diff --git a/.gitattributes b/.gitattributes index 2c029d9fd2330e6b128281245ab0169ce2753f2a..e61d4ea8a37ef626f5cebdd06078ca77407bad02 100644 --- a/.gitattributes +++ b/.gitattributes @@ -115,3 +115,7 @@ phivenv/Lib/site-packages/torch/lib/fbgemm.lib filter=lfs diff=lfs merge=lfs -te phivenv/Lib/site-packages/torch/lib/fmt.lib filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/lib/libiomp5md.dll filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/torch/lib/libittnotify.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/libprotobuf-lite.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/kineto.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/libprotobuf.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/torch/lib/libprotoc.lib filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/INSTALLER b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/LICENSE b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f41351673a223f58f44d545eff818e88e9e930bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/LICENSE @@ -0,0 +1,8884 @@ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions by Tri Dao: +Copyright (c) 2024 Tri Dao. +All rights reserved. + +All contributions by Arm: +Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + + +The PyTorch repository and source distributions bundle several libraries that are +compatibly licensed. We list these here. + +Name: third_party\FP16 +License: MIT +Files: third_party\FP16 + For details, see the files concatenated below: third_party\FP16\LICENSE + +Name: third_party\FXdiv +License: MIT +Files: third_party\FXdiv + For details, see the files concatenated below: third_party\FXdiv\LICENSE + +Name: third_party\NNPACK +License: BSD-2-Clause +Files: third_party\NNPACK + For details, see the files concatenated below: third_party\NNPACK\LICENSE + +Name: third_party\NVTX +License: Apache-2.0 with exception +Files: third_party\NVTX + For details, see the files concatenated below: third_party\NVTX\LICENSE.txt + +Name: third_party\NVTX\docs +License: Apache-2.0 with exception +Files: third_party\NVTX\docs + For details, see the files concatenated below: third_party\NVTX\docs\LICENSE.txt + +Name: third_party\NVTX\python +License: Apache-2.0 with exception +Files: third_party\NVTX\python + For details, see the files concatenated below: third_party\NVTX\python\LICENSE.txt + +Name: third_party\VulkanMemoryAllocator +License: MIT +Files: third_party\VulkanMemoryAllocator + For details, see the files concatenated below: third_party\VulkanMemoryAllocator\LICENSE.txt + +Name: third_party\XNNPACK +License: BSD-3-Clause +Files: third_party\XNNPACK + For details, see the files concatenated below: third_party\XNNPACK\LICENSE + +Name: third_party\benchmark +License: Apache-2.0 +Files: third_party\benchmark + For details, see the files concatenated below: third_party\benchmark\LICENSE + +Name: third_party\composable_kernel +License: MIT +Files: third_party\composable_kernel + For details, see the files concatenated below: third_party\composable_kernel\LICENSE + +Name: third_party\cpp-httplib +License: MIT +Files: third_party\cpp-httplib + For details, see the files concatenated below: third_party\cpp-httplib\LICENSE + +Name: third_party\cpuinfo +License: BSD-2-Clause +Files: 
third_party\cpuinfo + For details, see the files concatenated below: third_party\cpuinfo\LICENSE + +Name: third_party\cpuinfo\deps\clog +License: BSD-2-Clause +Files: third_party\cpuinfo\deps\clog + For details, see the files concatenated below: third_party\cpuinfo\deps\clog\LICENSE + +Name: third_party\cudnn_frontend +License: MIT +Files: third_party\cudnn_frontend + For details, see the files concatenated below: third_party\cudnn_frontend\LICENSE.txt + +Name: third_party\cutlass +License: BSD-3-Clause +Files: third_party\cutlass + For details, see the files concatenated below: third_party\cutlass\LICENSE.txt + +Name: third_party\cutlass\python +License: BSD-3-Clause +Files: third_party\cutlass\python + For details, see the files concatenated below: third_party\cutlass\python\LICENSE.txt + +Name: third_party\fbgemm +License: BSD-3-Clause +Files: third_party\fbgemm + For details, see the files concatenated below: third_party\fbgemm\LICENSE + +Name: third_party\fbgemm\external\composable_kernel +License: MIT +Files: third_party\fbgemm\external\composable_kernel + For details, see the files concatenated below: third_party\fbgemm\external\composable_kernel\LICENSE + +Name: third_party\fbgemm\external\cpuinfo +License: BSD-2-Clause +Files: third_party\fbgemm\external\cpuinfo + For details, see the files concatenated below: third_party\fbgemm\external\cpuinfo\LICENSE + +Name: third_party\fbgemm\external\cpuinfo\deps\clog +License: BSD-2-Clause +Files: third_party\fbgemm\external\cpuinfo\deps\clog + For details, see the files concatenated below: third_party\fbgemm\external\cpuinfo\deps\clog\LICENSE + +Name: third_party\fbgemm\external\cutlass +License: BSD-3-Clause +Files: third_party\fbgemm\external\cutlass + For details, see the files concatenated below: third_party\fbgemm\external\cutlass\LICENSE.txt + +Name: third_party\fbgemm\external\cutlass\python +License: BSD-3-Clause +Files: third_party\fbgemm\external\cutlass\python + For details, see the files concatenated below: third_party\fbgemm\external\cutlass\python\LICENSE.txt + +Name: third_party\fbgemm\external\googletest +License: BSD-3-Clause +Files: third_party\fbgemm\external\googletest + For details, see the files concatenated below: third_party\fbgemm\external\googletest\LICENSE + +Name: third_party\fbgemm\external\hipify_torch +License: MIT +Files: third_party\fbgemm\external\hipify_torch + For details, see the files concatenated below: third_party\fbgemm\external\hipify_torch\LICENSE.txt + +Name: third_party\fbgemm\fbgemm_gpu\src\quantize_ops\mx +License: MIT +Files: third_party\fbgemm\fbgemm_gpu\src\quantize_ops\mx + For details, see the files concatenated below: third_party\fbgemm\fbgemm_gpu\src\quantize_ops\mx\LICENSE + +Name: third_party\fbgemm\fbgemm_gpu\test\quantize\mx +License: MIT +Files: third_party\fbgemm\fbgemm_gpu\test\quantize\mx + For details, see the files concatenated below: third_party\fbgemm\fbgemm_gpu\test\quantize\mx\LICENSE + +Name: third_party\flash-attention +License: BSD-3-Clause +Files: third_party\flash-attention + For details, see the files concatenated below: third_party\flash-attention\LICENSE + +Name: third_party\flash-attention\csrc\composable_kernel +License: MIT +Files: third_party\flash-attention\csrc\composable_kernel + For details, see the files concatenated below: third_party\flash-attention\csrc\composable_kernel\LICENSE + +Name: third_party\flash-attention\csrc\cutlass +License: BSD-3-Clause +Files: third_party\flash-attention\csrc\cutlass + For details, see the files concatenated below: 
third_party\flash-attention\csrc\cutlass\LICENSE.txt + +Name: third_party\flash-attention\csrc\cutlass\python +License: BSD-3-Clause +Files: third_party\flash-attention\csrc\cutlass\python + For details, see the files concatenated below: third_party\flash-attention\csrc\cutlass\python\LICENSE.txt + +Name: third_party\flatbuffers +License: Apache-2.0 +Files: third_party\flatbuffers + For details, see the files concatenated below: third_party\flatbuffers\LICENSE + +Name: third_party\flatbuffers\dart +License: Apache-2.0 +Files: third_party\flatbuffers\dart + For details, see the files concatenated below: third_party\flatbuffers\dart\LICENSE + +Name: third_party\flatbuffers\swift +License: Apache-2.0 +Files: third_party\flatbuffers\swift + For details, see the files concatenated below: third_party\flatbuffers\swift\LICENSE + +Name: third_party\fmt +License: MIT with exception +Files: third_party\fmt + For details, see the files concatenated below: third_party\fmt\LICENSE + +Name: third_party\gemmlowp\gemmlowp +License: Apache-2.0 +Files: third_party\gemmlowp\gemmlowp + For details, see the files concatenated below: third_party\gemmlowp\gemmlowp\LICENSE + +Name: third_party\gloo +License: BSD-3-Clause +Files: third_party\gloo + For details, see the files concatenated below: third_party\gloo\LICENSE + +Name: third_party\googletest +License: BSD-3-Clause +Files: third_party\googletest + For details, see the files concatenated below: third_party\googletest\LICENSE + +Name: third_party\ideep +License: MIT +Files: third_party\ideep + For details, see the files concatenated below: third_party\ideep\LICENSE + +Name: third_party\ideep\mkl-dnn +License: Apache-2.0 +Files: third_party\ideep\mkl-dnn + For details, see the files concatenated below: third_party\ideep\mkl-dnn\LICENSE + +Name: third_party\ideep\mkl-dnn\tests\gtests\gtest +License: BSD-3-Clause +Files: third_party\ideep\mkl-dnn\tests\gtests\gtest + For details, see the files concatenated below: third_party\ideep\mkl-dnn\tests\gtests\gtest\LICENSE + +Name: third_party\kineto +License: BSD-3-Clause +Files: third_party\kineto + For details, see the files concatenated below: third_party\kineto\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog +License: MIT +Files: third_party\kineto\libkineto\third_party\dynolog + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM +License: Apache-2.0 +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\testing\python3\libs_3rdparty\colorama +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\testing\python3\libs_3rdparty\colorama + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\testing\python3\libs_3rdparty\colorama\LICENSE.txt + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\cpr +License: MIT +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\cpr + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\test +License: MIT with exception +Files: 
third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\test + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\test\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\fmt +License: MIT with exception +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\fmt + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\fmt\LICENSE.rst + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\googletest +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\googletest + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\googletest\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\json\test\thirdparty\doctest +License: MIT +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\json\test\thirdparty\doctest + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\json\test\thirdparty\doctest\LICENSE.txt + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\json\third_party\cpplint +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\json\third_party\cpplint + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\json\third_party\cpplint\LICENSE + +Name: third_party\kineto\libkineto\third_party\dynolog\third_party\pfs +License: Apache-2.0 +Files: third_party\kineto\libkineto\third_party\dynolog\third_party\pfs + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\dynolog\third_party\pfs\LICENSE + +Name: third_party\kineto\libkineto\third_party\fmt +License: MIT with exception +Files: third_party\kineto\libkineto\third_party\fmt + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\fmt\LICENSE + +Name: third_party\kineto\libkineto\third_party\googletest +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\googletest + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\googletest\LICENSE + +Name: third_party\kineto\libkineto\third_party\googletest\googlemock +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\googletest\googlemock + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\googletest\googlemock\LICENSE + +Name: third_party\kineto\libkineto\third_party\googletest\googlemock\scripts\generator +License: Apache-2.0 +Files: third_party\kineto\libkineto\third_party\googletest\googlemock\scripts\generator + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\googletest\googlemock\scripts\generator\LICENSE + +Name: third_party\kineto\libkineto\third_party\googletest\googletest +License: BSD-3-Clause +Files: third_party\kineto\libkineto\third_party\googletest\googletest + For details, see the files concatenated below: third_party\kineto\libkineto\third_party\googletest\googletest\LICENSE + +Name: third_party\kineto\tb_plugin +License: BSD-3-Clause +Files: third_party\kineto\tb_plugin + For details, see the files concatenated below: third_party\kineto\tb_plugin\LICENSE + +Name: third_party\mimalloc +License: MIT +Files: third_party\mimalloc + For details, see the files concatenated below: third_party\mimalloc\LICENSE 
+ +Name: third_party\miniz-3.0.2 +License: MIT +Files: third_party\miniz-3.0.2 + For details, see the files concatenated below: third_party\miniz-3.0.2\LICENSE + +Name: third_party\onnx +License: Apache-2.0 +Files: third_party\onnx + For details, see the files concatenated below: third_party\onnx\LICENSE + +Name: third_party\onnx\third_party\pybind11 +License: BSD-3-Clause +Files: third_party\onnx\third_party\pybind11 + For details, see the files concatenated below: third_party\onnx\third_party\pybind11\LICENSE + +Name: third_party\opentelemetry-cpp +License: Apache-2.0 +Files: third_party\opentelemetry-cpp + For details, see the files concatenated below: third_party\opentelemetry-cpp\LICENSE + +Name: third_party\opentelemetry-cpp\exporters\etw\include\opentelemetry\exporters\etw +License: MIT +Files: third_party\opentelemetry-cpp\exporters\etw\include\opentelemetry\exporters\etw + For details, see the files concatenated below: third_party\opentelemetry-cpp\exporters\etw\include\opentelemetry\exporters\etw\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\benchmark +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\third_party\benchmark + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\benchmark\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\googletest +License: BSD-3-Clause +Files: third_party\opentelemetry-cpp\third_party\googletest + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\googletest\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\ms-gsl +License: MIT +Files: third_party\opentelemetry-cpp\third_party\ms-gsl + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\ms-gsl\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\opentelemetry-proto +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\third_party\opentelemetry-proto + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\opentelemetry-proto\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\opentracing-cpp +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\third_party\opentracing-cpp + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\opentracing-cpp\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\catch2 +License: BSL-1.0 +Files: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\catch2 + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\catch2\LICENSE.txt + +Name: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\expected +License: MIT +Files: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\expected + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\expected\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\variant +License: BSD-3-Clause +Files: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\variant + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\variant\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp +License: MIT +Files: 
third_party\opentelemetry-cpp\third_party\prometheus-cpp + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\examples\rest\cJSON +License: MIT +Files: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\examples\rest\cJSON + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\examples\rest\cJSON\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.5.2 +License: MIT +Files: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.5.2 + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.5.2\LICENSE.txt + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.8.0 +License: MIT +Files: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.8.0 + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.8.0\LICENSE.txt + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest +License: BSD-3-Clause +Files: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\LICENSE + +Name: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\googlemock\scripts\generator +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\googlemock\scripts\generator + For details, see the files concatenated below: third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\googlemock\scripts\generator\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg +License: MIT +Files: third_party\opentelemetry-cpp\tools\vcpkg + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\boost-vcpkg-helpers +License: MIT +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\boost-vcpkg-helpers + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\boost-vcpkg-helpers\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\ffnvcodec +License: MIT with exception +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\ffnvcodec + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\ffnvcodec\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\gettimeofday +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\gettimeofday + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\gettimeofday\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\hungarian +License: Permissive (free to use) +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\hungarian + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\hungarian\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\irrlicht +License: MIT +Files: 
third_party\opentelemetry-cpp\tools\vcpkg\ports\irrlicht + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\irrlicht\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\libstemmer +License: BSD-3-Clause +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\libstemmer + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\libstemmer\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\pdcurses +License: Public Domain for core +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\pdcurses + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\pdcurses\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\physac +License: MIT +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\physac + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\physac\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\pqp +License: Apache-2.0 +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\pqp + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\pqp\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\sigslot +License: Public Domain +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\sigslot + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\sigslot\LICENSE + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\tensorflow-common +License: MIT +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\tensorflow-common + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\tensorflow-common\LICENSE.txt + +Name: third_party\opentelemetry-cpp\tools\vcpkg\ports\vulkan +License: Apache-2.0 with exception +Files: third_party\opentelemetry-cpp\tools\vcpkg\ports\vulkan + For details, see the files concatenated below: third_party\opentelemetry-cpp\tools\vcpkg\ports\vulkan\LICENSE.txt + +Name: third_party\protobuf +License: BSD-3-Clause +Files: third_party\protobuf + For details, see the files concatenated below: third_party\protobuf\LICENSE + +Name: third_party\protobuf\third_party\benchmark +License: Apache-2.0 +Files: third_party\protobuf\third_party\benchmark + For details, see the files concatenated below: third_party\protobuf\third_party\benchmark\LICENSE + +Name: third_party\protobuf\third_party\googletest +License: BSD-3-Clause +Files: third_party\protobuf\third_party\googletest + For details, see the files concatenated below: third_party\protobuf\third_party\googletest\LICENSE + +Name: third_party\protobuf\third_party\googletest\googlemock +License: BSD-3-Clause +Files: third_party\protobuf\third_party\googletest\googlemock + For details, see the files concatenated below: third_party\protobuf\third_party\googletest\googlemock\LICENSE + +Name: third_party\protobuf\third_party\googletest\googlemock\scripts\generator +License: Apache-2.0 +Files: third_party\protobuf\third_party\googletest\googlemock\scripts\generator + For details, see the files concatenated below: third_party\protobuf\third_party\googletest\googlemock\scripts\generator\LICENSE + +Name: third_party\protobuf\third_party\googletest\googletest +License: BSD-3-Clause +Files: third_party\protobuf\third_party\googletest\googletest + For details, see the files concatenated below: third_party\protobuf\third_party\googletest\googletest\LICENSE + +Name: third_party\psimd +License: 
MIT +Files: third_party\psimd + For details, see the files concatenated below: third_party\psimd\LICENSE + +Name: third_party\pthreadpool +License: BSD-2-Clause +Files: third_party\pthreadpool + For details, see the files concatenated below: third_party\pthreadpool\LICENSE + +Name: third_party\pybind11 +License: BSD-3-Clause +Files: third_party\pybind11 + For details, see the files concatenated below: third_party\pybind11\LICENSE + +Name: third_party\python-peachpy +License: BSD-2-Clause +Files: third_party\python-peachpy + For details, see the files concatenated below: third_party\python-peachpy\LICENSE.rst + +Name: third_party\sleef +License: BSL-1.0 +Files: third_party\sleef + For details, see the files concatenated below: third_party\sleef\LICENSE.txt + +Name: third_party\tensorpipe +License: BSD-3-Clause +Files: third_party\tensorpipe + For details, see the files concatenated below: third_party\tensorpipe\LICENSE.txt + +Name: third_party\tensorpipe\third_party\googletest +License: BSD-3-Clause +Files: third_party\tensorpipe\third_party\googletest + For details, see the files concatenated below: third_party\tensorpipe\third_party\googletest\LICENSE + +Name: third_party\tensorpipe\third_party\googletest\googlemock +License: BSD-3-Clause +Files: third_party\tensorpipe\third_party\googletest\googlemock + For details, see the files concatenated below: third_party\tensorpipe\third_party\googletest\googlemock\LICENSE + +Name: third_party\tensorpipe\third_party\googletest\googlemock\scripts\generator +License: Apache-2.0 +Files: third_party\tensorpipe\third_party\googletest\googlemock\scripts\generator + For details, see the files concatenated below: third_party\tensorpipe\third_party\googletest\googlemock\scripts\generator\LICENSE + +Name: third_party\tensorpipe\third_party\googletest\googletest +License: BSD-3-Clause +Files: third_party\tensorpipe\third_party\googletest\googletest + For details, see the files concatenated below: third_party\tensorpipe\third_party\googletest\googletest\LICENSE + +Name: third_party\tensorpipe\third_party\libnop +License: Apache-2.0 +Files: third_party\tensorpipe\third_party\libnop + For details, see the files concatenated below: third_party\tensorpipe\third_party\libnop\LICENSE + +Name: third_party\tensorpipe\third_party\libuv +License: MIT +Files: third_party\tensorpipe\third_party\libuv + For details, see the files concatenated below: third_party\tensorpipe\third_party\libuv\LICENSE + +Name: third_party\tensorpipe\third_party\pybind11 +License: BSD-3-Clause +Files: third_party\tensorpipe\third_party\pybind11 + For details, see the files concatenated below: third_party\tensorpipe\third_party\pybind11\LICENSE + +third_party\FP16\LICENSE +------------------------ +The MIT License (MIT) + +Copyright (c) 2017 Facebook Inc. +Copyright (c) 2017 Georgia Institute of Technology +Copyright 2019 Google LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +third_party\FXdiv\LICENSE +------------------------- +The MIT License (MIT) + +Copyright (c) 2017 Facebook Inc. +Copyright (c) 2016-2017 Marat Dukhan + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +third_party\NNPACK\LICENSE +-------------------------- +Copyright (c) 2017 Facebook Inc. +Copyright (c) 2015-2017, Georgia Institute of Technology +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\NVTX\LICENSE.txt +---------------------------- +============================================================================== +NVTX is under the Apache License v2.0 with LLVM Exceptions: +============================================================================== + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + + + +third_party\NVTX\docs\LICENSE.txt +--------------------------------- +============================================================================== +NVTX is under the Apache License v2.0 with LLVM Exceptions: +============================================================================== + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. 
+ + + +third_party\NVTX\python\LICENSE.txt +----------------------------------- +============================================================================== +NVTX is under the Apache License v2.0 with LLVM Exceptions: +============================================================================== + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. 
Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + + + +third_party\VulkanMemoryAllocator\LICENSE.txt +--------------------------------------------- +Copyright (c) 2017-2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + + +third_party\XNNPACK\LICENSE +--------------------------- +BSD License + +For XNNPACK software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +Copyright 2019 Google LLC + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\benchmark\LICENSE +----------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\composable_kernel\LICENSE +------------------------------------- +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + +third_party\cpp-httplib\LICENSE +------------------------------- +The MIT License (MIT) + +Copyright (c) 2017 yhirose + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + + +third_party\cpuinfo\LICENSE +--------------------------- +Copyright (c) 2019 Google LLC +Copyright (c) 2017-2018 Facebook Inc. +Copyright (C) 2012-2017 Georgia Institute of Technology +Copyright (C) 2010-2012 Marat Dukhan + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\cpuinfo\deps\clog\LICENSE +------------------------------------- +Copyright (C) 2018 Marat Dukhan +Copyright (c) 2017-2018 Facebook Inc. +Copyright (c) 2017 Georgia Institute of Technology + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\cudnn_frontend\LICENSE.txt +-------------------------------------- +/* + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + + +third_party\cutlass\LICENSE.txt +------------------------------- +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\cutlass\python\LICENSE.txt +-------------------------------------- +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\LICENSE +-------------------------- +BSD License + +For FBGEMM software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\composable_kernel\LICENSE +----------------------------------------------------- +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\fbgemm\external\cpuinfo\LICENSE +------------------------------------------- +Copyright (c) 2019 Google LLC +Copyright (c) 2017-2018 Facebook Inc. +Copyright (C) 2012-2017 Georgia Institute of Technology +Copyright (C) 2010-2012 Marat Dukhan + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\cpuinfo\deps\clog\LICENSE +----------------------------------------------------- +Copyright (C) 2018 Marat Dukhan +Copyright (c) 2017-2018 Facebook Inc. +Copyright (c) 2017 Georgia Institute of Technology + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\cutlass\LICENSE.txt +----------------------------------------------- +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\cutlass\python\LICENSE.txt +------------------------------------------------------ +Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\googletest\LICENSE +---------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\fbgemm\external\hipify_torch\LICENSE.txt +---------------------------------------------------- +MIT License + +Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\fbgemm\fbgemm_gpu\src\quantize_ops\mx\LICENSE +--------------------------------------------------------- + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + +third_party\fbgemm\fbgemm_gpu\test\quantize\mx\LICENSE +------------------------------------------------------ + MIT License + + Copyright (c) Microsoft Corporation. 
+ + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE + + +third_party\flash-attention\LICENSE +----------------------------------- +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\flash-attention\csrc\composable_kernel\LICENSE +---------------------------------------------------------- +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\flash-attention\csrc\cutlass\LICENSE.txt +---------------------------------------------------- +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\flash-attention\csrc\cutlass\python\LICENSE.txt +----------------------------------------------------------- +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. 
Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\flatbuffers\LICENSE +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\flatbuffers\dart\LICENSE +------------------------------------ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2014 Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\flatbuffers\swift\LICENSE +------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +third_party\fmt\LICENSE +----------------------- +Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- Optional exception to the license --- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into a machine-executable object form of such +source code, you may redistribute such embedded portions in such object form +without including the above copyright and permission notices. + + +third_party\gemmlowp\gemmlowp\LICENSE +------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\gloo\LICENSE +------------------------ +BSD License + +For Gloo software + +Copyright (c) 2017-present, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\googletest\LICENSE +------------------------------ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\ideep\LICENSE +------------------------- +Copyright (c) 2018 Intel Corporation. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ + +third_party\ideep\mkl-dnn\LICENSE +--------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + ============================================================================ + + Copyright 2016-2023 Intel Corporation + Copyright 2018 YANDEX LLC + Copyright 2019-2023 FUJITSU LIMITED + Copyright 2020-2023 Arm Ltd. and affiliates + Copyright 2020-2022 Codeplay Software Limited + Copyright 2021 Alanna Tempest + Copyright 2022-2023 IBM Corporation + Copyright 2023 KNS Group LLC (YADRO) + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This distribution includes third party software ("third party programs"). + This third party software, even if included with the distribution of + the Intel software, may be governed by separate license terms, including + without limitation, third party license terms, other Intel software license + terms, and open source software license terms. These separate license terms + govern your use of the third party programs as set forth in the + "THIRD-PARTY-PROGRAMS" file. + + +third_party\ideep\mkl-dnn\tests\gtests\gtest\LICENSE +---------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\LICENSE +-------------------------- +BSD License + +For Kineto software + +Copyright (c) Meta Platforms, Inc. and affiliates. + +All contributions by Microsoft: +Copyright (c) Microsoft Corporation. (The Azure AI Platform team) + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\libkineto\third_party\dynolog\LICENSE +-------------------------------------------------------- +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\LICENSE +------------------------------------------------------------------------- +Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + +third_party\kineto\libkineto\third_party\dynolog\third_party\DCGM\testing\python3\libs_3rdparty\colorama\LICENSE.txt +-------------------------------------------------------------------------------------------------------------------- +Copyright (c) 2010 Jonathan Hartley + +Released under the New BSD license (reproduced below), or alternatively you may +use this software under any OSI approved open source license such as those at +http://opensource.org/licenses/alphabetical + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name(s) of the copyright holders, nor those of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\LICENSE +------------------------------------------------------------------------ +This license applies to everything except the contents of the "test" +directory and its subdirectories. + +MIT License + +Copyright (c) 2017-2021 Huu Nguyen +Copyright (c) 2022 libcpr and many other contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +third_party\kineto\libkineto\third_party\dynolog\third_party\cpr\test\LICENSE +----------------------------------------------------------------------------- +This license applies to everything inside this directory and all +subdirectories. + + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users.
We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. 
The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. 
+ + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. 
This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). 
+ + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". 
+ + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. 
+ + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
+ +third_party\kineto\libkineto\third_party\dynolog\third_party\fmt\LICENSE.rst +---------------------------------------------------------------------------- +Copyright (c) 2012 - present, Victor Zverovich + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- Optional exception to the license --- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into a machine-executable object form of such +source code, you may redistribute such embedded portions in such object form +without including the above copyright and permission notices. + + +third_party\kineto\libkineto\third_party\dynolog\third_party\googletest\LICENSE +------------------------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +third_party\kineto\libkineto\third_party\dynolog\third_party\json\test\thirdparty\doctest\LICENSE.txt +----------------------------------------------------------------------------------------------------- +The MIT License (MIT) + +Copyright (c) 2016-2021 Viktor Kirilov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\kineto\libkineto\third_party\dynolog\third_party\json\third_party\cpplint\LICENSE +--------------------------------------------------------------------------------------------- +cpplint.py and its corresponding unit tests are Copyright (C) 2009 Google Inc. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\libkineto\third_party\dynolog\third_party\pfs\LICENSE +------------------------------------------------------------------------ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2020-present Daniel Trugman + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +third_party\kineto\libkineto\third_party\fmt\LICENSE +---------------------------------------------------- +Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- Optional exception to the license --- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into a machine-executable object form of such +source code, you may redistribute such embedded portions in such object form +without including the above copyright and permission notices. + + +third_party\kineto\libkineto\third_party\googletest\LICENSE +----------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\libkineto\third_party\googletest\googlemock\LICENSE +---------------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\libkineto\third_party\googletest\googlemock\scripts\generator\LICENSE +---------------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2007] Neal Norwitz + Portions Copyright [2007] Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\kineto\libkineto\third_party\googletest\googletest\LICENSE +---------------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\kineto\tb_plugin\LICENSE +------------------------------------ +BSD License + +For Kineto software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +All contributions by Microsoft: +Copyright (c) Microsoft Corporation. (The Azure AI Platform team) + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\mimalloc\LICENSE +---------------------------- +MIT License + +Copyright (c) 2018-2025 Microsoft Corporation, Daan Leijen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\miniz-3.0.2\LICENSE +------------------------------- +Copyright 2013-2014 RAD Game Tools and Valve Software +Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC + +All Rights Reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + + +third_party\onnx\LICENSE +------------------------ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\onnx\third_party\pybind11\LICENSE +--------------------------------------------- +Copyright (c) 2016 Wenzel Jakob , All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Please also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of +external contributions to this project including patches, pull requests, etc. + + +third_party\opentelemetry-cpp\LICENSE +------------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\opentelemetry-cpp\exporters\etw\include\opentelemetry\exporters\etw\LICENSE +--------------------------------------------------------------------------------------- +TraceLogging Dynamic for Windows + +Copyright (c) Microsoft Corporation. All rights reserved. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\opentelemetry-cpp\third_party\benchmark\LICENSE +----------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\opentelemetry-cpp\third_party\googletest\LICENSE +------------------------------------------------------------ +Copyright 2008, Google Inc. +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\opentelemetry-cpp\third_party\ms-gsl\LICENSE +-------------------------------------------------------- +Copyright (c) 2015 Microsoft Corporation. All rights reserved. + +This code is licensed under the MIT License (MIT). + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + + +third_party\opentelemetry-cpp\third_party\opentelemetry-proto\LICENSE +--------------------------------------------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +third_party\opentelemetry-cpp\third_party\opentracing-cpp\LICENSE +----------------------------------------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright The OpenTracing Authors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\catch2\LICENSE.txt +---------------------------------------------------------------------------------------------------------- +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + + +third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\expected\LICENSE +-------------------------------------------------------------------------------------------------------- +The MIT License (MIT) + +Copyright (c) 2015 Martin Moene +Copyright (c) 2015 Microsoft Corporation. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + + +third_party\opentelemetry-cpp\third_party\opentracing-cpp\3rd_party\include\opentracing\variant\LICENSE +------------------------------------------------------------------------------------------------------- +Copyright (c) MapBox +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +- Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +- Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +- Neither the name "MapBox" nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\LICENSE +---------------------------------------------------------------- +MIT License + +Copyright (c) 2016-2021 Jupp Mueller +Copyright (c) 2017-2022 Gregor Jasny + +And many contributors, see +https://github.com/jupp0r/prometheus-cpp/graphs/contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\examples\rest\cJSON\LICENSE +------------------------------------------------------------------------------------------------------ +Copyright (c) 2009-2017 Dave Gamble and cJSON contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + + + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.5.2\LICENSE.txt +-------------------------------------------------------------------------------------------------------------------- +=============== +Duktape license +=============== + +(http://opensource.org/licenses/MIT) + +Copyright (c) 2013-2016 by Duktape authors (see AUTHORS.rst) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\civetweb\src\third_party\duktape-1.8.0\LICENSE.txt +-------------------------------------------------------------------------------------------------------------------- +=============== +Duktape license +=============== + +(http://opensource.org/licenses/MIT) + +Copyright (c) 2013-2017 by Duktape authors (see AUTHORS.rst) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\LICENSE +------------------------------------------------------------------------------------ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\opentelemetry-cpp\third_party\prometheus-cpp\3rdparty\googletest\googlemock\scripts\generator\LICENSE +----------------------------------------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2007] Neal Norwitz + Portions Copyright [2007] Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +third_party\opentelemetry-cpp\tools\vcpkg\LICENSE.txt +----------------------------------------------------- +MIT License + +Copyright (c) Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy of this +software and associated documentation files (the "Software"), to deal in the Software +without restriction, including without limitation the rights to use, copy, modify, +merge, publish, distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be included in all copies +or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF +CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\boost-vcpkg-helpers\LICENSE.txt +------------------------------------------------------------------------------- +Copyright (c) Microsoft Corporation + +All rights reserved. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\ffnvcodec\LICENSE.txt +--------------------------------------------------------------------- +GNU LESSER GENERAL PUBLIC LICENSE +Version 2.1, February 1999 + +Copyright (C) 1991, 1999 Free Software Foundation, Inc. +51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +Everyone is permitted to copy and distribute verbatim copies +of this license document, but changing it is not allowed. + +[This is the first released version of the Lesser GPL. It also counts + as the successor of the GNU Library Public License, version 2, hence + the version number 2.1.] +Preamble +The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public Licenses are intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. 
+ +This license, the Lesser General Public License, applies to some specially designated software packages--typically libraries--of the Free Software Foundation and other authors who decide to use it. You can use it too, but we suggest you first think carefully about whether this license or the ordinary General Public License is the better strategy to use in any particular case, based on the explanations below. + +When we speak of free software, we are referring to freedom of use, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish); that you receive source code or can get it if you want it; that you can change the software and use pieces of it in new free programs; and that you are informed that you can do these things. + +To protect your rights, we need to make restrictions that forbid distributors to deny you these rights or to ask you to surrender these rights. These restrictions translate to certain responsibilities for you if you distribute copies of the library or if you modify it. + +For example, if you distribute copies of the library, whether gratis or for a fee, you must give the recipients all the rights that we gave you. You must make sure that they, too, receive or can get the source code. If you link other code with the library, you must provide complete object files to the recipients, so that they can relink them with the library after making changes to the library and recompiling it. And you must show them these terms so they know their rights. + +We protect your rights with a two-step method: (1) we copyright the library, and (2) we offer you this license, which gives you legal permission to copy, distribute and/or modify the library. + +To protect each distributor, we want to make it very clear that there is no warranty for the free library. Also, if the library is modified by someone else and passed on, the recipients should know that what they have is not the original version, so that the original author's reputation will not be affected by problems that might be introduced by others. + +Finally, software patents pose a constant threat to the existence of any free program. We wish to make sure that a company cannot effectively restrict the users of a free program by obtaining a restrictive license from a patent holder. Therefore, we insist that any patent license obtained for a version of the library must be consistent with the full freedom of use specified in this license. + +Most GNU software, including some libraries, is covered by the ordinary GNU General Public License. This license, the GNU Lesser General Public License, applies to certain designated libraries, and is quite different from the ordinary General Public License. We use this license for certain libraries in order to permit linking those libraries into non-free programs. + +When a program is linked with a library, whether statically or using a shared library, the combination of the two is legally speaking a combined work, a derivative of the original library. The ordinary General Public License therefore permits such linking only if the entire combination fits its criteria of freedom. The Lesser General Public License permits more lax criteria for linking other code with the library. + +We call this license the "Lesser" General Public License because it does Less to protect the user's freedom than the ordinary General Public License. 
It also provides other free software developers Less of an advantage over competing non-free programs. These disadvantages are the reason we use the ordinary General Public License for many libraries. However, the Lesser license provides advantages in certain special circumstances. + +For example, on rare occasions, there may be a special need to encourage the widest possible use of a certain library, so that it becomes a de-facto standard. To achieve this, non-free programs must be allowed to use the library. A more frequent case is that a free library does the same job as widely used non-free libraries. In this case, there is little to gain by limiting the free library to free software only, so we use the Lesser General Public License. + +In other cases, permission to use a particular library in non-free programs enables a greater number of people to use a large body of free software. For example, permission to use the GNU C Library in non-free programs enables many more people to use the whole GNU operating system, as well as its variant, the GNU/Linux operating system. + +Although the Lesser General Public License is Less protective of the users' freedom, it does ensure that the user of a program that is linked with the Library has the freedom and the wherewithal to run that program using a modified version of the Library. + +The precise terms and conditions for copying, distribution and modification follow. Pay close attention to the difference between a "work based on the library" and a "work that uses the library". The former contains code derived from the library, whereas the latter must be combined with the library in order to run. + +TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION +0. This License Agreement applies to any software library or other program which contains a notice placed by the copyright holder or other authorized party saying it may be distributed under the terms of this Lesser General Public License (also called "this License"). Each licensee is addressed as "you". + +A "library" means a collection of software functions and/or data prepared so as to be conveniently linked with application programs (which use some of those functions and data) to form executables. + +The "Library", below, refers to any such software library or work which has been distributed under these terms. A "work based on the Library" means either the Library or any derivative work under copyright law: that is to say, a work containing the Library or a portion of it, either verbatim or with modifications and/or translated straightforwardly into another language. (Hereinafter, translation is included without limitation in the term "modification".) + +"Source code" for a work means the preferred form of the work for making modifications to it. For a library, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the library. + +Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running a program using the Library is not restricted, and output from such a program is covered only if its contents constitute a work based on the Library (independent of the use of the Library in a tool for writing it). Whether that is true depends on what the Library does and what the program that uses the Library does. + +1. 
You may copy and distribute verbatim copies of the Library's complete source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and distribute a copy of this License along with the Library. + +You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. + +2. You may modify your copy or copies of the Library or any portion of it, thus forming a work based on the Library, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: + +a) The modified work must itself be a software library. +b) You must cause the files modified to carry prominent notices stating that you changed the files and the date of any change. +c) You must cause the whole of the work to be licensed at no charge to all third parties under the terms of this License. +d) If a facility in the modified Library refers to a function or a table of data to be supplied by an application program that uses the facility, other than as an argument passed when the facility is invoked, then you must make a good faith effort to ensure that, in the event an application does not supply such function or table, the facility still operates, and performs whatever part of its purpose remains meaningful. +(For example, a function in a library to compute square roots has a purpose that is entirely well-defined independent of the application. Therefore, Subsection 2d requires that any application-supplied function or table used by this function must be optional: if the application does not supply it, the square root function must still compute square roots.) + +These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Library, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Library, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Library. + +In addition, mere aggregation of another work not based on the Library with the Library (or with a work based on the Library) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. + +3. You may opt to apply the terms of the ordinary GNU General Public License instead of this License to a given copy of the Library. To do this, you must alter all the notices that refer to this License, so that they refer to the ordinary GNU General Public License, version 2, instead of to this License. (If a newer version than version 2 of the ordinary GNU General Public License has appeared, then you can specify that version instead if you wish.) Do not make any other change in these notices. 
+ +Once this change is made in a given copy, it is irreversible for that copy, so the ordinary GNU General Public License applies to all subsequent copies and derivative works made from that copy. + +This option is useful when you wish to copy part of the code of the Library into a program that is not a library. + +4. You may copy and distribute the Library (or a portion or derivative of it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange. + +If distribution of object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place satisfies the requirement to distribute the source code, even though third parties are not compelled to copy the source along with the object code. + +5. A program that contains no derivative of any portion of the Library, but is designed to work with the Library by being compiled or linked with it, is called a "work that uses the Library". Such a work, in isolation, is not a derivative work of the Library, and therefore falls outside the scope of this License. + +However, linking a "work that uses the Library" with the Library creates an executable that is a derivative of the Library (because it contains portions of the Library), rather than a "work that uses the library". The executable is therefore covered by this License. Section 6 states terms for distribution of such executables. + +When a "work that uses the Library" uses material from a header file that is part of the Library, the object code for the work may be a derivative work of the Library even though the source code is not. Whether this is true is especially significant if the work can be linked without the Library, or if the work is itself a library. The threshold for this to be true is not precisely defined by law. + +If such an object file uses only numerical parameters, data structure layouts and accessors, and small macros and small inline functions (ten lines or less in length), then the use of the object file is unrestricted, regardless of whether it is legally a derivative work. (Executables containing this object code plus portions of the Library will still fall under Section 6.) + +Otherwise, if the work is a derivative of the Library, you may distribute the object code for the work under the terms of Section 6. Any executables containing that work also fall under Section 6, whether or not they are linked directly with the Library itself. + +6. As an exception to the Sections above, you may also combine or link a "work that uses the Library" with the Library to produce a work containing portions of the Library, and distribute that work under terms of your choice, provided that the terms permit modification of the work for the customer's own use and reverse engineering for debugging such modifications. + +You must give prominent notice with each copy of the work that the Library is used in it and that the Library and its use are covered by this License. You must supply a copy of this License. If the work during execution displays copyright notices, you must include the copyright notice for the Library among them, as well as a reference directing the user to the copy of this License. 
Also, you must do one of these things: + +a) Accompany the work with the complete corresponding machine-readable source code for the Library including whatever changes were used in the work (which must be distributed under Sections 1 and 2 above); and, if the work is an executable linked with the Library, with the complete machine-readable "work that uses the Library", as object code and/or source code, so that the user can modify the Library and then relink to produce a modified executable containing the modified Library. (It is understood that the user who changes the contents of definitions files in the Library will not necessarily be able to recompile the application to use the modified definitions.) +b) Use a suitable shared library mechanism for linking with the Library. A suitable mechanism is one that (1) uses at run time a copy of the library already present on the user's computer system, rather than copying library functions into the executable, and (2) will operate properly with a modified version of the library, if the user installs one, as long as the modified version is interface-compatible with the version that the work was made with. +c) Accompany the work with a written offer, valid for at least three years, to give the same user the materials specified in Subsection 6a, above, for a charge no more than the cost of performing this distribution. +d) If distribution of the work is made by offering access to copy from a designated place, offer equivalent access to copy the above specified materials from the same place. +e) Verify that the user has already received a copy of these materials or that you have already sent this user a copy. +For an executable, the required form of the "work that uses the Library" must include any data and utility programs needed for reproducing the executable from it. However, as a special exception, the materials to be distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. + +It may happen that this requirement contradicts the license restrictions of other proprietary libraries that do not normally accompany the operating system. Such a contradiction means you cannot use both them and the Library together in an executable that you distribute. + +7. You may place library facilities that are a work based on the Library side-by-side in a single library together with other library facilities not covered by this License, and distribute such a combined library, provided that the separate distribution of the work based on the Library and of the other library facilities is otherwise permitted, and provided that you do these two things: + +a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities. This must be distributed under the terms of the Sections above. +b) Give prominent notice with the combined library of the fact that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. +8. You may not copy, modify, sublicense, link with, or distribute the Library except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense, link with, or distribute the Library is void, and will automatically terminate your rights under this License. 
However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. + +9. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Library or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Library (or any work based on the Library), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Library or works based on it. + +10. Each time you redistribute the Library (or any work based on the Library), the recipient automatically receives a license from the original licensor to copy, distribute, link with or modify the Library subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties with this License. + +11. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Library at all. For example, if a patent license would not permit royalty-free redistribution of the Library by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Library. + +If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply, and the section as a whole is intended to apply in other circumstances. + +It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. + +This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. + +12. If the distribution and/or use of the Library is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Library under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. + +13. The Free Software Foundation may publish revised and/or new versions of the Lesser General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Library specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Library does not specify a license version number, you may choose any version ever published by the Free Software Foundation. + +14. If you wish to incorporate parts of the Library into other free programs whose distribution conditions are incompatible with these, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. + +NO WARRANTY + +15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + +16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +END OF TERMS AND CONDITIONS +How to Apply These Terms to Your New Libraries +If you develop a new library, and you want it to be of the greatest possible use to the public, we recommend making it free software that everyone can redistribute and change. You can do so by permitting redistribution under these terms (or, alternatively, under the terms of the ordinary General Public License). + +To apply these terms, attach the following notices to the library. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. + +one line to give the library's name and an idea of what it does. +Copyright (C) year name of author + +This library is free software; you can redistribute it and/or +modify it under the terms of the GNU Lesser General Public +License as published by the Free Software Foundation; either +version 2.1 of the License, or (at your option) any later version. + +This library is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the GNU +Lesser General Public License for more details. + +You should have received a copy of the GNU Lesser General Public +License along with this library; if not, write to the Free Software +Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA +Also add information on how to contact you by electronic and paper mail. + +You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the library, if necessary. Here is a sample; alter the names: + +Yoyodyne, Inc., hereby disclaims all copyright interest in +the library `Frob' (a library for tweaking knobs) written +by James Random Hacker. + +signature of Ty Coon, 1 April 1990 +Ty Coon, President of Vice +That's all there is to it! + +third_party\opentelemetry-cpp\tools\vcpkg\ports\gettimeofday\LICENSE +-------------------------------------------------------------------- +/* + * Copied from PostgreSQL source: + * http://doxygen.postgresql.org/gettimeofday_8c_source.html + * + */ + +/* + * gettimeofday.c + * Win32 gettimeofday() replacement + * + * src/port/gettimeofday.c + * + * Copyright (c) 2003 SRA, Inc. + * Copyright (c) 2003 SKC, Inc. + * + * Permission to use, copy, modify, and distribute this software and + * its documentation for any purpose, without fee, and without a + * written agreement is hereby granted, provided that the above + * copyright notice and this paragraph and the following two + * paragraphs appear in all copies. + * + * IN NO EVENT SHALL THE AUTHOR BE LIABLE TO ANY PARTY FOR DIRECT, + * INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING + * LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS + * DOCUMENTATION, EVEN IF THE UNIVERSITY OF CALIFORNIA HAS BEEN ADVISED + * OF THE POSSIBILITY OF SUCH DAMAGE. + * + * THE AUTHOR SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS + * IS" BASIS, AND THE AUTHOR HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, + * SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. + */ + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\hungarian\LICENSE.txt +--------------------------------------------------------------------- +/******************************************************************** + ******************************************************************** + ** + ** libhungarian by Cyrill Stachniss, 2004 + ** + ** + ** Solving the Minimum Assignment Problem using the + ** Hungarian Method. + ** + ** ** This file may be freely copied and distributed! ** + ** + ** Parts of the used code was originally provided by the + ** "Stanford GraphGase", but I made changes to this code. + ** As asked by the copyright node of the "Stanford GraphGase", + ** I hereby proclaim that this file are *NOT* part of the + ** "Stanford GraphGase" distrubition! + ** + ** This file is distributed in the hope that it will be useful, + ** but WITHOUT ANY WARRANTY; without even the implied + ** warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + ** PURPOSE. 
+ ** + ******************************************************************** + ********************************************************************/ + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\irrlicht\LICENSE.txt +-------------------------------------------------------------------- +The Irrlicht Engine License +=========================== + +Copyright (C) 2002-2015 Nikolaus Gebhardt + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgement in the product documentation would be + appreciated but is not required. +2. Altered source versions must be clearly marked as such, and must not be + misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. + +third_party\opentelemetry-cpp\tools\vcpkg\ports\libstemmer\LICENSE +------------------------------------------------------------------ +Snowball - License +Except where explicitly noted, all the software given out on this Snowball site is covered by the 3-clause BSD License: + +Copyright (c) 2001, Dr Martin Porter, +Copyright (c) 2002, Richard Boulton. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +Essentially, all this means is that you can do what you like with the code, except claim another Copyright for it, or claim that it is issued under a different license. The software is also issued without warranties, which means that if anyone suffers through its use, they cannot come back and sue you. You also have to alert anyone to whom you give the Snowball software to the fact that it is covered by the BSD license. 
+ +We have not bothered to insert the licensing arrangement into the text of the Snowball software. + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\pdcurses\LICENSE +---------------------------------------------------------------- +The core package is in the public domain, but small portions of PDCurses are subject to copyright under various licenses. + +The win32 files are released to the public domain. + +If you use PDCurses in an application, an acknowledgement would be appreciated, but is not mandatory. If you make corrections or enhancements to PDCurses, please forward them to the current maintainer for the benefit of other users. + +This software is provided AS IS with NO WARRANTY whatsoever. + +third_party\opentelemetry-cpp\tools\vcpkg\ports\physac\LICENSE +-------------------------------------------------------------- +MIT License + +Copyright (c) 2022 Víctor Fisac + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +third_party\opentelemetry-cpp\tools\vcpkg\ports\pqp\LICENSE +----------------------------------------------------------- +Copyright 1999 University of North Carolina at Chapel Hill. +All rights reserved. + +Permission to use, copy, modify, and distribute this software and its +documentation for educational, research, and non-profit purposes, without fee, +and without a written agreement is hereby granted, provided that the above +copyright notice and the following three paragraphs appear in all copies. + +IN NO EVENT SHALL THE UNIVERSITY OF NORTH CAROLINA AT CHAPEL HILL BE LIABLE TO +ANY PARTY FOR DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, +INCLUDING LOST PROFITS, ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS +DOCUMENTATION, EVEN IF THE UNIVERSITY OF NORTH CAROLINA AT CHAPEL HILL HAS +BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +THE UNIVERSITY OF NORTH CAROLINA AT CHAPEL HILL SPECIFICALLY DISCLAIMS ANY +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED +HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF NORTH CAROLINA AT +CHAPEL HILL HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, +ENHANCEMENTS, OR MODIFICATIONS. 
+ +The authors may be contacted via: + +US Mail: Eric Larsen, Stefan Gottschalk + Department of Computer Science + Sitterson Hall, CB #3175 + University of North Carolina + Chapel Hill, NC 27599-3175 + +Phone: (919) 962-1749 + +Email: geom@cs.unc.edu + +third_party\opentelemetry-cpp\tools\vcpkg\ports\sigslot\LICENSE +--------------------------------------------------------------- +License +The sigslot library has been placed in the public domain. This means that you are free to use it however you like. + +The author takes no responsibility or liability of any kind for any use that you may make of this library. + +If you screw up, it's your fault. + +If the library screws up, you got it for free, so you should have tested it better - it's still your responsibility. + +third_party\opentelemetry-cpp\tools\vcpkg\ports\tensorflow-common\LICENSE.txt +----------------------------------------------------------------------------- +Copyright (c) Microsoft Corporation + +All rights reserved. + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + + +third_party\opentelemetry-cpp\tools\vcpkg\ports\vulkan\LICENSE.txt +------------------------------------------------------------------ +/* +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + + +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. + +Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. + +Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. + +You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and +You must cause any modified files to carry prominent notices stating that You changed the files; and +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. + +Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. + +This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. + +Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. 
You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. + +In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. + +While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +=============================================================================================================================================== + +/Copyright (C) 2012 LunarG, Inc. +//All rights reserved. +// +//Redistribution and use in source and binary forms, with or without +//modification, are permitted provided that the following conditions +//are met: +// +// Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// +// Neither the name of LunarG Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +//THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +//"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +//LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +//FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +//COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +//BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +//LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +//CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +//LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +//ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +//POSSIBILITY OF SUCH DAMAGE. 
+ +=============================================================================================================================================== + +#============================================================================= +# Copyright 2007-2009 Kitware, Inc. +# Copyright 2007-2008 Miguel A. Figueroa-Villanueva +# +# Distributed under the OSI-approved BSD License (the "License"); +# see accompanying file Copyright_cmake.txt for details. +# +# This software is distributed WITHOUT ANY WARRANTY; without even the +# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the License for more information. +#============================================================================= +# (To distributed this file outside of CMake, substitute the full +# License text for the above reference.) + + +============================================================================================================================================== + +// +// Copyright (C) 2015-2018 Google, Inc. +// Copyright (C) +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// +// Neither the name of 3Dlabs Inc. Ltd. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// + +========================================================================================================================================== + +Note: This license has also been called the "New BSD License" or "Modified BSD License". See also the 2-clause BSD License. +Copyright +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +========================================================================================================================================== + +/* +* xxHash - Fast Hash algorithm +* Copyright (C) 2012-2016, Yann Collet +* +* BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php) +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are +* met: +* +* * Redistributions of source code must retain the above copyright +* notice, this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above +* copyright notice, this list of conditions and the following disclaimer +* in the documentation and/or other materials provided with the +* distribution. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +* You can contact the author at : +* - xxHash homepage: http://www.xxhash.com +* - xxHash source repository : https://github.com/Cyan4973/xxHash +*/ + + +=========================================================================================================================================== + +# Copyright (C) 2018 Google, Inc. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# +# Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. +# +# Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +========================================================================================================================================== + +/* A Bison parser, made by GNU Bison 3.0.4. */ + +/* Bison implementation for Yacc-like parsers in C +Copyright (C) 1984, 1989-1990, 2000-2015 Free Software Foundation, Inc. +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. +You should have received a copy of the GNU General Public License +along with this program. If not, see . */ + +/* As a special exception, you may create a larger work that contains +part or all of the Bison parser skeleton and distribute that work +under terms of your choice, so long as that work isn't itself a +parser generator using the skeleton or a modified version thereof +as a parser skeleton. Alternatively, if you modify or redistribute +the parser skeleton itself, you may (at your option) remove this +special exception, which will cause the skeleton and the resulting +Bison output files to be licensed under the GNU General Public +License without this special exception. +This special exception was added by the Free Software Foundation in +version 2.2 of Bison. */ + +/* C LALR(1) parser skeleton written by Richard Stallman, by +simplifying the original so-called "semantic" parser. */ + +/* All symbols defined below should begin with yy or YY, to avoid +infringing on user name space. This should be done even for local +variables, as they might otherwise be expanded by user macros. +There are some unavoidable exceptions within include files to +define necessary library symbols; they are noted "INFRINGES ON +USER NAME SPACE" below. 
*/ + +============================================================================================================================================== + +copyright : [ +Copyright (c) 2017 The Khronos Group Inc., +, +Permission is hereby granted, free of charge, to any person obtaining a copy, +of this software and/or associated documentation files (the \Materials\"),", +to deal in the Materials without restriction, including without limitation, +the rights to use, copy, modify, merge, publish, distribute, sublicense,, +and/or sell copies of the Materials, and to permit persons to whom the, +Materials are furnished to do so, subject to the following conditions:, +, +The above copyright notice and this permission notice shall be included in, +all copies or substantial portions of the Materials., +, +MODIFICATIONS TO THIS FILE MAY MEAN IT NO LONGER ACCURATELY REFLECTS KHRONOS, +STANDARDS. THE UNMODIFIED, NORMATIVE VERSIONS OF KHRONOS SPECIFICATIONS AND, +HEADER INFORMATION ARE LOCATED AT https://www.khronos.org/registry/ , +, +THE MATERIALS ARE PROVIDED \AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS", +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL, +THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER, +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING, +FROM,OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE USE OR OTHER DEALINGS, +IN THE MATERIALS. + +============================================================================================================================================= + +CMake - Cross Platform Makefile Generator +Copyright 2000-2009 Kitware, Inc., Insight Software Consortium +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +* Neither the names of Kitware, Inc., the Insight Software Consortium, +nor the names of their contributors may be used to endorse or promote +products derived from this software without specific prior written +permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +AS IS AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +------------------------------------------------------------------------------ + +The above copyright and license notice applies to distributions of +CMake in source and binary form. 
Some source files contain additional +notices of original copyright by their contributors; see each source +for details. Third-party software packages supplied with CMake under +compatible licenses provide their own copyright notices documented in +corresponding subdirectories. + +------------------------------------------------------------------------------ + +CMake was initially developed by Kitware with the following sponsorship: + +* National Library of Medicine at the National Institutes of Health +as part of the Insight Segmentation and Registration Toolkit (ITK). + +* US National Labs (Los Alamos, Livermore, Sandia) ASC Parallel +Visualization Initiative. + +* National Alliance for Medical Image Computing (NAMIC) is funded by the +National Institutes of Health through the NIH Roadmap for Medical Research, +Grant U54 EB005149. + +* Kitware, Inc. + +======================================================================================================================================== + +The authors of this software are Rob Pike and Ken Thompson. +* Copyright (c) 2002 by Lucent Technologies. +* Permission to use, copy, modify, and distribute this software for any +* purpose without fee is hereby granted, provided that this entire notice +* is included in all copies of any software which is or includes a copy +* or modification of this software and in all copies of the supporting +* documentation for such software. +* THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED +* WARRANTY. IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY +* REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY +* OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE. + + +======================================================================================================================================== + +Copyright (c) 2015-2018 Baldur Karlsson + +Copyright (c) 2014 Crytek + +Copyright (c) 1998-2018 Third party code and tools + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +========================================================================================================================================= + +/* +Copyright (c) 2009 Dave Gamble +Copyright (c) 2015-2016 The Khronos Group Inc. +Copyright (c) 2015-2016 Valve Corporation +Copyright (c) 2015-2016 LunarG, Inc. 
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+*/
+
+===========================================================================================================================================
+
+Copyright (c) 2005 - 2017 G-Truc Creation
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+==========================================================================================================================================
+
+/*
+The JsonCpp library's source code, including accompanying documentation,
+tests and demonstration applications, are licensed under the following
+conditions...
+The author (Baptiste Lepilleur) explicitly disclaims copyright in all
+jurisdictions which recognize such a disclaimer. In such jurisdictions,
+this software is released into the Public Domain.
+In jurisdictions which do not recognize Public Domain property (e.g. Germany as of
+2010), this software is Copyright (c) 2007-2010 by Baptiste Lepilleur, and is
+released under the terms of the MIT License (see below).
+In jurisdictions which recognize Public Domain property, the user of this
+software may choose to accept it either as 1) Public Domain, 2) under the
+conditions of the MIT License (see below), or 3) under the terms of dual
+Public Domain/MIT License conditions described here, as they choose.
+The MIT License is about as close to Public Domain as a license can get, and is +described in clear, concise terms at: +http://en.wikipedia.org/wiki/MIT_License + +The full text of the MIT License follows: + +Copyright (c) 2007-2010 Baptiste Lepilleur +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, copy, +modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +========================================================================================================================================== + +/** +* `murmurhash.h' - murmurhash +* +* copyright (c) 2014 joseph werle +* Copyright (c) 2015-2016 The Khronos Group Inc. +* Copyright (c) 2015-2016 Valve Corporation +* Copyright (c) 2015-2016 LunarG, Inc. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and/or associated documentation files (the "Materials"), to +* deal in the Materials without restriction, including without limitation the +* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +* sell copies of the Materials, and to permit persons to whom the Materials are +* furnished to do so, subject to the following conditions: +* +* The above copyright notice(s) and this permission notice shall be included in +* all copies or substantial portions of the Materials. +* +* THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE MATERIALS OR THE +* USE OR OTHER DEALINGS IN THE MATERIALS. +*/ + +========================================================================================================================================= + +Licenced as X11: http://www.kryogenix.org/code/browser/licence.html +This basically means: do what you want with it. 
+ +========================================================================================================================================= + +/////////////////////////////////////////////////////////////////////////////////// +/// OpenGL Mathematics (glm.g-truc.net) +/// +/// Copyright (c) 2005 - 2014 G-Truc Creation (www.g-truc.net) +/// Permission is hereby granted, free of charge, to any person obtaining a copy +/// of this software and associated documentation files (the "Software"), to deal +/// in the Software without restriction, including without limitation the rights +/// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +/// copies of the Software, and to permit persons to whom the Software is +/// furnished to do so, subject to the following conditions: +/// +/// The above copyright notice and this permission notice shall be included in +/// all copies or substantial portions of the Software. +/// +/// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +/// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +/// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +/// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +/// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +/// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +/// THE SOFTWARE. +/// +/// @ref core +/// @file glm/common.hpp +/// @date 2013-12-24 / 2013-12-24 +/// @author Christophe Riccio +/////////////////////////////////////////////////////////////////////////////////// + + +========================================================================================================================================== + +// LICENSE +// +// This software is in the public domain. Where that dedication is not +// recognized, you are granted a perpetual, irrevocable license to copy, +// distribute, and modify this file as you see fit. +// + +========================================================================================================================================== + +Simple DirectMedia Layer +Copyright (C) 1997-2018 Sam Lantinga + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not +claim that you wrote the original software. If you use this software +in a product, an acknowledgment in the product documentation would be +appreciated but is not required. +2. Altered source versions must be plainly marked as such, and must not be +misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. + +========================================================================================================================================= + +/****************************************************************************\ +Copyright (c) 2002, NVIDIA Corporation. + +NVIDIA Corporation("NVIDIA") supplies this software to you in +consideration of your agreement to the following terms, and your use, +installation, modification or redistribution of this NVIDIA software +constitutes acceptance of these terms. 
If you do not agree with these +terms, please do not use, install, modify or redistribute this NVIDIA +software. + +In consideration of your agreement to abide by the following terms, and +subject to these terms, NVIDIA grants you a personal, non-exclusive +license, under NVIDIA's copyrights in this original NVIDIA software (the +NVIDIA Software), to use, reproduce, modify and redistribute the +NVIDIA Software, with or without modifications, in source and/or binary +forms; provided that if you redistribute the NVIDIA Software, you must +retain the copyright notice of NVIDIA, this notice and the following +text and disclaimers in all such redistributions of the NVIDIA Software. +Neither the name, trademarks, service marks nor logos of NVIDIA +Corporation may be used to endorse or promote products derived from the +NVIDIA Software without specific prior written permission from NVIDIA. +Except as expressly stated in this notice, no other rights or licenses +express or implied, are granted by NVIDIA herein, including but not +limited to any patent rights that may be infringed by your derivative +works or by other works in which the NVIDIA Software may be +incorporated. No hardware is licensed hereunder. + +THE NVIDIA SOFTWARE IS BEING PROVIDED ON AN "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING WITHOUT LIMITATION, WARRANTIES OR CONDITIONS OF TITLE, +NON-INFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, OR +ITS USE AND OPERATION EITHER ALONE OR IN COMBINATION WITH OTHER +PRODUCTS. + +IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY SPECIAL, INDIRECT, +INCIDENTAL, EXEMPLARY, CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, LOST PROFITS; PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF +USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) OR ARISING IN ANY WAY +OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE +NVIDIA SOFTWARE, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, +TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF +NVIDIA HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +\****************************************************************************/ + +================================================================================================================================================== + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + +1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. +2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. +3. This notice may not be removed or altered from any source distribution. + + +================================================================================================================================================== + +GNU LESSER GENERAL PUBLIC LICENSE +Version 3, 29 June 2007 + +Copyright (C) 2007 Free Software Foundation, Inc. 
+ +Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. + +This version of the GNU Lesser General Public License incorporates the terms and conditions of version 3 of the GNU General Public License, supplemented by the additional permissions listed below. + +0. Additional Definitions. + +As used herein, "this License" refers to version 3 of the GNU Lesser General Public License, and the "GNU GPL" refers to version 3 of the GNU General Public License. + +"The Library" refers to a covered work governed by this License, other than an Application or a Combined Work as defined below. + +An "Application" is any work that makes use of an interface provided by the Library, but which is not otherwise based on the Library. Defining a subclass of a class defined by the Library is deemed a mode of using an interface provided by the Library. + +A "Combined Work" is a work produced by combining or linking an Application with the Library. The particular version of the Library with which the Combined Work was made is also called the "Linked Version". + +The "Minimal Corresponding Source" for a Combined Work means the Corresponding Source for the Combined Work, excluding any source code for portions of the Combined Work that, considered in isolation, are based on the Application, and not on the Linked Version. + +The "Corresponding Application Code" for a Combined Work means the object code and/or source code for the Application, including any data and utility programs needed for reproducing the Combined Work from the Application, but excluding the System Libraries of the Combined Work. + +1. Exception to Section 3 of the GNU GPL. + +You may convey a covered work under sections 3 and 4 of this License without being bound by section 3 of the GNU GPL. + +2. Conveying Modified Versions. + +If you modify a copy of the Library, and, in your modifications, a facility refers to a function or data to be supplied by an Application that uses the facility (other than as an argument passed when the facility is invoked), then you may convey a copy of the modified version: + +a) under this License, provided that you make a good faith effort to ensure that, in the event an Application does not supply the function or data, the facility still operates, and performs whatever part of its purpose remains meaningful, or +b) under the GNU GPL, with none of the additional permissions of this License applicable to that copy. +3. Object Code Incorporating Material from Library Header Files. + +The object code form of an Application may incorporate material from a header file that is part of the Library. You may convey such object code under terms of your choice, provided that, if the incorporated material is not limited to numerical parameters, data structure layouts and accessors, or small macros, inline functions and templates (ten or fewer lines in length), you do both of the following: + +a) Give prominent notice with each copy of the object code that the Library is used in it and that the Library and its use are covered by this License. +b) Accompany the object code with a copy of the GNU GPL and this license document. +4. Combined Works. 
+ +You may convey a Combined Work under terms of your choice that, taken together, effectively do not restrict modification of the portions of the Library contained in the Combined Work and reverse engineering for debugging such modifications, if you also do each of the following: + +a) Give prominent notice with each copy of the Combined Work that the Library is used in it and that the Library and its use are covered by this License. +b) Accompany the Combined Work with a copy of the GNU GPL and this license document. +c) For a Combined Work that displays copyright notices during execution, include the copyright notice for the Library among these notices, as well as a reference directing the user to the copies of the GNU GPL and this license document. +d) Do one of the following: +0) Convey the Minimal Corresponding Source under the terms of this License, and the Corresponding Application Code in a form suitable for, and under terms that permit, the user to recombine or relink the Application with a modified version of the Linked Version to produce a modified Combined Work, in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source. +1) Use a suitable shared library mechanism for linking with the Library. A suitable mechanism is one that (a) uses at run time a copy of the Library already present on the user's computer system, and (b) will operate properly with a modified version of the Library that is interface-compatible with the Linked Version. +e) Provide Installation Information, but only if you would otherwise be required to provide such information under section 6 of the GNU GPL, and only to the extent that such information is necessary to install and execute a modified version of the Combined Work produced by recombining or relinking the Application with a modified version of the Linked Version. (If you use option 4d0, the Installation Information must accompany the Minimal Corresponding Source and Corresponding Application Code. If you use option 4d1, you must provide the Installation Information in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source.) +5. Combined Libraries. + +You may place library facilities that are a work based on the Library side by side in a single library together with other library facilities that are not Applications and are not covered by this License, and convey such a combined library under terms of your choice, if you do both of the following: + +a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities, conveyed under the terms of this License. +b) Give prominent notice with the combined library that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. +6. Revised Versions of the GNU Lesser General Public License. + +The Free Software Foundation may publish revised and/or new versions of the GNU Lesser General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Library as you received it specifies that a certain numbered version of the GNU Lesser General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that published version or of any later version published by the Free Software Foundation. 
If the Library as you received it does not specify a version number of the GNU Lesser General Public License, you may choose any version of the GNU Lesser General Public License ever published by the Free Software Foundation. + +If the Library as you received it specifies that a proxy can decide whether future versions of the GNU Lesser General Public License shall apply, that proxy's public statement of acceptance of any version is permanent authorization for you to choose that version for the Library. + + +third_party\protobuf\LICENSE +---------------------------- +Copyright 2008 Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. + + +third_party\protobuf\third_party\benchmark\LICENSE +-------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\protobuf\third_party\googletest\LICENSE +--------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\protobuf\third_party\googletest\googlemock\LICENSE +-------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\protobuf\third_party\googletest\googlemock\scripts\generator\LICENSE +-------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2007] Neal Norwitz + Portions Copyright [2007] Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ + +third_party\protobuf\third_party\googletest\googletest\LICENSE +-------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\psimd\LICENSE +------------------------- +The MIT License (MIT) + +Copyright (c) 2017 Facebook Inc. +Copyright (c) 2014-2017 Georgia Institute of Technology +Copyright 2019 Google LLC + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +third_party\pthreadpool\LICENSE +------------------------------- +Copyright 2019 Google LLC +Copyright (c) 2017 Facebook Inc. +Copyright (c) 2015-2017 Georgia Institute of Technology +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +third_party\pybind11\LICENSE +---------------------------- +Copyright (c) 2016 Wenzel Jakob , All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Please also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of +external contributions to this project including patches, pull requests, etc. + + +third_party\python-peachpy\LICENSE.rst +-------------------------------------- +============================== +PeachPy license (2-clause BSD) +============================== + +Copyright (c) 2017, Facebook Inc. +Copyright (c) 2013-2017, Georgia Institute of Technology +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\sleef\LICENSE.txt +----------------------------- +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + + +third_party\tensorpipe\LICENSE.txt +---------------------------------- +BSD License + +For TensorPipe software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Meta nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\tensorpipe\third_party\googletest\LICENSE +----------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\tensorpipe\third_party\googletest\googlemock\LICENSE +---------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\tensorpipe\third_party\googletest\googlemock\scripts\generator\LICENSE +---------------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [2007] Neal Norwitz + Portions Copyright [2007] Google Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + +third_party\tensorpipe\third_party\googletest\googletest\LICENSE +---------------------------------------------------------------- +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +third_party\tensorpipe\third_party\libnop\LICENSE +------------------------------------------------- +Copyright 2017 The Native Object Protocols Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ + +third_party\tensorpipe\third_party\libuv\LICENSE +------------------------------------------------ +libuv is licensed for use as follows: + +==== +Copyright (c) 2015-present libuv project contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. +==== + +This license applies to parts of libuv originating from the +https://github.com/joyent/libuv repository: + +==== + +Copyright Joyent, Inc. and other Node contributors. All rights reserved. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. + +==== + +This license applies to all parts of libuv that are not externally +maintained libraries. + +The externally maintained libraries used by libuv are: + + - tree.h (from FreeBSD), copyright Niels Provos. Two clause BSD license. + + - inet_pton and inet_ntop implementations, contained in src/inet.c, are + copyright the Internet Systems Consortium, Inc., and licensed under the ISC + license. + + - stdint-msvc2008.h (from msinttypes), copyright Alexander Chemeris. Three + clause BSD license. + + - pthread-fixes.c, copyright Google Inc. and Sony Mobile Communications AB. + Three clause BSD license. + + - android-ifaddrs.h, android-ifaddrs.c, copyright Berkeley Software Design + Inc, Kenneth MacKay and Emergya (Cloud4all, FP7/2007-2013, grant agreement + n° 289016). Three clause BSD license. + + +third_party\tensorpipe\third_party\pybind11\LICENSE +--------------------------------------------------- +Copyright (c) 2016 Wenzel Jakob , All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Please also refer to the file CONTRIBUTING.md, which clarifies licensing of +external contributions to this project including patches, pull requests, etc. diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/METADATA b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..f96b6c3a9d2f9f1b168777b52746505938dd62dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/METADATA @@ -0,0 +1,631 @@ +Metadata-Version: 2.1 +Name: torch +Version: 2.8.0 +Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration +Home-page: https://pytorch.org/ +Download-URL: https://github.com/pytorch/pytorch/tags +Author: PyTorch Team +Author-email: packages@pytorch.org +License: BSD-3-Clause +Project-URL: Homepage, https://pytorch.org/ +Project-URL: Documentation, https://pytorch.org/docs/ +Project-URL: Source, https://github.com/pytorch/pytorch +Project-URL: Forum, https://discuss.pytorch.org/ +Keywords: pytorch,machine learning +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: BSD License +Classifier: Topic :: Scientific/Engineering +Classifier: Topic :: Scientific/Engineering :: Mathematics +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Programming Language :: C++ +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Requires-Python: >=3.9.0 +Description-Content-Type: text/markdown +License-File: LICENSE +License-File: 
NOTICE +Requires-Dist: filelock +Requires-Dist: typing-extensions >=4.10.0 +Requires-Dist: sympy >=1.13.3 +Requires-Dist: networkx +Requires-Dist: jinja2 +Requires-Dist: fsspec +Requires-Dist: setuptools ; python_version >= "3.12" +Requires-Dist: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cuda-runtime-cu12==12.8.90; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cuda-cupti-cu12==12.8.90; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cudnn-cu12==9.10.2.21; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cublas-cu12==12.8.4.1; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cufft-cu12==11.3.3.83; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-curand-cu12==10.3.9.90; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cusolver-cu12==11.7.3.90; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cusparse-cu12==12.5.8.93; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cusparselt-cu12==0.7.1; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-nccl-cu12==2.27.3; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-nvtx-cu12==12.8.90; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-nvjitlink-cu12==12.8.93; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: nvidia-cufile-cu12==1.13.1.3; platform_system == "Linux" and platform_machine == "x86_64" +Requires-Dist: triton==3.4.0; platform_system == "Linux" and platform_machine == "x86_64" +Provides-Extra: opt-einsum +Requires-Dist: opt-einsum >=3.3 ; extra == 'opt-einsum' +Provides-Extra: optree +Requires-Dist: optree >=0.13.0 ; extra == 'optree' +Provides-Extra: pyyaml +Requires-Dist: pyyaml ; extra == 'pyyaml' + +![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png) + +-------------------------------------------------------------------------------- + +PyTorch is a Python package that provides two high-level features: +- Tensor computation (like NumPy) with strong GPU acceleration +- Deep neural networks built on a tape-based autograd system + +You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed. + +Our trunk health (Continuous Integration signals) can be found at [hud.pytorch.org](https://hud.pytorch.org/ci/pytorch/pytorch/main). 
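+
+As a quick, hedged illustration of those two features (a minimal sketch, not taken from the official docs), the snippet below creates a tensor, moves it to the GPU when one is available, and runs NumPy-style operations on it:
+
+```python
+import torch
+
+# A random matrix, created on the CPU much like a NumPy ndarray
+x = torch.rand(3, 3)
+
+# Move it to the GPU if CUDA is available; otherwise keep it on the CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+x = x.to(device)
+
+# The same NumPy-style operations run on whichever device the tensor lives on
+y = x @ x.t() + 1.0
+print(y.device, y.shape)
+```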
+ + + +- [More About PyTorch](#more-about-pytorch) + - [A GPU-Ready Tensor Library](#a-gpu-ready-tensor-library) + - [Dynamic Neural Networks: Tape-Based Autograd](#dynamic-neural-networks-tape-based-autograd) + - [Python First](#python-first) + - [Imperative Experiences](#imperative-experiences) + - [Fast and Lean](#fast-and-lean) + - [Extensions Without Pain](#extensions-without-pain) +- [Installation](#installation) + - [Binaries](#binaries) + - [NVIDIA Jetson Platforms](#nvidia-jetson-platforms) + - [From Source](#from-source) + - [Prerequisites](#prerequisites) + - [NVIDIA CUDA Support](#nvidia-cuda-support) + - [AMD ROCm Support](#amd-rocm-support) + - [Intel GPU Support](#intel-gpu-support) + - [Get the PyTorch Source](#get-the-pytorch-source) + - [Install Dependencies](#install-dependencies) + - [Install PyTorch](#install-pytorch) + - [Adjust Build Options (Optional)](#adjust-build-options-optional) + - [Docker Image](#docker-image) + - [Using pre-built images](#using-pre-built-images) + - [Building the image yourself](#building-the-image-yourself) + - [Building the Documentation](#building-the-documentation) + - [Building a PDF](#building-a-pdf) + - [Previous Versions](#previous-versions) +- [Getting Started](#getting-started) +- [Resources](#resources) +- [Communication](#communication) +- [Releases and Contributing](#releases-and-contributing) +- [The Team](#the-team) +- [License](#license) + + + +## More About PyTorch + +[Learn the basics of PyTorch](https://pytorch.org/tutorials/beginner/basics/intro.html) + +At a granular level, PyTorch is a library that consists of the following components: + +| Component | Description | +| ---- | --- | +| [**torch**](https://pytorch.org/docs/stable/torch.html) | A Tensor library like NumPy, with strong GPU support | +| [**torch.autograd**](https://pytorch.org/docs/stable/autograd.html) | A tape-based automatic differentiation library that supports all differentiable Tensor operations in torch | +| [**torch.jit**](https://pytorch.org/docs/stable/jit.html) | A compilation stack (TorchScript) to create serializable and optimizable models from PyTorch code | +| [**torch.nn**](https://pytorch.org/docs/stable/nn.html) | A neural networks library deeply integrated with autograd designed for maximum flexibility | +| [**torch.multiprocessing**](https://pytorch.org/docs/stable/multiprocessing.html) | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training | +| [**torch.utils**](https://pytorch.org/docs/stable/data.html) | DataLoader and other utility functions for convenience | + +Usually, PyTorch is used either as: + +- A replacement for NumPy to use the power of GPUs. +- A deep learning research platform that provides maximum flexibility and speed. + +Elaborating Further: + +### A GPU-Ready Tensor Library + +If you use NumPy, then you have used Tensors (a.k.a. ndarray). + +![Tensor illustration](./docs/source/_static/img/tensor_illustration.png) + +PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the +computation by a huge amount. + +We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs +such as slicing, indexing, mathematical operations, linear algebra, reductions. +And they are fast! + +### Dynamic Neural Networks: Tape-Based Autograd + +PyTorch has a unique way of building neural networks: using and replaying a tape recorder. 
+ +Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world. +One has to build a neural network and reuse the same structure again and again. +Changing the way the network behaves means that one has to start from scratch. + +With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to +change the way your network behaves arbitrarily with zero lag or overhead. Our inspiration comes +from several research papers on this topic, as well as current and past work such as +[torch-autograd](https://github.com/twitter/torch-autograd), +[autograd](https://github.com/HIPS/autograd), +[Chainer](https://chainer.org), etc. + +While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date. +You get the best of speed and flexibility for your crazy research. + +![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif) + +### Python First + +PyTorch is not a Python binding into a monolithic C++ framework. +It is built to be deeply integrated into Python. +You can use it naturally like you would use [NumPy](https://www.numpy.org/) / [SciPy](https://www.scipy.org/) / [scikit-learn](https://scikit-learn.org) etc. +You can write your new neural network layers in Python itself, using your favorite libraries +and use packages such as [Cython](https://cython.org/) and [Numba](http://numba.pydata.org/). +Our goal is to not reinvent the wheel where appropriate. + +### Imperative Experiences + +PyTorch is designed to be intuitive, linear in thought, and easy to use. +When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. +When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward. +The stack trace points to exactly where your code was defined. +We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines. + +### Fast and Lean + +PyTorch has minimal framework overhead. We integrate acceleration libraries +such as [Intel MKL](https://software.intel.com/mkl) and NVIDIA ([cuDNN](https://developer.nvidia.com/cudnn), [NCCL](https://developer.nvidia.com/nccl)) to maximize speed. +At the core, its CPU and GPU Tensor and neural network backends +are mature and have been tested for years. + +Hence, PyTorch is quite fast — whether you run small or large neural networks. + +The memory usage in PyTorch is extremely efficient compared to Torch or some of the alternatives. +We've written custom memory allocators for the GPU to make sure that +your deep learning models are maximally memory efficient. +This enables you to train bigger deep learning models than before. + +### Extensions Without Pain + +Writing new neural network modules, or interfacing with PyTorch's Tensor API was designed to be straightforward +and with minimal abstractions. + +You can write new neural network layers in Python using the torch API +[or your favorite NumPy-based libraries such as SciPy](https://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html). + +If you want to write your layers in C/C++, we provide a convenient extension API that is efficient and with minimal boilerplate. +No wrapper code needs to be written. You can see [a tutorial here](https://pytorch.org/tutorials/advanced/cpp_extension.html) and [an example here](https://github.com/pytorch/extension-cpp). 
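+
+As a small, hedged sketch of writing a layer in plain Python and of the tape-based autograd described above (the module name `TinyNet` is made up for illustration; this is not an official example), the snippet below defines a module whose forward pass branches on runtime values, and autograd still computes gradients through it:
+
+```python
+import torch
+import torch.nn as nn
+
+class TinyNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = nn.Linear(4, 4)
+
+    def forward(self, x):
+        # Data-dependent control flow: the graph is recorded anew on every
+        # call, so ordinary Python branching works.
+        h = self.fc(x)
+        if h.sum() > 0:
+            h = torch.relu(h)
+        return h.mean()
+
+net = TinyNet()
+loss = net(torch.randn(2, 4))
+loss.backward()                    # gradients recorded by the autograd "tape"
+print(net.fc.weight.grad.shape)    # torch.Size([4, 4])
+```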
+
+
+## Installation
+
+### Binaries
+Commands to install binaries via Conda or pip wheels are on our website: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
+
+
+#### NVIDIA Jetson Platforms
+
+Python wheels for NVIDIA's Jetson Nano, Jetson TX1/TX2, Jetson Xavier NX/AGX, and Jetson AGX Orin are provided [here](https://forums.developer.nvidia.com/t/pytorch-for-jetson-version-1-10-now-available/72048) and the L4T container is published [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/l4t-pytorch)
+
+They require JetPack 4.2 and above, and [@dusty-nv](https://github.com/dusty-nv) and [@ptrblck](https://github.com/ptrblck) are maintaining them.
+
+
+### From Source
+
+#### Prerequisites
+If you are installing from source, you will need:
+- Python 3.9 or later
+- A compiler that fully supports C++17, such as clang or gcc (gcc 9.4.0 or newer is required, on Linux)
+- Visual Studio or Visual Studio Build Tool (Windows only)
+
+\* PyTorch CI uses Visual C++ BuildTools, which come with Visual Studio Enterprise,
+Professional, or Community Editions. You can also install the build tools from
+https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not*
+come with Visual Studio Code by default.
+
+An example of environment setup is shown below (replace the angle-bracket placeholders with your own values):
+
+* Linux:
+
+```bash
+$ source <CONDA_INSTALL_DIR>/bin/activate
+$ conda create -y -n <CONDA_NAME>
+$ conda activate <CONDA_NAME>
+```
+
+* Windows:
+
+```bash
+$ source <CONDA_INSTALL_DIR>\Scripts\activate.bat
+$ conda create -y -n <CONDA_NAME>
+$ conda activate <CONDA_NAME>
+$ call "C:\Program Files\Microsoft Visual Studio\<VERSION>\Community\VC\Auxiliary\Build\vcvarsall.bat" x64
+```
+
+A conda environment is not required. You can also do a PyTorch build in a
+standard virtual environment, e.g., created with tools like `uv`, provided
+your system has installed all the necessary dependencies unavailable as pip
+packages (e.g., CUDA, MKL).
+
+##### NVIDIA CUDA Support
+If you want to compile with CUDA support, [select a supported version of CUDA from our support matrix](https://pytorch.org/get-started/locally/), then install the following:
+- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
+- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v8.5 or above
+- [Compiler](https://gist.github.com/ax3l/9489132) compatible with CUDA
+
+Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/backend/latest/reference/support-matrix.html) for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardware.
+
+If you want to disable CUDA support, export the environment variable `USE_CUDA=0`.
+Other potentially useful environment variables may be found in `setup.py`. If
+CUDA is installed in a non-standard location, set PATH so that the nvcc you
+want to use can be found (e.g., `export PATH=/usr/local/cuda-12.8/bin:$PATH`).
+
+If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xavier), instructions to install PyTorch for Jetson Nano are [available here](https://devtalk.nvidia.com/default/topic/1049071/jetson-nano/pytorch-for-jetson-nano/).
+
+##### AMD ROCm Support
+If you want to compile with ROCm support, install
+- [AMD ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) 4.0 or above
+- ROCm is currently supported only for Linux systems.
+
+By default the build system expects ROCm to be installed in `/opt/rocm`. If ROCm is installed in a different directory, the `ROCM_PATH` environment variable must be set to the ROCm installation directory.
The build system automatically detects the AMD GPU architecture. Optionally, the AMD GPU architecture can be explicitly set with the `PYTORCH_ROCM_ARCH` environment variable [AMD GPU architecture](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html#supported-gpus) + +If you want to disable ROCm support, export the environment variable `USE_ROCM=0`. +Other potentially useful environment variables may be found in `setup.py`. + +##### Intel GPU Support +If you want to compile with Intel GPU support, follow these +- [PyTorch Prerequisites for Intel GPUs](https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpus.html) instructions. +- Intel GPU is supported for Linux and Windows. + +If you want to disable Intel GPU support, export the environment variable `USE_XPU=0`. +Other potentially useful environment variables may be found in `setup.py`. + +#### Get the PyTorch Source +```bash +git clone https://github.com/pytorch/pytorch +cd pytorch +# if you are updating an existing checkout +git submodule sync +git submodule update --init --recursive +``` + +#### Install Dependencies + +**Common** + +```bash +conda install cmake ninja +# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below +pip install -r requirements.txt +``` + +**On Linux** + +```bash +pip install mkl-static mkl-include +# CUDA only: Add LAPACK support for the GPU if needed +# magma installation: run with active conda environment. specify CUDA version to install +.ci/docker/common/install_magma_conda.sh 12.4 + +# (optional) If using torch.compile with inductor/triton, install the matching version of triton +# Run from the pytorch directory after cloning +# For Intel GPU support, please explicitly `export USE_XPU=1` before running command. +make triton +``` + +**On MacOS** + +```bash +# Add this package on intel x86 processor machines only +pip install mkl-static mkl-include +# Add these packages if torch.distributed is needed +conda install pkg-config libuv +``` + +**On Windows** + +```bash +pip install mkl-static mkl-include +# Add these packages if torch.distributed is needed. +# Distributed package support on Windows is a prototype feature and is subject to changes. +conda install -c conda-forge libuv=1.39 +``` + +#### Install PyTorch +**On Linux** + +If you're compiling for AMD ROCm then first run this command: +```bash +# Only run this if you're compiling for ROCm +python tools/amd_build/build_amd.py +``` + +Install PyTorch +```bash +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" +python setup.py develop +``` + +**On macOS** + +```bash +python3 setup.py develop +``` + +**On Windows** + +If you want to build legacy python code, please refer to [Building on legacy code and CUDA](https://github.com/pytorch/pytorch/blob/main/CONTRIBUTING.md#building-on-legacy-code-and-cuda) + +**CPU-only builds** + +In this mode PyTorch computations will run on your CPU, not your GPU. + +```cmd +python setup.py develop +``` + +Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/main/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. 
Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used.
+
+**CUDA based build**
+
+In this mode PyTorch computations will leverage your GPU via CUDA for faster number crunching.
+
+[NVTX](https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) is needed to build PyTorch with CUDA.
+NVTX is part of the CUDA distribution, where it is called "Nsight Compute". To install it onto an already installed CUDA toolkit, run the CUDA installer again and check the corresponding checkbox.
+Make sure that CUDA with Nsight Compute is installed after Visual Studio.
+
+Currently, VS 2017 / 2019 and Ninja are supported as CMake generators. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator; otherwise, it will use VS 2017 / 2019.
+
If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain. + +Additional libraries such as +[Magma](https://developer.nvidia.com/magma), [oneDNN, a.k.a. MKLDNN or DNNL](https://github.com/oneapi-src/oneDNN), and [Sccache](https://github.com/mozilla/sccache) are often needed. Please refer to the [installation-helper](https://github.com/pytorch/pytorch/tree/main/.ci/pytorch/win-test-helpers/installation-helpers) to install them. + +You can refer to the [build_pytorch.bat](https://github.com/pytorch/pytorch/blob/main/.ci/pytorch/win-test-helpers/build_pytorch.bat) script for some other environment variables configurations + + +```cmd +cmd + +:: Set the environment variables after you have downloaded and unzipped the mkl package, +:: else CMake would throw an error as `Could NOT find OpenMP`. +set CMAKE_INCLUDE_PATH={Your directory}\mkl\include +set LIB={Your directory}\mkl\lib;%LIB% + +:: Read the content in the previous section carefully before you proceed. +:: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block. +:: "Visual Studio 2019 Developer Command Prompt" will be run automatically. +:: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator. +set CMAKE_GENERATOR_TOOLSET_VERSION=14.27 +set DISTUTILS_USE_SDK=1 +for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,17^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION% + +:: [Optional] If you want to override the CUDA host compiler +set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe + +python setup.py develop + +``` + +**Intel GPU builds** + +In this mode PyTorch with Intel GPU support will be built. + +Please make sure [the common prerequisites](#prerequisites) as well as [the prerequisites for Intel GPU](#intel-gpu-support) are properly installed and the environment variables are configured prior to starting the build. For build tool support, `Visual Studio 2022` is required. + +Then PyTorch can be built with the command: + +```cmd +:: CMD Commands: +:: Set the CMAKE_PREFIX_PATH to help find corresponding packages +:: %CONDA_PREFIX% only works after `conda activate custom_env` + +if defined CMAKE_PREFIX_PATH ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library;%CMAKE_PREFIX_PATH%" +) else ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library" +) + +python setup.py develop +``` + +##### Adjust Build Options (Optional) + +You can adjust the configuration of cmake variables optionally (without building first), by doing +the following. For example, adjusting the pre-detected directories for CuDNN or BLAS can be done +with such a step. 
+ +On Linux +```bash +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" +CMAKE_ONLY=1 python setup.py build +ccmake build # or cmake-gui build +``` + +On macOS +```bash +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ CMAKE_ONLY=1 python setup.py build +ccmake build # or cmake-gui build +``` + +### Docker Image + +#### Using pre-built images + +You can also pull a pre-built docker image from Docker Hub and run with docker v19.03+ + +```bash +docker run --gpus all --rm -ti --ipc=host pytorch/pytorch:latest +``` + +Please note that PyTorch uses shared memory to share data between processes, so if torch multiprocessing is used (e.g. +for multithreaded data loaders) the default shared memory segment size that container runs with is not enough, and you +should increase shared memory size either with `--ipc=host` or `--shm-size` command line options to `nvidia-docker run`. + +#### Building the image yourself + +**NOTE:** Must be built with a docker version > 18.06 + +The `Dockerfile` is supplied to build images with CUDA 11.1 support and cuDNN v8. +You can pass `PYTHON_VERSION=x.y` make variable to specify which Python version is to be used by Miniconda, or leave it +unset to use the default. + +```bash +make -f docker.Makefile +# images are tagged as docker.io/${your_docker_username}/pytorch +``` + +You can also pass the `CMAKE_VARS="..."` environment variable to specify additional CMake variables to be passed to CMake during the build. +See [setup.py](./setup.py) for the list of available variables. + +```bash +make -f docker.Makefile +``` + +### Building the Documentation + +To build documentation in various formats, you will need [Sphinx](http://www.sphinx-doc.org) +and the pytorch_sphinx_theme2. + +Before you build the documentation locally, ensure `torch` is +installed in your environment. For small fixes, you can install the +nightly version as described in [Getting Started](https://pytorch.org/get-started/locally/). + +For more complex fixes, such as adding a new module and docstrings for +the new module, you might need to install torch [from source](#from-source). +See [Docstring Guidelines](https://github.com/pytorch/pytorch/wiki/Docstring-Guidelines) +for docstring conventions. + +```bash +cd docs/ +pip install -r requirements.txt +make html +make serve +``` + +Run `make` to get a list of all available output formats. + +If you get a katex error run `npm install katex`. If it persists, try +`npm install -g katex` + +> [!NOTE] +> If you installed `nodejs` with a different package manager (e.g., +> `conda`) then `npm` will probably install a version of `katex` that is not +> compatible with your version of `nodejs` and doc builds will fail. +> A combination of versions that is known to work is `node@6.13.1` and +> `katex@0.13.18`. To install the latter with `npm` you can run +> ```npm install -g katex@0.13.18``` + +> [!NOTE] +> If you see a numpy incompatibility error, run: +> ``` +> pip install 'numpy<2' +> ``` + +When you make changes to the dependencies run by CI, edit the +`.ci/docker/requirements-docs.txt` file. + +#### Building a PDF + +To compile a PDF of all PyTorch documentation, ensure you have +`texlive` and LaTeX installed. On macOS, you can install them using: + +``` +brew install --cask mactex +``` + +To create the PDF: + +1. 
Run: + + ``` + make latexpdf + ``` + + This will generate the necessary files in the `build/latex` directory. + +2. Navigate to this directory and execute: + + ``` + make LATEXOPTS="-interaction=nonstopmode" + ``` + + This will produce a `pytorch.pdf` with the desired content. Run this + command one more time so that it generates the correct table + of contents and index. + +> [!NOTE] +> To view the Table of Contents, switch to the **Table of Contents** +> view in your PDF viewer. + + +### Previous Versions + +Installation instructions and binaries for previous PyTorch versions may be found +on [our website](https://pytorch.org/get-started/previous-versions). + + +## Getting Started + +Three-pointers to get you started: +- [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/) +- [Examples: easy to understand PyTorch code across all domains](https://github.com/pytorch/examples) +- [The API Reference](https://pytorch.org/docs/) +- [Glossary](https://github.com/pytorch/pytorch/blob/main/GLOSSARY.md) + +## Resources + +* [PyTorch.org](https://pytorch.org/) +* [PyTorch Tutorials](https://pytorch.org/tutorials/) +* [PyTorch Examples](https://github.com/pytorch/examples) +* [PyTorch Models](https://pytorch.org/hub/) +* [Intro to Deep Learning with PyTorch from Udacity](https://www.udacity.com/course/deep-learning-pytorch--ud188) +* [Intro to Machine Learning with PyTorch from Udacity](https://www.udacity.com/course/intro-to-machine-learning-nanodegree--nd229) +* [Deep Neural Networks with PyTorch from Coursera](https://www.coursera.org/learn/deep-neural-networks-with-pytorch) +* [PyTorch Twitter](https://twitter.com/PyTorch) +* [PyTorch Blog](https://pytorch.org/blog/) +* [PyTorch YouTube](https://www.youtube.com/channel/UCWXI5YeOsh03QvJ59PMaXFw) + +## Communication +* Forums: Discuss implementations, research, etc. https://discuss.pytorch.org +* GitHub Issues: Bug reports, feature requests, install issues, RFCs, thoughts, etc. +* Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration, etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1 +* Newsletter: No-noise, a one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv +* Facebook Page: Important announcements about PyTorch. https://www.facebook.com/pytorch +* For brand guidelines, please visit our website at [pytorch.org](https://pytorch.org/) + +## Releases and Contributing + +Typically, PyTorch has three minor releases a year. Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues). + +We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. + +If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. +Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of. + +To learn more about making a contribution to Pytorch, please see our [Contribution page](CONTRIBUTING.md). For more information about PyTorch releases, see [Release page](RELEASE.md). 
+ +## The Team + +PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. + +PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means. +A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). + +Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. + +## License + +PyTorch has a BSD-style license, as found in the [LICENSE](LICENSE) file. diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/NOTICE b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/NOTICE new file mode 100644 index 0000000000000000000000000000000000000000..c08c04389d281a4bf5ef77fc5426ffac8232dc3a --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/NOTICE @@ -0,0 +1,456 @@ +======================================================================= +Software under third_party +======================================================================= +Software libraries under third_party are provided as github submodule +links, and their content is not part of the Caffe2 codebase. Their +licences can be found under the respective software repositories. + +======================================================================= +Earlier BSD License +======================================================================= +Early development of Caffe2 in 2015 and early 2016 is licensed under the +BSD license. The license is attached below: + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +Caffe's BSD License +======================================================================= +Some parts of the caffe2 code is derived from the original Caffe code, which is +created by Yangqing Jia and is now a BSD-licensed open-source project. The Caffe +license is as follows: + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. 
+ +======================================================================= +Caffe2's Apache License +======================================================================= + +This repo contains Caffe2 code, which was previously licensed under +Apache License Version 2.0: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +======================================================================= +Cephes's 3-Clause BSD License +======================================================================= + +Code derived from implementations in the Cephes Math Library should mention +its derivation and reference the following license: + + 3-Clause BSD License for the Cephes Math Library + Copyright (c) 2018, Steven Moshier + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + + * Neither the name of the nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +======================================================================= +SciPy's 3-Clause BSD License +======================================================================= + +Code derived from implementations in SciPy should mention its derivation +and reference the following license: + + Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +======================================================================= +Boost's 1.0 Software License +======================================================================= + +Code derived from implementations in Boost 1.0 should mention its +derivation and reference the following license: + + Boost Software License - Version 1.0 - August 17th, 2003 + + Permission is hereby granted, free of charge, to any person or organization + obtaining a copy of the software and accompanying documentation covered by + this license (the "Software") to use, reproduce, display, distribute, + execute, and transmit the Software, and to prepare derivative works of the + Software, and to permit third-parties to whom the Software is furnished to + do so, all subject to the following: + + The copyright notices in the Software and this entire statement, including + the above license grant, this restriction and the following disclaimer, + must be included in all copies of the Software, in whole or in part, and + all derivative works of the Software, unless such copies or derivative + works are solely in the form of machine-executable object code generated by + a source language processor. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT + SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE + FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +======================================================================= +PILLOW-SIMD Software License +======================================================================= + +Code derived from implementations in PILLOW-SIMD should mention its derivation +and reference the following license: + + The Python Imaging Library (PIL) is + + Copyright © 1997-2011 by Secret Labs AB + Copyright © 1995-2011 by Fredrik Lundh + + Pillow is the friendly PIL fork. 
It is + + Copyright © 2010-2022 by Alex Clark and contributors + + Like PIL, Pillow is licensed under the open source HPND License: + + By obtaining, using, and/or copying this software and/or its associated + documentation, you agree that you have read, understood, and will comply + with the following terms and conditions: + + Permission to use, copy, modify, and distribute this software and its + associated documentation for any purpose and without fee is hereby granted, + provided that the above copyright notice appears in all copies, and that + both that copyright notice and this permission notice appear in supporting + documentation, and that the name of Secret Labs AB or the author not be + used in advertising or publicity pertaining to distribution of the software + without specific, written prior permission. + + SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS + SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. + IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR BE LIABLE FOR ANY SPECIAL, + INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE + OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + PERFORMANCE OF THIS SOFTWARE. diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/RECORD b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..3efd236841cb800c8011984eaf0957deba9a5346 --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/RECORD @@ -0,0 +1,13519 @@ +../../Scripts/torchfrtrace.exe,sha256=hSPQntEQf053ybBMC3uWgcfEdDiLRGzkbhEsneFNH_U,106360 +../../Scripts/torchrun.exe,sha256=6a6w0TkBYD4M5PXVlCFzWZzxihAfRoV8CrL00ZfO-Ag,106351 +functorch/_C.cp39-win_amd64.pyd,sha256=CY5CUYCtykJiWQu8GKtp09e7KCbpMeIhzl0uSWKo-EM,320000 +functorch/__init__.py,sha256=lkoKmqHGN8LKYTeSPXXXh28z4ehGyK9gTBtzohyNuu0,1076 +functorch/__pycache__/__init__.cpython-39.pyc,, +functorch/_src/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +functorch/_src/__pycache__/__init__.cpython-39.pyc,, +functorch/_src/aot_autograd/__init__.py,sha256=9blGZUidbjhxkEKHLqpZTmhU_o6twasOgrZ0li9RcLk,299 +functorch/_src/aot_autograd/__pycache__/__init__.cpython-39.pyc,, +functorch/_src/eager_transforms/__init__.py,sha256=CrfEZefUvrS8WmAmxFskj-Lbkvee502sdf1RpIaLl7U,298 +functorch/_src/eager_transforms/__pycache__/__init__.cpython-39.pyc,, +functorch/_src/make_functional/__init__.py,sha256=vo-yXbmEA9uzHQayl8mhCGArugyZe2vwHaj2fiklD3w,239 +functorch/_src/make_functional/__pycache__/__init__.cpython-39.pyc,, +functorch/_src/vmap/__init__.py,sha256=FrCE1RM1QK04BgLYVsXpZbetiVjtTri_1h7FGhjB2Mw,483 +functorch/_src/vmap/__pycache__/__init__.cpython-39.pyc,, +functorch/compile/__init__.py,sha256=SY48gIkfVpniQqcBqFZjEEAh2ID7e8HfUSzrv_rO9PA,786 +functorch/compile/__pycache__/__init__.cpython-39.pyc,, +functorch/dim/__init__.py,sha256=eEK-oQhTlqlhuOVnh9KCYwvCDbKccag5-tFQ8lXk95g,4837 +functorch/dim/__pycache__/__init__.cpython-39.pyc,, +functorch/dim/__pycache__/batch_tensor.cpython-39.pyc,, +functorch/dim/__pycache__/delayed_mul_tensor.cpython-39.pyc,, +functorch/dim/__pycache__/dim.cpython-39.pyc,, +functorch/dim/__pycache__/magic_trace.cpython-39.pyc,, +functorch/dim/__pycache__/op_properties.cpython-39.pyc,, +functorch/dim/__pycache__/reference.cpython-39.pyc,, +functorch/dim/__pycache__/tree_map.cpython-39.pyc,, +functorch/dim/__pycache__/wrap_type.cpython-39.pyc,, 
+functorch/dim/batch_tensor.py,sha256=n0TQMZkB81-2SvXSG9_pNVKjgY01-cWNgXCG6Gqfu4c,694 +functorch/dim/delayed_mul_tensor.py,sha256=KzetKaQs651FxwPJ_i7nf4GgPaq1OtLTLld8fuO-XB8,2453 +functorch/dim/dim.py,sha256=QK3zUOTCUk8z7OwdqPfyx73ssCIK6QHs9Rb1pjxyybE,3502 +functorch/dim/magic_trace.py,sha256=fjSSleCOyqpnihB2NukPIMcNJgQ0fNcKDbOVpahwing,1371 +functorch/dim/op_properties.py,sha256=GzHOa2ulADUvBOatQIxHbhLeUXKrclZxHRq4XPcUgms,6999 +functorch/dim/reference.py,sha256=tBlFMj7lgfzHkHMjnxLTkUeTB6nMvUd-ET3cD26RHlg,20977 +functorch/dim/tree_map.py,sha256=fl97CMvOv37UFqi3hgjZl3B1cfRW-3kzURx72ZMinWY,390 +functorch/dim/wrap_type.py,sha256=6qad-plrk0y-G0_czsWXUZAibCmbLdr1FD1j6FvrQJI,1943 +functorch/einops/__init__.py,sha256=ho1steDFM9r0hgvRWZWx9k6SkNzOMfH89Q35J79fW9c,63 +functorch/einops/__pycache__/__init__.cpython-39.pyc,, +functorch/einops/__pycache__/_parsing.cpython-39.pyc,, +functorch/einops/__pycache__/rearrange.cpython-39.pyc,, +functorch/einops/_parsing.py,sha256=SyX9AUEU6qPqUtJxpPuEl3G6c3xQWrwEk8lgOezQDiQ,12617 +functorch/einops/rearrange.py,sha256=9QOkRuibL8bqrWghknaTnw813OsiiiFL6N8SYTOOQ90,8299 +functorch/experimental/__init__.py,sha256=3ZNvV4NJakDHna7F96XCmv_KdQNUyUiDcFRe8xQWHlg,278 +functorch/experimental/__pycache__/__init__.cpython-39.pyc,, +functorch/experimental/__pycache__/control_flow.cpython-39.pyc,, +functorch/experimental/__pycache__/ops.cpython-39.pyc,, +functorch/experimental/control_flow.py,sha256=2OxWUXzrCW0XaSpQkK85UwUd5gHsHFN2V_qP5gAF4-o,150 +functorch/experimental/ops.py,sha256=g_LXSb0WzSHiOLW4OLkIiuhiHvHPHNtGpBN30Wss6Go,58 +torch-2.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +torch-2.8.0.dist-info/LICENSE,sha256=1hcB6HTDmoWNksR8LL872LnkmAMqGK5Cgy-6WRdSzh8,506902 +torch-2.8.0.dist-info/METADATA,sha256=IdWcvpg7Iuctq1uniGF6YxpOwwNXVEzpDKfXWqUbPAQ,30473 +torch-2.8.0.dist-info/NOTICE,sha256=HD3sbANCx-2eEq9ZspbHux05SvXPiB23QZHMgpnvzlw,24088 +torch-2.8.0.dist-info/RECORD,, +torch-2.8.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch-2.8.0.dist-info/WHEEL,sha256=Z6c-bE0pUM47a70GvqO_SvH_XXU0lm62gEAKtoNJ08A,100 +torch-2.8.0.dist-info/entry_points.txt,sha256=b8t38q301MKrL4Cxfn-q_mop6bMzdkZsKmm2HsnNsmQ,199 +torch-2.8.0.dist-info/top_level.txt,sha256=MsBcfJyMU15lW1efu5w7Tzd4MenrYHiuaixbHMfAoco,25 +torch/_C.cp39-win_amd64.pyd,sha256=efiAlA_YBoO5G1fgk6trs4Nh_4GubntSenDa_wD3Thg,10752 +torch/_C/_VariableFunctions.pyi,sha256=JavORauKFRu8ek2vuGJPzd04g9Vl1kuYevoBUCVz0G0,1198047 +torch/_C/__init__.pyi,sha256=sVU-poPViRc9ACz4kuipflDVpDs80beKNOFvNqQdlVo,419913 +torch/_C/_aoti.pyi,sha256=rohgn8QJCZucauXBUY332ut9NYfYmQDQxd1-MEzMqfw,6179 +torch/_C/_autograd.pyi,sha256=BLHIu721imnM343Wfk-EEBuRnCbUaWiVZLcpugOQjV0,4937 +torch/_C/_cpu.pyi,sha256=AxfveLRRrfQIaKZzyPyOKrBdLqmCJhN1xiA2Ga-hGDw,447 +torch/_C/_cudnn.pyi,sha256=QZJd1Z0nOV5-vdQk0c0dXFADrgQ-D2qYypn4sb64lFE,331 +torch/_C/_cusparselt.pyi,sha256=STZASq5R_xHmWNc23FYTi9xOuN6vX-YagUHiTPASNwY,33 +torch/_C/_distributed_autograd.pyi,sha256=RkSSkafXavZ08c2YDbRQES2H_RT309QTczd15rH14EU,925 +torch/_C/_distributed_c10d.pyi,sha256=s8cn8PU5uqxkRHRQPvp5rkMZ251_4L5xf6mcUOWy_Jk,24059 +torch/_C/_distributed_rpc.pyi,sha256=26VDsJyb4ooKY5tjFJdAxsEblhTyz2p2OivFNHW4Cso,6268 +torch/_C/_distributed_rpc_testing.pyi,sha256=ORYYhzRBVptDNYXzqKk-VWDm4MEauaxB0gBlEmrzebU,1049 +torch/_C/_dynamo/__init__.pyi,sha256=FXgJGha1KPh6dXYj_N_VRqxStKaU49xE8Tjwy-igDzY,170 +torch/_C/_dynamo/compiled_autograd.pyi,sha256=a86TIBwdhZQXIi0VkdQgusfB8RZJrXCKsl4ZWHycQwo,526 
+torch/_C/_dynamo/eval_frame.pyi,sha256=eQ0JYXTumocWo_gXm87u7hN9BRgwTZE2p4U-pjJomb4,2242 +torch/_C/_dynamo/guards.pyi,sha256=aPGu0ejb7aGPQSokSWf0gkMsdXNz-J-_NCOwq347Fp4,5267 +torch/_C/_export/__init__.pyi,sha256=xmsrWjm6_5M5FsoqReL-K0xUCE31PlysmTwuMXnRzUs,266 +torch/_C/_export/pt2_archive_constants.pyi,sha256=AFNF3v4yVq-KSzQInn2dIuzPWBu8OJH2k79nGEiVzBU,698 +torch/_C/_functions.pyi,sha256=6IcSjn89EeI0yZePCn5I7JH3-FbdRJinuplSUUEXJOI,627 +torch/_C/_functorch.pyi,sha256=npdaYEIIBAVVwXUkC9_cCN_ZDQRbNEnktpAJXdSF7qY,3326 +torch/_C/_instruction_counter.pyi,sha256=euVpDw5-_od7VDlZXexNsfpkv0C4cxFKgLK6SIdS7Bk,113 +torch/_C/_itt.pyi,sha256=nrLlzZvGArWL5299sUDM050Or-Id-JWc-qkmFGISri4,174 +torch/_C/_lazy.pyi,sha256=KRIpi3f87UcYNhqe1gVvcIzCI2mEkfglmHSv5pH4e2k,1041 +torch/_C/_lazy_ts_backend.pyi,sha256=7te5WIANu3IYJl9hggNBpUX0g1RFrljAXHoXXWNTD0I,338 +torch/_C/_monitor.pyi,sha256=UWx6Tvy0-mHYBgmqi3Fkl1hSb9jwe3t5v4-7sItCvgU,1496 +torch/_C/_nn.pyi,sha256=IcakUkd6-IIU6ttb2itluWB80zCEiaum4uzehysCSy0,5146 +torch/_C/_nvtx.pyi,sha256=u34pMm7E1ue-HvrnPXrh7ADvFTiriXQJUgkCaDGtFaw,389 +torch/_C/_onnx.pyi,sha256=pKWakDSrFUkLPqHSECJhNZ6cHb0qfdKz6kD_mi05q8U,749 +torch/_C/_profiler.pyi,sha256=ul3dFDFLq11S68JDUB1Kq8vccqFWhqI_ZEXGR-rbYYE,6490 +torch/_C/_verbose.pyi,sha256=7KJppleT0c1wLkIser5ezi9K4I-f6ecUMvB9a7LmYyA,137 +torch/_C_flatbuffer/__init__.pyi,sha256=mHUoCyZPzfJNZtH-lvh_DNh7P7RMnVQqYmb_jZoBr6o,555 +torch/_VF.py,sha256=hJLG33MTfEPRgdChJxIUR82WMiiqNCfTW5DINBJBnho,695 +torch/_VF.pyi,sha256=JavORauKFRu8ek2vuGJPzd04g9Vl1kuYevoBUCVz0G0,1198047 +torch/__config__.py,sha256=JQ9izWG915YpLkJQxG9LAFEmB39cZwsv8SOtD2Ueo_g,596 +torch/__future__.py,sha256=wcFs07Ls3wHda59z5xJrMU7_Cw7s5PabTdCVFwY7K50,3260 +torch/__init__.py,sha256=bWyS0LEJH8j0N8_oIo6cmtCI3A-vhEax34c49hwINPI,105826 +torch/__pycache__/_VF.cpython-39.pyc,, +torch/__pycache__/__config__.cpython-39.pyc,, +torch/__pycache__/__future__.cpython-39.pyc,, +torch/__pycache__/__init__.cpython-39.pyc,, +torch/__pycache__/_appdirs.cpython-39.pyc,, +torch/__pycache__/_classes.cpython-39.pyc,, +torch/__pycache__/_compile.cpython-39.pyc,, +torch/__pycache__/_custom_ops.cpython-39.pyc,, +torch/__pycache__/_deploy.cpython-39.pyc,, +torch/__pycache__/_environment.cpython-39.pyc,, +torch/__pycache__/_guards.cpython-39.pyc,, +torch/__pycache__/_jit_internal.cpython-39.pyc,, +torch/__pycache__/_linalg_utils.cpython-39.pyc,, +torch/__pycache__/_lobpcg.cpython-39.pyc,, +torch/__pycache__/_lowrank.cpython-39.pyc,, +torch/__pycache__/_meta_registrations.cpython-39.pyc,, +torch/__pycache__/_namedtensor_internals.cpython-39.pyc,, +torch/__pycache__/_ops.cpython-39.pyc,, +torch/__pycache__/_python_dispatcher.cpython-39.pyc,, +torch/__pycache__/_size_docs.cpython-39.pyc,, +torch/__pycache__/_sources.cpython-39.pyc,, +torch/__pycache__/_storage_docs.cpython-39.pyc,, +torch/__pycache__/_streambase.cpython-39.pyc,, +torch/__pycache__/_tensor.cpython-39.pyc,, +torch/__pycache__/_tensor_docs.cpython-39.pyc,, +torch/__pycache__/_tensor_str.cpython-39.pyc,, +torch/__pycache__/_thread_safe_fork.cpython-39.pyc,, +torch/__pycache__/_torch_docs.cpython-39.pyc,, +torch/__pycache__/_utils.cpython-39.pyc,, +torch/__pycache__/_utils_internal.cpython-39.pyc,, +torch/__pycache__/_vmap_internals.cpython-39.pyc,, +torch/__pycache__/_weights_only_unpickler.cpython-39.pyc,, +torch/__pycache__/functional.cpython-39.pyc,, +torch/__pycache__/hub.cpython-39.pyc,, +torch/__pycache__/library.cpython-39.pyc,, +torch/__pycache__/overrides.cpython-39.pyc,, +torch/__pycache__/quasirandom.cpython-39.pyc,, 
+torch/__pycache__/random.cpython-39.pyc,, +torch/__pycache__/return_types.cpython-39.pyc,, +torch/__pycache__/serialization.cpython-39.pyc,, +torch/__pycache__/storage.cpython-39.pyc,, +torch/__pycache__/torch_version.cpython-39.pyc,, +torch/__pycache__/types.cpython-39.pyc,, +torch/__pycache__/version.cpython-39.pyc,, +torch/_appdirs.py,sha256=riyDX0KIIvYgL-jTVmPowMPH0K8_iTfYIdKll4x1XM8,26872 +torch/_awaits/__init__.py,sha256=zi2RTUcKVCkwUpLXF2KHUVMhiHeRNn-BjOo8fLm8Ay0,1705 +torch/_awaits/__pycache__/__init__.cpython-39.pyc,, +torch/_classes.py,sha256=R5gj60vFsHqrOUeukvK6WECIGmqfpLlJk-RRbminpSs,1777 +torch/_compile.py,sha256=Sj9nY_Z9tSUvfDxkhjkgJvrLEw2aEsZLsRWsdiZF8CI,2070 +torch/_custom_op/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_custom_op/__pycache__/__init__.cpython-39.pyc,, +torch/_custom_op/__pycache__/autograd.cpython-39.pyc,, +torch/_custom_op/__pycache__/impl.cpython-39.pyc,, +torch/_custom_op/autograd.py,sha256=3PiZ78B2jRfUEEBt4K9iDJ6eWvorIHvq_lsTQ-vgHAU,12390 +torch/_custom_op/impl.py,sha256=CHQwLOxQ-d0rk9uEoQF_h1N-BYQm0TwHGRe_Qp2L0qs,27727 +torch/_custom_ops.py,sha256=yVY9HB2RWS1tjTPf1F00btTJFdVmtsuG7oq5bOAF6IM,13150 +torch/_decomp/__init__.py,sha256=1w7PjKmxmtWhi07cYtGT95Cp6yK6jTNOIuimZxW51h8,19688 +torch/_decomp/__pycache__/__init__.cpython-39.pyc,, +torch/_decomp/__pycache__/decompositions.cpython-39.pyc,, +torch/_decomp/__pycache__/decompositions_for_jvp.cpython-39.pyc,, +torch/_decomp/__pycache__/decompositions_for_rng.cpython-39.pyc,, +torch/_decomp/decompositions.py,sha256=_2GG0krGxq1JaX9YviPzyAtRguDg2ZTNEF8XsYNyhfE,181650 +torch/_decomp/decompositions_for_jvp.py,sha256=90wiio4f_u1PzZU9ZVBkAZTWYu-iyR0Iu5Mhkbcx9-A,12036 +torch/_decomp/decompositions_for_rng.py,sha256=D3nw4ohqI715fU8OcGGJ_mpHBg9MsN-uirIhDA3ZPrc,9444 +torch/_deploy.py,sha256=TwPikibwIuBarPXASg1iQrAOR9c3t1dtpuoERzyz0V8,3561 +torch/_dispatch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_dispatch/__pycache__/__init__.cpython-39.pyc,, +torch/_dispatch/__pycache__/python.cpython-39.pyc,, +torch/_dispatch/python.py,sha256=zfcrpRQTgK6SwdITVzI8sibWHLxT9SYB4H0YOfxas3I,6902 +torch/_dynamo/__init__.py,sha256=kvq81JO_vOIhDWhQX5l0xSnalPOzmqLhOEL7LpEVN1w,5393 +torch/_dynamo/__pycache__/__init__.cpython-39.pyc,, +torch/_dynamo/__pycache__/_trace_wrapped_higher_order_op.cpython-39.pyc,, +torch/_dynamo/__pycache__/bytecode_analysis.cpython-39.pyc,, +torch/_dynamo/__pycache__/bytecode_transformation.cpython-39.pyc,, +torch/_dynamo/__pycache__/cache_size.cpython-39.pyc,, +torch/_dynamo/__pycache__/callback.cpython-39.pyc,, +torch/_dynamo/__pycache__/code_context.cpython-39.pyc,, +torch/_dynamo/__pycache__/codegen.cpython-39.pyc,, +torch/_dynamo/__pycache__/compiled_autograd.cpython-39.pyc,, +torch/_dynamo/__pycache__/comptime.cpython-39.pyc,, +torch/_dynamo/__pycache__/config.cpython-39.pyc,, +torch/_dynamo/__pycache__/convert_frame.cpython-39.pyc,, +torch/_dynamo/__pycache__/create_parameter_op.cpython-39.pyc,, +torch/_dynamo/__pycache__/current_scope_id.cpython-39.pyc,, +torch/_dynamo/__pycache__/debug_utils.cpython-39.pyc,, +torch/_dynamo/__pycache__/decorators.cpython-39.pyc,, +torch/_dynamo/__pycache__/device_interface.cpython-39.pyc,, +torch/_dynamo/__pycache__/distributed.cpython-39.pyc,, +torch/_dynamo/__pycache__/eval_frame.cpython-39.pyc,, +torch/_dynamo/__pycache__/exc.cpython-39.pyc,, +torch/_dynamo/__pycache__/external_utils.cpython-39.pyc,, +torch/_dynamo/__pycache__/funcname_cache.cpython-39.pyc,, 
+torch/_dynamo/__pycache__/graph_break_hints.cpython-39.pyc,, +torch/_dynamo/__pycache__/graph_deduplication.cpython-39.pyc,, +torch/_dynamo/__pycache__/graph_region_tracker.cpython-39.pyc,, +torch/_dynamo/__pycache__/graph_utils.cpython-39.pyc,, +torch/_dynamo/__pycache__/guards.cpython-39.pyc,, +torch/_dynamo/__pycache__/hooks.cpython-39.pyc,, +torch/_dynamo/__pycache__/logging.cpython-39.pyc,, +torch/_dynamo/__pycache__/metrics_context.cpython-39.pyc,, +torch/_dynamo/__pycache__/mutation_guard.cpython-39.pyc,, +torch/_dynamo/__pycache__/output_graph.cpython-39.pyc,, +torch/_dynamo/__pycache__/package.cpython-39.pyc,, +torch/_dynamo/__pycache__/pgo.cpython-39.pyc,, +torch/_dynamo/__pycache__/precompile_context.cpython-39.pyc,, +torch/_dynamo/__pycache__/profiler.cpython-39.pyc,, +torch/_dynamo/__pycache__/replay_record.cpython-39.pyc,, +torch/_dynamo/__pycache__/resume_execution.cpython-39.pyc,, +torch/_dynamo/__pycache__/side_effects.cpython-39.pyc,, +torch/_dynamo/__pycache__/source.cpython-39.pyc,, +torch/_dynamo/__pycache__/symbolic_convert.cpython-39.pyc,, +torch/_dynamo/__pycache__/tensor_version_op.cpython-39.pyc,, +torch/_dynamo/__pycache__/test_case.cpython-39.pyc,, +torch/_dynamo/__pycache__/test_dont_skip_tracing_functions.cpython-39.pyc,, +torch/_dynamo/__pycache__/test_minifier_common.cpython-39.pyc,, +torch/_dynamo/__pycache__/testing.cpython-39.pyc,, +torch/_dynamo/__pycache__/trace_rules.cpython-39.pyc,, +torch/_dynamo/__pycache__/types.cpython-39.pyc,, +torch/_dynamo/__pycache__/utils.cpython-39.pyc,, +torch/_dynamo/_trace_wrapped_higher_order_op.py,sha256=vcLstoAhI5JIAs8OAq7Ys74HF-MTBNCtvmV-ZE-9U3E,9475 +torch/_dynamo/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_dynamo/backends/__pycache__/__init__.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/common.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/cudagraphs.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/debugging.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/distributed.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/inductor.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/onnxrt.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/registry.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/tensorrt.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/torchxla.cpython-39.pyc,, +torch/_dynamo/backends/__pycache__/tvm.cpython-39.pyc,, +torch/_dynamo/backends/common.py,sha256=dsWEpB_W2cbooy6t2KNR8QxpC_KzREKi9T_TRoxjEyI,5703 +torch/_dynamo/backends/cudagraphs.py,sha256=z8XeiWQJuIH0fKRfzfCmcYndSieXVvPrIgpf69_VUyQ,9541 +torch/_dynamo/backends/debugging.py,sha256=bIUK8Nl9V6A-VVN-tM5ezr0YMLeb35NpCOjMBsOnevk,16050 +torch/_dynamo/backends/distributed.py,sha256=NYp_WeLAxvhN-GRDNxLiL8a4fdEZYacOqCmx7qWrXBs,27090 +torch/_dynamo/backends/inductor.py,sha256=1E1XDz7DZqzkLacwC7OcSUWG5-XSSLk1VOQzvRyWPTc,909 +torch/_dynamo/backends/onnxrt.py,sha256=saxu4kT7VTt8I4SeUcHonvhRZizcu9D0EduhRnvjCMs,1579 +torch/_dynamo/backends/registry.py,sha256=XhMm85Axz4tV824vbJuHuMsF1X3p3qm6ucGIwsHjnRA,5611 +torch/_dynamo/backends/tensorrt.py,sha256=0UBGGjwh5yOcMG4w_uFqPBlrgCzNn9hsF5_6G8ndeiw,420 +torch/_dynamo/backends/torchxla.py,sha256=4s-drykMtZ5MLvES23JKBT5cMUpmUl1nucDx_yKq2ec,1302 +torch/_dynamo/backends/tvm.py,sha256=FTuPJwVFvJzU7A1Z4JFSGEzDli4UIoLxtktKqEJYWFs,7846 +torch/_dynamo/bytecode_analysis.py,sha256=WLdZP1Q1wh8VX0gPYhkBXnXdu0NbVMrdfJDkHih35kc,8785 +torch/_dynamo/bytecode_transformation.py,sha256=MDyHFmNTTnhf1mj4x6lJc5PyaoffBqdAqjQDxjXE4bQ,61371 
+torch/_dynamo/cache_size.py,sha256=j8bBw-4yBeUh1ra1w5LkBExt_vzzV853ccIuGY78XCA,8093 +torch/_dynamo/callback.py,sha256=DlyJ04kKYLBuJGjrtY4P_ndVw8He_2Zp3MaKJ3t5agI,5757 +torch/_dynamo/code_context.py,sha256=vrGYv-Uomoc-Nw0Vfb5TSe3MJvMWMy220f_J5UwgLYs,1878 +torch/_dynamo/codegen.py,sha256=1p9vj7waWnQ8H3pQ80COzyqgC6nK4dfspouxW4gHSdg,28525 +torch/_dynamo/compiled_autograd.py,sha256=3hc3Vh0HAGPpngLzqfynVXg5W5smcsf811QHjslGa_g,61987 +torch/_dynamo/comptime.py,sha256=7DnhLpr1db4l_QEbgYPN-hVH3eUeCOYzX3_Za9m0PzI,14641 +torch/_dynamo/config.py,sha256=S3CFsdJB8Ix_FO-2oYkDRh3FUfb4KV2NlHzVJK67uVg,27983 +torch/_dynamo/convert_frame.py,sha256=WJC4kRBrqDcslOeaT11oJYfgeuAIUGKc7pbbp72PrEo,58363 +torch/_dynamo/create_parameter_op.py,sha256=nf2ZwpFxoll8azktVOoOANKKWiLG-eulEEHEWNWhfak,2592 +torch/_dynamo/current_scope_id.py,sha256=lFy-kRPxTBYkhMrdJsk_sBax3qBGSfw0lU0Q2JPDKVA,1475 +torch/_dynamo/debug_utils.py,sha256=gcJb7RLWCl-DumyoYjgMUPcSyTHwIk1bHShCSyk_ni0,31252 +torch/_dynamo/decorators.py,sha256=cSaPP5cd9K1xgveJGmaHPYYp5ElX-xBLYA9YRE8nvGw,33681 +torch/_dynamo/device_interface.py,sha256=XI-XVGAyzNake5wDQxZAdltHyf8YD3mYNnVXbdo_8O0,18259 +torch/_dynamo/distributed.py,sha256=Z4DNAhzYcDguigwhwfk_c3WTo5AdL_J12mf2pK5lQYQ,1724 +torch/_dynamo/eval_frame.py,sha256=fT2-DANcYo-nOHEk4Lus3TpguTBqBSp3r7god7nByMI,83855 +torch/_dynamo/exc.py,sha256=RC6k4M1PjMLTtfN3dqYtT33649sXe11mvFx9jMT2UHg,24375 +torch/_dynamo/external_utils.py,sha256=slWvT61Z718uUgp2-ZSqThp2Rwk3dIyHY1Zc9fZuShk,7080 +torch/_dynamo/funcname_cache.py,sha256=mDVmnyyHKozM7kj-9g5VNUPSXGApc01SQSIz7Hw1MJo,2623 +torch/_dynamo/graph_break_hints.py,sha256=u9QZAhv9Q57TnNARiIC-zCgRYHbhEuyj5Id5qzOI6fY,1353 +torch/_dynamo/graph_break_registry.json,sha256=aS64if2Hwy8DOns7XctByJjW_7gThMeBYys-_V5Zr5k,86841 +torch/_dynamo/graph_deduplication.py,sha256=LLeclfXSqOFo9BZa0MLgXWFppPl91me7Dm9ofSTzT0w,17457 +torch/_dynamo/graph_region_tracker.py,sha256=Iioc15ao-eYGQKuMYXriYNwOUJ8724gnI_ftvGww91w,17965 +torch/_dynamo/graph_utils.py,sha256=vjhb3nK6gdsvs_p6mkBZZLbLmQ0tRZSwdo9t2BiyCto,2444 +torch/_dynamo/guards.py,sha256=gggZ-frsOy0f6jHc4PCvyY2W7A6V_hK5U_yRSiQZ9hM,149499 +torch/_dynamo/hooks.py,sha256=CdYt2Yf7b-S3X3ZHZRhJlJHbDzKAjbAn7Qto_gPDzgU,891 +torch/_dynamo/logging.py,sha256=AvvDW0XYOswuqeps8p0tNjqK2xfytmxOAyXgiOmAOdk,2260 +torch/_dynamo/metrics_context.py,sha256=eydGSI0n6tXoq1YrU7U464IG6_OHxJciaKatv-ViT7U,8249 +torch/_dynamo/mutation_guard.py,sha256=UGjX_9eQwO-ahBxknnqF4sb-inkDpOlutah1luzJylQ,5326 +torch/_dynamo/output_graph.py,sha256=eeQBLIM-qJpYGvRoZN0Y_viqYIrgVn9R-T7SdcPTHdw,136702 +torch/_dynamo/package.py,sha256=IRs4YU-Y8oQ1EtGx659NId66aoFIc-fThie3liVNFdA,16368 +torch/_dynamo/pgo.py,sha256=ZKxwMMTwbh6eFaMbK17Hy8P0-V-la9R37rw9GLIIDvQ,33023 +torch/_dynamo/polyfills/__init__.py,sha256=qLTGnteGNzjRpWOul2VaymsD-48PjAoaSGZAlQg1zMs,9559 +torch/_dynamo/polyfills/__pycache__/__init__.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/builtins.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/functools.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/fx.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/itertools.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/loader.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/operator.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/os.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/pytree.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/sys.cpython-39.pyc,, +torch/_dynamo/polyfills/__pycache__/tensor.cpython-39.pyc,, 
+torch/_dynamo/polyfills/builtins.py,sha256=xNHq_zVWRFLt8JVlcM1Wr9QsdJWQ0KWZD2nu9sCljxI,1459 +torch/_dynamo/polyfills/functools.py,sha256=cLFf3z8Ly1TVz4PPejxYoBQQ468FCAu8Ns0WQ9EmOTc,1013 +torch/_dynamo/polyfills/fx.py,sha256=kCc02KU8-18tsqjRHnrJVNiPyC-hlawrYoOS9aG6RqU,1363 +torch/_dynamo/polyfills/itertools.py,sha256=1puqKHB7mbLNAF7AM6Q7suLRrQnts7fBZ13EyedvMnA,6190 +torch/_dynamo/polyfills/loader.py,sha256=FusALnv0qkC6AlkIh7tu4JbtCXHPjnPFBimcfO3d5So,1266 +torch/_dynamo/polyfills/operator.py,sha256=vbQvq04JBfpgnt0oxoEqaL3ahBJKR48Up8yd94XG1cc,3029 +torch/_dynamo/polyfills/os.py,sha256=HLqUpw4bXrz0bRpkP67tDjbNKXRtMWPKcB6dR3LNesE,1014 +torch/_dynamo/polyfills/pytree.py,sha256=1bNRUpgZzLDHkhsZk8ZhVZWeyE1vKA24CU43NswiemQ,16164 +torch/_dynamo/polyfills/sys.py,sha256=7Z8bmfbsc1CTFQp0mHvK5i12bKaKgU2qDWVPhXYQPbU,472 +torch/_dynamo/polyfills/tensor.py,sha256=Xvr3-armSIna8ql7NVTIoL0xw4kPiSEUcWgDsINR7Mw,1444 +torch/_dynamo/precompile_context.py,sha256=3zvlpqFPuLJXxn8KKQgFhL9sb-3eseR02TU5s1GocOE,5926 +torch/_dynamo/profiler.py,sha256=EiKAtODJtgWZouRm4KMurqTk2oTEe-vRYwx7JIKUcVU,5883 +torch/_dynamo/replay_record.py,sha256=2CZcnfgBfhBsMeQJHWgsjscOPQq30MFlDLmeg4h1ZLg,4419 +torch/_dynamo/repro/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_dynamo/repro/__pycache__/__init__.cpython-39.pyc,, +torch/_dynamo/repro/__pycache__/after_aot.cpython-39.pyc,, +torch/_dynamo/repro/__pycache__/after_dynamo.cpython-39.pyc,, +torch/_dynamo/repro/__pycache__/aoti.cpython-39.pyc,, +torch/_dynamo/repro/after_aot.py,sha256=F8uDaXUsPM4T4zs2cYlxRC9ndMSIvNuXqI0xhdl5C0c,37376 +torch/_dynamo/repro/after_dynamo.py,sha256=27mAP6OkTmki5ItDVkfUBYec7vsjuNUFshVW3X3G-qA,21179 +torch/_dynamo/repro/aoti.py,sha256=vFCEu3k4RznpvhaCkb6kKR55lbl2JSBHShDJOGGHeOI,21827 +torch/_dynamo/resume_execution.py,sha256=4k65tW4yPsH8U718wE7tKZH-riRKugYHQ8SIkWHBu0s,22428 +torch/_dynamo/side_effects.py,sha256=p_Hr54vzvax9Vox104xiqBzokOJUjpAVK5MrJlSkAHY,49319 +torch/_dynamo/source.py,sha256=dL89aRSCxPD_9ZrMZ0ngEQDCe_yvGqszyRU9-nThwAE,34594 +torch/_dynamo/symbolic_convert.py,sha256=5qOXfMYJJ8EtsgAv-XZd9FwPgDRKUhW5utawXSaH7d4,171189 +torch/_dynamo/tensor_version_op.py,sha256=MKAwL-cvnzrxjLrBpGni79xXGpoUxtWKViTLTMjDpb0,2462 +torch/_dynamo/test_case.py,sha256=GSc_JW5VWCjWKZW1DSiEsC8We5QWNaI1Ew77XuhDkh8,7462 +torch/_dynamo/test_dont_skip_tracing_functions.py,sha256=eFed1TvI4WRSq5uyNxHVze9STb36PBBOuACuu4n82vs,857 +torch/_dynamo/test_minifier_common.py,sha256=6Xr0zgxpr9w4iillHtLwXKz8tx8xV_QoLqYtR4pYlW0,12207 +torch/_dynamo/testing.py,sha256=SvH_hnps204TcibdTl9frFQYTU8Anuo16lfEodaNvR8,17624 +torch/_dynamo/trace_rules.py,sha256=A0L7zK4m0bvIe7DzJxYcUgXEmL9wulZUHeiGiv8lavI,157987 +torch/_dynamo/types.py,sha256=_UIPaSPI_YKSnnKm_ozD9BH-YfRb-fudSgYg96RB0Eg,4294 +torch/_dynamo/utils.py,sha256=4QXLjZDBj4az2QODcMbY-U9lvvUZ1BgO3umr9PlGZqs,165838 +torch/_dynamo/variables/__init__.py,sha256=7A0zNRwY-f2FW5GHG4vSu3ARAdaHWF9jbLj6o_3fuc0,6864 +torch/_dynamo/variables/__pycache__/__init__.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/base.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/builder.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/builtin.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/constant.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/ctx_manager.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/dicts.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/distributed.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/functions.cpython-39.pyc,, 
+torch/_dynamo/variables/__pycache__/higher_order_ops.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/iter.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/lazy.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/lists.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/misc.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/nn_module.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/optimizer.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/script_object.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/sdpa.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/tensor.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/torch.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/torch_function.cpython-39.pyc,, +torch/_dynamo/variables/__pycache__/user_defined.cpython-39.pyc,, +torch/_dynamo/variables/base.py,sha256=7M7Dc2_IO6200VBTrtLZA9o0F4z1fknj4q9Motl0iqk,24424 +torch/_dynamo/variables/builder.py,sha256=F2F-QaPPAtG8pJkuH1mWPaax1PDiWwkkin-lcP-NHFw,158531 +torch/_dynamo/variables/builtin.py,sha256=Zlemn4Oows_CF89AHMgBViuyQPpKUBW3B3Yk9TnXk1s,104923 +torch/_dynamo/variables/constant.py,sha256=zxroBOQ6dOTd-0G_lDb9CyV9jvlgU6L0fUSjLFFsWBs,10697 +torch/_dynamo/variables/ctx_manager.py,sha256=zFWNuPZriJvDs9C3Xn0HS-4xbZe4DEox80w9psbRfs4,52308 +torch/_dynamo/variables/dicts.py,sha256=1AGFbXY4uISXEzw2VLEVUHjGjJTLXNPN8305wdYsSv4,42670 +torch/_dynamo/variables/distributed.py,sha256=AR-S2zpEKtyG3YsFlp5i3rpZgxkjHYczLcmJxdwh13U,16469 +torch/_dynamo/variables/functions.py,sha256=NMxzOCPHcYW_T4RaT_iGNXj-qk2ho_wSmCCWPnXipP8,91240 +torch/_dynamo/variables/higher_order_ops.py,sha256=MiP-gqEHFKdkgD8Uf46GrmUq2TsN-NHwRETNmr4gRAI,135143 +torch/_dynamo/variables/iter.py,sha256=nF1QCPwjqiXkq8G6fJtLEheBdWTEiPiZBJwfmdv_iOc,23871 +torch/_dynamo/variables/lazy.py,sha256=xbbnBxADw_DT2cPsIZ5uiG0GlqWEqfum3EZcpd9K4So,7278 +torch/_dynamo/variables/lists.py,sha256=QhZzr0B3ZydznNakR5M1j93e0q_u3GBz7tzwGV-whqw,41260 +torch/_dynamo/variables/misc.py,sha256=NFar9y8o4UiozqkFS42Xf7eXzwxwELOtvPQDKT5N36c,76522 +torch/_dynamo/variables/nn_module.py,sha256=M1c291_qH2rgLpVN7y1htxlBKMhyBA_B36fGsrJFWfw,53108 +torch/_dynamo/variables/optimizer.py,sha256=MD0QTjqltziFEvq71MGcrw0jpG8Gib1MM_rpwIZwvpo,17121 +torch/_dynamo/variables/script_object.py,sha256=oTJlTP_l8dVsSEzzIITAJilT06yw1zLLoGHJG-k-TRk,3865 +torch/_dynamo/variables/sdpa.py,sha256=ATsREiCSqdQyfvzcEehtY5vPtmLKzqNDyi5ekhSq_AI,2600 +torch/_dynamo/variables/tensor.py,sha256=Ym--0r8BuyrFAfFJalI1cpt0PT2JaDrDnwL9cN5GCAU,68435 +torch/_dynamo/variables/torch.py,sha256=5kixBRjccUGmqWPf5qKl2LC5F02J9hNWD_W4bl-U3pw,72967 +torch/_dynamo/variables/torch_function.py,sha256=Hjf2E7BaViGKl3D8LiZyIeZ3RhtLVLroS1U6k66nb-E,28692 +torch/_dynamo/variables/user_defined.py,sha256=bercqdLTZfqAirXRuh7D8z0Knv1eUX73MItspRZv7xc,74020 +torch/_environment.py,sha256=oY4eUdyvzp214s4IxxLW5tcr0N7KkAVyR0qrjXKPqmg,44 +torch/_export/__init__.py,sha256=hhSswkDQsFJxxG0NvS7dFdwzaBixRZV9DpnSqxIP-4I,6672 +torch/_export/__pycache__/__init__.cpython-39.pyc,, +torch/_export/__pycache__/converter.cpython-39.pyc,, +torch/_export/__pycache__/error.cpython-39.pyc,, +torch/_export/__pycache__/non_strict_utils.cpython-39.pyc,, +torch/_export/__pycache__/pass_base.cpython-39.pyc,, +torch/_export/__pycache__/tools.cpython-39.pyc,, +torch/_export/__pycache__/utils.cpython-39.pyc,, +torch/_export/__pycache__/verifier.cpython-39.pyc,, +torch/_export/__pycache__/wrappers.cpython-39.pyc,, +torch/_export/converter.py,sha256=Nm-y8k8VvdQcLVUipfbAuBf4SCHXUF0DSsHp4z199B4,66154 
+torch/_export/db/__init__.py,sha256=JQMVLRoFCKo629KQ_1vu4i10Rmkww8EWMyZnDLEGXag,211 +torch/_export/db/__pycache__/__init__.cpython-39.pyc,, +torch/_export/db/__pycache__/case.cpython-39.pyc,, +torch/_export/db/__pycache__/gen_example.cpython-39.pyc,, +torch/_export/db/__pycache__/logging.cpython-39.pyc,, +torch/_export/db/case.py,sha256=LTzxaWgX58ZQV5AYhHCli-eSbnoO3_IVSPSUJHrsUZo,5172 +torch/_export/db/examples/__init__.py,sha256=NpYiPkSVS3WILe8olwFApfL91puPyaicw9oNcGnZq68,1709 +torch/_export/db/examples/__pycache__/__init__.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/assume_constant_result.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/autograd_function.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/class_method.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_branch_class_method.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_branch_nested_function.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_branch_nonlocal_variables.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_closed_over_variable.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_operands.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/cond_predicate.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/constrain_as_size_example.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/constrain_as_value_example.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/decorator.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dictionary.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_assert.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_constructor.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_if_guard.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_map.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_round.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_slicing.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/dynamic_shape_view.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/fn_with_kwargs.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/list_contains.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/list_unpack.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/model_attr_mutation.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/nested_function.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/null_context_manager.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/optional_input.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/pytree_flatten.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/scalar_output.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/specialized_attribute.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/static_for_loop.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/static_if.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/tensor_setattr.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/type_reflection_method.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/unsupported_operator.cpython-39.pyc,, +torch/_export/db/examples/__pycache__/user_input_mutation.cpython-39.pyc,, +torch/_export/db/examples/assume_constant_result.py,sha256=KpwYNJ7FSfogkz2pR_bdCrpEGaEKoUYcOYZ41tMT52o,530 +torch/_export/db/examples/autograd_function.py,sha256=DGWCDLKevtQadFsNH02UM5XzrvmlSsQaWxFbTemZqYc,601 
+torch/_export/db/examples/class_method.py,sha256=7RhhdSXeU2bWbLCk1Ftdmot1r2rdzS2YC3LWC5L-naE,521 +torch/_export/db/examples/cond_branch_class_method.py,sha256=n7r_4c1IomIpdPhhsE63NSdm2JNItSuBh3_83wIh8wo,1371 +torch/_export/db/examples/cond_branch_nested_function.py,sha256=r0A2249Ipudz5ONHPCkiUZ5nA1PQrNPTroVcluJ5a2A,1343 +torch/_export/db/examples/cond_branch_nonlocal_variables.py,sha256=sUe5h7AMWE958RPYXBWHKBRD_AIY8dqhKkrmjyPPJWw,1900 +torch/_export/db/examples/cond_closed_over_variable.py,sha256=9Mn39Q8WcZ9HUu26wq06BDhyNCE2OSw-AeNj7tuZh9U,569 +torch/_export/db/examples/cond_operands.py,sha256=ZtEnfCwl-YFVWTiR_Fz0W3IV0R04PVHrcACrn848K8o,834 +torch/_export/db/examples/cond_predicate.py,sha256=pj3w9TsgmM_UffQTngHb_Yd-lBw9CbBrQdEmrh7ou38,688 +torch/_export/db/examples/constrain_as_size_example.py,sha256=7O7V9AGhrS9LqgsQ-YsQu2vDACFmgDKSK7euEr9NIMU,662 +torch/_export/db/examples/constrain_as_value_example.py,sha256=tm-3fG-Nu6mUDM8ZHKhYQvrgItaIUh0qrsTo9iu-HTc,717 +torch/_export/db/examples/decorator.py,sha256=oHQ39bHuPeK3EeOV3HsyRs7MY5KxXO1eOR1CNPiZQ_U,503 +torch/_export/db/examples/dictionary.py,sha256=NJ-Klzxl-mE0IhFJGNkW8HN0FAg9HcAU5yccntZvedc,421 +torch/_export/db/examples/dynamic_shape_assert.py,sha256=0rRtnAHLYSYPxtONdmJz_MyCeJdGpCaM-sbLspi10Oo,468 +torch/_export/db/examples/dynamic_shape_constructor.py,sha256=lOjFINCStd4zmu2Vr9zi_rwYaL5A0sobDoVX3qhT6HM,411 +torch/_export/db/examples/dynamic_shape_if_guard.py,sha256=qggdXy2eBJk0kKaHnRPoM5v_ORu7kw4AgLkbd1jW0XA,579 +torch/_export/db/examples/dynamic_shape_map.py,sha256=ROqfFs6mIxZGNf4m_j39kimJ6ZiegndVByWRmGqAp8k,473 +torch/_export/db/examples/dynamic_shape_round.py,sha256=AfXnT8_oeDAcDnpFscD3qGhUuFA3PiyJMdLig-9D1_8,546 +torch/_export/db/examples/dynamic_shape_slicing.py,sha256=jv84W8ewyLl6W7qRP9jj1FEJzTUsKv4rcnT3dU4bPOM,403 +torch/_export/db/examples/dynamic_shape_view.py,sha256=mWYcUNa-btUzG5aF994K0wOOqG5V1y6gxRE2h7Gb-Hk,461 +torch/_export/db/examples/fn_with_kwargs.py,sha256=qUPoxV4BUCkLfEznW5PjHLYg-ggap5QI-IFtzM4FqyE,761 +torch/_export/db/examples/list_contains.py,sha256=EG9tgRlSkbZnz9hmLV5OEHVGonugVkwFPZgVmx-dTgU,494 +torch/_export/db/examples/list_unpack.py,sha256=avXgBV5cMTlHTHoqn01zFYv1oH6rkOTUlpdEneA5Qwk,589 +torch/_export/db/examples/model_attr_mutation.py,sha256=h1dvo1psNLOVIQYISCZcdQaZI9KwVCgxUqDv0Jg6kSU,688 +torch/_export/db/examples/nested_function.py,sha256=6M1YbZo8AXA7pIEwDjNjloeguUPkv2tTC22RgBYSMiU,514 +torch/_export/db/examples/null_context_manager.py,sha256=1fBDJxbRhIAW9td0fsifR8QoBpdGAlXwq4GxMiYNYg4,499 +torch/_export/db/examples/optional_input.py,sha256=3kSSo2Nr9muJ4GiRPE8xewuMLuW7GFLi0SbMJKvjfMk,475 +torch/_export/db/examples/pytree_flatten.py,sha256=igCLbl367lo6N8zZXm5MXpX2tmF22uceQL6aDZP5tsk,392 +torch/_export/db/examples/scalar_output.py,sha256=fpbxy1MySLUTAz57P39pLBPVEfSdIAC0d9JDAa7lB9g,566 +torch/_export/db/examples/specialized_attribute.py,sha256=prbTioycap7pvNkI_3SA688cfTJHvtFYNw42tPns40M,546 +torch/_export/db/examples/static_for_loop.py,sha256=f4dXMQw28-cspeGBVhZ31y_IEXhZ3YUFzoRrPP_Qh7g,401 +torch/_export/db/examples/static_if.py,sha256=CQPwPy1r7gcN7cR2cdziwTnj9aakvW53-ojoc-snBcw,415 +torch/_export/db/examples/tensor_setattr.py,sha256=PCCle6Wexu0ZBwxOSFr35UZwIjlBAe9o1LuLKQIJFwY,352 +torch/_export/db/examples/type_reflection_method.py,sha256=3Z-5GHb5eWKuuLK0-6B_tG_nZNeizr4Y7AuLDZRQ7ig,483 +torch/_export/db/examples/unsupported_operator.py,sha256=jiHwslxhsRjC1Uld3EVVMCme6RG0e1rf0JdUvs7bfc4,429 +torch/_export/db/examples/user_input_mutation.py,sha256=fe-oUXenBLZprn69qn4J6s804qrnJBMcXb0gQZJhVaE,319 
+torch/_export/db/gen_example.py,sha256=43Jhg_l7Ke17ml79_eU8T_qT5NmyhjiEisZVK-YjRus,483 +torch/_export/db/logging.py,sha256=iFh5sub_1e67b1Hq5BnbXcTMKayV8Q8ICpnB23tQu-A,1697 +torch/_export/error.py,sha256=QXyVuJjbdbTcTodQ9tth-jUUpBeg4T8Ar3WLh1wHLXM,1826 +torch/_export/non_strict_utils.py,sha256=0hjFire1wPsgegbDu4Ib1jkD3m62pJH4wxcYgTTZRN4,40697 +torch/_export/pass_base.py,sha256=oALt6cfVP1owYmIEsfH3fmK7jm672ug-dkayTvNnqDw,18746 +torch/_export/pass_infra/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_export/pass_infra/__pycache__/__init__.cpython-39.pyc,, +torch/_export/pass_infra/__pycache__/node_metadata.cpython-39.pyc,, +torch/_export/pass_infra/__pycache__/proxy_value.cpython-39.pyc,, +torch/_export/pass_infra/node_metadata.py,sha256=rSit62__D-HKQrSX_8DrvzUsaT9XEuK3XwIB9AtHBqU,803 +torch/_export/pass_infra/proxy_value.py,sha256=McjGBissRThDTqog2Iwt2oE-m1jN_cPsFUYqh2Xv5Qk,1314 +torch/_export/passes/__init__.py,sha256=mUQ0mkji9c_GH4vU9VaS1CXALHLwW0mic3OhC2gXM3M,89 +torch/_export/passes/__pycache__/__init__.cpython-39.pyc,, +torch/_export/passes/__pycache__/_node_metadata_hook.cpython-39.pyc,, +torch/_export/passes/__pycache__/add_runtime_assertions_for_constraints_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/collect_tracepoints_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/constant_folding.cpython-39.pyc,, +torch/_export/passes/__pycache__/functionalize_side_effectful_ops_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/insert_custom_op_guards.cpython-39.pyc,, +torch/_export/passes/__pycache__/lift_constants_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/remove_runtime_assertions.cpython-39.pyc,, +torch/_export/passes/__pycache__/replace_autocast_with_hop_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/replace_quantized_ops_with_standard_ops_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/replace_set_grad_with_hop_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/replace_view_ops_with_view_copy_ops_pass.cpython-39.pyc,, +torch/_export/passes/__pycache__/replace_with_hop_pass_util.cpython-39.pyc,, +torch/_export/passes/_node_metadata_hook.py,sha256=IhgYnRblwmjxKudEx_q_JGUXGeGaV_WrzNsAWwUQBEU,2559 +torch/_export/passes/add_runtime_assertions_for_constraints_pass.py,sha256=WkTI6eISlWRumWZn9XmpyYksvsMJbZjlm7Y4RxnRK34,10410 +torch/_export/passes/collect_tracepoints_pass.py,sha256=LxTJJsz4CAGaBbwZv54Wy0FgWs06SsYjYe5fYu2GpZw,6694 +torch/_export/passes/constant_folding.py,sha256=o4FcmEZH-nSyFwGdi-EJPBQERGafQaPsr-EjNnDIqn0,11579 +torch/_export/passes/functionalize_side_effectful_ops_pass.py,sha256=dLd8XyxsoK70zUyKgURKgbIXf7VcbaIRJ7eFrVPD7gM,3378 +torch/_export/passes/insert_custom_op_guards.py,sha256=FYcgWcmUVjypMTnXqYTUKcatZ7LQBw7gfcZ2bXSi2mI,2933 +torch/_export/passes/lift_constants_pass.py,sha256=SRdp8v-vBZI9cqZNfcnV9E5iUeM3dra8Ljgq027E4k4,17872 +torch/_export/passes/remove_runtime_assertions.py,sha256=yokuUkDsU9_AJyeeWLPYR3YjIqdCM_jX9BW7h1eRpAE,1624 +torch/_export/passes/replace_autocast_with_hop_pass.py,sha256=vlcih0Fu2PDqfMS5LMkRvomu9ptP3B5piPb7fxFhm3o,7405 +torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py,sha256=YxC3zvw2SabZOD7ifgoiXNVap_Z1CBo0pYWSHQW4_cc,26572 +torch/_export/passes/replace_set_grad_with_hop_pass.py,sha256=G0tDBVubmReIEd9iZ3GvK8MPFHQmzeIW7WqeRxx25KE,4446 +torch/_export/passes/replace_view_ops_with_view_copy_ops_pass.py,sha256=MtQ07dRGrLc2Qfr5IyPMuxjcem2pVKEnq21_ci1xs2w,2476 
+torch/_export/passes/replace_with_hop_pass_util.py,sha256=LmOdItRluDZIbkiimXSWFQDlbPvShKvZDyGUfpYToeE,7739 +torch/_export/serde/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_export/serde/__pycache__/__init__.cpython-39.pyc,, +torch/_export/serde/__pycache__/dynamic_shapes.cpython-39.pyc,, +torch/_export/serde/__pycache__/schema.cpython-39.pyc,, +torch/_export/serde/__pycache__/schema_check.cpython-39.pyc,, +torch/_export/serde/__pycache__/serialize.cpython-39.pyc,, +torch/_export/serde/__pycache__/union.cpython-39.pyc,, +torch/_export/serde/dynamic_shapes.py,sha256=OEt9lOPjR4o0-VUzMtAJHuELbQqq683Oq_HjCP4yN-Y,11873 +torch/_export/serde/export_schema.thrift,sha256=huDw51cHidD-L0RgkGBnQqxETi3pHmq3BNo48nEl-BM,7375 +torch/_export/serde/schema.py,sha256=_gAc1w-jqVvqvVBlFsXDYYehsz0qfcxXlxUjvyPwmYE,15148 +torch/_export/serde/schema.yaml,sha256=tap9QOYg4FKoE_PP8UsCaPjmnJ67-7Eur8O-hMQ0KVw,10033 +torch/_export/serde/schema_check.py,sha256=YKScHhUC6RsWcqJxoIP_hDOC3SFVpllIz2HAX3JHmkM,25180 +torch/_export/serde/serialize.py,sha256=uEeJA7aUrep9Vek-ESJzC6bmr4l1yI0OvgsuSFbmqDk,147605 +torch/_export/serde/union.py,sha256=4jipD6-g39czcccL967O1WpbpG9zosOPflB21b-h_bU,2097 +torch/_export/tools.py,sha256=uwIPO6XLFqwT8xBnWfxnt1dJV2-lv1UxrXGWmwn6kWk,4719 +torch/_export/utils.py,sha256=kFymstUsNYOG3i-Db-FgZuNxaYqwhquzycEueIX5Fy0,57244 +torch/_export/verifier.py,sha256=WGTNx7UAYshvM_QxQZPZTu-qqPM8QYxFpS9C-D1rMTM,20195 +torch/_export/wrappers.py,sha256=qk1A2NeMcd0lx1SKd86YSbJ4NvwD8Ev9ETPAmGYzsEk,9623 +torch/_functorch/__init__.py,sha256=JQMVLRoFCKo629KQ_1vu4i10Rmkww8EWMyZnDLEGXag,211 +torch/_functorch/__pycache__/__init__.cpython-39.pyc,, +torch/_functorch/__pycache__/aot_autograd.cpython-39.pyc,, +torch/_functorch/__pycache__/apis.cpython-39.pyc,, +torch/_functorch/__pycache__/autograd_function.cpython-39.pyc,, +torch/_functorch/__pycache__/batch_norm_replacement.cpython-39.pyc,, +torch/_functorch/__pycache__/benchmark_utils.cpython-39.pyc,, +torch/_functorch/__pycache__/compile_utils.cpython-39.pyc,, +torch/_functorch/__pycache__/compilers.cpython-39.pyc,, +torch/_functorch/__pycache__/config.cpython-39.pyc,, +torch/_functorch/__pycache__/deprecated.cpython-39.pyc,, +torch/_functorch/__pycache__/eager_transforms.cpython-39.pyc,, +torch/_functorch/__pycache__/functional_call.cpython-39.pyc,, +torch/_functorch/__pycache__/fx_minifier.cpython-39.pyc,, +torch/_functorch/__pycache__/make_functional.cpython-39.pyc,, +torch/_functorch/__pycache__/partitioners.cpython-39.pyc,, +torch/_functorch/__pycache__/pyfunctorch.cpython-39.pyc,, +torch/_functorch/__pycache__/python_key.cpython-39.pyc,, +torch/_functorch/__pycache__/pytree_hacks.cpython-39.pyc,, +torch/_functorch/__pycache__/top_operators_github_usage.cpython-39.pyc,, +torch/_functorch/__pycache__/utils.cpython-39.pyc,, +torch/_functorch/__pycache__/vmap.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/__init__.py,sha256=JQMVLRoFCKo629KQ_1vu4i10Rmkww8EWMyZnDLEGXag,211 +torch/_functorch/_activation_checkpointing/__pycache__/__init__.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/__pycache__/ac_logging_utils.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/__pycache__/graph_info_provider.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/__pycache__/knapsack.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/__pycache__/knapsack_evaluator.cpython-39.pyc,, +torch/_functorch/_activation_checkpointing/ac_logging_utils.py,sha256=sRPWVWLCu8A1XpbYUKWxo2-jzkMH722-pr04SwmMYyI,5304 
+torch/_functorch/_activation_checkpointing/graph_info_provider.py,sha256=R3ehi2yuz7Yq2KexmsTmbRBjNTmvt8S9LynviwZCGnk,13087 +torch/_functorch/_activation_checkpointing/knapsack.py,sha256=tCWlDdZ8AohtrlSjqG2UxsE91CSqaPe5MpoI4P7YMJc,4076 +torch/_functorch/_activation_checkpointing/knapsack_evaluator.py,sha256=mtLyX1_ZNEl4KC5yWyqT5e2YKkZPSIvpqChELLQ9kmE,12096 +torch/_functorch/_aot_autograd/__init__.py,sha256=JQMVLRoFCKo629KQ_1vu4i10Rmkww8EWMyZnDLEGXag,211 +torch/_functorch/_aot_autograd/__pycache__/__init__.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/autograd_cache.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/collect_metadata_analysis.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/dispatch_and_compile_graph.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/functional_utils.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/input_output_analysis.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/jit_compile_runtime_wrappers.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/logging_utils.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/runtime_wrappers.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/schemas.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/subclass_parametrization.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/subclass_utils.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/traced_function_transforms.cpython-39.pyc,, +torch/_functorch/_aot_autograd/__pycache__/utils.cpython-39.pyc,, +torch/_functorch/_aot_autograd/autograd_cache.py,sha256=ntTcuYRNOKuJg7Fz3YdI38CAGAw0TRBXluRKn5oH5Pk,59472 +torch/_functorch/_aot_autograd/collect_metadata_analysis.py,sha256=ujG02-QKN0WsSgdw9iXS5L4pdbDz3Gc3jctWPd8TVN8,43471 +torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py,sha256=6blEjL-0JT8bMlWNKAli1v1UkQ6S6-bLO4BB0y0jPJI,12981 +torch/_functorch/_aot_autograd/functional_utils.py,sha256=YCvxQCbSrYxLMFzed7zLM8Yl9zgjW7rJgRwH3MM6rzM,23213 +torch/_functorch/_aot_autograd/input_output_analysis.py,sha256=5n1aTSt518kxJwI1UJjAZPHd2gP5tEL1D8oxZZ6C2ZE,18050 +torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py,sha256=5BLAiXSpm6jt2_veY_KVQGJ8mXCx7pgHyOK3IcTwVA8,78361 +torch/_functorch/_aot_autograd/logging_utils.py,sha256=CFTjdgtqBa9-F8Hdh1SHUCj0NvbA-RwUEhodDoh6ZoQ,4769 +torch/_functorch/_aot_autograd/runtime_wrappers.py,sha256=ebSqlFIHTE2ycSkrHkVMwX0KPxXgx9-E6ugINZ1Vc-k,108622 +torch/_functorch/_aot_autograd/schemas.py,sha256=Zr007j-F2YNdWb_lbCSw087LXu3kBpSYb6OWaR7Q8q4,42556 +torch/_functorch/_aot_autograd/subclass_parametrization.py,sha256=rQzhh1x8PRzMHVKxh9X-TkjcWH0ov5i_5Y7MuDwUIsw,4190 +torch/_functorch/_aot_autograd/subclass_utils.py,sha256=s9-TgWZ2YIH7ijKNuzHAcjaDimDtfVXbGXnLBnC761E,18889 +torch/_functorch/_aot_autograd/traced_function_transforms.py,sha256=XlgjjfagwInz6wVPWvjoKR_PZKpN6MY950sXlIiW1IQ,46158 +torch/_functorch/_aot_autograd/utils.py,sha256=xiWzCYthy70zAupVEhvxzKY0eeKe9QRUEDPAQulAEx8,19733 +torch/_functorch/aot_autograd.py,sha256=AKYCoY_raQ-OOvskImSbEMY5hsnZxYFVP2VWPtuvPfs,74396 +torch/_functorch/apis.py,sha256=2VfFhvHlFxdA2Z_V5VlCcLWERWI8Rz03Fo5nD9QsCYE,19537 +torch/_functorch/autograd_function.py,sha256=g22c_I-WmB4ME8fEb4tV7V4RPd7_xB6uDAz0ZrycItQ,29519 +torch/_functorch/batch_norm_replacement.py,sha256=rCCP_aiPj9c0zuKAW2E8fSCuR4u4-R3EdovaOg43VSY,884 +torch/_functorch/benchmark_utils.py,sha256=Ly1poMsLx76dp9O3rW3ILFdQ2-TIOADjKKhZnKL1hqc,6511 
+torch/_functorch/compile_utils.py,sha256=aO6t-pzqtzQKgRBNhcV9C3sS1trXFbr5WJyB2Pv38Xk,7894 +torch/_functorch/compilers.py,sha256=VZt2StacWFAg7eyTRUZ99OHXyW-O0oWDpzNdtdcXjvo,14494 +torch/_functorch/config.py,sha256=GSVvCzxB-lPjdNxJCRPdzxM7MfqJ4TuZX7Wv-1gSbFI,13984 +torch/_functorch/deprecated.py,sha256=O6Gg1clqz3YfpXE7QKlnPBPM8FaHG_FnDXDC5877MfU,5375 +torch/_functorch/eager_transforms.py,sha256=zD8WX2fA1bL4mMyDRTvTYAAEdJJ_BQKGO8Wy48PCZdM,72874 +torch/_functorch/functional_call.py,sha256=jrTeb5FfAUFOb4C86IlnnU_sE2SPnbubZyJnDgv-9xY,10831 +torch/_functorch/fx_minifier.py,sha256=xhnj0JUPA-zZwbSsMFU5d5MVqezFOlPfGxlkzgROhvk,17864 +torch/_functorch/make_functional.py,sha256=X0qgbReE3eI0dByB2Cka2Zqqo3Xl7Urok0EaCPkLyqc,23351 +torch/_functorch/partitioners.py,sha256=creqtjvzHeAqb4UC1FTIP4T0qn4KnUvEx8st8ESGO34,107945 +torch/_functorch/pyfunctorch.py,sha256=_69-7aB6Wl5sSenTNNT1By8-7EMF4yviIpIo2Qquvk8,10988 +torch/_functorch/python_key.py,sha256=J5m4n931Z1hmTQmIjQWju_OHePMWJGTn58QYvcTvF-c,457 +torch/_functorch/pytree_hacks.py,sha256=usKxP7Ui-WS-Xa3UQnrNP96SH35GM3avonfhZV7c0gU,721 +torch/_functorch/top_operators_github_usage.py,sha256=hOYNz4o6eiOFoHtrPKzjnYZgqTM0CYB3HT9jmrJ24vk,22021 +torch/_functorch/utils.py,sha256=ohM06i44JiV-aHTKTqAx54B2xWa7628nCEGUDCxKkgE,1092 +torch/_functorch/vmap.py,sha256=W3M-RBdWp6OBf58xxf6oFT9_V5TcSFtqy3dkB8Q4dEo,19522 +torch/_guards.py,sha256=TSJYfERTFWpfO-kmyTD6F9mfYqhLmLJzmgu60M_-VO8,40146 +torch/_higher_order_ops/__init__.py,sha256=g-e21D037zc0WiNAoJ20MBk99DynrOVrSmANKdyGlwQ,2350 +torch/_higher_order_ops/__pycache__/__init__.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/_invoke_quant.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/aoti_call_delegate.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/associative_scan.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/auto_functionalize.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/base_hop.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/cond.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/effects.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/executorch_call_delegate.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/flat_apply.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/flex_attention.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/foreach_map.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/hints_wrap.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/invoke_subgraph.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/map.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/out_dtype.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/run_const_graph.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/scan.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/schema.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/strict_mode.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/torchbind.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/triton_kernel_wrap.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/utils.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/while_loop.cpython-39.pyc,, +torch/_higher_order_ops/__pycache__/wrap.cpython-39.pyc,, +torch/_higher_order_ops/_invoke_quant.py,sha256=muLGTOGNpSa07wR05T1xp_AA6P57roTfpari1FbTfmI,1962 +torch/_higher_order_ops/aoti_call_delegate.py,sha256=W_Xe-a3I5QOxskGDQ_D7o0qmcRHOcsmFzPUekXA8DIk,6224 +torch/_higher_order_ops/associative_scan.py,sha256=FAdb_UC7kHeKqrhymWRw1tSvRtxnwcgTF0RBPm2acjg,18824 
+torch/_higher_order_ops/auto_functionalize.py,sha256=RwP5QdEmIg3K8MTk6Dwo2O8JtyRFktpbqaHUjfo1FSI,37207 +torch/_higher_order_ops/base_hop.py,sha256=29frpWYXCfQbE1dQa5serKZw41ufznzCsOZOINvCbjY,10798 +torch/_higher_order_ops/cond.py,sha256=yFbUxWQpNmRNp54-Pnvi6yV2TV_YKn5yf0eGbDZcvOU,28905 +torch/_higher_order_ops/effects.py,sha256=zHev1SQvLs7SUf9yVxIbKo-9onEfGy_IOINgCOeE70M,10417 +torch/_higher_order_ops/executorch_call_delegate.py,sha256=dU-LJh1XU1aSq0UWkBVM_pgkW36-K7Lb52pIv6V9JW0,6134 +torch/_higher_order_ops/flat_apply.py,sha256=l3xMFNL588LYk5FTwRiUjxR02hvyTtlq6HAxwnR_7jM,4494 +torch/_higher_order_ops/flex_attention.py,sha256=4JBLgIH8FLCrvzvlQpfD_PH6QhKr0tUOZYSWb-7lFJA,44306 +torch/_higher_order_ops/foreach_map.py,sha256=yAAPa-WWH7mX2DyQIS3-dWNU0hTl5ua5Ec4daOlvReo,686 +torch/_higher_order_ops/hints_wrap.py,sha256=9ymjphEvDzF3V_Ylx2v5J1iqXx8fGL9yld_jmWBtYvc,4979 +torch/_higher_order_ops/invoke_subgraph.py,sha256=cCqv736R73Je5bMr4hcjaCapDp7bBpKG6oCYIWbmrGg,26957 +torch/_higher_order_ops/map.py,sha256=vqUBg1Rf5EknTttGvWwClCzQbETegF5QCUSLiaJwfiU,10344 +torch/_higher_order_ops/out_dtype.py,sha256=oV82ov5up5yqU19fJFGEDtS01ni71yXFNpoXecRNHGg,5689 +torch/_higher_order_ops/run_const_graph.py,sha256=y26HpZQ6nZQqyy8QqLn1ISxoN7dYZcqZ9Pu3x-cbmqA,1950 +torch/_higher_order_ops/scan.py,sha256=V8BXLhfrQnhvMY8AY5GoNjZaLfHWsFO6ImVzwNjH7Ts,39080 +torch/_higher_order_ops/schema.py,sha256=NVywFVJMquUPz8S0P6vqqHLSm0z6zSk8LHO5AeUJBr8,11434 +torch/_higher_order_ops/strict_mode.py,sha256=lxlppDDOMxu1yFgVndhIs-I7EIKM6VVbN9HyoqIKkP8,3752 +torch/_higher_order_ops/torchbind.py,sha256=byhkmIAgEdJTgCVrIhbt5QhOlj7O486LIWMxSNBwGsU,6416 +torch/_higher_order_ops/triton_kernel_wrap.py,sha256=SZDCuoeysdTxw4dak2R1Df6yvSmgTF8EOQ1yLIJ6xP8,84489 +torch/_higher_order_ops/utils.py,sha256=01c2lgLpQtUBxXuIBMz6GswvmTZdpmXWmDf2q0FPdgE,44593 +torch/_higher_order_ops/while_loop.py,sha256=jpfEBXQUiqpF6yQFW16hUQNZ3WNWh-sGRSqJk9UOdKE,17986 +torch/_higher_order_ops/wrap.py,sha256=5SVxK035Y4i0OCWYASbQ_I8A60CqxhhsKG8Y-i2m4TI,11508 +torch/_inductor/__autotune_main__.py,sha256=eIhLRMgy1BNr_lSUEfXk9aGe6xG95TF_Tt2WJROawts,946 +torch/_inductor/__init__.py,sha256=Bz-rtPs-4m-dzr18r296QGfkDsAU0Z_Ypqx1cl5hEOQ,13931 +torch/_inductor/__pycache__/__autotune_main__.cpython-39.pyc,, +torch/_inductor/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-39.pyc,, +torch/_inductor/__pycache__/aoti_eager.cpython-39.pyc,, +torch/_inductor/__pycache__/async_compile.cpython-39.pyc,, +torch/_inductor/__pycache__/autotune_process.cpython-39.pyc,, +torch/_inductor/__pycache__/bounds.cpython-39.pyc,, +torch/_inductor/__pycache__/choices.cpython-39.pyc,, +torch/_inductor/__pycache__/codecache.cpython-39.pyc,, +torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc,, +torch/_inductor/__pycache__/comm_lowering.cpython-39.pyc,, +torch/_inductor/__pycache__/comms.cpython-39.pyc,, +torch/_inductor/__pycache__/compile_fx.cpython-39.pyc,, +torch/_inductor/__pycache__/compile_fx_async.cpython-39.pyc,, +torch/_inductor/__pycache__/compile_fx_ext.cpython-39.pyc,, +torch/_inductor/__pycache__/compile_fx_subproc.cpython-39.pyc,, +torch/_inductor/__pycache__/compiler_bisector.cpython-39.pyc,, +torch/_inductor/__pycache__/config.cpython-39.pyc,, +torch/_inductor/__pycache__/constant_folding.cpython-39.pyc,, +torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc,, +torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc,, +torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc,, 
+torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc,, +torch/_inductor/__pycache__/custom_graph_pass.cpython-39.pyc,, +torch/_inductor/__pycache__/debug.cpython-39.pyc,, +torch/_inductor/__pycache__/decomposition.cpython-39.pyc,, +torch/_inductor/__pycache__/dependencies.cpython-39.pyc,, +torch/_inductor/__pycache__/dtype_propagation.cpython-39.pyc,, +torch/_inductor/__pycache__/exc.cpython-39.pyc,, +torch/_inductor/__pycache__/extern_node_serializer.cpython-39.pyc,, +torch/_inductor/__pycache__/freezing.cpython-39.pyc,, +torch/_inductor/__pycache__/freezing_utils.cpython-39.pyc,, +torch/_inductor/__pycache__/fuzzer.cpython-39.pyc,, +torch/_inductor/__pycache__/fx_utils.cpython-39.pyc,, +torch/_inductor/__pycache__/graph.cpython-39.pyc,, +torch/_inductor/__pycache__/hooks.cpython-39.pyc,, +torch/_inductor/__pycache__/index_propagation.cpython-39.pyc,, +torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc,, +torch/_inductor/__pycache__/ir.cpython-39.pyc,, +torch/_inductor/__pycache__/jagged_lowerings.cpython-39.pyc,, +torch/_inductor/__pycache__/loop_body.cpython-39.pyc,, +torch/_inductor/__pycache__/lowering.cpython-39.pyc,, +torch/_inductor/__pycache__/memory.cpython-39.pyc,, +torch/_inductor/__pycache__/metrics.cpython-39.pyc,, +torch/_inductor/__pycache__/mkldnn_ir.cpython-39.pyc,, +torch/_inductor/__pycache__/mkldnn_lowerings.cpython-39.pyc,, +torch/_inductor/__pycache__/mock_cache.cpython-39.pyc,, +torch/_inductor/__pycache__/ops_handler.cpython-39.pyc,, +torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc,, +torch/_inductor/__pycache__/output_code.cpython-39.pyc,, +torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc,, +torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc,, +torch/_inductor/__pycache__/remote_cache.cpython-39.pyc,, +torch/_inductor/__pycache__/scheduler.cpython-39.pyc,, +torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc,, +torch/_inductor/__pycache__/sizevars.cpython-39.pyc,, +torch/_inductor/__pycache__/standalone_compile.cpython-39.pyc,, +torch/_inductor/__pycache__/subgraph_lowering.cpython-39.pyc,, +torch/_inductor/__pycache__/template_heuristics.cpython-39.pyc,, +torch/_inductor/__pycache__/test_case.cpython-39.pyc,, +torch/_inductor/__pycache__/test_operators.cpython-39.pyc,, +torch/_inductor/__pycache__/tiling_utils.cpython-39.pyc,, +torch/_inductor/__pycache__/triton_bundler.cpython-39.pyc,, +torch/_inductor/__pycache__/utils.cpython-39.pyc,, +torch/_inductor/__pycache__/virtualized.cpython-39.pyc,, +torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc,, +torch/_inductor/analyze_preserves_zero_mask.py,sha256=0HshUo4urIQshRM24V4smzLK1_USk7AHHd8gkZJbeUo,5618 +torch/_inductor/aoti_eager.py,sha256=7My54R9IiGfW2P7ZIwWLn-3U_kca8aO3you8Zdi_-xc,11436 +torch/_inductor/async_compile.py,sha256=DlrhdsP2fgqJ8lrsWbQjUKxgZBkoax0ZzPbYr17eago,20581 +torch/_inductor/autoheuristic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/autoheuristic/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/autoheuristic/__pycache__/autoheuristic.cpython-39.pyc,, +torch/_inductor/autoheuristic/__pycache__/autoheuristic_utils.cpython-39.pyc,, +torch/_inductor/autoheuristic/__pycache__/learned_heuristic_controller.cpython-39.pyc,, +torch/_inductor/autoheuristic/__pycache__/learnedheuristic_interface.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py,sha256=xEyu3zitjcoKRiNERluGOfStBcldKARFxcbZPm7IO7I,28340 
+torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py,sha256=KZ-6UQu8CRvDAZ2RrMzTp2auqS89Owc4yK6bz2SuVW0,30989 +torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py,sha256=FCJuKLuNRxCXIUgkGeBIsgduFXH_xKHlusHI_gf0KO4,8070 +torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py,sha256=caBG5DTSTivpZrmQmN6R1N6tMT4WreYdWiP-T8lhzSE,8031 +torch/_inductor/autoheuristic/artifacts/_PadMMA100.py,sha256=koaNU2XdPdYUbLBN-kqy0kJswOhmfbD3a33GDhwAKaY,5040 +torch/_inductor/autoheuristic/artifacts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingA100.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/__pycache__/_MMRankingH100.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMA100.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/__pycache__/_MixedMMH100.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/__pycache__/_PadMMA100.cpython-39.pyc,, +torch/_inductor/autoheuristic/artifacts/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/autoheuristic/autoheuristic.py,sha256=6Fvx2IqpoqjSTwCuehRbGn2abv2qu7KQPwhAdPkopCg,12249 +torch/_inductor/autoheuristic/autoheuristic_utils.py,sha256=XZt712PVAoz5_bWmyn8cUkzAlm8bBRPQrKi1ILkQpnY,11620 +torch/_inductor/autoheuristic/learned_heuristic_controller.py,sha256=URCnHE8dztWFTROS2tV41HVd0R6ByfcAGIGURiT8hIg,4436 +torch/_inductor/autoheuristic/learnedheuristic_interface.py,sha256=wwpOxagaLghw5h2oyMjx4XRNJG7ZeG7zu1CiWjgNR9c,2947 +torch/_inductor/autotune_process.py,sha256=mKN7CckZFrUC59IjfYza5028ToJjSAYmvWnUs5CF-OM,30328 +torch/_inductor/bounds.py,sha256=Zd-zVuey4cA1uDtyXTP5VijuX8VwMJg_aVuvo-MsFuA,9956 +torch/_inductor/choices.py,sha256=ikAf3xs-v0NciV420eMY3Yu9XOpmL-C1_OT6DU1TDpM,18092 +torch/_inductor/codecache.py,sha256=zr8zcJZK4XrTsNegteMUGtM-7cVP5qAHCf_Eu22syNU,159579 +torch/_inductor/codegen/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/block_analysis.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/common.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_template.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_utils.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/debug_utils.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/halide.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/mps.cpython-39.pyc,, 
+torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/simd.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/subgraph.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/triton.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/wrapper.cpython-39.pyc,, +torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-39.pyc,, +torch/_inductor/codegen/aoti_hipify_utils.py,sha256=AY73d3v-5U_LLnlm1eaIh-pzDmoYrG-LaOXfA2-MUzY,1351 +torch/_inductor/codegen/aoti_runtime/interface.cpp,sha256=EzDvCGzf9vVHtLDsmb8GhYICB4sR9xsyMXV1jD_zVqc,16788 +torch/_inductor/codegen/block_analysis.py,sha256=hmykxeukwAwfaQUXD237kSHKG39r2jBC7-QrjFQglRA,6854 +torch/_inductor/codegen/common.py,sha256=x8XSNKTw_FxPzGVweEGQUkiK24ZuK8b5Xn7AyBN1LMY,98569 +torch/_inductor/codegen/cpp.py,sha256=tM6-S7VyD5Pz5mW6XrdvQG5EZflXjbK5ockJcxl6XS4,229056 +torch/_inductor/codegen/cpp_bmm_template.py,sha256=Yk2lER20VbemJW8xY1NNZTBdt5ZhDEzdBcQMoeeXmDc,9621 +torch/_inductor/codegen/cpp_flex_attention_template.py,sha256=GjOUwkKxPQsb6s-jY1dNxv6pRdDy6t5ddYMJokZtq9Q,42105 +torch/_inductor/codegen/cpp_gemm_template.py,sha256=27Z5oBO2DjQvShy49heyJmgXcXbhC1jYNrM4q80vBUk,77657 +torch/_inductor/codegen/cpp_grouped_gemm_template.py,sha256=lE9k_Zo89txkhNNNMeC7Mqj6rpklseg34509RoI4sSs,20882 +torch/_inductor/codegen/cpp_micro_gemm.py,sha256=qVTab_ZAyMcg2uzBzWFXQaZ-wrMLRgGyuANMaXkRzws,71511 +torch/_inductor/codegen/cpp_template.py,sha256=qZMEe4X361abCLsgT4MuMrpnU_UOH-FRFPQNyu62lR8,5047 +torch/_inductor/codegen/cpp_template_kernel.py,sha256=9eBq2KsWEZcgpxmb_-gGKeiMe-i6ldh05vtJyNe6E90,25713 +torch/_inductor/codegen/cpp_utils.py,sha256=wkz1Qsc8mC1erUfDJFUWvJWWC4UYV-5Dg6Yw97taFlc,28536 +torch/_inductor/codegen/cpp_wrapper_cpu.py,sha256=aYyqGRc0vYDl1BhemxLBJvPQ0x0jwL2RraY6rfv-I9c,122134 +torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py,sha256=hEuq1R8DfMhx2Ok5fWsdHPa91IBPfgriJ62t_kbaNdU,39247 +torch/_inductor/codegen/cpp_wrapper_gpu.py,sha256=A01ERTPNeQutdwX7GWEC_481mO6Hs2Kf4ppRSpL2Jnw,29204 +torch/_inductor/codegen/cpp_wrapper_mps.py,sha256=n0XebpR0snDS6ahU7qE8WaYDRaDVtO249oczlKgtopw,3607 +torch/_inductor/codegen/cpu_device_op_overrides.py,sha256=BJGVNLiyDJ4uBj8tjyk1idNOo8o_CuLd243wSz_MOp4,659 +torch/_inductor/codegen/cuda/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cutlass_presets.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc,, +torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc,, 
+torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-39.pyc,, +torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py,sha256=paPBWhEP4jYNNw4F0DRTRt3WC2zmQaJRAuMqTNLniXA,12110 +torch/_inductor/codegen/cuda/cuda_env.py,sha256=L7uIb_crhaURbwB78h5zbcOZds1lh3pZtnJF1ne8JLk,1211 +torch/_inductor/codegen/cuda/cuda_kernel.py,sha256=UaAsyzWi8FOOrKwKoxTYkswYPDjEgjZTLVkLxbCvO14,24714 +torch/_inductor/codegen/cuda/cuda_template.py,sha256=GMh-ioBkMdOPjif8gnpy1ju9XuPdbT_UD4nQ-8V9NvM,11526 +torch/_inductor/codegen/cuda/cutlass_cache.py,sha256=GUN-aq-g-UfCgBG72BTjdiViEFNwQvZLn4ItlOG53Sk,3191 +torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-39.pyc,, +torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc,, +torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py,sha256=V3TSDciaEe1w2wwGIZ-xbCWWdtWNZ4YLcElfa11ft0Q,10083 +torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py,sha256=mHbG8w3wYHDQpFgyNG8uuvA_BBGFTHc16UeLA_nqEcE,19057 +torch/_inductor/codegen/cuda/cutlass_presets.py,sha256=3zPQodL_aIAYazkkuf9okGn8swWgMxQFjMILjFKyy-E,25745 +torch/_inductor/codegen/cuda/cutlass_python_evt.py,sha256=f6oVF-pevOGkY0n6a41tNQurg7gehCR4_LkVSMGX-Mw,12056 +torch/_inductor/codegen/cuda/cutlass_utils.py,sha256=91agq9SBd7fFNG6pYv2G5Gk6jnFOo2KGuGFarXz0nwc,17539 +torch/_inductor/codegen/cuda/device_op_overrides.py,sha256=JPi6YJ0PCU2NYb1wC_p643IamwEbql6QxUfrdoWGIEM,15218 +torch/_inductor/codegen/cuda/gemm_template.py,sha256=LEzdBj3fsP6yWzfI2oMsaNeTZZ1EBerozgE624hU8V0,76765 +torch/_inductor/codegen/cuda/serialization.py,sha256=Yv_lt9bfXK2UiX0PZ5-FNShnCqDio_EEuWEBaD-CalQ,16934 +torch/_inductor/codegen/cuda_combined_scheduling.py,sha256=Bw4-IUwufXBfvEhfIXVQFH3wpoO1SNvk-mHtYliMYAg,5167 +torch/_inductor/codegen/debug_utils.py,sha256=_ZLe6CvUmBqyzLSBveM8j6SsIpklSyjxjyflgD-XIYk,11535 +torch/_inductor/codegen/halide.py,sha256=_mHxgC6zd_09ynJ52U08EY3K13JdVyGTqDHF_3bI12w,63915 +torch/_inductor/codegen/memory_planning.py,sha256=lMK-1F8pnrRXZIZvwDavxJEytoS67SEY5sw0VZx7h-M,25856 +torch/_inductor/codegen/mps.py,sha256=Edx1BThIr9WD4J93M8AVP2ZYSqM5ztfSC44bg3E06-Y,38815 +torch/_inductor/codegen/mps_device_op_overrides.py,sha256=l7XXqoZQ6IBUfIGhf-WN3sLifUBrraLj2TZXUHtmVMc,695 +torch/_inductor/codegen/multi_kernel.py,sha256=08fsg9I8k76kBnAAehsrCn--aU6Ibn1Z6OnZMi0lZWw,13988 +torch/_inductor/codegen/rocm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-39.pyc,, 
+torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-39.pyc,, +torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-39.pyc,, +torch/_inductor/codegen/rocm/ck_conv_template.py,sha256=2q1lzPH0ckgDdu9N3P8OoC3pJkl7UFvmGiAM5q3JbsM,25012 +torch/_inductor/codegen/rocm/ck_template.py,sha256=IJUS7tPpOcNFCcK8M4M3ctj5GYLlYSqWqnO_6TG8C0Q,3698 +torch/_inductor/codegen/rocm/ck_tile_template.py,sha256=oDtWXsTmZdbWk85J7iNLOWJ_KYR3hWfzj1lVfJ9MfEE,1542 +torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py,sha256=EJcuqlU6a5E6YhdTpi5YqqNbWMvbFxQhbkXWNwcB2HU,37323 +torch/_inductor/codegen/rocm/ck_universal_gemm_template.py,sha256=qmYl2pTihxagPW8dOuguNkc3IgZA4npT7dx7QNbKEjE,40526 +torch/_inductor/codegen/rocm/compile_command.py,sha256=k4HvGhtEk5omm6PChLsUdcWYE6s8ic0tnCNPUkqi9ig,4620 +torch/_inductor/codegen/rocm/rocm_benchmark_request.py,sha256=OnD5Di9SHKjE7DlYrMfDiLeGRP_hGMcfoOxhbOYjofg,5214 +torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py,sha256=i3_7drGiCkqack8NsE0ett33xGHckgzH_tnVsicLjN8,3898 +torch/_inductor/codegen/rocm/rocm_kernel.py,sha256=zv3ZvX7d8_Z3HbcrEjxpf5l73wDaTasmqfa-WQiNIiE,10662 +torch/_inductor/codegen/rocm/rocm_template.py,sha256=YLSofTgVHBXYSurD59AUHlf5KOj_F94bmsXyFVUAGyU,6829 +torch/_inductor/codegen/rocm/rocm_template_buffer.py,sha256=mj6vxl4SAqfawUZ5Gp_Hix_kSGv0PjF7ulGGArqhszM,854 +torch/_inductor/codegen/rocm/rocm_utils.py,sha256=zVqOgTO4cnfmoweD60_SrhTYwFf-CBMN2CIv9D51NPY,281 +torch/_inductor/codegen/simd.py,sha256=2bx13Mwkufj9XohDzEx20IQaMpRgvHaEWqc6OgRdB2U,98123 +torch/_inductor/codegen/simd_kernel_features.py,sha256=SrwnR82FXh-s_ITbGiBHkhbIUGG3fOUyu5Ig0wgA6Ic,24339 +torch/_inductor/codegen/subgraph.py,sha256=fyshxgYD0-ywG54SO8e9ctKoRZVIpfLKVYzTLKtsOYg,7630 +torch/_inductor/codegen/triton.py,sha256=sTO9xkjN5jQmlHJ2yK-EYd_pvKUAeTP-9AoyddAQi90,181775 +torch/_inductor/codegen/triton_combo_kernel.py,sha256=bNM9cq_4jtX0TW-NvO1jy_sPYpbpseNrj8pX-FsNJ8U,42391 +torch/_inductor/codegen/triton_split_scan.py,sha256=FMdGWFtRQj6bg2dsA2OFnasfQq0mz_gWnZO_kXzp3DQ,7450 +torch/_inductor/codegen/triton_utils.py,sha256=aBz6rq5EfGHABHqhky6drRKwB93kHszBy1dv0xoi-EA,9222 +torch/_inductor/codegen/wrapper.py,sha256=lbKq-si0NwVUH42Wb7AUepGstnm6NZ5pHShgi-jVIqo,138351 +torch/_inductor/codegen/wrapper_fxir.py,sha256=7g6x6uiGjD5PBZsgTVfH-fuVbfk0Nd8u3u7Z_QajpL0,25306 +torch/_inductor/codegen/xpu/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-39.pyc,, +torch/_inductor/codegen/xpu/device_op_overrides.py,sha256=rjmpZo8c7v1r_nlFOvxf9ij7UyTwCSXFrW6-9GiKHlU,1871 +torch/_inductor/comm_analysis.py,sha256=LHjV8WCTXS8T42JoJIXASTrTRGK0YcK5o4ZoYAFET_A,8550 +torch/_inductor/comm_lowering.py,sha256=J8lxAyxmQsP5E9DK1EH8IYBs8-rntG-dfhp3j45FP5s,13024 +torch/_inductor/comms.py,sha256=W6-ylvLDAeE9r6pUMvPVMhFPF_7ogyhFdrdmNxW68QM,43535 +torch/_inductor/compile_fx.py,sha256=QmRj4HM1-81LZQBcVp8I2fHZHD9ftTj5xpXcKeI0UK0,105541 +torch/_inductor/compile_fx_async.py,sha256=MP3uyBFWmJUnwZw_sJxY10cN8wRx25hQ1f9igGPJ808,6618 +torch/_inductor/compile_fx_ext.py,sha256=VNOcenf8KTZ0Vk6JL-sd1NqATunEoa8XldB2sIBm1cQ,23825 +torch/_inductor/compile_fx_subproc.py,sha256=p9O1TkKXDhl9YTRfGRF6YzzOuZkAQkJLHjtdE6tNzFo,3290 +torch/_inductor/compile_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/compile_worker/__main__.py,sha256=9f6RijCYEp9cypg_1FUbshjLvluj2BMynSQzEf5-9kU,2325 
+torch/_inductor/compile_worker/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/compile_worker/__pycache__/__main__.cpython-39.pyc,, +torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-39.pyc,, +torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-39.pyc,, +torch/_inductor/compile_worker/__pycache__/utils.cpython-39.pyc,, +torch/_inductor/compile_worker/subproc_pool.py,sha256=WiqcZzpKE4n85drcSZnCD1Vq4dMGVr3YYwFNONjZbS0,13598 +torch/_inductor/compile_worker/tracked_process_pool.py,sha256=dnIDrH0rHQU0UJvBsfdwJ-eVJWw53hpSZ6q6_SVZ2Gk,3731 +torch/_inductor/compile_worker/utils.py,sha256=GTOA12OH6BCKN6_mIY0nApeg17kM9fUiXFuJyyeo0Q8,1555 +torch/_inductor/compiler_bisector.py,sha256=H9aBp3W5_azdcKZ5XPYOmP7bbHrrXCxy2U7fc6PbtGw,23046 +torch/_inductor/config.py,sha256=nGjlL7c3GA54QidxsY0ZUb7k29ZBZxBGixYTj_6BN50,72115 +torch/_inductor/constant_folding.py,sha256=xIWbdzLBOtcN9UpfvzOfDG425sfOBXC22mgaAKyTc_Y,15631 +torch/_inductor/cpp_builder.py,sha256=aHnnPXiMFRG20II48KXy4INjPi_iTFTdOVmtk6cS90g,69510 +torch/_inductor/cpu_vec_isa.py,sha256=nrImOxsZjkErn_5L7PLzDfsJPNDmQ-W7nnt4AKPwmKs,14482 +torch/_inductor/cudagraph_trees.py,sha256=GZcuECcFeq0jFvHJt32_VpjnQIC0JqLrZzWJ_p9WDRQ,106129 +torch/_inductor/cudagraph_utils.py,sha256=Hh0KtdU8mDCq9fM_oOmM2EwFwc43zul00GKiUYMUMEQ,14294 +torch/_inductor/custom_graph_pass.py,sha256=vEbsN-gFb8TYL8hK3DTdS9723t1lIpeiWX7vSj7ioIk,3962 +torch/_inductor/debug.py,sha256=Sr3Pb2cMBgOJeCiAaO1Cx_GGQBZLx8kgEeUHQ7wzCw4,33993 +torch/_inductor/decomposition.py,sha256=rG4mGgPq467H-OEEVNNWpw7peno3RKbe1Q6dMmpgmtc,38139 +torch/_inductor/dependencies.py,sha256=gZd1qp9nAp-xB0DbIk3bBjGmaHSlGsNfTM4DwJbujbY,30497 +torch/_inductor/dtype_propagation.py,sha256=r497d4FE86YLP-Zg4iWQmBY0gxO-Ym9OMcT0gOnnzFw,11877 +torch/_inductor/exc.py,sha256=iTTJ7wXTQKuXidWiBs_8Nl52lOc8XLthCXUl1_FXHyo,4808 +torch/_inductor/extern_node_serializer.py,sha256=9iaLxk6PJxF4UEtpMokFJDrMl32gZPIlkAnkNB-fKtY,854 +torch/_inductor/freezing.py,sha256=OJ7yDfDOSeC38HYcJE0W3I5FvUIjT8qPahPMUtw2T0Q,11164 +torch/_inductor/freezing_utils.py,sha256=A1mg2kmNjy-O3YxkKI04onOKCZ-7XhCURmSuN9U45fA,1323 +torch/_inductor/fuzzer.py,sha256=2L6miNDcnc_j2eAueXrc2BK2Y1q6lWYuQWdhVYGvmtY,37872 +torch/_inductor/fx_passes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc,, 
+torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc,, +torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc,, +torch/_inductor/fx_passes/b2b_gemm.py,sha256=nXUa5blS7z1cQqqkIN-usnLMI_5MyzcOSacoyQQMMCE,25726 +torch/_inductor/fx_passes/binary_folding.py,sha256=RGlr9hiFSxm8fHmVjOi0J5tDe0amyoJYKfAWj8qUIM8,20353 +torch/_inductor/fx_passes/ddp_fusion.py,sha256=shZ4bM3Tz6iSohq9lPaIMw2VJbfyamlSsdc4UCjhxWQ,21717 +torch/_inductor/fx_passes/decompose_mem_bound_mm.py,sha256=htn8hcf35IYdduH2pDLzWzvk0mcCx4mQPtSTmNf-G-o,5481 +torch/_inductor/fx_passes/dedupe_symint_uses.py,sha256=8TnkzP49-in-yBPL5pO4OiTXaSBQPZhU960v79SMFrI,2598 +torch/_inductor/fx_passes/efficient_conv_bn_eval.py,sha256=Ir1AGCwxUwk3T7fySR51zGe6xKH8eFd5v10leY1nlJQ,14487 +torch/_inductor/fx_passes/freezing_patterns.py,sha256=1W-EJx64RWUSp5fcU_-vNrXgpU4sJwyWXhsBEOAg1Rc,9327 +torch/_inductor/fx_passes/fuse_attention.py,sha256=T2ZL_JOJwyf3DRx2g4IOYpCGo0_SjDIFOz0dNt1nZ4k,36147 +torch/_inductor/fx_passes/group_batch_fusion.py,sha256=_eapKPU3DuRjpK_yNk-2_uz5Q3Nj6goRKPEOxlYM3Qs,60133 +torch/_inductor/fx_passes/joint_graph.py,sha256=QyE0K1lFOrt-k2fGmySea4SRHBt8p5JqO7TBAya3iJ8,34589 +torch/_inductor/fx_passes/micro_pipeline_tp.py,sha256=7yGHdmGS3SmyPUos1lehAH40jBPQckgefPeQKVivduU,40086 +torch/_inductor/fx_passes/misc_patterns.py,sha256=LcIri6-K7w1DqGC8m9-4q8dRZtrOTgUS9tjA-ZW29GM,4918 +torch/_inductor/fx_passes/mkldnn_fusion.py,sha256=F_1hoGNsqUCJpEb4X6E3mpJKR2Ij0VfI_c6iDT1yX0c,61598 +torch/_inductor/fx_passes/numeric_utils.py,sha256=9lnBQc6bo50bdYqjMxW-qWrlfqtDqAaycyfGojnzdrs,7491 +torch/_inductor/fx_passes/pad_mm.py,sha256=PPx0egWy-rYaJmzXXp0Kt9siuT3YitVrROa2rpi9BJk,30315 +torch/_inductor/fx_passes/post_grad.py,sha256=3OvSnW573ZfdxUjE57oaQgr-0J93CoAkWR41lt9506w,64795 +torch/_inductor/fx_passes/pre_grad.py,sha256=kfHtdOb9YlIOSotgiiv3PEcnRUyHyEMN761LkHEC3JQ,31394 +torch/_inductor/fx_passes/quantization.py,sha256=-fi0nZ5MlisecXs6vgq9usGq39-E2XDUMaAiv29wE3U,146201 +torch/_inductor/fx_passes/reinplace.py,sha256=gNuU_WPgGDkmuQSzEAlUUndaX6FuDtDoyaWEF8iQNUE,30608 +torch/_inductor/fx_passes/replace_random.py,sha256=2lf2xXqo0vGMvCQcrz-WxBx5CnlQCPV5Z_c7aJd6cf8,4195 +torch/_inductor/fx_passes/serialized_patterns/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-39.pyc,, 
+torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-39.pyc,, +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py,sha256=iM_6rnK6wv_zM-7M3CIb8AAAKhaD0sEQVnXYWZwnqzg,11351 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py,sha256=uRQxBa4zFL7oVqqy6ENGBWQX1JopKbysnWMrStBN3B0,14421 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py,sha256=LlSUr05EKQzSMpg2K8I9pXje5YicNKRGyIWGkSDEsfY,14191 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py,sha256=ipiXNX_AGdA7-rLJMFNdslRKoilIlLStounMBeikUCI,15477 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py,sha256=4oxOhW3xKEgLEl8vvWY8rHHjDggES8izC2fJqjBIKWg,7992 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py,sha256=14IUflF33gBpwLbEM_rDxAUraZe287uVbMZ5EKYBz-o,14533 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py,sha256=5iMVIsb_4OFZZKgePv9qYju0D4i2GLPtyLY0gp8TwcQ,16443 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py,sha256=yp3WyHDz7tUTKHOlyXF-HWExbTD0g9qLMC0VrEN7sM4,44195 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py,sha256=jPYif6T4ov4rLC3ZMc7tN2YP6niIYyXjjTNIRTvGTC4,17687 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py,sha256=cn0rJftxEgbabluDTvZma5BLd7dCOG5_jletm0ID-Ds,33189 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py,sha256=bRr9BwZUA9jBCCiiBWV8zAdR3HZ3FX0ph5uVwTVpwDs,14255 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py,sha256=CR1xIUsLPSLJZMm8IimSmmCfoqkZKiEblTwWa1Rbc_8,11361 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py,sha256=ByUPO43iUjt5RdBU6H4NmqetkIpqNQkf46VjvaQN-UA,17585 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py,sha256=RVYzFW3UniE0fojcpb_hiFWAqVrsmf4rPYbVNOn-D7Q,15283 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py,sha256=R6uY4nBdTqg6ilNPVqB7gBnt09SFMuyceTt6gktkXyg,15633 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py,sha256=j_GMX-G-pFPyLgGEXI_Sefja6LU-F7uJgpiBoAvS12Y,15543 
+torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py,sha256=Rf4cfoPb5pp_-aZ5lMMNSuhoxzFmya0SSnu0nL78YQQ,12637 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py,sha256=9RXNFbH4qnZbz8QEKHJ2XaoCrj5fH1ZWk5zXxN7SD2U,12601 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py,sha256=2ET6A50Qc7c4Hihy_gyW_6bU17E3HCRXuW4UMt2tLZQ,11591 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py,sha256=uzXMO9ifbdetAg1DtGeH-IM9_cSJUpfJSRx9HMDHfOQ,12835 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py,sha256=zIE2r3YVlog0iKUyxG09yBLAdZpmzjUecW9aTflDsm0,15657 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py,sha256=CDm_1CudB8wZrfw1EnjMGTlcRnIMOtz2VCM4GdgGGus,14409 +torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py,sha256=f_WJ8-i7G9KGrz9KfyPmFtGa3-zFUaCwv6N7zATyMBE,15665 +torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py,sha256=qvjHnXM6SV8Qc6n_Vf9fJQk9lrQOBZ2qxuBO8jEOqRs,1911 +torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py,sha256=P0ZNyqXqPsyHKQKpyAybcLNCVORkhJCY4WsN0tUbsyg,1317 +torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py,sha256=_BMNCRyTiU7hFEgosp_AKkAUsizpnO2sh-AX5_4M91I,1305 +torch/_inductor/fx_passes/split_cat.py,sha256=7KF7lMVT1Jxo9FfYKGAF9CbHZWHnVZJyOvibHw1FHTs,121290 +torch/_inductor/fx_utils.py,sha256=EKfkHzj7rjUU8FBCvMSDpuRL8FmxxHXIusy0d1MDXkE,10349 +torch/_inductor/graph.py,sha256=Xc4ZBPPIQGzB6gLB11htOVEW8wOJtJQQyk3YnKTaDXM,106266 +torch/_inductor/hooks.py,sha256=FdBqoaeoSDsdJtdH_-hMjjMt2ngAWWPLS3OC0BwjUeo,669 +torch/_inductor/index_propagation.py,sha256=88F_eG5peYsjgoUu_XyAmLXY19GNZ9s2BQdA18Ay8lM,13238 +torch/_inductor/inductor_prims.py,sha256=a_1vANd-sobCtWmEWwgx67Qkfo5Ozml2ybPDx-mN6J8,7526 +torch/_inductor/ir.py,sha256=wGyXwG0eg9nqNLqZROabHY2-M0fFc9ZW22xSdXJATvk,312907 +torch/_inductor/jagged_lowerings.py,sha256=BEKFSGXG2JJ18Ynqtjlh9V1xTh2_cY77UUdv_Gghh34,9234 +torch/_inductor/kernel/__init__.py,sha256=iK2dUCFpus-1jXlPvFruvIAzGweWBgsu7jPR5hKRP4k,41 +torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/flex_attention.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/flex_decoding.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc,, +torch/_inductor/kernel/__pycache__/mm_scaled_grouped.cpython-39.pyc,, +torch/_inductor/kernel/bmm.py,sha256=0R-GtGSVxYu0SzLz3yAeTPimqIfNKbYDoZJyTlKgf9M,10102 +torch/_inductor/kernel/conv.py,sha256=9j4Sw7mvpXVh6mmADzie-sRCU9pGR44CbJGUwehqi6A,22010 +torch/_inductor/kernel/flex_attention.py,sha256=jbG22denhbhRxeOovGNCHWCadth2NRifyEiHWwg8DoM,103116 +torch/_inductor/kernel/flex_decoding.py,sha256=NyZgv5FzjelLijsz9e9aIy_1a_lAVsbEhuXKiUjMvRI,23889 +torch/_inductor/kernel/mm.py,sha256=OGItdVFkl5fRe9pvb6-acSZb4OJ2DpCRDHrU1UWxe4U,46222 +torch/_inductor/kernel/mm_common.py,sha256=GU-b5O4Uy9EEpbDxLO-egUFuE8dcXrMWBLFuf14Vxwg,9201 +torch/_inductor/kernel/mm_plus_mm.py,sha256=c3xgFae6CG9u9NGLXS50fqCdWR63N1pKtQMxZ6SE9j4,5594 +torch/_inductor/kernel/mm_scaled_grouped.py,sha256=97_qIuBX3MOJKZ9_FfXE17TmekFxhZi93RyINT6F2aE,23128 +torch/_inductor/loop_body.py,sha256=RDo4gqzGItlLzw6Kszy3fiKMTB0uqBPsHm4DWjH0fhc,24993 +torch/_inductor/lowering.py,sha256=pe6_ad2hRLBB4hZyizKzau1wN66pbv22wRkuyijzvWw,239215 
+torch/_inductor/memory.py,sha256=CkddJhvVvb6cfuq-7FJQspkFhHA17uCcJ79wv1apbV0,26777 +torch/_inductor/metrics.py,sha256=Ks-aMQIbamDhOwJpEdcdHk0MYaFgGu2P-kJnvut3gGM,14325 +torch/_inductor/mkldnn_ir.py,sha256=niOr--cQn8oQmx_H2ddMLi3BozmGeixbl1V9i1AKlk0,43227 +torch/_inductor/mkldnn_lowerings.py,sha256=jWXhxjWHMS9bLoHpF4nyRqgKQTND1463NWfhDYa02Fg,55560 +torch/_inductor/mock_cache.py,sha256=YYm2FRXBX0ywp-Y-tE3GGEmCptzmu9v5c46ZQcA2ti0,8829 +torch/_inductor/ops_handler.py,sha256=JXdsWtyzij6jEcRE875YOjWiexPRNVsvJbCf6bWnhlI,36674 +torch/_inductor/optimize_indexing.py,sha256=Bi5Wa5unzm0gDtBuildjKPz5ffyzCnuwW0eOkKxisrA,4261 +torch/_inductor/output_code.py,sha256=QiLJoVeA1itH2MVHnv1EdEEfKNPozN2TWmbvmIz3xiw,30299 +torch/_inductor/package/__init__.py,sha256=sQTRXIg6O5rZQYURRT-cJ2r30R2AjcIO2ofzbbbfvzc,68 +torch/_inductor/package/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/package/__pycache__/build_package.cpython-39.pyc,, +torch/_inductor/package/__pycache__/package.cpython-39.pyc,, +torch/_inductor/package/build_package.py,sha256=eInHlzpquSyi0W0tjhjUu_bhfEjXTIoGLRpHxJZfI-w,344 +torch/_inductor/package/package.py,sha256=T9L5OEjIJkRI5IsFakArsZjlEvOibnrulEvelBhVrg4,4591 +torch/_inductor/pattern_matcher.py,sha256=oXPGIafJ1NiVvgRgr719_l_gkiHr1WZFWxMCubXPWwU,82881 +torch/_inductor/quantized_lowerings.py,sha256=VpnlU1X3gkO17ZdsDEuN4NehFZhv-8JXy37PIMq8BTY,5814 +torch/_inductor/remote_cache.py,sha256=ocSGgAKzd-y5H71VTTyGQOkL2dxqKZjlrdlbzJYh47s,13212 +torch/_inductor/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/autotune_cache.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/benchmarking.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/halide_helpers.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/hints.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/triton_compat.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/triton_helpers.cpython-39.pyc,, +torch/_inductor/runtime/__pycache__/triton_heuristics.cpython-39.pyc,, +torch/_inductor/runtime/autotune_cache.py,sha256=Y-AEHQbsU8pnKMcuZennAE7ArkVEoxeNBtn-rHdDB_A,23376 +torch/_inductor/runtime/benchmarking.py,sha256=u166zZ_n1jqnZGufVpwr0RbxbXC7Er_m9oNDnck59V0,11508 +torch/_inductor/runtime/cache_dir_utils.py,sha256=o-C1G9qJ6-UtnkOaxQcskpci5UGsHL5M8cqB9QpLgMY,1498 +torch/_inductor/runtime/compile_tasks.py,sha256=wZDokIim3O9TJCScBEFTu3mhBis0GtAG6QKWWK7W1tM,2084 +torch/_inductor/runtime/coordinate_descent_tuner.py,sha256=GNZ49fQSlwTmOolAq39zbx448bsXTRsS03pRgL5WVnQ,10348 +torch/_inductor/runtime/halide_helpers.py,sha256=1aC3s7wXawGzBhbAAmEPERm6zf7wb3qT2qPtQb3gpzg,3660 +torch/_inductor/runtime/hints.py,sha256=Afqnvxv3akcvAj32QXH9KTQamFDKcrq4JQtQywQHkNM,7258 +torch/_inductor/runtime/runtime_utils.py,sha256=vDSnFlsDLfnnJjpzlwp8RqipBSN96m0An9Bvnp16TvA,5166 +torch/_inductor/runtime/static_cuda_launcher.py,sha256=XnJI8BUV4xvvNApLNGugiJ2_CyGKiFVuIxK8YoTGzio,9498 +torch/_inductor/runtime/triton_compat.py,sha256=6v7mK5XADJ8HSz4AgPTefGxlVBG0YpXbmGUs6dVktb4,4260 +torch/_inductor/runtime/triton_helpers.py,sha256=IuImcvMMAK0fIv1mFsVcSqlpFUYZOl4VEvhXFzzjqkw,23835 
+torch/_inductor/runtime/triton_heuristics.py,sha256=Cky9OF-jkS2YsCR-KP61EzeaeWlMhydjvRJH7L-gfmM,114521 +torch/_inductor/scheduler.py,sha256=pqO3h5UYxQri7g1Thu39UlIQ7rCTuIpSIuerkaTUpeI,200819 +torch/_inductor/script.ld,sha256=gJa5rj-3zNDRd7_t4qAJzFpR8V_v8QhcybAM0-G8hDc,438 +torch/_inductor/select_algorithm.py,sha256=idvnm106S10g4WGF5nPAPOeU0PzOaNa-mueDAAWXYx4,120171 +torch/_inductor/sizevars.py,sha256=hPwhbWTPxAuq92jUeBaHCvAUfpeXmhfgTNLLgN9b8yA,38785 +torch/_inductor/standalone_compile.py,sha256=S4CtQ0ilZdIowBp_4KJ83GUmna2B_D1R7999rp8QseQ,9798 +torch/_inductor/subgraph_lowering.py,sha256=abzs_5aW7yDUbhaCAjD0ZSrZVfiE7jnt4RvZ9eqXdUk,7559 +torch/_inductor/template_heuristics.py,sha256=LbzNmHNyFIoGjJ8Togk47WMhrprb_d2VJitEm_UypYE,45746 +torch/_inductor/test_case.py,sha256=DllXJAzaBXqyU61iR1EMXm4V3rVbevuj4EUxd10pc3o,1426 +torch/_inductor/test_operators.py,sha256=ir56rnHMgO_4GGCDzwjIBCKA0TribsSsUpBzF3bVY90,962 +torch/_inductor/tiling_utils.py,sha256=9nl-ymqXKTStY-tXzFk-HJFUt1nDnChCM-tFEKcLjD8,26658 +torch/_inductor/triton_bundler.py,sha256=7i7SskB3q4CXUx2jwd7d-TOzuoojHqglMlTM1Jbb_pg,16558 +torch/_inductor/utils.py,sha256=QLU6zdFL3tYMPhw4uRNyr2YNy5tZxR46Ar3W8zES0sk,105914 +torch/_inductor/virtualized.py,sha256=g8ei5IhER0U7DkUmpFp6HOFKP0YRY4OXxGbdvMl8OEg,14481 +torch/_inductor/wrapper_benchmark.py,sha256=piU-QSLHt2oF-duJGB8mFSVtjGgAZYlOWoyQ2DsmAho,16341 +torch/_jit_internal.py,sha256=EJwafi3B_3CLn9t29YYpdW-Fz_5plJnbJOItFUogJGk,55429 +torch/_lazy/__init__.py,sha256=GKBTkAfZ9pnbRwoWKT2GBUUBbC1WbSec_ufn7ekDhSY,1848 +torch/_lazy/__pycache__/__init__.cpython-39.pyc,, +torch/_lazy/__pycache__/closure.cpython-39.pyc,, +torch/_lazy/__pycache__/computation.cpython-39.pyc,, +torch/_lazy/__pycache__/config.cpython-39.pyc,, +torch/_lazy/__pycache__/debug.cpython-39.pyc,, +torch/_lazy/__pycache__/device_context.cpython-39.pyc,, +torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc,, +torch/_lazy/__pycache__/ir_cache.cpython-39.pyc,, +torch/_lazy/__pycache__/metrics.cpython-39.pyc,, +torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc,, +torch/_lazy/__pycache__/ts_backend.cpython-39.pyc,, +torch/_lazy/closure.py,sha256=ebuMc8GDhvoBtxqSKp0hwBUIC2t8N51t5q9_o0Aav3Y,5707 +torch/_lazy/computation.py,sha256=Jd0b-aVk5Ek29dbMkt0Mov8Ysn1r5nqN8At7p17P6wM,946 +torch/_lazy/config.py,sha256=cSqerqEAgqTx59qy4GHAsrxA7xmBbNG9cg2_Uk1e0_Y,464 +torch/_lazy/debug.py,sha256=Nr-aSxYVQzOAGGIrelenLGc7Pc0xS0Qsl8stUg2XUyY,760 +torch/_lazy/device_context.py,sha256=5MeMBMnfsOnTq0Wx5F9ZhD3TFuiNL_sZ0GjRU9i5ORk,706 +torch/_lazy/extract_compiled_graph.py,sha256=riYKsAjGzl0CCUGw8Z22i0XCAT3udtri4vvJD7nwxcw,8648 +torch/_lazy/ir_cache.py,sha256=WOWccEwnoJyQnhzy0h78xPMB1ziKKx5wrIwYo4U-fng,362 +torch/_lazy/metrics.py,sha256=gMdsK2JUC5dFO4BQCCm5NypfTJmWBONucDDsxtDpp0M,567 +torch/_lazy/tensor_factory_functions.py,sha256=BotJAvC1li1Dj1YaOyd7tdI1A5ME-JfofMm_BV8oxIc,1417 +torch/_lazy/ts_backend.py,sha256=RR7S12OpeqSkVQrspwOcROc-ukN01WvBi_84YbLucCA,170 +torch/_library/__init__.py,sha256=hKRHdphnfJh6GXSYgjf75aVZ4but1yjVCMPFybY1a2w,275 +torch/_library/__pycache__/__init__.cpython-39.pyc,, +torch/_library/__pycache__/autograd.cpython-39.pyc,, +torch/_library/__pycache__/custom_ops.cpython-39.pyc,, +torch/_library/__pycache__/fake_class_registry.cpython-39.pyc,, +torch/_library/__pycache__/fake_impl.cpython-39.pyc,, +torch/_library/__pycache__/fake_profile.cpython-39.pyc,, +torch/_library/__pycache__/infer_schema.cpython-39.pyc,, +torch/_library/__pycache__/simple_registry.cpython-39.pyc,, 
+torch/_library/__pycache__/triton.cpython-39.pyc,, +torch/_library/__pycache__/utils.cpython-39.pyc,, +torch/_library/autograd.py,sha256=hjYpEP2E12pcjV-X5eL2-hX_0Xe-bjuP1wYrjb59iYw,8968 +torch/_library/custom_ops.py,sha256=2nwH71BQp7MZetbexpT6DnXIAadx-iY987u2gPUObIY,38470 +torch/_library/fake_class_registry.py,sha256=e7wjiulPXrAm_FG5Mh0Bk-vKsZZyKYwcD9b2fqfYyiE,13122 +torch/_library/fake_impl.py,sha256=9T3UBL8cEYUoHxq4FJyK21kRilRKOUkbu1PW0WCmHOY,8990 +torch/_library/fake_profile.py,sha256=rUkXyIyQwizwLy5m1bJTmuy1G06dXhMyVZLKo7OfVtQ,11811 +torch/_library/infer_schema.py,sha256=7Rmg8J5ylqocD9nJZ2TthNH3FDgL7bzZ0lPk2riFXoc,12912 +torch/_library/simple_registry.py,sha256=obJBykxD52tEGM-G-mjyG6vvIfYyqNEvS9Cm5ogiLag,2723 +torch/_library/triton.py,sha256=vlwpv9DFYhWiP90-cKZY4Ao1stgOG7SaSgcEHHXbeUQ,11961 +torch/_library/utils.py,sha256=KK3_N9EEawF3pFTg4ptdC_rX6x44VsJ0IOYjE9gTXxQ,18824 +torch/_linalg_utils.py,sha256=ohOWpYzZfqromwQwGojiSly17v6CxU0g7cQdwuJwVD8,5314 +torch/_lobpcg.py,sha256=YnHJG7kBN9SA2JuUb1OVr-GvjkzhlWxDUtOoCsq88GM,44581 +torch/_logging/__init__.py,sha256=H9CB5w-ZPLwmOQ-uZp8-gz1tEVNWnFk1BnLWryxnekY,818 +torch/_logging/__pycache__/__init__.cpython-39.pyc,, +torch/_logging/__pycache__/_internal.cpython-39.pyc,, +torch/_logging/__pycache__/_registrations.cpython-39.pyc,, +torch/_logging/__pycache__/scribe.cpython-39.pyc,, +torch/_logging/__pycache__/structured.cpython-39.pyc,, +torch/_logging/_internal.py,sha256=cFgJR_3IT5b5E_Lia9dtTFGlPrgrXxXw3TKfYpfi1UM,50146 +torch/_logging/_registrations.py,sha256=Ok0P_ZovA8xyAgCsT2uNmRnPWI_8Ai-ji7K0iAeoel0,8240 +torch/_logging/scribe.py,sha256=BbcOqqFejXTtxQ69sPBtkSX8dQXWPaUM1IMcowA7wmU,2641 +torch/_logging/structured.py,sha256=706TifsIEO3nt3yNODwiEzREWAvMOs6b7w27QRB3TAU,3035 +torch/_lowrank.py,sha256=IyeWMjPnewKBCJkXLJvnD2yQyJUF5DHbR1l_Dhaq_y8,10848 +torch/_meta_registrations.py,sha256=-tcU_BpZSLf5KsRCRKcuOyCjSUVn0T6cddUDXHtstVU,259400 +torch/_namedtensor_internals.py,sha256=DE-bDTDM87iai21aDg_KS9zF17MCGSu-I0oE5cBG6gg,5449 +torch/_numpy/__init__.py,sha256=rzNKs-1_I8dcRkXHRDOsJ-SLjviebJ7G3JB_8exvcfY,590 +torch/_numpy/__pycache__/__init__.cpython-39.pyc,, +torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc,, +torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc,, +torch/_numpy/__pycache__/_dtypes.cpython-39.pyc,, +torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc,, +torch/_numpy/__pycache__/_funcs.cpython-39.pyc,, +torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc,, +torch/_numpy/__pycache__/_getlimits.cpython-39.pyc,, +torch/_numpy/__pycache__/_ndarray.cpython-39.pyc,, +torch/_numpy/__pycache__/_normalizations.cpython-39.pyc,, +torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc,, +torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc,, +torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc,, +torch/_numpy/__pycache__/_util.cpython-39.pyc,, +torch/_numpy/__pycache__/fft.cpython-39.pyc,, +torch/_numpy/__pycache__/linalg.cpython-39.pyc,, +torch/_numpy/__pycache__/random.cpython-39.pyc,, +torch/_numpy/_binary_ufuncs_impl.py,sha256=m6w0fYCYwaKH8IolQmS_vBIHmSoDjKe5l-YfRIOsU_8,1956 +torch/_numpy/_casting_dicts.py,sha256=pNVF9rNC9TFps2ktvWau9H-GOtUBDx17UQ-AEaRuefQ,43846 +torch/_numpy/_dtypes.py,sha256=bATVCRn5KFiYicTfRIhVNxUMacwUv-T1072yawIvwRk,10779 +torch/_numpy/_dtypes_impl.py,sha256=5pEM3baVa5OLRPMkW5VDjrZHo8ljtnICh1PklpjDuzQ,6124 +torch/_numpy/_funcs.py,sha256=bjhFuiOr_fYT-tJuwy6U3T_iW_7bNLKIqPznlisWxIw,2173 +torch/_numpy/_funcs_impl.py,sha256=hoj-HCuYjMHNACKaPnLzMLo0K647jJoie_aA7gM_dYU,61297 
+torch/_numpy/_getlimits.py,sha256=9pCtO5PQvHygMv2keLR0-AlcM3zKwWjYQ5pvLxw9ASE,284 +torch/_numpy/_ndarray.py,sha256=FGIz_XvmWul1TsJ7HND8IppLF5tBBNtEVFbZqIy6z5g,17236 +torch/_numpy/_normalizations.py,sha256=duSg9Nq81ofi95BghXP_dzqVJSO6RX0HrahD7cTEzMg,8508 +torch/_numpy/_reductions_impl.py,sha256=T38ZyTuKN6N66pqt9ln4EBI8VNodqquvHtUBlQSzKJg,12259 +torch/_numpy/_ufuncs.py,sha256=4wJeqxLv37PN-ORu7jw0KqBdIOE4GdJKmpvMpfMqelQ,8700 +torch/_numpy/_unary_ufuncs_impl.py,sha256=oTAovTAdhzexww7B4rR9u2ONK7heDWc2Q52vuJ5hSjg,1233 +torch/_numpy/_util.py,sha256=1YYs7sk14tps2BNNOhIjQ14qJ_-2o_-EFaxpPk2UgE0,7818 +torch/_numpy/fft.py,sha256=CQRzqq_1Ys1j1ahExCTp5U8ZBigloBXK1kcoyAL0T9Y,2935 +torch/_numpy/linalg.py,sha256=m2K8GEMQGySc6j9_eATOwx91-nrMnCo4gbvj062xO0s,5891 +torch/_numpy/random.py,sha256=CmLgwsBliHnaRwfuZfzgZRWC3v3UDxM6_f8DjIE90HM,4841 +torch/_numpy/testing/__init__.py,sha256=3F4nlODzdqjjbhBV4Nz9eA44BzLgp-BpbLycggm6GVY,395 +torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc,, +torch/_numpy/testing/__pycache__/utils.cpython-39.pyc,, +torch/_numpy/testing/utils.py,sha256=MfRgBNSLDqhWI1-NVDpAC1SCV1VmJJJ_2xJR3sJStc8,78584 +torch/_ops.py,sha256=pr4egw9NfIC3uNjQ42JbKwIkCTiP7_esldCC40VAoQ0,62021 +torch/_prims/__init__.py,sha256=oTzJpbpPYrM-8dtxcX71t0Xs95Un4ImnBG-vVzOfWT8,84291 +torch/_prims/__pycache__/__init__.cpython-39.pyc,, +torch/_prims/__pycache__/context.cpython-39.pyc,, +torch/_prims/__pycache__/debug_prims.cpython-39.pyc,, +torch/_prims/__pycache__/executor.cpython-39.pyc,, +torch/_prims/__pycache__/rng_prims.cpython-39.pyc,, +torch/_prims/context.py,sha256=08OKc4WPIoS8JuXdcoReqxN6IYL7P12GgdymToIvqkE,6249 +torch/_prims/debug_prims.py,sha256=cBv9vw4uJ-9oynQi6EvZ7Se8STySagO-JFYmSbXB1NU,1943 +torch/_prims/executor.py,sha256=n3x6QEUw_b5z47hwTmXWe1oIf7YIFpxHw1QJGgEO-J4,1976 +torch/_prims/rng_prims.py,sha256=ClyeuM9XhFPVMsSCfA3p6Emnes3Twcde32CbbLQWmrQ,14855 +torch/_prims_common/__init__.py,sha256=vKjUvd-91F3_SwhImo194veF8-0WNrYQfPcCAC3gImI,71775 +torch/_prims_common/__pycache__/__init__.cpython-39.pyc,, +torch/_prims_common/__pycache__/wrappers.cpython-39.pyc,, +torch/_prims_common/wrappers.py,sha256=wQIdDswpNJXSBB5S8gUt5vzP6IzeeZQ569ieX3D7qHM,18711 +torch/_python_dispatcher.py,sha256=CBgBKLGX-LaF-fUR99qKJ8Fg_YdePdsdQOtExf_4tE4,7319 +torch/_refs/__init__.py,sha256=GfrOw5HYFSjhropu_UirSO5U5VoS7BcrEY7FaH4hgpQ,223120 +torch/_refs/__pycache__/__init__.cpython-39.pyc,, +torch/_refs/__pycache__/_conversions.cpython-39.pyc,, +torch/_refs/__pycache__/fft.cpython-39.pyc,, +torch/_refs/_conversions.py,sha256=xZr1y6bvByhAepVg0uqoo_z1S7OisYqt2jI7GwF2dCc,3652 +torch/_refs/fft.py,sha256=A1t1DoACBOHC2aX09E3RBUlvHtf_8eAL1M-5APXts8A,18558 +torch/_refs/linalg/__init__.py,sha256=Uc6e7jNK7LAEVagUpXVzS92sb0C1hoK5oHLWb_YD6G0,11871 +torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc,, +torch/_refs/nn/__init__.py,sha256=PFr_V0xYQhWjKk5oc83cYg_JcNZ2FEKTsjXlnxmkyB8,25 +torch/_refs/nn/__pycache__/__init__.cpython-39.pyc,, +torch/_refs/nn/functional/__init__.py,sha256=q6S5GQ5DkAWOj_WbNvgv5FP2STPhMAXX6F8wjQGhC1E,43773 +torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc,, +torch/_refs/special/__init__.py,sha256=fwq5D62DqplZKh-3xXUZbM5nUK6x5m54qkgrgPjlxuk,7060 +torch/_refs/special/__pycache__/__init__.cpython-39.pyc,, +torch/_size_docs.py,sha256=kcS1-bU8H8PCfDkENNH8YFOgfXkUXfsHt3f1z2JJvC8,939 +torch/_sources.py,sha256=8YnPINLrmUXvzjvOwPMjvPQh2q8TVfIoU-vr_ll4kg4,4561 +torch/_storage_docs.py,sha256=LQgAJq6wRYycwEY6_POp0HhW7bkOyjZ26Aw3GaGVN4s,1377 
+torch/_streambase.py,sha256=DwofDa9lCG6iGzBIQ4Yzam08Jeua98daSS_bidnYedM,455 +torch/_strobelight/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_strobelight/__pycache__/__init__.cpython-39.pyc,, +torch/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc,, +torch/_strobelight/__pycache__/compile_time_profiler.cpython-39.pyc,, +torch/_strobelight/cli_function_profiler.py,sha256=O1-YMsppDehyMZOLvS_faOu7Ot0ju0wu3B3b5GR3RUY,12089 +torch/_strobelight/compile_time_profiler.py,sha256=B-D0pSY0eGbo69qO-5V68TY3_N-jIxDGF296HS6XD-Y,7748 +torch/_subclasses/__init__.py,sha256=nDR4CWhmJQdYQVYeY7KcoJn7nrLG7cAfAehXh3MOcWQ,392 +torch/_subclasses/__pycache__/__init__.cpython-39.pyc,, +torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-39.pyc,, +torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc,, +torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc,, +torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc,, +torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc,, +torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc,, +torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc,, +torch/_subclasses/_fake_tensor_utils.py,sha256=QnZF8fkiQTqGcJ-6H13Z0-QCwSRYiPSKSj_mDpGYT_s,9110 +torch/_subclasses/fake_impls.py,sha256=pt1F2xDw4cRcMRfxiIXoVVFTOwM78J2c0tIqsBCxb3k,38155 +torch/_subclasses/fake_tensor.py,sha256=7SZZVgt8ZoGhDsJuhbvRaX0uNRhoKw3RSnltii_DYVs,132780 +torch/_subclasses/fake_utils.py,sha256=VwxMF5gdR_f_U7YwfDMIkOPwxqKUj3dmGZwYtm_XX5I,10587 +torch/_subclasses/functional_tensor.py,sha256=Vn7PUfDvyVjzodTS5nGzGbI-gjjLxRaQCa1j6Dyqd7M,35052 +torch/_subclasses/meta_utils.py,sha256=foQq1l6Nlh_Gv_9uW4jmerFsxS9d2xmwv1zL9P7aBHg,89002 +torch/_subclasses/schema_check_mode.py,sha256=NBVOJGxU5YOqZXXoxsJdaiSRoFQu85j3hv4_yP6zydc,8869 +torch/_tensor.py,sha256=1vo4YklnhiDkDGqgOIztId7gvCpnZqx7iHdldDItKZ0,74560 +torch/_tensor_docs.py,sha256=vVg-Zh3Q8Cwrm4R8Xb9MexJjxF7dnsGwgMP0zNSDfn8,151121 +torch/_tensor_str.py,sha256=5DKLr-jgmnu4F3ZNPSyY45SC9xTqyCtKKpieKZ3JaUA,29239 +torch/_thread_safe_fork.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_torch_docs.py,sha256=OkkdEYXl5hIDLeAC2luCE_Q5k1RmJ9r8op73SNgveYg,440510 +torch/_utils.py,sha256=p0xDPxLj8IZ-S5kR0_sHEgUZjulG947328-kOL2t1ss,41529 +torch/_utils_internal.py,sha256=QyoioBzU2od-pFcf0mLtHQORLTJdCyOr6fbjRZFupG4,8977 +torch/_vendor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/_vendor/__pycache__/__init__.cpython-39.pyc,, +torch/_vendor/packaging/__init__.py,sha256=tBMEyxhq3EEEWx2ildaGANR5SCtuuItz0Lbo217LOrg,511 +torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc,, +torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc,, +torch/_vendor/packaging/__pycache__/version.cpython-39.pyc,, +torch/_vendor/packaging/_structures.py,sha256=qrorE2MfdfAdF4JqdKD-aLgIo3sRECalhpljE8Accxs,1492 +torch/_vendor/packaging/version.py,sha256=o2agHbS_rqKFYhH6qABjDxZFyKWASNKtAWZ0KXHsjvU,16799 +torch/_vmap_internals.py,sha256=arCBjpxBmEjRtkGTxfmIIqCJMxoOikXa6C3lOkgsWP8,9687 +torch/_weights_only_unpickler.py,sha256=GCxALo3tGqVyDX5SFmEH_S1AQ8M_TpvDGFJmkJCfYwE,22787 +torch/accelerator/__init__.py,sha256=lJygbOyy_T0RYiittEHqUBbbwhKmXLeQoj8P3PG_5tg,8574 +torch/accelerator/__pycache__/__init__.cpython-39.pyc,, +torch/accelerator/__pycache__/_utils.cpython-39.pyc,, +torch/accelerator/_utils.py,sha256=FwiFWfZMWVFFXUzfBoPinUfTRzn4FP1tcqljO6nmpuE,996 +torch/amp/__init__.py,sha256=Uc1hDB3zzt2r2cyu6mgC-cFsV2erFWz5r2QOKU-yj7Q,190 +torch/amp/__pycache__/__init__.cpython-39.pyc,, 
+torch/amp/__pycache__/autocast_mode.cpython-39.pyc,, +torch/amp/__pycache__/grad_scaler.cpython-39.pyc,, +torch/amp/autocast_mode.py,sha256=clFRjAM6jOT-dCjlmL098IXFo7kCvLJir6qHOrWmrR4,25565 +torch/amp/grad_scaler.py,sha256=tNd0eXAZZGZxDtUCSr3CCxMI3Ry2ab826L2sCEZxlLY,31256 +torch/ao/__init__.py,sha256=BbFZLUQMOMzkKhsZ4-FhIGI_H3qyIjk4bflnCywfVxg,709 +torch/ao/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/__init__.py,sha256=g8uojNaWxOlXL20t1NrgkQ47A2YsSAK0RvjMogId2yA,869 +torch/ao/nn/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/__init__.py,sha256=VXw8ou3oEwbDkW1gNIZpayC5krlSwooV8QQkkQ1ATxo,1002 +torch/ao/nn/intrinsic/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/modules/__init__.py,sha256=TNOnhADW9Hj5TJmt_GwRWlkcFyRQT9P0bmRIwcutkag,696 +torch/ao/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc,, +torch/ao/nn/intrinsic/modules/fused.py,sha256=h0_T5uH3hTVTCIZ0AKSmiD6iMVMLHXM16BFJBaRRv_M,10604 +torch/ao/nn/intrinsic/qat/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/qat/modules/__init__.py,sha256=0ZwViGqdx6f0QYppptRbGOtCkl-E1tGcXJYqvkwjQkM,579 +torch/ao/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc,, +torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc,, +torch/ao/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/ao/nn/intrinsic/qat/modules/conv_fused.py,sha256=r0rt-OK4_UAD5r1rL_EXm1d_NiJaZMVP_4iDCfH_4Qo,34012 +torch/ao/nn/intrinsic/qat/modules/linear_fused.py,sha256=eO9Md6k-gQb65nDZC8nbMb5TA3PZ5KpjM6PNLWaKH3M,6806 +torch/ao/nn/intrinsic/qat/modules/linear_relu.py,sha256=jptpg2zr_gIklBnBtweXWyDMU8aiS9ShREUKgo7GTOk,1738 +torch/ao/nn/intrinsic/quantized/__init__.py,sha256=QyW1Bo9d_YpVgKKEz_0XCK_vDpzGzr3cRCr6VFThD_g,251 +torch/ao/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/dynamic/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/dynamic/modules/__init__.py,sha256=YS0r9fkbJU20FDakDuLmzDJrWZXL-fB2wC3c5AOzDp8,76 +torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/dynamic/modules/linear_relu.py,sha256=230M22hKL1x3YTxpOeRL-gDHFDbJkjliaK9KHqAxtRM,2088 +torch/ao/nn/intrinsic/quantized/modules/__init__.py,sha256=f6fu2CJkehRe7TovtjS-4asgjmz8O4LUhBlNKW5sPAU,427 +torch/ao/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_add.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/ao/nn/intrinsic/quantized/modules/bn_relu.py,sha256=4kbUvcSxUc9ryx-GrtA68uzy0c89rm9BymC0PP1HMeA,3392 +torch/ao/nn/intrinsic/quantized/modules/conv_add.py,sha256=QYuo_P4wGK20INRvCqrb1bcl0_Fu8u9KWGMW6fN0MqQ,4581 +torch/ao/nn/intrinsic/quantized/modules/conv_relu.py,sha256=V5J6T8mCXmKQGqavpNagr91kLyfYtqGdmU9L5qmhxLo,8717 
+torch/ao/nn/intrinsic/quantized/modules/linear_relu.py,sha256=o-1VLCF_x8InDrzI72Rog9XufamN73OGDGDbg9K7ENs,7074 +torch/ao/nn/qat/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/qat/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/qat/dynamic/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/qat/dynamic/modules/__init__.py,sha256=h_S3dv05DjPYXtDjuRqkzrGoE6aeaPOjE1oKRfO9Bq8,54 +torch/ao/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/qat/dynamic/modules/linear.py,sha256=b6YXXy2jVTJlV0JxVWybYR6TNf0i_TwGjcX00QwrY6M,1288 +torch/ao/nn/qat/modules/__init__.py,sha256=kSrrX2n3rqdHWMMphT8j0lxEch9YVFQOM9FTRDJhf7c,241 +torch/ao/nn/qat/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/qat/modules/__pycache__/conv.cpython-39.pyc,, +torch/ao/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc,, +torch/ao/nn/qat/modules/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/qat/modules/conv.py,sha256=VkQK3vFCANqOdaV5J2hg7m92LmHsmnpK_VktvS9BEoU,9888 +torch/ao/nn/qat/modules/embedding_ops.py,sha256=RTeiCLMU-mx1pJbUyBVX2FqSlF5USNdcxE7bp-tQ4Mo,8067 +torch/ao/nn/qat/modules/linear.py,sha256=hm7Wpx6kyKaQV79xWcXQAeSGVHYJQVRgW_qzA0NGP-I,3148 +torch/ao/nn/quantizable/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/quantizable/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantizable/modules/__init__.py,sha256=Cn3Ltv6d1JfKe7aV4GeTxQKkLfT2z5xLI-Jk1Sl938g,154 +torch/ao/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc,, +torch/ao/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc,, +torch/ao/nn/quantizable/modules/activation.py,sha256=Lqiu8tt7vfaIxaauJmcD6-xuYC3XCYFTOShc_BuuyYQ,23600 +torch/ao/nn/quantizable/modules/rnn.py,sha256=zENgUmu35wpY5Avkwoc_nBsWKUZ34qzTMZsHa2FMx6E,22206 +torch/ao/nn/quantized/__init__.py,sha256=rufZDT8E2pU6Y_RJUA9VnLIJ7-zHTR0wGyzghMKDFPY,725 +torch/ao/nn/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/__pycache__/functional.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/__init__.py,sha256=sE-0qvXnTY5ilDKux2jZukNxzkeR-G2mlBp5ss2b4yk,38 +torch/ao/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/modules/__init__.py,sha256=5gD6fcxxz2UnPlOKoCLOHU3aa5zW60RpXDDeOECBpAc,439 +torch/ao/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc,, +torch/ao/nn/quantized/dynamic/modules/conv.py,sha256=S0L7fOyCu4WZZMWilMDbkRHi2x2LCcFaLvS9XtTGIYY,18713 +torch/ao/nn/quantized/dynamic/modules/linear.py,sha256=OCT9lBLYERMveXYoBEqwI-R6uecHNlAwxAGUWeqmAQE,6515 +torch/ao/nn/quantized/dynamic/modules/rnn.py,sha256=4zF4JGxIpXwKWHKZ-lj84cSD_aizRY5WMs5-GPXbIp4,52711 +torch/ao/nn/quantized/functional.py,sha256=y-Fm-2w47-tD7jonoYhm_Z5tKNqDi8EuomfgrVit7h0,30368 +torch/ao/nn/quantized/modules/__init__.py,sha256=d7APn7TPLFlbJLMXxlGnTyqVZq9IGSch04idikjF_08,4683 +torch/ao/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/activation.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc,, 
+torch/ao/nn/quantized/modules/__pycache__/conv.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc,, +torch/ao/nn/quantized/modules/__pycache__/utils.cpython-39.pyc,, +torch/ao/nn/quantized/modules/activation.py,sha256=U0aQR1_Rz9NrSUoufuerAM4L5t-bcypDUIyRwWr-tx0,11952 +torch/ao/nn/quantized/modules/batchnorm.py,sha256=vEs81H-u_JfBQSzjtLhLURSGy-Q-Z45C1vnD09HyCvM,4555 +torch/ao/nn/quantized/modules/conv.py,sha256=XQP4c1wL9U40WwtJRlwoN7K_c9DN4QjNnXQ8FsNm4n8,44657 +torch/ao/nn/quantized/modules/dropout.py,sha256=7EWdajteiitQbayVUtYQDAs_xCZCPp_kfh7l2GOvIyg,836 +torch/ao/nn/quantized/modules/embedding_ops.py,sha256=A9smZl-xlEthxnNeU5a1mG1gH4lQbqflcijATMXLdfg,15089 +torch/ao/nn/quantized/modules/functional_modules.py,sha256=uwyhuSiHkg99QfkJns2R-WQibI6ueZRXuWp-2LNe0Ic,9520 +torch/ao/nn/quantized/modules/linear.py,sha256=HLJZX4WiNBNWYpmDH_Dd1M186uJN7Azh2tFQb-0ZOhE,13972 +torch/ao/nn/quantized/modules/normalization.py,sha256=bN6Ad-N7exObMOssBZ-7APgoPeBwgQgh107eto4VWas,9886 +torch/ao/nn/quantized/modules/rnn.py,sha256=hoEciGUuczxFd-vfOWzlLngKSwU2dyXLNDSbZODGBRY,1880 +torch/ao/nn/quantized/modules/utils.py,sha256=Kxj9nrbLYDbFsBxkbUZdrd2m5w9hXDjsO1I0Wqk7Aso,4839 +torch/ao/nn/quantized/reference/__init__.py,sha256=uy5ftWsvDX4VkLX2C2wbtYF_sokM3UiC3PM5wwMcx4A,303 +torch/ao/nn/quantized/reference/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__init__.py,sha256=8RHzSWvt67OBacAuyMPhZmpGHGm8GCe2TClxkq8BttM,523 +torch/ao/nn/quantized/reference/modules/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__pycache__/conv.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__pycache__/rnn.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__pycache__/sparse.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/__pycache__/utils.cpython-39.pyc,, +torch/ao/nn/quantized/reference/modules/conv.py,sha256=L7c_WfWjiksu2rSFVocw8YVR3kq--P5pv_SOuw90On8,15900 +torch/ao/nn/quantized/reference/modules/linear.py,sha256=jFax2t581eLcKI_vvUlcIXX9Q48dkSk0eigjXRKeXkA,2342 +torch/ao/nn/quantized/reference/modules/rnn.py,sha256=wEAQMlClo_P1rX9Whxxat-s0nNI7hisFk8ZCnO2Uuu8,30535 +torch/ao/nn/quantized/reference/modules/sparse.py,sha256=EDw-vk2xI1YEzLl1W-U6HG2fUGu8Rd8kV_uu-5xxxJk,4835 +torch/ao/nn/quantized/reference/modules/utils.py,sha256=cyox-Y4TGhQAOQ2djcqmYDuZvn7iER8Znkf4YnA45Io,15780 +torch/ao/nn/sparse/__init__.py,sha256=Lfoo4m9ZnVOja2YLfZa9jrz54iMOMhHzCg7RYcP5FLk,25 +torch/ao/nn/sparse/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/sparse/quantized/__init__.py,sha256=SEINwKuRz86NuFHzdWv4BowgnhoNv0qFpluJfm2nncQ,178 +torch/ao/nn/sparse/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/sparse/quantized/__pycache__/linear.cpython-39.pyc,, +torch/ao/nn/sparse/quantized/__pycache__/utils.cpython-39.pyc,, +torch/ao/nn/sparse/quantized/dynamic/__init__.py,sha256=fDXnm6MNQ02dzp2hv56Tg3fmTjnf2saSRH5ZBvPYafM,63 +torch/ao/nn/sparse/quantized/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/ao/nn/sparse/quantized/dynamic/__pycache__/linear.cpython-39.pyc,, 
+torch/ao/nn/sparse/quantized/dynamic/linear.py,sha256=3Jk1bcdJSnSKNz1Abq2btfECbAJHeogkCPHwQmb0XXk,6517 +torch/ao/nn/sparse/quantized/linear.py,sha256=kr-nj0796d-jSpkkmq2_h-5qqTOKJ3SfOo0torf9bVk,9292 +torch/ao/nn/sparse/quantized/utils.py,sha256=LxmAMLPuJMtxbx-JmwGN-RrSYsJ7L8khPv_V9XMgcNU,2143 +torch/ao/ns/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/ns/__pycache__/__init__.cpython-39.pyc,, +torch/ao/ns/__pycache__/_numeric_suite.cpython-39.pyc,, +torch/ao/ns/__pycache__/_numeric_suite_fx.cpython-39.pyc,, +torch/ao/ns/_numeric_suite.py,sha256=NquGuDW4agBRwrWjerO1kLa3cKT3Ig5gXMu7dQuPuPU,20643 +torch/ao/ns/_numeric_suite_fx.py,sha256=4r1NbDhHD-RiauE4GRQK3OYCd088gtb_gUdvdYoXtz4,42482 +torch/ao/ns/fx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/ns/fx/__pycache__/__init__.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/graph_matcher.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/graph_passes.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/mappings.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/n_shadows_utils.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/ns_types.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/pattern_utils.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/qconfig_multi_mapping.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/utils.cpython-39.pyc,, +torch/ao/ns/fx/__pycache__/weight_utils.cpython-39.pyc,, +torch/ao/ns/fx/graph_matcher.py,sha256=K6kKpG1f81GQe1kpRYpBMwMRJqGh8yApG41MD0NO1OI,19750 +torch/ao/ns/fx/graph_passes.py,sha256=mkzUe_we_av6MbTxzcYpSfPVqRLMzGxPWcL-EknVoN0,45484 +torch/ao/ns/fx/mappings.py,sha256=tVxlPJBfBMT8N06C4zVMkWoEMBcJRXnf9F4nRcGx1jU,19057 +torch/ao/ns/fx/n_shadows_utils.py,sha256=V0ZZeBZmntUgARxhUc5Yqymd4urBqpS3BO86oP2JffM,52559 +torch/ao/ns/fx/ns_types.py,sha256=uIFqq24jPtRKM8V-fhzQd8t1oNI2Yl03FYWwWFqxMd4,2392 +torch/ao/ns/fx/pattern_utils.py,sha256=WwMPYWeye4kKzVVMjF7YITFoowDgZU4MUhLNIyU1duM,8578 +torch/ao/ns/fx/qconfig_multi_mapping.py,sha256=lxDIC1M3qlxRX8wmNGIIvyaj0CTPBYqLSuAE-8BOA3M,10432 +torch/ao/ns/fx/utils.py,sha256=0zPgjYfNX5KqJ06HyToCrv25vmcSr0cKN84DgJT0RYo,21184 +torch/ao/ns/fx/weight_utils.py,sha256=xfRRiycSY--Hh1TlHoXlBv3Lgzot9hA30jsxvKwvJS0,11666 +torch/ao/pruning/__init__.py,sha256=buSUvIIoEdoVoBKVgjMHw6Tz52akY935O7_j9WCa034,663 +torch/ao/pruning/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/__pycache__/_mappings.cpython-39.pyc,, +torch/ao/pruning/_experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/pruning/_experimental/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/activation_sparsifier/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/activation_sparsifier/__pycache__/activation_sparsifier.cpython-39.pyc,, +torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py,sha256=FE6CAsmW1jYKAwtzROEDVO0Ub5TEFT_JomUalyGioaI,19548 +torch/ao/pruning/_experimental/data_scheduler/__init__.py,sha256=JZ8rO6gG9c8OxZ1COfSCl6viLVoj2hKjkrxaHcmIlWA,98 +torch/ao/pruning/_experimental/data_scheduler/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_scheduler/__pycache__/base_data_scheduler.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_scheduler/base_data_scheduler.py,sha256=ZRZoz37zWMXjvqfop3T4iBRGRbc3Ek8aYEmUNoF-R-k,7862 +torch/ao/pruning/_experimental/data_sparsifier/__init__.py,sha256=XFtc3q0kNE1_DoO59bcPIm79KMkblr0AkKzrIMw_J3s,182 
+torch/ao/pruning/_experimental/data_sparsifier/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/__pycache__/base_data_sparsifier.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/__pycache__/data_norm_sparsifier.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/__pycache__/quantization_utils.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py,sha256=NFZ3RSkLn3xDoOMTnr-5R8dJrjT3h52DqbZN1zvW15c,13766 +torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py,sha256=d3szjD_z8kYt4oprXp-SbJPuVZ1sQLzAD1kGFmHgvgM,7963 +torch/ao/pruning/_experimental/data_sparsifier/lightning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/pruning/_experimental/data_sparsifier/lightning/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/_data_sparstity_utils.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/__pycache__/data_sparsity.cpython-39.pyc,, +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/_data_sparstity_utils.py,sha256=oDL0SX3pzCs7EkZDklLYWSPwNZFA33Frc19532lQ7IA,1680 +torch/ao/pruning/_experimental/data_sparsifier/lightning/callbacks/data_sparsity.py,sha256=ZnB-pdsto8Zt-JV0pCIfztt55-nCwB6WnjxAiJAbxDk,6792 +torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py,sha256=WHnZLn6MD1LgaibgwxqkkgNfjUMXt00fL9-N9WgRlWU,6040 +torch/ao/pruning/_experimental/pruner/FPGM_pruner.py,sha256=xUziOzu3A4CQ6wnv1FMJcNBltBz3Ev0eqh-roi0Avzg,3553 +torch/ao/pruning/_experimental/pruner/__init__.py,sha256=B5dNGoaxQpyxLwzg-Ej7-UT3PZdJg4V4-58ZLqQCWTQ,265 +torch/ao/pruning/_experimental/pruner/__pycache__/FPGM_pruner.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/base_structured_sparsifier.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/lstm_saliency_pruner.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/match_utils.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/parametrization.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/prune_functions.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/__pycache__/saliency_pruner.cpython-39.pyc,, +torch/ao/pruning/_experimental/pruner/base_structured_sparsifier.py,sha256=apN2M5ULdHk23I7CI32E4nic6uk3H5mAAWx10Ild1-Q,11248 +torch/ao/pruning/_experimental/pruner/lstm_saliency_pruner.py,sha256=2bAtsqfPHu1Aw0QzpzVoamU7LSlVS2FwAKbcIXnRr8U,2194 +torch/ao/pruning/_experimental/pruner/match_utils.py,sha256=yl2foMS9RIrM02MljpFTjwPeqmdnVqTV7W0CAzc1uOM,2046 +torch/ao/pruning/_experimental/pruner/parametrization.py,sha256=V2QllSgsvCFfAU2Q6z0En0GPCA___UUhA7MgYdkTljU,1903 +torch/ao/pruning/_experimental/pruner/prune_functions.py,sha256=s5j9IsqigZRHNvnJwOFcx2hdYE4mHpGWmF9tFoKMjNc,19574 +torch/ao/pruning/_experimental/pruner/saliency_pruner.py,sha256=T4aIRKy_AYHr6Welx4hCMowTe53F9L1a5LaMsw23Xd4,1432 +torch/ao/pruning/_mappings.py,sha256=QyqsbrIsTq-gkHBuk2Y0iAIaA2UnY4DIvUo4bTW2d0I,620 +torch/ao/pruning/scheduler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+torch/ao/pruning/scheduler/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/scheduler/__pycache__/base_scheduler.cpython-39.pyc,, +torch/ao/pruning/scheduler/__pycache__/cubic_scheduler.cpython-39.pyc,, +torch/ao/pruning/scheduler/__pycache__/lambda_scheduler.cpython-39.pyc,, +torch/ao/pruning/scheduler/base_scheduler.py,sha256=W_EN0YNIsSg0caJbeY9C19mglURiqfWftSwO-NkT13I,6696 +torch/ao/pruning/scheduler/cubic_scheduler.py,sha256=43ACWjPjqyDHAC6OaAU9t3cVCzercN5PzKCSq2U9Qrk,3957 +torch/ao/pruning/scheduler/lambda_scheduler.py,sha256=VtDD48IQyzdX2HJPbtNIpXixNlqRm0oakn4S3nRo9ek,2169 +torch/ao/pruning/sparsifier/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/pruning/sparsifier/__pycache__/__init__.cpython-39.pyc,, +torch/ao/pruning/sparsifier/__pycache__/base_sparsifier.cpython-39.pyc,, +torch/ao/pruning/sparsifier/__pycache__/nearly_diagonal_sparsifier.cpython-39.pyc,, +torch/ao/pruning/sparsifier/__pycache__/utils.cpython-39.pyc,, +torch/ao/pruning/sparsifier/__pycache__/weight_norm_sparsifier.cpython-39.pyc,, +torch/ao/pruning/sparsifier/base_sparsifier.py,sha256=F-hFC22qlvIXe_Q_ASXv7-T3Hgu5D2vJCBug_RulqxE,14085 +torch/ao/pruning/sparsifier/nearly_diagonal_sparsifier.py,sha256=pU_R9KkTtsAixDQ9q230XMC4Jm3uEC1uS0a_J4VnsKE,2326 +torch/ao/pruning/sparsifier/utils.py,sha256=QtWN4pwtIa4_K7nft0xrmklivr80fVld0noy8QmSyVE,4940 +torch/ao/pruning/sparsifier/weight_norm_sparsifier.py,sha256=rJuGEzuCjSKikS4Oaej8NWIY0ruargE_v_WwPq3yAcI,9544 +torch/ao/quantization/__init__.py,sha256=Yi2cRZ3-8VMnQgXVcyWZ49hWVrO0BUTODV07Ag_6TBI,7553 +torch/ao/quantization/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/__pycache__/_correct_bias.cpython-39.pyc,, +torch/ao/quantization/__pycache__/_equalize.cpython-39.pyc,, +torch/ao/quantization/__pycache__/_learnable_fake_quantize.cpython-39.pyc,, +torch/ao/quantization/__pycache__/fake_quantize.cpython-39.pyc,, +torch/ao/quantization/__pycache__/fuse_modules.cpython-39.pyc,, +torch/ao/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc,, +torch/ao/quantization/__pycache__/observer.cpython-39.pyc,, +torch/ao/quantization/__pycache__/qconfig.cpython-39.pyc,, +torch/ao/quantization/__pycache__/qconfig_mapping.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quant_type.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quantization_mappings.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quantize.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quantize_fx.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quantize_jit.cpython-39.pyc,, +torch/ao/quantization/__pycache__/quantize_pt2e.cpython-39.pyc,, +torch/ao/quantization/__pycache__/stubs.cpython-39.pyc,, +torch/ao/quantization/__pycache__/utils.cpython-39.pyc,, +torch/ao/quantization/_correct_bias.py,sha256=mV3185FC5ebgt_8VflyfL-XGQWbmD8kLS2Jf2hyWEHo,5583 +torch/ao/quantization/_equalize.py,sha256=d9oYBdklbJ9Szhntz1JhfxhLDmvq86B26cPqsnrOgbY,9746 +torch/ao/quantization/_learnable_fake_quantize.py,sha256=hGTNJ-tC_RJXJVVY8oiDyCH_tgZj9hmhtcDvVV4gBRA,8111 +torch/ao/quantization/backend_config/__init__.py,sha256=4Y8ISKDtMW_f2GA6qOE-K4fo09GAclfC3RPo3bpxNJk,945 +torch/ao/quantization/backend_config/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/_common_operator_config_utils.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/_qnnpack_pt2e.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/backend_config.cpython-39.pyc,, 
+torch/ao/quantization/backend_config/__pycache__/executorch.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/fbgemm.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/native.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/observation_type.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/onednn.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/qnnpack.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/tensorrt.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/utils.cpython-39.pyc,, +torch/ao/quantization/backend_config/__pycache__/x86.cpython-39.pyc,, +torch/ao/quantization/backend_config/_common_operator_config_utils.py,sha256=S7L_VNQR8mwTwKDdRDXqZlIXxO4zwYJHal3Q33n7hjo,28294 +torch/ao/quantization/backend_config/_qnnpack_pt2e.py,sha256=jQjErEcpgC94Dx07DKacRtLwp-r3NV7_OV0F9pfDheU,6612 +torch/ao/quantization/backend_config/backend_config.py,sha256=UfjlsBJxo3_qSbe1TFe7xJ4dP2amMdQm0EtJCzx8-Cw,32410 +torch/ao/quantization/backend_config/executorch.py,sha256=erBNAKwOB10_T7tZxG79ELKwndYvwYmLW1SmNExzAAk,17422 +torch/ao/quantization/backend_config/fbgemm.py,sha256=Ol1Pr2cmjzPJnv5-jwz6mtvI1ggq7ZDv7-oOT4A3_Aw,4337 +torch/ao/quantization/backend_config/native.py,sha256=phw5IGaAYB6kIvFknf77xjJkXfeJxjaLWA8l1CWHaKk,8473 +torch/ao/quantization/backend_config/observation_type.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/quantization/backend_config/onednn.py,sha256=V6YJu7CNz_RBlqmiVSKo7UtYENyvDXVztNbQ-u6xgGY,19723 +torch/ao/quantization/backend_config/qnnpack.py,sha256=9rWj2g2g1BKg4btmxwunj9XqQHRB-k0y9Pxg6_tb3JM,5571 +torch/ao/quantization/backend_config/tensorrt.py,sha256=2EPv3N_J68FgWAJXmsrOr5393o_84xk5LFhU2k5A33A,3119 +torch/ao/quantization/backend_config/utils.py,sha256=kdhTdlNJXyIjjWT4O2B_ddlwLvhpxqNq5wiuyuIVH48,12747 +torch/ao/quantization/backend_config/x86.py,sha256=pe-AhPb41P1BoL4ikQQ-UlgfIJ6wIEmRa4vAcfCCRLY,3995 +torch/ao/quantization/fake_quantize.py,sha256=ECYhRiwVp5qkXC-EmGx5LtLz_u1vsH2L-pjdk2m07C4,23537 +torch/ao/quantization/fuse_modules.py,sha256=O7uXN01hNC1v2xegvAWHBQqwxK91TklTuVHq1R5Y4KE,7041 +torch/ao/quantization/fuser_method_mappings.py,sha256=Pq1prcq13kc5J_ncRazlKnda4PDzZm3xobA4eFO-70w,10667 +torch/ao/quantization/fx/__init__.py,sha256=LznkRltyNzEL0Z_KaUUV2lOMPtlTaKXlwWx2Y428B9k,84 +torch/ao/quantization/fx/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/_decomposed.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/_equalize.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/_lower_to_native_backend.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/convert.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/custom_config.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/fuse.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/fuse_handler.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/graph_module.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/lower_to_fbgemm.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/lower_to_qnnpack.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/lstm_utils.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/match_utils.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/prepare.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/qconfig_mapping_utils.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/quantize_handler.cpython-39.pyc,, 
+torch/ao/quantization/fx/__pycache__/tracer.cpython-39.pyc,, +torch/ao/quantization/fx/__pycache__/utils.cpython-39.pyc,, +torch/ao/quantization/fx/_decomposed.py,sha256=B0njMhYzQd9jHdj1d5i8SsFgG02n01pLqkixpRnLEU0,43666 +torch/ao/quantization/fx/_equalize.py,sha256=-9Ocpe0B_VTrt1Mu6VrKOTxTjIjWtMDCo-09MexvaWk,38872 +torch/ao/quantization/fx/_lower_to_native_backend.py,sha256=QQNJ409COekMXzJ-fy_tXteTiLaIf2h5HLx9QO1r4_w,55995 +torch/ao/quantization/fx/_model_report/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/quantization/fx/_model_report/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/fx/_model_report/__pycache__/detector.cpython-39.pyc,, +torch/ao/quantization/fx/_model_report/__pycache__/model_report.cpython-39.pyc,, +torch/ao/quantization/fx/_model_report/__pycache__/model_report_observer.cpython-39.pyc,, +torch/ao/quantization/fx/_model_report/__pycache__/model_report_visualizer.cpython-39.pyc,, +torch/ao/quantization/fx/_model_report/detector.py,sha256=HszE9U5TbidXxrBzVoK6-Gsi03YiDFGARLWXknMr6e4,78132 +torch/ao/quantization/fx/_model_report/model_report.py,sha256=96OOIAvXw2KBDgjumNOVBCy4NFmqBjINjbYpzF-r9j8,30305 +torch/ao/quantization/fx/_model_report/model_report_observer.py,sha256=bJubiRVwuaPFAcvnFySxNobJ3GRg1l-SXEYyiVu4k1U,12369 +torch/ao/quantization/fx/_model_report/model_report_visualizer.py,sha256=Dks1FCWID4ns389bqsXMRENwuk_mHDIJnIPkwj3Q-Lw,33258 +torch/ao/quantization/fx/convert.py,sha256=5Og0nK_TNavJOoqI4zzoVSpK_pvPurLle5yiR6GA_Wg,59010 +torch/ao/quantization/fx/custom_config.py,sha256=nRQ2fm58qovr6kgMMrwL6mkOJizLJ--0-WHjfCBehiI,22382 +torch/ao/quantization/fx/fuse.py,sha256=ZwGWejJTLfb1TqiOuUda9RSRiDtssl9NdAjdGQ6rcPU,7447 +torch/ao/quantization/fx/fuse_handler.py,sha256=VV4MGfwvghKLcdIHKEcpjZ51gmsRGG8ax1jbYSyz8QM,4774 +torch/ao/quantization/fx/graph_module.py,sha256=SLq5zUgIYq5dOj5JriEuuU1boq0J2KQqenvVso-yAoA,6849 +torch/ao/quantization/fx/lower_to_fbgemm.py,sha256=St_PLfQHbvbqlpeD1c1tr--CaI5sDSKPKTTKyQBhk9s,623 +torch/ao/quantization/fx/lower_to_qnnpack.py,sha256=emVN0RZnJOc7JKUjei-lUz98l_2cGxyitgFXNPVAUOk,545 +torch/ao/quantization/fx/lstm_utils.py,sha256=5hY5MC_LaM_ouDiE0qxnBQP_EvCyw_H1dVKKvj2itvY,10537 +torch/ao/quantization/fx/match_utils.py,sha256=6S9Nmjdnb9cJlNQcerDe3UwwM0Sz2QjeNAf3jB-XP2o,9088 +torch/ao/quantization/fx/pattern_utils.py,sha256=p6f1tfEomnFd16QlPVXWzgrV699kpI3LukSB7Jh9D3k,3780 +torch/ao/quantization/fx/prepare.py,sha256=6DDcJGvmGq--ZteBbjgjBV3HWuUv-7VuielVYKhsjwA,89736 +torch/ao/quantization/fx/qconfig_mapping_utils.py,sha256=tvvaGh2PsGZORSa9wD7C5mih8xsiMQmM7XQo7LedUrQ,15759 +torch/ao/quantization/fx/quantize_handler.py,sha256=q72tXW2tNIGt8WTyc4vHwd3Jjd4zNrv0jtKl0ncUjqA,7505 +torch/ao/quantization/fx/tracer.py,sha256=Gm5GTjPKm-jkRs3UyOAxsw-7cBXCIG6yoXXdPf4y_y4,1736 +torch/ao/quantization/fx/utils.py,sha256=35udCMwa7xRSVXn-nDY9lwCs7Mz8UX9sd_dsjUob4Pw,38578 +torch/ao/quantization/observer.py,sha256=6PF6COR_sulgm0JIGgLLWhk4JXzSLh9WVQgRZ7jMKos,81385 +torch/ao/quantization/pt2e/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/ao/quantization/pt2e/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/_affine_quantization.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/_numeric_debugger.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/duplicate_dq_pass.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/export_utils.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/graph_utils.cpython-39.pyc,, 
+torch/ao/quantization/pt2e/__pycache__/lowering.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/port_metadata_pass.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/prepare.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/qat_utils.cpython-39.pyc,, +torch/ao/quantization/pt2e/__pycache__/utils.cpython-39.pyc,, +torch/ao/quantization/pt2e/_affine_quantization.py,sha256=4_q6Np_8-XQkuKo8rJ9mvIZsY5_1z2Rm7ACNH2BJrv8,34891 +torch/ao/quantization/pt2e/_numeric_debugger.py,sha256=ZVqP_wqtBExMCfNHZPjFB9c8gwU6g6XTY5E8TFsJVBU,12476 +torch/ao/quantization/pt2e/duplicate_dq_pass.py,sha256=sfui7mCwILS-uBiyBS4T223pQn10my98wsnfcPOtqds,3230 +torch/ao/quantization/pt2e/export_utils.py,sha256=8SwCTavZTk-PDQ5YqsfpGRl9eXAZsLTcWcdx2iOJs1M,8230 +torch/ao/quantization/pt2e/graph_utils.py,sha256=hsOPw9i__8jjsq7UOQ0rfNtC260nFTSkCYj39tb9Dtg,6566 +torch/ao/quantization/pt2e/lowering.py,sha256=79j2kyha6jnLMFXBmHDKy47P5fyH6QO5AsV90hU6tsY,1964 +torch/ao/quantization/pt2e/port_metadata_pass.py,sha256=MyrYgbEvdIBuuqaJwkitPcIkgkO6Ic5CnyyiV2iy30E,9437 +torch/ao/quantization/pt2e/prepare.py,sha256=rpHEXgfJDuo6Wy17gPfjCrnt3mt0TCC9QXp2kxenfDE,22126 +torch/ao/quantization/pt2e/qat_utils.py,sha256=nizBxWZzJpOpfXbCAeeyYe8q_0yMWsMEPIxDieDXDuY,37681 +torch/ao/quantization/pt2e/representation/__init__.py,sha256=uhKkqgWmX_G43QgqbD4FvLpn6nyf_y_ZRM8A7U9SOMM,116 +torch/ao/quantization/pt2e/representation/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/pt2e/representation/__pycache__/rewrite.cpython-39.pyc,, +torch/ao/quantization/pt2e/representation/rewrite.py,sha256=o_C1HCQTqLUudzlQq-WdtPFpzskCzKMhXrLzsBIs-ok,29199 +torch/ao/quantization/pt2e/utils.py,sha256=zuA78S1FXjcXR73dzvG1M4oLj6i-b03dUonZSb0c1y8,23851 +torch/ao/quantization/qconfig.py,sha256=HT7AHIvZ5PdWUtVA6UOryHnmUq0fZVsaSqvOKTw7_Hk,25038 +torch/ao/quantization/qconfig_mapping.py,sha256=Q7gBccxmu4a0Cnu2QuCzQRlWuhjbZTSSgXoShVzu0vU,15192 +torch/ao/quantization/quant_type.py,sha256=2pNr0PUNaewO7-VwLB_k9s02isvSuDSoDBYuiDMvA9k,795 +torch/ao/quantization/quantization_mappings.py,sha256=BpkZkXu3HcMJMZhJpNgHgt7DRbkO6w_LoMTEQlglLWg,14181 +torch/ao/quantization/quantize.py,sha256=0c8nfFfHGwokgkU06vnX3Iu5n0NrqLGfSsyfQO0nRXw,31830 +torch/ao/quantization/quantize_fx.py,sha256=vZQnlo3WEpcAwLzdMVYdl6nyPlcuJoT3PiFkys0pSk0,33412 +torch/ao/quantization/quantize_jit.py,sha256=kNATSuRBkqTeY-XcFW5srQ_V6ErZRJGd7gY6pQcvWoE,15016 +torch/ao/quantization/quantize_pt2e.py,sha256=yAcoX2X61zBcrApDdo2o2mZqLw7tLSQJyJhuI7wauAM,9735 +torch/ao/quantization/quantizer/__init__.py,sha256=kDj8d9proE6S4UYCNwSxFnaldBqsYVktx0gs1RYl1g0,477 +torch/ao/quantization/quantizer/__pycache__/__init__.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/composable_quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/embedding_quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/utils.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/x86_inductor_quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/xnnpack_quantizer_utils.cpython-39.pyc,, +torch/ao/quantization/quantizer/__pycache__/xpu_inductor_quantizer.cpython-39.pyc,, +torch/ao/quantization/quantizer/composable_quantizer.py,sha256=qlJkHul7qH8boHjvTSue78ylpAPmrMtJtaJXk08TLCc,3095 +torch/ao/quantization/quantizer/embedding_quantizer.py,sha256=F0eKW-gimAqjlGmPrejxKdgx9sW5bzCrw0EEp_iP_kU,3554 
+torch/ao/quantization/quantizer/quantizer.py,sha256=XrzAobRRMmjim2zligtDbz1NvooIIYVI_2o_BzAc6zw,6805 +torch/ao/quantization/quantizer/utils.py,sha256=RWY1Q-pchMO_P6QwWbXkWaKtOBC9CWVPlwZ7Cn8DiD4,3307 +torch/ao/quantization/quantizer/x86_inductor_quantizer.py,sha256=FWoW-Pw-wa_0dNiqZS3CmkK88iLHGfRPleiWRlEG-bk,66449 +torch/ao/quantization/quantizer/xnnpack_quantizer.py,sha256=oj113s1pT94jFsNKHFmWqK-jpo_UI0NciqN3peoNNwY,16760 +torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py,sha256=rQjgiUPaamcO2_IktVN_ybWTYf4xIS1AoRVwN1RZexI,41786 +torch/ao/quantization/quantizer/xpu_inductor_quantizer.py,sha256=w7Dp8Q5vdFOy9BOFLdB9DWZgs2r-L0MT64LVr56jg4g,3995 +torch/ao/quantization/stubs.py,sha256=wksJyH14lbX_vM5Do1f6ZSRcuhrFaDOgAVJ0SlAFZNw,2372 +torch/ao/quantization/utils.py,sha256=fcGI7ObNWI4TtLybY0Y75C9vD98o2S7zSzSgmXQg0uM,29759 +torch/autograd/__init__.py,sha256=SA-nBSCl3uhGuOpAdUgDkjoN2Si4qTVQy-2EJIj139s,26030 +torch/autograd/__pycache__/__init__.cpython-39.pyc,, +torch/autograd/__pycache__/anomaly_mode.cpython-39.pyc,, +torch/autograd/__pycache__/forward_ad.cpython-39.pyc,, +torch/autograd/__pycache__/function.cpython-39.pyc,, +torch/autograd/__pycache__/functional.cpython-39.pyc,, +torch/autograd/__pycache__/grad_mode.cpython-39.pyc,, +torch/autograd/__pycache__/gradcheck.cpython-39.pyc,, +torch/autograd/__pycache__/graph.cpython-39.pyc,, +torch/autograd/__pycache__/profiler.cpython-39.pyc,, +torch/autograd/__pycache__/profiler_legacy.cpython-39.pyc,, +torch/autograd/__pycache__/profiler_util.cpython-39.pyc,, +torch/autograd/__pycache__/variable.cpython-39.pyc,, +torch/autograd/_functions/__init__.py,sha256=jgHrjtlIe3ka_xUh-D_Ayj7EmRq003Wtyg9TQwQ1QPw,37 +torch/autograd/_functions/__pycache__/__init__.cpython-39.pyc,, +torch/autograd/_functions/__pycache__/tensor.cpython-39.pyc,, +torch/autograd/_functions/__pycache__/utils.cpython-39.pyc,, +torch/autograd/_functions/tensor.py,sha256=FFmYvdcrc3IbToNe0i9jeGVAiBHZvEdSRBnNspPegGk,2268 +torch/autograd/_functions/utils.py,sha256=tqi3FNRdNidKzbY-X45Vb-SUjV2qH-KL4l0WvijXXCg,2080 +torch/autograd/anomaly_mode.py,sha256=Fowf0oEBew4jesoYI1bgaKXsJ8-u-qoU-9TeucGMsJA,5072 +torch/autograd/forward_ad.py,sha256=1MnBUf9lEpy0WAOjNhEToWX-3B7vyzt5WE4xd4MhFvw,7868 +torch/autograd/function.py,sha256=OkfjQl9xoLZZcM8yKwTdrWowx44PZdCTwlfr3i2Yf_Q,33991 +torch/autograd/functional.py,sha256=zsANMO6zymN519X25y6Nk7lBYxLk2zV3IuhnWzEJn7A,53794 +torch/autograd/grad_mode.py,sha256=3n_eB78Ar7tot4BRSEA-OW0_tnMOJBXxyAcIBcIiPRA,13722 +torch/autograd/gradcheck.py,sha256=eIhcNUcyz4yHuf43gheAG4swgrxGGOuUtts8jWy2H78,92833 +torch/autograd/graph.py,sha256=7MRntJpPTOeu1qAcrsKFO1tZyRyT2qy6_oICINlv4Sc,31286 +torch/autograd/profiler.py,sha256=bsCVgbGP_oGdA84Ki7q12h-evFkRqUmc0wS0xZ7MaIo,50148 +torch/autograd/profiler_legacy.py,sha256=Vvu96BYpfLcXarxSNK4DRuG2aEaRIlEo2I6Ih6rGQwM,11818 +torch/autograd/profiler_util.py,sha256=Wd5wURH_GMkozOaJxswmWK_Hi5wRrhgRTSfW4qQhBW0,42509 +torch/autograd/variable.py,sha256=DvFQwaN93efyHpN42uFcUVSHWNwFyfiuC3t45H5vJns,406 +torch/backends/__init__.py,sha256=nnlvNuVIYJQgct8-I3pMhy-zK8jgJ1urKH0KC5wyU5Q,1850 +torch/backends/__pycache__/__init__.cpython-39.pyc,, +torch/backends/_coreml/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/backends/_coreml/__pycache__/__init__.cpython-39.pyc,, +torch/backends/_coreml/__pycache__/preprocess.cpython-39.pyc,, +torch/backends/_coreml/preprocess.py,sha256=Mw8AQGCMdasD08DauzimYB0KEtwx49N30QJA_4WV324,4451 +torch/backends/_nnapi/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+torch/backends/_nnapi/__pycache__/__init__.cpython-39.pyc,, +torch/backends/_nnapi/__pycache__/prepare.cpython-39.pyc,, +torch/backends/_nnapi/__pycache__/serializer.cpython-39.pyc,, +torch/backends/_nnapi/prepare.py,sha256=GxtSK75LwZ9-5veE50_kWXDWqGf1eHH0BMM6mh4kfu0,6758 +torch/backends/_nnapi/serializer.py,sha256=07ShSKCCIq3rIU4ZQsQ6q2b3pIu6GRogwpX2OTY-pyI,85123 +torch/backends/cpu/__init__.py,sha256=B9FowRUnLL7Zbxlmzn84Sd8NwK2-o9HbopMdz4L4npo,335 +torch/backends/cpu/__pycache__/__init__.cpython-39.pyc,, +torch/backends/cuda/__init__.py,sha256=TZL9jcM1MDqt3FvBXdXNJ-ZDiN1bDrFvbfSGNcpHyAE,19623 +torch/backends/cuda/__pycache__/__init__.cpython-39.pyc,, +torch/backends/cudnn/__init__.py,sha256=XV2hlv_W0AO9Jh50SqSYpnA6DN1rhW0egDm6rjDwcV4,6811 +torch/backends/cudnn/__pycache__/__init__.cpython-39.pyc,, +torch/backends/cudnn/__pycache__/rnn.cpython-39.pyc,, +torch/backends/cudnn/rnn.py,sha256=gl9XlLM0aEJlQSYFuORoTWq1FQZwHJ-wy8jfwRHj9lE,2125 +torch/backends/cusparselt/__init__.py,sha256=Ktl12_Mgjoroyk7pR5CWHXlMqgZjiweX7yxXEYWOrpM,1303 +torch/backends/cusparselt/__pycache__/__init__.cpython-39.pyc,, +torch/backends/kleidiai/__init__.py,sha256=Ce5x4LSs2_Ec0uX5sq9a2Q9ZFlPXX0DQxyvoyYueoDY,169 +torch/backends/kleidiai/__pycache__/__init__.cpython-39.pyc,, +torch/backends/mha/__init__.py,sha256=E1q1Uqr78a4_L4EQYTILpDad4LxURpQigNBdHGAi4r8,743 +torch/backends/mha/__pycache__/__init__.cpython-39.pyc,, +torch/backends/mkl/__init__.py,sha256=OjXc4GzGn1hvOzElfPcYFKClk1LTM_om_7zy-Ab2geA,1839 +torch/backends/mkl/__pycache__/__init__.cpython-39.pyc,, +torch/backends/mkldnn/__init__.py,sha256=Hst-KOMuLs2CbMxYDYSMtLnpwoTfwdxSnmba3vrCtuM,3703 +torch/backends/mkldnn/__pycache__/__init__.cpython-39.pyc,, +torch/backends/mps/__init__.py,sha256=0MlxJPku0fNMou_8Jb5uER7PeAlZ-EouY4gglJhdMTU,1697 +torch/backends/mps/__pycache__/__init__.cpython-39.pyc,, +torch/backends/nnpack/__init__.py,sha256=Lza7vXQTuE_baJZL_05OozTyfeg4vu1VTWHCoMfso-I,869 +torch/backends/nnpack/__pycache__/__init__.cpython-39.pyc,, +torch/backends/openmp/__init__.py,sha256=Aee3z5fqS3euk0SSK4X39uZvxSZ-huoS5WY_Fvu54EE,164 +torch/backends/openmp/__pycache__/__init__.cpython-39.pyc,, +torch/backends/opt_einsum/__init__.py,sha256=ZK9DkzK067poaRJS2Jhblw8l8RWcNYEVQkx53ziBhbk,4001 +torch/backends/opt_einsum/__pycache__/__init__.cpython-39.pyc,, +torch/backends/quantized/__init__.py,sha256=Epc1-SXTMOvyBTVdV87ptvVYGH9DylVHfirzOPO7eio,1926 +torch/backends/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/backends/xeon/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/backends/xeon/__pycache__/__init__.cpython-39.pyc,, +torch/backends/xeon/__pycache__/run_cpu.cpython-39.pyc,, +torch/backends/xeon/run_cpu.py,sha256=ZJ0uM1WsSXBJ24qf7cBlJcQUAcm9PuTaFsQLpUp8OGU,38479 +torch/backends/xnnpack/__init__.py,sha256=ifIkHb5pGCc3aYwkb6tSQnm30iPBvViD8A8KLvXrc8U,731 +torch/backends/xnnpack/__pycache__/__init__.cpython-39.pyc,, +torch/bin/asmjit.dll,sha256=YBlDMmQEb7VbJ1sPgD1wbPcFrzrAbql_CHsKpJqs3d0,367104 +torch/bin/fbgemm.dll,sha256=3HRFT930Be00P9jpXi0ojTdDl8LWsGWy52IDeNJzPsU,5721600 +torch/bin/protoc.exe,sha256=AUQEEJ8_9DuzUnevtZxXbtMQJ5mpbP1uWZmSLIbgKF8,2812416 +torch/compiler/__init__.py,sha256=8jnhkjc0_YdAuKywAASbRoXN97dw_go1qtepBCqwX5w,24024 +torch/compiler/__pycache__/__init__.cpython-39.pyc,, +torch/compiler/__pycache__/_cache.cpython-39.pyc,, +torch/compiler/__pycache__/config.cpython-39.pyc,, +torch/compiler/_cache.py,sha256=q-Klz5_Mrij3pOBa3GwarYI_YEaybFZgwlMeuzOmhUE,11233 
+torch/compiler/config.py,sha256=v6JV_9X3I2eUEBaOTcX8XLDaQ4nJb_Ts7XbIsrxRXPk,3720 +torch/contrib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/contrib/__pycache__/__init__.cpython-39.pyc,, +torch/contrib/__pycache__/_tensorboard_vis.cpython-39.pyc,, +torch/contrib/_tensorboard_vis.py,sha256=PAsRpELdj9ig48p7Es4oWVLIjKL15JOxrHVzalaCl9E,6027 +torch/cpu/__init__.py,sha256=f22-ZnsrtpGWX7q3ofkoGuQTHwCBkPN-P860HJGNJ3k,5051 +torch/cpu/__pycache__/__init__.cpython-39.pyc,, +torch/cpu/amp/__init__.py,sha256=22SHegVxLfn2YM_gFoxYPfBrppjsWUXMqBTWipe3uIc,74 +torch/cpu/amp/__pycache__/__init__.cpython-39.pyc,, +torch/cpu/amp/__pycache__/autocast_mode.cpython-39.pyc,, +torch/cpu/amp/__pycache__/grad_scaler.cpython-39.pyc,, +torch/cpu/amp/autocast_mode.py,sha256=uYNYUnj1JcbKjf4HZOgBvBAeRGqcf1YgukaTSosL8RY,1572 +torch/cpu/amp/grad_scaler.py,sha256=GOq20a01Ba0YFMVEuGlYlgMGtBLo6koIUl9uYgxW8l4,993 +torch/cuda/__init__.py,sha256=QOAAhmT71vJWIKgQpfqFybuR5qp1uwBGsoFwug2FkPk,66651 +torch/cuda/__pycache__/__init__.cpython-39.pyc,, +torch/cuda/__pycache__/_gpu_trace.cpython-39.pyc,, +torch/cuda/__pycache__/_memory_viz.cpython-39.pyc,, +torch/cuda/__pycache__/_pin_memory_utils.cpython-39.pyc,, +torch/cuda/__pycache__/_sanitizer.cpython-39.pyc,, +torch/cuda/__pycache__/_utils.cpython-39.pyc,, +torch/cuda/__pycache__/comm.cpython-39.pyc,, +torch/cuda/__pycache__/error.cpython-39.pyc,, +torch/cuda/__pycache__/gds.cpython-39.pyc,, +torch/cuda/__pycache__/graphs.cpython-39.pyc,, +torch/cuda/__pycache__/jiterator.cpython-39.pyc,, +torch/cuda/__pycache__/memory.cpython-39.pyc,, +torch/cuda/__pycache__/nccl.cpython-39.pyc,, +torch/cuda/__pycache__/nvtx.cpython-39.pyc,, +torch/cuda/__pycache__/profiler.cpython-39.pyc,, +torch/cuda/__pycache__/random.cpython-39.pyc,, +torch/cuda/__pycache__/sparse.cpython-39.pyc,, +torch/cuda/__pycache__/streams.cpython-39.pyc,, +torch/cuda/__pycache__/tunable.cpython-39.pyc,, +torch/cuda/_gpu_trace.py,sha256=-YGY2Y8kVSfZGVmAIm5Qr8MXjXdSU-ExyYkB22JN4I4,2450 +torch/cuda/_memory_viz.py,sha256=iY-4bUC79tNFflipz2BrDVduTmrstdUh5Vs3RqCmlU4,26587 +torch/cuda/_pin_memory_utils.py,sha256=8RdDWW-YxnCaxgLnXWfx5rjJphfrE-Xl06By7lQA7E4,771 +torch/cuda/_sanitizer.py,sha256=RrY8RBdQBvFdSQRpvCOr9L60xvZt0y4lUo9ODhbh5xU,24861 +torch/cuda/_utils.py,sha256=vEp_WzaozN5gYQ4IqnUkNw5VBgxeFsr9PiytZSg7wtQ,12826 +torch/cuda/amp/__init__.py,sha256=SKHKvfv4Lj8s0Z0gO1S5FWSALds7Td7iWnEq4e-_vkc,279 +torch/cuda/amp/__pycache__/__init__.cpython-39.pyc,, +torch/cuda/amp/__pycache__/autocast_mode.cpython-39.pyc,, +torch/cuda/amp/__pycache__/common.cpython-39.pyc,, +torch/cuda/amp/__pycache__/grad_scaler.cpython-39.pyc,, +torch/cuda/amp/autocast_mode.py,sha256=AiSWnontfiZ6YOMUrhLrQ3z34lYPheUgjNSVxhUQmUU,2909 +torch/cuda/amp/common.py,sha256=3uWhYUr9-qeHDGh4U6n6NQYqeVnI4k1Y_FZnSIDJGgQ,241 +torch/cuda/amp/grad_scaler.py,sha256=Tv9cOcK98Sd2c65ft3gGf6SHBzv9Kb_wqzQMoNaS4MQ,1111 +torch/cuda/comm.py,sha256=mBimRSrUau-nRDQgCcmU4tsplgQBqZErOo1lcPKYGRw,363 +torch/cuda/error.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/cuda/gds.py,sha256=eaaMGet6XPMPZhzOPAOGD2VKgJIvoa0iwtAXqBYtRGQ,5972 +torch/cuda/graphs.py,sha256=4r6s97WnwVua0xV97x2BbKlhTqnZycdyO9vThPV-tpQ,24752 +torch/cuda/jiterator.py,sha256=lRGQfBiYPWrRUFAPN3ccW2M-ahTv9C2JLDpRu4sj_f8,7015 +torch/cuda/memory.py,sha256=qKia_PlKoV5B-5aV-_3vQWOl8FyFASBu4-G_fHmiH7A,47999 +torch/cuda/nccl.py,sha256=85XNZhgXmBqQM0o_HiHYUAd1gCKtaLGr5No6NOZqT6M,4729 +torch/cuda/nvtx.py,sha256=kYoWspwfOxbZRx1sCuyxkpQx0gNEe4_EaMDW3PRgRCs,3662 
+torch/cuda/profiler.py,sha256=ulYEZICkOD7hyCW-WjxnXAPHoGMtpDvfRjeW0v5hRzc,2487 +torch/cuda/random.py,sha256=Ge3FywdyAUXAq0YxTFddaPLJFK7o1TDTaoSmluddfXc,5625 +torch/cuda/sparse.py,sha256=Fz_B2EnQ-a1-gPpkiuaOykIin7tXmT-CeLj7YflcloA,68 +torch/cuda/streams.py,sha256=wdMhl__fqJncE4fBChPnEBELivLMsCSI8NavAiygZek,9795 +torch/cuda/tunable.py,sha256=_TqTw3S_HmA4CyAPvAtveb6Z1FhLcQIiifAKw7mpA5A,32122 +torch/distributed/__init__.py,sha256=8f7c5dpJxGq1P1YDLhmacn3H5egKX9xHUJGgzH8BK3s,5190 +torch/distributed/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/__pycache__/_checkpointable.cpython-39.pyc,, +torch/distributed/__pycache__/_composable_state.cpython-39.pyc,, +torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc,, +torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc,, +torch/distributed/__pycache__/_serialization.cpython-39.pyc,, +torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc,, +torch/distributed/__pycache__/argparse_util.cpython-39.pyc,, +torch/distributed/__pycache__/c10d_logger.cpython-39.pyc,, +torch/distributed/__pycache__/collective_utils.cpython-39.pyc,, +torch/distributed/__pycache__/constants.cpython-39.pyc,, +torch/distributed/__pycache__/device_mesh.cpython-39.pyc,, +torch/distributed/__pycache__/distributed_c10d.cpython-39.pyc,, +torch/distributed/__pycache__/launch.cpython-39.pyc,, +torch/distributed/__pycache__/logging_handlers.cpython-39.pyc,, +torch/distributed/__pycache__/remote_device.cpython-39.pyc,, +torch/distributed/__pycache__/rendezvous.cpython-39.pyc,, +torch/distributed/__pycache__/run.cpython-39.pyc,, +torch/distributed/__pycache__/utils.cpython-39.pyc,, +torch/distributed/_checkpointable.py,sha256=ZFUbvHl4IBMs3mbMR7ZqFn09_cSekEdB4Gfz2j0cPLI,1342 +torch/distributed/_composable/__init__.py,sha256=thBK3Pd-YbW1Q444eLu9PciaPGsNWw039cR1Q19Q4Gg,128 +torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc,, +torch/distributed/_composable/__pycache__/contract.cpython-39.pyc,, +torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc,, +torch/distributed/_composable/checkpoint_activation.py,sha256=QlcauT4vqwHafAERqEF5GYQ6O9lcTieZkI4jER-zTpc,4857 +torch/distributed/_composable/contract.py,sha256=faATdku4oBhYPSJnTnAE3vEu2xy43ZTvtd-TtBDxo2Y,10624 +torch/distributed/_composable/fsdp/__init__.py,sha256=9DBnmW1yy-ApaEL75QCo7hsYTUQMzo5Ew8O4-BikHaE,172 +torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc,, +torch/distributed/_composable/fsdp/fully_shard.py,sha256=bKBtbDN-wr8z90nXgoGpEOJO1J8Fc6it-HfkNbwRCgQ,248 +torch/distributed/_composable/replicate.py,sha256=tVxT8WZV3IPepqNE6oUMearpLYJcULqxNzW_0c0BFc0,9558 +torch/distributed/_composable_state.py,sha256=SG-CUbC4e328QzL33Rz9hN02nORP3We-VsNcqzIFrt8,1438 +torch/distributed/_functional_collectives.py,sha256=Vgl6naEIeziMH4DA4iP0dL8c9EfotE3zYSpLllfWHWk,45645 +torch/distributed/_functional_collectives_impl.py,sha256=6QgAGWwLzlocv05TubTtQOMdy6BI05fGBoL06s5OzqI,3340 +torch/distributed/_serialization.py,sha256=V6yherIQnWcLNGCtfV4hkPoZzVWJnNr8ycgJXeveZZg,4627 +torch/distributed/_shard/__init__.py,sha256=iQH3xAO1oGHbWgMAaP_ITI5ogx8MnQTSjowtvbYP6Tk,88 +torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc,, +torch/distributed/_shard/__pycache__/api.cpython-39.pyc,, 
+torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc,, +torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc,, +torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc,, +torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc,, +torch/distributed/_shard/_utils.py,sha256=NSZzZUHXID6AkkdQQ_-KT6okhW3sb35993CtI3hVxpE,1102 +torch/distributed/_shard/api.py,sha256=MnAEYQSnSFts0mx3aJ_1DCSq9kN8dYanLbgRdVPyzck,12702 +torch/distributed/_shard/checkpoint/__init__.py,sha256=BLwYCF5RBdVhgcQWIt-okUz4b2j4s8dZJUV5noFajKE,603 +torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/common_op_utils.py,sha256=GBorT9WnU4hYLQN2jnHL810UnDLiEzFMlSnL-nsLg7U,2244 +torch/distributed/_shard/metadata.py,sha256=Zv6m9C3NE9VbxcmfWfcl4fM3w_LZXaZfWXxQUf5fYy0,2279 +torch/distributed/_shard/op_registry_utils.py,sha256=CiG2tqa17vB8-byEvREH1xa6n3ok6jeZTw7yMfIJbOU,1072 +torch/distributed/_shard/sharded_optim/__init__.py,sha256=KTAmAspyEI0nVVpaZqDPcrSsQBCG-uJzJWuwmqvBtCE,1922 +torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc,, +torch/distributed/_shard/sharded_optim/api.py,sha256=V6WbwMNv4rFN-0UILp_wxfk-ctqoFTEJHL92U7jKYRg,4379 +torch/distributed/_shard/sharded_tensor/__init__.py,sha256=AI6M_8ik2MGBOmZD_EpNKqJhgPhBiRgcsoitG0FT4ps,19744 +torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__init__.py,sha256=lmxOnYVj19giuuuhimdPOa9_hhfUDLNtb54PRrrn0rc,511 +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc,, +torch/distributed/_shard/sharded_tensor/_ops/_common.py,sha256=WZ3HQwHgf7pEMrxtRH7iwQzU2BwA90YTc2eeBsrSJ2I,4309 +torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py,sha256=5r6jlsQqp-QSnX58cw15qd60IHOqdF1QabbmkCLKTMM,2814 +torch/distributed/_shard/sharded_tensor/_ops/init.py,sha256=tztslfE21-UP2LBiSbOLMGVH2BGCwxKh3A9XjrM2Xe4,5637 +torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py,sha256=PyhosH2velC0RYXbcy8M2AfRLBslVuslDIx8CqvSMdU,509 +torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py,sha256=9-6N_L2SP2YQhBT7Ug3u1te0_ZfhIrUCYgL66NwOCyk,7930 +torch/distributed/_shard/sharded_tensor/api.py,sha256=DE6cmHT9DjqrslSJc17frNTUUzv0DOSHIbCSxwIV4ao,56025 +torch/distributed/_shard/sharded_tensor/logger.py,sha256=LN2I0OxeVSkfYGDkvdKuicmHwxF05ojsg48juMHNlsM,1119 +torch/distributed/_shard/sharded_tensor/logging_handlers.py,sha256=TlQ6zgFzEXDyRX30awgGptSNGaPvhdBii_AMUtV8Vxw,375 
+torch/distributed/_shard/sharded_tensor/metadata.py,sha256=UTEMAm3UuoGip42IrBmIhv6MfINRfXjrbX5LDvA_C_k,3092 +torch/distributed/_shard/sharded_tensor/reshard.py,sha256=SgGVV1SW0M2e_LXW42WdLFLtCx3BhUXplx1ijWQxs5s,10969 +torch/distributed/_shard/sharded_tensor/shard.py,sha256=m9eiukI8VWDG7mKK3HMSeb5Ge_Sa8ttHkp021eFrh4w,2422 +torch/distributed/_shard/sharded_tensor/utils.py,sha256=Z4ctOqDV85xVvuD0XEXEO0e8-SPNlnZmO_SdHluwwkc,11990 +torch/distributed/_shard/sharder.py,sha256=Hit1lmP4br_NABzhst2jkTINzzO3Oj_9aTiIeI2rjHQ,930 +torch/distributed/_shard/sharding_plan/__init__.py,sha256=bwXPLAUqt1LbEwkmcKf0F2xfLjOZOe3-H7XF9WzSba0,48 +torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc,, +torch/distributed/_shard/sharding_plan/api.py,sha256=rbGYzRjoyyTBTkmIRlxG9cNveHQC3vcePgpRYOtFUhA,3736 +torch/distributed/_shard/sharding_spec/__init__.py,sha256=_5g_LTK5iUs19ZKCKThnOF3EfnUICbolGW_nuhep6zs,301 +torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/_internals.py,sha256=tANfqgl9N-9v4wPrdR-lmx64pBzLoEROcs5o0YvKyM4,8635 +torch/distributed/_shard/sharding_spec/api.py,sha256=O8E2or9w-RIkfwc-o3easLLvA5JXivcmheADtavd44I,10084 +torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py,sha256=I8QfI4XaCAR5ZHapeENDdw818wpYxVaWtI9ps2XKMZI,9479 +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc,, +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py,sha256=N0dbjzsCv5kDGBMztA8Ty_v-Xj5HqB8vFEjDLyc8XzI,13377 +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py,sha256=ZuCM_9GDQfMp_0iLvmRgPMxdmHj6mG85Acq6dqN6O9o,11503 +torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py,sha256=QMUHO6IWT5m0JKRlHGY_w32_gAAgOlU1kNXIwfjEl_g,18796 +torch/distributed/_sharded_tensor/__init__.py,sha256=F2v-WztwXDCJmCC9j12mYZKP9aR_qKHRCyW8-Fig-RA,638 +torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_sharding_spec/__init__.py,sha256=hpjniVEKyMpjJj_Y8896cZfS0O0DAQczWxKb0EQd6Zg,668 +torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_state_dict_utils.py,sha256=OwNH0wmgVYaB_RCF0VN5o0kwP1QEVA7XjWf7BK-FIC8,30096 +torch/distributed/_symmetric_memory/__init__.py,sha256=A8LgM4y2pWrCbFqZE-vtdYlUWjFtf7B3CTcXH-4cN74,63598 +torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-39.pyc,, +torch/distributed/_symmetric_memory/_nvshmem_triton.py,sha256=n3PW8B-FUmeudvMDOF7W2yMBfJebyihI1iZ2GZA_dMQ,5720 +torch/distributed/_tensor/__init__.py,sha256=7CEueTzVfPhuDOUkv1_or6pLzyuLxGPFUW_h6Y5tjBQ,1014 
+torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_tensor/__pycache__/api.cpython-39.pyc,, +torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc,, +torch/distributed/_tensor/api.py,sha256=66VAvT_Nt3qGd1RfjGp8GAaJ0GbE2tfNO2deCCygDIQ,309 +torch/distributed/_tensor/placement_types.py,sha256=7neBt_m_QwfieD-08DUeKq9y1z-UAhswe2NSznxdw2c,394 +torch/distributed/_tools/__init__.py,sha256=Dsgy-WmACc3x2RvRhcnu-RMJKwmGHk8tpAt2N40fSkU,339 +torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/common_utils.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/fake_collectives.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/ilp_utils.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/mem_tracker.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/mod_tracker.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/runtime_estimator.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/sac_estimator.cpython-39.pyc,, +torch/distributed/_tools/__pycache__/sac_ilp.cpython-39.pyc,, +torch/distributed/_tools/common_utils.py,sha256=yW9YHDy08wd0BD1TJ9atKeo7ilxNHDbY5kL2ReT9lNg,1221 +torch/distributed/_tools/fake_collectives.py,sha256=Q2EEjJuH0whhYyb-OaYqR_RRCtsonAYND6KDO2PdUIU,12166 +torch/distributed/_tools/fsdp2_mem_tracker.py,sha256=p0erRstO9U6_SSXz3HPxM3PfHG2996-RDjzFatDxNTc,24297 +torch/distributed/_tools/ilp_utils.py,sha256=vy_OoUp-BFIHLJECHG5n_4GupfywNWCFfsLwZYEXem0,10392 +torch/distributed/_tools/mem_tracker.py,sha256=sQgg_4YhC6c3fpOKHldd-jz69CAeqRR-7mLiCQeK9Mg,43785 +torch/distributed/_tools/memory_tracker.py,sha256=SPxJ9eHNjUEcXITWkGaV8FeFRLo6Jde-kjDJ-RejK2E,11958 +torch/distributed/_tools/mod_tracker.py,sha256=32kkce_jtMu1lzmQghSAZLWt5uymsX-CrqRyvt5Afsk,10279 +torch/distributed/_tools/runtime_estimator.py,sha256=CtniO5x2f5u8-nPgk0Pj66XASOvZZum511MmOFADVqM,21674 +torch/distributed/_tools/sac_estimator.py,sha256=RwBAzo3MRCMLbpWpX_JmWW_z-kPNPsbLSb3ajcQA-t8,43246 +torch/distributed/_tools/sac_ilp.py,sha256=ssw7FGfd46u1KjIy5kpKkrbDrZ9gQojTUhJQ3vTxeLc,11616 +torch/distributed/algorithms/__init__.py,sha256=Pc8JMBAT_SDP0rhlhN4ziK8VIAsxKju29iZsKZwUH9A,44 +torch/distributed/algorithms/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/__pycache__/join.cpython-39.pyc,, +torch/distributed/algorithms/_checkpoint/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/algorithms/_checkpoint/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/_checkpoint/__pycache__/checkpoint_wrapper.cpython-39.pyc,, +torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py,sha256=AxFJGlbs_inoPOeaSrvzdcUCv7fr6-sojwLFITolEVo,12623 +torch/distributed/algorithms/_comm_hooks/__init__.py,sha256=EzLHTA765VfQrU0Bb7qJNZlXoJ7tMh44_VWY4QpS5pQ,138 +torch/distributed/algorithms/_comm_hooks/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/_comm_hooks/__pycache__/default_hooks.cpython-39.pyc,, +torch/distributed/algorithms/_comm_hooks/default_hooks.py,sha256=wEX_wE24RvTc5Ad8hiECxl4BdoOlcbCA0m9stue-m1U,7845 +torch/distributed/algorithms/_optimizer_overlap/__init__.py,sha256=gfhXcWlZiV9qc3HF3rE0iAqb3i7O7IFvQzpKbnMnddQ,53 +torch/distributed/algorithms/_optimizer_overlap/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/_optimizer_overlap/__pycache__/optimizer_overlap.cpython-39.pyc,, 
+torch/distributed/algorithms/_optimizer_overlap/optimizer_overlap.py,sha256=CJddd_95wQv-AKNkBHW01tJGpxrQmLSjQFIRBpq8RB0,3849 +torch/distributed/algorithms/_quantization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/algorithms/_quantization/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/_quantization/__pycache__/quantization.cpython-39.pyc,, +torch/distributed/algorithms/_quantization/quantization.py,sha256=T03DHrPPcM8BaHTOHHBmTAae29SYV0gwmTfrvMnRGCA,5760 +torch/distributed/algorithms/ddp_comm_hooks/__init__.py,sha256=K_YD3IGlXqXEWNo7FhpuBj6Njr7hqUCkrcDvQQjLXSo,3707 +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/ddp_zero_hook.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/debugging_hooks.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/default_hooks.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/mixed_precision_hooks.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/optimizer_overlap_hooks.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/post_localSGD_hook.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/powerSGD_hook.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/__pycache__/quantization_hooks.cpython-39.pyc,, +torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py,sha256=goat8WWnYkCC9gpGpMfUMAFkrFJ00QGZximKlSQ9x9c,20178 +torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py,sha256=QQCaAKHjMyAvy9aAL5H66XR5sWbYAB5RaiYa_q_Aw20,1144 +torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py,sha256=jAeVOncsDwvX35q8cgSniO2LfWizZIoYW6x7ylDVSjU,7981 +torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py,sha256=pEPkGe14jATp7q_tGXv-WQ5XkozezzJfP48Kn7hO_SU,3340 +torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py,sha256=ZYUNbRoBnkiWDhh1znMTeHhR-36MopFU78J9JInl8jA,6287 +torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py,sha256=tNrXqBtz0ShZDpDsgVLsVEhYdTAgpcc4pEfucfebJVU,5274 +torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py,sha256=Zti8OvXTtodKdIL4E-70EsnN0ZI4XR9usengUEUgkkE,41274 +torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py,sha256=b-CWczlooxNUZFlhDUeF5B9491XKKz6JTunPLjTgr3A,8452 +torch/distributed/algorithms/join.py,sha256=n5-8GUatD3OyC0dT_dBghb8PR8nEOL_Nm4xO-Zyz0Ik,13734 +torch/distributed/algorithms/model_averaging/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/algorithms/model_averaging/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/algorithms/model_averaging/__pycache__/averagers.cpython-39.pyc,, +torch/distributed/algorithms/model_averaging/__pycache__/hierarchical_model_averager.cpython-39.pyc,, +torch/distributed/algorithms/model_averaging/__pycache__/utils.cpython-39.pyc,, +torch/distributed/algorithms/model_averaging/averagers.py,sha256=o8ey22WJkUBV6hiBCTrjJDihfJ2yNzaBCK-gwDeYdZo,5585 +torch/distributed/algorithms/model_averaging/hierarchical_model_averager.py,sha256=ir9Ct1Am04DJH7b9HSm2in8m1R9q9Lpp1TF4Ucr0-MI,9975 +torch/distributed/algorithms/model_averaging/utils.py,sha256=46Vi9H5FqJT8HVF0FT71Yto0mPW_xKi03zgKacakzRQ,3253 +torch/distributed/argparse_util.py,sha256=hsAuTVu0mAjfWycxDQwSMzF64PvTIE2s7cOa2kpOCHY,4007 +torch/distributed/autograd/__init__.py,sha256=GwTijeDd_HN8FuYgHpMIqooT1z3AkuhQMSnqqPm4Ugc,1700 
+torch/distributed/autograd/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/c10d_logger.py,sha256=6JRSA7oTPuJ_JAM1azCE9K2VV7DWGVN-Mwiz7aWEiVo,3223 +torch/distributed/checkpoint/__init__.py,sha256=pfRBkipQ-e2VsvJgzgZL_RIMOHKqFSOIirOoCiMz_CE,710 +torch/distributed/checkpoint/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_async_executor.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_async_process_executor.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_async_thread_executor.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_checkpointer.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_dedup_save_plans.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_dedup_tensors.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_extension.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_fsspec_filesystem.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_hf_utils.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_nested_dict.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_sharded_tensor_utils.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_state_dict_stager.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_storage_utils.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_traverse.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/_version.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/api.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/default_planner.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/filesystem.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/format_utils.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/hf_storage.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/logger.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/logging_handlers.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/metadata.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/optimizer.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/planner.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/planner_helpers.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/resharding.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/staging.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/state_dict.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/state_dict_loader.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/state_dict_saver.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/stateful.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/storage.cpython-39.pyc,, +torch/distributed/checkpoint/__pycache__/utils.cpython-39.pyc,, +torch/distributed/checkpoint/_async_executor.py,sha256=5u5AHn9muq_PUSe-FAncWCoSny8gwa_1F6pOF679PKY,1129 +torch/distributed/checkpoint/_async_process_executor.py,sha256=tHYUqrIn0CbRN6NbpBd8OpgSsZ-VM0I8rozhvYhNgXE,12715 +torch/distributed/checkpoint/_async_thread_executor.py,sha256=EDsYR04C0-vCsX_h1SVX_OSwznxTHulflcBQRiH3NXc,1404 +torch/distributed/checkpoint/_checkpointer.py,sha256=-HxfEA8MXhc4QzVqW2LqvgyjBOcZiWTesAEAQmH1uZk,3756 +torch/distributed/checkpoint/_dedup_save_plans.py,sha256=TZD096lELZdkogH3HGMzw5JGHeKcUZ9QT8aDTRYv4Fg,2745 +torch/distributed/checkpoint/_dedup_tensors.py,sha256=vTfohDLqz1wIJkP4p9xgcsp9DFLHZeb7YJzV5Mv-JVc,2053 +torch/distributed/checkpoint/_extension.py,sha256=u4PVNjPFd4fJ6LDJcvxbVavje4lzpQ8j5cwwuOdl1iA,7924 
+torch/distributed/checkpoint/_fsspec_filesystem.py,sha256=Nib_RA4hhGbNbX30lYLk6dK4ZgkfMNJhNPLSZKr62Qs,5750 +torch/distributed/checkpoint/_hf_utils.py,sha256=roYgE3rMBMHyPmWuYIWwOxV6F9gVcYKC5IrR5KdxYhM,2945 +torch/distributed/checkpoint/_nested_dict.py,sha256=YlZkWH5Q6XpUJwpLjSfie55l-M9SmXZJEbL7_UI_LE8,2342 +torch/distributed/checkpoint/_sharded_tensor_utils.py,sha256=vFi5JrynZ9tqGDO1osMg3ZFRcdY7PPOak-VSz4xmE2Q,4251 +torch/distributed/checkpoint/_state_dict_stager.py,sha256=bbmHPG1jHpBoByAhd8cts2tvXWVDVVCHwRZCKBTB_QU,13874 +torch/distributed/checkpoint/_storage_utils.py,sha256=lImEBWA00Wy_pBauprlFHzPi1vhrm1PEkXBZCRHddr8,1457 +torch/distributed/checkpoint/_traverse.py,sha256=5d4sIUn86T4lA8VnAtCKBAWTAfMOhObGcRUpWc9hrrA,7074 +torch/distributed/checkpoint/_version.py,sha256=2jSNOCX7Es6XawPrXxduplCACzbtzIgW_lSkURkP8a4,128 +torch/distributed/checkpoint/api.py,sha256=XulZrinATE9EIc4Ye_1oB5m73odOJlJseGzAO9doyS0,1459 +torch/distributed/checkpoint/default_planner.py,sha256=6bf-XoAFMl5G1AvWeCukfn7Pvo2HAcNdlxPUj7qyoA8,26881 +torch/distributed/checkpoint/filesystem.py,sha256=JcgNtA-RgxFDPITUKEErei8TOVHGdt7d13u8W6Jl4G4,35168 +torch/distributed/checkpoint/format_utils.py,sha256=GpVX23d2jbLGEEdX5E4UXP6IIeyRD_fGHW4w_8he7b4,10514 +torch/distributed/checkpoint/hf_storage.py,sha256=ZZts4jAbZIL3SZlytq6t-a5BK8Z5CeR7hnXOgaqYiGU,13275 +torch/distributed/checkpoint/logger.py,sha256=PTWeZk0B-Ayd3-XUZ6OEfHv5EE-uhwz3KP11V_ldQ0w,3667 +torch/distributed/checkpoint/logging_handlers.py,sha256=kjBL2bzgKjxZGodJrhZc_aDyabvhG-rzPLpxGIJwmnc,234 +torch/distributed/checkpoint/metadata.py,sha256=ain2ehPO0A4bZwIu9X6HqQjFAvVLSWLEx8gV2NaM0ZU,5781 +torch/distributed/checkpoint/optimizer.py,sha256=_Mv5o4iovhjFxYPqlmvnkZsZR6X-6h_mP_LmDsjNDM8,13510 +torch/distributed/checkpoint/planner.py,sha256=n14VKF0NJXT7sDs6sDIr7vlNBN2HYdyRWA5xOtEk_yI,16713 +torch/distributed/checkpoint/planner_helpers.py,sha256=5r3d7xX7n6MTB0J25VqWwMgRL9D0lJWLD16Kq0bURg4,17007 +torch/distributed/checkpoint/resharding.py,sha256=XLBXWRfcAD7pnM01B0Bmnez3rSMMNpOt59ModVMiS_U,2344 +torch/distributed/checkpoint/staging.py,sha256=rQfeMUZICoK9xqV8Axyf4GLS-CbnJk7diISviwDv124,5073 +torch/distributed/checkpoint/state_dict.py,sha256=bD3clV08Ydyz7h-jKXISyXCyzfZhZM42QB0BOKHiuv0,57846 +torch/distributed/checkpoint/state_dict_loader.py,sha256=TPz_aYbmlljSTsFOE0QSuTsw2cgeNSXHA1MEwcwMCy8,12756 +torch/distributed/checkpoint/state_dict_saver.py,sha256=DNFdawamX3JsCp1AcrTzeecRGUqRXwSyboaYUofSHxI,14517 +torch/distributed/checkpoint/stateful.py,sha256=3aEHPgcexX9gz3AIJaZejsl4OPqtSadvTxGCi3BPGpU,1103 +torch/distributed/checkpoint/storage.py,sha256=BqlHyr4Rs-sGJSI6gd0xovZ1yqEd8NTpGwuYGzGwoBM,10009 +torch/distributed/checkpoint/utils.py,sha256=8ZEN0igZBsmMUs8SZg3LbncopgwJRpR4ZU1wztOwum8,16550 +torch/distributed/collective_utils.py,sha256=PQ6mSRrt_W3iKiJsO6yQV7HaUHtnPpg3uSdHLuVrx4w,7490 +torch/distributed/constants.py,sha256=fk5-qXs8PGJIUHHHvihLBSz65uOMr-1KM9tnFs1VhQE,1255 +torch/distributed/device_mesh.py,sha256=baicIhfhW8OXGyzUWbe-SBC27Jx1qQhxVQFIfkcUW8Y,52221 +torch/distributed/distributed_c10d.py,sha256=FwzQN-e_9NayS08lUhZf7Iezy39A_cQgSyhMUq2EIT0,228009 +torch/distributed/elastic/__init__.py,sha256=I6xvpW3XNSMV50kNmYbKxwaz1Ihf-QwjazmhwXW1ANA,3731 +torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/__pycache__/control_plane.cpython-39.pyc,, +torch/distributed/elastic/agent/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/elastic/agent/__pycache__/__init__.cpython-39.pyc,, 
+torch/distributed/elastic/agent/server/__init__.py,sha256=UuEcXkRQTe9-_iwHlKktGEyx8btsOrld-ZHngbhYLB4,1442 +torch/distributed/elastic/agent/server/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/agent/server/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/agent/server/__pycache__/health_check_server.cpython-39.pyc,, +torch/distributed/elastic/agent/server/__pycache__/local_elastic_agent.cpython-39.pyc,, +torch/distributed/elastic/agent/server/api.py,sha256=FAKrCBm56o2pFR6xwDDoGK6m4_xFLN3BxqMe_lVEp7c,38462 +torch/distributed/elastic/agent/server/health_check_server.py,sha256=OuAC2PiBuSyBxRfreLqU30u-aDAVuSWotTR1lihhTFc,1744 +torch/distributed/elastic/agent/server/local_elastic_agent.py,sha256=BBocufS9m5psuf3eKD9UY6fOm2H3zweAPuGjDwfCae0,17002 +torch/distributed/elastic/control_plane.py,sha256=6WkVx31eMy0h1yCuQaMoaZCHfVy94Z6v0Blt7icaryk,1234 +torch/distributed/elastic/events/__init__.py,sha256=p5RfbE_2aAOgmdJIGWm-XBivStlOYe5Md9_oA_IJ-bc,5549 +torch/distributed/elastic/events/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/events/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/events/__pycache__/handlers.cpython-39.pyc,, +torch/distributed/elastic/events/api.py,sha256=JkNihCruVXtr9GY_neimLGMA_eDCFQt5IwC9cJ41Tp8,3361 +torch/distributed/elastic/events/handlers.py,sha256=1JH4wO7999r9Apa2_F0zXsy1guVH4Pf0Hl-Hj19xElk,577 +torch/distributed/elastic/metrics/__init__.py,sha256=oaHhtjSiK2X91vb1P79mSFwIG7ORS4f4842BfVQEmG4,5070 +torch/distributed/elastic/metrics/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/metrics/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/metrics/api.py,sha256=IDXHfpZCdQgrWSo5stnldXBjZULRrO82mBaqAnMW4xQ,5893 +torch/distributed/elastic/multiprocessing/__init__.py,sha256=YhZBNt-1iE7w9OKQuvfSraBgPV5M-QCa4j64FsoeWh0,7612 +torch/distributed/elastic/multiprocessing/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/__pycache__/redirects.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/__pycache__/tail_log.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/api.py,sha256=jDc-udxCXtFZVJ32k3G12y5JsIJEv_q4Bgv1SHL4aDA,34668 +torch/distributed/elastic/multiprocessing/errors/__init__.py,sha256=nbAP4o97Xu6byob2uGfpZByV7XScBJK6LvJqKucfDpY,14876 +torch/distributed/elastic/multiprocessing/errors/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/errors/__pycache__/error_handler.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/errors/__pycache__/handlers.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/errors/error_handler.py,sha256=FV1TdyS59acFA0xwBUu17A_IFqfDLMrPrUvo1OCG7yk,6766 +torch/distributed/elastic/multiprocessing/errors/handlers.py,sha256=n0SzDw51fxKuFsNbNCkonhmzao-fZ_KoUq3ipRZ-nGs,482 +torch/distributed/elastic/multiprocessing/redirects.py,sha256=M6kca0OhxlblpEvq8FSsZfZ6KuqxbM_VrlV65GRrLhA,2868 +torch/distributed/elastic/multiprocessing/subprocess_handler/__init__.py,sha256=ch8Cj7Mamh9UeiXgVwvn-DfiZA7pSZpz42BqNOjBde4,539 +torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/handlers.cpython-39.pyc,, +torch/distributed/elastic/multiprocessing/subprocess_handler/__pycache__/subprocess_handler.cpython-39.pyc,, 
+torch/distributed/elastic/multiprocessing/subprocess_handler/handlers.py,sha256=9KkYH48y0XNVOTmYtahIeSWVVah8SA5CgC5ihPgWxS4,756 +torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py,sha256=jmhJuk5i1qfpdUc8D5cMXKq8O725C5td5Z9LjAxZ8aM,2514 +torch/distributed/elastic/multiprocessing/tail_log.py,sha256=xuKWjBlzrCEgPx2PFnN3jo4h9yHfkQBYnVGfzbyj4mE,5105 +torch/distributed/elastic/rendezvous/__init__.py,sha256=cBTBUfhgLe8PhIpBTm61KRWd-fOHwjfELL2qSLinLng,6432 +torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc,, +torch/distributed/elastic/rendezvous/_etcd_stub.py,sha256=Ot-ThDI4HtowWGX4L6hRz2aZPjdd-iyVFHC2JTGt_cE,2089 +torch/distributed/elastic/rendezvous/api.py,sha256=A8FvriTOoxvC9VEd6_26YbfKAOFuBvucaW3t4SAEhyY,13484 +torch/distributed/elastic/rendezvous/c10d_rendezvous_backend.py,sha256=wjNF-DgOMaSApUBvFAZcbP8i5F3Nd8VKLjsWwwe5S4w,11173 +torch/distributed/elastic/rendezvous/dynamic_rendezvous.py,sha256=IPEIb-KYZlfSzo-fmc0O6wEwoZqf40vzT40F7YblFYM,50880 +torch/distributed/elastic/rendezvous/etcd_rendezvous.py,sha256=5qlbzNGbuLD_onQp7pgXJDg71OtWomiqTmuqmTycNtY,44611 +torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py,sha256=H-EzDxSNBTs5Dgl0jlCLD1qUTVvb5BQ6Pk1m5NUJnes,7624 +torch/distributed/elastic/rendezvous/etcd_server.py,sha256=xApL8vRQkGk0w2xewCIcHFn95ZNReQVHMMDm8xAOQBI,8685 +torch/distributed/elastic/rendezvous/etcd_store.py,sha256=SZqdO9TeYIiY5VmhkSdjFmWWkxl9ETxgQTkJ_pdr37I,7471 +torch/distributed/elastic/rendezvous/registry.py,sha256=pKutBL8fRfQhzfwiPRVA93wVLZruZ7nxM_JuoEikbmw,3123 +torch/distributed/elastic/rendezvous/static_tcp_rendezvous.py,sha256=ThxJvPoZaCdTHlg_eQGiVteLNPKdEl6N-hxTHQxIUxY,3793 +torch/distributed/elastic/rendezvous/utils.py,sha256=yOuQySTT4hNGA8-2OhXOPezB_PJdCpdgT3b0AXlcdV0,8674 +torch/distributed/elastic/timer/__init__.py,sha256=vh5LGpZfIfvcRSca6cMhWT6m30YPW7YxIR8s7fWpF2w,1804 +torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-39.pyc,, +torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc,, +torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc,, +torch/distributed/elastic/timer/api.py,sha256=gvX5TnB--QiQSUJDnDcQ9A4QvvEo-DrWBPGkm30hDsg,9869 +torch/distributed/elastic/timer/debug_info_logging.py,sha256=Cl88jOhn-HOHoIDmCJ30UfhpNHQnTKCO_QR5ExEuMiA,634 +torch/distributed/elastic/timer/file_based_local_timer.py,sha256=HFJo7nYr-SNIvdJJMRuT6y6jX1zR8XyJTa9UnSCX310,16891 
+torch/distributed/elastic/timer/local_timer.py,sha256=QIiwqFL5NSMvX-YWHi8zSC14njR3lMmmpHcM6i4Yq5Y,4413 +torch/distributed/elastic/utils/__init__.py,sha256=tUTWXyJ9Fd67kIGlH9Lv_G0qNPSv2GjzOIts1uZop50,327 +torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc,, +torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc,, +torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc,, +torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc,, +torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc,, +torch/distributed/elastic/utils/api.py,sha256=GWTQS0CYilyDEK0s8r85MSRZCUDvrzjf0E2KPel4IF8,1767 +torch/distributed/elastic/utils/data/__init__.py,sha256=SAnEMDWeN7az6Keywbq5hzReYUH2CtnatIkMOuiB0Wc,382 +torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc,, +torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc,, +torch/distributed/elastic/utils/data/cycling_iterator.py,sha256=nCr5OFt3TYLiKegT8owaUbX5yFLTIT7of8OuFcdzOrw,1700 +torch/distributed/elastic/utils/data/elastic_distributed_sampler.py,sha256=Q0MIcj4CUTk2ab1Tc4ME9HOIzVm6kiQJrr1GF4cKUVk,3126 +torch/distributed/elastic/utils/distributed.py,sha256=ihmDVbYs9vpIw47y_4C2YezQ2Jl-cqJHYvB2NqWtEg8,6107 +torch/distributed/elastic/utils/log_level.py,sha256=XU53rCS6i_d1O-_Iz4TRBE-Iup_Qe1hvgBHGzzFdXt4,353 +torch/distributed/elastic/utils/logging.py,sha256=LCJGehDHNJX85gEhlvBb6UjY44t1L-bKbAvk3N0ir-Y,2340 +torch/distributed/elastic/utils/store.py,sha256=6qSiA-i1Y5kNx_JMUUg0VVpQt1s-pONQQTX2-3Vkw1I,7504 +torch/distributed/fsdp/__init__.py,sha256=cz340oLncBv822k_hmxn6eXp3yBX8yM9j-KhqL8oyA0,1806 +torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/api.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc,, +torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc,, +torch/distributed/fsdp/_common_utils.py,sha256=IhTz6ZC4547jnt4jO2c33uYdMOacke97-rbNGFlLn4g,22940 +torch/distributed/fsdp/_debug_utils.py,sha256=y79vp1STzfk0DvkdLDDGEGeg_ZxY3w5aJsEzKrpkRTI,5851 +torch/distributed/fsdp/_dynamo_utils.py,sha256=5wfQqx7RfMgSY7-RsQdLSRGv2pVX4cWPqxZZVEeFbzA,2674 
+torch/distributed/fsdp/_exec_order_utils.py,sha256=HIx9LyzALwx2Cip-I3soaPtUvbKRoy12cvvSvW5i4g8,16437 +torch/distributed/fsdp/_flat_param.py,sha256=h9rKLaC7xo-nvYMEU1hmHgDMcO-TmFNovWp9QkQVdPI,126418 +torch/distributed/fsdp/_fsdp_extensions.py,sha256=86LMCJ08hU9shMJzm0U8OEvIYTDYeQV5T3T7lbuP-bo,5135 +torch/distributed/fsdp/_fully_shard/__init__.py,sha256=_yoX0FoVuiHi1epOHbpyDddq9vA-zAuNsaiuXOw79vc,394 +torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-39.pyc,, +torch/distributed/fsdp/_fully_shard/_fsdp_api.py,sha256=lGi6WzFTA1OkMXzfH044CF9_yRMzADbRb8p-yzHFua4,3422 +torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py,sha256=RJ_lJhR6BGkkGZ-KLcpZjNhs9cNmyGDvDVsNfVYh-eA,27461 +torch/distributed/fsdp/_fully_shard/_fsdp_common.py,sha256=jUen64KMENJ5LQkC9tsPSmje9EZTgTH1_866hL3VN68,5861 +torch/distributed/fsdp/_fully_shard/_fsdp_init.py,sha256=DO6j0EuffCQhjOJj49ROnlhfUim0LNXGDYBVUQowEVU,9384 +torch/distributed/fsdp/_fully_shard/_fsdp_param.py,sha256=yJqmjfRqrLf0ekJk_UpnrjhmN_dsihKm-vJdAItHy4s,42418 +torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py,sha256=7iRgYfL2wZ_5nLX41whSGvIF-URTrOkywfrRjAOwPrA,36221 +torch/distributed/fsdp/_fully_shard/_fsdp_state.py,sha256=vCJ5dLQUbgfYsgd-gPULauHWTxd9M3qu7HVtbsB-Qhs,18197 +torch/distributed/fsdp/_fully_shard/_fully_shard.py,sha256=WU54wW_D5cDTHNhY0jc7a5Vgu47mxuvWikBO413KBK4,30790 +torch/distributed/fsdp/_init_utils.py,sha256=f9zvIAexWMfp5VC9WaoXYEa-lGz4SdugZNV6ENyAIvU,47409 +torch/distributed/fsdp/_limiter_utils.py,sha256=2vinVgT6BalDbCTKABgCstpDCGxl94nm47E1G_qxO0s,1119 +torch/distributed/fsdp/_optim_utils.py,sha256=LChQyIP9uyQ-2QGZ5HQMs4FO2Jy7W7xnJIAdX3LnUR4,88692 +torch/distributed/fsdp/_runtime_utils.py,sha256=b1a-8WPvOb2DWj3P3EcNv8rClyzSUspkCsib3ZunWdw,68170 +torch/distributed/fsdp/_shard_utils.py,sha256=biaJcUFdVcGFtPEGcvLXOnYrb1JK-4QvMcb0xkRD4Mg,4760 +torch/distributed/fsdp/_state_dict_utils.py,sha256=btJMxETGPWNcYQplMyRInLASiO9Vc07iogGlCY4vWtQ,35212 +torch/distributed/fsdp/_trace_utils.py,sha256=nNuy4CSkOqqrxxhCEJ5edxiLKpaa1k2ESTeB8xdTrE8,10990 +torch/distributed/fsdp/_traversal_utils.py,sha256=I7av2W1foK7nHcKYE4ehsiH1EhyP1keNMvhvLuxtSA4,4722 +torch/distributed/fsdp/_unshard_param_utils.py,sha256=PBl6NBeKQI_FLTlTmuTQwHHg7sP1NFNXz3uR2HgiPLA,11861 +torch/distributed/fsdp/_wrap_utils.py,sha256=T6Ni5JTzNEOJ_Hp-iDUKzKd2KgKw3HaLWKg2bFJsILU,11157 +torch/distributed/fsdp/api.py,sha256=Ze9_vkpMquSeR6JAwjcIQ8QjwJ-bgy01KMEu9dW8wzU,19392 +torch/distributed/fsdp/fully_sharded_data_parallel.py,sha256=21hkg_TujKsbvInzNUd-yWkzBsAVBDFf9VacROhokTw,102447 +torch/distributed/fsdp/sharded_grad_scaler.py,sha256=JxucN3z6rN3iWSGlgzmyxxH8kEG6Bw34-5qIkIFkpPo,16523 +torch/distributed/fsdp/wrap.py,sha256=mS-fCpyUjOHR6I_YoDppgMVGtQBuVlLYvIdricGxf50,23211 +torch/distributed/launch.py,sha256=wKfJduKs2hFmzU3Ry7b0r5mesOkEdi1lYnUu0PmmQmY,7802 +torch/distributed/launcher/__init__.py,sha256=C63mqEEdfXL6c0fKxb8XMOAA5DQm28pvC-Y6EtiFKnY,363 
+torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/launcher/__pycache__/api.cpython-39.pyc,, +torch/distributed/launcher/api.py,sha256=7znQoeemNDneEilRJ6JYRc6xzRuoNx282Ebd6fEgth8,11749 +torch/distributed/logging_handlers.py,sha256=TlQ6zgFzEXDyRX30awgGptSNGaPvhdBii_AMUtV8Vxw,375 +torch/distributed/nn/__init__.py,sha256=rT2varq7EORrQ1fq7hl6VAaULUlo9wcFS3AvGw-tt2g,152 +torch/distributed/nn/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/nn/__pycache__/functional.cpython-39.pyc,, +torch/distributed/nn/api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc,, +torch/distributed/nn/api/remote_module.py,sha256=5QQRC9uSlKHWW-XLZofPzuVGEIpdOpRqzuyMqXOnD9I,32040 +torch/distributed/nn/functional.py,sha256=JhLvnfJmWQE7IF__vMxYpf5_Kno4o6xvH041JC-5Vns,15667 +torch/distributed/nn/jit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc,, +torch/distributed/nn/jit/instantiator.py,sha256=HxsmlbQFN6gzwpx0BkjXDzkW2OT16S309iHgKilVwRI,5666 +torch/distributed/nn/jit/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc,, +torch/distributed/nn/jit/templates/remote_module_template.py,sha256=Kdt3GxUr-0oy4S4x9D43MPn6MmBSMvvR_hJqB2IkPZg,3571 +torch/distributed/optim/__init__.py,sha256=XaERShuW8bi8HE3Hc3IcWRNpWopXEHH0S3fttiLUON0,1484 +torch/distributed/optim/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/optim/__pycache__/_deprecation_warning.cpython-39.pyc,, +torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc,, +torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc,, +torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc,, +torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc,, +torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc,, +torch/distributed/optim/__pycache__/utils.cpython-39.pyc,, +torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc,, +torch/distributed/optim/_deprecation_warning.py,sha256=fBIq67Fwa8ctxVsAPCiPF_S52-6oK5zB5fldvXykTOI,563 +torch/distributed/optim/apply_optimizer_in_backward.py,sha256=8uEDlD1xLBt8OVNQGmXD25mIULvoEzq8Hg8azfOLngg,5335 +torch/distributed/optim/functional_adadelta.py,sha256=B1NkCGXTCoFooIKrPe-yy3-UD6vfS6E0W_OTZmwF5Jg,4065 +torch/distributed/optim/functional_adagrad.py,sha256=cgesG57oakJdXzVKUlenn8pJ3dH19q2bcVXscIwCByY,4419 +torch/distributed/optim/functional_adam.py,sha256=DsKJ2rZgYod7AZmJTh1ttncwuts668VtBx0aOKZRxwY,7627 +torch/distributed/optim/functional_adamax.py,sha256=9XwgpqPuTal0oZjcjnoaBwBLKxR0Vjb8ApqaAp__4iY,4774 
+torch/distributed/optim/functional_adamw.py,sha256=rlgGZJJ0CCM47mAPVNs0v3BPfxn-ImuVU7ltVS-Ne3c,7748 +torch/distributed/optim/functional_rmsprop.py,sha256=u_wEkVBwVjG4KM97tq26g1Vy6KIDkwGNFnZVro8_CzE,4814 +torch/distributed/optim/functional_rprop.py,sha256=MhKRV7e3hcrNBQzfGjeKs88W35ImHEUpns-hlAYz6AY,3945 +torch/distributed/optim/functional_sgd.py,sha256=gaxlo2Adbs82GKjePrQKkn_qoqF4EY9RYXMMkRioUUs,6104 +torch/distributed/optim/named_optimizer.py,sha256=2wR4CvzCxmUjxZBAPSh4MuATO1l4HEV2B73qzELT3LI,14312 +torch/distributed/optim/optimizer.py,sha256=31qBZKlS3HEC-QALnjIoLD1Cocs30_ftXnKQ1zxmj0U,10026 +torch/distributed/optim/post_localSGD_optimizer.py,sha256=HGsQ7Zmxh6cAgt8ywtLMam85Xs1djqPYzOA0WdDPFik,4577 +torch/distributed/optim/utils.py,sha256=rXAeyMs8VONBi6d3TSWlZzaPmUgkLLuQaqNns8QIiEs,2303 +torch/distributed/optim/zero_redundancy_optimizer.py,sha256=P0QVEqCSoYRzNDux4mml39n6Fy3x1AkoADMSZa8qY8M,73657 +torch/distributed/optim/zero_redundancy_optimizer.pyi,sha256=2xY-roUT0qj67ee0aBaOFlAju3wcN9NPJpyWuByL3o8,2891 +torch/distributed/pipelining/_IR.py,sha256=Fv7hceeDIlwjhj62g9j58xQdKeR3FciyuH0jkQO-VoQ,50490 +torch/distributed/pipelining/__init__.py,sha256=VKWp2ptehCalmwioR733ODYJpIpekzy_qC5XaiKFlrY,669 +torch/distributed/pipelining/__pycache__/_IR.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/_backward.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/_debug.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/_unflatten.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/_utils.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/microbatch.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/schedules.cpython-39.pyc,, +torch/distributed/pipelining/__pycache__/stage.cpython-39.pyc,, +torch/distributed/pipelining/_backward.py,sha256=Fz5dhL2OMLiY5EEguD5eB3w2kct4oynt_Kk1ky0wefI,16270 +torch/distributed/pipelining/_debug.py,sha256=hreZ3Kq-c7t2OIPBxieVTDNm8ASZpTnCaarnSzW4jC8,630 +torch/distributed/pipelining/_schedule_visualizer.py,sha256=hKjIEr-syvTzjsQ0e8bQOqM4WcnxrGscp4WgDfrspFE,7111 +torch/distributed/pipelining/_unflatten.py,sha256=jJP3X5glRBRzxAB12ijFwJk_sbnVs73XazHVmTgQgPI,982 +torch/distributed/pipelining/_utils.py,sha256=qWBlD9yNXrD9acDogNz-PJGlK6XcBUwOgoTndaBoaJs,3920 +torch/distributed/pipelining/microbatch.py,sha256=Q8SKIaNC5qPR4SDS7T4BTMFYP8yzKPCtSEayaOdoZHQ,16667 +torch/distributed/pipelining/schedules.py,sha256=LS9wt6NEFCXLVWeCaVxLRJ8ph25dozYp7qiChzYJ6z8,116613 +torch/distributed/pipelining/stage.py,sha256=DEGqNKXKyquo0Eek7Qg5wlW3ibKFWqB5MWJ9VngBC0U,63186 +torch/distributed/remote_device.py,sha256=Rdo6612_BpP5mLwqiyPPatc8fIgWCKKjf6izbg7Xjfg,4720 +torch/distributed/rendezvous.py,sha256=SeFz4DlFmBQ8huSJd1a2drsMVETngNPcyso8i4_8Kgw,10475 +torch/distributed/rpc/__init__.py,sha256=SRsUzWZLdYoLuT2fIoT5Q2MwgOZghUm_KFMWbuQbSD4,10202 +torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/api.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/constants.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/functions.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/internal.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/options.cpython-39.pyc,, +torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc,, 
+torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc,, +torch/distributed/rpc/_testing/__init__.py,sha256=igpNeyZ8dzHs9oo9YCXtJ2YE4haiH9XK2GPw9L4ocmo,497 +torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc,, +torch/distributed/rpc/_testing/faulty_agent_backend_registry.py,sha256=m6TJSTXVA063XVAYeDQ9ZRvWHaDmGeXXzyRIlhv_p1E,1701 +torch/distributed/rpc/_utils.py,sha256=bA-93SM5MU6KLg24KEW_hV-OYKoVIE5keJCK0_wLfRY,1696 +torch/distributed/rpc/api.py,sha256=LuNYsXMCP1x_WzXks57hE_49PjE3MCB-s_YdmXxbHDw,37969 +torch/distributed/rpc/backend_registry.py,sha256=SlNi3JzxsKFngW9ygqQFxnUBcM5-hxa45_42Z-UVpLI,16683 +torch/distributed/rpc/constants.py,sha256=oJTKportdDFWqLpj0BK0LR1VNO_lBINjtKBSdBBGFJw,828 +torch/distributed/rpc/functions.py,sha256=jXJmab8tseRiT-x2DbdMqyGcujL5aFAgGTTWzHnP2l4,7441 +torch/distributed/rpc/internal.py,sha256=nIMGSSq4H7lrbthg2h2Dg2q_h_PmfGqKo5BYxKKSm_k,11397 +torch/distributed/rpc/options.py,sha256=2luQaapiloVz8JagU0q4nguhzvo4KyE_7fatKfazyT0,7411 +torch/distributed/rpc/rref_proxy.py,sha256=k7LMkvNtzZ-jo-dgKpb336Hy10nJjzvG6bp_vfDXy1A,2753 +torch/distributed/rpc/server_process_global_profiler.py,sha256=huWWPAaJvdMHxVdDIdUzorZiGuwM5mXklG_mK56Yqy8,8631 +torch/distributed/run.py,sha256=5h6aOmLKc4g7PzeVQp2bZzoIhcHUnniTESAas0KP86Y,32820 +torch/distributed/tensor/__init__.py,sha256=J7DoD9Z2MzBXUJaPGdaDCmfoHuH9tJYFDZS0lZXpff4,2252 +torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_api.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_collective_utils.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_dispatch.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_op_schema.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_random.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_redistribute.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_sharding_prop.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_tp_conv.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/_utils.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/device_mesh.cpython-39.pyc,, +torch/distributed/tensor/__pycache__/placement_types.cpython-39.pyc,, +torch/distributed/tensor/_api.py,sha256=fJm4IFG5zIzQB-pvOUPpG9kTBPV4y2Q2J8AQoHfKPDE,58374 +torch/distributed/tensor/_collective_utils.py,sha256=7wX9gg2zCBCkUoYI1XvEPLnzFe5IyGhQcLtTCfKVwQs,14385 +torch/distributed/tensor/_dispatch.py,sha256=3HPAA-SqoyNyTNDgYe9Ljsh4MbmN30Yn2CHIKQ9Qo8E,20003 +torch/distributed/tensor/_dtensor_spec.py,sha256=a0Pd6np43vQE_bGnxEXlUE_wR0wYsVkQga0DICNxYGk,10585 +torch/distributed/tensor/_op_schema.py,sha256=7HrTOgcJg96ixbcmjNwek-q70XEWb29FiwXldQU03yM,21014 +torch/distributed/tensor/_ops/__init__.py,sha256=56hg3LS8HtJYvn-hA5LobeEFPjFkfT9MoOim5z5bY9k,389 +torch/distributed/tensor/_ops/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-39.pyc,, 
+torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-39.pyc,, +torch/distributed/tensor/_ops/__pycache__/utils.cpython-39.pyc,, +torch/distributed/tensor/_ops/_common_rules.py,sha256=3bKuGo67NTNsAbYrm2mEF5rTFAzeE7JDDHY8N5Huh08,12086 +torch/distributed/tensor/_ops/_conv_ops.py,sha256=m_Ln7AmJ1Hbj-adiZ6bhEodlrHjCqwF9LLvtRS5JFhI,3633 +torch/distributed/tensor/_ops/_einsum_strategy.py,sha256=PDOas_vR5CLJcXAZEGzSXIGtLaAJcalqjhlM8AXSrSQ,6556 +torch/distributed/tensor/_ops/_embedding_ops.py,sha256=YiTUvlZ9TYBhkPP09e93i0vo3nBZ0vf0cfbP175Ixng,10426 +torch/distributed/tensor/_ops/_math_ops.py,sha256=NrmBl84goV2fwPZpxDNfPPodrRzDdyMjLcCCOxazhuA,42303 +torch/distributed/tensor/_ops/_matrix_ops.py,sha256=nqyPg4hRWVbCuZdTrELR-gpD4NEwTnnqakWLQr8iu_c,36996 +torch/distributed/tensor/_ops/_pointwise_ops.py,sha256=I6RpOAnXwnArEiDX0iSzEilIkSL79iIiIPSzwWRr5Z4,22085 +torch/distributed/tensor/_ops/_random_ops.py,sha256=oY15kNMHZr3dIvq228feeEjaNN7AR8o5-Ta6UDwAL34,1123 +torch/distributed/tensor/_ops/_tensor_ops.py,sha256=5I6Hia8SpULDXGjQqbsQ_u_J-gUgxGTwrGjmIXu96KM,36222 +torch/distributed/tensor/_ops/_view_ops.py,sha256=_wdzgdwwtd_bqeYCZM4QJQ69uurB8zVFh0nwWMN2Gr4,26323 +torch/distributed/tensor/_ops/utils.py,sha256=EahyuB_uu0md0lzz7nf14RS95NYrdumLhZH3HBKfr6s,11041 +torch/distributed/tensor/_random.py,sha256=WpyuGMCE2f-qaxD4BXyWvewYvVnG75Eobwfzdr9WFoo,16599 +torch/distributed/tensor/_redistribute.py,sha256=ZoVKUq7vNwkzzZpDPP6wYCqwa2ok4j9iR9hBMF-jXug,16677 +torch/distributed/tensor/_sharding_prop.py,sha256=IYuS_2OgXARp2neXQ2gcbWltoTKpdMI0NFV85-t-300,23488 +torch/distributed/tensor/_shards_wrapper.py,sha256=TbN27OMcgVAy3pyLqhgK7vhD82XMD2dod7WjCsRryE0,13839 +torch/distributed/tensor/_tp_conv.py,sha256=eS9DQCBQs0tkFMYUrbkRdfjBKbSENLadVWHsRV8jXOU,10433 +torch/distributed/tensor/_utils.py,sha256=8fOg1hJE5W36Uc5Lu1jqSHuAnN6RdQ3J7rfThjwLC6E,16724 +torch/distributed/tensor/debug/__init__.py,sha256=yQKhaK8ee4k6pdKbGGYsy_u6GovU6zweXan2aGRif0M,874 +torch/distributed/tensor/debug/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-39.pyc,, +torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-39.pyc,, +torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-39.pyc,, +torch/distributed/tensor/debug/_comm_mode.py,sha256=mfhbg7KOmEc_R46sunKiOAbibGrskiFG0vESDxr5_4U,29497 +torch/distributed/tensor/debug/_op_coverage.py,sha256=yM8c5kKCWFzrh6C39-dzfQxVPSXLwNxrXDYgBhFSeHY,3246 +torch/distributed/tensor/debug/_visualize_sharding.py,sha256=mZomw2__cDx5_CYmC2a80EB0qQrLHcTfx8w_Fz4sTvE,7776 +torch/distributed/tensor/device_mesh.py,sha256=WmVZoV1gmNbdpIPJm-7BWJba8WDtfwbQlTjswMiTBrI,199 +torch/distributed/tensor/experimental/__init__.py,sha256=m1nj4z-5TaSBIP-7Uhe_kbdXVnv2zAYD84ph8cC9tb8,1458 +torch/distributed/tensor/experimental/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/tensor/experimental/__pycache__/_attention.cpython-39.pyc,, +torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-39.pyc,, +torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-39.pyc,, +torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-39.pyc,, +torch/distributed/tensor/experimental/_attention.py,sha256=m9YiO6uS8Tu-EsXkFXii2PoL6rjtQ3t21WYy4e3f3Ms,52534 
+torch/distributed/tensor/experimental/_func_map.py,sha256=afnw-XGBSQWAn1McbQQSfxi1JhQfjH8nTHdsgefjW9c,12439 +torch/distributed/tensor/experimental/_register_sharding.py,sha256=zhB8lYYdN0V0AnvWFN-9u_kgtImOjpisZYTEpZSOvkA,5803 +torch/distributed/tensor/experimental/_tp_transform.py,sha256=vbgtJ6vcxfa_CQBvdoYmgskobS67svBZXXsPj4KW3c0,20834 +torch/distributed/tensor/parallel/__init__.py,sha256=1jpJR1PUtf58TkmaDCu3A_4LW_Q81q5c82aQCdnV_E8,668 +torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc,, +torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc,, +torch/distributed/tensor/parallel/_data_parallel_utils.py,sha256=sq7CtRtRwAhIU2VeDliIzjXcubn9NQvMLtWLwKDGeJc,1561 +torch/distributed/tensor/parallel/_utils.py,sha256=j46rUYMMzqs4g6b7ghuJYrbk86E30PRE12etHrloGtY,2380 +torch/distributed/tensor/parallel/api.py,sha256=-qffNnqv9aIQXQUJlsrMOlj8kf_2MD4gOaG9Mgo_Vys,6142 +torch/distributed/tensor/parallel/ddp.py,sha256=hpoXLV3JCqWxgXD7B4NffFc8SdWy4gMEflD5pfm0GgU,3838 +torch/distributed/tensor/parallel/fsdp.py,sha256=z_RINNqKgOOL-AOEx0oh8uiC99bXXl-dW4uwMpS3c9E,14058 +torch/distributed/tensor/parallel/input_reshard.py,sha256=b4mA-PzyQ6txeRYcFTxiR6QUbTK46y7Y78NQQBmyGaY,3862 +torch/distributed/tensor/parallel/loss.py,sha256=Al8S2eZEzbTaJuHCcixpKo7SJVbc0V3A9Ebn44aqHx8,18283 +torch/distributed/tensor/parallel/style.py,sha256=kaEN6lUBrxpbIC-cRgirqY4AaQ-On_a4w0XtPhaGdaw,38038 +torch/distributed/tensor/placement_types.py,sha256=6_qapKUzFTX8i3xYP-oj-K4I6G3HBQwYvo1JeR1tM4w,30322 +torch/distributed/utils.py,sha256=aHfG8wFEJG-Sc22PqZqZK8XGlrAD6jrKsyjLQWhDunQ,13730 +torch/distributions/__init__.py,sha256=C63TUTaGfUf7T4oHRB02nqf81IKkmdJZSftk9KNey9o,6285 +torch/distributions/__pycache__/__init__.cpython-39.pyc,, +torch/distributions/__pycache__/bernoulli.cpython-39.pyc,, +torch/distributions/__pycache__/beta.cpython-39.pyc,, +torch/distributions/__pycache__/binomial.cpython-39.pyc,, +torch/distributions/__pycache__/categorical.cpython-39.pyc,, +torch/distributions/__pycache__/cauchy.cpython-39.pyc,, +torch/distributions/__pycache__/chi2.cpython-39.pyc,, +torch/distributions/__pycache__/constraint_registry.cpython-39.pyc,, +torch/distributions/__pycache__/constraints.cpython-39.pyc,, +torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc,, +torch/distributions/__pycache__/dirichlet.cpython-39.pyc,, +torch/distributions/__pycache__/distribution.cpython-39.pyc,, +torch/distributions/__pycache__/exp_family.cpython-39.pyc,, +torch/distributions/__pycache__/exponential.cpython-39.pyc,, +torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc,, +torch/distributions/__pycache__/gamma.cpython-39.pyc,, +torch/distributions/__pycache__/generalized_pareto.cpython-39.pyc,, +torch/distributions/__pycache__/geometric.cpython-39.pyc,, +torch/distributions/__pycache__/gumbel.cpython-39.pyc,, +torch/distributions/__pycache__/half_cauchy.cpython-39.pyc,, +torch/distributions/__pycache__/half_normal.cpython-39.pyc,, +torch/distributions/__pycache__/independent.cpython-39.pyc,, 
+torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc,, +torch/distributions/__pycache__/kl.cpython-39.pyc,, +torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc,, +torch/distributions/__pycache__/laplace.cpython-39.pyc,, +torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc,, +torch/distributions/__pycache__/log_normal.cpython-39.pyc,, +torch/distributions/__pycache__/logistic_normal.cpython-39.pyc,, +torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc,, +torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc,, +torch/distributions/__pycache__/multinomial.cpython-39.pyc,, +torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc,, +torch/distributions/__pycache__/negative_binomial.cpython-39.pyc,, +torch/distributions/__pycache__/normal.cpython-39.pyc,, +torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc,, +torch/distributions/__pycache__/pareto.cpython-39.pyc,, +torch/distributions/__pycache__/poisson.cpython-39.pyc,, +torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc,, +torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc,, +torch/distributions/__pycache__/studentT.cpython-39.pyc,, +torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc,, +torch/distributions/__pycache__/transforms.cpython-39.pyc,, +torch/distributions/__pycache__/uniform.cpython-39.pyc,, +torch/distributions/__pycache__/utils.cpython-39.pyc,, +torch/distributions/__pycache__/von_mises.cpython-39.pyc,, +torch/distributions/__pycache__/weibull.cpython-39.pyc,, +torch/distributions/__pycache__/wishart.cpython-39.pyc,, +torch/distributions/bernoulli.py,sha256=CL4wHn9dkw3lo7Ssnv_hpmiczkqIsmEMyKqsY01ojy0,4754 +torch/distributions/beta.py,sha256=ZAb9sK_gUDthZnIZMyGeKXAamVI8lJGJXiCnF1InlIU,4093 +torch/distributions/binomial.py,sha256=bp-V2wEw2n7Ne8mNVD5c0NihaXNhIC5ZaJqJ1hGsFbY,6496 +torch/distributions/categorical.py,sha256=8uOsCxBf-2Iqu1SCFxjfU9S0o-K4CMRE-GJITaXVe-A,6229 +torch/distributions/cauchy.py,sha256=ElwrmIgHnQ1NLDZIoLpHRCB38Hiq3GLohT0I2wfqEFY,3271 +torch/distributions/chi2.py,sha256=SrWjaIB6q8XDmGihZj2-BDQFsNfTF5KT6Qk2NgoFpmk,1196 +torch/distributions/constraint_registry.py,sha256=lXyMFlXxIdpzIDotRZlM86Md5ODi4GCfzZK1-O9FhAY,10597 +torch/distributions/constraints.py,sha256=PLqRpyBpbZXYWJ7DYfCNz4MFKSMWVLtNtgp62xSQJEU,22339 +torch/distributions/continuous_bernoulli.py,sha256=pBuxyA9ndg_QPnVRNuFXq32yX-_zsGiDkV8N59fa3tg,9352 +torch/distributions/dirichlet.py,sha256=gUtVjd08V1_Xnch9zjUoWy-E_5Aqe4Nvlzc5G6T1DKY,4547 +torch/distributions/distribution.py,sha256=nImMRXlCSBA20jSKZ9uN5nyUc5RyvVuEVtZ5AdTdMqs,12951 +torch/distributions/exp_family.py,sha256=5vI4L3zPFz6UODF6yWEs6IQxosA510GjOg5TIMbHZzU,2502 +torch/distributions/exponential.py,sha256=O9cBaeLwb-yc_fOjJxL8N_2QC89dE6YiUcLpZngQKok,2788 +torch/distributions/fishersnedecor.py,sha256=PKVVE4yrmxXt317V9dCBz9-lcquJTKWalfJRGInMUd0,3785 +torch/distributions/gamma.py,sha256=TZWq3xM81c8n_Pv1vYF6PhDo5L0zGjKIABz8RBLOrSE,4026 +torch/distributions/generalized_pareto.py,sha256=-FdZgft9ImMOtCd46d1o6mOKhHKVz0eLpcsiijEjSY0,5907 +torch/distributions/geometric.py,sha256=vf2VVpkTlvn8U17vITjVGpYkTN9rd22sAQS2ki8J20k,5193 +torch/distributions/gumbel.py,sha256=N3NOPq4nCHQBEK5gYuSE9hcUjZH_mSZBop-Bjvkb97g,3106 +torch/distributions/half_cauchy.py,sha256=VSFK2yOnooyN7Sd_rXUs2amA2IMpOFqIfGTxr3r79oE,2699 +torch/distributions/half_normal.py,sha256=cT9Y2_KKiQfVvwJGCh04zvVkIqHjBJ_aSx0F9KBFNm4,2456 
+torch/distributions/independent.py,sha256=UAyIU7pRZMnJEmF92eODhroS2LftRxzjh5mC0T-ja9s,5119 +torch/distributions/inverse_gamma.py,sha256=vY91rTaCAi6BK2PuJuw6T_7TbTquOskYviCNNnRHIKU,2838 +torch/distributions/kl.py,sha256=n29Kldc5C7nQuEiBptB-KWak-EU8q29PMFxxYrK48X0,32709 +torch/distributions/kumaraswamy.py,sha256=WJSbf7ZmNcKoGq1c6iRIcqmRncvtluRrE7MS3WlD3y8,3758 +torch/distributions/laplace.py,sha256=eG_6YoaEAPV_oR6XEQsqfKdm_d1y72mLVwkSsPc0sIE,3607 +torch/distributions/lkj_cholesky.py,sha256=OA_rzn-xzritba6kFHVDYTfBjuFfSK_7UFGngGDGqAk,6708 +torch/distributions/log_normal.py,sha256=epjzyQgLOycbTDbWoHFh4ae4Z8y8qEk_p0PYdxoawBw,2274 +torch/distributions/logistic_normal.py,sha256=aGCk1h3y6ZLrGX9Reez7ZjRGz_03qcWPX2O3fNWrxUk,2282 +torch/distributions/lowrank_multivariate_normal.py,sha256=jJUfT6H3IdshmaHCuOhUMmoHh8gY__lZUgJAXmDF81I,10377 +torch/distributions/mixture_same_family.py,sha256=7MsKNDczLYgFvAFYKzDPygvOjuIwpbn6bQEpQ08zf0E,8872 +torch/distributions/multinomial.py,sha256=9tKfYc1iRDPnGQnvc3EJI9GmYYI5qh7PsRI2ueKhV2A,5786 +torch/distributions/multivariate_normal.py,sha256=ksMrUZ7HbhqF4Dw4sBYs5-iptkPn9JUyjz46lje7BBQ,11362 +torch/distributions/negative_binomial.py,sha256=RsqAB0Fq4KoDdA8crrH84pxsJUffiCd-DzzJKGHC_Ys,5192 +torch/distributions/normal.py,sha256=pnVTPGZV0glN-TSbCCKdB4Uf8EernSRfVQbxWDh1Dsk,3987 +torch/distributions/one_hot_categorical.py,sha256=vIrjeVy3ve1dIQGsCncHylnNVDu5TF_Hm5hwBOpoq7g,5133 +torch/distributions/pareto.py,sha256=O6kAc3HvIxwMqchb2l7tauFmsKJIOW0AwVH6zvBY1no,2571 +torch/distributions/poisson.py,sha256=1PWSugsTdzAeFKgqilBpXfdaXV4MFVj1dYM6onO_fDk,2533 +torch/distributions/relaxed_bernoulli.py,sha256=GZwjR-Kls8p9GHOp_CQfXtLiw4pnonFbT6HUYJT1QaE,6125 +torch/distributions/relaxed_categorical.py,sha256=RwuUQTsE8i0pT78ZN6y2x1pe3LARbtUVA0Nd7M38FYk,5801 +torch/distributions/studentT.py,sha256=zqLrRVpR1gW0eBBbfJzjfDpHGFekB6bKWd7x0tSTYyM,4277 +torch/distributions/transformed_distribution.py,sha256=zKjb4guPe5X-IiuoNjUX05-oejcRE6_NaaQx7CEz_TQ,9097 +torch/distributions/transforms.py,sha256=xCT1-Du_GFUOaPG7fcrg8DdFQ-hS5kLDINrxw-K3SSU,43704 +torch/distributions/uniform.py,sha256=Eo2olo4L9x0Y94pgoRZ6yzlTLZAUDk68CpAD4_tNCEw,3469 +torch/distributions/utils.py,sha256=kzZiRGQuSebrC2RMlO6TpFzf8wVlMFlxobDVd7sq9_o,8248 +torch/distributions/von_mises.py,sha256=rCk2wZ2DBFTJPx6672Q6dz05zJ4s5oeO7E4iE6UQcVQ,6543 +torch/distributions/weibull.py,sha256=zOR2ANQnjKmvK8dsPzRKe0W07cdmEvHrwltIxfWPscc,3479 +torch/distributions/wishart.py,sha256=j4mSZR74cUvIXzVGDJPFNbdACGqDORzeIQAXLudyfs0,14053 +torch/export/__init__.py,sha256=ZKvjdwcjrY33L589VRweQrpa-qmStMMbe75CkQlf2Lo,24128 +torch/export/__pycache__/__init__.cpython-39.pyc,, +torch/export/__pycache__/_draft_export.cpython-39.pyc,, +torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc,, +torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc,, +torch/export/__pycache__/_safeguard.cpython-39.pyc,, +torch/export/__pycache__/_swap.cpython-39.pyc,, +torch/export/__pycache__/_trace.cpython-39.pyc,, +torch/export/__pycache__/_tree_utils.cpython-39.pyc,, +torch/export/__pycache__/_unlift.cpython-39.pyc,, +torch/export/__pycache__/_wrapper_utils.cpython-39.pyc,, +torch/export/__pycache__/custom_obj.cpython-39.pyc,, +torch/export/__pycache__/custom_ops.cpython-39.pyc,, +torch/export/__pycache__/decomp_utils.cpython-39.pyc,, +torch/export/__pycache__/dynamic_shapes.cpython-39.pyc,, +torch/export/__pycache__/exported_program.cpython-39.pyc,, +torch/export/__pycache__/graph_signature.cpython-39.pyc,, 
+torch/export/__pycache__/unflatten.cpython-39.pyc,, +torch/export/_draft_export.py,sha256=GXpO-sbjCX6Jrh1ybqrXeIi_eVUBu-sNqx2T4TRON04,19775 +torch/export/_remove_auto_functionalized_pass.py,sha256=GRHkvAs_UXy7lli2j87zmPBLnnZSRmwnAt6lf8rf_Lo,2003 +torch/export/_remove_effect_tokens_pass.py,sha256=eee-So9Sidb8FpbS_Gf0obEGnq9jBUNG9nQ8qghk7P8,6502 +torch/export/_safeguard.py,sha256=OHb5_QzVfyGuP25IIvKkLE7-fnAyjLHQoaZsMiPZeRo,2000 +torch/export/_swap.py,sha256=mIxgF2v2kKcByR2G9HqiPynpwZQuzWce7_fxjhq6l9A,18017 +torch/export/_trace.py,sha256=yOfMrVY-yDC9aOGJbjOw3wf1F3xOsiI8ESq86H-rLmI,90570 +torch/export/_tree_utils.py,sha256=HJz4dBPaqj_qI9WkAjQj2TFe0g65E7n9gVyQb0moVpE,2326 +torch/export/_unlift.py,sha256=aYl_bFkx1BqiyQ99mNuptcSyPTcE4tH2M5VsdGpfJJQ,18009 +torch/export/_wrapper_utils.py,sha256=y7WMqitDwh3NwyuOrXmKCwKzth1QujBffpzYpqvypNM,281 +torch/export/custom_obj.py,sha256=sayHd5SniH20UVLJ4QVUuhtjLa2RRqDR2fbvbMDXXfA,312 +torch/export/custom_ops.py,sha256=MEEExuWtQVLr2U-NGaTCeVbKececTL5AK6M1kyKmPpw,986 +torch/export/decomp_utils.py,sha256=Jgmj2ZhUJWrCEu8UXS66LTIwPtKPOvj-AZ8gAWnD1oE,5674 +torch/export/dynamic_shapes.py,sha256=THUaWR_O5s6GTsQpFlfegUiLbxH96GnbHWFjQge7eIY,54785 +torch/export/experimental/__init__.py,sha256=RMn-1Ig1PCV0lW4V3pkSzXw9VGwWKyVDNyekToNwQU0,13024 +torch/export/experimental/__pycache__/__init__.cpython-39.pyc,, +torch/export/exported_program.py,sha256=LIwbFF__mAAMKvKRFPpw1oILtpm9YLScZU32rO-k240,68721 +torch/export/graph_signature.py,sha256=WLSjFPCO0tdxNzAOLtEnHvDU2vj7z7ioK3Y2TeAspwU,25520 +torch/export/passes/__init__.py,sha256=8FJPXNzYyKkBjG7REq_-lN2SJlyt29561T4--phJ9VU,2332 +torch/export/passes/__pycache__/__init__.cpython-39.pyc,, +torch/export/pt2_archive/__init__.py,sha256=Pj7CXPunjeMGhxVtV5HwYjEYOeiRLcsP4sgmABiu1Ks,148 +torch/export/pt2_archive/__pycache__/__init__.cpython-39.pyc,, +torch/export/pt2_archive/__pycache__/_package.cpython-39.pyc,, +torch/export/pt2_archive/__pycache__/_package_weights.cpython-39.pyc,, +torch/export/pt2_archive/__pycache__/constants.cpython-39.pyc,, +torch/export/pt2_archive/_package.py,sha256=G2cT9MBKE7kyZ6HMduOmi6nEQUuYqrRUDVw-5_VNUkc,25142 +torch/export/pt2_archive/_package_weights.py,sha256=NKbvhG_bOazwU4ZfRow3rBT595DsvmGt9bqPJI_7IXI,3515 +torch/export/pt2_archive/constants.py,sha256=Vgs4UV5wEC8OZPDODn6BG9VrAJ0ewx_giGvnPyODDuY,1532 +torch/export/unflatten.py,sha256=DlXc4326H4Bwct--2yoz7qRWPbR7zzaCDCw4GeSgzLQ,67717 +torch/fft/__init__.py,sha256=Ux2AyoXcfP6z3pLV96O_QxjHxqPteV85pVV1WfsJY2w,56779 +torch/fft/__pycache__/__init__.cpython-39.pyc,, +torch/func/__init__.py,sha256=fEBWU84aI3QYiqjuZf_7Z-FXp8tVgkRqesChzIa2V9M,687 +torch/func/__pycache__/__init__.cpython-39.pyc,, +torch/functional.py,sha256=7teqvJ_ae-wi5jCMA8T-yBA8MQRRZ11pd8Jw0U_Tqi0,91112 +torch/futures/__init__.py,sha256=Hs80b7LO5EFyiH08zhrAN__rDMCeZEXINYnPNrTAN-M,14831 +torch/futures/__pycache__/__init__.cpython-39.pyc,, +torch/fx/__init__.py,sha256=ktywCt2GbZUKydJ9JofisfqXd0pCFl7ak9UB5KZYt3Q,4279 +torch/fx/__pycache__/__init__.cpython-39.pyc,, +torch/fx/__pycache__/_compatibility.cpython-39.pyc,, +torch/fx/__pycache__/_graph_pickler.cpython-39.pyc,, +torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc,, +torch/fx/__pycache__/_pytree.cpython-39.pyc,, +torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc,, +torch/fx/__pycache__/_utils.cpython-39.pyc,, +torch/fx/__pycache__/annotate.cpython-39.pyc,, +torch/fx/__pycache__/config.cpython-39.pyc,, +torch/fx/__pycache__/graph.cpython-39.pyc,, +torch/fx/__pycache__/graph_module.cpython-39.pyc,, 
+torch/fx/__pycache__/immutable_collections.cpython-39.pyc,, +torch/fx/__pycache__/interpreter.cpython-39.pyc,, +torch/fx/__pycache__/node.cpython-39.pyc,, +torch/fx/__pycache__/operator_schemas.cpython-39.pyc,, +torch/fx/__pycache__/proxy.cpython-39.pyc,, +torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc,, +torch/fx/__pycache__/tensor_type.cpython-39.pyc,, +torch/fx/__pycache__/traceback.cpython-39.pyc,, +torch/fx/_compatibility.py,sha256=xaGmXxHE4MkaeqqNl6aVRU2vbkfYj3JLSngkLi6kvJo,1120 +torch/fx/_graph_pickler.py,sha256=163UOdulugfltsMnd-ggKlWrHObfRIcEliQ1FsosyOk,22330 +torch/fx/_lazy_graph_module.py,sha256=PaYsiptsRlJCGlQKzJiRBzPZWwZLHuTLCJb9dLjvuVM,7343 +torch/fx/_pytree.py,sha256=pouhXp_vY-ZUdEb-_z_LFBD_bkx1Ls940rirREjKsr4,3700 +torch/fx/_symbolic_trace.py,sha256=01UcfXVq9dWfJ-mV6YQQVSgyDWfZo5xZTGUWYX_wWw8,52153 +torch/fx/_utils.py,sha256=PX2e_hTCbfK--lEpkVSK7lljyDxVNlooSeL02s1Szj0,1803 +torch/fx/annotate.py,sha256=dn7XPNwrztonT4wNge5uEuD-pUa_WIsh5hqWX6DTwys,1292 +torch/fx/config.py,sha256=T3TTYqs3by9AqZqczmB4odnDmtnWfepicu-QExxw2gI,334 +torch/fx/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/experimental/__pycache__/__init__.cpython-39.pyc,, +torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc,, +torch/fx/experimental/__pycache__/_config.cpython-39.pyc,, +torch/fx/experimental/__pycache__/_constant_symnode.cpython-39.pyc,, +torch/fx/experimental/__pycache__/_dynamism.cpython-39.pyc,, +torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc,, +torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc,, +torch/fx/experimental/__pycache__/debug.cpython-39.pyc,, +torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc,, +torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc,, +torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc,, +torch/fx/experimental/__pycache__/normalize.cpython-39.pyc,, +torch/fx/experimental/__pycache__/optimization.cpython-39.pyc,, +torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc,, +torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc,, +torch/fx/experimental/__pycache__/recording.cpython-39.pyc,, +torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc,, +torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc,, +torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc,, +torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc,, +torch/fx/experimental/__pycache__/symbolic_shapes.cpython-39.pyc,, +torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc,, +torch/fx/experimental/__pycache__/validator.cpython-39.pyc,, +torch/fx/experimental/_backward_state.py,sha256=sY7kEctuCzdVZJUCrhkCgCCk61-XepGYuRmmNv6NW9o,994 +torch/fx/experimental/_config.py,sha256=SfjH0Vbdxm8dQW8-1Yd7ZGniog7gr6-mjtqxzQPEZbg,4816 +torch/fx/experimental/_constant_symnode.py,sha256=4ndLW6IPKnidbSnZVu3FZcw0WQ0Ssy7jQtyP9_kXbEg,1591 +torch/fx/experimental/_dynamism.py,sha256=umH_u8mPbSpqd5cILF76bhbbfftx3-nPklSQD5U2Dnc,4619 +torch/fx/experimental/accelerator_partitioner.py,sha256=59VcpKGTq8RWepcM_COnRVTs-b5hO7DS77Y45at7qBU,48847 +torch/fx/experimental/const_fold.py,sha256=1_Pf-GHQFZuukehMJeTwxi29EDiyPa9cT84B7FklVf8,12885 +torch/fx/experimental/debug.py,sha256=x5BdlNIc_AeXDA72WqGVOprN-WHkhSymqoQIFwEhHQE,844 +torch/fx/experimental/graph_gradual_typechecker.py,sha256=gF_Ym0YRE0f1FJ5I_vruGoJjmU2RQhu8w6WDltLgmqM,35111 +torch/fx/experimental/merge_matmul.py,sha256=SjSEzT06pQkHAQaUxl7lO2rK0XUMgDok1i7G2CIJTJ4,6240 
+torch/fx/experimental/meta_tracer.py,sha256=xP2LiyOuY3SQ12XklHH-c_lodmUhwWGFCDW54sVW-ds,10888 +torch/fx/experimental/migrate_gradual_types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc,, +torch/fx/experimental/migrate_gradual_types/constraint.py,sha256=WeKDkuPjoFYWGotIfW8hRgevxKDN6I2AoWAdJ5KmJ-s,17865 +torch/fx/experimental/migrate_gradual_types/constraint_generator.py,sha256=UE16y0PYb8tTSZIFh4XxAuuqqGvePYZHNECpQPXPY1A,52594 +torch/fx/experimental/migrate_gradual_types/constraint_transformation.py,sha256=A0K4TOMQm3z4V7Edak-AB6oxFZFxUGn-LDzmftSdKG4,42322 +torch/fx/experimental/migrate_gradual_types/operation.py,sha256=zbUGWchsU6rC48vqKklLLe_z8MW-y076OlSNy7w2w2E,301 +torch/fx/experimental/migrate_gradual_types/transform_to_z3.py,sha256=GL5enOEip4xoTg_Rymw7ISMVH8uKY7D63VeSqSymw4o,16816 +torch/fx/experimental/migrate_gradual_types/util.py,sha256=VqJ-qT2Ho6-n4U6fE-FulBGzY3ds1oBFHhvgld_6PMg,1424 +torch/fx/experimental/migrate_gradual_types/z3_types.py,sha256=ipCI1wud5fXjuavpPXK9T61D-sNXsId__Ux849pJcAs,837 +torch/fx/experimental/normalize.py,sha256=O0x4ptfJ8HRcESmlJ8VKd_TQzB61YkaJ1oB6bN6gpL0,5655 +torch/fx/experimental/optimization.py,sha256=792GH6fAgBk_kLyZ3lUPpZ83X7aIiBTopzKLo-lFel0,18172 +torch/fx/experimental/partitioner_utils.py,sha256=AAOnE3goH6RXZok1U-KPASa3moCclGC50d-5xEm9br0,12612 +torch/fx/experimental/proxy_tensor.py,sha256=h3P-XO3Rug2YegFV6g_SwbQT12gMNpe2QI6Q7sNiAvI,93345 +torch/fx/experimental/recording.py,sha256=inlQ-xHPQzPLU5NVqQfKfaX8cmO-glZGyjXEcLqImqE,20452 +torch/fx/experimental/refinement_types.py,sha256=2QiUBXblb8SBDr4n9yhYA5sBX6FkuaymUtny2x7dr8A,467 +torch/fx/experimental/rewriter.py,sha256=oKaqQa5ODK6OMVAF1ixgoj8WIQ11So3Q-fphTiDyFLw,5611 +torch/fx/experimental/schema_type_annotation.py,sha256=GiT8xnwEhVJ4FXcxF1TJMVK1wLL8J_smUjFVN6cvl4Y,5524 +torch/fx/experimental/sym_node.py,sha256=ID-N6vfiW_AUIiUeoP4JUBCoFAW-BYeCLPkbuqyyQkM,62101 +torch/fx/experimental/symbolic_shapes.py,sha256=9EWbAB_5vAA-p_NyDVof7iwJORGcwYdmUjHaHb6Z2IM,342516 +torch/fx/experimental/unification/__init__.py,sha256=gFSoj0KQx-HXb10CoJ-yxySJWTOfzpbfnDE8AFupMLc,200 +torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc,, +torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc,, +torch/fx/experimental/unification/core.py,sha256=oHudUofULazwHxFMY3OhQc44hEb3-3r2AMAl1en_ol0,2895 
+torch/fx/experimental/unification/dispatch.py,sha256=b10Td_2mmcMUQXpNaXlfxKqQr8TdHUQJ4fmiE5xoDqY,201 +torch/fx/experimental/unification/match.py,sha256=04VuNhNb8MVJMbCbSDSYffq5rQzSlWSoGsgbXxypU54,3543 +torch/fx/experimental/unification/more.py,sha256=hNwXS1xLeT_53UyIdOkZ1sQxbZ7bsSnsGoCQ6USpi7k,3084 +torch/fx/experimental/unification/multipledispatch/__init__.py,sha256=GFz_pmv1HGMq8u3esbhpKoY3evsaQ7j2fXBWNFbuz6c,146 +torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc,, +torch/fx/experimental/unification/multipledispatch/conflict.py,sha256=CMuB6kiOBSnm3tyUhdRYrLYP9WuunP5sj0181d-qynY,4349 +torch/fx/experimental/unification/multipledispatch/core.py,sha256=sTCbk8C2N42jV1MCs9O-T9Rf-mAcqoAs6RvB9KZtHw8,2652 +torch/fx/experimental/unification/multipledispatch/dispatcher.py,sha256=6gs2Rxmar4doNdT-VNczw7TqsyLWgkoS6GYelm4Cy3Y,14337 +torch/fx/experimental/unification/multipledispatch/utils.py,sha256=EKZjqQsAnnGOiTlzwprexz2GlTN8haGFfxAXYi4MH3M,3887 +torch/fx/experimental/unification/multipledispatch/variadic.py,sha256=L3fDMtGGBH6SyV9WcUdzQwgrkGDQfXN_6pQ7RLf4usg,3058 +torch/fx/experimental/unification/unification_tools.py,sha256=03QW1SVCUedg7YLcd1Ilv90V8ChO0L32RyfbFyLvRfs,10986 +torch/fx/experimental/unification/utils.py,sha256=Y9naJmr_hgJWv-bLeP5hwhekMXvLEHMTgv7NDF_4Tx8,3045 +torch/fx/experimental/unification/variable.py,sha256=VQJygi4031DYRKK7ZR9IZop0gV7gi1Yc6LoMtwrO-X8,2155 +torch/fx/experimental/unify_refinements.py,sha256=MlfK84ALx4gOF5WSNb1jkOkapjKF7PMIVu-hwqVzliY,3275 +torch/fx/experimental/validator.py,sha256=FnGeccrtLb3BqS0dP00k9Tvzxy2d1hbF_tvWEHSaXPs,34875 +torch/fx/graph.py,sha256=WQH1HLhNWP2fWhGNUrpH4QNnopuwa2PhHtwl3zY33Ws,78320 +torch/fx/graph_module.py,sha256=qem0t18_DY4IJSUkDG_rPzO1yetKuRf8hjWbL5Fud6k,43350 +torch/fx/immutable_collections.py,sha256=NUtdRSD0G-6tzNAU8xg8zXSxV9gcdwuoB1NjqXLNnhM,3391 +torch/fx/interpreter.py,sha256=ZsH5oK3Gx8KjPchiDa70MYsE4Aagt1wF8b35Y3mzFJA,23773 +torch/fx/node.py,sha256=MUKHIoIaJkMJ_G_gZ7hPSugRdKI9Vz-ipCTNhd3jUdc,35104 +torch/fx/operator_schemas.py,sha256=LlonLfeLjrOc5ZKzwKoaJbr_kz8z7I8z1WOSiYgFs-I,22420 +torch/fx/passes/__init__.py,sha256=OhaBf_mMwHezRK3aq2IaBYymXfmydwZEmbbkMvsrdgM,254 +torch/fx/passes/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-39.pyc,, +torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc,, +torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc,, +torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc,, +torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc,, +torch/fx/passes/__pycache__/graph_transform_observer.cpython-39.pyc,, +torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc,, +torch/fx/passes/__pycache__/operator_support.cpython-39.pyc,, +torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc,, +torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc,, +torch/fx/passes/__pycache__/reinplace.cpython-39.pyc,, +torch/fx/passes/__pycache__/runtime_assert.cpython-39.pyc,, +torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc,, 
+torch/fx/passes/__pycache__/split_module.cpython-39.pyc,, +torch/fx/passes/__pycache__/split_utils.cpython-39.pyc,, +torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc,, +torch/fx/passes/__pycache__/tools_common.cpython-39.pyc,, +torch/fx/passes/_tensorify_python_scalars.py,sha256=QMBd1cryd5rCs5dbZqlkAbwuHj7wPWx3vdG-oODjqzw,16466 +torch/fx/passes/annotate_getitem_nodes.py,sha256=knZiMDMX_daDFDGKPUqMHPJO2JsxtrFdXFYWausnW7E,2820 +torch/fx/passes/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc,, +torch/fx/passes/backends/cudagraphs.py,sha256=aYZYaP1-eeY4j8BgmFadWoO5vkhgbTVx6r_Vg3vNknY,2144 +torch/fx/passes/dialect/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/dialect/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc,, +torch/fx/passes/dialect/common/cse_pass.py,sha256=qRTv86JeOKpIqA8BpKKQ3TuY5TyqSpqGuH7iBNshop0,5403 +torch/fx/passes/fake_tensor_prop.py,sha256=8mHq3YFOQb7JuQwv3UyBwZ2naFcCCtZ6WG_od0-YTro,4309 +torch/fx/passes/graph_drawer.py,sha256=-25ZLkOk37xkRWrc1jjPMysTChEy8bGRYHrRlpnNpmo,19613 +torch/fx/passes/graph_manipulation.py,sha256=MZT9CIlADZY3V4BPBDjtK_mPj1BUPmrndJsjJ-41onU,4077 +torch/fx/passes/graph_transform_observer.py,sha256=jMXoeuOsyM8oKJUlGfsE1CwA_lDqNqvzMS5AzwSECoU,7611 +torch/fx/passes/infra/__init__.py,sha256=XeK4BKH-Vs9snoA5CTrYmG0ZXdkAiCus1btUs3EPjvo,28 +torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc,, +torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc,, +torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc,, +torch/fx/passes/infra/partitioner.py,sha256=ldAzGahmvq50QsV2WV5hVbhJF0f9duMHbeLDMI9QJaM,16835 +torch/fx/passes/infra/pass_base.py,sha256=_GwLcijp1sc2slUmDaaJfBUYRxtag6qzm0hrL_lU23Q,2579 +torch/fx/passes/infra/pass_manager.py,sha256=Y_W4ivx2Preki1leGo2n-AOg7VWoOU82LQ87XFTI6yQ,10654 +torch/fx/passes/net_min_base.py,sha256=1_08tlQVkoVusaXGn4N_pJIVMN-vxzL8DrmZ-tcp83s,37548 +torch/fx/passes/operator_support.py,sha256=5JNEtAgfP9UkRBveunSvqPNBvyp4k-6Aow5KO9QdhBM,7861 +torch/fx/passes/param_fetch.py,sha256=ytRb5NtQL_xDHW4RoaWBfy0BbxZuk7kc3VmGyM2M-6U,3810 +torch/fx/passes/pass_manager.py,sha256=x06AXEgJIxx_NK4IKDQve1pmEYafCvXFBz9qwuJJpDw,7311 +torch/fx/passes/reinplace.py,sha256=STWdTvL8cyp51xxrk1w9O8z6_PzrJG0VriyQvScHYOI,35283 +torch/fx/passes/runtime_assert.py,sha256=JgTRBCuGoHohorRnGUhqD3POWTS2SR03EeVI7E5T4NY,29698 +torch/fx/passes/shape_prop.py,sha256=Msuc4q2QXqmBgHO6cUZALIQ3cqHW06h3vymiy94PHtI,8554 +torch/fx/passes/split_module.py,sha256=foeYTLziyKTIzALsclPTEN9ooJcLbSgBf4ctaX0s-aY,27408 +torch/fx/passes/split_utils.py,sha256=5VSPGtyFKHgnPQgxmoZHhD2hCsPHc7lZomxZFUz83WM,11780 +torch/fx/passes/splitter_base.py,sha256=Z0-lZttzJmDhdVPJaSdsgG4e_VVltfob19a980kq1ZY,34734 +torch/fx/passes/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc,, +torch/fx/passes/tests/test_pass_manager.py,sha256=lL-SQ_BqCcQ7LALERNZzbgM58ClKBwFehjtCzDDv9g4,1929 +torch/fx/passes/tools_common.py,sha256=W_yFdd0xIU4wyVSw5SIQr16obfbGu8mKsFo6gyEYE-o,11521 
+torch/fx/passes/utils/__init__.py,sha256=zsDIGQhthjEgPZr4tzE61sCnPNNwG725qUyZ-yzP-_M,75 +torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc,, +torch/fx/passes/utils/__pycache__/common.cpython-39.pyc,, +torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc,, +torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc,, +torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc,, +torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc,, +torch/fx/passes/utils/common.py,sha256=9wnm1pspWGzjWJtvH5shN2Ul5EAGLmwOr1nANVoRAnM,3201 +torch/fx/passes/utils/fuser_utils.py,sha256=kx8tCLup47GkrakQiqvCkFsYLwjJKnFG890DpSNSxnk,10063 +torch/fx/passes/utils/matcher_utils.py,sha256=WtSsfYq6i1B61CtH4N6VzJj-rasXQCMXu6amzCR9ZPQ,17682 +torch/fx/passes/utils/matcher_with_name_node_map_utils.py,sha256=Y4wKb2dZ2pNgUY_CVW9TgLKru2Q6FbG9uew2y9uWBAk,4311 +torch/fx/passes/utils/source_matcher_utils.py,sha256=VPJF7wYFrYthOhIy_WwDTEtjXG1-AtRhcqr1tyMD_xM,5930 +torch/fx/proxy.py,sha256=eCvdGOk-8e_-FdMsHG8IwdzbfLxFCJEqRc8pTkpETww,29303 +torch/fx/subgraph_rewriter.py,sha256=GAbL1_uC3t2qkBXTTyNGYNNXEAqsPFUmv5YDI_Rts-g,16333 +torch/fx/tensor_type.py,sha256=-fPKzS2WSrGQwJH6s-riLEY6jLe5BfMN1h_gbQwCWuY,3120 +torch/fx/traceback.py,sha256=FQoJmly3KEuivuUz4thFCexCZ_SVqDthSVUjcYiDdbs,7563 +torch/hub.py,sha256=xY7ytuhNGIlSu8UGCG2eKEaPmfHvh7B-PY1o30Cln8o,34334 +torch/include/ATen/ATen.h,sha256=2_a2EFRq8L3ZbUfOQ-U_MB2YFIocV0SmrwJSogR1TKM,1144 +torch/include/ATen/AccumulateType.h,sha256=tY72G0xRvH6E7_bussnxnpp4LsaRZc8zxCDIKYxMnBg,6284 +torch/include/ATen/ArrayRef.h,sha256=4zYGzRZtvYteRd__xp5pu9VciucD8IIzuVMs3A5Dp6k,46 +torch/include/ATen/Backend.h,sha256=-GO1LYCj0b1jHvhCo90emKutow1cpN3rM32YAP9T_Co,45 +torch/include/ATen/Backtrace.h,sha256=SlqFuyqN3PvlfyP68y1DEnFdWQUgLcNM0gmR3E_8a_I,48 +torch/include/ATen/BlasBackend.h,sha256=IqTWIufqJi9IxLvvw-UOVeHadJlk2Q-fPlKBrZAQcfo,786 +torch/include/ATen/CPUApplyUtils.h,sha256=LZQ0BtpIZljzKYGV0DhuYk0eztgVPPnwiM2daTik9zw,10986 +torch/include/ATen/CPUFixedAllocator.h,sha256=HevhsXM95SEkJu9W99gG-Lxdsf8BPzeMtmwLWSy1mxI,915 +torch/include/ATen/CPUFunctions.h,sha256=_eyHokMWODET3hOMLSvjcITHqi1_DGMTKj44IEYDr20,1983 +torch/include/ATen/CPUFunctions_inl.h,sha256=ptacOXZE91e4Gyo1EVouzHZLtBqiu4CxJ1-A2-gL7zc,27834 +torch/include/ATen/CPUGeneratorImpl.h,sha256=VDLdZFrny0eUiwfOpUNL7LTA1xYc-1t2aEIftQ9SDkw,1576 +torch/include/ATen/CUDAFunctions.h,sha256=LHMeRXdRzosLFV4qmPvH5v_2qfRNCEGnxGstavfYbWE,1984 +torch/include/ATen/CUDAFunctions_inl.h,sha256=HMkIjnmUVcPQ-Bff5QC5IuxwkJMAdhRsxNWnY5M8L_w,33248 +torch/include/ATen/CachedTensorUtils.h,sha256=oJKriW3daq5pz89rPdOwrfdBBvyOmuApbv2a8cJlqkQ,1031 +torch/include/ATen/CollapseDims.h,sha256=ZUyvqAOAPtHfJGq3-mo_49i3SfnP-hzUUGdprm_mimM,2654 +torch/include/ATen/CompositeExplicitAutogradFunctions.h,sha256=0-lSNBSpctfyiR_BsNj1F0mBsaunfRe-vATxxLQuBW8,2005 +torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h,sha256=F-2AHLxJd01azDpaTE-6mL95G4AMPnpmSyOIm6JOpGI,41235 +torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h,sha256=nWMWcnR8GBCsKho56RjV6iV28-HyCTTE5q7EnzXh3-M,2018 +torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h,sha256=Ko9u6UsZ2XoGQl6jczJ_0WhsUZRFf6nrb3XLXlyYo6E,26567 +torch/include/ATen/CompositeImplicitAutogradFunctions.h,sha256=gliKmhdZ6nnFt44YvI4GnnAM3pmV6vN_V0QSOZXH3PY,2005 +torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h,sha256=ZGmOirJc0r6ZJV4A965QEgQbY8GTJ1aF1KJdAuHJ7Xg,35444 
+torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h,sha256=lXqy20Dh8nt-a7CUAhHw-_Go_9srUA7XeWGPkRS_nuA,2017 +torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h,sha256=J9aQiNqgqLlYNfc6aPBtYHOqGLqlyT3-NUHN5pL66Ls,1151 +torch/include/ATen/Config.h,sha256=aB8ef_RulFID2uUu0CtblZ-yUl5q1KwNpxjgi5LepmY,791 +torch/include/ATen/Context.h,sha256=CRRFRjMSTYiw3xZxhkxWnxjG43HzeoW6lmpC9vaHXc8,23019 +torch/include/ATen/DLConvertor.h,sha256=k4cdswqzeWLo1FIx9yaT3s1qEdvcJlA5R50f1D5VBAo,678 +torch/include/ATen/Device.h,sha256=PjTFxYOH8EE9gj2HbfmHkdVKNk0X1KBYoNoZgM-CEXw,44 +torch/include/ATen/DeviceAccelerator.h,sha256=O2O9kHBMG7PHNtGFgRmohCU-JE5pGcqJmcD4nKWa1jU,3140 +torch/include/ATen/DeviceGuard.h,sha256=WNtJWtrcb6QBvpqBe659BJDSH3GSTtPrUbt_fms9scU,1206 +torch/include/ATen/DimVector.h,sha256=UPELKgtkpIVtPe50MeJeVbzFzbmSSnd95P3eRfyVwbs,48 +torch/include/ATen/Dimname.h,sha256=irl8d98QKHdk4MLNTgPE-Qx2kiCFrzikBtkiAsyNGkU,32 +torch/include/ATen/Dispatch.h,sha256=xdJV7Z4WObFiiKZnGXmX7PT8cX92Z01-0YYi5gdBx7E,40953 +torch/include/ATen/Dispatch_v2.h,sha256=He6eC57jzepJutaNfrKQav62K3tcsQDCPy12jUmIL-s,60538 +torch/include/ATen/DynamicLibrary.h,sha256=3jFXb6PPJc69xKO93s2V9sdvw63lzPvghMmE3F4bKWU,715 +torch/include/ATen/EmptyTensor.h,sha256=97Sl49WBpUcga5FG-29A62IeD1CkhlzsSiXvpAkAYJs,4876 +torch/include/ATen/ExpandBase.h,sha256=M9aB8mJ0CnGtGPP6X6mIFzbH5PmrOX2yQ2rA7yAPRyA,944 +torch/include/ATen/ExpandUtils.h,sha256=EOI5i-AHROXTwokga91OjyGxX6XBPsfqljy1ufU7ajs,17216 +torch/include/ATen/Formatting.h,sha256=aDSVbqYGSFP73pWYd8X5cz7TwYLpZwHm9EtvITHprw8,35 +torch/include/ATen/FuncTorchTLS.h,sha256=JRGM1XzpnjcNA-7ei0S-anOq5bLfvjXKEZ8999BWlic,1862 +torch/include/ATen/FunctionalStorageImpl.h,sha256=AhdCg9lxndHfbrHeXmR3oupugfix7NQYHAauKtfLL90,7907 +torch/include/ATen/FunctionalTensorWrapper.h,sha256=ywrLNkzD_CTIkUvkSVZiTGyLc0f8WqLGh0SqhKqwn2Q,18220 +torch/include/ATen/Functions.h,sha256=YMPUUB9oWBogQ-p8odKdAuTJ7KLPmHE3bRUpiNo7qlc,56561 +torch/include/ATen/Generator.h,sha256=IsmYQ8-OVWES4O4d2Cw2zJ2zUCcVNnQtwmBJKM68P4s,48 +torch/include/ATen/InferSize.h,sha256=e3b9Gr8TALBD7e2jHeZ2GXXRbQpM_elIq3QhB27au0g,2977 +torch/include/ATen/InitialTensorOptions.h,sha256=JtlTji1MCEwQtcGcBFkAusPdNsq1HXx-mpHALNi4Wh4,454 +torch/include/ATen/Layout.h,sha256=uHy5slYoQiyLads1R9kjGHbFecuijeOHhHvtIxgnmOA,44 +torch/include/ATen/LegacyBatchedFallback.h,sha256=zsxktFB1yZ2exes2nr1Ct0r0n8TyoRtU4hEheSNM4sE,999 +torch/include/ATen/LegacyBatchedTensorImpl.h,sha256=u_rr-p-PMjg3pXjPIqOiNTliehlwJ9ccCCF_FYT0K38,5704 +torch/include/ATen/LegacyVmapMode.h,sha256=8EcLBLxOnPKRVBqdG1uOYWa3yitoazcKcFNSWWk33pE,953 +torch/include/ATen/LegacyVmapTransforms.h,sha256=q1qn5bc2yTlZJkZMjQQ1BSlEN6c1lLMCZJcrKbbAqXk,7998 +torch/include/ATen/LinalgBackend.h,sha256=SQjROHqkWakXnXNWO9FGhVme2GiElIWz9k8V81i7sk8,750 +torch/include/ATen/MapAllocator.h,sha256=M6MWF89_KBxge0DBieFkcvs4sjckKTCkiJnDCFGZPnc,3773 +torch/include/ATen/MatrixRef.h,sha256=CDbExbLStdCvaqTLqt7xFIkQLCgXgQNhT-HdivPDBYU,3115 +torch/include/ATen/MemoryOverlap.h,sha256=YbxE7UTcAX6LIZfzVGt-zvnAFd2VrJC5OVD27-YkqZI,1329 +torch/include/ATen/MetaFunctions.h,sha256=CTCyu8ARsTGOtWr0Wqn9ZiE4SLYQ8d_Acf7DqzT8MzA,1984 +torch/include/ATen/MetaFunctions_inl.h,sha256=03ky9P_qMhs5MwWKIGsevq7QbJvA931OeSPYGKvh9wk,16234 +torch/include/ATen/MethodOperators.h,sha256=opWOVIY0lE9N5SJK5_XB4iMjuzZybn6WPL2Mg2QAZr0,15888 +torch/include/ATen/NamedTensor.h,sha256=2Ccbg_TjhIsGFC1e6QAHsS79WUvjjhkKFOmf52bgG8w,36 +torch/include/ATen/NamedTensorUtils.h,sha256=mW1P7l48zsCxjcbIDYLciIwvYJospzfgD-JpdecROnw,7008 
+torch/include/ATen/NativeFunctions.h,sha256=FAJxia5RjW-z0Abkh5s2MRW-eTJh8o_FwLUXC8Bc4uc,62178 +torch/include/ATen/NativeMetaFunctions.h,sha256=2V8w6L9EW8qHhtoV5aRAQSwQ7s9U_CMJRiqlhDHoe5M,58671 +torch/include/ATen/NestedTensorImpl.h,sha256=Hki61cUh7GQDZjugHWGrdy4ujCzJFKg54ZwFMkfsscQ,10508 +torch/include/ATen/NumericUtils.h,sha256=Ff-htYloFWMsmWakl-Oyuq4QByAjcVr-WVzZYfqungU,5340 +torch/include/ATen/OpMathType.h,sha256=Ie464kKRhB2XJ321JyBS_3QXnkH5l-zP8LnMGJcBJJ4,1654 +torch/include/ATen/OpaqueTensorImpl.h,sha256=_t7kvAD2PkoGAIYQfI3zoCtjcG9aIvDuF0UxRsD70Uk,7021 +torch/include/ATen/Operators.h,sha256=SK3Ap_D-XQHhaPyEjeOY8vi85RqB6TJH97I7pe9A7pM,60307 +torch/include/ATen/PTThreadPool.h,sha256=XoZi3jPKkJZWTgsrJSM9JhhQKaa0VAJ5KebQYrudqFI,408 +torch/include/ATen/PadNd.h,sha256=slASLr37YJnwjzoBC49HtWzichdxhYkV92KkP8G6fNM,138 +torch/include/ATen/Parallel-inl.h,sha256=2sIG8SeGjEQ9KIAbmZADV55ihH1DUtTRGZics9QsdJE,2386 +torch/include/ATen/Parallel.h,sha256=HTvhEWY4382_QuVQrnD-8AdL1dgoTU3oke6IO32hz9k,4911 +torch/include/ATen/ParallelFuture.h,sha256=HZ6MOFJ2bNk1Ks-nqaMUVqp5PFYxLu7e-Km6LYAXCbc,312 +torch/include/ATen/ParallelNative.h,sha256=OQFHzvOXR4-kgl6GyW0xUJCAyogTFjscwq5qMkEp5H0,307 +torch/include/ATen/ParallelOpenMP.h,sha256=CL3441Xwv8xa5tcbW0QOuNA98_GlM9VV2DcjGblYP_c,1337 +torch/include/ATen/PythonTorchFunctionTLS.h,sha256=OV4er-KFz4enf85D9OCMvu1ofmWIv1hkj7LdA4y-2g8,1229 +torch/include/ATen/ROCmFABackend.h,sha256=_nuo5SeeCZtrIlBb6fnAYLbLIR9ZnPAzhAVALxmYPWk,754 +torch/include/ATen/RedispatchFunctions.h,sha256=0ZxCw2sM95P1HygFsYwS0_gQ_g_HdokmaiPrtb9EcoA,2238304 +torch/include/ATen/RegistrationDeclarations.h,sha256=iA9yWRodvlVNNDH-fHPQKviTejZC6ljk4PS3tZcB46s,917220 +torch/include/ATen/SDPBackend.h,sha256=rBvqB7H-1krs1xw4UqB4TOJREhKxcFU1V5IBLuLZ64Y,269 +torch/include/ATen/SavedTensorHooks.h,sha256=KY6WTcXFBve6v1yX2TJWBC-DxwKzMWCGTQ7YKwHimVk,2554 +torch/include/ATen/Scalar.h,sha256=xdFJVWXQK44860D9dbtCrS_dIekAxuHnaeKlrvfLzCw,47 +torch/include/ATen/ScalarOps.h,sha256=75E4Ow6xw85BlSK4qnIGznJR4dKzGc2RnBBn7qP0D50,1648 +torch/include/ATen/ScalarType.h,sha256=ex-_jwtx4oUXviBFRBVrRVkJtSaZMsXa2JELvxRT0DU,133 +torch/include/ATen/SequenceNumber.h,sha256=HVNCQ1sW1di-7_TzIcaxLdbRFjkEAFjxokjjXqraJU4,346 +torch/include/ATen/SmallVector.h,sha256=SHoiKu_mdXsJq0UTlSdElvvXCBP0ZJsjBsdxMS7B4so,49 +torch/include/ATen/SparseCsrTensorImpl.h,sha256=C_OwNOgRc_RaJqujRV40UlaaPd_sVS0rng7vWHyYsII,7315 +torch/include/ATen/SparseCsrTensorUtils.h,sha256=lwpFrnveq5id6e_EXO_RvdmKEwhGuE6_YqSM5q2W0Ik,18228 +torch/include/ATen/SparseTensorImpl.h,sha256=VePg7ZwWw3m3Vqbjc472wYQ-vqjsno9E-nvRjvqUIfM,15729 +torch/include/ATen/Storage.h,sha256=VcWXKctsV2y4A-q7YACqbA4B-hRxu9XrxPQlfE5xByA,45 +torch/include/ATen/StorageUtils.h,sha256=Eh67tSk2Vx9xHW59bBMg7wZ6Rg-GOvgtYJ2fIkPHstg,1357 +torch/include/ATen/Tensor.h,sha256=qHcsthT-fkjfFAcpWpbEmOAOwgZOfBzeyX2rHW9qusk,47 +torch/include/ATen/TensorAccessor.h,sha256=XVhywuAjpY8Wo1NnE-1OJJLRc3Rg6LOMvK6CoD8cI1I,53 +torch/include/ATen/TensorGeometry.h,sha256=2_tfTElYVJw0gk9UCzwa8XP8EsRSAQhhO7vFOs4CCHY,4710 +torch/include/ATen/TensorIndexing.h,sha256=FCvPNXaGhbCzFrWxysl-aUKdVuebzDqClz5iglAbLVA,24862 +torch/include/ATen/TensorIterator.h,sha256=w1RIKqxeiS8wmeiunVRduQ9Cm1-xi3XMunSh3EY0Lk8,39912 +torch/include/ATen/TensorIteratorInternal.h,sha256=bQVnD0TF5q1HXnqCPdNSRDa_e_eYAN8unMhG4OO6a64,2003 +torch/include/ATen/TensorMeta.h,sha256=DNMcV-wfwGnc5S5P6rOUglUrVR1I0bKGusdM4n5ZVQk,5171 +torch/include/ATen/TensorNames.h,sha256=ZKD71pEkjZvo7cK537WsZPxfBmVBtgO1CEt-t2EQc68,2646 
+torch/include/ATen/TensorOperators.h,sha256=XU-nDCEu3pd3a30vOprnTY9Edx3mk0MByVG1zlZJTuA,2542 +torch/include/ATen/TensorOptions.h,sha256=uSxXHUaxx0nKq6iXL8iu4qgb0XuWEADFDQmlNWslCXo,51 +torch/include/ATen/TensorSubclassLikeUtils.h,sha256=TMHqah26TN7Vv_neiBaqsxSw5Z3dMX0eLL66GGF1IS4,3317 +torch/include/ATen/TensorUtils.h,sha256=9mSWS8jNwZRFJGg4T-cP_5KJ71i_XMiJHisAcOX5ZJk,6148 +torch/include/ATen/ThreadLocalPythonObjects.h,sha256=spp4m4wqpaGRiZ12EQYGOjfJHKVCdjw4aE4Iq9lz7rk,628 +torch/include/ATen/ThreadLocalState.h,sha256=vd6RyYyFmDpEMh3d4U83NksRGVncXLR1Ar3jlHtid7Y,4433 +torch/include/ATen/TracerMode.h,sha256=Y-ZUcvLev3AhrwEP1K4K_GRVjco1wVtm5KU6KHFiAUg,5640 +torch/include/ATen/TypeDefault.h,sha256=2SUtJuipgP0Aho0EKh3CMy_TKkBYNQiNoIhaDw6S8Go,469 +torch/include/ATen/Utils.h,sha256=qm-Hv9mHK81-sZAxJP8fVr3DGIcvzncdB5W3Nd2zuPE,3702 +torch/include/ATen/Version.h,sha256=qxyNg02gIqQhC_ZS282Qgu0frus7RL07bXSYBXpIOh0,402 +torch/include/ATen/VmapGeneratedPlumbing.h,sha256=2OZpnAzVOkjOrM_ECW-ozXQfWuBjmrr3oo5pHjaNAxQ,1776366 +torch/include/ATen/WrapDimUtils.h,sha256=QH18Ku_vwxYuyTW2lk0jTwf0QRIzHfzrImLNMWOLoKk,5023 +torch/include/ATen/WrapDimUtilsMulti.h,sha256=mfGp49hqJg3uHiChEcZaP3NXx-AZSixHWqr9JcZOEQ4,1119 +torch/include/ATen/autocast_mode.h,sha256=4CmvJRTkJf2DRI6UWtQUWEnQYpkwQopgpsesrFP2LRM,42211 +torch/include/ATen/ceil_div.h,sha256=_TEAuNd5Nqes7nlpi9kP2yXI5F23wZzVvxvx3aCUPYw,521 +torch/include/ATen/code_template.h,sha256=DBC_68YyW_PW8cY0y49WVen8ydY1CRexm5dkRIat7iI,7196 +torch/include/ATen/core/ATenGeneral.h,sha256=IQQc92i1DuPkEZ7sKJ4DjM6a_zEyf0Y3bVzu5wDc984,48 +torch/include/ATen/core/ATenOpList.h,sha256=aakj8oiOCMjnG7fgAIaA70W_afez0btAF-x3PTPE9h0,259 +torch/include/ATen/core/ATen_fwd.h,sha256=TFN6NG2CHYp-O_gQIWLIAnxuuuCgZiAfnrFvQDpIzhs,1070 +torch/include/ATen/core/ATen_pch.h,sha256=NbMAYJgR2eL7Ow63ObuOOGxYmbq7FNqTE95Wsn8-0Sw,5239 +torch/include/ATen/core/Array.h,sha256=8kbQ-Pxou1H_OvMuBb1zydD_KFttpcpr8ORor0eukPA,1182 +torch/include/ATen/core/Backtrace.h,sha256=NHY_QVjlzACvtgnyHYa9nYai5N0aO0O9-ByeuDotpOw,61 +torch/include/ATen/core/CachingHostAllocator.h,sha256=ouknx5it8LruaMfy4NXRIIGn0qb-Xg7Eq4ZR6oUtH-U,27781 +torch/include/ATen/core/CheckMemoryFormat.h,sha256=RBxWRzYZl3YvVUyu0mgF-_kqe0FzwL5DlTerediJ5lg,814 +torch/include/ATen/core/DeprecatedTypeProperties.h,sha256=u9Dt7N1gfAyEiL_smw1z9n2Fd7aqLENpGrTTWTOqRW0,4018 +torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h,sha256=5SRdiKiCKqFjIgkP68i4o0aX7Gzu8pMom2yWE0FCDRc,877 +torch/include/ATen/core/Dict.h,sha256=60a0XDyb_Yoryek-L3ksKJH9hmLrKPkvTo9DBNL7oaI,13672 +torch/include/ATen/core/Dict_inl.h,sha256=ixptwh7lUJhZuJGfts2Lk7sGS6Wpgw0q1LHS64Qhb-0,7670 +torch/include/ATen/core/DimVector.h,sha256=Ls0_clT44nBVoSTR7RP-FUrDcb_ejj7JguM_lVyyXXk,292 +torch/include/ATen/core/Dimname.h,sha256=2rWiZJpRnk6Wfo8IiZWJd0Izn1f_mmn-r1WcmiXAeRw,1215 +torch/include/ATen/core/DistributionsHelper.h,sha256=aYUIDpdUtS9CREzbvaRB_JkUWH4QMp7gxPc8uq4zKSc,12935 +torch/include/ATen/core/Formatting.h,sha256=oM97AkBviueNwY7aDcE4CEFLnHcpBdyuj8EtOY4qlCo,718 +torch/include/ATen/core/Generator.h,sha256=NpJtnS68Ul7eqSC-7Vlz_R9WjqyxvYRUxpmyui4kQBY,6597 +torch/include/ATen/core/GeneratorForPrivateuseone.h,sha256=zBMSV40e4Gurb9R43otQpAgM97Z9xZ_G2xTdbNyM33U,1120 +torch/include/ATen/core/IListRef.h,sha256=tgBErnLE8heYtfrtWyqdN53ZDp_P1MhxLMLGACbvN1c,21567 +torch/include/ATen/core/IListRef_inl.h,sha256=KJMEVdM6J8YYJg8Enl2k9bCLY5t-vdTB5Zo0TCKdGpc,6408 +torch/include/ATen/core/LegacyTypeDispatch.h,sha256=7pRRzqbhIwm7wkqpWXV1BKPjR15BxiYWMRBCJ9cVwa4,4968 
+torch/include/ATen/core/List.h,sha256=BPPPrN4QwVFjE9dj-SSpqx-j9gGeqrrNxJOjDS-NWBc,16569 +torch/include/ATen/core/List_inl.h,sha256=dSmuHIzTPadyKk4vIf8PLdIX6Q1yAnsyplVfvOHYIIM,11105 +torch/include/ATen/core/MT19937RNGEngine.h,sha256=BvBrpPVLoygmSiLGPubPJlMOMwbwCzWFXgpk46y0C6M,6704 +torch/include/ATen/core/NamedTensor.h,sha256=mq9oOK9M2tCUBHAP2UtVeCpIPqxMeJNaXq3Ise9lx6s,5375 +torch/include/ATen/core/NestedIntSymNodeImpl.h,sha256=An4TFB0-qSROiv3pEkgwYNEP_kpAUHJaFzl2_bO6cDE,6147 +torch/include/ATen/core/PhiloxRNGEngine.h,sha256=gl4j8m7WKBoFYwiVlHB1rEF0ZoqQRjnE_hXNI7XOk10,7954 +torch/include/ATen/core/PythonFallbackKernel.h,sha256=nrEyVK3V-40iw0jR2hkbvWc_4Rh442UYiB28hWcufaU,1129 +torch/include/ATen/core/PythonOpRegistrationTrampoline.h,sha256=eu81ON9auDlHX86hHvI0fjIwiTn4sMw4KiaU5aCgeCs,617 +torch/include/ATen/core/QuantizerBase.h,sha256=gRuvu0vzE7CYmjuhq1--MvexJ-zabo2X1VSgCmBTq2g,2771 +torch/include/ATen/core/Range.h,sha256=47XY3n5sg3XIQ2gYgMPpTCD1RLQl32EVytCt4j7aHTw,443 +torch/include/ATen/core/Reduction.h,sha256=DMk5kCWiHjmL8VKSU6NIP0v1wNboJhmsKvTOQH4f5I8,413 +torch/include/ATen/core/Scalar.h,sha256=WPwKE4RciXVN9db99hpCLvDZAK39jlG1zlKffEFmgcQ,30 +torch/include/ATen/core/ScalarType.h,sha256=gzNZUMu8I-Zts17Fg2oXY0RLx6xTqy6XK_TIdCwxAus,34 +torch/include/ATen/core/Tensor.h,sha256=YHUdh67CCtXVxjDyupb9fr1STum49N5sT0VOpOOS6as,2501 +torch/include/ATen/core/TensorAccessor.h,sha256=4nSLpdq437Pzrx6o663LTrAhRpcJNqeHuFYxl0ApYeA,10680 +torch/include/ATen/core/TensorBase.h,sha256=Kebm9lIaXNp60g7MtVnGRv-EUzppS0Lon3RZpMcHT8I,39120 +torch/include/ATen/core/TensorBody.h,sha256=19xq7VNSqMm048IyYoWsDlZ26t49L36Pp7ctYEdOnc4,298490 +torch/include/ATen/core/TorchDispatchUtils.h,sha256=BovdvVDa5kO88mb6jiyWabkRr7E7p0TKDQ_NH5mEP9Y,501 +torch/include/ATen/core/TransformationHelper.h,sha256=Dj6V-dJtsXjESPTm-zxAsePkLGvEkn5SfzfjA2oPr6c,7029 +torch/include/ATen/core/UndefinedTensorImpl.h,sha256=gZf3eRbLC3HoqVKb2iiqACGBU_iTRgjareRV-Wub1DY,43 +torch/include/ATen/core/UnsafeFromTH.h,sha256=J6Ok0E3N54Y33xmyd0fPI0kOnYKUOtwfon559YnT-Yk,729 +torch/include/ATen/core/VariableHooksInterface.h,sha256=vWBqTy11Pte-sg7S8imcGN5ymCzxPE4CQFb0MpiNRs8,3621 +torch/include/ATen/core/Variadic.h,sha256=l40dcQD6ZmkpghGoEkIHyNi5xK50ct0vbq0CGiRzCpI,2472 +torch/include/ATen/core/Vitals.h,sha256=_FMxhm59ktqHyDv8imy0iwGeGQiYt15StiUm2hsfWNQ,2527 +torch/include/ATen/core/alias_info.h,sha256=NE6Kru3snnG05r_J9mitQaekpeqIG4whj-nwMo10ZV8,4744 +torch/include/ATen/core/aten_interned_strings.h,sha256=hSDA6X8KxZUdJ24L-hg8aCvhwfpwJ2TjKtHNH8In0-g,58800 +torch/include/ATen/core/blob.h,sha256=XaQtyXAJb0aGDhfV-Vd1cWWqI2gBO8dHG5YdoAXncQE,5448 +torch/include/ATen/core/boxing/BoxedKernel.h,sha256=3tNBJXI7iaBoj7qiWtupJGJJTqBi8F6uA5CkfGARoUM,8350 +torch/include/ATen/core/boxing/BoxedKernel_impl.h,sha256=dsqkF3DXRHni7DUYax7xuq-QaSWBnT8qsQG-N1ipD_w,3361 +torch/include/ATen/core/boxing/KernelFunction.h,sha256=9_2qFzxLV9kzTLbsdrRaQvyBTgp6AKI1_Il3qU-UTrc,9008 +torch/include/ATen/core/boxing/KernelFunction_impl.h,sha256=6hbhkD5JUSN9UJ7C7IWms-S9mhmA7coMZ8WH0oWslgg,12013 +torch/include/ATen/core/boxing/OperatorKernel.h,sha256=qQ3pujC0MeTGkO_m-Cx3gs0Y4zr32-VHFKY00pX-OBI,718 +torch/include/ATen/core/boxing/impl/WrapFunctionIntoFunctor.h,sha256=PTEoTFHJ0-N2ApDbpKamTMyLSX8HzTeNPLzOdNCXgYU,1366 +torch/include/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h,sha256=FpEKCbEh8L4tRXQIImZpS2M8MmBr6eYGX-yOk_DMMtc,1473 +torch/include/ATen/core/boxing/impl/boxing.h,sha256=8r5HqAMO1IhkS4vqvqttCXnKRf0mEIaDTHsMgugj034,13975 
+torch/include/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h,sha256=Q3xuJTQXtxnDie-v5bEAQFs3epBFPoF5vyhfI69zvYU,32473 +torch/include/ATen/core/boxing/impl/test_helpers.h,sha256=4tsSztQnU76NNyxLxjnWBI-AWyGifOOnDq72uC65Bu0,4541 +torch/include/ATen/core/builtin_function.h,sha256=So8-NHvU_URhTsdHErhCnyh_YmD4IkcmjafQK51LSRo,2134 +torch/include/ATen/core/class_type.h,sha256=0RRTDeMf-0RT75IL3zooOMLt1X0aMOnMhN9KSAzMmuo,14536 +torch/include/ATen/core/custom_class.h,sha256=oIQZOptRGJV2w908hvdHCd0aCW43mWpAmEU5AAKXThE,772 +torch/include/ATen/core/dispatch/CppSignature.h,sha256=uRRuUvV7xwRdt73agSMjAiNcPx0GzcEQM-sZ7jVJZG8,2449 +torch/include/ATen/core/dispatch/DispatchKeyExtractor.h,sha256=1O-MU4lwRNhO_ps8PxuxpDeNXCRegEbxITkUYBRAhSA,11207 +torch/include/ATen/core/dispatch/Dispatcher.h,sha256=ImN63kWE1o9wK1-Psn0SBItaMmNwIjxjzyZNZynWkYE,35185 +torch/include/ATen/core/dispatch/ObservedOperators.h,sha256=NyVoJTiGihd_fcQO-tzHXCBJh0QaY_aaM3JtiWCW2M8,346 +torch/include/ATen/core/dispatch/OperatorEntry.h,sha256=rsQJc6_QP1ONcjcknGmKQO6QV6I4HzBfL9zTfM1Dz1g,13354 +torch/include/ATen/core/dispatch/OperatorOptions.h,sha256=pY6BwShOgABKJYNgKvikONmXqiEWrpZXSguE_P7JBdQ,916 +torch/include/ATen/core/dispatch/RegistrationHandleRAII.h,sha256=xcpdAbTNfgfGAUmZQm6azJ42zg6CwpaThciTyT22iH8,913 +torch/include/ATen/core/dynamic_type.h,sha256=y8BVCTdRwkNgPmXDfsQu7PkoFWWity_yvsjMdmnYy3U,10983 +torch/include/ATen/core/enum_tag.h,sha256=XsEYgRq1C6AQk3rBgLuyjNBUWoU_2tTeQBMkGCpO2kg,630 +torch/include/ATen/core/enum_type.h,sha256=Yljv8ux8W6-NAQeilZuI_bdyL1xztqe5f5fFYeNLd7I,2942 +torch/include/ATen/core/function.h,sha256=24WTeWynMDGJbpDsYVBithT7l6VRU7RHd5bEzK1dJCI,3569 +torch/include/ATen/core/function_schema.h,sha256=8a772FnlEHPy5u5SdbMzUPfvV__77LUtlNhUkIkBMTg,24804 +torch/include/ATen/core/function_schema_inl.h,sha256=WT3P5xoNG9LuWZ3UV7l6KcVdprpjbMYV2WRmGq-n3Ts,2149 +torch/include/ATen/core/functional.h,sha256=NVGx0ofC04rv9MfOzCcqsSzubMY7v-sT8Ggj9N3wIDY,1518 +torch/include/ATen/core/grad_mode.h,sha256=ay3zsOf0ukxMxLVRkKEd5IOCU9syB7pAX1tL1AQvW2o,220 +torch/include/ATen/core/interned_strings.h,sha256=SWAHBjKumMsJHM0YSiTc3OqGLW59_qObCLAJZ8TNEo8,13763 +torch/include/ATen/core/interned_strings_class.h,sha256=Qnu3MhKsj0nEGQ7HRxCKpdhEiWfa6VnKsvfC4f-5NZI,754 +torch/include/ATen/core/ivalue.h,sha256=d0ux-nMD1uiWFgTIMyGrBpR60TCZJy5UYQrrhR3ItPU,53092 +torch/include/ATen/core/ivalue_inl.h,sha256=KXTPS98fTEjNwQzV0HNcHnqE9UempTUFAWci-2R66x0,90391 +torch/include/ATen/core/ivalue_to.h,sha256=1sZ1MHxxnyqiv1FLFg8KfoLnDw7kq8Q2IKWGOBHyVhg,792 +torch/include/ATen/core/jit_type.h,sha256=dVexhw1eTA6Rx6SVvUVTwir3mqhFCT5Blr0jXk_T6TE,74642 +torch/include/ATen/core/jit_type_base.h,sha256=o9XXQaf8lvscPkUPgEuHps6-hHKZ19lUo88NEHONjr8,23845 +torch/include/ATen/core/op_registration/adaption.h,sha256=_B5PnaAl6QWlPKi2IgycIZR1gu_ANPJe24ixfnVwo0o,3297 +torch/include/ATen/core/op_registration/infer_schema.h,sha256=4ymNXWROb8-WthAw1-1drtnSqY4pul-YUthZOuzLkQI,6884 +torch/include/ATen/core/op_registration/op_allowlist.h,sha256=V9VBTNE0m8iNQfMG5b87O7LcN4McnYYplUwDCcemiuw,6497 +torch/include/ATen/core/op_registration/op_registration.h,sha256=cBYBy_jHNuw7WO2lLXIh3p_ycTcrpdxA6ZaTCH388eo,29175 +torch/include/ATen/core/operator_name.h,sha256=TMsUGy1J8TO0cFir9oFTM6aF3N6ffcEufHeTYTwiA-4,3132 +torch/include/ATen/core/qualified_name.h,sha256=hs6uTmzCsudXlclzLMJt6CMEAJzhUFoSm3aLmj0WT0Q,4534 +torch/include/ATen/core/rref_interface.h,sha256=GQ2fMHcFaMV8uemkwNkbiNlQlrScEDlP68Qmf-TE7po,1249 +torch/include/ATen/core/stack.h,sha256=3McELWsmZ5sl1QMubV_u0_9g0KypE8WjAAUP4I66ArA,6386 
+torch/include/ATen/core/symbol.h,sha256=MpW9bAln8kwd5yllTP2Aeebsv41tSfJv5ncl43xCBPc,6020 +torch/include/ATen/core/type_factory.h,sha256=cFBe7PGGNbc2wk4RdCN6nLldBxp0n5pukvDTDFuUDbE,3355 +torch/include/ATen/core/type_ptr.h,sha256=TrlSyrHfXSMPezm_eIQYsMuNk2LVP3LrYwT30MWRcDg,1272 +torch/include/ATen/core/typeid.h,sha256=fR5DkdL7DayDhWT3Y69CQySt_TMFIsuCFAjV4Ihc1yc,30 +torch/include/ATen/cpp_custom_type_hack.h,sha256=938TgHVOXsBljP5x3gegAr4du0NGkJ8o2ZJ8XY-3-N8,5540 +torch/include/ATen/cpu/FlushDenormal.h,sha256=MTiddtcTeSZVkkfL7HQvP0_nzt_Q-cQxrjPoC_nYo5g,551 +torch/include/ATen/cpu/Utils.h,sha256=xUuUwKzPzVgbrFOQ4suHQc-K6HWCvZuJr02UODIsJmM,836 +torch/include/ATen/cpu/vec/functional.h,sha256=lUHsuFIHH4XHaUs9TOTbrr3rgVsiMB4ywI6JJwQrl60,106 +torch/include/ATen/cpu/vec/functional_base.h,sha256=n2aXD9QMn_9liyfV0Sy2DnofEo9rQPRFizHmTyk9PoU,15954 +torch/include/ATen/cpu/vec/functional_bfloat16.h,sha256=e7oBsoPnJuEPblZBwTgGaX071vuZ4hXJKQwiIWZ_JB4,26058 +torch/include/ATen/cpu/vec/intrinsics.h,sha256=0HLvhIHzYHtqDS46N3PyJ9CvfhBxqr-HpgHDGBFKE4w,2143 +torch/include/ATen/cpu/vec/sve/sve_helper.h,sha256=XuX0d1Oq9hVh8stOmUMjodDqZLP3bQQzFzrUaiUtE-I,3249 +torch/include/ATen/cpu/vec/sve/vec_bfloat16.h,sha256=34k1mJJI-Yc5EtdxmJK7X1JUWBoq5UVEDrm4-qWEL_4,20801 +torch/include/ATen/cpu/vec/sve/vec_common_sve.h,sha256=PCkgBFb409QvEZ-jcgoAsHs7exafzAcjaycQcNisf2I,8660 +torch/include/ATen/cpu/vec/sve/vec_double.h,sha256=c4EuE26yMaj39aJvPWCN-uH05_rkxT0tV65IJl54OME,19545 +torch/include/ATen/cpu/vec/sve/vec_float.h,sha256=y7IJspCCnIpoxrroYpHOxcjz1lJB4T8FMDaC_YuNf9M,27626 +torch/include/ATen/cpu/vec/sve/vec_int.h,sha256=SeHey9fppjVPhXjwCSE0BNYT8N0L95J3V6UiPRmv--w,28307 +torch/include/ATen/cpu/vec/sve/vec_qint.h,sha256=FuvR5dnpp5ITh22j1L_1AnsMbTXSXyInAK1gwGrYYLY,20649 +torch/include/ATen/cpu/vec/vec.h,sha256=Dj1cPVg5g7pH9EN3_ZGhgpb2rS8ERAyJaG9HwTnM_C0,1406 +torch/include/ATen/cpu/vec/vec128/vec128.h,sha256=a54IXRmXz7HwNvlenpZNnP47m59cV8B-ZWvDE0qeixQ,379 +torch/include/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h,sha256=dxJoRcYKYq8dGr713ZjL-cceM_agLiGhUIiC3vIhhxM,23137 +torch/include/ATen/cpu/vec/vec128/vec128_convert.h,sha256=B9qi7VIxVm9TVl9VZtuL1tBpSiAMt9kOqlvHLroSbOc,2052 +torch/include/ATen/cpu/vec/vec128/vec128_float_neon.h,sha256=2Id2bXjyGZIvZJTvfxFX9sQbvED5klSFx4IFL3LQeBM,20281 +torch/include/ATen/cpu/vec/vec128/vec128_half_neon.h,sha256=VIZY4E9Jw45CyKCETtOq7NLXgeDn_e0nPusEaWpspsY,20708 +torch/include/ATen/cpu/vec/vec128/vec128_reduced_precision_common_neon.h,sha256=iWgx3-z8n_QIIibYiiArq_jTsp-wpKSI0z7GoWf2t9k,10085 +torch/include/ATen/cpu/vec/vec256/missing_vld1_neon.h,sha256=-TgT2eWX5y60wb48zLv2cxLvJ1Q-MJJT8vTkqISZt8M,14179 +torch/include/ATen/cpu/vec/vec256/missing_vst1_neon.h,sha256=g1PRAShdtzXwFfgKJ173I0F-eNGO-r8dMc4Vi2NyFvU,292 +torch/include/ATen/cpu/vec/vec256/vec256.h,sha256=Qrw0qLt0kuRdMd8dlQSb415bpM2aCB-5PMjip6XRQ-E,13150 +torch/include/ATen/cpu/vec/vec256/vec256_16bit_float.h,sha256=it20ZC44GUwkka6zt0act8INkqRWtpXgXGVdqPpy95c,29220 +torch/include/ATen/cpu/vec/vec256/vec256_bfloat16.h,sha256=glSYCCVe_P6hQC1DpYEbkZivV69dq67Q8X2xWltr15U,9087 +torch/include/ATen/cpu/vec/vec256/vec256_complex_double.h,sha256=WljtXOUR-iepRwNq2VUOGAB-NBwFIqA1_8pBe33AcYQ,20031 +torch/include/ATen/cpu/vec/vec256/vec256_complex_float.h,sha256=NW9Z1PRyf4dGz9kBMTTeSYfmUAZ7KsOeW47vNWhzxAo,22194 +torch/include/ATen/cpu/vec/vec256/vec256_convert.h,sha256=K6a1FfFgDnrorj-wJpprcqM0mA3rqqdRrfTgL5W5e70,11398 +torch/include/ATen/cpu/vec/vec256/vec256_double.h,sha256=drnMcivbiLft-DuB8raev5IjX7lAviiAnE5OAy13cm8,15612 
+torch/include/ATen/cpu/vec/vec256/vec256_float.h,sha256=jC7f8TDK6Sq4w8ih8TK1kiVJHYyijR5eif3X1ZDtEqc,25240 +torch/include/ATen/cpu/vec/vec256/vec256_half.h,sha256=Zi9OOAdVsASepQYJVViEKACuB49RnGILaJOmSIZwIF0,8669 +torch/include/ATen/cpu/vec/vec256/vec256_int.h,sha256=yWPe9iOTwuc7vUEFjrY-O9NncSBKFbJ6slHGVYocdME,67728 +torch/include/ATen/cpu/vec/vec256/vec256_mask.h,sha256=hArgJrOoCPMpaXKtMkaAZXlBqyn5cI1-VL2dhboWbN0,9070 +torch/include/ATen/cpu/vec/vec256/vec256_qint.h,sha256=H7yhZy9Cu2YqxL4p3AAFw5WKqfyT4hqU4AfU2w37-tc,49259 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_bfloat16_vsx.h,sha256=j8a--3laX_H3bq7bBr3H8jLiI3F4LiGs9P5MDtLEyk0,2212 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h,sha256=CBsf6P71312wqWI6bKx-ZiSajj-ZpgunhBRF3hFm-S8,8337 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h,sha256=nGCjZ8NvC3c0oss4r-trwel-66hsNzfjDmCndhsBNJY,22361 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h,sha256=THw01IhY0KImJ38PiyoNs_eSdIB_8QIWNDXAn0Cz0a4,25142 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h,sha256=rujt4cNli4gWYt4gcoCOidoGuE4imj3iP5q_OB5uZEU,16718 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h,sha256=O_D0V4iGdPVg_eT44gVw-kA49ie6843pJiCePG5Uym4,17302 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h,sha256=auzePVNRR6cBXGs4_6603vn6lTkbjRmdUCBywz9mFaE,14709 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h,sha256=SMzZJkqZgyPtFiRvR-I4q_8yO2hAwQeaF9HAcrQgz-o,12454 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h,sha256=8Y4gSih2uLEqflyuUlUXeghkJvJz-qz6KTWcsKK_GiA,10677 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h,sha256=fOduwKt2RMEglcsw8AmmIOeimJUuWphyjiW1qVmOII0,10366 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h,sha256=AlMVrLEpVtV_2yiOAMwt1TOio57cEK7RmqtyDHPpv7U,17930 +torch/include/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h,sha256=3nunhmV6dS0tATiYFWe43N9bcrk7j02yzaxi1o-4fqM,18648 +torch/include/ATen/cpu/vec/vec256/vsx/vsx_helpers.h,sha256=-ww0xf7QhBYttMLJArxKC0V5UbD_sEk13lAYyMDNObI,21028 +torch/include/ATen/cpu/vec/vec256/zarch/vec256_zarch.h,sha256=N804I2itwl2OYYXpnz-Ho_p9AazSlDv-eTYPNYn81pk,104938 +torch/include/ATen/cpu/vec/vec512/vec512.h,sha256=xDx7KemRRqJrD53r0ZhGraFe177PpMKgS-xgo0ZgKfI,11545 +torch/include/ATen/cpu/vec/vec512/vec512_bfloat16.h,sha256=n-pTsZ4nrFUWW5kCxw5rMREBNDhvxTrpLsDhltscM0g,68331 +torch/include/ATen/cpu/vec/vec512/vec512_complex_double.h,sha256=5HCEBAaUgJuRr7rIEaOSFBv64H5SxEtigDYdul0xd6E,24670 +torch/include/ATen/cpu/vec/vec512/vec512_complex_float.h,sha256=p7KL-mhLgRNyVhDnQrQxD3ky-6r6itMq9EDxR_9sXKs,45731 +torch/include/ATen/cpu/vec/vec512/vec512_convert.h,sha256=MSr7Pi87aplKxIciNqhhe28uD9u647H8AnGOVEyQO3w,10296 +torch/include/ATen/cpu/vec/vec512/vec512_double.h,sha256=Ygd0wCLK7CuV_46_Q-Uoug59FX4d9nfhRXe8kdfBVS8,17182 +torch/include/ATen/cpu/vec/vec512/vec512_float.h,sha256=W3HUrTZSpVz-eD0HpXBZJI8inEfKfWv6QfVG2k9PL5s,27987 +torch/include/ATen/cpu/vec/vec512/vec512_float8.h,sha256=uSF8YZUUwQfDaaZ6PspNgKvfmqYHPa15C2BZZGXeMxk,24347 +torch/include/ATen/cpu/vec/vec512/vec512_int.h,sha256=l68W0VHR8j6VSvBnuGNIQ0EU3-h8DtmRhQpjIx9L7Go,60849 +torch/include/ATen/cpu/vec/vec512/vec512_mask.h,sha256=6TA9IfrHMWPgD3ay4njyg6BVO6dKxWQcaYfyktrS1fU,12688 +torch/include/ATen/cpu/vec/vec512/vec512_qint.h,sha256=GV3ZezZKAWxDj2tZyo--qMvzu27hm3-S8V1BGXQ5KdQ,52211 +torch/include/ATen/cpu/vec/vec_base.h,sha256=OERs0wW3QQp2B2iPRpAvXGB6Y88VIftnc_j3ZjSHmsc,50319 +torch/include/ATen/cpu/vec/vec_convert.h,sha256=7wM1Cti1xSJ3HwtviTlj-z6Nll1EzMYNoAnwjxnJWyY,2368 
+torch/include/ATen/cpu/vec/vec_half.h,sha256=MmH3TUmkpQqbPcD4sH-ldvqdD3j997jmrY3qbH_VLVQ,4791 +torch/include/ATen/cpu/vec/vec_mask.h,sha256=hOakyLGXq8iIfDIp_L53FbRcljeQEXKXeULrrv-bf3A,10271 +torch/include/ATen/cpu/vec/vec_n.h,sha256=-9owkQcSHX_VGNgxU86dZn4FPmtbqxD_4gpj9OkUGME,13866 +torch/include/ATen/cpu/vml.h,sha256=Wa4w5yQBvAU_JoIj0cSjF_3ofYNRzxGDefRBCHDrHXY,6245 +torch/include/ATen/cuda/ATenCUDAGeneral.h,sha256=-37BTZQ1d8h5R2_fTgelbx4RKkKYjI33FB2x2sf8jaY,199 +torch/include/ATen/cuda/ApplyGridUtils.cuh,sha256=8F8ridaqhj3mdDVgqPLSFbUaUNNEqxPORGfy4YE-c7U,1356 +torch/include/ATen/cuda/AsmUtils.cuh,sha256=1zIB1Nz8ZhhFQftcLwxwEeYo7IINf1bMPFd5XOVyflQ,3543 +torch/include/ATen/cuda/Atomic.cuh,sha256=nIO6Qn5l4_OPpYfUyCogfxdTkf0XOWrrLAzANGktiy4,27796 +torch/include/ATen/cuda/CUDAApplyUtils.cuh,sha256=gQKAWnTaN1vEGFvqkWauagW6l7imtH2IMEFr8Hy0D-U,20983 +torch/include/ATen/cuda/CUDABlas.h,sha256=MlfvPHCu0LbRcGnj7zrRKlUqNPtJQgFdxw7uHtSQJdc,16426 +torch/include/ATen/cuda/CUDAContext.h,sha256=cNOI7xgC-x_8PsQjeI6Usa4giKkD-ftPG_vIR3n1Wsw,247 +torch/include/ATen/cuda/CUDAContextLight.h,sha256=em83V7tAyfhqnKHCMBaJzmZz83z4aeS_Qrn7igVK9do,3013 +torch/include/ATen/cuda/CUDADataType.h,sha256=4JTbnmGsh3rt7lXfeMe4qauAYi7-hRKwpnbU35O9_p8,2889 +torch/include/ATen/cuda/CUDADevice.h,sha256=t7S0OV_AsARopbfquC6hGWnO4M0PgvwPwTTYBn65SRA,555 +torch/include/ATen/cuda/CUDAEvent.h,sha256=722oDT3mCnmkokZHNfBXfmDKIGw2Hrynnwrusc9iK90,9289 +torch/include/ATen/cuda/CUDAGeneratorImpl.h,sha256=tlE6fy29RKzYo7-eMidtujOpNWN0E4lt67weEx1y72c,6262 +torch/include/ATen/cuda/CUDAGraph.h,sha256=7lx_zbA75IXVPD5mkqoXxJrqoExgkaNOf8NV9EmY5po,3450 +torch/include/ATen/cuda/CUDAGraphsUtils.cuh,sha256=S5JvMK81gjexaeJ1EBEOc7K9Npt-4cpLWXsIgpfsLtw,1954 +torch/include/ATen/cuda/CUDASparse.h,sha256=CAQ_5Wc2lBCdvyRXg7wbbiK88UTz-a0EhCMmUIiySB0,2671 +torch/include/ATen/cuda/CUDASparseBlas.h,sha256=sC3nFgkZqNlhyXMQGpKcEsnAN3z7RwmpcVEmMoY25KI,13095 +torch/include/ATen/cuda/CUDASparseDescriptors.h,sha256=I4EoclUNNl3iILr6x4J5fdN8Udlj8k0eDZ9p24CX3b4,9637 +torch/include/ATen/cuda/CUDATensorMethods.cuh,sha256=NMDm85ZoVJgCF5a8IDFoYtTC6cjGGLbSRwZm7k7mocM,285 +torch/include/ATen/cuda/CUDAUtils.h,sha256=kShPcd6MMHM9udS_bqKHchrN6sHuf28T3o2GrpiiYIg,436 +torch/include/ATen/cuda/CachingHostAllocator.h,sha256=5ZBOg0kow967_zicCsDRvK5ex8bE66_PlMQA7Xf09DM,3102 +torch/include/ATen/cuda/DeviceUtils.cuh,sha256=5Ys4trVnuxkZK7YpVQTXTa6dYf9t--bR1u5zU988l3c,3401 +torch/include/ATen/cuda/EmptyTensor.h,sha256=JVhy_vgi-l5h8JU23Bi2tExQ-nP6yQdjYbeo2zSBuMI,1250 +torch/include/ATen/cuda/Exceptions.h,sha256=8v8QhFUBKvaTBHdGbFGn7V7e_jtgqAqyBVB3I03qNtI,12818 +torch/include/ATen/cuda/NumericLimits.cuh,sha256=l8ZMeSe4Z7YDm3eZOioax0qLXDxfW-QqErHdDzSTQis,5335 +torch/include/ATen/cuda/PeerToPeerAccess.h,sha256=BajqGhURmf0x3cWqQdQpJNRPJyKkGdFJKc54x4Vk4mg,306 +torch/include/ATen/cuda/PhiloxCudaState.h,sha256=_l4MLBwVfl97Y2lyUlJ7N5UKKPwzyWii7vZVDnY-LR4,90 +torch/include/ATen/cuda/PhiloxUtils.cuh,sha256=gXAtRGIoTgIVLwG5aZhwbggg9tQGu3_t6Xi7yfFPFdQ,99 +torch/include/ATen/cuda/PinnedMemoryAllocator.h,sha256=nfbOKweGeDbaFhvFSb5iIJzayl51SweNWuTugUC_iuk,233 +torch/include/ATen/cuda/ScanUtils.cuh,sha256=vly6Da0bLznDlFg6QD68wCm_L1W2ouZM5zzIzF3xtwI,2105 +torch/include/ATen/cuda/Sleep.h,sha256=suesWOumhYyEhgLdpcQR69Dh683RE22IfP7LtilxRHM,332 +torch/include/ATen/cuda/ThrustAllocator.h,sha256=XnhZ1t6xSoCEW-w2IDpbqye2mSNy_T-dX2vIfEGetoQ,528 +torch/include/ATen/cuda/cub-RadixSortPairs.cuh,sha256=k67rEqNM10Di_AULPjRbQaV_0UiXrJrVBIgy4jcB4m4,2279 
+torch/include/ATen/cuda/cub.cuh,sha256=roVGOpWjnKEZ0TgFE2SJgvUcF2VF-OBZUmYKseb89vk,23186 +torch/include/ATen/cuda/cub.h,sha256=9FrasgI3FN9cI2cbZ-APR8lyGf-Rx06_EkUbmtNzsaE,3455 +torch/include/ATen/cuda/cub_definitions.cuh,sha256=Oz2Q_AyivGFhtFPSS7K99ziO42trv7dE-WcCe8B7W3Q,1510 +torch/include/ATen/cuda/detail/CUDAHooks.h,sha256=HUjzZY7rqXrEgl1PHcEPo3i2_k8CUdRXgrtJYROpK08,2729 +torch/include/ATen/cuda/detail/DeviceThreadHandles.h,sha256=97vDhT1VlEy3DJZB-5BmNuV0U3_cjOdr4P50L8e4Lik,7168 +torch/include/ATen/cuda/detail/IndexUtils.cuh,sha256=pZQVDPSx-nrWRHaKQ1D4ErRuBE96jJP8GHTGK9VaMIk,903 +torch/include/ATen/cuda/detail/IntegerDivider.cuh,sha256=jYQGvBlRr344Q401CFlF8fErsVj8_ZL5Zz9L7GJJ58Y,4143 +torch/include/ATen/cuda/detail/KernelUtils.h,sha256=EIrlefQ0rrNS_VY4--nSezZTaZ9p4W238ydXB0zNV1w,1572 +torch/include/ATen/cuda/detail/LazyNVRTC.h,sha256=WMzEfhr89A1AVJ-6pP5YRIXwxnZOEjLgRUlcpTtsBHY,231 +torch/include/ATen/cuda/detail/OffsetCalculator.cuh,sha256=d7em1tcbq5RF2ajIyNGRdo4PeuYuIl78szlycUQzp-s,4506 +torch/include/ATen/cuda/detail/PhiloxCudaStateRaw.cuh,sha256=LWCXbSKG02-cDaIs3nJAL8sFwaJ1tGKGF0rxBp24iCU,1401 +torch/include/ATen/cuda/detail/TensorInfo.cuh,sha256=96mVn_puxthtLQmp4XnI3uv0rBm3mB1Ll95X3LHiw3U,3355 +torch/include/ATen/cuda/detail/UnpackRaw.cuh,sha256=iDEYhOID9jnlai0_hcgWhZ2wnzCQuUwI_OKq7URVgU0,1773 +torch/include/ATen/cuda/jiterator.h,sha256=UUmWS5iO_7CC9yILbli86y6pxVqGsufKGXTElgAuXEk,1000 +torch/include/ATen/cuda/jiterator_impl.h,sha256=8_mj3vLsZnghVYqNpnY9BwtfBx9TIrEg9eVpm518I20,7351 +torch/include/ATen/cuda/llvm_jit_strings.h,sha256=sgRyT8521Q5edS-Qb8jrj-Ca4sbolZtP9O7xBFws4GU,442 +torch/include/ATen/cuda/tunable/GemmCommon.h,sha256=mrUDMLpI1hmLXUfbkjpS4XxhcCCzo4jYvJGfr_qrXW4,22516 +torch/include/ATen/cuda/tunable/GemmHipblaslt.h,sha256=3Wr4_-rw6ecMBhtSOp48GYSZRZ8AOCvHsN6Xn2DyLuM,21705 +torch/include/ATen/cuda/tunable/GemmRocblas.h,sha256=qXHjdlaxNn0p_BnFYfMmTg7XyUCj_UdH1qSbZSiGbx8,10671 +torch/include/ATen/cuda/tunable/StreamTimer.h,sha256=T--DrfI2qULg4NozdGGK5KaTCk5GJUgBtM2S3rgBx78,1095 +torch/include/ATen/cuda/tunable/Tunable.h,sha256=NE1TibxPxwjToWirPsScsfICvA6_VavazezGXl4myWA,7779 +torch/include/ATen/cuda/tunable/TunableGemm.h,sha256=n8-feksiuBGQs7x6-RAj6v12RTNW9_Qx46V7Tqy4OQo,9886 +torch/include/ATen/cuda/tunable/TunableOp.h,sha256=3V_u5m0By0kkpHY3opMO1Efj5qT--0kUWga_Yasng34,15124 +torch/include/ATen/cudnn/Descriptors.h,sha256=G_BwvOP2W514rGnib4nb1Zxco1sx8kSM-fN7Hk5XISc,15452 +torch/include/ATen/cudnn/Handle.h,sha256=Nf75UCb8ZTtzwoGTbZamRPEzbFnJAz_bZR9SIAggeqI,202 +torch/include/ATen/cudnn/Handles.h,sha256=5VYWmo7Rbp7rZ7tqo41LW8uPffcQMr3BnOeraCNXyFg,46 +torch/include/ATen/cudnn/Types.h,sha256=fvPWpAN_ZVdy_aWFG_TmCw9hs291DTyHt3L742WjX5c,324 +torch/include/ATen/cudnn/Utils.h,sha256=-qGtL87I43FefzKtAWQutgqmTIO6qhG8p2c10JdUYbs,617 +torch/include/ATen/cudnn/cudnn-wrapper.h,sha256=V5j0h5KwKmPn6TtQAEXoDnPtXkCrnfBQ7YKVL0vxiS8,556 +torch/include/ATen/detail/AcceleratorHooksInterface.h,sha256=gMXG2BLO4VzC84G3Eq0JLlGWIBFgZzXULp9rrSM0M7Y,3090 +torch/include/ATen/detail/CUDAHooksInterface.h,sha256=-qziyRQoArjkHw9juMlRCWru1mVsfdqYqF0cY-ygP_U,7805 +torch/include/ATen/detail/FunctionTraits.h,sha256=XisBM9m6dhb6TSpOB-xLo8dxXVDyHD1DU8hT8I3VKsM,3178 +torch/include/ATen/detail/HIPHooksInterface.h,sha256=eShM1keIXb7t8ZSqRlCByV2VS9U4LD2R4m_h4q1_7Pg,2071 +torch/include/ATen/detail/HPUHooksInterface.h,sha256=xJSyi04CUeHRVi2twliBySvPYZY3F0zxJyGoKXL3UHg,1502 +torch/include/ATen/detail/IPUHooksInterface.h,sha256=k9fDhlqBRAbWkmbU3bcP8K2fBB0Z7V6JP_mkq4KcThI,1287 
+torch/include/ATen/detail/MAIAHooksInterface.h,sha256=0EraX-fUjg5bIqUSMiNGUQGEG4DBs2ucwPu1sVcwQ_g,1317 +torch/include/ATen/detail/MPSHooksInterface.h,sha256=i8sqYx6eUxLOPVsefApOMIOX53c6-kRZ16p_zNIk7LU,3865 +torch/include/ATen/detail/MTIAHooksInterface.h,sha256=f1ZSPKNugEYO9UqfWFy0rSFp4w8ODNuBJqGsl0XdjAI,4433 +torch/include/ATen/detail/PrivateUse1HooksInterface.h,sha256=lPuxBW_VCObijc4TDA6V_volt9CCgo-fYH80S_cSKVk,2578 +torch/include/ATen/detail/XPUHooksInterface.h,sha256=K_THShPDOYwBAbxYtdeOUlREnFoKwaXh5MbSZUk9704,2546 +torch/include/ATen/div_rtn.h,sha256=Q7k8n-6y8wBgGm7PCYEtbiRwpzevrhzHGZIRR15vcTU,222 +torch/include/ATen/dlpack.h,sha256=mFxGMjYMkTgwkShffsf1qmzoh8TiSlYgmvWogQmpC80,7243 +torch/include/ATen/functorch/ADInterpreters.h,sha256=ro4Qx9WgX-4S4IIR5iyt90lX2ByeUsW9UHaxQxG7IqU,1598 +torch/include/ATen/functorch/BatchRulesHelper.h,sha256=aAG3tefXbRwRscL1clVIPmSSdoJtezMYcmqRCb6gSwg,19191 +torch/include/ATen/functorch/BatchedFallback.h,sha256=6nc7zt07GiUdovhKWujknb1z1okQ2d-ppKQoTwgcTDg,3520 +torch/include/ATen/functorch/BatchedTensorImpl.h,sha256=vj9Nh81LPuwuAd934xsPUCG2wSOSr1P_pOSDrg6ybQU,6564 +torch/include/ATen/functorch/BatchingMetaprogramming.h,sha256=xQKz8P2IAt379YxdQtii7q62ZSdZEbF4MoIiiTt244c,5116 +torch/include/ATen/functorch/DynamicLayer.h,sha256=nFSf6Z7RUNNVxihqen9YqL-Pgtcl2C20CWAmZ78K1NE,5685 +torch/include/ATen/functorch/FunctionalizeInterpreter.h,sha256=0waU2epkJRIlw0EGIUCQ0mbhuGvLC463hKkZ1OgDT6w,929 +torch/include/ATen/functorch/Interpreter.h,sha256=4Nw_s6c2Otm8JJEBHr54CDAItN_CGLLvqZ6gny8cj3s,14472 +torch/include/ATen/functorch/LegacyVmapTransforms.h,sha256=c7obTpnca0mBbay29WvdPL0QAtV-oLE4x0QY86gSL7E,8439 +torch/include/ATen/functorch/Macros.h,sha256=Z8hbiYFWsvcM_BC6o_J4Pil5qQzy08uro3On4FJfCnI,53 +torch/include/ATen/functorch/PlumbingHelper.h,sha256=tvnthIraFvLOe1XIM47j_AM7omsUnXNHwCgStyP44pA,2917 +torch/include/ATen/functorch/TensorWrapper.h,sha256=E-heUCT274eyak8BnUHCl5ju-gA53bJIJoNiQ-N_C5o,4129 +torch/include/ATen/functorch/VmapInterpreter.h,sha256=sy5qLelkkoxdFVaK01jZW0x59kMu0r_ODM89AFtYQlg,982 +torch/include/ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h,sha256=xkwaSS6EDZxsz_-JieKG6PHDOZchv5JPd2EVrP3SMSM,1016 +torch/include/ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h,sha256=kRh_G1bry35jcSl_NVlZaPJLRybg_5gO-nieLSRbKNA,535 +torch/include/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h,sha256=7gcdZp8bNj5NoveNJlqGjmrjMjXqEnNTJa_3Xn94DMU,15553 +torch/include/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h,sha256=ZyWygxLAUi9M2k_oxEcHNTdlge2PzNDJT2Gh_GSpo8Y,4649 +torch/include/ATen/jit_macros.h,sha256=cfnCnYngW-RIzvgPQFYJ1ioU9__tjop3o7qbYI9NUwc,237 +torch/include/ATen/jiterator_macros.h,sha256=4HKQ1TD4xob_WnlJRrxQLf6Mm2WiDKylhoeE-Agn0ms,1544 +torch/include/ATen/metal/Context.h,sha256=3w91OddOUfDDvQeeNwiIS1SQbfeLAuziEILIonlBzII,705 +torch/include/ATen/miopen/Descriptors.h,sha256=vYRgjvdGtdKkkpn09p7xi10cnqQ0_S_8xh3eiiN8wFA,6831 +torch/include/ATen/miopen/Exceptions.h,sha256=XPTOzEKeFSeuwImwaC0ywu7X99xv6Ujv7wCjADcQygc,1117 +torch/include/ATen/miopen/Handle.h,sha256=zpkdgahmce5x0cqVhd7nypNXsAw-_3TuDTtDEe70TaY,198 +torch/include/ATen/miopen/Types.h,sha256=S_VPo90v_x150HSidMF_v9h2flBMJhUHZzQsU0ynXWs,283 +torch/include/ATen/miopen/Utils.h,sha256=H-Xt_OYpmgEwhSan-TAeGsgDbLPT89b1MJTgRCdhG1A,419 +torch/include/ATen/miopen/miopen-wrapper.h,sha256=qnfwbu1F5TIpy3ILj5LDjwpoVm-YF1qprFLQ-cSINB4,549 +torch/include/ATen/mps/EmptyTensor.h,sha256=QxeOGTk1aqjG5AoNvtoEjeKqecm2E0LtEwtAX-Wco8I,783 +torch/include/ATen/mps/IndexKernels.h,sha256=bO1RX-k-88VZHAPVvnaEk105emDrqoKF5N1ilta_5ek,9696 
+torch/include/ATen/mps/MPSAllocator.h,sha256=0ZW0G8jhUmgzksMhoopD7CUXukVNYjJJzzmkGM7ZzZU,18993 +torch/include/ATen/mps/MPSAllocatorInterface.h,sha256=uBgPZBG1FF5-Zbm70P_MqvH4F0slPDea9thIFpMnAE4,2786 +torch/include/ATen/mps/MPSDevice.h,sha256=DBhJMLthdtdaO0I_NEvc2MluLX_WnltBuBSjUFgVycg,1845 +torch/include/ATen/mps/MPSEvent.h,sha256=B3eR_4hbyfiSO7T75AXwgxeimzPbwbhQCZpW6kCENZo,3673 +torch/include/ATen/mps/MPSGeneratorImpl.h,sha256=xO_zDPgPtupQLbNlo0Desfo_UdWosFWehEAvuqi_oO4,1623 +torch/include/ATen/mps/MPSGuardImpl.h,sha256=Ig-AN5VliedNw-LvCeqx-FNEmC1mn6L7_aTWqVeuTek,5601 +torch/include/ATen/mps/MPSHooks.h,sha256=2Aeo4Ytkj0NvWSUNwKgDW1pI0DsE7kP8pkpW84FmzG0,2373 +torch/include/ATen/mps/MPSProfiler.h,sha256=QO0dzqBmNMHsFkyLcUQdt4J6BIc7611xvn3dbPvWCB4,17057 +torch/include/ATen/mps/MPSStream.h,sha256=iFxw-hQOwZvkYXttUyfU2JL2hyj9J-A81rkkKVPqCsU,4765 +torch/include/ATen/native/Activation.h,sha256=w-eq24SiSpy1fB8tmq5cV-O7M8pPL4USNS-7PX_AUwY,3588 +torch/include/ATen/native/AdaptivePooling.h,sha256=hWxPmjg1DCSNnCrJUtmOyc8PXE3-kPjXeZDrj8tKeO0,2477 +torch/include/ATen/native/AmpKernels.h,sha256=Aw5YDDuhaoP2GiUVIV4bRaa884ABelzmi_nOTx8eC7I,645 +torch/include/ATen/native/BatchLinearAlgebra.h,sha256=dX40wdYBja5PTmR5Ma8njRwvxAbve6PU9uGry0lTlcI,10217 +torch/include/ATen/native/BinaryOps.h,sha256=NGYNT-UF2s-C3WSxFhzR1ZL_3QZlr47mJzlGaMeUY0E,6051 +torch/include/ATen/native/BucketizationUtils.h,sha256=h0dygvjMzoURwdKPYUbw6XBbzI6_J5Vgbk534EJHZZ4,7962 +torch/include/ATen/native/CPUBlas.h,sha256=DTG5y_-mo88WI5DyV4nfKcZQj710xgTJ2VHtPVaWAJI,8669 +torch/include/ATen/native/CPUFallback.h,sha256=-qaQlDTyFDJZsaiWybF7bQU5Dr3SaZ46l_I201qDDu8,2459 +torch/include/ATen/native/CanUse32BitIndexMath.h,sha256=U52Kul-OwisG0btJZF2uE4f088sCDlhnbi02Tm1V3rc,255 +torch/include/ATen/native/ComplexHelper.h,sha256=GrvKp2LRHsH_Qd68Yjmp_Hznh5gNTqUN8kWAW8fXfQ0,4145 +torch/include/ATen/native/CompositeRandomAccessor.h,sha256=xPsPiT87SI6A5HgfosgHXxs0ej-EtIUfVqmH4jHSyxs,910 +torch/include/ATen/native/CompositeRandomAccessorCommon.h,sha256=hZmBUjm5zNTylyMM9op-YStUCqLK2xl6D5e4ahc2ZOc,6996 +torch/include/ATen/native/ConvUtils.h,sha256=gfKRcHxKLScPLdXB4jx0SxyfED4JGkfe6Egveo_igUM,20136 +torch/include/ATen/native/ConvolutionMM3d.h,sha256=0hDCvqA3y1Ha7WhuGQx42mF01mACnLxvFGL25iL_r3U,354 +torch/include/ATen/native/Copy.h,sha256=5-Wh5nVLhuPFMgj79OxBDquqvp-LfO5o_7k__ScvHAg,392 +torch/include/ATen/native/Cross.h,sha256=J_5BCt9RdEfgYe7Q8CvtPngfUHx6eaJxlfDKfxYbqWs,273 +torch/include/ATen/native/DilatedConvolutionUtils.h,sha256=vhVksJGxl7x3MEVJ3L4F_TFEuQYooiBpsKt1YVyvcH8,6631 +torch/include/ATen/native/DispatchStub.h,sha256=gipDcGk8JP0fnVoBJZ4LZkg-sq0kiCJadUobTt06gW0,16514 +torch/include/ATen/native/Distance.h,sha256=RTDQPiS-wsDylb4xOu2vRL6ISammoZM0AN8h6I0TBxk,740 +torch/include/ATen/native/DistributionTemplates.h,sha256=SH0WI8jN6QrMhZaDsJd2IJxdEw6QyD1T8zDo1Ayx8pw,18888 +torch/include/ATen/native/Distributions.h,sha256=tUPkbp644uqRp2qY6-GzrXWTIZNmog0zqmQJWSSso8s,22090 +torch/include/ATen/native/EmbeddingBag.h,sha256=4HGEr_ToyUfCTLMbTZUwib2BBjW1W_5UKoxay5hkJl8,5368 +torch/include/ATen/native/Fill.h,sha256=G-Y7r7VjTtPdQ6V4McYr1STrJAblQyC4RBjedEasTak,417 +torch/include/ATen/native/ForeachUtils.h,sha256=n5XpCzQdSPK9eogdl477X6RqgmXY92irgzdCwd9H92s,15476 +torch/include/ATen/native/FractionalMaxPooling.h,sha256=BkLXkSARJbVcWovxqf7JyevhPpEN5PX8ly-dkbTMKao,2241 +torch/include/ATen/native/FunctionOfAMatrixUtils.h,sha256=DxD8VGsuy4eDch3LAn7dE2tNpfEoiyHd3oA7AeRJ2F8,408 +torch/include/ATen/native/FusedAdagrad.h,sha256=8TnDNnQYyZm4I9M7g0Y6UtCOOpYt32tqxMGS7XFzbNI,515 
+torch/include/ATen/native/FusedAdam.h,sha256=HcmPv9VWj9Cs9h4BxVmfHvLp8ybmlcQN5y2MS-Y0nyM,710 +torch/include/ATen/native/FusedSGD.h,sha256=F1moSFCTk7jISyYoNOroM6Aogx-3iNzVg5JxWL4XuzI,537 +torch/include/ATen/native/Gelu.h,sha256=Ex-Nqv6K8cMquZ2TeurctdXuqwO6r0gwo_lMxx_smIY,876 +torch/include/ATen/native/GridSampler.h,sha256=3vHkG4USE2wN4HAvvjqx_fSTXXcVvsEmj02eyVwDOxI,10705 +torch/include/ATen/native/GridSamplerUtils.h,sha256=tdVy5xNQZV1WTidufBgcAksfXtqwv4p__CdSTa4cS60,3604 +torch/include/ATen/native/Histogram.h,sha256=4P0LLz7d89omvU22HnAZQEboz56y0YZoKKsGSiJyAHM,761 +torch/include/ATen/native/IndexKernel.h,sha256=piBtn-V1Yz4eNmuQisPY2HZI4CRxFg6tPlcrlc3r8Jk,1744 +torch/include/ATen/native/IndexingUtils.h,sha256=TW2cKjC_z9otJUx_LZ7IBaZqgsueWw8phVCco2-fE4o,5789 +torch/include/ATen/native/Lerp.h,sha256=3rnb_-pij_rs7Yz5amtrMe_259Al8rHgrwyJaW5794o,1507 +torch/include/ATen/native/LinearAlgebra.h,sha256=sFafYjeSEdcNRAAAoupbGafLi36EYM-IDNUr9Ex-mhQ,316 +torch/include/ATen/native/LinearAlgebraUtils.h,sha256=1Fu61dliAnIXE5-jx0YoCjGlalOjknQXgcpJy6H3joo,26981 +torch/include/ATen/native/LossMulti.h,sha256=nk6hc3M5ZbWrN7PAL2IIPy-iXOS6PSa7yQXNNrSAd9E,2183 +torch/include/ATen/native/Math.h,sha256=C9sTRm5sdJAQX5EGtvEhTzbxvDyTuxQMmJtDerROwn0,145772 +torch/include/ATen/native/MathBitFallThroughLists.h,sha256=APedlL7LkuLODYxBKrnnu52Otm8tUxctz5v4KPJolAY,4207 +torch/include/ATen/native/MathBitsFallback.h,sha256=ST8hG7BE4LCp6xzScyH1sXwVqqfGv28M6Iw_eWgpRI4,7475 +torch/include/ATen/native/MaxPooling.h,sha256=MPog5iqNPVFWYCgUV_7QRLoWYUUzjE_i8sPx0rsopto,3365 +torch/include/ATen/native/NonEmptyUtils.h,sha256=LQZRFh6lr1mK38DiWjlJYd2tav2xzauFZfHevAWIZdo,626 +torch/include/ATen/native/NonSymbolicBC.h,sha256=lJJ_tyRLmmmNivbYh6To1Z4dBI5vJhWISVwrVdoU5Wk,2902 +torch/include/ATen/native/Normalization.h,sha256=XMTS5KKEP6qt8LMnd4Dum2G34idEoew9tI5hwV-qazM,572 +torch/include/ATen/native/Padding.h,sha256=a-W9-f6uCk52Lf9HZh1BttUWzq-G15DQPa86HqkTHXE,2124 +torch/include/ATen/native/PixelShuffle.h,sha256=Bcpq2zR7D3SGENiEWVR-L1fxuA1dBkaWNl7Qib8uwi8,1790 +torch/include/ATen/native/PointwiseOps.h,sha256=e7tofC6uVLaNmpNW6SwRhwibyp-q-QjjmTAujXaX-Qs,809 +torch/include/ATen/native/Pool.h,sha256=nu1_9q2vkruFKDthmwBgCXjW0UgxOFDg6xkG4_uWsQo,13613 +torch/include/ATen/native/Pow.h,sha256=kJqmBpmimTA219I6mxITLwqWL7eEvrDxxyI5Mi7ahCg,1718 +torch/include/ATen/native/RNN.h,sha256=H-VH3_OiEzxTtLxbI2dXCzEoxEZs9zfbFO7plxQt4c0,2557 +torch/include/ATen/native/RangeFactories.h,sha256=gN6xY4X3VM9xTOxWMUawxT4u4eIWeBFWo8le9tZxIl4,366 +torch/include/ATen/native/RangeUtils.h,sha256=Mnw5C91QKh8is4cxOntFz9Lks2Tve5PPlDGfZS83ZXg,2181 +torch/include/ATen/native/ReduceAllOps.h,sha256=yOimIBqQFDthtDGST-__q7HiJg92LLPUt2UnBJ901To,413 +torch/include/ATen/native/ReduceOps.h,sha256=dzcZt3pB4buPfTF-kGthU4x__AaOPFwBJXa1mFeGqcs,1822 +torch/include/ATen/native/ReduceOpsUtils.h,sha256=i2qhfJJlXSKiqwGbKa9bouPbmosfdug6NmVE_M2HKWM,16905 +torch/include/ATen/native/ReductionType.h,sha256=tRZO6vVCuiP3pCxmCgeqiJsqAoK3COfcnGYc-8B3w28,1179 +torch/include/ATen/native/Repeat.h,sha256=121zLyyDsGpo7nv14Fx2D2bX9HllOz8momYpyBv0mMs,1525 +torch/include/ATen/native/Resize.h,sha256=EALyfwXvllK_yum8EP_50lgXhQCo4cZsNI9dILaekL8,8370 +torch/include/ATen/native/ResizeCommon.h,sha256=QqlMUJtLul9FAXnay1PfEAv4TqbHg99Z3uO4YFtXBow,2566 +torch/include/ATen/native/ScatterGatherChecks.h,sha256=TlqmAnl4AU_wLQWsB0mktHvq05i3cEJntW0WpMn7KS4,3864 +torch/include/ATen/native/SegmentReduce.h,sha256=eSS5ICFfGKXzZb2lqKWDNWt37QWqj5RwRYcMtyP_0CU,1315 
+torch/include/ATen/native/SharedReduceOps.h,sha256=VoSHNHdLlUDhawdTHBpDMRqkBhkscy7dd2_UzyzR7wE,16668 +torch/include/ATen/native/SobolEngineOpsUtils.h,sha256=NeugDPq15OXgIqLOuNMkGfilM0VTKaC5hRL_89ndURQ,1890 +torch/include/ATen/native/Sorting.h,sha256=f1LXeM0rdTtfaqMZY4SDcic3ssMmGENACm_RKif3bSk,644 +torch/include/ATen/native/SortingUtils.h,sha256=bIE98OX_i0DMF-GaBTu0w__N68bcoolkCWL92uCUYTI,2760 +torch/include/ATen/native/SparseTensorUtils.h,sha256=eI_QN2a4CjCIdbw-rN-RWy_Si4ovZS5VqFNqsCS1bWA,6690 +torch/include/ATen/native/SpectralOpsUtils.h,sha256=tAzQdNQuVOvTN_41z2wCXLV-dnhd9oI_hE_ehad42OA,3367 +torch/include/ATen/native/StridedRandomAccessor.h,sha256=5RWP8XBxj2iroOvF71RDImGSr7bsj5yLdf4Pt8NJACo,7136 +torch/include/ATen/native/TensorAdvancedIndexing.h,sha256=76nCAK-4FvL8sGrdEULvdIpWg6qlW0VQIifUnGa5uok,2999 +torch/include/ATen/native/TensorAdvancedIndexingUtils.h,sha256=2niso3mJph3wnwjSGJ-SCqtI2Pssdz_LptaQikE5mZA,3355 +torch/include/ATen/native/TensorCompare.h,sha256=vDWNhaC9SiOS4mjW3AI6bBy3FGOvjFvUUMev1Cvwhlw,1574 +torch/include/ATen/native/TensorConversions.h,sha256=kOwrRNiAMC1vKoa2JixxRuaM6lWOxV8XLnS5CUmzly4,869 +torch/include/ATen/native/TensorDimApply.h,sha256=agGvHJ320P4puhv-hubednM88WqbGft-g7VZYbXkAXI,1894 +torch/include/ATen/native/TensorFactories.h,sha256=62G-ZHvvU_n1DJ4O5WeJwtoHcjtYhaEo2SrtmmSOIcU,5567 +torch/include/ATen/native/TensorIterator.h,sha256=YtrZWy-qCP6CIX6LY2t6G7K7wuERpRuC2u6otAg1c7Y,48 +torch/include/ATen/native/TensorIteratorDynamicCasting.h,sha256=KcDx6tKNEAorvW3DWkMNAihMDUaD5n8ImyzpGkAtSh0,1862 +torch/include/ATen/native/TensorProperties.h,sha256=uFG9u7k1bBJ550OKSabBPSrgZSsRmLkrGj_3PACvBDw,210 +torch/include/ATen/native/TensorShape.h,sha256=PEexG4juPn8Bi82KZADy_Op67MDmTppF3Rbf6C4RvFk,4742 +torch/include/ATen/native/TensorTransformations.h,sha256=pjoFFiYCSipQ5jd60eSLm5F8X_SQLdXOphbHg01JL5M,962 +torch/include/ATen/native/TopKImpl.h,sha256=3ATbQ19F2TVzNvqdp6DlXLRq0DildzAnoRj_xi_nbN4,3557 +torch/include/ATen/native/TransposeType.h,sha256=fyc3us0hyOJEwyK0x2JvD2jSLdydgGvhj1WMeSzXMxI,601 +torch/include/ATen/native/TriangularOpsUtils.h,sha256=SBf8u6J6NTfFl2lBWIfTXKPlI3DHyXAsd4PWXOkojOo,2059 +torch/include/ATen/native/TypeProperties.h,sha256=PaIxe2X6O2ahFTN8HreJdhPXqO9Qrca_eZbQT6ollIQ,678 +torch/include/ATen/native/UnaryOps.h,sha256=Pf4hZbN8pyD5TX98nmPkGMdzcFXHFMdaJM5HAfenyQY,5543 +torch/include/ATen/native/Unfold2d.h,sha256=zZS0_Ac16KUXA90JzhNgf6AvIKq2pF-9trtxdMoHWBg,1027 +torch/include/ATen/native/Unfold3d.h,sha256=LcmkxSu9CkhvEviPb1HF8k6D7466H-kWlF6dBSBbMP8,922 +torch/include/ATen/native/UnfoldBackward.h,sha256=gv8ben2gdqJFFadg875Oz-ckGn5Uja4syYcekOJFrDE,3210 +torch/include/ATen/native/UpSample.h,sha256=1MRYOjMo9-4qaz70E0XlCX1zyusodpYDErxHvMmKStg,19721 +torch/include/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h,sha256=apSJnJpmb543Bnl12DB4lJ9TX2JuLv9ZeSBibfPAvxc,2995 +torch/include/ATen/native/ao_sparse/quantized/cpu/packed_params.h,sha256=PD4_u7pVrQGMjJtaLnhnNtlpjKvf0nNIKEc6BOP5IlU,2809 +torch/include/ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h,sha256=mTWA4X031p4I04LxwlchMJyHZVn0FKAOptAew27ASnQ,3332 +torch/include/ATen/native/batch_norm.h,sha256=0WBU59_s_-S9miN3AuK59-0HNKu7fVoCNL4kz6Hitws,1463 +torch/include/ATen/native/cpu/AtomicAddFloat.h,sha256=9t9RmFpCbL-0Li4OpcTajsXW-_WzQa-HeALl7Tt7bWs,894 +torch/include/ATen/native/cpu/CatKernel.h,sha256=NYMajqK5zwwir7vd605zck88852jcqiHjqAFIr_HzeY,318 +torch/include/ATen/native/cpu/ChannelShuffleKernel.h,sha256=itT2xG4osbiorOhkKIOLzKC368fqX2TPkuS9Nw0qfe4,300 
+torch/include/ATen/native/cpu/CopyKernel.h,sha256=Rh8HoSfeC8SsqVk4tOJdzCLUa6euo1uqbOxNIMSAHAs,326 +torch/include/ATen/native/cpu/DepthwiseConvKernel.h,sha256=PAwD562BJU4ha-L-tOheXXBsoj_vPayu273FGywiAhY,491 +torch/include/ATen/native/cpu/DistributionTemplates.h,sha256=hteOOgC-9QwmkiZ4vzoo0RQRuQP_vRDdUYqeoUlje1Y,16889 +torch/include/ATen/native/cpu/Elu.h,sha256=x9Ak_bFyPKVxOIQ-DNyI0hARZLFJdusvxxYG4MsvOfQ,2940 +torch/include/ATen/native/cpu/Gelu.h,sha256=0jf6DZFUKlw6HxO7waBLlhoYFYUHwEdi2WDa7tpLQhc,3187 +torch/include/ATen/native/cpu/GridSamplerKernel.h,sha256=fOVEnusBoP8ktEXd84rivIeTUTWH57Hd_iR-FJDS7lw,857 +torch/include/ATen/native/cpu/IndexKernelUtils.h,sha256=Tm7K2LaZrAyWtglotUeCksZhIaOGkuBKbfOTlJOmpU4,3012 +torch/include/ATen/native/cpu/Intrinsics.h,sha256=ZcIIF9QKb0w4h5x6S7JX42ydWbDuSf1B8EBOukjZaCA,1245 +torch/include/ATen/native/cpu/IsContiguous.h,sha256=Ky3x3EDBfnZptf6LXYEka8nlY4gZnKVCEIrGjzhCOc0,2433 +torch/include/ATen/native/cpu/LogAddExp.h,sha256=dVwpkz5rmTvRB0VbRgmvr42IAWk-XvSCUM12y4rrE-s,2507 +torch/include/ATen/native/cpu/LogSoftmaxKernelImpl.h,sha256=dBmeEQbahBD_dAa6Rn_WWwoha5uJA5AWUHWekxNUQvw,13487 +torch/include/ATen/native/cpu/Loops.h,sha256=qHnPOsPzKbR7RdjeinYD2GoPb4KqQ1sNsa3CuwmVvUQ,15300 +torch/include/ATen/native/cpu/MaxUnpoolKernel.h,sha256=2cKRRLMxNIMX6vypgtEXEPyC6QIclvGJ3agxxmk1DRQ,320 +torch/include/ATen/native/cpu/PixelShuffleKernel.h,sha256=SGg4WEeCMr7Pd6UMUCFh8B03A-AwaD34cuQtht9r58M,334 +torch/include/ATen/native/cpu/Reduce.h,sha256=NuX1ptAtdpGW-T9lw9FY8ZosoSb5MvvdsVqgHTJ_jSI,12380 +torch/include/ATen/native/cpu/ReduceUtils.h,sha256=RQifpnWDImXIUgFgwR9J87r-R8umY7zggEzkuZ7-S_w,8976 +torch/include/ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.h,sha256=go3XiE8Agi8-o4vAt73CnYdJjXAcCjbfrUuqXNo6moE,1147 +torch/include/ATen/native/cpu/SampledAddmmKernel.h,sha256=79g8IipvP9FMjdvH2h1BCKvq7rK1Hh4xTeSLKnehtdI,334 +torch/include/ATen/native/cpu/SerialStackImpl.h,sha256=Vnnb4mYKjiT41HxshbOPRkScUWyjQUMDiI94IKdrS94,5602 +torch/include/ATen/native/cpu/SoftmaxKernel.h,sha256=J7d_n2OGJSOMcRdq29Fv6mYPWLI6ipL9Q7NNBjj1fZQ,963 +torch/include/ATen/native/cpu/SpmmReduceKernel.h,sha256=g3eSj61yRAzX7iGJ8pRXh9OGwNfg6RBkxET0PGnZCfA,1374 +torch/include/ATen/native/cpu/StackKernel.h,sha256=q3_Olp3FLU1gan4DonR8QApiYTuEYIB_P_qXXBD7_pc,320 +torch/include/ATen/native/cpu/UpSampleKernelAVXAntialias.h,sha256=NmWq1qOyYdD3abdEzqU45r50AYUVr84OVDXp4QvUMEY,59554 +torch/include/ATen/native/cpu/WeightNormKernel.h,sha256=zRaE5HeSq_Xaf9pfLQu3u6T7korC3nq4NPeWjoFvegw,570 +torch/include/ATen/native/cpu/avx_mathfun.h,sha256=A3D6L0Mu1xxCZoDTUpUbxq6_dwsC_no1TSB16_XlFZU,17971 +torch/include/ATen/native/cpu/int_mm_kernel.h,sha256=rgV9_-27CLGsbtXJOfQkSbIbTSrdfwcJXO9ZdxpvBm0,1143 +torch/include/ATen/native/cpu/mixed_data_type.h,sha256=qJ_l0zO9ExRkI0WUVIVEDePJlwK0fhi0Djw8vKFEWHk,1449 +torch/include/ATen/native/cpu/moments_utils.h,sha256=KvMNHMa70ritejKgqlaU25YtetwtkvSaK6rJT5ccltQ,6779 +torch/include/ATen/native/cpu/utils.h,sha256=GNF0DspT36GwpysPxzCZh2kcsVONPDlveBPqnXvmnIs,7363 +torch/include/ATen/native/cpu/zmath.h,sha256=GjtYmzTMZ7iYbznMEfNwWterwRSfpWGwzudl4y6XdyA,6872 +torch/include/ATen/native/cuda/Activation.h,sha256=bUQiLTRr97tZn5DDzOZ_d6i20c9ngGkWZLEG0T5tvuM,556 +torch/include/ATen/native/cuda/BinaryInternal.h,sha256=J6R3KlzgXjHEhdAZ0M9fpXL0cH68V-crbopAJ9MIlV8,1231 +torch/include/ATen/native/cuda/CUDAJitLoops.cuh,sha256=gYqHjKdDwWAck54k_YA9YGXYX0ig8K1MgV_yUS5qCLk,12205 +torch/include/ATen/native/cuda/CUDALoops.cuh,sha256=tXHIs8mTbCkmmoQqJOFdigcs5572bin6kNqSPxmz6bc,35607 
+torch/include/ATen/native/cuda/CompositeRandomAccessor.h,sha256=OrszxRptzH-0CcEs7tVGQObnBI9ok0nQvjAvg8GbCKo,964 +torch/include/ATen/native/cuda/Copy.h,sha256=dXAA0lggucz5X7V9CU075ArUBinIdwH5ePq7szaKzfU,165 +torch/include/ATen/native/cuda/CuFFTPlanCache.h,sha256=A9OEVMzYx-VesnDw8yFjuwZXo0GARGVpnb_VX43RgNU,18423 +torch/include/ATen/native/cuda/CuFFTUtils.h,sha256=B0kZmcXiYYMx7C0KhoJ-Fz8msGxYIqfoZBrhyAFSESc,1946 +torch/include/ATen/native/cuda/DeviceSqrt.cuh,sha256=NEcpLMzvoNwJiVB6GLVdRcs9W-CXNQCVNlYDDES71IA,610 +torch/include/ATen/native/cuda/DistributionTemplates.h,sha256=GdUGJroNk2vgg91XW2jt_CV6QCToWQd0_JWGuaaE1KA,29946 +torch/include/ATen/native/cuda/Distributions.h,sha256=lLZ4jOyaLZyJUD-pTigWN7WZ9UaHJ4qiN9AITNEr9Xw,666 +torch/include/ATen/native/cuda/EmbeddingBackwardKernel.cuh,sha256=coa6GAqH0MzrxQr9bYUvZ4g5Ow6Xq4x0mKJ-JMOKV5o,576 +torch/include/ATen/native/cuda/ForeachFunctors.cuh,sha256=LgPKwk4p_9aH4WoANoE1AFfC8YsHIYyf7nCjZUAWq8w,25376 +torch/include/ATen/native/cuda/ForeachMinMaxFunctors.cuh,sha256=yHPnNBDV5oQiqY17xeXO8UFYf3gLe6dLliagjI2Zgfk,448 +torch/include/ATen/native/cuda/GridSampler.cuh,sha256=MiCqfMnvREMuP7crQDjrZ4a2PqahlrGLUSL0qUriX8g,11296 +torch/include/ATen/native/cuda/GridSampler.h,sha256=swQJkjxzqz86fGbZnw8jXT9hTsM7v7PrcARnqlL_8nc,1176 +torch/include/ATen/native/cuda/GroupMM.h,sha256=RC709rUFNw78rx8vpAlPz1Jfqvl7_i0QKHbbAiEDA9I,339 +torch/include/ATen/native/cuda/GroupMMCommon.cuh,sha256=JMgvAxsg1ttoZB4yUSHVO5gkEMWDUFZn69m_xejLaRs,6385 +torch/include/ATen/native/cuda/IndexKernel.h,sha256=tOFuXNaP66hcSRhVuEKXQwUgWjBqan_GzPpeuOcuQcI,353 +torch/include/ATen/native/cuda/IndexKernelUtils.h,sha256=_zWzBIlzY0X2eVRfB9enKyYYlDAIcTCwioJeTQ3zTN8,1941 +torch/include/ATen/native/cuda/JitLoops.cuh,sha256=B9HcboDzkT2Zy9d2dr7m-zNH6piQCaYo0kiNJuPIajM,7093 +torch/include/ATen/native/cuda/KernelUtils.cuh,sha256=4dhybPOqpxXgx3O7SNGL_BCjcurRZpxLjba7pdQoE6Q,13214 +torch/include/ATen/native/cuda/LaunchUtils.h,sha256=1a-qdCxNQSpFqgz5TQ7bv_u2XxMHEEeZP8deoTfEJSk,298 +torch/include/ATen/native/cuda/Loops.cuh,sha256=Ed7hpDNkiSsc-Qm2m37PLtdmZCWngcJcNGDmpj7qKmE,12068 +torch/include/ATen/native/cuda/Math.cuh,sha256=h1UcvU3hs-DXA88KCBLaYZqpx4p4YuT3hd7kADMs0Yc,126333 +torch/include/ATen/native/cuda/MemoryAccess.cuh,sha256=WTQRPRgWvYffySBhZVljqeaR2xoWEGG5tLq1O6uzpcY,22960 +torch/include/ATen/native/cuda/MiscUtils.h,sha256=zd6ODVUcY3yWeXfQTpuGSuElIXaw5-4cF37QmDmDa2E,965 +torch/include/ATen/native/cuda/MultiTensorApply.cuh,sha256=cKZAMmXfaIdX4AC85aj-UG4caFzXtyjrlwZA2aAjXYc,14317 +torch/include/ATen/native/cuda/Normalization.cuh,sha256=Q_l4a42YS5ET47huEzsSO-y2BzYXuOb1h_9M5qicr90,75967 +torch/include/ATen/native/cuda/PersistentSoftmax.cuh,sha256=3fW-Lz3aPHEnzcFuWcp3X4FxFdFCBOGa-betiTkx0hk,18385 +torch/include/ATen/native/cuda/Pow.cuh,sha256=o30Ts6PL5iUApZkSn3e706OfI_7VG_QA_c1-oHm52fk,2178 +torch/include/ATen/native/cuda/Randperm.cuh,sha256=pduqztZygnWLgWCVhgIR_s6QNROGjN_D6zeipJ3o7Cc,2166 +torch/include/ATen/native/cuda/Reduce.cuh,sha256=SKx83R04mbbKId0GOf2Lr93FwVXXCd1aheAVRNzde04,51508 +torch/include/ATen/native/cuda/ReduceOps.h,sha256=vK7q7aqyoIjM0z4rbtet8VKPhUjDxtfAI9f-5xSvLEE,508 +torch/include/ATen/native/cuda/Resize.h,sha256=5tdpuPNkTjhOB21zw0x5U9LiSV4rSXLLW3sLHl1r4n0,1597 +torch/include/ATen/native/cuda/RowwiseScaledMM.h,sha256=6iYXKL_dzZVaXWY8LLuTeErs_WRV0LlmP9JUAJMLhgY,383 +torch/include/ATen/native/cuda/ScaledGroupMM.h,sha256=64QFXg79hRiCR8HpM1U0O0CzoraFsJ4k7gzizaOUKu4,429 +torch/include/ATen/native/cuda/ScanKernels.h,sha256=1zfILyxxqd5HRNjTjAa577LsX9rxIZVMFGXUx-cuimU,797 
+torch/include/ATen/native/cuda/ScanUtils.cuh,sha256=e6nrye0_SmsZ16ifxmdkhvV1MMBXfqzuwugPVRHbM-k,21748 +torch/include/ATen/native/cuda/Sort.h,sha256=nlEJ8hFW7VSh8w9LUAKf0RW4Rvrdt91YvRMCSIbKxcw,418 +torch/include/ATen/native/cuda/SortStable.h,sha256=LKrKGvanplh37OCGTlnKmZjLaW6qcVh77ESJZD-vJvw,456 +torch/include/ATen/native/cuda/SortUtils.cuh,sha256=fmYVJ7V0p66NgPHDkfmNoxqu5kLVqAxS9Ffu-QMs7DQ,12718 +torch/include/ATen/native/cuda/Sorting.h,sha256=vAnpVScSkqC1Ak5RaBMJL6bXR-EsdNoJMIxwb7_qd9o,413 +torch/include/ATen/native/cuda/SortingCommon.cuh,sha256=xFDGxI8WlXgal2R2wU6ViHZBtFSvLXTlO8HLCtQBfKw,5649 +torch/include/ATen/native/cuda/SortingRadixSelect.cuh,sha256=-KrnH7BRzeFir0VHrV58U_7auEblUYdwBrUYoaTYUKE,12741 +torch/include/ATen/native/cuda/TensorModeKernel.cuh,sha256=LrtSGe72VbUzuFa7e-BQXrQxWwSWyCZCbBXZ5TwPnZk,14860 +torch/include/ATen/native/cuda/TensorModeKernel.h,sha256=H8-rNfvEvsiJ84oax1qNJFQdzpSxcwh-Wbkb92Zl3X4,437 +torch/include/ATen/native/cuda/TensorTopK.h,sha256=nxt6_0ypGrKXHtEdDGARVD6yk4qJKEipRxTiyYSm8XA,267 +torch/include/ATen/native/cuda/UniqueCub.cuh,sha256=dA195d5ES3K-QMyaTBUGpRUT0rGPNvMmIziEr1wzEAI,314 +torch/include/ATen/native/cuda/UpSample.cuh,sha256=2IsVVsgm-fFOIdsSZVZAYYyTfZxIPmPRBO2_aqDbaT4,11992 +torch/include/ATen/native/cuda/block_reduce.cuh,sha256=l-QV1Zc-I4gYABAkqrg4MMevptB0B2xuN6MaAKdZiVU,4367 +torch/include/ATen/native/cuda/cutlass_common.cuh,sha256=VSr8shgUOLcpXDLPGRqBN6ey_OLCUDnCTjYIQ_rIbCk,1020 +torch/include/ATen/native/cuda/fused_adam_amsgrad_impl.cuh,sha256=TflsQN-84nDH3xd0GDz76fWSPKNV0PJBNn4zEpWKQOE,1067 +torch/include/ATen/native/cuda/fused_adam_impl.cuh,sha256=Z9FOhtX1QDtPzR-SXCjy2XiRRXK_R74MyWrBV8GmB0Q,977 +torch/include/ATen/native/cuda/fused_adam_utils.cuh,sha256=L0FlzFCP9zx8mODPeTPwnaYNmg9OMJMnFHSsO5BmAZY,7197 +torch/include/ATen/native/cuda/fused_adamw_amsgrad_impl.cuh,sha256=Uo-4skXIzxxiD97j_f6PkJqpe40jPLTwHESwIi1MMLc,1069 +torch/include/ATen/native/cuda/fused_adamw_impl.cuh,sha256=ucijwI0HIugvrCnhzXBFFlqcYTfDHoG_wLEpJ8bYPE0,979 +torch/include/ATen/native/cuda/im2col.cuh,sha256=SfgjmAgUEuwzxqsvHL74PTzBvbRWMzyv4RyfTKPP7Kk,10036 +torch/include/ATen/native/cuda/jit_utils.h,sha256=WLzbgLAJFvemzKjpMQT5VQlwRlzMzs_eZ7RnH23Iw-0,7360 +torch/include/ATen/native/cuda/reduction_template.cuh,sha256=k-KJ6L6q6FD263HJquY364pVLdJswbvIB-kaPLwwxSM,22348 +torch/include/ATen/native/cuda/thread_constants.h,sha256=-o6W7gVMlbgIfOjmOpEUadPiHIgqnYf-7zWHpQnSK58,685 +torch/include/ATen/native/cuda/vol2col.cuh,sha256=oy91IHLAYwuOC3ouap4GuqvveQNX7wBbDXhJbQwMiQI,8355 +torch/include/ATen/native/group_norm.h,sha256=jb8QS_68-AEwK5PHOcevNA1OVKA6WfTkx3SpjxjnF2o,947 +torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_collection.h,sha256=BJU1HXJyNuI-W_EpQH9oGC2gk6feNeIswHqBGaB35Hc,3486 +torch/include/ATen/native/hip/bgemm_kernels/bgemm_kernel_template.h,sha256=o9ONsNbOast-I6iTeCFwX-apceldcBr6bQwJX7bdHlg,6142 +torch/include/ATen/native/hip/ck_bgemm.h,sha256=Z_f35wXbQHiPfXCawHgiCweJ2oYriOASuQEk0lSSakY,461 +torch/include/ATen/native/hip/ck_gemm.h,sha256=-cnlGlh2s3zXGy9K88YybaFDuB0pfYNhCNso3CsQWYI,646 +torch/include/ATen/native/hip/ck_gemm_template.h,sha256=nNLDeMde8ylzk8rDWKGEBuy-SQgf9RPuCKRq4a8SyY4,16112 +torch/include/ATen/native/hip/ck_types.h,sha256=7zAYKS7Ec1rFDNvXgBrS84gom_nlCVEXSKR_ZLaosVk,1717 +torch/include/ATen/native/im2col.h,sha256=EjbClkS0NVTJOqxSVC6xH9EUNZJGelabyBz3u6hTqfI,5386 +torch/include/ATen/native/im2col_shape_check.h,sha256=xOI0CckFMlW-BFO9yFNiE0IFXBInPq3hwE_asoru6i4,7167 
+torch/include/ATen/native/kleidiai/kai_kernels.h,sha256=vOj_vGOa6gobkSPZU0xjHDcEh8mewvbL-cDiYyZTrRY,1088 +torch/include/ATen/native/kleidiai/kai_pack.h,sha256=6zkKqzxsx8WDTYdzJU4SuwCbDZlUefEtlBaBYDprsFU,2815 +torch/include/ATen/native/kleidiai/kai_ukernel_interface.h,sha256=n-rByvG-gwaUbQlAppJU211GTATY0yek4cSO-6-Qj6k,5019 +torch/include/ATen/native/layer_norm.h,sha256=Po3G0Y2mJ-HgNziluT8SLMY287rb2KvUMFgRfeZDHHE,4344 +torch/include/ATen/native/mkldnn/xpu/Conv.h,sha256=Uzs8XsSJTknrgc8ltPp1izEAONvcMULyxOqyLMf_ZtU,1611 +torch/include/ATen/native/mkldnn/xpu/FusionUtils.h,sha256=0anrbXOBB2TKCBgyqsaqfHpPCsVo4ChxAU9JHM2rJ4I,1798 +torch/include/ATen/native/mkldnn/xpu/detail/Attr.h,sha256=Zk6Th-fCT_eOZyehsMENefGicmzIuNwCqjF_kSA5Wns,17877 +torch/include/ATen/native/mkldnn/xpu/detail/DnnlExt.h,sha256=y7wNxbiqndEArDWjuXA_jOll-Nh1q-2gCyY7kiJQzn0,18636 +torch/include/ATen/native/mkldnn/xpu/detail/LRUCache.h,sha256=MYKh-sGMX7OxJS_TKhqm3GnFBKHIH6oMMnoVADgE9i8,2571 +torch/include/ATen/native/mkldnn/xpu/detail/Utils.h,sha256=YNRS_pG2pBesxJPn-nrrqcZk2ibs34drhPIxRZhbGsA,3990 +torch/include/ATen/native/mkldnn/xpu/detail/oneDNN.h,sha256=qsqEgmuC-lJ9AY1_mjKNWmM2MIhnXzLCTHK1ZU8kFoc,5439 +torch/include/ATen/native/mkldnn/xpu/detail/oneDNNContext.h,sha256=tGlCQGoK58hgY8-7m32oLkbB6KE_5jM2Yf6Ih8f4omA,2802 +torch/include/ATen/native/mps/Copy.h,sha256=9yt3cf9b2x8sACUn7Dx9iAIVhTsMSijASyhMSnBCDSo,309 +torch/include/ATen/native/mps/MPSGraphSequoiaOps.h,sha256=1jfG2B3WoFEUXYwON5B6mQsVWmYHrClbiMF6H8ekYPE,1560 +torch/include/ATen/native/mps/MPSGraphSonomaOps.h,sha256=8rqs_yz9ejFJuiurvdCaTWpgLtd-GqvGzGxsvU-asfQ,2376 +torch/include/ATen/native/mps/MPSGraphVenturaOps.h,sha256=Dvo7Od0ZuQGy9tXitVp6k4jTwHhOp9jP27-7chnN9Yk,12413 +torch/include/ATen/native/mps/MetalShaderLibrary.h,sha256=C6e7g8xfz4ayJjyEQDAIWGgD6FQ1Sh0Ayy5fWh4i1tw,5973 +torch/include/ATen/native/mps/OperationUtils.h,sha256=eKutK1QjDE8TSxUw5inZXDK2KnTShyus9rbgjgWUUk0,25602 +torch/include/ATen/native/mps/TensorFactory.h,sha256=q93kMHFDZJmhvl20JP2TQEyebF56PD3cOU6J9l7nZl4,1018 +torch/include/ATen/native/mps/kernels/UpSample.h,sha256=JylkR3sgtOO-Vk3knKOdvSn3g8UAGhJyPJq8eIJoO2s,471 +torch/include/ATen/native/mps/operations/BinaryKernel.h,sha256=ZHmp71Qg_hoWo2Qbg1Vq6Gye5QwHTIxXuJ3ZnH4IK0Y,270 +torch/include/ATen/native/mps/operations/FusedAdamAmsgradKernelImpl.h,sha256=sfBhyWoGgcP8822qYRI7FbxO2_sqjwIq4Knqs0DwPYY,1019 +torch/include/ATen/native/mps/operations/FusedAdamKernelImpl.h,sha256=o-eOOlQm7ry2mlPptzC3c12hrRKUrKKfKn2W7kgKAaI,923 +torch/include/ATen/native/mps/operations/FusedAdamWAmsgradKernelImpl.h,sha256=z0mOzMgGuHKaRPZfwpf0nLjdX4c7oHma7lWwz8u-16o,1007 +torch/include/ATen/native/mps/operations/FusedAdamWKernelImpl.h,sha256=70t7n1eAAz2pvgi919_S0Cb9XQ97fkrr6H7Fj0NWmBY,927 +torch/include/ATen/native/mps/operations/Indexing.h,sha256=RdbqTFd_YYLv2IiyADP_0VAcPHA73xJDJ4MN4pjFv9I,255 +torch/include/ATen/native/mps/operations/MultiTensorApply.h,sha256=DlpVhik_mM3ke5SIfbNtYPQXyVrLs0_WJJ99dQzh0Xk,16386 +torch/include/ATen/native/mtia/EmptyTensor.h,sha256=f6boHeoMz3bJ8nearpnH7EZYW--H1n3sJ96Sl_EzkbQ,1124 +torch/include/ATen/native/nested/NestedTensorBinaryOps.h,sha256=b4Wv15orB2RVyiiheAgNyCVSSwD02nvDrS7cS4oN-Vo,431 +torch/include/ATen/native/nested/NestedTensorMath.h,sha256=dJlZkSxkLET9Htvsbqm1tl-UX0NZN-Md6ABg1oFuoGg,2797 +torch/include/ATen/native/nested/NestedTensorTransformerFunctions.h,sha256=Ur2wFcKjlJFGhtAEyJXXu8QftKk8X9kaiXpWR-Ah8EA,2935 +torch/include/ATen/native/nested/NestedTensorTransformerUtils.h,sha256=Vlny2PQ6hLZWapI-QXWor1mjj_I7z4WmUZ-sijCrovo,1438 
+torch/include/ATen/native/nested/NestedTensorUtils.h,sha256=EQvyj6omJOeR8mTaUwwg11nDJChPHlP88v1cTIGk3oE,15697 +torch/include/ATen/native/quantized/AffineQuantizer.h,sha256=HJ_Sp3eSIElfYrR5AWsFwxTj7uMkdCyEKgqMpMoZmVg,3843 +torch/include/ATen/native/quantized/AffineQuantizerBase.h,sha256=vQygeUXuEGfppDz4sBSTEXqNxDukNdqusL9GGq_ALQk,1472 +torch/include/ATen/native/quantized/ConvUtils.h,sha256=seN12fWOkAEvHZBTtO42HLjn281xu49n4H1oTerRQww,2302 +torch/include/ATen/native/quantized/Copy.h,sha256=xah-h77O-cmGIXLLfOpOxgO7vgDbZLDIJuTqaj1zo1I,164 +torch/include/ATen/native/quantized/FakeQuantAffine.h,sha256=g8rcQYh0LYvi69qvrbjW73cul6GfFacRlViWlL83eGU,1854 +torch/include/ATen/native/quantized/IndexKernel.h,sha256=HxCBEAyIkb4pGBqHPBVfh3syua8tyZjAmQ-l5lFjZMI,593 +torch/include/ATen/native/quantized/PackedParams.h,sha256=M2X-c_eZW163CZBXll-Wtd32WG20IPyB4pvxC0j2qc0,4886 +torch/include/ATen/native/quantized/cpu/ACLUtils.h,sha256=jeuJnPwQdUd5VIhv8q5F1GVp1rhHer_dvRf5237nCsk,8144 +torch/include/ATen/native/quantized/cpu/BinaryOps.h,sha256=1d5kERam0caSwpUMgndD2zkxRXPx5ApfyVW4HTXMeQM,174 +torch/include/ATen/native/quantized/cpu/EmbeddingPackedParams.h,sha256=J1RbfggPZZMZ426fz8ZeX4pwCI6reK_fLBqNSVxRvVE,950 +torch/include/ATen/native/quantized/cpu/OnednnUtils.h,sha256=_G5wDCtMR0GwugGWOqE4dfA_qbJEdAWK_z6poHbVlCM,14386 +torch/include/ATen/native/quantized/cpu/QnnpackUtils.h,sha256=SZXWhwzrnIFyVpBFUIg-sPV1heGmH2NoilaMkeSIYJI,17918 +torch/include/ATen/native/quantized/cpu/QuantUtils.h,sha256=b-Ai8zO0xtNFNh-fi70GVDzC_Gw6Usq52qQsTMxUqTE,8571 +torch/include/ATen/native/quantized/cpu/QuantizedOps.h,sha256=N7KYcugJER0k1vpuf87mNX7FQ3qqjHM3TrPLvas5N-s,8892 +torch/include/ATen/native/quantized/cpu/RuyUtils.h,sha256=GOdYSMJ0WWhJIGJ0oB2UxjMbHsdn9XSKYihVJjG2SrA,355 +torch/include/ATen/native/quantized/cpu/XnnpackUtils.h,sha256=IFAn3vmPcSBUU1FTDac1pXh8BMqbcCSFWV0OZGBJXIM,14436 +torch/include/ATen/native/quantized/cpu/conv_serialization.h,sha256=kFleeLmR86deXBe75GuoL7t0hBfxACVs2bidORM1ibo,13305 +torch/include/ATen/native/quantized/cpu/fbgemm_utils.h,sha256=1370XGklReICLC2-OiKPdC07oP3aBkBlNIwQeOmocc0,12345 +torch/include/ATen/native/quantized/cpu/init_qnnpack.h,sha256=fiGGqZRkWMnGbkWC19hGb3ye7jpqu8I-4EtCcLi2nps,132 +torch/include/ATen/native/quantized/cpu/qconv.h,sha256=lMRNupeDijr1pYWHU-rPAQFKg4ZdAs4EcBLxu1l3R4o,3655 +torch/include/ATen/native/quantized/cpu/qembeddingbag.h,sha256=UjbET4aD5-EHI-3Je8Ku5UCPu5HwG5F9TX3qhdw0rEs,1047 +torch/include/ATen/native/quantized/cpu/qembeddingbag_prepack.h,sha256=y5-QKhJ8s5plKNM4QOWj122f1vWXngG60e_lwIFHPPE,306 +torch/include/ATen/native/quantized/cpu/qlinear.h,sha256=yce7QT-VdeGgB3-mliUAA5pDctYG71aXJXnP0ZRHBx4,1729 +torch/include/ATen/native/quantized/cudnn/utils.h,sha256=PNWL1uxglz_MYs22kFCZb0tVVoQmyaJ1ncZHT6vzGf0,10800 +torch/include/ATen/native/quantized/library.h,sha256=0QEhBrJ_hnNqampwXoJWaKFA5ctUKrkw0XPxhbPksW4,197 +torch/include/ATen/native/transformers/attention.h,sha256=c_iyq7Ok-GFrCY9x1KEs77QvXY8CAQM5vDXnvDNrFWI,2336 +torch/include/ATen/native/transformers/cuda/flash_attn/flash_api.h,sha256=x0gi-BOyBxl3_PxBDz4C_kZhUO6TaU6xk-1wc7795IY,5450 +torch/include/ATen/native/transformers/cuda/flash_attn/static_switch.h,sha256=3vZendL_QmHgNhz7ZbQbiqwdobIUxHu86qPHnH-XCr8,3874 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/debug_utils.h,sha256=LvN2r7424Fm00K529vqGtPNz_tAU066tCEXPqMHJY1U,10464 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h,sha256=apUvn3fJqux3HE_qTIX2Sqp_jNafLk5Z-TOtIa2_O7g,22918 
+torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_rescale_output.h,sha256=GRGZLJLj3tQpe7l7GAv46_-r0Om3dnU_Ig_LYxQRBZ0,8018 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_thread_apply_logsumexp.h,sha256=PgRuxPf3-iIgmA7v24rv47GY5CbFN6IUFMfh63R6bJY,6286 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma.h,sha256=hzmvWp4fIBYcIlBYNDUo-Kx25tySxKEmZtNdLg4-jeg,2591 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_base.h,sha256=qdzlHRRVdaPO6hgm8ovZKbmKxGrGSL99L7GS6ZioIIw,6424 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_multistage.h,sha256=WRayBISIk4yyt76w5olKND1kYJQeyyFafQt1mm2CniA,28049 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/custom_mma_pipelined.h,sha256=22aMPEFRIu65LC_KVRZkMvLFYPKEjn2EmZOh3tRtviI,14547 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/find_default_mma.h,sha256=0JDFoyKaePmgSjFtO5Re5IkkK0l7DGXilMrAch7IxFc,5343 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_accum_lambda_iterator.h,sha256=AwO-RQUZlpB9sSzz8zQycRmeDOG6ilEYkmDMrgzpYbQ,12704 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm/mma_from_smem.h,sha256=h7Yx-Qafve61cOC5-gTlSxLM3KRo4XYlkc-vDQ7GBgA,70748 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h,sha256=dw2xDrbYzaDfQ6YRIIJ576DRpRO676uluS8vQchPsvM,8817 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/default_warp_iterator_from_smem.h,sha256=eb2AJ8_q7yQtWsbLci9U79p54ZEMNbJwQdpD3fqllGQ,5970 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/epilogue_predicated_tile_iterator.h,sha256=IR4UUaUKAnHHYmadvgEqiLAaYh9wx3VNll1HElDGrcY,24607 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/make_residual_last.h,sha256=iaC9pAgBiOf8lKa985I0cyi9S74ljs8CSuwVtQQqzrg,1724 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_access_iterator_residual_last.h,sha256=PxutLEQMtYpRFvHabdclrx8dnkLqJ7X4RKQE8TkRUvw,66581 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/predicated_tile_iterator_residual_last.h,sha256=yYskm96IHojx7s1g3kumdXQvTm48rejmXNKwySdoh18,66619 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/transpose_warp_iterator.h,sha256=Qgx8Rdj_Q_igd8bDYeoCanck9gUATK3AfKM_MnfEtOE,992 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/iterators/warp_iterator_from_smem.h,sha256=9i3NylTSGikNzRdUlYRTjVQ_e0V4JLL3rxxlv9hFsTw,10339 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h,sha256=fwJe_51nBCb2gCRy_JoLZTpY4t40fIwhfyvur36idDo,102648 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h,sha256=GW2YBlHZUVcL_cNYQAe_2aAX6Yls6P4B0vFUGD07TSk,55467 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h,sha256=xZHYVL1SJwsuYmgWtcI-183bZodTHgw1UQgU8ZszAlk,97555 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h,sha256=WrQY-0wxX3aZ6rkFvKf1jFzgBDXnGFYIiXwdeP_H54Y,26263 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h,sha256=MJt_kpyDlp4jEzKntfUfsUSbyxqxDumwMDVty6jVpGs,1004 +torch/include/ATen/native/transformers/cuda/mem_eff_attention/transform/tile_smem_loader.h,sha256=UtRRYzipVcnQXed5JGSSr8yhgZdxXNcB2pGvUhAJSUs,2218 +torch/include/ATen/native/transformers/cuda/sdp_utils.h,sha256=tUntvlio9-zOrU-gtD9oO4W106aYU7LxmdhskIrHa_U,644 
+torch/include/ATen/native/transformers/hip/aotriton_adapter.h,sha256=FxMJAHNO1JbJAAkVd6ZIqKYnqN6HvTOB-VxOiXzx4t4,3836 +torch/include/ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h,sha256=lmH-Gi22koQUzVO__P1iFTC3Ef5tir7MGGV1gJifcE0,1846 +torch/include/ATen/native/transformers/hip/flash_attn/flash_api.h,sha256=jJlM_XWc2dA4LDjpRWnwiORuLqeve8Pynaqtve74qpE,22560 +torch/include/ATen/native/transformers/sdp_utils.h,sha256=qBXR9NYmOGRk3Fwyc79oTjen22QL3vqGq_CUIrdF5Vc,2959 +torch/include/ATen/native/transformers/sdp_utils_cpp.h,sha256=4wkcRcVV1-Oscx0DC6dZ2E_EhzuWkt-0sJFmq7nIRVA,18477 +torch/include/ATen/native/utils/Factory.h,sha256=PmNgcH1tGKt2NQ5W17-62tty7Frqj2VfEt51yl6jGfo,507 +torch/include/ATen/native/utils/ParamUtils.h,sha256=5tLsOeu0UmFj2YyypGCo_5dSthnJNTM9QLflqyx3V1M,1240 +torch/include/ATen/native/utils/ParamsHash.h,sha256=naW9o6L151VBIFZer2OSrmW--fzQJUm_s5mu6kq66DQ,3228 +torch/include/ATen/native/verbose_wrapper.h,sha256=xS8Eo9ShhQLbBJRjg5BydHX3miqWGT7bVbJ7VGr_XhQ,201 +torch/include/ATen/native/vol2col.h,sha256=xqXZK-qrhU9nxkhjxD2w67chbrdOGgajdItW5HhU0Hc,3664 +torch/include/ATen/ops/_adaptive_avg_pool2d.h,sha256=itw7lux01nBoJqpreYraQiyx-On-CXI_NKxTj275xf4,4233 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward.h,sha256=YgZNzNeQqO1d3CDYl3PwtlfMSlR3RWzkQhHCPEzkbBU,1492 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h,sha256=pANA8PPa9w6d39hwIAW3d_vIYgze_yFxB4z2pUClmlI,1001 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward_cpu_dispatch.h,sha256=oa4uVfaqzbnO-x2HJeB1JMiJXP-OFoDZ9V3bsTJWQg0,798 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward_cuda_dispatch.h,sha256=au5tf5wNJxE56ppJ1Q8Ccb2WdOw8SNr2sbuXWyd9AkU,800 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward_native.h,sha256=cZ5-FsV7aKFO3G9DrUwO1F_CxJweRHDj3reGYnJZkNo,803 +torch/include/ATen/ops/_adaptive_avg_pool2d_backward_ops.h,sha256=zZMl0krMLbk_tNK0r09au2sYHLSUkrPvEVPSMe-xZGY,1933 +torch/include/ATen/ops/_adaptive_avg_pool2d_compositeexplicitautograd_dispatch.h,sha256=aq4CmVrYzsGS8e2E64mHIxJG_obvRvUEb3c-5cKSELw,1244 +torch/include/ATen/ops/_adaptive_avg_pool2d_cpu_dispatch.h,sha256=irE7CDWQ91LzrzoXf288NBRAi02EJaYAe48yIiOfk-M,895 +torch/include/ATen/ops/_adaptive_avg_pool2d_cuda_dispatch.h,sha256=AT9OJq4wdLAHaOYljU7cbWV9jVpsgJSRERDmOPX94ec,897 +torch/include/ATen/ops/_adaptive_avg_pool2d_native.h,sha256=8MOZRzNkLafqy1zh0NwVE0rwZQ4Avuh6LMB6iQrHj2c,1001 +torch/include/ATen/ops/_adaptive_avg_pool2d_ops.h,sha256=3jEzG9EtGb-j_tMHKykPLwg0vW9r5K3tRMWvo1pDSEM,1891 +torch/include/ATen/ops/_adaptive_avg_pool3d.h,sha256=04Gn2Lle9wdZrfdP4x_80nbAqUla4JBFusNQQ30FDCw,4233 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward.h,sha256=0ZYEONWCpF6gVr1olyzSV04C_UQyzQJZ8vMtOSKRs9w,1492 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward_compositeexplicitautograd_dispatch.h,sha256=Sczz2-fRLMvyF_lPhvBSGkMh-ZOuaxdkZlMuiFjx0w0,1001 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward_cpu_dispatch.h,sha256=AtKoGN1BgAc8mD58wJyY-b064QsjZxCXr2-AudAfPRs,798 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward_cuda_dispatch.h,sha256=1jdPmTS16FpDTFFWV7Xh0_tZPrGl5w2I_KOCIvv2htU,800 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward_native.h,sha256=-FmiN5nSrzX3q2VGef_M-Sfm8WPetVj4OKMI7lyHjkg,803 +torch/include/ATen/ops/_adaptive_avg_pool3d_backward_ops.h,sha256=CFht7vmJVBzNOevVUWy4aYkopeEXrWTkcFeZV2l0D7A,1933 +torch/include/ATen/ops/_adaptive_avg_pool3d_compositeexplicitautograd_dispatch.h,sha256=11NH6mgY9nIvLuB1tdzjxdf2HxqgVmhrlNwtoMFF9Ho,1244 
+torch/include/ATen/ops/_adaptive_avg_pool3d_cpu_dispatch.h,sha256=slB5A0_V-CXs0CcNc7GyxjFJQpABjN1yVk1FoM3F9SY,895 +torch/include/ATen/ops/_adaptive_avg_pool3d_cuda_dispatch.h,sha256=eNPirPuPg2yZblx_dceSSIq8zHil_l8-zsxSsZBDSBg,897 +torch/include/ATen/ops/_adaptive_avg_pool3d_native.h,sha256=gXBEXtf_bGOIK2KA4VJK1DrTFi2rndKf9ZXhubOSHTA,889 +torch/include/ATen/ops/_adaptive_avg_pool3d_ops.h,sha256=kzWTVJNRAzsDrFiIJt79jnQBnoTQa6iCN6nsXrNXR_s,1891 +torch/include/ATen/ops/_add_batch_dim.h,sha256=_8v8ZaREh475LlrBf-lrnziwjIJ48ST-pZu8xi3ak6Y,778 +torch/include/ATen/ops/_add_batch_dim_compositeimplicitautograd_dispatch.h,sha256=ScLkbrTMOURu8UPSgO_OWen7t1EMu72ZVf8gUJxCYbM,829 +torch/include/ATen/ops/_add_batch_dim_native.h,sha256=REBJVuUyHnrV64k2nIXCVxoe7NHxTtL28gx6d1aU4R0,539 +torch/include/ATen/ops/_add_batch_dim_ops.h,sha256=fo0yJwRReyzZFNX0lSBUQi11Ep9iYOC-FYTpM3no_Gk,1122 +torch/include/ATen/ops/_add_relu.h,sha256=1kfy4kEvMjay9eqYHU4TXkpr3aMdDDseW_GgVzCvkfA,2845 +torch/include/ATen/ops/_add_relu_compositeexplicitautograd_dispatch.h,sha256=jF86va7Sq0DA3dpJCoRgpCUB7mQC-4qjACiRjx0AyT4,1003 +torch/include/ATen/ops/_add_relu_cpu_dispatch.h,sha256=EQe6rq-p-IR9g--Yww5jEHkFLOCfN4XuAevS7JHhxns,1401 +torch/include/ATen/ops/_add_relu_meta_dispatch.h,sha256=RgEkbMxE81nAJ3J6PUMnEOtuHt_VCresrSlR6-eBtQc,908 +torch/include/ATen/ops/_add_relu_native.h,sha256=3o5F1bQEWvqNmdUzoupiPlybl9Uphh_nTP2cIn43rWc,1154 +torch/include/ATen/ops/_add_relu_ops.h,sha256=4f5NarpI5JQaVyRD1gG-2kkisKWOXfV2Z6CarbafbRM,4924 +torch/include/ATen/ops/_addmm_activation.h,sha256=AgSR4vgVW5MT4h-_1L40hrHLiyuWO2dx_-On1u5Y6iw,1887 +torch/include/ATen/ops/_addmm_activation_compositeexplicitautogradnonfunctional_dispatch.h,sha256=b1Xj0KBf6aj0r-aTnJej05kLi8KgCKGtL60NG4GbQC0,950 +torch/include/ATen/ops/_addmm_activation_cpu_dispatch.h,sha256=iYibqkZe2aOQoERWv0GXxT6PujHEzLvsBUFWpcHLuFE,1303 +torch/include/ATen/ops/_addmm_activation_cuda_dispatch.h,sha256=u4LtSqOjXdXzOJbO-uCKG70t0yWbu44sXfUyOVnsTyA,1305 +torch/include/ATen/ops/_addmm_activation_meta.h,sha256=1TryV0ZObKuOb4KQ5gd64pMezBnUL45entGfmBX-R-M,721 +torch/include/ATen/ops/_addmm_activation_meta_dispatch.h,sha256=j7HY8YwS5FkRG1sy6fVgY4ufvACgxcxYdbHKgKOjKm4,1305 +torch/include/ATen/ops/_addmm_activation_native.h,sha256=jXvUQmPsz3LENLuHcr3hBY44Z3REWoHsoQ04eqefakY,1057 +torch/include/ATen/ops/_addmm_activation_ops.h,sha256=eg5saEt70074RSRbFTi0JNzhcdwBc0OpoSrzmGnmdTI,2448 +torch/include/ATen/ops/_aminmax.h,sha256=3ssX0J5vIl1m9B2G-1r0YvYVsQbgUUsSgOJOJoyCQIY,2304 +torch/include/ATen/ops/_aminmax_compositeexplicitautograd_dispatch.h,sha256=eNo8LlyCwr83SRFc6Bcx6oZomBOv2PDpARN0I5hTsHU,1306 +torch/include/ATen/ops/_aminmax_cpu_dispatch.h,sha256=JQdIVCG-KH--2flxdixP4AgDoqr77R6wJpEYb0svUfg,885 +torch/include/ATen/ops/_aminmax_cuda_dispatch.h,sha256=o5nciTlAmIEd_qCtTgCM2tel3oBUE0ZoWjlftfJ3dbw,887 +torch/include/ATen/ops/_aminmax_native.h,sha256=j0VvEcyERd31MAZ7_n3CUIyHRt0zgKONgoPyGjwrC9Q,930 +torch/include/ATen/ops/_aminmax_ops.h,sha256=K2oekqJnd8iPGvRp_qlXgIBST6DYfoS-GsQQHs64guQ,3446 +torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale.h,sha256=tP6rLahse90gMOLZX3TaiZj_7_kxgCHcxfVH3aAymM8,2168 +torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_compositeexplicitautograd_dispatch.h,sha256=b0wHQnoGB2Af1DkiIbfIDoMq5UpXdPb7CRTO9yTze8g,1235 +torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cpu_dispatch.h,sha256=KEjFeJWMUwy3hS3J2HoPPP-imyHVaSJPunn9S5WYJTY,823 
+torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_cuda_dispatch.h,sha256=caieoRLUGobHxQVlI4ZymUhnsgllHlAoRHNzlS14BsY,825 +torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_native.h,sha256=N4Nhv1V8PNkPqQdBGtDWbirtfnHk4RLHb3ONU6OZCAI,1065 +torch/include/ATen/ops/_amp_foreach_non_finite_check_and_unscale_ops.h,sha256=dLwVplzy1WaHCiXuTa6vjWS6KsvTL6_vlLaerYRE014,3059 +torch/include/ATen/ops/_amp_update_scale.h,sha256=GjeTn5jtY1TmlAoMguqfvfksqrhvtBhOxiuOHJjffZQ,2809 +torch/include/ATen/ops/_amp_update_scale_compositeexplicitautograd_dispatch.h,sha256=WHBR1m6XqEc9XK77h7Ncv_MQ2eqgF3jTygDTKPwJIJM,1433 +torch/include/ATen/ops/_amp_update_scale_cpu_dispatch.h,sha256=Av5AeAvAO43oBC22vHtgt6vz_5oKxwkK3wNT2HanKwQ,892 +torch/include/ATen/ops/_amp_update_scale_cuda_dispatch.h,sha256=6LZzC80wx7l7icbGV8Chu7eACT3GWAvQWaqkB722zbw,894 +torch/include/ATen/ops/_amp_update_scale_meta_dispatch.h,sha256=h9Ev5kF_XDFrlApMiaY8rh6vNxH5HIHaDOq7bZpuu6o,894 +torch/include/ATen/ops/_amp_update_scale_native.h,sha256=E-Ocjz00GSDLxCjpbBPiCktomhAdv-wfU3-vQsMZHHc,1328 +torch/include/ATen/ops/_amp_update_scale_ops.h,sha256=9Qbsmf5UQIAE7pM49SDaC5mRkNwhqGhxMfP67zgyBOk,3700 +torch/include/ATen/ops/_assert_async.h,sha256=fKga2NiDGenVz57Cmf5TDR6zTDKUUMCI7n9cD9DWBXw,902 +torch/include/ATen/ops/_assert_async_cpu_dispatch.h,sha256=L4-Y8LLFrGRag2kiqjzPl32v2jZKUyDy-c-wN2iCXrg,829 +torch/include/ATen/ops/_assert_async_cuda_dispatch.h,sha256=eTObruAJJVLGErVIy9TNM-QCBqVfEGYJB0-FP9wJRGk,831 +torch/include/ATen/ops/_assert_async_native.h,sha256=twWJiJFAnQJVnuTTGsJrDwg7tLN7g0i2A5fwA7ADUUQ,750 +torch/include/ATen/ops/_assert_async_ops.h,sha256=dZxDp9-G1Juljc3BO0Gg2qTBTdIPujd7ml0phdx1kaQ,1598 +torch/include/ATen/ops/_assert_scalar.h,sha256=Io5Ss4efj0XjsWxPSuWLr2_1f20pyT8PVeInyPfiunQ,747 +torch/include/ATen/ops/_assert_scalar_compositeexplicitautograd_dispatch.h,sha256=6zBeyr51kd38BDPeNbghNj36Y8l69_k7q__TOB8Mt2M,818 +torch/include/ATen/ops/_assert_scalar_native.h,sha256=UtI57zQ-qW1iiksQnRkMu7kXCUCORwAQkAbSYl5otSI,528 +torch/include/ATen/ops/_assert_scalar_ops.h,sha256=sOYKVZyTQYdRN81AUVzyDv5-d64GtjkUhSOjJaiTFFI,1080 +torch/include/ATen/ops/_assert_tensor_metadata.h,sha256=53iqgUSVMIhQlfR-xLMke-aCibSjxD4hT76QxajWs2E,3066 +torch/include/ATen/ops/_assert_tensor_metadata_compositeexplicitautograd_dispatch.h,sha256=Egh729_4lvmvCDbqvZt2R6EPCsoNSCPTI_Det4f40bw,1367 +torch/include/ATen/ops/_assert_tensor_metadata_meta_dispatch.h,sha256=tCMEim27krBSkngt4BuoJgt4CuTJ5Y7j9oOKfgIIa_c,1325 +torch/include/ATen/ops/_assert_tensor_metadata_native.h,sha256=vAFRiWJ1rKyXCqDBJ5q36IM11DEkKKqCT6YF6sYeUFw,1082 +torch/include/ATen/ops/_assert_tensor_metadata_ops.h,sha256=s94Jg5WvTwODgBOG3gv_X5Uc8cnORCLy7p-nqPtmI7o,1624 +torch/include/ATen/ops/_autocast_to_full_precision.h,sha256=L0BxsYEHIAngpA0VWcWjOvUJSnlivV1qYeRLycYnCJk,551 +torch/include/ATen/ops/_autocast_to_full_precision_compositeimplicitautograd_dispatch.h,sha256=_U1_03lX7WkSQsCoU1vN-U8XiCPG78fbwzCP-HiRWY4,845 +torch/include/ATen/ops/_autocast_to_full_precision_native.h,sha256=Vb0oQ07zhIPHfhlVsDnbiv4qv8OLkpLoYE3WwKRTnec,555 +torch/include/ATen/ops/_autocast_to_full_precision_ops.h,sha256=Dr3i9T2dnYKSFSn6RIFwi51CrL1Tk1-Z_DYE22xbgX8,1178 +torch/include/ATen/ops/_autocast_to_reduced_precision.h,sha256=cJt08R9orSQ98xi-aBAfiCjHiEu6lOmNxWuKGX6U9jY,554 +torch/include/ATen/ops/_autocast_to_reduced_precision_compositeimplicitautograd_dispatch.h,sha256=Cv8MBhpSfT7lHvi5OBpEGUqEyTyROPinV8CtOEqYqKs,901 
+torch/include/ATen/ops/_autocast_to_reduced_precision_native.h,sha256=LMQtpdmmYqHmg00b4lVo3DJbL6-Ogej0MSDoxVWfzZI,611 +torch/include/ATen/ops/_autocast_to_reduced_precision_ops.h,sha256=XEUcXovSn_eQ3jGB8VyNbd28_QE0KfJsuiah1FX0_98,1370 +torch/include/ATen/ops/_backward.h,sha256=tqlM0aTlDbed91-WSoIXQZp9bxU78aTZ7wIuwLwlzw8,533 +torch/include/ATen/ops/_backward_compositeimplicitautograd_dispatch.h,sha256=8fNJ2JxxegqKS67dzCYLKohLOeZbmEcJZGAYMgvPd0E,932 +torch/include/ATen/ops/_backward_native.h,sha256=tzH4qFnRrh-nobtrTzjV9zTFq-dPCOaboQYlfHniCdo,642 +torch/include/ATen/ops/_backward_ops.h,sha256=2SWpJKrVaglTEUU-wwJQWOKHE99a2TRWIIPer5A-FGQ,1393 +torch/include/ATen/ops/_batch_norm_impl_index.h,sha256=4CeJrIXtw_uS8RhBuV9hUbTnDmq6lLcWWvGtXm0yDBA,1286 +torch/include/ATen/ops/_batch_norm_impl_index_backward.h,sha256=f7H7-3oGBOxZDZUXzg4d3EBApyjBjVfsMSI_YvAZaFA,1554 +torch/include/ATen/ops/_batch_norm_impl_index_backward_compositeimplicitautograd_dispatch.h,sha256=3-cnEBbrBi8JbJgWfpdhaVAY87AugBXuktZhmPxFgm4,1239 +torch/include/ATen/ops/_batch_norm_impl_index_backward_native.h,sha256=ccQHHHzsCG65WurX_fCzf_t9SnFJnmDDfUdvyKzy3iU,949 +torch/include/ATen/ops/_batch_norm_impl_index_backward_ops.h,sha256=AqAmeGUsoGLBQHiHmqVnnH54YsdXMd4VZeFfrXLgWSc,2448 +torch/include/ATen/ops/_batch_norm_impl_index_compositeimplicitautograd_dispatch.h,sha256=CXcfrzDFCx3EKbxz-f5jyoCQewH9sMvu9K3apLiCe_c,1108 +torch/include/ATen/ops/_batch_norm_impl_index_native.h,sha256=tMZmsvd8nd4DfPsmUL47eJBKHCcvLnNHtU_1s2zvIMU,818 +torch/include/ATen/ops/_batch_norm_impl_index_ops.h,sha256=32zm6nHuk8wIgLjWAjnaIee55ux8hqXXhfDf0qfmBqg,2042 +torch/include/ATen/ops/_batch_norm_no_update.h,sha256=rMIk2pfzUWGBNuqtn_WCp-QVe1u0pTnG4IG0FTlr0Ts,2849 +torch/include/ATen/ops/_batch_norm_no_update_compositeexplicitautograd_dispatch.h,sha256=b6-5joZxXmdc3etq_Vape7ICj3EBysEu82M-MP2x_dw,1905 +torch/include/ATen/ops/_batch_norm_no_update_native.h,sha256=pR1oysdLyUeDdm_c9ZiW5YSTP0XJFSWxLvc38Id4viA,1194 +torch/include/ATen/ops/_batch_norm_no_update_ops.h,sha256=7pYu-lMXBYPHvvJ--ifor21vyWEsksdy-A7uCyCGqvQ,3641 +torch/include/ATen/ops/_batch_norm_with_update.h,sha256=jA7EIqEbeyWV-iz4IraTJdXLkgyxgeUFPJiVIUFFXUE,3538 +torch/include/ATen/ops/_batch_norm_with_update_compositeexplicitautograd_dispatch.h,sha256=N3fJ84Cqry3jZBu8_mNTqxIdxUzJ6KzdmV9HKPLU6VU,1065 +torch/include/ATen/ops/_batch_norm_with_update_cpu_dispatch.h,sha256=lmLWYuszFRg9SYv-A7T8JlI3qmNw7FOFwCzTmhAn78k,1757 +torch/include/ATen/ops/_batch_norm_with_update_cuda_dispatch.h,sha256=NK3yf7QN6wIV9g0gjbynw6FPsHqhOKNg9ONyD0NNpzo,1759 +torch/include/ATen/ops/_batch_norm_with_update_native.h,sha256=XXMWeZ4ilgI62_b-L_sVnYmKz5SPyp8mnELaLe24CHs,2444 +torch/include/ATen/ops/_batch_norm_with_update_ops.h,sha256=FuPQDi4z0UnnXsx-KtY_K8ZfqjQ-OuAZHoTmlXmaxW8,4899 +torch/include/ATen/ops/_cast_Byte.h,sha256=DujszKXwMtIqLRD6xTTj0_XBKxsQZr0_U16IjRfHfko,748 +torch/include/ATen/ops/_cast_Byte_compositeimplicitautograd_dispatch.h,sha256=x-nzOsNwnEvGHJWwEPqsqlS1J1de0HDLpFANVsW_ozE,816 +torch/include/ATen/ops/_cast_Byte_native.h,sha256=iIiFdsP8Nri8iueCiFMeZWMhj20aWt8ElO5KkJ5H9oQ,526 +torch/include/ATen/ops/_cast_Byte_ops.h,sha256=B8EAXKLRcUkiF_P0WI5a0qmdcS9Yi907zA7Wp0qVZ5I,1067 +torch/include/ATen/ops/_cast_Char.h,sha256=sfzhfZYnsM0sy47zGmPtTLQ2oTBxbOnBCfJHkFDcW5w,748 +torch/include/ATen/ops/_cast_Char_compositeimplicitautograd_dispatch.h,sha256=EYVr9OItGZMxDYDQOnhn6VeTJ4h9t-qyQbblJeP9iNY,816 +torch/include/ATen/ops/_cast_Char_native.h,sha256=zDiByFv0LpUpiu3yz8YcPUCzQclHPcOLJjqnZmSDvNs,526 
+torch/include/ATen/ops/_cast_Char_ops.h,sha256=dzx4UP2Qa3v7Kujr1RwKqDB6PFT3o4vmRioD8FTuxjo,1067 +torch/include/ATen/ops/_cast_Double.h,sha256=f8Ar5YjgfyaYQ-MeNWNIgI6cZvo8TOMja2hExjshYPc,756 +torch/include/ATen/ops/_cast_Double_compositeimplicitautograd_dispatch.h,sha256=SW32RvDuw8_TwkhnYzDYyUNOrkuCaBUbAPK70O0f9zA,818 +torch/include/ATen/ops/_cast_Double_native.h,sha256=EYfrIbp3l4ZXCLxA8ZTzGQTZZUIXqS40wazERI27z64,528 +torch/include/ATen/ops/_cast_Double_ops.h,sha256=JYjgqAVy7SRNwjUnATkiCbViI-HBhxd5dqHfl2SURzo,1073 +torch/include/ATen/ops/_cast_Float.h,sha256=uO3Msy_r41F15woGt2fCsqtMSkyqp6O-CtO_dGZLDqo,752 +torch/include/ATen/ops/_cast_Float_compositeimplicitautograd_dispatch.h,sha256=OJhv0SlE9VORDo5b0S1DDQwNKmi72_1cKwCjqT0GJtA,817 +torch/include/ATen/ops/_cast_Float_native.h,sha256=arQvOeOnS1Gm-JJUr3VE1gnI6BDyrHufys76UrELevQ,527 +torch/include/ATen/ops/_cast_Float_ops.h,sha256=Thotqqg9GHLllQ8isdUjkc3b3-_jh08pMtrQPgCQek8,1070 +torch/include/ATen/ops/_cast_Half.h,sha256=OS02edPFML225VPJ4SQn5ZQyw8lv41ltVmHqArE-gYE,748 +torch/include/ATen/ops/_cast_Half_compositeimplicitautograd_dispatch.h,sha256=gQmob8ogOPleKWSQW9oYe38iRnFmq1uyN7gD3X7zpeI,816 +torch/include/ATen/ops/_cast_Half_native.h,sha256=sEekZ_rbibdK-HBE2wdOkVJ1HGcaALjb_vsdCQJGlzg,526 +torch/include/ATen/ops/_cast_Half_ops.h,sha256=eWNUpQtSxERhtsUYvn342-N9sdFDioZbHRHHkShzh8U,1067 +torch/include/ATen/ops/_cast_Int.h,sha256=ge2x002R0mXGfPA3UQ2HDh1_z4cGkgawYqdXjNjfdFU,744 +torch/include/ATen/ops/_cast_Int_compositeimplicitautograd_dispatch.h,sha256=WaPWLY7WTzqP__vB-rbpd1ayR7z4nrvZxzQHgiS_zPY,815 +torch/include/ATen/ops/_cast_Int_native.h,sha256=kh_0f0cHdNQpDfuDIUchWrXdutH-6mqGwIoa5pqEFoc,525 +torch/include/ATen/ops/_cast_Int_ops.h,sha256=pRUGfotUl5D_QzyqmNFFhx-RzMCcnzri3qlpdkrmbbU,1064 +torch/include/ATen/ops/_cast_Long.h,sha256=GhqrinN2sKlH4kGlfU-Tc0BEwBPDjIgy8lUKPquw8t0,748 +torch/include/ATen/ops/_cast_Long_compositeimplicitautograd_dispatch.h,sha256=stJ57xlV2LQxIHTNgKBtjtcUhRDR_9_8y9OBNLNihS4,816 +torch/include/ATen/ops/_cast_Long_native.h,sha256=unYUm3nGWF6gq10XE2_jaOf-Eh3MXuOEHUIDY881vhE,526 +torch/include/ATen/ops/_cast_Long_ops.h,sha256=EU7BFCfLRQu-Fezk0rtBsvCjnwifHbNg3pFRPKdDGEA,1067 +torch/include/ATen/ops/_cast_Short.h,sha256=EWIbiTK75eF2rHsgSV4_udSdpWz54fqwINJIdPEGUEo,752 +torch/include/ATen/ops/_cast_Short_compositeimplicitautograd_dispatch.h,sha256=mq4jHXg63wFbO06YaEdNQU1Q5vEEvIVGBsjazFdlyQk,817 +torch/include/ATen/ops/_cast_Short_native.h,sha256=d8Psk6Mnf1x4nQzEBqbpcDTM-4veJKE9N-8vT2fBMjM,527 +torch/include/ATen/ops/_cast_Short_ops.h,sha256=Z_MlqY0tzMTK93Ow_id2kqITXF4RxMC7pB6SZiknerc,1070 +torch/include/ATen/ops/_cdist_backward.h,sha256=pVzUcBuPCOHmAld8LjX6BQgffQBHZsIsORslLxSeJcA,1592 +torch/include/ATen/ops/_cdist_backward_compositeexplicitautograd_dispatch.h,sha256=sgA7Su3m-8iHc9V4ZF2IV9a4UnR-VuOzFuXI8-12V6g,1073 +torch/include/ATen/ops/_cdist_backward_cpu_dispatch.h,sha256=Cm_SF7Eh3rHCYTCTap_cTVVjQmcXY5k-muYeT5crLro,834 +torch/include/ATen/ops/_cdist_backward_cuda_dispatch.h,sha256=A0FIf35VKqIFRni-31nbg_bbxwaM2Q_5klDhcoJbD2Q,836 +torch/include/ATen/ops/_cdist_backward_native.h,sha256=dwKjn1DFUNHOoWhRQN6RFuDN2NPv3s6QTbNykOYzYAs,758 +torch/include/ATen/ops/_cdist_backward_ops.h,sha256=CpdTNkaINzm8ZuxxSuU5FYRuUwmGXE33oFENuGltQG8,2195 +torch/include/ATen/ops/_cdist_forward.h,sha256=Lm7_2Km3iFKooD0aA6zYcU-9XR5VCHbtRIRbhyAVjCE,1525 +torch/include/ATen/ops/_cdist_forward_compositeexplicitautograd_dispatch.h,sha256=VfWi6WXjR3JsvrMbIXISAuU3liFGz6vwzbdiQg_Km1A,1047 
+torch/include/ATen/ops/_cdist_forward_cpu_dispatch.h,sha256=T6twrFA8x8WgXdYOyp8QWa4pyu1qsfE9VS1CPV-PxT0,821 +torch/include/ATen/ops/_cdist_forward_cuda_dispatch.h,sha256=LvYQooPq_LoVpMCzLJunw0Arsr5AWvvJsdraEwxB-Jc,823 +torch/include/ATen/ops/_cdist_forward_native.h,sha256=bMWDYlROEDlSobZLpMAmmy3yIKkIuC-64xwCocOIeFk,732 +torch/include/ATen/ops/_cdist_forward_ops.h,sha256=opX2OJ8guVFWjxj2WpjaGsZx0qpOQJ4AODY4AtqyXzc,2097 +torch/include/ATen/ops/_cholesky_solve_helper.h,sha256=1IbQdc0gGfsd1g5ry3uDt9x7BoSAIUbfVOkPa7KJZU0,1425 +torch/include/ATen/ops/_cholesky_solve_helper_compositeexplicitautograd_dispatch.h,sha256=EiOBZ3k9Na1NuIXgsmFHNBu_3a-D2M4AR_cq8JswLu4,991 +torch/include/ATen/ops/_cholesky_solve_helper_cpu_dispatch.h,sha256=fBGKTGSQN1JKFfEnfE7z_44Imi1HtqUMvn7O5ox1MT4,793 +torch/include/ATen/ops/_cholesky_solve_helper_cuda_dispatch.h,sha256=0bTFGi5OFMfuFmnGJcKtECiTUE0_ukns_xoUtwH8aFI,795 +torch/include/ATen/ops/_cholesky_solve_helper_native.h,sha256=IQGLxRcnZ3i3Vw3uh-mZoJCbBQ0yw6b4FmMc7qOW4wE,790 +torch/include/ATen/ops/_cholesky_solve_helper_ops.h,sha256=OSRljHIo11KB3L-g_bNroH3kOkav3cYImzPNhGBlg3Q,1915 +torch/include/ATen/ops/_choose_qparams_per_tensor.h,sha256=wmRjwLD9GkI26lAU0kvJiDwLLOSXI1fvWUwVSzDU3GE,836 +torch/include/ATen/ops/_choose_qparams_per_tensor_compositeimplicitautograd_dispatch.h,sha256=g3KOZar9pPpa49yLXnmpw-vsSswFIn_ZvSI7JIZB2Cg,850 +torch/include/ATen/ops/_choose_qparams_per_tensor_native.h,sha256=lKiSerULu1iWDdlEq7ITUsEASYtvUMvGJumqQxQsbj0,560 +torch/include/ATen/ops/_choose_qparams_per_tensor_ops.h,sha256=FL5Xq4S2Sw5iuSbpQwnl84u5SmCOHacGCGmktwy2DMA,1175 +torch/include/ATen/ops/_chunk_cat.h,sha256=5vMR8DpyybSWCaRjcdxDVQkDfLymXdgLhAj-u8bD0R4,1353 +torch/include/ATen/ops/_chunk_cat_compositeexplicitautograd_dispatch.h,sha256=eYuAV3i11oT5P-A7wnN4hFnW_xxBCzj8-hGy-EGSpaI,1054 +torch/include/ATen/ops/_chunk_cat_cuda_dispatch.h,sha256=y1AjBvq7qbHfN5vYP3_rA-dfJkMW_RwaRDgpxenbnIc,1012 +torch/include/ATen/ops/_chunk_cat_native.h,sha256=O6iDDLH82R1D_ynhGjv8k_2dmywGJgWD17OGC_4U_oY,864 +torch/include/ATen/ops/_chunk_cat_ops.h,sha256=kpB53-FMP3NnoyJaFZ7zjUdXTbMWvLG0N6CMwzt_mSI,1827 +torch/include/ATen/ops/_coalesce.h,sha256=bdUCZ02HzKBXwEe5GJIxr0uUDcP994WH3yY14QH7v0M,1097 +torch/include/ATen/ops/_coalesce_compositeexplicitautograd_dispatch.h,sha256=yB7yEtg2XEJF5ZgA4nyoG-k_1H0iELy3Gdyd65IjoH0,897 +torch/include/ATen/ops/_coalesce_native.h,sha256=_T8rC5DchG_N1XafTmh_-HYaldVBh5gGoc5XH6qEwYE,663 +torch/include/ATen/ops/_coalesce_ops.h,sha256=4_tFYFRZn4NsCCLhHYl2LphClnNN5Y-dcsvu_DPMFBg,1605 +torch/include/ATen/ops/_coalesced.h,sha256=X2Jflm0UYwooEANALuL1CY-7lrT11KlwvKiX_ixVcmA,1236 +torch/include/ATen/ops/_coalesced_compositeexplicitautograd_dispatch.h,sha256=j7i91HeUW3vwTecJSsEpF0aFnrNUE0Ji3Riwvq1wYzk,1006 +torch/include/ATen/ops/_coalesced_meta_dispatch.h,sha256=Ja3WiiBMMjG1qfYRW7kwYakjHHcwa-yUH-VQ9jk-jIU,762 +torch/include/ATen/ops/_coalesced_native.h,sha256=KJ4jZiTNf_VtJeSLn0_dXxf0MW6TaM540Err4lvJLc4,695 +torch/include/ATen/ops/_coalesced_ops.h,sha256=atqEYW5XbUOTwr2egakI4DNaGd1Uyko866TOLHN2lLQ,2295 +torch/include/ATen/ops/_compute_linear_combination.h,sha256=bcDgTNdC9_T45Zvefi6sgd4UfRta_Icv0MDEttR01kg,1490 +torch/include/ATen/ops/_compute_linear_combination_cpu_dispatch.h,sha256=sA13mtupRzN00xJf-4fC-jp3TGTAvAzWmsUFPVhm8Ao,1067 +torch/include/ATen/ops/_compute_linear_combination_cuda_dispatch.h,sha256=ZONZ-kGg2EgNO43ZkWlGUUEL-tt2EeS8NpD1Qp_BYC8,1069 +torch/include/ATen/ops/_compute_linear_combination_native.h,sha256=nzopHo1Vy2v68_BYJSRQpn_c2oFHr8hzxVVh0Zih_iQ,686 
+torch/include/ATen/ops/_compute_linear_combination_ops.h,sha256=Ehq2f55-Y9ujD8G0RePCJB3X_AixRvAbZ-uyJL44pBA,1933 +torch/include/ATen/ops/_conj.h,sha256=7i7OuJDt8fayXqaFCgZjIjOhQJ2zmi0Uy1Q_ASY5euQ,670 +torch/include/ATen/ops/_conj_compositeexplicitautograd_dispatch.h,sha256=HJPCcR2g6VH4VEi6vjEfWjfb5T8Zyqb4hWsWDBTPL-c,786 +torch/include/ATen/ops/_conj_copy.h,sha256=38Xi7ABhkmenh7nRsiWh5XX0Cx-NdVisRZZd4mW9-AM,1107 +torch/include/ATen/ops/_conj_copy_compositeexplicitautograd_dispatch.h,sha256=UyRI63XSqqYSEhZ-RayyLSHQGomVFtZji0T5_adxac0,899 +torch/include/ATen/ops/_conj_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=xBt0CK6HTq8UQukMwRCkX7kr7-50G8P91t7QXTZGz9A,817 +torch/include/ATen/ops/_conj_copy_native.h,sha256=vj1Ioua7Waw0oGm4wmqoSJKMSGGxgFo5S4sNdF1HikA,584 +torch/include/ATen/ops/_conj_copy_ops.h,sha256=pHOk152BWcn7K09mN02-FWd4IpW0SFkrm2lYwdbtcpo,1611 +torch/include/ATen/ops/_conj_native.h,sha256=n4zF_Wis0SdANht7A2XhkC4OtpuvMb1LfLCT2w8yoic,496 +torch/include/ATen/ops/_conj_ops.h,sha256=6bBdnCNU9i_K331whmvVbVeiEEyFtU7xbM-570N_JBI,989 +torch/include/ATen/ops/_conj_physical.h,sha256=DvdT3xoOZyFOM_8Bf7CE_YE_0B5rqwhw-TNKDbC8pgM,1147 +torch/include/ATen/ops/_conj_physical_compositeexplicitautograd_dispatch.h,sha256=SxJJ6UyzER7U_WpTLVqa1RmS-Yvvar-D6qgRuuqcIVQ,970 +torch/include/ATen/ops/_conj_physical_native.h,sha256=adccAFr6Yj9WwRQiFyFmDneZgGFBIn90v6tch82WDwo,665 +torch/include/ATen/ops/_conj_physical_ops.h,sha256=-XGf3jpO_zBTDtsGcexWkfrytSNvmwG2inh1_V-_k5Q,1635 +torch/include/ATen/ops/_conv_depthwise2d.h,sha256=ub-ddeNS0WkTwPfjy_XQHlqJlrdR4UFUvDqj8t77DPo,7416 +torch/include/ATen/ops/_conv_depthwise2d_cuda_dispatch.h,sha256=JoQ5KMJ0jV7-b1pBjKzSQkkW9dge_XOP1xgQ6Fs35-4,2291 +torch/include/ATen/ops/_conv_depthwise2d_native.h,sha256=DQbySE2ldMLhXmqYaQ5k59vVwstK3lL5v-9kKnF0gMY,952 +torch/include/ATen/ops/_conv_depthwise2d_ops.h,sha256=7rydf4PfRyIcdVRqcMtIg44h_06NFegoXU68aTYFs18,2909 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr.h,sha256=K7qp0YffYQNxRvO0ITCybZXPkdhjJG26kWW8I5BxtrM,1579 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_compositeexplicitautogradnonfunctional_dispatch.h,sha256=na70uMryAzCjGE-SF5CJ_33SOsTCONIZaise90ruAS4,875 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_cpu_dispatch.h,sha256=ckHNC4iLTGL2T9g03_P-0EtAvAzq9QHViwUEmdsJ_MY,1082 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_cuda_dispatch.h,sha256=lQP1UY5gVLami-fHRqGe1xFwK2SrX9H8S1R-iBwHkp8,1084 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_meta.h,sha256=RUZwb0s2kf695-ydUMhN_OCMwumY714AewlgE17M81E,650 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_meta_dispatch.h,sha256=k8RrPHprXow5W9fLnsXWHC_mpubuQS_vGTDUlSoWEgE,1084 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_native.h,sha256=l_2AwpQEaA12p7uYqN2xxsSq3v2DmlrBKkLvTRfFdao,976 +torch/include/ATen/ops/_convert_indices_from_coo_to_csr_ops.h,sha256=4uPEpy41RbENqgmk1UhnF7ZcLpH_PC9P1MKnsZEDnnI,1960 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo.h,sha256=nNmujmk9POJDRZZxO4TsmCferXVp9V5wg0dmxnibydA,1915 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_compositeexplicitautogradnonfunctional_dispatch.h,sha256=l2vn16L7aeQd4ZKf7LqXDBhZMIxfb1Yk7SosPIR0pRU,923 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_cpu_dispatch.h,sha256=j7ga3HIpV6aU5vrLqBF_uaGeQEdOYc-yoJXURBAPgAE,1220 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_cuda_dispatch.h,sha256=v0oTPgrOY5UsSTxQ0ZqYyOQ7WYxXaGQeWJzY_-ZLR2U,1222 
+torch/include/ATen/ops/_convert_indices_from_csr_to_coo_meta.h,sha256=4fPGQI4Apkjxw47TxOVSiqFhi9qwZXUsIveseMu6wqk,692 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_meta_dispatch.h,sha256=DDfV9opmf096rrehB5u5HLT3VgmeDEqNOlFFrtU3ivo,1222 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_native.h,sha256=ABpBxQ-eHqlomz3UTGgvCxdnW6JR18BG7mPL6PccD8A,1060 +torch/include/ATen/ops/_convert_indices_from_csr_to_coo_ops.h,sha256=yNDjGtKui0UeNqAtBvsdRB70k0dJsqbv0iBDK-KI7Kg,2242 +torch/include/ATen/ops/_convert_weight_to_int4pack.h,sha256=YaWh0t1yTHRP3VyVuiL7MArVZlaz80Cxhd3a_KXKnHw,803 +torch/include/ATen/ops/_convert_weight_to_int4pack_cuda_dispatch.h,sha256=iLUhn_4cvD3ZtBrksRGt-whANqAjHDMRXzm-tBRt4RM,787 +torch/include/ATen/ops/_convert_weight_to_int4pack_for_cpu.h,sha256=U-k5sdwSRvCv89AcAx_ksIO9L-2k3YAo13220z6b2qo,835 +torch/include/ATen/ops/_convert_weight_to_int4pack_for_cpu_cpu_dispatch.h,sha256=0Q4kJBaOitp6Q0YpUqpqoVAU5D07rqUTLYYHuw8Dvck,793 +torch/include/ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h,sha256=YtuThNJfZiCyZG42NDvBsotwtjyio4tr6-3XU3sA2pI,543 +torch/include/ATen/ops/_convert_weight_to_int4pack_for_cpu_ops.h,sha256=nQquvyoYmjHscVDgS43JaMlMrUxWk75d3TMrot11Mb8,1141 +torch/include/ATen/ops/_convert_weight_to_int4pack_native.h,sha256=VHGkSbJ9Cc8T46JzZvCkspTknfzSfkBlIO7xbZTWMvc,544 +torch/include/ATen/ops/_convert_weight_to_int4pack_ops.h,sha256=mykNGVpZ4W2ph9sEig-J5vUIWa1WygRlWEMeSojSyC4,1117 +torch/include/ATen/ops/_convolution.h,sha256=PGusWqYmD3BaLnwwq8NR5QH92_TfiFNd9sE8MENeC60,13092 +torch/include/ATen/ops/_convolution_compositeexplicitautograd_dispatch.h,sha256=ejvo7k1rt86-TUfVIhCLLR-nUkJtMfCDv2TNwXrjsgU,2975 +torch/include/ATen/ops/_convolution_compositeimplicitautograd_dispatch.h,sha256=NjcLXczUNQ40oWmP3pQzU0bBgU0LxwxncqOtVn0_pZo,1409 +torch/include/ATen/ops/_convolution_double_backward.h,sha256=VWeCQouQA1XCJ_4ijIlwMkv8o4YS3Bp79KTEns9sQe4,4083 +torch/include/ATen/ops/_convolution_double_backward_compositeimplicitautograd_dispatch.h,sha256=-s0w5GgTLr_i564EaGx6XQJi0I1JZ7UdIozVXIl2HiA,1679 +torch/include/ATen/ops/_convolution_double_backward_native.h,sha256=Gw-LWnZpjh_3jYWLD4KU5KMu2-vgYNJmK1o4p8vljAE,902 +torch/include/ATen/ops/_convolution_double_backward_ops.h,sha256=gBFnvudUd0h83n8-vzsyO1qrOBBy2KOX1qDlWMC4Y8U,2385 +torch/include/ATen/ops/_convolution_mode.h,sha256=hauLWjwFwTwpmhQTgk9fnsnQknb1Tlk5TBbwduXxdvw,2504 +torch/include/ATen/ops/_convolution_mode_compositeimplicitautograd_dispatch.h,sha256=PZo7Bf5TNXdMmc1oSf8bxh9zUtLqejILBcZqp90DB4I,1207 +torch/include/ATen/ops/_convolution_mode_native.h,sha256=JcZMBy14v5ovJu9WpbtFVOWyMnGiWd-7gTMYnW4duhw,689 +torch/include/ATen/ops/_convolution_mode_ops.h,sha256=COGHEaxdGsBlgTZRp8fWGuH1zlVaJ1kF9xjfEouLBwI,1591 +torch/include/ATen/ops/_convolution_native.h,sha256=iz3Z5iGV_tP2kJhJfGeRVLUmOxcRyJb-rPh8ntt-jVQ,1508 +torch/include/ATen/ops/_convolution_ops.h,sha256=L8_JMsE9VkXF8HkvjAKYuRAh1wQyC9rN8LzTJ6RoF-k,5124 +torch/include/ATen/ops/_copy_from.h,sha256=T7vd5q_Krh7KLjbEhRsmSW5qXkUiqq7fuWsRjYz_yPM,1416 +torch/include/ATen/ops/_copy_from_and_resize.h,sha256=GVwRYN10JNzjHrqLqrZXaFBSxier0iaGidm0VVZlPNw,1340 +torch/include/ATen/ops/_copy_from_and_resize_compositeexplicitautograd_dispatch.h,sha256=5lpBxcy2Ay5kRcmaA50iURA0sNBbgfCj3NBdgg_Opjw,969 +torch/include/ATen/ops/_copy_from_and_resize_native.h,sha256=XKiGor0z-NvKjF5PFVTEskl4TdNvC8PTtmjsBm0KcOE,560 +torch/include/ATen/ops/_copy_from_and_resize_ops.h,sha256=BSoVA2wEEJDYsMGpUuf7l4_GDtRHU9fd9vkMJIgHMxU,1837 
+torch/include/ATen/ops/_copy_from_compositeexplicitautograd_dispatch.h,sha256=Xkmp-_XINm4Z6eFythNxt4kSWJiDPEebC87z9xqSHUA,991 +torch/include/ATen/ops/_copy_from_native.h,sha256=NBgPYCInNWVGGJ0lkKCUA4Spp51ZqZZspbkLs9bhTOA,568 +torch/include/ATen/ops/_copy_from_ops.h,sha256=Msumgeq3hyD_HZKxs5JhKjWoFGZQg_xVPJby3L7vzGY,1909 +torch/include/ATen/ops/_cslt_compress.h,sha256=Ee3nx7_XAYACfX9Q1V_ENLXkHcGx-lpWILzvCX8PuWE,703 +torch/include/ATen/ops/_cslt_compress_cuda_dispatch.h,sha256=kjXf1hY1GQfRmjLsfwE3y2mfH-1c-FyI82fL_956FSE,754 +torch/include/ATen/ops/_cslt_compress_native.h,sha256=TGknp3myJhWMDXKK1IZKnOWqqCyj_A3v8mMn1i5RSsw,506 +torch/include/ATen/ops/_cslt_compress_ops.h,sha256=5PbQvtWS_mEAiBDc9TH-VOA-J1lWCnKeeFeQHBXKmV4,1013 +torch/include/ATen/ops/_cslt_sparse_mm.h,sha256=A199a2t3QYe1Cp6sFqE5FgOS_T0f_ut0A3FnPozEJCM,1240 +torch/include/ATen/ops/_cslt_sparse_mm_cuda_dispatch.h,sha256=ThuVQaxfkSjnVoJBh0oX0GdEwwCbkI73Mc7PEpP35M0,1030 +torch/include/ATen/ops/_cslt_sparse_mm_native.h,sha256=24q5-BUN7odtfEJmzPKzGeuXgYLwnGn4PrdHeXzku_Q,782 +torch/include/ATen/ops/_cslt_sparse_mm_ops.h,sha256=bkV_lSVjp_kk1LLUHJmg_cCs7jypXAd1iV4PhNo9GFA,1827 +torch/include/ATen/ops/_cslt_sparse_mm_search.h,sha256=aK0gS8oUJxTdas6o6Vy293qAIQJ0VHROzJyQo1AhRXE,1119 +torch/include/ATen/ops/_cslt_sparse_mm_search_cuda_dispatch.h,sha256=gQuEdVD9eboR2wHVeFDNOPTozg8_nxzh2tO5eR-qoRA,972 +torch/include/ATen/ops/_cslt_sparse_mm_search_native.h,sha256=EkghcOrgSDSb8vZuh2s9lMmEZXlFDPR2FI613Y-EzGc,724 +torch/include/ATen/ops/_cslt_sparse_mm_search_ops.h,sha256=Lax2shN-gapoeRpynsLifwhZ_ISVMR2iB0UyLVa4aJk,1649 +torch/include/ATen/ops/_ctc_loss.h,sha256=7eUOUjwo08KFAy_vxN0IxBgqRwo1BzeKQHygp5UZkrY,3981 +torch/include/ATen/ops/_ctc_loss_backward.h,sha256=fklTPUSQfVOawVcdmFh68UhqVl0jojo9WytmXp1a2l8,3355 +torch/include/ATen/ops/_ctc_loss_backward_compositeexplicitautograd_dispatch.h,sha256=yrWR1h9y3U5tBD5_Uc3y-kdGdX-Ocdku9fo-o0EHHbU,1371 +torch/include/ATen/ops/_ctc_loss_backward_cpu_dispatch.h,sha256=dSUO_3KCjRWrKIR9ywaWd4HVFZzZOJubToAyJsAYrn8,1290 +torch/include/ATen/ops/_ctc_loss_backward_cuda_dispatch.h,sha256=tTZrTC-YhlnYDFgE6Q5vEyjc_zpMAt8stAbYKP-uN1Q,1292 +torch/include/ATen/ops/_ctc_loss_backward_native.h,sha256=n1ZA7ybZwmqN9Ltd0_sbD8g6jdXuzW_O5eef-wMvDkw,1670 +torch/include/ATen/ops/_ctc_loss_backward_ops.h,sha256=nosyPwvhR0oIbOYiKJwGssTTYPVYcb3CPgL5ejHB_7k,4460 +torch/include/ATen/ops/_ctc_loss_compositeexplicitautograd_dispatch.h,sha256=o4MQ7wOTCD0IP3S-JYXx2t7o1vEkbO5pnXku6ztodUo,1802 +torch/include/ATen/ops/_ctc_loss_cpu_dispatch.h,sha256=oWCWyO5QBh5XYSOo-WcefJhRwWvMKh_6hdglANvhyOc,1138 +torch/include/ATen/ops/_ctc_loss_cuda_dispatch.h,sha256=xxFjzSadRFrwWl4uWJCuz3ilBkTXgOkA6lwvkiW5zGg,1140 +torch/include/ATen/ops/_ctc_loss_meta_dispatch.h,sha256=VOdpGzmo0KmOHvNP2f9uXMqGdCLR0oqMPin4eYdAWE0,912 +torch/include/ATen/ops/_ctc_loss_native.h,sha256=NAhs4S9GAn6aroeeZofI1IXHmdnhMHkALPwC-TT8igc,1885 +torch/include/ATen/ops/_ctc_loss_ops.h,sha256=HpjMjPZvMJiMPoSZbImq7OzQXkP72wsVD3t11gHedW4,5056 +torch/include/ATen/ops/_cudnn_attention_forward.h,sha256=LV4wnoH-fTUysub5QVx1TLgPYaZvE9dFr-Mhp2NEmkk,4607 +torch/include/ATen/ops/_cudnn_attention_forward_cuda_dispatch.h,sha256=H9zXizgcpnFj0LrH9kih8EuwWSAKOXt4ghDwGool1fY,1793 +torch/include/ATen/ops/_cudnn_attention_forward_native.h,sha256=Xr9IieZ2ThHTFVrx3-7OKFagR6Z0L0Ha_zegSGMUdn8,986 +torch/include/ATen/ops/_cudnn_attention_forward_ops.h,sha256=3dV3MpJB44SsoIEgGJ-dSgTV6Xx3XpFmB_QiSvMlyuU,2657 +torch/include/ATen/ops/_cudnn_ctc_loss.h,sha256=jxSz98HS94CATjVyGUkB214y8qSLnO9ItKZiL5fQr9o,2980 
+torch/include/ATen/ops/_cudnn_ctc_loss_compositeexplicitautograd_dispatch.h,sha256=UZgTpU85FXOPg4G8MLoTrnBQTJbPM6ekiWkjhfdM6eU,1305 +torch/include/ATen/ops/_cudnn_ctc_loss_cuda_dispatch.h,sha256=3J_XMFKt2cqArQ3mTinaUM627OmNOZUeRsGUdb_RUNQ,1176 +torch/include/ATen/ops/_cudnn_ctc_loss_native.h,sha256=2MGrzTmwwMScJ4NGS8nnuTfP6XMvyHf9IOkust6BJ5A,1221 +torch/include/ATen/ops/_cudnn_ctc_loss_ops.h,sha256=YjeQI9Q30cRykO21JQjawsuX6gJEyvTQO8dMbFjCHV8,4032 +torch/include/ATen/ops/_cudnn_init_dropout_state.h,sha256=rIDafXeye9YWJ47gk8MMuMANC56PHD4OeFHJHkcL6y4,2293 +torch/include/ATen/ops/_cudnn_init_dropout_state_compositeexplicitautograd_dispatch.h,sha256=t0xxnO_EP0gGxQ5xO1ojymrw3R537JWoEobqDj1V0Eg,979 +torch/include/ATen/ops/_cudnn_init_dropout_state_cuda_dispatch.h,sha256=4LY6BYl7aFcRqPPbrvms5hN89GDeDw_tYHSoHrPUTic,1060 +torch/include/ATen/ops/_cudnn_init_dropout_state_native.h,sha256=pTWhYCE7aLglANp3cI1qhMz5_kYIjO4knMMN1EMGGjo,821 +torch/include/ATen/ops/_cudnn_init_dropout_state_ops.h,sha256=rw2-mvinxtKEc6prmRYOIbxIRVNfoERJ-HVe-660bkM,2368 +torch/include/ATen/ops/_cudnn_rnn.h,sha256=8bMPmeoa7idjOAOrWDYgCaQ6rUiKxnzLe-sUfXPkik8,13435 +torch/include/ATen/ops/_cudnn_rnn_backward.h,sha256=bfEQPFqdkNWwNodaMVVdVDZwAm59jI5KKQZhCdmUTvk,16550 +torch/include/ATen/ops/_cudnn_rnn_backward_compositeexplicitautograd_dispatch.h,sha256=bgqdmn9VP6WmrjjaFqFZSWUj1zhtqvsciLP4QGSsU7U,3708 +torch/include/ATen/ops/_cudnn_rnn_backward_cuda_dispatch.h,sha256=TEhJgjPdCq8yvwDPQ_Qcj9doQtcLx_MgOAIRPzTIli4,2149 +torch/include/ATen/ops/_cudnn_rnn_backward_native.h,sha256=2W0eYjL02yW-iTgClnwOkRQd0--31Dq7iEjb7NbHcSA,1915 +torch/include/ATen/ops/_cudnn_rnn_backward_ops.h,sha256=4noeEIFz-V66xslFWoO_-u7Kp54oOyk9PeU0fxzmxS0,6001 +torch/include/ATen/ops/_cudnn_rnn_compositeexplicitautograd_dispatch.h,sha256=css1XRtLMzOF5B8kQkPtYuODD-MANtXkoqoK9iYfBiM,3192 +torch/include/ATen/ops/_cudnn_rnn_cuda_dispatch.h,sha256=rf7N620D_SK8IRRjAakus9TwMoU3D5YnZgkGs7skUrg,1701 +torch/include/ATen/ops/_cudnn_rnn_flatten_weight.h,sha256=xBtXgx5mXNil-zL_PGvG23TuInz0OmMotaindVUyw0Q,7676 +torch/include/ATen/ops/_cudnn_rnn_flatten_weight_compositeexplicitautograd_dispatch.h,sha256=-xg6wQO2uGokldVfokv9BGHPg_B1iaicbX6INNjkPBk,1796 +torch/include/ATen/ops/_cudnn_rnn_flatten_weight_cuda_dispatch.h,sha256=AgCNDQ28AWAlYrC7CQZ1uPdZXZWaQBDqLPi12zCsJfs,1173 +torch/include/ATen/ops/_cudnn_rnn_flatten_weight_native.h,sha256=P0hdEU1hNb7Rea1CWF61JHej2oZ9d0VkexcxcTi0dpQ,949 +torch/include/ATen/ops/_cudnn_rnn_flatten_weight_ops.h,sha256=7W3aVDhCVrdFpS4EDqPgr0m7Rv3bVJ7rbF1YBUYNiIg,2827 +torch/include/ATen/ops/_cudnn_rnn_native.h,sha256=m-Nir1dlLXt-vkjI8Q5Y_3V7GcmB8U9tp_tMRtn_AZY,1562 +torch/include/ATen/ops/_cudnn_rnn_ops.h,sha256=6Zn8kTc1_nU0VUlanWVPxujjYAloZLqRqxLjTTsUiEA,4914 +torch/include/ATen/ops/_cufft_clear_plan_cache.h,sha256=De5W0gIjOfmWYEWyAX72DH_b3PNXm8Zb3-T9vC58yTs,752 +torch/include/ATen/ops/_cufft_clear_plan_cache_compositeimplicitautograd_dispatch.h,sha256=sZ8GgHOZvzohD5FEUW7tZ_sTZTME16cUiP-LNvd-hhM,803 +torch/include/ATen/ops/_cufft_clear_plan_cache_native.h,sha256=Wc9dDjWaN16p9Sb7ZFe9dmVSMFNzVXynkbKEl4njFtE,513 +torch/include/ATen/ops/_cufft_clear_plan_cache_ops.h,sha256=ETsZ_sdCgnC6EKKfh8DJM25K38bHcQwoOs__oc4jIRQ,1035 +torch/include/ATen/ops/_cufft_get_plan_cache_max_size.h,sha256=XWoZR_NxG0jxb7i70FKOjIRdPpEgaA0SBD245R6h6iE,784 +torch/include/ATen/ops/_cufft_get_plan_cache_max_size_compositeimplicitautograd_dispatch.h,sha256=coAxmuuCH4XGvd_6_FwGz14XSworvIcYBW0ttMGmwNQ,813 
+torch/include/ATen/ops/_cufft_get_plan_cache_max_size_native.h,sha256=oAeLburcrgwBJVMxD3BlZAlaIIAqhAXgZh1aaPFYhoA,523 +torch/include/ATen/ops/_cufft_get_plan_cache_max_size_ops.h,sha256=ZLs8bX20USttKcWCMn535aQ4kH_Qw8W4x_bID31cWHE,1066 +torch/include/ATen/ops/_cufft_get_plan_cache_size.h,sha256=MZSNpguMw0eXo7R7rJu9g9xZrASAKtiYltpknXAv6t0,768 +torch/include/ATen/ops/_cufft_get_plan_cache_size_compositeimplicitautograd_dispatch.h,sha256=2O_JWB-QxCAwpRklIjX45FGsTbUObMViFB4VbHn-evc,809 +torch/include/ATen/ops/_cufft_get_plan_cache_size_native.h,sha256=DkmgPuIAmVBY4i7BUuVfPWameDdhUT0blozOPGkq4a8,519 +torch/include/ATen/ops/_cufft_get_plan_cache_size_ops.h,sha256=tOnbpB6INcxERwcpyqcMfkWKrZ9zZS8oh6ybijZ8UNo,1054 +torch/include/ATen/ops/_cufft_set_plan_cache_max_size.h,sha256=SLDTeZRiucF6qO1FVMtjofJR84OXBZFLhpa6VNJ1sgg,822 +torch/include/ATen/ops/_cufft_set_plan_cache_max_size_compositeimplicitautograd_dispatch.h,sha256=b3uS5uK5v4c6jOP16dk7pbHokPf1w4ew2Hyipse-Sh0,828 +torch/include/ATen/ops/_cufft_set_plan_cache_max_size_native.h,sha256=sFD_qvolufeFys-Wt6YO9Dmjr5puUCo2Ft5lQREo9dQ,538 +torch/include/ATen/ops/_cufft_set_plan_cache_max_size_ops.h,sha256=QY0M5Lzb5eB64edXI6EwMAS2A9U7KvZNfU31Zu_Gm1Y,1115 +torch/include/ATen/ops/_cummax_helper.h,sha256=YWcrL_BPqSW0goMAiSA7OKFdX7jmtw16lP4M0qJZ3B0,816 +torch/include/ATen/ops/_cummax_helper_cpu_dispatch.h,sha256=4TE3lVbvykJfTixJqtbixqDAScdYJW5lJCaD4lvOa_E,801 +torch/include/ATen/ops/_cummax_helper_cuda_dispatch.h,sha256=yABm2pGiE23Pxrp_3sqt2hoFG8O4tXNjsVVdJaKrK9U,803 +torch/include/ATen/ops/_cummax_helper_native.h,sha256=O-H78SWPfUgUPs5TD9DnssMVFP4l72dhmxQmI8eYXZA,675 +torch/include/ATen/ops/_cummax_helper_ops.h,sha256=_TPsVpgdupEATE4SMBIRifCukzb8SAzS4sphI26dnS0,1185 +torch/include/ATen/ops/_cummin_helper.h,sha256=36vvBLH1FYfPyJbMP0iD5YHx8eZM_BfwUfmS-rh2lng,816 +torch/include/ATen/ops/_cummin_helper_cpu_dispatch.h,sha256=8CDptiiSuCCTzKf8iH0m55sD3nTF_m3AAXwJuE6T8gM,801 +torch/include/ATen/ops/_cummin_helper_cuda_dispatch.h,sha256=IJ5aosznRWfztLFDOOVJPCg7YiRQJttziqI52c4R7Vo,803 +torch/include/ATen/ops/_cummin_helper_native.h,sha256=_ntUVDbm6cpv-XFlo9akMoclZO_Ls0ek2yBKehb1rMg,675 +torch/include/ATen/ops/_cummin_helper_ops.h,sha256=nGFlwwlKnVNGdG7bGtjClehJ8U4Hc1jeykt0z-6mmKk,1185 +torch/include/ATen/ops/_debug_has_internal_overlap.h,sha256=Ttyllf8k24IuWCcuulrvt8ywAzne1vyl8SNhJZgAhRk,746 +torch/include/ATen/ops/_debug_has_internal_overlap_compositeimplicitautograd_dispatch.h,sha256=8lp3ZiiB_Q3FRTdbxZyaYyc003SiCraCw28Fi3LcrxA,805 +torch/include/ATen/ops/_debug_has_internal_overlap_native.h,sha256=TmMTcLdxow0WpTsc1jLDnIy5xB5gHZhPgVnCPUgdx70,515 +torch/include/ATen/ops/_debug_has_internal_overlap_ops.h,sha256=0y7BrRqKg6jSo_-5Ai55IaaK1-_6o2ezxBQQEmOrbuo,1037 +torch/include/ATen/ops/_dimI.h,sha256=c2t8XmQQ4AqcWMSXnnHlQ0PqGS5nvfjJyp5n-tQqDWk,529 +torch/include/ATen/ops/_dimI_native.h,sha256=zcbHWmVfEMALURiWxtyvZym080UXkzqJDWS_zr4qqc4,505 +torch/include/ATen/ops/_dimI_ops.h,sha256=867f6gXDjtTFgSLNpjzeF4sibtAQKyAcEYKaMb9PrkI,971 +torch/include/ATen/ops/_dimV.h,sha256=ZT8Ji4RYYGxSYVFM-FmuZaFyp72k_pSvZfneqZ6Nq6s,529 +torch/include/ATen/ops/_dimV_native.h,sha256=-LfQLBtX7FWF3mFjQFKQW9pFNazFQHGtT692HfIwjvE,504 +torch/include/ATen/ops/_dimV_ops.h,sha256=tfvMJSFUOgs-coHRaFH9J_K0OF2pWqifa6iSUkZ8heg,971 +torch/include/ATen/ops/_dim_arange.h,sha256=3yhy_-i0U70rleMjWlzmbmubOdCusxvKWRYYswQPHYA,715 +torch/include/ATen/ops/_dim_arange_compositeimplicitautograd_dispatch.h,sha256=D1_G1LZIXaOY10Ytt81tpICevbTYbqUnn99ScJm4oQ4,805 
+torch/include/ATen/ops/_dim_arange_native.h,sha256=u9JAfMjws3zS7x5EHwRIqkNB7EI-DGw4_tDAH7NZtgQ,515 +torch/include/ATen/ops/_dim_arange_ops.h,sha256=E6Ja2LdDhh42UnSdXit39SPFV2GeaDlQLyJQ1A7botY,1045 +torch/include/ATen/ops/_dirichlet_grad.h,sha256=uiPfJ57M-eBd2ySXf8WCmd9-wW-ePv6V1i3Aj5WHgBo,1412 +torch/include/ATen/ops/_dirichlet_grad_compositeexplicitautograd_dispatch.h,sha256=J2Lw6TIc1q6mbeV9f14vle87_I7zINm_t4x5ev_RAYc,1007 +torch/include/ATen/ops/_dirichlet_grad_cpu_dispatch.h,sha256=cCVdLbAtbuZUQ4hDEGRP39SUMHhgcQJBxVHDamWusqw,801 +torch/include/ATen/ops/_dirichlet_grad_cuda_dispatch.h,sha256=hvM6bwy8zezkrXEghiJLdYuSA_TgPOX_mzhRU_ldxeQ,803 +torch/include/ATen/ops/_dirichlet_grad_native.h,sha256=8V4HR8Pz102baetFYhGWbw2MBtI8NAvmq6Fdixgzupg,814 +torch/include/ATen/ops/_dirichlet_grad_ops.h,sha256=RPvXu0PEKdIHBqPQd0jGUX-WgSJ9kasyTq4aYhIoqck,1967 +torch/include/ATen/ops/_dyn_quant_matmul_4bit.h,sha256=PY3gROoS7jjyVDi1kMC6OprEgyV_XWPovXjpEJKWyvk,956 +torch/include/ATen/ops/_dyn_quant_matmul_4bit_cpu_dispatch.h,sha256=qSEmIla40gf3X7d0oWuseE1zlA6B3lpwGJ1VIJ_9cnI,856 +torch/include/ATen/ops/_dyn_quant_matmul_4bit_native.h,sha256=YJ4MCrVX3gtMTp4rz2Z6fLhymraXeyKyllCYyOFOMFQ,614 +torch/include/ATen/ops/_dyn_quant_matmul_4bit_ops.h,sha256=nTCyheebB_VaqoAV8ayPfutdvSQHuAWETMsnzDTLM0c,1348 +torch/include/ATen/ops/_dyn_quant_pack_4bit_weight.h,sha256=fbeVwSXhM_d4NL36l0WjZvphKFNjRHCI34Lkkf2yH7I,1044 +torch/include/ATen/ops/_dyn_quant_pack_4bit_weight_cpu_dispatch.h,sha256=ZElcFCeJmFcSY4RGN08vWZ4MOtyCKWhkpq_8jlZvECI,905 +torch/include/ATen/ops/_dyn_quant_pack_4bit_weight_native.h,sha256=mHj693LFaYOkRCK0AXl7zM_ujMzhgseNaJLd2KukwgY,663 +torch/include/ATen/ops/_dyn_quant_pack_4bit_weight_ops.h,sha256=QPX5OK-Dj90j8u_bYe4oBEAftzu7A438DXdZxQ7Pm_s,1504 +torch/include/ATen/ops/_efficient_attention_backward.h,sha256=gwEK0Wx6X3M1Yc_v33qQotzqB30pLvf5C93wbRZzoWI,6032 +torch/include/ATen/ops/_efficient_attention_backward_cuda_dispatch.h,sha256=kFN97yp1ATeDGjeyNRLoqnoNCHSE1UcBeSrV0m01nlY,2245 +torch/include/ATen/ops/_efficient_attention_backward_native.h,sha256=Ee0bzVjaExWpI3F_r6PpA1q1JqLAD_3Ykoa6VjnDP8g,1212 +torch/include/ATen/ops/_efficient_attention_backward_ops.h,sha256=4MebUD_KjhOW8d94dxGsOlIQubKvPC9nrYwO1z4cg-E,3195 +torch/include/ATen/ops/_efficient_attention_forward.h,sha256=-ldavXQ-kC7SobYOtiZG2RQwdQRGn4jhmbAWIl7uRqo,5453 +torch/include/ATen/ops/_efficient_attention_forward_cuda_dispatch.h,sha256=FtQbtY5LU7H6CL2YtWxg1HvI86_NbHj3xtF9PwDza-I,1989 +torch/include/ATen/ops/_efficient_attention_forward_native.h,sha256=CDYfRRuKyj1pwy2LIhP4QZXk8k4H43Ox9mytXB_cer0,1084 +torch/include/ATen/ops/_efficient_attention_forward_ops.h,sha256=Xf9eEVsLAtBQMo3FajoUgQXntMESbCNd7eMq4Id3UkE,2903 +torch/include/ATen/ops/_efficientzerotensor.h,sha256=G-JR9PhsKFReUBwB8ADyF9BrcGEr78z4QpyicAhh0LA,6154 +torch/include/ATen/ops/_efficientzerotensor_compositeexplicitautograd_dispatch.h,sha256=Jsq941vZgJsSL4ZUDvB-5sJUiMszUI9k1Q__a15r0-4,1116 +torch/include/ATen/ops/_efficientzerotensor_cpu_dispatch.h,sha256=KMZygXI7znw7yWrw3AYs7mzTGKcg7HgWaENnpIfYI18,1324 +torch/include/ATen/ops/_efficientzerotensor_cuda_dispatch.h,sha256=Oei6Sayslmj0Z5549con22Gc8Qo2MVswfrWFSX0YWXk,1326 +torch/include/ATen/ops/_efficientzerotensor_meta_dispatch.h,sha256=bGLsy5Tstt_a4I_7WDO7AlDySs-H2uGJ7xP3wr2Bah0,1326 +torch/include/ATen/ops/_efficientzerotensor_native.h,sha256=6Mv-91MCYZ7wYglRtHlhuIYy33b0N2UmItJHKD12ONM,1233 +torch/include/ATen/ops/_efficientzerotensor_ops.h,sha256=ydtmlfRsjQLzbSvKCL2hjs7mdIznd5KGNW1MaviMVS0,2177 
+torch/include/ATen/ops/_embedding_bag.h,sha256=NtHaRCCkMhxrHNfEil4MTjxcO-WGTDRJBTJ8qjeK1zs,3218 +torch/include/ATen/ops/_embedding_bag_backward.h,sha256=SA0UwMoDZsxCBjjRKsdsNS5MgdyCv3R-lZhQ34bSD6w,3598 +torch/include/ATen/ops/_embedding_bag_backward_cpu_dispatch.h,sha256=dCIaE3D-hUpyhBClc5tC6_RhzCw9R5ysdjBdk53NJJw,1453 +torch/include/ATen/ops/_embedding_bag_backward_cuda_dispatch.h,sha256=XFQAhjtGtXouBF3js9c7w4rOWr99HV8rNqGdYsr9jqA,1455 +torch/include/ATen/ops/_embedding_bag_backward_native.h,sha256=EUh9ZAv4YITRUlk6FEmYiho90vJKDUtiFEegXZbqUF0,830 +torch/include/ATen/ops/_embedding_bag_backward_ops.h,sha256=aFee1PjELJDFcvGHPJp5UlOuJyJriYOqFkeRxUlZT_A,2037 +torch/include/ATen/ops/_embedding_bag_compositeexplicitautograd_dispatch.h,sha256=sXQM-GONqfJrVdFQOQhQ0c_1B40JF6mlUDw6H8hzGeI,1581 +torch/include/ATen/ops/_embedding_bag_cpu_dispatch.h,sha256=9HsmBlp_fRkJxKAId4Rv0AQLTZ5-BX5K9H6QZX7P7EU,1037 +torch/include/ATen/ops/_embedding_bag_cuda_dispatch.h,sha256=zXURLiwgQtjeFKXavvALtOKGjAItrd1TcjJE2J_g0oQ,1039 +torch/include/ATen/ops/_embedding_bag_dense_backward.h,sha256=H77B7NHEp_UEDCfdNj5zxjqGzCukYJHT0SRLxw6_Fgs,9384 +torch/include/ATen/ops/_embedding_bag_dense_backward_compositeexplicitautograd_dispatch.h,sha256=IiTRixmKVcmtBEV5qqymnG2DcH5rkZoYW6IsDY51zPc,2214 +torch/include/ATen/ops/_embedding_bag_dense_backward_cpu_dispatch.h,sha256=chqiYJpw9fBfeSFVUlge_FXaNGf2OpK-xeGdBhUe1bk,1383 +torch/include/ATen/ops/_embedding_bag_dense_backward_cuda_dispatch.h,sha256=J-2rDmXp6AZhbnXX6rus7rOHFO0X0GDDZPKGWQe6YxU,1385 +torch/include/ATen/ops/_embedding_bag_dense_backward_native.h,sha256=zeS66u0IBakfN5cOzeV16fe4aazyfQFrFelfIeaqQHA,1509 +torch/include/ATen/ops/_embedding_bag_dense_backward_ops.h,sha256=kZ_LQ-gooyVSt7OsVDRD0NE1bgoDpHSUADDS0o5ros0,3451 +torch/include/ATen/ops/_embedding_bag_forward_only.h,sha256=uqGq-93W4IBdsxNTfMHn6h7bmjaScLjirjXYbF7nkSw,3348 +torch/include/ATen/ops/_embedding_bag_forward_only_compositeexplicitautograd_dispatch.h,sha256=H8G9STbraeYLRHjDVTlzyFguxbTwd6wtnsJbQLOdhgA,1607 +torch/include/ATen/ops/_embedding_bag_forward_only_cpu_dispatch.h,sha256=IXK6lcUfDgVr4iBJggbGeWhEOva2EKMgduo5UTHmvMA,1050 +torch/include/ATen/ops/_embedding_bag_forward_only_cuda_dispatch.h,sha256=ZjDrVdOsA-4mV1xGh8THy5WTsXZsuRJBbBsoGLCjCwk,1052 +torch/include/ATen/ops/_embedding_bag_forward_only_native.h,sha256=AfO3NY19_eKrhpU9iQrr3hFZ752GZWgzUoUOOhTbSzU,1599 +torch/include/ATen/ops/_embedding_bag_forward_only_ops.h,sha256=YeRsSokXT-UplRO9ttqMmQdIKjrrzGSUWiMe256XMnM,3727 +torch/include/ATen/ops/_embedding_bag_native.h,sha256=uYnr8DQyPEG9HGFc1BCvOpH0cBNzcSI1lyIK9QQFYOA,1560 +torch/include/ATen/ops/_embedding_bag_ops.h,sha256=XClfIw8peqimFNOpWj5Gr967yuM5S71hvrxOel-2ZWA,3649 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward.h,sha256=VCv3rDLbLU-Z8pLOPzoCZN4FGAYcMevlnc4WiODR9cI,2339 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_compositeexplicitautograd_dispatch.h,sha256=46y_fCgO7CufPS7m6JpjS3dU77xiO3x2OT-_rNEHu5g,1264 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_cpu_dispatch.h,sha256=vQB8MnRiwSFdSKm8IxFkIKg1sTzwqJav1kNeKrn5298,931 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_cuda_dispatch.h,sha256=fXyYQ8V514WTtBNpFJPsfAZXrTHRlYUS1YzsFG-1D8U,933 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_native.h,sha256=lMguZ4UC1i9vlSvej3VU4mtyZcP7DKKRg8HR3dxhr04,1201 +torch/include/ATen/ops/_embedding_bag_per_sample_weights_backward_ops.h,sha256=5HCeyWcJ5iyg160lh7E_7S0oI7zMy-D6i6K36yBKpW0,2787 
+torch/include/ATen/ops/_embedding_bag_sparse_backward.h,sha256=tlKRgriN4Esnk_7y0zrPQe6LXYvEptAm3vslacV7VYw,3305 +torch/include/ATen/ops/_embedding_bag_sparse_backward_compositeimplicitautograd_dispatch.h,sha256=YQ9BeYv0rN7czAVUm_aitT5DQ_tXh_UjUOGCQo25zOw,1413 +torch/include/ATen/ops/_embedding_bag_sparse_backward_native.h,sha256=dDh8vM6hr0a3CqYYSjV_HhRB-D0rFdcJN0Gs5mONv7o,788 +torch/include/ATen/ops/_embedding_bag_sparse_backward_ops.h,sha256=BMA3WedMNjJnXW8Ws-qB6bTEK3D75gZFAVoCQS7qluk,1897 +torch/include/ATen/ops/_empty_affine_quantized.h,sha256=Sg8sImiDCc_kPk9jsVajmV7Q0uvG5OmtJdyocJwaU1A,9333 +torch/include/ATen/ops/_empty_affine_quantized_compositeexplicitautograd_dispatch.h,sha256=ziPomIWxyT5xw8NzyZluQapVvcwwTg3xdrvf2qZFNaY,1528 +torch/include/ATen/ops/_empty_affine_quantized_cpu_dispatch.h,sha256=fbXfKPtdYpFJDaRgnXR0pPC9Z8CLWRLJDimLk6Nj0zg,1736 +torch/include/ATen/ops/_empty_affine_quantized_native.h,sha256=uzSrXR_iRveaEtKWyszVeO9eWjmh6hEM4t4si79frK8,1333 +torch/include/ATen/ops/_empty_affine_quantized_ops.h,sha256=fUTKrNyJtYhkjNYZIRVqM7Wzgt0dhjn90UbdWVes_WY,2791 +torch/include/ATen/ops/_empty_per_channel_affine_quantized.h,sha256=U3wJOLwRqjaIJnHNxSfwlC6cj03Yndw6s9LeDUHYbBk,10641 +torch/include/ATen/ops/_empty_per_channel_affine_quantized_compositeexplicitautograd_dispatch.h,sha256=UjoCaV_Sbsz47DyvPvbRFHMMIKcrTyEzsJ7h_woaMAA,1724 +torch/include/ATen/ops/_empty_per_channel_affine_quantized_cpu_dispatch.h,sha256=6zCDElUzYJLOI0M5kbuXCmgye_tB7QBCcn7r9WFUXbU,1932 +torch/include/ATen/ops/_empty_per_channel_affine_quantized_native.h,sha256=D_78nwueFpuXn3hrh2Abpb1fLLhigrtvJjkDyP3sunI,1478 +torch/include/ATen/ops/_empty_per_channel_affine_quantized_ops.h,sha256=Fo0_Ro7mYx8rYtRa-vulWh1KHniY3gyOVTpgKps17xc,3107 +torch/include/ATen/ops/_euclidean_dist.h,sha256=tNIL46gixlG8N7xtywbYULUl1fPMs2dpKXlzjEd0RMM,1253 +torch/include/ATen/ops/_euclidean_dist_compositeexplicitautograd_dispatch.h,sha256=VgJ96vyDrxZN1XUZxLPgQOs_5l0_g8rmAA3Y6GeZXIQ,1036 +torch/include/ATen/ops/_euclidean_dist_native.h,sha256=FS7Z22psFIRTa-eQ-Lx-uTwWqvbWf3eypWXm3AIGaXM,636 +torch/include/ATen/ops/_euclidean_dist_ops.h,sha256=Op8u6It1atX1TE5jNa0NeE_1ZqpYtLOoAJ-wdQ2SAA0,1783 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine.h,sha256=jqzR3kdQjxBEN2Fs2EDn32wME9HRayesFlELB09cVno,2300 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_backward.h,sha256=9yriB29bR4ZFPWfEoPCQ9uDNgdYyjT6Dr_kfAeCnglo,1239 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cpu_dispatch.h,sha256=FPwjopwMRZ5Mor2-LX2qF7tcqmhqw8rfizJH-T4xP6M,983 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_cuda_dispatch.h,sha256=nEp_tdqqyM4Dhe3GDOev0wkopNIGEY3gtXry0GfaLcA,985 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_native.h,sha256=3P3PUw34reJA4xbcyVU7IGoNbUpWxmBm8a_7n-PB-H0,737 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_backward_ops.h,sha256=MvPyLpiLB3OfMzDxHE5h0vOYugGqzrCchmZEivAcyD8,1762 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_compositeexplicitautograd_dispatch.h,sha256=t1YDBDsEQQSoQ1AccQYSUxgFF-Fr-F4ptpVSX1o1OTs,1227 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_cpu_dispatch.h,sha256=IGoxYOEkoyND0zcPNTKoxi1OpN_tkx0zVRL8tVp2FN0,913 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_cuda_dispatch.h,sha256=NbN7-ghjOZ0EJib_IXDPFEGB9laMpx2v2hpG6WAd320,915 
+torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_native.h,sha256=okCgF6jII9G2-M14l9KY58I1s7wq0eAyeXOncS-MS1c,912 +torch/include/ATen/ops/_fake_quantize_learnable_per_channel_affine_ops.h,sha256=iQdkSgsLlHc8rnLW07Ih9dVqpDOaVFrUc_Q_1X5FcVM,2667 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine.h,sha256=VIXrVJK3zafbBySXnK8cF-650S1CrTOCiVqdK0pc73k,2200 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward.h,sha256=V3dHewjCiwnzo3RsIxSYN_jqDH1T99wGs0iEj0R8lPg,1205 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cpu_dispatch.h,sha256=f1-peSegO2axvPAOgjk07DzQNnUmrMf2E0ur1pk9JZQ,968 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_cuda_dispatch.h,sha256=vRbsK3_bPuwHjafv2snWgrU8yCPVZkmao3RrTzbZCro,970 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_native.h,sha256=vuiO2bQTKguafdDrYQpgWFNS7QVvlI49kxU1chLFzaQ,722 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_backward_ops.h,sha256=d2J84TxpM76qth7BMcZkA32yk-TumZ6zDob6p_v412A,1712 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_compositeexplicitautograd_dispatch.h,sha256=dTCvFR7Bp7PDRGXBKh8Jcj_22Eu05KMUaQps65WIHW4,1197 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_cpu_dispatch.h,sha256=64WmAk9fXRmQWlLsfCgzdwvzleOuArR9dbk6zNfPHmw,898 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_cuda_dispatch.h,sha256=ic9rTCIneF0Madl0CR6tNVG_kuqAb41e_3RtlcW_XM8,900 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_native.h,sha256=SW9TeXWZHwGDm7_Mm9UFZOWzk1vHalPiBoIGdY10hc4,882 +torch/include/ATen/ops/_fake_quantize_learnable_per_tensor_affine_ops.h,sha256=yVulTfdPO6F930g7l5S93NYehOtd8BGtAFcopIaKP8k,2567 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.h,sha256=oVEqbOCeDqQ8duVXhbgvD0ObnG0fDxPDQfJucKmZ_Ak,2651 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_compositeexplicitautograd_dispatch.h,sha256=x7VXUMU1EePYrHbAr0v7NPx4FVivZsENxLvXoLSU4p0,1355 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cpu_dispatch.h,sha256=yKvNLyeA39e-i1pq3WsTGlj6ENKgBPgnMIcBURFicRA,953 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_cuda_dispatch.h,sha256=sI7sJCe8CTUPpcgUdxVKpNJ9mO4MmoYDHcMQFcr-gA8,955 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_native.h,sha256=ucl3FbVcZo3I4iHOk3svxtKa9a5ZQR7jTW2RTvvT2-Y,1018 +torch/include/ATen/ops/_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_ops.h,sha256=6I9WB9x_LETRe-WbDVBdUR51AzL7Yq1D6znvMwOr7Cw,3029 +torch/include/ATen/ops/_fft_c2c.h,sha256=ho0PCt0UcHkcnfRpLDOgr0KGqLEqeXsVV0zu9e-J45I,4545 +torch/include/ATen/ops/_fft_c2c_cpu_dispatch.h,sha256=wAmu1DWWu1j9szShrxWYL53xHNOHyRLBNGA44vOKy_U,1509 +torch/include/ATen/ops/_fft_c2c_cuda_dispatch.h,sha256=fjQPFTTZs2b6GYFt7InOmQpWlAgu4w4KPmT6Rkritp4,1511 +torch/include/ATen/ops/_fft_c2c_native.h,sha256=YAyFODmdfRpw4uZgNcvYbi6j95orgBX2bt5EIj3GI7I,970 +torch/include/ATen/ops/_fft_c2c_ops.h,sha256=18OqgM77_wsabB-8oP-nECCNnENbDadXpY5HGm0WGE4,2013 +torch/include/ATen/ops/_fft_c2r.h,sha256=m_LBPG-gbGI-J5HZU191rY3cdlEjDIdjhk7dv2FYJMA,4599 +torch/include/ATen/ops/_fft_c2r_cpu_dispatch.h,sha256=SbhvFpgjNgLZTLXT5tlEK5bN6EKNbcrH0Rk7G5CW1Ns,1563 +torch/include/ATen/ops/_fft_c2r_cuda_dispatch.h,sha256=tv0-83Z0ZREg6HU-rqmBLJrvloi3V4RnNSLp5yzO5Ug,1565 
+torch/include/ATen/ops/_fft_c2r_native.h,sha256=mNom9fSFT4KlmLMEB9Z1KLxkdTfwe4Sz_-Cb8DGKHAs,1006 +torch/include/ATen/ops/_fft_c2r_ops.h,sha256=ElgKsnI8DQLjMCdAjFzIaB0dcLAU5wKEGQdapy85s1E,2065 +torch/include/ATen/ops/_fft_r2c.h,sha256=PdQtqZ37x34o2Dz9P73nzGRSlq6xLbP-NkabWsait8A,1489 +torch/include/ATen/ops/_fft_r2c_cpu_dispatch.h,sha256=ct5lK1ri3Q14YGm8pWQn19SqBtrIZK_xpfX8N9x0ToE,1085 +torch/include/ATen/ops/_fft_r2c_cuda_dispatch.h,sha256=o0oMO1tk5sxpJsqPgchQsejOZjnF_KnMnELoayFbbiw,1087 +torch/include/ATen/ops/_fft_r2c_native.h,sha256=ICsPy-UCUPfqBcLdwfLL7zo6kVrpj8UmTgIm-w5Ap30,974 +torch/include/ATen/ops/_fft_r2c_ops.h,sha256=gTIzyVLjl0vXa6XVgHQcQ_R7zGdwyyeJDyn35R6u0m4,1989 +torch/include/ATen/ops/_fill_mem_eff_dropout_mask.h,sha256=uOjJFqFS80cKjI3GhCIcjeDUh_sa4wm7oj23KWHbotk,867 +torch/include/ATen/ops/_fill_mem_eff_dropout_mask_cuda_dispatch.h,sha256=3ociCup4lwP685YF49dnAvCi5nkJGkNlMqMN48mxung,810 +torch/include/ATen/ops/_fill_mem_eff_dropout_mask_meta_dispatch.h,sha256=zMUpMfhwRsLqGmjc_4LZq-b_teCSFmAi0qy02cj9tQY,810 +torch/include/ATen/ops/_fill_mem_eff_dropout_mask_native.h,sha256=EHoO9hn96LF58BeZcdNEeiuc8Roprh5KSlliCfpaaKY,562 +torch/include/ATen/ops/_fill_mem_eff_dropout_mask_ops.h,sha256=5IqyhUz_d-oo8Ihs43SHX0kfi2uutsAzMwDbxOsX2nU,1206 +torch/include/ATen/ops/_flash_attention_backward.h,sha256=qDUQG4456D7SdXjFxX9lyAiBrGtgRom4Isp-C3Y7Bn8,5108 +torch/include/ATen/ops/_flash_attention_backward_cuda_dispatch.h,sha256=0j6cPRqrHXdom_IVP95vPclsvTbWZd_ma2OUGdgYUeM,1889 +torch/include/ATen/ops/_flash_attention_backward_native.h,sha256=ZPZ5V0VSBN8KVJ9skMXOCKA25lU6NAqGeuxuNcw2GtU,1030 +torch/include/ATen/ops/_flash_attention_backward_ops.h,sha256=yqqAywB6FrmSxNcntExtRL7OM-Wa8yhtxb1sxtlBj0E,2661 +torch/include/ATen/ops/_flash_attention_forward.h,sha256=lXz71uzefE6Y1rheduTAXCDnHrpDvY3AIGD0g4yF4n8,5407 +torch/include/ATen/ops/_flash_attention_forward_cuda_dispatch.h,sha256=sUj0-HwAaSjZQXZ7Xn_hLJZSrIGaz_7q5abB2qwLZIc,1973 +torch/include/ATen/ops/_flash_attention_forward_native.h,sha256=szTMR0q2iri4tvbS6DSoKiJ6GPkP8w3HKZ1KZOmUuRM,1072 +torch/include/ATen/ops/_flash_attention_forward_ops.h,sha256=Bt6ZTQYIYU6oEBTLtbU5baRT3uCzej2dx5n1DCzZPSY,2839 +torch/include/ATen/ops/_foobar.h,sha256=VEW_o4BsMx_5tp2y5rSVIS-GaZyOacySugiZDCglsqw,1407 +torch/include/ATen/ops/_foobar_compositeexplicitautograd_dispatch.h,sha256=ZvqcYw1re99EWIJpcwObLygnWF_KlgBDfmCjOGqXgEY,974 +torch/include/ATen/ops/_foobar_cpu_dispatch.h,sha256=2vxrpPWFIf4INqrI91L_-C4J_LLUtdWSVwDdLIzqONo,792 +torch/include/ATen/ops/_foobar_native.h,sha256=U3p8GXFYMFx_80Aq-gV9GLRtblRLvniCOgMEigXQk80,658 +torch/include/ATen/ops/_foobar_ops.h,sha256=SyLLcUrZMav6bl6WsEu2KiLIjj0GSQKj_b9ihd38uh4,1860 +torch/include/ATen/ops/_foreach_abs.h,sha256=DpyYAB7XVVSe4C1OqUUU4r9e8JQmaGsRWkUmA7T6oy8,1267 +torch/include/ATen/ops/_foreach_abs_compositeexplicitautograd_dispatch.h,sha256=kUQUxs6sZ2byArP_CvtApPbtDd5X-m8gB8N7cwW-yJY,1007 +torch/include/ATen/ops/_foreach_abs_cuda_dispatch.h,sha256=Qw1k-WuOgu4GKxK0XUVaEoiHVD4QbYeRZ7q1yVLJTn8,814 +torch/include/ATen/ops/_foreach_abs_native.h,sha256=p7_ydGAIf1ywyyeizCzynEkNbdFHyVqpSYVKDnqjrkI,809 +torch/include/ATen/ops/_foreach_abs_ops.h,sha256=-JeUJmaegkpbghxTaHJGtuaPxizGfsKEcrAa74V4Kt8,2130 +torch/include/ATen/ops/_foreach_acos.h,sha256=C3XWk-lXOpX1rczhQcutjnh4Gr4Pjz2y17ifoZhgg64,1280 +torch/include/ATen/ops/_foreach_acos_compositeexplicitautograd_dispatch.h,sha256=5IbR8Wy4RrzRT_kNqO5ezC3dsUqjbKihB5UEI5HoAss,1011 +torch/include/ATen/ops/_foreach_acos_cuda_dispatch.h,sha256=c6xSwKxZMDZiUgHp5CuFqVmsfcgrtaK5oqZtf6EdWSs,816 
+torch/include/ATen/ops/_foreach_acos_native.h,sha256=6P2f9syc_VRQZG3KA8Lm0GXVGpgwBRqjJv3UsLlBsNg,814 +torch/include/ATen/ops/_foreach_acos_ops.h,sha256=f4ResbOKP03pX35OgHQOnJP9S35Q7YLpIp-3LRkqx_U,2139 +torch/include/ATen/ops/_foreach_add.h,sha256=-czu79zNiR60DVzcynA-_iqs2l6LduFHc6wTfZX3tVE,4934 +torch/include/ATen/ops/_foreach_add_compositeexplicitautograd_dispatch.h,sha256=oNbbzPZoBrCIsa6P_mCWF0ciMh2X8ijV656PFK7J4zU,2488 +torch/include/ATen/ops/_foreach_add_cuda_dispatch.h,sha256=xQ2MAkn5YomoIujllpgMFzewqVE67U4AIF5Vei8N5xk,1516 +torch/include/ATen/ops/_foreach_add_native.h,sha256=Asl5jLOdoUd36VA50XG2TRxA65eP9r19lOcfadxX-FA,2993 +torch/include/ATen/ops/_foreach_add_ops.h,sha256=on8Y_V7hv8Mjjekvmo9G9lUpJnIaJEKlQSJMFx8_Zjs,8986 +torch/include/ATen/ops/_foreach_addcdiv.h,sha256=xR0qc5-1l2IHiXaRBePZ9usAL6lMeF3T3m6D4aeFnT4,4963 +torch/include/ATen/ops/_foreach_addcdiv_compositeexplicitautograd_dispatch.h,sha256=ipb7N6Dow4atZS691BBA1QWt4rnIYwozoQvQ9SpIf_0,2539 +torch/include/ATen/ops/_foreach_addcdiv_cuda_dispatch.h,sha256=rrYVhC2_COK8aczDU_E_nMEFiXcJw-kffTEsQvUboLo,1554 +torch/include/ATen/ops/_foreach_addcdiv_native.h,sha256=1lzi64wPrGW4PAYk_lXvqGEzVNQ7eAKsA6lXUgk_hdM,2896 +torch/include/ATen/ops/_foreach_addcdiv_ops.h,sha256=J2BainVREVUFCWgfskP7YA2qn3Pfk_CWGfkF5ohppDk,8119 +torch/include/ATen/ops/_foreach_addcmul.h,sha256=srcr39WXGDgGdb_MBiLjWohYb4w3nlIIwqhXO38EtHQ,4963 +torch/include/ATen/ops/_foreach_addcmul_compositeexplicitautograd_dispatch.h,sha256=CRP-SmR5l2ZSdDq7qvDx1ejLqdtActICQX1Ls7jbnI4,2539 +torch/include/ATen/ops/_foreach_addcmul_cuda_dispatch.h,sha256=G8fJo09MXBjnvSfNd4LUucmROjBNf88fdiC3fn2fZT0,1554 +torch/include/ATen/ops/_foreach_addcmul_native.h,sha256=N4OG1S5dqRCeS9bkpGyL_RwGZgYNKoo5uVfq33jXpw0,2896 +torch/include/ATen/ops/_foreach_addcmul_ops.h,sha256=ltBYp_fa-i3s_Gp_Pdo7adbVoyCO2x-NvDk5pPkgzgc,8119 +torch/include/ATen/ops/_foreach_asin.h,sha256=7-sEMomfvXsSzEimtphRApRJqgyKk55zwOVx26CJmQ0,1280 +torch/include/ATen/ops/_foreach_asin_compositeexplicitautograd_dispatch.h,sha256=lDJhe__vH80iW9Z4vMq3qpwfBJMze0_z_HQaeIwvSL8,1011 +torch/include/ATen/ops/_foreach_asin_cuda_dispatch.h,sha256=1rbb-HDxkr6m56GZgodXz9fQddUihJ_ygSwZe_hfH14,816 +torch/include/ATen/ops/_foreach_asin_native.h,sha256=6TRxRI75narXiE9A7MxIWGq07u7VEgeUt9IgT7fETDQ,814 +torch/include/ATen/ops/_foreach_asin_ops.h,sha256=eVPymi5fuAv0K3mV0h16NT4ANs-rsFZ6flxCu9HC-ss,2139 +torch/include/ATen/ops/_foreach_atan.h,sha256=qvkkSWS6EF8I9ZVGovUBfIvIg2o1FXyeG2gb6tLFyyQ,1280 +torch/include/ATen/ops/_foreach_atan_compositeexplicitautograd_dispatch.h,sha256=NMZhOmBaYHBQhJfPl1cfRdF1zmrzaOfY1kPRdxObXgY,1011 +torch/include/ATen/ops/_foreach_atan_cuda_dispatch.h,sha256=uZIksQAqJTNoIbO1YiFnOv2QPiwFN-QPhJXPuPa89t0,816 +torch/include/ATen/ops/_foreach_atan_native.h,sha256=2xJgms-3tFrfU1tQMwnE9CfjUm8eMOUAd6S_zz1HaoU,814 +torch/include/ATen/ops/_foreach_atan_ops.h,sha256=IVent7-8VSdeyG2OKoOwEpjEcBT9I3l7P-4PGaEBfYg,2139 +torch/include/ATen/ops/_foreach_ceil.h,sha256=SoxspYVeCD-lICX4p9OHCV02BAtmWKRYtHtnFfbWhSE,1280 +torch/include/ATen/ops/_foreach_ceil_compositeexplicitautograd_dispatch.h,sha256=kLN4VAvFkFZlpikwlYzZ44cT4pg81X7cW55HqBpdCkI,1011 +torch/include/ATen/ops/_foreach_ceil_cuda_dispatch.h,sha256=KBXQWJjHv7FOHQgr5yRNEv52estlmnxROa2DxV7IpfM,816 +torch/include/ATen/ops/_foreach_ceil_native.h,sha256=mx_ky4CdjLbfjLb4fRLUwYg1UWnNN_8K_O5Hccr6zjs,814 +torch/include/ATen/ops/_foreach_ceil_ops.h,sha256=lrcBk_WcnHtsZSziCpmuww9FfK9n4zy4mM5APNnYI-o,2139 
+torch/include/ATen/ops/_foreach_clamp_max.h,sha256=bF3ncWMPvDA9uO4ZCVvP5Jgr4XOsj2xFDS_VaK8BXiM,3763 +torch/include/ATen/ops/_foreach_clamp_max_compositeexplicitautograd_dispatch.h,sha256=D3ltR38bv3Qdzaa9h5KScxHGr3i64maJJlUVSl27Pac,1961 +torch/include/ATen/ops/_foreach_clamp_max_cuda_dispatch.h,sha256=bb54PiJWgh4tBRpYchjbc5K6mq_QWHGv6aXhI008bfI,1264 +torch/include/ATen/ops/_foreach_clamp_max_native.h,sha256=8p5-Xl6ymXVMDno0psBrmmmhH8BuItcynifEBMdK71Y,2247 +torch/include/ATen/ops/_foreach_clamp_max_ops.h,sha256=aZ13CIKvyIkka0mfQg8UPaNClv4dr1Ffzhp30Rc46h8,6634 +torch/include/ATen/ops/_foreach_clamp_min.h,sha256=XKYyx1wA8nZSyxKlY4l2oINLQBQmKg7zirJkE5CWo04,3763 +torch/include/ATen/ops/_foreach_clamp_min_compositeexplicitautograd_dispatch.h,sha256=d90ekybbmSO5uN3Z_bxl_FQ3ZXuUmHiqbgGjBY4YP_8,1961 +torch/include/ATen/ops/_foreach_clamp_min_cuda_dispatch.h,sha256=0jzIfY2kxZ8gwwSGi2dscH67PlfDWPHRvvVuC7dbEeg,1264 +torch/include/ATen/ops/_foreach_clamp_min_native.h,sha256=FfZ1dx3MYvzSA4sibjLPVubbKq6wUdhHsBCK_OESq2Y,2247 +torch/include/ATen/ops/_foreach_clamp_min_ops.h,sha256=axv2dKAWR2mut6fK6hYHpxmt3lnu-TY0pVJoeX4-Emc,6634 +torch/include/ATen/ops/_foreach_copy.h,sha256=1vfdLAUVer_XevrKDZgncGKEXviqXZdw7wKqcbI608I,1695 +torch/include/ATen/ops/_foreach_copy_compositeexplicitautograd_dispatch.h,sha256=77JimkSD3Lw_6jluP_XTPeoFbw4U-62T1BeLCRTwz-M,1185 +torch/include/ATen/ops/_foreach_copy_cuda_dispatch.h,sha256=Qo8EXEJnuiQnPfP132gaPv-uzpFGNxIkr-rXwpXWSLY,788 +torch/include/ATen/ops/_foreach_copy_native.h,sha256=nz60aGkT4h4PYb-ANx2VFxc-lEDpVk2Msq-OunGiQUo,917 +torch/include/ATen/ops/_foreach_copy_ops.h,sha256=oxhVuFh8I8UALfAn2FleKHsQsimjGebfiXoNvUegCqU,2565 +torch/include/ATen/ops/_foreach_cos.h,sha256=nGW9-8W0-XfKmkeNF-QE-8tCGRrNXYmN4iZW7EKZV6U,1267 +torch/include/ATen/ops/_foreach_cos_compositeexplicitautograd_dispatch.h,sha256=EesSBGZ9J6ZXRVEFe-KZq3fTFsVAvyBVAfB6zQXtps4,1007 +torch/include/ATen/ops/_foreach_cos_cuda_dispatch.h,sha256=gNtxCCqb0iewwINESO5Efo9c7SPGLKCec4gEULWXUgk,814 +torch/include/ATen/ops/_foreach_cos_native.h,sha256=k9Zs1fbw0-A10LMzYrK4LGEzm1FuZOxzPcCAWNaUDIY,809 +torch/include/ATen/ops/_foreach_cos_ops.h,sha256=W6jnAF6VpwttN-No3HZyIg0NZLPs6LNJWcA3N_tprSA,2130 +torch/include/ATen/ops/_foreach_cosh.h,sha256=BzNJM9iN87JFCmXWMK6qbuQD5mElMampA5s8KQ0-EZY,1280 +torch/include/ATen/ops/_foreach_cosh_compositeexplicitautograd_dispatch.h,sha256=kiDQ_8bsEhC1qkvWOFU8P5PvQMuEkrbV3PRm8nug6-E,1011 +torch/include/ATen/ops/_foreach_cosh_cuda_dispatch.h,sha256=Hf7Yp6dIvJJ6nzn2mXeZy1kykUDv3OL93VFr8RZ7aZg,816 +torch/include/ATen/ops/_foreach_cosh_native.h,sha256=T3kvrrGRz5f0deEM2ndxiiC8viQXbyaD-72ozhHA-Jk,814 +torch/include/ATen/ops/_foreach_cosh_ops.h,sha256=TtlZ1VIhm8hhAZQ80sy0OH2QuLbAi8wVjitfE944N_o,2139 +torch/include/ATen/ops/_foreach_div.h,sha256=ZlLMenaNcrRdvE4ceSrfioYy7z3JCsE-LBJZP5sWa2M,4518 +torch/include/ATen/ops/_foreach_div_compositeexplicitautograd_dispatch.h,sha256=iZHT89tCm6wOUCeMyZTSZQRlEEw1fXMsFqBnAPXi-EI,2268 +torch/include/ATen/ops/_foreach_div_cuda_dispatch.h,sha256=nmqyqFjS1NkNv2Hvh1YCOoBlIQOBTIFDFSvM7detCvY,1404 +torch/include/ATen/ops/_foreach_div_native.h,sha256=8ts7dvh060lQzWLNeXc7WOGm5EFeLYukK71ECg-j0UQ,2717 +torch/include/ATen/ops/_foreach_div_ops.h,sha256=o5mgNeZFhw4XbP0y2WswmPN97RjhNJR1UoN4OKHdCOg,8446 +torch/include/ATen/ops/_foreach_erf.h,sha256=uePkyLDdZn6OOkvROOFNY54CMlL--0RxmPjw5qkTmtA,1267 +torch/include/ATen/ops/_foreach_erf_compositeexplicitautograd_dispatch.h,sha256=bYC0QW45cTbfCqTkazXExSj16Bj9Lj9sHLv7mDBz53c,1007 
+torch/include/ATen/ops/_foreach_erf_cuda_dispatch.h,sha256=cVZSWuvkd0f50B_QB72h20KFqsInm0Z_M2JHUrgUMnM,814 +torch/include/ATen/ops/_foreach_erf_native.h,sha256=1wiIy_qO8DAG3P5Wbwn8XcnrLznre3gAsiJCDlSH4Bw,809 +torch/include/ATen/ops/_foreach_erf_ops.h,sha256=LkZWo8NOFx1nV78fAhtvybZox1-BGa3WfTkfUJfGeEQ,2130 +torch/include/ATen/ops/_foreach_erfc.h,sha256=hoKx7kk8hPSdlbHxRoQYHkpTWlytDkGJY_IDezHQGSI,1280 +torch/include/ATen/ops/_foreach_erfc_compositeexplicitautograd_dispatch.h,sha256=cxtQ3AtMLbqcf_JQOCIvVvYHi6rHVs0Ui1N3gwG4Kj8,1011 +torch/include/ATen/ops/_foreach_erfc_cuda_dispatch.h,sha256=DAvgZ5guwOibB_-r4cqrl-JfuTetG3sPI_5RwdJry-Q,816 +torch/include/ATen/ops/_foreach_erfc_native.h,sha256=faLDz5BkrNTqgp63Y7Zb7HIRk06Vz_qg3anrQFsxe2c,814 +torch/include/ATen/ops/_foreach_erfc_ops.h,sha256=Cbz3gbmdSvVXpaXfJF3X6OT3qED9naGKI_juNMysf24,2139 +torch/include/ATen/ops/_foreach_exp.h,sha256=Oan9_E7TGYYhtgvKS4Ztc-4QV-L_PHI7f5cGv_Xp1HE,1267 +torch/include/ATen/ops/_foreach_exp_compositeexplicitautograd_dispatch.h,sha256=-CFrlJZk3JMBrG2xQN3UI9Zuf4IkgNd9NEeL2CgiUps,1007 +torch/include/ATen/ops/_foreach_exp_cuda_dispatch.h,sha256=bsMJgt9zQMH47vTmaL73MkEasVOCt-RbS8l_jI16EpE,814 +torch/include/ATen/ops/_foreach_exp_native.h,sha256=t-p6w6g-XWk547jtjsq2uxXkpxNT6UwfctxcjzbQtlU,809 +torch/include/ATen/ops/_foreach_exp_ops.h,sha256=VJKF8NVcIHFMYZ7thY0H68HDUNOrf5Gds9F1Ic4qbIc,2130 +torch/include/ATen/ops/_foreach_expm1.h,sha256=D8Tb1CwS4QkIZQLbRXlEb2A83v1r88plx6H3C2ulYU4,1293 +torch/include/ATen/ops/_foreach_expm1_compositeexplicitautograd_dispatch.h,sha256=LxqCa69jkOSwgQdlvOwqouNQvE-RWGCfrsUYPbbKb34,1015 +torch/include/ATen/ops/_foreach_expm1_cuda_dispatch.h,sha256=3gq10QvTUbRzKmJ0D4KBX9HxeZO9RAzXQqRuEsOHAaw,818 +torch/include/ATen/ops/_foreach_expm1_native.h,sha256=gAdg52efng0pNGNDKywd6CyAf07pB513ChnQLa-m518,819 +torch/include/ATen/ops/_foreach_expm1_ops.h,sha256=G35qLzW2NP05Mu5ZUy6r8jINPX4SMr9we4dOSxhpxNY,2148 +torch/include/ATen/ops/_foreach_floor.h,sha256=ACH-ZHOftGfTrFVhFZbTGIbjivRJxkZmUxTbgjdYVKI,1293 +torch/include/ATen/ops/_foreach_floor_compositeexplicitautograd_dispatch.h,sha256=mqOT2T8YPpBBELfoe8IEM0mDLrkmmKrWAYuSrvkfEbk,1015 +torch/include/ATen/ops/_foreach_floor_cuda_dispatch.h,sha256=ZjZb4tRD7_br11ueHqNLSjXXQwtrZ03_OGJ8dc1O2yo,818 +torch/include/ATen/ops/_foreach_floor_native.h,sha256=nr-0ImR71xmpgD4t1hv-BEmX9qmcJpcn3gIvRk-BGmU,819 +torch/include/ATen/ops/_foreach_floor_ops.h,sha256=ICp0lDUrbzNrfkrBIJXgRBkRw9v06E-rNw_R9c4N_Xs,2148 +torch/include/ATen/ops/_foreach_frac.h,sha256=sljQuAIakNvheNYnxF6lbDYI37Pxbf3xchVkRV7JQC0,1280 +torch/include/ATen/ops/_foreach_frac_compositeexplicitautograd_dispatch.h,sha256=Rjni4tRWzSRUGXnX8mTRi64k3nb2aOpBUQjsdFyFiTc,1011 +torch/include/ATen/ops/_foreach_frac_cuda_dispatch.h,sha256=AGgSAWwi8Z_QFUA12ph0kFFEnWbH7gLF7MDvdYbF8NU,816 +torch/include/ATen/ops/_foreach_frac_native.h,sha256=UFASwYjQOIs3T0WjXKyTTeYwAIQ4sXDztqsihsNOzuk,814 +torch/include/ATen/ops/_foreach_frac_ops.h,sha256=Evi1wUynY4OVDXxE9sDP7WQzQYkTjcO9_7ZfJmsIikQ,2139 +torch/include/ATen/ops/_foreach_lerp.h,sha256=dc1epUicwoKIOysWltPD8xH0HfTBm8whdbHDTTtrG9E,4238 +torch/include/ATen/ops/_foreach_lerp_compositeexplicitautograd_dispatch.h,sha256=_qKlab0U2Nx9eRmwebuTva8E84enWR2JxNZAJBZsbCg,2205 +torch/include/ATen/ops/_foreach_lerp_cuda_dispatch.h,sha256=CzNnR8rKFQ3ImBV5ryjz4OuViQVz-I05w1ZwBVDgilc,1386 +torch/include/ATen/ops/_foreach_lerp_native.h,sha256=o-k7pq2bUi6AhwYSN3ofLXRJfs5loI72fOA7rWVpOFA,2500 +torch/include/ATen/ops/_foreach_lerp_ops.h,sha256=TJoScXjIUkr89T5_dZ8rjH2CwlFjYhXhWn5GecV1cb8,7273 
+torch/include/ATen/ops/_foreach_lgamma.h,sha256=-RR-pS5vib3ChSBtq0N4GzPMslhddp9xmPnfjE5dewI,1306 +torch/include/ATen/ops/_foreach_lgamma_compositeexplicitautograd_dispatch.h,sha256=NbIMZVJ5r4mc1Cv1Dhjv28Qer8UFpn4tXy599iglwnE,1019 +torch/include/ATen/ops/_foreach_lgamma_cuda_dispatch.h,sha256=7pSd1yHpgNjUoar-cRR_8pRJk9ClFUcxO3Wt7m6RQ_k,820 +torch/include/ATen/ops/_foreach_lgamma_native.h,sha256=C7m7YrBXiz-6dW9zq0G8ngVkkb1ZOe2W9rj5n9UMQHw,824 +torch/include/ATen/ops/_foreach_lgamma_ops.h,sha256=VliJo4ivMc6i3uBqsmsWK3MjyYKWcbIpP4CcV2kH6N4,2157 +torch/include/ATen/ops/_foreach_log.h,sha256=URF8kQOuoM9cmpCF4F57o0NbNMAgBzLZm6WLl4wWcbE,1267 +torch/include/ATen/ops/_foreach_log10.h,sha256=uDHwtKefciq0jer9wkbMkF7sm1BUgiAZvyk35ZKnj2s,1293 +torch/include/ATen/ops/_foreach_log10_compositeexplicitautograd_dispatch.h,sha256=EWDe6AQsUHfhb6iEhc5-B8YLk0UFRMyiSfdnWX0HGQ0,1015 +torch/include/ATen/ops/_foreach_log10_cuda_dispatch.h,sha256=NOiiUghxbi3sWqWWAb7MD8vqtDjXq4F10N2XUpdrY18,818 +torch/include/ATen/ops/_foreach_log10_native.h,sha256=guiAGL80xiZ9660kHNtxM2aVMvpTU0E-7XttHJk7y54,819 +torch/include/ATen/ops/_foreach_log10_ops.h,sha256=fMtwmx6n_tkuaVgwxZbHFk-J7wXjOqwOc6J12Mazc0w,2148 +torch/include/ATen/ops/_foreach_log1p.h,sha256=tmWCRA2Ag7OOtslu6SF617gxk6C7ysH18VRBPp29T_E,1293 +torch/include/ATen/ops/_foreach_log1p_compositeexplicitautograd_dispatch.h,sha256=ZD9y7iYgy7aSdFIwqJfxTssGo28jNVCnsKmmURJUCc4,1015 +torch/include/ATen/ops/_foreach_log1p_cuda_dispatch.h,sha256=NF4tg5UI1GYRIDxSCDm95cUc-64MUsp3d441wO0IOfI,818 +torch/include/ATen/ops/_foreach_log1p_native.h,sha256=jLpHWXKTtjIEsILMURhqZNLwghZKG94JCjwVYm9kA7g,819 +torch/include/ATen/ops/_foreach_log1p_ops.h,sha256=Om3KbeOoeZRgNqbuV8PxbwSOEr4BEtmQ9cxLNv-TK4k,2148 +torch/include/ATen/ops/_foreach_log2.h,sha256=MeSN5wgJxZD1itVXVF441lQMfwZZomhbFZtpP4kQDiw,1280 +torch/include/ATen/ops/_foreach_log2_compositeexplicitautograd_dispatch.h,sha256=5_dJInJpGGaoxlHbjoUmDf7wuU70us9JbBWtRrtXno0,1011 +torch/include/ATen/ops/_foreach_log2_cuda_dispatch.h,sha256=27WrpW639myaTZN_hd9SGSg9Q_HEpx9Pz79cyw8i35E,816 +torch/include/ATen/ops/_foreach_log2_native.h,sha256=FCG2PukAKVgtZXpO3yGL3xxoumgho2jqmToLwu7alLM,814 +torch/include/ATen/ops/_foreach_log2_ops.h,sha256=OKaxeHJUhQ9f-n4XKe8oKcn-5z8sn13rWssCttxEbBA,2139 +torch/include/ATen/ops/_foreach_log_compositeexplicitautograd_dispatch.h,sha256=fOFrYvrViaKj2oS5HlCeuzNJBBeOHXuzqcYM2UlNsUI,1007 +torch/include/ATen/ops/_foreach_log_cuda_dispatch.h,sha256=oSCrImDiiiblF7Xbs9CbrjotQ1c2qr525NnYGOh6VSU,814 +torch/include/ATen/ops/_foreach_log_native.h,sha256=3WSYsnMEllTOm_Z24H-fftsuGYT4oAWm6H7FOaYDX-Y,809 +torch/include/ATen/ops/_foreach_log_ops.h,sha256=LDKS6TPJioHtToLAcEEeknVydqxtk8JKd_MkFLLkceo,2130 +torch/include/ATen/ops/_foreach_max.h,sha256=Nd5G9Siyw_-p8v232oMEFhRJ_4Hn2ugHU2mTFNHxRg0,1114 +torch/include/ATen/ops/_foreach_max_compositeexplicitautograd_dispatch.h,sha256=SHVSKOtoLEnM-y2vIk33vk1rOS8C8K0S9gRFS27yTDc,955 +torch/include/ATen/ops/_foreach_max_cuda_dispatch.h,sha256=nYv-SMNAp5ARQ2LhVdyEoXvtTntsxkaGiUONt_8zU4E,762 +torch/include/ATen/ops/_foreach_max_native.h,sha256=Gq_eWZRF4I2gxprtvMkYL68TGxa3w8QNKX3S5age5TU,683 +torch/include/ATen/ops/_foreach_max_ops.h,sha256=mIcIbXtSnrXuIQ3hHIvVMnWGz8d4N-qIE0PO2cJ9nak,1626 +torch/include/ATen/ops/_foreach_maximum.h,sha256=KiykbJ5XhgLDJNH0bg4XJg4vn1MNt61xx4xyeB_tsig,3689 +torch/include/ATen/ops/_foreach_maximum_compositeexplicitautograd_dispatch.h,sha256=Nvy4ygRcPpWzu078dB9T-HODwyFKT_t8LWRCPobGt54,1937 
+torch/include/ATen/ops/_foreach_maximum_cuda_dispatch.h,sha256=wxhhJirHBVH70J-BgsWwcYAzDnDqgOLNyYJuB1WXsas,1252 +torch/include/ATen/ops/_foreach_maximum_native.h,sha256=a38GzcyCrdlAZEHPon__x_Dxp3VsiJemBlsMy4Os5Ag,2241 +torch/include/ATen/ops/_foreach_maximum_ops.h,sha256=7VTQL4G5SIbS6p2fukBAYHrQ51AGZ46YF7ICHHdgF_U,6580 +torch/include/ATen/ops/_foreach_minimum.h,sha256=MNTaRsd-2rl0xo4VX_1C8otLkCtTL0paNXUV6MIZP0c,3689 +torch/include/ATen/ops/_foreach_minimum_compositeexplicitautograd_dispatch.h,sha256=KzGR1ojz_HSvmyt9sCoomzn2qhXbb8TMufm1vDvizRA,1937 +torch/include/ATen/ops/_foreach_minimum_cuda_dispatch.h,sha256=Chkx9ZSJ-VwnUGZAiCc0wXnZkbnR04A9cWFP2bA50WI,1252 +torch/include/ATen/ops/_foreach_minimum_native.h,sha256=Q3aXWMX2xeUw6kR7Tl2VcjVo2FVyzD95Ttjb65Y1P6w,2241 +torch/include/ATen/ops/_foreach_minimum_ops.h,sha256=FS8KTPjiwv2xchrvBafRYQexA2zgamnSRMHnob6tEiQ,6580 +torch/include/ATen/ops/_foreach_mul.h,sha256=cSCawKgxTngR7OmMUXpqibGETXTAh55pCOSpvTAiUpI,4518 +torch/include/ATen/ops/_foreach_mul_compositeexplicitautograd_dispatch.h,sha256=1Xswzd5bfpN8bwMQmUHeVFin_fKA006ZkUk4STHldfQ,2268 +torch/include/ATen/ops/_foreach_mul_cuda_dispatch.h,sha256=BXAM1dH0Rl8XJEPD1S-nL_4awvt-3FxWCDjLHNCviik,1404 +torch/include/ATen/ops/_foreach_mul_native.h,sha256=VXAFZjcfsYbvpF_al3VlziBSLnRle02WNXVhefQDmDs,2717 +torch/include/ATen/ops/_foreach_mul_ops.h,sha256=at0NSdkVeO1g9g6Uk_3tQUa_UWZDDSnm-nyUg2XojVg,8446 +torch/include/ATen/ops/_foreach_neg.h,sha256=6z-mIF3lM1MYzrHTkTvhVdeK0Ajo9PS0qJxXGx3kTMg,1267 +torch/include/ATen/ops/_foreach_neg_compositeexplicitautograd_dispatch.h,sha256=2Sy5mp3IkHIMY2NTQQ2YFiOuzk45nU2W7kzPLtGoql4,1007 +torch/include/ATen/ops/_foreach_neg_cuda_dispatch.h,sha256=HctH5pmxuc1UXjVpnMRhXBeFDlQUtF3pphYvw9xQ1aU,814 +torch/include/ATen/ops/_foreach_neg_native.h,sha256=xBSgVhZqqOl54PoKgwMC_8OC0CRd5BXCnPs3K6jUlqc,809 +torch/include/ATen/ops/_foreach_neg_ops.h,sha256=3Su4nZ9aOOYF080IIft_FN0f7LyzTaVibCct2Qav6oA,2130 +torch/include/ATen/ops/_foreach_norm.h,sha256=-ppRhfvlbMZrQy6A0Qc5wu5lCQGjRU-Dqzk9zY8Ve7g,1539 +torch/include/ATen/ops/_foreach_norm_compositeexplicitautograd_dispatch.h,sha256=hTjkR125W14oiWhJh9Tl1xXL-7Oj4rw2bTznqzNO8eU,1181 +torch/include/ATen/ops/_foreach_norm_cuda_dispatch.h,sha256=wUQZ3OlaemfAHu3kCrVBCdCZTb694KIYjApMKb7HfVU,843 +torch/include/ATen/ops/_foreach_norm_native.h,sha256=zA6dfaYdZzoZfaQg_HjH1fFVRkPbRoBw8GP0fkiNYuI,916 +torch/include/ATen/ops/_foreach_norm_ops.h,sha256=c7kEh5hc-AO0O6q4PEI1EKnXIwIlmeWC-8lrKzXAS0E,2107 +torch/include/ATen/ops/_foreach_pow.h,sha256=i-a7V1Drvjovt5zApEQzJEsfgSVt_CpWBCXyA705pTw,3874 +torch/include/ATen/ops/_foreach_pow_compositeexplicitautograd_dispatch.h,sha256=nAKsdidSwfTZRJKNqHT9jY_kx5DNLh35VFQDvHdfPtQ,2014 +torch/include/ATen/ops/_foreach_pow_cuda_dispatch.h,sha256=NreTFaIk285z2it5SGPMqajmSERWJ8DZlfUtiLLj_JU,1341 +torch/include/ATen/ops/_foreach_pow_native.h,sha256=RQZ3d1NM93sjI_eOSuxkqmMtR_-cZmek8G3VPwa_p_U,2435 +torch/include/ATen/ops/_foreach_pow_ops.h,sha256=_mDl5gh_h0kEktEQOg2X38YmEc_cQQzKQOrivOVp6cY,7234 +torch/include/ATen/ops/_foreach_reciprocal.h,sha256=w_8zl94s8jcTiGKQrblpYKiilz2PYxkAzPBlwgCdKAQ,1358 +torch/include/ATen/ops/_foreach_reciprocal_compositeexplicitautograd_dispatch.h,sha256=JRUmzZyHjW02x8wadcTfIUDCrHVK9YpDGgT5IxrMUM0,1035 +torch/include/ATen/ops/_foreach_reciprocal_cuda_dispatch.h,sha256=cQOV8YFm6wsc8YwVPjAR6CkxLKxgBIkIMNmANMt6I94,828 +torch/include/ATen/ops/_foreach_reciprocal_native.h,sha256=jZ9l4e_GSoGTbDcSFtA1pWYYdYbXcCky9f8l0LWBkXA,844 
+torch/include/ATen/ops/_foreach_reciprocal_ops.h,sha256=yBFLlRv8mqZj-d1ZtYiOPo-pJsMnVp9Qxfh2SL6P6xU,2193 +torch/include/ATen/ops/_foreach_round.h,sha256=ALGM-gDd52j0WQ_VKCDHV-Z3n9Uoybg8PkZ3q3bIMnw,1293 +torch/include/ATen/ops/_foreach_round_compositeexplicitautograd_dispatch.h,sha256=KW7GcbOLfhsj3OUnEd3wcl73OG0XYJV-3LVzhxH211E,1015 +torch/include/ATen/ops/_foreach_round_cuda_dispatch.h,sha256=gaeMz4mkJ3U6Tlr-rL0RC12K7a6aylZ6CTuD0y2AHX8,818 +torch/include/ATen/ops/_foreach_round_native.h,sha256=GuRH1r268J2nfaiK0PIVTj3kHckZyCvfagNr9S5XFc8,819 +torch/include/ATen/ops/_foreach_round_ops.h,sha256=2y3s1Iimt5kcmuL1PFr61nI6qoLhq0DNU_s1ZVncq6U,2148 +torch/include/ATen/ops/_foreach_rsqrt.h,sha256=bVvnflj731TWnf82fIDxfLhs5ShCZ9Xz9tooV-Mn9EU,1293 +torch/include/ATen/ops/_foreach_rsqrt_compositeexplicitautograd_dispatch.h,sha256=N8VwhkRz2AaZ2AKXzweKqRm8Uo2jMoaWkzY2QJ-g9Yo,1015 +torch/include/ATen/ops/_foreach_rsqrt_cuda_dispatch.h,sha256=DqBSP049iEKAGd922vV1YENk3thr3DJAIS_UBCUvw4Y,818 +torch/include/ATen/ops/_foreach_rsqrt_native.h,sha256=JwSvqrCfJB-O8dvLtOp8BTQ9fj6fzffn5-t7bTf-3-E,819 +torch/include/ATen/ops/_foreach_rsqrt_ops.h,sha256=icnVXcuSKTT1LV5YDGX4TaxAA5JT5JZz0yTSdRIGByw,2148 +torch/include/ATen/ops/_foreach_sigmoid.h,sha256=8JpVIwk4Ne_AEEEOD4KNSXUQtlirh4VhgfXG0AQcxYc,1319 +torch/include/ATen/ops/_foreach_sigmoid_compositeexplicitautograd_dispatch.h,sha256=s1Yp2Q6wBPWvda4wBqN675VqltTsVt6HAKU9guN4XYU,1023 +torch/include/ATen/ops/_foreach_sigmoid_cuda_dispatch.h,sha256=wgqbeS5koIiIT-gXkYloeR-_0gaitiXykenwDb8OPq4,822 +torch/include/ATen/ops/_foreach_sigmoid_native.h,sha256=P0AyMAZrkYwlFDXopHpa_NqKzRAT-d-dp03QSydkro4,829 +torch/include/ATen/ops/_foreach_sigmoid_ops.h,sha256=gNEVGODt6QsnxsDs70bzaFyJzys1ok3FB2CbLAyCmo4,2166 +torch/include/ATen/ops/_foreach_sign.h,sha256=b4ZkxKkxuSnIDU1EyBITRIjQkfxvzxRw67AF0eaZc5Q,1280 +torch/include/ATen/ops/_foreach_sign_compositeexplicitautograd_dispatch.h,sha256=3UZ6WwUV7o_9ZHM0QCd0COMFadKylc0eGrjkGbL7rnU,1011 +torch/include/ATen/ops/_foreach_sign_cuda_dispatch.h,sha256=AggdiGW-cPQkb9JH2uyL_O39_xtFuBB_ZIUuJr7KWgQ,816 +torch/include/ATen/ops/_foreach_sign_native.h,sha256=Urodq140kaPjSIQDg67HckFY9WZFHHeNOfW2uZiVDus,814 +torch/include/ATen/ops/_foreach_sign_ops.h,sha256=52qnat0o1N1s4jviEN43lso4VOO-xZsEm7Sql4Ctf8U,2139 +torch/include/ATen/ops/_foreach_sin.h,sha256=kMhWmXxN6AE3JJKpkx0PQaLI8lS3gl0Y4qWi_AZmBRg,1267 +torch/include/ATen/ops/_foreach_sin_compositeexplicitautograd_dispatch.h,sha256=P1-dSOGwZ3CRODAIJwEN1O94eeDUGUMkD3S70oD73Hg,1007 +torch/include/ATen/ops/_foreach_sin_cuda_dispatch.h,sha256=AqpFP11MrT8CGfWo2NNPxpl-W49QcD9T1UZrMrFABUk,814 +torch/include/ATen/ops/_foreach_sin_native.h,sha256=1NDZV1MzsEzcQ_TCVz05CVJa8lCZ78aQnlCaGCyHOwk,809 +torch/include/ATen/ops/_foreach_sin_ops.h,sha256=s5PhixuBdXzf1XrLgv5udj_uJrBW2-fG8hO-KYKcouI,2130 +torch/include/ATen/ops/_foreach_sinh.h,sha256=szxKRPCsCkm-QgZFSmuEaPpb72CprYkMtJVYAQiVjok,1280 +torch/include/ATen/ops/_foreach_sinh_compositeexplicitautograd_dispatch.h,sha256=6EfthNm7bOD6kdOLNeKITLhVYibMMR28VvQlWd8QbsY,1011 +torch/include/ATen/ops/_foreach_sinh_cuda_dispatch.h,sha256=XY6uGBoGSdQ5hq0FBUJSlOl7ALUnVDDgLjDKlxBCXxg,816 +torch/include/ATen/ops/_foreach_sinh_native.h,sha256=NJHhs6Lu_Jte93XGt3epILl06MQwHxAlJMG1NsGH0GU,814 +torch/include/ATen/ops/_foreach_sinh_ops.h,sha256=ahw519b8lBioOxfiVUreITB5mRIafUMl37HD8yziOd0,2139 +torch/include/ATen/ops/_foreach_sqrt.h,sha256=mrbA4MXZdXoWqkW8ZimXwrw2QSwguf9haI-1LRW9e9k,1280 
+torch/include/ATen/ops/_foreach_sqrt_compositeexplicitautograd_dispatch.h,sha256=qEowmipvXI5jF9zF72fTfVvKUaabYy6nxuK1mRMrm1s,1011 +torch/include/ATen/ops/_foreach_sqrt_cuda_dispatch.h,sha256=ORc-a_xAEOVRhmozWEJpWVIm_8Wln70q6RlaqYzYsYE,816 +torch/include/ATen/ops/_foreach_sqrt_native.h,sha256=liJLH6AkOqoB_7pXeeOs5jL2JXSTjsfNc8cYwhANSZI,814 +torch/include/ATen/ops/_foreach_sqrt_ops.h,sha256=hKNpOOAzu6gy8D2D1hkcO-KE2ZPMakNRLVcp27TQJEU,2139 +torch/include/ATen/ops/_foreach_sub.h,sha256=22T_-hSWVNlvSwujwi_Zc8AJTBDgUwMnBDXAz23BwGg,3749 +torch/include/ATen/ops/_foreach_sub_compositeexplicitautograd_dispatch.h,sha256=k-yF9q3WBEArgv-WOUu91V_yhzgP4TyCdFL8nQWwYA8,1999 +torch/include/ATen/ops/_foreach_sub_cuda_dispatch.h,sha256=3hjoJAv71AU-UnugARYpeSGsIgE080NGg2uRL6hK2k8,1284 +torch/include/ATen/ops/_foreach_sub_native.h,sha256=KUz8OGZ720gIrma3fe4fGhhzXD2KapXH-PRfjf2pHxs,2295 +torch/include/ATen/ops/_foreach_sub_ops.h,sha256=of8wPncoJBPhkbTraYoHPkblwaT5kk-d1b1R7NllyI4,6742 +torch/include/ATen/ops/_foreach_tan.h,sha256=OD4L7NBwVo4VFGQCTxftDfB2JkBzumdaLYuo5LindhQ,1267 +torch/include/ATen/ops/_foreach_tan_compositeexplicitautograd_dispatch.h,sha256=ys4Tjft1OnkV1rxwn8lWcxL-rszulBPcOrqOkMvX1j0,1007 +torch/include/ATen/ops/_foreach_tan_cuda_dispatch.h,sha256=C5iap7ZTylMP6kQJgRxn_6MSgwIXYmBqr6TbhpQEOA0,814 +torch/include/ATen/ops/_foreach_tan_native.h,sha256=HvF0sPDb_cRk8v2yB_t6CKSkM-oHu0vPG7G8aPJ99AA,809 +torch/include/ATen/ops/_foreach_tan_ops.h,sha256=bpefss--rhPDJGI2qmrZ57cNTh7rxPaOYFoA0B1ijSA,2130 +torch/include/ATen/ops/_foreach_tanh.h,sha256=58Nvt1yPGd4z_MX1ILzGfIDECUbFzpdU8HUv0IZ702M,1280 +torch/include/ATen/ops/_foreach_tanh_compositeexplicitautograd_dispatch.h,sha256=O3Lut_Mj9_1Eji67weSX5dhbrYoAsmSXL-DDO6fo1QY,1011 +torch/include/ATen/ops/_foreach_tanh_cuda_dispatch.h,sha256=p4SvB9KpzfQdi04N_8dIVfbTGbV-oq0kxcs410wwDzM,816 +torch/include/ATen/ops/_foreach_tanh_native.h,sha256=oRjlkUB691vSULbceQGLszKIr1zr5k7U7H0taT06_UE,814 +torch/include/ATen/ops/_foreach_tanh_ops.h,sha256=PnmpgpP6dUbJKjw_rfhClDf6jiWh9A-b_DDzTLIVtbo,2139 +torch/include/ATen/ops/_foreach_trunc.h,sha256=ZM62ndxQrcaNeRZUEmQJcXh_rWjdX80oPYN86Nkqmd0,1293 +torch/include/ATen/ops/_foreach_trunc_compositeexplicitautograd_dispatch.h,sha256=eTTifPwWGxwndZBCyOkvmroPvmf9D8B9tgMnjGQezbQ,1015 +torch/include/ATen/ops/_foreach_trunc_cuda_dispatch.h,sha256=qdlmneZv1ZWK9K0FX7tWfcpeouLIZ9M10i8q0MmS7NU,818 +torch/include/ATen/ops/_foreach_trunc_native.h,sha256=B1OrnoBeOTU2sPpqQjpwjxSVEnvMbL4Ih9pAYVD7TxA,819 +torch/include/ATen/ops/_foreach_trunc_ops.h,sha256=vO_X7yLSgRQrknnQk6zqqLT5z2ar8sEEe5plH45e1x4,2148 +torch/include/ATen/ops/_foreach_zero.h,sha256=hrtSd5-RC7t2ya6zjIycJFECQrzjh0WpOi4Kgbi1VI8,1289 +torch/include/ATen/ops/_foreach_zero_compositeexplicitautograd_dispatch.h,sha256=UF9weyuxoQVL4-kgdwGNcX1Pv3ByQymv1YNf0r0_qTo,1011 +torch/include/ATen/ops/_foreach_zero_cuda_dispatch.h,sha256=YHpBwZ-9EQ9yP2czG54_GkG982U6ZnkgL_b1plmTp1M,743 +torch/include/ATen/ops/_foreach_zero_native.h,sha256=q28acsY6Yk95sfoZloK0OeA9s9As5aMgnZiyhop5eNc,719 +torch/include/ATen/ops/_foreach_zero_ops.h,sha256=s8vvrNxeudemY5GSq7_72d-jqyYBu_sobZubvQ4pkpk,2148 +torch/include/ATen/ops/_functional_assert_async.h,sha256=TnxvB9lMUeZonHJD9fBYCxDiLGvLzsy9gl9sXViOPsQ,864 +torch/include/ATen/ops/_functional_assert_async_cpu_dispatch.h,sha256=XcOEnkFgwXNMgjtSGAplZggc0yX-xWi2WLQ2vcg7-ag,820 +torch/include/ATen/ops/_functional_assert_async_native.h,sha256=gEO_-bwEJgnLFYp5_fs3k7xuddrHz9Or_YybTUWNnbo,582 
+torch/include/ATen/ops/_functional_assert_async_ops.h,sha256=VbBuTsV_R5EpM5_sAArAcRjkrNZ8CB_bDkdpMvg1lF0,1241 +torch/include/ATen/ops/_functional_assert_scalar.h,sha256=7e4IGNl1eTP9dDlxjXeVM1qaGNc-mxzJfVObKCr4Rq4,860 +torch/include/ATen/ops/_functional_assert_scalar_compositeexplicitautograd_dispatch.h,sha256=S9pcjZYz1uylj4i_BUT4pf5eJhiU-xxLytYKk9Mjpt0,865 +torch/include/ATen/ops/_functional_assert_scalar_native.h,sha256=XfNS3eeCFH4R_rlpDQ-wLo7NrdWw3EJuZGffl16hikM,575 +torch/include/ATen/ops/_functional_assert_scalar_ops.h,sha256=4qDaODvzjrwaLzKlO_b3FbHfGA8l-4bQJtbNqrOUbRU,1233 +torch/include/ATen/ops/_functional_sym_constrain_range.h,sha256=EpoLvxD_tqlwW1Z_M5H8lVw_ejXzDEqzQSjDoD8InP4,917 +torch/include/ATen/ops/_functional_sym_constrain_range_compositeexplicitautograd_dispatch.h,sha256=VTwwdbcDZlYNCofbM8iffaBMBEdBw38v-bgdtF0ZPxI,902 +torch/include/ATen/ops/_functional_sym_constrain_range_for_size.h,sha256=GZk5FYVBfRLJw3pMBGGlHzjs0x2MtCGlCtebV5bij1g,953 +torch/include/ATen/ops/_functional_sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h,sha256=iKONRy4zJMJDJ7hfAY6CViyelxPm7LjV6g6Z04RLt8E,911 +torch/include/ATen/ops/_functional_sym_constrain_range_for_size_native.h,sha256=Kt08LIbUQDn3TS7-hYqs64x85mLZjLuSmB7NNHAyqug,621 +torch/include/ATen/ops/_functional_sym_constrain_range_for_size_ops.h,sha256=v6GPZ1dtIfIyTHREy7QOi9u6VLCcIMzsOWwGARQvqKs,1378 +torch/include/ATen/ops/_functional_sym_constrain_range_native.h,sha256=zyCi5hSZrdU4seqbMoI-Cy89y6st_S6VSVGscMediVk,612 +torch/include/ATen/ops/_functional_sym_constrain_range_ops.h,sha256=cyiRD9jQ3fyroy1EBGs26w_Ic8bzB12c8JIzP3jSUKg,1351 +torch/include/ATen/ops/_fused_adagrad.h,sha256=cJKhGBYbKJCIgYNkdEsP0l--fsDIl9YYqLkWPM5nFqo,6794 +torch/include/ATen/ops/_fused_adagrad_compositeexplicitautograd_dispatch.h,sha256=_UhGw7Be9RJSDl_GPW3z1BGXe1J_0aZFIHYmgH9Bj5o,2896 +torch/include/ATen/ops/_fused_adagrad_cpu_dispatch.h,sha256=I_zbNELOvTCRXtpzON3hC-14YaGY_m2EF3_1Q5WnZEs,1316 +torch/include/ATen/ops/_fused_adagrad_native.h,sha256=fK25iD8S47-bBCC-71Y9uASgbpT-C7imrrv541m31WA,2590 +torch/include/ATen/ops/_fused_adagrad_ops.h,sha256=N9b_CcfGRKWdQ5jMwb2YuLUpFH-IoY35avzA6O9zkzQ,9700 +torch/include/ATen/ops/_fused_adam.h,sha256=_SB_i6lkHzEl6eublQWs6cWjrbR94SWnor-khPD4qTY,8467 +torch/include/ATen/ops/_fused_adam_compositeexplicitautograd_dispatch.h,sha256=OuTafQtEHRCuLyMK2EYIRYLt61TCy26_MkCUCfMcgZc,3454 +torch/include/ATen/ops/_fused_adam_cpu_dispatch.h,sha256=95vdYZC3w1jowkmiGQcldzP98XYi2Qzs9dqz8Kyvyn4,1476 +torch/include/ATen/ops/_fused_adam_cuda_dispatch.h,sha256=ubmBz_gGn8qzkjW2GIjTqHQskoHAYewO0EMxMKrLga0,1478 +torch/include/ATen/ops/_fused_adam_native.h,sha256=1pwS02iWP-GdCUoAZeK2Eubg49r9pBePAaH2kE0ppTc,3960 +torch/include/ATen/ops/_fused_adam_ops.h,sha256=XdIMbyDy4X9_CdGxxJmOtALgv16hwq2yXU-zkbuEwtw,11678 +torch/include/ATen/ops/_fused_adamw.h,sha256=w0lkBOoxzUPgHbjLETRogJDm6t6jFpgLaU9MTO4pANw,8492 +torch/include/ATen/ops/_fused_adamw_compositeexplicitautograd_dispatch.h,sha256=jNDuD0COT-mPn52fAWt15UeHNTERM809rpXXuYm9Xxo,3460 +torch/include/ATen/ops/_fused_adamw_cpu_dispatch.h,sha256=wgtpXZzJebG8MWQD5-1t1wMtR9yqgTSaPpjXdCBhfoI,1478 +torch/include/ATen/ops/_fused_adamw_cuda_dispatch.h,sha256=-hKZQDzwIeTP6GHUShk2BDuLfwrPb5DCG9HxbSth3eQ,1480 +torch/include/ATen/ops/_fused_adamw_native.h,sha256=Wjp1yulx8WuVz10VQY-eq-oy9x9jQnaESTckPUseTtM,3968 +torch/include/ATen/ops/_fused_adamw_ops.h,sha256=jplszQG10Od-yuCepBK7BLe2UjC7rHdG57JDCTLAn-0,11696 +torch/include/ATen/ops/_fused_dropout.h,sha256=eY3BpDCwc2aJbbj-n_8wpDCXFeT3lSptQmywSKisDuI,1690 
+torch/include/ATen/ops/_fused_dropout_compositeexplicitautograd_dispatch.h,sha256=OXw4if2z6gH9aP9MLCkajgs8q02nc6SJqFzgPOo9_DQ,1120 +torch/include/ATen/ops/_fused_dropout_cuda_dispatch.h,sha256=dcq1kvcvNWMVVO9i_i3cf2gcppHMIYq0gVpkZXHDawM,845 +torch/include/ATen/ops/_fused_dropout_native.h,sha256=nbL6WICYuVT4UAR2jFUvRub5RFVwL81Y--l8wLZj1gc,787 +torch/include/ATen/ops/_fused_dropout_ops.h,sha256=tTOWcOjqTerBByqdNEVTWVYnPMQIkFPE9ZQQLFz2M8U,2247 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper.h,sha256=v_Vy8F1HM1cYt7aFPEXJObBeaMNkhgtXN1rgIHdDjbs,4820 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper_compositeexplicitautograd_dispatch.h,sha256=sojV5y8wBMGDJ9Mi1d2gLYjVF1xaKDndvwu6lr8AMrs,2110 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper_cpu_dispatch.h,sha256=cF6Exaux5zJ4uPNwVD8rZWDOqUHuQI5WWNsEymcuFMI,1094 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper_cuda_dispatch.h,sha256=FhQXuOBUOqq6heHKsCNWfELF7caP-fPwqPXGT2kQlJw,1096 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper_native.h,sha256=DCzzw6Clf3hWhNEf43Yin5nyXMLpWYp96z0HHP8mCRo,2188 +torch/include/ATen/ops/_fused_moving_avg_obs_fq_helper_ops.h,sha256=zt4wmH7w0BBVW7Hm2_13SqBPa6iCXCdCggEHecmSvuM,5925 +torch/include/ATen/ops/_fused_rms_norm.h,sha256=5lFiQsdiroonX1-TWqVv6UVjD7wiXyoA535i51-GibY,866 +torch/include/ATen/ops/_fused_rms_norm_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_fused_rms_norm_ops.h,sha256=tHhtIm_nT9Zmhl0CKTaGW2WMBM3bpgPkpcyPXpeUyv8,1246 +torch/include/ATen/ops/_fused_sdp_choice.h,sha256=n68tPm44Qw55j_BwrJSUhfOi4p40RnvYJAMXcyKlXrg,1124 +torch/include/ATen/ops/_fused_sdp_choice_cpu_dispatch.h,sha256=OBd6TKq8zL11rdu6w2f2eXgE9-dEA_HFxIKLP6bKnCs,965 +torch/include/ATen/ops/_fused_sdp_choice_cuda_dispatch.h,sha256=17PyZ9oV0upkZ-HAboj9cwcwAkrpf9XEU1DBT5zkV38,967 +torch/include/ATen/ops/_fused_sdp_choice_meta_dispatch.h,sha256=RNtdl1BjQS4GD8muvoomkD-p-CIpoJaQGd2qbLiiUWM,967 +torch/include/ATen/ops/_fused_sdp_choice_native.h,sha256=6CkTccYqe-QxHj-wphAnPVjGiHq7KZxaPa4y6eV555Y,1287 +torch/include/ATen/ops/_fused_sdp_choice_ops.h,sha256=PwieFPeb4pDdDOTy2xqrI6iyaP8cd0d_oEj-fgi2rIU,1628 +torch/include/ATen/ops/_fused_sgd.h,sha256=UgOB_qScPGQVZ5xJQX7eX1UxJQsMBQ9j3nl7TAy1O2A,7290 +torch/include/ATen/ops/_fused_sgd_compositeexplicitautograd_dispatch.h,sha256=ga4g_ElEXnl2IIbf4zvnqIwXA8gslJEmoe6RFRejsiI,2984 +torch/include/ATen/ops/_fused_sgd_cpu_dispatch.h,sha256=0HRW0lAeFVwWaDudr92ffngXuaZPpI5psIzUtGqrUOw,1354 +torch/include/ATen/ops/_fused_sgd_cuda_dispatch.h,sha256=6niBeqpBH_Wvb7dy4-gg4B_6IEo67NefpJYzn2qKMPU,1356 +torch/include/ATen/ops/_fused_sgd_native.h,sha256=-vRY4TfPs1bNpx1m7MnVR7e6JfGOe81Vfo4iJXpbW4c,3368 +torch/include/ATen/ops/_fused_sgd_ops.h,sha256=E7TCqjROag88MObICqoZgYsFHLGBni1c-O9yvFkdRYA,9962 +torch/include/ATen/ops/_fw_primal.h,sha256=hszd5FpSmaN79TAVRhjkTnDLyH9K6jgpEbO3dqVEW5s,534 +torch/include/ATen/ops/_fw_primal_compositeexplicitautograd_dispatch.h,sha256=U15a4UjV2C8F1-m7rSgERz6STC2Bi6TH858Is9fIm6I,806 +torch/include/ATen/ops/_fw_primal_copy.h,sha256=2qERq5WwOjDeeq41KtBGDefmhRTqXJQCfjT2WKrHwCY,1256 +torch/include/ATen/ops/_fw_primal_copy_compositeexplicitautograd_dispatch.h,sha256=DMYG2QFD3s8d1bROTQVlZRmPjVm_PITUuvxVEvKqzsA,939 +torch/include/ATen/ops/_fw_primal_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8eiUpOMUqtZEwVIIKy0-be5CTDWYFQffTOxe9jtXoWE,837 +torch/include/ATen/ops/_fw_primal_copy_native.h,sha256=bTPX09GVC5w7yndjoOUbkJ6s_Q9lElmERGvrdgjWWnU,624 
+torch/include/ATen/ops/_fw_primal_copy_ops.h,sha256=gNUrCj1WvhHHEouK2NlOQ7w2y6Uli_ZLEeczBmV33Lk,1741 +torch/include/ATen/ops/_fw_primal_native.h,sha256=GFcuGq-Xeb7SqKQvBONGUdOeXpTjm4OU2GhdTqRi94U,516 +torch/include/ATen/ops/_fw_primal_ops.h,sha256=0x_vE4R6hCLt3VvJSiwyYDAllSdHWWhxQ2n0B64kjUw,1054 +torch/include/ATen/ops/_gather_sparse_backward.h,sha256=HnQy5D0jcqCvePfl2giuZeF_N8u-7k9leT2us31bjvo,854 +torch/include/ATen/ops/_gather_sparse_backward_compositeimplicitautograd_dispatch.h,sha256=cOCHmD3fKQmGe0rfTd6HgxUibRJRzwoR45Cyjf-d2TY,868 +torch/include/ATen/ops/_gather_sparse_backward_native.h,sha256=5lk8H8Bsc8rD5wvcZFJD2IR-D3STJJp3PPGiCJlzlRY,578 +torch/include/ATen/ops/_gather_sparse_backward_ops.h,sha256=ESR3n7trZWCQPYaBZDH6ypxwMh64QCzfrvcmLbLfdKg,1250 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback.h,sha256=KMqtaW3O4kvdrkmO3RDcbVb7R7WmGXqLve6sCRHo0LA,1981 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_backward.h,sha256=NfLbH8_SqrjnvtH0VdGFWbpfuyogpRLsxWvH2T8Or3s,1124 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_backward_compositeimplicitautograd_dispatch.h,sha256=ouZ5BLvv-Grk9nbOSSqxdB2bKreyIHaBFjxR6XIlQUw,972 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_backward_native.h,sha256=sQ2aYpqJOCnFH53m_ZGGBn90Oe1_aLBrtYC3YIeKT7o,682 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_backward_ops.h,sha256=2T5jORPihFtSPFIG1dTkTsmCT4AR2JyopFtQTXa-bow,1583 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_compositeexplicitautograd_dispatch.h,sha256=Im-2cqDU4PM9DAhey_wbA4pUbaoOIOdapt88Y1GI-nA,1303 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_native.h,sha256=DXkAQan6aPNFiJIi-Uso9p2tXVU-5xLWpILDjvhVM50,814 +torch/include/ATen/ops/_grid_sampler_2d_cpu_fallback_ops.h,sha256=VCg4myW77uiF3CQH5ToLnTQ7SF9cWd73m-cYTmHNcmY,2349 +torch/include/ATen/ops/_grouped_mm.h,sha256=uiG-yQflaWmn3VQCtM5VdmbgYSPeSupsv7oWsThCLW0,969 +torch/include/ATen/ops/_grouped_mm_cuda_dispatch.h,sha256=RS6_2dcg9GPg7CoKj-Aj9LGJXNlAMtDrrXHNEuNUskg,923 +torch/include/ATen/ops/_grouped_mm_native.h,sha256=GiyjVfHRoq9rsBgfVJIitIcj0gWqHr9kFcoWx6TvNNA,680 +torch/include/ATen/ops/_grouped_mm_ops.h,sha256=NkegZkmzM3vg6qAiTD3tyXJp5YCsqUs-pABWaSfXS8A,1511 +torch/include/ATen/ops/_has_compatible_shallow_copy_type.h,sha256=p6XYkH0TIAEioMuKvFlxnqiLulFoz1gR9ur1D7a_uZ8,812 +torch/include/ATen/ops/_has_compatible_shallow_copy_type_compositeimplicitautograd_dispatch.h,sha256=Zb5VxsNh8lEajv3iIcX5ytSq5bIcSQsZYA54crUWe3A,833 +torch/include/ATen/ops/_has_compatible_shallow_copy_type_native.h,sha256=QHNulqTasP3HMo5XYq8-7Q86fpfXAksRF1bLLHkGBqQ,543 +torch/include/ATen/ops/_has_compatible_shallow_copy_type_ops.h,sha256=zBd9WsHeZLJgAAxA7O7XWPZ82cKm-YWLc9q2CdEgobs,1130 +torch/include/ATen/ops/_has_same_storage_numel.h,sha256=J2jnjiJqdI5Lzg1U_xMOp94T_muyHSnYj8X9vwTF7lI,775 +torch/include/ATen/ops/_has_same_storage_numel_compositeexplicitautograd_dispatch.h,sha256=LqI7D2WilCCkBCUNW1XQmnB1-5HANFwZS5_W-Mq39mE,824 +torch/include/ATen/ops/_has_same_storage_numel_native.h,sha256=J-M_NS7Py8org39yXxwnC1YgJK0VXfe4hHYSZa42ES4,534 +torch/include/ATen/ops/_has_same_storage_numel_ops.h,sha256=38jrkFoZPfY7QtuTqW5ccqK3rkW77WL61mA_y4JZEjM,1103 +torch/include/ATen/ops/_histogramdd_bin_edges.h,sha256=W9D67GlElWZjamZFotRrinA-SVSwiIMlacD96gz4JWQ,1958 +torch/include/ATen/ops/_histogramdd_bin_edges_compositeexplicitautograd_dispatch.h,sha256=jgn0L5RZTaHdYxo_YKcDM_bZd857N8Znhi9JlHPOI-U,1185 +torch/include/ATen/ops/_histogramdd_bin_edges_cpu_dispatch.h,sha256=m4IJ7wghOaibnOSPhYcLM_Bl5zQkxUOvwXsL_rb_a0E,923 
+torch/include/ATen/ops/_histogramdd_bin_edges_native.h,sha256=CJjooTV4xTMDzorciUo2G6S_2zWjPEgTVozlIIIj_84,890 +torch/include/ATen/ops/_histogramdd_bin_edges_ops.h,sha256=tCyTGJVq4NjBMPEdsAfEqsDeZG9fWkxXAVL_pZJCkBI,2555 +torch/include/ATen/ops/_histogramdd_from_bin_cts.h,sha256=IlDS-v8KkPnQv7uGnmHHusrzGBtOZaFQVerUcpTMmS8,1995 +torch/include/ATen/ops/_histogramdd_from_bin_cts_compositeexplicitautograd_dispatch.h,sha256=nRyIOgFIkQDPpL07AUddUbEDwVpDELexyZ0hqscSLzI,1203 +torch/include/ATen/ops/_histogramdd_from_bin_cts_cpu_dispatch.h,sha256=c0CE6iIGIA9xbnuspucNnWJ6VRgP46DoDwH6O2S1uhs,911 +torch/include/ATen/ops/_histogramdd_from_bin_cts_native.h,sha256=yl8a_Z3cdn6sk1GgiXUM-ax19KnsZLUufvgdkeMvDos,875 +torch/include/ATen/ops/_histogramdd_from_bin_cts_ops.h,sha256=tChr42mY2ZMhvI9hoRIhvTWQqdSaB6O-WzSFI2FiSQA,2550 +torch/include/ATen/ops/_histogramdd_from_bin_tensors.h,sha256=c-a8j8PnQx9uJs2S4TUIAl0Zk5f-FgqPouRJq5-x4Tg,1792 +torch/include/ATen/ops/_histogramdd_from_bin_tensors_compositeexplicitautograd_dispatch.h,sha256=2_XhzkbnAoRSlKkmU4HiYSXyOLzcSw1MosJxt1jTqDo,1104 +torch/include/ATen/ops/_histogramdd_from_bin_tensors_cpu_dispatch.h,sha256=QY15kNr3TU9Vhkg_SmU3c5OP7zE6pm5ZHDPHDp1iZbQ,854 +torch/include/ATen/ops/_histogramdd_from_bin_tensors_native.h,sha256=CA64RhXkMcSnHfcvd6VvHy0dUSZTdk7pnWIKYROLWh8,772 +torch/include/ATen/ops/_histogramdd_from_bin_tensors_ops.h,sha256=0g3It141mDylnykB19KiLMOFZaHxD7R0npdsHCPtWME,2274 +torch/include/ATen/ops/_index_put_impl.h,sha256=s4w1C9H2hRRNUv7Lta9TMh9PxJ5gYTSuQigI13GikPo,2271 +torch/include/ATen/ops/_index_put_impl_compositeexplicitautograd_dispatch.h,sha256=5UaQHJrxfXSEVj3A4-kbBDfuXJgrsKppymXCfgMmC5U,1336 +torch/include/ATen/ops/_index_put_impl_cpu_dispatch.h,sha256=A2lYjLLxOPh7n87ALWHuh3l6WQwMxoKwTdgJ5krW9mk,874 +torch/include/ATen/ops/_index_put_impl_cuda_dispatch.h,sha256=huEZP0FOP-dXv6pPMrL58DQPZnz6ibqm5zN-174A_aQ,876 +torch/include/ATen/ops/_index_put_impl_meta_dispatch.h,sha256=febd4VaNy2O_2rXFgruJn_ji-nOOdGxzNkLQFYjfPJU,876 +torch/include/ATen/ops/_index_put_impl_native.h,sha256=eZki7U7Wn1FkXqPahA32e-4RNGMfuJyA5d-66a6eYoE,1419 +torch/include/ATen/ops/_index_put_impl_ops.h,sha256=Y9_OZ02BGN_ZCrDUDeXiQiiyItW6Vl3iPW9ziGsHDrI,3324 +torch/include/ATen/ops/_indices.h,sha256=aQOSG0UEO4wOhRy-Q0EHsDP_bqR5xxQ5LYi5uXTmDTE,532 +torch/include/ATen/ops/_indices_copy.h,sha256=V_Rl-iBwdxUE5Wsny0sNFCqguZrTqmx8slRak_9m2os,1137 +torch/include/ATen/ops/_indices_copy_compositeexplicitautograd_dispatch.h,sha256=C_zZJTkxIylmkavZ_Cz11Qq6XmMi1PlM0I1GhW2B6fI,905 +torch/include/ATen/ops/_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=FsnX0M1CWVUnVLTY9Fuoj9L7TVh341fBfZMq1KjPEuI,820 +torch/include/ATen/ops/_indices_copy_native.h,sha256=_EQdNhn3AOuE56KMClEhSviYxWRgi0h50vzyhMiUrDU,590 +torch/include/ATen/ops/_indices_copy_ops.h,sha256=iN4tp4W3CoFkPF7QRFu1cCX-JV98HonEMkoUBpY2p-A,1629 +torch/include/ATen/ops/_indices_native.h,sha256=ZT5L_lJB-QoLR4S4SCN20G5ffmPwaEkwZ-JRlVpbEXU,506 +torch/include/ATen/ops/_indices_ops.h,sha256=cnuJS-LTllPVMGvBKAhEkOd0rbWUhKVYg5GYwVEGNJA,998 +torch/include/ATen/ops/_int_mm.h,sha256=B9dO_hilKQyQlY2aGZHqFT_S3SfSEztH3Cvg0s7MGSY,1209 +torch/include/ATen/ops/_int_mm_cpu_dispatch.h,sha256=QmdZRZyqQ8XGV9qpI9V_H1qx_ixhIcnW3wNv3vCf8Po,980 +torch/include/ATen/ops/_int_mm_cuda_dispatch.h,sha256=zOT1118JHuvRKVw_JakKskYbMrb7fiQpvouHtilfARU,982 +torch/include/ATen/ops/_int_mm_native.h,sha256=pkJIi1USVovxVlyLT0MRJHnlPaz4tXU4ntx9xObooEs,832 +torch/include/ATen/ops/_int_mm_ops.h,sha256=IEX0T0ueecJrCBgE34OWTAhmdNCYzUwGe9dFpUi1B5A,1759 
+torch/include/ATen/ops/_is_all_true.h,sha256=8X5FbK_6nA2OyRVnlRl2X9RrTd5z05EyTtF_hBjYT_M,692 +torch/include/ATen/ops/_is_all_true_compositeexplicitautograd_dispatch.h,sha256=aCE9KJrr00D0RlxAJcNuAUYTYuvTO5IMwDdDldsylnk,793 +torch/include/ATen/ops/_is_all_true_native.h,sha256=lxE5FdL5mqa8HmhOwIdTGXJIZZDtryHIyA528Xj5HPQ,503 +torch/include/ATen/ops/_is_all_true_ops.h,sha256=QdbzMjzQxHxDw1EDdhmHHBfmOooMefnAovItD3mlUrU,1004 +torch/include/ATen/ops/_is_any_true.h,sha256=tWDz8mOl8O8dmtGgVD7rGaiTFzygtOoiQnjBZUWcX8E,692 +torch/include/ATen/ops/_is_any_true_compositeexplicitautograd_dispatch.h,sha256=2wXAEQoUCgreYJS3nDA1BK-d7GvpyAZRqktw8YJETew,793 +torch/include/ATen/ops/_is_any_true_native.h,sha256=t95xS9X-BLG_8YuPtx99z3XB6BrS46CDHiEQR3BaAjw,503 +torch/include/ATen/ops/_is_any_true_ops.h,sha256=zon5lOhk_xxc3IJXAz2AIkWpWaxdmh8FH0lU9r-4xWQ,1004 +torch/include/ATen/ops/_is_zerotensor.h,sha256=Cec15NdhIIIOY8STtMEsCG1GlpR4WkUtn9PkEWbcJxM,703 +torch/include/ATen/ops/_is_zerotensor_compositeimplicitautograd_dispatch.h,sha256=D2KamppiqN0bbwgHcRJQzkC4tkcuGB8VFrNqayrW4fw,789 +torch/include/ATen/ops/_is_zerotensor_native.h,sha256=n3uZEIL_uPRnVA2cB1-dZUfg2jyOnAdsDtJ_6IT7o3Q,499 +torch/include/ATen/ops/_is_zerotensor_ops.h,sha256=gJOFVWnm4ZQHHMEI1mBLGla1kvokLNTICA5sfWBOggs,990 +torch/include/ATen/ops/_jagged_to_padded_dense_forward.h,sha256=0kYXThkcMi8eYJdyPL4r9ChqlNkz_547Om4-64-kJQ8,2168 +torch/include/ATen/ops/_jagged_to_padded_dense_forward_cpu_dispatch.h,sha256=OU3S0YQdoi4vV_FrKrMA5Xazt3n_F2nUU7yVCg1C8XQ,1021 +torch/include/ATen/ops/_jagged_to_padded_dense_forward_cuda_dispatch.h,sha256=lmSdt_82SmY_QhLlldJIP8nC2bMgMlOSHBWVAk2D0ak,1023 +torch/include/ATen/ops/_jagged_to_padded_dense_forward_native.h,sha256=XAXbOlq_npj8ZyJpH2qCjgZrOcRMtdOID37uzd7up20,775 +torch/include/ATen/ops/_jagged_to_padded_dense_forward_ops.h,sha256=e9GAJTDzWJhKTdo3BrDKC_68YC8bOuldQuzBroA3ciE,1335 +torch/include/ATen/ops/_lazy_clone.h,sha256=mPpsMN2QQLP3XgFfJrB9L6Fl0DiYKPysma4orAEDcAM,688 +torch/include/ATen/ops/_lazy_clone_compositeexplicitautograd_dispatch.h,sha256=hwre64cEkpmMqNa5nu4LqjyrrhF3_88VLt2tl79Hivg,792 +torch/include/ATen/ops/_lazy_clone_native.h,sha256=MgU-xpT-DCtIMJ2Hoz8JEDFCIzn-GhfvE26lY5G84kU,502 +torch/include/ATen/ops/_lazy_clone_ops.h,sha256=rVb9pJVchP6rbhUI27HLpq6dzEJ3iEBPaILF62GDUvo,1001 +torch/include/ATen/ops/_linalg_check_errors.h,sha256=fPBH1Va7vUIVubBXGtYmClfiR5p0PkWu9S8Z7Ky-yOs,811 +torch/include/ATen/ops/_linalg_check_errors_compositeexplicitautograd_dispatch.h,sha256=8WTZZE8eaHXLe_eVhm20bNWgCdlnJ2V3PdQw9_iKyQQ,838 +torch/include/ATen/ops/_linalg_check_errors_native.h,sha256=XLabFA9FYS2vrFWt47cL7xbauUoqczdi1Kkx5z83CAM,548 +torch/include/ATen/ops/_linalg_check_errors_ops.h,sha256=9IzpQQONx3f6sYf9wwgufE0E9t8KczYaRPFYys5q3u4,1149 +torch/include/ATen/ops/_linalg_det.h,sha256=1G4UKYWHFv1YcLQGSsW_YB-2J1MbsA-nz5ANHHi5flw,1525 +torch/include/ATen/ops/_linalg_det_compositeexplicitautogradnonfunctional_dispatch.h,sha256=v778MbcSI9fJ1rZO7gvdVaah9StaViZn7QToWqjC0-U,851 +torch/include/ATen/ops/_linalg_det_cpu_dispatch.h,sha256=eP3n9q2P6D08uREBJNXI1CqTEzGZTaLuJ-xehnQ8NTA,1106 +torch/include/ATen/ops/_linalg_det_cuda_dispatch.h,sha256=MnGDd4QfcA1Z0E9XSBXUQQEbfFUH7q3JAxuPGDgOC8o,1108 +torch/include/ATen/ops/_linalg_det_meta.h,sha256=uPOtXAUZCc589h9TcL_KAHKeGDdX68aojy4POQfYCd8,596 +torch/include/ATen/ops/_linalg_det_meta_dispatch.h,sha256=DnQrAY-bMP7NF1xQcMW8X17juQk59gOfxASxssVtTB4,1108 +torch/include/ATen/ops/_linalg_det_native.h,sha256=Ep964GcQsF0vhp4wt_rtE-1Cm-M0QlPwrA94Z0dDywQ,684 
+torch/include/ATen/ops/_linalg_det_ops.h,sha256=IjtVLogZ6G7yeKq0RuBur3BWjw8EC4uni_TDsIoo5d8,2061 +torch/include/ATen/ops/_linalg_eigh.h,sha256=9FxTG6bajGOlP_8a2LSuVECLJhLrG9Q-0-2RaTzmmFM,1795 +torch/include/ATen/ops/_linalg_eigh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=eupA2xz1hkoBJslG7cT6OYNpdnBM5quRdcx1t5ca2hA,889 +torch/include/ATen/ops/_linalg_eigh_cpu_dispatch.h,sha256=qmt7dQ0gGEti2xStpK3dPHbp_VqDmQlWf_eOXfXMlRo,1195 +torch/include/ATen/ops/_linalg_eigh_cuda_dispatch.h,sha256=8H28c4AW5ejYjrKGMG0gKqRtPfGZHuIl8YSjQL7ohmM,1197 +torch/include/ATen/ops/_linalg_eigh_meta.h,sha256=m8jyidUj6TEuvltu60r1YlrNVfdgNljq6M5mXsN5558,636 +torch/include/ATen/ops/_linalg_eigh_meta_dispatch.h,sha256=K1o1NXh8Yhs5Dd2yH09NAjaBCDd62Bi5pystei-a_Gk,1197 +torch/include/ATen/ops/_linalg_eigh_native.h,sha256=lCycNGcrL5cdwbw_GBzSvtxh_3lZDWLutLAvuZBMAiI,714 +torch/include/ATen/ops/_linalg_eigh_ops.h,sha256=D_VDmngsEZIt9ePYdbKruJWy7QOtcE0_OQLTdJjdiJc,2254 +torch/include/ATen/ops/_linalg_eigvals.h,sha256=b-uzi-h23seD-nhd2Ha4_WrdmX4LYbj4XNhb4QQcjew,704 +torch/include/ATen/ops/_linalg_eigvals_cpu_dispatch.h,sha256=IWjrc2wi0gESuEoVmmeMd1J3kRMbicsQtw-wLSDI2cQ,752 +torch/include/ATen/ops/_linalg_eigvals_cuda_dispatch.h,sha256=umLEsRoynUZX3laByW6P2AP-uBSbnSb_6DkHRl3GH10,754 +torch/include/ATen/ops/_linalg_eigvals_native.h,sha256=5bx6_FOhr1QsF9405ks7iKmYutw43TS7zLR_7sSOfmM,506 +torch/include/ATen/ops/_linalg_eigvals_ops.h,sha256=LOqd5f_Y64Q8JPmxmCT9LsmfCz_VjgSWu6rQbPNEKVI,1013 +torch/include/ATen/ops/_linalg_slogdet.h,sha256=wJO35OW62QNA_QJhh7MMKffV2jc7RSUcQ8hH1Ndu0k0,1752 +torch/include/ATen/ops/_linalg_slogdet_compositeexplicitautogradnonfunctional_dispatch.h,sha256=q0C9mr-pNMim0-Mvfz4f2SkxA_5V-PJ0HmTcXPi_9G0,866 +torch/include/ATen/ops/_linalg_slogdet_cpu_dispatch.h,sha256=YsJlXeAje2rmUvR4Hx6MDOdY6qSlv8wRNjZR9ohJGNs,1199 +torch/include/ATen/ops/_linalg_slogdet_cuda_dispatch.h,sha256=Cm_bhXNM2EI2myTQbHLvu8g5_45e_oMDNUTE-ln6JK8,1201 +torch/include/ATen/ops/_linalg_slogdet_meta.h,sha256=WxpZi-L5UpdjQg-EaxuKRj4zna7aFqJMaD3eWQUI6is,600 +torch/include/ATen/ops/_linalg_slogdet_meta_dispatch.h,sha256=ZfyQH3zE-u8imhmL-DZ3avLf8dxj6sJAY_VlxlOzu6A,1201 +torch/include/ATen/ops/_linalg_slogdet_native.h,sha256=nAgHq4VHF3k9WucdQUai96mrTrFwszEkSGpXBqDIGjs,724 +torch/include/ATen/ops/_linalg_slogdet_ops.h,sha256=HddRwDVLwTZpGDPEzNnOC9v1P-JZ3cBTad2rEyewqYA,2265 +torch/include/ATen/ops/_linalg_solve_ex.h,sha256=IX3DqwUjRJp-wsx59hnzY_8NEsNNEC5je0hO04378qU,2146 +torch/include/ATen/ops/_linalg_solve_ex_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2zggqmMan2eWDFVIfA-t-DhpwmHvXaoswndkYOobQvM,930 +torch/include/ATen/ops/_linalg_solve_ex_cpu_dispatch.h,sha256=sIHBcAThiUEg7gP6zBL906C3ZH4PrfFqj_2y6uerEFM,1374 +torch/include/ATen/ops/_linalg_solve_ex_cuda_dispatch.h,sha256=E56xnub2pMUzqn0YIC5vsrUN_5xK4mUeu8lALUYnx_Q,1376 +torch/include/ATen/ops/_linalg_solve_ex_meta.h,sha256=IAKvU67d4clPXgvESbZqPXqqWUxN7TxmtDcGoo1wgJc,653 +torch/include/ATen/ops/_linalg_solve_ex_meta_dispatch.h,sha256=e9qZJm34llyzPzE8WqMQoot6jcN0-TAu_IZlcmELX0g,1376 +torch/include/ATen/ops/_linalg_solve_ex_native.h,sha256=5n4eV9uW7-75G3JBel7JV4F3LzIHsUr-MMBo1B_a6hg,776 +torch/include/ATen/ops/_linalg_solve_ex_ops.h,sha256=uP1G5Hux-v4e-5781d-ff7529kSevq02240A3aR1JC4,2639 +torch/include/ATen/ops/_linalg_svd.h,sha256=-T5aff6B1ylrCcGliRnOLyACBewjhX--zxc8U65grDw,2010 +torch/include/ATen/ops/_linalg_svd_compositeexplicitautogradnonfunctional_dispatch.h,sha256=4_ZAwwu7mPLS_zJ5LdBbCYvCeyeK-QraITkTL6777iY,956 
+torch/include/ATen/ops/_linalg_svd_cpu_dispatch.h,sha256=DiPd97pLm-K0a_ygi60-FiFH6W1wAvK86_ivkOwzfMA,1375 +torch/include/ATen/ops/_linalg_svd_cuda_dispatch.h,sha256=7PsyA0TndGbtuIw7vCkzPfre3WiuYjypM30tYq_ugrk,1377 +torch/include/ATen/ops/_linalg_svd_meta.h,sha256=CqtULdNJ02r7PdH1Fz4waIJRuFssk04jeRjCnHmZPc8,675 +torch/include/ATen/ops/_linalg_svd_meta_dispatch.h,sha256=ljdqyymNKy6arYccnKL9LW_uqKS6bXP3LqU3daPKqxg,1377 +torch/include/ATen/ops/_linalg_svd_native.h,sha256=wBSRG5pdYFljpQSAfEw-6P2Pz0k_VP4i9jYTwg-a8r8,753 +torch/include/ATen/ops/_linalg_svd_ops.h,sha256=6yiciJcbTPCydkvsJHW46fUzcOaLi23r4Cj72l0SKhE,2541 +torch/include/ATen/ops/_local_scalar_dense.h,sha256=9XK81NmAlF3XbpW4VNnVrkN8063syJlPjqlKo2fWonY,720 +torch/include/ATen/ops/_local_scalar_dense_cpu_dispatch.h,sha256=vtYVIkSmsvtrIJTj-9uthBqhfovb75bEJsfDuUcsGyo,756 +torch/include/ATen/ops/_local_scalar_dense_cuda_dispatch.h,sha256=1P_Rh6Yu2rQQO9FniII8OV1JlsGJUT_XI6l7PHgg6TM,758 +torch/include/ATen/ops/_local_scalar_dense_native.h,sha256=PZ1UjyGsg8YL7IT4ifKfVstgz2YxpRyrpijJ8bDbNcQ,587 +torch/include/ATen/ops/_local_scalar_dense_ops.h,sha256=A08mSrMp6QIqKVzahQ20UEWvIIPVVc9qLod76hbrheY,1025 +torch/include/ATen/ops/_log_softmax.h,sha256=d9C5VuwI_0giQUwvpji386R42_UftzrnYho6V3iGEGU,1373 +torch/include/ATen/ops/_log_softmax_backward_data.h,sha256=g2rgXFkyIn171Gc22iaeloCjqWxvd1qzxjIC8Kgaops,1756 +torch/include/ATen/ops/_log_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rxA__a7TuQ_qKnqKrxFLks6wM7zu8YmO3qqybQXmxKQ,908 +torch/include/ATen/ops/_log_softmax_backward_data_cpu_dispatch.h,sha256=xbe2Jylv1AKiKxr60b-G0eMNWCkva2Y-5xu-M5x9gPY,1187 +torch/include/ATen/ops/_log_softmax_backward_data_cuda_dispatch.h,sha256=TBG3vuXKBWFSDnMQzfkkUrM2-9TtiB8lnOWVH6ADiaQ,1189 +torch/include/ATen/ops/_log_softmax_backward_data_meta.h,sha256=oaULI-MNAZMnzpWDAG_1o4o-MeS1C7kPDPIbKN8kj_I,689 +torch/include/ATen/ops/_log_softmax_backward_data_meta_dispatch.h,sha256=hU-pB16_WO_o1EJq_sZXfm9qlIkW88B0zRVtEhhUkWs,1189 +torch/include/ATen/ops/_log_softmax_backward_data_native.h,sha256=i73Svb1S1etIUDA0GrgcM0uSmgzkO1sUp2Lrl7hjmcw,1010 +torch/include/ATen/ops/_log_softmax_backward_data_ops.h,sha256=pLBVEfqYUxrQ_1H_M92VQg1n25lPruHjF3KRW_wfXh8,2207 +torch/include/ATen/ops/_log_softmax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=s9ssEgID_tKKPZcxhTXSJ88zxoojQRbkg0Y1j8LTQrE,852 +torch/include/ATen/ops/_log_softmax_cpu_dispatch.h,sha256=M5MXCl5wxJgbTZOEWlc3FEFGRirInK81K562PtsaBIo,1019 +torch/include/ATen/ops/_log_softmax_cuda_dispatch.h,sha256=6dS8HcSBm-iA9-t_PkFg0ICQEbleChWtvz9cQPMITPg,1021 +torch/include/ATen/ops/_log_softmax_meta.h,sha256=aL3-R5tkBgSYBOFlr2-DjR5AgPCVMEklbw79pjoUfEo,633 +torch/include/ATen/ops/_log_softmax_meta_dispatch.h,sha256=OO-sdSJK6Zc-6G9uXAOofpx_wysKeE1JYgTTdwF129o,1021 +torch/include/ATen/ops/_log_softmax_native.h,sha256=Qs49vcEMUNuvbxOjsscGU7AJrLxRuLIW802Iy3WgXtA,866 +torch/include/ATen/ops/_log_softmax_ops.h,sha256=ip6JQ9onouac2aXbg6RqVbWBIv4Pwg-xIY1EHIJ5V40,1843 +torch/include/ATen/ops/_logcumsumexp.h,sha256=eGHxy2zOgzEHJ-6OpqiJqtEMXPLJNLtrZCBY2NEW0Ww,1218 +torch/include/ATen/ops/_logcumsumexp_cpu_dispatch.h,sha256=TuIP17NHGjgPYyI3g3cwVSj9y0FVLZdhi_pdUup4Aqs,962 +torch/include/ATen/ops/_logcumsumexp_cuda_dispatch.h,sha256=Q7PEnQ84AEGuMIa2EWxzhDtasN2xWzeq1eNc3fGi7-g,964 +torch/include/ATen/ops/_logcumsumexp_native.h,sha256=nhA1Nve3HbU5k7zyPgAYKlDnPQurOeqpxVB9_zOlRx0,808 +torch/include/ATen/ops/_logcumsumexp_ops.h,sha256=QUP0yRRZbedDTTO8xTq3NyIkJzlbOqLPLwk10smqlyE,1717 
+torch/include/ATen/ops/_lstm_mps.h,sha256=oYKTbGG-qaYbkHEBKM14b8j2re9SLrs4D9ttebyVblM,2984 +torch/include/ATen/ops/_lstm_mps_compositeexplicitautograd_dispatch.h,sha256=lvgku05CmwOteLBKp5C4agsKmFYtq51Fl-UshGojuGI,1539 +torch/include/ATen/ops/_lstm_mps_native.h,sha256=G3ZWR1-7pRCnahhsFKzdvcnDkhIgCbWoxVbAbLvKWR8,845 +torch/include/ATen/ops/_lstm_mps_ops.h,sha256=0Pu7ypGB15cyoTBJ2Ia5ab-HcTeBUwjUdk31UK0Kbz8,3489 +torch/include/ATen/ops/_lu_with_info.h,sha256=gOHdZhvcq5JlEDItmHWg4CyZuQWUyB3qgb0Rx-QpHKk,868 +torch/include/ATen/ops/_lu_with_info_compositeimplicitautograd_dispatch.h,sha256=Kw5vnvIqc3pxCcSaf96VWcZVAW_RHvh7Za7xskd8QsI,871 +torch/include/ATen/ops/_lu_with_info_native.h,sha256=fDeDbK4ZXs_2m_hXQEIBHtQxpGFk_Nmp_BYvMLGlFbk,581 +torch/include/ATen/ops/_lu_with_info_ops.h,sha256=XyZzMsBNTJsQK3CrErfTEVbILWJan4ilL2Fyif5IXo8,1263 +torch/include/ATen/ops/_make_dep_token.h,sha256=zo_xsNyTkYjrj1dPtDNdpzGfD6imycoCE4s1vWrrWYs,1578 +torch/include/ATen/ops/_make_dep_token_cpu_dispatch.h,sha256=KVekXkpOiJ9o7I8mtF9OT_lUf-Q6u5YUQM3ys-t0VZk,1054 +torch/include/ATen/ops/_make_dep_token_native.h,sha256=qlDEAg2NFIY-Kr8xRgRT6clnBPX5KbXKAX8x7L8Q7rA,706 +torch/include/ATen/ops/_make_dep_token_ops.h,sha256=c3qOPQwW_anaeDZkZGpF_-ieJdStwnEbuYAjmJB85yc,1593 +torch/include/ATen/ops/_make_dual.h,sha256=OiqCFf8rP1IQiW7081zFia8R_NOJ7tKQkhNSTPFyXAs,782 +torch/include/ATen/ops/_make_dual_compositeexplicitautograd_dispatch.h,sha256=JDSpA_RnBmOZ1x9q-TJigJbSqy0qLMIkjk5QVkHRSCA,836 +torch/include/ATen/ops/_make_dual_copy.h,sha256=4NqoO7ysBMbquDmmkoNw0JhAXr7VSN5MyQb8ZNC6hX0,1433 +torch/include/ATen/ops/_make_dual_copy_compositeexplicitautograd_dispatch.h,sha256=O2XneTBuzyWuKedlTYiG3kl9tsCde5_lhNAqu3rqGjk,999 +torch/include/ATen/ops/_make_dual_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=NDQkqqw29plG1pZkyEtos4_9PgWACJvxR8vQ6gPSr6s,867 +torch/include/ATen/ops/_make_dual_copy_native.h,sha256=IjzW0Zk3_M-bpGKIFNLyDbH31zS-RIEoDCOimY-ad8c,684 +torch/include/ATen/ops/_make_dual_copy_ops.h,sha256=2GizPrxzYFtsBtEBAmZI5XDDRlrjO1OJcZHfXY10lzI,1937 +torch/include/ATen/ops/_make_dual_native.h,sha256=yRc1eJ-kvgiqt1wiZE7SXkTLD_G8RLknWk7wU8wu-qg,546 +torch/include/ATen/ops/_make_dual_ops.h,sha256=n4slq0ZhPGw3ZVb0ZevRwLpjVRJmC1DbS5bPUR5ofew,1152 +torch/include/ATen/ops/_make_per_channel_quantized_tensor.h,sha256=_lxwl7iKmjk9Sc0Q8tZ61jK-A6OmyMwpcr5yT4FQiqQ,1764 +torch/include/ATen/ops/_make_per_channel_quantized_tensor_compositeexplicitautograd_dispatch.h,sha256=ah3xUy0cqI7NsoZdk4rGVkIB9csBx0rZXIZPwg7bsos,1089 +torch/include/ATen/ops/_make_per_channel_quantized_tensor_cpu_dispatch.h,sha256=t9cMO48VhxL9HpbtirD1p4D5URaK-IIOTx2DFT8oXGA,842 +torch/include/ATen/ops/_make_per_channel_quantized_tensor_cuda_dispatch.h,sha256=UMJBcPUfHNJSp6ZwJ618_s9pB1GJjpjSrlCYzYRNxwc,844 +torch/include/ATen/ops/_make_per_channel_quantized_tensor_native.h,sha256=8O1wSwMdl2tPY76qBOvpYyYiQxtjkgUMkWGjUYMqX10,935 +torch/include/ATen/ops/_make_per_channel_quantized_tensor_ops.h,sha256=jwqg0Plq-wdg5baKLNR9ut8FtIkGnObydkmAZGFNuzs,2223 +torch/include/ATen/ops/_make_per_tensor_quantized_tensor.h,sha256=h1qt46l-WyQySwXCDbszkMXKFBGz2CyYj-EqT6A9ihw,1583 +torch/include/ATen/ops/_make_per_tensor_quantized_tensor_compositeexplicitautograd_dispatch.h,sha256=YZAeixFucRyu7vS8OqmGOpbzER8mAyZ9EdDLt6bydE0,1013 +torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cpu_dispatch.h,sha256=ijfzgJc0a15cNxJBY5_quTbd67nuqocabIKqvGYwo3U,804 +torch/include/ATen/ops/_make_per_tensor_quantized_tensor_cuda_dispatch.h,sha256=zc0TzOts9K4Q1vzgj07300_tzRYzV2t18CPIsjl_qQw,806 
+torch/include/ATen/ops/_make_per_tensor_quantized_tensor_native.h,sha256=hkHlwL-b9WZU75xt1C8SevJ3XO9XwKowKSKd4ARAfUU,821 +torch/include/ATen/ops/_make_per_tensor_quantized_tensor_ops.h,sha256=8QEeyKZ1Oqs5yMQNVdsi-MgeNe8iDUQpaVtDBKbOlgc,1977 +torch/include/ATen/ops/_masked_scale.h,sha256=ApooN8Iq4P8ykmOJk_zHpkrC8m2OEvp-aRCzcxr5j-Q,1371 +torch/include/ATen/ops/_masked_scale_compositeexplicitautograd_dispatch.h,sha256=eQr0qC-oQ24nBF0hUr7b1jhYa8R1CjBnel_-7y2Hnps,983 +torch/include/ATen/ops/_masked_scale_cuda_dispatch.h,sha256=7j-6_7Kd8TdDECKH4Ui1Q-I6lGtLsLydA2B1W0Oq864,791 +torch/include/ATen/ops/_masked_scale_native.h,sha256=cqoVEydkpP8D7K0bY3q11owBth9xbQWWoGo1YIMicdw,672 +torch/include/ATen/ops/_masked_scale_ops.h,sha256=roelW6k-9pqoDHnvrZySTh3UIQjVVsR2657xWuMcpw0,1893 +torch/include/ATen/ops/_masked_softmax.h,sha256=Nkn2rteLsY-gdZG2b31dH26w9Eh9jVngmf4wJNakQnM,1703 +torch/include/ATen/ops/_masked_softmax_backward.h,sha256=n7vHudLMVs0xNQjFrJfrpwwxvGpZ0suElSL6jvQZPuI,1772 +torch/include/ATen/ops/_masked_softmax_backward_compositeexplicitautograd_dispatch.h,sha256=qxfBnBqoJjGVh5zzmYogK1vbL8ffh9TRHXEvJ5UYaGs,1120 +torch/include/ATen/ops/_masked_softmax_backward_cpu_dispatch.h,sha256=meOF9rozMD8O09zBAubvyKhqnQI7atJs44q2EWHSKIo,865 +torch/include/ATen/ops/_masked_softmax_backward_cuda_dispatch.h,sha256=iSoFBJlMkIVe7zaMyZlixh5rzZfHF4DgSoGVhLUjJeI,867 +torch/include/ATen/ops/_masked_softmax_backward_native.h,sha256=nEZXTsE4iC1A1PHXCUxufaymySyHsyHvYjitFGkyAP0,989 +torch/include/ATen/ops/_masked_softmax_backward_ops.h,sha256=qLZFp0t35StlswgYqj4sANpmI0zvZpHTaOzKk8X13bE,2283 +torch/include/ATen/ops/_masked_softmax_compositeexplicitautograd_dispatch.h,sha256=jMjAqcC_5vHXOYO2VIV64cHluSUSursjur8OUfB6ZhA,1121 +torch/include/ATen/ops/_masked_softmax_cpu_dispatch.h,sha256=krlJzwLB81M2k6DcKF9m4a_OhVgsq2iutGdyfA7OEBI,873 +torch/include/ATen/ops/_masked_softmax_cuda_dispatch.h,sha256=uMe9To1YuLiqhVkHUySk4YJ1RApcFvDs3yFLn-QouqY,875 +torch/include/ATen/ops/_masked_softmax_native.h,sha256=8LBllapMOAg07a9MLxT_BnZXryS3t_WfqeI4GQe1UcI,998 +torch/include/ATen/ops/_masked_softmax_ops.h,sha256=wm6sBL76fkc6GeAEaf8MwnGwuTV-l84G-4-TJWLjApc,2247 +torch/include/ATen/ops/_mixed_dtypes_linear.h,sha256=EpxqcuqIBhbAay2mvZSZz-6alDWd01C8Uwqs0qvfiBE,992 +torch/include/ATen/ops/_mixed_dtypes_linear_cuda_dispatch.h,sha256=zYONspGM-NfHtZPLcpOt8-3K36pEBZn4aW99ymJKEKc,919 +torch/include/ATen/ops/_mixed_dtypes_linear_native.h,sha256=fTP4CoBaCuNvdo8LMwlmxYuXyCfAjReuzQ2yqa33tzo,671 +torch/include/ATen/ops/_mixed_dtypes_linear_ops.h,sha256=SbnGaRDVKEbOi_LB4zH4oLMw-vkNWF2UmzjI0pl3QVg,1498 +torch/include/ATen/ops/_mkldnn_reshape.h,sha256=LyWvUO-fg5muSNr3aKpu8cJnSbaqpd_gJ2RNTQyGvmE,1286 +torch/include/ATen/ops/_mkldnn_reshape_compositeexplicitautograd_dispatch.h,sha256=52DACsy1eGQHHje7h8DnhX20VrqUAScp--c28SQPMZQ,955 +torch/include/ATen/ops/_mkldnn_reshape_native.h,sha256=abWpNx-5LZdSlyFHp9Rex6UbRtoTYcGFRpRZMe1bMGM,639 +torch/include/ATen/ops/_mkldnn_reshape_ops.h,sha256=bQNczA4YmMg5cq_mOiXhyELof5yn4hws5JSuYvkWM7k,1793 +torch/include/ATen/ops/_mkldnn_transpose.h,sha256=cJ3hLPtXdLOiiX-h0C1p_0E_2yHQpHbfPUpSEPEJWgc,1597 +torch/include/ATen/ops/_mkldnn_transpose_compositeexplicitautograd_dispatch.h,sha256=WbToP4XWkjksIjcFAmyyOt72b4VMLalbvs4GZ_nZ-EY,969 +torch/include/ATen/ops/_mkldnn_transpose_meta_dispatch.h,sha256=FTFUY6y2pMqRr6vVYAR9t5naV9mVQjFcIKIRcFovlFE,781 +torch/include/ATen/ops/_mkldnn_transpose_native.h,sha256=weiDSnoAshFxoaUgJKDTcmLmyWPzSVpo1yJ6U-PY8U0,743 
+torch/include/ATen/ops/_mkldnn_transpose_ops.h,sha256=ZDi7qT5wPrNkTvmy4vEdAEVre8gXbyQwgvtc6ATrbQ8,2478 +torch/include/ATen/ops/_mps_convolution.h,sha256=zMbjvNe1fX8Cp6-6NQ_LnJGBRMejbHNHBgLoMQhwFHo,6947 +torch/include/ATen/ops/_mps_convolution_compositeexplicitautograd_dispatch.h,sha256=94IVZyvJ6z354Z3YnNLjiBOw_WshmuDLslUsDfLeO7U,1776 +torch/include/ATen/ops/_mps_convolution_native.h,sha256=CMmyf09p8wG2tN95LLHti0U1IYLOpOApV5xGpgNnzUM,714 +torch/include/ATen/ops/_mps_convolution_ops.h,sha256=sfXmRhqcp82O19AbXqTSXGHXbN86h60-A_WS44Hu5WU,2813 +torch/include/ATen/ops/_mps_convolution_transpose.h,sha256=SMH2DcfAeY5VZIrxQ9qnwQUUG_fGfOZGIJL_5jeFW40,7503 +torch/include/ATen/ops/_mps_convolution_transpose_compositeexplicitautograd_dispatch.h,sha256=ikv_w0YLq_nkYMpBQFCPGpTzskZqiDguhfDCAYrhnyw,1784 +torch/include/ATen/ops/_mps_convolution_transpose_native.h,sha256=0drNEbABVN7U-sRzjw6of9yH5x5_gk6UAFF_j1YPosM,718 +torch/include/ATen/ops/_mps_convolution_transpose_ops.h,sha256=dmAn-ywNPd4usVI_-Uy6Vk3L6pQGxPBgKos5lxlxl7Q,2839 +torch/include/ATen/ops/_native_batch_norm_legit.h,sha256=wVEtMIqGRCr3p-iu2rVg58-XBjkYzfKSM_EDcrJ5v5A,5434 +torch/include/ATen/ops/_native_batch_norm_legit_compositeexplicitautograd_dispatch.h,sha256=nBYr2bDPYmTf-jKM9d33aagdqClGBgm0XDLucowx66Y,1070 +torch/include/ATen/ops/_native_batch_norm_legit_cpu_dispatch.h,sha256=aaYU-GDibNrJ6Tc_gnsmfjFWg34EeQVDaH3KNbZ9sPs,2601 +torch/include/ATen/ops/_native_batch_norm_legit_cuda_dispatch.h,sha256=44Dp-XTMzNV8ihKRanQm0eni36N0UpYm0H3lZeHw7wk,2603 +torch/include/ATen/ops/_native_batch_norm_legit_native.h,sha256=JaRqkSsI1_AssTNhqaVl4TXPCAHVRqQEzCdzkmYnl9M,3782 +torch/include/ATen/ops/_native_batch_norm_legit_no_training.h,sha256=Rr-wmFRle82oNcKzf2E41XR862Awvx8m0Fj2aTk_Mo8,2738 +torch/include/ATen/ops/_native_batch_norm_legit_no_training_compositeexplicitautograd_dispatch.h,sha256=pBxUy9hfCX9P2cCrVBaSSv7_O_ABOfhnIN4WcRqE_Yk,1773 +torch/include/ATen/ops/_native_batch_norm_legit_no_training_native.h,sha256=EYsfShJgUlniBkulSNhpSvhRwhsBtUViGNk2ZazP-6Q,1106 +torch/include/ATen/ops/_native_batch_norm_legit_no_training_ops.h,sha256=tJc4hDF4gcMqcd_dpNYCgPRRwmmtc62NqWic2DNwR-c,3362 +torch/include/ATen/ops/_native_batch_norm_legit_ops.h,sha256=8Jwt8sY0oytyOw2_qp4NNjH55WEAeaHfXVyc2SjRnBA,7431 +torch/include/ATen/ops/_native_multi_head_attention.h,sha256=_smKLnH5-qS1zdzRIKQC3vFmUS6w03E2j0VWX-_Y9UE,3629 +torch/include/ATen/ops/_native_multi_head_attention_compositeexplicitautograd_dispatch.h,sha256=L67zUaBQ4HhsL-Nim4yRD7sydXF9QhoUJiyYF0d7Lyw,1725 +torch/include/ATen/ops/_native_multi_head_attention_cpu_dispatch.h,sha256=PiABSMIgV_9KrbB66HXPK7MsQ12_Pp0Gd13IAZExBiE,1152 +torch/include/ATen/ops/_native_multi_head_attention_cuda_dispatch.h,sha256=oS49Oq8rplvcq7vbP8l-jTLWFY8fk66Df1rRihWgaGk,1154 +torch/include/ATen/ops/_native_multi_head_attention_native.h,sha256=6kDSe4jh_aVD96lzO7dFvbevgQiavCPnVMOK3q_9gjE,1859 +torch/include/ATen/ops/_native_multi_head_attention_ops.h,sha256=l2SWh9JT5H2ozLXiZ3TguLQPX6f-u6zNlYwoVzWnlPg,4185 +torch/include/ATen/ops/_neg_view.h,sha256=hEru4NVC8I3YcCVbS1PzljWvhreCi9XrDZa2MJ1EQdU,686 +torch/include/ATen/ops/_neg_view_compositeexplicitautograd_dispatch.h,sha256=kwty3z4v1SjdpIRnMtWPYDzp2gzDyVSc4UzRIMQdfH8,790 +torch/include/ATen/ops/_neg_view_copy.h,sha256=4zwdf37P7ws-NcNDUAg5KBggutavzuVWf0d5fwvmdAA,1147 +torch/include/ATen/ops/_neg_view_copy_compositeexplicitautograd_dispatch.h,sha256=jomBXUVPb0z8GAhWSc4nArDdJd0_eFC3U07YrLOmcvk,907 
+torch/include/ATen/ops/_neg_view_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=UznGlAPSw9lMzzStVW5qZV7LFry7EWFlzd6G7b7MTtU,821 +torch/include/ATen/ops/_neg_view_copy_native.h,sha256=g0xIWtb2HwA5WrRTAF5naf5QFLTpjTX4iF01i8h_FDM,592 +torch/include/ATen/ops/_neg_view_copy_ops.h,sha256=eprKaMC0aVZs5IAJdsm7LFykNspPnDzkcO6vcIRXjH4,1635 +torch/include/ATen/ops/_neg_view_native.h,sha256=GBnnFNJ2QbKx35YopGMcKtjW1QLETmRRSVZVkDxNoWw,500 +torch/include/ATen/ops/_neg_view_ops.h,sha256=rFFt1URSsEYhNB-qS_pTAGihOx9ftGdMvn82VYC_rZ8,1001 +torch/include/ATen/ops/_nested_compute_contiguous_strides_offsets.h,sha256=bZUfN_hzjz8HWx4UEjFVDuwi9wJQZzGpPl1GKngKC2s,868 +torch/include/ATen/ops/_nested_compute_contiguous_strides_offsets_cpu_dispatch.h,sha256=_KySeSgwsNNZQ_ypjYrbZBqWnMZVLkCH6WUKT5r0arU,811 +torch/include/ATen/ops/_nested_compute_contiguous_strides_offsets_cuda_dispatch.h,sha256=m5s8hp4hACTYALs91d3e1MYtvSOAqb52xZC8hsXpHQU,813 +torch/include/ATen/ops/_nested_compute_contiguous_strides_offsets_native.h,sha256=6J9qHh0AGCm3ZAEZpm-VMZzchH9bQXJjnrmJlI1LIXU,565 +torch/include/ATen/ops/_nested_compute_contiguous_strides_offsets_ops.h,sha256=L8jw9HHlIrdXGOhfL58m9c0_7k9PQ4B-iEawuy5OP8I,1200 +torch/include/ATen/ops/_nested_from_padded.h,sha256=RZH-oGshCrd-XfBIsWbAuk4P_pZCaw9XNullhe92h4g,1776 +torch/include/ATen/ops/_nested_from_padded_and_nested_example.h,sha256=buXo4XN4XatRvZk89WfQDJUXjo_uP9whHBMXeagmCoc,1591 +torch/include/ATen/ops/_nested_from_padded_and_nested_example_compositeexplicitautograd_dispatch.h,sha256=TpWNmI32Ar7n24cDyYicXAW_3jVB0vk-BTQU1awmSgQ,1021 +torch/include/ATen/ops/_nested_from_padded_and_nested_example_native.h,sha256=ojBfPDfnVh0cjIqUTGeJ1qc33VpZTdezDcB1I3YZe5Y,711 +torch/include/ATen/ops/_nested_from_padded_and_nested_example_ops.h,sha256=92ZUuP1KE61tVCQuidNt9ZQG86OGI0o8HdxyKjc3Gec,1993 +torch/include/ATen/ops/_nested_from_padded_compositeexplicitautograd_dispatch.h,sha256=L6PTErALdoo1Qvyik4NpxMCEYHkzwriLBeHRmjwgeqQ,1069 +torch/include/ATen/ops/_nested_from_padded_cpu_dispatch.h,sha256=EgSFaqhSNQCH37Y5pCtAcmiYINKN8Erxfd1BT-ThWCE,835 +torch/include/ATen/ops/_nested_from_padded_cuda_dispatch.h,sha256=ys3d4Y0dOvzbl-j5nevsAutDFsd1Z7WHxuIyO8pbmUU,837 +torch/include/ATen/ops/_nested_from_padded_native.h,sha256=9WKF_b7TJN4UEJDdhA3NAtFMzXwt86LFnNoHYR4aLF0,912 +torch/include/ATen/ops/_nested_from_padded_ops.h,sha256=UExoLF2DqavyxloEdPfwQutFiVRU66oDjHRdEFGoMD8,2143 +torch/include/ATen/ops/_nested_from_padded_tensor.h,sha256=Ek6fcgzFcZwkpmOUwlIyzgc5hf9KALq2pJGt5fKUnkA,2985 +torch/include/ATen/ops/_nested_from_padded_tensor_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_from_padded_tensor_ops.h,sha256=OaZEVX5F4AgQeSc4cqtVNWkTDL5jGyruo-r-jmlx2Dg,1735 +torch/include/ATen/ops/_nested_get_jagged_dummy.h,sha256=YzNpXJtpGz11B_T_YiyiLZuQCKTdTFDwW7r5nBCjqPM,737 +torch/include/ATen/ops/_nested_get_jagged_dummy_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_jagged_dummy_ops.h,sha256=wYO5GusVuwBzYK-RiB7dOXat4VltfiEaIMJpzomgPz4,1037 +torch/include/ATen/ops/_nested_get_lengths.h,sha256=OOKbHLQrWA62TCY7kPgzh12T7rOgOgbq0A8IUaRDqfQ,720 +torch/include/ATen/ops/_nested_get_lengths_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_lengths_ops.h,sha256=tnE36qE4Lm_o1HMqHkSZZxtPNgImbfUhUvpFyrAJhmw,1025 +torch/include/ATen/ops/_nested_get_max_seqlen.h,sha256=QxxritnVahEIfKMFQTGqFhX0UwTcBwuNtSKuzkqmhAQ,732 
+torch/include/ATen/ops/_nested_get_max_seqlen_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_max_seqlen_ops.h,sha256=Jauo0ciXnpvWOj2KQ0QdGusx4yZRkWchSU8p4SgbEfU,1034 +torch/include/ATen/ops/_nested_get_min_seqlen.h,sha256=Re81pF8WcaoV0EmB8kkNjRveOVF_1KqC_G-pPJuCDhQ,732 +torch/include/ATen/ops/_nested_get_min_seqlen_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_min_seqlen_ops.h,sha256=C2ssBoXHv-CHEaOBgi_pYtVtWu3R3VzkGrJg4lDuz24,1034 +torch/include/ATen/ops/_nested_get_offsets.h,sha256=I6RKqH0VNq4dncEl8S0M9v-w1lnDFUoay4xh-cEQOtY,720 +torch/include/ATen/ops/_nested_get_offsets_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_offsets_ops.h,sha256=_lf8flui8asax8sddvg5FulB-gemAN-LiCjrUrI-1vk,1025 +torch/include/ATen/ops/_nested_get_ragged_idx.h,sha256=XezHDgPcC6FBByHBS6rjpNRUCy28jHCBwktsAcHk65A,726 +torch/include/ATen/ops/_nested_get_ragged_idx_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_ragged_idx_ops.h,sha256=AseA9IsQJkVMMhFT5fu5k3m73cEz99eoeXm55RnEnc4,1022 +torch/include/ATen/ops/_nested_get_values.h,sha256=Gk5v7ra8ow6V3DpqiSrSHq8ZrrbNC4usVEKgVcHzYtg,722 +torch/include/ATen/ops/_nested_get_values_copy.h,sha256=VDrZcQf6QUC0RDvIEkthCyusKXQEGC9xI_jVpQqhagI,1237 +torch/include/ATen/ops/_nested_get_values_copy_compositeexplicitautograd_dispatch.h,sha256=Ah0MA6M3t5eXBD6YJl5L1dOuQGukr49IJ5uyd7y6X84,925 +torch/include/ATen/ops/_nested_get_values_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=h7LRDig_84xdSHi0yIlxDxa4AfZ9gP2APHrVE-3We_8,830 +torch/include/ATen/ops/_nested_get_values_copy_native.h,sha256=0PUZgrtQbREfOfJJhsdmmtudZiUj7M_ZXOvvAmPsUz4,610 +torch/include/ATen/ops/_nested_get_values_copy_ops.h,sha256=-HYHS2PKcrl-gnLtnuFQwiOupKXa6wQLMRV2gz3XY6k,1689 +torch/include/ATen/ops/_nested_get_values_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_get_values_ops.h,sha256=1IzrfW4b48_tQ7Evis1WwC50IbQXzSm7hZL6KT01eic,1028 +torch/include/ATen/ops/_nested_select_backward.h,sha256=MKaH3TophnduwJcmv0albQWy_HTYDiVvMIujCEvAhgU,1840 +torch/include/ATen/ops/_nested_select_backward_native.h,sha256=YG2ozKZ-XpskFHhrsAzpHhjf19v3-MNTXKnlw-TgNRs,585 +torch/include/ATen/ops/_nested_select_backward_ops.h,sha256=SVXFvStN-GjAA0_Mp5EDPPbHP6xIC7iVymeEFEpBvmQ,1250 +torch/include/ATen/ops/_nested_sum_backward.h,sha256=k4GSDeSP7oSaP06Gnl827LhdhwIyHamPplvrxcwErJQ,864 +torch/include/ATen/ops/_nested_sum_backward_native.h,sha256=I9ect803gsDWd1ih8b6MIOOcCbv8A43PZ0QWhRk28xQ,589 +torch/include/ATen/ops/_nested_sum_backward_ops.h,sha256=vZCKpGjS6FUwM2SPrOIhIDv3grFAwcN5q6-k-kKVL28,1261 +torch/include/ATen/ops/_nested_tensor_from_mask.h,sha256=uIuhacS7xTh2PpwLNvAbSuV9i31YzZLCiouTabX18mo,1515 +torch/include/ATen/ops/_nested_tensor_from_mask_compositeexplicitautograd_dispatch.h,sha256=GTYv9FWLIa_bH-gr6eKmIvEsMbeuNVD5n4qZzQOoIas,1010 +torch/include/ATen/ops/_nested_tensor_from_mask_cpu_dispatch.h,sha256=2wwMOpHhk91psWlWp1ZhsGgOoQKaQBxyg1s374f6uPc,805 +torch/include/ATen/ops/_nested_tensor_from_mask_cuda_dispatch.h,sha256=jdXhCShdWJPv5qzeVhEgvQo2Fj6NEKp4OdqnlcSzf9M,807 +torch/include/ATen/ops/_nested_tensor_from_mask_left_aligned.h,sha256=fddY03-CA5a0he4c3UmKoSE0j7-cw3sdhXccd91Dtls,819 +torch/include/ATen/ops/_nested_tensor_from_mask_left_aligned_cpu_dispatch.h,sha256=F_uhE05AXvkSYmjXvRnuEBRtfshPS_4pkWii7ABV888,790 
+torch/include/ATen/ops/_nested_tensor_from_mask_left_aligned_cuda_dispatch.h,sha256=MaS0Cmc0Hl4SU3klyyn2ZcnWFr-FEnQpzWp0jiXrYfA,792 +torch/include/ATen/ops/_nested_tensor_from_mask_left_aligned_native.h,sha256=VEZRRxSLy-oLkpMKNUT8Ofz5kkulZCAdE0W4RXzU4AA,556 +torch/include/ATen/ops/_nested_tensor_from_mask_left_aligned_ops.h,sha256=67hhbe8cutUx5SPPWAxunoUo9QlghttYXXR6fzAEhxM,1133 +torch/include/ATen/ops/_nested_tensor_from_mask_native.h,sha256=B1j2zhZ2SL8xUJPvVhayU4YHQn8ripe5NBoFAWtL08U,707 +torch/include/ATen/ops/_nested_tensor_from_mask_ops.h,sha256=r9a6i1TmZRTQgbnF4lXIFIpsojvfSh7tKkXo6KntSik,1967 +torch/include/ATen/ops/_nested_tensor_from_tensor_list.h,sha256=771X-SZK234i2Br2ARh1oN1o4ybxrFd6b-gLlXDcfsE,2238 +torch/include/ATen/ops/_nested_tensor_from_tensor_list_compositeexplicitautograd_dispatch.h,sha256=fnKTYBS3WFN-CZxQ5C_4SvdQiuI7d7tBGiOCxAAvHQA,1564 +torch/include/ATen/ops/_nested_tensor_from_tensor_list_native.h,sha256=yG19vs7Mg_SJJjTZfjCF9vvI1PPzBNm7IZ6q7RlpsQU,968 +torch/include/ATen/ops/_nested_tensor_from_tensor_list_ops.h,sha256=WS5ze2nUNlct_kXD1X9dqlrt7DXmS5GlXYOK-KMOltM,2703 +torch/include/ATen/ops/_nested_tensor_size.h,sha256=5wPyLrhy39Y1yQWgYnVj_2p0LIaHKJXdvpMmDWGeeyY,1018 +torch/include/ATen/ops/_nested_tensor_size_compositeexplicitautograd_dispatch.h,sha256=zAky5LEOIM8_ddtNiIb1IYm2eSqGssMpXpnqa64f8uM,917 +torch/include/ATen/ops/_nested_tensor_size_native.h,sha256=eYnnU8n6W12nxfFwtWQWLtiiGCjVYl1xMv0w3cjRmzM,602 +torch/include/ATen/ops/_nested_tensor_size_ops.h,sha256=q_XxBdfbR3zikrcEtiv7bhGN1P6glnK2ohy5u2zXZBs,1665 +torch/include/ATen/ops/_nested_tensor_softmax_with_shape.h,sha256=XGQEG43jFVb60iqilNoOFW-cfhCUJ3aphrS3PNzFyKc,823 +torch/include/ATen/ops/_nested_tensor_softmax_with_shape_native.h,sha256=MwRIEnF3v8c3tveGnzMGKiqPR8h6bFXZlWtGv-0NRuk,653 +torch/include/ATen/ops/_nested_tensor_softmax_with_shape_ops.h,sha256=N8T9MXTQbGygHrvejcb7WljMl-b7P_BBkeAkRg__QE8,1153 +torch/include/ATen/ops/_nested_tensor_storage_offsets.h,sha256=kmHMwp4dVIb5DNIoOMBdmqvOnCxK1FcXgGOXMcCVJ7s,1095 +torch/include/ATen/ops/_nested_tensor_storage_offsets_compositeexplicitautograd_dispatch.h,sha256=pRkrZ_2iA6la4hC7XYnvOxMA5rnt-esVIbV5JheEAss,939 +torch/include/ATen/ops/_nested_tensor_storage_offsets_native.h,sha256=Z5MbixSw-iIaKc5W6cn1aljgE32XxNM0NhLHPgXwem4,624 +torch/include/ATen/ops/_nested_tensor_storage_offsets_ops.h,sha256=gt53L0pXv-wdQ1TFZkDTd_2OqQzTpvce_QPyAaRvHNM,1731 +torch/include/ATen/ops/_nested_tensor_strides.h,sha256=deNvHATZHD8qM8HzPbZ9ToCivkYVYV3o44tVWe-YNYQ,1039 +torch/include/ATen/ops/_nested_tensor_strides_compositeexplicitautograd_dispatch.h,sha256=VhDGwHqNFsFii5TMkiMGBLd-kvKNyZiz3BR8EQb16Jg,923 +torch/include/ATen/ops/_nested_tensor_strides_native.h,sha256=pRbmArmSeYDd2ErTMFOaAmavWw4ZRMqpDRnEobxpiTA,608 +torch/include/ATen/ops/_nested_tensor_strides_ops.h,sha256=apAHWiahV4fhm0JKK532YDoDfmDFKl3Q6qR_j9D4_Bg,1683 +torch/include/ATen/ops/_nested_view_from_buffer.h,sha256=d-ACvfBPEdqoiqla4EGg3na1eN4QKaCwswqRiUY5uQ4,938 +torch/include/ATen/ops/_nested_view_from_buffer_copy.h,sha256=Xd5VTv4A8fzt_bh2vxcaz1ylCLNuvyy-Pfy-Tc6JIfY,1873 +torch/include/ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautograd_dispatch.h,sha256=7P0dBD3FwwTx5LucTcSVoa23CSgQHehoq9BCj6pPtIw,1127 +torch/include/ATen/ops/_nested_view_from_buffer_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=n-8omkz2w9bH36RBWA4kokhIw8790TojPjQ930766-I,931 +torch/include/ATen/ops/_nested_view_from_buffer_copy_native.h,sha256=1p6Aq6j5e-ad1Yz5jtZGuC4gKfKUyyypp-b_b7-oc8Q,812 
+torch/include/ATen/ops/_nested_view_from_buffer_copy_ops.h,sha256=qXcLm0cuuk9O9YInbxq3DI6BP8KE8bBsHn_NYXFX6QQ,2343 +torch/include/ATen/ops/_nested_view_from_buffer_cpu_dispatch.h,sha256=0NUXMROIUdFslWBe3M6UnjDUYCi-tMiAx9NXLj3N4kU,856 +torch/include/ATen/ops/_nested_view_from_buffer_cuda_dispatch.h,sha256=uE2XseRvdPjBW4voY11JjzXXb0ZmUfPRKhN5alDiCFg,858 +torch/include/ATen/ops/_nested_view_from_buffer_native.h,sha256=eJYs9OvTt7AZOadfuFg9-6zdQ_YpuNC7fpWm25F69VY,610 +torch/include/ATen/ops/_nested_view_from_buffer_ops.h,sha256=Z0m02v1yJTTwKnbdcSTkVWisW2YzbwdMfp9BioVZs6M,1355 +torch/include/ATen/ops/_nested_view_from_jagged.h,sha256=D8dqbMagwrqlmcIQJ8C_jMDXbF99JQhDIPxhz3-BsUE,1153 +torch/include/ATen/ops/_nested_view_from_jagged_copy.h,sha256=mCkLi3Eedt0JOTCXv7eks97BrY6LRu2UUaxD_tk24LY,2507 +torch/include/ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautograd_dispatch.h,sha256=7y7jfZg5Aif9ggKphUbdMznhCLEFAf4huhtqkr7V1sQ,1378 +torch/include/ATen/ops/_nested_view_from_jagged_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=QQF886-UE_pd8JFsgJBrXoXTB5tVGdoYj_nM0F0S4Gc,1062 +torch/include/ATen/ops/_nested_view_from_jagged_copy_native.h,sha256=__K3F0o-cZN8RvsRQGieOzz8646kuV3_DxdavQn0Qas,1063 +torch/include/ATen/ops/_nested_view_from_jagged_copy_ops.h,sha256=j7KeHLKahDF380vWaUS3IKecbMMPD0HOp_A0_f6Cjho,3145 +torch/include/ATen/ops/_nested_view_from_jagged_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_nested_view_from_jagged_ops.h,sha256=h8B_HOp8ZT5n6--xQ1x9sQr965Wef4acnKT5HJmKvuA,1756 +torch/include/ATen/ops/_new_zeros_with_same_feature_meta.h,sha256=Vy9b0DhconfG7MKBW6Anr2BKv6DGotq9Zq0rKlb5Nd8,1716 +torch/include/ATen/ops/_new_zeros_with_same_feature_meta_compositeexplicitautograd_dispatch.h,sha256=w48MZqpdLHxHE3m4HrWEpUiFYXgL5u7M8xavt-tGpF0,1196 +torch/include/ATen/ops/_new_zeros_with_same_feature_meta_native.h,sha256=SkYTijaNSUMoZ42RkTb38YXFtdM3Y9zJ6Omy_ghq38g,742 +torch/include/ATen/ops/_new_zeros_with_same_feature_meta_ops.h,sha256=j_o9sjQdCgjm3IuA__Uso9KIZwd43XKG4jy-EV6TcHs,2112 +torch/include/ATen/ops/_nnpack_available.h,sha256=ZjJsRWh1AENOxX1Ivu7hjAJMK51on0TOP2xp4h7vRoI,666 +torch/include/ATen/ops/_nnpack_available_compositeimplicitautograd_dispatch.h,sha256=-hAoHXu2LcKSiLhc8MAzkGePQof-VJGMRCmJUtwDLOw,769 +torch/include/ATen/ops/_nnpack_available_native.h,sha256=gnpoBb1pa5X3s8ZFxvHxdYfQXqyxkL657YoWWGmCVBk,479 +torch/include/ATen/ops/_nnpack_available_ops.h,sha256=L5yMsYE3pAS6FA4NQ5gCURqXTL3y6Ughk3HqyTR1Q_8,922 +torch/include/ATen/ops/_nnpack_spatial_convolution.h,sha256=tCzCSrUT7VZ9i-8XZ-5cwCC5JEKz1EAwhB4H4nBF0UU,6282 +torch/include/ATen/ops/_nnpack_spatial_convolution_compositeexplicitautograd_dispatch.h,sha256=brCLSVWQdrCxBN5IK_QCc_leLyEDu33iv-J_oXuJALo,2079 +torch/include/ATen/ops/_nnpack_spatial_convolution_native.h,sha256=HwM3dkMWe3hyopU5eYr5q8g6GmE4Ior8Ty7dbaKfHwE,873 +torch/include/ATen/ops/_nnpack_spatial_convolution_ops.h,sha256=zS_f_qNuWIpWeii0Ec_W_UnLVp8oHpVk3TMhiaQHLY4,2557 +torch/include/ATen/ops/_nnz.h,sha256=KK0muwsE95Wiloq1RG5b6qEoOJu2PRvSDGbxyCpADoY,528 +torch/include/ATen/ops/_nnz_native.h,sha256=J1ecwoI91zlwsuqXCwbdL9fYkUNjkaeLGFCmy2O-9bQ,560 +torch/include/ATen/ops/_nnz_ops.h,sha256=GWvh1_98JL1bUuziI5B4vmWMYXlrN-_PEl2fWSZ-jTk,968 +torch/include/ATen/ops/_pack_padded_sequence.h,sha256=tKEkeDxo0d33C8BKGn1aY3RXG7AmszMJpQ2fHPeKkps,1739 +torch/include/ATen/ops/_pack_padded_sequence_backward.h,sha256=PqUCFnZOlzytB6t9OxKMhR4sWqIK02Lq4D3oEs-37oI,2125 
+torch/include/ATen/ops/_pack_padded_sequence_backward_compositeimplicitautograd_dispatch.h,sha256=W7BKEOD5cnClc-hqg8rYVERmb0wxGksKpg3XHxt2Tr8,1057 +torch/include/ATen/ops/_pack_padded_sequence_backward_native.h,sha256=w1YbLGEiSuW4ecHNPAXMz3zbJ9n8XSqD0I3XKfh4jrk,610 +torch/include/ATen/ops/_pack_padded_sequence_backward_ops.h,sha256=7WceKq6XulAEyhClhHeC9wMkJuou-POdLUIrcB4J4RA,1328 +torch/include/ATen/ops/_pack_padded_sequence_compositeexplicitautograd_dispatch.h,sha256=Y0TyoGkG8CXE5YLp-Bxl6CglStsbzEc-BHmFLz7Ebb0,1251 +torch/include/ATen/ops/_pack_padded_sequence_native.h,sha256=hL8iC38E8i5D0ps6vQxvbkFG2Vjf7vlf32T537qyt50,772 +torch/include/ATen/ops/_pack_padded_sequence_ops.h,sha256=yQTTrC-nqpyoTuwF3_AU__55yb-LxM4Lhwkug7CyLHc,2239 +torch/include/ATen/ops/_pad_circular.h,sha256=GoAqneqNfHP0OOGT0v6E2G6I5HobqQjJrAquKMP3wow,1488 +torch/include/ATen/ops/_pad_circular_compositeimplicitautograd_dispatch.h,sha256=k-f9CvpcGUxJUgI6fkxSGHZXzp41qbmGGCwCtnifSAw,909 +torch/include/ATen/ops/_pad_circular_native.h,sha256=KMUW_vXU4g1v167fKYsXQSPNrpjQaYOIRPlHwZ8CP-I,536 +torch/include/ATen/ops/_pad_circular_ops.h,sha256=E42KVgq3Ov3NixPHJUC8a7PcVhFtln6Dx1BIhqB96x4,1092 +torch/include/ATen/ops/_pad_enum.h,sha256=s9adrylif_9MLElf2Mu31XnLH3ZOlacjX3sQ6Gisyr4,1794 +torch/include/ATen/ops/_pad_enum_compositeimplicitautograd_dispatch.h,sha256=RmPVhQAft9uvT_51KaNDXJ7dfGnBctSHZDgzJ53FWU4,1021 +torch/include/ATen/ops/_pad_enum_native.h,sha256=O-AFKfGkSZp_NtcITusdLpN5ihF0ynperS9h_C2rAq0,592 +torch/include/ATen/ops/_pad_enum_ops.h,sha256=w190JCrlzjTVZuuS9-4wcmjScV-pwX0vdWD4XHKC9qo,1233 +torch/include/ATen/ops/_pad_packed_sequence.h,sha256=sEP1zqb9jg9sKLJvRRz0GueCAimY1FdqpobgPfaEBvs,998 +torch/include/ATen/ops/_pad_packed_sequence_compositeimplicitautograd_dispatch.h,sha256=-x15aDEUSDsKNic06sm7ax5A6w_wkxDCWGRVVTLACeU,932 +torch/include/ATen/ops/_pad_packed_sequence_native.h,sha256=BTb4SjEcVd814dzIKmGJoxeNULwKgnIz--ZHI35UIMI,642 +torch/include/ATen/ops/_pad_packed_sequence_ops.h,sha256=DcYRJ_x6j229igBqd2AS53Sl9xe_Josr_k0j5-e1G94,1458 +torch/include/ATen/ops/_padded_dense_to_jagged_forward.h,sha256=HwjRaEd4uZDO5trOEDcTEXijwIzgcABtXIDEnncd1v4,2106 +torch/include/ATen/ops/_padded_dense_to_jagged_forward_cpu_dispatch.h,sha256=kseyP-HEGUZi8FSdhvx959VF_cIjA6UVyPfspE_J2GE,1007 +torch/include/ATen/ops/_padded_dense_to_jagged_forward_cuda_dispatch.h,sha256=eFrYteYgGIKP0sm7Bdh_GYb9AD0q9ZrPgSdrWpz4Fk4,1009 +torch/include/ATen/ops/_padded_dense_to_jagged_forward_native.h,sha256=kfqVDrO7uTpC4H8ALFxFRL9c87gQOQ12B0tduPtMq5s,765 +torch/include/ATen/ops/_padded_dense_to_jagged_forward_ops.h,sha256=MIIn_TeU6u_Gjte60m6-9HUPp_9X7O6bVt-5JnDVZaU,1274 +torch/include/ATen/ops/_pdist_backward.h,sha256=VU2NM517xbM5aVLez6jsa3xQnYePIMK4vW7UMjaXH8U,1496 +torch/include/ATen/ops/_pdist_backward_compositeexplicitautograd_dispatch.h,sha256=Rx68wcd0cihxuR5Tbvj0FnhNDgnL4VG_MdmqHLSE0a8,1031 +torch/include/ATen/ops/_pdist_backward_cpu_dispatch.h,sha256=aSrJYWlXT8sPLTZCm45a4YznNV1rT1rvRdLtoZZwHSs,813 +torch/include/ATen/ops/_pdist_backward_cuda_dispatch.h,sha256=TPSS8X59ZR6Ql7AQtUbosUi9uahxJqgAuFDgpFHgLrA,815 +torch/include/ATen/ops/_pdist_backward_native.h,sha256=HqwQGpVGaKZy-0iOCLKItaDXnBfJVhMRRLATb_yctrQ,716 +torch/include/ATen/ops/_pdist_backward_ops.h,sha256=wIOCrUKb_WLIlbXqie4jijlWcdBJH-86SxnVI79GplQ,2053 +torch/include/ATen/ops/_pdist_forward.h,sha256=sXOEtOiFVhysdNFfSnT9ffr9WB-2MNS3YbRzk5L9EHM,1223 +torch/include/ATen/ops/_pdist_forward_compositeexplicitautograd_dispatch.h,sha256=7mtKy9zLyWWhHDsZRq9E2vA_wns4XddsFNG3vnlMekI,929 
+torch/include/ATen/ops/_pdist_forward_cpu_dispatch.h,sha256=4mgpq8f0gv7FWB0-RA2QZpP2f4_myagPbjjqU8NO9iI,763 +torch/include/ATen/ops/_pdist_forward_cuda_dispatch.h,sha256=LV00vy5cNuGRoPcrsgdpSWXcY1zZgYUaALS8Dv7cqd4,765 +torch/include/ATen/ops/_pdist_forward_native.h,sha256=bbXNQvCdrrR2WUlgkRAzVM18kr4q_AejSvAG3rKKnEs,614 +torch/include/ATen/ops/_pdist_forward_ops.h,sha256=ptzf25XCJxDP4mIYiDD5FE5uZAr9aR7A46rjR899Cdk,1713 +torch/include/ATen/ops/_pin_memory.h,sha256=sth4zYzCxwquAb7qD3OuRHjM2JGJXTvc3zXkkWHCKCg,1342 +torch/include/ATen/ops/_pin_memory_compositeexplicitautograd_dispatch.h,sha256=WisrWATsiQuFcOKc8PWPRZQP6vmDkYNaKa733esAe30,1099 +torch/include/ATen/ops/_pin_memory_native.h,sha256=PjAGwRt17fRYHuhQZO6Wv71T-XNIdg6G0wVtzoPjLr0,1042 +torch/include/ATen/ops/_pin_memory_ops.h,sha256=qFSo968chH_wlSQcaqhggDi_UodpXuOSxbkWTrZ63JU,1861 +torch/include/ATen/ops/_prelu_kernel.h,sha256=pmiVSYnEc-14e7-Hu76YPvTkRp-YD4DyH-n9PQk3-W8,746 +torch/include/ATen/ops/_prelu_kernel_backward.h,sha256=HmjASflypOGz6xaJSDmBd01zsKDB4zp8fxU44J0h6ZQ,882 +torch/include/ATen/ops/_prelu_kernel_backward_cpu_dispatch.h,sha256=kkDnEHpZw6vjOO-mN6XFKNO8yZfQbqpPjPAq3s05zPE,843 +torch/include/ATen/ops/_prelu_kernel_backward_cuda_dispatch.h,sha256=SCiDIsqxYqIKIGgn6YFJKP4VgUmm-KKt9JtX4r_2JPE,845 +torch/include/ATen/ops/_prelu_kernel_backward_native.h,sha256=AGrZXl7Z5psWjqzkhevzN35ahxCWlFrFx25wiUnRwRM,751 +torch/include/ATen/ops/_prelu_kernel_backward_ops.h,sha256=t6z9tvrIRfYBTfs6aG0hA4DhEWBjiawbYoS6sY15Ung,1312 +torch/include/ATen/ops/_prelu_kernel_cpu_dispatch.h,sha256=fC9VdD_iDGW5qHj3jWTo0E3OxmMIV3CVAoG4AyJzHgw,777 +torch/include/ATen/ops/_prelu_kernel_cuda_dispatch.h,sha256=pzivuUH5mrlIs7gJmunIV127MT3YEAHzczBdXQokh0k,779 +torch/include/ATen/ops/_prelu_kernel_native.h,sha256=r-sSmL7ZbZn2zgnz5jdcqxBGi7STrBEz9bM66ejsfL0,722 +torch/include/ATen/ops/_prelu_kernel_ops.h,sha256=G6-mP8JkPGFX9p9UnY3MI3pyPmLdKCfCOSOKNDd-ia0,1096 +torch/include/ATen/ops/_print.h,sha256=OTMCKt841f3vYc0r5nv1r8PRW8_M-JjE35MlK_ovb-s,644 +torch/include/ATen/ops/_print_compositeexplicitautograd_dispatch.h,sha256=Erl727TeuofFlNQJEobPzsb338CMG-fXyRLv3z2Sfv8,776 +torch/include/ATen/ops/_print_native.h,sha256=aV7cgzETeESG_jBetosMJ3Kp6CqFkYbW8kpZ-7_1Xlw,486 +torch/include/ATen/ops/_print_ops.h,sha256=AVifNPz43qfnUdZrcHczbORCQV7anXw_F-Ul-nhfR38,946 +torch/include/ATen/ops/_propagate_xla_data.h,sha256=DHbP6SDFOQpglQKmYJ-NgtI0ohooK5oen5MwF3HX1As,763 +torch/include/ATen/ops/_propagate_xla_data_compositeimplicitautograd_dispatch.h,sha256=kUw7uN_68vvdB8atz1afqOJNVEk1QzKMkx1uMCv2BLA,822 +torch/include/ATen/ops/_propagate_xla_data_native.h,sha256=k8QtZfxGIiRKZ-WnAs3rt5HC4V6B6vHkDjFt_JqqVaU,532 +torch/include/ATen/ops/_propagate_xla_data_ops.h,sha256=MSbeZNT22O8aLJVEkLCPXZwr_stnWudgQ0j3ylSR87M,1095 +torch/include/ATen/ops/_remove_batch_dim.h,sha256=lh5wC05LR77EKjj5OTIS7dJHcautFS0VKN7nAKNamS8,1754 +torch/include/ATen/ops/_remove_batch_dim_compositeimplicitautograd_dispatch.h,sha256=K7UcLGh1eVKkuIvnz6fCtAqj1XxjRcq4VVgWRQdR54Y,979 +torch/include/ATen/ops/_remove_batch_dim_native.h,sha256=0x3okb-N2W97KHs03mFZxdpWXugCMm6mXLCLb3QWT98,560 +torch/include/ATen/ops/_remove_batch_dim_ops.h,sha256=WrD7_E04CFcVsCQETckh--t4iOswdQbuxKSk0n_4Pgg,1205 +torch/include/ATen/ops/_reshape_alias.h,sha256=-mq81oewvhyLgd_ConwmigJ-e3z1hzChsaMGY3oC3nI,1743 +torch/include/ATen/ops/_reshape_alias_copy.h,sha256=HTBiVJXNxoWp0ERSvp4sDZTg6zvCCTGsPoL-13eFq8g,4652 +torch/include/ATen/ops/_reshape_alias_copy_compositeexplicitautograd_dispatch.h,sha256=vNI69_B2eR7Oe8yDtYfYAKsChdgG4q8yLHfjXRlCU2Q,1316 
+torch/include/ATen/ops/_reshape_alias_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=fwzGKdcHoLrnjJnIECr-j2kyf6hVvoMvsoqLR1RepmI,1001 +torch/include/ATen/ops/_reshape_alias_copy_native.h,sha256=jLnc31Hz5tuinTy28oYUKws2vRROXeC4-1EMi6l_lag,724 +torch/include/ATen/ops/_reshape_alias_copy_ops.h,sha256=eWearVZ2ArnXXLaI5YynOpaGD4Rj7OTRT9z5a7TCoik,2029 +torch/include/ATen/ops/_reshape_alias_cpu_dispatch.h,sha256=mKzqpEDSFxZLs1x1bJEmmWll1BQiDpbcEZfkPN_hAtY,921 +torch/include/ATen/ops/_reshape_alias_cuda_dispatch.h,sha256=Vsg72Jkq4zRfNcpix0mQxc4NIpfHXcEL5CZSbhS5Ru4,923 +torch/include/ATen/ops/_reshape_alias_meta_dispatch.h,sha256=WSeJDs-frwtyUVaiAHQB8mJtFbGahd0ItCa28DBghUM,923 +torch/include/ATen/ops/_reshape_alias_native.h,sha256=HbjigftLL3augUmS_le8RymoEnLttuBhxYiEA2ykgw0,551 +torch/include/ATen/ops/_reshape_alias_ops.h,sha256=jqbmErjG24X_OBLwkoMq2i1eoD7x99WeZ_KyCCNgcEQ,1198 +torch/include/ATen/ops/_reshape_copy.h,sha256=lpknZPDjIOXH9OrLqEKqp2NlsyMqjJryE7UqtaZ7XDY,1498 +torch/include/ATen/ops/_reshape_copy_compositeexplicitautograd_dispatch.h,sha256=ZJGNiETAL35Cf9FshnX5SbCYfG3Ck1yoS6LVnbxRFZc,911 +torch/include/ATen/ops/_reshape_copy_native.h,sha256=b1gLWt3alf3Nkj9df6tKOmtluBaKYQwRINZ3avg9KFI,537 +torch/include/ATen/ops/_reshape_copy_ops.h,sha256=Y10eOlrjB-waxSUWfIo_t2P6ToO2G-WEHhaqOkIL0n0,1095 +torch/include/ATen/ops/_reshape_from_tensor.h,sha256=CitjHSbnub1nBByh_3SnHBPH8ziVuc5nQEUwQnYnAaE,771 +torch/include/ATen/ops/_reshape_from_tensor_compositeimplicitautograd_dispatch.h,sha256=BV31h48eui0j0BYziFWQ_VddhPrQp6bW4ymv3BFQr2A,827 +torch/include/ATen/ops/_reshape_from_tensor_native.h,sha256=KJn0c9YuIByZ-YuhKU0YN_dO2tNMHUHCScUx0hoS9Fc,537 +torch/include/ATen/ops/_reshape_from_tensor_ops.h,sha256=HGb8UJV2ix8fRXXRUHl_YeBokKTEUpqw0EHtwxNU1SQ,1114 +torch/include/ATen/ops/_resize_output.h,sha256=_Qpvjiyc49kTCnyVv6f_jK0xkR8SEo8xlHOhSjbjhaI,5510 +torch/include/ATen/ops/_resize_output_compositeexplicitautograd_dispatch.h,sha256=VVlWSsdMnefbMdFKXvQMRPYMurb5gf9rbb97KQEzyzU,1535 +torch/include/ATen/ops/_resize_output_meta_dispatch.h,sha256=hhTH-tI_iT3ca-v4U6ZO__ghCT1ZhJrR0VBrvv04SeY,927 +torch/include/ATen/ops/_resize_output_native.h,sha256=72dyThSiLzRJ3gyRl_gKXb-VsawpPoGjqXpA42ac2VQ,821 +torch/include/ATen/ops/_resize_output_ops.h,sha256=t787Xg91unXOi0t1AxyNJEJ9bTZ5oZooYFUeovc9X_k,2700 +torch/include/ATen/ops/_rowwise_prune.h,sha256=QXtDQ1_2V4wgskAB3aLT6gdwo51Y3PGgY0P4-4gxhu4,889 +torch/include/ATen/ops/_rowwise_prune_compositeimplicitautograd_dispatch.h,sha256=Moz7ruAeD0EJVaj3FhYcNQ--wnJx9Hzy9Ivsb52tZlk,888 +torch/include/ATen/ops/_rowwise_prune_native.h,sha256=Kr-MTYtUyuii3eGm8WZShij4Gn8eewh3C-rSflWTvbE,598 +torch/include/ATen/ops/_rowwise_prune_ops.h,sha256=OzlnZJgI4C2vRNlS0y8WtvOsT9DIpkVX0DBVAEv_tB4,1319 +torch/include/ATen/ops/_safe_softmax.h,sha256=BeQj9mOBFJES_6cloxgAoB60F2-255ITp2tNaVvV7F0,808 +torch/include/ATen/ops/_safe_softmax_compositeexplicitautograd_dispatch.h,sha256=Pj7O9CgqoF1HlAJIcmvMso7IhWgpamw9B3lq5Ue2xwU,861 +torch/include/ATen/ops/_safe_softmax_native.h,sha256=kkXObJCU5i407ly672OZliUNukKtHqmpCafNBucmwQo,571 +torch/include/ATen/ops/_safe_softmax_ops.h,sha256=BLHlM-GhfJse3l4gLJl1ma0BUsUjqe5y6RtuNAbTJHg,1186 +torch/include/ATen/ops/_sample_dirichlet.h,sha256=H8311jHxikDqt-dtjXxWoQqRFk90Ds2R8edfIcdZt4g,1447 +torch/include/ATen/ops/_sample_dirichlet_compositeexplicitautograd_dispatch.h,sha256=6ifmJa0DngGLzH-zkaf-GQTmM0yB9Uf9rgkdPloqnyA,1012 +torch/include/ATen/ops/_sample_dirichlet_cpu_dispatch.h,sha256=z7rbqtkKVRVknLozf6g_E9gz-9b52_2kSzFxF4NW0hs,811 
+torch/include/ATen/ops/_sample_dirichlet_cuda_dispatch.h,sha256=93RYromj5OUdivMuOe6O0Eu1070PEmcYAo423BOkoMA,813 +torch/include/ATen/ops/_sample_dirichlet_native.h,sha256=xdcFbT8VmJdJqfK6ykW78H10B-4kBnjywRUNPxL8Psk,819 +torch/include/ATen/ops/_sample_dirichlet_ops.h,sha256=hashFTCVierfjNTQiklbdzurST-wk23ioa5VuI32tMQ,1939 +torch/include/ATen/ops/_saturate_weight_to_fp16.h,sha256=ckrIeQBIqP6h6oU_tyl_NLZ-PYkCC4WhpRCW-mE3C1I,746 +torch/include/ATen/ops/_saturate_weight_to_fp16_compositeimplicitautograd_dispatch.h,sha256=MmbKOuZfY6ilKxAAkMm2XEstqA83H5G20TQKA5sEaAY,807 +torch/include/ATen/ops/_saturate_weight_to_fp16_native.h,sha256=3JFX4DFfW385UFvoleejUPdmYkMCy2NCqqTtTi9B_Xw,517 +torch/include/ATen/ops/_saturate_weight_to_fp16_ops.h,sha256=_Jsc8A4TTGzCvtKMkpMILKZReP1eMrJs-0bWfNWf-P0,1046 +torch/include/ATen/ops/_scaled_dot_product_attention_math.h,sha256=5bMpYPTL9HRZ8Nyqgb4KX4UgM9ZEGpjqFDx518Q1mQ4,1327 +torch/include/ATen/ops/_scaled_dot_product_attention_math_compositeimplicitautograd_dispatch.h,sha256=LbATTrR1-jPQj0TUgPuyoJRmCOjq0RTumAveDFA-1NE,1107 +torch/include/ATen/ops/_scaled_dot_product_attention_math_for_mps.h,sha256=zBbSKyNGbJDOEwWlU_hr5lhvK18cijwKNAmbP78a8s0,1301 +torch/include/ATen/ops/_scaled_dot_product_attention_math_for_mps_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_scaled_dot_product_attention_math_for_mps_ops.h,sha256=EiwVfHnsuUsA_f9iO9Feh_hh_fSXnXTtAgLsYfa-KtQ,1901 +torch/include/ATen/ops/_scaled_dot_product_attention_math_native.h,sha256=is8xIpwwCtxbVHZDcpIeLGTZCGlCDhxEF3H1NeRwUsA,817 +torch/include/ATen/ops/_scaled_dot_product_attention_math_ops.h,sha256=4AUQ4aKdV4eKMg0yBm1HN94RNjheGgakMHRN8gm6d0k,1940 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention.h,sha256=wEeLyND1XjAEB8ugFG9y5190W41UrwDbit6w4-6Y4Ik,1546 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_backward.h,sha256=LCjloF6aVl8DqaDreiiWLgrJ_mRgdZfoAUcN3EpLu4E,4511 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_backward_cuda_dispatch.h,sha256=tGC5GC5ZUcbd81fHeUIRlnJSqE-Z83vsz50rzqQ6M_8,1763 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_backward_native.h,sha256=9rslWI_wykxaniPN-gPPiOwwyYZOSmdEqf62-bfx5O4,976 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_backward_ops.h,sha256=f6awRx_SBk3RyKg9lAXqUyc0EwIjsRU-0UzsbRFHeeQ,2530 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_cuda_dispatch.h,sha256=IiRI-V6jodFdH8_xhU1TNJSQAXQ0EFDdCwII3A0ND_s,1121 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_native.h,sha256=LpG8TJwPVENFKTICzM8vIkWYP9rvBbSWaGd2mzRItOM,1327 +torch/include/ATen/ops/_scaled_dot_product_cudnn_attention_ops.h,sha256=OvDhoNV2hbPSq6qb_QICVtyhtaDLfpccGXCpPOkDLHA,2260 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention.h,sha256=ojXnA3KpHLbtGzo613vRNUN2e4qELqbVIIXJm8nqNM8,1339 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward.h,sha256=U5ORGIXM0cn0F3W-4uohROqRnCjhnqc3i2YL02yPfaQ,1610 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_cuda_dispatch.h,sha256=2cj2a8Tey4dDG4ov1EbD8anwBxl7XJqkVkfa7yRMAw0,1189 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_native.h,sha256=9xzstRyhiXv1KroUH0t_vGSPvKiEYi_dSKqhKXQcKzo,946 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_backward_ops.h,sha256=xaBKRm7HwmXQsvNqm-jAToUWziiyw9XkUmOn6I_da6c,2389 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_cuda_dispatch.h,sha256=Omb4r-tGGjo59De3RHo9tae09aIGcZLks-CSudMfXEs,1038 
+torch/include/ATen/ops/_scaled_dot_product_efficient_attention_native.h,sha256=jNsUH1GvbgCTkj3zkeOvue7ghr3fr4K4P08AjFe6Fs0,1161 +torch/include/ATen/ops/_scaled_dot_product_efficient_attention_ops.h,sha256=4yqO73EJ_7hPodmdQ9b8P6rhl6OUgDZRWoWXMdPLBoY,1930 +torch/include/ATen/ops/_scaled_dot_product_flash_attention.h,sha256=b5g-s3LzGp75ZvOK9OCFnoTq3Qa4iXwaCjubMZIU7b0,1390 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_backward.h,sha256=fDFmqJxDPDSpBES_le1qdlqcBFIkyB1rpJxMipxGnzQ,4373 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_backward_cuda_dispatch.h,sha256=V7ThMWXuAfSzgxM8GyNS5KjAjbqrhX64sOnv0O_6YbY,1703 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_backward_native.h,sha256=zU8hrrbCtD6UAqS0n_5X3AlUJXqLg2Lru2JNImFC9F4,1452 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_backward_ops.h,sha256=camEXsIOJ3BMfaTJxGvWRArklHNYlRmWNfAE5R45Oqs,2463 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_cuda_dispatch.h,sha256=1x3OXJQY16yDy6_JNzozPHgqxc7iFMf3C5pm8Tn4EgM,1049 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu.h,sha256=0cXWvpil-klQI2h49zh43_m5g9e5bE0hO2mKPnpFURs,1228 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward.h,sha256=iAyhiWuM5ypVUO1cPu4kf7_OcrctEqXQv56auZIcEb0,1433 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_cpu_dispatch.h,sha256=Bv1PL73gdqxC-r1xGDBlPVc_ZjNbPufiX4xKd6613Ho,1089 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_native.h,sha256=ue2stwZ9lLAf9VOYq0XKcjV7VpIGow29OqhHPWBGsvw,839 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_backward_ops.h,sha256=QRyli182g0Gjp5uEKzzhmyeHyYovQza34BWj5qgrzww,2102 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_cpu_dispatch.h,sha256=J6gL-4PmyIJNILE6r9wWpxGYi4z6aoGRQoKDKTjj_Hg,996 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_native.h,sha256=UkgNfLUnxapqusgMBU93PrsG7b9DWgdxX215zuiGB0c,746 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_for_cpu_ops.h,sha256=PsfADJqPEJ3M0-ffw3H43XLX6KeX3HWST9jseDq7d1Q,1757 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_native.h,sha256=3wEjiE3u6JLlLofbndWCImGycja8eFcrctKjttWksrQ,1183 +torch/include/ATen/ops/_scaled_dot_product_flash_attention_ops.h,sha256=d6wi3FijoRaWB2_xezN5HHNnAYDxj3O_VdULbE-F43M,2020 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable.h,sha256=P6e0hy7q3yEMvFpsLKD4TvHhR8Kn1Gy94Xm96kCgn34,1536 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward.h,sha256=a52v20RTXJgA6cE0l8BpWWWMPRDeRAf9N24V8f7spoc,5076 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_compositeexplicitautograd_dispatch.h,sha256=v66LYhSYRUPhn6kSVOkzWoMCs1resaWYdFkhlL5SdQA,1929 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_native.h,sha256=le7vafWMcMvwGZ-q59cxnrrnnzFh7ekWljJEw9c9d9o,1033 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_backward_ops.h,sha256=yW-Lb2WunFWCRuEGRWcTkuW4V3GDh3xShQ6UH_JUFRs,2779 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_compositeexplicitautograd_dispatch.h,sha256=1Rs530VKaS5wzEa8dx8yo6z7VCmZxy0UWYBdljufmN4,1154 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_native.h,sha256=iYiuv2cJV6XR_RA8_zD7YFrNUx1Q1gkNyZT5b243p5U,864 +torch/include/ATen/ops/_scaled_dot_product_fused_attention_overrideable_ops.h,sha256=CbbAhJ1NUs5OtK-aI5_i1i6eMUSycno_hk50XBhgpS4,2223 
+torch/include/ATen/ops/_scaled_grouped_mm.h,sha256=DsL1L3ALp4uQCffl5VodbdYYyYd1WKvgdHJXk870Chk,1267 +torch/include/ATen/ops/_scaled_grouped_mm_cuda_dispatch.h,sha256=OTFHJTKU2sKjsXJ_t-bJH5HH6n93a6HpJMHuwlFBl3w,1066 +torch/include/ATen/ops/_scaled_grouped_mm_native.h,sha256=kPy_M4x6tsQpYdD6a81duLjmhAYI6HT7sdbLL0ca9F4,823 +torch/include/ATen/ops/_scaled_grouped_mm_ops.h,sha256=pS6mZ9RIm5gmEdYLa-1WVoDbcoBNO4UXLkcNquEecYQ,1955 +torch/include/ATen/ops/_scaled_mm.h,sha256=oP6zj0aZBqFkhhLyUyaBLVzIxQ-HZ9w3sqSvjJNv8nU,2523 +torch/include/ATen/ops/_scaled_mm_cpu_dispatch.h,sha256=JI-fSyx9Q815mKMTKNcyRp1XlK88aDtcM2mRY4L9zG0,1679 +torch/include/ATen/ops/_scaled_mm_cuda_dispatch.h,sha256=IgDPdqtExJK9xOLlwwMe6izPIciBYVBbk_YdwqdXVyo,1681 +torch/include/ATen/ops/_scaled_mm_native.h,sha256=9TNj8KoTxuTbRxNoh84KfASA0HBDB3mskiC5_JH0v9k,1746 +torch/include/ATen/ops/_scaled_mm_ops.h,sha256=jsJ-xtCkRGp3Z5Mc4HDRwfn5egdUxGIN_xCzprQNjcQ,3197 +torch/include/ATen/ops/_segment_reduce_backward.h,sha256=gV3uD85gPOqaK_Sm1blo0hic68-XkgSperM1sQV-Ni8,2493 +torch/include/ATen/ops/_segment_reduce_backward_compositeexplicitautograd_dispatch.h,sha256=jeDWnjyEoP4Nv56S1gCXjrvWHkuPPqPWRk-Egc1f-s8,1402 +torch/include/ATen/ops/_segment_reduce_backward_cpu_dispatch.h,sha256=u-PKVj2FBi6iSG0J173XQvZFyC5TmSM3ktv9u_2b-Ww,1010 +torch/include/ATen/ops/_segment_reduce_backward_cuda_dispatch.h,sha256=y9qOrPBOgDJ7VLoQpI47PI89WyHaRHcWVsUZKGcRcHg,1012 +torch/include/ATen/ops/_segment_reduce_backward_native.h,sha256=amJKgZrLLCjUPP2FKA-oI_a3oUVH7BmmU3cBgnBRkN4,1094 +torch/include/ATen/ops/_segment_reduce_backward_ops.h,sha256=TLZIciwIKIeF8PXmA_QSSD2O6tkqTJzWH-d3KzIo3vQ,3194 +torch/include/ATen/ops/_shape_as_tensor.h,sha256=I6wcLwxG3Jn6bQzkjHHtDLaO3bkzmKKc8b11spAatrM,708 +torch/include/ATen/ops/_shape_as_tensor_compositeimplicitautograd_dispatch.h,sha256=SHmmx0AvrhR0cB8nIO64XTr4nUrIk7hu1viGpS1wsjM,797 +torch/include/ATen/ops/_shape_as_tensor_native.h,sha256=En6lvt8WvF8SLVqwSaP5_qew8i95vAB10ipoD1pplF8,507 +torch/include/ATen/ops/_shape_as_tensor_ops.h,sha256=iwqfV88uZzt9jLFFc1Phx1Sc6qLiqgYO35ruFqNNLSw,1016 +torch/include/ATen/ops/_slow_conv2d_backward.h,sha256=SBP4WR6Sl1XP3B4slDpHz7YvZg2aupAANoOas4UilA4,14666 +torch/include/ATen/ops/_slow_conv2d_backward_compositeexplicitautograd_dispatch.h,sha256=JP30koPAUJE0vMOaltTnLMDZ6iidstd9cR7ejgwyt3g,2148 +torch/include/ATen/ops/_slow_conv2d_backward_cpu_dispatch.h,sha256=ESxrgpYGN-huwn6aSCAAQ7LPeYktOKCRxq973vBJszw,2613 +torch/include/ATen/ops/_slow_conv2d_backward_cuda_dispatch.h,sha256=Zh7CLYTb5cWUgEgHRjjsFDhQ_Tznkw4szS1F3xDKN_0,2615 +torch/include/ATen/ops/_slow_conv2d_backward_native.h,sha256=y4pNwRcjtsQ9BVH9fSf7gV82I4tD1GwNZ2Wvu3EV0HQ,2041 +torch/include/ATen/ops/_slow_conv2d_backward_ops.h,sha256=gHjhKuT6krXZBPpHY38k5HqFMPFVtDHM5we01xk8hGA,4912 +torch/include/ATen/ops/_slow_conv2d_forward.h,sha256=5FV8LzZc60HyadrEGluRqL6Zm--V7MSFRAPgcRDmf_w,6873 +torch/include/ATen/ops/_slow_conv2d_forward_cpu_dispatch.h,sha256=qpfOzmpdMILXsqqzJnjWaDI9ZGO5iFzXnMt_0AAr7TE,2151 +torch/include/ATen/ops/_slow_conv2d_forward_cuda_dispatch.h,sha256=Y5OiVujbGYryMXA_TdwTe_qP21Q1xMN9xNpDXhvgv0w,2153 +torch/include/ATen/ops/_slow_conv2d_forward_native.h,sha256=paZOLXIM91ZCjhbMDxq81YUY2SB0ozm1H97E4k0FUhQ,1374 +torch/include/ATen/ops/_slow_conv2d_forward_ops.h,sha256=yPmUManjTQCf02zmJjhpI1Q8n0iIxI6qj_nUVGN9Ais,2743 +torch/include/ATen/ops/_sobol_engine_draw.h,sha256=XO5nL3_1AQdMbBb2fi4eNJ3Nj3VrImc-Twk5Dh3Hz7Y,1004 
+torch/include/ATen/ops/_sobol_engine_draw_compositeimplicitautograd_dispatch.h,sha256=DKIHmbl2glZiL5IaEXtaXsljRoxUkCLpkHjLKYVV0FA,948 +torch/include/ATen/ops/_sobol_engine_draw_native.h,sha256=jQgA7vLCiEw_5NYW3XNFp2V1d2Q2kk29bheQEqHmXvw,658 +torch/include/ATen/ops/_sobol_engine_draw_ops.h,sha256=EP3Oxfihp0Nzy5IMQMfLiJm0qE7s6RdkLAZ17H66OfI,1515 +torch/include/ATen/ops/_sobol_engine_ff.h,sha256=vlQGmE-3V3H1HUBHxDwWih0dgZ_kESInOsYo-NSjojM,900 +torch/include/ATen/ops/_sobol_engine_ff_compositeimplicitautograd_dispatch.h,sha256=U7ppA4nbFqIwcwAubPmBr5f5dYh9Q9YeNprRJ0MAWUQ,878 +torch/include/ATen/ops/_sobol_engine_ff_native.h,sha256=KHvsphxPM7PmE4uQRoFURLOCbwQ1485bXp02yMOcFoE,588 +torch/include/ATen/ops/_sobol_engine_ff_ops.h,sha256=4dtQCxk8bK8c17kYSAc_VG9E9iv98qMc1x3fYTy79n4,1290 +torch/include/ATen/ops/_sobol_engine_initialize_state.h,sha256=CNDD0Tja_0AERXAhCcrBW8Fm6SlKI2QjE5C6eIYNOY0,816 +torch/include/ATen/ops/_sobol_engine_initialize_state_compositeimplicitautograd_dispatch.h,sha256=7zIvRDxQ_YOFwx0LbRrllhQnRGt78_2ASFTEIcApH-Q,827 +torch/include/ATen/ops/_sobol_engine_initialize_state_native.h,sha256=IXvFjYonJb6t3HrhHVp_7NW-a9ox0NMvCRG8u1Rz5Z4,537 +torch/include/ATen/ops/_sobol_engine_initialize_state_ops.h,sha256=aQXaRVrg_mcDiJNEyDvDSZt9hzwJW8ygO_c6KgLTOOU,1119 +torch/include/ATen/ops/_sobol_engine_scramble.h,sha256=IgDsb3de2E3SV6MW3ekDDZc88zca-VsDUqgW0xwYhIM,825 +torch/include/ATen/ops/_sobol_engine_scramble_compositeimplicitautograd_dispatch.h,sha256=ytWb3gPPUWeGbgf4Jt1YS40Y2MkYMKwhKCJKWC0SyLw,843 +torch/include/ATen/ops/_sobol_engine_scramble_native.h,sha256=J6oQB16hezBM-7-eT48vlVzCQkhCvpy3KDgFVlqkg8s,553 +torch/include/ATen/ops/_sobol_engine_scramble_ops.h,sha256=hS0XrUVGw-AKyks1GhNuAkxIH5ISnPSOUtZkrOyfDkg,1175 +torch/include/ATen/ops/_softmax.h,sha256=YEgoQbY-_hANDGFcjeL8Vz2sbEf39OnjHBZ_ShL1Wfw,1333 +torch/include/ATen/ops/_softmax_backward_data.h,sha256=awLADjRfQ3r-nghZzcOz6w7KZIDhKFezSC7RBCHE71M,1758 +torch/include/ATen/ops/_softmax_backward_data_compositeexplicitautogradnonfunctional_dispatch.h,sha256=DcIWdq8zP4BOeUgfl9kyXWRwy1HAu0TBteopMnM1-4U,904 +torch/include/ATen/ops/_softmax_backward_data_cpu_dispatch.h,sha256=Y9p462K6fInP2Oo41a3lnb4nw_z17MUnl1cqDvLiy-0,1189 +torch/include/ATen/ops/_softmax_backward_data_cuda_dispatch.h,sha256=Hj7vXrQ7uQWQvuPiSmH7bcrR6tjPCxwPWXsEJ9vw0og,1191 +torch/include/ATen/ops/_softmax_backward_data_meta.h,sha256=j3TmodAcC-XWWbi-mthYguOuv_uhblKjYzpdYk7wYAk,685 +torch/include/ATen/ops/_softmax_backward_data_meta_dispatch.h,sha256=HUv-N_vf-qW3F2Xix-57d69xz0y-Jfro1StAqb_0skA,1191 +torch/include/ATen/ops/_softmax_backward_data_native.h,sha256=cdl06hJLnXbx5qgoLpFjcPTc0hjJCrHKLFFobLkYHpQ,1151 +torch/include/ATen/ops/_softmax_backward_data_ops.h,sha256=PkvqSw3Ea5yLBxjJnVqCalRQsDAD8IstLA5Oa5pG5Q0,2204 +torch/include/ATen/ops/_softmax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=UBgHl5QlASCQJKNOT56OUnTYRx4mMUIasBTbkKcvNbU,848 +torch/include/ATen/ops/_softmax_cpu_dispatch.h,sha256=Uxe0BPRjKjzBFD1QglomMz_cYDzdKpe_HP-5geD4PpQ,1007 +torch/include/ATen/ops/_softmax_cuda_dispatch.h,sha256=qyTKzsEcx1T7nZT-1fjIZmOXDkgC0VetHXPFYJ7C5gg,1009 +torch/include/ATen/ops/_softmax_meta.h,sha256=gccY7lf8BbGB7UBq_APwkyDZ-MDiivWV8jZr_NC2nq4,629 +torch/include/ATen/ops/_softmax_meta_dispatch.h,sha256=MXPuYDKHkMUOqamc-iEFVjNGUdpl8t9sE_p8XpJXUyo,1009 +torch/include/ATen/ops/_softmax_native.h,sha256=QL6qvLsH2ixczx6z1BYiQn-Kn4jGYC29kri38-yi40g,1038 +torch/include/ATen/ops/_softmax_ops.h,sha256=n2UJqRVDVF_9nwWBwK7778OsrGpfoJgGe-6nDFZzoVs,1819 
+torch/include/ATen/ops/_sparse_addmm.h,sha256=yRlJYrAmk45nm9yBlqV-39VkKFjqEKGr1zyYsMAeudQ,1697 +torch/include/ATen/ops/_sparse_addmm_compositeexplicitautograd_dispatch.h,sha256=6PzdHz2DwFOmfogm-GAlcsRbvMECA9EpgKSJoJ1wLfM,1278 +torch/include/ATen/ops/_sparse_addmm_native.h,sha256=QY9OhnCFAs1uZz-d4zTsF6-SyimQv_Q9X63AiYArXjI,796 +torch/include/ATen/ops/_sparse_addmm_ops.h,sha256=DOS9By3r5y41OUvdKUF5zChauHXbqwqJDQlMUMFguiU,2310 +torch/include/ATen/ops/_sparse_broadcast_to.h,sha256=YPqCiZ0RZNIbUKM07rRUiHeRBRTxKRqvWeN5ipy4tHY,770 +torch/include/ATen/ops/_sparse_broadcast_to_copy.h,sha256=6Ks2xxB1vh4Hfh9Al1rRZNnHf2OCEl5xG0uSgesocmE,1377 +torch/include/ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautograd_dispatch.h,sha256=f_M6rCizOWlkbPYyXTSfoZ8U1WANuULM5vxzxB3_Md4,973 +torch/include/ATen/ops/_sparse_broadcast_to_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=NB1u35yrFNKOOxzQBgUhQFvicSAQL7TFhDFA-ytNnbc,854 +torch/include/ATen/ops/_sparse_broadcast_to_copy_native.h,sha256=2iofrY97Cw6HaJfs5D8L-7K0xFrTzY3W-JTXN9JDnbo,658 +torch/include/ATen/ops/_sparse_broadcast_to_copy_ops.h,sha256=ExCqhaxWeYtDjmmYlgKXItAJ34v5dwoEUGyV5QK_Yw8,1847 +torch/include/ATen/ops/_sparse_broadcast_to_native.h,sha256=KHtF7pHQuRZkR_iolKX5j6xjXObAiNR9jcx5M_fKDx4,532 +torch/include/ATen/ops/_sparse_broadcast_to_ops.h,sha256=jNCm-5zFOixa56QyGFUzaFs7yyHif7zQc5z2dkYPQr4,1107 +torch/include/ATen/ops/_sparse_bsc_tensor_unsafe.h,sha256=yA2iw2kmxwu85rF8nU5k07MtGg5lQOI1yTXHHpUcODk,1815 +torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=fD_Nza5BlGXqi7Z_BrTlaLp0_9zLr2hkPK8Smw6q17Y,1233 +torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_native.h,sha256=FySVPp81Mu8XTdnBNuU5vnVF1n-q_1VqsV1__pQuCyE,762 +torch/include/ATen/ops/_sparse_bsc_tensor_unsafe_ops.h,sha256=hhyVwINPLMtIbEjs5pvzWGg_3EL0R6v_QY2oDy7ac2A,1829 +torch/include/ATen/ops/_sparse_bsr_tensor_unsafe.h,sha256=1mphwimc-0McYEb-PeJRGJPgLbTMAbEJF-61ulogeI4,1815 +torch/include/ATen/ops/_sparse_bsr_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=HxjT_h_gvXQWD-MKHzujhd_5XD2YJANhH_anytd9H-A,1233 +torch/include/ATen/ops/_sparse_bsr_tensor_unsafe_native.h,sha256=nxAFBEhqu8OdsPjvnPcaF0PSAoCU7LZE_I05q_kbYm8,762 +torch/include/ATen/ops/_sparse_bsr_tensor_unsafe_ops.h,sha256=2861kZ-Qk7VhFRzC4JSe8fcXjgu927V1c4MiiKcfgII,1829 +torch/include/ATen/ops/_sparse_compressed_tensor_unsafe.h,sha256=6dhLNw07QJXLZm8F0XvCG4Y36Rzqmg5MI-883LKDr2M,5664 +torch/include/ATen/ops/_sparse_compressed_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=0HzNHTwLfPlQGFsoL4ZGgjtJNCB59ifJdxBPLI28B90,1816 +torch/include/ATen/ops/_sparse_compressed_tensor_unsafe_native.h,sha256=7dR4gKCS-MmBOkFAWP1hWuALQLnC7lLmUjf56hUoVIQ,788 +torch/include/ATen/ops/_sparse_compressed_tensor_unsafe_ops.h,sha256=e1EPkrbc4FQX56t64MOGzqATKlW05BD2XjQsc0H1tyo,1889 +torch/include/ATen/ops/_sparse_compressed_tensor_with_dims.h,sha256=pYrlBjPTFupt05YmmkbG3j3-LLv4SvPlArpcMdIULmY,1902 +torch/include/ATen/ops/_sparse_compressed_tensor_with_dims_compositeexplicitautograd_dispatch.h,sha256=D9Gpz4XiimvB4tCmw9vtDZmn4epr3rn2-YXdCT8NHWY,1240 +torch/include/ATen/ops/_sparse_compressed_tensor_with_dims_native.h,sha256=iAq27nEqpRZKUuTxyLz_5cCJxs99cl8B-E9kpFHoemw,766 +torch/include/ATen/ops/_sparse_compressed_tensor_with_dims_ops.h,sha256=KyLNv3YYF8GhHKnFjs0neFWMDBbX_6djmdQdc88U-w4,1850 +torch/include/ATen/ops/_sparse_coo_tensor_unsafe.h,sha256=negnItEgRmvd9h5bmBI-kpK13NeT-OuZQqB_IrqFEn0,5377 
+torch/include/ATen/ops/_sparse_coo_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=slzuA49lwAUZR3aFBMkxhFm3H30xvmJn7jtAY_A2-nU,1782 +torch/include/ATen/ops/_sparse_coo_tensor_unsafe_native.h,sha256=UwTUpEBhW04qvaDXW7D7xbMe3D75ggknjguadciyDcI,787 +torch/include/ATen/ops/_sparse_coo_tensor_unsafe_ops.h,sha256=M-hCc9zXZkW3O4Afivks2N_NK8nJPcLso4Mbu-22Bh0,1845 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims.h,sha256=PTA8RCnU11JQXQaBu_-YZz3JqzHfdaLaaYgElbir8Tk,2364 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h,sha256=RFaLYJtwwDKhTT7jYTha__xQfKU6yKf23rFh28V5_d4,10726 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_compositeexplicitautograd_dispatch.h,sha256=kqUx9Mokxf3MuWIwyTFu9CAGGyfSGdEXRds5ySu6i14,1746 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_meta_dispatch.h,sha256=S2Sui2S8_gctxhVq9SRMsufge-dfHlgVZoIJzBnr_Nc,1950 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_native.h,sha256=jxXhl0OnEQrQja4gABkrcqs4vDNQaHltN3Yw3AT-RYc,1083 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_and_tensors_ops.h,sha256=gzUS_cWX_rwH3JKlhUWloVkFsUfCVO3o_kwEQkbOPsY,3154 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_compositeexplicitautograd_dispatch.h,sha256=7YsWN-jjumSdzlFUNPITBvSviz11Lil5tv9g3wyygSk,1007 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_meta_dispatch.h,sha256=dUVuLZRFOtLZom7AUfnAS7NOPa6K20wPuBMbvxakmW0,1088 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_native.h,sha256=qRnBb6OWd58w5S3f0In5XTk0vFOBZwu9yCVj0xSb9zQ,841 +torch/include/ATen/ops/_sparse_coo_tensor_with_dims_ops.h,sha256=jJPhId0LRnfwgSzPpIta-NLT9TGF-PnR_DLNINM5jl0,2450 +torch/include/ATen/ops/_sparse_csc_tensor_unsafe.h,sha256=botA5D4zgjG_eiqlUoPW-uR7OhJ2SRjRkSQ-3jdok4w,1815 +torch/include/ATen/ops/_sparse_csc_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=DdQz0B0arJZunyDFUDKXh9pG4NK7c8e4OC-j1GN1F4E,1233 +torch/include/ATen/ops/_sparse_csc_tensor_unsafe_native.h,sha256=KiTE44AZiv-tknGaPIeaNDBDTQ_j-vzew25-0AwLi1Y,762 +torch/include/ATen/ops/_sparse_csc_tensor_unsafe_ops.h,sha256=mGwtEV5SRAyYVwLfDGbEFd0cvxv_qDsWXvtiW3FS5bs,1829 +torch/include/ATen/ops/_sparse_csr_prod.h,sha256=7z98-7-pT0wrv9dQBvAvFvf-SUBPfJGxiaNmNkGuwkg,1725 +torch/include/ATen/ops/_sparse_csr_prod_compositeexplicitautograd_dispatch.h,sha256=vSlLeVPMmlhcQJRsszzRQrEGpopV3smAj4LmdgXtnrY,1080 +torch/include/ATen/ops/_sparse_csr_prod_native.h,sha256=3JRdXNW94ecdiwOWxfnwbWnuLZ8CMdbuWwpFmtFal6w,944 +torch/include/ATen/ops/_sparse_csr_prod_ops.h,sha256=xll9js7H79DnJRE37sHGIewPizc45w2JRI72-zhcjOM,2229 +torch/include/ATen/ops/_sparse_csr_sum.h,sha256=qMQSxjgVVyazPD5_otdIWUq0HouyNPJan7cfWx7mP6M,1715 +torch/include/ATen/ops/_sparse_csr_sum_compositeexplicitautograd_dispatch.h,sha256=pDVz0c4VPnuTYtuj28nLuMV3Ml9xKuvydOS8l2w--aw,1078 +torch/include/ATen/ops/_sparse_csr_sum_native.h,sha256=kCkeBBJKRPRa_-RKreeEL5SsUhX_pD_zx2zXptlLal4,941 +torch/include/ATen/ops/_sparse_csr_sum_ops.h,sha256=syDQ4lQ4wZm7oTWG1NSEHJ-7lzdBuKnD3ygb5oljD7c,2223 +torch/include/ATen/ops/_sparse_csr_tensor_unsafe.h,sha256=12-8GOhU591rpBFvnZk56c3xBs-feuEWJw_UaC3aTkU,1815 +torch/include/ATen/ops/_sparse_csr_tensor_unsafe_compositeimplicitautograd_dispatch.h,sha256=iwU_a6nWlWL1vg5gYv9Fjx2gYSR3bWc_QtzyqZphPuA,1233 +torch/include/ATen/ops/_sparse_csr_tensor_unsafe_native.h,sha256=9ljK6DGheT4tLsAa7lfxfCmTbyQuHcPnLzx-52GQqVk,762 +torch/include/ATen/ops/_sparse_csr_tensor_unsafe_ops.h,sha256=ZG7N0qjG4hZZ5VpuhhaRxlpnNY1yFfxXl3PjjbfoPsU,1829 
+torch/include/ATen/ops/_sparse_log_softmax.h,sha256=iYxJyqeGOKkBehh4alFNyHsvqlVdUx0AwJp_Neaw3Vc,2060 +torch/include/ATen/ops/_sparse_log_softmax_backward_data.h,sha256=lNiEldY4buXdtrK2jFaZhjo1LFl1bjqdfk0uO0Nfz9w,1763 +torch/include/ATen/ops/_sparse_log_softmax_backward_data_compositeexplicitautograd_dispatch.h,sha256=HvR-ncFNFxswEX1x8a6Q6MVWWWll9YiQ-C9E2TE_Jx4,1089 +torch/include/ATen/ops/_sparse_log_softmax_backward_data_native.h,sha256=chj0zZlRmOFtL5Yny2fVqDx1r8U7lGim3aBpoeItVPs,925 +torch/include/ATen/ops/_sparse_log_softmax_backward_data_ops.h,sha256=BZnwxZzgo54CJr_mx_7Mh9PeOEMu2NFPJHkIPnwo8Fk,2223 +torch/include/ATen/ops/_sparse_log_softmax_compositeexplicitautograd_dispatch.h,sha256=Q5jQzsaEdxSgIjQAliPy-6yraZDqEo8vDJBMm5r4MrU,983 +torch/include/ATen/ops/_sparse_log_softmax_compositeimplicitautograd_dispatch.h,sha256=nK926vQS1wkEze7XCsuIYnKnKLPrIPkxhkWMXAN3ros,1006 +torch/include/ATen/ops/_sparse_log_softmax_native.h,sha256=FmD_ue5Tub4ITILWVYlEe33XMzDlwYqegmZckr2a_jk,1050 +torch/include/ATen/ops/_sparse_log_softmax_ops.h,sha256=MhgnExurM6HyEd6iJMWbGVBNa-aAGDLS50jGRLOGDG0,3396 +torch/include/ATen/ops/_sparse_mask_projection.h,sha256=2wqLAFO9zPdgcnbTgVaIFmLlJ9hNai0VsVlM-g4iF7Y,1292 +torch/include/ATen/ops/_sparse_mask_projection_compositeexplicitautograd_dispatch.h,sha256=r2_yQRB_oOghptPr3LQb0lPU0sXHBFkWwxslrMtqFqo,1031 +torch/include/ATen/ops/_sparse_mask_projection_native.h,sha256=N_hL4hO00j-g70Hnev_wmq3wnmO-exE2hIx7f05EEj8,715 +torch/include/ATen/ops/_sparse_mask_projection_ops.h,sha256=B5d2zTy80MDDQz-X9tRTZk-hBE_jIJjvlALfC1i3pZM,2029 +torch/include/ATen/ops/_sparse_mm.h,sha256=ZTfHiXQG4S-yitMAAIram-YyFAdFy8SSgGcYr1XNFV4,1001 +torch/include/ATen/ops/_sparse_mm_compositeimplicitautograd_dispatch.h,sha256=dw3egYh8dhms2-7F4yWWayzPxh__lAiXBYpa7z9Zi_M,931 +torch/include/ATen/ops/_sparse_mm_native.h,sha256=d3V6xm7lHAJkmdqWX_xPIWJ0THmO30EP6qvy-Ep45GI,641 +torch/include/ATen/ops/_sparse_mm_ops.h,sha256=1Rn8TgGVtXSPBBNZIb-Cfgej4ddUnfYay4RXZIVHLd8,1805 +torch/include/ATen/ops/_sparse_mm_reduce_impl.h,sha256=s8ZcY0L8p9-Lu5B_NBA19DUimsvibzRnIbsRgx_BwyI,859 +torch/include/ATen/ops/_sparse_mm_reduce_impl_backward.h,sha256=Xk0WCF17u9EdkW_A5Ci590EUFce1cSL56GPAACQFTfo,1075 +torch/include/ATen/ops/_sparse_mm_reduce_impl_backward_native.h,sha256=xjTmSV36qXzud8hBvl-5Nph0fVbbp5tTFsxLD2nev30,705 +torch/include/ATen/ops/_sparse_mm_reduce_impl_backward_ops.h,sha256=NOyrk3jOxdObVgwP_rUdjYnGKp0lNQWMqWQahvCRygo,1613 +torch/include/ATen/ops/_sparse_mm_reduce_impl_native.h,sha256=KbbOSs_RA9nF0UUsVVEaXY-h9Li4LQtW_T_3iNAKijk,604 +torch/include/ATen/ops/_sparse_mm_reduce_impl_ops.h,sha256=RLQaRsFfcSpMW-WJVtcHC4GJJWgEBDZnRNYC6G1wYAM,1285 +torch/include/ATen/ops/_sparse_semi_structured_addmm.h,sha256=8zV4gRfnLZrug398T6WVGyL2ED3bbM8PvcWhzFFlHRc,1109 +torch/include/ATen/ops/_sparse_semi_structured_addmm_cuda_dispatch.h,sha256=BI2xROQrCCcBtc1wfBWxP7qcaCxxrgHafiqVFgIPoDU,962 +torch/include/ATen/ops/_sparse_semi_structured_addmm_native.h,sha256=TiQty0WBpGPxMgpKAWQmuGtZUiLMeE3dtROxchiZcnM,714 +torch/include/ATen/ops/_sparse_semi_structured_addmm_ops.h,sha256=ddqiRow_Pbq8y8kTcYDQnjNqxs1jpEnJA5x-JAg9Rjc,1645 +torch/include/ATen/ops/_sparse_semi_structured_apply.h,sha256=0oUvG7GYyIG0h00PteTb9lUB9Ir9wR7Nj-wipli1AXQ,866 +torch/include/ATen/ops/_sparse_semi_structured_apply_cuda_dispatch.h,sha256=8kYwnBAL12xaXpHkJgKqxG7g0lcsCrSGlqHTlC7Y80M,827 +torch/include/ATen/ops/_sparse_semi_structured_apply_dense.h,sha256=l4tqeFfNLfmmB5qCHLlAZX2zfhOMiYMd5KtglMeQDcI,855 
+torch/include/ATen/ops/_sparse_semi_structured_apply_dense_cuda_dispatch.h,sha256=RUV5HDeYk2gXocOnqJe_gwuBoROrxt7s7uguzm1fyfg,808 +torch/include/ATen/ops/_sparse_semi_structured_apply_dense_native.h,sha256=LMGDvGk6EH5cjL-zXfZ7UmgHz083EphUEHGMpAFaFvQ,560 +torch/include/ATen/ops/_sparse_semi_structured_apply_dense_ops.h,sha256=_0XLABgK8rO0AhifggkHHXgT5MaXsYlRzYlbGtTZAcc,1183 +torch/include/ATen/ops/_sparse_semi_structured_apply_native.h,sha256=S9ViNOrM1mbIyjNOaT25ZnsJk8eaWkAarym0E3QcKuQ,579 +torch/include/ATen/ops/_sparse_semi_structured_apply_ops.h,sha256=YSkOAiNhWsKY6t7jS_Y6nAfdI9EZ5qjQebHzjMtCA28,1250 +torch/include/ATen/ops/_sparse_semi_structured_linear.h,sha256=pDSU2izSwtHq7HsYDphqS8TRVyMN57Y9RvPDa02rgzw,1126 +torch/include/ATen/ops/_sparse_semi_structured_linear_cuda_dispatch.h,sha256=amPkDOHCJ-GHCt_9eGZhBVNgH3w8KdgYQYNptdrXU14,986 +torch/include/ATen/ops/_sparse_semi_structured_linear_native.h,sha256=USG3qpBgqlq39WsRjJnmB_zzF13gBO5AimrZH52bJHU,738 +torch/include/ATen/ops/_sparse_semi_structured_linear_ops.h,sha256=UCwdJ6Ud9nYQE22pAG-zPmGQZNm1LqYXLwNfzH-VuIA,1672 +torch/include/ATen/ops/_sparse_semi_structured_mm.h,sha256=Qja2TLpmc9OR6ZaGijYu5Tcz2pEqjrxOuK9Q40C39Cg,951 +torch/include/ATen/ops/_sparse_semi_structured_mm_cuda_dispatch.h,sha256=WXy13Roq_AG2IajteOWOJX8LjcLOyZFH2dHg8SnwtzA,878 +torch/include/ATen/ops/_sparse_semi_structured_mm_native.h,sha256=sIcvXxwx2CoFtb-r2eg0fVQCXBb59HL8wVhcbCNBoL0,630 +torch/include/ATen/ops/_sparse_semi_structured_mm_ops.h,sha256=xWAPkcZCcUdODphAd0qdThs4yiXK7Pmp-ndZAItsfBE,1377 +torch/include/ATen/ops/_sparse_semi_structured_tile.h,sha256=x38Dz-sJBE13r7NVlkhsnCJXVcuOcIk1vEE3OutSiII,970 +torch/include/ATen/ops/_sparse_semi_structured_tile_cuda_dispatch.h,sha256=5oI682P0DPsCqmbvGPFiGzZAZUp8nqjBCidohXj97M4,880 +torch/include/ATen/ops/_sparse_semi_structured_tile_native.h,sha256=-SFQTsn14Z2nJxBNAhjm88NXzQQhbT8UGEfGD3s0Cxg,632 +torch/include/ATen/ops/_sparse_semi_structured_tile_ops.h,sha256=5O2tuayjPKhnobMxqCLB_cgZ9WW7Mbs-sz2xC9ImucU,1422 +torch/include/ATen/ops/_sparse_softmax.h,sha256=iKK8JWu_bPFLI-nR6P855Wo84vg7wyCZkvDh8MP6720,1996 +torch/include/ATen/ops/_sparse_softmax_backward_data.h,sha256=Bw41nKFQ315Ng89yZAW5yDl3StT3sQWhyn2v571xG_4,1723 +torch/include/ATen/ops/_sparse_softmax_backward_data_compositeexplicitautograd_dispatch.h,sha256=tJtTxLguHX-vITr7GcElLGeGPZf9bwkERmpTM8vxbWM,1081 +torch/include/ATen/ops/_sparse_softmax_backward_data_native.h,sha256=Ab0RkrIw_VyXqygdVXl7jr8LIbRKLM7NLwv0RqfkzIo,913 +torch/include/ATen/ops/_sparse_softmax_backward_data_ops.h,sha256=nsy85eS_oIVJp0Ix2Fqm9e68iVptZB-RkxsbwQR8DuU,2199 +torch/include/ATen/ops/_sparse_softmax_compositeexplicitautograd_dispatch.h,sha256=6zvPrdFF3Ek3xC7d_C0sTFBLzg_NLByXAlMTfZA-HAk,975 +torch/include/ATen/ops/_sparse_softmax_compositeimplicitautograd_dispatch.h,sha256=cRr9x6RL-HiwoH5ZaQEvTbri19nAGKwrKC5I6nTvKNI,998 +torch/include/ATen/ops/_sparse_softmax_native.h,sha256=pJZA9DkcgrjKWJiWXRdS-JkZqXwIxhz49Us_abbo6aU,1030 +torch/include/ATen/ops/_sparse_softmax_ops.h,sha256=xSPLvigVIoyA1WLnGZg-1S4LJIYKNoFxIFq94f353I8,3348 +torch/include/ATen/ops/_sparse_sparse_matmul.h,sha256=Vq0Ed7K8AZpLUTZpkA9q8pG7BG_uH0va5V_pq9Y0YrM,1358 +torch/include/ATen/ops/_sparse_sparse_matmul_compositeexplicitautograd_dispatch.h,sha256=B8dQcFiWfvp_VT0BWj9Jr1IRiPeQ9tltYwGMu2THDRY,973 +torch/include/ATen/ops/_sparse_sparse_matmul_native.h,sha256=OIwtvwtFuKGiEXFH0xZwsSx8W2huaSAv0Wfcvzlwmzg,761 +torch/include/ATen/ops/_sparse_sparse_matmul_ops.h,sha256=LZzDftf98sPk52vEKCpNJJaP351BUD47BEp9HCbx8XI,1849 
+torch/include/ATen/ops/_sparse_sum.h,sha256=4vpHGylshqZLSntwJ4BRXcgDBtYYp7u3PBkNK5SfOV4,1890 +torch/include/ATen/ops/_sparse_sum_backward.h,sha256=wt1t3PP0KdjMvAnAl5Kh5JmrP1KTgZYNXdShgC2wx0c,1450 +torch/include/ATen/ops/_sparse_sum_backward_compositeexplicitautograd_dispatch.h,sha256=abvZ-TD3riT971nlXuoYLVXuG3SIc0f5yAUakhl2368,1011 +torch/include/ATen/ops/_sparse_sum_backward_native.h,sha256=HX8Ny2sHBHdVinRQTuB8OnRzJDLZHOsTDD4lBOLzNqw,820 +torch/include/ATen/ops/_sparse_sum_backward_ops.h,sha256=dDCCH5BB4KFO3PkyLUUDycGOZA8w7YaFdNiF1vDGFhc,1977 +torch/include/ATen/ops/_sparse_sum_compositeexplicitautograd_dispatch.h,sha256=05D-cMTs5xT7JGUwhtB3KZOtNRMSV7uvXhIMtuGTOgc,1024 +torch/include/ATen/ops/_sparse_sum_compositeimplicitautograd_dispatch.h,sha256=JiAGRiZ-83JjLZkL1tdTdrpTDvAjDOn2P0VzxwyoYKM,977 +torch/include/ATen/ops/_sparse_sum_native.h,sha256=bwrn50bgnA1sa5qlFbMybSUzMds2cJbvTaFuSdoASFc,877 +torch/include/ATen/ops/_sparse_sum_ops.h,sha256=_2f1QV2kIc8BIRTJ3DI9rDw3CtrAbjhCahQYsm0utSA,3639 +torch/include/ATen/ops/_spdiags.h,sha256=5XrUM8LFtOa09uHer6IHt0356sytmBKcHgnFlIqR0zc,1645 +torch/include/ATen/ops/_spdiags_compositeexplicitautograd_dispatch.h,sha256=rbseLJ5kCyYEdHgry09i9ZPM5W5QGnFiATF4WOby5Eg,1094 +torch/include/ATen/ops/_spdiags_cpu_dispatch.h,sha256=8XTp063BgbUE--CkDfYRESUt9iaTxmiiK_uqYomOG6Q,852 +torch/include/ATen/ops/_spdiags_native.h,sha256=8adqKjUCAKX_x8UJ7_DAGUuFu9aIvCZnzQgJOPLOsvs,778 +torch/include/ATen/ops/_spdiags_ops.h,sha256=M0LWneOAwLWpbiL7tD7ToBw1AQVLTIcsuS0wrH93SYo,2209 +torch/include/ATen/ops/_spsolve.h,sha256=766qrNXnSj-fVORuNUoLZG-6lL9n9suCihqkTgyoSsA,743 +torch/include/ATen/ops/_spsolve_native.h,sha256=gQbMW0mLmPwkijsvf2J4xbjUZlRHQyaEiEOEDZt8HaI,550 +torch/include/ATen/ops/_spsolve_ops.h,sha256=0YYlL241xv_IxKfJ0pqHmUuuEOe4AQLy3uHkRRMJpEM,1104 +torch/include/ATen/ops/_stack.h,sha256=-3RyAXhDbiNUfdnCcmNkCVk9gmiEunYCphG3Uc-8tL8,1179 +torch/include/ATen/ops/_stack_compositeexplicitautograd_dispatch.h,sha256=akeMUmb7XjXuv2ywqLjpGQ09LaQcZG_V9ahK674Qas8,986 +torch/include/ATen/ops/_stack_cpu_dispatch.h,sha256=wLS5eXTdh5JE6d2f3A4JekqG-HHjyc5gnd-cECVJh6k,942 +torch/include/ATen/ops/_stack_native.h,sha256=1nYYxmRvQaS4R3gQYs7cgeIj1Sg286g7ragzObJ4g6Y,770 +torch/include/ATen/ops/_stack_ops.h,sha256=pilueWz4Fo1Azb07jsNNmYAZ5tLVrRLKZaiLM9BjSkM,1677 +torch/include/ATen/ops/_standard_gamma.h,sha256=kRXUDkw6ip6G11O0BHrSF4gSptq6QqyaE5P7B3-oGQg,1427 +torch/include/ATen/ops/_standard_gamma_compositeexplicitautograd_dispatch.h,sha256=cwS65LOPPY3V9Sj5jdZxomHzhydy0DxOGYm84aL1hfg,1008 +torch/include/ATen/ops/_standard_gamma_cpu_dispatch.h,sha256=fT6vdOOPs1hR70IcZ9_m29kniUcVxheQbDdRbRYYOCE,809 +torch/include/ATen/ops/_standard_gamma_cuda_dispatch.h,sha256=sr-WOQ4Va_jmXzAkniiAt2DOud6kYJ2LJzoXd_kuuIo,811 +torch/include/ATen/ops/_standard_gamma_grad.h,sha256=2EqMOT-eMGVw0-joMrlc54o7EuWZmRxrqKg6xlPOY4Y,1357 +torch/include/ATen/ops/_standard_gamma_grad_compositeexplicitautograd_dispatch.h,sha256=ay3mewn6Kn5PNl_OOJ0xTC-7yCU0kMANPwImQTJaHr0,973 +torch/include/ATen/ops/_standard_gamma_grad_cpu_dispatch.h,sha256=kdHScbqofZMJQFMumnj-dixJ3vLxGoa5H54nipk8M5k,784 +torch/include/ATen/ops/_standard_gamma_grad_cuda_dispatch.h,sha256=jUwz2rKzW5LAhL9pdtHdp3Vv8B_LLFT3IAw7Da_QX0A,786 +torch/include/ATen/ops/_standard_gamma_grad_native.h,sha256=kI9-wBLsomSLd66m99_rJPf-v6qI5lix8T4LzgFSJEE,763 +torch/include/ATen/ops/_standard_gamma_grad_ops.h,sha256=3G2SPnjG5oN9dlYJdI4uIM2hNvzPZ6jS_KQn2gSE6z4,1849 +torch/include/ATen/ops/_standard_gamma_native.h,sha256=eR8-bdhH9zW_cvyMEF0HHYdUXN41R5yOnbpGeqUpCds,809 
+torch/include/ATen/ops/_standard_gamma_ops.h,sha256=u65jIpU1lL9eVkl24TUvIoCa9hRIj54DKjmudJSuZsk,1927 +torch/include/ATen/ops/_test_ambiguous_defaults.h,sha256=X4Fcd-TFRhbB2NPBKzDru1x-7to2-4JEJvv2aW1T8qM,1055 +torch/include/ATen/ops/_test_ambiguous_defaults_compositeimplicitautograd_dispatch.h,sha256=lbV5ApVqzNPaIyx4qi-6EWTj8q_Lp4PH4QmyirmXEYU,937 +torch/include/ATen/ops/_test_ambiguous_defaults_native.h,sha256=w6Qb4W84CdYK1H-ZYG2OX2clLT382FuEN5fNqXi6nHk,653 +torch/include/ATen/ops/_test_ambiguous_defaults_ops.h,sha256=iWD79hMk84Fm-oQCn-OM9DZreJoDMukhzrVQt4U2CTQ,1812 +torch/include/ATen/ops/_test_autograd_multiple_dispatch.h,sha256=55Hx6Ejn12ggflLC9utpB-RmiwYnpUv0m0JQp9lpFqw,1656 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_compositeexplicitautograd_dispatch.h,sha256=Hb2TALfqnG3nCz8g3XPzkDzrf668hgiDGfp5-Pc4IXI,1024 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_compositeimplicitautograd_dispatch.h,sha256=aTgITMF1Ic7Ogwoe_utobgRVmNO_18eK3EktVmGlCq4,821 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_native.h,sha256=hDqaxkESQaM25oGoGgjC1KWFa4rYluj-ToE6pkne6i0,750 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_ops.h,sha256=vYJXzwEBwDawjOi1F-JAlc4L2But_ijAriabSCbf_7Y,2459 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view.h,sha256=w1n35FBCidN6QaPsMOK88CE5-SEBBFvAvhzrVf9O0ZA,798 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_compositeexplicitautograd_dispatch.h,sha256=T8Jkg49OOHPewIPkqBs9L8y0jYWIutTwdfi7b5mp9W8,818 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy.h,sha256=wKfp-wMrOYOEdEdx0JUkF47PajFI7Gu9XoF367qYTkg,1427 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautograd_dispatch.h,sha256=38_VAi0T86Zbsl1T8guRKLarDCZ-90_r9LMSHhVvq5M,963 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=GswrOWUgk30YjcVvpBUWw5XUnLS7DSHcV1DLnLiTPOU,849 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_native.h,sha256=hx4rcTHE2PNALQXnJgSy3zKXcnFlWz3Lh9rkX5jrp8Y,648 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_copy_ops.h,sha256=erOfqMwU1DGJQyycujAO3nK_ohtaRg8p4Sblo5vqoTo,1803 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_native.h,sha256=Us9PMaAB-Fb-W4Gab9ZHr3oDMwOd67ZZ6_GBFBlsvNg,528 +torch/include/ATen/ops/_test_autograd_multiple_dispatch_view_ops.h,sha256=qIpnHgOqdPFh0J2CkwDB0HtbUlKmWqpVJGQbKIHcTuM,1085 +torch/include/ATen/ops/_test_check_tensor.h,sha256=FLCN7T1IIPH5iikAy-4NxGkZwzFGd7ICUnXYdD3_kvQ,716 +torch/include/ATen/ops/_test_check_tensor_compositeimplicitautograd_dispatch.h,sha256=EzWc3pxBc9xwMIyUitmXI6BRX9Eu072YuO7y8_sefd8,799 +torch/include/ATen/ops/_test_check_tensor_native.h,sha256=4zalP-_JXKh753gq-RH3MvUbWwFOfTZw40qZGf0XiRI,509 +torch/include/ATen/ops/_test_check_tensor_ops.h,sha256=ujFIgvX9cZ25fHf5--vLCMhFNWk8KeW_xR2z4J1sE6k,1022 +torch/include/ATen/ops/_test_functorch_fallback.h,sha256=pccpwLMg4uxISy-gPoY5VPrQHrPICwItDwRfZeu53Wc,1388 +torch/include/ATen/ops/_test_functorch_fallback_compositeexplicitautograd_dispatch.h,sha256=6XmxuB_l0ER5t5Z8GGur5rrW4Kmh3m18Ob5b3kCHBr4,979 +torch/include/ATen/ops/_test_functorch_fallback_cpu_dispatch.h,sha256=h3Me7F1VwAp6qpEzH7gNmG4ULN9HiP-Lxr4_anHnvVY,787 +torch/include/ATen/ops/_test_functorch_fallback_native.h,sha256=7qocT_jdUM72U9o5cs6mZenGaZCWTgcuqfaiA6-kl-k,664 +torch/include/ATen/ops/_test_functorch_fallback_ops.h,sha256=vblXbG-CcJSxleAQ8pBMVtR_Zzwmv2eIt7jyhG2YWR8,1867 
+torch/include/ATen/ops/_test_optional_filled_intlist.h,sha256=0bPbj7PglgJZED6NCI2Du_nla25DlVjpiB6oaqgKEz8,1492 +torch/include/ATen/ops/_test_optional_filled_intlist_compositeexplicitautograd_dispatch.h,sha256=-m8oALEcvSTu6AT6Q76XXTHwWcP7gr0UVj-g52cTGaE,1007 +torch/include/ATen/ops/_test_optional_filled_intlist_cpu_dispatch.h,sha256=Ikfwc8hgdp28yJwWI9EsWTpzvArgXAo9THeGGLx3N-U,801 +torch/include/ATen/ops/_test_optional_filled_intlist_native.h,sha256=j3dPNQBtnpRXjUNcFuDzpm55hmRUXxjoBaKer9WyFv4,685 +torch/include/ATen/ops/_test_optional_filled_intlist_ops.h,sha256=0r0G-jL1az0M-HApdQ4C5PI-AVW9Zd3PGEj4e-LzF9U,1953 +torch/include/ATen/ops/_test_optional_floatlist.h,sha256=Dfxigx-RZ9oOmt2IyGT6RnirPAsW5R1LHet86sNT078,1487 +torch/include/ATen/ops/_test_optional_floatlist_compositeexplicitautograd_dispatch.h,sha256=UXJSzTZVOKi4VL5z-YQtOcFgsVBDOjuXZQnkeQqIgEE,1025 +torch/include/ATen/ops/_test_optional_floatlist_cpu_dispatch.h,sha256=UTn4xAEjR4iPOK5FITWQcztwgsq_VO8qE_LHxAb_ibc,810 +torch/include/ATen/ops/_test_optional_floatlist_native.h,sha256=F361sKixryFkT6naYs9GzJaSSYrlSGTcg-DQFTpE4YE,710 +torch/include/ATen/ops/_test_optional_floatlist_ops.h,sha256=S5Yg6kommgc-9xQ9WrAJT60UL_R4jnntVPpshH-Do90,2009 +torch/include/ATen/ops/_test_optional_intlist.h,sha256=1VTRrKdOtHz53yBcftQoetTfgiTL2qE5J8EJVKdSg8A,1419 +torch/include/ATen/ops/_test_optional_intlist_compositeexplicitautograd_dispatch.h,sha256=KwRih7XmNG6qQRT5ev8SkOSa36xFiEn2PC3bn1IxSBw,993 +torch/include/ATen/ops/_test_optional_intlist_cpu_dispatch.h,sha256=_559kE083pRrstKGSUSYCE6hAY0jLsD5LUellOSpZq0,794 +torch/include/ATen/ops/_test_optional_intlist_native.h,sha256=2bsMJz8IwoHpKsvC60tiQCfKS6miOuKfturuO_tDZjo,678 +torch/include/ATen/ops/_test_optional_intlist_ops.h,sha256=6XUYIx3C86_8grM-ClSQn56WA_VIbss6l8TzfEI63KI,1909 +torch/include/ATen/ops/_test_parallel_materialize.h,sha256=NHRcI5tg5vzxv5EEfIpQvLmQOK372rgT6rG79tzwTEk,860 +torch/include/ATen/ops/_test_parallel_materialize_compositeexplicitautograd_dispatch.h,sha256=TCzeh1rtzCO4DPc9VO8TvY9a_yP4ylqYsARFZT4VunY,852 +torch/include/ATen/ops/_test_parallel_materialize_native.h,sha256=HofiSlfBCroxghgXTrCXEICKwx-Iask2PKfgZX0NahQ,562 +torch/include/ATen/ops/_test_parallel_materialize_ops.h,sha256=e1eWOC7oLAfZoQ45MyYNi1ZF6-CxSB-V9kDgohPOSTg,1180 +torch/include/ATen/ops/_test_serialization_subcmul.h,sha256=GpaRJJrMzbvMZHgFoadN9cmD7y03fiL0MCGyJj0OkYs,850 +torch/include/ATen/ops/_test_serialization_subcmul_compositeimplicitautograd_dispatch.h,sha256=az34ikhVLSaY_Y4X_1mptmSxCxHIliSmsaS1DTKM80c,862 +torch/include/ATen/ops/_test_serialization_subcmul_native.h,sha256=tpPDSsMyA20KJVgp_7Nuy4mw8xugihSERzTSq3vRnFw,572 +torch/include/ATen/ops/_test_serialization_subcmul_ops.h,sha256=qvHIEYGdmI8pDNLbhkZOTaMKsjvRgiB3prMEz-2ETrE,1223 +torch/include/ATen/ops/_test_string_default.h,sha256=U_lVa7YcsXI9H5xD5BcnbHbnrl5g_rRR7mA6bZJpysk,819 +torch/include/ATen/ops/_test_string_default_compositeimplicitautograd_dispatch.h,sha256=9rCIEdNBzF6tYB1odbpUyd5LY9KCR6IlfduI9l64IBE,858 +torch/include/ATen/ops/_test_string_default_native.h,sha256=FutEvcEeI-lsSgxExibJKpt99RG37I-D889bX7mlWoA,568 +torch/include/ATen/ops/_test_string_default_ops.h,sha256=yWfgNeQnw9gs1e2lqCVFb5UjP--vbBl5tUAvsRZY97U,1187 +torch/include/ATen/ops/_test_warn_in_autograd.h,sha256=BAzKS0JfXc-K8tF9AXWJwkhdYhPegRopMP3CEFs0B5I,1227 +torch/include/ATen/ops/_test_warn_in_autograd_compositeexplicitautograd_dispatch.h,sha256=dro6vjGea8plVW8GWU5uMvOsXZfwZ9odz2wlCwtMtXs,994 
+torch/include/ATen/ops/_test_warn_in_autograd_native.h,sha256=YTU6mOoCVzIEIULm4YecfG0clKPJwduZrbkdOxON9y8,608 +torch/include/ATen/ops/_test_warn_in_autograd_ops.h,sha256=VeCOotQZJ2VJz3A-OqExel7VBNVeM1dJYtuP67Xqi-k,1683 +torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward.h,sha256=ieHlG6KrkBWCGxxzbM5h83AmCi8zsm7ULeA15nTQq9A,1231 +torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward_compositeimplicitautograd_dispatch.h,sha256=H4I9K2i5__CpLeYFPqXcv5UZm6DB-ptu8pTul6ugQj8,1065 +torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward_native.h,sha256=51K1r13Jx4U20r4GDHHoh_Z-OXSDg2Spz3IAudvG7hg,775 +torch/include/ATen/ops/_thnn_differentiable_gru_cell_backward_ops.h,sha256=ow8pNbhy7vG_vwYl-GmDHJ2VWLJWaPyGkVuml7irq-k,1896 +torch/include/ATen/ops/_thnn_differentiable_lstm_cell_backward.h,sha256=h4DloXqXxcc-1lzz91b6PGWzbFUWuF9eFEQkOpkteHA,1362 +torch/include/ATen/ops/_thnn_differentiable_lstm_cell_backward_compositeimplicitautograd_dispatch.h,sha256=zfXwqdEUB81oiv66dGsJoJKYX6hRTLmvPrXm4wrZRjI,1151 +torch/include/ATen/ops/_thnn_differentiable_lstm_cell_backward_native.h,sha256=2w47Qwn9kVCPd-LSHLh7MJHaDFlWTe7gGxewWgryErA,861 +torch/include/ATen/ops/_thnn_differentiable_lstm_cell_backward_ops.h,sha256=p3VtbhHs0yGqGp8dRfx9invvx0JmlS2PjLvAMbtObWQ,2172 +torch/include/ATen/ops/_thnn_fused_gru_cell.h,sha256=1PZ2E_ePmCy5-B8xVA8tJuxB3TENXVlvcio_bv8tQmU,2326 +torch/include/ATen/ops/_thnn_fused_gru_cell_backward.h,sha256=6Qfc88Gke81-R25PaDbaOo3NJXFhvenC2gka5U_RpCI,2287 +torch/include/ATen/ops/_thnn_fused_gru_cell_backward_compositeexplicitautograd_dispatch.h,sha256=fJBvcFNCA3qr5PfCOnGQdT7DOKL6JbVctsxBreIl4uc,1319 +torch/include/ATen/ops/_thnn_fused_gru_cell_backward_cuda_dispatch.h,sha256=bYhE5wQHGtazUaGXXpqVFIOrLb6_5uoNQyFkD5jTNO8,874 +torch/include/ATen/ops/_thnn_fused_gru_cell_backward_native.h,sha256=j9U4-1XiyGPZrhMmRKxgXA1-_o34uPdHKBDcMSBTOwQ,924 +torch/include/ATen/ops/_thnn_fused_gru_cell_backward_ops.h,sha256=0wNUWMw0UqRJdrhUii_d9tG4FsMiO8nQjyS-Kt_lokI,2776 +torch/include/ATen/ops/_thnn_fused_gru_cell_compositeexplicitautograd_dispatch.h,sha256=5Khrk2Oca8jVaY0lznTX6kJE-uITIqrashbZIipxcH8,1339 +torch/include/ATen/ops/_thnn_fused_gru_cell_cuda_dispatch.h,sha256=jgR0K4lywIwb7Cq3wOdai4gTZirlSNug1TY-i9IViso,950 +torch/include/ATen/ops/_thnn_fused_gru_cell_native.h,sha256=Kc0wm7m1sDuJAsYIbrU8NAZ6c3BTZXJaZHtySM9U-tA,1007 +torch/include/ATen/ops/_thnn_fused_gru_cell_ops.h,sha256=Kwb3b4LWJWJiY8T_Sx1OxYbVBtVwmuSxUvMqo2-JvSY,2971 +torch/include/ATen/ops/_thnn_fused_lstm_cell.h,sha256=pqvASP6SLij8qfqcev_L2-sdcckY6BH9ow7zWaBa50o,2489 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward.h,sha256=nDLp1kUR8GXk8itCcihD3Ur6NvaEcUDVAYLgKlGGaIM,1129 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_compositeimplicitautograd_dispatch.h,sha256=c6f0Q-ZINzU0e5HsYA0qOoAhXBLF4ldXyva_UMGzT7g,1025 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_impl.h,sha256=xQ7z_xY7E1bUPG8Zi3onbqP2a3Y0HmUzILluiJHEYoA,2536 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_impl_compositeexplicitautograd_dispatch.h,sha256=b7uvMDZHVnsxjRPoQ2HN-MDE7m5vdNnVzNCTF8KzUAA,1419 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_impl_cuda_dispatch.h,sha256=n3MGgXNzMCraUNoEctsLa1A85ClzIkVdZqx-9YUdKRU,966 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h,sha256=lXvwojDcjgkIEDPtkJT_tKgCyPjjI1qsLRUtop_VnWA,1066 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_impl_ops.h,sha256=5TcOD3G2VmCap9wcfC13nnpJj1DvI5WeeHA5WVnRtXc,3190 
+torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_native.h,sha256=bI1ulvaqC1aLCgod5LTON9yYZK0vcG5xMwNIu7N9q_4,735 +torch/include/ATen/ops/_thnn_fused_lstm_cell_backward_ops.h,sha256=zaCthQuKWemfQxsfk9g9btTAUGZ8mWpdhCW3K3SVXsE,1774 +torch/include/ATen/ops/_thnn_fused_lstm_cell_compositeexplicitautograd_dispatch.h,sha256=FJi3nn95HK8Y8cvZfaonXO8Ggsd_84uZeU1tAZYdJxs,1405 +torch/include/ATen/ops/_thnn_fused_lstm_cell_cuda_dispatch.h,sha256=fdkSN6fspT1qOX8wdvedlZqhmA0OgIH3owZKLPMiHTk,962 +torch/include/ATen/ops/_thnn_fused_lstm_cell_native.h,sha256=-H4k4MsoA36c0kPF7yoofI1KYCcNLGTsMsOWGxxUtDI,1052 +torch/include/ATen/ops/_thnn_fused_lstm_cell_ops.h,sha256=nS-0ckFhQaEte5wNc6j2TwfHOFtOxoRYnfN2hJ1x1TA,3138 +torch/include/ATen/ops/_to_copy.h,sha256=y4j6j-zURoWv6yoyBHdSwS6TO-0BIVgEGrtZEyPpNL0,2483 +torch/include/ATen/ops/_to_copy_compositeexplicitautograd_dispatch.h,sha256=dLoqngAuFwlog8GsmzFYpp66ricfZaZtIVutFLppamU,1498 +torch/include/ATen/ops/_to_copy_native.h,sha256=yd335QfOXU6PKHnSUuk7dmxnY17x5WG-vxZUqibiGA8,1204 +torch/include/ATen/ops/_to_copy_ops.h,sha256=57gUVBaO29v5sA2X4zuTP3343OO-OSzZUpjs_mxSORo,2567 +torch/include/ATen/ops/_to_cpu.h,sha256=f93tOWDbXuz0MOL1rPYXY9Ttl_dkZibHc7AgOnDrJLw,696 +torch/include/ATen/ops/_to_cpu_compositeimplicitautograd_dispatch.h,sha256=Sg8jtKVY9SC9POO5xeord6t98meHMwGPM7V4PAS2FVo,802 +torch/include/ATen/ops/_to_cpu_native.h,sha256=EamW0i0j-8VAh4ixsPRfZHKcvwWdobz7aBYJmn6La34,512 +torch/include/ATen/ops/_to_cpu_ops.h,sha256=t3J8VtOzZiVGWSMMTK6xsNJBfoPNqizpA49BI3uvtDw,1035 +torch/include/ATen/ops/_to_dense.h,sha256=6TFzkVdG4yD7PspCxLxchYSHGw66T7iJ7_q8xaQhUYE,1262 +torch/include/ATen/ops/_to_dense_compositeexplicitautograd_dispatch.h,sha256=M6mWyTL2ckHTIFtBUjgG4VXKP-iKfqqdV8i18JjqqwU,1075 +torch/include/ATen/ops/_to_dense_native.h,sha256=2F4NBLR6MEOAPi9LGmjofRgcv9_OjbmdctNwJ3Lz6yM,1113 +torch/include/ATen/ops/_to_dense_ops.h,sha256=CAAqFuT1YxThf-znvmmHhqhw1DfBhdwdAONMQeV-hZQ,2109 +torch/include/ATen/ops/_to_sparse.h,sha256=8lY-zfBqrL1z3I6SSSlzfcXsEZ1lB5udtq8MKziX-Zg,1969 +torch/include/ATen/ops/_to_sparse_bsc.h,sha256=-6rZHHxXEp-A2hhoH2_7A5Fl2GR1-4LYWbfhYiSJLEo,1246 +torch/include/ATen/ops/_to_sparse_bsc_compositeexplicitautograd_dispatch.h,sha256=NCYyJn0lb-gGLJibLeKeYw-lpqVKa5k2bxwuhtYB6VQ,1048 +torch/include/ATen/ops/_to_sparse_bsc_cpu_dispatch.h,sha256=p7eClLUtsPcE3vhRoteFnVw2XFJO5CX-SHCzapksgvo,829 +torch/include/ATen/ops/_to_sparse_bsc_cuda_dispatch.h,sha256=FnhpPy73W97JjjmCtZcj9D7IZtVhozlPD09pV3MqGl8,831 +torch/include/ATen/ops/_to_sparse_bsc_native.h,sha256=LAwrHhPv8HwdijUnUZh4eW9_UuMgGISVbVRNjuetDEs,1040 +torch/include/ATen/ops/_to_sparse_bsc_ops.h,sha256=Vq2Gv4CkgZbkqBEdBLkvr3WDjJoF3TUjDlSHdVBo5bY,2051 +torch/include/ATen/ops/_to_sparse_bsr.h,sha256=QwpztDCY4dTzBv5JkMVGBx-mMJ462FPkoHfw6UUbqeQ,1246 +torch/include/ATen/ops/_to_sparse_bsr_compositeexplicitautograd_dispatch.h,sha256=fd1vnI7iZyESkd9pbPhIGWMuoWsiGahnv71Jm344W5Q,1048 +torch/include/ATen/ops/_to_sparse_bsr_cpu_dispatch.h,sha256=g4It0_TuIl8oTYDDDosaoKvOx9HcC0JvZVpUSHchl4k,829 +torch/include/ATen/ops/_to_sparse_bsr_cuda_dispatch.h,sha256=b2KGgb5WddSP9ZoOz0eBMQnS6vZtUtnJT4OGsVjsLNg,831 +torch/include/ATen/ops/_to_sparse_bsr_native.h,sha256=pbRO2G59KoL1Th4VyVQRDU5z2xLYodU1g5FSEossuGg,1040 +torch/include/ATen/ops/_to_sparse_bsr_ops.h,sha256=9ZwhThr8gZousKEg_RHkyD_1kirHm74ett2fCpc1478,2051 +torch/include/ATen/ops/_to_sparse_compositeexplicitautograd_dispatch.h,sha256=6phNq8YNczJnM4gSscTxyi1Sq-j1DlcXKCfs7let8UY,1365 
+torch/include/ATen/ops/_to_sparse_cpu_dispatch.h,sha256=Ru_XGpXYwYMAOV--JjPOWldWl7nhNYyHi5csRZyqjq8,978 +torch/include/ATen/ops/_to_sparse_csc.h,sha256=V5fY7UP7B5YnnK9UQoiJXaXPY34ldKB9KIPNTkMKyUU,1134 +torch/include/ATen/ops/_to_sparse_csc_compositeexplicitautograd_dispatch.h,sha256=QMWV2TSHDJMPRzhyzrLndceSYS-yai0d8yJCoGaA6G8,994 +torch/include/ATen/ops/_to_sparse_csc_cpu_dispatch.h,sha256=10KPrtMvHJ7ub9p7ZNv7bKKNLKRONfHegVbtH_8_wbk,802 +torch/include/ATen/ops/_to_sparse_csc_cuda_dispatch.h,sha256=SQndfY0h-oBBSIS8zyUsyaOgQnYr5j9ZIkK0B1mz-6U,804 +torch/include/ATen/ops/_to_sparse_csc_native.h,sha256=HqTuFUgeRQlikDaBeEq8XeUESyrPUDscmhW0mbx1XxA,932 +torch/include/ATen/ops/_to_sparse_csc_ops.h,sha256=-fNYxdWAAJ7211IgWJJA5y-eDsVs3IRWTK4l9Mab6BI,1873 +torch/include/ATen/ops/_to_sparse_csr.h,sha256=n3JcOzs2xiLgE5D-JsPMZx_lPC4dWJVj5x6CCNesfKw,1134 +torch/include/ATen/ops/_to_sparse_csr_compositeexplicitautograd_dispatch.h,sha256=YOierIH8MIkMctmhk2hadw6_tGOjrG_C7_EDVa1X0PU,994 +torch/include/ATen/ops/_to_sparse_csr_cpu_dispatch.h,sha256=Ifqj8CJSckAq-aoT3kUZ6ZVgFojDol27-vWAOH3K-M0,802 +torch/include/ATen/ops/_to_sparse_csr_cuda_dispatch.h,sha256=4Q6gNqasS8M5ivktppDmtDOumCNDIehylvh__oUHcJU,804 +torch/include/ATen/ops/_to_sparse_csr_native.h,sha256=FL9PzFEhMEkXILiXOD_CpIjLhAn-K9vznzRNlh2B7us,932 +torch/include/ATen/ops/_to_sparse_csr_ops.h,sha256=bo9-FovFBd8KkCSu037ZA6pTIV_DtfHJYh_E5Qzps0o,1873 +torch/include/ATen/ops/_to_sparse_cuda_dispatch.h,sha256=5nOarLydXet7aS7P-aYCeHOuE6iM0SfMLlbzAvK5FDA,980 +torch/include/ATen/ops/_to_sparse_native.h,sha256=y29qFQdxNDURsgb_wbNFO7tFcWRMJk5oam0eFK_Nac0,1680 +torch/include/ATen/ops/_to_sparse_ops.h,sha256=2fkD75bGeX6GRIG95mdg9Y_ruKCGNv415oIsfvjNLKE,3665 +torch/include/ATen/ops/_to_sparse_semi_structured.h,sha256=FZpdMe3jBtKKDtGIvNo72gXYR3ljNFfg9p-YdSVMC4I,786 +torch/include/ATen/ops/_to_sparse_semi_structured_cuda_dispatch.h,sha256=1Az6SOH3y7WqmkTv2_5yyWBbCLqHiITnx_rbj1RIoAM,791 +torch/include/ATen/ops/_to_sparse_semi_structured_native.h,sha256=KWxV-GPZD9-TKA9Q7Eu7V9FrOT818lWyBPmc2wmgQsE,543 +torch/include/ATen/ops/_to_sparse_semi_structured_ops.h,sha256=r74-eWZkb32F_VguSu9QumSx478zSeSHYcS0uUSc0MI,1134 +torch/include/ATen/ops/_transform_bias_rescale_qkv.h,sha256=cf-wWD6q_fAoyckZXFy41ekpJiK7_97uVy9d7PQcmH8,1931 +torch/include/ATen/ops/_transform_bias_rescale_qkv_compositeexplicitautograd_dispatch.h,sha256=UzUwfsVshmfY49A0yrLb_PyxX4JDIfT6-TQ6ynBijQQ,1185 +torch/include/ATen/ops/_transform_bias_rescale_qkv_cpu_dispatch.h,sha256=39AtC41Bj562tF-g_1VDUeUNQEhrS29meuQif1yHRRQ,847 +torch/include/ATen/ops/_transform_bias_rescale_qkv_cuda_dispatch.h,sha256=DdwCIs51az6hROxfQPA7qyebVMF8dcmfwlrEaULKGu0,849 +torch/include/ATen/ops/_transform_bias_rescale_qkv_native.h,sha256=YPE3zoY8eZVptlWaUff-PwGAUB9_fo5coj-ek9mFB80,993 +torch/include/ATen/ops/_transform_bias_rescale_qkv_ops.h,sha256=IfPUrNXvRcJYquCXuy8-p3xvSMxWjfAWX-7ojL-babY,2434 +torch/include/ATen/ops/_transformer_encoder_layer_fwd.h,sha256=ouiWnwQYj0-DKN4011BpWS6aiE8tRL8wf2lcGxCcav8,4673 +torch/include/ATen/ops/_transformer_encoder_layer_fwd_compositeexplicitautograd_dispatch.h,sha256=5moQKYmfvEkMir7wUYu7PnIE2tQVGxuHtDN-Hy80zhM,2039 +torch/include/ATen/ops/_transformer_encoder_layer_fwd_cpu_dispatch.h,sha256=fWo4HSzh3I8y0TBeL1kFW8ecc2CZxFWDux4VX-W3j00,1326 +torch/include/ATen/ops/_transformer_encoder_layer_fwd_cuda_dispatch.h,sha256=v6M9AuO3IkgMm5C4q3aiF5eiv4hrdPLl_xHPYJ_cEYc,1328 +torch/include/ATen/ops/_transformer_encoder_layer_fwd_native.h,sha256=vC4lN7LiW18DZRBb4Dp4hzjYBTA6OuAmz47WnZWC2Ig,1727 
+torch/include/ATen/ops/_transformer_encoder_layer_fwd_ops.h,sha256=o4lWA0MGH-oMscWlb1HIKG9HPZNPVAIVJBzhNFOOvx8,5277 +torch/include/ATen/ops/_trilinear.h,sha256=6Sh8H1IlsGifBN-y6dpN0VP37lerzMOjA999jRgb7ao,2050 +torch/include/ATen/ops/_trilinear_compositeexplicitautograd_dispatch.h,sha256=UmNMMZD6vqDdSTjZUl397uM3Uqys-Khi10CoRwlQkqs,1227 +torch/include/ATen/ops/_trilinear_compositeexplicitautogradnonfunctional_dispatch.h,sha256=DyaiUYdDOH-0YVUjSyMHC0Ueb5VQQn4VSDQvnsnv9lo,982 +torch/include/ATen/ops/_trilinear_native.h,sha256=c19OowMiGl7OvgToWN7weIv3TtgxUbs9WgCNYUvki-U,912 +torch/include/ATen/ops/_trilinear_ops.h,sha256=U5MhZxczG3X3fqncAfQwRzdvtg-471IaBw_JnHexSmU,2691 +torch/include/ATen/ops/_triton_multi_head_attention.h,sha256=fvmLBmEp4WpsQupdlX1hqdaoWZkc9-IzOLwK53cdOTA,2754 +torch/include/ATen/ops/_triton_multi_head_attention_compositeexplicitautograd_dispatch.h,sha256=0Lt3J0WoWIY1JakBGVzfmmii0WeJxELRI826VxJIrS0,1442 +torch/include/ATen/ops/_triton_multi_head_attention_cuda_dispatch.h,sha256=v0OrPTxXuChk_-sMCNm-ARBg9dbLo1jnHabmHrai4yk,1022 +torch/include/ATen/ops/_triton_multi_head_attention_native.h,sha256=D7P6jmSYw0orcKbPccWmrFJggpK1ul6yIy1N3Tzjuh0,1126 +torch/include/ATen/ops/_triton_multi_head_attention_ops.h,sha256=vPKwl0KAGAaySMrSVSIl0ok2ULgrNLtP0GRWe4bwyHs,3375 +torch/include/ATen/ops/_triton_scaled_dot_attention.h,sha256=SDfK1e-mHVPggRvG2X9ixjm4BZccrHDgDONL8wIwY70,1628 +torch/include/ATen/ops/_triton_scaled_dot_attention_compositeexplicitautograd_dispatch.h,sha256=n06-EObUehPQ834BapksP87gPTPlKRvYRPkuNTYg_us,1057 +torch/include/ATen/ops/_triton_scaled_dot_attention_cuda_dispatch.h,sha256=4d-htzM832LknvpiA1X7HP1NGnQNVfgaNqAJKz5HmyA,830 +torch/include/ATen/ops/_triton_scaled_dot_attention_native.h,sha256=5X4pK02C9A2-MXENC5ZvquSUxv50PTunyIbpXyIdfHc,741 +torch/include/ATen/ops/_triton_scaled_dot_attention_ops.h,sha256=f59ZCD0yxv0ZiSg5402lhQbBJz5Z3k-ya7Y8RBNx_M4,2127 +torch/include/ATen/ops/_unique.h,sha256=bOBQ_nXtjRZBUJ_bvwoXFooLaqy3amVehaGJQFkHl1Q,1615 +torch/include/ATen/ops/_unique2.h,sha256=A_TVJBBIh8bP93AYKKoyxH0HKBP7R--byd7EzXRXUj0,1973 +torch/include/ATen/ops/_unique2_compositeexplicitautograd_dispatch.h,sha256=AroKAfcwbLUkYA2qjklhriygVUd5BJp-VjejUgaD3cM,1178 +torch/include/ATen/ops/_unique2_cpu_dispatch.h,sha256=ts6KdUdiLnrvIxPGX1xPW0yeITyRqprcpNvqZaMB9Xs,852 +torch/include/ATen/ops/_unique2_cuda_dispatch.h,sha256=aJZLmPCYTAyIFXZmD29NrIXrjF880WFnuvVtWAuKing,854 +torch/include/ATen/ops/_unique2_native.h,sha256=PkOFvhepEVcqBXr7JB9lurnGA1vmso17g2JA9XAgjbE,993 +torch/include/ATen/ops/_unique2_ops.h,sha256=V7g6yEHMbVKOVLKIZ9S3BtLV82RmBQxqEYgIEncYdqI,2406 +torch/include/ATen/ops/_unique_compositeexplicitautograd_dispatch.h,sha256=_JRPYpFxBc6j7q_5VCSz2P7bglFQyGZPkQraoYYxN4Q,1066 +torch/include/ATen/ops/_unique_cpu_dispatch.h,sha256=7OWiY_UpVTSj-xdmBM4F_g3WkzLRd_daHn58mtmi9qs,814 +torch/include/ATen/ops/_unique_cuda_dispatch.h,sha256=vDfELk22kGYkt92LbLXVrrST06Rsn_xr6dWlgTNZOhc,816 +torch/include/ATen/ops/_unique_native.h,sha256=iVg-gs3T7NRZKbsRGzHoYy72B7I5BpFmrbhgVUSy1rI,864 +torch/include/ATen/ops/_unique_ops.h,sha256=sehC7xhJPYvaTC1cPpNuj-jSeUK55ag4Pg465Xo3jo0,2095 +torch/include/ATen/ops/_unpack_dual.h,sha256=QmIdFT1cj2--FRDo3EykkNlPmurKZ7iM3yZytUvcWZ0,781 +torch/include/ATen/ops/_unpack_dual_compositeimplicitautograd_dispatch.h,sha256=RKXGNtC-RsikFuqnh7_cGtglpY-bg9JLheICm9VU7Pc,833 +torch/include/ATen/ops/_unpack_dual_native.h,sha256=mXrX79PjuuQVqmT0hkVdu0MGGNEUBD6Yq4ebGuLKvC4,543 +torch/include/ATen/ops/_unpack_dual_ops.h,sha256=njxsKxYXyEOVSNcJR6wFt3X1hp8fHGM8Dh2UiABAcko,1160 
+torch/include/ATen/ops/_unsafe_index.h,sha256=Hd-axbSl6g5uvtPY-tLMciDT2uHSTT0fgM2CFFxNznc,794 +torch/include/ATen/ops/_unsafe_index_compositeexplicitautograd_dispatch.h,sha256=Kz2G-HoT85SCMjtWgyojxfwSAcdGXCldMugsb1vVTMY,850 +torch/include/ATen/ops/_unsafe_index_native.h,sha256=RqpQjnCLOj1U9pwYP-9b9BKRbzITpURYPzo-uxhHAP0,560 +torch/include/ATen/ops/_unsafe_index_ops.h,sha256=WAgN8MpaqDzjIyIPYyKWJ6K-1uRXBIko9c_sUNtQZ5Q,1206 +torch/include/ATen/ops/_unsafe_index_put.h,sha256=f4d6SJ9vWDvyhuUqoW7JsskT1CWoWLlp4Z5IeSe_ONM,904 +torch/include/ATen/ops/_unsafe_index_put_compositeexplicitautograd_dispatch.h,sha256=ZX0E9bKlsacF7k_Jh1K2bsoEAPBZ48CP-UIOiVRpMj4,904 +torch/include/ATen/ops/_unsafe_index_put_native.h,sha256=XuP1tWoCRw0yIYR4j6edBg4uZbMWX6K-GuWZO-TAH5c,614 +torch/include/ATen/ops/_unsafe_index_put_ops.h,sha256=KItO1sDVjG8fL8AupCBAkzdc4B7VOiRNHJEJjvkhS7Y,1350 +torch/include/ATen/ops/_unsafe_masked_index.h,sha256=iA-eHyQyqKxF6UwujGBE_O9Wv07OUAuFeOI9_1AG9KE,896 +torch/include/ATen/ops/_unsafe_masked_index_compositeexplicitautograd_dispatch.h,sha256=shuCLfRHW4Mi4rig-Xj0vCmjLZVgAUYOHmfPZVVmnD0,907 +torch/include/ATen/ops/_unsafe_masked_index_native.h,sha256=rbfC-kljGSX7olzDmVvKPQF8tXTkFNTsO_R9YlNRaEY,617 +torch/include/ATen/ops/_unsafe_masked_index_ops.h,sha256=fNJBG_0GYxHfG5E8GVAcd2WzoseqBDK1iw1ABNMIq9I,1373 +torch/include/ATen/ops/_unsafe_masked_index_put_accumulate.h,sha256=dlq-dEjZpKYWwpKnVR0eXaAHPUciZq0aHYBlS8bNGPI,962 +torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_compositeexplicitautograd_dispatch.h,sha256=ePiAsHodug4gOPFOuI6ON3FjhcdTT4oFPh8bS8KNH_0,924 +torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_native.h,sha256=VD6x-ElxoO00tGA-Up-B19TJQSxCyGTv6RVxmqQQGQI,634 +torch/include/ATen/ops/_unsafe_masked_index_put_accumulate_ops.h,sha256=Na4tgXQloprnjlp67Pj-bbZQulEAzYqt9ahfj1A6EHk,1424 +torch/include/ATen/ops/_unsafe_view.h,sha256=TjfaTH1VK0ABzM2N324fWEl7mk-G-SEW5UtzuzzPRNg,3769 +torch/include/ATen/ops/_unsafe_view_compositeexplicitautograd_dispatch.h,sha256=3UMMTbQX6t-AK6eykVhlwl960F72Dw51ySq-hE30WPM,1361 +torch/include/ATen/ops/_unsafe_view_native.h,sha256=-3Hp_6lAljzZFQd2jW8GImYUWknL1dWXOX_fPfjvhzc,643 +torch/include/ATen/ops/_unsafe_view_ops.h,sha256=6n8zvNJHMbjvBEZAM101s-I873UddESdU9TZLloYUvA,1799 +torch/include/ATen/ops/_upsample_bicubic2d_aa.h,sha256=0qvGYLPwtmfXoETbuz6EgCBBtWVFaEZqu4PaHrUGZS8,8216 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward.h,sha256=VtgHAaAtK5XWULrP6GsH3N4VUUHs-t4IW-1mjlGCWLU,7900 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=qbQRDFBIhCXcnOFYALCU7aqjlvF9ekziCUGUkVdhheU,1297 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_cpu_dispatch.h,sha256=yhP50ZrowNaONuzAP9bs1iVXQzeSmxjQRdmGybu60hY,2371 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_cuda_dispatch.h,sha256=5TeZXqB0cT5RB_FMFsARwy_5rC0dNtdTP0bPmtIlkJU,2373 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_meta.h,sha256=_f8m4s1QlYuwo-RjOzJeJy7bn6GNLLpQCmyDTIrECpk,783 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_meta_dispatch.h,sha256=rNANyxY5uUSc6SgkreeFvtNDwpAFdY7UFjrG11d1H9o,2373 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_native.h,sha256=JGFAK8y2BdCmXtGjW0hGwnGIq2BiygYnGmOicyDNx7M,1239 +torch/include/ATen/ops/_upsample_bicubic2d_aa_backward_ops.h,sha256=19n0lD6J1cQY55seKZe1WMzB2aaMrXEg5mSzb_m1h1g,2847 +torch/include/ATen/ops/_upsample_bicubic2d_aa_compositeexplicitautogradnonfunctional_dispatch.h,sha256=CRn1wdP0YYS8mlR3knEqtoyBYKVsbWCN3phyQS6Y8Eo,1205 
+torch/include/ATen/ops/_upsample_bicubic2d_aa_compositeimplicitautograd_dispatch.h,sha256=0ZhSzHHimqipnw0kbnoHpEqvfYVE38MwnERVnYUf-yI,1106 +torch/include/ATen/ops/_upsample_bicubic2d_aa_cpu_dispatch.h,sha256=kIKqJPGrFeulhZBIvjygGza7446h6CVYDymyBjlQA7Q,2067 +torch/include/ATen/ops/_upsample_bicubic2d_aa_cuda_dispatch.h,sha256=bQX9rz98xpvTtp8EGhnbp3s6iaJ-VsvEmBPPTcLrFns,2069 +torch/include/ATen/ops/_upsample_bicubic2d_aa_meta.h,sha256=w2IZxZ71_6gBavDa-7wO-TznoO_FC0ZbiJVMuVwyw8E,733 +torch/include/ATen/ops/_upsample_bicubic2d_aa_meta_dispatch.h,sha256=9gZl58wQnTyj5KXooiGqH-9B_YqaRFpOqRjBhSoVAU0,2069 +torch/include/ATen/ops/_upsample_bicubic2d_aa_native.h,sha256=PdvdLgkn78PESXgxrlig29Bg34UlDW9ZZwSRqYlqzqQ,1280 +torch/include/ATen/ops/_upsample_bicubic2d_aa_ops.h,sha256=ul_7dK2QjGfAlr8YKSPgDv7b6MVjowQQG1gQlONckZk,3434 +torch/include/ATen/ops/_upsample_bilinear2d_aa.h,sha256=cKecofabtwr09h72SNJkz5BY9M3F5MLThXnXfZfnDRY,8257 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward.h,sha256=O1JReW7-UnzchPnnLYaeC0fu78C87sUkfL9gyaTz66s,7931 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=IAMUREq1WELcn8LEjOCDqGdAo9ZwQEwcqWarubv4sFE,1299 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_cpu_dispatch.h,sha256=SvebdTlN4PFe0AJr8v3gvtYJvulvQ3Us2Qu6_Ox2AXQ,2377 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_cuda_dispatch.h,sha256=P2NVk6I7gAYyrDM8SpyLr5LTZB_qjENthNyT1QnA5I0,2379 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_meta.h,sha256=yNiR8Zeo01FDHDU1Xi82i5iSzZzInR6Vj9Qvndlb0CA,784 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_meta_dispatch.h,sha256=Akror947iiwADOJcjN-o8keOxY2wRHQ30yIVjWGTTM4,2379 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_native.h,sha256=UnQ6LYycfk8jUFINKh1K6Z4w8HtTS4zTxFkrqm6k6Q0,1244 +torch/include/ATen/ops/_upsample_bilinear2d_aa_backward_ops.h,sha256=8ENDenN8efxQ2XsvtYe4uGtNWqL6f9JN9M4lCVQyAjs,2853 +torch/include/ATen/ops/_upsample_bilinear2d_aa_compositeexplicitautogradnonfunctional_dispatch.h,sha256=dcdk3-sPjLP-BQ4EkganxZAeokKfiHSs_lqHbnMzO7k,1207 +torch/include/ATen/ops/_upsample_bilinear2d_aa_compositeimplicitautograd_dispatch.h,sha256=xkU0zRQ5fqdy_8zJJcGXsXsYstDS3grT_NW63nAKdKQ,1108 +torch/include/ATen/ops/_upsample_bilinear2d_aa_cpu_dispatch.h,sha256=VQT0Vi3FiBK2QNLUMLrsK7ncCik-h7VcNP-Q9NuRxjw,2073 +torch/include/ATen/ops/_upsample_bilinear2d_aa_cuda_dispatch.h,sha256=JD3FnPPh_18dMJWOkHoMkbHxoNWLchkLtKFCkhY3EF0,2075 +torch/include/ATen/ops/_upsample_bilinear2d_aa_meta.h,sha256=FAUAKxuzmd2HyK6XlS6eS1qCXXCA1f9Eee_WPmn-tw0,734 +torch/include/ATen/ops/_upsample_bilinear2d_aa_meta_dispatch.h,sha256=01m_n5waYHpEE3VZuOgX9RZZ3NqNH4gJxEx_pEl2PLg,2075 +torch/include/ATen/ops/_upsample_bilinear2d_aa_native.h,sha256=fwmuglcgVqTB-tWmZmBDRp_L4yePObhFBDPK3zvtf2k,1286 +torch/include/ATen/ops/_upsample_bilinear2d_aa_ops.h,sha256=0qCx8741dWNVrqm9LnYDZQe6CKv1I-Lj2BZUXDr7AKU,3443 +torch/include/ATen/ops/_upsample_nearest_exact1d.h,sha256=aq3fw6LD1gX9rwnG45GYOoh0XDmP_93rCk1UFRSqYmk,6779 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward.h,sha256=BGjTfee8RMi8PQCCKL5_C-KUAvHlTi3B4NDI061_q4g,6613 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=GU8KAX7OtZ0GAHqRvG2WedRxHq5TNtjsOtTBB7wXqFw,1161 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_cpu_dispatch.h,sha256=u0H6-iZu1yraRxu34IMKKFPSH5TCEuJBzem4dsLTWfI,1993 
+torch/include/ATen/ops/_upsample_nearest_exact1d_backward_cuda_dispatch.h,sha256=thuH9LMiLTuJDhrIw9jCEYllyjpT6v7CaR7XuFXyaQ8,1995 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_meta.h,sha256=omVsdghHNnuVRG-7XmdKAvwlGRiQni2dIIMcqoVdn04,730 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_meta_dispatch.h,sha256=zQoSE6VCFJRbRqoqpFU6n_bOHXTDgcB0IVH2tD5-3Z4,1995 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_native.h,sha256=K-Enfv3RSAkgMX_bWFQl2gughNicjo8XkUZ_0l011wA,1142 +torch/include/ATen/ops/_upsample_nearest_exact1d_backward_ops.h,sha256=yFrAkmtwKdXWlvyMnyuCRa8aJi_dKKuAovmnVF6PAS8,2491 +torch/include/ATen/ops/_upsample_nearest_exact1d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=hT1U8NgX1i-VjWiATXmcIbQJRqXRbnYvgDoxfjtcusU,1069 +torch/include/ATen/ops/_upsample_nearest_exact1d_compositeimplicitautograd_dispatch.h,sha256=F40VoYkzqpBjMcVhi7YUj2RyJe9Pt-OISNOtqqkZuwA,1072 +torch/include/ATen/ops/_upsample_nearest_exact1d_cpu_dispatch.h,sha256=TCG_yF7e8dfRTKgEBYPCiMCwZvWhKqVD1gmPOaQLV7E,1689 +torch/include/ATen/ops/_upsample_nearest_exact1d_cuda_dispatch.h,sha256=ZaF6cB2r2XVJN1LdhXsj4PwdGfN5EhBOZp6Ssb8TmsA,1691 +torch/include/ATen/ops/_upsample_nearest_exact1d_meta.h,sha256=kLct2E_Bz1KXbpSZfz5U1FULiuE2m9_z4ovFV3EN8bY,680 +torch/include/ATen/ops/_upsample_nearest_exact1d_meta_dispatch.h,sha256=qHVUbJ7vPipbz-K7gGyNC6ZMy1Gx9CvjzHTux5cSqXE,1691 +torch/include/ATen/ops/_upsample_nearest_exact1d_native.h,sha256=NS3YvdHQi48Tlu_dz0km9NlSbbyyGQDjvPnAJcybQ88,1166 +torch/include/ATen/ops/_upsample_nearest_exact1d_ops.h,sha256=CYfvj86ORsUtYm9UxzZz7wy-j2KcAXo-_CZySpjLq_U,3021 +torch/include/ATen/ops/_upsample_nearest_exact2d.h,sha256=sPNQUSJTlJiS9gwwLQFdFVCaS7nUaA01XSGlcKqdpOA,7619 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward.h,sha256=7Qw9u9Uf9kV9VKjy_2WPnnsmzmqp8hd1kcyWTsQgg-U,7453 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=oAsMerpmWRr3gLTb4FdrAgVxQdjyh2zzIcgYFykwag4,1263 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_cpu_dispatch.h,sha256=6P86CMUI_EHu9Agq8sahaTEQQXUTWJJnXvdmOdDmSIk,2269 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_cuda_dispatch.h,sha256=e0uvdvqf7LxQGLRRPburzWgPKgZYfWkWwv3BqOTZXW4,2271 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_meta.h,sha256=O1lKcC4HY-dF5zV4cOJLEEG44foeO5M54wx3C47cfas,766 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_meta_dispatch.h,sha256=d7_4ndCA2lb5UvPdDyII-phL7n1H4xSgWgNgFkbIGHc,2271 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_native.h,sha256=ZaL1Cdbm6MhnGHOVVPVOukmX9CkxVpvY1cFamIVEOoQ,1214 +torch/include/ATen/ops/_upsample_nearest_exact2d_backward_ops.h,sha256=qJrmTVp3QsDguxSMCmVTzMnqZAWLS5HtlHeFLlfnmiQ,2733 +torch/include/ATen/ops/_upsample_nearest_exact2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YICzuyhokGatftpOpaL7QQM66_a-JTjYkGQz11Vpe_E,1171 +torch/include/ATen/ops/_upsample_nearest_exact2d_compositeimplicitautograd_dispatch.h,sha256=qSbcpR88NIC5wcdbCh1yeD-rFMRfooPSRM0vTopjypo,1072 +torch/include/ATen/ops/_upsample_nearest_exact2d_cpu_dispatch.h,sha256=KKJr7yKSv2cM5FEBWPGu9auu95ow-fAYFjHjaSrKvhc,1965 +torch/include/ATen/ops/_upsample_nearest_exact2d_cuda_dispatch.h,sha256=O_mRXymhI3nxYFbp4-TH0uuHEpKQPRrUT5laBKgo1Gc,1967 +torch/include/ATen/ops/_upsample_nearest_exact2d_meta.h,sha256=x9U1fL9GUTJEnWAUlxy9WICxkAxMCw3SiiI5Z1WNI0w,716 +torch/include/ATen/ops/_upsample_nearest_exact2d_meta_dispatch.h,sha256=aY6UtywWszjubOshcLMABgNbe2uJCT7qT3erwQpxUIE,1967 
+torch/include/ATen/ops/_upsample_nearest_exact2d_native.h,sha256=pQNSTKmkBj_oL66jbQlXn5bWHTQmEeuTOMTZBsxYWmg,1453 +torch/include/ATen/ops/_upsample_nearest_exact2d_ops.h,sha256=5TjAfHO7zezp2hYaXExzmt-D1l_uvhqaKPz9Atlhwgc,3263 +torch/include/ATen/ops/_upsample_nearest_exact3d.h,sha256=youhmMHZkfXK7wA0HA7xsiWNPruGju6E_KOKiEEfBDA,8399 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward.h,sha256=NOyXLW7pnQBB4-Sz-wrLw6yWfZl8XJ87HQNYc-fBu9Q,8233 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=PJmtFESVUcWIsULwMUWBXk1jfytVzrdhPa9YbGwmjrc,1361 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_cpu_dispatch.h,sha256=KpvHQdUZWzKTXvowOthxStSXnwXgX5Pj6pA2HFT1DYc,2533 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_cuda_dispatch.h,sha256=rOca5fZz6AjRKHLbXQYFXhJpW38eF74rUrNzuPIXvuE,2535 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta.h,sha256=9z06lsb01pz_Y3OeLtV8aCF7WUSj-ikZe7RtMQUPEH4,800 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_meta_dispatch.h,sha256=jVG0PBcmUcz_6ijaVY7ZUe76AzfHRQI4Q4pgkIPM9SU,2535 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_native.h,sha256=0jzB1zFQz3Q5b4HxehGRwGcfCaR2oD560O2ZQKsnmac,1282 +torch/include/ATen/ops/_upsample_nearest_exact3d_backward_ops.h,sha256=un4Yl4oA4rv0is_tPRvHLzLauFbz2FNwODpOw1uWHz8,2963 +torch/include/ATen/ops/_upsample_nearest_exact3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=eadGuYYpcRtt12ar5MYVfFngzog_07F4MPlDF9u-g5w,1269 +torch/include/ATen/ops/_upsample_nearest_exact3d_compositeimplicitautograd_dispatch.h,sha256=e3rIBti5QMqyZIKLtt00UxB-_cPoCXL9Odg8yBxrOJ4,1072 +torch/include/ATen/ops/_upsample_nearest_exact3d_cpu_dispatch.h,sha256=Uelwg7W0Pkj_8wFnF1EXoCGmPYfMlpR2wkrmEXnI-Ko,2229 +torch/include/ATen/ops/_upsample_nearest_exact3d_cuda_dispatch.h,sha256=qlIYrwXVBcflGrxF7arO_QTgzNJYr1rXgV5ibJ6i1ow,2231 +torch/include/ATen/ops/_upsample_nearest_exact3d_meta.h,sha256=q2i-cQEHk2WCildUDBo9q1xoaDeW-TCAOXdc-EkNwGE,750 +torch/include/ATen/ops/_upsample_nearest_exact3d_meta_dispatch.h,sha256=pw0QsXZZLLh35RnSUqXsJFivvijwVY50n5yAdbAr97Q,2231 +torch/include/ATen/ops/_upsample_nearest_exact3d_native.h,sha256=jKpVJ-f1iROXW5uJ4pDRZ_iCvS6PwXQiJ0tGdL8qIr0,1570 +torch/include/ATen/ops/_upsample_nearest_exact3d_ops.h,sha256=d0rTL6dJVuCaSZyD-XL6iZJD64aZ4gjUixBuH_fcKBM,3493 +torch/include/ATen/ops/_use_cudnn_ctc_loss.h,sha256=fMn6w164CfOYh_Qo5tcUICvXMzKLSYzKfAD2DlXdxwc,1381 +torch/include/ATen/ops/_use_cudnn_ctc_loss_cuda_dispatch.h,sha256=Hv06OHwxf_a-jfJgXbdg5nezPy3cUwS7tDTG6MxNuLE,1042 +torch/include/ATen/ops/_use_cudnn_ctc_loss_native.h,sha256=1_dldeHBqCii6oI-8hpanozWpaJr9xfQCerPTM_sSK8,801 +torch/include/ATen/ops/_use_cudnn_ctc_loss_ops.h,sha256=RXFBsTXU9fidi_xcqkCeNOxVfcs9H5U25B3Co8KgSzs,2295 +torch/include/ATen/ops/_use_cudnn_rnn_flatten_weight.h,sha256=KnXcL1nCGeAtGhXGN98z9Seez7zAgCI91apvdOrNpyI,714 +torch/include/ATen/ops/_use_cudnn_rnn_flatten_weight_compositeimplicitautograd_dispatch.h,sha256=gXw5dbBBXj2CfFu_I0rKKMUqlDkzBmwFiQaQ2qUTTbI,781 +torch/include/ATen/ops/_use_cudnn_rnn_flatten_weight_native.h,sha256=YZUzvNm76Bj0fd8ZBxwlzjEbXor8N_dORBxzbKp06Lc,491 +torch/include/ATen/ops/_use_cudnn_rnn_flatten_weight_ops.h,sha256=sD20TkLjXrcGJbS8xV9VFNITLHOYwIDrB7wv6DBBlM0,958 +torch/include/ATen/ops/_validate_compressed_sparse_indices.h,sha256=bmIoxg-xxDBkfYiMSiu_dXwFpz3MSw6_MWeU_t88T-o,984 +torch/include/ATen/ops/_validate_compressed_sparse_indices_cpu_dispatch.h,sha256=P3kWDsfvOzYN9ekghjDdtmxe1KpxPrxC_kmh5-bOdYQ,860 
+torch/include/ATen/ops/_validate_compressed_sparse_indices_cuda_dispatch.h,sha256=NKqyaI29P79JvFdcgTBh1xE_raWC1IOaP16GZnxgO-Y,862 +torch/include/ATen/ops/_validate_compressed_sparse_indices_native.h,sha256=5uK9tQQ0Pa__Rgspht2eOomIyOokHpqzOrWPMXmXhtM,795 +torch/include/ATen/ops/_validate_compressed_sparse_indices_ops.h,sha256=JEB0CUtjsZSYPXvrW5EBaFe6WSzIPyPw91yS_LzgisI,1362 +torch/include/ATen/ops/_validate_sparse_bsc_tensor_args.h,sha256=a5Hnpa01XG2-xeG8JjWmGqU4cLlWis5fFhaJmNKSy2w,1034 +torch/include/ATen/ops/_validate_sparse_bsc_tensor_args_compositeimplicitautograd_dispatch.h,sha256=nF6g1-k5X4eJk6k4cK9Loa3YrYyepRlZGatfRqD9cSw,948 +torch/include/ATen/ops/_validate_sparse_bsc_tensor_args_native.h,sha256=Mg0pRhIreZxLi2yz_Onrvf8_gW6rpy18LJqy23knwTM,658 +torch/include/ATen/ops/_validate_sparse_bsc_tensor_args_ops.h,sha256=Jc9G99QvX7c4gvP0eZ0Hqy1WO3xkD00Meevq1mECSAY,1455 +torch/include/ATen/ops/_validate_sparse_bsr_tensor_args.h,sha256=fAyWFOvPWEfq7V2mxI-7RpGBzmfzY8rB15qmDeQqZKk,1034 +torch/include/ATen/ops/_validate_sparse_bsr_tensor_args_compositeimplicitautograd_dispatch.h,sha256=MV9ZIlb9B1ikO0QRTpwPK4gDB6TZAXRdHQdJHCkOTV0,948 +torch/include/ATen/ops/_validate_sparse_bsr_tensor_args_native.h,sha256=yrYwmmuIxP6vYVZPy2LwltmJnztoXc0G6hhk-2pumME,658 +torch/include/ATen/ops/_validate_sparse_bsr_tensor_args_ops.h,sha256=l_fnctcDI4TCR5ROhqKQs-Iv63s-9nIeqltOjtVZmfc,1455 +torch/include/ATen/ops/_validate_sparse_compressed_tensor_args.h,sha256=ymlyEo092CZwM1K1E1R1mFukKCSd904Pq3k0YP1XVPw,1128 +torch/include/ATen/ops/_validate_sparse_compressed_tensor_args_compositeimplicitautograd_dispatch.h,sha256=v8SILYu4WrJzNjHdg6h9iCDcJ-tz5_YKWgYpd1_9J54,982 +torch/include/ATen/ops/_validate_sparse_compressed_tensor_args_native.h,sha256=bt0cszRsvMTMmWtirHAKH2ZfFNlJcpGNn4ZQhOz3z4g,692 +torch/include/ATen/ops/_validate_sparse_compressed_tensor_args_ops.h,sha256=Qc1xcR1drLUizxeyzX7HMN2rQiYx0ECaU-74DmPnm_E,1565 +torch/include/ATen/ops/_validate_sparse_coo_tensor_args.h,sha256=eNx5Y39EeCjGNOAJpr9wc7qD-uE18_2Ab27KMUMgIJA,1044 +torch/include/ATen/ops/_validate_sparse_coo_tensor_args_compositeimplicitautograd_dispatch.h,sha256=A6UWtlmrJ8HfA8avrMebZGwViXNUea5X_wQST12FQqY,962 +torch/include/ATen/ops/_validate_sparse_coo_tensor_args_native.h,sha256=XJhXNx1I9zyqoYzM7WZv2z8sVOgXwN2C-xvauChgVZk,672 +torch/include/ATen/ops/_validate_sparse_coo_tensor_args_ops.h,sha256=TkpvpTelloz4np4uwX6mVtv42OC8uNe4wLr4ycR9w3E,1456 +torch/include/ATen/ops/_validate_sparse_csc_tensor_args.h,sha256=sWVoYM4B4GPJTXXhT6oshdR36rZB5IHM-ofGpu02SwQ,1034 +torch/include/ATen/ops/_validate_sparse_csc_tensor_args_compositeimplicitautograd_dispatch.h,sha256=HqIYW2BONsKaGR5ObKmEqoVc3Y7JM6rHCL8DjkZT4oA,948 +torch/include/ATen/ops/_validate_sparse_csc_tensor_args_native.h,sha256=mIG1CyoRx6WcW96BPmlbAWrJQ52OqP3ZRmKzloy3SVQ,658 +torch/include/ATen/ops/_validate_sparse_csc_tensor_args_ops.h,sha256=1bRjuk8_13bqAja1EKe-JOPqOLL4XO5i7hKh1uC8KQs,1455 +torch/include/ATen/ops/_validate_sparse_csr_tensor_args.h,sha256=THsbVEIB_3y6dQL--xw8dasPpYV_-FuQcab1KJKMaTk,1034 +torch/include/ATen/ops/_validate_sparse_csr_tensor_args_compositeimplicitautograd_dispatch.h,sha256=f8lg2gDn7CP-pgTVl-0rK22SI0N84ka1iGoVbANxI7g,948 +torch/include/ATen/ops/_validate_sparse_csr_tensor_args_native.h,sha256=n5_Ggf-QXL3bC4ygCsByMu0NVCDhncevVirazn8ewa4,658 +torch/include/ATen/ops/_validate_sparse_csr_tensor_args_ops.h,sha256=1pjIdPl4GCVl1wMJXsUb6A7YkxDnBAazcQx9MgmUF_w,1455 +torch/include/ATen/ops/_values.h,sha256=M6hy3_-M0laF5qgSjz2U2xTs8Quds9Ze2qnOJMgPrh0,531 
+torch/include/ATen/ops/_values_copy.h,sha256=2vqpbH3U48-wknlLc4t58z6e-Obp0BF6lSwiCS8vQrQ,1127 +torch/include/ATen/ops/_values_copy_compositeexplicitautograd_dispatch.h,sha256=VHV0T6wGo7v66wBqo4uMUX6U4RlS7FfgCxto5RfDqwc,903 +torch/include/ATen/ops/_values_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Gz8Mah9aL7JerxBoMM-qIANHLbqPV6eb0CLF7Ek0hKA,819 +torch/include/ATen/ops/_values_copy_native.h,sha256=cD3AMkiSXJTke2uTgHaJNgx2SwlzPzA9mvcsXxrU8jI,588 +torch/include/ATen/ops/_values_copy_ops.h,sha256=Cr5LtwKchkKVt4F8DOxxBBy3oOoyVRRF1jfrDWVqMR0,1623 +torch/include/ATen/ops/_values_native.h,sha256=UtZmgRUUxT1ITmQuqyvxGi-9PWfzyQf26yXveBoYRTo,505 +torch/include/ATen/ops/_values_ops.h,sha256=v5eHjqWHCMbC7iry-iUx4VA_THppBA_APnWyLto3-vg,995 +torch/include/ATen/ops/_version.h,sha256=PSZPLvyGBLEvv1z7uSkvpEsRYDxTSXQvEyxJHf99FZ4,532 +torch/include/ATen/ops/_version_compositeimplicitautograd_dispatch.h,sha256=6it7nPoV3u_gPapMm4KFQGnsuiJrZ5NdBireBr5lzCY,786 +torch/include/ATen/ops/_version_native.h,sha256=xORcktpAQxxoE-ynCKXrShVQwOq1pRXzbf57FRgoXVM,496 +torch/include/ATen/ops/_version_ops.h,sha256=qJ8Fdo7AgO3bbQOTZCmn_L2hZ0d2PUm4F3bn82x-NkI,980 +torch/include/ATen/ops/_weight_int4pack_mm.h,sha256=WE_djn6ZgJyYWNDV03POV0G-cPXLHfFY3tj70vHYXfM,886 +torch/include/ATen/ops/_weight_int4pack_mm_cuda_dispatch.h,sha256=8JA9J6eXRis1NpLq6N-wc5AajBoRrCZwLJzV06HbiGA,838 +torch/include/ATen/ops/_weight_int4pack_mm_for_cpu.h,sha256=uLEumyj7dEbIIAQNe3KCx5PzKxNE4CeOqVI5Ze_kjKM,918 +torch/include/ATen/ops/_weight_int4pack_mm_for_cpu_cpu_dispatch.h,sha256=OLfqOnwTNnlbiWfTcNKjce0wpDJki8bq3Hz1dCOM9Ao,844 +torch/include/ATen/ops/_weight_int4pack_mm_for_cpu_native.h,sha256=3_eZ5stOgrpy5VqA3l2YtivocKH_E6FEg0QO6G3Sjm0,594 +torch/include/ATen/ops/_weight_int4pack_mm_for_cpu_ops.h,sha256=MKId8N2G_ElzZR_mar_mny5oAmwggWmZfJWjWmxbG1A,1310 +torch/include/ATen/ops/_weight_int4pack_mm_native.h,sha256=e0JvOErEK-pWsJiG91cdJiN-kZ2Ls3__RGfzWDB51k0,595 +torch/include/ATen/ops/_weight_int4pack_mm_ops.h,sha256=0_ouq7qQcNe34NvfCU8F_8jyf7mB45SkGn5N3_oYLJM,1286 +torch/include/ATen/ops/_weight_int4pack_mm_with_scales_and_zeros.h,sha256=OMf4lN_YrM0E00XLPCrrscAWf3nwP8CKVphTfELrxaY,1000 +torch/include/ATen/ops/_weight_int4pack_mm_with_scales_and_zeros_native.h,sha256=9XaUGLJyiIUXnjau4tp-HFMbdt34qeE71RClgADHWpo,442 +torch/include/ATen/ops/_weight_int4pack_mm_with_scales_and_zeros_ops.h,sha256=Puut6x7lHI8ejkhzLkdu8nDQx1StURj_Lb6qEtKv3uM,1417 +torch/include/ATen/ops/_weight_int8pack_mm.h,sha256=YVvNgmuiWg7BCa57fSqaThd_tH6W3Cc9fXYDbEhGc5U,814 +torch/include/ATen/ops/_weight_int8pack_mm_cpu_dispatch.h,sha256=tCbxxp_AbP_py5tGir8fAq1GGnZx00vxrvPArP8le-U,808 +torch/include/ATen/ops/_weight_int8pack_mm_native.h,sha256=kAj0qcxfC2GlApLn5T-CsBQpt8vxWmsmq6FbLftd5Vg,566 +torch/include/ATen/ops/_weight_int8pack_mm_ops.h,sha256=YjAzVWAUkocD-Mrb-QM32Kmul5F9U1Oj7borWpVNoTA,1197 +torch/include/ATen/ops/_weight_norm.h,sha256=Me1zeNdsokPo2palSdVhEW1ClUce7aY6BZadhROGOJU,749 +torch/include/ATen/ops/_weight_norm_compositeimplicitautograd_dispatch.h,sha256=WbL_kFUUHk9KUGK9XR0wQmQ20BmriNsSTyDkqONyKHw,827 +torch/include/ATen/ops/_weight_norm_differentiable_backward.h,sha256=ljrYW0zTK4aTrBSXfHfxudKXsDIy9boxOJ_X3JuwaW0,1027 +torch/include/ATen/ops/_weight_norm_differentiable_backward_compositeimplicitautograd_dispatch.h,sha256=ICA7ZnUFUoquJklf6ZhoXGGToVMykCxEyqt9jc7S1GI,945 +torch/include/ATen/ops/_weight_norm_differentiable_backward_native.h,sha256=siZA_oYB5QcRFM3JxVN99GbDTtf6ubuN775WFHsv1Ow,655 
+torch/include/ATen/ops/_weight_norm_differentiable_backward_ops.h,sha256=D2GQuuo5Lx4OqHcccHb3eZZDqO6s2xTVwVacrygaEBc,1499 +torch/include/ATen/ops/_weight_norm_interface.h,sha256=NJViu4vR1BnxqUAkgp_oAYuLrCq072Nas0HkWkKOWlc,1603 +torch/include/ATen/ops/_weight_norm_interface_backward.h,sha256=rpnxqxvOrIediF9V06QV_sLdRYZE-9en8diwlah9tC0,2136 +torch/include/ATen/ops/_weight_norm_interface_backward_compositeexplicitautograd_dispatch.h,sha256=nJdseLLYMTbb5hUkLn096lc3NyUWrw15Q5GSRECVZfE,1241 +torch/include/ATen/ops/_weight_norm_interface_backward_cpu_dispatch.h,sha256=N2S37Vbi8XA2hiBNkF8Uoz3kozaDqE9R-8njjI5FGVU,896 +torch/include/ATen/ops/_weight_norm_interface_backward_cuda_dispatch.h,sha256=wh4rsRNuYMs-aAcwh5faT1AG0fZwPSMzy3rTCw5RAGU,898 +torch/include/ATen/ops/_weight_norm_interface_backward_native.h,sha256=GPUfsa-GANxojRwIjFgO8gLA95LGEl6DAJ4x5Jj5jpw,1099 +torch/include/ATen/ops/_weight_norm_interface_backward_ops.h,sha256=bAQCbEsjt_sO-aOVhlTcpVZCGCU8z-nxL4Lt1ONCERk,2665 +torch/include/ATen/ops/_weight_norm_interface_compositeexplicitautograd_dispatch.h,sha256=u2fG0iAoUEjLEhJ54yJMgpjks0xwKBFFIBgXtdE0Yuc,1083 +torch/include/ATen/ops/_weight_norm_interface_cpu_dispatch.h,sha256=VD2g_J8_r_O4uuh1nn61jIztSaGva2KJuDLxTDeQeRI,818 +torch/include/ATen/ops/_weight_norm_interface_cuda_dispatch.h,sha256=_4w9TlAr25bwiXE-CIuXj0Eo1UtCcH5YeKHUEAu9Qx0,820 +torch/include/ATen/ops/_weight_norm_interface_native.h,sha256=WWL0myaWDZoXFSHUrtrW3nbe8IpYpLN7OH5tWS2FqY0,863 +torch/include/ATen/ops/_weight_norm_interface_ops.h,sha256=7WByrd1t7HdF200lM8CKoYkuz-2O-PbndMxvd9MUL3M,2157 +torch/include/ATen/ops/_weight_norm_native.h,sha256=zJOcrgW90HZ4YU9hUqtz7rEW9f90dBihb8B8LvGXsa0,537 +torch/include/ATen/ops/_weight_norm_ops.h,sha256=HGnsIr8c0Z-JE-8RsaS26UdgvlEk6rYCFm0Yk1Hvkck,1115 +torch/include/ATen/ops/_wrapped_linear_prepack.h,sha256=0VmWAtUTnASqotMcWSYwC-sPvfEg-LSw1pL2KDLZX2U,937 +torch/include/ATen/ops/_wrapped_linear_prepack_compositeimplicitautograd_dispatch.h,sha256=WTxSknJuocCp5l_PMnKg8gKFPzI36QydmasXU2w-PI8,902 +torch/include/ATen/ops/_wrapped_linear_prepack_native.h,sha256=gIRUrNDp7jw-POKGEWc9ZQcVMMdaW4yB3Xat7r2Tn8Q,612 +torch/include/ATen/ops/_wrapped_linear_prepack_ops.h,sha256=T8unhHEFc_nqdhfilvUhg6cxSTSmHb6dWS8r4z6orSk,1355 +torch/include/ATen/ops/_wrapped_quantized_linear_prepacked.h,sha256=abSoZuf2leJqXddnAmJgvvTMh08-z919RYDyAM2GTHk,1205 +torch/include/ATen/ops/_wrapped_quantized_linear_prepacked_compositeimplicitautograd_dispatch.h,sha256=DEbqIag0qrWHXMp0pRFTa0BKSWiMxVGbRYQbm66i97Y,1012 +torch/include/ATen/ops/_wrapped_quantized_linear_prepacked_native.h,sha256=HFeDV0OvAk122Z9pEFgYR9kgAhHqYuv5he9jzOTSHn4,722 +torch/include/ATen/ops/_wrapped_quantized_linear_prepacked_ops.h,sha256=TgTKuSHorLLe9P1NAmo9O9Fu-2XXdQN1nYtbtkpp98E,1706 +torch/include/ATen/ops/abs.h,sha256=WO-f9pAlhmgDBHDGM2LEbGn7xIb2l2utOHt9Ur2X4Go,1175 +torch/include/ATen/ops/abs_compositeexplicitautograd_dispatch.h,sha256=zfiGci0gSyW0Z_NWYgyNO69etSIJgXjr_NXORxo1WgI,833 +torch/include/ATen/ops/abs_cpu_dispatch.h,sha256=DvRv9U95xiZLtQXv835s-Q7xz-Ipp9h1QOqLama5G1w,841 +torch/include/ATen/ops/abs_cuda_dispatch.h,sha256=qbJaZeWTxZnhxqQSkjyEdmN4XwJ9GGEqUGmYobTMBQM,843 +torch/include/ATen/ops/abs_native.h,sha256=OMoD7TI-v0-XdT6ffedJJQ9yRpbIY4ZcB35AEtMo21M,1154 +torch/include/ATen/ops/abs_ops.h,sha256=adpajGYM8yDBg0vAfdJUoFHsdcu9d_v54SLY8hyjN9Q,2070 +torch/include/ATen/ops/absolute.h,sha256=IjsS2RIsqtA9Y_tgwvmQOejPGX1Q-aYzUI30EzNFmtU,1087 
+torch/include/ATen/ops/absolute_compositeimplicitautograd_dispatch.h,sha256=9ITB3u38WT9bhmNAoN7-diWPsY6OrhNBd8roo0PLwVE,1006 +torch/include/ATen/ops/absolute_native.h,sha256=GJ9ji6y285YL9lfMdbW8ddwu7zjiQLod1yIzeaF29aw,634 +torch/include/ATen/ops/absolute_ops.h,sha256=j3dw9a2EwVqHb04D1lw3mWlPLmiYbckwW1FluPoQfXY,2115 +torch/include/ATen/ops/acos.h,sha256=cXC2o0x_BdviG12k_he31dFEDGN5G41GDYUlKTJkmt4,1188 +torch/include/ATen/ops/acos_compositeexplicitautogradnonfunctional_dispatch.h,sha256=4gjkkjZDaoF32DObypEV0oC2wjmw6Vp1xcbjfyTEzJs,861 +torch/include/ATen/ops/acos_cpu_dispatch.h,sha256=eHMvbnS_MyF2_JVYH63myih7un_YxodeWqf6Q8qCTII,946 +torch/include/ATen/ops/acos_cuda_dispatch.h,sha256=coA-iEYuPgrMK88KGZqJ15fFOxe2T6-xCUjSX5_8EsQ,948 +torch/include/ATen/ops/acos_meta.h,sha256=kays7PPR67Wss9O1V6v1Ju1ZliKH15xxBXYu_2M5n8w,592 +torch/include/ATen/ops/acos_meta_dispatch.h,sha256=hZHZkFmaULDlhHh_MIHJO3Ikoac82e9MLlV7dA2GIXo,948 +torch/include/ATen/ops/acos_native.h,sha256=ZcgdccycruF5-_g-CIHk6ME_hMGljpMj6xf7_9id584,613 +torch/include/ATen/ops/acos_ops.h,sha256=IxqOhoUvpgbIEwrOQgDcAfsPflW2-JQVJfvcqnAo8Do,2079 +torch/include/ATen/ops/acosh.h,sha256=qg37_W9HlKM18Gm0XkQ0Vbx8wTE3Ao4cYqmYcT7AI2g,1201 +torch/include/ATen/ops/acosh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=oMgL0_iUvoWj3JRTXjWRxrlBxU2qcMbyCsJzFgZ5hns,863 +torch/include/ATen/ops/acosh_cpu_dispatch.h,sha256=EukLZ7hwut7vEKNw0hd9_0ocBRTD5kFtw_aby85DtC8,950 +torch/include/ATen/ops/acosh_cuda_dispatch.h,sha256=WzcPjSduo996_4SmNbqbsHLMBi6U_xCLUC648MqYWmM,952 +torch/include/ATen/ops/acosh_meta.h,sha256=3b_hobwBLw5340337BpLIJefnestiid_9glY-fEPsSE,593 +torch/include/ATen/ops/acosh_meta_dispatch.h,sha256=8_ZhijK-SyEQmDxs1ng8xfV8uX57tLXUhlolLYbHKo0,952 +torch/include/ATen/ops/acosh_native.h,sha256=1zP9sV3i3gykTy02ml7g5wJbGzov8gCpfKTmYhtpm-g,616 +torch/include/ATen/ops/acosh_ops.h,sha256=oRAAX0Hi_0mj77Jzy1vt9Z6IDp2JOPFP9zJNDbdh1II,2088 +torch/include/ATen/ops/adaptive_avg_pool1d.h,sha256=cf4jDlAVnvqJm1gm-QkLeFnVW8FZ8L1QaSwb0QvENAc,1383 +torch/include/ATen/ops/adaptive_avg_pool1d_compositeexplicitautograd_dispatch.h,sha256=Gc0WlsEsmC9HAkqRmT_1trJPbTUBfm5ogcuiZaMFDVk,975 +torch/include/ATen/ops/adaptive_avg_pool1d_compositeimplicitautograd_dispatch.h,sha256=hBl8GNIFQXOy9Yaq51uNoiHfpFDZP5LLr978NH_4zcc,829 +torch/include/ATen/ops/adaptive_avg_pool1d_native.h,sha256=0MlXFX4Hp-kqEStf9UjvpcGkNRLAaOY2nqQRm474PVs,660 +torch/include/ATen/ops/adaptive_avg_pool1d_ops.h,sha256=1jW2Z5UGllU8rRpkyiu0_WSIujtCzg0Fqw7zvdoCqxw,1855 +torch/include/ATen/ops/adaptive_avg_pool2d.h,sha256=J8bjHCzxQ4MmBYzK85MoZ38l5L2z-0Q7WkVjry7BVVg,4202 +torch/include/ATen/ops/adaptive_avg_pool2d_compositeimplicitautograd_dispatch.h,sha256=vmx2vM8mQ_qe0-tEVy23IftHIzWstanKcquIG0fh_qE,937 +torch/include/ATen/ops/adaptive_avg_pool2d_cpu_dispatch.h,sha256=LmQJ5Q27CYs6InIjOFaMmbdlnkLfKH8zGtSE9d6IFEk,1196 +torch/include/ATen/ops/adaptive_avg_pool2d_cuda_dispatch.h,sha256=eynEnu65UUWEB_om1Ez_kzlGb_Y19OoVd41BIwmdYGQ,1198 +torch/include/ATen/ops/adaptive_avg_pool2d_native.h,sha256=Mw6I2RXoHIDyr3HntVe3cEe7Ax5zy_V3U7F0jgKpIGg,934 +torch/include/ATen/ops/adaptive_avg_pool2d_ops.h,sha256=P-kPCGqSUZcY-uJWIyLMFfY6WKvt9s__Loehga7AZaI,1885 +torch/include/ATen/ops/adaptive_avg_pool3d.h,sha256=rXJ53stWtkOTe1BaQc_xmt4F6qcl-svRPVELQtzZqF4,4202 +torch/include/ATen/ops/adaptive_avg_pool3d_backward.h,sha256=nagXcLYOFhE3m4V1xTI0OBCPakZP74Th7WbJkgiF964,1281 +torch/include/ATen/ops/adaptive_avg_pool3d_backward_cpu_dispatch.h,sha256=cYys0Ypk61evA7wztvRfH6hS3dyyinBo991_q13ihHI,969 
+torch/include/ATen/ops/adaptive_avg_pool3d_backward_cuda_dispatch.h,sha256=BIhrAR2kADjsgCwW0TsPhdEa5ZHz_nXLvBZ9iGhg8dI,971 +torch/include/ATen/ops/adaptive_avg_pool3d_backward_native.h,sha256=6acAUQ6I9c1HOBtK6Rk7uK4yJ0x9l0cJo880HOY1i6o,731 +torch/include/ATen/ops/adaptive_avg_pool3d_backward_ops.h,sha256=H4Bb_7mu8f0K8UJLF9u-uKuv9o1p0iTOF9R_7SJt8Ag,1288 +torch/include/ATen/ops/adaptive_avg_pool3d_compositeimplicitautograd_dispatch.h,sha256=aNH7q4mxEfm7fHfn5XlhLv_r9YWpCIaiAxJCOP9mUt8,937 +torch/include/ATen/ops/adaptive_avg_pool3d_cpu_dispatch.h,sha256=gaZGrG9MFfDMjkMu7J5s3dcoCfSr4Bh6qQ_gJ9UBd6U,1196 +torch/include/ATen/ops/adaptive_avg_pool3d_cuda_dispatch.h,sha256=935YZkm8kBuNU4pW3QnfQM6HebIq76YhXbEBOlJ9H1M,1198 +torch/include/ATen/ops/adaptive_avg_pool3d_native.h,sha256=GaRLL8y4Jv5wW4IuAQcmBKiSrUtL3I4mTUy4fQQDv6E,936 +torch/include/ATen/ops/adaptive_avg_pool3d_ops.h,sha256=NxYj9kvcVUVRuzbnlN1ks9rzqgZVZbnpnB7f-UJvFQk,1885 +torch/include/ATen/ops/adaptive_max_pool1d.h,sha256=CVoAJEEGuIGjuv85PosT4QUJO_ZY86XxJnPCOFS9Gvg,817 +torch/include/ATen/ops/adaptive_max_pool1d_compositeimplicitautograd_dispatch.h,sha256=WWjtiBYyJqGnW7BeVQARfMQUZ2ROoap0m_HUojrCmwU,854 +torch/include/ATen/ops/adaptive_max_pool1d_native.h,sha256=KUFbIF7-xkmkVDsu1QXKOTMUX5kxOGOPqA2gmlB4u_8,564 +torch/include/ATen/ops/adaptive_max_pool1d_ops.h,sha256=ILVmEJW5fEi5bmyAUyyEEGg6KByU7W-AJ6MGxfXWp-M,1205 +torch/include/ATen/ops/adaptive_max_pool2d.h,sha256=YpaPd3Cz7CS_dsoJUqrCSbUVeQHyt3MgCvttqDpDDqk,1602 +torch/include/ATen/ops/adaptive_max_pool2d_backward.h,sha256=cg4oK0hQboArcZUYM0TVakZWiQtpcKx9Jje2DNs4O6s,1711 +torch/include/ATen/ops/adaptive_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=MYvc8odjKrlKnxWXLrY7QwP0mktQsv1wezjkLwVgDFE,895 +torch/include/ATen/ops/adaptive_max_pool2d_backward_cpu_dispatch.h,sha256=69c9Z4cBoUeC0P7YcVX6tgrrS9TFbDnIRhvu2MaP-t4,1162 +torch/include/ATen/ops/adaptive_max_pool2d_backward_cuda_dispatch.h,sha256=_z0nEVSyRr4y9s7kTbH2G6q_1sKepKusHJ_1hddcJLo,1164 +torch/include/ATen/ops/adaptive_max_pool2d_backward_meta.h,sha256=EkTp4xNSw1024-UFXNRzZj1Ycs9F_zZKCkUZBFlw7DE,676 +torch/include/ATen/ops/adaptive_max_pool2d_backward_meta_dispatch.h,sha256=QxbM-cXRLqZA7jtLiuhn4XumRdz6WOVB-6c--Fuf15Y,1164 +torch/include/ATen/ops/adaptive_max_pool2d_backward_native.h,sha256=lpbVvVXw9nAuPqcQT9789wX3BXubg6tDmDstcGgIf44,1016 +torch/include/ATen/ops/adaptive_max_pool2d_backward_ops.h,sha256=ZKeOMmHJPcjkE8--k-IHFgLg_1DG76xuIdBCW7EDfBc,2153 +torch/include/ATen/ops/adaptive_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=oU8dUuV2gJJqHV17BLmvRyYmp5phvBeJnyS342oyU4s,880 +torch/include/ATen/ops/adaptive_max_pool2d_cpu_dispatch.h,sha256=tWTFGmuZQFcWAAsjc81XecLvRdGDelO3arBuX18yO44,1151 +torch/include/ATen/ops/adaptive_max_pool2d_cuda_dispatch.h,sha256=CzE1nwvd2-xraXySgFiYfNosga5Ep_I_R-0fCw6vlpQ,1153 +torch/include/ATen/ops/adaptive_max_pool2d_meta.h,sha256=BTpIU6dPITf3E0pLbLviIgJRLjxU9Kl1y2XokNeUTYY,636 +torch/include/ATen/ops/adaptive_max_pool2d_meta_dispatch.h,sha256=ii-RPUkmqF1JGV-XM6EQqIR8mAiUaxQwvZXXo9cf0vo,1153 +torch/include/ATen/ops/adaptive_max_pool2d_native.h,sha256=SB2mFefVM0yJaHTIAVZUSH00oIvZFsPmrX0J26I9VZY,951 +torch/include/ATen/ops/adaptive_max_pool2d_ops.h,sha256=JDhJMmTpdOz7nyqrT3V3U73yWZ7a-x4gmSUTtdSL1dc,2113 +torch/include/ATen/ops/adaptive_max_pool3d.h,sha256=4-RpdRnAkEJ7GRlrFXyZ1gMY_6Jo1aVB3s0j6XygB6s,1602 +torch/include/ATen/ops/adaptive_max_pool3d_backward.h,sha256=iU4c6yTUeNqZSiuT4VEneNbiBmqJclPBDa680WxvVoE,1711 
+torch/include/ATen/ops/adaptive_max_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=mqiio8OOBoCtaBCQiJocef72YiCmncANEktGRcemZLQ,895 +torch/include/ATen/ops/adaptive_max_pool3d_backward_cpu_dispatch.h,sha256=3wja2RgND1PCIh7B70UZLOw506U3w1YsW1z_LeHJH2c,1162 +torch/include/ATen/ops/adaptive_max_pool3d_backward_cuda_dispatch.h,sha256=yOmlB5uMx2GuM1Jf18_CCE6ml0WnjMMn5h9aIlurLA8,1164 +torch/include/ATen/ops/adaptive_max_pool3d_backward_meta.h,sha256=DzE4KLV2PAahfKRH6_7v2e9gbfFuDHMEtrR9mVxMoJ4,676 +torch/include/ATen/ops/adaptive_max_pool3d_backward_meta_dispatch.h,sha256=V2rVN6CIkg9Su41RB21nnStPE3LVsoJ_I6XluDtEKkI,1164 +torch/include/ATen/ops/adaptive_max_pool3d_backward_native.h,sha256=l6OBMY30EG61THmHs4p0bhY83ioyZFw4Rx7MGh7XAb0,1016 +torch/include/ATen/ops/adaptive_max_pool3d_backward_ops.h,sha256=wdU8eXMu2ovoOhODBJkIRY99GSS5_XQYBLivNJHoHQQ,2153 +torch/include/ATen/ops/adaptive_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=86JEJl0HPiS_vZY72z1CmLe3vW5oydQtCBI0phyS5pw,880 +torch/include/ATen/ops/adaptive_max_pool3d_cpu_dispatch.h,sha256=9Y0J81hdZFnW2-S8WSy3L5Bj4lsc2tt2GIbbeuKXoVk,1151 +torch/include/ATen/ops/adaptive_max_pool3d_cuda_dispatch.h,sha256=lHczTvUaVyAr_a1FSDitQo9wYC4NXpt8WNWoKmFFLMk,1153 +torch/include/ATen/ops/adaptive_max_pool3d_meta.h,sha256=FokgkCgKn58fCVoSfV7jV8Gi1xtErsKUasjnctN-UyY,636 +torch/include/ATen/ops/adaptive_max_pool3d_meta_dispatch.h,sha256=mkFo9EWMAHSIZXaqOguTWZ43PCRHOLHO_KpawVWT8NU,1153 +torch/include/ATen/ops/adaptive_max_pool3d_native.h,sha256=rl-Rn6819u60kI6Fb-_5su8BCunooRG32G-NZkU5z4k,951 +torch/include/ATen/ops/adaptive_max_pool3d_ops.h,sha256=7iRhfONNO4IRHCZrciKsKPYFQNRsKaxiVo1Myu5K4ns,2113 +torch/include/ATen/ops/add.h,sha256=HwOnFL2u8pDEATL4Y_vIBqQTKSZ6_BJjLn0Oqb7Hr_o,2192 +torch/include/ATen/ops/add_compositeexplicitautograd_dispatch.h,sha256=TntbCILJ0FcpUfob-PBJPd4QePLvLifpIt2hVZh-S-k,1200 +torch/include/ATen/ops/add_compositeexplicitautogradnonfunctional_dispatch.h,sha256=u9JmtrPVtvyNO1AOKf8ooe89vSTBCMYU2GyfOIzaMnI,967 +torch/include/ATen/ops/add_cpu_dispatch.h,sha256=fxizkKhw58936ek-VZ-gp2syRyYwtOxjW8dfVM5pcnc,1156 +torch/include/ATen/ops/add_cuda_dispatch.h,sha256=bxs1-Q16HaJ--w1Sy81N4zYhkMy1DGG-99fKPZIqi9U,1158 +torch/include/ATen/ops/add_meta.h,sha256=c3-XEdxYVErDdQ1PAf6nIV1AWhq6t1fB8b-j94WA9ZM,650 +torch/include/ATen/ops/add_meta_dispatch.h,sha256=yPuzknZnlePcWAEabZ0n_LoNfJr8q0KZOFDN6_v9qYY,1158 +torch/include/ATen/ops/add_native.h,sha256=gU54nk274yH0a-Z0RVAw2SyBEV4L4bQnoA8JEaDvERc,2981 +torch/include/ATen/ops/add_ops.h,sha256=tOj5gU3MDSYY6U5_BI7ToEdbodaOM1N5ZUI98bq6cjo,4816 +torch/include/ATen/ops/addbmm.h,sha256=p83fty3pf2QYuhlo0KbSOKrBlOU3uXKVhuHCAehJafo,1663 +torch/include/ATen/ops/addbmm_cpu_dispatch.h,sha256=Mu_xmnT5eicVZu-zNpnzepAn_cyD6zkVeVFXl6hqVyI,1386 +torch/include/ATen/ops/addbmm_cuda_dispatch.h,sha256=9lKtvq6j_U-ABd8ZKbclHjZOPGK1qbpoNMaoD-B7wFY,1388 +torch/include/ATen/ops/addbmm_meta_dispatch.h,sha256=UfOr8Q26hQLBx45GgxkO3FT7XVA_5l1-D6tCj5fd27Q,851 +torch/include/ATen/ops/addbmm_native.h,sha256=fThPgWaHu9u_brIIyd0EOXmgm1ED45T7oUJXbBjV8Uw,951 +torch/include/ATen/ops/addbmm_ops.h,sha256=6ULYIxr7xBUdb0qk0UCzSVyd871JrMObsMmwiOCscdI,3156 +torch/include/ATen/ops/addcdiv.h,sha256=NTVRjVkbMHaBUWD2e7Ail5akYxQZWOHkOvX22QcHysY,1549 +torch/include/ATen/ops/addcdiv_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Pm2z781OLhCy4ru4ShW_smFfq2mVECwr6tg75D1rFwk,1035 +torch/include/ATen/ops/addcdiv_cpu_dispatch.h,sha256=huG7nADQXtXyh3VYR819CegNI01ZKE_BGzsis-XENa0,1292 
+torch/include/ATen/ops/addcdiv_cuda_dispatch.h,sha256=b9v-f_7ufvOm6ZAdfrd_MlEZ26_m6U0L-wANqsiq82A,1294 +torch/include/ATen/ops/addcdiv_meta.h,sha256=rbKI0jwzoEkbSu2ec7kqrerwPi5d7l_SemMM-IYww_Q,677 +torch/include/ATen/ops/addcdiv_meta_dispatch.h,sha256=rj-Husw2NZzw4CHlrHADZRwMoEM-oLVpKGJkFWlFeIw,1294 +torch/include/ATen/ops/addcdiv_native.h,sha256=3WXVchyP1ushDQpf8sFSbIX1SQGsdpxs7GaA6wHVy_k,704 +torch/include/ATen/ops/addcdiv_ops.h,sha256=3zorOS8FlBA2_2vZGbksfBFLnppAr5f8fIq6_7YQ4us,2928 +torch/include/ATen/ops/addcmul.h,sha256=-dgDoK5Glr6tzqUGXPgWC2qkhdSdJyFibnUQUPQ3Vrc,1549 +torch/include/ATen/ops/addcmul_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AkW1m4qljSBBaAi4hqAB8OeN9vMuL2AxqdWpbE3cU7s,1035 +torch/include/ATen/ops/addcmul_cpu_dispatch.h,sha256=z1hZE422OZoaJKTErhy8Pq9ioJ12Axeq-RqHUD87KeI,1292 +torch/include/ATen/ops/addcmul_cuda_dispatch.h,sha256=VsuYivBBT1L7sbXhQgsUp1cr9K_Reuk_wvzdEDTsFlI,1294 +torch/include/ATen/ops/addcmul_meta.h,sha256=xdlYXyh0Lp1Qlgblgz3t0FiEDsq1M_1XRiv9nWX3Adk,677 +torch/include/ATen/ops/addcmul_meta_dispatch.h,sha256=LLrFuqPoNvHWA7z6N_Glihljx1pfkkwt3gt69X5cu5Y,1294 +torch/include/ATen/ops/addcmul_native.h,sha256=ASlEfOaVyld6M1oPbFAsjrKhp5fRpeVy0QZiPJmKa_8,704 +torch/include/ATen/ops/addcmul_ops.h,sha256=bLrVtq6lMYlbj7HdWAhJXvk0UqDTl46CVonHQscMg0k,2928 +torch/include/ATen/ops/addmm.h,sha256=3n59cIA5gMRFv9KOEUABWt-vQ5pkdUPNP0JtUEi0YbA,2920 +torch/include/ATen/ops/addmm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rdvT2t_NHJ52JCSe_DoQNwGBBaDDuKMd01r0uY83Jhw,1073 +torch/include/ATen/ops/addmm_cpu_dispatch.h,sha256=5l9nCMKlKaDyKqtAc3G7hD6sdGqSNhMZNmMYiCcjDNc,1366 +torch/include/ATen/ops/addmm_cuda_dispatch.h,sha256=ZdmpCMUuOGZXxZfKFqN39qCBX_HbWy0OVuULRlNVTLY,1968 +torch/include/ATen/ops/addmm_meta.h,sha256=tsFmBrhkP4NA2LL7viP_lrOkB2C2RsiDH654BwQMRwo,694 +torch/include/ATen/ops/addmm_meta_dispatch.h,sha256=tn3pSfkLn7dDzzhUIVpD6CBd_xibh_YCJ-_cQDedcnI,1368 +torch/include/ATen/ops/addmm_native.h,sha256=I58U2bGD-FXb-JrgGGRxeAOgMJJBGNCUwOKCoDOO3NU,3066 +torch/include/ATen/ops/addmm_ops.h,sha256=lgIsA_WqOlo-REVnKOiQpCNEUof2aPLxvkXE4XO18lQ,5113 +torch/include/ATen/ops/addmv.h,sha256=Zxvth8WzapQPL7TlFp5DSWrw7sOBov-EHxowPzn8WAg,1927 +torch/include/ATen/ops/addmv_compositeexplicitautogradnonfunctional_dispatch.h,sha256=g3VQHZ40XizUCSg0TdN8YGOfK8E_2xpvGJZzV5Yd24s,1069 +torch/include/ATen/ops/addmv_cpu_dispatch.h,sha256=f7-JW6x7d5nmgsmoVnQkmaiYqStgdLazu_u3Zc1NpfQ,1358 +torch/include/ATen/ops/addmv_cuda_dispatch.h,sha256=P-c9G5FS47FYqx8lrUJ15x4jSCTF8XW8q8sVyugsdpw,1360 +torch/include/ATen/ops/addmv_meta.h,sha256=iUP_j9yICNXOP1CGnd9HULS2kawdhrOuqxscDNuAcYo,692 +torch/include/ATen/ops/addmv_meta_dispatch.h,sha256=N_wgkNgBmKvVxFdZt_6pP09P90jAqaTxj-78G-YC9_A,1360 +torch/include/ATen/ops/addmv_native.h,sha256=E1fLpEmLD_ToO_UNQIIuPtuptBK7SnUrL-Zk9raU488,1360 +torch/include/ATen/ops/addmv_ops.h,sha256=-K5__nInLkXui6qg8h0EB1GPzd2wryD223mL-QrMtsY,3093 +torch/include/ATen/ops/addr.h,sha256=eVqGjSKVQcDXM6BIpGJsOq2mTMy260TuVVW82Xl7TZs,1607 +torch/include/ATen/ops/addr_compositeexplicitautograd_dispatch.h,sha256=Pu3Q4q9WKUHWhv49gw7odUfi8tIdhts4L0_V2pJnyzQ,1406 +torch/include/ATen/ops/addr_cpu_dispatch.h,sha256=Ad2wCQk2lGhduInrg8SfK-iI4k1uY62R0FtqwAsgWaY,1207 +torch/include/ATen/ops/addr_cuda_dispatch.h,sha256=_h7SOKKDvFjPU6imqBROBX9ti1V_832g1aWyfHVb-GM,1209 +torch/include/ATen/ops/addr_native.h,sha256=fdcLW-qIwTAAB37Ip0wmjnuT_A8AY7eZ455Ftgc4R9E,1279 +torch/include/ATen/ops/addr_ops.h,sha256=c3sJ9jbTh2Roe_1vIm-BaagNXAkp7h2fPwS87Le81Aw,3102 
+torch/include/ATen/ops/adjoint.h,sha256=PX1_73xMygI4dyeKFKbqEMozc12-bFXZzUTspHXR2JY,678 +torch/include/ATen/ops/adjoint_compositeimplicitautograd_dispatch.h,sha256=FPY3WClLDGQ6IroX8yVlxBui1KOhBXOossPthMf1kwU,788 +torch/include/ATen/ops/adjoint_native.h,sha256=2Q5wVALoXbPYRJdV09GjsehKAEC4qWba7_bxnJCz7Rk,498 +torch/include/ATen/ops/adjoint_ops.h,sha256=oR4yAIza5lnC8-XKa85msWQG09szdTkR6U1eMWF3Qrw,995 +torch/include/ATen/ops/affine_grid_generator.h,sha256=uCMUxdQU--Fpz_HFvKmGqbW1ztJpHVpV_ZTYWPDdtY0,4618 +torch/include/ATen/ops/affine_grid_generator_backward.h,sha256=UzXwopDfJBaow99IwqqGoG-i3p20caxzD7rEA1vFowI,1865 +torch/include/ATen/ops/affine_grid_generator_backward_compositeimplicitautograd_dispatch.h,sha256=IHIvxW8yQI7z57xB-43dbjfRGT1WpU_UeFgsBgeFdc8,985 +torch/include/ATen/ops/affine_grid_generator_backward_native.h,sha256=sJW4FN28mjYez9l_aurrO95pp9onli1UN1tN74_l7v4,563 +torch/include/ATen/ops/affine_grid_generator_backward_ops.h,sha256=K6Y-awJli7MktNcIWFbrmVzaO1G8OSUpPe10nWRnHX8,1212 +torch/include/ATen/ops/affine_grid_generator_compositeexplicitautograd_dispatch.h,sha256=kdEp4zMDwPc8Qv2O__xANsDdeUbWCsNdtmEq0sSWjas,1541 +torch/include/ATen/ops/affine_grid_generator_native.h,sha256=1e7EQejq8jVgVYq4One5ecgvV8O169k-ZLPs9HuYwmo,703 +torch/include/ATen/ops/affine_grid_generator_ops.h,sha256=vvGZ1Z8sHWWfjI7LkrpLTpjN7gjpoSYsTlboG141diY,1991 +torch/include/ATen/ops/alias.h,sha256=BscOvKXdsoFPuP13bQoBiywpjc_Mqic7OM2ZKZSTY2s,670 +torch/include/ATen/ops/alias_compositeexplicitautograd_dispatch.h,sha256=YGYejf7wcZFLObDvfTta4CRehSr8uh7PhlBhQXS9kfQ,786 +torch/include/ATen/ops/alias_copy.h,sha256=t-yyGfpX_EmwDYBTahG1U6-l-l1uYUMT6NVWHtTAsZI,1107 +torch/include/ATen/ops/alias_copy_compositeexplicitautograd_dispatch.h,sha256=VogW3NiMpavSHQTn9loWTpYTHkMLNyy4qwLjJA9Pp6k,899 +torch/include/ATen/ops/alias_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=njuEY_U5G3uOwrKZ2CwqHxhKPNVGayLQG-QpQRT4lBw,817 +torch/include/ATen/ops/alias_copy_native.h,sha256=DeU3yDd4iEGD7eIEidVyOAs3bkr2JAt2LUq7JvFgImg,584 +torch/include/ATen/ops/alias_copy_ops.h,sha256=bpoUoweCQ6SZlzYqDdoHVnDytuAjthz6BSl0PPHfh6M,1611 +torch/include/ATen/ops/alias_native.h,sha256=dtZAnfLJvJylpwcsZK4C5173VjxExDmfHJHo7fArGsU,557 +torch/include/ATen/ops/alias_ops.h,sha256=HOPb3jiGzdVZr7hhj81Q8v8nrjgmVk_MKtnhUJqCx60,989 +torch/include/ATen/ops/align_as.h,sha256=dOF9RNmEo-rwsIR9QLWz6fpklpmcG0TWsKKsXcuNTL8,532 +torch/include/ATen/ops/align_as_compositeimplicitautograd_dispatch.h,sha256=tlpoGNBPlPpa1z07qgFZNtmUF47c7tGnTP9wD6ohxk8,815 +torch/include/ATen/ops/align_as_native.h,sha256=zQpQcIn3Ssx5vumwiIpLE7NFbm8wnICRniWxo4O18rw,525 +torch/include/ATen/ops/align_as_ops.h,sha256=LClALHo67akm1iPb7MMSJk_ZLWrKvQCLM8nNqYF4Owk,1078 +torch/include/ATen/ops/align_tensors.h,sha256=lidLcV3i5UKPhQls2EutfMVjHWLvYxdCdAzhYNWECDc,720 +torch/include/ATen/ops/align_tensors_compositeimplicitautograd_dispatch.h,sha256=c7cr0qF2bAQ4xPVMUKa8J3ZbGuX1G8HukyF8aTwjiCA,808 +torch/include/ATen/ops/align_tensors_native.h,sha256=qgDupCqjpecqzUGPPCkxCAzI3OP_7-qzi7gtLAdIXRc,518 +torch/include/ATen/ops/align_tensors_ops.h,sha256=_3xwYegYTUGOsXwWJOMGCGsMFC6gpHfNduRTxUo1LGQ,1053 +torch/include/ATen/ops/align_to.h,sha256=_pAdTpUR5iDAr-jF8xdnbWIEDyCkyq2HYBj70_m_-vY,532 +torch/include/ATen/ops/align_to_compositeimplicitautograd_dispatch.h,sha256=0dY_bHaC5d_0xfIsZfbhADzLlrxSeE5iC29ZAfSuuVA,914 +torch/include/ATen/ops/align_to_native.h,sha256=b6LerdFjHOSGmRyc-kyvg8YrJBqsQ2Zrq7rbJyLlqqw,624 +torch/include/ATen/ops/align_to_ops.h,sha256=to9pA2sudzlrfZqNi7OnfK64NcbPmr-0HXB9Qm9DdGM,1790 
+torch/include/ATen/ops/all.h,sha256=-VReDI4bwbdBYLxrEjb9PJr6t5qw1gifxdc3pdqBT_A,3437 +torch/include/ATen/ops/all_compositeexplicitautograd_dispatch.h,sha256=L-7SdLRArZIrL1j9l3xshgOjoFDV9JHrzmCtepw89ZA,1078 +torch/include/ATen/ops/all_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ib4UaDwcgDcjRDjVZ87gTYBWIg2vq-mJMeW2VqqXvQ0,996 +torch/include/ATen/ops/all_compositeimplicitautograd_dispatch.h,sha256=__Vecm_Je74i6QpwabblslaWCguVS1dY7mN-WQsJELI,1042 +torch/include/ATen/ops/all_cpu_dispatch.h,sha256=QmFEKXW9KQJKxaFhTNTU1E8okhEtBpx4fG7ORTx5Zzs,1537 +torch/include/ATen/ops/all_cuda_dispatch.h,sha256=Gybmy4ChVeEa8JXvrwOoEVohM4ClkXfZuiceD8u7ja8,1539 +torch/include/ATen/ops/all_meta.h,sha256=1Lwd96hNhF1f9YqqRMw3ZjwuTokjdomKmkDbrTpLZiY,894 +torch/include/ATen/ops/all_meta_dispatch.h,sha256=T0R4CY0VUlScdS_LZ0IiHk-NKBClrQp3UUZeSHFFx0g,1539 +torch/include/ATen/ops/all_native.h,sha256=_70z4NJRunPthhH4U2kGyh2N2jsMMBRs2ytZ5qh9gMI,1529 +torch/include/ATen/ops/all_ops.h,sha256=2G8SeA1bxRzbnJNK3Gvrb3MQR8EvAzlask3XL9x2KrQ,5682 +torch/include/ATen/ops/allclose.h,sha256=-ehBaT0Whn0Ls_o50SpvDOkEdiNUeufmwtjvN3rV1X0,856 +torch/include/ATen/ops/allclose_compositeexplicitautograd_dispatch.h,sha256=R-JiN6b_MRzVICJAkkht9QVDdAbg_lnFZG5AthSxGyk,869 +torch/include/ATen/ops/allclose_native.h,sha256=0PPUiEle20Ee4WEBNLvklKgFQibDu1iJXuPjFCGJmr0,579 +torch/include/ATen/ops/allclose_ops.h,sha256=d7p9Vqgz0YNzSh_eWgKCJlMisjm2LwrZxVWVIqtpDkc,1222 +torch/include/ATen/ops/alpha_dropout.h,sha256=JzEHU1KnzW98zfV7VLjDPb2WSpMoJ86w-RQNvA-oM9U,973 +torch/include/ATen/ops/alpha_dropout_compositeimplicitautograd_dispatch.h,sha256=x73XS5LPGg-TIKupbPTG2stRbYffVSyeVO7MudlXQ60,898 +torch/include/ATen/ops/alpha_dropout_native.h,sha256=CiQlQQ1-RdbgKmSSdZkxMq6stlsXIyYfCWvmj3l1lG4,608 +torch/include/ATen/ops/alpha_dropout_ops.h,sha256=Lk7T-ZaodUAxh_3ZKwTGjTV1vnBsqk3MddbyDFzRnbE,1699 +torch/include/ATen/ops/amax.h,sha256=9RRzhXJ66p52iuREIIsdL6r9EfUiDFuLVp4VxTflUlo,1317 +torch/include/ATen/ops/amax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=cUvEsfROV2xAZC_gORt9KWpohJglhmKm3Q1TjUoQAhE,855 +torch/include/ATen/ops/amax_cpu_dispatch.h,sha256=L3Io7YJekxVzTPs1lkYyVKB90eC1009aCbhKdnevlE8,1019 +torch/include/ATen/ops/amax_cuda_dispatch.h,sha256=8GkNJySA75oZrJeKt515gv4fgi1w1tHRirUcLFQWOog,1021 +torch/include/ATen/ops/amax_meta.h,sha256=BSdnC5jThxiRDTZ5cYDV2Ow_xMYBrz-R8ZFcH666SJY,627 +torch/include/ATen/ops/amax_meta_dispatch.h,sha256=4_jLAJMOqORNjlvqC0gQa2-yMwGZtct-xU9s4WVTb0Q,1021 +torch/include/ATen/ops/amax_native.h,sha256=052N_Y555X0ZFEfd2NZE4ZslhEBLy5pUGvqS4i_WGXs,648 +torch/include/ATen/ops/amax_ops.h,sha256=exZd8Td0CEnh1bnAjA9gwJCAv3qZ2o9moKdCE-USQi0,1831 +torch/include/ATen/ops/amin.h,sha256=h_eGUvQF1IUnVHxxyUG-0KyuTJ4xeOYhZTj7zKYXsp0,1317 +torch/include/ATen/ops/amin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=En0nxcY3kr0CbwyBqqvwwf7Ja4m-WyfDWHVEPzAijbY,855 +torch/include/ATen/ops/amin_cpu_dispatch.h,sha256=nXJ1hXPcRwHcmRec3C5oCcffrKOvE_fJYSdGpqr4_X8,1019 +torch/include/ATen/ops/amin_cuda_dispatch.h,sha256=YVBsfmOkpdZVH8SD0P6j-5xMKaV-ONQNT5-k9HkS45k,1021 +torch/include/ATen/ops/amin_meta.h,sha256=_l1206L6Y55q96qH7d9cnFjonpRnTVio_zUOxg4OX8E,627 +torch/include/ATen/ops/amin_meta_dispatch.h,sha256=3lirtjqNm-8r01-B_pVfxah04CH-qMapkxXfQZLZTOk,1021 +torch/include/ATen/ops/amin_native.h,sha256=Tc11u4hOTPAD3znpQjHFFXgREYu25ufivR8tEu02V54,648 +torch/include/ATen/ops/amin_ops.h,sha256=FGaWy9Z-zUcV4RlY1SSq5O_lCGJc9rrZqTMBjcp0VJk,1831 +torch/include/ATen/ops/aminmax.h,sha256=9WE1g6YqOASgdf2zy-is8qMoAETB_ED7c_bBkqCU4Kc,1620 
+torch/include/ATen/ops/aminmax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=EYbaxTx9itW8_PDWgePYhv4dPfl7GngGu3zSADR_J2w,904 +torch/include/ATen/ops/aminmax_cpu_dispatch.h,sha256=K9mt-A8l2ObTFMI-Dldi3Bz_7cmmG4UWOoelg3opgwI,1194 +torch/include/ATen/ops/aminmax_cuda_dispatch.h,sha256=AQTOXfoE4yivE-5VnLX4_Nder2G3ZJPk5-FAiLC8tFA,1196 +torch/include/ATen/ops/aminmax_meta.h,sha256=mcH-Ot44CJH9x51RiwmNeh-7VylPm6eTwCc3p2KjyBo,639 +torch/include/ATen/ops/aminmax_meta_dispatch.h,sha256=7W-jtU6YFz9fn5CaihbEnqU4d4Rglau8D1QEobHFvKo,1196 +torch/include/ATen/ops/aminmax_native.h,sha256=Ti-YhjzfSSPI20smu0F9Qth2HfrowIx1aD0lwTH4pgo,690 +torch/include/ATen/ops/aminmax_ops.h,sha256=bBoAtcbFSYN_IDbnT9rM5YPyv8Y5oggyB8JUwgLvXeA,2168 +torch/include/ATen/ops/and.h,sha256=yzw-MzsIrRnTVSdX-AKBquvElOpqwxvZ0uaa36R46LY,933 +torch/include/ATen/ops/and_compositeimplicitautograd_dispatch.h,sha256=kVXrgnVXq2EJgvFX7mi6dWxUFjvNQMPuabcp3rtJ3t4,1054 +torch/include/ATen/ops/and_native.h,sha256=vXcJ8Nbe1IRrCQYPsVS3iE0GiEcRQ3EEIt_3F2_TJU8,764 +torch/include/ATen/ops/and_ops.h,sha256=C4mYBrHo4XVEzfFQa2tHcGffvhVH1WQShIETjE5BXcs,2953 +torch/include/ATen/ops/angle.h,sha256=D3Ix3OzaRx89LEgzf3BI3tQEY2dvWul7umz6Vx4cfj0,1057 +torch/include/ATen/ops/angle_cpu_dispatch.h,sha256=A_fIWZOt5ofFVODIuDBaxjxDhpn58DdhEJe8kjCVt28,899 +torch/include/ATen/ops/angle_cuda_dispatch.h,sha256=Ond7AUHfyyxLp5iFuYchqpFsGviyAV7wkl282ZI68Ho,901 +torch/include/ATen/ops/angle_native.h,sha256=GAQS-T3QelINzJhTt-HUzF2uDFjQLvQM5hY1x1fkRec,728 +torch/include/ATen/ops/angle_ops.h,sha256=pdQeoccHXouXsZb7zlFOrR0F1yRBW9XbF7jEmU4OmvY,1581 +torch/include/ATen/ops/any.h,sha256=asCXWUGyqBcD2YHEKWGbRsRZjZku_NhVeqIG1ofsy-k,3437 +torch/include/ATen/ops/any_compositeexplicitautograd_dispatch.h,sha256=lPZ7HZfnQLtrezhR2v7X0i0p-AA5IwkDyVEy5SkX8dY,1078 +torch/include/ATen/ops/any_compositeexplicitautogradnonfunctional_dispatch.h,sha256=roXNPoebVBtV0hkZ4qjT0W1mfr_YzjwRFmlKM7qEHOU,996 +torch/include/ATen/ops/any_compositeimplicitautograd_dispatch.h,sha256=A42OgdYZM-nD3TNoLsGX5celIOPMHBBDBoXRIKn9xwY,1042 +torch/include/ATen/ops/any_cpu_dispatch.h,sha256=_8jbX3Cl8ZHMsM5pWObDRFfxTLSDrhg1u6uier7xoFQ,1537 +torch/include/ATen/ops/any_cuda_dispatch.h,sha256=UnG8m19gUKe55HibKBs5c-nU7altx5fIxmedznGO138,1539 +torch/include/ATen/ops/any_meta.h,sha256=kwcz04_T_4txg78mDpFgUj5ovNjC0NYEIfNHWA-5pg0,894 +torch/include/ATen/ops/any_meta_dispatch.h,sha256=LS3hfDPVYVmJv1ibAN7MjN35dGgEri6tdcHldC-wnPo,1539 +torch/include/ATen/ops/any_native.h,sha256=X0V_YGlMNHhcWgjoQ74nSkHlhcVq6REqj0Vx3KoCzCw,1490 +torch/include/ATen/ops/any_ops.h,sha256=GHF4j-E_zp6Jmtn9zymUM7eRkk2KrbgMOkJdv3F-m6k,5682 +torch/include/ATen/ops/arange.h,sha256=CFPclhfeZt2pCYPLyNAsGLvgc3LWJP_ZZ3AxFjPExPI,4242 +torch/include/ATen/ops/arange_compositeexplicitautograd_dispatch.h,sha256=qsXrPQwBdm-60ISMPdOQ08Ot9hcUq049uTKpH7Cb7V0,1892 +torch/include/ATen/ops/arange_cpu_dispatch.h,sha256=ucekjPPT4IDc9V0tD8jSSvJrb7_E01QaohaKLvGg5vc,947 +torch/include/ATen/ops/arange_cuda_dispatch.h,sha256=jb52TldaBy4p2th9HtW_yXi3FBkkfm_a2HisU_EqSSs,949 +torch/include/ATen/ops/arange_meta_dispatch.h,sha256=KKlCqxWYanA54fiDiT-Lcic5uEb304obAkEr-rcpGv0,949 +torch/include/ATen/ops/arange_native.h,sha256=OWRB1APQRxTD7_FBaygtrek7agG9nJaTnrVR9M6LvdY,1495 +torch/include/ATen/ops/arange_ops.h,sha256=lWtnyH143txhrbQamsGHRTVSIcuy3dpF5Mn-i4w-lZo,5178 +torch/include/ATen/ops/arccos.h,sha256=qHxBjWiok1mjXnu8wpTogVSSxaGftz4yXqjS8kJd3-s,1214 +torch/include/ATen/ops/arccos_compositeimplicitautograd_dispatch.h,sha256=7q2DPNsL2g5O7fY031Emmm4IK4rKHke5z_gy5mwladA,998 
+torch/include/ATen/ops/arccos_native.h,sha256=hIeswCCStWoYWI2E9kL387mCFcs6itqJ2LvvoMPJzT4,628 +torch/include/ATen/ops/arccos_ops.h,sha256=oBdM3LNYNjjr-RGATbks4NKY5vVGniFS3y55rTQFJ0M,2097 +torch/include/ATen/ops/arccosh.h,sha256=jz79BGxW4bmwxJIM3hXioDKB3PcDx7FUr7c0rJ3lwEU,1227 +torch/include/ATen/ops/arccosh_compositeimplicitautograd_dispatch.h,sha256=EYaIHZ-8Fu1cO0yyS6UzZI926zUItJn4FxNBm-U8cgk,1002 +torch/include/ATen/ops/arccosh_native.h,sha256=Cva0FSKM4gB6PtCQyyvuEYPwXcVPl1zDtAls1EZnLzw,631 +torch/include/ATen/ops/arccosh_ops.h,sha256=3QH7aRYEiyqcR0oPnZWoqyyvec0dExlWhhrRf9_6_0c,2106 +torch/include/ATen/ops/arcsin.h,sha256=PEgT-WIPzV6wbGNCMy46BVP7XcAgVGKYdLa_QGBdJfg,1214 +torch/include/ATen/ops/arcsin_compositeimplicitautograd_dispatch.h,sha256=4DCjrlvi6TkOqy8s_tCgf3jD86dq458sCVi2dSGVt-0,998 +torch/include/ATen/ops/arcsin_native.h,sha256=PSB77b6ZnUfxeBMM8zQ-ac8Nfddf5d_q6_gEWvUC9R0,628 +torch/include/ATen/ops/arcsin_ops.h,sha256=cQ661Q_bgbrO7KJOMI8JmG-nxQV2vDVHFhgbCmbYW7Y,2097 +torch/include/ATen/ops/arcsinh.h,sha256=z-60NGwJxpcQRGRK6Bvepkt2rhmtalcUHxckTWprfsg,1227 +torch/include/ATen/ops/arcsinh_compositeimplicitautograd_dispatch.h,sha256=FRw_r449rpR-WM4JAn46LmszZOvasGaB-MWGr04rYCM,1002 +torch/include/ATen/ops/arcsinh_native.h,sha256=KuR6ZcvSC0FSPY_Tnvr5J0YCCpQFkJFRf6rgzKwjeYU,631 +torch/include/ATen/ops/arcsinh_ops.h,sha256=r7r-kMBVUOMoXWRPuHKoBwoTI1TQ1GK3DrPbMLfeltE,2106 +torch/include/ATen/ops/arctan.h,sha256=FJhOia7aYSqreVF7rWr7E6qDdVSli9ZR6UhNc62_qbQ,1214 +torch/include/ATen/ops/arctan2.h,sha256=O8P1YmeBhidUUnEftk7EF9LJ3pL7ybKt_ssJ_wdnvo4,1218 +torch/include/ATen/ops/arctan2_compositeimplicitautograd_dispatch.h,sha256=HuWEeJ1xBvhfxnhGooqoznT_FP8OCUv8P3oalFXfLDY,1106 +torch/include/ATen/ops/arctan2_native.h,sha256=BsLR6ywZj_AXCnjnJT5hwmc5rCiBoZR_RURamLq4K1Y,709 +torch/include/ATen/ops/arctan2_ops.h,sha256=f0tStIjFI6Ca8Q7SJMx2qaaXHHX5HZpfT0Ldhjms3EA,2364 +torch/include/ATen/ops/arctan_compositeimplicitautograd_dispatch.h,sha256=2IppZI3w9gjusdVB1A-I8OSVMlnkG6uR-B1rWWdFoLE,998 +torch/include/ATen/ops/arctan_native.h,sha256=OtpQKK_pKNoPg30I1HdOk8MpIdI7Q3_H81lvGLXxv94,628 +torch/include/ATen/ops/arctan_ops.h,sha256=b_3i281jRyvhVjGNeuRiFL9LTYUVmWkBSy0hmWR05P8,2097 +torch/include/ATen/ops/arctanh.h,sha256=guLC8aH49pjNKQi_TCsrGEE2L2TGj5JS0kacHi3Lm40,1227 +torch/include/ATen/ops/arctanh_compositeimplicitautograd_dispatch.h,sha256=CLHwky_TT5hLXZwUjmgpa2GN4S-_rX4R0llLKm7o_-I,1002 +torch/include/ATen/ops/arctanh_native.h,sha256=4X4f71Y68S_Oz6LSro-ypM2g3faAG9Ip9i6Sy7NTTBo,631 +torch/include/ATen/ops/arctanh_ops.h,sha256=0V_N3UQKt2j1SFn7zvNbAanApuOgSTfV2X5EFaipJiA,2106 +torch/include/ATen/ops/argmax.h,sha256=WDiMNSvlK4sWEbG63b2CIKb_2vk8TZcSfRsOJWtLTZg,1388 +torch/include/ATen/ops/argmax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=iEiWw0Q9lTOnRqkwgbQeCJwu7XoE0MC6ACjCa57Ob6k,878 +torch/include/ATen/ops/argmax_cpu_dispatch.h,sha256=eDs03MEmIxhiNKSkhxKSMguX5P1YcSX03yGT8DoapgQ,1076 +torch/include/ATen/ops/argmax_cuda_dispatch.h,sha256=7oMtsfCygRM1xKAYCIsnWOER1Nl6dAjqSEOh6TWXiKc,1078 +torch/include/ATen/ops/argmax_meta.h,sha256=VUUM2YBX8eMFAm7hUiqxPsXjzynA6obaLAgA1n4ALNY,638 +torch/include/ATen/ops/argmax_meta_dispatch.h,sha256=uD_iFuRzbRCtxdZ9hEcNu0MpWYxfMQpUR8arMba7hQw,1078 +torch/include/ATen/ops/argmax_native.h,sha256=PGqeAJ0QjlTI3A_JlQKadbCxOpzwE6lqJ1xXHvZXg0Q,663 +torch/include/ATen/ops/argmax_ops.h,sha256=CXSWGBClBLRTs2miWQMbmP4Dp1LWx3BBJAkqklmDlqk,1897 +torch/include/ATen/ops/argmin.h,sha256=zmvMR0sRApOZW2i--WOBUvOrBZDB2URnWl7FVAE2984,1388 
+torch/include/ATen/ops/argmin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=RF5Q-FQlz2lNghJ5ZqUyPg-X4NQaUiBzkjJW6lezwDQ,878 +torch/include/ATen/ops/argmin_cpu_dispatch.h,sha256=LH9h5vxwb-_fEkdumOMFJ6ixgcTRoFT5Zuy4K3gxkjc,1076 +torch/include/ATen/ops/argmin_cuda_dispatch.h,sha256=6nfimgKRYih7icR1mTVYaN9wv3H80PbuTdW0LorjY60,1078 +torch/include/ATen/ops/argmin_meta.h,sha256=qTyLUB_v0WuhG1rKmjzfKLKGApVUOz9ZB8EzyGK-X3M,638 +torch/include/ATen/ops/argmin_meta_dispatch.h,sha256=GeGAs7q8pHwfJWN5i9LWNHrUuC_BY8rN_tOwwlI14VI,1078 +torch/include/ATen/ops/argmin_native.h,sha256=JJlWafFzmOLZkqKEJr9fx-R6TuIQ-RrP6PObvbGk4Vw,663 +torch/include/ATen/ops/argmin_ops.h,sha256=mahj8QsYDsIpEHP-bfmgdXWwc9FvGD_bXNX5DpKad5g,1897 +torch/include/ATen/ops/argsort.h,sha256=Zs833kXs96iaI0X1LqLsz5WEjL6YI6-ufk_mcIRvd7I,1974 +torch/include/ATen/ops/argsort_compositeimplicitautograd_dispatch.h,sha256=_f2bCz7-z6BGHLwc_lr6r-9Z6IjDB17c_8L-HYdta9g,1287 +torch/include/ATen/ops/argsort_native.h,sha256=hK-uYvQ3bpcqXN7dqCOLDK0EUd7AGET95jM2WVOv528,864 +torch/include/ATen/ops/argsort_ops.h,sha256=YeT_N2kSofZiiFQ7lzkYocVuUGk-CTWb0RjMxWvhQAM,3231 +torch/include/ATen/ops/argwhere.h,sha256=Ns_bMPoNxGQ2bFtQeaL4CVWpAGGYwyjZV7ELaaA2AM8,676 +torch/include/ATen/ops/argwhere_compositeimplicitautograd_dispatch.h,sha256=y92LhevhGwzuQLV1Pun-JN-IXE9RIuCSfzKDEnVety4,789 +torch/include/ATen/ops/argwhere_native.h,sha256=mumJ4COBb5T8ASt-WeSQonYqbPpg7yN3Mn6AO9K6V1w,499 +torch/include/ATen/ops/argwhere_ops.h,sha256=vtQaUVYz4isK5uAyuU8g6efYNIvBrevxJAiSzMttKqM,992 +torch/include/ATen/ops/as_strided.h,sha256=lumUfE6t2wA7qQrdnJuAymUbbPiN3pqnYOXPXNbnjsU,3948 +torch/include/ATen/ops/as_strided_compositeexplicitautogradnonfunctional_dispatch.h,sha256=TrdtisJBlDOu7qzEtJxDApOJHB9m-OEUJnJKbaFe1co,1117 +torch/include/ATen/ops/as_strided_copy.h,sha256=cv1VQpqRoAdrPjTOg4JB_-asqzaIQ9sVPur7R_KvwXY,6022 +torch/include/ATen/ops/as_strided_copy_compositeexplicitautograd_dispatch.h,sha256=8UhnkI0G8EZbqpi-XY905_M7GjF0Ht8cZdIvy2ONPc4,1502 +torch/include/ATen/ops/as_strided_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YcNFbvAhmoSzuGmCkFOWNPFCr9tto2ndkPi0C6mdO0c,1109 +torch/include/ATen/ops/as_strided_copy_native.h,sha256=lB9yiOiVozQR6KU1nW1u6X4Vgrp5xN8ebXmk34S84z0,821 +torch/include/ATen/ops/as_strided_copy_ops.h,sha256=hwXfmTZLk2ACnj6uSslcFePz3Ck8xAwgauM7V2Lt3Mg,2303 +torch/include/ATen/ops/as_strided_cpu_dispatch.h,sha256=2oup49AdZJvke-mWqliPEKYAWfUThXzy4J9Ur10_pXA,1029 +torch/include/ATen/ops/as_strided_cuda_dispatch.h,sha256=noV2BDzlT97oR074NiXQacO3b0sYeVjm-gToCiF5BXo,1031 +torch/include/ATen/ops/as_strided_meta_dispatch.h,sha256=stZ_iNlOqEPXU0HmzhEp78pxpQfG-hHlDQ76oXaBdOs,1031 +torch/include/ATen/ops/as_strided_native.h,sha256=cO7BeNOE_2ngALVujFVk7gX8VrVsYP9qebxz87-hAG8,1172 +torch/include/ATen/ops/as_strided_ops.h,sha256=3gspOGys1ykHi0geJxRtfHLr8dzGgX2fV1RqR__UZBk,2224 +torch/include/ATen/ops/as_strided_scatter.h,sha256=T4X4BEjalIEZaeXQOk3Rde-KsV5fa2JNG5moyJV025Q,6535 +torch/include/ATen/ops/as_strided_scatter_compositeexplicitautograd_dispatch.h,sha256=5pmPPRCYNJFt99bkMgn00Ubs7w9PvtLQoLGk7cyI90U,1610 +torch/include/ATen/ops/as_strided_scatter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=avEaMQExziIY1gq3bTYGTE_vzf-_YX8lKLOZAJGoPEI,1163 +torch/include/ATen/ops/as_strided_scatter_native.h,sha256=QJaAOkJKQzcwUdMiBqaOPHVa97nfultY2Avnmbyoy6o,875 +torch/include/ATen/ops/as_strided_scatter_ops.h,sha256=jRqzlT0xJvx_zdIe8d9qBAPq5NmtM7jS2KGWrk9zdkU,2481 +torch/include/ATen/ops/asin.h,sha256=y7RPoKjCipKWKl4ujcubyU64SlR-m866M1nraePnFHY,1188 
+torch/include/ATen/ops/asin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=wnlTtsc6jXfEq0nQpwfAvx18JFWK8vMf5dQ5lXwKJPA,861 +torch/include/ATen/ops/asin_cpu_dispatch.h,sha256=6JxahdkIjdn5UQZhKI8_1qYluTZAkeBA5tzsH2GITtM,946 +torch/include/ATen/ops/asin_cuda_dispatch.h,sha256=7mhWp4u0wO1OsRsmAgDeUF3GhPR1b57jHGE7yiL3QU8,948 +torch/include/ATen/ops/asin_meta.h,sha256=SFwozeyBeN83il5SijSzxOuEQN9YGkm3LlSkZS9wZlU,592 +torch/include/ATen/ops/asin_meta_dispatch.h,sha256=q4oVLO9VomRxVbhkTumPzzaGtoyjS7riswcFyhRT-a4,948 +torch/include/ATen/ops/asin_native.h,sha256=eYUE3Wn5lzhmZZR2VqJe2yTFvLoVxgrfvJpSBTy7FnQ,1027 +torch/include/ATen/ops/asin_ops.h,sha256=uyewtY23hzTIUIJ6kcjToy1ds03d49I_2_1Wmc-RHXY,2079 +torch/include/ATen/ops/asinh.h,sha256=-vra-YfQDfFBQdwusqJar4nwsAbnRS_ogYffkihuM3M,1201 +torch/include/ATen/ops/asinh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=gBgAvhTYrkVoN9JryKpNR_Xex1NqZFScCi3rPsVSreI,863 +torch/include/ATen/ops/asinh_cpu_dispatch.h,sha256=JE23tq8U2PciyPub4h8kUaMtJ4rju6DcsdfgzvHQLEo,950 +torch/include/ATen/ops/asinh_cuda_dispatch.h,sha256=ptJkuUf5i-93ynPYdiKJo884PNQ-kCefida9AFeghKI,952 +torch/include/ATen/ops/asinh_meta.h,sha256=Se9gkaQf5tvqC0fpIvzK3iJHoCDyCVTFltlIu8rHP-M,593 +torch/include/ATen/ops/asinh_meta_dispatch.h,sha256=fMokD40H-xZCFjpt0LkfLl6HCeUWWT-JvWkVz0Aq1gU,952 +torch/include/ATen/ops/asinh_native.h,sha256=QFQC7-clZyqFfWCWfCpZWrE0ykV7wI5-khEw0cbunqA,1036 +torch/include/ATen/ops/asinh_ops.h,sha256=LN8pfIn5IOsZMkcCXgpwpUetSl70kC2wIMiyylkarKQ,2088 +torch/include/ATen/ops/atan.h,sha256=hp6p2KDVdLbCR4fK0q1HU6tdlXO7coFMw4Wd_PJ2CKw,1188 +torch/include/ATen/ops/atan2.h,sha256=w0olk3JAhecb8fsbccen-nGjZ3mLhCZ4VjnUhSYxvg8,1198 +torch/include/ATen/ops/atan2_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ElTWNSAj1APWClWjkTySgf_e39cfquV3TbwmkjONX40,915 +torch/include/ATen/ops/atan2_cpu_dispatch.h,sha256=Q3DVI8Jfp7yof-tHBcoqZjI65Qw3tTWNyXGb_9EYOoE,1054 +torch/include/ATen/ops/atan2_cuda_dispatch.h,sha256=rr1i-I3YTTQbkc5wVPTFQAtFxRoZ3TQJ7V6xTKdtF40,1056 +torch/include/ATen/ops/atan2_meta.h,sha256=KGEVBUZSFZ5eNEjbmo6yUoQrgrvqCCo5y9kOCPUcTNU,619 +torch/include/ATen/ops/atan2_meta_dispatch.h,sha256=r_l9vlX74cr221srdhiDtt6-hX1MV7bzrHjZssxPs1w,1056 +torch/include/ATen/ops/atan2_native.h,sha256=OlWyoqBidcmeWVOaA0mtXMJKjBZ3rnZsckF16SqN3Ic,642 +torch/include/ATen/ops/atan2_ops.h,sha256=2cpzgptPJXUpI0Nk3vjt8UBuaU0MyqJgqof8zsecShU,2346 +torch/include/ATen/ops/atan_compositeexplicitautogradnonfunctional_dispatch.h,sha256=5KKyBVSYLeEv0RoFZkUIlO947syiJHO_CLYLbVGQ718,861 +torch/include/ATen/ops/atan_cpu_dispatch.h,sha256=J08mYAiyy20j_8b4oz0naRHQZVfiU77dzbLdZsixoUo,946 +torch/include/ATen/ops/atan_cuda_dispatch.h,sha256=46f1vZmZZMTjKc2CE2UQVVpsH6eVfqhqKwR_yQwEaQY,948 +torch/include/ATen/ops/atan_meta.h,sha256=pFn3Sb2xGkvZuxAZ5psMr_H0ki4y4SQyizVy2c1XPKI,592 +torch/include/ATen/ops/atan_meta_dispatch.h,sha256=mi5oyOC2LG7mcTX2hVjY771yKSHiS4Dxo5Q38eAP7O8,948 +torch/include/ATen/ops/atan_native.h,sha256=1ou5z5QEtW9DvXZ4VG0D-IXl6BYQbPZGz9yGCU4CUWg,1027 +torch/include/ATen/ops/atan_ops.h,sha256=X_qFkMLWbpCUQ2UE3KCtA5ucPpaHOT25UKVuQTosoFs,2079 +torch/include/ATen/ops/atanh.h,sha256=EzkPpMASd59yvxAI9APxrmKnV4aZwiV9YCZOdGIDF9s,1201 +torch/include/ATen/ops/atanh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=DSXQEq-UMxwr1fUmlTSkPRPYIFzOQ4czxciu71tBzos,863 +torch/include/ATen/ops/atanh_cpu_dispatch.h,sha256=-1XhblogUyRFizgr3FJv3aAWlSwi4c0CE5XJ_7k1DbA,950 +torch/include/ATen/ops/atanh_cuda_dispatch.h,sha256=AlLK1E2-m_Ox99TDAb5QTenk2-fkw1OvYuGrvnx6rjM,952 
+torch/include/ATen/ops/atanh_meta.h,sha256=IepPpf-efIJKb7k3tn4qHhvI2DxByH764TQ4b4AoVYw,593 +torch/include/ATen/ops/atanh_meta_dispatch.h,sha256=0OuCwktLRCLA7p2bFMBGOCFHf0X2LqKL5Vy7Jm9o8sc,952 +torch/include/ATen/ops/atanh_native.h,sha256=ZmSDLDsMvU4FyhowU39yA_L4pZXZ91_KDifZifkgG6c,1036 +torch/include/ATen/ops/atanh_ops.h,sha256=JW2WtQWP4pbAnkY7FgAej8VHYME-cGGp6JsoTHCc89s,2088 +torch/include/ATen/ops/atleast_1d.h,sha256=B1umkpWDd9w6AyfLG53CTqC76Z8OTICoaSrwiyVSyaM,878 +torch/include/ATen/ops/atleast_1d_compositeimplicitautograd_dispatch.h,sha256=Jvp2TCKzw4v10RJJLpuNlG-348DFpYkYzLHI5aaJdqk,864 +torch/include/ATen/ops/atleast_1d_native.h,sha256=lwXShN4IrsqwaNQx5-7cguKCwPUwdtaboPoXFnbjP3Y,574 +torch/include/ATen/ops/atleast_1d_ops.h,sha256=Qtm5dGzdK_56EO6ce4iBZCEsmtLV6el_YOq1nD1Do6g,1593 +torch/include/ATen/ops/atleast_2d.h,sha256=QEqgXX8UC3pmjYyx7nBR6bFoszmI9rzjCNmA09BuUss,878 +torch/include/ATen/ops/atleast_2d_compositeimplicitautograd_dispatch.h,sha256=ZKfjVqXnKc3s0UmsAOhewN5JO7JzFpbI_M5hxVGeues,864 +torch/include/ATen/ops/atleast_2d_native.h,sha256=MAp6IiqcwAC8q7tlkuazto7rdTh5D2RyPCup0C4zxn0,574 +torch/include/ATen/ops/atleast_2d_ops.h,sha256=6KXI77TO7tr813m89laIsh8z04pHbIun1GlmqC-Ywuk,1593 +torch/include/ATen/ops/atleast_3d.h,sha256=GkET8IqHsppTAHCDeD1wnydqQj9atLwB8QlZ-bprAlY,878 +torch/include/ATen/ops/atleast_3d_compositeimplicitautograd_dispatch.h,sha256=c2TqVt2URBaSWsvePikgY4Jsnxs-Hm4l4V1MG6HtmiU,864 +torch/include/ATen/ops/atleast_3d_native.h,sha256=R1lumPdBbgym1qlHBhZRbS-PfkmQ8gVGKxQqgOo3Y-s,574 +torch/include/ATen/ops/atleast_3d_ops.h,sha256=vzDW1jG1VKS4aac2j7rlO-SCYXFfQ6nRPD9airPb5H4,1593 +torch/include/ATen/ops/avg_pool1d.h,sha256=W3_JyRrqdT_80ksF6IR9pEVF3dW08KYrSGdFWqk02nY,1994 +torch/include/ATen/ops/avg_pool1d_compositeexplicitautograd_dispatch.h,sha256=gFFZ716pJNJR5x6cnPr5ajQIM6jeHI9lYTYSHNNi5QM,1151 +torch/include/ATen/ops/avg_pool1d_compositeimplicitautograd_dispatch.h,sha256=C970shXfeDRiwzshesVvGxkGZuOcvNXriAg_ZEHkzqA,925 +torch/include/ATen/ops/avg_pool1d_native.h,sha256=9vvh7D1kSc9BE4Dfmmj_sIZMq-ij19fXaJfHpjANqTg,836 +torch/include/ATen/ops/avg_pool1d_ops.h,sha256=5r7RPjVCkqbVDxNpIAb7RwGJPfMwJNKkNxxBzQFYazI,2423 +torch/include/ATen/ops/avg_pool2d.h,sha256=9qhgCPHVJxj9fRkDEkTDjFueQR98NZhvmh6yQ8BEK0I,2291 +torch/include/ATen/ops/avg_pool2d_backward.h,sha256=97T9aN2Gs2yHdFAfQ2Ly05d9btYMSC-uISYRXAe-wnQ,2521 +torch/include/ATen/ops/avg_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=fRXG5gKpqQ4oNnWe8u3c5vFjfGOoPMZK5f5RqtVO-ZA,1019 +torch/include/ATen/ops/avg_pool2d_backward_cpu_dispatch.h,sha256=f0nt1Ec-jhgnmrWdy4ep18-A_DUoCbgjhwboB4n-ZrE,1534 +torch/include/ATen/ops/avg_pool2d_backward_cuda_dispatch.h,sha256=XpGcOLBheGZ-7FcUq2rGGV9KY5Hc4XkLngX_8K4UP0w,1536 +torch/include/ATen/ops/avg_pool2d_backward_meta.h,sha256=T4aiaf9dLyELvS9u9ape0AjmngxWrI_kI__Q7DK6-M8,800 +torch/include/ATen/ops/avg_pool2d_backward_meta_dispatch.h,sha256=q2PCqyFtmbeMMYuPhKBGB39btXeekc91HESsEqWH_WU,1536 +torch/include/ATen/ops/avg_pool2d_backward_native.h,sha256=3ppF1oRYmBn1cMGznfBfym1ZM1iBfg4qfOE6fE0FyHg,1804 +torch/include/ATen/ops/avg_pool2d_backward_ops.h,sha256=NJAPe2mkF0PUD3HovmFQT2L0SbFR__gLGWOYNqpHra4,2965 +torch/include/ATen/ops/avg_pool2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=MfJh7QqSnuhalxguVqGua0_1lshXhITGPIvyd-eUzlY,1009 +torch/include/ATen/ops/avg_pool2d_cpu_dispatch.h,sha256=zpPIgu3ea3foZQIzORj8-Pj_0gz7HZth7D0gPg53Knc,1459 +torch/include/ATen/ops/avg_pool2d_cuda_dispatch.h,sha256=L9BQeGvGxRksRTaAzWt-FMjsceXfi1yEXAOP7_wTzD0,1461 
+torch/include/ATen/ops/avg_pool2d_meta.h,sha256=HDtIQTC_GpCTyeymd1nPD1qy_wC8fNX2PqDszWvxNXM,3556 +torch/include/ATen/ops/avg_pool2d_meta_dispatch.h,sha256=BkpW1criHSHTvjnpSZHBSeFAjSHQ4rAuik-wlEbVqYY,1461 +torch/include/ATen/ops/avg_pool2d_native.h,sha256=EpeXt3xPvY-Qg_U4Ct6OwvrUoj97DV-olbNOpf_DbAc,1884 +torch/include/ATen/ops/avg_pool2d_ops.h,sha256=SzkVdjgnj_9rtG-hYBC6evkaHmFZPGuGDDnuKKJWB7M,2703 +torch/include/ATen/ops/avg_pool3d.h,sha256=MQR1wattJxkMr4whiRsXZOBX0O-LDdsPfm12fpmCye8,2291 +torch/include/ATen/ops/avg_pool3d_backward.h,sha256=HqLnbtvQTK6Cyfa6b3XhAW5eInZDZEvx4NqerrST8vU,2521 +torch/include/ATen/ops/avg_pool3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=33uMQH4zm2oawwpWjBgdDR8Z60pfHkgnNw9fSCgQK58,1019 +torch/include/ATen/ops/avg_pool3d_backward_cpu_dispatch.h,sha256=K_l2PzO5SuW6k4ZwN7fwq0Xxa7daGOzZKx-P0XfpQSA,1534 +torch/include/ATen/ops/avg_pool3d_backward_cuda_dispatch.h,sha256=q-kt1Csx_5Wh7vkAk93wD_A9wewv73r8hZl7H2G1MNs,1536 +torch/include/ATen/ops/avg_pool3d_backward_meta.h,sha256=ldP4wYUTtrAuLHp5ibO-vmNaXINs3HPmo0nxshJEg84,800 +torch/include/ATen/ops/avg_pool3d_backward_meta_dispatch.h,sha256=OYi7Jmjhxr9UceapU0NW756PBKv5qjz-6_WcwoAoL2A,1536 +torch/include/ATen/ops/avg_pool3d_backward_native.h,sha256=5U9INYj_UYpv5v1usfqTgS2BqVk2cU0nzToq8vR_h1Q,1804 +torch/include/ATen/ops/avg_pool3d_backward_ops.h,sha256=b38UGFh62YKfhTJAayldpiqrxCRCxYG6VyBpaC1Fq4Q,2965 +torch/include/ATen/ops/avg_pool3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YBFpZikSQhY50RPjMhttQNIvSsWe30rB9nNeuSiM3Yk,1009 +torch/include/ATen/ops/avg_pool3d_cpu_dispatch.h,sha256=eaFQ9q5akXV-kbao5kvuUmJbu8DiMZzB_hIWXibbmuI,1459 +torch/include/ATen/ops/avg_pool3d_cuda_dispatch.h,sha256=8hwoC-Y3hYE_BrJMFNdaNiuq2AvOyipJ2c7BXJHU6jc,1461 +torch/include/ATen/ops/avg_pool3d_meta.h,sha256=KPZa7mLDBqkB4sFlqH3OTdjglK0KffM3p5---Kn_zqA,759 +torch/include/ATen/ops/avg_pool3d_meta_dispatch.h,sha256=dZqj8LUkuO-sm9hm6aPaqc_EPytfOaiz_y1BaU00J7I,1461 +torch/include/ATen/ops/avg_pool3d_native.h,sha256=7s2nIEJ21G3ZYe9vFx8-lokLqZK3W6wFDYzUme-Q8dw,1888 +torch/include/ATen/ops/avg_pool3d_ops.h,sha256=joCBCE_q0tVZkfWBod-Zouz2Hph7-u6hAT3wlLNc8dw,2703 +torch/include/ATen/ops/baddbmm.h,sha256=_L1lijIBLc3rmHI9j0MLqTZOUaXf5Lb-4RUovXo-h84,3030 +torch/include/ATen/ops/baddbmm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=4EESnVcxL-gm7QzWakKyPEfFMIqbiRtlIIPKwT1aIWs,1085 +torch/include/ATen/ops/baddbmm_cpu_dispatch.h,sha256=uoGuXoi-s12ODeb88ZK_feFnIajINEghhmzQ0v_QG2Q,1390 +torch/include/ATen/ops/baddbmm_cuda_dispatch.h,sha256=je2YLEa-6K_FqDfaktWIB8Dbr0PcsrbecHiY92TGah8,2010 +torch/include/ATen/ops/baddbmm_meta.h,sha256=qRn7R-8A0suZZva8vmp_zK-aq9dx0V35Yew6mIQ8x_8,700 +torch/include/ATen/ops/baddbmm_meta_dispatch.h,sha256=lf1zoqJOxDnmPTy9X2-n5ilFtKBBLX_LDYG679BC1WE,1392 +torch/include/ATen/ops/baddbmm_native.h,sha256=O-8nxi-CwfkbbC6ZLZy3AxvdJ0TORl9I-BGRN9Fx-Lk,1614 +torch/include/ATen/ops/baddbmm_ops.h,sha256=yQzOaGXulZLvk6tAq1XaagUeIBd4-2vETUnagpOEoiY,5203 +torch/include/ATen/ops/bartlett_window.h,sha256=Kt-moSupi4pVQp8NeQxzP3rJz8zmza5fPw9C-roVWpo,3479 +torch/include/ATen/ops/bartlett_window_compositeexplicitautograd_dispatch.h,sha256=37z5EZFrV7tjZZ1Ex4XT2yc3KF8EcQ4BcI09d0zbLWg,1736 +torch/include/ATen/ops/bartlett_window_native.h,sha256=mYriYR4W3dGzlAiKHmcPxi813w4RUP_vhVulmx0WlTw,1091 +torch/include/ATen/ops/bartlett_window_ops.h,sha256=zjeOuA8FmJWdsg5uUsckYLJSjiasUOdNeImy6aH8Y1U,3918 +torch/include/ATen/ops/batch_norm.h,sha256=fLW1tjxAFc49KGVWjT9vb9ZsRTjzUFpG49qtl4qfAHE,1152 
+torch/include/ATen/ops/batch_norm_backward.h,sha256=AWwgL_EIIbCSWLxyQJoP3YlLzgmRcyhbsMitg-PvuWs,1386 +torch/include/ATen/ops/batch_norm_backward_cpu_dispatch.h,sha256=Y-wDB1viILXYmtaNIiTL4ZKv1YI10wsAzs_iE9jhqfM,1128 +torch/include/ATen/ops/batch_norm_backward_cuda_dispatch.h,sha256=oDGTqU-hU0lez3chD94nYK-jE6muDBPmhqbXD2Bn5Ho,1130 +torch/include/ATen/ops/batch_norm_backward_elemt.h,sha256=4pEdk_RxGj-YHnNAqbKqMplIZAr1FoNw9wBXSWhx5WE,2397 +torch/include/ATen/ops/batch_norm_backward_elemt_compositeexplicitautograd_dispatch.h,sha256=Jw3Ox5Yunk9OSYxHhYtubc29u__T1ABN0cZvfMNIOrY,1349 +torch/include/ATen/ops/batch_norm_backward_elemt_cuda_dispatch.h,sha256=GNLQmMUJKXXTwp4J-fthRv-LH1XxbIpxXQ_3EYJuog8,974 +torch/include/ATen/ops/batch_norm_backward_elemt_native.h,sha256=zs85zbMjVPl5Y4U1jgvlUR8mWIzAMNl0DTniBn26o0o,1039 +torch/include/ATen/ops/batch_norm_backward_elemt_ops.h,sha256=keyYxTPjfWxX9fhRgTh7Gr8LEU1-p4fyAYYOVBDbZmw,3075 +torch/include/ATen/ops/batch_norm_backward_native.h,sha256=V6o3weQwbKSZHFAz9DOEjj6uGN2MYTSr8WCyx--PWwA,1793 +torch/include/ATen/ops/batch_norm_backward_ops.h,sha256=wz9cj1da30zJE-uwZmAZhgWsqRakqf_rucHNB9KxU3c,2241 +torch/include/ATen/ops/batch_norm_backward_reduce.h,sha256=0AJ3qL0hb5Xe2fd-3v980shMWlxwURYfh26DM5lw0Qo,2776 +torch/include/ATen/ops/batch_norm_backward_reduce_compositeexplicitautograd_dispatch.h,sha256=Tf1tAdRyspp5ABUGEtWFK3w6MKyNt9DhT8obQI-EP0c,1489 +torch/include/ATen/ops/batch_norm_backward_reduce_cuda_dispatch.h,sha256=c2yU8P5MKWoQl0bCDObNag-4uwgIT9u02Ym_3V_VkQk,980 +torch/include/ATen/ops/batch_norm_backward_reduce_native.h,sha256=-KaAQcuP7ZLFBUpV1We4nR1jOBTCq68cfbP48LvNCyM,1115 +torch/include/ATen/ops/batch_norm_backward_reduce_ops.h,sha256=hDBQAdcSBsD39bRh0TIHtLKzGKUZs7fb5IXrzHAP-l4,3391 +torch/include/ATen/ops/batch_norm_compositeimplicitautograd_dispatch.h,sha256=W81gr6k2y6h_O21keqKvj7-1fyb0utG2qLFLt7M5UD8,1041 +torch/include/ATen/ops/batch_norm_elemt.h,sha256=CczNtV6Tmog0TT4r0-4bdbKjDXkYABofBF2kqIGZML4,1932 +torch/include/ATen/ops/batch_norm_elemt_cuda_dispatch.h,sha256=EIUCgr2IhlpfJ4BdKrFYarBoPYdEt718i6v2k7KCNXE,1387 +torch/include/ATen/ops/batch_norm_elemt_native.h,sha256=oG1CjYZ9AWQBaWqp1VGnIEQ7o3dq9V8xYvWUxtM5_Bo,908 +torch/include/ATen/ops/batch_norm_elemt_ops.h,sha256=8A1lFXNrIjhi3fe_BcrBK0tMZn3DGrhzYlqvZ2tddus,2635 +torch/include/ATen/ops/batch_norm_gather_stats.h,sha256=EThWDQoqE5r3cB_-Pdd0i9DGoSyX-eqVbC8PUyReMCs,2554 +torch/include/ATen/ops/batch_norm_gather_stats_compositeexplicitautograd_dispatch.h,sha256=6BHV3w7aAFjp8_dGWJrQvYDK9MBe1y6iJF7ljHVrSeQ,1411 +torch/include/ATen/ops/batch_norm_gather_stats_cuda_dispatch.h,sha256=0TaKuCG5yZLfgsOfwA_iZkmE1VyC6zCJk_l_bO_gnH0,983 +torch/include/ATen/ops/batch_norm_gather_stats_native.h,sha256=b-is3i_yPhLj7g6j5CGlnaexYew-TV7awF-MgkS5qvY,1079 +torch/include/ATen/ops/batch_norm_gather_stats_ops.h,sha256=3Ph7VE2sX5iCkG-oxeA9IyH86Et6WRBbPkhLVFJ17TM,3223 +torch/include/ATen/ops/batch_norm_gather_stats_with_counts.h,sha256=4giF2OXjCpjug7bZpRZxrJgyHv-pnBrdMLbxaCSx1-8,2725 +torch/include/ATen/ops/batch_norm_gather_stats_with_counts_compositeexplicitautograd_dispatch.h,sha256=q0hn0t2NRub5anr_nkbgOv_Qrx8BpHN8v1D2N3iBhzM,1459 +torch/include/ATen/ops/batch_norm_gather_stats_with_counts_cuda_dispatch.h,sha256=InO_xbC0yz4XNzoVSb4nsFvFO5tS24-JDl89tCRWUzY,1007 +torch/include/ATen/ops/batch_norm_gather_stats_with_counts_native.h,sha256=3btPQbPAErrRFzd0SGGgLdd8RRi8losS1lJf6g22y8s,1127 +torch/include/ATen/ops/batch_norm_gather_stats_with_counts_ops.h,sha256=ib_o5D4h6YBTCbRm_kkBEsAffVt20kvFByM_WqJyWv0,3373 
+torch/include/ATen/ops/batch_norm_native.h,sha256=t3aioaOY7Q-iggOKhgQBa4y2lnS3ZY21kBEx3nDSxqk,751 +torch/include/ATen/ops/batch_norm_ops.h,sha256=tqe94gGgqkq6_TZsLulEmUa2MQa2fINuIAmpAtzSbsg,1810 +torch/include/ATen/ops/batch_norm_stats.h,sha256=JBQ0RTQRz-tjzclMuPzjPPBE1FeN6Tc5w-12W7CfKEU,1467 +torch/include/ATen/ops/batch_norm_stats_compositeexplicitautograd_dispatch.h,sha256=P1Ms1oIz8V4MAbiKYjfweBWFumZAk93yxfeOWXXgmXg,1031 +torch/include/ATen/ops/batch_norm_stats_cuda_dispatch.h,sha256=RshrENJGA4TBzAt4bOA0Koe6ImhalCaqQE7cm6N5ES4,793 +torch/include/ATen/ops/batch_norm_stats_native.h,sha256=W2_3yta6RBVA7-DBRa8i5VK7ONxbcrs6tIbVX0fA2h8,699 +torch/include/ATen/ops/batch_norm_stats_ops.h,sha256=xlwN8BKCwoxJKZh4XWmv46UPK-9UZXHZvVc77O-I2xE,1991 +torch/include/ATen/ops/batch_norm_update_stats.h,sha256=K9gwpC3kJbaRSRUacwnnMRjrrvVEt8MXL_BPpUs5OiM,2089 +torch/include/ATen/ops/batch_norm_update_stats_compositeexplicitautograd_dispatch.h,sha256=-1xWCT7RqJLq0W9By5LRVTFjJBfZRg3bsJc2txJYCO4,1253 +torch/include/ATen/ops/batch_norm_update_stats_cpu_dispatch.h,sha256=PgPydYQxoGvaTIcToCwgA0Uv5FRh42gMNGrwO6rPa18,902 +torch/include/ATen/ops/batch_norm_update_stats_cuda_dispatch.h,sha256=KX1JK3ZM-7zxfVc7CWjxjz65O2qBe5OQOttiZsCi58Y,904 +torch/include/ATen/ops/batch_norm_update_stats_native.h,sha256=tv8h-ud06LU6o-RXyAc85m7We94yiVk7N15pp28zbnY,1139 +torch/include/ATen/ops/batch_norm_update_stats_ops.h,sha256=izDw4yz-JELp5EX1LKBlHkKkyg4h6kEQb4K9KAXw7fQ,2693 +torch/include/ATen/ops/bernoulli.h,sha256=Zsw20FG4kFBGuTLB4RIJVqI4PWquaTxn4VJyMTu4MtA,3301 +torch/include/ATen/ops/bernoulli_compositeexplicitautograd_dispatch.h,sha256=QuiyM3M3AhKJ28TmI01gdtNcdlPlqPO63c8pa0nsevQ,1580 +torch/include/ATen/ops/bernoulli_compositeexplicitautogradnonfunctional_dispatch.h,sha256=beRWxt21E1iR7ehzIx5p-haAIhMSrCSnkjxQG4bmypI,883 +torch/include/ATen/ops/bernoulli_cpu_dispatch.h,sha256=_5nL-BbyNDxH9S2j5ijfPFbECShjKZ-BVpv2VC1ghac,1212 +torch/include/ATen/ops/bernoulli_cuda_dispatch.h,sha256=h2WMpu_bvHshHl_MrK0QsKQ-IXjY3xsXSlk20vsobzA,1214 +torch/include/ATen/ops/bernoulli_meta_dispatch.h,sha256=vnCIBfiZqpqG-Jr_XI1kRtgfKAJFK8vfPJx_NeVnNuE,950 +torch/include/ATen/ops/bernoulli_native.h,sha256=7CJEM7WVQNCmXveHaXTW9WjUe9goUr6XyMpfOIp6O8w,1496 +torch/include/ATen/ops/bernoulli_ops.h,sha256=2BHZcvwy6cXS0k4UYtxt2g0Wu85EgPUCTysNfIv_Dcs,6504 +torch/include/ATen/ops/bilinear.h,sha256=87N-hY9oXkIQefqBX7Vt-gM8Nd0ERSfzUyeua_Xelx4,852 +torch/include/ATen/ops/bilinear_compositeimplicitautograd_dispatch.h,sha256=zHV66Qi6JSBhe1qUxBiTDDsCmItiHQfN1YQ8sAXHfRs,890 +torch/include/ATen/ops/bilinear_native.h,sha256=21FxM0ZBnaexioPFhsZjZQuUgtErp8enj19gSmR7guY,600 +torch/include/ATen/ops/bilinear_ops.h,sha256=zNpl29GcapHOda0rSj6IpisaZnYO6SKeXUNWWu7uJBs,1316 +torch/include/ATen/ops/binary_cross_entropy.h,sha256=dgKPvNxS34zN-8jgJ0otlkhjjhXFjSZmnoM6mf0-6-c,1772 +torch/include/ATen/ops/binary_cross_entropy_backward.h,sha256=KOBWnYcN29llu8pNUTH4B_zpJLbyMSz-aSUv-NXc6NE,2127 +torch/include/ATen/ops/binary_cross_entropy_backward_cpu_dispatch.h,sha256=7vlDnlGJ7eMlQcproQ1HD55GZomSJsHDuPdbpbPHcik,1397 +torch/include/ATen/ops/binary_cross_entropy_backward_cuda_dispatch.h,sha256=UfijUrymEvPWEmSwmFg_rAL7zeBBMYVAfHJ-N7uS5Jo,1399 +torch/include/ATen/ops/binary_cross_entropy_backward_native.h,sha256=E1HuBFOz4oEVn8gLObZHO3swCd8seVPHfs4R2dZOpH8,1368 +torch/include/ATen/ops/binary_cross_entropy_backward_ops.h,sha256=XDTgBzctHtlkfRQZ6qntvJ9sPJ6RWM4zvArIGKmQNwE,2579 +torch/include/ATen/ops/binary_cross_entropy_cpu_dispatch.h,sha256=3Sgn9R7kldbJM6KCvuJSZ57-8r4iCV28jsXw-SkJnrw,1260 
+torch/include/ATen/ops/binary_cross_entropy_cuda_dispatch.h,sha256=rnS4-PIk09_p0E-WKzxFs_bG5V5iLgTHIFzPbyINOYs,1262 +torch/include/ATen/ops/binary_cross_entropy_native.h,sha256=XAxLNDkNt5j-Jw8Ax1d3KwFxhxTuOg79NpoRoOLgeyA,1190 +torch/include/ATen/ops/binary_cross_entropy_ops.h,sha256=UkeRu9BxjqDJZfmPx2z6hHxTqezXhGOO0oMdx99slp4,2275 +torch/include/ATen/ops/binary_cross_entropy_with_logits.h,sha256=A1R86gvQIra5Ca9VZ0RVHFVy4_sdqXVEoFvwJGFkOPg,2153 +torch/include/ATen/ops/binary_cross_entropy_with_logits_compositeexplicitautograd_dispatch.h,sha256=-F02HBOCrbZBYZVIyMaD3t7bM6TdH2vqVLaq3QLq0mk,1490 +torch/include/ATen/ops/binary_cross_entropy_with_logits_native.h,sha256=9H3pm4qRDjKqnYaEfJRfcKqOvHv8N9FRa7PgBxZmSZ8,930 +torch/include/ATen/ops/binary_cross_entropy_with_logits_ops.h,sha256=fd6mPNunIg1zU8hkYw29LO-ibAd-SodjqkLxsSViUKU,2663 +torch/include/ATen/ops/bincount.h,sha256=4rBEyoIiwtAJUeaoW2gh6Y0ocHgWZAdqxhvjnV5Hngo,4363 +torch/include/ATen/ops/bincount_compositeexplicitautograd_dispatch.h,sha256=b9xa83IA_4w7J6CDpIuzxSHZe75Bb85yBf-twG9Uc8Q,1346 +torch/include/ATen/ops/bincount_cpu_dispatch.h,sha256=mG3FXcJkTmFfRJwODOVy39D-ZXlbShpYK5xy0J4YanQ,951 +torch/include/ATen/ops/bincount_cuda_dispatch.h,sha256=_3jadsLGlPBZ17PhqU3VP3F-XueM7ZKflBry6fy_shs,953 +torch/include/ATen/ops/bincount_native.h,sha256=2c6mqcq02g81Q9hjkHCKrndu1FOJtNusji7ZDLq1r1Q,861 +torch/include/ATen/ops/bincount_ops.h,sha256=VFdncefydnWfIZGt5KnRKEj9ZKw94lvt7I0Xm6zMr8I,2055 +torch/include/ATen/ops/binomial.h,sha256=9UOUGzRqOucnFT5iHf3CDYRLwHrL53cN_VPlv_d78N0,1498 +torch/include/ATen/ops/binomial_compositeexplicitautograd_dispatch.h,sha256=Yx5xZt492s6jydq15LHto9wIZVi5nP4kTLreAtqk75A,1046 +torch/include/ATen/ops/binomial_cpu_dispatch.h,sha256=yB6AHKWHblgxv-90E8g4S709U5CmeAj7jqt7bx-EJpQ,828 +torch/include/ATen/ops/binomial_cuda_dispatch.h,sha256=LP8EmSKoQ8OXKF7yD7ikbCXoNRjPvQdhTiqIDO7iGIM,830 +torch/include/ATen/ops/binomial_native.h,sha256=Lo5Xmdb0KuwZp7SVuMig4IdDd-AWiZ1kmEkgYIcqo1s,886 +torch/include/ATen/ops/binomial_ops.h,sha256=ohT8Ybsyf9YLAM-5y719dFfzeSLQu7x2N2hNW3-lptQ,2057 +torch/include/ATen/ops/bitwise_and.h,sha256=FYT5IIAaGu9i3lz6eIITZtwt5d-2-9fPAUFUM6IokMA,2876 +torch/include/ATen/ops/bitwise_and_compositeexplicitautograd_dispatch.h,sha256=HicBtTGVjbFZ8mnqld-a99EeSIbP_8-2Bk0dzI5F-Xk,1429 +torch/include/ATen/ops/bitwise_and_compositeexplicitautogradnonfunctional_dispatch.h,sha256=LrkdRjcVkcXDme_l9QXfTopKJoZxbN0kDuCD85xHHQ0,927 +torch/include/ATen/ops/bitwise_and_cpu_dispatch.h,sha256=o-cFJiwDYXkIqLlB3GHA4hbFNfpfbxWwtCKbDH-1qTk,1078 +torch/include/ATen/ops/bitwise_and_cuda_dispatch.h,sha256=WmXg3twzRU5qfhlbpoLSXuCgMBl6cojlXajsMQwmBII,1080 +torch/include/ATen/ops/bitwise_and_meta.h,sha256=VJ4g4rv_SDvx4dj1dkp0hLer32WLx1yI6Ry1DwbStLU,632 +torch/include/ATen/ops/bitwise_and_meta_dispatch.h,sha256=QudXh8DsYNSOiGXT_NaZnau-xBNlNkKtNh9S66rExxo,1080 +torch/include/ATen/ops/bitwise_and_native.h,sha256=SPGTExp_6elPf2gpciClN0HXeFq3rE1MbWdd_C2eKK4,1156 +torch/include/ATen/ops/bitwise_and_ops.h,sha256=4uT2Nx3rcktQvEiOys3aAJwatYXNFKZ8p0cNLzg7rwA,5844 +torch/include/ATen/ops/bitwise_left_shift.h,sha256=IW-7Ezntvk4ahGRikynoIQakMo-NKlO9jQLkmvJZM1g,3114 +torch/include/ATen/ops/bitwise_left_shift_compositeexplicitautograd_dispatch.h,sha256=o52IOA4c-DbwaL0LN1A1pLTkz2ud-Kv1d50-8yZhdtw,1478 +torch/include/ATen/ops/bitwise_left_shift_compositeexplicitautogradnonfunctional_dispatch.h,sha256=dxAX9RiutHUKDUwnL7FSaRSKwtua92f-xijH-zUwCms,941 +torch/include/ATen/ops/bitwise_left_shift_cpu_dispatch.h,sha256=_o2PasYc_IIUMRMfMNu8Gs4Z5ozkkjyxsvbuySmfUZI,1106 
+torch/include/ATen/ops/bitwise_left_shift_cuda_dispatch.h,sha256=4-EX5Ljw9eBoy89zulV5oNxmkLyo8aWbXSm3JUrSV-s,1108 +torch/include/ATen/ops/bitwise_left_shift_meta.h,sha256=9_GKw9hxTyCxu9Y0L8re8tsUW87dZUt47U_QA61L7w0,639 +torch/include/ATen/ops/bitwise_left_shift_meta_dispatch.h,sha256=WBKYkAkz6Rjb74ArdcyVCyUBwY0poaQJ4P1_YlCwUIQ,1108 +torch/include/ATen/ops/bitwise_left_shift_native.h,sha256=Kfkoy8vHtCpbb1-cKVfBAAR5ckviRzB-EoQGzrazY8M,1212 +torch/include/ATen/ops/bitwise_left_shift_ops.h,sha256=Lr6fRc8MImhtwWUsciFMjfXkiSHWQD5-qHFWmB3wWPw,6075 +torch/include/ATen/ops/bitwise_not.h,sha256=ellbwvWarGq_7uKvczg7HH5h07eoJwDN-GLRuFEKTiM,1117 +torch/include/ATen/ops/bitwise_not_compositeexplicitautogradnonfunctional_dispatch.h,sha256=HvCpWTz4xF0x_OnnYVzCNnWRkTZAJbjOXvEcOEoRUnc,875 +torch/include/ATen/ops/bitwise_not_cpu_dispatch.h,sha256=BEx6SGaV2kFf-uIdTQy7-2twl2HfQ5ihPF0eIdeTang,974 +torch/include/ATen/ops/bitwise_not_cuda_dispatch.h,sha256=rTnkNZ_lu4rfwka9Qm261RHQzKX2Kn4pWqPgblXNuz0,976 +torch/include/ATen/ops/bitwise_not_meta.h,sha256=nVT_hI4LquiX4Ir_KrSvHgaEbMfBHrLFIXjo-M5ytwI,599 +torch/include/ATen/ops/bitwise_not_meta_dispatch.h,sha256=0hF_4tDZHNpJXLiIroL9l_22nyzJxtuH672Wp6pxQtM,976 +torch/include/ATen/ops/bitwise_not_native.h,sha256=LVUyMTkIepOM-szTjxmDa04nTFOAlnU1o4hjcdGYYGc,634 +torch/include/ATen/ops/bitwise_not_ops.h,sha256=_IFAEt_tdfGGALtc_fmB17lfJ2SuG5wdTcrFV1Z1fCA,2142 +torch/include/ATen/ops/bitwise_or.h,sha256=hHGLXHJPJE7HIlyxgVd6ADKkszRi3OyKRDk_Tpl-8PE,2848 +torch/include/ATen/ops/bitwise_or_compositeexplicitautograd_dispatch.h,sha256=Tztw7zJOklukXYgQUrub451lQE5k3liWjycRfuxBmLI,1422 +torch/include/ATen/ops/bitwise_or_compositeexplicitautogradnonfunctional_dispatch.h,sha256=e6rsFAzlC_fCqLSRWW5SvPtan2fT7WilL9hUeTttIv8,925 +torch/include/ATen/ops/bitwise_or_cpu_dispatch.h,sha256=5NQxrSp0dIVvUvih8mgNt1Zq6YJcCfvF6zJBrxcEW1k,1074 +torch/include/ATen/ops/bitwise_or_cuda_dispatch.h,sha256=qbNAzLfQzWPOYtNGdP24o81ztLZQPuMHNMlX0cDKuiQ,1076 +torch/include/ATen/ops/bitwise_or_meta.h,sha256=WFp_MT8QvRRX7n_PU01WMRuIYJGZ47PBXekGM6zADh4,631 +torch/include/ATen/ops/bitwise_or_meta_dispatch.h,sha256=n2tzvM2CqwOTsLPf8gB9S4OhOHxtpEcs-3gIr_y3kNI,1076 +torch/include/ATen/ops/bitwise_or_native.h,sha256=z8Jqh74fQGLBuAundwwHC6IcJq3JaWT4DE5OX7eypC4,1148 +torch/include/ATen/ops/bitwise_or_ops.h,sha256=IHzgzU7OD3p-qCLR9MElucGrl5nRGMUFoNueI612yX0,5820 +torch/include/ATen/ops/bitwise_right_shift.h,sha256=XVFvDxGu20lZ8guJwfakZU2Znt2PLQU6pQgwoPl0S28,3142 +torch/include/ATen/ops/bitwise_right_shift_compositeexplicitautograd_dispatch.h,sha256=-6OmIRkg16UXBF2cmroRp3yRY7yDtbTS34ek8_Qm8iA,1485 +torch/include/ATen/ops/bitwise_right_shift_compositeexplicitautogradnonfunctional_dispatch.h,sha256=6qlqior1aOo9PnMb6ZILQ-K2OaYgEv9XM4q5fWwjb5U,943 +torch/include/ATen/ops/bitwise_right_shift_cpu_dispatch.h,sha256=ABcFQn9vOVOGTf6In7FQinTY_cQ5demhhuGR22zuEtM,1110 +torch/include/ATen/ops/bitwise_right_shift_cuda_dispatch.h,sha256=yfuoEARzcLAgu5l7Xc2AHlzUVTXlR99l5Z3AKCevuaI,1112 +torch/include/ATen/ops/bitwise_right_shift_meta.h,sha256=igoeN1GY-4ufjSUP8jfZuP7K9E7nE4xCLQijx8ZrrmE,640 +torch/include/ATen/ops/bitwise_right_shift_meta_dispatch.h,sha256=feXagVuB4TNjNnLg3Ma8t7gXhMZPogJ8yZmIgdYOtnI,1112 +torch/include/ATen/ops/bitwise_right_shift_native.h,sha256=XJiYnOfYc5cSvKT5VjO84vyqyPOTSgjUpa6u0sG-xBo,1220 +torch/include/ATen/ops/bitwise_right_shift_ops.h,sha256=wayxVILt83U2aj8sWyfQkkn9W470VarpjEaMcQgRxqE,6099 +torch/include/ATen/ops/bitwise_xor.h,sha256=IQ1BtH1sn6aZzYsG_CiutAR6IQ93bBoJZwaki99ZHLI,2876 
+torch/include/ATen/ops/bitwise_xor_compositeexplicitautograd_dispatch.h,sha256=0vHtldQKeZxfD1peDHiy2ld8BzDRDcclEQvb3eDBiMw,1429 +torch/include/ATen/ops/bitwise_xor_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Z5ednpJOQCgWzz_z_Y6QDs7HZOMIfClj56RYMZX9290,927 +torch/include/ATen/ops/bitwise_xor_cpu_dispatch.h,sha256=19BttxWsFk1jdfxGbMlmBsXB4GE_61IxYdSy7U3eQZo,1078 +torch/include/ATen/ops/bitwise_xor_cuda_dispatch.h,sha256=KzEWGA7YC5NCR6SCsWjqrCdPDIlcLKB-fJNuj00vC_E,1080 +torch/include/ATen/ops/bitwise_xor_meta.h,sha256=FaPKMP1wN6j67GQvg22BFmEXMM1x7uIH3Uzrdg7EyFY,632 +torch/include/ATen/ops/bitwise_xor_meta_dispatch.h,sha256=Lr6YhTlDiJaOSphOPFsc8RO4ELkZmZh5ch-F9PF4Nkw,1080 +torch/include/ATen/ops/bitwise_xor_native.h,sha256=3mUbBDTpb0HnfD5o142TcnJdJOnws2In-63r1Hbb1t8,1156 +torch/include/ATen/ops/bitwise_xor_ops.h,sha256=V-8yu0BtoqR-UNNlE3k_fzo_1epZyv45hi3YFlphJXQ,5844 +torch/include/ATen/ops/blackman_window.h,sha256=NzW2Q4VUv5DJCO9VITY3z4gDsTvARSi7iCcIjekUTC4,3479 +torch/include/ATen/ops/blackman_window_compositeexplicitautograd_dispatch.h,sha256=AFNQoa88tjIeld1WWq2ZnVqid3rQLefbi9dKMSe9f7I,1736 +torch/include/ATen/ops/blackman_window_native.h,sha256=9Udkjb4bPynMTSZjJDpYt-h-HG6Eq5yf7vGi9u-MjH8,1091 +torch/include/ATen/ops/blackman_window_ops.h,sha256=3Kstiu2krv6zhcmREkh895Pnsjs_SxcjrCgS8jQ_ahw,3918 +torch/include/ATen/ops/block_diag.h,sha256=t8zFyPO7yUj6gieC_xSD24hsvyIQ0Ujy3c5dhefLiPE,1128 +torch/include/ATen/ops/block_diag_compositeexplicitautograd_dispatch.h,sha256=LCZveoqyxZ_Fg7Qvr0qtPNPMJarLBPS0PuPk0Pk9zos,955 +torch/include/ATen/ops/block_diag_native.h,sha256=NBsFdty974NcCdw-fifpJUAjJ23mIVn8cgDddpe2WY4,582 +torch/include/ATen/ops/block_diag_ops.h,sha256=UqPp6JdgeKFZu3xve9JlByQUCR5NRtCzWS6-5vrFAbs,1609 +torch/include/ATen/ops/bmm.h,sha256=7DE_pQozPM5N5cbcotmlerq8sOvHQdmBFHGLxaksPKs,2026 +torch/include/ATen/ops/bmm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2PhthoqZ1nqmaWrwmT3-dCuskm2lZjcyrtlgUR1MG8s,835 +torch/include/ATen/ops/bmm_cpu_dispatch.h,sha256=OeA07IHdOn7HYFQX1cStCOjJdFSgNugrClFEVe6Ng2k,968 +torch/include/ATen/ops/bmm_cuda_dispatch.h,sha256=Smz4-6OY8GmmXA-HVGzq68Yasw6HE_XV14SNo1CWiL8,1328 +torch/include/ATen/ops/bmm_meta.h,sha256=5L9_mfPlQPpH3Z9dERQdldnll5tOIKbhTuVVLpoC98E,616 +torch/include/ATen/ops/bmm_meta_dispatch.h,sha256=ZHksbdFcn-MSv1XmcnCrCMHpqd-rVKHTYCPpkjAiGOA,970 +torch/include/ATen/ops/bmm_native.h,sha256=lCVI-f-1a1x_efrP5s10eB3-MVsfXbpO9S2biy2Vv8Q,1753 +torch/include/ATen/ops/bmm_ops.h,sha256=caMkjCRpZFjBVWABWEijfmjZm6Ks9gV1hwRVfP2AE90,3210 +torch/include/ATen/ops/broadcast_tensors.h,sha256=RyTca5gswFUFaGEFu2zBM7H3K-xXsdhTpGsQ2MAinYE,736 +torch/include/ATen/ops/broadcast_tensors_compositeimplicitautograd_dispatch.h,sha256=NkN-MkyW8jq-opdaYjyu6tfMAd8EeAAST_bQcwFEjeA,812 +torch/include/ATen/ops/broadcast_tensors_native.h,sha256=rLLFJ2KglsuNFRiHev2Ur3tXMS8eyZkdKqrm6ynSlIw,522 +torch/include/ATen/ops/broadcast_tensors_ops.h,sha256=FnZW4bmQWdJROgkIURy36SdebbwetBIf9KKHiGOjPZU,1065 +torch/include/ATen/ops/broadcast_to.h,sha256=ElDWbe16RPmAzwt-lNzD-KwTQ28Dg3TkFNkWFJEjzws,1499 +torch/include/ATen/ops/broadcast_to_compositeimplicitautograd_dispatch.h,sha256=vVMB5IkQ5yYHkTdgtKFFAQaaWHr6KwY_Hp4vKecC4AM,909 +torch/include/ATen/ops/broadcast_to_native.h,sha256=xQDWE0AL2V24IcQpWPiuUU124htMrS_BkfIfySzJIyM,536 +torch/include/ATen/ops/broadcast_to_ops.h,sha256=G0J1Cw7mbOc_A7K2IwwBUruowqwahq6iCypC2wk9NT8,1098 +torch/include/ATen/ops/bucketize.h,sha256=O9QKyCQj00C5sh0lTE3aZzqcDbty4B1a8UwUZ_WDqYU,2689 
+torch/include/ATen/ops/bucketize_compositeexplicitautograd_dispatch.h,sha256=rd86uX1L2n5Hw1EhMDNxHTVELljDltMlsisZgB4bi_M,1027 +torch/include/ATen/ops/bucketize_cpu_dispatch.h,sha256=AaZthxO54dxhcKfIo6XrEjwtgn0s4COF-uuMzCSg-1Q,1241 +torch/include/ATen/ops/bucketize_cuda_dispatch.h,sha256=RVKGTh_lxdGdNjfj_JCr5N55XKrVe1aefpXgkvmyPZg,1243 +torch/include/ATen/ops/bucketize_native.h,sha256=JHTT0uiqohIirbTuq_Nkky5pJXPPQ_SZPZLOpWz-nqU,1415 +torch/include/ATen/ops/bucketize_ops.h,sha256=bxliN1eQAdfSXiDQRYNxTZ7vUtxlc159Lno4cQH07y4,3659 +torch/include/ATen/ops/can_cast.h,sha256=KQrMC8P2LXZJF_rxEjNufeWAnMBtjUBFhpUvz_AfeeE,709 +torch/include/ATen/ops/can_cast_compositeimplicitautograd_dispatch.h,sha256=CSRddvIQ3cUXsUTT-FaRg_LIIHBSauIrmS0nt5ulN90,799 +torch/include/ATen/ops/can_cast_native.h,sha256=mw4KtwOijX_fmZikzLeIg8fXaujzhPXGN1WmpHiwDzo,509 +torch/include/ATen/ops/can_cast_ops.h,sha256=s_WgqdX9p5f82rHmzLe7t0nSQI9KbZ36kizgtvP7Wv4,1036 +torch/include/ATen/ops/cartesian_prod.h,sha256=w1ClQEeEukFJnrCB429RLDLkF9oPgW5ZlcLFFnqWnAg,707 +torch/include/ATen/ops/cartesian_prod_compositeimplicitautograd_dispatch.h,sha256=K358lBsp6-6jwLgwi-sLbqv4f4nFCJ0AyvJPah2dkbA,794 +torch/include/ATen/ops/cartesian_prod_native.h,sha256=trOi5fjwUjKsedLeW0LTUW5uvfFw9p0nYbCeKTCwSHo,504 +torch/include/ATen/ops/cartesian_prod_ops.h,sha256=zNuBPdJ2NlelmwylI1zKyJLjn0Dc792_Pt_N1e6kubo,1009 +torch/include/ATen/ops/cat.h,sha256=j-WkeHFAFKFRL6Ihe11ND3Ku6T-BwkUsbTtvsaz6GpE,1859 +torch/include/ATen/ops/cat_compositeexplicitautogradnonfunctional_dispatch.h,sha256=t__-cl53QiYtB94Zmw1JO1DhlOxJZQIHIA5c70kcmVM,836 +torch/include/ATen/ops/cat_compositeimplicitautograd_dispatch.h,sha256=qOr20uUe7n_v_wzndSZe-Rek1TpFrdr8kSV5K8vsz3Y,985 +torch/include/ATen/ops/cat_cpu_dispatch.h,sha256=-76NwBlHvzWd0VhkIqA8rLXNUEG5ZvADzHGR8iURsFc,969 +torch/include/ATen/ops/cat_cuda_dispatch.h,sha256=znbTEHV5qvo9MP0Pz8sjFnp6n8ToNRrJZcRJuTeLETA,971 +torch/include/ATen/ops/cat_meta.h,sha256=lqnGbFSeqsZj7ibs_6XeA4OHsKwpp2g5mogEQXHRXwQ,4920 +torch/include/ATen/ops/cat_meta_dispatch.h,sha256=JrxVPGdeJtqNsA9uNodinDON-Orz5w9UnglH2-2NBqk,971 +torch/include/ATen/ops/cat_native.h,sha256=_lfmW8NBfYlG0ypT9P07BbihU0V2WykyeV8itlZ-xrc,1583 +torch/include/ATen/ops/cat_ops.h,sha256=4k8S2AbfIQVitA_kMQgsf6OQti2GV0u2xp4amvsgLsI,2978 +torch/include/ATen/ops/cauchy.h,sha256=wXjxf-iXvHMUqaIDXxxGgodeiz57IEyyWykevdkiQyU,1573 +torch/include/ATen/ops/cauchy_compositeexplicitautograd_dispatch.h,sha256=lvUEz8tY33wLCVgMV9WoGOGORiiS6t9ffC2MaOKSXls,1197 +torch/include/ATen/ops/cauchy_cpu_dispatch.h,sha256=W0c5q2NJQobbGWYRj6PSx7FLGH9XSC09ihmqRdyKYDI,830 +torch/include/ATen/ops/cauchy_cuda_dispatch.h,sha256=vNf68oSQ7XAFpNtLoUa6fij1bH_bI6cMstQ7REE8F2Y,832 +torch/include/ATen/ops/cauchy_meta_dispatch.h,sha256=wZ2CG6Sbc-iJjzFw5fCt2Pppb5DFIUkJzpkUeq_qXPE,832 +torch/include/ATen/ops/cauchy_native.h,sha256=SkxojkSzB6bjnH-cY_BbM4rO8sFAY7Ylogu3RB8-5xQ,879 +torch/include/ATen/ops/cauchy_ops.h,sha256=M4ZBWP1nxbrmZrr7XJtlcrYJHyogSKqF4sDH7YckFBo,2847 +torch/include/ATen/ops/ccol_indices.h,sha256=RyKrdXlpMSnrLrnRiWlwOrwhFOF1-J2y1DX-UtAqwRA,536 +torch/include/ATen/ops/ccol_indices_compositeexplicitautograd_dispatch.h,sha256=pmTNMSoS-Ibf2lkWGTMstkA_cWR_V-_Trf21FSAFkQE,793 +torch/include/ATen/ops/ccol_indices_copy.h,sha256=8TXAc_6Oxd8uiz_Odpu3N9PnH2agKpZuulC3EjrRTa0,1177 +torch/include/ATen/ops/ccol_indices_copy_compositeexplicitautograd_dispatch.h,sha256=9MPCVdN9rRyXLmlrIHj3Fy63fJ-mOQAOSGisYabTFMY,913 
+torch/include/ATen/ops/ccol_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=hO3_Y2p9GMtq5jVWU612XuBDyo653Lj1ZBElSuAOzUQ,824 +torch/include/ATen/ops/ccol_indices_copy_native.h,sha256=sXtzTZgw27LCwG5XlOLBi48lhj1ywrRJG7hQl_pRNtY,598 +torch/include/ATen/ops/ccol_indices_copy_ops.h,sha256=8Opg7yDqKSMHFMxSIdmjT3h_DfdIcAXlV9lL1bnO3gM,1653 +torch/include/ATen/ops/ccol_indices_native.h,sha256=9VqreGwWf2fXzrHUFQrA4setcFDG-0oe38q393Bbzp0,583 +torch/include/ATen/ops/ccol_indices_ops.h,sha256=qJz97lzVAGdJF3GhN_5v3qUJ3UmQwgMyX3H_lULANZU,1010 +torch/include/ATen/ops/cdist.h,sha256=yv2SgYkCJjgDyysci1hswVy-XySbASPcIdu4IZuGRnQ,814 +torch/include/ATen/ops/cdist_compositeimplicitautograd_dispatch.h,sha256=DimkmWxsVsRbRPjaeiZn5pL8LfB_DoZQfDABhPM1j7s,873 +torch/include/ATen/ops/cdist_native.h,sha256=mCk9GfdJQ5qqJ9eUAan0BN_nzWEHuK1-Zy1bdNvMJws,583 +torch/include/ATen/ops/cdist_ops.h,sha256=NWL5aYM7FNXDHltOTK_E98nPGBYyFp9r2BkBsycROcw,1221 +torch/include/ATen/ops/ceil.h,sha256=8banBgn1BrEf08b3o84msvyoPaS91dluVgs-hEGBUp4,1188 +torch/include/ATen/ops/ceil_compositeexplicitautogradnonfunctional_dispatch.h,sha256=6S6Lf3ua2fTvNGMr2ewsBx4zNuvuoDzdfmAWccWOsWM,861 +torch/include/ATen/ops/ceil_cpu_dispatch.h,sha256=vJUpyThq7wJFeu9PZgxOP9Fpt79X2rL4Ukv2zCnEr0s,946 +torch/include/ATen/ops/ceil_cuda_dispatch.h,sha256=2xKW-Cmm7uh5p4lnpzqXvf7T1iT0O9bk5W6tzudpUTs,948 +torch/include/ATen/ops/ceil_meta.h,sha256=am2CBMUYaePtpE6udlGCCrdwtKKUzGEVuBXqGnr43bE,592 +torch/include/ATen/ops/ceil_meta_dispatch.h,sha256=u6ftZhJyfHtYKCclxL3WQx5CXwKEHInpp80e73yzLls,948 +torch/include/ATen/ops/ceil_native.h,sha256=XSnsft1jHM0L-pSOTV0C5SMTk2VUeOKHiv8N_EO6mtM,1027 +torch/include/ATen/ops/ceil_ops.h,sha256=uqsXXgOq_0EzTdF97en0hVo12eE7CuLPgBsPFfKLCc4,2079 +torch/include/ATen/ops/celu.h,sha256=lKslGLXQYMw0XMvKkP6wV0px6vvXwKgMj50AiVflMqU,1404 +torch/include/ATen/ops/celu_compositeexplicitautograd_dispatch.h,sha256=QONqnAiTKnPd5_8JO8MqAcDy9bqUnv3x_8O7YLx-Lwg,1106 +torch/include/ATen/ops/celu_native.h,sha256=CMyU51YVzGslvZ_hJXDqHSRNEzayJmeJ0OnhomnbRYw,708 +torch/include/ATen/ops/celu_ops.h,sha256=-UHfZSVgmm1KF3b3z1ceD0eMzRf4ciHRLIjEek7lGPc,2349 +torch/include/ATen/ops/chain_matmul.h,sha256=i43fSBShyX34gkIn7jY5UZXOXsZHgAYvOs2NunONK6M,1157 +torch/include/ATen/ops/chain_matmul_compositeimplicitautograd_dispatch.h,sha256=5H-mk1ELa3CCFW6Y4zpMuz7mI-62Wni6eoGQeD34PsI,964 +torch/include/ATen/ops/chain_matmul_native.h,sha256=pqwqtPVpsbTq4i2EruMrU598_1yW4WD_rjwFGHGfe3w,588 +torch/include/ATen/ops/chain_matmul_ops.h,sha256=11uxxohFcdHNLNQ0ax9UzCmBa1RsKD5dI08IsDrEPpk,1627 +torch/include/ATen/ops/chalf.h,sha256=LWaOK3o2OHx0rFmpNeBTga-nDWDLE_RVBpcPWn1dVYM,529 +torch/include/ATen/ops/chalf_compositeimplicitautograd_dispatch.h,sha256=8J4Zb4SPoFBQOtDBOgVo-3t9l4ei4UZGzhAVvAAI3wQ,850 +torch/include/ATen/ops/chalf_native.h,sha256=2ir_2yy1l6KPTpA8o6XEKVVITbV-ctKxauV0OICTejs,560 +torch/include/ATen/ops/chalf_ops.h,sha256=OBrrN6NGpwYlyqHRz4m5DhTmoNnftxRK-HnPnViQQR8,1153 +torch/include/ATen/ops/channel_shuffle.h,sha256=oMUbXf75jsmqoAzJjW-irD1ZjGvAgWC-Zcfluu-Ujv0,3658 +torch/include/ATen/ops/channel_shuffle_compositeexplicitautograd_dispatch.h,sha256=HOq1eX-Y9KO5XXRR5mmF_wUpQRr3PPvxzWfuwSBYRYQ,1172 +torch/include/ATen/ops/channel_shuffle_cpu_dispatch.h,sha256=iribNdSHHf8y1ogLZv0wL4zbh8G0l_wi2OXHCl4e6Aw,859 +torch/include/ATen/ops/channel_shuffle_cuda_dispatch.h,sha256=CubBpIh-e0QkJSblaWBpgZWqqXbgPri67PCYDwjIzXc,861 +torch/include/ATen/ops/channel_shuffle_native.h,sha256=BNq-ZfDBJkzao8t-fKjASVXWQiUD2CY25nORExtffD4,731 
+torch/include/ATen/ops/channel_shuffle_ops.h,sha256=LPlKjHGKw1TmB8WmwBE2JAormc5Fx2Zvk7_28-9jXF0,1777 +torch/include/ATen/ops/cholesky.h,sha256=2iTO-X8CUzjUQ4czhykFSOGYE-VNhE71esT8C9Qk9eM,1210 +torch/include/ATen/ops/cholesky_cpu_dispatch.h,sha256=R-dqHBYv0Ld1l0NjbAhw7hhEgWFH5p_j6_ssTzgWb5w,956 +torch/include/ATen/ops/cholesky_cuda_dispatch.h,sha256=aMBm11mXX0HHbG4O5eHwb3GtoJMwWal5SKRzZvpdpyo,958 +torch/include/ATen/ops/cholesky_inverse.h,sha256=qAfy7xHHuMytPOocG_Vpb2f37mYKqc5SttTOuWBXhGQ,1290 +torch/include/ATen/ops/cholesky_inverse_cpu_dispatch.h,sha256=MqZ3xtMOirG4ZDx8aHb6WxF-5VYcTh8GhQvirlhAxhk,980 +torch/include/ATen/ops/cholesky_inverse_cuda_dispatch.h,sha256=XHkoEqIrAExteL3mR3pQMi-dN-6iS1mde02ZgQFYv_c,982 +torch/include/ATen/ops/cholesky_inverse_native.h,sha256=zhE4qIlcnWWseld20dThYPn-wkfOUOKIM5bYMCxrO-k,626 +torch/include/ATen/ops/cholesky_inverse_ops.h,sha256=gKPMY7lXvQudKhixnXWedxWBMtfoH4wbtsVoLfuQMNc,1743 +torch/include/ATen/ops/cholesky_native.h,sha256=-VO_qKjf4i3PLqPv7uqZVmuQN9NB49_JnobNYmPPW7k,610 +torch/include/ATen/ops/cholesky_ops.h,sha256=RIz6M1LB15O9MpYJ5UYmaFMl3dBQXtZvi_8bkT6vVjg,1695 +torch/include/ATen/ops/cholesky_solve.h,sha256=0aSPRlknxLwEUR8B_Jig1jTVLaveARHwVo2Tjx_vXc4,1420 +torch/include/ATen/ops/cholesky_solve_compositeexplicitautograd_dispatch.h,sha256=0oAX61a5O70P-1KdbdxXxLuDvhDAKIG-Qlc58LbA5j8,1099 +torch/include/ATen/ops/cholesky_solve_native.h,sha256=oHcaifLmcRT-AUqwQcHETr4COJ1dqHxEqSWs4OgzVcE,676 +torch/include/ATen/ops/cholesky_solve_ops.h,sha256=C2PNg9fciufRKo6kdcVGV4tIGGfi_fmvNmZRswMksw0,1909 +torch/include/ATen/ops/choose_qparams_optimized.h,sha256=9v94fD8bQO6JL3C61Uu7fy9Uolsqq6EEAlUEYxDfesQ,926 +torch/include/ATen/ops/choose_qparams_optimized_compositeimplicitautograd_dispatch.h,sha256=1PkgD9f16NkbvVkIboqT4SRw81GK86ghra8qFvLrOfM,895 +torch/include/ATen/ops/choose_qparams_optimized_native.h,sha256=pTvdaLxnR7snPyKskXFGic8LJclgQy2WFq8BrDkC7cE,605 +torch/include/ATen/ops/choose_qparams_optimized_ops.h,sha256=PobSEjOj6YlClZa2A2Yfa-_jcDrGt1dr_Np9BNytUBY,1342 +torch/include/ATen/ops/chunk.h,sha256=C86ILyeBcxmKnj9cIym4pGYSXP2hcLNG7pSvQQC6zAQ,759 +torch/include/ATen/ops/chunk_compositeimplicitautograd_dispatch.h,sha256=WxvNsXFW-nXCDgZGih-vwUR3x_f5fSYhk7IHbQSdF6o,832 +torch/include/ATen/ops/chunk_native.h,sha256=ri4rm0SvHolF2r3fIwIeRmph9oDhAUoaQ_zF4Gh9HIw,656 +torch/include/ATen/ops/chunk_ops.h,sha256=3uy96vXii2wmwZUZTcKlMoh1NJQR5J0bCNtdntuF6iI,1140 +torch/include/ATen/ops/clamp.h,sha256=8eVwNLPD7i-p_nBnX8bnVJLp3LSncnbfxYvKM6NQJ3I,3018 +torch/include/ATen/ops/clamp_compositeexplicitautogradnonfunctional_dispatch.h,sha256=9v-7ufB1BziE-2MZYNPGCUyVcG4IMA6k6yLydFfd114,1338 +torch/include/ATen/ops/clamp_cpu_dispatch.h,sha256=CpFyQcP3pFRmNrnw7QQOJ_OG4dBn1fQ2e9g9Pw7xpWA,1931 +torch/include/ATen/ops/clamp_cuda_dispatch.h,sha256=eNFu4Hh0Lc9y85OSxyY207RTLxOSUhIw-BMe0rUBgBM,1933 +torch/include/ATen/ops/clamp_max.h,sha256=Hw_BESp9IodLFzR_ASKrsXyL1aU_QOWZsLDMrsyUKys,2359 +torch/include/ATen/ops/clamp_max_compositeexplicitautogradnonfunctional_dispatch.h,sha256=i3bYtbUqtQVuFvmWTnjzosybuMIOEP08wi_qYu9abqU,1080 +torch/include/ATen/ops/clamp_max_cpu_dispatch.h,sha256=nfysLEGlzPZNzPtdpJQwPkaDXTO6BMZkdseuOVJ91GM,1436 +torch/include/ATen/ops/clamp_max_cuda_dispatch.h,sha256=rAN3eL7ZEyZvpKyu2oHxmJY4Ur9Wncso3XLNf2T3WWk,1438 +torch/include/ATen/ops/clamp_max_meta.h,sha256=BWjmuIXTu_6TE02h_64XaTN81z_1EtFNooVLox3np-g,770 +torch/include/ATen/ops/clamp_max_meta_dispatch.h,sha256=SeIC0Jpd4L9S9eLBSNKq-fTr-VKaMBdDimzytVc1WaU,1438 
+torch/include/ATen/ops/clamp_max_native.h,sha256=LtmF7vKcfFDF7PgN-dDXv2PZPUIa02DDtB_9A-Q_5OQ,840 +torch/include/ATen/ops/clamp_max_ops.h,sha256=vPKLpMy3OKXuaREHTIqrlI0e3UvtCiCuMENKGTyZF2Y,4314 +torch/include/ATen/ops/clamp_meta.h,sha256=lkVFy4JY14kAcHq5j8xOAmkdZU4pCELR3p4kx3oMK2g,822 +torch/include/ATen/ops/clamp_meta_dispatch.h,sha256=IE2JvEYkCU7ig6qWeRDKDTjZRyvLhYDpCQC3S1j56Js,1933 +torch/include/ATen/ops/clamp_min.h,sha256=Dzapjbcvdj5OScqi075fKKT1Lzb0_LLMdk0WxaPz3nA,2359 +torch/include/ATen/ops/clamp_min_compositeexplicitautogradnonfunctional_dispatch.h,sha256=3CJs5XfOlcAkJV5inaAozvD6TBHaOFPh4bT_imqKCD4,1080 +torch/include/ATen/ops/clamp_min_cpu_dispatch.h,sha256=LQ_m_egeIf5q3FziyiUN4c5JD5hfxjgP4NR44naN-zk,1436 +torch/include/ATen/ops/clamp_min_cuda_dispatch.h,sha256=wwNQSqLUN-f15mQIidjzEGnZTLp4ePyQXn4vokSrWNA,1438 +torch/include/ATen/ops/clamp_min_meta.h,sha256=JCtoroU61KIYbSNk75knwAh6LfWH1uIG2I9hGkDQTh0,770 +torch/include/ATen/ops/clamp_min_meta_dispatch.h,sha256=5jjdwukZVnVme6xnyz_SpTX2MKirVl9IzV5q68LuYZE,1438 +torch/include/ATen/ops/clamp_min_native.h,sha256=AfoerS--daaLcKXVNKAbPVZlaX_Zm_h2Y-rDGzyW-4Q,840 +torch/include/ATen/ops/clamp_min_ops.h,sha256=PeLHovNN3zYQIOv_gAMoAS8FWO3yJbNmfTFETSlj2Ls,4314 +torch/include/ATen/ops/clamp_native.h,sha256=1Z3SPoiY_LLVWiZgiXEa235wqpWc-pObbUqVr4ks0QY,1060 +torch/include/ATen/ops/clamp_ops.h,sha256=gpTkpQdJ5rW2jnxwOwDvNWiLVgHPG3mdwuHR00mZRNE,5406 +torch/include/ATen/ops/clip.h,sha256=YbEMT9SFYetZz2g3LKUejmWI0y_CO--9TC-y5ZrgPEk,2993 +torch/include/ATen/ops/clip_compositeimplicitautograd_dispatch.h,sha256=OIrdjOSNw9M5GkW24eR5A9-a5TLVRz3jNSonB1SYqrU,1967 +torch/include/ATen/ops/clip_native.h,sha256=9TPI1nMC7rxGoxHtbdk6LlTre7rD_ywY4fANymQzlhc,1366 +torch/include/ATen/ops/clip_ops.h,sha256=lOPx7KMi_Pa3O99fW4m-A2MVZC1euq4dAYB5UPoValI,5388 +torch/include/ATen/ops/clone.h,sha256=V8Fgrujd9nyd6PNpjjPbDdcvuca63WCEvuZJ87a6bio,1384 +torch/include/ATen/ops/clone_compositeexplicitautograd_dispatch.h,sha256=g4VaUZZXRy5522TJZqcE3V89JIFAKPrUyZOqWPwoCdI,1120 +torch/include/ATen/ops/clone_native.h,sha256=jE1PYjyTEHKmNP5XgxUT2_amqujgnAY2lLSQa1K7Rpo,1326 +torch/include/ATen/ops/clone_ops.h,sha256=scDDn998SW6x5rf8ltVSxxQS6YS590bPjy2EgtQ_PdM,1918 +torch/include/ATen/ops/coalesce.h,sha256=PCgsrxRtCvL8VQ823rgoZLdzGLHBh08EScNzyeNmlkA,532 +torch/include/ATen/ops/coalesce_compositeimplicitautograd_dispatch.h,sha256=_aOw2EPPSA5st_XgDaG6b7pdiqpw3FAcVqUheHjG9ys,789 +torch/include/ATen/ops/coalesce_native.h,sha256=Ki3LJgGqN5i4BT2O4SjotSVQhWJp6GoFZ2TtWI0_LIY,499 +torch/include/ATen/ops/coalesce_ops.h,sha256=AQBubyFzFepB08u2Ftq8hCprs8tTsnG4vsKHz1dO14o,998 +torch/include/ATen/ops/col2im.h,sha256=L4TUFiJE5h4nO3X9lwqay9Jc0YZ7peQEx9L7Oeh6pTE,5935 +torch/include/ATen/ops/col2im_cpu_dispatch.h,sha256=mhBsG7wp5gN1--0KVTZI8u7BtEp6MBITPWKYKwXQ1YQ,1947 +torch/include/ATen/ops/col2im_cuda_dispatch.h,sha256=UinSLNxeFJwrLcugyAa_SqDRzhWXEAAXwoOyUphVP1k,1949 +torch/include/ATen/ops/col2im_native.h,sha256=mPV-6eSOqz31hU2ylVaFgJzg6aIpoP7PyEAGs-HQuuo,1260 +torch/include/ATen/ops/col2im_ops.h,sha256=tPnkgzLrzJghng2Mm02RxMRkZ48kgd069GBpP_J43-Y,2495 +torch/include/ATen/ops/col_indices.h,sha256=iZlwoiSX3ljQpMMZrKuqwWDjHSyqvkq6U2lIh5R16vE,535 +torch/include/ATen/ops/col_indices_compositeexplicitautograd_dispatch.h,sha256=MvGSCEy1K97sgpgSu_2hSbbu5MKxBcrqkExI22LC0FA,792 +torch/include/ATen/ops/col_indices_copy.h,sha256=g-j1ff_cCHm2pqfrUxuXl6fj3kTmnx5P55gll5KDyZU,1167 +torch/include/ATen/ops/col_indices_copy_compositeexplicitautograd_dispatch.h,sha256=7OhGhoE0lnz0gxwg92rDrbYFdlVuN_UZhV3Q69QN2eE,911 
+torch/include/ATen/ops/col_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Uqv5-asqPG1XIeDXCgdV_ykxBYr1AE0SBiMDfAyv-To,823 +torch/include/ATen/ops/col_indices_copy_native.h,sha256=14ibvE-okRCm11Ioey3LAkUpow8u3wHBlsOE-tUWKcU,596 +torch/include/ATen/ops/col_indices_copy_ops.h,sha256=BjIetwuOjBJMumixs-q1caitcmtTmwERFtePKIsIPfg,1647 +torch/include/ATen/ops/col_indices_native.h,sha256=Q21UvVrZXT6nGRJzB84NV_gI3FM1A1BtnwHYOl5488o,581 +torch/include/ATen/ops/col_indices_ops.h,sha256=zYkNNCwZRtmkhWjZwGZYACxmAT-EmYOQZnQkkZNyHsM,1007 +torch/include/ATen/ops/column_stack.h,sha256=WTh-fVpAtOI34H0lzk7kh2EJqpISzMlZ-RqIufWx5yM,1148 +torch/include/ATen/ops/column_stack_compositeimplicitautograd_dispatch.h,sha256=_IdwFCen6GkatAaXxhvFE06pS7N23XVYHcjwE8Fal1Y,961 +torch/include/ATen/ops/column_stack_native.h,sha256=klX_JX0Dh74Up1werV4rYR4PyLJI7ATr3NhaYubHHqI,586 +torch/include/ATen/ops/column_stack_ops.h,sha256=rq-GhdrTFnBrlglAZC7bWXDIVZAVk_DVxeSmeghGNqk,1621 +torch/include/ATen/ops/combinations.h,sha256=tOilGn30txi24Rmluu-3xLlqovB0AWRGKOmJKqxE8bI,793 +torch/include/ATen/ops/combinations_compositeimplicitautograd_dispatch.h,sha256=60qzt1eJ1LhRHtkBl_MCjfoHn5j9fn9zjASKxt3mqnc,835 +torch/include/ATen/ops/combinations_native.h,sha256=Pw_6JzWWA1Y0VlEpsaeAcjIzOLQ-CecxiWZwTbLAAeg,545 +torch/include/ATen/ops/combinations_ops.h,sha256=00KEdxRtnjp5mrvbeCloKWMCp0uszH6e_xDZDQV5WKs,1125 +torch/include/ATen/ops/complex.h,sha256=G99HZDCzl3kQHi3Aq9hdAU6JQqP04k4acvCDx0xhIsI,1209 +torch/include/ATen/ops/complex_compositeexplicitautograd_dispatch.h,sha256=2JcSwF320fuRzmTJctVJftwwaCVgpJOC24Pn68wSNCc,813 +torch/include/ATen/ops/complex_cpu_dispatch.h,sha256=7Sd4QEZ_h-s7wYeRNaynFi16X6ZLbYo7gAcbDZkZhRQ,899 +torch/include/ATen/ops/complex_cuda_dispatch.h,sha256=RCovE-mtOgHsYS8Yr1BPa1TFNykWkPuT7Tbc-dJm3C8,901 +torch/include/ATen/ops/complex_native.h,sha256=4d9zTvgx7ccWBshL5WtwKFWsL47SHO7aUod2qjH1QcM,628 +torch/include/ATen/ops/complex_ops.h,sha256=uPCGECnBVKdPLwm3uS5wNjbiRwB-4FlJ0ec1Q0GFpxg,1759 +torch/include/ATen/ops/concat.h,sha256=dNkTfBvGsq_g1NJu3iG2zHEmT9uZ6WiEGZvtZscZVWw,1880 +torch/include/ATen/ops/concat_compositeimplicitautograd_dispatch.h,sha256=aM6QnlnpIsjvccARJVLqIXi965x_NKE2SPV50UNs-xU,1248 +torch/include/ATen/ops/concat_native.h,sha256=KfOpx7GCVwsx50T95QAtQ3jIdEDdssTgWmSmuo4Jb54,768 +torch/include/ATen/ops/concat_ops.h,sha256=mJ6gGDejz34-x-UZyTXW_MmI3b9SPLmS4IxOlmS7m8k,2942 +torch/include/ATen/ops/concatenate.h,sha256=xcQK6f_Tzf3CF1W6ccaeaq5mBj_JYhuHvD2G1A1w2y4,1975 +torch/include/ATen/ops/concatenate_compositeimplicitautograd_dispatch.h,sha256=ArWuIA7tZgWySm_EGPs0QeAYJXuU5OiPRgmzFWhhv3U,1278 +torch/include/ATen/ops/concatenate_native.h,sha256=nHN7KeXhUBXgbE0QQqwxp7Itl3ctxc-Z4-Mx8SZabR8,788 +torch/include/ATen/ops/concatenate_ops.h,sha256=uMA1BRcE8Pnol6WjZzUc2Q_1-S4uZXDZF1JRLVRkZz4,3002 +torch/include/ATen/ops/conj.h,sha256=52yLvnVaIdmRe3BGzFHlb_u3HH5SUdBtfF1glpnFE_k,677 +torch/include/ATen/ops/conj_compositeimplicitautograd_dispatch.h,sha256=6tbNYo4RzyooIzrEV4_Husju8upMm6HKJfJkMpUHe24,785 +torch/include/ATen/ops/conj_native.h,sha256=DoKBkrwZSZorCp9GDGWWP50IQjrRzFBw6uhu0hD9f2M,495 +torch/include/ATen/ops/conj_ops.h,sha256=FZTOl-3LiYLzBy67ks0hCjFSC8xAIAET9NIoX3M4oLc,986 +torch/include/ATen/ops/conj_physical.h,sha256=d1GYI4sggHEd1Elg80jUQU121n3KnfULQSxE-BMz_58,1305 +torch/include/ATen/ops/conj_physical_compositeexplicitautograd_dispatch.h,sha256=_0rFF2eTc72z9PEw4d9yFgoBljUlwBXOW6hNG1H7yFs,791 
+torch/include/ATen/ops/conj_physical_compositeimplicitautograd_dispatch.h,sha256=P6cm0qwqarMvL4VzVkgxwwlDakpHYQX41HQB5fQxiCk,794 +torch/include/ATen/ops/conj_physical_cpu_dispatch.h,sha256=3ISrvFgClyVIVK97zW0vTC2-PWbV_s85gbRlw0rz0RQ,861 +torch/include/ATen/ops/conj_physical_cuda_dispatch.h,sha256=Cn7oQF4guhsrJyISA3M-ozL4DRlFBBr4uK_G_SEm8gQ,863 +torch/include/ATen/ops/conj_physical_native.h,sha256=7VpeLUIwbtnH1xgCa0yBzYAVU1ucwLLYgGpTmcoIBiw,909 +torch/include/ATen/ops/conj_physical_ops.h,sha256=utKibn9-WfjAACfsXAZCM3dE0LUfXx24cM11Uo89K0Y,2160 +torch/include/ATen/ops/constant_pad_nd.h,sha256=J6PCua4RdhZfgzWSCq9mw3rwldjeUx9IVbAPKZQrwns,4340 +torch/include/ATen/ops/constant_pad_nd_compositeexplicitautograd_dispatch.h,sha256=34eNChkYWlcgSE8TlV8j_RVGMvmBSxdGmUEE8DPmiYw,1537 +torch/include/ATen/ops/constant_pad_nd_native.h,sha256=ajvhQ1D_fAUZZ-8IyzQkvxAS2Z15BgB00tRJ8lPT05Q,701 +torch/include/ATen/ops/constant_pad_nd_ops.h,sha256=1CeW4YpzSc0_uFAvIRJy46a2fHMiexqukfxM0UolGQE,1987 +torch/include/ATen/ops/contiguous.h,sha256=PAAoGHklaArxMlvXd8FzbJaCfQZGDYaz8C2t9scwjDk,534 +torch/include/ATen/ops/contiguous_compositeimplicitautograd_dispatch.h,sha256=S9JxIco4yVeqnl9fXtUDUC01FJiRdSI7DLNK1gUfo10,853 +torch/include/ATen/ops/contiguous_native.h,sha256=QZZR3w-LFCW9rNchWv-DoMTgUlHpX4-4z0MYMz4mcqA,563 +torch/include/ATen/ops/contiguous_ops.h,sha256=7KWX9jwR6DnqbNrmfxl1Uhp-FQQ0j28mAfcqpSIq4y8,1135 +torch/include/ATen/ops/conv1d.h,sha256=Yty6aD4HkScM-B5rt87Ve_XEiup-4HHZQ71IoicLqSw,4590 +torch/include/ATen/ops/conv1d_compositeimplicitautograd_dispatch.h,sha256=64bVt54xk3IOtd5_8QsUPsBPyN-Knj3HXCnb8O-jP6g,1722 +torch/include/ATen/ops/conv1d_native.h,sha256=vQ0ueE3s6xRHdJCwl4gloKvyuVtkx2E1FEZCKPUGPyA,1018 +torch/include/ATen/ops/conv1d_ops.h,sha256=zm_D8ps4-_ZaPbT0xzUWBATAAub5J3HNXgdZOxYMZsU,2717 +torch/include/ATen/ops/conv2d.h,sha256=Rrg4MNcJtnc6To_0T012Z8_ABlUyqZwsQzymkSJu9U4,4590 +torch/include/ATen/ops/conv2d_compositeimplicitautograd_dispatch.h,sha256=LfRbX-KIbGihUVzvt7GqXi0-Exfyf3azcTB0dRroGHg,1722 +torch/include/ATen/ops/conv2d_native.h,sha256=7tTZQp9nKcYdsnexd9r3dGm7rrvd5FapEjeWuqHFnoE,1018 +torch/include/ATen/ops/conv2d_ops.h,sha256=Io5NundyMFYw1naHAgss1BFDGp2XYd3RBKBxqzpBDW8,2717 +torch/include/ATen/ops/conv3d.h,sha256=-sHi7ifO2pUU5J0LcDqRtNfNONC6V3DIbRzCCItPA_w,4590 +torch/include/ATen/ops/conv3d_compositeimplicitautograd_dispatch.h,sha256=FPuJqX7hZmFZf5C-NDEaWC53CJwHAFtrG_H26nyU15E,1722 +torch/include/ATen/ops/conv3d_native.h,sha256=g-3qI2BcOmfz0CzApv_INbgF7fTffdgALy_RbxxikzY,1018 +torch/include/ATen/ops/conv3d_ops.h,sha256=F16GevDJ2ILeL6s2dRZA7Emp_NL5kympboRuDOUv2aE,2717 +torch/include/ATen/ops/conv_depthwise3d.h,sha256=k_Cup2teLw-jFNf3Ksdp9xBtRTlqL_ZsFsImxH-1UkM,7385 +torch/include/ATen/ops/conv_depthwise3d_compositeexplicitautograd_dispatch.h,sha256=Uvcwf1rBq9qNKViBpNKOHTvf1uekiawpzXO1Z3q_waE,1828 +torch/include/ATen/ops/conv_depthwise3d_cuda_dispatch.h,sha256=EB1Uw_oLrEjS40zpMk13YeE89jgt0qgECSAS9R6X8Vw,1189 +torch/include/ATen/ops/conv_depthwise3d_native.h,sha256=U95zbCLHkVy9CYZARjnuvg8Ssg-N7YvgyVbmzSUyfAk,970 +torch/include/ATen/ops/conv_depthwise3d_ops.h,sha256=gRYgcprd81JGd207QGRBUXvM-9y1oPMg2fOPznox9po,2903 +torch/include/ATen/ops/conv_tbc.h,sha256=BB3494Nz1tdi9yWmvIa9B4x_RDr4yn2z04fOadUbGCY,1460 +torch/include/ATen/ops/conv_tbc_backward.h,sha256=DANO_obsR85nldBM0p4bmuOIDKKbmxpNcqfg2W87RX4,934 +torch/include/ATen/ops/conv_tbc_backward_compositeimplicitautograd_dispatch.h,sha256=G4mXO1RMiBOF7giiwaaZ9K7BZCkpWSEyM53sR7AGnd0,925 
+torch/include/ATen/ops/conv_tbc_backward_native.h,sha256=M920roM_VRjnldDzDMGotfSYiPWgUBGzq-dyZ4oDrm4,635 +torch/include/ATen/ops/conv_tbc_backward_ops.h,sha256=xZsXnj7Coz2jZm-KPt2DglVckIYeKeCi6ZnAyiX20XM,1447 +torch/include/ATen/ops/conv_tbc_compositeexplicitautograd_dispatch.h,sha256=4rbsyeqnmYsoBCZyI7MSA4VhPIYe1NE-xZT4_-rDGy8,1151 +torch/include/ATen/ops/conv_tbc_native.h,sha256=jzsaFDtjiV4DjzkEiTxXNNkFzR0RZ9cSCM0DTmZhcU4,712 +torch/include/ATen/ops/conv_tbc_ops.h,sha256=UbkEsjJRZ7zywyjShZZi-OejmiazJsRQuavlBldhmtI,2035 +torch/include/ATen/ops/conv_transpose1d.h,sha256=Em7sztcGWLLLSqGQF3PWs2qKKQ9KDUroLvo9eREb_ac,3055 +torch/include/ATen/ops/conv_transpose1d_compositeimplicitautograd_dispatch.h,sha256=wH9eJujDK5_eNkDZPSKaWLJ4p3BnPdPBYFAVVxQC0sw,1353 +torch/include/ATen/ops/conv_transpose1d_native.h,sha256=E2qO-Y9Alyrq3Tgml4fr0wrpFAC6cH652m678KODyi4,792 +torch/include/ATen/ops/conv_transpose1d_ops.h,sha256=9_zoo9_tsD4ksrg7JY1r1MSId7ksj89b59qX16vxar0,1739 +torch/include/ATen/ops/conv_transpose2d.h,sha256=4XEiW2PIWQf2Nev4U9WH-YvLCcKqI2cAzmkGKApVyjI,3091 +torch/include/ATen/ops/conv_transpose2d_compositeimplicitautograd_dispatch.h,sha256=85ZvXuBfN7Enp3JJCuZ5apc0xrE_oF2ss4ZBapnHcKA,1353 +torch/include/ATen/ops/conv_transpose2d_native.h,sha256=4RytatwIMMIW2IZvilFkZZk_Nbd87nzK5vEhHE5zwJc,792 +torch/include/ATen/ops/conv_transpose2d_ops.h,sha256=5irt3F61DkULj2SqhN3HyV0OmFwYXwU563HMwIprNNg,1756 +torch/include/ATen/ops/conv_transpose3d.h,sha256=YW1JDrAd01LM7QExSaOD_jy1FFWrV99jckPSpMDj9JI,3091 +torch/include/ATen/ops/conv_transpose3d_compositeimplicitautograd_dispatch.h,sha256=bvmI4FEaIz4acxNYofM3hp8Mbe6gNsifskGBFuc9byI,1353 +torch/include/ATen/ops/conv_transpose3d_native.h,sha256=XdRSw0mAbKK-jkzXKjLkTXeTKa22uGcV3gKjznbA_ow,792 +torch/include/ATen/ops/conv_transpose3d_ops.h,sha256=ENR6Nj5ejNVlsuM0gAHaFWifOJxXKUSunfu0ZsHBrT0,1756 +torch/include/ATen/ops/convolution.h,sha256=zUo_BmNgoWaO9W_m0wbrDuY2BIZ6IKOSdK7CV9j74eQ,8178 +torch/include/ATen/ops/convolution_backward.h,sha256=G0qNl7gS0kAg6oOxsaP6S29zRnhfB35AOUZ0SLGG3-s,11579 +torch/include/ATen/ops/convolution_backward_compositeexplicitautograd_dispatch.h,sha256=_HiIjk_yZGK_6xslY5brJW3qlK1niFdqinedb0jNxt0,3342 +torch/include/ATen/ops/convolution_backward_cuda_dispatch.h,sha256=CkB9YnwuJ58qexHevhJM0qCAumzTJb3K1pH2bjvjarU,1470 +torch/include/ATen/ops/convolution_backward_native.h,sha256=AxFyIb5ESH1EeJa2uilWZbpWjhRxsIkUXZjdX2pxm1s,1289 +torch/include/ATen/ops/convolution_backward_ops.h,sha256=8PNbKmjQ-9Nu9VTxX8VZqzp7TQvtiLBuL4FS9iWldAA,4022 +torch/include/ATen/ops/convolution_backward_overrideable.h,sha256=rd9VaB94QKiIDhJ9QDYwO-DSGF91MxBBNFebp9oSCZo,10776 +torch/include/ATen/ops/convolution_backward_overrideable_compositeexplicitautograd_dispatch.h,sha256=B_FQ67OONwyt3hbnr3p7ObTdqTwD_TSR7fg23J3XtjE,3195 +torch/include/ATen/ops/convolution_backward_overrideable_native.h,sha256=UE-RgRiUHSgfRDLfYPTdooymK_qC3-U9LQx0NIYn-jU,1240 +torch/include/ATen/ops/convolution_backward_overrideable_ops.h,sha256=MmBjG1nRcVEiUD2BXbEOzLvlql8vAvWCuuhLquaUMNE,3877 +torch/include/ATen/ops/convolution_compositeexplicitautograd_dispatch.h,sha256=bnnPDHdrZNmaA4S1Zrcp4jbH8dejOwgjxzwILhiOMQo,2531 +torch/include/ATen/ops/convolution_native.h,sha256=GRjWyS2OBU7HOJCxIy9Xb9mplVo9YNm2gVMX5CRoo5E,1033 +torch/include/ATen/ops/convolution_ops.h,sha256=By5fcK-fRUuQhCcjp3o3WH_A3a5-3LspMwLTrAhvOBM,3139 +torch/include/ATen/ops/convolution_overrideable.h,sha256=t5QqQS8rzij0mlOqDfPijxykEN0x6b7dBl-iINy-ZJI,8581 
+torch/include/ATen/ops/convolution_overrideable_compositeexplicitautograd_dispatch.h,sha256=ubdDwdVVqG2AvUX_LbpkXH95PKZaHxCmXnpi3c3ceFA,2609 +torch/include/ATen/ops/convolution_overrideable_native.h,sha256=lA4kSRbc-UgfyAEW2XCj84ynjbZviE-rMm05cjlkyGI,1059 +torch/include/ATen/ops/convolution_overrideable_ops.h,sha256=LcUA0aFo_AbMGWmtui4Ly0ka_dmzwlNeFOXzhFOozO4,3217 +torch/include/ATen/ops/copy.h,sha256=1mYQMil3fTQCQkO9yzKj9No8MkFWEtxHa1f9jKqLheM,1356 +torch/include/ATen/ops/copy_compositeexplicitautograd_dispatch.h,sha256=kOPwiM1kkUiC2iJ8q0lo408jSS0qcdJJHIGXh7fwdng,1078 +torch/include/ATen/ops/copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8jq_UeW5qQeQDowrBxebWtykk7-YmUB0Rv8HK_kG7WM,860 +torch/include/ATen/ops/copy_meta_dispatch.h,sha256=QdDXsEm_8vedlBPxAks4EF1YNGv26mcSgppvON7Bkrw,792 +torch/include/ATen/ops/copy_native.h,sha256=b0zyZFJcYxj5DdX50HuvtXTYEi6FO857BM9cltvw3HA,1313 +torch/include/ATen/ops/copy_ops.h,sha256=F-g9ZdQVXrXfHY7ZnyJogWWaxREtpENowvpxZ8AYzgc,2526 +torch/include/ATen/ops/copy_sparse_to_sparse.h,sha256=hh070TzcxAbycRluSMmpFQ5Zdses3T97j_zq9LKaZAY,1823 +torch/include/ATen/ops/copy_sparse_to_sparse_compositeexplicitautograd_dispatch.h,sha256=RhL77q--KyoVQz9O4Hx-x8F9AydnPPH2ge9AVn8Yxw8,1132 +torch/include/ATen/ops/copy_sparse_to_sparse_meta_dispatch.h,sha256=_XmI30A7PmL_98qvxm9dwOwZ533iXNXHsyigGH1-E8E,806 +torch/include/ATen/ops/copy_sparse_to_sparse_native.h,sha256=0p-W25LHXcH-1uHkA6u5TnsBXcBzyTHxs_TKZhtBCh0,804 +torch/include/ATen/ops/copy_sparse_to_sparse_ops.h,sha256=FQiXwFh1_qS2yL2WL_x6vPJFFoUCoEoM6dfNzKgbfCc,2679 +torch/include/ATen/ops/copysign.h,sha256=W9EkI_BTvzQFoJS_IHXsQOjtKJqyYOCck6oeY5SjUjU,1982 +torch/include/ATen/ops/copysign_compositeexplicitautograd_dispatch.h,sha256=vrIIM3VZ9LDDwROLsQEfF4gqY3pbSUtsDApXmtIZcgA,1110 +torch/include/ATen/ops/copysign_compositeexplicitautogradnonfunctional_dispatch.h,sha256=X0XJ9ZEDAzaoq8RKaSnYmjwdkudlv9uTlx78FJm-rtU,921 +torch/include/ATen/ops/copysign_cpu_dispatch.h,sha256=hJiktX5kbRN7uaeA-8Sjo0HB7jKYbzujus41TtXgjCk,1066 +torch/include/ATen/ops/copysign_cuda_dispatch.h,sha256=fTlL0oJ6JfxYTuMnj-a03afVX4iZlYD6RfXwdRxL_98,1068 +torch/include/ATen/ops/copysign_meta.h,sha256=PvA_FnVeS1Br4m4VBcKGqIn1py4sUEgHezDBQ0pTJuI,629 +torch/include/ATen/ops/copysign_meta_dispatch.h,sha256=tZ0AdILSvSOFqnbE6IfLKexmEHCThHaZovppgqv3jYw,1068 +torch/include/ATen/ops/copysign_native.h,sha256=NeHJ_hf6rY5QRL1igDsoDpFwMnjUpxXgughPGWTudN8,928 +torch/include/ATen/ops/copysign_ops.h,sha256=9Ilxl89la3usfTChyPLdJE9bl4hyeHfOk3_xhf24YT8,4372 +torch/include/ATen/ops/corrcoef.h,sha256=cgqRzHvBY3fD3oON8TZG-Z0xCnR2Fb4NrTVRSxTNMvI,676 +torch/include/ATen/ops/corrcoef_compositeimplicitautograd_dispatch.h,sha256=7V_oy1Vfu2ojrCiv2kYG4q0KUdCEjCBFJqgor-71K-0,789 +torch/include/ATen/ops/corrcoef_native.h,sha256=xdrkVGOeaAFUHzwsrckz3u7quex5o5MEVkcC6rwd6Qw,499 +torch/include/ATen/ops/corrcoef_ops.h,sha256=FxOUo-W9sJWHrwonOMDyYCzBm0W5quDNejuIYOxBiJc,992 +torch/include/ATen/ops/cos.h,sha256=vaYw-V5NdCqOkb6vxm4_MPLzmAfxKNJrz3L-RmsRYNw,1175 +torch/include/ATen/ops/cos_compositeexplicitautogradnonfunctional_dispatch.h,sha256=XVAn4vVZrOUGDcUAXTx2LTJThImkNe1iOm_jsN1M6aE,859 +torch/include/ATen/ops/cos_cpu_dispatch.h,sha256=yGfp9ceMPL56fW7Ttp7dJ-le134Ru8llvV4b3cyOLw4,942 +torch/include/ATen/ops/cos_cuda_dispatch.h,sha256=RxAQWgKfXTJtnJsIu48mIFpIsHcrElIdjbqz2iWqKFI,944 +torch/include/ATen/ops/cos_meta.h,sha256=oXsSZuVa-wkR45ccIoVGBFjVrd9Jh9F4f6zp4_R2tUA,591 +torch/include/ATen/ops/cos_meta_dispatch.h,sha256=9BM-QpPkro2ZmqkvbW1JKCwvRQpeeXbJgCy8DjdsXxE,944 
+torch/include/ATen/ops/cos_native.h,sha256=x79YeaQgDsUJlA8Q-UgY-Xlt83tr1iTHq8hiwJessc4,675 +torch/include/ATen/ops/cos_ops.h,sha256=YX05w92RGtzwNDOOqqX2EQm3bmmHLaIChPu9ooIZ6EE,2070 +torch/include/ATen/ops/cosh.h,sha256=a4GmHCoLW5gkx-z8LN1LPdt5DYp1DX_fMcZhqWBhNu8,1188 +torch/include/ATen/ops/cosh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=UDdRZlU8PUapXxAzqZgJprNf0pez72FfJLgp8KW4jPM,861 +torch/include/ATen/ops/cosh_cpu_dispatch.h,sha256=bjnJ3hun1-kXRkALDiQpKoJXqxllRiNt5vwdOj-NGyI,946 +torch/include/ATen/ops/cosh_cuda_dispatch.h,sha256=AKwDHJjOHoo2e_7mxkX6N0kGHEv-5Xc1KcrH2QCA5BI,948 +torch/include/ATen/ops/cosh_meta.h,sha256=dGf1rTDwt9xLYQA1dqKGB9LJeU7kAMzT6FZ6n_gMHbE,592 +torch/include/ATen/ops/cosh_meta_dispatch.h,sha256=64rZAnLgvqHAM3xO9acy8ahZ5JjnZ7C_bkt7fcN_5OQ,948 +torch/include/ATen/ops/cosh_native.h,sha256=2GM7DLmw46PRO6S-TJXw5cpyvrXC5UFOZhgTnPn6E2E,613 +torch/include/ATen/ops/cosh_ops.h,sha256=_8sK0DlL1yMqXMeSLBz-za2I4dYmpp-7V0Gg5RJNAMQ,2079 +torch/include/ATen/ops/cosine_embedding_loss.h,sha256=y73QDWkAuyKDvzUkw4lEW84I7I9knn84c7spBHbVbRA,949 +torch/include/ATen/ops/cosine_embedding_loss_compositeimplicitautograd_dispatch.h,sha256=Fcm8ClSeixYUn0ri8Reel-qDifyUh9AAPtbi9uKile0,916 +torch/include/ATen/ops/cosine_embedding_loss_native.h,sha256=qrrcYlqpJds1obZ95blgRcNE-7KEq7BDTsVK1o3OUlo,626 +torch/include/ATen/ops/cosine_embedding_loss_ops.h,sha256=gAirIbkjQ7ch46k8-liNC8zhlc8RdNmkOfg1E2fq-Sc,1338 +torch/include/ATen/ops/cosine_similarity.h,sha256=b6acMzEZS1voEJ2-GQ1THoLO7aYNLDYXh9S-ZfjxpIs,815 +torch/include/ATen/ops/cosine_similarity_compositeimplicitautograd_dispatch.h,sha256=QxuOvuBbJ3-G-UJQWaiH4Cgd6XxbLHIwc5urhUMbOks,852 +torch/include/ATen/ops/cosine_similarity_native.h,sha256=5L7dqwkAl8BSMTcqiay-UqviLyhu7yAZQEpsOWj83D4,562 +torch/include/ATen/ops/cosine_similarity_ops.h,sha256=ujbqvGrYAhLnRrx4MHNQ7otezKsjwXWV-Kd5jKHi9EI,1185 +torch/include/ATen/ops/count_nonzero.h,sha256=MpETeyLULac7BLp0DAoJeLXVxWjkYxP1ySuvbJxJBto,2102 +torch/include/ATen/ops/count_nonzero_compositeexplicitautograd_dispatch.h,sha256=LmFkwrRkscbuvzYCTQRZnvKpYeUIIoan-Q3zF7FRgwY,1302 +torch/include/ATen/ops/count_nonzero_cpu_dispatch.h,sha256=E4KLfCKtcVbLqMvUDD9CXLxf_s3hp3cR77dt5pPuE8g,771 +torch/include/ATen/ops/count_nonzero_cuda_dispatch.h,sha256=I_VabnPkxoz6tA2TpDMl-oMNCn3fZeDVDTQjk5oL_2M,773 +torch/include/ATen/ops/count_nonzero_native.h,sha256=axh1kLESq7wmwob5nLN0r6eGscCrvIMP8yyCuvMBBDE,959 +torch/include/ATen/ops/count_nonzero_ops.h,sha256=osU4VMWQhxdnfv2Ae4B6cLwzzUO-bfTEqt-K9e080ZI,3196 +torch/include/ATen/ops/cov.h,sha256=901eWhYJjC-IomuEzToT9TXzZPE_ewwtz7HeQxsxxtU,875 +torch/include/ATen/ops/cov_compositeimplicitautograd_dispatch.h,sha256=R0aEpV2Oxsy9pYvV3e7uQTKDdNS0en8sDt7G5a52P20,904 +torch/include/ATen/ops/cov_native.h,sha256=0YIHRqSAg2AAgj3rgajMVfRrEQky-JljxOolSEr7fgw,614 +torch/include/ATen/ops/cov_ops.h,sha256=QLYXijbvphaIWPtDNOjqcaWj3i9O1Cn7QkbQedvRj1o,1351 +torch/include/ATen/ops/cross.h,sha256=ZyrHO3HwbKDlzwsd1nKNhq3SoM4dahTTNfTuCG7VJaY,1378 +torch/include/ATen/ops/cross_compositeimplicitautograd_dispatch.h,sha256=OUH51b9T9pssmqtYwvgEG8V_cnNqCv4jkMOSFXDXFSU,1141 +torch/include/ATen/ops/cross_entropy_loss.h,sha256=resAib_2zjCe8YNRh2d5Kd4wqxgqu6fWm5W_WA1LRyI,2481 +torch/include/ATen/ops/cross_entropy_loss_compositeimplicitautograd_dispatch.h,sha256=_EuqCxUnyXMk0dAU-BaTanFf1i_x5ek-zKZpWGTi_yU,1213 +torch/include/ATen/ops/cross_entropy_loss_native.h,sha256=2c-54JHidRPFbu92K68g9gJ-D99E-XVlMdhOzbrnR-s,688 
+torch/include/ATen/ops/cross_entropy_loss_ops.h,sha256=Bn0KHXueESvoWUI5WbHNk4NAcu-mQwy5tYcLYSwS_v4,1498 +torch/include/ATen/ops/cross_native.h,sha256=je5YTZSkru-C9ILtssXPHeOBFI1OxzMZAxz-uPekhno,701 +torch/include/ATen/ops/cross_ops.h,sha256=pL-O-s1lZyc2WF6sQih-NLv_1yWH3baai4RJunbqGZw,1955 +torch/include/ATen/ops/crow_indices.h,sha256=8tWReSewliy628gjT9vUjoSeOe7wmiBuA9QvtYuAJvk,536 +torch/include/ATen/ops/crow_indices_compositeexplicitautograd_dispatch.h,sha256=Kop498T3YYpCgH14FWgKyu-wmDdCOAIvJj08oY-ECUY,793 +torch/include/ATen/ops/crow_indices_copy.h,sha256=33Uh4mBBRg0oJ3SwsyX3_XgpPi5qHG_lBTOGcB0z5Ys,1177 +torch/include/ATen/ops/crow_indices_copy_compositeexplicitautograd_dispatch.h,sha256=FlJ_Zhgdb7mDs3ytL8gyQ2-QztECxwUnqjnUmNfn8rU,913 +torch/include/ATen/ops/crow_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Z1XL_E9hvJuhEiUDeBkX1CLQdNLjOX-TjIl3z_0zK08,824 +torch/include/ATen/ops/crow_indices_copy_native.h,sha256=gxI0rc8jIqONTPRpilVQZrPPp9I56R07U2h1H8t5M7U,598 +torch/include/ATen/ops/crow_indices_copy_ops.h,sha256=utbMKR_UReR-a9OITCTerWNCQWKC7QX8tg1om2I2CB8,1653 +torch/include/ATen/ops/crow_indices_native.h,sha256=XSQXplDNuClSHj6FG9VMtfe_GQHp1Og2dapNDcNetTU,583 +torch/include/ATen/ops/crow_indices_ops.h,sha256=7-skhT3j_ZmmSoEqD4SgeHmPBhaHSm-b8Zv96efTLrg,1010 +torch/include/ATen/ops/ctc_loss.h,sha256=zbih9memOKijOa52pnMuoh6FPWR5jlJns7JUEOpMIDM,1618 +torch/include/ATen/ops/ctc_loss_compositeimplicitautograd_dispatch.h,sha256=6kXzfZDGCxd1IKnmX-GurCiXZfvXKnNxEUIB-fNJyB8,1208 +torch/include/ATen/ops/ctc_loss_native.h,sha256=NnCee2W2pErf4fkl6q91DwaORlQkOg-IpNX60qjL-og,918 +torch/include/ATen/ops/ctc_loss_ops.h,sha256=ewWoPyb_h37QqC0KCktKrf_nv7sctQCfNyqttUDQjgo,2574 +torch/include/ATen/ops/cudnn_affine_grid_generator.h,sha256=va9xQnE_xifD3uk0d5WGx7N2NclrNjP8hFM1gLla0Ng,1543 +torch/include/ATen/ops/cudnn_affine_grid_generator_backward.h,sha256=41Oy2RMJAtlfkZs1GDUyF2rWwkXix82ShJv4BA-sAMM,1630 +torch/include/ATen/ops/cudnn_affine_grid_generator_backward_compositeexplicitautograd_dispatch.h,sha256=xDmV411rtD7NIwyt32WYUyh3eUJST_7DJmAKA3wrr5k,1039 +torch/include/ATen/ops/cudnn_affine_grid_generator_backward_cuda_dispatch.h,sha256=yLEm7twRR3B8OjPdkYZrGak9RE25ld8y2kN1ghbINYY,819 +torch/include/ATen/ops/cudnn_affine_grid_generator_backward_native.h,sha256=K89RqVZGgSyTU93giNMmHTkTZJOnlXM3AZHYKzkzUFM,724 +torch/include/ATen/ops/cudnn_affine_grid_generator_backward_ops.h,sha256=YHIgah-sLNXACz_qMbdJUi4JdxKm_q6GvP4i69e4JAQ,2082 +torch/include/ATen/ops/cudnn_affine_grid_generator_compositeexplicitautograd_dispatch.h,sha256=PBldYIKHBwuEpNQmKbuz0NWUKfPn-59OvNstBdBwBZk,1023 +torch/include/ATen/ops/cudnn_affine_grid_generator_cuda_dispatch.h,sha256=drV5ySjoKOhVGUVcVaBXN93LMsxLzc10PQeHncMbGJc,811 +torch/include/ATen/ops/cudnn_affine_grid_generator_native.h,sha256=7MhMMPdtiFqVGATnOOSRKbi3fFDSuBQLKA4qrEWO8cE,716 +torch/include/ATen/ops/cudnn_affine_grid_generator_ops.h,sha256=ShIHsN1Ej1gTsH4xRynzHQiXzWBURlW-7suq42G0lBE,2028 +torch/include/ATen/ops/cudnn_batch_norm.h,sha256=oTiDYxwMrM0OPA9UBSmJNl9QSSKLikIkcXma6dZ0jOE,3063 +torch/include/ATen/ops/cudnn_batch_norm_backward.h,sha256=6O8QSrG1SF_x4V3hCfxiRur5txQ-rPhKIwDMVUAkVjE,3255 +torch/include/ATen/ops/cudnn_batch_norm_backward_compositeexplicitautograd_dispatch.h,sha256=AvBsGu7paQq8BIZFMnMTef7RQjsm2ddv1arRHYe52oE,1689 +torch/include/ATen/ops/cudnn_batch_norm_backward_cuda_dispatch.h,sha256=lFrrxFEP-Xvx2vpGFVl4sPOqdAIQKHK3WeEjXqcPkfk,1101 
+torch/include/ATen/ops/cudnn_batch_norm_backward_native.h,sha256=l7G0b4AHr5RmO14yYZRhc8_6Z47NWgBUKh6DXeVXQGs,1331 +torch/include/ATen/ops/cudnn_batch_norm_backward_ops.h,sha256=6wYD1EQanBfxipxmaBUnGMH3_YnoS9iOJzLKKVyIfMc,4054 +torch/include/ATen/ops/cudnn_batch_norm_compositeexplicitautograd_dispatch.h,sha256=buazrjC5v12UZKCzu4yT8_43G6YZOXwXfP9SzCQEKU8,1603 +torch/include/ATen/ops/cudnn_batch_norm_cuda_dispatch.h,sha256=rC65R9hvbkAjx6RKdjLUVRVPZbCbNVAYznhBenTkhJo,1037 +torch/include/ATen/ops/cudnn_batch_norm_native.h,sha256=EV90Tp47CPChv2eF5b7ThuNOTzw0KuMRcNquy25A3w8,1224 +torch/include/ATen/ops/cudnn_batch_norm_ops.h,sha256=G5cyUNKqmAPLFEYiywarykK-pmJWAqT6g_F7MkQy4M0,3741 +torch/include/ATen/ops/cudnn_convolution.h,sha256=FgNY6NmXZuMNR6A7Xa66Qmti8N89HebGkYxdTm_3Qlw,7728 +torch/include/ATen/ops/cudnn_convolution_add_relu.h,sha256=OGHGgy92dNCdozP8iua75IvdQY_AYyL67jtNVhmnq0A,8307 +torch/include/ATen/ops/cudnn_convolution_add_relu_compositeexplicitautograd_dispatch.h,sha256=hIO8Vj8oEICI3ECOqO-h47xOj_Q8tY3d15RRdZBP2jc,2076 +torch/include/ATen/ops/cudnn_convolution_add_relu_cuda_dispatch.h,sha256=KlHd2n8b4CPRD1dSaKEORrSqebJD8vPJcBsomHsPhp8,1313 +torch/include/ATen/ops/cudnn_convolution_add_relu_native.h,sha256=lv9Bclp2uoRhZ-dvDhagZH96GRkpGMQLN7aQfiFvr7A,1089 +torch/include/ATen/ops/cudnn_convolution_add_relu_ops.h,sha256=7jSRa7qYyDsRwxZPFXIg4W6n_cIuZj9pxll8mUaiU1s,3297 +torch/include/ATen/ops/cudnn_convolution_cuda_dispatch.h,sha256=_ca7HA9zdBgUkeZJEw3nltPBAAp2WHpXJzVNkpzMefo,2279 +torch/include/ATen/ops/cudnn_convolution_native.h,sha256=0TxmWUD3xc_E6l8Umfu0qFdJzTyCzu1IGZSbcOF3nWQ,940 +torch/include/ATen/ops/cudnn_convolution_ops.h,sha256=-x6XsZMvQttYsHPOybezekKF97RyHMEtcn2i_aW90mg,2903 +torch/include/ATen/ops/cudnn_convolution_relu.h,sha256=6B_p-5Cq-YPfmmKFSuxoTBieM4BD0amrA0vqvlHQ9W0,7133 +torch/include/ATen/ops/cudnn_convolution_relu_compositeexplicitautograd_dispatch.h,sha256=Ywjvgoeg3hHepND-ArD8qySSHU7WGC23MvgKZi6JzJA,1800 +torch/include/ATen/ops/cudnn_convolution_relu_cuda_dispatch.h,sha256=BA9GRRFfnUC0SwTohbDw2dQ-OIlUgmbWrZBKEoiFOCE,1175 +torch/include/ATen/ops/cudnn_convolution_relu_native.h,sha256=nq3Me5an4YOVSV_D7Cpm2d0eE_29KQvruiBWHSIoX7I,951 +torch/include/ATen/ops/cudnn_convolution_relu_ops.h,sha256=6-ypPeL5iRG1d8ngEF2wOd3Ap0uHUWWr05gk4m3dw7A,2849 +torch/include/ATen/ops/cudnn_convolution_transpose.h,sha256=W-KwAfMIE1kK7WWP0r8jnwrn8PSVdWbZ1rT1oJPKuns,8944 +torch/include/ATen/ops/cudnn_convolution_transpose_compositeexplicitautograd_dispatch.h,sha256=P9qOdo0mGuMmG9siV4aibTvX52fAtLDwl5CS7A5zSGI,2000 +torch/include/ATen/ops/cudnn_convolution_transpose_cuda_dispatch.h,sha256=BC0PlbroAHYVuoMz2HW4rzX9l5AC9SG_VPucieYw9A4,1275 +torch/include/ATen/ops/cudnn_convolution_transpose_native.h,sha256=yToX0iNRcgptx9-Yz72sEGMx66T5DFiDiG2vzG-4qT4,1051 +torch/include/ATen/ops/cudnn_convolution_transpose_ops.h,sha256=Ve7AG4bYg-NpbTFxwAZvgmjkjvowL-wGxpqYTsGomso,3199 +torch/include/ATen/ops/cudnn_grid_sampler.h,sha256=scVqBdlSci0QOIe7WPPPBXS18VurqAS8HCYSUmMZsJE,1326 +torch/include/ATen/ops/cudnn_grid_sampler_backward.h,sha256=4grW3fm-JNF9VIRjhialVBzuLt_QCWlz_64deQkvS3Q,1831 +torch/include/ATen/ops/cudnn_grid_sampler_backward_compositeexplicitautograd_dispatch.h,sha256=pljl6Gbe1-4PY1iJDP-JF4WsTGVeRsp2uJh3PWNTfEw,1141 +torch/include/ATen/ops/cudnn_grid_sampler_backward_cuda_dispatch.h,sha256=OjONM5OFtaJ6i2JzhC1clA9k14jaFGvVV4pf79DKCcI,848 +torch/include/ATen/ops/cudnn_grid_sampler_backward_native.h,sha256=qo0DUXEyuD9z9o17Msc5QjTrnV78IQwo4vNkgjXFyvI,804 
+torch/include/ATen/ops/cudnn_grid_sampler_backward_ops.h,sha256=0hlHg0v3_clzn5Pch97jWciwCT2-8-ACOUTiAbG2ISo,2359 +torch/include/ATen/ops/cudnn_grid_sampler_compositeexplicitautograd_dispatch.h,sha256=zVcEVnuNQ7cj8AqrdNe63dOb-JsCs-mh8Qa-NsbMLE4,965 +torch/include/ATen/ops/cudnn_grid_sampler_cuda_dispatch.h,sha256=Cd6JcfxECgdnos3j9uq0_QhsjrmxlWsFYnuA1rOV3Vk,782 +torch/include/ATen/ops/cudnn_grid_sampler_native.h,sha256=qyOL-15FdSgut8LAZJaCuez0xLtNosnKJmKTNOLXlZI,658 +torch/include/ATen/ops/cudnn_grid_sampler_ops.h,sha256=KFmeE_qunrGEbYIaEf0Ee9FzMIpMqitzjroGh1ORAew,1832 +torch/include/ATen/ops/cudnn_is_acceptable.h,sha256=u7WKE91yiO4CJI0xCVbBsV0HjUePxVKUjoX6NX9WMY4,712 +torch/include/ATen/ops/cudnn_is_acceptable_compositeimplicitautograd_dispatch.h,sha256=4RDfynJ-NfnbUyEjn-LcW5Qf8U23_aJ5tR4SQBdv4xg,794 +torch/include/ATen/ops/cudnn_is_acceptable_native.h,sha256=MpaaOngDOMq1l5n01DHacdStxwzTTWs5Uwy35Xa2lE4,504 +torch/include/ATen/ops/cudnn_is_acceptable_ops.h,sha256=8oAahBStbW7x6JOACALY2lGKnvuyKOrlEmGwMtZzBuQ,1005 +torch/include/ATen/ops/cummax.h,sha256=IAbd4vEs7TffCflm6ZyDbd2uNqsBw4cShb5axysPYEM,2404 +torch/include/ATen/ops/cummax_compositeexplicitautograd_dispatch.h,sha256=76vwM6ayRGIrsMkUfdpExtaYKOp03IuyFr4b1REIuxw,1114 +torch/include/ATen/ops/cummax_compositeimplicitautograd_dispatch.h,sha256=Y8BzjVNWCu4k-uqJKVMEyMi7rdf0l5RqGRYGED69Ot8,1126 +torch/include/ATen/ops/cummax_native.h,sha256=5vaqWGOBSZ-Y0K9S7ECxE6Pj448IO68jKzq0ImlfyPc,924 +torch/include/ATen/ops/cummax_ops.h,sha256=EMmGSpEJgFo_Sos_CxZZi96DEMkweEZfq1KvgdLOyRg,3548 +torch/include/ATen/ops/cummaxmin_backward.h,sha256=eXepNCYfdyx72bCWbxy01yiJ3dRpw-KWLjJSpujrhuY,843 +torch/include/ATen/ops/cummaxmin_backward_compositeimplicitautograd_dispatch.h,sha256=8u-PM6RmSu-O8Jw43WFZyfB8xZgnwXORGmhwjPMvwww,866 +torch/include/ATen/ops/cummaxmin_backward_native.h,sha256=PJicNQOjNYsPLmIGmpnMu2s1D1DUoBEDKjPw6ZvD-bM,576 +torch/include/ATen/ops/cummaxmin_backward_ops.h,sha256=iH3h-FzF1yvzMVkKHiYzrWik0CxzGCQuJOPynn5jbk8,1244 +torch/include/ATen/ops/cummin.h,sha256=D_k5IHmefCGyQy_Wx0O4wsC5S7j6ytKKhOX04D9AtBs,2404 +torch/include/ATen/ops/cummin_compositeexplicitautograd_dispatch.h,sha256=VE1V9Kgbpgna3qcD2sbANDhUEfsNt3TT5mt6bFKE6BY,1114 +torch/include/ATen/ops/cummin_compositeimplicitautograd_dispatch.h,sha256=rayHgWV2LY64vsz4g4pwFUmQFZpqUJ8BLX0YR4w6d9s,1126 +torch/include/ATen/ops/cummin_native.h,sha256=aijlv0O8NjhnKD7W8HK__g-NvYORRL-yt-pQwi8uDdo,924 +torch/include/ATen/ops/cummin_ops.h,sha256=gxtneY5cjC0wNLdIteS1dG-AThPn0BGFN0C9IM9RNMI,3548 +torch/include/ATen/ops/cumprod.h,sha256=l4JLipgfsEjk6HjZups0IWXHkAnfxYLq5IyVE4nT9DE,2345 +torch/include/ATen/ops/cumprod_backward.h,sha256=pTBMqhGsPawsWiJB2VTgh7txhjkvkjFb33k32VlQRyg,832 +torch/include/ATen/ops/cumprod_backward_compositeimplicitautograd_dispatch.h,sha256=pYqv8aobsqiQMa_N99Z74REd134Ak_-86mAQ2tiPyUw,863 +torch/include/ATen/ops/cumprod_backward_native.h,sha256=MoANNAI8V6F9MKAPHGRUVDKT9fK1FNbGUhzxub-fKMY,573 +torch/include/ATen/ops/cumprod_backward_ops.h,sha256=o_MXarES0VGWUcVO_Po3N2yjYC3YAExTJ1RRQ780o-E,1235 +torch/include/ATen/ops/cumprod_compositeexplicitautogradnonfunctional_dispatch.h,sha256=NUlJqw4G7YrtFcnxxqN-FAGc8idePqWa6FIlXNuMoS4,1001 +torch/include/ATen/ops/cumprod_compositeimplicitautograd_dispatch.h,sha256=LwoMHUisTerulqfJ2qTFzpe0HVPCoTznOdJQM_tIHfI,1271 +torch/include/ATen/ops/cumprod_cpu_dispatch.h,sha256=IxXllw-tf-1tNw35VVVOeHAnBsidtXR2WlHred5KYKc,1211 +torch/include/ATen/ops/cumprod_cuda_dispatch.h,sha256=zdwE6Hb3Fx9RFyoRjZCnVvfVtLe-jqhY_ELKHOYBYPI,1213 
+torch/include/ATen/ops/cumprod_meta.h,sha256=eJUUIUPcQYzflek_l79rXk-oTEBunpS4nf4Ffzy1ou0,647 +torch/include/ATen/ops/cumprod_meta_dispatch.h,sha256=YPEDEIJMET_sPePg3Om0_9oGl9Us23ucK5XA5ORfLB8,1213 +torch/include/ATen/ops/cumprod_native.h,sha256=V9AxJI2UbZYrunWAsoNKhpxm3B2XC6OlBlX-EiXRa3Q,1061 +torch/include/ATen/ops/cumprod_ops.h,sha256=TaXO0ULDHgRoeTpBVMqs5u2ld7cdY7cdi-CmIuoyRD4,4941 +torch/include/ATen/ops/cumsum.h,sha256=maCI8BLmJ-Mgj9fFsegcUMhDkIF2mcooCSfvk7fzZyQ,2326 +torch/include/ATen/ops/cumsum_compositeexplicitautogradnonfunctional_dispatch.h,sha256=S4sBEjrqZ34ajYasnz7hx4jqvrXEQf_0o7BsJB8QMNY,999 +torch/include/ATen/ops/cumsum_compositeimplicitautograd_dispatch.h,sha256=vXYu3kW44ljcBKKvwSYGTdWHV24YbIMU2Rr_im_asFM,1267 +torch/include/ATen/ops/cumsum_cpu_dispatch.h,sha256=Q2r46H7K69UHsyvJCPeUOz8yAEMPPu9K3j0zia4AK6Y,1207 +torch/include/ATen/ops/cumsum_cuda_dispatch.h,sha256=2hq4Zr5lCRWVPGpj0wMEhUTr7ZnUQA5MhRmVwjXOBq4,1209 +torch/include/ATen/ops/cumsum_meta.h,sha256=J_PnSeEIaFejTwjtSKU-vyaCSvhztIAmDXcr998-p9A,646 +torch/include/ATen/ops/cumsum_meta_dispatch.h,sha256=NrV_Am87nv6WHcZBV7C5jLS9mIQiRbfQ5wgBPGtrYNE,1209 +torch/include/ATen/ops/cumsum_native.h,sha256=ewJ4mSpog4a8lcKEU5eViGiIrQgc9kMztnrtzmZDuyw,1055 +torch/include/ATen/ops/cumsum_ops.h,sha256=2b5RXSiiPCoZO64WAV--DuVE8NQTe1IlYMKhhoxZrco,4923 +torch/include/ATen/ops/cumulative_trapezoid.h,sha256=nxIQJpOwZTPZdMAmLllthynqRYWZHvzqFe4iGVPN0TE,1047 +torch/include/ATen/ops/cumulative_trapezoid_compositeimplicitautograd_dispatch.h,sha256=L2kKTN-rCzcIJQTE_ljFjMi_Y8hCi12hg_65j3CSQgQ,943 +torch/include/ATen/ops/cumulative_trapezoid_native.h,sha256=rtjoL-EAM1_LyVmvQsWY0xKKWDhcErBCjdf6k2WHt4U,653 +torch/include/ATen/ops/cumulative_trapezoid_ops.h,sha256=pR0Anq9-dKUCP24OJK6OUM7kptUeEVSJ9FFQVc-T0a8,1829 +torch/include/ATen/ops/data.h,sha256=x7OLKs_uu93mXMcQDWiBDLFYDSqIzoIapxAEIQUYGyA,528 +torch/include/ATen/ops/data_compositeimplicitautograd_dispatch.h,sha256=UWuVkwslXZZ17m84qhfncK321SAt2NhbFcAAD3MamFs,785 +torch/include/ATen/ops/data_native.h,sha256=A2okSTB9JUXlZR0npGXOe8rnCKq5xF4UUnxPjyoW0CA,495 +torch/include/ATen/ops/data_ops.h,sha256=9vuw3OZuCaTIwk8TI2IpNBoeHC82y6sIhUGp2DbFZDs,980 +torch/include/ATen/ops/deg2rad.h,sha256=hqGQnKQBcusw6xYvBbe7LUq4bDnkZCCK7gu91cbT0d4,1227 +torch/include/ATen/ops/deg2rad_compositeexplicitautograd_dispatch.h,sha256=rxnQ8cFMIDarcRtyLrlJhimCRVNovsQnG6SfObPyV-E,1002 +torch/include/ATen/ops/deg2rad_native.h,sha256=XK_kbXaIn8AKb-jtAhnJDdkEUv1nabqrV8TB82vg5TU,1063 +torch/include/ATen/ops/deg2rad_ops.h,sha256=dmc6RNdWWyoyn_cpjR7SvzFu42BPH_0qyPVxca8hEg8,2106 +torch/include/ATen/ops/dense_dim.h,sha256=INxhke_egSDTrhNc4HI0c9_LUaZ-ul5GBum4qQ11dOE,533 +torch/include/ATen/ops/dense_dim_compositeexplicitautograd_dispatch.h,sha256=k_3GVogDxPA8RAAE4e9fWWt2DqwmQTu1RR9AlR8gNhQ,787 +torch/include/ATen/ops/dense_dim_native.h,sha256=v2LGOnKev2cyuAoQQx6SL7x5F0fB80YJgS8o1YRoVeM,633 +torch/include/ATen/ops/dense_dim_ops.h,sha256=pr8EOhS6gIrxdlKw1kbXh4Nuidmi57I3QFacUxg6LU4,983 +torch/include/ATen/ops/dequantize.h,sha256=U-fzh0ADbgC1JL9Hb2xh2n_OdgADg73d-BlS5_shomE,1774 +torch/include/ATen/ops/dequantize_compositeexplicitautograd_dispatch.h,sha256=FIxFmusTq27Atuq02viBdzhM2Zh4NRRJqkK_tFaIsgk,1052 +torch/include/ATen/ops/dequantize_cpu_dispatch.h,sha256=t6dPLbOLPUCAf80QSACYm2qzVklhnTTk3uFcdQn9NXE,747 +torch/include/ATen/ops/dequantize_cuda_dispatch.h,sha256=ubwpngHDiDSA_zsu5tpPZrjvYpnigtj-1YEnlpymkSw,749 +torch/include/ATen/ops/dequantize_native.h,sha256=z_NGDkE-oL8WGIccRwfq-anGGyDOoULPzdZcI2ybkhM,849 
+torch/include/ATen/ops/dequantize_ops.h,sha256=LvtsxhN3LuhSCyN9nEtdb9YxNfPtBXih_gtsWCgOcOA,2844 +torch/include/ATen/ops/det.h,sha256=z7Y5uCit1P9foSb8AtwsNCehJE7iE6Mn8_mCno5wsGk,656 +torch/include/ATen/ops/det_compositeimplicitautograd_dispatch.h,sha256=hzsrS4965cFF0uvJUaSiY2zN2LyKjQYsURr_JnVq9RM,784 +torch/include/ATen/ops/det_native.h,sha256=PH0IYyq6SByBinA91daYj2zziuIZYd_CBCWxWMhD17Y,494 +torch/include/ATen/ops/det_ops.h,sha256=mTeSV57bC8nu9gamSLD-sK2mT8xzQ9-1MMbwxADobXE,977 +torch/include/ATen/ops/detach.h,sha256=uzaCTxLRHGu-umaGnxuGCQ9x9wJzSna-oLnMmP8Sxck,821 +torch/include/ATen/ops/detach_compositeexplicitautograd_dispatch.h,sha256=m_EMq_LYUltG0PEnb0JL2fxBxMlpsqEFs2Psig5Goh8,839 +torch/include/ATen/ops/detach_copy.h,sha256=mcpqcFXvB90i4T4nL7NQEYYv0bqJ_LEvOBzBsAP6MZY,1117 +torch/include/ATen/ops/detach_copy_compositeexplicitautograd_dispatch.h,sha256=N9yvEklo-g23Lgl1Trm7km-FpGVRwEbGHTbYThtxlqQ,901 +torch/include/ATen/ops/detach_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Mbg6RLY1vYTGdapeiAx8l57td1igdAPu4hq1qyT5rRo,818 +torch/include/ATen/ops/detach_copy_native.h,sha256=gKB38Chew9wjlCkLOuLl-WWlAjZqHdMIJvneruwysak,586 +torch/include/ATen/ops/detach_copy_ops.h,sha256=AinLBCeq9k6J7Fbme5gRogxDqsKodQMuKLLjUzj-0-E,1617 +torch/include/ATen/ops/detach_native.h,sha256=-aiGqPnmSjhTnee_VmbZOqXQ_Eu4-1B2-X0uNClMG6o,549 +torch/include/ATen/ops/detach_ops.h,sha256=_humAWHMMrDeclfqpIydrHHRyFCQANOwtHKbK_8RLuM,1502 +torch/include/ATen/ops/diag.h,sha256=CpKl-6pBwkLFx1rhJU1G82NuzRX6j_omByitbbop4gs,1183 +torch/include/ATen/ops/diag_compositeimplicitautograd_dispatch.h,sha256=tKsnlXwyPWenQf6Gq2aH_Y_Q8EfIbpaqF6mJmVWYvSk,998 +torch/include/ATen/ops/diag_embed.h,sha256=QFDnRvliYM6lD0ibXstmICtAuAi6fiw_A510RPRJ1VE,1435 +torch/include/ATen/ops/diag_embed_compositeexplicitautograd_dispatch.h,sha256=pK8ZcPhFBAzg-mrXqv0g8inFQj2JtUXSYURTW1rpgkg,995 +torch/include/ATen/ops/diag_embed_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ZhD5YH7gGxx72xpCsWwbET5_G6PRHKloLza0cUj-FSk,869 +torch/include/ATen/ops/diag_embed_native.h,sha256=Jz7UlSE8MFXIgyqAlCA0ZaF8JD9ar52StfuADkgQwBA,680 +torch/include/ATen/ops/diag_embed_ops.h,sha256=3LzcVlBsq8reJC5-wQZUsd2S_zbBsZQDWeOoenh8kgM,1921 +torch/include/ATen/ops/diag_native.h,sha256=nrSnFaP1uH3Z413ifrjFH8OllJwnCKtsVbzErwFOLds,610 +torch/include/ATen/ops/diag_ops.h,sha256=f4y1f-d5dC10aicnkXFcHz5Wh-LCyAlOe4ebZza3FWQ,1697 +torch/include/ATen/ops/diagflat.h,sha256=5ca1izkFIjh64a5BmugghHqOkf5zGX9BLFVssWTRnAU,716 +torch/include/ATen/ops/diagflat_compositeimplicitautograd_dispatch.h,sha256=AxXEHll2B1BIh2lYVXHLoTMsR9jd2Rj9yZpSfVB4nAo,807 +torch/include/ATen/ops/diagflat_native.h,sha256=lUBhiLd1tpT2C1k3WNl_NSyVCMcTjjaGYWFBIPFf2mY,517 +torch/include/ATen/ops/diagflat_ops.h,sha256=SMC5LJbGxK8hiUMpWqkqDz-oMQ_6plu2rEAB63xqiXE,1047 +torch/include/ATen/ops/diagonal.h,sha256=jK3XochAfoAzrg6Him1glyHLuBw_wmVa1CTetnNg-1M,1121 +torch/include/ATen/ops/diagonal_backward.h,sha256=MMmTdGGCSIIDvne-gmjTRdgFc1Qr8bCm47bKTGMGCic,5304 +torch/include/ATen/ops/diagonal_backward_compositeexplicitautograd_dispatch.h,sha256=R-nQC_P_XezAmusCdZo4WDe1GXVYjZYMwwWV9w-fQEM,1739 +torch/include/ATen/ops/diagonal_backward_native.h,sha256=qxi1JaRom2ltGN0puRbVsI10OrHpy6JIV1jUQG5d7SE,780 +torch/include/ATen/ops/diagonal_backward_ops.h,sha256=DDitl6zZZqqWJOJKs5KNTjt0wmq9KswnuDN1mB6hncI,2207 +torch/include/ATen/ops/diagonal_compositeexplicitautograd_dispatch.h,sha256=er_GT-mvtJSpx_oYa8qsrnrZaVx8r4l-So60myVinF0,839 
+torch/include/ATen/ops/diagonal_compositeimplicitautograd_dispatch.h,sha256=mZ2xO_vOHYpxPARIUXj1HYGbBSICBM-v4op4C_H2_bQ,863 +torch/include/ATen/ops/diagonal_copy.h,sha256=_Vam4asCxjLqBt_WproPeqKaY0ugRt6QSRo8WsEkdBo,1455 +torch/include/ATen/ops/diagonal_copy_compositeexplicitautograd_dispatch.h,sha256=q4jhmuqp58xLJ1iSFaaQ3HlmyNi81JDH83uziUxIkY0,999 +torch/include/ATen/ops/diagonal_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=vjAUVNabpanhou3hnOzjUvISS96Nx_RKKMbaLKMsmgo,870 +torch/include/ATen/ops/diagonal_copy_native.h,sha256=hqVbRM6D1ertpiD4p_oMb3viB_lbORH4KQD_dYobIW8,684 +torch/include/ATen/ops/diagonal_copy_ops.h,sha256=tu1ZrSZS8A92RyyDb8kyWA4QwqY3rS3YfQ-f8oiLUvw,1935 +torch/include/ATen/ops/diagonal_native.h,sha256=-z5k7HLnGGt8jkqBMsLwxpl09XZtg7OUmy5iDY8wfRI,680 +torch/include/ATen/ops/diagonal_ops.h,sha256=L-v71Nv6TBmuLfPoVeSvKtVk24f0KWws2RgRn8rqHH8,1950 +torch/include/ATen/ops/diagonal_scatter.h,sha256=ExKgsy_6kSEsC_CQTP2WLArGAaocPf02xdOVlnW63iQ,1608 +torch/include/ATen/ops/diagonal_scatter_compositeexplicitautograd_dispatch.h,sha256=eFB3alxl-RWBodD93HFso7aN8Ph9t0BKL3RJvHUvfhU,1053 +torch/include/ATen/ops/diagonal_scatter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=5mpat0wzqhXEEW4vLDVSwPB1lj99oFGt3c5jsi8KB20,897 +torch/include/ATen/ops/diagonal_scatter_native.h,sha256=kyXvU5QoUviTNnOnhXUN1PxhlCTRDtFp7CSKHmc6Dg8,738 +torch/include/ATen/ops/diagonal_scatter_ops.h,sha256=fQDYdFVWFOKYZ8htEqnui0ZBRb18-Ou4K4kiwZK4o1A,2113 +torch/include/ATen/ops/diff.h,sha256=ytGAC0eWC6Jeg9CraJ8oZdaPjoTqI7ECENKzPd4-Zqs,1675 +torch/include/ATen/ops/diff_compositeimplicitautograd_dispatch.h,sha256=c2vg9pLz7pwwIjPHDVN8JR-ieav35Yv8imPlh-7gwBU,1301 +torch/include/ATen/ops/diff_native.h,sha256=wOifb7tQbmpXiDfZpcYrMH0CNfk_MZ3-V-5vnjJiJLo,809 +torch/include/ATen/ops/diff_ops.h,sha256=r07wnObOPOh0EYZgRzlXK0KhLMPsOhPXB_njBY4VMDg,2339 +torch/include/ATen/ops/digamma.h,sha256=CU3R_HlncrN281xG0VcRSbhzjNSa7m7QHzwkvSPOjxQ,1077 +torch/include/ATen/ops/digamma_compositeexplicitautogradnonfunctional_dispatch.h,sha256=UfK09Ulh-cIIueqGjg9qAuAygnRGQzHQBn4_gsaNxDA,867 +torch/include/ATen/ops/digamma_cpu_dispatch.h,sha256=39IvxUe-egmJU1uR0LeAeEEzfCqkXz47EmcQyKrCRIU,958 +torch/include/ATen/ops/digamma_cuda_dispatch.h,sha256=4yuKvdbZe1KDFPbjl874BSmlcAAjBY21YDaP9yiYap4,960 +torch/include/ATen/ops/digamma_meta.h,sha256=QFi9-v_PcS5CH_8wXRJuOXcgDiGBscukdKxdmqQi0vg,595 +torch/include/ATen/ops/digamma_meta_dispatch.h,sha256=4vPGNuw7l3TQ1HFd2H1yarV6s9IzsnVSOAJkPcfJv58,960 +torch/include/ATen/ops/digamma_native.h,sha256=lfmlDkqUGQC5ASbvtvtQuhySoq9nSwVQCtEmMPezGkA,622 +torch/include/ATen/ops/digamma_ops.h,sha256=nIH2uPmy_7hphhrZfCdUAT94AcsRD5dtSgeUeNTFx8c,2106 +torch/include/ATen/ops/dist.h,sha256=pQAAjmmUKxnDRh1MQCc0w1tu__3bKRRdI2a4IHHnQY4,1303 +torch/include/ATen/ops/dist_compositeexplicitautograd_dispatch.h,sha256=3bWdDsHWkCGBooXsWhwnKdj9vUYqdErjrpwMTpVI1mA,1088 +torch/include/ATen/ops/dist_native.h,sha256=BD89-_q5-fNXLxY2FbuBr6PKpLOKrDZPn5JZctQdsAE,670 +torch/include/ATen/ops/dist_ops.h,sha256=CL_GMnWgoNfL7zxIO87v1-FSLqHnCpUf_RSW5yoytSk,1899 +torch/include/ATen/ops/div.h,sha256=Yk7YCF2pi2UtK9oig2cyfwR4ocBThwDKlSNUQ53s4l0,3819 +torch/include/ATen/ops/div_compositeexplicitautograd_dispatch.h,sha256=4eeErL-lzng70cg0_5pmkrEPU2dd-p5ntV1KIvmG_oE,1644 +torch/include/ATen/ops/div_compositeexplicitautogradnonfunctional_dispatch.h,sha256=QCLD8tBY_XpwjhIRMGC6oRBiP9e1Lfy7VrRHTGA0Z_Y,1162 +torch/include/ATen/ops/div_cpu_dispatch.h,sha256=KuIDsDKGFIuC8JVk32uLuT8rjqu0dNSfm3SgolmnJcA,1600 
+torch/include/ATen/ops/div_cuda_dispatch.h,sha256=cFZ5dV1HCQzC491r0B2Gz7m-jagfkKFlkGkC7oiWfCY,1602 +torch/include/ATen/ops/div_meta.h,sha256=sbmwOiFyHb9PCW4tsGAcydtXRsiFsHPAAcukslHhvdA,823 +torch/include/ATen/ops/div_meta_dispatch.h,sha256=qtWy8YthDLe154VavewoUUJ3IG9l0egH2XCBLxL3-nY,1602 +torch/include/ATen/ops/div_native.h,sha256=ZpXaDS6AMkVCP2C3es2h34Wi_HUeF0J-WZTHJOeD0VI,2549 +torch/include/ATen/ops/div_ops.h,sha256=iHjA1UkENJ43al6GIJQPg3PhL2bcPs4RzQnK1wIZspw,9109 +torch/include/ATen/ops/divide.h,sha256=_alq8sE0mrD4YzTXt035HtT3fkSwS0warxpF-ZoS6sQ,2700 +torch/include/ATen/ops/divide_compositeimplicitautograd_dispatch.h,sha256=NSVXhYL30_sJrrWx4MyXZsbDoyYSODIOjosPTMcM7Aw,2084 +torch/include/ATen/ops/divide_native.h,sha256=ouetj1l12-3cqf9O0PAu7OawvHvPlAMONaV5lZ8rWSk,1533 +torch/include/ATen/ops/divide_ops.h,sha256=vtd-Byyze0lqdJbm5Lh8ANnRaNaPcxZhZSeZ1JX-w9c,7633 +torch/include/ATen/ops/dot.h,sha256=1bm9VQ28Bf56PfCsggsGvqzuDMsaGrJ5HeWe-ZDg4AQ,1187 +torch/include/ATen/ops/dot_compositeexplicitautograd_dispatch.h,sha256=VqYelKawhAxipiENSTHo9ZvufMxW-CRxxWbdL-p044k,939 +torch/include/ATen/ops/dot_cpu_dispatch.h,sha256=joBcjSScsZAmYII3YfVPKQmLiutUD0kwGz8cYM2u0io,767 +torch/include/ATen/ops/dot_cuda_dispatch.h,sha256=wfMRbZlMbeG7Hr1Octie1xVYdbIf0qABKArF4oTwWhM,769 +torch/include/ATen/ops/dot_native.h,sha256=0tm5A8D1IupasSYcBKitv_-PO7K8WTDmekHGCbvYLdY,708 +torch/include/ATen/ops/dot_ops.h,sha256=4jUSL-qwLXKDeLdKyK2caCFuHVAxB66fUoDlR_3sLAs,1747 +torch/include/ATen/ops/dropout.h,sha256=xKo5ok4I3i51IOgnJBZ_PjCfz6BkipnbQD96rR2tms0,931 +torch/include/ATen/ops/dropout_compositeimplicitautograd_dispatch.h,sha256=zA9jQTy_IiyG4P4_ckEBw7kjj2MvBmQYio_mGCXmMpI,886 +torch/include/ATen/ops/dropout_native.h,sha256=wykd7NmBdBBZzqU8DrSvUpiTu7xDcMOgUpXZP0UkxlA,596 +torch/include/ATen/ops/dropout_ops.h,sha256=_vbxD4VvbncIxXYn8En6yejvE-5j_MI4Mc5iIr39LgA,1663 +torch/include/ATen/ops/dsplit.h,sha256=C09YqC_y5Jl5x_9W5hD4ZW1hB5mjUD_FzFolcgrPgqw,975 +torch/include/ATen/ops/dsplit_compositeimplicitautograd_dispatch.h,sha256=G6ROgXAUp-IfSLn4ocAqeVYcEPQYWqVLslT55vJe1iI,915 +torch/include/ATen/ops/dsplit_native.h,sha256=7bWYX8nlLs4AJEVMGYPweRUDZof9TXfhuFqQ25NMuco,625 +torch/include/ATen/ops/dsplit_ops.h,sha256=25ebuv_tiJvG1K-xnu5TCsUM-fMUJOmPLeavSIYDEIM,1782 +torch/include/ATen/ops/dstack.h,sha256=pSChdEB6ojL5N5wxRt0qATWx_f9vYwdSjFU4JBjdMe4,1088 +torch/include/ATen/ops/dstack_compositeimplicitautograd_dispatch.h,sha256=ffYw9HbgtW4JK1RH1tVm25MZw7IyETX7t0eEw1nAdnQ,943 +torch/include/ATen/ops/dstack_native.h,sha256=lnzJxuStO5vhcsSy-cyrynhAuceUTvkny7idNZHxJPc,574 +torch/include/ATen/ops/dstack_ops.h,sha256=ztxKe018qpFRaB-cX0WJkPnjAx0ue_0XkN3LTG84_38,1585 +torch/include/ATen/ops/einsum.h,sha256=uJ-3SgcksK5f92Bl406hXuxth6tbd6tdTqJZ4ZvuGsw,798 +torch/include/ATen/ops/einsum_compositeimplicitautograd_dispatch.h,sha256=ABUUUmy-oNHeZ_EScOPEqm4ljHXYzOj4IKWZ-8vxmkI,858 +torch/include/ATen/ops/einsum_native.h,sha256=aVnwVaTmXFSRnebXiqf-BfBmO4CWUOtoyXHaLeaAF14,568 +torch/include/ATen/ops/einsum_ops.h,sha256=BJiiD66KIi8yzcF1T7Z4dmGFUKiGC0i_xXKodKdZtVg,1177 +torch/include/ATen/ops/elu.h,sha256=nA0MP_wktOjO4Pe0NcLsZ_bKiLLhcjQmiXaLklcTpV8,1853 +torch/include/ATen/ops/elu_backward.h,sha256=-5AbdUd2eRhkAJxiPKXZg7qW4S8ZuL-oqCR0xoLH3GY,2088 +torch/include/ATen/ops/elu_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OQ6NGf2RSe7Zi56Kya2-2nemlh2JkcdXHVUnXopO_QY,961 +torch/include/ATen/ops/elu_backward_cpu_dispatch.h,sha256=Ocx0rA2R_PS1cVUqkbb_hjjlRx0_964hyU4c72oru5c,1360 
+torch/include/ATen/ops/elu_backward_cuda_dispatch.h,sha256=NvdlPDTxChD4rOZObv_GQ8kf-Z_fm-mmqzGBu9GvEnU,1362 +torch/include/ATen/ops/elu_backward_meta.h,sha256=VbnvDxdRvHZMozkqhvwnFi92KmHaGnLA9LgDc_n3LpM,742 +torch/include/ATen/ops/elu_backward_meta_dispatch.h,sha256=b9QSjD2uQprX0BqwNIer_0hELmgHDdXKHA6EqsKy7h0,1362 +torch/include/ATen/ops/elu_backward_native.h,sha256=YElmHmEdpxiPvhlmscqfFp1oqF5RA9RmTnLboQQL4SQ,786 +torch/include/ATen/ops/elu_backward_ops.h,sha256=WYZFlwkTRh3eEOR-MEulHa9tr3W7vItOQHKsO3It1EI,2593 +torch/include/ATen/ops/elu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OCm5LtpH6Rv9b_sXRjhy4D3lsLH0C3RQsGl6vmn3Bmw,1039 +torch/include/ATen/ops/elu_cpu_dispatch.h,sha256=itJn2lmQ5MjLVEwRJUhf-GwWRYekcccCAqlHMtdtAmk,1296 +torch/include/ATen/ops/elu_cuda_dispatch.h,sha256=XwKgCAWj8qVknGYlQ_j_NY-5yLGuUl5MbudvsVRuVYg,1298 +torch/include/ATen/ops/elu_meta.h,sha256=kg01R0H2d3XiEBAakAwoVqB3xcbk2uTF781JjwOlot4,675 +torch/include/ATen/ops/elu_meta_dispatch.h,sha256=thNLeYb4OUXYryTvrewE0xqqe3kX4Z0Ggy3VigmkpZo,1298 +torch/include/ATen/ops/elu_native.h,sha256=26wVsnnE8TZOSiSy36k8Wg5N5fGMoy9dvdDUYgTomMk,694 +torch/include/ATen/ops/elu_ops.h,sha256=ErCsI23k-eHsDvt6ZCBWuGwhN11VIlHff3MgHd2uHXM,2916 +torch/include/ATen/ops/embedding.h,sha256=ZkILTkGn5W8ChS5FDPA1GOaF2knUiLY9DKxXPtVXoRA,5452 +torch/include/ATen/ops/embedding_backward.h,sha256=vxsgpp6lG8CR-44gBNGjArY77c_5gNsC2qdogGDbKbc,2239 +torch/include/ATen/ops/embedding_backward_compositeimplicitautograd_dispatch.h,sha256=W9Qs6tOfsIQBIGga027432zGRvZ7EELYIbyirFpvB1w,1097 +torch/include/ATen/ops/embedding_backward_native.h,sha256=FkO6hgAtP2BHAtvH3IBigOS_EiALqsQEXmC6Yiuo2zQ,632 +torch/include/ATen/ops/embedding_backward_ops.h,sha256=Jm_Hj5gdxWShkTLiFbg_fd0JFqdlx6PJVbX_fCK0uko,1406 +torch/include/ATen/ops/embedding_bag.h,sha256=c4auup8vZFFgcwJxUfDz225dkEU9tw1CkFv00uO0c6U,1990 +torch/include/ATen/ops/embedding_bag_compositeimplicitautograd_dispatch.h,sha256=6Yi3Qsu7UalEhOzV-i2aC1Id1nSeEx0pEik0SV4YbeU,1395 +torch/include/ATen/ops/embedding_bag_native.h,sha256=B53aMU43bdzmtjEEYjE95t00ZYSHtrN71LgMK6Ol87k,1105 +torch/include/ATen/ops/embedding_bag_ops.h,sha256=MSUhYHULYzKXUdm7gpeqTBa5Sh7zM8RGzuoqZ8sZl_k,3301 +torch/include/ATen/ops/embedding_compositeexplicitautograd_dispatch.h,sha256=nvksXwRCdsCYr4sERlE_GqdDO3xSvaDauQ2FyPMf7fM,1805 +torch/include/ATen/ops/embedding_dense_backward.h,sha256=qeZpfykTwSV9cS7_G30wo8byUbWRzGUb7Eksd2QQUbE,6079 +torch/include/ATen/ops/embedding_dense_backward_compositeexplicitautograd_dispatch.h,sha256=QLgkWuM5npod7iUrER_Yav3VOtBj9KS21jLi3eGcey8,1560 +torch/include/ATen/ops/embedding_dense_backward_cpu_dispatch.h,sha256=UaFy8ykg2Vvn0zGgQVQFym8xIj-3dzXAiTknXxPDAlY,1053 +torch/include/ATen/ops/embedding_dense_backward_cuda_dispatch.h,sha256=nU7ASdlKges8akSm0BWMRyV0Rgt6m6t7GtHttMCTUWw,1055 +torch/include/ATen/ops/embedding_dense_backward_native.h,sha256=DEQSQnda1MNtnrjFnHC82dF_a6USGvoW5zWDEei81GU,1015 +torch/include/ATen/ops/embedding_dense_backward_ops.h,sha256=bbFbOlPp1ukuyIw-SUB8ZDVExM54MAcqWqbQMh_v1cM,2415 +torch/include/ATen/ops/embedding_native.h,sha256=6nrR9KwTUYvu4KeBl6JncdrmEUSqs-7Wgo9PNrPzrvE,972 +torch/include/ATen/ops/embedding_ops.h,sha256=eu2zJv-jfuGsAFJ2Ar1uTDrjWlKvtGTMMFgxQ3dCyCA,2249 +torch/include/ATen/ops/embedding_renorm.h,sha256=5V8Gc7A68SQAV1evfHm0QMvsoLIul1ZI8tI4X-KXUKc,1912 +torch/include/ATen/ops/embedding_renorm_compositeexplicitautograd_dispatch.h,sha256=H5_3_wMyvCklnMlA-cTz-qw0vZMHygT-YzimOuCfGks,1165 
+torch/include/ATen/ops/embedding_renorm_cpu_dispatch.h,sha256=mrOcm2uLSnh2wU9DwiQbvaE5uFLbnPKsAMiGysOP21c,813 +torch/include/ATen/ops/embedding_renorm_cuda_dispatch.h,sha256=gOsiXXZ_5geioMYCr2P44nZa3W-uTb5P3-4aScBtEQY,815 +torch/include/ATen/ops/embedding_renorm_meta_dispatch.h,sha256=NQSHNkVx5k8M9hFDA37jqev3FE6-WZ52FLFVHKVc29s,815 +torch/include/ATen/ops/embedding_renorm_native.h,sha256=E-X2fOxgikD93YBgjM5EGgSX1zAd1pkHld5tyqLZR_8,981 +torch/include/ATen/ops/embedding_renorm_ops.h,sha256=WiEzUGVbAoJ4YliupfN1lJEtwxJ59SZr6KbuRFHZPfw,2820 +torch/include/ATen/ops/embedding_sparse_backward.h,sha256=IZlRvCukWBxz20x6Q1K83VAFTrMDhktXP2dAryUCERo,969 +torch/include/ATen/ops/embedding_sparse_backward_compositeimplicitautograd_dispatch.h,sha256=rik2rbXzxK33WgzpcBEHXtZFD_mC81q_aAglOjq_aqg,901 +torch/include/ATen/ops/embedding_sparse_backward_native.h,sha256=8VKVUgNNSJsPy9eDgPhhNtpwlUWK-TDCfVDFtiZ3eFU,611 +torch/include/ATen/ops/embedding_sparse_backward_ops.h,sha256=eJfKfdIN9M-cojRg2l4FnEOTiFqPcXdiEF8-zTI-A0U,1352 +torch/include/ATen/ops/empty.h,sha256=dJFtZeuJQ5GP55ekU2AjRb3nVS0TOdWRYghi3tZWrnI,9360 +torch/include/ATen/ops/empty_compositeexplicitautograd_dispatch.h,sha256=NflAQrbcGfmsDvuLllht42a1gsnA4h8XPnTz5y0Jjo8,1546 +torch/include/ATen/ops/empty_compositeimplicitautograd_dispatch.h,sha256=McoBzoagAGhiqi3evSpbgdWQiy-4tR1_77EyvPLRTfw,1282 +torch/include/ATen/ops/empty_cpu_dispatch.h,sha256=nHo_HBSMUCRAkeCAq7lT4ytPR3yn9S_WBhl9P7gR3qE,1490 +torch/include/ATen/ops/empty_cuda_dispatch.h,sha256=8D5l6f1IIxlR4tsa-sneREyQoD3QeM8tBVXCCv8ncIU,1492 +torch/include/ATen/ops/empty_like.h,sha256=Uk9UA3iiFqFkPMOxvmOq9VKIKrhc2qxKsXspYC-nBOc,2265 +torch/include/ATen/ops/empty_like_compositeexplicitautograd_dispatch.h,sha256=tWRpazyztfke3QVKEOhHlQChciXhYVSxZKAf_3kG3Qc,1418 +torch/include/ATen/ops/empty_like_native.h,sha256=cy180I7WfnKHJ3Pj-PMHb5pfxgOv0BrWu7HVVmCllo4,2013 +torch/include/ATen/ops/empty_like_ops.h,sha256=L_xmQCekm-coKFXDH7JBtscSULwVRDeX06XPdWV29fU,2441 +torch/include/ATen/ops/empty_meta_dispatch.h,sha256=RAho-ujyUBgsUdFpI6xWcg4aMAjudq8ISt_D7Pzh1E8,1492 +torch/include/ATen/ops/empty_native.h,sha256=Pv-6kEHA1A46NBoIm2goiNciJF9jzdTArbE-gWCWrWE,3624 +torch/include/ATen/ops/empty_ops.h,sha256=PP-wEci57Uaf_cPBHP0hFlFs0u1LK5rqUNQx0HJhsI8,4677 +torch/include/ATen/ops/empty_permuted.h,sha256=-hi58-W7KHo5QN5rgJpjmUwF6WK-fgDBTpb4p9Gpqlc,6892 +torch/include/ATen/ops/empty_permuted_compositeexplicitautograd_dispatch.h,sha256=QDOwU3pQVBVAAKY1xeQe8T2FndYGiWN51WP6eJFDTEA,1968 +torch/include/ATen/ops/empty_permuted_native.h,sha256=gZ_t4HIaUC-xzkzXIcrKZRE0_RILc10FtzrhEAVOWdU,831 +torch/include/ATen/ops/empty_permuted_ops.h,sha256=VtIw45ltWIGWkbTF0SEpW38a3zsm8dVq3rD1Sb1DgR0,2353 +torch/include/ATen/ops/empty_quantized.h,sha256=FpnAqLNtOAbTo6fla0kw053MuTyfM8Yf-P0BAEy0XjA,2526 +torch/include/ATen/ops/empty_quantized_compositeexplicitautograd_dispatch.h,sha256=X_XfMSGqFcOsVt2sKmxqHmMM_k24k_uDyGq3wMCTxyM,1072 +torch/include/ATen/ops/empty_quantized_native.h,sha256=_m5HK56wj1JR81_tbCsu2vwfYZs5DiWoabesu9h1fZs,914 +torch/include/ATen/ops/empty_quantized_ops.h,sha256=0dTkY2UMsO6anaLyMo24fJ4Rcsjqtcu2TKSXiMQZvow,2635 +torch/include/ATen/ops/empty_strided.h,sha256=pZIqgflFLInrHa0QENxlR1L3R0OTiSP6vIXoV3rOsRQ,6755 +torch/include/ATen/ops/empty_strided_compositeexplicitautograd_dispatch.h,sha256=kEx6pKNCM9DSUMVpBY6c9Ow5iaXuV4kB24AHSFVdSGE,1192 +torch/include/ATen/ops/empty_strided_cpu_dispatch.h,sha256=lqfXKrFbfpRKp5RBOGMWmoYu86TmoHrnMe8rQe-H4fE,1400 
+torch/include/ATen/ops/empty_strided_cuda_dispatch.h,sha256=ZdSYmuiKoAu1VDI5OGo6CO1TgM1xWZU9eNaUOmq240E,1402 +torch/include/ATen/ops/empty_strided_meta_dispatch.h,sha256=GY1h1mF1uB0544gFb1HEGMJYDeEJCtrFsuQfsc3DlQs,1402 +torch/include/ATen/ops/empty_strided_native.h,sha256=-q9nwxhkWQJY80eXZ1Zs8MEysydv6PA7Um-ne2IJFgg,1571 +torch/include/ATen/ops/empty_strided_ops.h,sha256=0mpXqckOaG1WP3D6HCEHTz0RH__2aq4sH-3lhFcIBj0,2323 +torch/include/ATen/ops/eq.h,sha256=PLS5fwHOvp-PJ73t0vCcKbcD7o0F6b2FJI4MYCEHrlU,1896 +torch/include/ATen/ops/eq_compositeexplicitautogradnonfunctional_dispatch.h,sha256=0GGlFeqLJCgZJvLlSQd6ThFKGGB5DadgH1xA45uYFqs,1060 +torch/include/ATen/ops/eq_cpu_dispatch.h,sha256=m2KH95kej88ehSWAhwJMtKXh_UyfjRfIo66xEvLWkTc,1396 +torch/include/ATen/ops/eq_cuda_dispatch.h,sha256=tlHCuxYnRO3VabLFKjlet-QFel9M55mEGKxy28GnI7E,1398 +torch/include/ATen/ops/eq_meta.h,sha256=s8pf0G__ugYXB5FcuqC3ljMp2RsXciaqZC4dK62yZy0,767 +torch/include/ATen/ops/eq_meta_dispatch.h,sha256=daefTgYlHZfcuaBYKsocdE3qvDMgvTTzwmzi_3aWUEE,1398 +torch/include/ATen/ops/eq_native.h,sha256=9ue9N6NXDXMMgJp8Hx6DW3yF7jNOeFc-TEvyu0hhZms,1417 +torch/include/ATen/ops/eq_ops.h,sha256=bC_mtpjTcHQMD71xD9-Xv0FFcua5WcYwj1yu_RmJHvA,4285 +torch/include/ATen/ops/equal.h,sha256=QF_fhhdiZQsZtE5RPJfqbMinhqxofcwEB406QiTuncY,703 +torch/include/ATen/ops/equal_cpu_dispatch.h,sha256=vz7hmw9ex41tC_B1UbedEyIpU70GOrMR652-_q8HIX0,762 +torch/include/ATen/ops/equal_cuda_dispatch.h,sha256=IvwMSwuiqRGEqMWeCDa_V-ODV5kWEDm7rkcpIExBLGM,764 +torch/include/ATen/ops/equal_native.h,sha256=ihig0p4oOFI7uc8BNSjUncrwPq4QX5JUdloawZrgOWY,687 +torch/include/ATen/ops/equal_ops.h,sha256=IOBe4OFGL9Gux7c9ApRGWL29DAouY14Mtomb9fZan_M,1049 +torch/include/ATen/ops/erf.h,sha256=4ewBr59t719yrV2ZhCPOVTs1FB0ASpb4PqFnOLX5k6c,1175 +torch/include/ATen/ops/erf_compositeexplicitautogradnonfunctional_dispatch.h,sha256=h9xyEPbJQ1xDGvigkKBB3Apk5gQIpZWrPiT9d7v8h74,859 +torch/include/ATen/ops/erf_cpu_dispatch.h,sha256=DcORWMdZAl70kqbP6t88YxB8a8jIzS1MjnpeAsOCBJY,942 +torch/include/ATen/ops/erf_cuda_dispatch.h,sha256=aElRUzn_cyKHpUMTFF2Qm91uWhcBde5T0aNkfWPn7pw,944 +torch/include/ATen/ops/erf_meta.h,sha256=gIcUh8ijEij1_j86LmKj6Wrh7QeVrAiLh1oqhF-gVMQ,591 +torch/include/ATen/ops/erf_meta_dispatch.h,sha256=7ZFjwRLfKUp966KckOOeqJp5qVQb23sOf3fy5_TFrkA,944 +torch/include/ATen/ops/erf_native.h,sha256=szqvjFPaRlvZTP6kxAbiKqHD__x6ZQm8xDy2-av6Ips,1018 +torch/include/ATen/ops/erf_ops.h,sha256=eDW0an5sTny-GLyiBztBkLmMw8yAbkBhygUTp2tBnB4,2070 +torch/include/ATen/ops/erfc.h,sha256=D8xZRHafLI_yGgYIfxq87qT7bOySiJkjPOSzE0fOfGg,1188 +torch/include/ATen/ops/erfc_compositeexplicitautogradnonfunctional_dispatch.h,sha256=7LsskttYY7TjVzjh1e6nSdY9vTItkin1DS7mmjmCQbI,861 +torch/include/ATen/ops/erfc_cpu_dispatch.h,sha256=XbvC3Qvy4vH1VqWB1T4bVE153U6vwKR3fdvRpd9FvsM,946 +torch/include/ATen/ops/erfc_cuda_dispatch.h,sha256=v8ggcLZhnOXC6JtqpjrNTgjx6WDWnlEeXGsb4SLmNyY,948 +torch/include/ATen/ops/erfc_meta.h,sha256=94H9VGCJCEAL7Q9qYX-OPH_z6TDY_cMHWXdeWvYCrMI,592 +torch/include/ATen/ops/erfc_meta_dispatch.h,sha256=wbe8M6I7RT6cT05Ubfw-NaMqT7187Dh68i2mdnyruyo,948 +torch/include/ATen/ops/erfc_native.h,sha256=Hh_75jCWiLbvOh0z2_ccYtvcIrjiPDCOfRowRRyUgfk,613 +torch/include/ATen/ops/erfc_ops.h,sha256=sZqOLFBAu4GExk8rrkRHcEmEZvzpMHpQGmWTgRr1xCg,2079 +torch/include/ATen/ops/erfinv.h,sha256=K7IyB8Ltbv-2qFeR0HAAqN8y6-oulyHf0Ke0iWzfnYI,1067 +torch/include/ATen/ops/erfinv_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AyvhQ0niBogEeJkdRXy3U2VI9mzVnijGdJJFZPVBvN8,865 
+torch/include/ATen/ops/erfinv_cpu_dispatch.h,sha256=HRt-gpV5nxIW6rcAgt_R5J3qKyHzYq9s_0H-VUmC75Y,954 +torch/include/ATen/ops/erfinv_cuda_dispatch.h,sha256=E92GxI3eDcPacHnho9HrGI35Jp8Yo6GoFyizJ1lvvRU,956 +torch/include/ATen/ops/erfinv_meta.h,sha256=9t5Yy6uadoolcByH3x19U4Us2HpNNEKDUx1L3SHGwjk,594 +torch/include/ATen/ops/erfinv_meta_dispatch.h,sha256=CGqbgALRCdynFtXx9ged2FTxmkag0VKZZacIhTNRVik,956 +torch/include/ATen/ops/erfinv_native.h,sha256=k9mEUwM9Uoz9PAQXK-AyMNADgKwO8nWF6GBAYfipN2E,1045 +torch/include/ATen/ops/erfinv_ops.h,sha256=1lI_zR5p1PG57giEkKsab6wth21JIS7TEQTE2oc5hIk,2097 +torch/include/ATen/ops/exp.h,sha256=1jl-7ubBk5Gu9bcoe_LBj4sJIM7SU8Zs9ZCJOIe3Mn4,1175 +torch/include/ATen/ops/exp2.h,sha256=p2HxFzZG58-JKDoUqKxS6Lzk-x7psHAI8E6BffAJrBo,1188 +torch/include/ATen/ops/exp2_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8WHEptHjB5qVNdzDuSuXumGUBBWSTk5d_ZoIHc-lwu4,861 +torch/include/ATen/ops/exp2_cpu_dispatch.h,sha256=rMHO1zpoXqWsgm2wVx-33nwn2KPWWwooe2R8lYP9Sqo,946 +torch/include/ATen/ops/exp2_cuda_dispatch.h,sha256=LsT_BE4vxDuax3evVuWO8m7MswxQ-aXlhPE-0CtmMIE,948 +torch/include/ATen/ops/exp2_meta.h,sha256=8w2ExdAQcRb8oFOmWAnu3L_AZJQM0KK6OBTFxjyMikE,592 +torch/include/ATen/ops/exp2_meta_dispatch.h,sha256=LTPycEqemhIR0De2oNcoyIigROCAhXTIA8rLf4G0cE4,948 +torch/include/ATen/ops/exp2_native.h,sha256=jbk370cwHxUSxFRj_P6wU-GwQ4sjD4g-lIO4qDOtxv0,613 +torch/include/ATen/ops/exp2_ops.h,sha256=rfzVZIAjsJUeffT7vdFxEl0HJuarAUTLtMk_odOip4w,2079 +torch/include/ATen/ops/exp_compositeexplicitautogradnonfunctional_dispatch.h,sha256=bgUdPxvPWwaEg6MS2NZv41ogCSHsdZUERHJydbMc5J4,859 +torch/include/ATen/ops/exp_cpu_dispatch.h,sha256=KMkk8_2RShL1si7QjR6tyUX5XMv2l99t0ZTqFQMYOhY,942 +torch/include/ATen/ops/exp_cuda_dispatch.h,sha256=z3jc0bge-_dCEdw9D7A279sSIEgEtMg9Ys-iOcGeBMM,944 +torch/include/ATen/ops/exp_meta.h,sha256=pfMoVEV-rQeqpfgw8fQi-QePI034pG3mW60xRkvKT1Y,591 +torch/include/ATen/ops/exp_meta_dispatch.h,sha256=ydU6d0BieN_dFdz9oUtwXh6MwsowPHLEog-JQ6mYeSU,944 +torch/include/ATen/ops/exp_native.h,sha256=-AlM4OqsztXu9lzit_2VRo16-PgsZPJN-ckSYMXc7h8,610 +torch/include/ATen/ops/exp_ops.h,sha256=YIoQa3XaxQ8hq95lfM8ZS6NNTdvrvsiTh1wh1vdL7gM,2070 +torch/include/ATen/ops/expand.h,sha256=-cOlqvfXWkBfPC_F1arpo8CjvbbUMC3l0VuEt1WnoHg,1084 +torch/include/ATen/ops/expand_as.h,sha256=f_N0IJhhYyT4jj15Ee6UKJqvGfhY9lGWhm48Jc0yUKQ,533 +torch/include/ATen/ops/expand_as_compositeimplicitautograd_dispatch.h,sha256=e5pPRr4OOqc4Swdg2pfVAXMnc3UKu5uDCglyNiM7wxo,816 +torch/include/ATen/ops/expand_as_native.h,sha256=fQShmz4U7AItSVPnVGlgap0t_pHukqu8CthFA-STv-U,526 +torch/include/ATen/ops/expand_as_ops.h,sha256=JovwLSwDBQyWNgboeS59uUzG8YF3gT11Js68UYw-nr4,1087 +torch/include/ATen/ops/expand_compositeexplicitautograd_dispatch.h,sha256=5MrLbNh38D3E8LJZwOCK-WvyJwn6CG5FBEtO2vQH_84,939 +torch/include/ATen/ops/expand_copy.h,sha256=MipgcTCqJGaYWkh77EhYySFNg5YUyZhoyeYkxml8Ivc,4218 +torch/include/ATen/ops/expand_copy_compositeexplicitautograd_dispatch.h,sha256=NgFC1zsirDXTd3kCqA-qiiHP_X8JrO0qP6Dd3eY4JzM,1252 +torch/include/ATen/ops/expand_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Xn6rftlWACoT-C3OCWD6PoXbzwCDZ_mw7LSSYhOqASY,975 +torch/include/ATen/ops/expand_copy_native.h,sha256=R15uX7YI_z5o6fw_zPKr9osoYqEVmnb59ggnq8niw20,688 +torch/include/ATen/ops/expand_copy_ops.h,sha256=01mrPdMAdLaB74NUsgNuYu5cKtgnNt_pwhTwu5dVvnc,1910 +torch/include/ATen/ops/expand_native.h,sha256=p6kHyTfhM_Rj-rNaYXVIxBnvEh0NliYZ77386qbsYOM,540 +torch/include/ATen/ops/expand_ops.h,sha256=RClZ1hgsYm1zJYSe_Eb3yG52LoYoA4f55-xhkx2vDWc,1140 
+torch/include/ATen/ops/expm1.h,sha256=fixRfDu5lQJJRIec71Vd4Lfvu8SdDpSOF2XFiNJfs4Y,1201 +torch/include/ATen/ops/expm1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=A_AQ5jf8gq3p_UGpLMmL8uULpAs7LsX2Chsu8wyP7oM,863 +torch/include/ATen/ops/expm1_cpu_dispatch.h,sha256=m3JN8aNNAducJWBqkR7TwXzujaAZHg0Ak5S_GWEqAPc,950 +torch/include/ATen/ops/expm1_cuda_dispatch.h,sha256=Np0ZhqRs5L74uHAvmR5uR_K1n1L5BOnf6UAb5ZRvCM0,952 +torch/include/ATen/ops/expm1_meta.h,sha256=93Sc0TkAIVL1jPhGR4J35xyJQhR-CStwcsfXKF_PxP0,593 +torch/include/ATen/ops/expm1_meta_dispatch.h,sha256=WJCiLSH0ItAVOYR_nulcYfn-rS78h5968QBajyVFmrA,952 +torch/include/ATen/ops/expm1_native.h,sha256=VN90iIQxRutVvdteZAHFZ89pXmqlwpXwm_INbIGxPzs,1036 +torch/include/ATen/ops/expm1_ops.h,sha256=orDa6gzz-K9l8XQ3dsxBPn2wzl1Z098uCouFDCmypp0,2088 +torch/include/ATen/ops/exponential.h,sha256=4hfPI1Hv6jAYSrWyERjhPWS5OsVsgpz9HYs6jz99Zks,1502 +torch/include/ATen/ops/exponential_compositeexplicitautograd_dispatch.h,sha256=G97t2RJ6xKOsCZ3bOVuNw_vFOCr6B7uXNgcMoeUIsSs,1163 +torch/include/ATen/ops/exponential_cpu_dispatch.h,sha256=OT2MEzMeZQ_Hc-Ngyf8-_kYne8izzxj4fzcZYFqcvbo,818 +torch/include/ATen/ops/exponential_cuda_dispatch.h,sha256=aPjDTLm4Sg3786pWLunDHRlHWLeTRUAIrfaAO14gExM,820 +torch/include/ATen/ops/exponential_meta_dispatch.h,sha256=24cV39GKlMIptrYXmc2jbZhdg5oH3QUea_LpV4sW5eI,820 +torch/include/ATen/ops/exponential_native.h,sha256=tspfO4UlMwAeVA8xffpV_dpeQM-gqtdEiNXRUmouGfc,845 +torch/include/ATen/ops/exponential_ops.h,sha256=H1cOmjCJmsAiJdEVrobf_qk9yaC1JRNi9DPIUoBOQ2U,2730 +torch/include/ATen/ops/eye.h,sha256=z-8s8zsHzBrGxSVMtF7-lp8Un9aPmsQU4nCrcwt5WhU,9829 +torch/include/ATen/ops/eye_compositeexplicitautograd_dispatch.h,sha256=YxsHlsjEtuDVH6crEBTF9OW-2Ej9fpu75J_Ko4A9Qr4,1832 +torch/include/ATen/ops/eye_cpu_dispatch.h,sha256=SDghe6wVM-i2G4o7lSLxC2wVP4-A_AHTDIB4q97-E_g,1284 +torch/include/ATen/ops/eye_cuda_dispatch.h,sha256=riNhJfsklMaQpexMoi8ckJ3WtDV9RdGQM9oKNnaN_ho,1286 +torch/include/ATen/ops/eye_meta_dispatch.h,sha256=q4TtUKBPpvQcBPH8z1_wSrY3GvfVpXmnq4RPdThIuPM,1286 +torch/include/ATen/ops/eye_native.h,sha256=IfhsakXPuj84jtMwMsVhVooPqlU_uNIqR4BOJZgeAQ8,1131 +torch/include/ATen/ops/eye_ops.h,sha256=RiASalCXnOYIELJDO1DrXYycWQEQHC2iN84zw0ug1sc,3652 +torch/include/ATen/ops/fake_quantize_per_channel_affine.h,sha256=Jm4Mxp4RUyJn36wGveWWnbchz47Rw_AfFFzT3LuowL8,1001 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask.h,sha256=ywkYHx9NhN_XS4Ty19ru62NZWzxmKu_F6hR5jgMoBLU,2333 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward.h,sha256=xBvbhw74wppNWsl7eA9VJTdBLToAkhBF0XTk90SAd_I,892 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_compositeimplicitautograd_dispatch.h,sha256=XdxHB4Lm5xbuRwQlnsPxppb0YpveLlS778XF6lfTeYA,857 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_native.h,sha256=xiaySuFqop7Npj5ZlW6gKttDvx77NCXIfE5yrxInJ0Q,567 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_backward_ops.h,sha256=-y2tBHQVgES7RzUGSrUnciAtzznOpCzoz5gOvT-pVJA,1204 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_compositeexplicitautograd_dispatch.h,sha256=mPsT0kJBmCJhrD2KCDYdlKbIggq2P8ygJADLDOHTY20,1275 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_cpu_dispatch.h,sha256=1-WcypustbSG8zXIEhNqb4xlbHBg_k_Ujl1ul-UPNPI,913 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_cuda_dispatch.h,sha256=VJmLXrIb27CPu8-CWxvcz3KbJ78mX0Nv4HWZejUqAzg,915 
+torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_native.h,sha256=uWz8fCBwTnUoBXzthZ5W-UTxVJwj22zVlMLbLq0NNDk,938 +torch/include/ATen/ops/fake_quantize_per_channel_affine_cachemask_ops.h,sha256=vGhWoMeXRNkcpJGRoOnQgfBXNhYkLhY0kVoC0c2l8sM,2783 +torch/include/ATen/ops/fake_quantize_per_channel_affine_compositeimplicitautograd_dispatch.h,sha256=FdLHvzub6oXAhaa4D9q_9lZftzR1ghgQ1pJzhCRN4lo,922 +torch/include/ATen/ops/fake_quantize_per_channel_affine_native.h,sha256=6y5lIE3zkoIZBasSi8FHKuFXReZk9mcyDUw9HunvfW0,632 +torch/include/ATen/ops/fake_quantize_per_channel_affine_ops.h,sha256=Nr3jM73H9Gp43vkG8aFgPj6ZgdfzSFdIg5-3CywyXgM,1422 +torch/include/ATen/ops/fake_quantize_per_tensor_affine.h,sha256=AYtK1rvBlPgQK8QBHYUUP6-Z7dd4e1q9OGuvf3RONJs,1384 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask.h,sha256=3ZOrhjCdd_GhJhkYIZa6aGCdWgv16CGOQZAK119ilO8,2152 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward.h,sha256=7qhWCtInMIVtX91OS7efdL89jKx0DvexjiZ4xUBqUBU,888 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_compositeimplicitautograd_dispatch.h,sha256=kmw94RXWKxB48SbYPPwvjjMs90j1VIPSvk6ZhXepYnk,856 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_native.h,sha256=RyUjqDKfznWYGCIjy4QuwnzU2yUVKz-sjiJKCSEZhog,566 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_backward_ops.h,sha256=242AJwo5_RfZAs4HqJpt6u-7NmdPNRIHSbVUkkmSIBY,1201 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_compositeexplicitautograd_dispatch.h,sha256=0QZV6HRkYxLi2z-6LCRRENoue9tY7c1vYc24FK9YTUU,1199 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_cpu_dispatch.h,sha256=WbXdS23G2b914PnFXUu-F9u5rZLGn1i6yvyXqwRTVco,875 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_cuda_dispatch.h,sha256=iMMGg7mSQZWBRPyXtrRVJvKWZSvUd3hKYxJl78kUWG4,877 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_native.h,sha256=AnPkcaLEhAan9nCPjuCORbB91Yy7U3XeQ4iZivwoi7A,862 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_cachemask_ops.h,sha256=nv3OD4044BIPdAJXTeTulghzVJp9P6iYKvm4lrcBfvk,2537 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_compositeimplicitautograd_dispatch.h,sha256=lKDkvrRdncB78rUnmw7T7nVTxSZJriQ6LZsTZtA-RJU,1059 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_native.h,sha256=UfMNmyd_jKwOUwUHEyDE1rAmAHJCUYenHDt1Sq94PtA,769 +torch/include/ATen/ops/fake_quantize_per_tensor_affine_ops.h,sha256=BS1xInEhgcESEhFj1W5v3C-05VIaO5-rAOB87Te3FLM,2240 +torch/include/ATen/ops/fbgemm_linear_fp16_weight.h,sha256=iiyweMNnHGcfu59nW9pghxoIYQCub1uWY2QasfBfsa8,862 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_compositeimplicitautograd_dispatch.h,sha256=C5KDQLoLOW82CMAxcYX8KmeyLK57eJXiZuqDRv34Ers,866 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_fp32_activation.h,sha256=LYGxV6vW9x_esxOApRg0WKwiXp0SAGp_C8Wxb-xPdTs,926 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_compositeimplicitautograd_dispatch.h,sha256=ErOqV8lLD133dVfUW3fzFz-kXNFD4Lhs3x3nfdc2qhk,882 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h,sha256=F5bYzKZA9-SqaQjkhmMguc79wYvL56iklitHgR-7Dpw,592 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_ops.h,sha256=5wNvbuS1D8pijeiK5PajddSvtFr0nI9_lNeBpUQoWIg,1287 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_native.h,sha256=pNuhpR_BvMm4Sizzo1as0gz9hmglyfoeIeKspZavehU,576 +torch/include/ATen/ops/fbgemm_linear_fp16_weight_ops.h,sha256=AAn6KFJoQ5nJH9cKUzW9FpIGIAm-CdhxEazl5m0HjS4,1239 
+torch/include/ATen/ops/fbgemm_linear_int8_weight.h,sha256=CZicv_w7arv7eRSWvYgKyU_TmP6Amzgsnf0AWk5tZrY,1107 +torch/include/ATen/ops/fbgemm_linear_int8_weight_compositeimplicitautograd_dispatch.h,sha256=4vymeLvZKJOG72WLQ_Zh9Pqhu1jkLmsflSRXAxjuLcs,989 +torch/include/ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h,sha256=p3PkSYsvgffKxaBnI-gJlwWUj7-CT996sPemPFMjA7c,1171 +torch/include/ATen/ops/fbgemm_linear_int8_weight_fp32_activation_compositeimplicitautograd_dispatch.h,sha256=nzVbJrbQ6wpZziNTVwIwoKyQKpUFWnYK3-rES0zhCvc,1005 +torch/include/ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h,sha256=zQonRBtIc_WwNj5uRdS11Mz_EksLafgpzHCPA2hV3Tc,715 +torch/include/ATen/ops/fbgemm_linear_int8_weight_fp32_activation_ops.h,sha256=SOKVxVgPpEDYua02BxtiqX_hpNtYzfmqnZt6iZ48YKM,1688 +torch/include/ATen/ops/fbgemm_linear_int8_weight_native.h,sha256=P9gY80gkdyTBXkReLhFYqYcj3rgR9SvTc7rhsj5PK1A,699 +torch/include/ATen/ops/fbgemm_linear_int8_weight_ops.h,sha256=Vu4fNAhjlLUNdsLIkwgsb1oVq-lBRMAVccGFEnue4ag,1640 +torch/include/ATen/ops/fbgemm_linear_quantize_weight.h,sha256=G7RvaBu3pnKuBHijA7rIiFIj51X07BoFPNlMKh-FsXs,825 +torch/include/ATen/ops/fbgemm_linear_quantize_weight_compositeimplicitautograd_dispatch.h,sha256=QND0FO2vaMVrW3Te8ICS-DaWtK-KJQxhd8PII4EHPas,851 +torch/include/ATen/ops/fbgemm_linear_quantize_weight_native.h,sha256=RdixkUYA-fjkMeuIqNHnMzfDbxkTjy_VQPHnraR8pKM,561 +torch/include/ATen/ops/fbgemm_linear_quantize_weight_ops.h,sha256=eOR18xVwDpOv_FRRKR_89JaLhWojZwO9kURNypjGoks,1200 +torch/include/ATen/ops/fbgemm_pack_gemm_matrix_fp16.h,sha256=ynTYL88FiAxmWUHHc54-a3y7EzHZkHL0Ed6Wx525C7Q,759 +torch/include/ATen/ops/fbgemm_pack_gemm_matrix_fp16_compositeimplicitautograd_dispatch.h,sha256=E6IzjSyqp6vIBqc_NSrhPUfyP8WuV6mbCvBPnWBBWbE,810 +torch/include/ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h,sha256=z1JouhPCtENlOA2Mixi8ua53V3joHuLUe_veQpHVkQc,520 +torch/include/ATen/ops/fbgemm_pack_gemm_matrix_fp16_ops.h,sha256=E_3Pds-DmDtsN_vw3THfO6DUQW8OAvmN-oQQ56_xhQw,1055 +torch/include/ATen/ops/fbgemm_pack_quantized_matrix.h,sha256=wTVMuH1OBPF6w4b9l1HfexGZq9goO_dYrc0y1snSaAg,1016 +torch/include/ATen/ops/fbgemm_pack_quantized_matrix_compositeimplicitautograd_dispatch.h,sha256=zyc5v2yf1VYYmfOpqg8Wxn_SCsS8W9_tNbtBFbLT_XI,910 +torch/include/ATen/ops/fbgemm_pack_quantized_matrix_native.h,sha256=IUJTxyuZqm_KcPN5Fx9p9KIKTwS-XIaJAzwPd-2dFUk,620 +torch/include/ATen/ops/fbgemm_pack_quantized_matrix_ops.h,sha256=IsPLvPxyqSoMGfJR0XhzUMvFR_c7ekbXSTUl7u8oexQ,1719 +torch/include/ATen/ops/feature_alpha_dropout.h,sha256=ZXR2nDjl2ExfOLJAee-hEakTy6d_sG9Xb_WrFOyaIL8,1029 +torch/include/ATen/ops/feature_alpha_dropout_compositeimplicitautograd_dispatch.h,sha256=aThteQaMsyhntGWMS2cPeYR6cL9PUHwbomLEO37yFmo,914 +torch/include/ATen/ops/feature_alpha_dropout_native.h,sha256=8QKY_TUs9KqaVliFhk9PZSfGLcWC6vfx_X_YwgKhk-U,624 +torch/include/ATen/ops/feature_alpha_dropout_ops.h,sha256=cmm2Kb6X1dFhVldI5XaaSRpWIZPhbXWmGWXNOmrLCAk,1747 +torch/include/ATen/ops/feature_dropout.h,sha256=k_XB_MzJszA-qOL-YUCK-KOD0NfeJns8D2zgX6Lo2tc,987 +torch/include/ATen/ops/feature_dropout_compositeimplicitautograd_dispatch.h,sha256=PtRo5tmlU1cpfcrgRcVVuIQ0q78-IOAi44ESIjdOXQY,902 +torch/include/ATen/ops/feature_dropout_native.h,sha256=czfjGD2QwIKw-l2ExJm2UNau_zlmOA2-T3FU5sX_dSY,612 +torch/include/ATen/ops/feature_dropout_ops.h,sha256=gn4sS8z7BNqECOdDYgxGUYU3RR6ZTqVm-IhFVB2lpLQ,1711 +torch/include/ATen/ops/fft_fft.h,sha256=NpFvt8vsNIruXAwQXr-MRIvy-kdoaKfit17DQcUNt9g,5114 
+torch/include/ATen/ops/fft_fft2.h,sha256=jnpqggA3SgaNpMFFhZTgV89nrNXb4Ltjm-6LlBIiORc,5407 +torch/include/ATen/ops/fft_fft2_compositeimplicitautograd_dispatch.h,sha256=uZdFsk5GVMLnm-qrl82K5U3DldpWv875Et_0XE-zaYA,1882 +torch/include/ATen/ops/fft_fft2_native.h,sha256=SbAQLvn-veiGXNbfd6518xl2ADYxSTgofybRKxdMEgU,814 +torch/include/ATen/ops/fft_fft2_ops.h,sha256=tMq7a6iwsGEhoeAyqRk_0TUYEEgRQ-cjjjegu4Ohl0s,2233 +torch/include/ATen/ops/fft_fft_compositeimplicitautograd_dispatch.h,sha256=tBWmroEywJrVXAzMqhFuSsLySfFkCWIsfby60x_s6Ck,1817 +torch/include/ATen/ops/fft_fft_native.h,sha256=_19feDCmM5kiz1Kpz_LjzSJWVpouXqpCHy72vRPL4IA,795 +torch/include/ATen/ops/fft_fft_ops.h,sha256=XqLepCsU8h8zycDmkwBkyqFF_Ai9t1S1q3scCHwwbOQ,2169 +torch/include/ATen/ops/fft_fftfreq.h,sha256=BJPyBCQaEutKA6cWw1LbCdD1wEUbRLADAGntaDucGJQ,1808 +torch/include/ATen/ops/fft_fftfreq_compositeexplicitautograd_dispatch.h,sha256=axpnPfxcHMHEThbOYa6Fhm6C-phJIZnwYxCHpqhsSvw,1188 +torch/include/ATen/ops/fft_fftfreq_native.h,sha256=h4mahel6sogilDYczEitkXM42AVVqCM1DUR6odVN71w,739 +torch/include/ATen/ops/fft_fftfreq_ops.h,sha256=mtge5dJSQWGx_nsGo9Fg5UXvkrQujPNK-n4srLivx6o,2105 +torch/include/ATen/ops/fft_fftn.h,sha256=iCvXxH4YGsqykrt8xQ8pPU8UCXB9NUiomfcIZ-xUFic,5547 +torch/include/ATen/ops/fft_fftn_compositeimplicitautograd_dispatch.h,sha256=mlnmk4Q1I07E2WbxLtjqGWoKpMqwfcHazBq6YY6O5hE,1958 +torch/include/ATen/ops/fft_fftn_native.h,sha256=-HFmHBok2XF5YZSWpPKWoMB-ZhtbfqjZE4eh5sgLWFI,837 +torch/include/ATen/ops/fft_fftn_ops.h,sha256=bn0Q9AHFyyMnx22Ef5Ji8U-8o7kCzxpN3WeEMiRp1kc,2277 +torch/include/ATen/ops/fft_fftshift.h,sha256=Ed7CMsg6SY025i6o7xMydKtDWcYTfjrLhUQOYMVnbk8,759 +torch/include/ATen/ops/fft_fftshift_compositeimplicitautograd_dispatch.h,sha256=_z4XJ8jxJmRtMlqu7cfNyo5VqPbRsP4vZtL6aGj6dh4,837 +torch/include/ATen/ops/fft_fftshift_native.h,sha256=Og5L-0ECaPeuLUwrprXKoUmKizw4UszVnRBupK0fbLg,547 +torch/include/ATen/ops/fft_fftshift_ops.h,sha256=Bp6sXWhEuIqBqxaZ2__CSd1kVkqcqUQBdACcb4darcM,1105 +torch/include/ATen/ops/fft_hfft.h,sha256=3jImU3JKetYDINBcJ9nVlUn1Eb5q5LRFNfbynJqWXek,5145 +torch/include/ATen/ops/fft_hfft2.h,sha256=iXaFR2doWgmmhGsDThHLvVSqniNEfaauEA9HvZXnd0M,5438 +torch/include/ATen/ops/fft_hfft2_compositeimplicitautograd_dispatch.h,sha256=iCzpuP-MZi3SF-ysf0m4jVTNCFmPJeLNCVyILnbj3eU,1888 +torch/include/ATen/ops/fft_hfft2_native.h,sha256=hcp4j4wWmTrjuw9EOAKRWDYcIQQ-F9wbxzJ9lnMzXr0,816 +torch/include/ATen/ops/fft_hfft2_ops.h,sha256=B_Po6yAzNnB7VX0sjP2ayvPyA6iYnwWDmuOLqJBrLO0,2239 +torch/include/ATen/ops/fft_hfft_compositeimplicitautograd_dispatch.h,sha256=b2qNchbIk5fsBQePCh8RJMhjhKsEjsx5UAMxe07XqFg,1823 +torch/include/ATen/ops/fft_hfft_native.h,sha256=WkFsixazxf3-gXrtNVJyOLO78fhsWqxSPho07soYCtk,797 +torch/include/ATen/ops/fft_hfft_ops.h,sha256=h7uYV4gxPLwjaUGV5X3D_XIchpc-jFh0JgYa03K2j2Y,2175 +torch/include/ATen/ops/fft_hfftn.h,sha256=uJ0CK70hDX8cLXQEamEvAwUbdvpBWDVE_AC0dyVMF3o,5578 +torch/include/ATen/ops/fft_hfftn_compositeimplicitautograd_dispatch.h,sha256=BtZsjj6eNfjglDNLZ96frW-OUSliZwqU1UG6ggc-Tp8,1964 +torch/include/ATen/ops/fft_hfftn_native.h,sha256=TBXvUEQu0100gSvdlj4cZ10H7NgEkz6f66djg18GFdE,839 +torch/include/ATen/ops/fft_hfftn_ops.h,sha256=rwUMGI9asniG6zwA36MoeQ1arsrTqR6GCkYAseky_Dc,2283 +torch/include/ATen/ops/fft_ifft.h,sha256=iDJ18LWPmIld312AFGiFb9LTJ77MUCSmWnZeWMUHbRQ,5145 +torch/include/ATen/ops/fft_ifft2.h,sha256=tjwcNuyTiwSQnzMlCGx8Sc411bXmV4M59gGH1sXXad8,5438 +torch/include/ATen/ops/fft_ifft2_compositeimplicitautograd_dispatch.h,sha256=Op0gCyQbsp7pyPjfwR5QY_1eY3n8gGgLro_FzwPTYSU,1888 
+torch/include/ATen/ops/fft_ifft2_native.h,sha256=HI-qFeOwuLM1sY-Giv3rCHYQk-z83F0efzKTA-qg6ts,816 +torch/include/ATen/ops/fft_ifft2_ops.h,sha256=652vZlB-AVT1gYRSK5jlK4XvbmNRiRolnea6K1YTshE,2239 +torch/include/ATen/ops/fft_ifft_compositeimplicitautograd_dispatch.h,sha256=8M0Dr6ZWpEzL_oXgr4al55DOnbylwJFCJuD2KYvy5lU,1823 +torch/include/ATen/ops/fft_ifft_native.h,sha256=XfItHPHjmPnoHk735Mu5inSdkkWtmPaZFPh30CnagXE,797 +torch/include/ATen/ops/fft_ifft_ops.h,sha256=XTMKXUbiTTDQPoQ1BEP5EjxKfusK6ZUwAXsCtTyvoMQ,2175 +torch/include/ATen/ops/fft_ifftn.h,sha256=mNeHjeHswkFbESs-nQqYpI4oaN-1-AnLPSImdSiN59Q,5578 +torch/include/ATen/ops/fft_ifftn_compositeimplicitautograd_dispatch.h,sha256=I85NLp_VvqoPaEIYSXReatzpQxqI-MwIvkltw29yxuM,1964 +torch/include/ATen/ops/fft_ifftn_native.h,sha256=vWlFOUNGXQTjQLqKZNWaKLrtg4wIDQ0ZdstwGiiNYyo,839 +torch/include/ATen/ops/fft_ifftn_ops.h,sha256=Ne-78LeDSdikHfbt-XWoWLI_xpcuZLPRwWZQh2FDNGw,2283 +torch/include/ATen/ops/fft_ifftshift.h,sha256=zJp7AqlPF-ePeBpPxN0G3mBf_qg4xIdUJTL1sbvwUpc,763 +torch/include/ATen/ops/fft_ifftshift_compositeimplicitautograd_dispatch.h,sha256=8h23hQ_hUYtQM1A1UR2y47wxr7plfFKu9HuSnMta22A,838 +torch/include/ATen/ops/fft_ifftshift_native.h,sha256=P65O4_7UPbTTsBUjBCycsr0718sXJHl3TRkAqAv5fhU,548 +torch/include/ATen/ops/fft_ifftshift_ops.h,sha256=Vb4mK4_yPP0Jy1EscCguK3nyPEafQmK0h4DqiWuXK04,1108 +torch/include/ATen/ops/fft_ihfft.h,sha256=GmKhQAlkQIQEudUfMcGntel8_PJptOf1AgP0FXEHpNc,5176 +torch/include/ATen/ops/fft_ihfft2.h,sha256=xLwPd9mJKsarPB8vkPLqHk0lpMOsLnW2w6gp2vvzpDc,5469 +torch/include/ATen/ops/fft_ihfft2_compositeimplicitautograd_dispatch.h,sha256=zd5Pa5qUaCCjv9jG9LrqxmrzYen5ZFfdj3daFExcL_I,1894 +torch/include/ATen/ops/fft_ihfft2_native.h,sha256=zHwTEt1NNPu2QrApaRCgiG3VHCV1iLUAqw22ClzqHM4,818 +torch/include/ATen/ops/fft_ihfft2_ops.h,sha256=q5PT7Rby7FEsdmYkL2OSrCo8shVKd0S5O1_CglFERPg,2245 +torch/include/ATen/ops/fft_ihfft_compositeimplicitautograd_dispatch.h,sha256=iEOb6jl_r7qwVgGzYFicBOxc6MxEsM5aGnYi1xPI5KE,1829 +torch/include/ATen/ops/fft_ihfft_native.h,sha256=P5XRM5Jfo2pA1nY2PBhsu4zdrdAlqSZztN0HrLMFmeE,799 +torch/include/ATen/ops/fft_ihfft_ops.h,sha256=W3mhvkaA7brXNdu_6VInytfvjo3FvWaEJG07tbIQ4lQ,2181 +torch/include/ATen/ops/fft_ihfftn.h,sha256=f5P_7eny07nT9ZZoERHdZYRtzzwLCOclgQy3fBQSw_Y,5609 +torch/include/ATen/ops/fft_ihfftn_compositeimplicitautograd_dispatch.h,sha256=taPLZxKEarqL4UIvlxNIQlVlTQyROGekSObg5s189Jk,1970 +torch/include/ATen/ops/fft_ihfftn_native.h,sha256=nwx9uBBA2t8mvmNGTdT_2BnvN95hInSVUaP1yOFlJS8,841 +torch/include/ATen/ops/fft_ihfftn_ops.h,sha256=HDa-ESDKkL4llhCNndg2jHL1d9dpDV721a4HbMm3r1U,2289 +torch/include/ATen/ops/fft_irfft.h,sha256=hgOG8JPsEWrClS9kYMDWWpMd66QZDT4mwTQ1FiLAVkA,5176 +torch/include/ATen/ops/fft_irfft2.h,sha256=weYonVbQY3-afIJn_fOf13dyI0jQKzBYfTaQVooezYk,5469 +torch/include/ATen/ops/fft_irfft2_compositeimplicitautograd_dispatch.h,sha256=VNAhcR2MbtFFcj64916ynjlWnBgnF_uHs2IfYU4ZhVE,1894 +torch/include/ATen/ops/fft_irfft2_native.h,sha256=wJ74b74fIZgT1AR94nnOZbivao5i5qlM1gdiDIMBAhQ,818 +torch/include/ATen/ops/fft_irfft2_ops.h,sha256=oYVEcFvNs4_jgFrDCqVuTF53160NgGTScHfWfInB72k,2245 +torch/include/ATen/ops/fft_irfft_compositeimplicitautograd_dispatch.h,sha256=x4nXNAcHdga36GQORwx75tHgmQY4A9pz9dtYN_QL0Ec,1829 +torch/include/ATen/ops/fft_irfft_native.h,sha256=evtlnnzrVGyAYX2y5yIeyORKxLccYII-tio2oUFk6jA,799 +torch/include/ATen/ops/fft_irfft_ops.h,sha256=BheOcLiNfbCiFgK319bTAURno8U7AA55T8JYlPSGKo8,2181 +torch/include/ATen/ops/fft_irfftn.h,sha256=CM5bJ418og0ys8rzMMXSrc4sAtjHj-8Gpujnuw2vneI,5609 
+torch/include/ATen/ops/fft_irfftn_compositeimplicitautograd_dispatch.h,sha256=1ZE8bkUdghYyD2DlFZUx271lp0pjc99UYHR2cMrbO4w,1970 +torch/include/ATen/ops/fft_irfftn_native.h,sha256=JiGfqXmrV5rKDJTDwvFPoD9AEkTvoFuaHxZEN-FZFns,841 +torch/include/ATen/ops/fft_irfftn_ops.h,sha256=391SF0Ao5RbRspsrYF8FXMtBSzhIFQ1qYd-6wKj9H-s,2289 +torch/include/ATen/ops/fft_rfft.h,sha256=rOYn7pkPt_-6vDZ7kpjDLYJ68S0lwXj6RgRcxlSS7-E,5145 +torch/include/ATen/ops/fft_rfft2.h,sha256=Pw0jW6JfvA_B79mXhjnAd93Sk7aGnKG8eDRjALBY1tk,5438 +torch/include/ATen/ops/fft_rfft2_compositeimplicitautograd_dispatch.h,sha256=g_iuXy74fn9sdewd52CWv8gG0U6GGfogl7qNE99GMSw,1888 +torch/include/ATen/ops/fft_rfft2_native.h,sha256=ZriUDX72Eq7GfWOiZ29Ei6efikd0gKvUZq4W0yQHa64,816 +torch/include/ATen/ops/fft_rfft2_ops.h,sha256=OApbjLqEzpTnxERFqwLpGkBXIZY0E0_RDkuzT97WL5E,2239 +torch/include/ATen/ops/fft_rfft_compositeimplicitautograd_dispatch.h,sha256=M4YYUwpmdpViKyLh8QADwcSjrAAB_WncDfU3VlfTTo4,1823 +torch/include/ATen/ops/fft_rfft_native.h,sha256=drVykdXiUHsz0Kg0Ido5qDRDwvH0YXVVlnIMVp2Jw9M,797 +torch/include/ATen/ops/fft_rfft_ops.h,sha256=mjh0JWM688iJUT8qlanuk3tluNErmQBIcnuMpLqzYsw,2175 +torch/include/ATen/ops/fft_rfftfreq.h,sha256=WC1S1kV_lathYyJweKj6FvNNP14gWtkn4T_pa9NxO0Q,1821 +torch/include/ATen/ops/fft_rfftfreq_compositeexplicitautograd_dispatch.h,sha256=cqj909mSYNcKqioVBtkNderN5MHRPVBkoxVGcw1thog,1192 +torch/include/ATen/ops/fft_rfftfreq_native.h,sha256=nCqbA4d99AJ8UZe8MUytl0Tyh7WNk4LBK97m4wxJzbw,741 +torch/include/ATen/ops/fft_rfftfreq_ops.h,sha256=ZUi27_o_Url3JF6T4o1Iggss35IQ3r4_pg0nf8uOCMQ,2111 +torch/include/ATen/ops/fft_rfftn.h,sha256=KJXh46OKlxGFdAMelTBw_eJbRaTkwGamkpihiHPfEpY,5578 +torch/include/ATen/ops/fft_rfftn_compositeimplicitautograd_dispatch.h,sha256=bIntqX1vJACfD2OHWmUcvuxvw2lENZH5rRvmjg46uss,1964 +torch/include/ATen/ops/fft_rfftn_native.h,sha256=tzhtzZznE4NfoJFKiyVINDzJv0VWYIboTUU3F52OOV0,839 +torch/include/ATen/ops/fft_rfftn_ops.h,sha256=lvT5WxD-R2vxYpGFJW1omMnQbfSMNVTlxZrnDvNdFx4,2283 +torch/include/ATen/ops/fill.h,sha256=b8ujGD5VWCFVIgIROjN55F0TYpp3q5TF6Dx0muHPa7g,2338 +torch/include/ATen/ops/fill_compositeexplicitautograd_dispatch.h,sha256=QrWK8g2MukRDdZ4pBYSmsyGMepJCi_vaxk25xL35pTk,1304 +torch/include/ATen/ops/fill_cpu_dispatch.h,sha256=xITqKkIx_7TnZVn8BmPsy3_wmfZUImOyPytkyuCPaKg,840 +torch/include/ATen/ops/fill_cuda_dispatch.h,sha256=a1HTtHaz0HYuNkX-j1AuUeBtUDTQCglhYBLOJoeq5BU,842 +torch/include/ATen/ops/fill_diagonal.h,sha256=0YHhSadk9LRd1nTwDHRV7mxxYeVuoMKsMIvs6Fyr8is,537 +torch/include/ATen/ops/fill_diagonal_compositeimplicitautograd_dispatch.h,sha256=W3cfyZtcM2Egr5wIJFHZdGIFOhyZp4iW_BgXObmQRNE,839 +torch/include/ATen/ops/fill_diagonal_native.h,sha256=tuwcqKCQ78vP-g93v_4eSkOJoswyL04vGJ2C9JSuoBU,549 +torch/include/ATen/ops/fill_diagonal_ops.h,sha256=u4IC0PSTaWgz-4zdbdsdxZ2OTpK3VLbuj2ILW4Hrk5k,1152 +torch/include/ATen/ops/fill_meta_dispatch.h,sha256=MKcKQ3BeJHdm_AGx2yCbHPfN8n70LNzginNx76Kf0M0,842 +torch/include/ATen/ops/fill_native.h,sha256=qBMK2H1QEFYpsxgh_ouh5jynyW_jrme_TlVrsyn34UY,1559 +torch/include/ATen/ops/fill_ops.h,sha256=yDYvKOSuWqj3sT6GIXgZsBa6kxk6_HEGXZmgwldLJBc,4321 +torch/include/ATen/ops/fix.h,sha256=QAil0yIU1sMQTsuBAbxecWxTjJxEJ_dYqnqPXSyED8o,1175 +torch/include/ATen/ops/fix_compositeimplicitautograd_dispatch.h,sha256=UcKNz6B14G8JqAs8wbXDXxz9za2U9eMckxn-D2IgDaM,986 +torch/include/ATen/ops/fix_native.h,sha256=xZm1swCvpOUm42AZKulE_869Aiiu0NqnbQdiJ-N4k1A,619 +torch/include/ATen/ops/fix_ops.h,sha256=pEJ0g2ZnQH_XHA90KVufNu_4xb1xVal8pexVp9vbZhs,2070 
+torch/include/ATen/ops/flatten.h,sha256=pv_f4u9gY92NkEbU6Q_0BMuSN-4NaYABnReT_XhIVOw,1686 +torch/include/ATen/ops/flatten_compositeimplicitautograd_dispatch.h,sha256=61qlFT2IHk2LeHuMcHG_Fgi9FQRyHhV4G0Ykdhi0Oww,1162 +torch/include/ATen/ops/flatten_dense_tensors.h,sha256=XkekStOVlmpyPq98nCX_Js_A162u6mTd-BMCuPPo4n4,735 +torch/include/ATen/ops/flatten_dense_tensors_compositeimplicitautograd_dispatch.h,sha256=nioNKhBuII1hT910ATCdHM84dVA0_7jaXVaOFLSiJYI,801 +torch/include/ATen/ops/flatten_dense_tensors_native.h,sha256=NuS59VZ9zzWI-SDK29mKw3o_mAwP_vSeCYe-q61tNNQ,511 +torch/include/ATen/ops/flatten_dense_tensors_ops.h,sha256=WkYgvc-Go13QTiB8SsfVeNPmhoWfzPZyz5gWZVqZxS8,1030 +torch/include/ATen/ops/flatten_native.h,sha256=_wHreF0OUzXi0321dGuGBDQAxKM6Ma-x8eA_PXzaFhQ,872 +torch/include/ATen/ops/flatten_ops.h,sha256=F8nhug_z1lf6ZHlEGo7mgSV692IhWv85Qms8L4IGzUs,3382 +torch/include/ATen/ops/flip.h,sha256=zCwHVgDYqxzy4ChnnJBAs8vae59Shec-jX_N9vOHp2I,1167 +torch/include/ATen/ops/flip_compositeexplicitautograd_dispatch.h,sha256=9Yr_v9RTWiukzd7sg8DDrC1TmkZLPL5cf1scQuscS-4,931 +torch/include/ATen/ops/flip_cpu_dispatch.h,sha256=ZsZjEHGya37mrL-hOpQwFOSACOwj782PtgQUt9XRtDI,763 +torch/include/ATen/ops/flip_cuda_dispatch.h,sha256=Rk9kOmucxo6KOapoT1jy0_CTqf-XGTH8p6N6i1JewD0,765 +torch/include/ATen/ops/flip_native.h,sha256=oy4T9cMfJWbWAC39LsjtSK1HGtxmzQEy1Do-xgQ_i74,616 +torch/include/ATen/ops/flip_ops.h,sha256=8Jz4tQwa__akmpaFOLuAF4B4T7lWdKhiLwyn5r7PRbE,1721 +torch/include/ATen/ops/fliplr.h,sha256=iR_-LIo6iyWNqusZsewKc-96IdKzifQR7BuM3UgCBQA,668 +torch/include/ATen/ops/fliplr_compositeimplicitautograd_dispatch.h,sha256=BFF29Q9LuGzshe2urYORhGQHodw2ErciFjvoabpU0G8,787 +torch/include/ATen/ops/fliplr_native.h,sha256=s7ggWzj72fNHBS0JaQRU6hWzBxeUiyfPDt0Yu8ba2Zs,497 +torch/include/ATen/ops/fliplr_ops.h,sha256=dCPa2nr_Th7cQXsoffP8JRZZBPwPqfMa0pdK2kIf1w0,986 +torch/include/ATen/ops/flipud.h,sha256=mzCcyteRvBTMe65p6fTcHv0P76HaYvkHDDLw0_db2Gc,668 +torch/include/ATen/ops/flipud_compositeimplicitautograd_dispatch.h,sha256=URIbF6s_StXQzyWb25FG7-MoOJYyyns86KeLHf9DCPM,787 +torch/include/ATen/ops/flipud_native.h,sha256=qpJuj1XjHmFaiVk_D1NscZt2-rr1uEr-l-n9-nVHnQg,497 +torch/include/ATen/ops/flipud_ops.h,sha256=2_Lm4lxRFG-D1nbMRBPHfHziuBLGQRY59u4_N7NEqEY,986 +torch/include/ATen/ops/float_power.h,sha256=lNrqlz0egtumA7RnZasxC9aDkxaA0MIs6XdDPpHzWlw,2999 +torch/include/ATen/ops/float_power_compositeimplicitautograd_dispatch.h,sha256=RkY5lCkBhLiEX7aDs2GUsP3gFPnyW15ApvWrSyrcBjA,1852 +torch/include/ATen/ops/float_power_native.h,sha256=pm1oBp8HUxGBt9O7nuzjY6jqeXOnYWhfoVJBfeBgcgM,1220 +torch/include/ATen/ops/float_power_ops.h,sha256=jty9qMbM2YUZPY_maAoau5ckACsp_VZlKVsgEva6eRc,5958 +torch/include/ATen/ops/floor.h,sha256=G02KwnEQwKCOUq9_h0n2IhfHVYRNla7E99pEsUYivUk,1201 +torch/include/ATen/ops/floor_compositeexplicitautogradnonfunctional_dispatch.h,sha256=9roGqNabIZFvu9eG91rig5NxT7lDYZ-XHFw_GimnDog,863 +torch/include/ATen/ops/floor_cpu_dispatch.h,sha256=ykYFHSa-It_Qm_0JfjB3D0M8zujiGQrFiV5Oj5ymXGA,950 +torch/include/ATen/ops/floor_cuda_dispatch.h,sha256=LVZh5NRxSh3HzO30k3_Zt1AcTj3qNbIPZdURBuTvp4g,952 +torch/include/ATen/ops/floor_divide.h,sha256=kTSkVCYVhR1FPSlW1OSRpk2jLFX0Y93FA84xL0IUQsI,2044 +torch/include/ATen/ops/floor_divide_compositeexplicitautograd_dispatch.h,sha256=-Q09AgmUv0CQTPi9lr51UY6rCgkdrN5EaCz5VtdsxyY,1126 +torch/include/ATen/ops/floor_divide_cpu_dispatch.h,sha256=pE3ZAGxazRWwn6tExYF9OrJm3wdHnQDphe5G7O2T_XA,1082 +torch/include/ATen/ops/floor_divide_cuda_dispatch.h,sha256=uGiMB-_0O-28ivZwXyPfIvuSiPSrVe64DSQOjfed-0A,1084 
+torch/include/ATen/ops/floor_divide_meta_dispatch.h,sha256=_PCBIgeutxTdyrr5qQUBl-tX3_br3gOjYsOaTuAnPhI,774 +torch/include/ATen/ops/floor_divide_native.h,sha256=RLcxLISaO9M3qg-xgJ5wHPBXl0P9UTWff_ff0FWmjDA,1324 +torch/include/ATen/ops/floor_divide_ops.h,sha256=8pRLQGk-q_7hcIAUZDePaxO0-daYSO_kZK5Cfwdb2-s,4424 +torch/include/ATen/ops/floor_meta.h,sha256=8QydO_526PagvnA9bdy3_Ie88Z--Zs4j7uGZsJgvJss,593 +torch/include/ATen/ops/floor_meta_dispatch.h,sha256=HdyFsR3yQC2A5z-cmOvI4FeYBzSD7NdU3b02R0SOTKk,952 +torch/include/ATen/ops/floor_native.h,sha256=Dygh9Br8gwGo59cyKfOeyavvkDNadG-XnrkWAJ2BqlA,1036 +torch/include/ATen/ops/floor_ops.h,sha256=KJg_pseM8-3ZSCMcUfMbFOtgp5rhIdmD88AmxaLD2aI,2088 +torch/include/ATen/ops/fmax.h,sha256=CVEbmSrnaM2LRKOXr0PrvIYiz2QbMgSpqs77qlRMuhQ,1188 +torch/include/ATen/ops/fmax_compositeexplicitautogradnonfunctional_dispatch.h,sha256=mfbOMM2v_EHhV5-B0cZO9D39aTjP4YJi9ZLSWlW9WlQ,837 +torch/include/ATen/ops/fmax_cpu_dispatch.h,sha256=FZPRn7PXuz1axhZC6D0_rPT6pCVV6dBt-TCTFovkkjE,974 +torch/include/ATen/ops/fmax_cuda_dispatch.h,sha256=uItWVF9ScGaYXEK8zmbMRa8aV7vsUgwODTgqEGsaYJA,976 +torch/include/ATen/ops/fmax_meta.h,sha256=tCjI4fbbdmsslWHjTO_RZSk_7ALdULm5Wh012aOedes,618 +torch/include/ATen/ops/fmax_meta_dispatch.h,sha256=Cqn3kxKvCJnz1HgOb2JUUcm978eY__X_3rxJlRs32q0,976 +torch/include/ATen/ops/fmax_native.h,sha256=aLUX0Safq79n01AiFxotQt8dAT6l753eG3gB-zBMGNs,639 +torch/include/ATen/ops/fmax_ops.h,sha256=Wj_pw4vdEV9B0g4H0eQh0D307mQziHE9nQA0QrJJfdQ,1747 +torch/include/ATen/ops/fmin.h,sha256=bJx0_pPzb5Xc6KLBijUNTEkQLGHfFJ5-ytsY_HIpE8Q,1188 +torch/include/ATen/ops/fmin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Wfod54GyzaY4EXB2tiU1FtSnedDnH15vydeSkADZAII,837 +torch/include/ATen/ops/fmin_cpu_dispatch.h,sha256=MmtmqeSo9OklQrbuhtD1RxiwHKw8jcML-CVXIASb57Q,974 +torch/include/ATen/ops/fmin_cuda_dispatch.h,sha256=6oi5CfTsoypJHaNo4qKNrl_oviEjgtRauTxKyLBxQQw,976 +torch/include/ATen/ops/fmin_meta.h,sha256=Lcqk6zNsj9JYrhEKs7GpuFTDhjng4XpgjF53FnV65bM,618 +torch/include/ATen/ops/fmin_meta_dispatch.h,sha256=DqE6j-qdoRXPm-0v_d5N0VrmUrObvOUrocVSNW37G0Q,976 +torch/include/ATen/ops/fmin_native.h,sha256=08UYmAVNk_1bDcNa1f2YXIeOyLs3bQdnLnqi5B9hoM4,639 +torch/include/ATen/ops/fmin_ops.h,sha256=e6-fqcz8BtQTFV9SC5SHURHrjwC_ezd75Xpgl7XrjME,1747 +torch/include/ATen/ops/fmod.h,sha256=ci8sykwug892LqMC4sSRou_sw-Y1ToYfuvI6kqRdgm4,1934 +torch/include/ATen/ops/fmod_compositeexplicitautograd_dispatch.h,sha256=EIvo7eiX8d9jN63qJnCqlUTm1_6OCxfx_mwU3v7ufg8,1094 +torch/include/ATen/ops/fmod_compositeexplicitautogradnonfunctional_dispatch.h,sha256=dpsvSu7M7v34-ssh2ZNrmOBYfw9wmVcgnhIHwL82ywk,913 +torch/include/ATen/ops/fmod_cpu_dispatch.h,sha256=HPGtRyJ7LsQjeM4qWQGfgOV3IH5SrAUQw3pOw2eKzrA,1050 +torch/include/ATen/ops/fmod_cuda_dispatch.h,sha256=hUDJqP7L26njdSRN2rf6-mWo7sL157_2aZiL-cEm6Iw,1052 +torch/include/ATen/ops/fmod_meta.h,sha256=hHhDQjQP3rp8q7qKm-BffdHTvwg8SQT4zgXcHpBl2ls,625 +torch/include/ATen/ops/fmod_meta_dispatch.h,sha256=EfDsNxFdIYT6wg9wEGw87wHosThwURv57mMwrr4pc6s,1052 +torch/include/ATen/ops/fmod_native.h,sha256=XkvLGt6tv5-gIuNf2M-j57IM6BQK9ea7SqXRUmh3brg,904 +torch/include/ATen/ops/fmod_ops.h,sha256=HlZvUHD_MdKM8m2RMoVK7s0Q0CaPb_qaKstGzo66qK8,4321 +torch/include/ATen/ops/frac.h,sha256=mX-aYVUxjsw1sfJU0iQK5NWULkjmw2T4tuldqk6MK50,1188 +torch/include/ATen/ops/frac_compositeexplicitautogradnonfunctional_dispatch.h,sha256=vbMcLKh1FBKK87n4yIF3PDWf47z3BR-MFkrs-vISlFM,861 +torch/include/ATen/ops/frac_cpu_dispatch.h,sha256=BL0jOfqkr1XwfRK0fMfFD7FrAwG0BH7_CS2auL5pag0,946 
+torch/include/ATen/ops/frac_cuda_dispatch.h,sha256=E3SjWbAojHkHVgAujzmVErO3doKC7kPODM7N7AEqa10,948 +torch/include/ATen/ops/frac_meta.h,sha256=k-b3lFowjoLtLn7nkTBT8Ygy8Nuqn_FuA7TDGWseWOQ,592 +torch/include/ATen/ops/frac_meta_dispatch.h,sha256=UaX7rb7eOJa1MmFSHF_CuORtCU8ade8qkfegCjbCLh8,948 +torch/include/ATen/ops/frac_native.h,sha256=K6V_chkymTWrttwH4ZIt6KmPb0Ve6X8Fpa8smQrmVDQ,1027 +torch/include/ATen/ops/frac_ops.h,sha256=RI4ETMxIxOYng51Ytl1Y9KPcmoO3O2Xd2WjhiQHbQlA,2079 +torch/include/ATen/ops/fractional_max_pool2d.h,sha256=8Njuof2YcUO87lPGHuLkxI4NJJYKEs9GOQMvL__-EXc,2060 +torch/include/ATen/ops/fractional_max_pool2d_backward.h,sha256=z-ITqfPkTDhSWg8Ds38ODuYJyzfHrmYnD5L3pIi5cWs,2103 +torch/include/ATen/ops/fractional_max_pool2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=A_MZN_35wNZ8J2F8CkFZmnflvdv6OzmiEMEdnm9qpHU,955 +torch/include/ATen/ops/fractional_max_pool2d_backward_cpu_dispatch.h,sha256=LZ9vdI03VkNe5Zn4e2CGoifsPoixcTlZf7EXD4wKemA,1342 +torch/include/ATen/ops/fractional_max_pool2d_backward_cuda_dispatch.h,sha256=fDcHnRwUhokeGmAEfw57WDee4uz8sPYY4-mdfmQACU8,1344 +torch/include/ATen/ops/fractional_max_pool2d_backward_meta.h,sha256=Iej-NLBzDOZLdSlEQ8Wl2IeM6h8PPA1eu5fYBPBWrKE,736 +torch/include/ATen/ops/fractional_max_pool2d_backward_meta_dispatch.h,sha256=W3pmLzdWAGClhupZWktYq1E7GskZwqlwgJUDZQEL2bM,1344 +torch/include/ATen/ops/fractional_max_pool2d_backward_native.h,sha256=_u4bgSLdIF0v1cYsn7iaeVhRUF0wvM7uTERjJgxFae0,1134 +torch/include/ATen/ops/fractional_max_pool2d_backward_ops.h,sha256=59SG0enXgVihIduSOpsSWVBVZGI46NYLo2-uGSDwq4g,2545 +torch/include/ATen/ops/fractional_max_pool2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AhiAgeKTa3L4b2N61geudqXhsZ7_vI_SgDY-DXwSXHM,946 +torch/include/ATen/ops/fractional_max_pool2d_cpu_dispatch.h,sha256=C0QexDqB0rkKIhfgI5welhrPaCRXe2G7OJrOO1aX0I8,1355 +torch/include/ATen/ops/fractional_max_pool2d_cuda_dispatch.h,sha256=eLwR-e-uileZ5fJVM50e0u4EfID7aVUy3E_fAdRocf4,1357 +torch/include/ATen/ops/fractional_max_pool2d_meta.h,sha256=E6nj6M6oQpAl4-76Brxm1-z9F3hZU-iOpZ-9MrQC9IM,702 +torch/include/ATen/ops/fractional_max_pool2d_meta_dispatch.h,sha256=ARKBRxjxyHrCCmSGhe9AzfB3yOeM6vWkilZ60lXiVKM,1357 +torch/include/ATen/ops/fractional_max_pool2d_native.h,sha256=9L0Rlq3E-bVeF_3XhAwz8zCqR3QOPXkFB1eutmeFxfo,1095 +torch/include/ATen/ops/fractional_max_pool2d_ops.h,sha256=CtQgO9UVEfbiYznpwXDKP1VFR-umLSdiuLz1h8jajBU,2559 +torch/include/ATen/ops/fractional_max_pool3d.h,sha256=yK86EiJ-48Y_eHSlgVMJQz4YtwJ4_bcVl7YzPGv4WhY,2060 +torch/include/ATen/ops/fractional_max_pool3d_backward.h,sha256=8_7PQR2C-xIVAvxvWNZm5Xgm-VVKyRnPm3pavo-umeQ,2103 +torch/include/ATen/ops/fractional_max_pool3d_backward_cpu_dispatch.h,sha256=jDqeeYdwtfhXpJYq7xAU_aVhUg9FiQ_hykJximGTqXs,1342 +torch/include/ATen/ops/fractional_max_pool3d_backward_cuda_dispatch.h,sha256=pH7jj6fZY5WG5P0Tv1HanQ-n0JV9scGi_0i3m70cnUo,1344 +torch/include/ATen/ops/fractional_max_pool3d_backward_native.h,sha256=yz5BekxbOEYPr-lcbi5lS5wcaPBrc3FpOX5IopgS_qY,1310 +torch/include/ATen/ops/fractional_max_pool3d_backward_ops.h,sha256=tiesn0q4OtKHGAt-tKokx9k0LfeyUWe00oeFbi20zJE,2545 +torch/include/ATen/ops/fractional_max_pool3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=VRvk9CAsVDpA_eK4H40aX5uPREvqolKDHSFBYt8rdTE,946 +torch/include/ATen/ops/fractional_max_pool3d_cpu_dispatch.h,sha256=FhfwrBMJ8Eg8JkZfsbNyJ9_alSaJnYb7Mss91UPAuEs,1355 +torch/include/ATen/ops/fractional_max_pool3d_cuda_dispatch.h,sha256=sWxq_NEtkcwRxuHgNkQe7lOfmg6iif7tXLUnHkujnVY,1357 
+torch/include/ATen/ops/fractional_max_pool3d_meta.h,sha256=WW5ClwjxnON_bq61OjRjsgS_P-EYtL4I2Oo-51NkKKE,9846 +torch/include/ATen/ops/fractional_max_pool3d_meta_dispatch.h,sha256=I-JpDC13xjmDKweCNMEVvpY9jqcLr9UmMb4cngN0V9w,1357 +torch/include/ATen/ops/fractional_max_pool3d_native.h,sha256=W4pPNFhkC0gLnm4uwdy7fXkL5iGdaF4Wlu1SUzjWKY8,1365 +torch/include/ATen/ops/fractional_max_pool3d_ops.h,sha256=Gm1UVpLrmjkl4nv2dmAF8y5KJDqwOgoeziDgL1sWZEc,2559 +torch/include/ATen/ops/frexp.h,sha256=IJwtbij6GJ9NQdBcuFupGkDoiB-du45M-impnqpH4LM,1408 +torch/include/ATen/ops/frexp_compositeexplicitautograd_dispatch.h,sha256=deF99eyGbjSdbGPkTDpCP_cvzU2HexXU66Ff_IXWtRY,811 +torch/include/ATen/ops/frexp_cpu_dispatch.h,sha256=HbWHWxUobLfQXyF9Kk4izrqrXK0NFBR2zPOoK7etIWo,955 +torch/include/ATen/ops/frexp_cuda_dispatch.h,sha256=RtHcFCBvjxx4iEi3fsJuLnffzJdUw8GcFC7WKJYU2L4,957 +torch/include/ATen/ops/frexp_native.h,sha256=6VAjDOubEvHAgz5w8DG1YefATRwFZC8vccOMzTYclMA,654 +torch/include/ATen/ops/frexp_ops.h,sha256=myCCBX40zmgXYnUCC-MxQwvfTA_ZszQbOomwfn5wP0o,1934 +torch/include/ATen/ops/frobenius_norm.h,sha256=bSIIDMrMGN0AeE57jKQ89IzR0J_9s6Wl5FCgbY9kihk,1410 +torch/include/ATen/ops/frobenius_norm_compositeimplicitautograd_dispatch.h,sha256=8_Ok-tTCQ5nfpMyv6jq5T8FLNWxcsPiEixglUDhILpo,1087 +torch/include/ATen/ops/frobenius_norm_native.h,sha256=S1O6QrFwdBv4-LV0yGAXWwAsW8OxVennceKoCOaRiM0,668 +torch/include/ATen/ops/frobenius_norm_ops.h,sha256=mNzO34mjY2bplbwlFiyUEbyBbdM3kImsRzZFvzdaDIc,1896 +torch/include/ATen/ops/from_blob.h,sha256=O9mlg92bmafqsnGGEX2QzQwuKWYnLZTOPC_x6Pw6Vas,4330 +torch/include/ATen/ops/from_file.h,sha256=c8QwYBlySY-I66zgmT1rKtMprY3lfs1V4JnVrpFfbVk,2252 +torch/include/ATen/ops/from_file_compositeexplicitautograd_dispatch.h,sha256=9YnaJEcYzKaktdgyyEMLwRo9oLilhQRYbZ2OPvu7n8c,1040 +torch/include/ATen/ops/from_file_cpu_dispatch.h,sha256=NMq9Z_jyYb6iQ0vFVfw3fHigx-CdJTeXs6GjJKpdW6g,1122 +torch/include/ATen/ops/from_file_native.h,sha256=5ohz5lSnuboJKuntb-HAoEsFBOhzviJYRQ1lw8s0pAc,882 +torch/include/ATen/ops/from_file_ops.h,sha256=f6d2gtv329kh7O_e640criIkYFBuZosW6l_XL4yK8eg,2513 +torch/include/ATen/ops/full.h,sha256=M7mh1VMyIkdOnMcNHX4sjULxedBgmMGFcCVY1_a25hA,8049 +torch/include/ATen/ops/full_compositeexplicitautograd_dispatch.h,sha256=HqsVVuSpSdKb8Rk1dE2vYRqn5LzZsswgksn5QHmZRBU,2580 +torch/include/ATen/ops/full_like.h,sha256=tviegJUn8dKK9AgvTY6tQwpyYCyl02DrPiDijulOn-4,2500 +torch/include/ATen/ops/full_like_compositeexplicitautograd_dispatch.h,sha256=LU3Epq7x3ERLbVs-GBCx6z3FkdGtXv9qrwuZ0J3I2WQ,1538 +torch/include/ATen/ops/full_like_native.h,sha256=0Ug3QlpwWG1M7BnYrqbl-BcaT5HJrXhlB_woMI0tRC0,914 +torch/include/ATen/ops/full_like_ops.h,sha256=B4C1DtlvkFI9t0EybW5W3PhywBi8Y15CxilAngRBMl0,2637 +torch/include/ATen/ops/full_native.h,sha256=VjIMs5z5pest3oMcnM0iCMk5SDcOVizzFXeEhQX1CGc,1214 +torch/include/ATen/ops/full_ops.h,sha256=oCUDZ2RNgNS1rN5OMvVzuYEcJa9hdAMB9nx_3WxB-2E,4360 +torch/include/ATen/ops/fused_moving_avg_obs_fake_quant.h,sha256=iSemqWYWGxH4jY-YaiPRfXNovo0aCQF7MDzJX_PuKgQ,1483 +torch/include/ATen/ops/fused_moving_avg_obs_fake_quant_compositeimplicitautograd_dispatch.h,sha256=1vaCrmVRA4weYeI05s_A_gxl9dHCrjQ5zdlNZNVdELA,1113 +torch/include/ATen/ops/fused_moving_avg_obs_fake_quant_native.h,sha256=uubym-1EpgpgMIt6LKki4AGd_Omei6kAm58HmJR8gr0,823 +torch/include/ATen/ops/fused_moving_avg_obs_fake_quant_ops.h,sha256=I-Ux6Aj5a5vKD16Sb03SW4T4OUFJPc1ziuLZBKJqR5g,2038 +torch/include/ATen/ops/gather.h,sha256=pLKVuVTjC7m9iySRfdmupibQoxifqddODbrPraCkH4g,2482 
+torch/include/ATen/ops/gather_backward.h,sha256=q1MZxLpXO0t1WOdRj9n21hK1ApVXTatOyrymJ0H2SK8,871 +torch/include/ATen/ops/gather_backward_compositeimplicitautograd_dispatch.h,sha256=WxH2OZwKPBjIKuGaPY1E5zRMnhA9SaARSye-NGOY8Ek,878 +torch/include/ATen/ops/gather_backward_native.h,sha256=4hzoSN6Ns_sTtI73-67UymfvTTh4WTCxCNTu0J3NH8k,588 +torch/include/ATen/ops/gather_backward_ops.h,sha256=5YoPNimi5qfn3YWJe_R0DvjlKfZol_BXvUIWBLMj_B0,1286 +torch/include/ATen/ops/gather_compositeexplicitautogradnonfunctional_dispatch.h,sha256=dHTThEYf4YJBb8xuwTBslfbvhMHRGDMu0nOd-U2qjhM,876 +torch/include/ATen/ops/gather_compositeimplicitautograd_dispatch.h,sha256=guHPqvi_nys0GvEYMMkJC9XqjdBAQPTQ6puP2gkFXdQ,1141 +torch/include/ATen/ops/gather_cpu_dispatch.h,sha256=9s0vE6xQLwwH8qdteFRWzD0ftBnY_rwC6quStftqr88,1085 +torch/include/ATen/ops/gather_cuda_dispatch.h,sha256=zcxiGa5Ttny_rTFbtHYa9fS7n1AuyNdfIl8kaVk13iw,1087 +torch/include/ATen/ops/gather_meta.h,sha256=llMDPdAJ4RA3We6LfQKpeDjV5CEBvzmrbttGC8of1WE,651 +torch/include/ATen/ops/gather_meta_dispatch.h,sha256=gs8Z3952uHGBsDB5d17Cam5lM9utWcvBJ5xCza9yGgU,1087 +torch/include/ATen/ops/gather_native.h,sha256=ztu0Y4CoT4eo-7U-DWpAbqP9AtmrY0AwNOEm6xA4PZQ,938 +torch/include/ATen/ops/gather_ops.h,sha256=q_mpLQV75zxolFqyRohZNNlqsxyc1UDW3a3LGaWqu68,3568 +torch/include/ATen/ops/gcd.h,sha256=Gv71wBgHK5gJJO4b7-6Dcai3hFEd4IWIayl1f1ahEc4,1363 +torch/include/ATen/ops/gcd_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2byiHT9IxGLj2G-zmvE8jLktt64as6g-d__imGx9iN8,911 +torch/include/ATen/ops/gcd_cpu_dispatch.h,sha256=m_lKydFT1gULNsHCnZ2YZHqUlOCQ-xxyXfga9avoZZw,1046 +torch/include/ATen/ops/gcd_cuda_dispatch.h,sha256=G3x3FPPpqlZPImDbGsKpt-p9DjJLKRzPIeDV26Y-CBI,1048 +torch/include/ATen/ops/gcd_meta.h,sha256=za1sBQF4d4Er5sApLEUEhmq7qT-aYLdb2sESahgQc0M,617 +torch/include/ATen/ops/gcd_meta_dispatch.h,sha256=ifwsQCn3_rSHfNc8enRb1m10mSxSkWaFyGug9SDP60Y,1048 +torch/include/ATen/ops/gcd_native.h,sha256=EvL06e8UkeIGMIsLL0_QA_hGybyRI9rmQ3vbSOH07_4,636 +torch/include/ATen/ops/gcd_ops.h,sha256=aEpyOkqFDAZBlCU2X93H77fPpAk3soFvlOcyWLXqxZQ,2328 +torch/include/ATen/ops/ge.h,sha256=txInxyMO9LfKgUh2V8mcHMhc_VaCe150rgj2lqnOtso,1896 +torch/include/ATen/ops/ge_compositeexplicitautogradnonfunctional_dispatch.h,sha256=gVN2YA1H8CBD-QOdo5jrAd_7h3IbokNmV7-jigbYDRw,1060 +torch/include/ATen/ops/ge_cpu_dispatch.h,sha256=PKtZ8cTN5JsiuqQWnAFlWUxE75KWaKOaxqH6lvSgye8,1396 +torch/include/ATen/ops/ge_cuda_dispatch.h,sha256=ACUnY2_6X_1qwfF2f0q1VieVym_oyl32lJdrrLWI4Qg,1398 +torch/include/ATen/ops/ge_meta.h,sha256=f0fpLCsEX2ogf123Y8FHKalhKUA_GsTDv0tjpGo710E,767 +torch/include/ATen/ops/ge_meta_dispatch.h,sha256=lKsS58hpC0uZhe-bJK1vm4y94KkYbceWHLpdo5QbkG4,1398 +torch/include/ATen/ops/ge_native.h,sha256=dfeLZMqAPpStWuslrTGq9bFdm6DeAr-fi5rrZcFhjDA,1326 +torch/include/ATen/ops/ge_ops.h,sha256=6TGdxE05ff3fiFh8zg5wZa5SYEZWOlK9IH5aduPUTw4,4285 +torch/include/ATen/ops/gelu.h,sha256=pRHQfGBvZeVeoDoB1y4rZj-zSLLH_v7rJbPT25g83RI,1483 +torch/include/ATen/ops/gelu_backward.h,sha256=-ZAbxAhDcZFlr4JlvrOZN2LLqIwsxX52lKEq0vt7Fnk,1620 +torch/include/ATen/ops/gelu_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=qWVXQSpTN99JlMRfvpQwsYgRSyHui5fyR7uhORHm2m8,889 +torch/include/ATen/ops/gelu_backward_cpu_dispatch.h,sha256=baYV-iBi3ZEmC9Zul9FfEb8j2fM8x_M-jynP1O8YVIs,1137 +torch/include/ATen/ops/gelu_backward_cuda_dispatch.h,sha256=TeLD_WSbhUs2okVFhQgQ7uvQOFr0WbRoeguW5XsbZDU,1139 +torch/include/ATen/ops/gelu_backward_meta.h,sha256=k-iVt3AVo3uyTglYzj2RV1ljeJMj_7JGHZ4yrqg3FZ8,663 
+torch/include/ATen/ops/gelu_backward_meta_dispatch.h,sha256=w_WzBHrbfMXL6EZ-3XAGa2ZYP-BcgIltMpJryppsAZM,1139 +torch/include/ATen/ops/gelu_backward_native.h,sha256=ueB-7NILazqE1Wn7l7V133v_kMRgBamanR7Y30cH7Rk,1222 +torch/include/ATen/ops/gelu_backward_ops.h,sha256=U9ai9wnNkCDnqxnGsF3nP5AxEI0Y67e0XYK7j7PQGRs,2086 +torch/include/ATen/ops/gelu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Bngro4JC9tuvwcJCsYhCerE2yOCJlSMM3NEwkKKZ9kI,935 +torch/include/ATen/ops/gelu_cpu_dispatch.h,sha256=NZQI9tAuGJq1E2e7p4N2UdVZfOC1YHxAvIYdbAujQ0c,1087 +torch/include/ATen/ops/gelu_cuda_dispatch.h,sha256=cZWMnoA2CF_r_YS8Rygu4Y3xu9j1ZP6FSsJXVzhGU50,1089 +torch/include/ATen/ops/gelu_meta.h,sha256=-H0kPV4ZjPijXsYYOUSUj0-TLGxVaeyf8U7PUo83l0Y,622 +torch/include/ATen/ops/gelu_meta_dispatch.h,sha256=DUGMTD8aNwNwdQ4NtF04Oyi-CoSaqh3XpJz0U2bndXg,1089 +torch/include/ATen/ops/gelu_native.h,sha256=DyEzXqgg3pcO2PJ_0S_tAEOvYGbM9I3FcQkPLbNB4xM,1432 +torch/include/ATen/ops/gelu_ops.h,sha256=_EhVSIoXiVBCBTxsGY_ybZ3pkALwdOeOsgZvfXrT4iY,2391 +torch/include/ATen/ops/geometric.h,sha256=tCNTN3ybIiIwGpdUKaTveWjc8RmLa-NsGMqN63ZuFdM,1436 +torch/include/ATen/ops/geometric_compositeexplicitautograd_dispatch.h,sha256=TE7-YkGBSfcz0gdYnOTMTm2l7Ec2Y_djJrBiyafYppo,1141 +torch/include/ATen/ops/geometric_cpu_dispatch.h,sha256=VuWsHs9_5Oh5hkARF8cTjQbxfQOqS1SszJwmnqvcYGA,810 +torch/include/ATen/ops/geometric_cuda_dispatch.h,sha256=Nru-LIsyYfCQHAE5lrk46mnai-A1KJNZQKHT_-zule8,812 +torch/include/ATen/ops/geometric_meta_dispatch.h,sha256=Y8ImO6Hj7CGEeB3dwlOA48Lxq7v4uvVezrA4h1_Vteo,812 +torch/include/ATen/ops/geometric_native.h,sha256=Pe0Ms9d7XEprcN-AT-EL-qWweMgnahirs2bpfZQ1GFc,823 +torch/include/ATen/ops/geometric_ops.h,sha256=iqr1VtexQoJndmdk7jZJwcI__H1tafrFB193w5_oqUs,2670 +torch/include/ATen/ops/geqrf.h,sha256=v0AWdkOMiUJaw8YkBxckRdzjBJY3UeLtQ50_9dqNmZw,1250 +torch/include/ATen/ops/geqrf_cpu_dispatch.h,sha256=15AijNVsRObUP7jD_q-HEb3HWD1RXLwpY9clLdmsS9o,1010 +torch/include/ATen/ops/geqrf_cuda_dispatch.h,sha256=sPAs-b_vExwWGhY_5_a3PPwRtd3z-SFoUkVUfsSllPo,1012 +torch/include/ATen/ops/geqrf_native.h,sha256=ZXtr0RNt9jhWH1URUqe5MJhpGuhx547jNXvPuNaQrlU,642 +torch/include/ATen/ops/geqrf_ops.h,sha256=U_WEbSNsAt8UWrn1c1ZkkS0zA8ZY3gqeV32-3YWGt_g,1827 +torch/include/ATen/ops/ger.h,sha256=6vHT07gEedSW-VLjN-S8NF60-wqbGAURAX96DFIF1I0,1169 +torch/include/ATen/ops/ger_compositeimplicitautograd_dispatch.h,sha256=K7FKNcJrsDjdi-NynLs_TXhm3Up0W_BYCQll5_HnrWM,1012 +torch/include/ATen/ops/ger_native.h,sha256=Hgn_HoJY6Di5lOlPYIH4fSOze1lO-TpnPZ0RAbBCnlw,620 +torch/include/ATen/ops/ger_ops.h,sha256=pmInd9WiljHTa5MYaqrrCRS9imMvxWnrlF7ztUtjK7o,1735 +torch/include/ATen/ops/glu.h,sha256=1Z9qzdboCbdFYfpBbsGfYcDGoC_FI16U98JDSz_yGFc,1133 +torch/include/ATen/ops/glu_backward.h,sha256=8jOpzbeZ2KijjOxde4deVyBXjACwETzThdLdzwZe9EU,1473 +torch/include/ATen/ops/glu_backward_cpu_dispatch.h,sha256=CiGPGy4Lx73gwSbYAhHEtlqG6m2_2UAGZGvJ4fnXb-c,1069 +torch/include/ATen/ops/glu_backward_cuda_dispatch.h,sha256=OtadHCIw9gTvoRuIXZdzKWwIZjUT_tsOXNGM6enB-RM,1071 +torch/include/ATen/ops/glu_backward_jvp.h,sha256=mmPmhozsfbunDKjB4XyzaQs5P-qIuIwiBnkhqFAPWkI,1830 +torch/include/ATen/ops/glu_backward_jvp_compositeexplicitautograd_dispatch.h,sha256=lsd-nKCn_UCgkV2O7wCz_RBj3GHRMhvSRfko1KXoMZg,1149 +torch/include/ATen/ops/glu_backward_jvp_cpu_dispatch.h,sha256=uj4WENQzBx7KSiLUDxI5TnHUsZrqvqVV-wSdtR9M3GE,872 +torch/include/ATen/ops/glu_backward_jvp_cuda_dispatch.h,sha256=XepTFM3UzNmd2ps9j2Ze0zF3yw3RGAgXl_AQ_f97jIQ,874 
+torch/include/ATen/ops/glu_backward_jvp_native.h,sha256=aFCAtkUWV6LcVauPg6B-PnMSyp4u67f4SnJpgjBQW3Y,834 +torch/include/ATen/ops/glu_backward_jvp_ops.h,sha256=nYCdfWBhZcBHuDyIGThpGsGUQH5p3XZfKqEDXarU-mY,2435 +torch/include/ATen/ops/glu_backward_native.h,sha256=oX_ozps1ShNTGg78lWFKgmJ0iMouqun79IpuZTDsukk,946 +torch/include/ATen/ops/glu_backward_ops.h,sha256=JyZMUXjWKY7mtckJrerG4a07Ki4o7HKLMuyybPEnmPc,1961 +torch/include/ATen/ops/glu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JxHBq33uGbOfhKV9BgraWtfCH00nurGZ82BZ9DSu0XM,826 +torch/include/ATen/ops/glu_cpu_dispatch.h,sha256=XAROScoLUEPYOHAbqplXtZlKlHYIQUOlLyb6pHx3Q8Y,938 +torch/include/ATen/ops/glu_cuda_dispatch.h,sha256=Tusi6OLVWipv8tdOJJ1D84fG0r9dfpvwy87lZkZJ0dY,940 +torch/include/ATen/ops/glu_jvp.h,sha256=R_faYzuNwrRnydwsB6HfqkfP6zoja41znL0yAEh2DIU,1368 +torch/include/ATen/ops/glu_jvp_compositeexplicitautograd_dispatch.h,sha256=Fsjc6H3etxAv0nIv2oDRJlQnXVr5WUdJj3lgRrUEYag,1007 +torch/include/ATen/ops/glu_jvp_cpu_dispatch.h,sha256=CgOWH-LflGH7jaSRjCIic4Ss_qnpA4EAK8rxORnW7jM,801 +torch/include/ATen/ops/glu_jvp_cuda_dispatch.h,sha256=xpTD6Sm5JxSPzzReLayOKVYoO2n8UP7Pn1WcW77s1-Q,803 +torch/include/ATen/ops/glu_jvp_native.h,sha256=EZmMBc3WC4ILbNwYg86fGOdkYmF0gP6ChrUgJClkX50,692 +torch/include/ATen/ops/glu_jvp_ops.h,sha256=sZ6vH6spEZjMQ1W3K0hgciglElVi-BCedgskngrSKMg,1977 +torch/include/ATen/ops/glu_meta.h,sha256=HjkDcorznopfW2U1qT8A4RAcKtu6HizHPyeb_Z9DNbI,604 +torch/include/ATen/ops/glu_meta_dispatch.h,sha256=-BzX_1h7paitic4dik8xR2pwMpDovuY4mwz_-FvXVgk,940 +torch/include/ATen/ops/glu_native.h,sha256=MHx4rFcWBU0CNf8ssR46AZecZkgN0-shO_GkuHQ0tKM,623 +torch/include/ATen/ops/glu_ops.h,sha256=lu3fHmrEdstcrusTC7XXJxXv9aZDVogKNWDkGC-12o0,1663 +torch/include/ATen/ops/gradient.h,sha256=5Tf5_u9M7vCVy_sekALOpEuU71wFjuY9mZK57QR9y9c,2933 +torch/include/ATen/ops/gradient_compositeimplicitautograd_dispatch.h,sha256=F22NVmcTTDeX3gDPX8lvOCmiJJOC6o1RN-Ehnxi-WLk,1813 +torch/include/ATen/ops/gradient_native.h,sha256=DiYPmCyMe6Bc58pkoA2IhAEa2z2_s7oR6pSnc_JPPXw,1523 +torch/include/ATen/ops/gradient_ops.h,sha256=P1pRe69BzasHP6aR-gVczqWY8bU_T7hJu6dZdC-C1NE,6367 +torch/include/ATen/ops/greater.h,sha256=ICOZ7pYf9uY2os2XmuiTqGk9FcqAf6YG9ghL7TZAwIY,1991 +torch/include/ATen/ops/greater_compositeimplicitautograd_dispatch.h,sha256=8m6b9kwclvpfoVGEZJq_WseULlPO2gxn1ssrkWvgkzo,1480 +torch/include/ATen/ops/greater_equal.h,sha256=tHzJfWICDOuBIyc5DzNVmmmADHpYKfcvprUq-OihYWI,2105 +torch/include/ATen/ops/greater_equal_compositeimplicitautograd_dispatch.h,sha256=fuSOyFfahtvmnsUaNgh9JGWOv1GTNj1tbmGy8CrmybE,1528 +torch/include/ATen/ops/greater_equal_native.h,sha256=KJxz7zw2boo_8Acevi5cuxPRjC5KDI71NaOPbdGZ2Oc,1012 +torch/include/ATen/ops/greater_equal_ops.h,sha256=FhtkOcNqZsA18-98sf_HraxexUKl91ur7wFo6haCEgM,4483 +torch/include/ATen/ops/greater_native.h,sha256=1hlTcCk813lWb_Aoo9qfjkqlj3-G4AgaSflUAyHYrPQ,976 +torch/include/ATen/ops/greater_ops.h,sha256=FxN1qDDrmBlgJo76fqArzndTTe01noxiEtbMXLGOLdI,4375 +torch/include/ATen/ops/grid_sampler.h,sha256=AICHoAbnXDbHOWVBb0Z-sz5qYu-BRcZE1MwRmQPlfmk,920 +torch/include/ATen/ops/grid_sampler_2d.h,sha256=JWMUmWHQh9tftFHizLZ1jS75kWokeOU_aRb7YQwpT20,1841 +torch/include/ATen/ops/grid_sampler_2d_backward.h,sha256=9-9IWNdXw-xtR0z_Zt1OiT5rsnQG7g9js1KFLySguVk,2537 +torch/include/ATen/ops/grid_sampler_2d_backward_compositeexplicitautograd_dispatch.h,sha256=0eof-LfTRuMNCKP9kcI0NWS5WotjDLAgOCduCP8Th-0,1345 +torch/include/ATen/ops/grid_sampler_2d_backward_cpu_dispatch.h,sha256=0v7_2jUuEVN8_yUk_XD6FMWkAN-lt0TYEA8EeSDImfs,948 
+torch/include/ATen/ops/grid_sampler_2d_backward_cuda_dispatch.h,sha256=j6va7wm1_fsxoIFsR3M5ML9nIxjc_JYdxpP7yh7t7U4,950 +torch/include/ATen/ops/grid_sampler_2d_backward_native.h,sha256=HIBQ3wLG4SVwIO0nLjH_fMzF7TI33Y5E0Ug7R5eFq4Q,1277 +torch/include/ATen/ops/grid_sampler_2d_backward_ops.h,sha256=2FQZDXCfKWffnWxvYghjJZAwfnU42yLMbzzKX7e0v90,3001 +torch/include/ATen/ops/grid_sampler_2d_compositeexplicitautograd_dispatch.h,sha256=vtg90Gq6R8kMnK2u9d3Otq1dGnAfIC0CFvGylieliC4,1101 +torch/include/ATen/ops/grid_sampler_2d_cpu_dispatch.h,sha256=aPsZQIm4MQB2xQzijw5xpps0pmbh0iUwlMU1VdKx6kY,848 +torch/include/ATen/ops/grid_sampler_2d_cuda_dispatch.h,sha256=3Cm35FpDy5eFc54KpuebPC0tDA1fq6Hce3sQ6yW-yKw,850 +torch/include/ATen/ops/grid_sampler_2d_native.h,sha256=1Z4wh1nj3_P5L-feMDbONEhSxdlDz8hcsLhZ9Sr8pqs,955 +torch/include/ATen/ops/grid_sampler_2d_ops.h,sha256=hVbH_yo0AsdEEri-SKrsS9ja4G0fgrmKRTduPFi-EfY,2265 +torch/include/ATen/ops/grid_sampler_3d.h,sha256=S5efhFUQnW9N7CCsvjBofaRVMzAXAvAC3I6UjbN6I4Q,1841 +torch/include/ATen/ops/grid_sampler_3d_backward.h,sha256=wUydPiSSWL39FMYnn-4XWGp-25gtPQAZIecnyDppbKE,2537 +torch/include/ATen/ops/grid_sampler_3d_backward_compositeexplicitautograd_dispatch.h,sha256=pEPkdPNof9vMPqinSfAXO9kKpiuSMkCW_Ssd9twZTT8,1345 +torch/include/ATen/ops/grid_sampler_3d_backward_cpu_dispatch.h,sha256=7vZIOnvndYW0YixBSQ3ypM3t3RvHEfW6bQt4IJHe1Us,948 +torch/include/ATen/ops/grid_sampler_3d_backward_cuda_dispatch.h,sha256=23gIXncAwciW3VfkUjBZFBECKOMY6rY1bFOTLQ7Z52I,950 +torch/include/ATen/ops/grid_sampler_3d_backward_native.h,sha256=6Xrr5lTSrW3DQcL1rvkdTcRxRfRbw1sw5rBdRz5HgYI,1277 +torch/include/ATen/ops/grid_sampler_3d_backward_ops.h,sha256=zHKktYt26XWh-q-LL6Pm0BxKculCP1AzTgzDIuS9Bgk,3001 +torch/include/ATen/ops/grid_sampler_3d_compositeexplicitautograd_dispatch.h,sha256=9SUe9Hq-j7p4c8zMFaRfCBnOlr4S8nsGN4h7SnjBkVc,1101 +torch/include/ATen/ops/grid_sampler_3d_cpu_dispatch.h,sha256=GC-zXpvdAgXxDnqk1KaKNHmiWevDoQFVnarIdNdTp44,848 +torch/include/ATen/ops/grid_sampler_3d_cuda_dispatch.h,sha256=ESgJtfz_Aul0Faywq-Fr2kZmV8qrcqcdkI9nBJ_uEbM,850 +torch/include/ATen/ops/grid_sampler_3d_native.h,sha256=9emds8m8dPIL-xnnIwN-ll-1OZS7QLHcqjH78xaeiMA,955 +torch/include/ATen/ops/grid_sampler_3d_ops.h,sha256=j_vUl6x4crF5eOfR8EZeUjsgmaw6Zl0g_NlRDsdBdt0,2265 +torch/include/ATen/ops/grid_sampler_compositeimplicitautograd_dispatch.h,sha256=TzvShjnFoV-946upn5-TKq-Cy2Xw7LJZdqG7z5rP0w8,889 +torch/include/ATen/ops/grid_sampler_native.h,sha256=qZRpJWyTvZcssRRxqLYQ6oxCjdVrLpq2K7-d1uKqhpM,599 +torch/include/ATen/ops/grid_sampler_ops.h,sha256=cLst5B4c4vwYabqjAeMsN9Yce9m4k1l-1LtLfqnfBlA,1316 +torch/include/ATen/ops/group_norm.h,sha256=F1GCi4WKWudYB3zfexurYQo6_6z920vLRI25q_8AAZQ,986 +torch/include/ATen/ops/group_norm_compositeimplicitautograd_dispatch.h,sha256=5npicrSiIRekLmeugxlxVEzYWQ2e_rB1mOZXBQtpMno,947 +torch/include/ATen/ops/group_norm_native.h,sha256=PCbFJQxjBj_8SxjNQ3tW_bukuu381ba8RNk5Le_eA0Y,657 +torch/include/ATen/ops/group_norm_ops.h,sha256=FOgdd6t91Wr49y5DovDZglSnWCVfbJKvRNt_oF75BiQ,1472 +torch/include/ATen/ops/gru.h,sha256=cfAZj1i_FCQHzqUcv3PiFHHeKRhboa27Io4poHYbVms,1608 +torch/include/ATen/ops/gru_cell.h,sha256=TtAuD32C6eP3Ie-5RskHThjOfin9P5hGwG4zPLMLMnk,945 +torch/include/ATen/ops/gru_cell_compositeimplicitautograd_dispatch.h,sha256=Uf5BdbZvd2K-iIeEE5O9C-HKBnY3ddWYMmpJLIwb2iM,953 +torch/include/ATen/ops/gru_cell_native.h,sha256=OVLIyxDwkNACdq6J6opGB5UcGTXAEogJKAVM_DMgZ2o,663 +torch/include/ATen/ops/gru_cell_ops.h,sha256=o8w-eVUC8c5VqWTxyRD-mzBXDhLjY-GjIJKAYP-ss1k,1518 
+torch/include/ATen/ops/gru_compositeimplicitautograd_dispatch.h,sha256=mQRRYCDpXC-uG05oGcbkM7_xCpirXfvYeKsarK4yF4Q,1199 +torch/include/ATen/ops/gru_native.h,sha256=4xOjzzkJIMVmGgqmWrVqCwIvrhGbtPvhIqbhkFWZ6TQ,909 +torch/include/ATen/ops/gru_ops.h,sha256=ZJqIAdnV8zbS_IqZ7bRtxWBUZZwUr_2yAj3zK_Mq5AI,2729 +torch/include/ATen/ops/gt.h,sha256=458_7H-vbKGjYCiGq5URIs9tbs5t7gLqTCXPLTIkZgs,1896 +torch/include/ATen/ops/gt_compositeexplicitautogradnonfunctional_dispatch.h,sha256=F_-oKd6yJcn4TO-GY1V_h7W-qKvmHv9LQFWkwSGJ510,1060 +torch/include/ATen/ops/gt_cpu_dispatch.h,sha256=olFMnVQYzQ1ndAMqVgoq2NXAkBffkEk_KN0DBdMfxac,1396 +torch/include/ATen/ops/gt_cuda_dispatch.h,sha256=2uYWHPbcMrjN9AeZIfR6wqu6runjUhjoIvlS7dPSGx8,1398 +torch/include/ATen/ops/gt_meta.h,sha256=HHfgxBQKqKC7jcmJSXXAmSOSfe0UL0GIBOwDar319JM,767 +torch/include/ATen/ops/gt_meta_dispatch.h,sha256=v99mjphAY0vGncwOTaDDXKirVgHYOuglgPqjqK9o4d4,1398 +torch/include/ATen/ops/gt_native.h,sha256=FL951OL7FfHBVnJmdCqtu_kSMfPAyVTNDem4se242AM,1326 +torch/include/ATen/ops/gt_ops.h,sha256=8iL5gMdzUwqU6EtFk7YIIZbJhdTPeJ247R--oHAmfpg,4285 +torch/include/ATen/ops/hamming_window.h,sha256=GiNi7LBSjem1BbEok8j4cuV_xOE35VUc0VI5oSq7wd0,7136 +torch/include/ATen/ops/hamming_window_compositeexplicitautograd_dispatch.h,sha256=R1lhareJLDWMlvegYZTLPffvIUew2_UB__7uNRRaZk0,2948 +torch/include/ATen/ops/hamming_window_native.h,sha256=HkUxzzeETI5Ih-C9rqna3fs1ZfBSJPt184IK4a3vO7Q,1870 +torch/include/ATen/ops/hamming_window_ops.h,sha256=HEz0liVB0HbaNw8Ao8MqNnYl7P_4l4M4X1ak-moLMmw,7882 +torch/include/ATen/ops/hann_window.h,sha256=AZc_2a7fdQna0YjQYUuXPl2dFRd5YDXvyOvzIsQ3VhQ,3379 +torch/include/ATen/ops/hann_window_compositeexplicitautograd_dispatch.h,sha256=U1QtzyLvyYVBEoJfRs5RmJ00w2y9kxJtHwHeCHoyHfM,1704 +torch/include/ATen/ops/hann_window_native.h,sha256=xaOEnCKhsn1GZB6JOkRZoiLLZbhwiHKYZCzell9Czy0,1075 +torch/include/ATen/ops/hann_window_ops.h,sha256=zPxWRFxl1N7y9-iNufNELKGqbMRcYvKgwQqeEUWqPlc,3870 +torch/include/ATen/ops/hardshrink.h,sha256=79Eoh-NmawqE34-b4jQHFPaYPDUlliQmepgBhx0IpCY,1268 +torch/include/ATen/ops/hardshrink_backward.h,sha256=DupWxakQZoDv4t02HlWs1JNoOtETJw4Mnj-EbQeB_zo,1576 +torch/include/ATen/ops/hardshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=DSfvNgURbTtZV4Go9rSooE8p0e2K-zKRsmlaRe_swiM,881 +torch/include/ATen/ops/hardshrink_backward_cpu_dispatch.h,sha256=TT5Phbe2GoIXPxQ8iXnC75n0m8tJzpI_-2O_LECecv4,1120 +torch/include/ATen/ops/hardshrink_backward_cuda_dispatch.h,sha256=l0dc0t4zieu0Kb-H_DWDSI2ylkydR046zlTcxEWUtZM,1122 +torch/include/ATen/ops/hardshrink_backward_meta.h,sha256=LUC-Jv_ai0AdVQk8jGoXXlnG8rvjYkTMSBloHj5E8r4,662 +torch/include/ATen/ops/hardshrink_backward_meta_dispatch.h,sha256=nVECIFKc310wCjOvbuNJDsc5qAzPneoiFWUiYyRE9Es,1122 +torch/include/ATen/ops/hardshrink_backward_native.h,sha256=IiJ2zvmQQK9sdMjblmlSPHFtjziLLgzzlqlqnmZ_SmE,720 +torch/include/ATen/ops/hardshrink_backward_ops.h,sha256=PwZHEyb7V_8JTrjWWEXX3TeThHbcMd0iLJrVnAKyILc,2069 +torch/include/ATen/ops/hardshrink_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ppiMfrw40um6MgUYp3QHjRKbwdisIFHpoviZDigEvWI,847 +torch/include/ATen/ops/hardshrink_cpu_dispatch.h,sha256=CuRtV1qyX0QZ1g1iMy7FZunxUzTfKt0_A3y5AunR_lY,1000 +torch/include/ATen/ops/hardshrink_cuda_dispatch.h,sha256=u3DZCdU557GFgjF3VVOgYPnlu5vW8AxC-slNotmhC-o,1002 +torch/include/ATen/ops/hardshrink_meta.h,sha256=pNHGh0UzXUBG_HQ2ch4eBJJvcGzB4M7Hk2oKt-l4Rcs,624 +torch/include/ATen/ops/hardshrink_meta_dispatch.h,sha256=BzSIaHEpwLwT1qoawyi1OgMKsFpU9y45XpRNBOCrpPY,1002 
+torch/include/ATen/ops/hardshrink_native.h,sha256=9G_p8t1KUgjLr8dz5rp3ydeg4h1cCXXltN63fa-hG00,657 +torch/include/ATen/ops/hardshrink_ops.h,sha256=7NJoxUO_iKLD3m0EXYtxorccxv7edFgOXeCQQScNxqI,1791 +torch/include/ATen/ops/hardsigmoid.h,sha256=efY_hIwrYfzv6msh1J_UUjTebBPC-Yx-16DKwSnOKA4,1279 +torch/include/ATen/ops/hardsigmoid_backward.h,sha256=aDj2gkKIwmHY2lt4de8ajPIP0pplcUG33pf2sQs65bc,1472 +torch/include/ATen/ops/hardsigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=FzgASsN6eLbg-RiEMMNE__-SDjRJYQKDIcluF0r8zUk,859 +torch/include/ATen/ops/hardsigmoid_backward_cpu_dispatch.h,sha256=55a441ZZoaJUFMvtDkc1foxPnwGUhQVf43UVzvRgER4,1054 +torch/include/ATen/ops/hardsigmoid_backward_cuda_dispatch.h,sha256=t5wH76Vugsaheht8a-aZB2amdizZ1wawUxtTYQ0fRNA,1056 +torch/include/ATen/ops/hardsigmoid_backward_meta.h,sha256=nXi7swGFgHoIFr8xsW_-9jfykgBQy6QZexJk30CArNI,640 +torch/include/ATen/ops/hardsigmoid_backward_meta_dispatch.h,sha256=iZ3m_Zmq4FgNIeGrGUyccGl67OUWggasTqUytj2cpx4,1056 +torch/include/ATen/ops/hardsigmoid_backward_native.h,sha256=bwhch2g2u93zKIF9TMFmqVltvfQSei9XIXQfzKVTqaw,700 +torch/include/ATen/ops/hardsigmoid_backward_ops.h,sha256=kZ0JxH-hY4-xHObDjnv24ccCoY_Ul6N8Jm5oos2sY8Q,1921 +torch/include/ATen/ops/hardsigmoid_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Srfp8scdHFCdhoa5677IVr7mu7OXuDkHmR_kFkwQ2OA,875 +torch/include/ATen/ops/hardsigmoid_cpu_dispatch.h,sha256=rV1504Sea9n9gJczpcNWMcFEgjSo07OqN5QtnjKj1lY,974 +torch/include/ATen/ops/hardsigmoid_cuda_dispatch.h,sha256=jLr7zbrdcOBAsRTiPx_QK3OTn9rZcVxheo9FI0SNtyM,976 +torch/include/ATen/ops/hardsigmoid_meta.h,sha256=rRpN4qsow9ZeiDfdPa6eK0ecxWgaif0VUP0U5C3ryRE,599 +torch/include/ATen/ops/hardsigmoid_meta_dispatch.h,sha256=P_PDJ2pcZId03_kDD6YDi_YkD80EU-KULtQ4C9xt3co,976 +torch/include/ATen/ops/hardsigmoid_native.h,sha256=9zAuj79u6VZCXFjd3bB3QmeKrq6ngOrGN7G-g7xNyro,806 +torch/include/ATen/ops/hardsigmoid_ops.h,sha256=YmIbEkjuU2mjqjGHYjVM4_Xldl9QvBt9ZBr_dLvO9wY,2142 +torch/include/ATen/ops/hardswish.h,sha256=V7c_uTcxK5sUEQL-Z-VeGOd9x3emtmmAZafSH06HbRY,1253 +torch/include/ATen/ops/hardswish_backward.h,sha256=nLGHuiTuLjIP4LBPXj_jY-IepS5duBBoJKmIlv22NUI,1382 +torch/include/ATen/ops/hardswish_backward_compositeexplicitautograd_dispatch.h,sha256=YBbNDNFJJMzPrQ3kzCLguqCeJfEn_6bDlnre93_1p_g,979 +torch/include/ATen/ops/hardswish_backward_cpu_dispatch.h,sha256=a6EO1eXUkK4NveImLd91Eh2QMrUk6jXyVs20Szq6RXA,787 +torch/include/ATen/ops/hardswish_backward_cuda_dispatch.h,sha256=uxe81h0kRhvTRxFyV_FGQ6CwfgX1d2GbeBNHhHJbpjc,789 +torch/include/ATen/ops/hardswish_backward_native.h,sha256=SfNNsY9quIw56hvKOtUzyYDgcppKSwVS2C8UFwDrrm8,664 +torch/include/ATen/ops/hardswish_backward_ops.h,sha256=pAUrhkavHRyRlZs6DJuPJv_hTUTsr80JnKhWen4VbCw,1867 +torch/include/ATen/ops/hardswish_cpu_dispatch.h,sha256=-RuVmpD8qcWtdy4fuMCmSP1XJvpGx1MN9yj_EFP2HC0,966 +torch/include/ATen/ops/hardswish_cuda_dispatch.h,sha256=lj78B9IR2taT7KRpXu-G2IFMN-3b1pVZ4RL73iEY7Ww,968 +torch/include/ATen/ops/hardswish_meta_dispatch.h,sha256=vCKVzIHBDgeJ8c4qOvBbvsFqPx08bUm8Pur-7Mr3CiM,745 +torch/include/ATen/ops/hardswish_native.h,sha256=35tiHi8NDeMxDwRyjsfXbhH8NuSGp77L2qgg2JI1dKc,637 +torch/include/ATen/ops/hardswish_ops.h,sha256=v5AAcIM6qDI8wfPh4rJVHw0FSEDwqBXZKi5I2IpfODE,2124 +torch/include/ATen/ops/hardtanh.h,sha256=yn4neMrCjmPmPKHUEupjr01UOmBs9e7-7TMhbRWvn4s,1699 +torch/include/ATen/ops/hardtanh_backward.h,sha256=kha9UcfUJ-fRbJ3vKh9NgYMiIkHes3wdsp-RQWh2Yq8,1760 +torch/include/ATen/ops/hardtanh_backward_cpu_dispatch.h,sha256=pxDY8M14gyrMXFoGv8NvTnWBylP4_4cyMPsiwQ03W8w,1213 
+torch/include/ATen/ops/hardtanh_backward_cuda_dispatch.h,sha256=WYNmpwXbj1pzw7YRqR33PSEBc1hOhJNy9ZAizDEXVUM,1215 +torch/include/ATen/ops/hardtanh_backward_native.h,sha256=sUEZUA53CgZyJavGKCYKhIn0IfUeSK5OlcE5zlX9ZPo,781 +torch/include/ATen/ops/hardtanh_backward_ops.h,sha256=BjX6OpAPQeSSGGBZukZIf6fj-2CyYhaqrepfZj1gMrA,2271 +torch/include/ATen/ops/hardtanh_cpu_dispatch.h,sha256=W0Mk2X1984cMzfClzFh0BdA8u4GdG0a3Z3HJoTiZNPQ,1201 +torch/include/ATen/ops/hardtanh_cuda_dispatch.h,sha256=9gSZbr1m1wq_C8llrqZky-mroBVSqcjfpnI7V1GVC_M,1203 +torch/include/ATen/ops/hardtanh_meta_dispatch.h,sha256=_yVg8Vb6NQ3L513GbhWXKydJVpvtpEoXmfxEiXwJKg8,805 +torch/include/ATen/ops/hardtanh_native.h,sha256=H5wqJeiug4koyrwhDB-NMwOoMWgtwlYA3mzAXC2Zl9I,1224 +torch/include/ATen/ops/hardtanh_ops.h,sha256=xPcC7izmt4Y06tLIcWNBDdTzNMGo1IWEpG8Y0xR4MPo,2682 +torch/include/ATen/ops/heaviside.h,sha256=-u6-fHflPm2qBVYmVJpbT_RjilmQp3ON7JpDqwGEWWU,1247 +torch/include/ATen/ops/heaviside_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ffcLkx0EQ1pg-xZrdIoKdnN_gtqo0rtBySdFJ8qBq3k,925 +torch/include/ATen/ops/heaviside_cpu_dispatch.h,sha256=GwDhKCf6-TKGk6PC82mZmyS73EYPGADDoJg4TpLZFNU,1074 +torch/include/ATen/ops/heaviside_cuda_dispatch.h,sha256=4e96-T_08QVWqyoRVsMvG-IeRDtVHad-30I8K5TBEZw,1076 +torch/include/ATen/ops/heaviside_meta.h,sha256=FASNhbFmy2ru2BLxmQAAPid1vi885aGZ9-UoSe7BI1Y,624 +torch/include/ATen/ops/heaviside_meta_dispatch.h,sha256=kMdeEtH57XT7xMnV0C18-xnZ0hqtW71Bb0aQAKSROJ0,1076 +torch/include/ATen/ops/heaviside_native.h,sha256=NpMEJDpNvzGox4egsXUu97Ni0IkXqgy4ILXPCvq5s_w,655 +torch/include/ATen/ops/heaviside_ops.h,sha256=_QV80mcAR60ry1xnTvXaRmOqOBhlDwEEWOuR9xdT49I,2391 +torch/include/ATen/ops/hinge_embedding_loss.h,sha256=IauVUQK1koqOYXV_r-TCZkmuyZT8SV_O9K_zz7C-tu0,889 +torch/include/ATen/ops/hinge_embedding_loss_compositeimplicitautograd_dispatch.h,sha256=agNle1fnmgUwjzFAs7WzG-T_0g-aIznJglvL3H9o89k,886 +torch/include/ATen/ops/hinge_embedding_loss_native.h,sha256=eJTzshZ3HtFplNePUGKgs3tzqiQOZ5Z4knGtSTZD8lk,596 +torch/include/ATen/ops/hinge_embedding_loss_ops.h,sha256=elK05QFglWEwmbSmi_ZS2JOjMMoXjZrCKESVgXg570E,1240 +torch/include/ATen/ops/histc.h,sha256=bQaaNjsUCkwNsy5sBKxPqrmqiL6BoF6YKx9n8N-Ve4M,1433 +torch/include/ATen/ops/histc_cpu_dispatch.h,sha256=2cZA18gg8jb07USKHY9Uq8kqhPrMQDlcBenmot09hHc,1101 +torch/include/ATen/ops/histc_cuda_dispatch.h,sha256=1iASw3tSWy1gEunCA2lPTrxihLgYyjaB0Mj6iUIhmrM,1103 +torch/include/ATen/ops/histc_native.h,sha256=oSd5lQqCrSxwF-yMFcAqlA1lAItvTFquMisZFjLQVLo,1002 +torch/include/ATen/ops/histc_ops.h,sha256=w-asb3jTt0GDOR0QSwx36jWZo5RzfNPATkOYBiWAhzE,2011 +torch/include/ATen/ops/histogram.h,sha256=eX6FE4tdTTNO22gcIXFx-ztI1xtBvHI_Xjb8j7Lrhq0,3570 +torch/include/ATen/ops/histogram_cpu_dispatch.h,sha256=EyXXtqtXXiLrChH9atpaNN9XegUc3519lu3h8wJRxBE,2066 +torch/include/ATen/ops/histogram_native.h,sha256=ego-9UyGampskzONPcS73XvyTmjPCh-Bgtlo87VQeRw,1313 +torch/include/ATen/ops/histogram_ops.h,sha256=TQHIx2lubBGaCSOsDQWfeGAgbEvLe39WZojACgmjr-E,4815 +torch/include/ATen/ops/histogramdd.h,sha256=Lu8moG7ZPezAMlvKlacVEglO_TdcFRVnmtw1qaz6ad4,2000 +torch/include/ATen/ops/histogramdd_compositeimplicitautograd_dispatch.h,sha256=neenWuCY0QNwJ0w_3p_9xDaUB8Z7NgPaISL2-eWMrIk,1470 +torch/include/ATen/ops/histogramdd_native.h,sha256=cDbcGb1OFe2HaJhK_-K6Uxebe_qIgSNdaszwLloUecI,1180 +torch/include/ATen/ops/histogramdd_ops.h,sha256=8gT3ZtPEED-G9AZO1XNkJHlavHN1PTa9hZBHoMbXga0,3810 +torch/include/ATen/ops/hsplit.h,sha256=PCNyCZnaL8VhDyxmNM3ZoNxcMyYAgC81a6BZyxnJL0E,975 
+torch/include/ATen/ops/hsplit_compositeimplicitautograd_dispatch.h,sha256=oZSSWZKuLwrASSeEZbOpxi6KgnYIX5k4C9Z52ChcXC4,915 +torch/include/ATen/ops/hsplit_native.h,sha256=HL6pJaREzQpTePbSUXoUSLXGfjeVNiQW7Kr9A6xFOr8,625 +torch/include/ATen/ops/hsplit_ops.h,sha256=HveuyPqeHCqw-n4Tb5cXwJwc5GGqfnkZVztNCxo2BEQ,1782 +torch/include/ATen/ops/hspmm.h,sha256=6_IOGYGSpb7Xn_QTy9d87gYZEHTsLqxnckJzafb-e7I,1189 +torch/include/ATen/ops/hspmm_native.h,sha256=Qqq5KYgWNPI5I_eMQFnOntdzsHVT_gVlU5bwS8KDfrY,852 +torch/include/ATen/ops/hspmm_ops.h,sha256=35pJ1-0PBtPU0L6KjHydanX-RLI2SV6F6Q0NS1pfnAw,1747 +torch/include/ATen/ops/hstack.h,sha256=smNRL5zXoutPQHntmPhy9yRE6NHNf6jYVnnpDvIxBTQ,1088 +torch/include/ATen/ops/hstack_compositeimplicitautograd_dispatch.h,sha256=qSvFN-1m3i869NssaHyqfGUBaTy_5EGL_9aXfwPt1L4,943 +torch/include/ATen/ops/hstack_native.h,sha256=XIPnHW27VGKdF3KjRwy5N2Gism6GANDoDNI6oKzJrNc,574 +torch/include/ATen/ops/hstack_ops.h,sha256=k8xTNGj0GVXnumIreTvIQeEs88LMlpW8hws34Eo2kCY,1585 +torch/include/ATen/ops/huber_loss.h,sha256=PJX46rFXbvNdkcOimCrRr-yRBGDf67FSvc9Ea61fVno,1569 +torch/include/ATen/ops/huber_loss_backward.h,sha256=4_3p4SFCJfIgTd1TxUvQHA-djEik1lyIf8A-jZgVBDE,1821 +torch/include/ATen/ops/huber_loss_backward_compositeexplicitautograd_dispatch.h,sha256=KZrJFso0Aj7ORNXKvRiMMUZ9IFFw37yDAI39LrGEBgA,892 +torch/include/ATen/ops/huber_loss_backward_cpu_dispatch.h,sha256=1FeHpsJ5inbeIQm23syCWpDCUDX8mTBQLiJgLxokQW8,1071 +torch/include/ATen/ops/huber_loss_backward_cuda_dispatch.h,sha256=7R0oAumeJHEie9P6wgPwcMOQsBHDBQQQY6ghPQNY5q4,1073 +torch/include/ATen/ops/huber_loss_backward_native.h,sha256=1EshJnvvrcAPqC5tuIMyrVIxRflFHKeSIPZKt6XP-Gg,793 +torch/include/ATen/ops/huber_loss_backward_ops.h,sha256=hMknr_32LRf_rajocEPX-ag9qF6a7rEArMJNNK3gmVo,2294 +torch/include/ATen/ops/huber_loss_cpu_dispatch.h,sha256=JAdOXOxXF8WPuLYkqwkUBvGfKgiudAn4zazFwF39pe8,1142 +torch/include/ATen/ops/huber_loss_cuda_dispatch.h,sha256=bbIO69bT6cu_kLWrEdjmS2YdXpGROOTaITX9DaNjyJE,1144 +torch/include/ATen/ops/huber_loss_native.h,sha256=gODrsFtQ8ENp5yk7pSsLDXPWhTvk3tvORaPJ3bZQ2RU,728 +torch/include/ATen/ops/huber_loss_ops.h,sha256=q7oBzospKT-F-2jCbalPg1U4HyfYoDVH45d92tMDi3U,2029 +torch/include/ATen/ops/hypot.h,sha256=u0RqAZz88SoM2YAseBhJkmrgtEWv8S72iHsa3TM3Hf8,1198 +torch/include/ATen/ops/hypot_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ImEKpZuQpVoU7VvCs0hurrZaakFMeCuUwqVtAmfNnfs,915 +torch/include/ATen/ops/hypot_cpu_dispatch.h,sha256=BuiiXcEnAC4pWxe7vENddsI2aO_rzqb_0Vc153g-rFw,1054 +torch/include/ATen/ops/hypot_cuda_dispatch.h,sha256=KT78XiNkPT3uCLzW7tqeT6Lkl_7n6kKF91VQ-uK5gVE,1056 +torch/include/ATen/ops/hypot_meta.h,sha256=3py8YfZegQFrfzCmOqtVQc9X8GW6fixaBnv0qSYtris,619 +torch/include/ATen/ops/hypot_meta_dispatch.h,sha256=xUWGOTGtli5apfCgFyQKxxC6HmwEZgrXkvaQEroL2xE,1056 +torch/include/ATen/ops/hypot_native.h,sha256=cdEqSu4vGxcmeNMZAxAMlpkK2KgxGPb3X1LDNo1QLNU,642 +torch/include/ATen/ops/hypot_ops.h,sha256=WDWQ6y9S6dDdLDp3NuLjFEC32_1jJRHNuZaCsvFfqa0,2346 +torch/include/ATen/ops/i0.h,sha256=vtmHT18Y5atuViY0Y-2dx8Rixjeo-qCpx8Ud4wYVskM,1162 +torch/include/ATen/ops/i0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=nJmSu1dFd7mmJIimRaMgg3pkd2jTn7on2qhAeJP9zPw,857 +torch/include/ATen/ops/i0_cpu_dispatch.h,sha256=kPYD60G30UEDrOBW6zfEgChn4ykcTsmxFMdkX5cfSUA,938 +torch/include/ATen/ops/i0_cuda_dispatch.h,sha256=60zO0yYcSmghHrvH3AH00J-GL2Xc_Q3bby0oE7_N8FQ,940 +torch/include/ATen/ops/i0_meta.h,sha256=BQYqdflNOnqHQGI3O7zFP4TQmDiFW_4bVrR5aS6RS1Y,590 
+torch/include/ATen/ops/i0_meta_dispatch.h,sha256=8XtZuuqF3bPS5MpRq_Kf1SpqGBTYRDoEBUtxiUwoV1w,940 +torch/include/ATen/ops/i0_native.h,sha256=CcQAbXQgWxrmIM_jV2WIfWynPhKDj9fENFzsu6_vSNo,607 +torch/include/ATen/ops/i0_ops.h,sha256=smnUJnoRQQHl7Z5pJ-OwSuXg-5sakHjsUoSCUjsxNBw,2061 +torch/include/ATen/ops/igamma.h,sha256=sWFrySL4XF6ZJl8F0vZO1uAxNGj3Ea-SIHc5sc27p4k,1208 +torch/include/ATen/ops/igamma_compositeexplicitautogradnonfunctional_dispatch.h,sha256=blictQrEgYJY2l68yB3O2IsDpEQL-jv8smGy4DR1BWw,917 +torch/include/ATen/ops/igamma_cpu_dispatch.h,sha256=3prEsXehK_I-EhXyfINYDkbgrSrDElv_Lqkk0A6L4ZQ,1058 +torch/include/ATen/ops/igamma_cuda_dispatch.h,sha256=Fa1GZ_hJyyWnaMQM0eYoufjcwaxBO1aHBYHW-Mn5wEo,1060 +torch/include/ATen/ops/igamma_meta.h,sha256=6XAJz2ZbOQDhmg_TvPBOtLXnxKY01Ds0FA5RLWP7kjU,620 +torch/include/ATen/ops/igamma_meta_dispatch.h,sha256=3peisq9YwTqk1u2F1FZeQ6CHWO4c2cSkCrhFxxLUsz0,1060 +torch/include/ATen/ops/igamma_native.h,sha256=tM_jX_9uCvBmA0VZ__rXJZ5trvA_fOCm4qNbiYpKSBM,645 +torch/include/ATen/ops/igamma_ops.h,sha256=RipkatruxAXKn8Msko0BTOLDPILU-tmywP3xJCXQst4,2355 +torch/include/ATen/ops/igammac.h,sha256=iFY-80b8K-SPgav3hS51-H-1dQPrJ5dC84dAXrH-ZyY,1218 +torch/include/ATen/ops/igammac_compositeexplicitautogradnonfunctional_dispatch.h,sha256=yPvQMwlUk3N1KH1LsrU1ufGGFlYyn6EdLox1wuD05ZA,919 +torch/include/ATen/ops/igammac_cpu_dispatch.h,sha256=mpFiOt45ohKUw3VWZ4ld2Rhvn3yRR79-4Aewcb-i71s,1062 +torch/include/ATen/ops/igammac_cuda_dispatch.h,sha256=k1J3zoJuazqDWzvDeJRrvIwTrJkQo0TQ2qCbm4NQCBQ,1064 +torch/include/ATen/ops/igammac_meta.h,sha256=yW3UlSarkw9DvA31_zwEqDFCKliu-_g0miGUjUi2plA,621 +torch/include/ATen/ops/igammac_meta_dispatch.h,sha256=ALu7tgpPFY_WCEz0fJcDRUkUotz3xY91tDXkakrBtTw,1064 +torch/include/ATen/ops/igammac_native.h,sha256=yu9E0d_JfeZ4h0sAU1_vGElgdBh3_K6ZbZwnbA-ZQTk,648 +torch/include/ATen/ops/igammac_ops.h,sha256=Bwq7okmnWWITSHbTD1q-tfGMraKtkXz5P1vaFZgrdZ8,2364 +torch/include/ATen/ops/im2col.h,sha256=4rjGzGRc4pOeBCscm4vzDZKSgidHHtiAqXBNNNLu7pE,1703 +torch/include/ATen/ops/im2col_cpu_dispatch.h,sha256=WGWf3jSbWoIzqPa6U2RZxKnaBqccqliKatc1BaoHIZw,1214 +torch/include/ATen/ops/im2col_cuda_dispatch.h,sha256=zaifIoRLTE9-aoRjDbaKwvWy9E-Tjw5SYFm9tTJAcyM,1216 +torch/include/ATen/ops/im2col_native.h,sha256=F9lFQdw9AZmcZUnNRzr0GUm7C86HrmN77RJ1WLzvV8k,1144 +torch/include/ATen/ops/im2col_ops.h,sha256=NB90pYbzjZEz9YyTuyxdS87gOvHEsruunJi98HUtwvE,2275 +torch/include/ATen/ops/imag.h,sha256=b8xS3qXlCKrPgrk7wCxP5qjcjn0IkjpYk3Za5vJ4ySM,666 +torch/include/ATen/ops/imag_compositeimplicitautograd_dispatch.h,sha256=oeKDm1PHOSuVHdBlWkS4FjnnzcZnnLeHyYsS_Rsh7gM,785 +torch/include/ATen/ops/imag_native.h,sha256=E3B4KyUsquLsYeJSH-KBD5H1E7ypyXTbbU0fhV7tOHw,495 +torch/include/ATen/ops/imag_ops.h,sha256=idZu-1cE2y1fFzPm6dop9IADp0NVDEQYJ71km0WEgLE,986 +torch/include/ATen/ops/index.h,sha256=6aX_2iMI0lrL0FLD-KAK8l8T5uh3p1NqEq62YCJ9T_k,1351 +torch/include/ATen/ops/index_add.h,sha256=HWJxMZvkb_uuB5P-d36XCDcnHf26L1jvF7d52KkM5so,1974 +torch/include/ATen/ops/index_add_compositeexplicitautogradnonfunctional_dispatch.h,sha256=-DDi3l7afTs4cs86AMCZhDmQFnm5wCIeZ3fEndCb74Q,1059 +torch/include/ATen/ops/index_add_compositeimplicitautograd_dispatch.h,sha256=vTKHKuoycVfKBCT6adnbo86D37ZA4Kl5Hq1MTw-6IyQ,888 +torch/include/ATen/ops/index_add_cpu_dispatch.h,sha256=fHwUqAQb7SbL5UEPg4YdMNRHKoYRgoIilydUmAoiNnE,1340 +torch/include/ATen/ops/index_add_cuda_dispatch.h,sha256=m3rp7URb41pvtNq9wAblMNx6AMb22_jlZG6xCZ1YMp4,1342 +torch/include/ATen/ops/index_add_meta.h,sha256=aGdOhKQl6_ayfoEU-zTRvkP6ZB6fx8vdB1v78tyzK-o,1143 
+torch/include/ATen/ops/index_add_meta_dispatch.h,sha256=NxZ8hsr27GYYSZ9ENFK8jZT258I2losVrXzMPEN7izs,1342 +torch/include/ATen/ops/index_add_native.h,sha256=ewVqaOHxle8oO5d-g7ovqIEJrc241FJ-P4T84wYvG5g,1127 +torch/include/ATen/ops/index_add_ops.h,sha256=9lXw9ypp57T10Ycdk-nS7pSRbx-tVolNTQpafVPV56U,3920 +torch/include/ATen/ops/index_compositeexplicitautogradnonfunctional_dispatch.h,sha256=hKY-OjEImjWXZ9ohTbQlcDDXwMh3Cb2l0WF0urS5edQ,868 +torch/include/ATen/ops/index_copy.h,sha256=fRCj1Yk_sOUWJoBwNr2QSp7xo70_juzrFfekJvTU888,1779 +torch/include/ATen/ops/index_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=-K5w2RPLPCr-Ir6xzS3F0DZyLtg_KOTRi50CbifyVeY,1005 +torch/include/ATen/ops/index_copy_compositeimplicitautograd_dispatch.h,sha256=BKg9qkoKe871tOp9YnObLl-GGN6rO1aaD-GQNvxyOxw,987 +torch/include/ATen/ops/index_copy_cpu_dispatch.h,sha256=HuhTOpm7tndQLHowfJaQcZ5nXipPbKmKsMR0rAqvNgY,1234 +torch/include/ATen/ops/index_copy_cuda_dispatch.h,sha256=PV30xZek9NeXlAvwinAZJtqX94YVWftc7XA3ELnfifY,1236 +torch/include/ATen/ops/index_copy_meta.h,sha256=4_YjRHW4Z_IN6EKUChvqgcIG1yqrMTPRfEJENYBgH_M,1118 +torch/include/ATen/ops/index_copy_meta_dispatch.h,sha256=S56Pu1KvCEyVQGBdgP87OpPnt58ifvZBAPU1sVP7-K8,1236 +torch/include/ATen/ops/index_copy_native.h,sha256=XRkOxXzXLpQSvJ4caJrLCnKhj-tMK_Fc_9LpbgLemwU,952 +torch/include/ATen/ops/index_copy_ops.h,sha256=6vsEyiqUzUgIbw9mCdLlsjpUf3LHJjMGAYr1aYq04rY,4351 +torch/include/ATen/ops/index_cpu_dispatch.h,sha256=nFrd6OnHRxlE0JKFzntjN6aOnzpGabQbTXKrXOasfRo,1067 +torch/include/ATen/ops/index_cuda_dispatch.h,sha256=lBT0b23OgMe7Q8MbNAHlkfphk-S8B3yr3iPBcecNqfU,1069 +torch/include/ATen/ops/index_fill.h,sha256=gGCoVCbYC_8OXQlh2qXEIE4KCyrDeB3D0Jn84Oqsqs0,3162 +torch/include/ATen/ops/index_fill_compositeexplicitautograd_dispatch.h,sha256=oJF6Uufn0Unps-eI0qnfIVcb0lLMPGo0wo0KDwikng4,1574 +torch/include/ATen/ops/index_fill_compositeimplicitautograd_dispatch.h,sha256=7VA3VD9kOhiY41gLMCpBrmv1oLzimFTuEVYNCITzOXQ,1238 +torch/include/ATen/ops/index_fill_cpu_dispatch.h,sha256=ycHSpTKzFTtmi5eE5rXHZNj9hPZbGOp__ygtfmtC9PE,930 +torch/include/ATen/ops/index_fill_cuda_dispatch.h,sha256=zZhLeI99vBu8UTEgNwkw65cVTTyw7lKdD6biHWgzQXE,932 +torch/include/ATen/ops/index_fill_meta_dispatch.h,sha256=G2AW_fXMC1tx0XsmB0WStJedCDoIXP_YXxhuCmqBpjQ,932 +torch/include/ATen/ops/index_fill_native.h,sha256=ZTprWxlt41-edoazT5h8P58TthmpDAXoTYeTR3hiqJ0,1756 +torch/include/ATen/ops/index_fill_ops.h,sha256=yO2H7FlP3GOdRI5chvW_8iri8dOCYUICj63LaHG6u1I,8475 +torch/include/ATen/ops/index_meta.h,sha256=3GhrIi7wj2DkBOV9f7hGg5nSg59tvVTDCMb85vPEmz4,1520 +torch/include/ATen/ops/index_meta_dispatch.h,sha256=rJ3a_sMVgNWfmrRxJh9jq0uAhOJK-kIscr2-dP2zDw0,1069 +torch/include/ATen/ops/index_native.h,sha256=vnnKWlAAhdEVDuPeOBQO1SASZayv7VWk7U3hRoDM4UU,787 +torch/include/ATen/ops/index_ops.h,sha256=buaI8etFpMrxgbwXSo_IkLHJUgxOE1fvIdgvzkHz2-I,1980 +torch/include/ATen/ops/index_put.h,sha256=xNA-O8h2fJbxBPq4a8goVXwrwqMAoPEYASHBs6f4bIE,2015 +torch/include/ATen/ops/index_put_compositeexplicitautograd_dispatch.h,sha256=OZKf3IY0Y-EqhJq81da3g7VgoFkMkougaA8qxt6sCuA,1428 +torch/include/ATen/ops/index_put_native.h,sha256=4Z0UQDX_yyj16dkX70hw_rJxIBLM46mVXVvFa7S2GqA,949 +torch/include/ATen/ops/index_put_ops.h,sha256=g1OUoPY8evBx-_bDHTNzTn3spLbBcFP4--Oh4_Mb1bQ,3117 +torch/include/ATen/ops/index_reduce.h,sha256=mRFW3fNvF7TmcUkVG8-rZgSHi3MA88Zqd0LbZtlshKA,1818 +torch/include/ATen/ops/index_reduce_compositeexplicitautogradnonfunctional_dispatch.h,sha256=aRUyQZaIwLmFbGH6J0UNqwVj7jxnO-Nvl-AiQ5qpQZY,1107 
+torch/include/ATen/ops/index_reduce_cpu_dispatch.h,sha256=Bdw8N7MaqlGH_5Gu61XR8Ll5luitxVzGfsiQX2cRq3s,1433 +torch/include/ATen/ops/index_reduce_cuda_dispatch.h,sha256=NUxOrIMQLlcqEysIERuHq12OCPPFDwhcekbFmrnFYW4,1435 +torch/include/ATen/ops/index_reduce_meta.h,sha256=ePuggPI2pfii-dibazSUPd-Xk4uKKCQimHZ8zVotgDo,1164 +torch/include/ATen/ops/index_reduce_meta_dispatch.h,sha256=lnbG__ivHVCRN8yOATqc6qWY7M908HD_-Bt3FuOrSYU,1435 +torch/include/ATen/ops/index_reduce_native.h,sha256=F3hwZhKH70QhvKQ3yjmk2QV0qheaiSIAnzkEHV_MUPQ,1022 +torch/include/ATen/ops/index_reduce_ops.h,sha256=6Xg4kHL0f1fM1E8n0b5Gr8lbIgNGbo_pyAbn_qrea-I,3258 +torch/include/ATen/ops/index_select.h,sha256=2-Mu9LMxwcSSfcYEr_Z560YpR3hlwz112wBEI_XQHmQ,2236 +torch/include/ATen/ops/index_select_backward.h,sha256=JGRL9qrknrCN-pDAAD5Xnb4skbsptJdebeIl-ku84pE,1896 +torch/include/ATen/ops/index_select_backward_compositeimplicitautograd_dispatch.h,sha256=X2aW_PO8IXcs8bP-qiDKKrRmMGmotQE_R5QDzN3cb2U,1017 +torch/include/ATen/ops/index_select_backward_native.h,sha256=iOld3Wan-bq1jmw-ijDi45SvYPuDiOqiqBvB44d-nhw,590 +torch/include/ATen/ops/index_select_backward_ops.h,sha256=OL62xLMzpEOA2YZldUFoMhTDaA7C-8kvueDMUq5SE_Y,1267 +torch/include/ATen/ops/index_select_compositeimplicitautograd_dispatch.h,sha256=pEKqmFPTAdPKA9xWKtdCuOUBUftSRDvDnE_oC21NHQ4,1093 +torch/include/ATen/ops/index_select_cpu_dispatch.h,sha256=kMWS3kuFjYUHLnQA92e_5jWN-bXMOrE-yd5yX-DGoEU,1037 +torch/include/ATen/ops/index_select_cuda_dispatch.h,sha256=eVr5d1UnMfQyFyyeLLbdZ9zl8BI_odKDKvnYT_zWU9g,1039 +torch/include/ATen/ops/index_select_native.h,sha256=fwLUav3pF8cQUE-GVIYgFot__rJuj2FAFsu5RhnpcEs,1595 +torch/include/ATen/ops/index_select_ops.h,sha256=D1OeS-nA0Txha4fmMkbuQ2duYjoKV9CpvHXqhLAlIVA,3370 +torch/include/ATen/ops/indices.h,sha256=lrLxh7wmxnSjJnVIixLOChyWIlhByOz3vkX5DUdZeO8,531 +torch/include/ATen/ops/indices_compositeexplicitautograd_dispatch.h,sha256=yyjvm24Lk1DpHyIhQ9S1zEFgGgpeCFWYF0U9GFNnCDM,788 +torch/include/ATen/ops/indices_copy.h,sha256=IMg7ef8MJbL-mY3t6MJzFMnueYZ0gVicSfyemBteKbQ,1127 +torch/include/ATen/ops/indices_copy_compositeexplicitautograd_dispatch.h,sha256=MYM3k4HTy0ZNyKI6krEhghiaQAwbOjmA_B8ywlA_6Os,903 +torch/include/ATen/ops/indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JoSTyAUblo1T5brcglprg2o04pqtqsg6wsgkCdAZt4w,819 +torch/include/ATen/ops/indices_copy_native.h,sha256=BFz_ZlYXaJ2XEvxaa5dpzokNEGBcvvq4GzfCkGF17O4,588 +torch/include/ATen/ops/indices_copy_ops.h,sha256=5RJr2I7EQu66UcGkL8cWmTqBRQNBSMmLD83ifQw5co0,1623 +torch/include/ATen/ops/indices_native.h,sha256=Av2H5KnqIE9odH6POMxOKf0QMg1Rrxsf4Oqcbo0v_08,569 +torch/include/ATen/ops/indices_ops.h,sha256=JhQD33wS1_cdVACDPdz9gK8elLLS7GIN8WUjnI_JVXo,995 +torch/include/ATen/ops/infinitely_differentiable_gelu_backward.h,sha256=XyNvjIqi79Ql_xWhyzzd3txAUpNEIgloUZ2-ZmzbxoI,844 +torch/include/ATen/ops/infinitely_differentiable_gelu_backward_compositeimplicitautograd_dispatch.h,sha256=5vjj1huJjb2eA4KFKVD8U7C1nqOxUOOwQzdxhYpl2nw,845 +torch/include/ATen/ops/infinitely_differentiable_gelu_backward_native.h,sha256=Nyrhv5F6ZKZdY3KyyInqLWYjsdRLwrYJdzNDUHa7HrE,555 +torch/include/ATen/ops/infinitely_differentiable_gelu_backward_ops.h,sha256=lP89uG3ZoORK-GDy-oKZiBtdryw8upOLQYSq0IPjvuQ,1168 +torch/include/ATen/ops/inner.h,sha256=-VV_yvYZPb27LA4y1KNSAtXPH_Pt9ZtMa2kWgXNED7s,1198 +torch/include/ATen/ops/inner_compositeimplicitautograd_dispatch.h,sha256=0B-jJ2uHNiYA4kOhMxOEOkXQPZwK7vByqoSE3sqFKgo,1021 +torch/include/ATen/ops/inner_native.h,sha256=fCUqkIKiZfSNPB2t__5Ctph3z0_Xcz6ORAZqGezTCjs,626 
+torch/include/ATen/ops/inner_ops.h,sha256=t6AStw5xazx3OBoYf0CB9Y6DV3F-mW_8aVTsKXETz4E,1753 +torch/include/ATen/ops/instance_norm.h,sha256=7_pnEl3FCJpvkDyAzLKEj-t6y_Ytz4VbM0SehMdDf44,1185 +torch/include/ATen/ops/instance_norm_compositeimplicitautograd_dispatch.h,sha256=Oy5ijMxPTsO_rOcxGMXBEnPLYEdOGddZwe3M8caUxok,1051 +torch/include/ATen/ops/instance_norm_native.h,sha256=hhLMx-bwETPcHbsOvoE5CbpFf-eWqATTXR-Ulafx2dg,761 +torch/include/ATen/ops/instance_norm_ops.h,sha256=AJqnhXSWeaArEERk0zYuXf_xPzIA2XpL-mGSk82oz8c,1840 +torch/include/ATen/ops/int_repr.h,sha256=hXdguibBscyrV5_drEjdf8mofU6uh7eiclPPtkD1-0U,1087 +torch/include/ATen/ops/int_repr_compositeexplicitautograd_dispatch.h,sha256=66wwGFQC4qnvTmrpJpaIE2iYOCc-ngJhlpKFuRAR-QY,895 +torch/include/ATen/ops/int_repr_native.h,sha256=eTmOd4A6C6d4PeClVUFNykzoj2CWZMrY9XJuO0DqFf4,666 +torch/include/ATen/ops/int_repr_ops.h,sha256=XbRK0_GHC8zIKrupRlSVnvLWvRT0bcwUq-mNgQFz2kg,1599 +torch/include/ATen/ops/inverse.h,sha256=1-ufa7sdTSHm9f34-s4ShJURsUAGuulX5NEQQ7uwjds,1077 +torch/include/ATen/ops/inverse_compositeimplicitautograd_dispatch.h,sha256=D4JXeREsciBavXqHNcGOGnhc6x5WILWh24_pukTXpI8,949 +torch/include/ATen/ops/inverse_native.h,sha256=qGTrrZjHgpy1utKmTbq2o6IqE5YPAq-2MmGcjMJuRxI,578 +torch/include/ATen/ops/inverse_ops.h,sha256=HGBu0SnP2V9HTcvXLthJAOvJQU1hnYBJaDesH2PAz2o,1593 +torch/include/ATen/ops/is_coalesced.h,sha256=wgZI8dNYq1OqfoAefwQSTSk-tPRSElmP8YmMhbcfNPc,536 +torch/include/ATen/ops/is_coalesced_compositeexplicitautograd_dispatch.h,sha256=AUTixpEYNI1FEJVNh4dVjfWdIsZtc7vWyDXOdA2RQ-4,787 +torch/include/ATen/ops/is_coalesced_native.h,sha256=wR9us3mfu-HJe3RvTBbKN6X23qFeTRMo66YrA3zA_-g,567 +torch/include/ATen/ops/is_coalesced_ops.h,sha256=PRmBijU__wE00s8Vb0NRuGd5K2LNZgk3JlWM0DzHUVw,984 +torch/include/ATen/ops/is_complex.h,sha256=DxWNYA2TilDoORIQyn5szdWbP6obXWPdLzE68dJApJw,687 +torch/include/ATen/ops/is_complex_compositeimplicitautograd_dispatch.h,sha256=h95xAY2QAfv_xxkIp-lxxTpcxS2XpSXnzepeDDV6mqI,785 +torch/include/ATen/ops/is_complex_native.h,sha256=DkDC-jwwvKOkLb6287gEohkio2vtaqL6bNWQ1IBjZNc,495 +torch/include/ATen/ops/is_complex_ops.h,sha256=N5xyD0wkbx9GLYqZggHAtHMNCsU9A2258z_syZJ4ZZw,978 +torch/include/ATen/ops/is_conj.h,sha256=I9S9TNQQWjUeaCU81WOZkgXbhiXYwf5me6P8zA1IGb0,675 +torch/include/ATen/ops/is_conj_compositeimplicitautograd_dispatch.h,sha256=MaeV7CQJq98HKiLxpdtpxbGicGz8wHq0GBxQqJgzLkc,782 +torch/include/ATen/ops/is_conj_native.h,sha256=L7Qg4DvAUBasbW-PlhSO2_yAL6lbhYsaFTcwZRBVBaE,492 +torch/include/ATen/ops/is_conj_ops.h,sha256=4QY-SN--TjuNvhQMuQTf5N2wqbUSeQo2uVCU4Smiux0,969 +torch/include/ATen/ops/is_distributed.h,sha256=p29xOXfKS6OAOeVPRDkVNHIxqhz4Zdj0hMHOcqKiKa8,692 +torch/include/ATen/ops/is_distributed_compositeimplicitautograd_dispatch.h,sha256=nN2r1wNg6KT8E2SNn8bQz2MHwJSY4v6RjtJkqJzTtjU,789 +torch/include/ATen/ops/is_distributed_native.h,sha256=mpfX0PI6EBdC5h1gret2-_5eHd15ZcwoVrXM26iuT8c,499 +torch/include/ATen/ops/is_distributed_ops.h,sha256=rzzQxQ6n3Gh0hdotQxSEoujtupXdhgAMxbY5fJsmyLk,990 +torch/include/ATen/ops/is_floating_point.h,sha256=6PUXs2lB4Pv_8oTn4fPHbapSW5OeTGyDagxtipx8hiA,715 +torch/include/ATen/ops/is_floating_point_compositeimplicitautograd_dispatch.h,sha256=XUXBogNWTpVJCuECya8qGxThsnZ9AA5ts_KQNAWMsUw,792 +torch/include/ATen/ops/is_floating_point_native.h,sha256=9WUYXdG1uqBfOZOopEGwSMHEQkqL-3WJdqr1lniGyI8,502 +torch/include/ATen/ops/is_floating_point_ops.h,sha256=7_gLt-tnLhBfKKso7ytB28Mjsrq0KKdXt4oVU2aRVsY,999 +torch/include/ATen/ops/is_inference.h,sha256=4k4TrUfTSnwYRnlbtinWF-oA5NY6wPYtWdzx2wK6wYg,695 
+torch/include/ATen/ops/is_inference_compositeimplicitautograd_dispatch.h,sha256=yeJwIcQrd_tcxjI8XPx8DzAR0aMebXobg9OTvXEY83I,787 +torch/include/ATen/ops/is_inference_native.h,sha256=R_gTQUcwPtfS-KIyFCx_I6ZPD64pPk_3HkiVaGjNLa0,497 +torch/include/ATen/ops/is_inference_ops.h,sha256=-DnO9g99lAdWVNk5pEtOS9l8GwuhZIXbEJLLWWfJogA,984 +torch/include/ATen/ops/is_leaf.h,sha256=r0ajuYN5undTQmjKH3ISAROc2meRTi9GFxfKvn16w8c,531 +torch/include/ATen/ops/is_leaf_compositeimplicitautograd_dispatch.h,sha256=ee-gSc1O1DGyb5r8kEOrpglYwl1cIsU4XWITtHtGjOg,782 +torch/include/ATen/ops/is_leaf_native.h,sha256=mnoEieI7tUAhr0h0M6Xwo49gpSx4Dlh8w3jbC9UKZ1U,492 +torch/include/ATen/ops/is_leaf_ops.h,sha256=Y14OlT-Y4N3uITJmr1pM64PQQOZe4PKivJAMx6mvmH8,969 +torch/include/ATen/ops/is_neg.h,sha256=5mXEJ2zzM_NSVtGJJxxWVC37w1KGpcRV_JucFJkSgbs,671 +torch/include/ATen/ops/is_neg_compositeimplicitautograd_dispatch.h,sha256=1r1cbHELLtoP_doK8-RzcKa78-aJmnLR7ywHCUs2wk8,781 +torch/include/ATen/ops/is_neg_native.h,sha256=LhGqoy_ZLbY6YgAImL2D7W3rViwfeZE4qpciU-6uAGQ,491 +torch/include/ATen/ops/is_neg_ops.h,sha256=HUWyHFbMGOjq9_t7g2-7KkZ7xDGce-Ww4xO7lSVnWjk,966 +torch/include/ATen/ops/is_nonzero.h,sha256=LOKBuz2a74GbwByxc-iqJo5u-pJMBqB98Lf_hJa5qzg,676 +torch/include/ATen/ops/is_nonzero_compositeimplicitautograd_dispatch.h,sha256=wyDCtYnaB5stV2L_7BujsabeWVoCxt6AXPCmiLsUTpI,785 +torch/include/ATen/ops/is_nonzero_native.h,sha256=MTlQS_yNqSOAk5g1BD8G6eJQ56MkUy3nOWVJ8eGsb3I,495 +torch/include/ATen/ops/is_nonzero_ops.h,sha256=Ko3GXEHQmYxgPavtOcfPS259jag0dS4_E7MbdxWaX-o,978 +torch/include/ATen/ops/is_pinned.h,sha256=dYt6TmdwfhsS0nwdMgtlKpVma6IBHREgyLDZfHMBQoE,533 +torch/include/ATen/ops/is_pinned_compositeexplicitautograd_dispatch.h,sha256=n6yUUBdxRZVU60qY4L7W5ZGwGhaV62GdMOqtQ1nmQvs,835 +torch/include/ATen/ops/is_pinned_native.h,sha256=M6UcKhnwFcRHrFsbL0-ORhvvvoFRuU9980u2KNVu5ao,780 +torch/include/ATen/ops/is_pinned_ops.h,sha256=hhFFs-hUmZw-T-ZxGvbNn3CpoC8cfmpF7XauYoOje6c,1097 +torch/include/ATen/ops/is_same_size.h,sha256=Oc86WTVtTOxUYwSJjHKKKMy000EnoY78DGYUYWVSIIU,731 +torch/include/ATen/ops/is_same_size_compositeexplicitautograd_dispatch.h,sha256=Jh20srMIEdQFmI8EwxsB3g5yMIwzjRidQNJK8KkW1tA,813 +torch/include/ATen/ops/is_same_size_native.h,sha256=cOSw9NZfTtNse27Tm3sQyJQcuwSQsb--V_tmWUXON5o,611 +torch/include/ATen/ops/is_same_size_ops.h,sha256=391ON_2ttRgcQw_tD4sQhD9IUKRz5bgRUZttbB9lALs,1070 +torch/include/ATen/ops/is_set_to.h,sha256=ulV313fQrkBXcITC8rUsHA7GMB-2LkdG5PdF8V1IEd8,533 +torch/include/ATen/ops/is_set_to_cpu_dispatch.h,sha256=RunyQ_cDiXI04RbkoLWS3Gi95KE5RAuHmH1YjmR1Ius,767 +torch/include/ATen/ops/is_set_to_cuda_dispatch.h,sha256=byipTQYtM5S7x8zXyMhH5ZOKvB0DefrtB6wCBEE60ZI,769 +torch/include/ATen/ops/is_set_to_native.h,sha256=CpT_0eouHrJap-cU4TpDzKhXQR5DzHezFKUG5qnfmuc,521 +torch/include/ATen/ops/is_set_to_ops.h,sha256=GiiFl7XvJubOxuwedHK8gagY1kQ62V5Mw8zH8ITktBQ,1064 +torch/include/ATen/ops/is_signed.h,sha256=8-q8p6NZ8sEseszbZbOACNtya78zWNrcV4VHCQYCrAU,683 +torch/include/ATen/ops/is_signed_compositeimplicitautograd_dispatch.h,sha256=BVo0Fs7lO9wbHyPY4TFu0zZoEjVTGa1heB7LvG2xX9E,784 +torch/include/ATen/ops/is_signed_native.h,sha256=Oluv3oNHDl9lIm-AIHcRFHUXG72IEgbRPByN0USuoys,494 +torch/include/ATen/ops/is_signed_ops.h,sha256=ZctUKEU45meM7gqvCGLucepkuscpt5EzS7J1KOxyIdA,975 +torch/include/ATen/ops/is_vulkan_available.h,sha256=qVvssOhZ5spVrwpmEuXulV5aQq-L0st88XkXuKQVdNw,674 +torch/include/ATen/ops/is_vulkan_available_compositeimplicitautograd_dispatch.h,sha256=QMLnEjHrLU9zJ48THu6NBuqghAiyGTWLDUI2oZlo3Cs,771 
+torch/include/ATen/ops/is_vulkan_available_native.h,sha256=En2M8Eq3imC1NQTD-S17jACAqFa8OAcsT-awo-bkpW4,481 +torch/include/ATen/ops/is_vulkan_available_ops.h,sha256=PNxj654rbyq2hiyx4sNUGHwTh06xIJx8iuaZeOFk8vA,928 +torch/include/ATen/ops/isclose.h,sha256=ivXj2nxbp0yj1JwfXt71Rx_o-9nzzBbb3mUTEDhU8ug,860 +torch/include/ATen/ops/isclose_compositeimplicitautograd_dispatch.h,sha256=Mt-VCfkGwFaI2R3Tyox6zscr6ZEBKDivJEL6gh8eVOU,874 +torch/include/ATen/ops/isclose_native.h,sha256=0f4hfYEeJdQr1P5o74s6NnCVYvuaacxhF5tR4b9IoO4,584 +torch/include/ATen/ops/isclose_ops.h,sha256=-uRErayLo8akwZWR-i1fe5fh9FB6ZfO_VsWn8JYSRN0,1239 +torch/include/ATen/ops/isfinite.h,sha256=XEPZ9nPW0dmpzTP3qBOxLXIXhTSfW9JTnhjFvEr6JFg,676 +torch/include/ATen/ops/isfinite_compositeimplicitautograd_dispatch.h,sha256=YteVGAaV_tGhE1VVrYN1QdEsF6ac7prW8q74pbfM4To,789 +torch/include/ATen/ops/isfinite_native.h,sha256=O82KpuyYBABBWppTNf7_ibTD0tBFUozGbBOFenBvUJk,499 +torch/include/ATen/ops/isfinite_ops.h,sha256=fsX6vPKsEKlW8Zn-zMRVC24or1CQzKR7iIoKwoLQGbk,992 +torch/include/ATen/ops/isin.h,sha256=sObREzW504MYqXdBoPtVlwY7jW20-HZh064Cy0q0Xz8,4060 +torch/include/ATen/ops/isin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=mhihbAQijjT1HjXCe7yATpbqxGotvL33a9tlyTyOgUM,1164 +torch/include/ATen/ops/isin_cpu_dispatch.h,sha256=0Dr77LKWPKwx45KOWj8NTFmhC4WAjVVUJwGEHPmdmYQ,2017 +torch/include/ATen/ops/isin_cuda_dispatch.h,sha256=IK4v9WiUpG3l2yqsdYxJDDsPFHxNhk2oFsbxg9XAYsc,2019 +torch/include/ATen/ops/isin_meta.h,sha256=v6WM9OqvpJ5LnoR5dhZvxjmuwIcrxLghanol1CIGtns,1071 +torch/include/ATen/ops/isin_meta_dispatch.h,sha256=j2C4ujo2PXMONTWgbSbWVaQR9wyNEdv9ySjkLhQIeUg,2019 +torch/include/ATen/ops/isin_native.h,sha256=kfHL8T4izLXRdYpI2YrGenEyTs_PYunlRQuCpyIYV3U,1188 +torch/include/ATen/ops/isin_ops.h,sha256=VhQhEjxKwEGzy23OgLas3nszGNwZNQKUr9DBfOojMf4,5491 +torch/include/ATen/ops/isinf.h,sha256=9qSPvQKbOA_nwqOHREGoMNNSRTKL6AeCGB8SUDhGVwk,1057 +torch/include/ATen/ops/isinf_compositeexplicitautograd_dispatch.h,sha256=cwkzs5SqVwNYSPY4o13ihEnMTk78jp_JmapEgf50wNE,943 +torch/include/ATen/ops/isinf_native.h,sha256=MFuc2RtYAjsisNbM8CEe-j0JcVGAm5K9QeSABeBlnKo,833 +torch/include/ATen/ops/isinf_ops.h,sha256=9c0pJyM_KrWgnbjNVcL1a8vGUYkIF5aCyhGNsuRLYkI,1581 +torch/include/ATen/ops/isnan.h,sha256=jwcEsLXWoFrobd82W8N9NgpLgYMS7giLS3_r5VPZbG0,1057 +torch/include/ATen/ops/isnan_compositeexplicitautograd_dispatch.h,sha256=FGKiJsl_rMQ-B58UJU3v-O3NATqjS0F3TBp5_2SXmzA,889 +torch/include/ATen/ops/isnan_cpu_dispatch.h,sha256=-ryf3BAV6EAwhWwNuFWQWoJBDBQHkyZrtig42uA2prk,742 +torch/include/ATen/ops/isnan_cuda_dispatch.h,sha256=bpwoLAc76QtmnWM1sjLeO3O3nCtGEiP82jC5D5HXVyA,744 +torch/include/ATen/ops/isnan_native.h,sha256=8NPTv9NCKj0mKld4Q8qk-771COZD-PH8F1dToeegfn0,767 +torch/include/ATen/ops/isnan_ops.h,sha256=Uz95n6kt-mps1I94RFJy0RsAvoTGrNkYS78qGgCEPb0,1581 +torch/include/ATen/ops/isneginf.h,sha256=As3yfb4QjZtDDz2Nrb5FZWjjl2tZjjSMMkpJzV1hWTY,1087 +torch/include/ATen/ops/isneginf_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8A7MaWqDbDPtdhOEVDGq-oTS-jKBm7nrJqIQY6dvzc4,815 +torch/include/ATen/ops/isneginf_cpu_dispatch.h,sha256=5HHAcx5fqzVs4PedFNh0ktEPtG-8zpPgTU8G_Z-CaVk,908 +torch/include/ATen/ops/isneginf_cuda_dispatch.h,sha256=jfqLhner1jAQkxbdALC_Z3yheq10vddnvHvyXVX5icU,910 +torch/include/ATen/ops/isneginf_meta.h,sha256=SulztcRhyP02Ews5yksFZPTbSLJNUGaDv6tP7kQq-qI,596 +torch/include/ATen/ops/isneginf_meta_dispatch.h,sha256=I9yAX1vF78K6QUBXs-ccn_CNcac59X3Fg_chUZ5e-t4,910 +torch/include/ATen/ops/isneginf_native.h,sha256=nd3y8tb783Ue_mQ5s5mrCLFs_O5PjsRfbiou1ZqxTMM,1007 
+torch/include/ATen/ops/isneginf_ops.h,sha256=kU4AhHRPdEEyeh8iDWofWhr9FHpTIot9rYC73I_tCW8,1599 +torch/include/ATen/ops/isposinf.h,sha256=OyoiCzfpb_LmhDa7nLtUVYEZglbS0jws2KxtpkRVUQw,1087 +torch/include/ATen/ops/isposinf_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rR-wHjAl5VDw5EnvMVfINmnq4n5y2KW-nyVBKDO9AHQ,815 +torch/include/ATen/ops/isposinf_cpu_dispatch.h,sha256=PQUUsi7ST-DHMRYzdEzw6ozY0V517fCQxHUQqREBuCg,908 +torch/include/ATen/ops/isposinf_cuda_dispatch.h,sha256=2lVnyGQkV9RmB9lA-egvIV0JjB581EKhSJklGtdH58E,910 +torch/include/ATen/ops/isposinf_meta.h,sha256=MWLheYVdgZ8_UzmYrsXe8v1PbnyGrztJXj1UMmyPkrA,596 +torch/include/ATen/ops/isposinf_meta_dispatch.h,sha256=Vx55Qbu32ceYrIphuvsEfbdKb1WAZL0Pw3gl3_IlSEc,910 +torch/include/ATen/ops/isposinf_native.h,sha256=VSOHxG0sE3JrFH4bOeHx-MCdqJJ78OVhK9lFSi259eI,1007 +torch/include/ATen/ops/isposinf_ops.h,sha256=GOWhxJjHWnPByF_zuGCHQxExGJaILvJIKrUWhMmJarE,1599 +torch/include/ATen/ops/isreal.h,sha256=V2zD8vNoyBnjZtqDdeszuUhCDjK9c0_CmHqXXq0j6J4,668 +torch/include/ATen/ops/isreal_compositeimplicitautograd_dispatch.h,sha256=M_k2RCQuJJGJRULH5smFKnDoALN31w8aIDStR1O4xOc,787 +torch/include/ATen/ops/isreal_native.h,sha256=RcM6eR7HgYK9bB5RTTfyiUiu24YwNQ_6ntw8UvCypr0,497 +torch/include/ATen/ops/isreal_ops.h,sha256=MQgxys97_C5I-g00_fe2YBIY2AClqVbSGDpHxWP28v8,986 +torch/include/ATen/ops/istft.h,sha256=UXtyXqMFFxJfqxCIssOT52duDeictlu638bMqHH0Wjk,1269 +torch/include/ATen/ops/istft_compositeimplicitautograd_dispatch.h,sha256=57o0_b2QRMkqgCvABv9gRMZBUWyAW29oyB1yiAZ4Tr4,1115 +torch/include/ATen/ops/istft_native.h,sha256=-jUNelHmoHi-aOHwufrrHjqtjrO58EfGL9DczPe_t9c,825 +torch/include/ATen/ops/istft_ops.h,sha256=p3Ey25jN8w1WEGG5FEDo08PWPLkGWr3jOEOXVqAKvfQ,1829 +torch/include/ATen/ops/item.h,sha256=mI99eJlnWv-E4ujFe7dkzK293PwEOxour6uycKvEbIc,528 +torch/include/ATen/ops/item_compositeimplicitautograd_dispatch.h,sha256=1ecGhHGt8_A4dGOEXc1raK9f-VlfWu8JR6S1oSDdAps,785 +torch/include/ATen/ops/item_native.h,sha256=KhFHUrXbVH5XehVLDNT5xJdSyFb7PFFuFB25WBJ_TaM,495 +torch/include/ATen/ops/item_ops.h,sha256=Hy3AuUkNFMlkG7Ja27BjAWEBfVcAZrkPL3zzinx5lh8,980 +torch/include/ATen/ops/kaiser_window.h,sha256=mm0addVjWEMklKJbT5Rp-Gm8sJ_-fTy77_z-5O_GrbA,5084 +torch/include/ATen/ops/kaiser_window_compositeexplicitautograd_dispatch.h,sha256=gQl2k_LSnCL2vG8wfb5Wag59CEuSMy3hnUYeAh4yZp8,2296 +torch/include/ATen/ops/kaiser_window_native.h,sha256=7TE996k4sSk_3MpMxcRQKTcY-dePLtV5JIpnq3wYsxU,1445 +torch/include/ATen/ops/kaiser_window_ops.h,sha256=jKq-qPTYiPyN6bSPUshpcDU2Jl8RWNAkwsjwkFuE6Kc,5749 +torch/include/ATen/ops/kl_div.h,sha256=rD-GKmFonuhZVtEXTe4wIgCj6z9r4EXPDfeCkwxVMnQ,849 +torch/include/ATen/ops/kl_div_compositeimplicitautograd_dispatch.h,sha256=06QVEuF4ONULDQRJfoPwAwbIauVBtM0XsqRgLTAknf8,876 +torch/include/ATen/ops/kl_div_native.h,sha256=MNzMSmjTdLWkLTMq8WHpoaXCQXP3RijZ6pqV26qYnTA,586 +torch/include/ATen/ops/kl_div_ops.h,sha256=pMG8VS_xRlIvaymdH7QZvoJV2mFigi3-ZCJScGXtJJo,1208 +torch/include/ATen/ops/kron.h,sha256=nMJ4bMo6qeWTZmvyRbwuoxwY7wsFYqPMrD655n1t23E,1188 +torch/include/ATen/ops/kron_compositeimplicitautograd_dispatch.h,sha256=qSwnf1Ryj8RiCmKu2Dv_FD5FmaXPW40E7Ivg7eOvI9M,1018 +torch/include/ATen/ops/kron_native.h,sha256=EwHMdnPKWfeg7OdLLeLNSMZnWfFkIvGwRRzUeWOP7c0,624 +torch/include/ATen/ops/kron_ops.h,sha256=GanCvfEXNpoL0W8ZO7t0tSLFdbL_MjaRBSvmyE02Zyc,1747 +torch/include/ATen/ops/kthvalue.h,sha256=YwKMxGz_gD7RK4erfVotRdAN9ZlRnG63FHoGsuqj5PM,9514 +torch/include/ATen/ops/kthvalue_compositeexplicitautograd_dispatch.h,sha256=DC0GWkBJxcdyfOrwqi9uyBBICXaBydzhmQOY1v4j2e4,1001 
+torch/include/ATen/ops/kthvalue_compositeimplicitautograd_dispatch.h,sha256=_l8LLH5JIs9eAJmIVyRgh05Ht528ZcYgkbWMKGSE7d0,1739 +torch/include/ATen/ops/kthvalue_cpu_dispatch.h,sha256=xmZUkxs7RTL-j_skkJQWll_0H4ZyDhTExLiOyfeBY3Q,1414 +torch/include/ATen/ops/kthvalue_cuda_dispatch.h,sha256=8fzBesWe2Q_n_MIxyFJQmAgJzJN2Af6FVfwGjueeBKA,1416 +torch/include/ATen/ops/kthvalue_native.h,sha256=snSV7-QNciFY4vANi3bG-BbKm15VcaFcSyvy6q0HOus,1227 +torch/include/ATen/ops/kthvalue_ops.h,sha256=sVx3KATwU76y1TyqYshoUizIvQA2GP5w61VBGF1bp9I,4015 +torch/include/ATen/ops/l1_loss.h,sha256=CmCVoxUJGLWECYzxZStgtg3psbuGy5oACiqnAlqKF9g,792 +torch/include/ATen/ops/l1_loss_compositeimplicitautograd_dispatch.h,sha256=M5Mp217w5jbEifSlFhPtVg-8t5YKKGxK5RRr0Est6aM,854 +torch/include/ATen/ops/l1_loss_native.h,sha256=4IQyZ3uswTelSrh7ZK3CL6awa_3vTnhjxVy2CftIhS4,564 +torch/include/ATen/ops/l1_loss_ops.h,sha256=ENK_obj9iglsMYxn_FM2mVfPzA8CyxNh5ZMLvza-jRc,1145 +torch/include/ATen/ops/layer_norm.h,sha256=VAw3tu_-ZmEl1dAJJZoQ-6jeiJH9ZOA4Tttz5tRcPQk,2425 +torch/include/ATen/ops/layer_norm_compositeimplicitautograd_dispatch.h,sha256=nzgb8UIRBZT74kO5PlPWF-icb6SKpM9DhfHx0FxrwGo,1199 +torch/include/ATen/ops/layer_norm_native.h,sha256=uUDna7UXEjnOo3Tlz3k4vvth66__t_Tts90UY3AwbLI,681 +torch/include/ATen/ops/layer_norm_ops.h,sha256=VHn63QoS8kB16yfqa3xMVvLNukNwGtJa8dDsdKKdQmk,1528 +torch/include/ATen/ops/lcm.h,sha256=rK6ZAddscNzD4rpWsQb8ANs0mgwA_nkwYEIJgq-GWc4,1363 +torch/include/ATen/ops/lcm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=zY-IDlvJl0xWRiQnjoa0B2vVtVbp5Da0pSBbs5UId5Y,911 +torch/include/ATen/ops/lcm_cpu_dispatch.h,sha256=7TJaa5whkpQDBf7u6HAD7uRsrli2_z5bTUFjyOyfMXk,1046 +torch/include/ATen/ops/lcm_cuda_dispatch.h,sha256=rMQ8tdAw4D2nPvhMcrAIMpfQ4VoK7rjYUDCodwOmMwM,1048 +torch/include/ATen/ops/lcm_meta.h,sha256=ZzoJN5uyfz69j53kCEtPJ3aKfEhahahKzvJjM6WVlxc,617 +torch/include/ATen/ops/lcm_meta_dispatch.h,sha256=hiPxjqLePWlN4IToX0eA-QptHTd6t__pdGaVNRsJul4,1048 +torch/include/ATen/ops/lcm_native.h,sha256=QAe8jOSb9_CUq_cRH0lDPSXh7ZsjMJBy2C7ErSUGjw0,636 +torch/include/ATen/ops/lcm_ops.h,sha256=zZ0sg8tVz1pIqz9RBFiVfuiH-iMsp_BO1R2rQS_lTqo,2328 +torch/include/ATen/ops/ldexp.h,sha256=vC4FiY-hUYkJI4sTD8Gm01q_M8JuPaOATPKMnZUMRxk,1403 +torch/include/ATen/ops/ldexp_compositeimplicitautograd_dispatch.h,sha256=u5RwhvcX-ni3pCRDFk33yYWeRcvy3A2IVtm2FJ0OTN8,1098 +torch/include/ATen/ops/ldexp_native.h,sha256=JduRNliT5W-FFbgrW5TATyeaUGpK9j78T33tnwCTiiQ,703 +torch/include/ATen/ops/ldexp_ops.h,sha256=EBfH3Yplqo_3jXHOaeKetBAkk0HhHFqC8gKkKFrMkTs,2366 +torch/include/ATen/ops/le.h,sha256=7LpkqBv8S9L_7IEWV7mFSSLpnB8Gp8WJJ958VZ1LrEE,1896 +torch/include/ATen/ops/le_compositeexplicitautogradnonfunctional_dispatch.h,sha256=aepkbaqdU5XqYBpepOyqZRajB-ddZbZ-SzIgLwmVmkw,1060 +torch/include/ATen/ops/le_cpu_dispatch.h,sha256=7xoFyKZ_2Mbp_yp7FWM3z1yYWNsA1eTpIOR23_6OBu8,1396 +torch/include/ATen/ops/le_cuda_dispatch.h,sha256=NXd80r4s51wU4bveraqHGsQnY6ZcIArO_sNV-F6MWpE,1398 +torch/include/ATen/ops/le_meta.h,sha256=WwHRvjLHp7etj1t8sZLFv8p3RhFEoCFxp9WAItbNa9w,767 +torch/include/ATen/ops/le_meta_dispatch.h,sha256=tU4pcuKNuzvK4l0_C0fiZzXQuNe_wFoEbTHoXsUxEJQ,1398 +torch/include/ATen/ops/le_native.h,sha256=DKtkwx1a6TNtfRlBOqFF8ohVJ_4Tz9BS_2oUop66t2w,1235 +torch/include/ATen/ops/le_ops.h,sha256=JmLrc4empZXpD2AQ8Gno_VjQGLakMQ2cL-WYKk6RYs0,4285 +torch/include/ATen/ops/leaky_relu.h,sha256=g96deLvEoZc7iUFGmfbhvdVzJM6s3-HhireqiqAJppU,1597 +torch/include/ATen/ops/leaky_relu_backward.h,sha256=WoLboEAvk4ir9C1xECfMhmmO3Q0CjS_-Qy3ro-wtfE4,1858 
+torch/include/ATen/ops/leaky_relu_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=A9tqDQOBlYY0VLohKdIAXzcQ9uXZthzPow--NmT28dA,914 +torch/include/ATen/ops/leaky_relu_backward_cpu_dispatch.h,sha256=Dr-2OKGIaX8nAv10gGB5cmf730JYL98t7ZWmy49oYZs,1219 +torch/include/ATen/ops/leaky_relu_backward_cuda_dispatch.h,sha256=i-BHyBH6vN_Vfx1mHCagrt1Sltg7J9b8jQ2TdU6b4RU,1221 +torch/include/ATen/ops/leaky_relu_backward_meta.h,sha256=WhcmiBgoV1rRgviwGhjfFWnlZZfFZuNjUXwDnRIuPVs,695 +torch/include/ATen/ops/leaky_relu_backward_meta_dispatch.h,sha256=3YuZcZ4MjuV-H1hA7ulprI6PkSiTXZksm_7ZlnCm968,1221 +torch/include/ATen/ops/leaky_relu_backward_native.h,sha256=9S3cq_-XQ11MdGEUh-9whtd4gOW6L1fuqJvCqdkheM0,753 +torch/include/ATen/ops/leaky_relu_backward_ops.h,sha256=EctA-DgzTPk18RQJrUBzlF0V4bSovy5lC2HS1QJEiAk,2279 +torch/include/ATen/ops/leaky_relu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=KDC2F32jCuVnXlX7pL6-FoUAdOaJpmuQkNo9U5u9WCs,953 +torch/include/ATen/ops/leaky_relu_cpu_dispatch.h,sha256=tEXMq_cVJk65fPTjkcXI4bWY0rUHPN6xzenS3nOXvvk,1125 +torch/include/ATen/ops/leaky_relu_cuda_dispatch.h,sha256=SDfpwuNltusqJqPg3uoHMq8UE9mkN1UAYpvhXueJRBo,1127 +torch/include/ATen/ops/leaky_relu_meta.h,sha256=mGrIPHJb_aLGTaEJKVRSWYpkVv73UGd447qRlgqhdM8,633 +torch/include/ATen/ops/leaky_relu_meta_dispatch.h,sha256=qwsmE0_M6uVacJH2O_-GR_0KGzTGGm1BoTojALKeZm8,1127 +torch/include/ATen/ops/leaky_relu_native.h,sha256=3CWjuycc0hOSdMp9BNHOloWRHcglQcYet2-Z8srR8QM,1021 +torch/include/ATen/ops/leaky_relu_ops.h,sha256=q1B48Whn1-PWGQ9cGr-HkxTNoYp6QqHFLdjN0Dtd-L0,2487 +torch/include/ATen/ops/lerp.h,sha256=ayEy6fgehrCxvmZjG4rm-ccFRCl0lS8DPF_J8x5-vTE,2198 +torch/include/ATen/ops/lerp_compositeexplicitautogradnonfunctional_dispatch.h,sha256=pnb0P9rrT4n4svhpBS0RWsXkgfxz3VhqHT4zYhxW3yc,1168 +torch/include/ATen/ops/lerp_cpu_dispatch.h,sha256=MvUG64EEi-9VNlhI9GCg0vmb7h3HrzMjv7HUxeq5CsA,1612 +torch/include/ATen/ops/lerp_cuda_dispatch.h,sha256=kTNfodjLoesyXjaeEI7CBX5xpTHU_Ko_V6ChgzGkGlQ,1614 +torch/include/ATen/ops/lerp_meta.h,sha256=bXTMThLWzG5fgKS_MEDnc0KE8ppHKTEdGcJySrAatiI,821 +torch/include/ATen/ops/lerp_meta_dispatch.h,sha256=-A4GqcvaGQFDrGxclxm9K1mw7DxUi_ID9jCole9skj4,1614 +torch/include/ATen/ops/lerp_native.h,sha256=5yevCEZbAYw0GaJw42EBmX7KGlNq7hv0NF5fg4-3vNQ,875 +torch/include/ATen/ops/lerp_ops.h,sha256=LO4rQxpSU7fFiTO03WubseD61rSokuXZC2r8bpbaPmY,4819 +torch/include/ATen/ops/less.h,sha256=DpsXtR-nrRZNXsFjvbbq5lt0epfl5BTYSyzmgTTeBCA,1934 +torch/include/ATen/ops/less_compositeimplicitautograd_dispatch.h,sha256=pGm-gnrR0EbWSi-R22DV4EaQS6aH-BHL72gFzPuGhOk,1456 +torch/include/ATen/ops/less_equal.h,sha256=OAB5QTFKuBKBKGPWq-dOB2Y-fCNznNGvJ5lNG2mutCE,2048 +torch/include/ATen/ops/less_equal_compositeimplicitautograd_dispatch.h,sha256=q2TQ2jYJToBDd72ZPhi0_kj7hWlzjmhPn5dO_WWwpqs,1504 +torch/include/ATen/ops/less_equal_native.h,sha256=kmDmlhfrBXwRXhH3ODNp1ANpjks_DT-QUG3F_W5OPKk,994 +torch/include/ATen/ops/less_equal_ops.h,sha256=D-umb686TxNVrT9YvvzCxoVXGQqmgSQfK25-KaG3aVA,4429 +torch/include/ATen/ops/less_native.h,sha256=Y5W6Mq1Gj9EfK7a7518nha4ljiBacoAQ_WZ-0qypj4A,958 +torch/include/ATen/ops/less_ops.h,sha256=UEo97W7UuL2gE3VoytBAT7ZNkLJfgn6hZvkFRWVGar8,4321 +torch/include/ATen/ops/lgamma.h,sha256=moUK9CTd478sLqhLBEMBPuZtNydPH2jwjHLtWqUhvn0,1067 +torch/include/ATen/ops/lgamma_compositeexplicitautogradnonfunctional_dispatch.h,sha256=TeOtEIkrLxURhLEnJcdH_AvmBV7hJ0b1ti12_FZe_qs,865 +torch/include/ATen/ops/lgamma_cpu_dispatch.h,sha256=w6fn93gs-Gyd95w-b-58_fxsElD7eWuZOWeQEb6dwr0,954 
+torch/include/ATen/ops/lgamma_cuda_dispatch.h,sha256=AnmGPYO7BYuUpVvMyYiwg_3BpN4-a3vpCUTOIewr85Q,956 +torch/include/ATen/ops/lgamma_meta.h,sha256=kTUo0ehyharR7ubtuJFOp0qQsqH7q4AGx-0eGLmmiRs,594 +torch/include/ATen/ops/lgamma_meta_dispatch.h,sha256=mtJKsJyN0fmSA9uqDJBfAj-CL9Jglt6uhr5LO_ntRWg,956 +torch/include/ATen/ops/lgamma_native.h,sha256=xSGDhsN6YrAbynEkvnqikPo3qnpjMiArGPVomrcL8tk,619 +torch/include/ATen/ops/lgamma_ops.h,sha256=2t7F3wJ4nuhJoC8UUhyN5_8jkLwiBb3XVO2tkuHMe70,2097 +torch/include/ATen/ops/lift.h,sha256=gt_5IqHr87_WQ2pqkDKcBBLKzoL8XXLPfjwpm7l863Q,1047 +torch/include/ATen/ops/lift_compositeexplicitautograd_dispatch.h,sha256=Yg2sBC4Dq-eOgF-6mSU1RJQp677d1voARD-icdw_JSo,940 +torch/include/ATen/ops/lift_fresh.h,sha256=SMUUIbFDHHm8kSLGCQxwIlVsn0gz-ZLKwmhXbqhXczQ,690 +torch/include/ATen/ops/lift_fresh_compositeexplicitautograd_dispatch.h,sha256=3H1mwN_Yq1wR5KfNZMFZETCdVf13Wyd64_xbL7ZddgE,791 +torch/include/ATen/ops/lift_fresh_copy.h,sha256=rClaEc_Qf14j0DlcVJ7VohxiLiNXvdN_93avCSyE62s,1157 +torch/include/ATen/ops/lift_fresh_copy_compositeexplicitautograd_dispatch.h,sha256=KIKMIToRE7dISahe9pXOSnMp-S6i-HRDhXAp-GoBdsM,909 +torch/include/ATen/ops/lift_fresh_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=0SLnkZo7O6f4q3BDjy3qrywO6ejKI_iy7QC9Q0o-zjo,822 +torch/include/ATen/ops/lift_fresh_copy_native.h,sha256=F24O3qk8T9KfV2Giygxjf5mWRCw1XnRi2oR1-lNa_p4,594 +torch/include/ATen/ops/lift_fresh_copy_ops.h,sha256=53uwnwshhuUnp85WfJ5mrVw4TkkY9eUSkcwHmp-dgTk,1641 +torch/include/ATen/ops/lift_fresh_native.h,sha256=F6zOgV98f2ZbgczW3tW82jS2ZRx4hk9HXzDCW3vSSkw,501 +torch/include/ATen/ops/lift_fresh_ops.h,sha256=EDfelfwzJ_AsMScM-bTvAjFgHio9gPsfO2BN-BgVV7g,1004 +torch/include/ATen/ops/lift_native.h,sha256=qlH415EJ11nqOX0CTc7PLLoAsLxpibnGjO7XGYlpm9g,572 +torch/include/ATen/ops/lift_ops.h,sha256=6iU-lKTQV6i-e5zKsvjlvPKg2gMPrAK_8ZWhbvgoGpI,1575 +torch/include/ATen/ops/linalg_cholesky.h,sha256=07X0yOVRVEEi1QwI3MZLq6fZ91vcnBYG5rpbRNlMN-4,1283 +torch/include/ATen/ops/linalg_cholesky_compositeimplicitautograd_dispatch.h,sha256=sNLo5PF1xgmI6hLu67Ca1FLznK2q33hTyxzFRD39JBY,1021 +torch/include/ATen/ops/linalg_cholesky_ex.h,sha256=6K2jf7-mfCDxUpcgoqSPtCVrCcYwWmCtG019sf1ecSY,1701 +torch/include/ATen/ops/linalg_cholesky_ex_compositeexplicitautogradnonfunctional_dispatch.h,sha256=L2qJWeRUsMuBB2U7_yAgCs5eByDkfmlbLgTqTzqzzqY,893 +torch/include/ATen/ops/linalg_cholesky_ex_cpu_dispatch.h,sha256=PjdrzW4TMynwVtozFKdoQrw9pt0W4XAZGPRW6pwJRGM,1168 +torch/include/ATen/ops/linalg_cholesky_ex_cuda_dispatch.h,sha256=wIDDff2-JTfP7gqPJ2mwVApWZPtMymQRB2lcLAizyT4,1170 +torch/include/ATen/ops/linalg_cholesky_ex_meta.h,sha256=mZT3qVGIruRmWYt_DkcCS2UftR6FuGOsuglR5Ce18Pk,637 +torch/include/ATen/ops/linalg_cholesky_ex_meta_dispatch.h,sha256=qGl3tO7KwsMZmh8i7AJCkh-5pt9Tq9u0r4iry1HvyFE,1170 +torch/include/ATen/ops/linalg_cholesky_ex_native.h,sha256=iLtuiSOuQCTUPa3vlAagRWscKWUwW30rCANAvrfN9xg,709 +torch/include/ATen/ops/linalg_cholesky_ex_ops.h,sha256=h6qVzhlZ8SKAJx9xrG_HganUYRhF3I7g99X5MhJ6WbA,2147 +torch/include/ATen/ops/linalg_cholesky_native.h,sha256=uxUKpFHnDvJ4UqwDIe5zGs_zpoTTh41Hto-Exi3vdRg,624 +torch/include/ATen/ops/linalg_cholesky_ops.h,sha256=s6bXtO7kGxtBBAeX2NB5Zl2VMXR1MC7rzd0RsHPPXt8,1740 +torch/include/ATen/ops/linalg_cond.h,sha256=KZWJv0bb4-XuHJHSEgOwoYIBp8vybxJ_axcsndBuoU8,2031 +torch/include/ATen/ops/linalg_cond_compositeimplicitautograd_dispatch.h,sha256=wA9BnAqamEo4wZjQLZe8C41wpSHOn2F3GXCRpRK2q8o,1397 +torch/include/ATen/ops/linalg_cond_native.h,sha256=WWKEOVNgyDVHurxJ-36gR7Nfb8pD_yJk7_rEmktqH2g,863 
+torch/include/ATen/ops/linalg_cond_ops.h,sha256=YguMmHvm3EoFQ8mlsypfJyDtFT6ciezYWdsfo_lLQcU,3186 +torch/include/ATen/ops/linalg_cross.h,sha256=scPpGu2i0yDoIdrbYM4D_GkE6kvCw5akW_a3Q0HQJzk,1367 +torch/include/ATen/ops/linalg_cross_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AjApHTgho7xGdqO3mK6nJk8o3HhCg3ggE1NfqNq6ooM,861 +torch/include/ATen/ops/linalg_cross_cpu_dispatch.h,sha256=05fFackUWVrdG6zamgka3irMoVV-8qHlchuBNwpoVlA,1043 +torch/include/ATen/ops/linalg_cross_cuda_dispatch.h,sha256=HpqZtLrYLm9XiLayHtAJahenoB2ZDyU0B3C_af0bGDI,1045 +torch/include/ATen/ops/linalg_cross_meta.h,sha256=td1WFXkYyHdZsOYOLTwhD-MxCZ-_TqFdKsnlkvHumBs,639 +torch/include/ATen/ops/linalg_cross_meta_dispatch.h,sha256=cxP4DIolMSVOIaeP4KGGaGhWXtcHhkW_dSpleRFULB4,1045 +torch/include/ATen/ops/linalg_cross_native.h,sha256=V9wgmRMjXtb_kedy8DNK0GrDdkYe1Mi2Iqwp4TFJQh8,790 +torch/include/ATen/ops/linalg_cross_ops.h,sha256=Pq1pfkgpTVE6yUMlH9qS1bKTNnmWwcnKWrUYUD-fUaU,1892 +torch/include/ATen/ops/linalg_det.h,sha256=VstSdJ-SxH8J6GZ421Yfqbknccv6HsZn6lpMMEDREF0,1080 +torch/include/ATen/ops/linalg_det_compositeimplicitautograd_dispatch.h,sha256=CEMOGhm13rogwdq6TWANpWVCGhaQmC0mD4cc-lvRZI4,949 +torch/include/ATen/ops/linalg_det_native.h,sha256=cPvwxKv-qN9kX25FMlXeGR9USsphLTnK-le5s9vsyqU,578 +torch/include/ATen/ops/linalg_det_ops.h,sha256=eLdZGK2u5Hn7Fcs-jJq5ojX8Q0cEdFu6ulHDA95OKTY,1593 +torch/include/ATen/ops/linalg_diagonal.h,sha256=dkEEUKtg87b4mWJugSQ842Sod8sHYvTEEv3tlXeDwjQ,816 +torch/include/ATen/ops/linalg_diagonal_compositeimplicitautograd_dispatch.h,sha256=__eF9ZuN_zW7Gc_naIGOMU_tLg9V_v3E2vQ84jBOGz4,845 +torch/include/ATen/ops/linalg_diagonal_native.h,sha256=sWRpHybtLlMQe4n3WjKIkuBPWEfaQODGs5Ii5Pcxe3k,555 +torch/include/ATen/ops/linalg_diagonal_ops.h,sha256=NrnTeJDAlxuDCXYJdn7mVoks1xPZdLj6cGkx11kBlP4,1168 +torch/include/ATen/ops/linalg_eig.h,sha256=4xc3hjFwB0o4IoK749L-GinP4KhdjYlyWbVSBk7dDlA,1479 +torch/include/ATen/ops/linalg_eig_cpu_dispatch.h,sha256=yCavEkzbw6OQeIwj6ff8BuyRL2NVd8etDG_l48F1cxI,1063 +torch/include/ATen/ops/linalg_eig_cuda_dispatch.h,sha256=gQQaONaKvnBzypFP6wOXzFhOEwxi-9kXlMS_toymQ6o,1065 +torch/include/ATen/ops/linalg_eig_native.h,sha256=vmBnsZn6T0PYhT0D0AOCEqOEoQj1amINdePQ15jps2Q,671 +torch/include/ATen/ops/linalg_eig_ops.h,sha256=YR2K69P-ifVLgoiKD9suCjqIPDDEghGIAQIEkKw8ai8,1958 +torch/include/ATen/ops/linalg_eigh.h,sha256=Dm69ihcuZ2T0WtEJI9gwNPfxXpVUJtvYMoFRfwuUh7E,1588 +torch/include/ATen/ops/linalg_eigh_compositeimplicitautograd_dispatch.h,sha256=agbG-x1M01NKw0bfw0wInSoTghOz5wWBz_kJJzFr4jI,1169 +torch/include/ATen/ops/linalg_eigh_native.h,sha256=ZsjGTYmh4TPGabE2XCC8RU1qqjMQ1A67vD0jv5i8qKo,714 +torch/include/ATen/ops/linalg_eigh_ops.h,sha256=TP5TzyXPQGZTAJfEDHJE12Kv55xvz356LQC91MCyZ8Q,2109 +torch/include/ATen/ops/linalg_eigvals.h,sha256=fubW7YwKUku-Am4Res8l5EUrHTAGjTxVjI64l1VaO90,1147 +torch/include/ATen/ops/linalg_eigvals_compositeimplicitautograd_dispatch.h,sha256=aUTo9IBIJ6Tvr8DG3iNP-4huHpacRYIwOzgxY5jspBk,795 +torch/include/ATen/ops/linalg_eigvals_cpu_dispatch.h,sha256=3trf1XLw5kZmswCIOe2ZQOdzwlTWxXahioHSi--jzBI,863 +torch/include/ATen/ops/linalg_eigvals_cuda_dispatch.h,sha256=YI-ikqzSSPoTdaryegSZJD3FwkC0Q9FHTvLEZ7jdsQY,865 +torch/include/ATen/ops/linalg_eigvals_native.h,sha256=s_U-JhEfIe0ECZvWarUZl1GaPuAOSLmwkCjCQuSiOJ4,592 +torch/include/ATen/ops/linalg_eigvals_ops.h,sha256=-eJKGQ5IYr70YkZmqrNua7HhSTFwQAL86TA9r09F25E,1635 +torch/include/ATen/ops/linalg_eigvalsh.h,sha256=xnPu5KYVBacTfal68Iit1YfnSgjdatBGu441b2vI0LM,1294 
+torch/include/ATen/ops/linalg_eigvalsh_compositeimplicitautograd_dispatch.h,sha256=wSTXIjZ22GoPwo0j5ACS57Ix34LmZgviZw-IUwsMD4s,1050 +torch/include/ATen/ops/linalg_eigvalsh_native.h,sha256=bu6wf11rIag5cbFOb-GMBtqL6ZtOFQq6BfeJ71LiTk0,644 +torch/include/ATen/ops/linalg_eigvalsh_ops.h,sha256=VGXYBJ7T5FMh6dlY11BPk9johPv9cFfH4OcjCgMJb8Y,1801 +torch/include/ATen/ops/linalg_householder_product.h,sha256=WcC9rIGfdCji3TZRHE3vD9fw5-a5P6rmueSDGru0j5I,1399 +torch/include/ATen/ops/linalg_householder_product_cpu_dispatch.h,sha256=IWc8HN4J32TB83Uct4dGsOE0NJ7ZYoJFLTbXJu4qIFo,1037 +torch/include/ATen/ops/linalg_householder_product_cuda_dispatch.h,sha256=AAkGVfRmHnKlKdPc4FjV5f-gHa0GHHYzJfyoDAXSZQw,1039 +torch/include/ATen/ops/linalg_householder_product_native.h,sha256=2g44g2E_8f1IAg3JmP18uKWtrSyq2xbImvl-gw4gCl0,666 +torch/include/ATen/ops/linalg_householder_product_ops.h,sha256=-JFBGdFeYp8iZOcLIyHsqf74Cw1Hfydv8MUYMLjFZiM,1873 +torch/include/ATen/ops/linalg_inv.h,sha256=rkuEEadpvXPAXx2xyUTMT2V8c_d1153XhOUZORe4-CA,1080 +torch/include/ATen/ops/linalg_inv_compositeimplicitautograd_dispatch.h,sha256=8YCFi4O_QVjy909IXghbx1WhxCaVAsjlCMd1DDYybEE,949 +torch/include/ATen/ops/linalg_inv_ex.h,sha256=nOdyDKtWovRxMeaEkuKzYFp6p9Bxy1y0rMe45zKLh18,1579 +torch/include/ATen/ops/linalg_inv_ex_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rM7hNptwdbhFG7zYBqPm8WvAXo72rJn_EhBc9-k-sAo,867 +torch/include/ATen/ops/linalg_inv_ex_cpu_dispatch.h,sha256=KPA31rwy9qgIqPRUy1_F8pKwdUFlt5xjhT-wjmqWHs8,1108 +torch/include/ATen/ops/linalg_inv_ex_cuda_dispatch.h,sha256=T7QQC5pvM7K_20NXlcQo6rAp_-8aDRehDds0k219Pjs,1110 +torch/include/ATen/ops/linalg_inv_ex_meta.h,sha256=vuegM0Rivp4zXmA4uabw9zxpIeB-0KJngZR3ZkdoBas,617 +torch/include/ATen/ops/linalg_inv_ex_meta_dispatch.h,sha256=YstI_i3-NBwpJp2PJwxa-qEedS8kW5DuAHX_0hOZzYw,1110 +torch/include/ATen/ops/linalg_inv_ex_native.h,sha256=H28k57oe0Xt2GJJm5RpDZ_DdrbFX_cMFEoOR7EwFeZE,685 +torch/include/ATen/ops/linalg_inv_ex_ops.h,sha256=5d_TftVRpfjEXaqJrzsGJtFZnFMVQ1E5sGXsAx_yDZA,2051 +torch/include/ATen/ops/linalg_inv_native.h,sha256=kMVg-OqmGjDNb0g-tQaQ_7m9J6dbwan1jtRrnJXvgPQ,578 +torch/include/ATen/ops/linalg_inv_ops.h,sha256=-_ovQSY_PPrxHSe-E9iA1FTwQ2NSqMAILHt22tWqcYQ,1593 +torch/include/ATen/ops/linalg_ldl_factor.h,sha256=92wjnJlMeUBWQUEYa_HwUmb32hzsUOvcY5S9kgE4AUI,1576 +torch/include/ATen/ops/linalg_ldl_factor_compositeimplicitautograd_dispatch.h,sha256=qEBWW6q3K0pS5bevS8ZpwAXkc3J2yDteIaCAnfE0xWA,1158 +torch/include/ATen/ops/linalg_ldl_factor_ex.h,sha256=zCdGjkhgoZcOqN_YR57TnBmqoLIskFJG7wBaprdPe88,1960 +torch/include/ATen/ops/linalg_ldl_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Q_15XsQUIz7Z3yCTTMyx-ES9Nc8lpRGSCofL6g5edHY,910 +torch/include/ATen/ops/linalg_ldl_factor_ex_cpu_dispatch.h,sha256=uZNsjsfcf6CJvvLiRlbKAlYo3gKcwLW99wXJp72fp7w,1267 +torch/include/ATen/ops/linalg_ldl_factor_ex_cuda_dispatch.h,sha256=XzQo_Iq74g6jpKGnAJoJgwCcBSie020w-VjQCC47T8s,1269 +torch/include/ATen/ops/linalg_ldl_factor_ex_meta.h,sha256=2wS0pkotvBiC595j-ETssZwJyBxT6nnoteOiFBY4m4Q,643 +torch/include/ATen/ops/linalg_ldl_factor_ex_meta_dispatch.h,sha256=bShYy9aJZOCh3_Zi08EcH7tFmJ06sbr7E2EnpLOVLLQ,1269 +torch/include/ATen/ops/linalg_ldl_factor_ex_native.h,sha256=LR_teyJ1PfvzN4AcSScswuo6nHRrQFliRVZSPucyhEg,747 +torch/include/ATen/ops/linalg_ldl_factor_ex_ops.h,sha256=Rtmyc4sW5TY5NpdGOPFZasHSiNOlC7x6ld37TNnx4MI,2375 +torch/include/ATen/ops/linalg_ldl_factor_native.h,sha256=SnBtIwuKJ6lY8Zgw8pxtjCEqnOp9SQGHyGhHncNgMxg,708 
+torch/include/ATen/ops/linalg_ldl_factor_ops.h,sha256=U-H_F_nzco_cybKABgqeFYVBn3ROl1HmuPh0W17L41o,2048 +torch/include/ATen/ops/linalg_ldl_solve.h,sha256=X6vNDhjjvAucRovTnF188OIPRJiakI_xVRs8GXY_d1g,1566 +torch/include/ATen/ops/linalg_ldl_solve_compositeexplicitautogradnonfunctional_dispatch.h,sha256=HFGnzqycFoUQB_RFZo8tchW8KbKijP87kGRhIh1YpB4,892 +torch/include/ATen/ops/linalg_ldl_solve_cpu_dispatch.h,sha256=c-EOS4j5z7FSP_THrdGtYMaTG5INvznUMsg9CCsQz44,1133 +torch/include/ATen/ops/linalg_ldl_solve_cuda_dispatch.h,sha256=pVT1puytTdRx2R0cvAcFZUmpskRpEE0IP8A6jP7BvrQ,1135 +torch/include/ATen/ops/linalg_ldl_solve_meta.h,sha256=bWtRerUfQU061dJ4MVHjpAuqkL8DiUQ8ZlKBMGT1M-E,667 +torch/include/ATen/ops/linalg_ldl_solve_meta_dispatch.h,sha256=oqMgl5RBborDW-olvivj0Y1zeYjS8SjEtqZ5YIE302s,1135 +torch/include/ATen/ops/linalg_ldl_solve_native.h,sha256=ppBXO4PKDbRlSekrhJYQCHCpHk_5tGY73OFLH2Vplpo,712 +torch/include/ATen/ops/linalg_ldl_solve_ops.h,sha256=EVqcluAidaMn9zin3-b_D4zpBON2M1p8fNeT2Iw0K_o,2084 +torch/include/ATen/ops/linalg_lstsq.h,sha256=nJ7YYj2g52vBvsajLZxQxLQvE13kDhYLw-FFRTMmKHY,2423 +torch/include/ATen/ops/linalg_lstsq_compositeexplicitautograd_dispatch.h,sha256=7tY7gJeD51R5aptcIHmgfdr_RI8JOO9u-5VVuHpgO5g,965 +torch/include/ATen/ops/linalg_lstsq_cpu_dispatch.h,sha256=F7RsNyWV1nqmpR3FOvxY6Covsv00WpKBaiyHiEM3mKs,1341 +torch/include/ATen/ops/linalg_lstsq_cuda_dispatch.h,sha256=yTw1w72Eox28g2jUTD4w7HHwIaBdKRu29hcGypgdZgk,1343 +torch/include/ATen/ops/linalg_lstsq_native.h,sha256=Pf3oMYTGkzuGqT9tOtopb1iQ3xoHTIJxDStig5P2Ats,986 +torch/include/ATen/ops/linalg_lstsq_ops.h,sha256=_trIvjMELoHEtQUoVnChh2upUbWlOAK7cbLtrxlGz2Y,2974 +torch/include/ATen/ops/linalg_lu.h,sha256=nuZoMYPctAJTr5txHXkLmi5tY3UcKwHePXO64mE6-3g,1515 +torch/include/ATen/ops/linalg_lu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=eSJzSWFLSbCI74XivERueyiVmwzhON-Vo3IO6bWzsGU,866 +torch/include/ATen/ops/linalg_lu_cpu_dispatch.h,sha256=jUz0Wh6wdRTwqljnjD6pitVNSWbKEoKK0vDUy1s-veg,1124 +torch/include/ATen/ops/linalg_lu_cuda_dispatch.h,sha256=-xSfvdTi9GPVAsuUsQb0tZsFJ0pcW8wYS2L3VXKspdQ,1126 +torch/include/ATen/ops/linalg_lu_factor.h,sha256=Ul0Z3uFoo4vKT5b907z_DcFfIJQA5rXqpiYtjDDpzMw,1498 +torch/include/ATen/ops/linalg_lu_factor_compositeimplicitautograd_dispatch.h,sha256=HlqM8dHvTitoK32QQ3hP2z33JL6G1PrN3qpWuLq19NI,1132 +torch/include/ATen/ops/linalg_lu_factor_ex.h,sha256=WV7UgI2JB0mAkgWxRdVC4HzROzR4NIqL8GGbHyTPLIk,1882 +torch/include/ATen/ops/linalg_lu_factor_ex_compositeexplicitautogradnonfunctional_dispatch.h,sha256=V37dBozm3bNyGxbB9ja7Ar975PdtiHNdyKD3l8GNxXA,901 +torch/include/ATen/ops/linalg_lu_factor_ex_cpu_dispatch.h,sha256=HNiBlwcgdDa3Xp8sxlVZ3wlzcCxDi_AVcg947nmeDpI,1241 +torch/include/ATen/ops/linalg_lu_factor_ex_cuda_dispatch.h,sha256=BedP0jxVoyCkgJaHeSPpHxev4iUhbkOClVEwV2eFyYI,1243 +torch/include/ATen/ops/linalg_lu_factor_ex_meta.h,sha256=St2Bjam0YBNFjpCEWBhDfB8oyhk636gMRiA20JAklGs,635 +torch/include/ATen/ops/linalg_lu_factor_ex_meta_dispatch.h,sha256=HWqT1BYWeypbTFZm08XaBEUt-XhoUGs18m9bhMTGMQc,1243 +torch/include/ATen/ops/linalg_lu_factor_ex_native.h,sha256=vXLYZKxvIsxN-cUw7ylloxADUMEMa1cNwWp8QD2FtJk,737 +torch/include/ATen/ops/linalg_lu_factor_ex_ops.h,sha256=b_kttRN0nk2VyHgYmSd28sdrSxQwye53O55hqVZpYVY,2325 +torch/include/ATen/ops/linalg_lu_factor_native.h,sha256=gO1bNc-k4PCBy-EHUb5mOXTA_cipwRPtRA4hyy0X71U,691 +torch/include/ATen/ops/linalg_lu_factor_ops.h,sha256=-gQLY2HZrZMekebhjLFzbNN5Vt72tSsuUDj07Cx_VA4,1998 +torch/include/ATen/ops/linalg_lu_meta.h,sha256=Rh8eKYe0K-lzZwBUucdiGVgduUV0dw83VUOW9oOL82I,606 
+torch/include/ATen/ops/linalg_lu_meta_dispatch.h,sha256=bpdq8fWPRNlPvbtz1sQ5J_XkKk7NZoNXMYj6DZ0waP4,1126 +torch/include/ATen/ops/linalg_lu_native.h,sha256=Y9HZ5q9_0qCKWOmpPYjWHpo-tGh0l8Mem5A_hN1DTOg,679 +torch/include/ATen/ops/linalg_lu_ops.h,sha256=I_9QoJOtyZpCnYIVOsq_fVReeOgTvpGTloE8HLM67HY,2082 +torch/include/ATen/ops/linalg_lu_solve.h,sha256=WCCLVPQ6vv9wne-3O2jEJhwJYgqoqCaDkpdDxLE7scU,1647 +torch/include/ATen/ops/linalg_lu_solve_compositeexplicitautogradnonfunctional_dispatch.h,sha256=7zj-Y2ErMEUN6e-uGW-io-5Mp0aidJ4fZLflgjVhof0,905 +torch/include/ATen/ops/linalg_lu_solve_cpu_dispatch.h,sha256=xK4m8h2N9SCQG2chu9E1TeosELw9uTTguWJJGjoDgM0,1167 +torch/include/ATen/ops/linalg_lu_solve_cuda_dispatch.h,sha256=gpqZU6wPdrcJ1eWEDPw_qptwmxsLmKZJyMceuMByTPY,1169 +torch/include/ATen/ops/linalg_lu_solve_meta.h,sha256=5hRIraenpc0eYyo5i8wMILWqC-LhGx4yaZqFp_1zTbI,675 +torch/include/ATen/ops/linalg_lu_solve_meta_dispatch.h,sha256=53l5UanEMMvg8E-fuQlxmJ5I9szm1BNjO5OW3w2XAxM,1169 +torch/include/ATen/ops/linalg_lu_solve_native.h,sha256=fIZAdYgruLZDPEbQaHFmkCmDNwAFBHloa6l39MFfIbU,718 +torch/include/ATen/ops/linalg_lu_solve_ops.h,sha256=rjMomgWc3NXs4Fkmo3vDFUyDZfMei_T6-Q9B6jRsdhs,2154 +torch/include/ATen/ops/linalg_matmul.h,sha256=j4SUcyCnKNeYckLxQroel_FauqoNjDESqh-tSQ7Ehwg,1278 +torch/include/ATen/ops/linalg_matmul_compositeimplicitautograd_dispatch.h,sha256=U0aY6pEJIPSo_WbkdqDYFoFDqqHAnlUE7MMa7nj6Hss,1045 +torch/include/ATen/ops/linalg_matmul_native.h,sha256=h4FLUbMZZKkd7xO7-jppB9N420QVmQSLuWsqAmH3zaE,642 +torch/include/ATen/ops/linalg_matmul_ops.h,sha256=sc5XPoyYGPnN0L6d6GBExFU-SfHGFor1oSCGeIkOj68,1801 +torch/include/ATen/ops/linalg_matrix_exp.h,sha256=2B2R5MsGwbICrJp_RMVaWsD0IfPV-U8aSUMOuVvi72I,1177 +torch/include/ATen/ops/linalg_matrix_exp_compositeexplicitautograd_dispatch.h,sha256=GHGJnjT7e2kT8qZbnZsd3k-Po8er-ipFkUmMtXkGK0o,913 +torch/include/ATen/ops/linalg_matrix_exp_cpu_dispatch.h,sha256=S-4FsyrhsBLYGinIQlHmqCEn6l-HXf1cxdGVOGyF5Fs,754 +torch/include/ATen/ops/linalg_matrix_exp_cuda_dispatch.h,sha256=2ybl9Ory6qmP-6teJ-bCIW4bY12ugH_MQK1Vgtg5oAA,756 +torch/include/ATen/ops/linalg_matrix_exp_native.h,sha256=fCSF6X-h2HuD85wYmzYnTaCKDmGtcKq1x6-8T1gmMpg,598 +torch/include/ATen/ops/linalg_matrix_exp_ops.h,sha256=nKn-Goln6kC_27Yp0hVGv-cjPQXwjWQk9NpwIuJRSY0,1653 +torch/include/ATen/ops/linalg_matrix_norm.h,sha256=S6j6CGq6dFtj2GS3jV_FcoYv6fuev13D1dPwVwC2Kgo,3213 +torch/include/ATen/ops/linalg_matrix_norm_compositeimplicitautograd_dispatch.h,sha256=-Z7WEWUWI0f8sfHBd0K8W8qiO6Krj4l2cFidi-3CaeA,1942 +torch/include/ATen/ops/linalg_matrix_norm_native.h,sha256=36ystJuti0i3upoP_fX3cO3piqTXWWr0yALsUJsRcrU,1210 +torch/include/ATen/ops/linalg_matrix_norm_ops.h,sha256=N9ZkMPkVrIzXDiZCQoa3UiZDlwPxDMouRGAHWPIccj4,4278 +torch/include/ATen/ops/linalg_matrix_power.h,sha256=RkQlsTdbJyH2dkUfRlnPq0a4YHgXNu4ml8z67bLafmk,1260 +torch/include/ATen/ops/linalg_matrix_power_compositeimplicitautograd_dispatch.h,sha256=OKHdFNGAPD64iUhWZ_dU0XWYdNBSGC8eotIlPNhTUk4,1018 +torch/include/ATen/ops/linalg_matrix_power_native.h,sha256=4-HyiZK7hpzpAqFPCUBfv_Z_7y4ZVt1GXYMzbUa2SpA,624 +torch/include/ATen/ops/linalg_matrix_power_ops.h,sha256=w4YAT2WRKAfDhxWpah3Me8_51OuOYP4nkeT147z5yBQ,1741 +torch/include/ATen/ops/linalg_matrix_rank.h,sha256=taUMoFPJj5H-o6iy4XqF3TH3FNZsnwQGjHAuPdUVxY8,4997 +torch/include/ATen/ops/linalg_matrix_rank_compositeimplicitautograd_dispatch.h,sha256=cazqxhnoCTei96Xbfww19MCM0-hT08-P7FdElxsRv9o,2530 +torch/include/ATen/ops/linalg_matrix_rank_native.h,sha256=eZtrqGmanxuJ55cOpYiikGE36wpDyXc0nbv3G8Gjtww,1626 
+torch/include/ATen/ops/linalg_matrix_rank_ops.h,sha256=Q916cHZKohcOKA3pAJlg0LcvLFMkr7Th7OTS4-Gh-KQ,7188 +torch/include/ATen/ops/linalg_multi_dot.h,sha256=oyDcwCB_5w2I_noljp0qJJOjZF4rvelyCkGPxyVgwf4,1188 +torch/include/ATen/ops/linalg_multi_dot_compositeimplicitautograd_dispatch.h,sha256=fd2R8mbeFhUfQzhn_zMAlD82Ix09w2h41V6R6bywMZg,973 +torch/include/ATen/ops/linalg_multi_dot_native.h,sha256=aqA5rZ81LLSWEhkIpgE8pLJiNiNxInKmUn_rMPPExdo,594 +torch/include/ATen/ops/linalg_multi_dot_ops.h,sha256=akV3dsJc4a5M3UOEiU49dKJkcVNkA7J5J-PsG0eHVaQ,1645 +torch/include/ATen/ops/linalg_norm.h,sha256=5Sb0VBL1RqXXCNqDs-Tzylc3B0wmSApxIAgYC1hOqDQ,3219 +torch/include/ATen/ops/linalg_norm_compositeimplicitautograd_dispatch.h,sha256=M_cqMzBDWDPIHkxiKJHufEO5dQVCKkaT01nTDl1aJ_k,2045 +torch/include/ATen/ops/linalg_norm_native.h,sha256=mjb9N_iuxY4lw4H23eWeclRuKGYDNlnKaDgX7yUUhDY,1271 +torch/include/ATen/ops/linalg_norm_ops.h,sha256=tIf7Ky6alzmvjNw_xkWiw3WEo8j7krShqBhGntCIDio,4388 +torch/include/ATen/ops/linalg_pinv.h,sha256=SE7scSXm9fmezZays9eNKk3TeYF1WocYSxy6GdhJcKg,4768 +torch/include/ATen/ops/linalg_pinv_compositeexplicitautograd_dispatch.h,sha256=6AvMw7UdFzhs0EvVhgVqqYHWwltxTm9sYvJyImqpnqo,1113 +torch/include/ATen/ops/linalg_pinv_compositeexplicitautogradnonfunctional_dispatch.h,sha256=mZMTxLPkg9dexdU_NUZKs8Qw2oV024Ytc5aB9vRwXwI,930 +torch/include/ATen/ops/linalg_pinv_compositeimplicitautograd_dispatch.h,sha256=Ve6PNzCnKMqu2cjp3TXbuyhBJKxXAC6VrPUEQfD0CzM,1899 +torch/include/ATen/ops/linalg_pinv_native.h,sha256=gJ7chqSVVud_D8FeN4JnYRa1NjA3EMMJDltme5fz9to,1574 +torch/include/ATen/ops/linalg_pinv_ops.h,sha256=Q5SqredFYbqGH2y6OGkG84WipOCcOhNJgIIQP7TYVcI,7044 +torch/include/ATen/ops/linalg_qr.h,sha256=M-3-a2IHC710PXuIRh-wfMmWBa6QBlWIhRIlWArQnzA,1420 +torch/include/ATen/ops/linalg_qr_compositeexplicitautogradnonfunctional_dispatch.h,sha256=6spNlsrIiz31j50VWoGrHbavie8GRAn6-XstoZ9PMrk,871 +torch/include/ATen/ops/linalg_qr_cpu_dispatch.h,sha256=OTr23PVwLB_pQQz48Y8AEpi1QKFh9R1hCdLsPd2zK0c,1098 +torch/include/ATen/ops/linalg_qr_cuda_dispatch.h,sha256=E7KMO20H4y6g0yl1dUVQP2p-uDiVn1KR7kvaPyJ5fww,1100 +torch/include/ATen/ops/linalg_qr_meta.h,sha256=vrP94ywVFlJE5myOyMtFyOwRX4CpPyIRlAmxCb53vsI,617 +torch/include/ATen/ops/linalg_qr_meta_dispatch.h,sha256=dJErAl9Xgk_rtf4FUMOYlVNxcbD5syc8La_CH2TYiTw,1100 +torch/include/ATen/ops/linalg_qr_native.h,sha256=t2F5iCe4OvlRILA04Xt6AW2WWyiy6y30SR-Lg5ZQe7Y,668 +torch/include/ATen/ops/linalg_qr_ops.h,sha256=ExyQrYV_X8hzdh926UWoAhiTBYSx3X_ahLEbvBnKhK0,1997 +torch/include/ATen/ops/linalg_slogdet.h,sha256=fgmfFAZQH-vN3yV3BzWXX_DYrosmfMKRVQYI_PkxVUk,1402 +torch/include/ATen/ops/linalg_slogdet_compositeimplicitautograd_dispatch.h,sha256=1EaoBAHvM7u8CqUmJNrivq3FWb3JBncFT1pMR-EY4rw,1090 +torch/include/ATen/ops/linalg_slogdet_native.h,sha256=_vb0-9d2BAEGH6UPJEZHSwiCXuJ_WG6WqiZMxwC5a4M,663 +torch/include/ATen/ops/linalg_slogdet_ops.h,sha256=GIal3jSpDNwRV1dVI-eLgMcyD_n_v3ce8i5ORGIJKhQ,1914 +torch/include/ATen/ops/linalg_solve.h,sha256=R1BZtLI3WSriTfyvwoD1qyVjNAvCaoO7KjJ1cwk1nOA,1317 +torch/include/ATen/ops/linalg_solve_compositeimplicitautograd_dispatch.h,sha256=cL4nenT7ZuTma2yqKva9csxuO5H9DY3wkKf71ES-OB8,1064 +torch/include/ATen/ops/linalg_solve_ex.h,sha256=lpaIiB3_GIiOk0A2eF1SMHNJWXjgT0FjyDj5JomHdrM,1788 +torch/include/ATen/ops/linalg_solve_ex_compositeimplicitautograd_dispatch.h,sha256=CcygaDoyBU8CIcTD9jSXL7UdHo19p3_CYZSc8CD7gnY,1265 +torch/include/ATen/ops/linalg_solve_ex_native.h,sha256=GmaCR3wATY0QxrAPGSE8cP_YnqHcrIwtyXGqwHNVt7o,777 
+torch/include/ATen/ops/linalg_solve_ex_ops.h,sha256=v5FePagloFSiS7X8YDsgTMMu2Dq0o_m8x8vm66JtPsQ,2282 +torch/include/ATen/ops/linalg_solve_native.h,sha256=O5eNw9u9ezNs8MjV-HdzzfHZLta7EWZiA1W0NAzt3PI,653 +torch/include/ATen/ops/linalg_solve_ops.h,sha256=WY-Kcj2zd_uLMPoxahn8k4NWMhBNPQs7FUWpr-M2Fs8,1844 +torch/include/ATen/ops/linalg_solve_triangular.h,sha256=9q1I9kQaQo1SBZIIKe4m2DeJd-Gxw7FPaV03MEJjQts,1742 +torch/include/ATen/ops/linalg_solve_triangular_cpu_dispatch.h,sha256=rkgYTebF6HZHnx93hQ6R0GJWwpwU9fDROzYMu2M7Qec,1170 +torch/include/ATen/ops/linalg_solve_triangular_cuda_dispatch.h,sha256=2ztxlI0qElFPWzVcYfUqRxxnz1M8qlHti46Uo6qmN2g,1172 +torch/include/ATen/ops/linalg_solve_triangular_native.h,sha256=6lQnb9drUYEhLMib1cW33qF7QwudGlmbLjtSjCpMHRA,751 +torch/include/ATen/ops/linalg_solve_triangular_ops.h,sha256=q31qNu6LfG1ww5YkU0rk1UdKQOBQSy4R5ZXaWvuocl0,2156 +torch/include/ATen/ops/linalg_svd.h,sha256=DX_9aPId2eH2KZZYndm6cz9nIGWAkF8nWlCmBItcVfY,1832 +torch/include/ATen/ops/linalg_svd_compositeimplicitautograd_dispatch.h,sha256=h5VuLLtQ6E0Q3MzvBU6xGp4ggLFLgAo7IqxUDBNmufU,1353 +torch/include/ATen/ops/linalg_svd_native.h,sha256=Y6Zbar6fmcPSwj7LE9gwKj2pHJkzjZwruy-NK6ZrdWM,829 +torch/include/ATen/ops/linalg_svd_ops.h,sha256=TyShiMEsnfnSbWY-lgEOltST7NFGxIGmEqzrsCPH8nM,2409 +torch/include/ATen/ops/linalg_svdvals.h,sha256=oIOXnfgDjPLIk3bEV0d0cJjqrtnTV0I68PKJguRA7LM,1357 +torch/include/ATen/ops/linalg_svdvals_compositeimplicitautograd_dispatch.h,sha256=JVXjFjIs0DAgHv8jT1Nlas7NAaYw80bKz0qGyQGzT1Y,1117 +torch/include/ATen/ops/linalg_svdvals_native.h,sha256=ojzrVF6FQZ2ZBhIVidFU0l0aY7yTJatZzYWwfWQU6Ys,685 +torch/include/ATen/ops/linalg_svdvals_ops.h,sha256=CdfSg8HiE2njXjSX8amLIJeYK8VTInGkxwZ2qfsHUTQ,1894 +torch/include/ATen/ops/linalg_tensorinv.h,sha256=4FZcWTGHUjZIxEZINSXJxPkCeNYtdrt7JMdnomD3484,1258 +torch/include/ATen/ops/linalg_tensorinv_compositeimplicitautograd_dispatch.h,sha256=cbKHMCgE8GnwN-C0Ovr2Vqq7kF5pYcweRZvQ-mkaxtQ,1019 +torch/include/ATen/ops/linalg_tensorinv_native.h,sha256=SYyS0FIXxFgoAyI50Ih8impX8PvbvFlG6XdMiLf40Bc,624 +torch/include/ATen/ops/linalg_tensorinv_ops.h,sha256=-RZEjoqq7bz9KtFlIufxeoeEU7klllZwkA9QBT_rA7w,1739 +torch/include/ATen/ops/linalg_tensorsolve.h,sha256=dbWvkQ2hb-qPNd4TDEpPhMDSgprtKcVzv_bBgJv3UAw,1520 +torch/include/ATen/ops/linalg_tensorsolve_compositeimplicitautograd_dispatch.h,sha256=vJ8vw1oOy8TQBus-7PCyzYBy6lRYE3NLvyNFe2hmEPk,1180 +torch/include/ATen/ops/linalg_tensorsolve_native.h,sha256=v7i8PA2lP1JzZ1DngLetoApR8_SQNMcoUYkxyxJon0c,727 +torch/include/ATen/ops/linalg_tensorsolve_ops.h,sha256=yFKFsT1vXTRzTNxCKIN5o_UZPtRvhB6klVGn81jp4es,2037 +torch/include/ATen/ops/linalg_vander.h,sha256=0Z0wqovbfZC8ft8PUcBBEZ648nAHxC7xn6uRFbBoDEQ,1634 +torch/include/ATen/ops/linalg_vander_compositeimplicitautograd_dispatch.h,sha256=Dmx4NaOZZ5g8o_y9taSZBSKhTtZRkcT1sj5FesJLc-A,947 +torch/include/ATen/ops/linalg_vander_native.h,sha256=f2anfXt_SH3CX_2MyMtYvnem6576ckjPQntTOJ3Vduo,555 +torch/include/ATen/ops/linalg_vander_ops.h,sha256=fADhSuxQ7NJL8b7MdomeXjcSEjj4fuOhz5tJ0tkS_g4,1111 +torch/include/ATen/ops/linalg_vecdot.h,sha256=CkoTypidj5uGKM84wGoPL_vm8KuuMrjylIPNfqNDwnQ,1314 +torch/include/ATen/ops/linalg_vecdot_compositeimplicitautograd_dispatch.h,sha256=7FhuvPZT2c-JH6FKu9KWn6prRTxzZspVa_Pw-GOPk8A,1069 +torch/include/ATen/ops/linalg_vecdot_native.h,sha256=OKYZIng8f_X0o6qIsNZwVQVuvix1mKMh2dmXH2crnPU,657 +torch/include/ATen/ops/linalg_vecdot_ops.h,sha256=RMNLJWVkNq6W8ssdU3X6g3EOOprvAecTDJhXSOtMKtI,1856 
+torch/include/ATen/ops/linalg_vector_norm.h,sha256=UEWQ4PSgG-A98yASGpWNM-EZF9iHegvxMU41ejCLoEs,1890 +torch/include/ATen/ops/linalg_vector_norm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=9f7b95MVpSLRqh98UN6drY-k93sUXaj4Nz1U_DXIo10,969 +torch/include/ATen/ops/linalg_vector_norm_cpu_dispatch.h,sha256=JrMKWdXNp3I7GZjywxmbmOrxfzL3Evfsu5dbSqYaWcs,1332 +torch/include/ATen/ops/linalg_vector_norm_cuda_dispatch.h,sha256=HdqL_p4S95pno05c8sqMXOgY-Z_0RWu40_VI7DG7Pbs,1334 +torch/include/ATen/ops/linalg_vector_norm_meta.h,sha256=zYkmjaaw_Z2zHE9BR55JdP-aypFQoyWP_2_D9192MVg,712 +torch/include/ATen/ops/linalg_vector_norm_meta_dispatch.h,sha256=h5qYEQ7Nf3QlmzQRPrHywffNbc8HKzZ2-APCeOdI5HY,1334 +torch/include/ATen/ops/linalg_vector_norm_native.h,sha256=oM-ds5-sWXi5FN85SJcaQSW227fLGC41ks-7M_X2N0M,761 +torch/include/ATen/ops/linalg_vector_norm_ops.h,sha256=q90EC4iyo1jbPYymwsmBxBwuEY8jCDYHajvt-jHvJPg,2406 +torch/include/ATen/ops/linear.h,sha256=dVqRLtQnhgF0lKcsYYqtHuVTKiipTUJM5jWcXRRPrYg,1433 +torch/include/ATen/ops/linear_backward.h,sha256=uqbHW7oUPSQlMVr8sbFAKCdEeM0eBFkuvZthwf10YmQ,2066 +torch/include/ATen/ops/linear_backward_compositeexplicitautograd_dispatch.h,sha256=ziZIHjcPwk-5JxJvEUgFhgd0EFn05OnWadzo_ddCvD8,1253 +torch/include/ATen/ops/linear_backward_native.h,sha256=SB_cFa7znQPxUb3eC-qLKMrVgzMCxHwtyGQx6GUd-gc,902 +torch/include/ATen/ops/linear_backward_ops.h,sha256=GB6aXKbzIX3vTYgECaM_ONVkQ5qpw8R4plnzH8lVfxY,2662 +torch/include/ATen/ops/linear_compositeexplicitautograd_dispatch.h,sha256=KkG_vSff22DFNODvFKnAVZLieWW3KqgNJHoSRqph9s4,1034 +torch/include/ATen/ops/linear_compositeimplicitautograd_dispatch.h,sha256=GS-fSZveweBPNlukOhlYvgxxQjQ4MiJfU6jt16_mJKk,860 +torch/include/ATen/ops/linear_native.h,sha256=_PTz7OvDzZwYkvxIOiKBgx5Qg4IJ1YOd6kPaSUcOIfQ,854 +torch/include/ATen/ops/linear_ops.h,sha256=mXLUUYQ1McjbVyMT3Dbm3vqoUWirOk39OG36f552qOc,2051 +torch/include/ATen/ops/linspace.h,sha256=5lL2yRuKUhP4-ya2YuWWywmC-WO_M5SSP1l96fZ25UY,6974 +torch/include/ATen/ops/linspace_compositeexplicitautograd_dispatch.h,sha256=GSTWLMPTsudD9TpvAx006CHe8t1NBFyO4yDVxqb6arY,2937 +torch/include/ATen/ops/linspace_cpu_dispatch.h,sha256=JVHrj_cl7YfZi7g2UntlGqaWZlEbe7VyGtXbpgoBZRE,931 +torch/include/ATen/ops/linspace_cuda_dispatch.h,sha256=TFIjorpyhCO4h46OpwB_iHVVGMoei5Z3f93EfDA7NGg,933 +torch/include/ATen/ops/linspace_meta_dispatch.h,sha256=SW8Ro1fApvgvt2WlBQYT_nGYJfhu7prK-2gkpAMb3kc,933 +torch/include/ATen/ops/linspace_native.h,sha256=md5O3kt53L_Nsaf0RFZKqRauBMQzpI_4KnH9QV51VXo,2068 +torch/include/ATen/ops/linspace_ops.h,sha256=x05-lECw6FqsBHThimD_sRLpPWqKQiMYr9VMOwhonV8,8268 +torch/include/ATen/ops/log.h,sha256=x56LVBDequryUiqKx6HoFojGFzH2Un4LHRPiYc0xGmU,1175 +torch/include/ATen/ops/log10.h,sha256=OoSPySFNyXim6fWlpQcRySAEHP_Rhec96tJr-9PR8Vc,1201 +torch/include/ATen/ops/log10_compositeexplicitautogradnonfunctional_dispatch.h,sha256=m-T5Lht5kV6_p_s2eDf8LSjIgo69aRXCZb0Yy-iXZYU,863 +torch/include/ATen/ops/log10_cpu_dispatch.h,sha256=IS5RSglcVkdiGs6-vbhhDpeyUINKtwz_NXROsukNxRA,950 +torch/include/ATen/ops/log10_cuda_dispatch.h,sha256=FUWlAd8aCXJIQqAHn7iqFyifYpmZLLqDrlLsJwcNqLQ,952 +torch/include/ATen/ops/log10_meta.h,sha256=jnOHBXjKx5QvXKjADu-K8zgCC5ld3QUs-J-WZDESWno,593 +torch/include/ATen/ops/log10_meta_dispatch.h,sha256=Zeob0tpDyHluSkPB8xQUbT9LBfXsApYZpd8RHYTFrJU,952 +torch/include/ATen/ops/log10_native.h,sha256=7FxC1V7yS3wUoI79XxzzYGH7UEz_slUCZsO-r6I-SJ4,616 +torch/include/ATen/ops/log10_ops.h,sha256=3WpW84O5Gnb6Z3JYKittGt4pPjH6Eh3qt0W__EYP4v8,2088 
+torch/include/ATen/ops/log1p.h,sha256=Q1TCL9u4Dk7ElQPf0COqilNzT3B0ctwiM1coOLFBSOo,1201 +torch/include/ATen/ops/log1p_compositeexplicitautogradnonfunctional_dispatch.h,sha256=XW20tRc0PEjr8az4YhS-sBYIPXGLXnLT_4UHe61_uJU,863 +torch/include/ATen/ops/log1p_cpu_dispatch.h,sha256=AR5G4slKywedUcTfF3pmiKMKVAckcm7nBWFt-KnONH8,950 +torch/include/ATen/ops/log1p_cuda_dispatch.h,sha256=wp1tKHb0eDOsCHsCoY-Guq3Y-tdOurG5qeVqijHbRkY,952 +torch/include/ATen/ops/log1p_meta.h,sha256=PxhFWjN5Otc4NuffVfkNNNJPgmZocHPxviQAikh40yQ,593 +torch/include/ATen/ops/log1p_meta_dispatch.h,sha256=1zcZdWA_yAD6k21971SEorN_FcqdV-JOuu1OOEgbkTs,952 +torch/include/ATen/ops/log1p_native.h,sha256=A-ZtkIboVh65g07-JwjS44FFh6K9JH8kI-6mHDQuGKc,1036 +torch/include/ATen/ops/log1p_ops.h,sha256=6FD2dRxE9anf0N3D1OaOHzv35F8jpWJTgFI03bKmHTw,2088 +torch/include/ATen/ops/log2.h,sha256=d9tjqRaoe08Q2ZXgaTLvPIm9Y7v6W1HFMOV70qoBvGo,1188 +torch/include/ATen/ops/log2_compositeexplicitautogradnonfunctional_dispatch.h,sha256=9lOoZAV3sWO4YrA9FnS4t8JvXNYWIpHcZR8eM_CBqAM,861 +torch/include/ATen/ops/log2_cpu_dispatch.h,sha256=sR6n55fpzvroxC_QAoQ1uKML55YjjTmptpPGFT7UZ-k,946 +torch/include/ATen/ops/log2_cuda_dispatch.h,sha256=Wimy2puCD97ccRiHNAWiBdW1Iq6wwI_noCOJMGMl-Fs,948 +torch/include/ATen/ops/log2_meta.h,sha256=bYofsjml5kmu5_r5UEMtLyuRAG6bwIs22xi7sGXRu_I,592 +torch/include/ATen/ops/log2_meta_dispatch.h,sha256=DgWEsEn34IWNsjVCaM0gDc3HDoSXHmzwg1ubJB4PAP0,948 +torch/include/ATen/ops/log2_native.h,sha256=J3-510l8O_KX64FYYTYRw2UmA62h-FPHet4AYUnKmy4,613 +torch/include/ATen/ops/log2_ops.h,sha256=A5C_DcptJRXJ14gxipnmPwAcazBG9GtL4afM4-biK0Y,2079 +torch/include/ATen/ops/log_compositeexplicitautogradnonfunctional_dispatch.h,sha256=LDvHxonRyHLAhaen7Jf1fGUSp6Acj1qqWWkRHGQLxME,859 +torch/include/ATen/ops/log_cpu_dispatch.h,sha256=ZA0Fbc5bo-n0I6-zlF192wDuK8PttCJNrfzs0gZNE84,942 +torch/include/ATen/ops/log_cuda_dispatch.h,sha256=1faZ9WEzwr0zM3fN51Zmr3RiZseFX4nUVdtTeweQMfM,944 +torch/include/ATen/ops/log_meta.h,sha256=L_PFgR8urg4xhGS9prRAgzynZCf7plV_8LClPDUfp2g,591 +torch/include/ATen/ops/log_meta_dispatch.h,sha256=y3VyrU5iUxElUX7Dw5qDBdBWzgkiIhUwJDyvkwcpIDw,944 +torch/include/ATen/ops/log_native.h,sha256=T0o12vzESKwtYpl2b_NkN0UdSgNjNRkuKp07YtOMHOs,610 +torch/include/ATen/ops/log_normal.h,sha256=k3_w5WnKaNxX_oi2Z0FPXyJJD479cYpg6SpmDPASYJ0,1577 +torch/include/ATen/ops/log_normal_compositeexplicitautograd_dispatch.h,sha256=FI2U-C9FCdCgKFn90Mde1_9FalTEOkvwBWanrCfgOE8,1197 +torch/include/ATen/ops/log_normal_cpu_dispatch.h,sha256=uPY9GCtsZkjx5HTlcuWzY5PB2NC_18qLJrtmYRaz26Q,830 +torch/include/ATen/ops/log_normal_cuda_dispatch.h,sha256=W9VWcEEMxuVYurwUFZ2kFwc1YCkkGLTin4BzeckXY8Q,832 +torch/include/ATen/ops/log_normal_meta_dispatch.h,sha256=lFHCEk0Y4Sbo7Q3KX4lHmumwjWdBLLHqWwecx90CUmI,832 +torch/include/ATen/ops/log_normal_native.h,sha256=yE6BkYRLaFdPhEENuXuDAkkGCP2VP0VVpF5QHjdD0zs,879 +torch/include/ATen/ops/log_normal_ops.h,sha256=npHD7Rsw_L97KDMzp3U2OfsmRlta7nWtc7C0wMqgwnk,2847 +torch/include/ATen/ops/log_ops.h,sha256=KANPrAxfPj92WHZplBpJ3bxnftG1oicEo31xR3QhnFc,2070 +torch/include/ATen/ops/log_sigmoid.h,sha256=_GivhvqO2iCil3ZFHsGzgy-RqxFp-Hp87p3mlSKMkl0,1117 +torch/include/ATen/ops/log_sigmoid_backward.h,sha256=CwRe6UEklPI1lyzZ05OmPuEpeRBd7roqomBgTmBWBgM,1622 +torch/include/ATen/ops/log_sigmoid_backward_cpu_dispatch.h,sha256=w389G_vvr8WMbDfILS6mrA5iAtPoBcffFPNPqwcqu6Q,1135 +torch/include/ATen/ops/log_sigmoid_backward_cuda_dispatch.h,sha256=ug6k4QhII2xSdE8Vspzn7JWPTnXmWORqkiUPOLQlA1w,1137 
+torch/include/ATen/ops/log_sigmoid_backward_native.h,sha256=LSfYeJDNcR57usEq7IsRVM8dspe1YCujk5WYEFk7ZSw,1034 +torch/include/ATen/ops/log_sigmoid_backward_ops.h,sha256=wmM2t8Cso5aEf1-GTLmIjVzmct0umCbYXe-lb5R1UTQ,2099 +torch/include/ATen/ops/log_sigmoid_compositeimplicitautograd_dispatch.h,sha256=gOZuTTko3JdjIJJcoS2OcZNKPrIQ9WPBEeMPcYZEMZs,961 +torch/include/ATen/ops/log_sigmoid_forward.h,sha256=cgwW4ijTyoD_b_tnENi26iLT4GWS54Km8L_kNCx01ng,1454 +torch/include/ATen/ops/log_sigmoid_forward_cpu_dispatch.h,sha256=P_9hSV5NXT-D48fOkC0ZQi40YrFnpCjNFjdwm_0J-88,1068 +torch/include/ATen/ops/log_sigmoid_forward_cuda_dispatch.h,sha256=deNvfbB75RZp6y-2Nsp-Kd66lVQqDyDruyORNNgRGZU,1070 +torch/include/ATen/ops/log_sigmoid_forward_native.h,sha256=HRCOBKCVyuLmpi7Woq5eC8zhqJjoD3BFlRETChy1z5U,932 +torch/include/ATen/ops/log_sigmoid_forward_ops.h,sha256=7_9DhmEXv7ll7HtdLNel3QqEecaROW21zWNi4YD1xpI,1952 +torch/include/ATen/ops/log_sigmoid_native.h,sha256=r5nyMuZLLCTAXh4xOknN-COu4vO3K5ORpEyQnGcimJw,586 +torch/include/ATen/ops/log_sigmoid_ops.h,sha256=Fci_nG1usB6bVGWP3YZSXxE1utFoYmbQcv1hLRuRGGU,1617 +torch/include/ATen/ops/log_softmax.h,sha256=6auiyKwIu4Lh86Q6GkrOXDjmOpjz2Wbysv0pY0URlME,1756 +torch/include/ATen/ops/log_softmax_compositeexplicitautograd_dispatch.h,sha256=UukHdvEvMZ7fmXQTbYeHSHDm-7JFy4HkfUvd7z04hkQ,1020 +torch/include/ATen/ops/log_softmax_compositeimplicitautograd_dispatch.h,sha256=aTTMpH0L6h6pe5CdeOitmKqlVHZ53isQp4dzwHeWEmM,990 +torch/include/ATen/ops/log_softmax_native.h,sha256=SWWJqLim87ykg6BjtUQXr_0AAie9m8CsrYgjq2zJ_9Q,836 +torch/include/ATen/ops/log_softmax_ops.h,sha256=LOJmPWopx78XYlF2A-nlnQ2HgTOtwBmazVNStkD3ud4,2745 +torch/include/ATen/ops/logaddexp.h,sha256=LQVbUuUF9v85wm9h8JTynVzhuMxp2o3ZxeVMzBUS5CY,1238 +torch/include/ATen/ops/logaddexp2.h,sha256=48wNQblJVqefloFX12Cpiruf0wtY5zHb90OEY2ERuj4,1248 +torch/include/ATen/ops/logaddexp2_compositeexplicitautogradnonfunctional_dispatch.h,sha256=_bRu_kGwj81bPDAfBznZwAr2UnCU91H3wBarrIj6fow,843 +torch/include/ATen/ops/logaddexp2_cpu_dispatch.h,sha256=DOnYEX6DVT4QyiH7rNbrXn4n4nTAQpUeD3UhC3zTFIA,992 +torch/include/ATen/ops/logaddexp2_cuda_dispatch.h,sha256=lBEiTzjShIOdgPjvbqNYV2rWiyDBTjU4FvEGahJQYMU,994 +torch/include/ATen/ops/logaddexp2_meta.h,sha256=yIdU0sfor8ueW3wJBZV-TrDhR-Peid7wRKeOcBebawk,624 +torch/include/ATen/ops/logaddexp2_meta_dispatch.h,sha256=eQHbromfamm9LNpz7RM7-AWXt9sy87pdHTySpwvrMvQ,994 +torch/include/ATen/ops/logaddexp2_native.h,sha256=japKiIRPSvLRT80dHOAxwozgnQ9AgQa6YnldufSY1QM,657 +torch/include/ATen/ops/logaddexp2_ops.h,sha256=x64e60_gKc4eVlbBMfjsEbAJJV0x6QOz_ZRvk5fNYMc,1783 +torch/include/ATen/ops/logaddexp_compositeexplicitautogradnonfunctional_dispatch.h,sha256=r6ZEkNIW_O9SJkEtsq7LF3xKDfN0F2J61EHHCreqBuw,842 +torch/include/ATen/ops/logaddexp_cpu_dispatch.h,sha256=J3KMv4X1nMvzvPLuddqxTIbpKjzHHWBbeDwIdfYdUJU,989 +torch/include/ATen/ops/logaddexp_cuda_dispatch.h,sha256=vVgynp6ss53EUIfrOlJ017afE9rpMXIrmj-vAuc6awQ,991 +torch/include/ATen/ops/logaddexp_meta.h,sha256=E42UqlgGP0dXK688t_fo8WcLDGSpNRjHeW82EL6m7nk,623 +torch/include/ATen/ops/logaddexp_meta_dispatch.h,sha256=i1_lVQhSSjpPh3dQ0S08grECPt7lf92r_xM9Y5DWWDs,991 +torch/include/ATen/ops/logaddexp_native.h,sha256=Sb6LkfQK_8jn3ikjmr2bEfyWMkPhjESTjy9V3y7qCQ8,654 +torch/include/ATen/ops/logaddexp_ops.h,sha256=O3ulGARRlSZ50j0S8cise_gY2DuZXmpB2Cpf1RyL9Es,1777 +torch/include/ATen/ops/logcumsumexp.h,sha256=BXNKj8q8CKZNMUZdr0EkKaHc9uehH7lIzHlGhlPLh4c,1954 +torch/include/ATen/ops/logcumsumexp_compositeexplicitautograd_dispatch.h,sha256=9khBfJcVPbyJEAJL8mSEsIOGYBJSIyo06XaaSf7MGxo,1003 
+torch/include/ATen/ops/logcumsumexp_compositeimplicitautograd_dispatch.h,sha256=WCmx0AKiEU29tkWUz4mlnvQe7RFG4YLnp1DN7a8Dr7c,1015 +torch/include/ATen/ops/logcumsumexp_native.h,sha256=V9GBB8f1Zu9wGMTlmeANDb5Sv2jRS5M57DiJhcSUzI0,794 +torch/include/ATen/ops/logcumsumexp_ops.h,sha256=fLxG3I5x4swB1NFUEN5tsAuYnGysBT3PvG_nFRgYxnQ,3026 +torch/include/ATen/ops/logdet.h,sha256=l5locCdWAs6_2_BxDvn-C97mbwrnoOXpkSgRYroqKcQ,668 +torch/include/ATen/ops/logdet_compositeimplicitautograd_dispatch.h,sha256=ll6mmji_0ZuEuaAv8c-RB61n6vHoA7pzuHXm_HGnnZI,787 +torch/include/ATen/ops/logdet_native.h,sha256=-aBcRd3SPcX462xBqnWNzbRO2BtFSD7lY-R7IrPnOcI,497 +torch/include/ATen/ops/logdet_ops.h,sha256=jzyQ4CQOkFFoh5QIl89hF_RbabHzIXQennqMKGqBAXk,986 +torch/include/ATen/ops/logical_and.h,sha256=FLzmZMf15UVFyZSjFQ2YEu5pjHQ45Dp3WvQgbr0xfUw,1258 +torch/include/ATen/ops/logical_and_compositeexplicitautograd_dispatch.h,sha256=JTdsIPhezL2BGb_CH08uENjXW25f3JlsgORWNj10B2E,901 +torch/include/ATen/ops/logical_and_cpu_dispatch.h,sha256=G2NmF13Bv5jCSjy_NeTmAywJqL4hRYP4LzeAZNfK8t8,909 +torch/include/ATen/ops/logical_and_cuda_dispatch.h,sha256=vLFhEeFJTfgQt09snnRjc-0pOmsODIzVlBy_B7Cykew,911 +torch/include/ATen/ops/logical_and_native.h,sha256=9-QQSHSpEu22ZdpsnU-qUubf0pAmURCpwCG1SAU7EMM,721 +torch/include/ATen/ops/logical_and_ops.h,sha256=PLLXiZBnSvjdslezYmhySANp8Xub22m0XRFGmrLZ_wo,2400 +torch/include/ATen/ops/logical_not.h,sha256=JGA20IEHvCxjFOjFAory6cejVbx7TGQxVTDDCl5_Hf4,1117 +torch/include/ATen/ops/logical_not_compositeexplicitautograd_dispatch.h,sha256=eN_zb5FF-w8wa2kmMtKJ4VDTpMjhA0Wr4nR_4D408xk,849 +torch/include/ATen/ops/logical_not_cpu_dispatch.h,sha256=-0NeXyOgPP3X0hmORh4cMSqGzCbN5xvxr6rcaIM9IAs,857 +torch/include/ATen/ops/logical_not_cuda_dispatch.h,sha256=zhst5qE49Iqh0UZtW8qayyuHikhV_OtcRfJMRxUH90o,859 +torch/include/ATen/ops/logical_not_native.h,sha256=TtQxMLeHOLTwyf-1absEkEGsa8AOcW2ARYTOprMH6_4,786 +torch/include/ATen/ops/logical_not_ops.h,sha256=nijTVUOCVxlYhapuBuBlA_7-wrxofILhQ3DlmRd32N0,2142 +torch/include/ATen/ops/logical_or.h,sha256=KI-TPR7zYXfBL4dnsIxTCSLbPYEDg8F_R2cjikdjI3s,1248 +torch/include/ATen/ops/logical_or_compositeexplicitautograd_dispatch.h,sha256=9dDxtNwXtx60f4G0oHDjti_KvUGcVy-8A2yXfMgdwqI,899 +torch/include/ATen/ops/logical_or_cpu_dispatch.h,sha256=YepEaEDNWaZp9FpuFP4WA0OLQtZryN0ivrSTbwH2wg0,907 +torch/include/ATen/ops/logical_or_cuda_dispatch.h,sha256=S1xJfRroy2IkBNJp2IzuY9ewZV9VdiHteZpXIb9k9T4,909 +torch/include/ATen/ops/logical_or_native.h,sha256=gJMNSL24lxcP3a5IC1qyuhO8JAIG14J7Wr5dgmq_9ss,718 +torch/include/ATen/ops/logical_or_ops.h,sha256=0B0VFZKZS_ZWz0aosRE2kLhyg9_Y5zWms643E2GtT2g,2391 +torch/include/ATen/ops/logical_xor.h,sha256=TqKs4-xpJu8eg-Jbpc4cyjma0A4B1ApwHA5qKvOWV7o,1258 +torch/include/ATen/ops/logical_xor_compositeexplicitautograd_dispatch.h,sha256=DCSwtgLLBnSBChasP-0ei23_xcWaZxh9xRznYYZohlI,901 +torch/include/ATen/ops/logical_xor_cpu_dispatch.h,sha256=X_zoJQ3dLZEZ55TkQtcwnbZd1KlzM8E0j-03rvrfDuQ,909 +torch/include/ATen/ops/logical_xor_cuda_dispatch.h,sha256=_iG2mZcP7UrQH-_QyVLfXviDikhLEFr06PykWzA2qjk,911 +torch/include/ATen/ops/logical_xor_native.h,sha256=b2sEIwUZ1RZk0lk_yLpc6epOsSW1403npQcoFc6pX1M,721 +torch/include/ATen/ops/logical_xor_ops.h,sha256=iX3ZT_eLSl43_SzkD22kA7gnth8dbed3usy-gbOYgkc,2400 +torch/include/ATen/ops/logit.h,sha256=8XXCFE6wmZBx4-kNMv7NsZ-G67ByxVCBViMXN9W9sxw,1450 +torch/include/ATen/ops/logit_backward.h,sha256=4_DBsMw1dyNlbbuOWzmXpnK3O4NiUfqVmBKKGyldYqs,1595 
+torch/include/ATen/ops/logit_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=otuIVGsE9SpEmQhSyAA352bQ6MtgZdrE-vHoJrY8mYw,897 +torch/include/ATen/ops/logit_backward_cpu_dispatch.h,sha256=xBBRm0ICDBLNwCbvGr68S-ZQiXqd5K4HSczC1MyFGog,1153 +torch/include/ATen/ops/logit_backward_cuda_dispatch.h,sha256=uRpzMglkMAKYmRK_ccKzfJQx_CL_IBs68hC9FK58D6g,1155 +torch/include/ATen/ops/logit_backward_meta.h,sha256=8H6KMwoYTRG8h5hoBAxLrt7gHr7AatmLlYYDpoKXh30,663 +torch/include/ATen/ops/logit_backward_meta_dispatch.h,sha256=ezw7UGOr51vTrzuXUejPdg_UmvLDVeQIfF7FR-BAaDg,1155 +torch/include/ATen/ops/logit_backward_native.h,sha256=fItYvMEx9nSVwZDdGmhiSLLpT5_VdUUKdv4jia-nyIU,711 +torch/include/ATen/ops/logit_backward_ops.h,sha256=TstPrdeNPuHjt8-iS3Z0cxLJFKCUcim8yTEACVOgDiE,2085 +torch/include/ATen/ops/logit_cpu_dispatch.h,sha256=AnAmNi1cE9EkidgK87xa97xLUYBAYYKrKt5l23MDHjw,1111 +torch/include/ATen/ops/logit_cuda_dispatch.h,sha256=SgCMhzm2bL8rX4uPln35cZHl072mIYytwf4kgweeUgQ,1113 +torch/include/ATen/ops/logit_meta_dispatch.h,sha256=TjwHhFT4b-HLtD3H_MHsaIVvFhYoYM3D-ipsi2aeNXc,785 +torch/include/ATen/ops/logit_native.h,sha256=Z4rY57caPltET-47xsxCUjvx23qn45mzzYzBvqi5VYY,742 +torch/include/ATen/ops/logit_ops.h,sha256=8hU3aqMiBSpoZyLA05E5OJzdRU7FuD5IDLFmGne3Bls,2388 +torch/include/ATen/ops/logspace.h,sha256=2jHOSf4_zPrJDPu2IIi5u0nriMn7cVO7czSEqWnTogk,7590 +torch/include/ATen/ops/logspace_compositeexplicitautograd_dispatch.h,sha256=pZdtvNemo-Uu4T13kTnDMbrgf-6YKX3y1YMmxu8lHcg,3154 +torch/include/ATen/ops/logspace_cpu_dispatch.h,sha256=xQ-ekKZMpFmvwp65krkaQyCRX2_68r-c2qKZ5_FEtps,962 +torch/include/ATen/ops/logspace_cuda_dispatch.h,sha256=hNxvc8l3tsaP1af4qWbkLwabEZKvKsYmSFIixXi0Xz0,964 +torch/include/ATen/ops/logspace_meta_dispatch.h,sha256=3k1mpi63_q4-5qWmvbjNsTO978WwaewmnWsw0TPvfUY,964 +torch/include/ATen/ops/logspace_native.h,sha256=eLwVSfK9ICVS5n9Y2_JjCSTMRfkwMj5LYhNsEUD-KR4,2205 +torch/include/ATen/ops/logspace_ops.h,sha256=fwJ3rzeCljFYzjl7g5lxUZvTAZ3Zsb6yEkIBTVrHS2M,8676 +torch/include/ATen/ops/logsumexp.h,sha256=_IryyfZBVKQZGs9SrzzJX9ALjNd5CfNWEu_eQWVjpMM,2221 +torch/include/ATen/ops/logsumexp_compositeexplicitautograd_dispatch.h,sha256=ypPcTNoS9GXGl76AZS1BSDiPbCLrMdKOPhA_xW5Fqto,831 +torch/include/ATen/ops/logsumexp_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OLo0l7KHFMVs1GRquhk4lAz3EhKhRYifWT9ETEZnR8s,999 +torch/include/ATen/ops/logsumexp_compositeimplicitautograd_dispatch.h,sha256=UkryT07Ooiwmb-Byki2RG7TqOpQmM0rmLQNRxO4enYE,1072 +torch/include/ATen/ops/logsumexp_native.h,sha256=fySyJmwAzPxnGcgoSm_Da8qJMTrqiNM6KRWOiv13K4Y,874 +torch/include/ATen/ops/logsumexp_ops.h,sha256=329JM-QkTXhPYYhom0tL5X6lHHXpzU3B5gVouSxPDXY,3278 +torch/include/ATen/ops/lshift.h,sha256=Bozs57AlZnXjXRRFMQRh9RzmFhRKqjz7LK6HxKRCLRg,2044 +torch/include/ATen/ops/lshift_compositeexplicitautograd_dispatch.h,sha256=IVFjsHQ-xrShrrWgGQrm2d3YbXNTfSKQTWGlCmjAEW4,1170 +torch/include/ATen/ops/lshift_cpu_dispatch.h,sha256=HLXZhQNRhSpma1uD2bbYnX35XFkPcAzhw0ssp7k34Bg,1022 +torch/include/ATen/ops/lshift_cuda_dispatch.h,sha256=jU2BZ-U59cvO8Dibc7bGWVcy-9v3QoS7Jq5AxhEZB8I,1024 +torch/include/ATen/ops/lshift_meta_dispatch.h,sha256=gyQmMB8g-91bG2N5WR4nBbEC8c6XK4GeQ7ZVnXt9A7A,854 +torch/include/ATen/ops/lshift_native.h,sha256=eElf2E15fHxcoiOylulRbmYYnOz--yh96ytkdCSl1I0,1008 +torch/include/ATen/ops/lshift_ops.h,sha256=pWfdHbD4uMUpEpBcvPVmGSdH47uB7_aKomrPzHITUc0,4429 +torch/include/ATen/ops/lstm.h,sha256=a_rtbOAxkP0NROhVWP8oGEwDnwALv34N3NOyWAeMJWk,1649 
+torch/include/ATen/ops/lstm_cell.h,sha256=fDczyMxAchghz3j97NQ9Y6e0uOvufVG3VrCAwt6uWOM,982 +torch/include/ATen/ops/lstm_cell_compositeimplicitautograd_dispatch.h,sha256=vSDw6OHk196zhYNqUmKyTP6jtsYO4NJtXekBlPGRSOg,975 +torch/include/ATen/ops/lstm_cell_native.h,sha256=QzhU4E_NV6-NMjiDJU8yWrkyj9TgiCfVP2IWt_RW_Yk,685 +torch/include/ATen/ops/lstm_cell_ops.h,sha256=GhEerZbpEAiogR2dugwNycHGOsDe1jHLx5YpcN4l3FM,1596 +torch/include/ATen/ops/lstm_compositeimplicitautograd_dispatch.h,sha256=cZ7MZKuEtdSca9qIpkhSmUN7A-qzBQMU-W7GoRvvuog,1215 +torch/include/ATen/ops/lstm_mps_backward.h,sha256=xI6Jh7yMdzUrrFSG8Obj0ghFu_pgU8cMRzhy2aYwMiM,3715 +torch/include/ATen/ops/lstm_mps_backward_compositeexplicitautograd_dispatch.h,sha256=2jtNZLCm-YQGPXZjyUiETrST9V3q-eURdfRtaSCDMGk,1737 +torch/include/ATen/ops/lstm_mps_backward_native.h,sha256=Ph5Ou-mRd7fQsqhN9JD9ghVjQ6SGDgm_BFclWlZ7QCs,944 +torch/include/ATen/ops/lstm_mps_backward_ops.h,sha256=yuiGesv-xK-NE1C7zPNXeQkpat-Ur8BmCMF4CF2TIVc,4474 +torch/include/ATen/ops/lstm_native.h,sha256=EWd8gVe_kattHnvY2h8yMjXmFNz1Rcw-4cPNpVIhVKY,925 +torch/include/ATen/ops/lstm_ops.h,sha256=hwpIk-Ki6JFIPbJVigoIJlNlCzxk2s99kRYGnUFBxw8,2797 +torch/include/ATen/ops/lt.h,sha256=8Mzz4YLt95j2dpW6uzh8ChaaAim5IIBcXjr3Uv24qUk,1896 +torch/include/ATen/ops/lt_compositeexplicitautogradnonfunctional_dispatch.h,sha256=cDcw01H_DNWYfPiJS0sMZ2rSwNuTkf20ReJ2VBMKtgQ,1060 +torch/include/ATen/ops/lt_cpu_dispatch.h,sha256=EJ0_MpyqdqLIKCFgO5I14K7CQq1jyI1obN7TpJt5gdE,1396 +torch/include/ATen/ops/lt_cuda_dispatch.h,sha256=GgJFuNMLI2GiseqhrB4BJ6d9NWTaiVKAm5hN5ZeNVzk,1398 +torch/include/ATen/ops/lt_meta.h,sha256=4qkNq7g2ifGnmWOeAqHs_gp3-M1qgrly9xRbqkaiUFA,767 +torch/include/ATen/ops/lt_meta_dispatch.h,sha256=ZzpTBPV3FovQ1L64_wouxVEWuaYgx59HBCpTiVmkpYM,1398 +torch/include/ATen/ops/lt_native.h,sha256=Ei-c_O9gE5BuAlEAIvgeNwEAJ-slPgfDVlHf7XusLG8,1235 +torch/include/ATen/ops/lt_ops.h,sha256=1nVRFKpkeHjyy43mr5T6JfKm7-B2vLRMYmkO8fgrUb4,4285 +torch/include/ATen/ops/lu_solve.h,sha256=GWjBEkxB1zNI7qCiAIGfo72VWCA3DEoiRZ0LZ0m0ps4,1423 +torch/include/ATen/ops/lu_solve_compositeimplicitautograd_dispatch.h,sha256=5hDUimd7Tf8VkmDjrwJwj2oxiqq6XHhj_5KOEvG7kfc,1126 +torch/include/ATen/ops/lu_solve_native.h,sha256=E38YvEo_-VnMuWHKbqnILFjImLO9sEBF9GYwShQ8mdw,696 +torch/include/ATen/ops/lu_solve_ops.h,sha256=yg_SSoLmfoV5-cAj2E32lfc-ovkK-6Pcdnb90f4HQgw,1979 +torch/include/ATen/ops/lu_unpack.h,sha256=XDoCQrhTg439jx6U3vTiv48oeRbmJ29GffYxQn5Wlr4,1987 +torch/include/ATen/ops/lu_unpack_compositeexplicitautogradnonfunctional_dispatch.h,sha256=MdZAjCKbNNoGJBIwmZtV-MrNv3PpqpIhj_aChJHcc1E,933 +torch/include/ATen/ops/lu_unpack_cpu_dispatch.h,sha256=x55j_aiHgX_d6qTHQLyNqLQ7I6zqGT1Wvnp2uHqHloc,1320 +torch/include/ATen/ops/lu_unpack_cuda_dispatch.h,sha256=AlBwXZ6afW4DlKDPKjxqGnpoPH2HSOS9trTjrP2tIAQ,1322 +torch/include/ATen/ops/lu_unpack_meta.h,sha256=hjKFOzjxb3DTAKa47uffu9f41C3rzPcVLQT0KQ2oohM,668 +torch/include/ATen/ops/lu_unpack_meta_dispatch.h,sha256=kLaQu5PVsxypBteAU2iN11zF33kDJVBe5cKNpJ6NSkM,1322 +torch/include/ATen/ops/lu_unpack_native.h,sha256=-fDYbJ5JrK8mNWZ-HUV92yUadER0aqaOfRJiw_CsjuQ,741 +torch/include/ATen/ops/lu_unpack_ops.h,sha256=gOT0jnw1it_6C3rF5gk0vyTMJrXFGUjhJfcojWXrcR4,2489 +torch/include/ATen/ops/mH.h,sha256=iixVx1qZKWSFou3u7KPy34_yFtxvEYOpk51x2kDn5vM,526 +torch/include/ATen/ops/mH_compositeimplicitautograd_dispatch.h,sha256=1wb7xnXxTFWICKpsB0BYmMjsGH18s0580IRMaozP39c,783 +torch/include/ATen/ops/mH_native.h,sha256=xRDtIdpSbRyMKGN9LhEfNdjPn5GXO8QOe4PXOC8ljpw,493 
+torch/include/ATen/ops/mH_ops.h,sha256=bPKLNBffdDow1UfIeSfYkl1g3uXp5RfuBnEL8EyDguo,980 +torch/include/ATen/ops/mT.h,sha256=PTI4EJQYZhQ2A3o6ArwNXKmFxgSXndhbGBfzAY0evmY,526 +torch/include/ATen/ops/mT_compositeimplicitautograd_dispatch.h,sha256=xKuANkqb7E40nATpR82x4SSUt-wiy5iaU7B7JMNrMOM,783 +torch/include/ATen/ops/mT_native.h,sha256=2vAIUsX-imNoUUwWKKJvQs2apSNL22miEYjaeesHcAs,493 +torch/include/ATen/ops/mT_ops.h,sha256=beRuZg23oFnyfZVdM2UeFbx9csZ4XtnE-Dc5CG9ICDs,980 +torch/include/ATen/ops/margin_ranking_loss.h,sha256=b7SEH_Jy1mwRoA1QioyPNs4sfNA-hWrl2E3wyQT28ss,941 +torch/include/ATen/ops/margin_ranking_loss_compositeimplicitautograd_dispatch.h,sha256=t1YQVabL6kuGR9A-e1rJN7u4wyyL9U1maD6tLxjjjJo,914 +torch/include/ATen/ops/margin_ranking_loss_native.h,sha256=A_sA_Puf0ApNGO8MXge7Ha7K_Wk-YiEnP1n11NhlD-U,624 +torch/include/ATen/ops/margin_ranking_loss_ops.h,sha256=4_5qKoFOL5zeuXp80SAMszh5UMGNnuap4uikbwOmtN0,1332 +torch/include/ATen/ops/masked_fill.h,sha256=4WO1D0dndJcnEwvW94dRe--Nyx2g8tBFEfLyaJFsiOI,2331 +torch/include/ATen/ops/masked_fill_compositeexplicitautograd_dispatch.h,sha256=msf_kxy1bIMu8w4JBLzsBrLmK1H5snFWh4GefLpKVTc,1496 +torch/include/ATen/ops/masked_fill_cpu_dispatch.h,sha256=53qRrvpKsdBZqKODkszPA0PywPBdPa_EvhqGmVpJi2Q,904 +torch/include/ATen/ops/masked_fill_cuda_dispatch.h,sha256=0MT5nuxalieqeUQ5RqeRalYwFZ7duQRwMko0-Fd5NvM,906 +torch/include/ATen/ops/masked_fill_meta_dispatch.h,sha256=9r0BpuWQggnQrtW72ZIv2o81kzakEqnd7-x2ibyEFfA,906 +torch/include/ATen/ops/masked_fill_native.h,sha256=ieDm93anRfOi8LeOUNX38CDmtBPgAJ94cghzF5GeyzE,2012 +torch/include/ATen/ops/masked_fill_ops.h,sha256=_6H1tUBS69m5awkiukiz-b1yOATJZV18jcsUODJRoaE,4945 +torch/include/ATen/ops/masked_scatter.h,sha256=Rs2E_pzvak-eQWLINNukpLHNiOYC7hkKUjwKxJqHhwU,1429 +torch/include/ATen/ops/masked_scatter_backward.h,sha256=5Gz1ETosfIog0PlQJrdwFw27YgvfOYZIeitBVxN4I-Q,1838 +torch/include/ATen/ops/masked_scatter_backward_compositeexplicitautograd_dispatch.h,sha256=TQTlQLDXYwA1vXJ_dOW3hJ4NVx3ddamvHeTWMnjB3Cg,997 +torch/include/ATen/ops/masked_scatter_backward_native.h,sha256=kGHDnLwxCEo2rdX1JwjIAi6U_nh8FLEaupMWq_KflxQ,580 +torch/include/ATen/ops/masked_scatter_backward_ops.h,sha256=6cWlnZFRlGkRsP6ykYzTofzf4kyWJuJM7VSGn0chOuU,1232 +torch/include/ATen/ops/masked_scatter_compositeexplicitautograd_dispatch.h,sha256=5cdmUB6EbKaJJ1yrNT-BORDkHZwqSuVt9SvSzMDyxo0,1126 +torch/include/ATen/ops/masked_scatter_cpu_dispatch.h,sha256=P1Swoh3a_qxkUak-NwDQhyFvtR5YmuB8_Fb3njeoic4,800 +torch/include/ATen/ops/masked_scatter_cuda_dispatch.h,sha256=yqGUQrsEn4jYau4VGa59DnxxV11QhsaPZjhm0GG0rxU,802 +torch/include/ATen/ops/masked_scatter_meta_dispatch.h,sha256=dZDC8i36cbEJKBzPDSVs9eH3Y9MZAxVjj2bVODjtVMc,802 +torch/include/ATen/ops/masked_scatter_native.h,sha256=H1PB5PgoGTyUZIE_QfZr15BxeMolkUnAv3u1uGdxqa0,929 +torch/include/ATen/ops/masked_scatter_ops.h,sha256=1ri1eNUHM9Wj_7Zf0xRDPMz4UihlnZ449hvMtIHEC8s,2685 +torch/include/ATen/ops/masked_select.h,sha256=b07cBgOEKl-JFbdl6oCyZZJwAwjNtmx1fQh_P6lfEos,1269 +torch/include/ATen/ops/masked_select_backward.h,sha256=vjKLMWpde4w2VeOa_vw7VbzHJR84McFPpICbRE33Bcs,823 +torch/include/ATen/ops/masked_select_backward_compositeimplicitautograd_dispatch.h,sha256=e3Jio-_PnMGJ5hSeNTReYDnpJ2suyE24vTqW1eO-tNc,854 +torch/include/ATen/ops/masked_select_backward_native.h,sha256=hrbh7ytULGW_E_HxiJA2nWNwHhIspHJarUxZJwHKl0Y,564 +torch/include/ATen/ops/masked_select_backward_ops.h,sha256=KiaRE1Njqx6juvR0_H3xiUxeefM7xgw_w0nSAUlPfgs,1203 
+torch/include/ATen/ops/masked_select_cpu_dispatch.h,sha256=4xnWkxQwGmLtyR--fM2nm1LQzTgUz048oP8-vZEmDF4,998 +torch/include/ATen/ops/masked_select_cuda_dispatch.h,sha256=bPFL9eK-avekNbM2gBkAOueUXGP6z4HOR8_J70EQS5E,1000 +torch/include/ATen/ops/masked_select_native.h,sha256=Mtg3pYaWBJKCDX0mQ8r3GH6P3PVsiypEqwIg18EJVno,856 +torch/include/ATen/ops/masked_select_ops.h,sha256=jdcaSNWk04qi_5JbxYiu_HyQnzXA7Ty3IbbgVGYqlI0,1795 +torch/include/ATen/ops/matmul.h,sha256=60ixRo5IfWKrgLVolsoCXUTIhST09q1TzF2-vwk78TI,1208 +torch/include/ATen/ops/matmul_backward.h,sha256=el37z-pfqGz5M-DD-vs3FhxHTgP8gc6cqs8bGJ9S0yE,1778 +torch/include/ATen/ops/matmul_backward_compositeexplicitautograd_dispatch.h,sha256=sG3wWCZm3YIijcPctAzRouMEay89FaVxuVOlLNrJGAs,1159 +torch/include/ATen/ops/matmul_backward_native.h,sha256=5gexnaMjl9k-Q3Xx9meE9olKyMGCzfbFuHeg1j8bnWY,829 +torch/include/ATen/ops/matmul_backward_ops.h,sha256=eMXiSOAFtSfJRRqjvWSd8HNMp4aqmWgJnI6JX8bhu3U,2411 +torch/include/ATen/ops/matmul_compositeimplicitautograd_dispatch.h,sha256=CfD_8OUh7065VntU-CkWXxlxJi3_qTuLf819_OxFrmU,1024 +torch/include/ATen/ops/matmul_native.h,sha256=JAM0eTlhkeLmYbuHwajLXXzOUzB3GY6pIJDoxukk178,828 +torch/include/ATen/ops/matmul_ops.h,sha256=URad7ArEWnIA6bKTH9hfMWIv0Ned0kH84tuST2G25X0,1759 +torch/include/ATen/ops/matrix_H.h,sha256=DoAGvqSTiQhnC-QGjPmlLqft2aoyBwUpUHMNZ_Q59GM,532 +torch/include/ATen/ops/matrix_H_compositeimplicitautograd_dispatch.h,sha256=fd0dIVE2Iy-AFM7wuhSS92-gSsZGcuedRmlr--bwn_4,789 +torch/include/ATen/ops/matrix_H_native.h,sha256=pLK6EUapcEPkPiv9r1Nw8bnMnJSEiL_FpCh_K8l0dkA,499 +torch/include/ATen/ops/matrix_H_ops.h,sha256=U8LbDlf1HdO158smZOUXn_5L0OpaehtXo_p1ZZy5NcU,998 +torch/include/ATen/ops/matrix_exp.h,sha256=cnl7rg-EZS3iUf27KZqTaUE1kGsh8HUOhtQ4zdrqRZA,684 +torch/include/ATen/ops/matrix_exp_backward.h,sha256=YqxX0KwyPyT5VaZ1cnAHkjSoRLsZaokF1RtkCD3pEeo,764 +torch/include/ATen/ops/matrix_exp_backward_compositeimplicitautograd_dispatch.h,sha256=IsotoHFqQhRiLK0Z62hXToiNsqfnWGMS-su6oRJoMtk,825 +torch/include/ATen/ops/matrix_exp_backward_native.h,sha256=L6Lrj0z1JnyL8ZD-4eCCr04nJ91Dskf0Xj5RafhHrh0,535 +torch/include/ATen/ops/matrix_exp_backward_ops.h,sha256=mOwB7pi9QmLyj5qUKwL-TQHkTGjxNLE7xuvbWQS8HT4,1108 +torch/include/ATen/ops/matrix_exp_compositeimplicitautograd_dispatch.h,sha256=0ZRWpR0DFXtQ7BaVPg4lUHqQhtsnnzZ7fFja_VUD1To,791 +torch/include/ATen/ops/matrix_exp_native.h,sha256=q_AG-tZXKtfZW4GYKWxtpJcp1IK4RITYhUP7Y1dVw3w,501 +torch/include/ATen/ops/matrix_exp_ops.h,sha256=zJLfz4n510ywHhowNMg9nZrwdFDDpxihZEr_4NkXEus,998 +torch/include/ATen/ops/matrix_power.h,sha256=U8DDJMz_dMGp8JFmWSZYz7YzjaHI7l0zpKYfHAvrnAg,1190 +torch/include/ATen/ops/matrix_power_compositeimplicitautograd_dispatch.h,sha256=BuD4OYnvx8Es-BnsUfFgGc0TBmzgqCalitzL_RpCnis,997 +torch/include/ATen/ops/matrix_power_native.h,sha256=JvQ5lRN6NN6dAnVOgnVyQgrKlrhdg5JkGSpQGqt-I48,610 +torch/include/ATen/ops/matrix_power_ops.h,sha256=uN-Hv_OJV-oXcbHyG8yhQ1XYWWpViUNVHO_9-DWy8sQ,1699 +torch/include/ATen/ops/max.h,sha256=10HtPW3tVQbGfKLnTRHOALeAtd7kDo4FePhE7PuUaDY,3866 +torch/include/ATen/ops/max_compositeexplicitautogradnonfunctional_dispatch.h,sha256=D_fQlP5c3mb2Q08RJ1LAakKhFFkEeJPLk264Jp9mJtg,868 +torch/include/ATen/ops/max_compositeimplicitautograd_dispatch.h,sha256=NnpL0m8LU7Mtap-qmnKhyxNXhzZRVsAE6bI6xyM8Oo4,1454 +torch/include/ATen/ops/max_cpu_dispatch.h,sha256=D4WnrdBMLIVp0LCeuFr84rVoNOd7GIQMJVIO2MqljeU,1320 +torch/include/ATen/ops/max_cuda_dispatch.h,sha256=XXtJBH3JlDMa8AEjls55uqGC_-ywLak9_G3lsqZ6tc8,1322 
+torch/include/ATen/ops/max_meta.h,sha256=oigIU7X4FN5qEmQEcR315Ee6Sn_zpLSqq83e1-I3p8g,1076 +torch/include/ATen/ops/max_meta_dispatch.h,sha256=e2HZa_sfV8zapF_SJEXQnzaMA8G-1bPQBiS_mQWdfPA,1117 +torch/include/ATen/ops/max_native.h,sha256=ZQrQxlQY34WRXpW9wc900iUxFLZmLnUi-dtf2xarf4c,1528 +torch/include/ATen/ops/max_ops.h,sha256=yADE14CPtR8PpuTm2qdYiYoj-PmMvo9OSFfcaCO5gCE,6158 +torch/include/ATen/ops/max_pool1d.h,sha256=VtJcMqfxg4dRkUFT6G1EZimZ_4xf4L2dUf3-tDjGLsI,965 +torch/include/ATen/ops/max_pool1d_compositeimplicitautograd_dispatch.h,sha256=KlmC35wiP8Q1Ap_hzB9VI6KRKo8-czUlvSO1Ey_LRL8,924 +torch/include/ATen/ops/max_pool1d_native.h,sha256=IgQ4Fr6vbhuJmEX9hmVkX4FKOiLs_wEnNYssCOhuXoE,634 +torch/include/ATen/ops/max_pool1d_ops.h,sha256=j7BgGrWM0CWXfaGMK_9kYtinP6v0O-pSgU3o2Q5ncZ0,1409 +torch/include/ATen/ops/max_pool1d_with_indices.h,sha256=JdBMLvO6TVmRwF8_wNSzE3cSte21Gs0B5iTn5NW3G4Q,1052 +torch/include/ATen/ops/max_pool1d_with_indices_compositeimplicitautograd_dispatch.h,sha256=FTDkIcG7HhgjAHXbv5vcvnVboWfUutpxWr60OqJh0Ow,962 +torch/include/ATen/ops/max_pool1d_with_indices_native.h,sha256=GPgalqZPOHO7oNTthrvR_hKiJHnQzzIWVyrhw-9lSzI,672 +torch/include/ATen/ops/max_pool1d_with_indices_ops.h,sha256=h-zqzAb93K9wvPNCPhgNJEXFCfLofqP4jrSIZDa4XdE,1533 +torch/include/ATen/ops/max_pool2d.h,sha256=Bo0ucI1A8sGFHggRO_tAJPlFxboWMYYqqjg7DYbmndY,965 +torch/include/ATen/ops/max_pool2d_backward.h,sha256=eXjbDexUYYrk-N-0d95fsID1bCb155aipRPAlRFPKr4,2222 +torch/include/ATen/ops/max_pool2d_backward_compositeexplicitautograd_dispatch.h,sha256=8BacQx6lQXlh28ZzQhwE-aaO6y1s3AW1f-RxJMgTs2g,1234 +torch/include/ATen/ops/max_pool2d_backward_native.h,sha256=2sUgTIjVjgoe0ZMMsBcZJFBIEGz2miDaYXbSuFOZIYs,686 +torch/include/ATen/ops/max_pool2d_backward_ops.h,sha256=f-0bkJ9BDorESa_jcVpGX9Jr7HYfqNUzZARpXJ0BGDk,2695 +torch/include/ATen/ops/max_pool2d_compositeimplicitautograd_dispatch.h,sha256=2G0p5t_8oSjIbNPW4yckDFEPu_p46DtGRWqnUcapkgw,924 +torch/include/ATen/ops/max_pool2d_native.h,sha256=UfqeW2hvacIxqN2en8kjvWmK_ClOBo8EZnQFWXYfjMU,634 +torch/include/ATen/ops/max_pool2d_ops.h,sha256=pezmxk5J-tP-oUPqRO-aiNPfWgtHLwrlMsgDSjFSS0w,1409 +torch/include/ATen/ops/max_pool2d_with_indices.h,sha256=XewLetAHLlHq1OFcCcR6swPw5Qi4xbo_Rk6UiW36Uic,2286 +torch/include/ATen/ops/max_pool2d_with_indices_backward.h,sha256=5KpAsdq8domcX5hZeVH9eckbaSH4Nb0Z54RhUmVod7k,2516 +torch/include/ATen/ops/max_pool2d_with_indices_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8BqMcAXB_51-0odIBa-iaYqir4bs-YIwVzd2Ii8I85I,1019 +torch/include/ATen/ops/max_pool2d_with_indices_backward_cpu_dispatch.h,sha256=uXozHrmqI8_tJfc-U1myFuBR65Jxvp6i2iMymw1igjQ,1534 +torch/include/ATen/ops/max_pool2d_with_indices_backward_cuda_dispatch.h,sha256=-0jpfqzVRX4Ed6P1btLLZRjDfqBONWWzPfvbzhUlwak,1536 +torch/include/ATen/ops/max_pool2d_with_indices_backward_meta.h,sha256=C3S3Qg5codqz7pp5iU92R_DFzjyu5KY5QlnY-9hLtUw,800 +torch/include/ATen/ops/max_pool2d_with_indices_backward_meta_dispatch.h,sha256=yu9X-Y4KFSEntpsi3HWR-UncvLfFuAh6hMCKxKT4wW8,1536 +torch/include/ATen/ops/max_pool2d_with_indices_backward_native.h,sha256=pjRLcHcZWXFyDDPEIUkspMqCAlw1Qhx7BpxGqb7lfqQ,1276 +torch/include/ATen/ops/max_pool2d_with_indices_backward_ops.h,sha256=YIc_OrTflejOV908IlQYJLcdyD3gBHoYFBAKq4qIKVY,2973 +torch/include/ATen/ops/max_pool2d_with_indices_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Qhu2m98_qbSnsGKW5OYT0qGRYnzSLA3oTjd9hVU2zIY,988 +torch/include/ATen/ops/max_pool2d_with_indices_cpu_dispatch.h,sha256=H4KLrT-tbXF4yQP_bETNvyK3DZXauhUQLFlm9RmXZtg,1462 
+torch/include/ATen/ops/max_pool2d_with_indices_cuda_dispatch.h,sha256=sqvng_Ha7Vl9JdM9vYFSXAizVkDwz9m7eE6K4OMzlkA,1464 +torch/include/ATen/ops/max_pool2d_with_indices_meta.h,sha256=ZkfUa7EvznZE7f1tWtwsS85sMc_vPy25_CcS3Wo4zSQ,731 +torch/include/ATen/ops/max_pool2d_with_indices_meta_dispatch.h,sha256=iO8PoddnyK_-84zq0Hvw_mRLMIYaSM8caWIfwL3c1xA,1464 +torch/include/ATen/ops/max_pool2d_with_indices_native.h,sha256=YGsQjYhjyrDAeTDwUVZRaZkstFDu4hr_rC8i6LwDoqM,1153 +torch/include/ATen/ops/max_pool2d_with_indices_ops.h,sha256=T6RxfeCqYzA6Be8wORtPMYKRLQNFS6ic_Wzfv9erb9I,2769 +torch/include/ATen/ops/max_pool3d.h,sha256=mynh7ZlyX5egqt7El7-pTUZXbgOcSJClsPoPxzaaASk,965 +torch/include/ATen/ops/max_pool3d_compositeimplicitautograd_dispatch.h,sha256=xNWBy9sxFjnWpjm2Jf3L6CN4wK6gvtPcDafkhm1xpHk,924 +torch/include/ATen/ops/max_pool3d_native.h,sha256=Daymi03M-s3y1pqMT_lBNCjHk5iHl0w9_qKnfm7Tues,634 +torch/include/ATen/ops/max_pool3d_ops.h,sha256=M-VUDBM8ao-nO409kgu_Vy-z5nLqEvVIKqNbGT8sQNk,1409 +torch/include/ATen/ops/max_pool3d_with_indices.h,sha256=3wsMSKi_UScr3-rN9Gf72ynixVM7ex-lhs1jIprJ1PY,2286 +torch/include/ATen/ops/max_pool3d_with_indices_backward.h,sha256=VAfs8QXU3wsdj80nqgMq7TNFrK4k9jV9ijSCmpZMaOg,2516 +torch/include/ATen/ops/max_pool3d_with_indices_backward_cpu_dispatch.h,sha256=JlRIrIDi-oST4ouPpC8upNKCVdVDL3Z3grsE29u2TiM,1534 +torch/include/ATen/ops/max_pool3d_with_indices_backward_cuda_dispatch.h,sha256=EJC6mKLGa7SE-dTeo_Sgitb6I0tIrdmb0xJPjiPAxOk,1536 +torch/include/ATen/ops/max_pool3d_with_indices_backward_native.h,sha256=0HaA0HG8LCdl1FZkf7oLBq4seQwFE-Z9bj8iWfgVCzc,1566 +torch/include/ATen/ops/max_pool3d_with_indices_backward_ops.h,sha256=RcOyJMCyfRSx1c4bPERuhFHi9eWBdGQe1YrROzAtPs8,2973 +torch/include/ATen/ops/max_pool3d_with_indices_cpu_dispatch.h,sha256=sT0iToWPGjJwTn3xPbvmzI1zrHc3uQrcZdFH85s9D9Y,1462 +torch/include/ATen/ops/max_pool3d_with_indices_cuda_dispatch.h,sha256=BYqpgdfwAUUlaa3pvB7oaIuDRPR9FBV7o4LDgWTFsJc,1464 +torch/include/ATen/ops/max_pool3d_with_indices_native.h,sha256=OKXS0p_1JFoqbi4l9ZlOhHAgwPgyDY6fd5WfPCelrOA,1450 +torch/include/ATen/ops/max_pool3d_with_indices_ops.h,sha256=f83nmhAmEnk90ArPux1pfayDnuWa1AQrZ4ng1Gw0kp8,2769 +torch/include/ATen/ops/max_unpool2d.h,sha256=5dF917xPRZFA8An3lNktz9iSjPt8hFitw4aeARAUL1g,4525 +torch/include/ATen/ops/max_unpool2d_cpu_dispatch.h,sha256=E9DIMisyGOiJ2tp_6tffB3j0TijMNVsMAw5KXtvzupE,1527 +torch/include/ATen/ops/max_unpool2d_cuda_dispatch.h,sha256=H3sFei4APQ8MA9eCElg15uG-gjh5GVdvDUTXokhhVW4,1529 +torch/include/ATen/ops/max_unpool2d_native.h,sha256=TYutRWxfhczy8Sx_jQu72elIFIN7OPv8O_VntU5coyg,1024 +torch/include/ATen/ops/max_unpool2d_ops.h,sha256=4qk0OMrLzmcCRplInEsdQJI_I0dIPO0yx4BucN6zn8k,2027 +torch/include/ATen/ops/max_unpool3d.h,sha256=mYw_e_HL3QqzqxyfbkPzeoTvqsTPnUwxD1ffloPgDGE,5503 +torch/include/ATen/ops/max_unpool3d_cpu_dispatch.h,sha256=MiLsgflzxD977edzsIoF629WfWmrKvEGBdkb63CHb1Y,1821 +torch/include/ATen/ops/max_unpool3d_cuda_dispatch.h,sha256=jE8kU55ohcyGB37Y5UOt6fyO3RwhQlZgIdH2TNdc5kw,1823 +torch/include/ATen/ops/max_unpool3d_native.h,sha256=l23KHv1-1mZM4L03th4W7xygUN1hxAza_1VYQRz6eV8,1220 +torch/include/ATen/ops/max_unpool3d_ops.h,sha256=vSNsrmA5L2e5QLZq7eF5J7gEvFdcQU6A1gsVfuo9mBc,2353 +torch/include/ATen/ops/maximum.h,sha256=zlmy9o-AumDHbCpsjzl6McMHjYk9xVqCqG3EpRA9PJs,1218 +torch/include/ATen/ops/maximum_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Tqlg3mcod788vOcEf0TEkz6oyjssXZXOmWd0XwotlEA,840 +torch/include/ATen/ops/maximum_cpu_dispatch.h,sha256=1_qErKsDTcyaCJBIaIP_j7BZymedQpw6etPX1dQpvMI,983 
+torch/include/ATen/ops/maximum_cuda_dispatch.h,sha256=JydSP3geHxs31YGQwhtCztfpXynsddgqVhZSs_V334g,985 +torch/include/ATen/ops/maximum_meta.h,sha256=X8wbN1VU-d18i3xOzIeipH1RbtpBjzSp9Ug0XByVlHs,621 +torch/include/ATen/ops/maximum_meta_dispatch.h,sha256=gHvOPsorP2cRo4bnbwp5U-Ml1qc99TYVSyFqT4efWuw,985 +torch/include/ATen/ops/maximum_native.h,sha256=R6-7trKp5rnmppKB1bTPdoLlTNBg34w15BS5F_2gcqo,648 +torch/include/ATen/ops/maximum_ops.h,sha256=iabslcbCfjgYKP9ud23_xUbqF6pv6HUX1nJJ-evrWKM,1765 +torch/include/ATen/ops/mean.h,sha256=rWT8Hv84Sd_-7b_AJl_BzammYa8LdiCR-mGU6cuN7e4,3443 +torch/include/ATen/ops/mean_compositeexplicitautograd_dispatch.h,sha256=GQutBEtudwzYr5zZ2buGKv-ztrgjNAD-azAsKzntXdQ,1087 +torch/include/ATen/ops/mean_compositeexplicitautogradnonfunctional_dispatch.h,sha256=w74-QXowirSUQjBAhzKFSgA6IwpSwjiMIFifQ29tW68,914 +torch/include/ATen/ops/mean_compositeimplicitautograd_dispatch.h,sha256=Fg4syyMsDpj860evLiN7l7RtJNKXBoSK1mmwWJxDzwE,1204 +torch/include/ATen/ops/mean_cpu_dispatch.h,sha256=0Ia0UbEg_Y1CdSU-fmcYMAgKUHVV18MG5DsUFj3gonc,1184 +torch/include/ATen/ops/mean_cuda_dispatch.h,sha256=uMhK68hPUZUM12007nt7rCHsYZfte9c1PNsWfj4Hc68,1186 +torch/include/ATen/ops/mean_meta.h,sha256=oBv76gtfRR22WaQCguRAhcj9vhgPU7erlaB4YgVi7hg,678 +torch/include/ATen/ops/mean_meta_dispatch.h,sha256=-JxopXXJgxETUzVN6oWmeajHJHgGHPVbISPYuRXgVrE,1186 +torch/include/ATen/ops/mean_native.h,sha256=amrwR6fml8aLQ1ty4EmUam7IiiThHv80flqEm_7UIko,1570 +torch/include/ATen/ops/mean_ops.h,sha256=ia8IpJ-suIK_31OBvjFBB_YwurjHgDPxO_X9R_KLnh4,5228 +torch/include/ATen/ops/median.h,sha256=p-Ygw4LYwwGRwJ2-1jHOwAfUI8wJWtyRcoXcUsqhOVI,3285 +torch/include/ATen/ops/median_compositeexplicitautograd_dispatch.h,sha256=D4me7UQP9P-e4gdYSMfR4VTCv0HawVQDmgEzMg8ymrs,1004 +torch/include/ATen/ops/median_compositeimplicitautograd_dispatch.h,sha256=ap_jAjC6hbtFGROOeRlPdayFoMxiLIGQrFD5vS4xKMo,1180 +torch/include/ATen/ops/median_cpu_dispatch.h,sha256=QrbAbELG4ck4o8T1xLlMJ5dPOV4g4NutumXljWKi_eg,1066 +torch/include/ATen/ops/median_cuda_dispatch.h,sha256=1N3flrd0q5UikD4lTZSkpuPZV8B_Taek5bKR9gMtKS4,1068 +torch/include/ATen/ops/median_native.h,sha256=lkoJzz8kUHpDIvUvNTYY3MBBmo7nZfjwOCNZtIYsqxI,1357 +torch/include/ATen/ops/median_ops.h,sha256=s88K3fIl0pvOCjWtPlbKHq1Rao7HOHmgHhoeTD-SYk4,4929 +torch/include/ATen/ops/meshgrid.h,sha256=yHGdmLaOwosQ5nOGdf0RK1TwAZroyvV2UtxG6eM1lI8,942 +torch/include/ATen/ops/meshgrid_compositeimplicitautograd_dispatch.h,sha256=8YPx5_UODIlB0bBgNxOA9kZlPpW_ee4eLaIjK-JEIbQ,901 +torch/include/ATen/ops/meshgrid_native.h,sha256=Kcw_4_L84nC7CQut6I9mGhad65d0bRAL-7JysoXh6Rc,611 +torch/include/ATen/ops/meshgrid_ops.h,sha256=-kS84Doc_JQBEmAg6DE4atFy8cRpk2TAimic8Bbny-M,1716 +torch/include/ATen/ops/min.h,sha256=fKrOZK8BEGJjQXoNb9-0Z3xwUwxiO8qYGNNA0ng2Xi8,3878 +torch/include/ATen/ops/min_compositeexplicitautogradnonfunctional_dispatch.h,sha256=wRNZodHAX1zcNia2DVgT5WWXRcijbkd-_NIhaPRALxg,868 +torch/include/ATen/ops/min_compositeimplicitautograd_dispatch.h,sha256=xVvGM2R8dNB_iTCq9hwdCIgT3_BsO9hjAGE7v67Jous,1456 +torch/include/ATen/ops/min_cpu_dispatch.h,sha256=fBRNi069wKS1eW67bOaFfI0Kcl3s5sh3s73JfCMAoNo,1322 +torch/include/ATen/ops/min_cuda_dispatch.h,sha256=xxNTwjiCDVfyKU1HOh43RCWe-oqcJyIPH002ocRcPa0,1324 +torch/include/ATen/ops/min_meta.h,sha256=cj_CNmozOCdSvec-ktRHCYd8xDYx3TvSbtElN6-7ssQ,1076 +torch/include/ATen/ops/min_meta_dispatch.h,sha256=PGhnRFWSHf9DKKClJ_JfeQ8vSHX0B2Phl0DNbG1QoVQ,1119 +torch/include/ATen/ops/min_native.h,sha256=tkOimfnAMbTot76DhL90GB_sk2HbAfN-C8Rbt0_OmJo,1530 
+torch/include/ATen/ops/min_ops.h,sha256=XWd4TCUG000tj8bXBMkwwbegWmtBhALxF7GYD2xYVFw,6164 +torch/include/ATen/ops/minimum.h,sha256=R3YCVn4x-0draLxewkxBjbWWFFknQ-sRChHIwwky7cE,1218 +torch/include/ATen/ops/minimum_compositeexplicitautogradnonfunctional_dispatch.h,sha256=A_j8joDqI5clBn80z0T2h1Z6Pyvv88fKzopsSA5izoE,840 +torch/include/ATen/ops/minimum_cpu_dispatch.h,sha256=AlNXXhrglX_N2mOTjpu2ckoLbXKcLbrOUd6NSnB0rwY,983 +torch/include/ATen/ops/minimum_cuda_dispatch.h,sha256=QPSTaPpOD742WNVTCETlcXg_0fbOALkDp-PL7JHDnwI,985 +torch/include/ATen/ops/minimum_meta.h,sha256=932T3PgKliVmM-J0RxhIIZtoaAGB0rt4ymgAbQOBtBA,621 +torch/include/ATen/ops/minimum_meta_dispatch.h,sha256=TfSthPrCLZT1_HU_Rng-Lum0ASNOxHh1ws-wFnNn-0Q,985 +torch/include/ATen/ops/minimum_native.h,sha256=5lERCD6sSSsFvEf7Nqg25gN4wms7_1T8k3pbiS5wAB4,648 +torch/include/ATen/ops/minimum_ops.h,sha256=Y9lgJViJ-M2d7wVD4SgVC_NVtI8RxDV7EzSLSLIrDr4,1765 +torch/include/ATen/ops/miopen_batch_norm.h,sha256=030fG1jdaPyqUgpEK8JkuAwm5D5kx0OE16oJ8H8Mfw4,2920 +torch/include/ATen/ops/miopen_batch_norm_backward.h,sha256=ojIQcA724H9DMYDokUub7bocDpx5-4lf-r7FqG0QvpQ,3061 +torch/include/ATen/ops/miopen_batch_norm_backward_compositeexplicitautograd_dispatch.h,sha256=2bbLOQ0TfJzED_-aOPS5-U_t6__hh3nTWM4xeHUrXMU,1625 +torch/include/ATen/ops/miopen_batch_norm_backward_cuda_dispatch.h,sha256=XNuliN158W8uxSDfUdOFD4ERoJmlVywBd5JYKM6rZN4,1069 +torch/include/ATen/ops/miopen_batch_norm_backward_native.h,sha256=T4Qo_K7N0zqn3WHEQ15lgTIpCjvTglpiTqIfYJnAVso,1267 +torch/include/ATen/ops/miopen_batch_norm_backward_ops.h,sha256=ZQjpCCusoNqScqbNAf6s-9Axj-RieM0T0iGSRfn1jm8,3846 +torch/include/ATen/ops/miopen_batch_norm_compositeexplicitautograd_dispatch.h,sha256=Q6h_FEoH-aBWptqu7RUQgMIBuMA0D2edpIT78dUEaOw,1541 +torch/include/ATen/ops/miopen_batch_norm_cuda_dispatch.h,sha256=qfIwy-wX5RR27cj5uwJQY5zb-H-WNrrAGFs3R7BcKOM,1027 +torch/include/ATen/ops/miopen_batch_norm_native.h,sha256=o5R7jCcAV9Yp-j62Gmo08CQsyK-aIIbIIji-xpxYOMM,1183 +torch/include/ATen/ops/miopen_batch_norm_ops.h,sha256=V5tIeEtvLIh8cByf8uxgAow8UeQeEqTAp1MOKaw7roI,3586 +torch/include/ATen/ops/miopen_convolution.h,sha256=Gwzag9NYKqGNfdfi8rID4FUFP45UwGhSonNvVvS8vLg,7969 +torch/include/ATen/ops/miopen_convolution_add_relu.h,sha256=5PysZbRdyR870jCplnAGQ6whxLW6nANcuDb81rHdQk4,3020 +torch/include/ATen/ops/miopen_convolution_add_relu_cuda_dispatch.h,sha256=8Zb172m-ZzgIbjnSOoFE0lChFIsmd27MBXzc3T1zBA4,1315 +torch/include/ATen/ops/miopen_convolution_add_relu_native.h,sha256=MrH8FppW-GHc8W1okxM7G82wR2f00Gx_9NzEqT66L20,743 +torch/include/ATen/ops/miopen_convolution_add_relu_ops.h,sha256=1VIrH7MhTvSdlM4aPlpTo0hRlPCyjEhFLJ9zdsfiqZs,1844 +torch/include/ATen/ops/miopen_convolution_compositeexplicitautograd_dispatch.h,sha256=_fYRarf4zKAYVtg-iomPGjaekNvUWjR_o1BJ0O-VUo8,1928 +torch/include/ATen/ops/miopen_convolution_cuda_dispatch.h,sha256=cPMdyM2LGctaD6NzILuGohkMmelem5qGyZg16xjNsyU,1239 +torch/include/ATen/ops/miopen_convolution_native.h,sha256=l11IDK4DCK3EaygPha4rK7PXncreBKNVnfPQkRcCEKI,1015 +torch/include/ATen/ops/miopen_convolution_ops.h,sha256=jE939udMx0Vhr5WEnLt-ijE8HBbvjqBL-nAKAl9bIPs,3065 +torch/include/ATen/ops/miopen_convolution_relu.h,sha256=hMwCC_CxwCX5kkxsFZP15GsYH7pLdwXq1M5byAj5g0g,2626 +torch/include/ATen/ops/miopen_convolution_relu_cuda_dispatch.h,sha256=uKLftXuzcZqWiGthQwBp5D-nD0TT-K0BaTBx2fiee7A,1177 +torch/include/ATen/ops/miopen_convolution_relu_native.h,sha256=R8LuUSmFVz4R-jawUU6qz_37SHFPXY1aGXTKbXzJasY,674 +torch/include/ATen/ops/miopen_convolution_relu_ops.h,sha256=XBBF74cPrHFVpJB17WFx2CbGq0itAh18aT2r5uJhK1w,1620 
+torch/include/ATen/ops/miopen_convolution_transpose.h,sha256=diI-J2hMm_IAcUIBu-BZcOM4kZjwQziDSJms8J6Ny7Y,9185 +torch/include/ATen/ops/miopen_convolution_transpose_compositeexplicitautograd_dispatch.h,sha256=ZOlKZxoV3kVBXFJIgSq9V-dFAadyQBLKlgoXGe9tbMo,2104 +torch/include/ATen/ops/miopen_convolution_transpose_cuda_dispatch.h,sha256=o9IvKY8z4LCkJgfdRo_f7pmJuC4QLajN3ET8e9ye1GI,1327 +torch/include/ATen/ops/miopen_convolution_transpose_native.h,sha256=isFSgulSFflpFxd_gSnJJRX-k5tpAKgkY4_VDGsLl9I,1103 +torch/include/ATen/ops/miopen_convolution_transpose_ops.h,sha256=Z3Q92nR4T91bnZr2WRNRyRsHdrdoArjhgASP0NO0kXg,3361 +torch/include/ATen/ops/miopen_depthwise_convolution.h,sha256=eNad5c8fw_M9OS5yUCNnCbpF9utGz3noKiotBsbtG1s,8279 +torch/include/ATen/ops/miopen_depthwise_convolution_compositeexplicitautograd_dispatch.h,sha256=L4t7NPfqPUmam3LwHN_cH1e58cnUpr0SelXNJWr7HCA,1968 +torch/include/ATen/ops/miopen_depthwise_convolution_cuda_dispatch.h,sha256=Y-GXjuJAc6Tz1_OLmADDJlVgAw8SCrbTFI_VBz1CaDY,1259 +torch/include/ATen/ops/miopen_depthwise_convolution_native.h,sha256=ymbuCUeXgHhJowT0t5VS9CeMuBQLiyJaqgWZgidUsSQ,1035 +torch/include/ATen/ops/miopen_depthwise_convolution_ops.h,sha256=EzJkXil3IEigoFoxNqiI4EDqslbqeOjaNaxPmoAg3GY,3125 +torch/include/ATen/ops/miopen_rnn.h,sha256=dVQ8lQE7y7p-nC94jTOC1MY20yCLIHIiqtF_2tn2_Zs,3750 +torch/include/ATen/ops/miopen_rnn_backward.h,sha256=XEllswQWZdz3UalhGAvPIFPS9wruULW37PolLyEvfIs,4872 +torch/include/ATen/ops/miopen_rnn_backward_compositeexplicitautograd_dispatch.h,sha256=giDHa6qBnGRYwBi8eWF-2cPCBLoGbJTe2bb1EbPactg,2163 +torch/include/ATen/ops/miopen_rnn_backward_cuda_dispatch.h,sha256=N9TECVtj2BWf8NqaB7d9mvi82CxzxqwUjjwgqHIkkhE,1391 +torch/include/ATen/ops/miopen_rnn_backward_native.h,sha256=SnFVKvhpbMk9yMX4fXOFlf_PVhDvh2237QjdaSwq8KM,1858 +torch/include/ATen/ops/miopen_rnn_backward_ops.h,sha256=RPy70anWTp2ekZknIjx0lBbHS8y74yoWvVM4QOZutng,5787 +torch/include/ATen/ops/miopen_rnn_compositeexplicitautograd_dispatch.h,sha256=7nt1T4eHJ76Sm3S_E_KvSH76xvlr9eOGKen5v_vSuag,1809 +torch/include/ATen/ops/miopen_rnn_cuda_dispatch.h,sha256=xz_1zQGR84qmXhL0_7BWCpQhOM-oDHeKS7m1oHixme4,1119 +torch/include/ATen/ops/miopen_rnn_native.h,sha256=efyfRaJ-q_T9V2eXXkxRyTWuPM_vo0pCS-i2mCvYYag,1409 +torch/include/ATen/ops/miopen_rnn_ops.h,sha256=RGjOtm8Nc4eRM-VbFFaA12zomjaLfVbo3SDHBzWUaeY,4394 +torch/include/ATen/ops/mish.h,sha256=2UpeF1Dv142pY3-ZiGyRJQ8_cMWclrcTN6Km6oDqlmg,1188 +torch/include/ATen/ops/mish_backward.h,sha256=8GE49olGcMUyiTKiLcGsyjqJqT8Ui65XDXiFaCiUaKQ,761 +torch/include/ATen/ops/mish_backward_compositeimplicitautograd_dispatch.h,sha256=vCsagJYVX-cp5RI4Ri4q4BUZqbGCwZ-D_AH9neJ6VQs,826 +torch/include/ATen/ops/mish_backward_cpu_dispatch.h,sha256=eAUpLMSbwG8s1uEvYWY0qi7xrIk1NkI5eNh5CP7MUjA,782 +torch/include/ATen/ops/mish_backward_cuda_dispatch.h,sha256=xZRU6IrbVqK_qAWQN9s0hamIs-c-tGkngM4oBjdkpf0,784 +torch/include/ATen/ops/mish_backward_native.h,sha256=P8XIh4YHP0NQRQtzzo-4yHN1zidNh1QzxdqPZVAj5DE,635 +torch/include/ATen/ops/mish_backward_ops.h,sha256=hZPMbmHxheP8h1Lfcop2s087iA9RQZoAq8bY9UbxjtY,1111 +torch/include/ATen/ops/mish_compositeexplicitautogradnonfunctional_dispatch.h,sha256=GvNrZ3LK1UNy9r8BTQDn2C7osMamVWPbrdlIdASN6rA,861 +torch/include/ATen/ops/mish_cpu_dispatch.h,sha256=DQOEDLDM5gi0QVpUn3rdIjM9BxeJamv_H_EqY0X7kuM,946 +torch/include/ATen/ops/mish_cuda_dispatch.h,sha256=0pwH5YwK-HiQx7nEX0J3ONHqGoOHDWDLXAkeT2MW51Y,948 +torch/include/ATen/ops/mish_meta.h,sha256=DMTVLzsAT7WJ7-KwT_EfqWwG-OmG_2988q1luYK9X4g,592 
+torch/include/ATen/ops/mish_meta_dispatch.h,sha256=BXA0U_805IiXqN0bfBbIyrjGHPuht4WmA0YF73xKBmQ,948 +torch/include/ATen/ops/mish_native.h,sha256=F4DwabGWW-NBSKXjCq02jAVchsABFyIZoxdcwi3VKFI,613 +torch/include/ATen/ops/mish_ops.h,sha256=RGbGgb-fJIA8c5tO2FJE-h3Fkz4pFwocUSB0ofpmZ60,2079 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d.h,sha256=WBgHi-bgG5GbbTzqpQxatwxNxwrXGR8z5eZj2UWk-5Y,1453 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_backward.h,sha256=rZfhRp68wUUyMTJoFk59Ct_a1zm5i5NsnHmzmsUIT94,1552 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_backward_compositeexplicitautograd_dispatch.h,sha256=Pz2xa4NPDZ-ZH1HoRhcIuuvsNKgj9OAu-jiezVJgkoM,1013 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_backward_native.h,sha256=cJ_Aqc-1HElP1vc8LNutvwkt13OOEmLYhqwEk6RADJM,698 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_backward_ops.h,sha256=D9GMlc19CnesxvMxjTSqAXc3euFeZv4LxnR8nh-97bM,1969 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_native.h,sha256=IvrAMWhEsCKW2JlnpOKswDBt7mO3dfn278eh7T2LA4o,674 +torch/include/ATen/ops/mkldnn_adaptive_avg_pool2d_ops.h,sha256=kFRoGZFMNPJIRr45Q2uHNKc4VMpuApN8wEa-4Thrr8k,1897 +torch/include/ATen/ops/mkldnn_convolution.h,sha256=RgyAPtPwjpn4y0MaP6c4tNjBSW7o747II6LqEccmZkk,7009 +torch/include/ATen/ops/mkldnn_convolution_compositeexplicitautograd_dispatch.h,sha256=6yS32cbM5RTW35AQbzC08SV1ZAqELcOzRcNLT4e-oPA,2261 +torch/include/ATen/ops/mkldnn_convolution_native.h,sha256=fde4dg-zbxxyP6J8lA5cJqonf_HtXrd3y1O99cDRB6A,943 +torch/include/ATen/ops/mkldnn_convolution_ops.h,sha256=p9XfEFHgfvEyhOwvSWdcr7wkHr56xHmssVJazs1HZ6c,2825 +torch/include/ATen/ops/mkldnn_linear.h,sha256=hKAy7YXgCn72fQZBoR_SiOJWU8YLwg5MeTqWvxVsN3A,1494 +torch/include/ATen/ops/mkldnn_linear_backward.h,sha256=DO0-lof3Gx0FekuiP0R9jnMromqb54IRGb5J5tOInhY,2136 +torch/include/ATen/ops/mkldnn_linear_backward_compositeexplicitautograd_dispatch.h,sha256=bF5LHOAl_fw8wXld1x6FQL9C8IvJjFA5iqObCKDemcE,1267 +torch/include/ATen/ops/mkldnn_linear_backward_input.h,sha256=_r9fMaq8o3NyKK_Hutf_8XGaqYfCPweILPaVHzp9wkA,1674 +torch/include/ATen/ops/mkldnn_linear_backward_input_compositeexplicitautograd_dispatch.h,sha256=8cLpNgsha8K1laGqWN9n6B6IFalR4Q5fSqf-9idjT14,1059 +torch/include/ATen/ops/mkldnn_linear_backward_input_native.h,sha256=6MC_CEjPGWU3O2tp3WYie6je3N7fJ4mJDkfjkX2iun0,744 +torch/include/ATen/ops/mkldnn_linear_backward_input_ops.h,sha256=EeFOMOsnus1rYaExTIZSpsxfC--lENqcrFaKUbl-Jks,2121 +torch/include/ATen/ops/mkldnn_linear_backward_native.h,sha256=BP2N8JcGleDrjY6WV8rA11gI2shZtD15mPbTSzPlP3A,909 +torch/include/ATen/ops/mkldnn_linear_backward_ops.h,sha256=mz0NGjtQOO1_EBUfP3j2zLr6ZPf7JU5kAeHMPEfh9pM,2704 +torch/include/ATen/ops/mkldnn_linear_backward_weights.h,sha256=8V38cBvzqTiobRG_yZQldEeMvtP_dmh--FPdeT31HWE,2024 +torch/include/ATen/ops/mkldnn_linear_backward_weights_compositeexplicitautograd_dispatch.h,sha256=N_NMbHyzLDF6eoYYzmmhgcWl5YU3qj_1trLrfW1Y82I,1191 +torch/include/ATen/ops/mkldnn_linear_backward_weights_native.h,sha256=c6xV9Pe8J0nMBpMPMnj7dcMWrKf8sJUoXCXMiYXqjhY,854 +torch/include/ATen/ops/mkldnn_linear_backward_weights_ops.h,sha256=FR2K9VtG6KYrLtot-u3MkEI5pr_z5RhDIUMFccY88Kc,2501 +torch/include/ATen/ops/mkldnn_linear_compositeexplicitautograd_dispatch.h,sha256=Vsab5Ndc8aCs62Rkz1Bw79Onrq_StCj2uYGXU7oWrZM,1046 +torch/include/ATen/ops/mkldnn_linear_native.h,sha256=RGva8XKAkGYOJ3GwiWiAnr_JOYDCym4Jr3ESc6q32BI,731 +torch/include/ATen/ops/mkldnn_linear_ops.h,sha256=4OlQqk3jRkoDHLwCqoXueQDSNSOXUnQKG9JIKL3jDS4,2087 
+torch/include/ATen/ops/mkldnn_max_pool2d.h,sha256=hRh6onnS9nkX8b4fM7RN3UgH6G9AqF41hkAUACMynf8,2007 +torch/include/ATen/ops/mkldnn_max_pool2d_backward.h,sha256=19xaCJyZfz8I4cjDiyO_j9W3S92YePqXMnQ5yhabxBg,2451 +torch/include/ATen/ops/mkldnn_max_pool2d_backward_compositeexplicitautograd_dispatch.h,sha256=G183pfirZODZcEwmZoSu1muqKT-bvdSZbBmpzb-6FWE,1304 +torch/include/ATen/ops/mkldnn_max_pool2d_backward_native.h,sha256=yB5gacQmWWN8c5ehn2eNmaXQFgYNMVd34Agsa0fQf8w,989 +torch/include/ATen/ops/mkldnn_max_pool2d_backward_ops.h,sha256=ohja1BdTA0UW8BJvqA7Ezf51wlyvXLPRv5tqmFaagVE,2921 +torch/include/ATen/ops/mkldnn_max_pool2d_compositeexplicitautograd_dispatch.h,sha256=eIQp6KlHCGGBlnEWVKAsMy7GccSCwETfrieN_xgC_o4,1166 +torch/include/ATen/ops/mkldnn_max_pool2d_native.h,sha256=l06WN4l-4ggqNCbv15IhbDmYGOj67hM3iiv6W0yfs8M,851 +torch/include/ATen/ops/mkldnn_max_pool2d_ops.h,sha256=wkJYs5bvAcH45lb6Y1iJPjBmr6-JqyDO96jXHIAoZ9M,2475 +torch/include/ATen/ops/mkldnn_max_pool3d.h,sha256=4pBm0SKXzSj-u01Z0lMgVG-rnxjkF0-GIiAhp1ay_LQ,2007 +torch/include/ATen/ops/mkldnn_max_pool3d_backward.h,sha256=DHe340U8sN_u_pVPwOJS9pWt50q_lHcP1ntMtl23uhU,2451 +torch/include/ATen/ops/mkldnn_max_pool3d_backward_compositeexplicitautograd_dispatch.h,sha256=bnddq_2pp-a_50RLIzpmQ4PmQHgDuc4s20Yd2oZwR6o,1304 +torch/include/ATen/ops/mkldnn_max_pool3d_backward_native.h,sha256=4YQ91xkPYQf19dn55lh0JGLPenAk02ZAqYK3p6PBlng,989 +torch/include/ATen/ops/mkldnn_max_pool3d_backward_ops.h,sha256=MjO7mWI2iBEkLTOzWSKzJx7WihhgDZwJ5GDnrdEwLAg,2921 +torch/include/ATen/ops/mkldnn_max_pool3d_compositeexplicitautograd_dispatch.h,sha256=UZSoOOmpsCeFXVsyg9wGOjwIkF7B_Uc4WA8u9wktbSE,1166 +torch/include/ATen/ops/mkldnn_max_pool3d_native.h,sha256=UzYOg42ScYkRw1HKaPcWR6ORy0aQ4ClV5OCJ2Z_UcA0,851 +torch/include/ATen/ops/mkldnn_max_pool3d_ops.h,sha256=Vxhje6FrlBNCN9bERDAJJsQg5KY1oH0CH1VmIjlHH0c,2475 +torch/include/ATen/ops/mkldnn_reorder_conv2d_weight.h,sha256=gOTJWsCdIS7J7WPmFfDso8-Zffvfkq2b0ss-rA7REqA,7857 +torch/include/ATen/ops/mkldnn_reorder_conv2d_weight_compositeexplicitautograd_dispatch.h,sha256=2L_3zA0SbMbtjcqpTXHOVGU9h_K96WvAQl13C58rAf4,1783 +torch/include/ATen/ops/mkldnn_reorder_conv2d_weight_native.h,sha256=Ut9Mc4cBsSElQ1hOD4SOWynBkgnMjh67fiOXj23CJ_A,923 +torch/include/ATen/ops/mkldnn_reorder_conv2d_weight_ops.h,sha256=NAzfPCjvnYZI-DzcFSF5TC4NBkHtkOYS3vYnn_jphyw,2725 +torch/include/ATen/ops/mkldnn_reorder_conv3d_weight.h,sha256=3A_b2aFkiW55ifnAUUqO052-UuLcGA8-Mby4N8sodf4,7857 +torch/include/ATen/ops/mkldnn_reorder_conv3d_weight_compositeexplicitautograd_dispatch.h,sha256=6Wn-XxlRusjrX2TYWKRs-5tAi_14GrYFHpGXqJJSmAg,1783 +torch/include/ATen/ops/mkldnn_reorder_conv3d_weight_native.h,sha256=kkvOoT9RBNi7m8Jrs2yndtJFMB8aFS5KuTrQ5aztklI,923 +torch/include/ATen/ops/mkldnn_reorder_conv3d_weight_ops.h,sha256=B8sdFogksvUsF9O5vrqx_LqlqRcyj2sSM_0491wqA8Y,2725 +torch/include/ATen/ops/mkldnn_rnn_layer.h,sha256=Ak_cfgXwjV4VFIkY9CN-NH-b7idt_bBaxpAEbjpwhAs,3795 +torch/include/ATen/ops/mkldnn_rnn_layer_backward.h,sha256=fDpweFRaBicb8T_KTTCkXBIstchBTIC0wlJPy6NIVR8,5619 +torch/include/ATen/ops/mkldnn_rnn_layer_backward_compositeexplicitautograd_dispatch.h,sha256=CgEj_gVZPaNR3lerjqukxs7z_41mfGu-OrjnpkqgJt4,2489 +torch/include/ATen/ops/mkldnn_rnn_layer_backward_cpu_dispatch.h,sha256=X7EgMBSxz1kuKL3ngxVtTUf5mbjvh2HhBm74SRqpu84,1415 +torch/include/ATen/ops/mkldnn_rnn_layer_backward_native.h,sha256=Yf72se3Gb0L5v5BDyHieYjIHOQWxd3VUbdq-dX3k_K8,2047 +torch/include/ATen/ops/mkldnn_rnn_layer_backward_ops.h,sha256=40f9oL4SNlHs5R2H48c4POtW2hiGzsIHz-quJVP4WzE,6514 
+torch/include/ATen/ops/mkldnn_rnn_layer_compositeexplicitautograd_dispatch.h,sha256=MELqzXu7GVpsr3RUmod-gJ-ULyV7K953388KzyWYVDw,1785 +torch/include/ATen/ops/mkldnn_rnn_layer_cpu_dispatch.h,sha256=oTfnWlRwNDYmuho6wy5bIr7XrI7oTOyRp6UACBwJRaw,1126 +torch/include/ATen/ops/mkldnn_rnn_layer_native.h,sha256=IF4d2R1HOEQBLmkgkrA5vm75yeIVOkkc0S7awZolJas,1406 +torch/include/ATen/ops/mkldnn_rnn_layer_ops.h,sha256=Rj5kvRmgzKGplkcBlPAnic1ErYh0e7dIthoU4o-IN7s,4377 +torch/include/ATen/ops/mm.h,sha256=jnHNkDFmymEHqi5EMoLxwVe4F2DmCpDwg_QK9vn-mZg,2007 +torch/include/ATen/ops/mm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ZtkYbsNxONSyV5zbRZCbXHubl7mPMIO6yxxbSOBnA9c,834 +torch/include/ATen/ops/mm_cpu_dispatch.h,sha256=M_KAe78i0x99QxlhCqytOlF531MMYGpUMLYLvehJne8,965 +torch/include/ATen/ops/mm_cuda_dispatch.h,sha256=fi4QWURIm05rtxoZbPljHBKK-HoI1Mn_LJ8sRM1ywQY,1322 +torch/include/ATen/ops/mm_meta.h,sha256=_cAbADg24Gq19vIPWPVDimP_EBQ2rn8mOLtoxLrzZt4,615 +torch/include/ATen/ops/mm_meta_dispatch.h,sha256=MfALu29g2Unr9yaTKbRzYBL3HEb71z2Y7xedXrHGTuI,967 +torch/include/ATen/ops/mm_native.h,sha256=zTYASFBw01hLYXbhVjELGmKQaYPbYRT35dFfWG9fRtI,1446 +torch/include/ATen/ops/mm_ops.h,sha256=pDGRLi1-OnUftGtodYKh0ey5Kl8DEZSrFJAoE1gWM40,3198 +torch/include/ATen/ops/mode.h,sha256=9U0jd7g2V7XcqHBUcwFcIVxlICWzyRJW1f4rzUa6gro,2675 +torch/include/ATen/ops/mode_compositeexplicitautograd_dispatch.h,sha256=bVBeKSVlnpE4Qg1MBi7v83vR5il67LLM1IRfQH36LyQ,1054 +torch/include/ATen/ops/mode_compositeimplicitautograd_dispatch.h,sha256=WjdKh42UyE3UU_3DmXszQJM1mwvy2h_vqjMbH0kN09w,1174 +torch/include/ATen/ops/mode_cpu_dispatch.h,sha256=Z8NPTSoRaRxPXm0kU2zFJqf0KZsaiJld31UrsYKQxoo,802 +torch/include/ATen/ops/mode_cuda_dispatch.h,sha256=ZuudM0a2ykJx8NxQyER6OuG7OG96FpsDUfZcNlHrsDY,804 +torch/include/ATen/ops/mode_native.h,sha256=T8LmAze3u-hsh8gnNDbDjDY_NskWckLGklEeQk_zsp4,987 +torch/include/ATen/ops/mode_ops.h,sha256=egrt52x20w-HTY8O5-9gHoF1aUiMC7_bzztTn0_aPl4,3755 +torch/include/ATen/ops/moveaxis.h,sha256=i4-ITEkWaCH6xgOBz0WUr6Hrc5rnZevcjBa0k3p5j5o,1052 +torch/include/ATen/ops/moveaxis_compositeimplicitautograd_dispatch.h,sha256=_3FlhmaY7aVeMyCRdNsbbCvpLY2wQyL9-ZFcwZH4-fE,936 +torch/include/ATen/ops/moveaxis_native.h,sha256=F-qtU8obzu3o0TiOgU8RjXjh4DW2PcOECAdAmdmsLuU,646 +torch/include/ATen/ops/moveaxis_ops.h,sha256=PkTENbgCDhXDZuSWxiPhrzajqqy1J7035tSVZfcq7_A,1849 +torch/include/ATen/ops/movedim.h,sha256=VjNsmkg113FLfK17ch_dl1DJYBjpONcAuw5n6lWSDQE,1045 +torch/include/ATen/ops/movedim_compositeimplicitautograd_dispatch.h,sha256=ghIYm6Tod1vPAb_fjhQ5EB5UjWRF81i91-8AboaPC5c,934 +torch/include/ATen/ops/movedim_native.h,sha256=FncnYnWDIfIbnrsZks3iHU2gU3Uxx_7lqD5NcOYUhVo,644 +torch/include/ATen/ops/movedim_ops.h,sha256=bY1tD_QktGVeoebnPVi_KmCcrpjgs2G_4iiplhKo9c0,1843 +torch/include/ATen/ops/mps_convolution_backward.h,sha256=9LRswCNfL2PxQJLr0QoTg95_YCCSMPe7t0QqzgsLPbc,9045 +torch/include/ATen/ops/mps_convolution_backward_compositeexplicitautograd_dispatch.h,sha256=P54U-B5_L0WpJtHlQSS0KKvTDhszTdHn5b31ztBXTwo,2220 +torch/include/ATen/ops/mps_convolution_backward_native.h,sha256=oetW6-rJAa8AwzbB1nAABIelcKAE1qBm3MImpKttSWA,825 +torch/include/ATen/ops/mps_convolution_backward_ops.h,sha256=K5FdrzpypdSRC-tbVM2M7JHCOTNtyjweYoIkd2qPg8k,3434 +torch/include/ATen/ops/mps_convolution_transpose_backward.h,sha256=QshiGAJaACmlVTn9I_Io6hk8v6sSnL96jqPbsRg67PY,9781 +torch/include/ATen/ops/mps_convolution_transpose_backward_compositeexplicitautograd_dispatch.h,sha256=0AEBQhB4yPaqSyYnQuF7fFU_zaTH17qjX8JmKZotajI,2268 
+torch/include/ATen/ops/mps_convolution_transpose_backward_native.h,sha256=fLjUZgfwRoKfBSlUTlI0P_20VPcukXV8ON2_aOy1i44,839 +torch/include/ATen/ops/mps_convolution_transpose_backward_ops.h,sha256=8jVC4-yU4xX1x0gGnLrqturmA1TGvBWMbwjROLV53IQ,3569 +torch/include/ATen/ops/mse_loss.h,sha256=El2SrRWXmkKd77HQydMz4bMLHWTnGC9XtfHsL4zB16U,1427 +torch/include/ATen/ops/mse_loss_backward.h,sha256=LulkfBnuEHnY69RZSwuk8MlPP1o3h1yY5R9A93_hggI,1727 +torch/include/ATen/ops/mse_loss_backward_cpu_dispatch.h,sha256=tbXFyEbbFz7gE2t38-kuxcUWusfudihKIO3--JMJ_eU,1183 +torch/include/ATen/ops/mse_loss_backward_cuda_dispatch.h,sha256=nvtVCTGrcSUrzEyph9LXDj3XcFZ-a4aLf0-8nKVZVf8,1185 +torch/include/ATen/ops/mse_loss_backward_native.h,sha256=Nfxd_0wjNZ1YP_UZoQkDcOpHVK1071_iuvqmHeFBLaE,761 +torch/include/ATen/ops/mse_loss_backward_ops.h,sha256=UJa1d8BdiqEWoc9f0dW0bML8rhJeXNquYtg2swZ9eJQ,2205 +torch/include/ATen/ops/mse_loss_compositeexplicitautogradnonfunctional_dispatch.h,sha256=leqQk4-6YHal29s09U7Sx__HJmvhoU915MWRYYr2t3s,881 +torch/include/ATen/ops/mse_loss_cpu_dispatch.h,sha256=z9jtzaYFrNwqIFpE2qv6GkwGOS-dLipumxW87qyYL70,1086 +torch/include/ATen/ops/mse_loss_cuda_dispatch.h,sha256=ojPzdQK7RzOLAD9JqgNCLsi0PGKzMpUhsBHYeGvqYpw,1088 +torch/include/ATen/ops/mse_loss_meta.h,sha256=N0BRQipGoUpNAspo3N1j47b-trU6d9y5ppdF6NC_vlo,642 +torch/include/ATen/ops/mse_loss_meta_dispatch.h,sha256=_PH7YqX1hMqduawcXWGMmo8HMqJCkajFRy2VIs3P2Cs,1088 +torch/include/ATen/ops/mse_loss_native.h,sha256=hrcOQrVr1gi2B5Wj8HlzdQItz57pIEBLhctNF_0ZUmc,671 +torch/include/ATen/ops/mse_loss_ops.h,sha256=Jw3yAgK7D9my1Nh60TZ7nkiF7XGxVHBDQXMtUD1gpHA,1911 +torch/include/ATen/ops/msort.h,sha256=1FYX9a8pW-C5dCXsbiwRbgt_coDnuLM9ofcEPJoCL44,1057 +torch/include/ATen/ops/msort_compositeimplicitautograd_dispatch.h,sha256=9wNDMigughV69kuvPwYF0pAs82cHKj9oZ9ef7dciBSk,943 +torch/include/ATen/ops/msort_native.h,sha256=IvTLqcDSNuX1-zL3NUbjv0H91VFacYwcU3qi1kh33Og,574 +torch/include/ATen/ops/msort_ops.h,sha256=sx2BCEKv9w_qT3Juknd8WEaA7KiC29GvACDDLMMsz0c,1581 +torch/include/ATen/ops/mul.h,sha256=pa51ZeQKggs8uWZbCvLYmUhFYFUnUq0HgFndgAxVTGk,1887 +torch/include/ATen/ops/mul_compositeexplicitautograd_dispatch.h,sha256=nV2ntdpVKBKLfDh18thRz4xlzDL4azuD3AIklh__E4c,1090 +torch/include/ATen/ops/mul_compositeexplicitautogradnonfunctional_dispatch.h,sha256=X_6N9So06ZeT1tk39rPhxFu2lXxXF2rxn3JxooD461o,911 +torch/include/ATen/ops/mul_cpu_dispatch.h,sha256=ReW5kM10Lgdq4BogyUF_HU1JugdFslZ6K9tZhKZZYj8,1046 +torch/include/ATen/ops/mul_cuda_dispatch.h,sha256=zSrHJvlnjx3LqO6rNWAGG0Uzb7Vog3DrHuM98HdRz7I,1048 +torch/include/ATen/ops/mul_meta.h,sha256=D0350O5phc6Ef1ESTw5kEbqlWgmn_07TNS-RT_8ZOL0,624 +torch/include/ATen/ops/mul_meta_dispatch.h,sha256=SD8Zlq6Ikoc6KBtkq0xMfYXMeipGTcd4JWqtblB94A0,1048 +torch/include/ATen/ops/mul_native.h,sha256=OoQoa41rfCqGyM0y1bB-w0WtU8TKIoTvXlYdiFNEh5Q,2527 +torch/include/ATen/ops/mul_ops.h,sha256=z0tcEe-7z8aD_ITzUx6jVA4mzRP_bcKyBU11__NBhpc,4282 +torch/include/ATen/ops/multi_margin_loss.h,sha256=NNQaakM5-HqOHcc4U3f0eLKyeMuSqEIBCwLDohd3bD8,2017 +torch/include/ATen/ops/multi_margin_loss_backward.h,sha256=dQ_0p_PCGXhBP-IgN3d-6pddKB-NH3dN1fyN0aFDLGE,2352 +torch/include/ATen/ops/multi_margin_loss_backward_cpu_dispatch.h,sha256=Fn4IQIOwQ-fQ76IAd-sDlddEH6XzQV6V_2x1XkzEe60,1535 +torch/include/ATen/ops/multi_margin_loss_backward_cuda_dispatch.h,sha256=TbUwZ2-CpmKfPE9V9Oci4tQjfTUkmRjdIlQ3SiRVkeI,1537 +torch/include/ATen/ops/multi_margin_loss_backward_native.h,sha256=_s0EN-8CKuHDddzeBb9eUTPN6jGw_s2Rc4iDWiyZY2E,1552 
+torch/include/ATen/ops/multi_margin_loss_backward_ops.h,sha256=i5mIkbZ9y3A5giiAHap-37OxfATqg8OdoXEvnG013uk,2887 +torch/include/ATen/ops/multi_margin_loss_cpu_dispatch.h,sha256=4X55-jWlfIbpIhhKDmxrrFwPjtMjdd6F5lfm1RKF_oo,1406 +torch/include/ATen/ops/multi_margin_loss_cuda_dispatch.h,sha256=IuyUUw63zViFiVA3TJBq58H-ROqTs1KHOrszsxA0hgc,1408 +torch/include/ATen/ops/multi_margin_loss_native.h,sha256=Lfqivs1Wf7R-ZaD1QbKgHrX4WPg2-ruypfiY0rinV3M,1382 +torch/include/ATen/ops/multi_margin_loss_ops.h,sha256=qVneUEakK9UBNzXKu16UVJzmMuEv3mYddWeKYZWw6fk,2591 +torch/include/ATen/ops/multilabel_margin_loss.h,sha256=oCe65eA4UVGkt_NcFTzyHTC0gf1v8YaBeGwl59hxAEo,1567 +torch/include/ATen/ops/multilabel_margin_loss_backward.h,sha256=t4g366QzB7Hm174kZvNbpLLne1ZiwJPnR20WY-D8pCc,2044 +torch/include/ATen/ops/multilabel_margin_loss_backward_cpu_dispatch.h,sha256=mKO1QMdT9m4_50Dp9TvE-9AsZeo-1WWadwIWKGx7cQg,1315 +torch/include/ATen/ops/multilabel_margin_loss_backward_cuda_dispatch.h,sha256=R9eIsylDDSCJL8rtn_d-tdqeWzn0zdXqT4NNuhY6lzc,1317 +torch/include/ATen/ops/multilabel_margin_loss_backward_native.h,sha256=wDJ4Ko4kC9N5L6qtkjDBY120yiI71daqcL4YV1raorM,1274 +torch/include/ATen/ops/multilabel_margin_loss_backward_ops.h,sha256=qMy1hqMjPRoJaOR_lZqYdpn37mjF5khWA-boYpR9BOU,2485 +torch/include/ATen/ops/multilabel_margin_loss_compositeimplicitautograd_dispatch.h,sha256=KmOC7Z7XI-hPlrPEWik9XYdpVUKKEEfpg3UoQrCu9qI,1172 +torch/include/ATen/ops/multilabel_margin_loss_forward.h,sha256=qAIEAmMRHVPpha15IW3zdr907LCC71eSv2iHs8H04Dg,1870 +torch/include/ATen/ops/multilabel_margin_loss_forward_cpu_dispatch.h,sha256=TgNrSExeuUvNMBLwTvPafZ6420tl2BbGqkdQLx8YCvg,1245 +torch/include/ATen/ops/multilabel_margin_loss_forward_cuda_dispatch.h,sha256=YBczlw1Dc7hPFd_FAp_T1wgUHrtxqFAo18HDOyJ4SqA,1247 +torch/include/ATen/ops/multilabel_margin_loss_forward_native.h,sha256=qPsxHc5coZ2TIAaekQTiMiO6WWTAiYs6_4bCWchtnH0,1166 +torch/include/ATen/ops/multilabel_margin_loss_forward_ops.h,sha256=4a3nMpM37b4kw35k8ricBy05uQjhbQFVeE-cwBM_QYs,2332 +torch/include/ATen/ops/multilabel_margin_loss_native.h,sha256=VY0whCdEkQa0N8UYRoWLSkmDaDpluSWqh7s1XUPrLwg,720 +torch/include/ATen/ops/multilabel_margin_loss_ops.h,sha256=J79ZcO1DVIbGhRCYGId7VBLzU7k0d6z22WhNpoqCWhU,1995 +torch/include/ATen/ops/multinomial.h,sha256=220vT6W1zXM2zXGYopsxDhBfRM8jyyLzloxS3N_5PYA,5172 +torch/include/ATen/ops/multinomial_cpu_dispatch.h,sha256=kYRDpC7z5scCtp4HC7X3aDZxxHfua8M0TFdM-bSgg7A,1749 +torch/include/ATen/ops/multinomial_cuda_dispatch.h,sha256=KaLrAaKC2tZghjRPyaQ8U0ooFI0lDVg5w_-zrVH3lVE,1751 +torch/include/ATen/ops/multinomial_native.h,sha256=-2BiLUECAdyH37ir_O48-FNOkcLAFuKfJBjGRMQ5iF0,769 +torch/include/ATen/ops/multinomial_ops.h,sha256=my5vVdho98yJuGyTOtg0vfymyhnWcNxM-IMhI9lkRQY,2204 +torch/include/ATen/ops/multiply.h,sha256=hVTjttjT5oZKURTM78SZVy1ZpwyLww_svWNUqXJlp04,1449 +torch/include/ATen/ops/multiply_compositeimplicitautograd_dispatch.h,sha256=T5E7AhCvlzCOKk8K3GSmcIXdXYf3f27F8P0G88tf6vs,1273 +torch/include/ATen/ops/multiply_native.h,sha256=aVVPloFAJHcd9hk4KH5IfEyGDEKE-o31FDiMEK3lpjQ,875 +torch/include/ATen/ops/multiply_ops.h,sha256=7htL6N-cWVDAWJSRlBZCloVUVzUcnxJCHkH5CfasV8M,3658 +torch/include/ATen/ops/mv.h,sha256=2ncfNqwXttQ04SGjI32QBvbKYcSHESXWP9LMzqZ-D2k,1150 +torch/include/ATen/ops/mv_compositeexplicitautograd_dispatch.h,sha256=8AnES8ic5_9IUwiOJR1xw-Mz-Tgbwb7AKiMsSRwGZ_8,1006 +torch/include/ATen/ops/mv_native.h,sha256=LBBrnJBE-R5sfoynmRPagilvek-85xTcUbtUJwq68CY,698 +torch/include/ATen/ops/mv_ops.h,sha256=sYY0YQFtJZz6TV37EowTxju3_pLhkaXSjClfwzDF9ks,1723 
+torch/include/ATen/ops/mvlgamma.h,sha256=WUq2-v81QTgt1Nkdvk-UUsE7LgKRln5r6Q-kWW9if-o,1150 +torch/include/ATen/ops/mvlgamma_compositeexplicitautograd_dispatch.h,sha256=6Ez4onXMxwJbac6Z_MB89r5WFcoMTLQIDzmGqDTtI7o,865 +torch/include/ATen/ops/mvlgamma_cpu_dispatch.h,sha256=xGD5uqcvkT4sG5e8hdCpNwdbPH1U5x3bDTn5qiFygfo,873 +torch/include/ATen/ops/mvlgamma_cuda_dispatch.h,sha256=ciuW11TeqXcGR_tQNzqTClOUM-SfMHQhDo9RoVXfrE4,875 +torch/include/ATen/ops/mvlgamma_native.h,sha256=v-AtdMpP8rnOSlt35YBE2tphid35rgR6PLvjUpq4Uss,667 +torch/include/ATen/ops/mvlgamma_ops.h,sha256=o_qeMk-sSaw05Esqb0r-r3hqgiUDnD5bpZso4zVJhVI,2229 +torch/include/ATen/ops/nan_to_num.h,sha256=90CIhauS8u0T1_-XBCq_9T10iT0W4YZBslm4aTzMwes,2085 +torch/include/ATen/ops/nan_to_num_compositeexplicitautograd_dispatch.h,sha256=2lTc67wiEkCcsXNdJFxfiu8MXhaNefAKzSds_Y1t1ck,1123 +torch/include/ATen/ops/nan_to_num_cpu_dispatch.h,sha256=6RJYIqFwHuNc_3VWIpbpyaTO-2JbNoWytXDFPhE8RQU,1086 +torch/include/ATen/ops/nan_to_num_cuda_dispatch.h,sha256=i8PAHh6JJ_wPZqiL82PlVWF5UCOI9MQFJOEaLS6WZQU,1088 +torch/include/ATen/ops/nan_to_num_native.h,sha256=xqwN7nbxoJqzjm-e3TzyIFlIMcfx_ogODWP08-bo6OQ,1597 +torch/include/ATen/ops/nan_to_num_ops.h,sha256=gyT895uKOy37wfsIeyDi9yynoJKIKyWWaHdqIJl3RME,3087 +torch/include/ATen/ops/nanmean.h,sha256=6QvL4Ns055FYGn_-8thoB7mlLDOH4AoeToMw_LPNDL4,1647 +torch/include/ATen/ops/nanmean_compositeimplicitautograd_dispatch.h,sha256=4qXIMF7oxcAjU_6imsqDzJTFqdO1Zw-hrcqgi3HHLQQ,1267 +torch/include/ATen/ops/nanmean_native.h,sha256=tZ1kWuVobm_KHsmgpFKcpikkdnb_JflhxPiRmsj-ioY,778 +torch/include/ATen/ops/nanmean_ops.h,sha256=Zk_mwlWShUNPOkjoST_UxoiuZirRxxII7P7M-vqy4-s,2176 +torch/include/ATen/ops/nanmedian.h,sha256=vfCqBoSDCnO3zbO_nwie8ol-mtpCC7DA9235-UMxyBo,3369 +torch/include/ATen/ops/nanmedian_compositeexplicitautograd_dispatch.h,sha256=wnkJTTV8EN5GRiIvZf6f8cM9J9GfHvgSSo2OsZ9T5bA,1013 +torch/include/ATen/ops/nanmedian_compositeimplicitautograd_dispatch.h,sha256=VbGdqWPc9bfpELES14_fOkyy4MphCj88yqfN7wCm2ms,1189 +torch/include/ATen/ops/nanmedian_cpu_dispatch.h,sha256=4cErTSAwj9-Cqqlgs5XmdcEQXP2ztjdtQP18_I3o714,1075 +torch/include/ATen/ops/nanmedian_cuda_dispatch.h,sha256=1ReZCv9H5TaiAiBB8EQvl4-V4Rzu60ZeLET7f61UiHQ,1077 +torch/include/ATen/ops/nanmedian_native.h,sha256=BnShBnD0pUMEjv9Su37OOSh9okkyfxlH0t6AihkLEqo,1381 +torch/include/ATen/ops/nanmedian_ops.h,sha256=sdQcgTlth3Pv4XyhWhT_Fgya-o2DR5W1hpHiSJTFaRU,4983 +torch/include/ATen/ops/nanquantile.h,sha256=AjrxhvrwuVsLbIra2jafaJ-cYz4vOoUaUCqJCoY7o-8,3048 +torch/include/ATen/ops/nanquantile_compositeimplicitautograd_dispatch.h,sha256=zfDouG3YmUTRyXo-UKa2pa19nFfwpezo68YiHb-McC0,1862 +torch/include/ATen/ops/nanquantile_native.h,sha256=p2vLzyrqr3emuXvqiWIktZQVz7FUQYbD2SCsy5t76ek,1158 +torch/include/ATen/ops/nanquantile_ops.h,sha256=jtNQzdczjc461IV7DMzwIyvnc7Dzzm_S-bJmvfz_8t4,4088 +torch/include/ATen/ops/nansum.h,sha256=lHUz43sMCoVu3XHsT8pU-QLZrfBhUrSir1G6xzNoLgE,1637 +torch/include/ATen/ops/nansum_cpu_dispatch.h,sha256=gaeoM_Iubr_9IjzbvCjjAPRkfcUhCCZaHTrmjbIKz8M,1220 +torch/include/ATen/ops/nansum_cuda_dispatch.h,sha256=yBzjytqvXa5OwG_mkqsGcv6dSh3aArg1FYVnWmkftrY,1222 +torch/include/ATen/ops/nansum_native.h,sha256=jNnox33T3qqFvOvbU3d_FGXeIUdtRD1nDifuWxhDVAw,776 +torch/include/ATen/ops/nansum_ops.h,sha256=zpc2zMkR1gsP84-hFOvEtUlrGtvnvoKY9bKY6AQt49E,2170 +torch/include/ATen/ops/narrow.h,sha256=5mCalPM_Pc47EvdrsrxHjT-WIhEgjwh1iSxJviWnhNI,2708 +torch/include/ATen/ops/narrow_compositeimplicitautograd_dispatch.h,sha256=uJ-aOjDQ9zq4juvb-iQsbKG1uHBpzU9CMlPPJlFHN6E,1176 
+torch/include/ATen/ops/narrow_copy.h,sha256=Ru_EiEcbJq1W2lzI8XRnNgyHBKGOorZ4O-BkSLREVMI,4176 +torch/include/ATen/ops/narrow_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=0TPt75hpdigkuizk7VwHI2XWlsqbUoEuc3GhdASx0N8,981 +torch/include/ATen/ops/narrow_copy_cpu_dispatch.h,sha256=n-MgWrGl-WrujkDhv3qcn5eZMCR9T86mLkbiyDMZYX0,1455 +torch/include/ATen/ops/narrow_copy_native.h,sha256=klxsY6FjXQcutmUWGj234rxSfmbaMavYmrSnUQpMjko,930 +torch/include/ATen/ops/narrow_copy_ops.h,sha256=_1u0k021FS1cHqScriGoI_iQGGyDfN1qZCsfmXbPivc,1971 +torch/include/ATen/ops/narrow_native.h,sha256=vyObbIRNWI1UKYCAxjTVs20oUdBF7HA8NspKdtGObNo,805 +torch/include/ATen/ops/narrow_ops.h,sha256=p51j7cHySEakD-SvIlH9o8HoIbuWbQ25ZRO27N9SW_g,1904 +torch/include/ATen/ops/native_batch_norm.h,sha256=LplaQXzp1CLaZCxwcOT4_uGGiZSfZ5iJcwueZ38yYmI,2842 +torch/include/ATen/ops/native_batch_norm_backward.h,sha256=CloWdvwUiwdZDY1yQ3iqfGgRDZHaEr7ccM-MW3tjBxs,3376 +torch/include/ATen/ops/native_batch_norm_backward_compositeexplicitautograd_dispatch.h,sha256=udPlOeodbPgIOBgV3pAI1emWeYKpvAoLDuLGHcRMhiQ,1743 +torch/include/ATen/ops/native_batch_norm_backward_cpu_dispatch.h,sha256=--g5n4dXL73oRdYFXyYzbJYELF9YzmQDIVYFRzM-Ax4,1126 +torch/include/ATen/ops/native_batch_norm_backward_cuda_dispatch.h,sha256=9nQwFg5oX4Qg5oBR6NmbF_bjz5NYgS1Qp_N4TMelF-Y,1128 +torch/include/ATen/ops/native_batch_norm_backward_native.h,sha256=WeeII-vh3dis20XVJy9FAqahsbzr9My7KVfpRpDqwS8,2256 +torch/include/ATen/ops/native_batch_norm_backward_ops.h,sha256=QU3fxcpZYPSRlNzhvZbD0FOOzeZutFyB4DyjdjJNKE8,4232 +torch/include/ATen/ops/native_batch_norm_cpu_dispatch.h,sha256=igTLT25ZbZhQcsSkPyFfpQinC5kDfIZPMYB1oNb1zOI,1841 +torch/include/ATen/ops/native_batch_norm_cuda_dispatch.h,sha256=sx9WbbZuSC_qCa8FG-ePKaHithEfgus6GIGyZFfpmL0,1843 +torch/include/ATen/ops/native_batch_norm_native.h,sha256=bWtJLh70MurGCYi3JGEqZThrAPu3RlCrlEDwKmR6KeQ,2248 +torch/include/ATen/ops/native_batch_norm_ops.h,sha256=K84I1FaXZwZBpDRpGCqcTolwe6VJBD-BmhquCUBG7qY,3591 +torch/include/ATen/ops/native_channel_shuffle.h,sha256=7Jgvc74EzQ0todM1zG26IDEGyeDBBNzLbv0WkietDhw,1529 +torch/include/ATen/ops/native_channel_shuffle_compositeimplicitautograd_dispatch.h,sha256=OtxnhikaiMpQR6iVX2V8RlCbgqkJ2VSeuc9W8ZCcsXQ,917 +torch/include/ATen/ops/native_channel_shuffle_cpu_dispatch.h,sha256=4LOzc9znpF93JPtManCIGeUbYSevoOX3I7ZHd7_fO0Q,873 +torch/include/ATen/ops/native_channel_shuffle_native.h,sha256=cZlCKagnqFuDswbiEDGge4PvoHoHq2VVV5r78Tnr3WE,611 +torch/include/ATen/ops/native_channel_shuffle_ops.h,sha256=521o5MWqIlWqKi8w8_bkGh8tMIx2naMR804dcAfafPA,1102 +torch/include/ATen/ops/native_dropout.h,sha256=LDOYVpmApEtgPb_zE1GGdp5WMDfLRRdJQg1lih82TbI,1576 +torch/include/ATen/ops/native_dropout_backward.h,sha256=JXYMm6ie3gLkCksL5HffQKxyhM9TLxCtwXjEuQnyC_Y,1534 +torch/include/ATen/ops/native_dropout_backward_compositeexplicitautograd_dispatch.h,sha256=LLvjxWT0C4KByoBUdZPv5NHNGZfClRUM9YxsrMvnqgg,1017 +torch/include/ATen/ops/native_dropout_backward_cpu_dispatch.h,sha256=LpVapYLofIQyDkfs1PYs26VWTHkeglmc12lHFNZl0gk,806 +torch/include/ATen/ops/native_dropout_backward_cuda_dispatch.h,sha256=7YODmYuhWZQz5li55vCxZOtSDKBmU2ivr6o2bGeY2fw,808 +torch/include/ATen/ops/native_dropout_backward_native.h,sha256=BwpkptOqnsm9uiqxx2SYB6Xxx9Se5PLLg1GiXQVjcoY,825 +torch/include/ATen/ops/native_dropout_backward_ops.h,sha256=fauZLo6b8sa6e8UFGAenJk8hTZmFMSQzpCP69lSGSd4,1995 +torch/include/ATen/ops/native_dropout_compositeexplicitautograd_dispatch.h,sha256=Q_bpVoLbvY5cszW3I4cH9j-Gd2S0PtrLS7cOjPeghuo,1081 
+torch/include/ATen/ops/native_dropout_cpu_dispatch.h,sha256=B1hU2sX-j5hCGQRQwM8OIiJhGPTuc4hbTc4hzDNB0Ws,816 +torch/include/ATen/ops/native_dropout_cuda_dispatch.h,sha256=pv-nYW7zVyTxd1alc-hcxMj6yMJfTGYZtImTSDdzK9s,818 +torch/include/ATen/ops/native_dropout_native.h,sha256=Ral5_wcng0c5zMrR6-z9DAArP2SNnquAOgpqK0pQVU8,1016 +torch/include/ATen/ops/native_dropout_ops.h,sha256=WFUtwMFdQtE_lsHL1y7bLv95FS6GdN6TYiGSNudiwGA,2155 +torch/include/ATen/ops/native_group_norm.h,sha256=LGGLvA-Vjws7Zxd3rI_mTZ_-VakSp43d4u5nkN3_pF4,7304 +torch/include/ATen/ops/native_group_norm_backward.h,sha256=dIIavlCZFP8PSxd4lT5hUn2I64GbIf-GGrRNgigcWws,8813 +torch/include/ATen/ops/native_group_norm_backward_compositeexplicitautograd_dispatch.h,sha256=EOvW_TCb-oMCpyGEIGudXWdBCqQ1T1fqHsvePdGyGoU,2316 +torch/include/ATen/ops/native_group_norm_backward_cpu_dispatch.h,sha256=dfq8nnIbeH82tL0_TvwSdroaaswM_voFdv8ac2Te4VQ,1345 +torch/include/ATen/ops/native_group_norm_backward_cuda_dispatch.h,sha256=I98qf6yn0twXyenTOBVFt9soGh3Oh9LLLTpjRfGUmo8,1347 +torch/include/ATen/ops/native_group_norm_backward_native.h,sha256=dkFs6mGmSa0h9tNpH5PPpyTLPDVCqV0PjaPC5iHivl8,1166 +torch/include/ATen/ops/native_group_norm_backward_ops.h,sha256=YCshptN-Gi944s3zPPEZCHTBxtdMTDLSxoaQ8l5vBVI,3582 +torch/include/ATen/ops/native_group_norm_compositeexplicitautograd_dispatch.h,sha256=MmX4X26MMcsO6lZdBrg0TqcDrWEmN5Mo2ac51XUNNSo,2565 +torch/include/ATen/ops/native_group_norm_cpu_dispatch.h,sha256=Q2D3L0s1j04E6Lh5C4tkGGKUwk9hfekT1PSn_OM2mbo,1209 +torch/include/ATen/ops/native_group_norm_cuda_dispatch.h,sha256=KgxvYpDvVGndged2LLc9vgUkBl4nl-KswCTHWOyiYFA,1211 +torch/include/ATen/ops/native_group_norm_native.h,sha256=9BLH5nGr9oId_PLELjcD1k69i4oZzMQX7KxDQyAejlQ,1279 +torch/include/ATen/ops/native_group_norm_ops.h,sha256=e2IAnzoJ_a11raFNGy2Yh40Bqr21NmSSWRITZZz-WCY,3140 +torch/include/ATen/ops/native_layer_norm.h,sha256=E_rs_qRDYnkrm8qD2YHh88JRCUUqbOo0feqIrXztuOU,7124 +torch/include/ATen/ops/native_layer_norm_backward.h,sha256=wcZg1gwsM_c978glayInCbNAT15F45PD-JAUE2Fn1As,9293 +torch/include/ATen/ops/native_layer_norm_backward_compositeexplicitautograd_dispatch.h,sha256=tHnVYB1V8TynL_lvRoKiNXXgKZXa-_IAietoTWn9dvA,2404 +torch/include/ATen/ops/native_layer_norm_backward_cpu_dispatch.h,sha256=kRMpt8MeGyQ7uMd6e0owsNNkoa2Rwl5HG1U0XDWfw_Y,1389 +torch/include/ATen/ops/native_layer_norm_backward_cuda_dispatch.h,sha256=a_QG-I1EyKQvDRhEW6R0eYtHkG5g0FsdbLr_ytJesT8,1391 +torch/include/ATen/ops/native_layer_norm_backward_native.h,sha256=YYl9VhA9solojrz9MFvdlDG1o0K3U3I90rC0XF5l8-0,1895 +torch/include/ATen/ops/native_layer_norm_backward_ops.h,sha256=aZ12Iwl9lyAYScoeb2vlh6SbxQ3HZKE5Fak6TboAFIU,3670 +torch/include/ATen/ops/native_layer_norm_compositeexplicitautograd_dispatch.h,sha256=Agpt0eqHAEmkOBfeYpoj-PbH3HtMdA0Ntz0TqbgF4cw,2445 +torch/include/ATen/ops/native_layer_norm_cpu_dispatch.h,sha256=tB1lHF_kYFTwxHNDjs1JXygm7qCc7BE-p2xR2wgqnlY,1169 +torch/include/ATen/ops/native_layer_norm_cuda_dispatch.h,sha256=GN2Y8LW9tT7kr8C2Gagp00IHYw7-KxtRr8brmKu9xHg,1171 +torch/include/ATen/ops/native_layer_norm_native.h,sha256=7hefz0ZJ2pj8hI0r0oEOO_QayBILIsSkvGB-0rWxuEQ,1695 +torch/include/ATen/ops/native_layer_norm_ops.h,sha256=BbWfr025PdvNiP4yiDxKanAXYXaJcTnfqBRKYve5gak,2958 +torch/include/ATen/ops/native_norm.h,sha256=NrYV1iB4suyIrk51hay99LtAnhnwWsiQ_aVml29VSN0,2515 +torch/include/ATen/ops/native_norm_compositeexplicitautograd_dispatch.h,sha256=w0i8r8REFV2SUJF2NDQgVoA8pTQmEVMGaVNWbl-WYYE,1342 +torch/include/ATen/ops/native_norm_native.h,sha256=KuY9nfiPGVCUbqHrmeamicoW84rNzh46N1-3gABPUCw,1022 
+torch/include/ATen/ops/native_norm_ops.h,sha256=LIerycRdgZZpNahsvdwWT2fJ9gXBam5QnUoQ-EBK_WY,3780 +torch/include/ATen/ops/ne.h,sha256=PzkbFaviAuRWWoXIS-qNM7SZoBiSyZrRJRxahkb7wns,1896 +torch/include/ATen/ops/ne_compositeexplicitautogradnonfunctional_dispatch.h,sha256=CUfSa3jdg2lMFyEeZTiY6-_33TepYJoE3Dt3mLin5TE,1060 +torch/include/ATen/ops/ne_cpu_dispatch.h,sha256=NCudWTUQ87Hxv-LkcWTH6fuYj304NRFD49ZHuIGvUvU,1396 +torch/include/ATen/ops/ne_cuda_dispatch.h,sha256=eyzj2oMaGt4USuTsB2QeoLEb6aZ5uJTDhv9ohByhhWE,1398 +torch/include/ATen/ops/ne_meta.h,sha256=LnoaafBFfF7ul4z3o6znwH0V4KDxGr-PxizGip9SwnI,767 +torch/include/ATen/ops/ne_meta_dispatch.h,sha256=DzcY6_N-g1X0Uk5H-PDdJ-VZ2pM0n9rqCj-DC9WXagI,1398 +torch/include/ATen/ops/ne_native.h,sha256=PpJv2C33RvGZ0z-U5p639TVrk5dohVc_Mc21v1BgaQw,1235 +torch/include/ATen/ops/ne_ops.h,sha256=lpwMVqhFQItatrBMyjqTruPQKkXAUAEUVw-Ajq9GR2I,4285 +torch/include/ATen/ops/neg.h,sha256=vkqdRUb06vzH5MGZe_59GqmRL6bN6RMHRvyci5EhQME,1175 +torch/include/ATen/ops/neg_compositeexplicitautogradnonfunctional_dispatch.h,sha256=QyGOfBwB9wigvbcbVKEa29lgO9oqayyPdnkH9GUCfO0,859 +torch/include/ATen/ops/neg_cpu_dispatch.h,sha256=cuciUpSH9cchDH5m6u-GUdQPDsrY1q9T1W07pAke6is,942 +torch/include/ATen/ops/neg_cuda_dispatch.h,sha256=6swx_PyvOBQ-3jd5l7LmFqmPyW2clyhm0nysCmNshu8,944 +torch/include/ATen/ops/neg_meta.h,sha256=rHedOr6STGoQjbuwIjYRL09ha5Q-Oa2ez_dOiKyl_qw,591 +torch/include/ATen/ops/neg_meta_dispatch.h,sha256=KowkJguMJq7zexuWw877L6Lo2FRZL-YVVJtQ5G8ehw0,944 +torch/include/ATen/ops/neg_native.h,sha256=f2SxujFrqeAEUm4vto2AZYaMlU3j2QUeyXSOAWzeajs,1145 +torch/include/ATen/ops/neg_ops.h,sha256=8V5x1_X_98D8M5iKVfMU_lIir0ro5sGRqM80M6pJVTc,2070 +torch/include/ATen/ops/negative.h,sha256=U9fDYbcKRmtkE2WqhAb3bxJtC08EXTdeMgUBxCn1OK4,1240 +torch/include/ATen/ops/negative_compositeimplicitautograd_dispatch.h,sha256=O6cilrHxSlVSec2sW7xrAFDXDmT6PGsY2UqFAf626Iw,1006 +torch/include/ATen/ops/negative_native.h,sha256=ACqyULEYPpcy_dU3or7Pw-uME5xGheMMvFmSZDWyx-4,634 +torch/include/ATen/ops/negative_ops.h,sha256=9Ys1XraW6MI8Nr8vdcoFeIzhfkfAXymsWSzLwidmV_8,2115 +torch/include/ATen/ops/nested_to_padded_tensor.h,sha256=95gEZwKsB2c7HLN6bEIgFnr5W0Wv823xXETzLtSJKRI,866 +torch/include/ATen/ops/nested_to_padded_tensor_compositeimplicitautograd_dispatch.h,sha256=2p0sZ-YMDdJWKsPe1vUQur16xDTHwQclYzTLhN8398w,872 +torch/include/ATen/ops/nested_to_padded_tensor_native.h,sha256=dz8LZzt8ukGqr7KF0xgHon4MO1g11FwaxOxhocSeP4E,582 +torch/include/ATen/ops/nested_to_padded_tensor_ops.h,sha256=rYo4ubOfjG0JnIfn69uS4-y8wJGSELa2iBf9LnI_EsM,1216 +torch/include/ATen/ops/new_empty.h,sha256=kxcL8jFI4zrIKMZmXRXKiAdrnIySZ50k3-NmiYTh24o,4433 +torch/include/ATen/ops/new_empty_compositeexplicitautograd_dispatch.h,sha256=orZ9BemrjYYAstC9wyPmT905Hb_hODRpGj06tSga2XU,1864 +torch/include/ATen/ops/new_empty_native.h,sha256=NRJ-02WVdLMDhN6qey2RpqKxhS-X6m_H-BiNutNnE3Q,805 +torch/include/ATen/ops/new_empty_ops.h,sha256=ghCbR2kpd3I2sm_F1zvPTfLTqhy--ZWw5ulkLUlHt94,2277 +torch/include/ATen/ops/new_empty_strided.h,sha256=azKJxb32MvZKSgJqjO_hyA69ECfdcxRWdP-lV_uHWuA,5297 +torch/include/ATen/ops/new_empty_strided_compositeexplicitautograd_dispatch.h,sha256=EHGMjtQpRaE0JQrsLxb71RsuKdrFj2Zl58FiNv6lt4o,1308 +torch/include/ATen/ops/new_empty_strided_compositeexplicitautogradnonfunctional_dispatch.h,sha256=a3EYEeN5JkfOtkg02Rby_bxlfuE6E6JnhMnKHJadEIQ,1586 +torch/include/ATen/ops/new_empty_strided_native.h,sha256=7PnkrzcQDoiDm9wWuVfZ7_-xhNFEztQOe_0rwTHXQyo,877 +torch/include/ATen/ops/new_empty_strided_ops.h,sha256=R55J4KyyQBV-4hULHfjRc91yLr6lXiKfkvbHqMjpLx0,2513 
+torch/include/ATen/ops/new_full.h,sha256=TOWFFjm6WG7eYo8k1A8RuvX6ry40Pc0U-vvtmO3eR_0,4996 +torch/include/ATen/ops/new_full_compositeexplicitautograd_dispatch.h,sha256=iho3OMy0Ob_IKV4tehjZ06HuFFIk8_44QEG9E_Cez_s,2104 +torch/include/ATen/ops/new_full_native.h,sha256=-kFhnf8X9NfdTP5nAxAt8u0oKTBTx7Yc6RmSJGvfAmM,854 +torch/include/ATen/ops/new_full_ops.h,sha256=_sYAya09BrnOzybqW7gwi5JLaq7LIVf81J01spdXzN0,2473 +torch/include/ATen/ops/new_ones.h,sha256=yeSXq4vNwnGk091o9B4ZNYlDe5P8yHB-4LY3ntrlPHM,4404 +torch/include/ATen/ops/new_ones_compositeexplicitautograd_dispatch.h,sha256=M7_pkIa1G8Ax--OTj4JUIjHgwR4cIfbqa5amwq0BbSk,1856 +torch/include/ATen/ops/new_ones_native.h,sha256=mPErXoINt8dXK5xWFo2iOy8E7wpqkFXxWvZZ9EgJKPI,792 +torch/include/ATen/ops/new_ones_ops.h,sha256=cbCShLZWcfasSXtAgwv-5isMhF3kRoEt9VMKIqucdAM,2271 +torch/include/ATen/ops/new_zeros.h,sha256=0kyEq3R6fyPjhy-ICB01MHKV1dIkAMt00MNpF6c8Id4,4433 +torch/include/ATen/ops/new_zeros_compositeexplicitautograd_dispatch.h,sha256=kDgBF5xx9juMMZeJUsIuQj1vKVGoAB5t9gSp2Hgw5VI,1864 +torch/include/ATen/ops/new_zeros_native.h,sha256=_HYr6gog3Nla6NjDx5rTMu2wqBmnyrLtSdh0dkKVCYs,794 +torch/include/ATen/ops/new_zeros_ops.h,sha256=vvjcljPmGrSWiVQkfUajuYRhvymApBZLKAuXFv_rNmw,2277 +torch/include/ATen/ops/nextafter.h,sha256=RVgxz8B019aUOZumuDfJD0T8-zOuWQK4RfAnH_0ViGg,1238 +torch/include/ATen/ops/nextafter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=xXSdqDz6ZvfM1vQK1CddiXVsSjsRPOhxEgStIavywpE,923 +torch/include/ATen/ops/nextafter_cpu_dispatch.h,sha256=xvO8NzNKXzK0pA4Lkbi6OoUVBAA3Pp_EXgnm5INVFSM,1070 +torch/include/ATen/ops/nextafter_cuda_dispatch.h,sha256=yJ_J4V1LywMvBLp5LM946JRDIDnZrO-6CaB3463doII,1072 +torch/include/ATen/ops/nextafter_meta.h,sha256=Qr5vTWVJGwcp3kf0_5tTlWYQEsvoAZ9c8E_bYkuRzEY,623 +torch/include/ATen/ops/nextafter_meta_dispatch.h,sha256=mkJ0Q36UOKM_f3AZtEr5MiJDyMwg8ssJOpfpwpKS6Ao,1072 +torch/include/ATen/ops/nextafter_native.h,sha256=rEc6AUrRzjyb5IpoT34_VjbTsTS3Ws1omA1uAKyYme4,654 +torch/include/ATen/ops/nextafter_ops.h,sha256=BG74_BDqi1YaIOWu-FCKQ_foaGnmhNNyYIBz9oWjKdM,2382 +torch/include/ATen/ops/nll_loss.h,sha256=-PpNw2_60nL18Iogur3Xek4j7sDdFP_r5RPEm93IDVI,5615 +torch/include/ATen/ops/nll_loss2d.h,sha256=faWhhv-ZIxAbXDefpNJrpdVj8wGe5vNETKFma-gKG2g,5677 +torch/include/ATen/ops/nll_loss2d_backward.h,sha256=vlSwGMZCLU1RI2_YqS2Rjt0vNSp38M7UWAeNEMwiaQY,7216 +torch/include/ATen/ops/nll_loss2d_backward_cpu_dispatch.h,sha256=s-rSK3XLq60O6UlsvGmoev9QWrQZ-Sl-eaKCO1sZzlE,2317 +torch/include/ATen/ops/nll_loss2d_backward_cuda_dispatch.h,sha256=0NhX6GOV1gG4xReO4IEMv0SRtfIQA_dVBA2Kq3ab_L0,2319 +torch/include/ATen/ops/nll_loss2d_backward_native.h,sha256=DJNX6qrvLXVSn0i3oQdJrQ2KLNxufs5ggg24NO9YhRg,1502 +torch/include/ATen/ops/nll_loss2d_backward_ops.h,sha256=Mqg0Qt_4sGsc7V0_F5xwwqYHAUXClQohnd_WxqciKpM,2885 +torch/include/ATen/ops/nll_loss2d_compositeimplicitautograd_dispatch.h,sha256=VL6rhiTlJqnxgd_Vqmn1vv2crD2Of5Epe5DSe4SfJwY,2001 +torch/include/ATen/ops/nll_loss2d_forward.h,sha256=XH2Pzz4v13gtDIshkXb9XompKjxwkjBW0HilITUYGkM,6567 +torch/include/ATen/ops/nll_loss2d_forward_cpu_dispatch.h,sha256=aoTJN0FWNLDgnZKEkgtUAAGzVvrzGaDyaOeDmN4-0XA,2171 +torch/include/ATen/ops/nll_loss2d_forward_cuda_dispatch.h,sha256=xwze_GvQqQZ8dbcHZCVIxQj4np2PE-cStNstAk0Pf_s,2173 +torch/include/ATen/ops/nll_loss2d_forward_native.h,sha256=JkyrSxMqdSlqICa0zkv5qv-Q-k5wx0BjY2SDqQ0p0LY,1388 +torch/include/ATen/ops/nll_loss2d_forward_ops.h,sha256=ZPnEGaDUMrdmrlnWrJVuvlIa-TPlAEIWHHOTX1B6RwQ,2726 
+torch/include/ATen/ops/nll_loss2d_native.h,sha256=nOiNAJz5hulGp3hMnArpjQKYbCZzZdjrLuz44atbJTY,847 +torch/include/ATen/ops/nll_loss2d_ops.h,sha256=7Ip9ByM83HxRtITSYBcUwBS4WxYnYOEvr2IbKtbJpUU,2397 +torch/include/ATen/ops/nll_loss_backward.h,sha256=B2j2p9N6DzL-lZwKXdIQTElMUH3tbL83rVQJSlJjWOg,7154 +torch/include/ATen/ops/nll_loss_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=XBEyiPiLBgKUhmJyafMYfZoC0O2DaxD8OtJ1I-JHz7Q,1255 +torch/include/ATen/ops/nll_loss_backward_cpu_dispatch.h,sha256=RiDGZibI3wpo89UjNlVAnOBQ2lLdMsfvIypXkXyVJIs,2305 +torch/include/ATen/ops/nll_loss_backward_cuda_dispatch.h,sha256=wqcK5zKQmk5EM7gSgdtRYGxmVVpXjbi_9c_AdlujzoE,2307 +torch/include/ATen/ops/nll_loss_backward_meta.h,sha256=iA7t5IdZQ3rt_8qYZjhpD_geq-WzjdE6F5XNO03seSg,768 +torch/include/ATen/ops/nll_loss_backward_meta_dispatch.h,sha256=dx3Rnzq4047puvLky7MF3uDPgOkQefBD0B3oQC4F8cc,2307 +torch/include/ATen/ops/nll_loss_backward_native.h,sha256=Eodz1BKL3JmZXNbQFZRmLLC_JFqft40aedZ2_1qiMis,1167 +torch/include/ATen/ops/nll_loss_backward_ops.h,sha256=bzKkeXDRU4CTs8KL-emvoOyE21RWaVBJEWgf1IyirlQ,2873 +torch/include/ATen/ops/nll_loss_compositeimplicitautograd_dispatch.h,sha256=W9lIUwFovkjms068IajuXGYK9MvIy6r--1pYZR59O0I,1989 +torch/include/ATen/ops/nll_loss_forward.h,sha256=Z1-QFxCHpwtGaD_SVed6pSlb96vJsxQ4m_T1d82FsV4,6505 +torch/include/ATen/ops/nll_loss_forward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YYleFtt8SBl2XFir4AkKioODlKPlkl9YGkvVZfdJus0,1173 +torch/include/ATen/ops/nll_loss_forward_cpu_dispatch.h,sha256=5f5T4GsSY9uovSMdf4dV3bbhQ7v0LYJfq-GkA2lwFCo,2159 +torch/include/ATen/ops/nll_loss_forward_cuda_dispatch.h,sha256=P0e6Y6XmmC4SfKNb2mhn-4cnVKhBDPqv123b0Sga3VQ,2161 +torch/include/ATen/ops/nll_loss_forward_meta.h,sha256=_PJFFoJCt9oVFbEEX4txigvXTsD6ZNr_7eVNiuTRsmA,702 +torch/include/ATen/ops/nll_loss_forward_meta_dispatch.h,sha256=wbH7ydpwwjncTztTiIcbGWAJmDN1oqSBWW8LA2KoBbk,2161 +torch/include/ATen/ops/nll_loss_forward_native.h,sha256=AowgR3ngHY-M-6QE2-CyEKbfsaL35BLgtHxc9BntuX0,1090 +torch/include/ATen/ops/nll_loss_forward_ops.h,sha256=Ja6l-bWEOmJUR-2wBnGAkRBkLeL5WRy7FpwfcRLj60E,2714 +torch/include/ATen/ops/nll_loss_native.h,sha256=bmK4LM4QOGRuCe6d0PqohKeWb8Up8h5p34YM7mRk2ik,843 +torch/include/ATen/ops/nll_loss_nd.h,sha256=HbTXZxznVheWjwLGiRqixij5ma6ksFJvHWgRFiMiPuE,2170 +torch/include/ATen/ops/nll_loss_nd_compositeimplicitautograd_dispatch.h,sha256=JACaHb0pMgh-BCGOS0MYmo1WiLFjsMMRlbEpKewKcpY,1143 +torch/include/ATen/ops/nll_loss_nd_native.h,sha256=Jvl23wu0kiIBjh3vr1OawdIHSBW4xYcNpicTFAq1AeU,653 +torch/include/ATen/ops/nll_loss_nd_ops.h,sha256=jOU43M4FH3feosvqRrgkbJuOXAzk4zxN3tNIsy6qHII,1394 +torch/include/ATen/ops/nll_loss_ops.h,sha256=1-hsUoGkSUoqsL5Jby13L76wCwO_CGnl-RJr8x1yKSg,2385 +torch/include/ATen/ops/nonzero.h,sha256=RPP3_ZicPsF_vv-HN82XM-s9k2ulpjCf4LYm0gJbpa4,1077 +torch/include/ATen/ops/nonzero_cpu_dispatch.h,sha256=z_Y2AR6_00WoRlL7YcT5HXKBAJ1WxA_tUPVBYziAfyQ,905 +torch/include/ATen/ops/nonzero_cuda_dispatch.h,sha256=isXmPnweDW30bsL9e7LrGMgRQE5Pffav4DGP7J9eTvs,907 +torch/include/ATen/ops/nonzero_native.h,sha256=j9sAv5jK2jGSc6yUxcBz9uwtoPxGCKWc7q9xN9TMHNA,732 +torch/include/ATen/ops/nonzero_numpy.h,sha256=9BqEp3gqEvh9MfgrNXzq4OCVpVtD_YbN_rZjrEt8wtY,713 +torch/include/ATen/ops/nonzero_numpy_compositeimplicitautograd_dispatch.h,sha256=DW2mMwio4rRvPC7wSBYBtVr847pA-a5tziIc8qbsExc,809 +torch/include/ATen/ops/nonzero_numpy_native.h,sha256=Eu8-c15cajGbgTJaWA6XwYp4bo4UDqtYb8_ryqBvb-E,519 +torch/include/ATen/ops/nonzero_numpy_ops.h,sha256=s-xSGBv-tBjlWKHJ8sl0UCjA_ucG-DuYWjhUhdtI608,1054 
+torch/include/ATen/ops/nonzero_ops.h,sha256=DCmLkMfd-kX_NeeSUayk43o6tB6AH2i6uvnMkA7ooGc,1593 +torch/include/ATen/ops/nonzero_static.h,sha256=2pLBFqIN6F2DgtWCXyOMCRPWWfYMwRQlmPGynGEwKps,4095 +torch/include/ATen/ops/nonzero_static_cpu_dispatch.h,sha256=kfUxHVOz6b8vPTxiBwrvjzJlfw8Q4wYH2sguWs9P9mg,1413 +torch/include/ATen/ops/nonzero_static_cuda_dispatch.h,sha256=RIEv5XQGP_2dhyl8sTIel8pIcD4Lgb44fAmOFWfZd1A,1415 +torch/include/ATen/ops/nonzero_static_native.h,sha256=dWxMrtDdB_BpYj4gNylxBdNC8Wqo_-UgH5im9q_O5TU,902 +torch/include/ATen/ops/nonzero_static_ops.h,sha256=E3_A4lS7szvY8J-yQHMjLhdGvG7pAeObjO1_e8PTw2E,1898 +torch/include/ATen/ops/norm.h,sha256=TXBrOGXyXamMAN5-o-gOZ0tsW9x_QPuU23tYvqK1lYo,6372 +torch/include/ATen/ops/norm_compositeexplicitautograd_dispatch.h,sha256=_yPpGOmKaI8uzCOmEpiTcP-iiiobh2mkVgQNDOcD1M0,1401 +torch/include/ATen/ops/norm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AzRO29YsTmUlZmhALIShRUe4JVksPolRx9-HhwDyTSg,1040 +torch/include/ATen/ops/norm_compositeimplicitautograd_dispatch.h,sha256=Jl7a6vGBnu1bhRq0QxRZiuyBeeUPbZD5BBga7FiTTZ8,1670 +torch/include/ATen/ops/norm_cpu_dispatch.h,sha256=Q4YoNXxDp6qJwyHX5VgBgvUFSZNZ9MAKAo-Q_tpgX4w,1626 +torch/include/ATen/ops/norm_cuda_dispatch.h,sha256=bL9p5xxkilVHAUR4hqFerFAoRoSOMVNH8U_3wQrD1-o,1628 +torch/include/ATen/ops/norm_except_dim.h,sha256=CSnDLdEABvaNl58lpsNtUDMR4bnXmb6x5XPsZ9BW2_w,757 +torch/include/ATen/ops/norm_except_dim_compositeimplicitautograd_dispatch.h,sha256=-kS1bktk8zKrjaC6KExKnaCvu3nAm61dVDU2UBlfVZ0,823 +torch/include/ATen/ops/norm_except_dim_native.h,sha256=YHev8E-7famAcRby9G_wg6nnGlBS2YLBzJiFdG5hQrw,533 +torch/include/ATen/ops/norm_except_dim_ops.h,sha256=XrZHqHXklhRde1jnCTvnlUNBpwoq4iJOp4-IfV-jvO8,1096 +torch/include/ATen/ops/norm_meta.h,sha256=VCysIshI4QENDmoAmTFJuiajOsGVAxKw5qsk_XkJdz4,881 +torch/include/ATen/ops/norm_meta_dispatch.h,sha256=9E1BuqkfeouW764qZZKLqXIF8hrAjrU3oiOuoeF2EOQ,1628 +torch/include/ATen/ops/norm_native.h,sha256=7lsQoioyi0lMHXWSkbZbZYX_buWLdZJKblni1xA2zaY,2294 +torch/include/ATen/ops/norm_ops.h,sha256=LCS6lJsICtNytDcoQl4LkJ4289oZ3mPgQIGqNQWCN1U,10400 +torch/include/ATen/ops/normal.h,sha256=qhDQa6eNfhhMpXLYPn77G0THSc7HaGiUtpzmCYNKMOM,11831 +torch/include/ATen/ops/normal_compositeexplicitautograd_dispatch.h,sha256=NSHfEeKd4zAZItIC135dS8wrgnQ5WZ67jRQK28va8X8,2700 +torch/include/ATen/ops/normal_cpu_dispatch.h,sha256=lL2yATflB7ivMB-qC4QAey-t2Pb5zMz54A2UqerqThU,2084 +torch/include/ATen/ops/normal_cuda_dispatch.h,sha256=Ui6Qx6d3KqUjznktmz_kRdKEU5VTJFyc-a6bBYm-Hyw,2086 +torch/include/ATen/ops/normal_meta_dispatch.h,sha256=5YnKjk_lZzYtdKdYbk-G2-iq3yIgGEaHOLam7eKV6yU,2086 +torch/include/ATen/ops/normal_native.h,sha256=FHsnsdD7Z8aNT2RAiwKTGYBifwE7WwJPHRDiDPvTp5M,3373 +torch/include/ATen/ops/normal_ops.h,sha256=F1CVP83pRfkemEsjBnK9xAyUervgyywUkgBoGR8H2yg,9773 +torch/include/ATen/ops/not_equal.h,sha256=fczZ-TaOECbHoYsN1oTKN9omjyQ3cLZgj0og5JmAKLo,2029 +torch/include/ATen/ops/not_equal_compositeimplicitautograd_dispatch.h,sha256=Z6ssIqH5PM5q2XBbm_uh_B5asTIuO8maIKDOMUFxkgU,1496 +torch/include/ATen/ops/not_equal_native.h,sha256=vfeTQ_zto-B_6Qg7rN0U4avJTUfkjQfn4HrLG2an-MI,988 +torch/include/ATen/ops/not_equal_ops.h,sha256=uxOEv3cRXJoQFyBUMG7aUKR-oMDKd058Ft688VBvFbw,4411 +torch/include/ATen/ops/nuclear_norm.h,sha256=5vjk3QT7SO0yMgRx7Rhwn83wN1Oo0nNSiJfMr3Lua1E,2140 +torch/include/ATen/ops/nuclear_norm_compositeimplicitautograd_dispatch.h,sha256=j3tjiK_pMW-g9fiD3kQoNX73SGAuID4VzWpz4RsZg50,1367 +torch/include/ATen/ops/nuclear_norm_native.h,sha256=08aC2xeTpoHII0_h_VltfR_nB5vlZi-9ZqKfhZpt2gw,844 
+torch/include/ATen/ops/nuclear_norm_ops.h,sha256=j8U3F7WS4tyL56Sab8RYAYnHFq4Xxr9YScmrdyefFBw,3152 +torch/include/ATen/ops/numpy_T.h,sha256=DOIhW7DCaM3jFQqGFSzRCP13WQ5kwkrnURpbvDkGrsk,531 +torch/include/ATen/ops/numpy_T_compositeimplicitautograd_dispatch.h,sha256=-16WHMpDEKttaKzSkcr7_HasEnlimpKjGqhV9hZ07q0,788 +torch/include/ATen/ops/numpy_T_native.h,sha256=wUP2mwRYSHKyFbKfWr53OaMmZ8jyUfziQ4DGL6k4jWA,498 +torch/include/ATen/ops/numpy_T_ops.h,sha256=EHFYo_3M_pTBfyyEg1cbSP3P1mMxAp-v_Uk16827tVs,995 +torch/include/ATen/ops/one_hot.h,sha256=C6k3ltR1XjA3nqDvNFnt24RUIQxhl9k9ynckmCq3KzI,729 +torch/include/ATen/ops/one_hot_compositeimplicitautograd_dispatch.h,sha256=w3XA9Z7UdSC_I3hLer8y645lX_Zh2kR8nemuLpgDy88,812 +torch/include/ATen/ops/one_hot_native.h,sha256=_hfF-IbrhcOO0BXwun3Xhw_B4BATTT2lWpMhSxA9EPM,522 +torch/include/ATen/ops/one_hot_ops.h,sha256=zDBkLDtUwmMPHU1h4X6sdzo1wUHgE_yHu8x0LZOVwn0,1060 +torch/include/ATen/ops/ones.h,sha256=DYA9DIpSGeF_hSZeD5yVkqP2tDhktt4jgCgoY3Tny3U,6961 +torch/include/ATen/ops/ones_compositeexplicitautograd_dispatch.h,sha256=66gtK4veAMkCbe_ST21MwOZ5s6lRZCN7yGto7BswlLs,2208 +torch/include/ATen/ops/ones_like.h,sha256=MNH0qSGdy8MWy_e84vg3GTqL4tLrhozUE4DolmE7-SA,2252 +torch/include/ATen/ops/ones_like_compositeexplicitautograd_dispatch.h,sha256=QfVHhsOk3k2ogqWtvdZimRlThgXU0ySQEtYVpx0R1PI,1414 +torch/include/ATen/ops/ones_like_native.h,sha256=_WgjH2yGCVFmTz1VCXfVYNIJHH6vwS2FyNoPwSMI-mE,852 +torch/include/ATen/ops/ones_like_ops.h,sha256=2F3PeN2ZvWxy0JAzEVmhQKbA4fJMr_HpjsFCrKJGT4Y,2435 +torch/include/ATen/ops/ones_native.h,sha256=wSeAuUbvqcBgClxB0InwK2AdRXl7PMfp-dfGWFvlLU4,1090 +torch/include/ATen/ops/ones_ops.h,sha256=86wbQjUBuF-MPL8sp6nb5yJ4GDE0Q0ZSM79iy0zt2S4,3956 +torch/include/ATen/ops/or.h,sha256=QqXrsUqgsrNZ9e6R9lzMRS8gn3l_QXcPS2rJyrUkWJk,926 +torch/include/ATen/ops/or_compositeimplicitautograd_dispatch.h,sha256=OZSJ3zUZYPekwgEc9OdNiydFLc7HSfuEPe_y-mngDmA,1050 +torch/include/ATen/ops/or_native.h,sha256=Q3DmNr1Jx-ipL9HxCOX8pkpxWrky_8Cysm8KDNFl2Pk,760 +torch/include/ATen/ops/or_ops.h,sha256=jQ7efclI9QK-Oie15Seo_Hkqwk13oZD68EUnDwwtlGc,2941 +torch/include/ATen/ops/orgqr.h,sha256=rIpeqxkTW4u5f8NBVNaJHoTBw7_wM46kgYituFbHIAI,1207 +torch/include/ATen/ops/orgqr_compositeimplicitautograd_dispatch.h,sha256=D8SRqIzZwr2K8VHkzeEWgvcFZtOz54IaNoqIl_OtaRY,1024 +torch/include/ATen/ops/orgqr_native.h,sha256=crQS7Wl45ygfuBTnhnO7bPpdxhQzFjGrCfBOTuXpT6A,628 +torch/include/ATen/ops/orgqr_ops.h,sha256=8OUwrc9HGEaSxAmOQw2uzIzm9KX-63lLdUtxVN-c3pA,1759 +torch/include/ATen/ops/ormqr.h,sha256=Ly5phzwYPFcUSHSYlo7oCZzjCKKBbm_TikUhCaSDWQU,1625 +torch/include/ATen/ops/ormqr_cpu_dispatch.h,sha256=Si5ZEZV9ts81RTdIRuNWO06fZcAH3CSoEugY_HtwPSE,1164 +torch/include/ATen/ops/ormqr_cuda_dispatch.h,sha256=V8RHSsK9OuGTVIssvB80tZkGqCnPqLRiaXAH0i3q3IE,1166 +torch/include/ATen/ops/ormqr_native.h,sha256=-JQyL63-HDhA5lMQ8OxxqEr5djkY9LsV7KJ6YhDDKXs,747 +torch/include/ATen/ops/ormqr_ops.h,sha256=APAt0rGxRfsAnPele4Eh8jRQUW6kCfXq2PVDyCedT9M,2145 +torch/include/ATen/ops/outer.h,sha256=AJ5NH8F0ubQ6XsKaG4DZrJKf2iZ3icyU4a2uSBuDqMw,1189 +torch/include/ATen/ops/outer_compositeimplicitautograd_dispatch.h,sha256=cCi0gOIt4wksjCKp1O4hAdZLDXa58xMDxyat2BApoH0,1018 +torch/include/ATen/ops/outer_native.h,sha256=Tuw3noufgbij6FIGttvEfsnt1FmAFckrr5seu68MgeM,624 +torch/include/ATen/ops/outer_ops.h,sha256=7De-0nxFCCo8WaHwz-tWJu8DUSp1eQGVo0-04-Z8ZSw,1747 +torch/include/ATen/ops/output_nr.h,sha256=5U4BMgxUNN5etrXlQXsnrhyPb4b2uFQjJGgUSlwbZA0,533 
+torch/include/ATen/ops/output_nr_compositeimplicitautograd_dispatch.h,sha256=6Gd4RS3mhcrWv6hSuOxxNFNrBGdsseJcV-xeCM9MM3Q,787 +torch/include/ATen/ops/output_nr_native.h,sha256=P4PJEmrzf9BmmOyZyea0Mjr1LYM2CZTncxj-Yr7Y6iE,497 +torch/include/ATen/ops/output_nr_ops.h,sha256=QhV6GzST-9yjLIPnRBy-EO2sTN9JzLFhfciS6miefPo,983 +torch/include/ATen/ops/pad.h,sha256=GXBqNeY6Cgh2GidsOLE1rXpiS07GLfdwrHMiEg2wr3E,1830 +torch/include/ATen/ops/pad_compositeimplicitautograd_dispatch.h,sha256=wD4D2RLz1mIbcBldgD7ONQNt0DgancfOnAkiziThwPQ,1049 +torch/include/ATen/ops/pad_native.h,sha256=k0HjgvuMVZaVc7-vuL7bCet8mzcVa3uDuf0Od8FwHLM,606 +torch/include/ATen/ops/pad_ops.h,sha256=_WVdAPeps3s7o0AkoFMGUlZ_R_uqSl7kmgW_p_pKvCw,1255 +torch/include/ATen/ops/pad_sequence.h,sha256=ao7FHmg4egIDf9t7OrDzATtMEOzIqZ6nzTHkDdGRZEA,911 +torch/include/ATen/ops/pad_sequence_compositeimplicitautograd_dispatch.h,sha256=mXaBDlkq0NbS8iFNbY8OcsLzMuI8ktpCps7CiG7i4aU,883 +torch/include/ATen/ops/pad_sequence_native.h,sha256=TYDNyL8_bLT7fzDKgAJGoeldqRVqxUIGeUAg2jsBOe0,593 +torch/include/ATen/ops/pad_sequence_ops.h,sha256=PVnwIdPP2LlsBSC4UX4f8EyQ9T5ytcM1ovaIgmH4c2M,1260 +torch/include/ATen/ops/pairwise_distance.h,sha256=7VfcJmNmrNeqw0zKVZ24oiztRoUymJbkatCjeF7wNNI,859 +torch/include/ATen/ops/pairwise_distance_compositeimplicitautograd_dispatch.h,sha256=CTs4SzZVTH8VUD7fvGMmrxcj187HYfaasTbRmlGvoAU,869 +torch/include/ATen/ops/pairwise_distance_native.h,sha256=B9UwJqG1mT7PMzGpY1eFQ4BAE9DI2hVXyRtX92RTY5c,579 +torch/include/ATen/ops/pairwise_distance_ops.h,sha256=570w_tw5aRNuaEqRJ5X92pnyp8e9uFX8N9az100i9eo,1232 +torch/include/ATen/ops/pdist.h,sha256=11FAA0POrzfZ_QxqESSJjWqlcuCpqUsqdMmh8wtduFY,690 +torch/include/ATen/ops/pdist_compositeimplicitautograd_dispatch.h,sha256=Z5l2UQRGHeMDSkAqrbM-2TvBEU-okClhqFmb7exwU4Y,798 +torch/include/ATen/ops/pdist_native.h,sha256=tsJbyXCmZxnb14SVh9ixSnC8s9li-F40NAx944tAFDg,508 +torch/include/ATen/ops/pdist_ops.h,sha256=7laoT0ycq7g4FUeDSqSa29VUWsiHhhx4H-89_mjQDDk,1022 +torch/include/ATen/ops/permute.h,sha256=giGIYgpiBtZA2yIBcdYFQmpJW72bf1ZoY23VYa7FUIM,718 +torch/include/ATen/ops/permute_compositeexplicitautograd_dispatch.h,sha256=LDhMjhEuPcE7l03g4j9zh3C7Mdw3wg6X_G9iKMdxT9s,810 +torch/include/ATen/ops/permute_copy.h,sha256=Dd5Wn09KJ7_fULelYs29gfpFQAm7tjD-roE1JtYuWRY,1247 +torch/include/ATen/ops/permute_copy_compositeexplicitautograd_dispatch.h,sha256=qmzsx-m1zyTL1K-avvVCNL7AzT3myogL0BRqhmA34Uo,947 +torch/include/ATen/ops/permute_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=CoKAIx2eu66YtL8zZuNwAQQpTyyx1JYDCyDIkJKblgQ,841 +torch/include/ATen/ops/permute_copy_native.h,sha256=Iu7vDmUk9pvYSQRABucdEJxs13uTrAqHfhGleuQloz4,632 +torch/include/ATen/ops/permute_copy_ops.h,sha256=flbxX9NGyzVqogV0uizxxA9RDNAt9QUaTMBT8rIdNjI,1769 +torch/include/ATen/ops/permute_native.h,sha256=5HLPzSVULjfBNRwjDPN3VWLfKZtb8B35VyB-uKVtvqk,609 +torch/include/ATen/ops/permute_ops.h,sha256=MmvhxX-6Y8kaq4dDmC6xcDFSZ7WO9DckXH-MwgrsHls,1068 +torch/include/ATen/ops/pin_memory.h,sha256=cEKELYG-14VytXIviIJ3TfX_DHTM0cYAsPF-is97ECQ,534 +torch/include/ATen/ops/pin_memory_compositeimplicitautograd_dispatch.h,sha256=mWjPCKQeX6zM_RQjLGjsa1dulmuFEiVxAz2nPnZF2r4,842 +torch/include/ATen/ops/pin_memory_native.h,sha256=gFya2QYwb5Tfmr7SHNwguURyBPnbQhc5RqIynvDui3o,552 +torch/include/ATen/ops/pin_memory_ops.h,sha256=cMgBtImxWwXiR6g2t3ACQZse2x01Rcoo17xQyyXBrW0,1126 +torch/include/ATen/ops/pinverse.h,sha256=Ovu_edV70WfD2IgMXeiruofYGqGMlUj4E1Vrm5HbnrA,722 
+torch/include/ATen/ops/pinverse_compositeimplicitautograd_dispatch.h,sha256=btPvs_wjlW0lL0Buig3MmagP9ei8MLwCKK_AzqGJaac,809 +torch/include/ATen/ops/pinverse_native.h,sha256=Mrkbydotym1h_Mv26_cMhZpEW9KYyrGwXT6SiYOqOVg,519 +torch/include/ATen/ops/pinverse_ops.h,sha256=KGjV2lipd2p2XyJ-gkUvU-tN-bWbE6QlYCtltd2HPCY,1047 +torch/include/ATen/ops/pixel_shuffle.h,sha256=ZOgGzzwr-NhAzX8zC8XYd4LWjwbY8suQuIqBEQwxRvw,1317 +torch/include/ATen/ops/pixel_shuffle_compositeexplicitautograd_dispatch.h,sha256=6WqVv4Wljk2AwvgwEA-_ND-Y47xfoMoHUvB9rj1pgvc,953 +torch/include/ATen/ops/pixel_shuffle_compositeexplicitautogradnonfunctional_dispatch.h,sha256=5hprYUBkhGF3CP0Fy4hBR5axgRStzIn_mTOfd6ikxJU,844 +torch/include/ATen/ops/pixel_shuffle_cpu_dispatch.h,sha256=Opd0ZXNZEdkhWFzJ6Ow0u-tEibkI4L6xeE6MFwzOWLk,774 +torch/include/ATen/ops/pixel_shuffle_native.h,sha256=umhie6T_zhRU56dPQKWbYoZNcsMaFdDZ1OKyytWZkw4,733 +torch/include/ATen/ops/pixel_shuffle_ops.h,sha256=anP-6xmHhBn_FKcWcjzS13uxI7ea4k02duaj9nIpdBo,1783 +torch/include/ATen/ops/pixel_unshuffle.h,sha256=cA73TwAS12WtQVu0y0DOVrrEiY0MYalXh4A9LCXkL6w,1355 +torch/include/ATen/ops/pixel_unshuffle_compositeexplicitautograd_dispatch.h,sha256=WIUyo04ziEP-jxCu0fuLFApUwcsYUa67A_SgubU4qBA,961 +torch/include/ATen/ops/pixel_unshuffle_compositeexplicitautogradnonfunctional_dispatch.h,sha256=4ffMTIU27j8BwdoVUT8b7tTHvuZaayfi4yfv1ZooxhI,848 +torch/include/ATen/ops/pixel_unshuffle_cpu_dispatch.h,sha256=-6qAo3bsgA644mDrEDONNNxiF8SKMXMoP-bnnsUSEA4,778 +torch/include/ATen/ops/pixel_unshuffle_native.h,sha256=v8ZVhayX3_ex4ZbbPVcuzaEaf4oM5oL_aoXOkGJM_AI,745 +torch/include/ATen/ops/pixel_unshuffle_ops.h,sha256=m8M27tRecg8CN20pd8kJTs94cW8dw6n-ExlCIWWFBII,1807 +torch/include/ATen/ops/poisson.h,sha256=geCmU6GoB1dK3Tid2ldOybruaefZe6iWZG3kaSZTvFI,1347 +torch/include/ATen/ops/poisson_compositeexplicitautograd_dispatch.h,sha256=vzMJsJJ9uKoY1paN-SI3KZblGDg2tfoV4-obAENFOEM,992 +torch/include/ATen/ops/poisson_cpu_dispatch.h,sha256=IF9j3pMqbP9SplYslWTNx8SXIl0cC7XiNfFsp9WNQEM,801 +torch/include/ATen/ops/poisson_cuda_dispatch.h,sha256=_xXZ9BzG2am5LU-7e6Ue0PAoJrKo1PgC3ZF7dfCP2Xg,803 +torch/include/ATen/ops/poisson_native.h,sha256=hXvrOGRPzcfeNHsVeRoIlj7FFlOnMgEk3KqG6Tnz5jw,805 +torch/include/ATen/ops/poisson_nll_loss.h,sha256=bk_EtLBI9lxAixFbST36ielHlKK2hZDOXgZNu-FzVtc,905 +torch/include/ATen/ops/poisson_nll_loss_compositeimplicitautograd_dispatch.h,sha256=4w01Wz35l9U_pKMDA_yvcNGdAwykWnEdpCbWuA_L66k,883 +torch/include/ATen/ops/poisson_nll_loss_native.h,sha256=AuGQ-U76Ic05Kczr00qw6659Xo5yHm1LlmORxLLPME0,593 +torch/include/ATen/ops/poisson_nll_loss_ops.h,sha256=xONFzGO1W-jd51jC2nGg_ePAFxgvvrTX18QlBeOSqRY,1306 +torch/include/ATen/ops/poisson_ops.h,sha256=0dDy2-sAmaOQYU2ip_NPqW4BlJRsk_kwa6lhtMrcTa8,1879 +torch/include/ATen/ops/polar.h,sha256=crDk0M0aqfMABebuBtAnL0SWQrs8js4McNeo5abn_ug,1189 +torch/include/ATen/ops/polar_compositeexplicitautograd_dispatch.h,sha256=W8FF6dIA6QPw2l7FsSFQz1No2EI8qVceUU5KFJcFXKM,811 +torch/include/ATen/ops/polar_cpu_dispatch.h,sha256=zGm4TAKI-7Nia9k1ZD0eeiiIQTtYPUj-mi2eEbT2V04,895 +torch/include/ATen/ops/polar_cuda_dispatch.h,sha256=XNf3QYZ62yzfFYD0v5YLYe8QoCKU8V9wzG8R1nE2CFw,897 +torch/include/ATen/ops/polar_native.h,sha256=b1wAxwwaMwAM6hldZioLyKjLX-_OulWPoEEKDJqMgXo,624 +torch/include/ATen/ops/polar_ops.h,sha256=hjGpCJ5LXV0kPZA3ClGcYUN6FmNa-BZLWwwvPTBD4fs,1747 +torch/include/ATen/ops/polygamma.h,sha256=NlEfml1vbhurBR_JjKNXV7zugoR6zUvIw19D0asgUPw,1160 +torch/include/ATen/ops/polygamma_compositeexplicitautograd_dispatch.h,sha256=ecoi_p_ArZ0gv9KPPpjgvl8Z9Lar1_oCPbKqcQ7SGDE,798 
+torch/include/ATen/ops/polygamma_compositeexplicitautogradnonfunctional_dispatch.h,sha256=kFSv2CzlrZnPGRkpHGVbERMCIWGv9UdEFtZjxpcWkJM,827 +torch/include/ATen/ops/polygamma_cpu_dispatch.h,sha256=-zEHbD1zA1qturI2Pr5_ZsaxOc3mMEuT5pSt_mMQlow,944 +torch/include/ATen/ops/polygamma_cuda_dispatch.h,sha256=8itUXUnYnAwF9duQ9QlwLxDML4c_QdwTCjxX_hgcnfo,946 +torch/include/ATen/ops/polygamma_meta.h,sha256=22gfr9s7UwPqphhlPPBGxeLPcT8BNbg1wlYqgH2KMjo,608 +torch/include/ATen/ops/polygamma_meta_dispatch.h,sha256=8m1yi7O8GDRVYAiwDACxqi3kTMF_IV3Fdzv83PLgGSI,946 +torch/include/ATen/ops/polygamma_native.h,sha256=f_00PMXq2TvKzC8YxI6vDJCm8B8qoaeXHgOESnWokxQ,705 +torch/include/ATen/ops/polygamma_ops.h,sha256=SLIQjbQ79SXlVlKNLPieWmnrlDLZbA9LtpMqfGjHVSs,2238 +torch/include/ATen/ops/positive.h,sha256=_1n9U-h7OPaMF-reF2ePT6G6SQVJG9PZpXxpdY3G8HY,682 +torch/include/ATen/ops/positive_compositeimplicitautograd_dispatch.h,sha256=GOiEADQKnI7rHdmuG09mVA8sfde29kdZiL_mXjpEloQ,789 +torch/include/ATen/ops/positive_native.h,sha256=cN5ja50kqlPcAcE37U9WE202LAxc0BasTu36KkiTcPU,499 +torch/include/ATen/ops/positive_ops.h,sha256=kQ0yOMBR_fY1LZSVqaaddiaH7HcI7VgQGn2I4bWdDTo,998 +torch/include/ATen/ops/pow.h,sha256=D_12CMU4dLnXzg3zMOug_o8BTo6BiXlrgAY6PsZqIpo,2775 +torch/include/ATen/ops/pow_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JJzYhl3lAX1MsPqtCazt4awU3uAoJ4EIiIqqEhTmxMk,1157 +torch/include/ATen/ops/pow_cpu_dispatch.h,sha256=ryy24PGCMyCt3jjcDxizj_xVN43_80MQVhal-bOte7E,1720 +torch/include/ATen/ops/pow_cuda_dispatch.h,sha256=BLZceOx1gW_mDOEdfhi7wXM7zawEyC7fms_NxyHYXL0,1722 +torch/include/ATen/ops/pow_meta.h,sha256=dbJKyUvvuHwxj7Hf-0lgFE7Q8nElZufVpQHmebFuAhs,937 +torch/include/ATen/ops/pow_meta_dispatch.h,sha256=ggZDVcwPTxjd3wVG70bQAUzfES2VOmyU4G8NXxiTyJs,1722 +torch/include/ATen/ops/pow_native.h,sha256=mwpimvcT-B3xOohDiQF9d1fy93g6may1YMh0qFm5Xrs,1257 +torch/include/ATen/ops/pow_ops.h,sha256=WuSHC_DH5oP_PzBSdgqFokEszzItnxs_AbtrYYmpk_c,5766 +torch/include/ATen/ops/prelu.h,sha256=nLtWBKR0IMc_5HjlA5YG1X9ir7c1OwI7dT1hGCKBAT4,714 +torch/include/ATen/ops/prelu_compositeimplicitautograd_dispatch.h,sha256=KG2DXixtW9F6UdLAvHrbwEF4oPNKMAvnwD2o5ZJe86A,813 +torch/include/ATen/ops/prelu_native.h,sha256=HuwdL1M64HhicZkqxyJ-92xhyfTj5ad2HsHhi2ar_Zc,523 +torch/include/ATen/ops/prelu_ops.h,sha256=dQFE-0Ilq6tu4o5-mDjwW0Vj-IJSgnJKPBU2J6B2dMc,1072 +torch/include/ATen/ops/prod.h,sha256=D-JKqXCTpGLyAfh7boTIHfwPAro1n4xH6zI_S-G7uUk,3374 +torch/include/ATen/ops/prod_compositeexplicitautograd_dispatch.h,sha256=rn3WS_UKI-V8PgkTVKfknI2OHXkm4KrjzJeCSQSSYWQ,980 +torch/include/ATen/ops/prod_compositeexplicitautogradnonfunctional_dispatch.h,sha256=HZF7vg6O0C2RTerq9FhCFem3WLKNrD4ab1p8_4JaX0g,898 +torch/include/ATen/ops/prod_compositeimplicitautograd_dispatch.h,sha256=-Tdqh_6P0YnVbHG3M_X57lodFSvafQKZaWpovdaXq3U,1192 +torch/include/ATen/ops/prod_cpu_dispatch.h,sha256=yqhHXezQ4U9wpxplGQt9ujDcrhor004U5tEIdJUnl5c,1243 +torch/include/ATen/ops/prod_cuda_dispatch.h,sha256=AvvVaDem8lk7uvcMO6IwW3kCD1pwj8pQbtMXSx6YkrU,1245 +torch/include/ATen/ops/prod_meta.h,sha256=bvP5u4QgJe1y0qjJnDSGCXRItLjrLsRvynLD2c7lJEI,666 +torch/include/ATen/ops/prod_meta_dispatch.h,sha256=BAYVv3isupIc5lnCQaM0frogN12bE09qCtQ6u1B1dJE,1138 +torch/include/ATen/ops/prod_native.h,sha256=qiWIL9KGpDAFSKcNoQGRQl_K8otY2cqg66wcY3ksXno,1201 +torch/include/ATen/ops/prod_ops.h,sha256=VpbI0b4qH_2amCFPIbH0FPE92i37U1sKqCHvht5B_KU,5112 +torch/include/ATen/ops/promote_types.h,sha256=c__O5UI9bm94Ircw7iM4XLJ49AzCnxAk5cTE8uWeRI0,754 
+torch/include/ATen/ops/promote_types_compositeimplicitautograd_dispatch.h,sha256=RUh0RQinWZOw76uundTAbZGNpNNQjzw4pShOeqGCz40,817 +torch/include/ATen/ops/promote_types_native.h,sha256=6F8rJtK9VCd3580QrhVNq2-mkg9uhf7vamYNduiTSBE,527 +torch/include/ATen/ops/promote_types_ops.h,sha256=_w_131vUBxBe4B-ZFNOOWuYABIC5_dX6Q_Dv_R-udx0,1096 +torch/include/ATen/ops/put.h,sha256=hNgd_dIW2U0b1vkiH8fxlsFKXIgcpyR62yBy5XnZgUQ,1496 +torch/include/ATen/ops/put_compositeexplicitautograd_dispatch.h,sha256=Cp8PbYQ8va7pdBTQoRPFHG8g9lyIY0sCLiaDyPGYVJY,1159 +torch/include/ATen/ops/put_cpu_dispatch.h,sha256=HtfqDYPihnK4zs3rdcnU-fAkO2JDZCYSIc2h9piP8h8,813 +torch/include/ATen/ops/put_cuda_dispatch.h,sha256=MX7mRZPNDJOCRPfwcVc06D1vBfE2nRYKEZ_sjEBSh3g,815 +torch/include/ATen/ops/put_meta_dispatch.h,sha256=2oMmTGIYrsKv95FgtW46oJIMCVoH28M7h7k1yR_vMpM,815 +torch/include/ATen/ops/put_native.h,sha256=aAppfwzdBaq3dF5UTsYYV0ZSePAkSeW-Hw_uHnoK5c8,841 +torch/include/ATen/ops/put_ops.h,sha256=IS7xsCpfiPnSvgql1V0M3Km5z6bzNC0PZzip4cUTWTU,2784 +torch/include/ATen/ops/q_per_channel_axis.h,sha256=AHhalERTl95u7mfGcOm8vyO93khHfsC06kzGoUY3hjQ,710 +torch/include/ATen/ops/q_per_channel_axis_native.h,sha256=OrprVoZ0VHoGq1eMQ2Dga8Kb1Q2kv8kqzaqKx5puWcM,506 +torch/include/ATen/ops/q_per_channel_axis_ops.h,sha256=RNO_4OuoSaSIFfQjVL1js5LMMnLQVVKD_XJEKiY8C_c,1010 +torch/include/ATen/ops/q_per_channel_scales.h,sha256=OtrQrlSAmM-Oytjx0yUZ-9YN__auiybzS9EUnv_mlio,1207 +torch/include/ATen/ops/q_per_channel_scales_compositeexplicitautograd_dispatch.h,sha256=hUgV1ZiSaPhyTUCHjj_EQnNK30CePctPP23Edb5IH8U,919 +torch/include/ATen/ops/q_per_channel_scales_native.h,sha256=RW5RQWJcofvGTvLe0d_zGU209g9J8mf5naNwDZtZWmY,604 +torch/include/ATen/ops/q_per_channel_scales_ops.h,sha256=7Y499BdInuTUHsPF5686EiQxQuftvbGZKl_UYRgV3IA,1671 +torch/include/ATen/ops/q_per_channel_zero_points.h,sha256=uwHWx9KYzit0XusxEdCsa2fjStGMaJdAyL8LjcS15NE,1257 +torch/include/ATen/ops/q_per_channel_zero_points_compositeexplicitautograd_dispatch.h,sha256=8HYV2QCu7iKZCkHdyD5b_S49GKPtyRAmeic1AUQG3io,929 +torch/include/ATen/ops/q_per_channel_zero_points_native.h,sha256=WS7eNFaisMZX_45DFxQRlOhYwudUkv_eIgEwBI4DXTA,614 +torch/include/ATen/ops/q_per_channel_zero_points_ops.h,sha256=jTUY-y0mRabwEB-ZPdrtIfruN7pt0RosO28OizCP16c,1701 +torch/include/ATen/ops/q_scale.h,sha256=hdzsnwzr3I6vysWKeP_RqH8ji73AIMNu_LrzNeXsJpk,667 +torch/include/ATen/ops/q_scale_native.h,sha256=x6lpYgRZ6nJPC66mIhlfy093sJNHoplxYaxiujHX05o,500 +torch/include/ATen/ops/q_scale_ops.h,sha256=oXoaVUIwqq0T6gr81qCbK6seSV5vqT7IOfYABSnCrYY,976 +torch/include/ATen/ops/q_zero_point.h,sha256=WgFDOKMwR6I2vZJJ7mL_BZ3Z2Rfwyxz0GWJ4DRSnQUc,686 +torch/include/ATen/ops/q_zero_point_native.h,sha256=agx0Ar88aEhHaoHI1bOySghFF_omj7c_yYX_PIYEU7M,506 +torch/include/ATen/ops/q_zero_point_ops.h,sha256=S4iJtAh7izKDqRy9AT8uhBW8v8ySUpECALJCAsvAi5E,992 +torch/include/ATen/ops/qr.h,sha256=XCIuk6csa2suPOhUv48zdPV9iRJc0R9YXnvj-ky0Fv8,1311 +torch/include/ATen/ops/qr_compositeimplicitautograd_dispatch.h,sha256=F8E0h1WE4cPhe41ZYbKwkkKB32H7w3n5GuBoGS0Q_7E,1084 +torch/include/ATen/ops/qr_native.h,sha256=LSMiJ6oObOj1sjW7fmZXBv1aLyyic_CLzugKAYKmTL8,661 +torch/include/ATen/ops/qr_ops.h,sha256=Gp2g3RrxxXkPBhfhdc7q4dagNgq02yDH60znCIpwKeg,1887 +torch/include/ATen/ops/qscheme.h,sha256=1ZxmWZICOcELy7XKpWe3CoWqtTsZ2gBTG1kUsMR6c78,531 +torch/include/ATen/ops/qscheme_native.h,sha256=Zx1VOhBb-LJBES1HbXRSVY32xXTWsAhC5EXJuX9Wkj0,505 +torch/include/ATen/ops/qscheme_ops.h,sha256=a4SQlo8mhx4f-jQ-V51gA8jIR068E4E3IqtTVcJghlU,993 
+torch/include/ATen/ops/quantile.h,sha256=kMaa2chO51y2kiGkW36sMuHBo763IFGEWSP8uRghfdg,2991 +torch/include/ATen/ops/quantile_compositeimplicitautograd_dispatch.h,sha256=DJ4HbWwaI3bhrPj2rkOzeHf5TA6l-6vXECjVvmBjZIw,1844 +torch/include/ATen/ops/quantile_native.h,sha256=HayYd5NRS-PiTwPzJfPuBVgQ-RmqkOiYpEJkvewiTQU,1146 +torch/include/ATen/ops/quantile_ops.h,sha256=BhR8iN1wKjC6CQAYsZjKBcOcU21MWZK2AZsCOXMMIjM,4052 +torch/include/ATen/ops/quantize_per_channel.h,sha256=s_C77gZ7l-Jrvsp8SYoceWs0EkNTAp2b6oyDn8Q3xz4,1783 +torch/include/ATen/ops/quantize_per_channel_compositeexplicitautograd_dispatch.h,sha256=RebNwmadWc_DU7QzSTibSP1nuRUL9hl8tCDsASqkiI4,1109 +torch/include/ATen/ops/quantize_per_channel_cpu_dispatch.h,sha256=DceOqlAQ717PDE_jx-H2SnhAbHOXmoOh_ubKOWu5XQg,852 +torch/include/ATen/ops/quantize_per_channel_cuda_dispatch.h,sha256=kYB66OSBW9Odrj0yPVmGzniqQqaZ19PtUiMwGvDYb3c,854 +torch/include/ATen/ops/quantize_per_channel_native.h,sha256=MdQyjGYYeCFfyPNuRqdyiCtGDvY60xIdnyh8uyJ7cMI,794 +torch/include/ATen/ops/quantize_per_channel_ops.h,sha256=EM7uJ2nwo7cPmb07KoesZ0304X5nsbj4LuqzwGCcscE,2307 +torch/include/ATen/ops/quantize_per_tensor.h,sha256=nQhaEaSCOep04j3Jf7Ko7UGXoZcpJgjiF9-yJFLmapc,4002 +torch/include/ATen/ops/quantize_per_tensor_compositeexplicitautograd_dispatch.h,sha256=2mOCYYWVxgBXrXSFnFLp1JLq0xZWHmNIuKn3mUwgbOw,1705 +torch/include/ATen/ops/quantize_per_tensor_cpu_dispatch.h,sha256=VgXn0rRFcRlWUfmx0XkWdbV7UQ-svsBo_ENu-30kyCo,1122 +torch/include/ATen/ops/quantize_per_tensor_cuda_dispatch.h,sha256=OMCfAtqpr37YckMAnQZzT5CFeVzznLqh7xFQd4XIW-g,961 +torch/include/ATen/ops/quantize_per_tensor_dynamic.h,sha256=8THbn_JTuPpnfvB0mlpyR0BjPCgH5qliCgiQUPm-jsc,1574 +torch/include/ATen/ops/quantize_per_tensor_dynamic_compositeexplicitautograd_dispatch.h,sha256=Xl0aEMMckDu-PTk3_k08Jdtuk3oUY5lCG9vQWzfhm_A,1015 +torch/include/ATen/ops/quantize_per_tensor_dynamic_cpu_dispatch.h,sha256=T5JG1McI_stUq_-zwgEcUhgPPxOyEBi7A9bCVUfeiYc,805 +torch/include/ATen/ops/quantize_per_tensor_dynamic_cuda_dispatch.h,sha256=vLlk2lvdxOXLnWL93LWgkDFvrFaIzS5hZ8hifn8qmOk,807 +torch/include/ATen/ops/quantize_per_tensor_dynamic_native.h,sha256=-2mDTx8vcYWNH1hO-ELVUt2xCho-X3y632FDGDC1K0U,700 +torch/include/ATen/ops/quantize_per_tensor_dynamic_ops.h,sha256=SHGWP-bKT00dZ7yufrC2dZWy6R_JpdFcCDZIs8hSuWY,1995 +torch/include/ATen/ops/quantize_per_tensor_native.h,sha256=xvdXmGaHpxb_Kd-VJQw6jIi0d9AOgPKPL4GCGGtuoms,1408 +torch/include/ATen/ops/quantize_per_tensor_ops.h,sha256=WYH2R28DFC4Tdi7oYy8oGBSPdyN-logGSduvcdxR-IE,5658 +torch/include/ATen/ops/quantized_batch_norm.h,sha256=U2qZ2oR74fh6fecpduxcx0qNcdNZpqe8VDnK__goyyc,2317 +torch/include/ATen/ops/quantized_batch_norm_compositeexplicitautograd_dispatch.h,sha256=IwsCU81eHBBcVBx6IjhXj8x_Hb2zK88euHT96opqZwE,1311 +torch/include/ATen/ops/quantized_batch_norm_native.h,sha256=wAsfffgy1H_SQhvhM92XSYZN0mMCZTZZ2O5dxyc2ZXk,996 +torch/include/ATen/ops/quantized_batch_norm_ops.h,sha256=fEvbU5qfn6INxY82_k0VK4HX2wZ-LwHMY-CS3_-gqeI,2953 +torch/include/ATen/ops/quantized_gru_cell.h,sha256=MAO2vuEd3qzqKFpcOBGB2vLftG6fx8VAMgP2DM0bzOA,1453 +torch/include/ATen/ops/quantized_gru_cell_compositeimplicitautograd_dispatch.h,sha256=4DG7qhQlnCaCquNnpfeMN-F4zHPIacI-gwE7s1_4dyA,1179 +torch/include/ATen/ops/quantized_gru_cell_native.h,sha256=iiLC9znP5mAICCr2sbVVfIjVpn8MecdNFhaPu7K_Kys,889 +torch/include/ATen/ops/quantized_gru_cell_ops.h,sha256=Ik7u0RSab02IiCE6yerbW_u3VhExknu9x7vVvhdP2l8,2266 +torch/include/ATen/ops/quantized_lstm_cell.h,sha256=jA_yGP430KzHDpUnkmGn7R-aQL6i0SLw_Z6CVIOkm-Y,1490 
+torch/include/ATen/ops/quantized_lstm_cell_compositeimplicitautograd_dispatch.h,sha256=mfVeVk37Ha_ba_536IMyke930nqIG-j9HmFEPR56fYQ,1201 +torch/include/ATen/ops/quantized_lstm_cell_native.h,sha256=akFcsDvivsp6X3J14hBDyJdk5bWogDf7FSpFCRP8xjk,911 +torch/include/ATen/ops/quantized_lstm_cell_ops.h,sha256=BynjVRr3m_mYtn1ExCxwUGUqpZOORrh83sWjf6L5wJ8,2344 +torch/include/ATen/ops/quantized_max_pool1d.h,sha256=E-f5tC6q-9NDrssIyns6izDjiW0pfcZC86RFVOmXdxo,2037 +torch/include/ATen/ops/quantized_max_pool1d_compositeexplicitautograd_dispatch.h,sha256=2F6PwvpgpoBPQbNVYn8ZsO58PczCbSDVlrQ8tOEjwUs,1172 +torch/include/ATen/ops/quantized_max_pool1d_native.h,sha256=W1r_94fU_OQOftYNQC9Z134hq4Kgf78DRHNdFYsZ-q4,857 +torch/include/ATen/ops/quantized_max_pool1d_ops.h,sha256=kOs85fKDhNt41Z8nUS43vCw4-1rhlx2TCHiKHbmARxc,2493 +torch/include/ATen/ops/quantized_max_pool2d.h,sha256=YWXdeRz_mr5WYVZ60fbpYioi5bJRhQmBIIHPgEaCiYk,2037 +torch/include/ATen/ops/quantized_max_pool2d_compositeexplicitautograd_dispatch.h,sha256=2z1bvV9tQfvDxrBjfROoUXFPXvmc2HwwlC-DQRvxLFM,1172 +torch/include/ATen/ops/quantized_max_pool2d_native.h,sha256=pwcafPocpPx3O10riy_egZuPccAj0qykkMNvVsm6YZ0,1065 +torch/include/ATen/ops/quantized_max_pool2d_ops.h,sha256=ddnwrSzyUQJCtpWt77NPr9DMhOFdYfHHxuwJN2OuySU,2493 +torch/include/ATen/ops/quantized_max_pool3d.h,sha256=yRpOgz5C0Z9s4pnTS9NaiT1jJWBdy-UVkiio8Opz3Co,2037 +torch/include/ATen/ops/quantized_max_pool3d_compositeexplicitautograd_dispatch.h,sha256=rJdzl1d_IsSshFYu8m5ZRg-qRHit4l0ASRkR9sTpeJE,1172 +torch/include/ATen/ops/quantized_max_pool3d_native.h,sha256=HSkN0ye1rpNKb0MbaTLLCkkwSM7bspkr50n5v9tiH3s,857 +torch/include/ATen/ops/quantized_max_pool3d_ops.h,sha256=6B_rBNEqOykWFd8cfKmgUrfFpTc_ETe0T-rqnbCz2TA,2493 +torch/include/ATen/ops/quantized_rnn_relu_cell.h,sha256=FHLbH4YIWiplRVcopCQxiqRAIkDJR0Plszgi-nsNc48,1473 +torch/include/ATen/ops/quantized_rnn_relu_cell_compositeimplicitautograd_dispatch.h,sha256=_5xk3Ttc9020YtinGdpo5jbnqWuHwTjxTv03z-XEU0E,1184 +torch/include/ATen/ops/quantized_rnn_relu_cell_native.h,sha256=uzGwD2JmNZucHofgLBLfNtBoERavDEun9XZzUM0xw0M,894 +torch/include/ATen/ops/quantized_rnn_relu_cell_ops.h,sha256=NRBaLMxiuGFHWZ0ZzpYKMhKVWYUNAjwKIWbYquicxkQ,2281 +torch/include/ATen/ops/quantized_rnn_tanh_cell.h,sha256=2PfbRrXE0Sp6WxHzRiSCT17S0gwFRCJF2zLwIxBCDNk,1473 +torch/include/ATen/ops/quantized_rnn_tanh_cell_compositeimplicitautograd_dispatch.h,sha256=nWird6MrFdEV8-u1o5q4puCtrTA8yfMkzMgTBUWGucg,1184 +torch/include/ATen/ops/quantized_rnn_tanh_cell_native.h,sha256=uhc-NtSVxvet-yBI--vbjYj-DBY1shT3mw1--hzYIbw,894 +torch/include/ATen/ops/quantized_rnn_tanh_cell_ops.h,sha256=kZmwXhcYeVwqQcng1jw6LbkFwfxGncV6PgAqP5Fe2kk,2281 +torch/include/ATen/ops/rad2deg.h,sha256=pPEUfcMqLcLsAgszORGut-CubG7hwZjx7x6jAUbTnMA,1227 +torch/include/ATen/ops/rad2deg_compositeexplicitautograd_dispatch.h,sha256=Ax7MoEMsDK0iKcEXm4giP3ivzRgmjNFfQPRMKmXYDcM,1002 +torch/include/ATen/ops/rad2deg_native.h,sha256=uakaoH7TLiAZmt1roQdPaec27a0y-mNyAUdxE2GkOq4,1063 +torch/include/ATen/ops/rad2deg_ops.h,sha256=305xNxYzXrLz5s9H0joNeemhPJfiXeTot-7c4iV3l0I,2106 +torch/include/ATen/ops/rand.h,sha256=7h8eUykwNThVyxoOa72x9BGmAp60Qnp0BfdKyxtXqHQ,25142 +torch/include/ATen/ops/rand_compositeexplicitautograd_dispatch.h,sha256=U44wJ7s5jC7dWUt8qpwHZBB8CDrz2URd19GQDX6vpuM,5124 +torch/include/ATen/ops/rand_compositeimplicitautograd_dispatch.h,sha256=9J956eskSCdAicz2MSvqa4vOkaAxqflQaKX-ADSjikA,1220 +torch/include/ATen/ops/rand_like.h,sha256=S_GTCWRbkUP8WMi8bMuWbPpM9DZHe-uvnGrWqxbihLU,2252 
+torch/include/ATen/ops/rand_like_compositeexplicitautograd_dispatch.h,sha256=NBYk_jM86U4RyELzgvIOuaoY-wuAmOZOkcTtU4lwaP4,1414 +torch/include/ATen/ops/rand_like_native.h,sha256=5C9vDd0RJNYx0IEZvc9c4JwbJXe0OnWUBJRUBtMAIps,852 +torch/include/ATen/ops/rand_like_ops.h,sha256=p73H8nPCHokFDRFw1P-t4iVLPG3K5mKVnWMkpD1pwuU,2435 +torch/include/ATen/ops/rand_native.h,sha256=Zd7AikPBdzN9-pa3jnVyjSGWeSkaG-R9Zs4hQv_z3Wo,1943 +torch/include/ATen/ops/rand_ops.h,sha256=w1CaxHdu6998wESkHJlFcOVrqM4L1l4jJn0pWOvANkM,8198 +torch/include/ATen/ops/randint.h,sha256=pup5eGA8aaLadbpvVpUYAKKhREpHLBrw-wgCNfmtBsQ,26385 +torch/include/ATen/ops/randint_compositeexplicitautograd_dispatch.h,sha256=MEHenKoEnlD_YmKT-Om3wGfeOI9s7Bea5Y1ja9ndC1s,5876 +torch/include/ATen/ops/randint_like.h,sha256=4_dW0hC4siKl8H51LEYenehVTZk_8r4pNufk37ucrD8,17635 +torch/include/ATen/ops/randint_like_compositeexplicitautograd_dispatch.h,sha256=5iikGhObJGdkXwK9o76GuKNnDoIeTVARM4i-643rpqA,4734 +torch/include/ATen/ops/randint_like_native.h,sha256=HkSi7iq19PvnGUEMAOn8eUlDmuCY8cyOfWcN8BtynD0,1865 +torch/include/ATen/ops/randint_like_ops.h,sha256=iWNXeI7S92agzasDaVKH0A0pSi56srnoJsonT7Z3rq8,7041 +torch/include/ATen/ops/randint_native.h,sha256=ymxxibwmrFZ7FNnkRpj2FZCkAooYLTSFJFFQRqI0oak,1922 +torch/include/ATen/ops/randint_ops.h,sha256=Q4NVvvdUfNO-9o7pN0huIgBfu-Is3e2XIa3_SBFV6hs,8420 +torch/include/ATen/ops/randn.h,sha256=WpceslN-ZhbEYYhfRxvjSDvjpREQLICP8BgU56aozhs,25303 +torch/include/ATen/ops/randn_compositeexplicitautograd_dispatch.h,sha256=qKEMXjI21CPOR4Km_xeqd3mvuJovBDfg2-2oj_9JflU,4828 +torch/include/ATen/ops/randn_compositeimplicitautograd_dispatch.h,sha256=dVwNQu4LCk40iGtRcm-T4KVWHeNrHROW9Fo37WXJN3A,1548 +torch/include/ATen/ops/randn_like.h,sha256=AbvjG0FHiJJjX6mgiWaHs1cctPVpue-CrUn06j1X1sw,2265 +torch/include/ATen/ops/randn_like_compositeexplicitautograd_dispatch.h,sha256=fuMe2ipJAjjBobGLyL5VcrbBC8AN6yXZtoDnar8xbPI,1418 +torch/include/ATen/ops/randn_like_compositeimplicitautogradnestedtensor_dispatch.h,sha256=mnnNLeel7DqBTkrwb3-U1ZhjKXmt2KfpBVGO7qtIHmk,1162 +torch/include/ATen/ops/randn_like_native.h,sha256=wEDA96M5DIX5USP9OPwbm5voxqxzLqdUAgH2FN3V0qo,854 +torch/include/ATen/ops/randn_like_ops.h,sha256=sMZT5CHpYd6jIo2w4iPcTvBG7d-c_ng1Qg02h3AEvUs,2441 +torch/include/ATen/ops/randn_native.h,sha256=VHa6UOgQcXxeX86iIEwzqHUlb6tX-ZoLZvccrnpmpX0,1951 +torch/include/ATen/ops/randn_ops.h,sha256=hnQ-wMnbcRH7DBEiqmrXziJ_rQHIpYW5crSoz27hlgM,8222 +torch/include/ATen/ops/random.h,sha256=lZNIn3RupoTE5h2YLNANn-krv75zb1LquIhnUrJP3Dw,3300 +torch/include/ATen/ops/random_compositeexplicitautograd_dispatch.h,sha256=5yu2VG1rjfUEuKPTLtNkCNvLE9Cdo7EKJXLxxYrWPeQ,2007 +torch/include/ATen/ops/random_cpu_dispatch.h,sha256=F3lO0hW3x-FUCF-nRZ3dXL3m_Zyao1_4BgNyMlMli44,1070 +torch/include/ATen/ops/random_cuda_dispatch.h,sha256=bsw6VugSsiAb4DNK2STcpCO-k-4rxWU3dMh0aT2c8Tg,1072 +torch/include/ATen/ops/random_meta_dispatch.h,sha256=hE_7JTlFyAdPuAPvs0LKBFDdf8HWHLcieN8qZSsneLU,1072 +torch/include/ATen/ops/random_native.h,sha256=pdn09lH-Vh9l5inWM9CvrEcNs9Sri3yll2HOqNiLCBQ,2038 +torch/include/ATen/ops/random_ops.h,sha256=Q42rLASriORbOpzllrkcnb9QjLkznVdu0EZmTDlXZlM,7257 +torch/include/ATen/ops/randperm.h,sha256=ZRAafCn7_X7cZY-jgE2ODX5vCzKl4OKHI7rFTGcpkAQ,11170 +torch/include/ATen/ops/randperm_compositeexplicitautograd_dispatch.h,sha256=msr9b7zbP1wtbQRb10H9mn048l0M3W6D_Dxty-6A_ys,2308 +torch/include/ATen/ops/randperm_cpu_dispatch.h,sha256=6PbFxH3ikG3wgm4y5B1QXofR7nOXhqiItWH5kLPA_wA,1148 
+torch/include/ATen/ops/randperm_cuda_dispatch.h,sha256=k7ygHJJPlXRH9a1CShxViwQwQlu0yDx2gOOSH8Oll_Y,1150 +torch/include/ATen/ops/randperm_native.h,sha256=Bu1B9H34_dQqIW_9lgF8a44mAt3UytmvMc1MFgSqyfM,1178 +torch/include/ATen/ops/randperm_ops.h,sha256=JpaU8xCRca2smZJMgMiHsmo1R5IWcWBpOE8e84LeVdI,3930 +torch/include/ATen/ops/range.h,sha256=KTz5sdS4y2xO7NwAUOYDeh1gMAmDyoHOrIWWrrVEhkM,3455 +torch/include/ATen/ops/range_compositeexplicitautograd_dispatch.h,sha256=xEgI-hH-muDduxvPM_fj9X2gf1hVSSMnF2j7_5dLg5c,1657 +torch/include/ATen/ops/range_cpu_dispatch.h,sha256=iotevuOTQKiMHiDcXHIp3GsOA_mOZ7d8kRcobmb5pR4,945 +torch/include/ATen/ops/range_cuda_dispatch.h,sha256=q1fDPIZ3F0os8pqBpykAIXd1tDzIR_GviKzrjwe26aM,947 +torch/include/ATen/ops/range_meta_dispatch.h,sha256=S8WW9rrbwePxDUroTnBQeXrLYRgn1doBMX2DTr4aBhc,947 +torch/include/ATen/ops/range_native.h,sha256=Hmg8UTHPcwtKIO_1ZExTB5X8S5VNejcnbHjxvf_qtig,1313 +torch/include/ATen/ops/range_ops.h,sha256=SWtVMiasM2Z3DyFz14R87Frc-BfxhFw84dXEJ_aE27A,4198 +torch/include/ATen/ops/ravel.h,sha256=VLXyTb61yuKnEGGT-o2a1ts8_jQWDyqCrKuiebglmEY,670 +torch/include/ATen/ops/ravel_compositeimplicitautograd_dispatch.h,sha256=cFGqUv1uJaer_wvAmVn4yfW3sOHdkfIBQpL4JNY-XiA,786 +torch/include/ATen/ops/ravel_native.h,sha256=S0zJFTwaVAlVfwRFWvjw1ISdbYZ-mOIQaFL4tWtaK5E,496 +torch/include/ATen/ops/ravel_ops.h,sha256=pW2oQoatw1q5EQWnHlhXppbq_xlpY_-KnKB8kNBnlws,989 +torch/include/ATen/ops/real.h,sha256=hTR00UmblRKF1mhjie_rUYGX6GzOdK87Zyk5JuBRA0Y,666 +torch/include/ATen/ops/real_compositeimplicitautograd_dispatch.h,sha256=deGUVtTHkCpsEgllWZ7u3zq_ftqNGU8GN660NvPIOJM,785 +torch/include/ATen/ops/real_native.h,sha256=OSm_TBJE3j43ol5GI2gFf8nqCKWZbYpKP3-VEXp7H_w,495 +torch/include/ATen/ops/real_ops.h,sha256=U5JuEQ7Gc5BDnOcl8Pv1xdiHHKx-XzUpS_OkryAXuDQ,986 +torch/include/ATen/ops/reciprocal.h,sha256=rv2iDhE1ndL0LjnbAuXOcx5Ghhl0CxCztDtEGXXlHOU,1266 +torch/include/ATen/ops/reciprocal_compositeexplicitautogradnonfunctional_dispatch.h,sha256=7vn9674AL5mvaGRodUgnzu5YHCSkaHzfDnLiomJv6ss,873 +torch/include/ATen/ops/reciprocal_cpu_dispatch.h,sha256=nHSXwkzxamFVil8g0ql8t8lbq6ypCjDDEjyIWGv_j40,970 +torch/include/ATen/ops/reciprocal_cuda_dispatch.h,sha256=G-b9PklcSqFz21_MZ1bo42RVSyz-bprJr8MCI0m2hfg,972 +torch/include/ATen/ops/reciprocal_meta.h,sha256=_bSvyNhZjWU6HNckhr5gUm0u0mj6Js_lIOpJFAntnJI,598 +torch/include/ATen/ops/reciprocal_meta_dispatch.h,sha256=wsDeoSHhK13aoGDVwJ8XWacy8aMxbBN-JDxkM4EzxtE,972 +torch/include/ATen/ops/reciprocal_native.h,sha256=9dNylgspbgGZ_tUrX9_ZBOQ5ZFS1h1mKYP9QABksC3w,631 +torch/include/ATen/ops/reciprocal_ops.h,sha256=GItEk2KkAGga7x1CeYfEx9ad9g0izyl6TkHbq4-P9Lo,2133 +torch/include/ATen/ops/record_stream.h,sha256=Id3qwzjhYGOVDShDJWw9MAQrXQIw1jBgF9tMCtRcp7U,537 +torch/include/ATen/ops/record_stream_cuda_dispatch.h,sha256=bF9GDL4S9dSswTPNiImc32TOGy0jsGScyVbWLiagJXU,754 +torch/include/ATen/ops/record_stream_native.h,sha256=SwJdXLLj2ws4sTr3c7dMwpHVdGBJLmSDPV3KMXVJ-FQ,511 +torch/include/ATen/ops/record_stream_ops.h,sha256=B6q9J5OidQFjlqXRQZ4ExaeK-tul6d5-2knpJJ_9q6w,1021 +torch/include/ATen/ops/refine_names.h,sha256=vzqZXytwGMdRhu_3RmDeCiVfyk7dVPfHFWP1qdi2FAM,536 +torch/include/ATen/ops/refine_names_compositeimplicitautograd_dispatch.h,sha256=ocxupGOJ5bLnlwkC0xmlZWE5jDWnTmlhhvwABF3IsJQ,816 +torch/include/ATen/ops/refine_names_native.h,sha256=djvuv2071xuxbNk77W8F417qqV5aIYDOgmiEeWHkQ4k,526 +torch/include/ATen/ops/refine_names_ops.h,sha256=jYO0arBTXbjbqbdmjjNGU2W4LJSXFkeHl8gghGaZBTI,1090 +torch/include/ATen/ops/reflection_pad1d.h,sha256=nXfHPeJxZrJu6DSQFi--HYcXhxeFTtGLW6dosqbatFA,3989 
+torch/include/ATen/ops/reflection_pad1d_backward.h,sha256=_BKZsG1vSaw8EHfG2zyDwK-MU1jCuT5T94hbnH0C5i8,5152 +torch/include/ATen/ops/reflection_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Byl7cED3fAPXxYRInIcysRYLfGxaTlIbKEGJ7yIhASA,1031 +torch/include/ATen/ops/reflection_pad1d_backward_cpu_dispatch.h,sha256=wN0WJPr4iAS8A6HAj3DvdGFQ-C9AgD_pQu6EysIdWr0,1633 +torch/include/ATen/ops/reflection_pad1d_backward_cuda_dispatch.h,sha256=GZHAmHImVX_ZhH4LptiiJMDlxbVqr-PcTZBoeCVV76s,1635 +torch/include/ATen/ops/reflection_pad1d_backward_meta.h,sha256=s3inTPYDOvPK3QqbVNrhcdEzv-0wNVyio1mty5GgBOI,676 +torch/include/ATen/ops/reflection_pad1d_backward_meta_dispatch.h,sha256=Ls7cDz7Z3jUjRwev_lMxdKvbN_3U5O3XO66T8Lnut9E,1635 +torch/include/ATen/ops/reflection_pad1d_backward_native.h,sha256=gIQC-unAl0-qOFuDmSJMmUB6rm8thTmXQoAtTpcYYrI,1007 +torch/include/ATen/ops/reflection_pad1d_backward_ops.h,sha256=tCoHB5r1TF1IFkS7T52ucB73X8cDepc3SftC5aZzVJc,2147 +torch/include/ATen/ops/reflection_pad1d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=XUhasiO_t8hilvYVzstAKeKUmtVyxZz3hBImpTXJlZc,949 +torch/include/ATen/ops/reflection_pad1d_cpu_dispatch.h,sha256=RUkSheQMZs135ZoMRGPcZBjJTbC0tiAN5LKmqQlTQZ0,1359 +torch/include/ATen/ops/reflection_pad1d_cuda_dispatch.h,sha256=J8TzOpOuMi7df8YJQeeyQmTC9oZk-JW0cXauCm0YsEo,1361 +torch/include/ATen/ops/reflection_pad1d_meta.h,sha256=j9qQGFpfvyBCkeIm6Re_Ks_P4elafAhxPR1ljdgJq8c,635 +torch/include/ATen/ops/reflection_pad1d_meta_dispatch.h,sha256=AtEzPt6MVFO0CUl_S7S1iqA7gaM04Mj53WOWiwdFrS0,1361 +torch/include/ATen/ops/reflection_pad1d_native.h,sha256=DcC8NzOD90LfB9QwEI1n9YA6gBG_A--trkErrnjcbpQ,1012 +torch/include/ATen/ops/reflection_pad1d_ops.h,sha256=qnvRDOTaLZO3sBSM0D9AsGGLQ0zl48KvLdsdK2ezpw8,1843 +torch/include/ATen/ops/reflection_pad2d.h,sha256=LUomLzEdzapgSviaMr5VF5_Jogx0ejrsBHOzsRrKG4s,3989 +torch/include/ATen/ops/reflection_pad2d_backward.h,sha256=Iu6Q64dcQ6Pt_5xnN371zoCDUHPWArXAwDYJYqgKeJY,5152 +torch/include/ATen/ops/reflection_pad2d_backward_cpu_dispatch.h,sha256=kOw45kGu3AdimCNHfSKx2QwaocASAcvd7alEHVS39rk,1633 +torch/include/ATen/ops/reflection_pad2d_backward_cuda_dispatch.h,sha256=pfoKgAM3v3SzJAYmcAUarea96CYuqUqV3b9jQhytNl4,1635 +torch/include/ATen/ops/reflection_pad2d_backward_native.h,sha256=ucjZThzNCqDwA7T8NVcjv163tW4khdUFCfxGwbjI7O4,1046 +torch/include/ATen/ops/reflection_pad2d_backward_ops.h,sha256=htAGKsbrshg8ETJYiXX1rC2aSnZgNVmnATN3glC088Y,2147 +torch/include/ATen/ops/reflection_pad2d_cpu_dispatch.h,sha256=gpM-SnLKU3SoEycd6G1Ej9IkjfHR0bJqar3XHZBgxtw,1359 +torch/include/ATen/ops/reflection_pad2d_cuda_dispatch.h,sha256=NU50HMaOu383ZQxmG-vzNiEjvAtdraN4tPI9aCIoqFY,1361 +torch/include/ATen/ops/reflection_pad2d_native.h,sha256=dfmL386aQl8CaCQxduqw2C71TW8CWq4PXw2op-32P7w,972 +torch/include/ATen/ops/reflection_pad2d_ops.h,sha256=xe_QdVXhDMS3oOyRhgKGvk_lb_duk57rTeiQuFz0SXY,1843 +torch/include/ATen/ops/reflection_pad3d.h,sha256=3C1QvzJgEmW6lOQ7eUHQNqWd5yCfI1cXLnaYfYVrMVg,3989 +torch/include/ATen/ops/reflection_pad3d_backward.h,sha256=d_nn0h1DIzdtik_ZA10SZArMa-GMrIB80wTSZuvTyhg,5152 +torch/include/ATen/ops/reflection_pad3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=nBtitBWGDZ2qQ0Qo0Lkm17QODnK0bBpkjjCEj9T80rA,1031 +torch/include/ATen/ops/reflection_pad3d_backward_cpu_dispatch.h,sha256=PAaZe6TdL5Eu3x5s4MgDLoKc0rwePvC6cPdMXlwk3Ms,1633 +torch/include/ATen/ops/reflection_pad3d_backward_cuda_dispatch.h,sha256=NWzNx5BGLm1jEaAXxlCzGRMR8ldc8RmQyWvEEF3G-aw,1635 
+torch/include/ATen/ops/reflection_pad3d_backward_meta.h,sha256=-B-CddRuXJEn8URHlOpYhBv1gnR2GKw3BA1T2QSMHk8,676 +torch/include/ATen/ops/reflection_pad3d_backward_meta_dispatch.h,sha256=5TZeR9IC193c8cg8dKTef05f5_DYMtDirI_DpiFJNyg,1635 +torch/include/ATen/ops/reflection_pad3d_backward_native.h,sha256=4a6czlEONj5Dm_h0rnMuih8JJBNr9pwplJUJ5Jca0N4,1007 +torch/include/ATen/ops/reflection_pad3d_backward_ops.h,sha256=lZb0Ij6UJq-ZbRRUydD_DI5fjGOhrOQCGLV7UeHi6xw,2147 +torch/include/ATen/ops/reflection_pad3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OSpqx5it7a4vav0xvNsH51QafhozveUJvZFPxzab_lw,949 +torch/include/ATen/ops/reflection_pad3d_cpu_dispatch.h,sha256=GA9g5wyRmbQa7G4CYehGoriYwIq6MzqnWHqOqUqzfag,1359 +torch/include/ATen/ops/reflection_pad3d_cuda_dispatch.h,sha256=xvtO4KYmd3BAn51BGp9q9P0DlG6paQ09SJ5xLDaHBko,1361 +torch/include/ATen/ops/reflection_pad3d_meta.h,sha256=4w-F6rkQAdzooYvNrSQjre9cUJ1QKcH8Qc9lHn6hSSU,635 +torch/include/ATen/ops/reflection_pad3d_meta_dispatch.h,sha256=NZwKo5qslyydPIpDWlnD8x9wqfgwfXv7WBqdv2tNT-0,1361 +torch/include/ATen/ops/reflection_pad3d_native.h,sha256=wmpODccG2ohVTlLNkBf6dHuNcGlVayskowqRixPEdZk,884 +torch/include/ATen/ops/reflection_pad3d_ops.h,sha256=gTsFsHN7nSTEWvZUPJJ36fl90xuAVdgTrNCERa9i7cg,1843 +torch/include/ATen/ops/relu.h,sha256=xnTTyhkXO5_fdFeJhWZfxagMwaDPI9-4lOSjLK28cH4,1188 +torch/include/ATen/ops/relu6.h,sha256=DvEaDWOhyfoii2YV6HYwNEOs3277N9VM7vhDGQYixlw,808 +torch/include/ATen/ops/relu6_compositeimplicitautograd_dispatch.h,sha256=o7j3GMc0Q0CAT-12Mtk19aJxd7ApTnkJ6TgL78JFMmk,837 +torch/include/ATen/ops/relu6_native.h,sha256=MJQKjJo4uAGuSohzjQa04dbEc0-VVTOpEnLku2ruRZA,547 +torch/include/ATen/ops/relu6_ops.h,sha256=iPWvrooyqLDqC6C3xRk1qJk9AoIZ6Qr60_DyKWiIiik,1490 +torch/include/ATen/ops/relu_compositeexplicitautograd_dispatch.h,sha256=6WIocVBNENJzDBVRjw9bi9sxQ7YE-JlHVxuuo-zLOSQ,887 +torch/include/ATen/ops/relu_cpu_dispatch.h,sha256=T5XFC7pBfvoQPW1ZfEF23lVVTAqwqKJnj_CwaA5hlMQ,791 +torch/include/ATen/ops/relu_cuda_dispatch.h,sha256=fD3Rvtw95YHBz0OOmRi1lfAZF8eH9DwAimP3gozZj4M,793 +torch/include/ATen/ops/relu_meta_dispatch.h,sha256=Pjslh42_99Lkq6IFVc7LruPqXuyidnB3LivzQVHCMxI,740 +torch/include/ATen/ops/relu_native.h,sha256=uVeD5_ozOum_ho5Ji8AC7HryOkZtN62vaOPardLAvYA,1374 +torch/include/ATen/ops/relu_ops.h,sha256=7y1xUA2wv7GKcspq6av-7m0bnJkWVWGHLkrqiqGl9fw,2079 +torch/include/ATen/ops/remainder.h,sha256=OjMK0QcWjSH2OI9A4OSciq11WpuUi2WOsjEmFZyjJpg,2820 +torch/include/ATen/ops/remainder_compositeexplicitautograd_dispatch.h,sha256=CNIASKgZRkWEOVOAVgfKIRVAM5c68p4vne-CJl6BPoM,1331 +torch/include/ATen/ops/remainder_compositeexplicitautogradnonfunctional_dispatch.h,sha256=bFSakTqxx1OQ2WmqDq3LYeBrFPPh_hmOFEJr2ej6iqI,923 +torch/include/ATen/ops/remainder_cpu_dispatch.h,sha256=ICI3KUUdxQkCx6-mVjYaI0AvcRvcSu1yRvpMeH2TuHo,1154 +torch/include/ATen/ops/remainder_cuda_dispatch.h,sha256=1ZH52unDh2dvCW_-63eRr2I7SA_lXRDetb1XO7updAw,1156 +torch/include/ATen/ops/remainder_meta.h,sha256=Mgd8l3cBX2gZmfrRrAdnLmo2MDOPFtoQZ5F3fnIM0rE,630 +torch/include/ATen/ops/remainder_meta_dispatch.h,sha256=5V68epEnsg-8Us4nisUDtFFMKHUMCKKqWIZ04XmAf9M,1072 +torch/include/ATen/ops/remainder_native.h,sha256=MfDS1Pjm8fikfTjflwasq5hHpb-5kAKlyF-xuSVyneg,1140 +torch/include/ATen/ops/remainder_ops.h,sha256=qwNINjuBNDRh6Ra3ei4kbCr9bE6gSJZvasklGI5FnqM,5796 +torch/include/ATen/ops/rename.h,sha256=fces875T5XOhlL0lMoJLY14yBYh2M3vc4nsSUg6ZeBM,530 +torch/include/ATen/ops/rename_compositeimplicitautograd_dispatch.h,sha256=fwk5sa0UbZnc9e-Ksq2AsOBAI3405G3EbXqO_A-548I,919 
+torch/include/ATen/ops/rename_native.h,sha256=QWXcFnu3C_9BzTNSaz8F2qMzG-54TeE0Bf45xGjjN24,629 +torch/include/ATen/ops/rename_ops.h,sha256=wpLiFpvtbIQt_npeV2kM6oiFtPb4BmbRAe2JTWDj7YA,1766 +torch/include/ATen/ops/renorm.h,sha256=iX-LcMR4SynV_QzYV-OMUPWMlxrvWj9Wxi7Jng2fycc,1412 +torch/include/ATen/ops/renorm_compositeexplicitautogradnonfunctional_dispatch.h,sha256=TblYjtM0sWzBvg5QmConnzmBMr1BDvpPU8HOKmVZGCs,991 +torch/include/ATen/ops/renorm_cpu_dispatch.h,sha256=UF_4KVtUg9Z1q5YwwLW4TovA--A2MpUP5bTGx04gz80,1206 +torch/include/ATen/ops/renorm_cuda_dispatch.h,sha256=TG3ZCBuLGr-9Txj4H7o4Z3yCweJP0MDewXvj5xjAPHI,1208 +torch/include/ATen/ops/renorm_meta.h,sha256=2GaX-bWWauwsIcVNcVLfPXDdU69gfxiMVtiKs_yBiHU,657 +torch/include/ATen/ops/renorm_meta_dispatch.h,sha256=fCgUhpk3JmbPbTTwKsUf1zrzh4pI4t-1hwYFA9Y3wGI,1208 +torch/include/ATen/ops/renorm_native.h,sha256=aq1dP2C4rIP34YWbpeva5yViMTnSa2HJfIXFxPQNSlY,682 +torch/include/ATen/ops/renorm_ops.h,sha256=4ycaR0gCTSXvBZWbaqJCBtrSqhS9KQDT7NcrHKTyYWY,2727 +torch/include/ATen/ops/repeat.h,sha256=YuvefBBNyGt8uAqglHH_jat03z7m8lugPK8fSgXP9_E,3256 +torch/include/ATen/ops/repeat_compositeexplicitautograd_dispatch.h,sha256=bSl_-wRwUAh11tSFVMmHdAHXmPxL2W6dQj0CWkkW8og,1343 +torch/include/ATen/ops/repeat_interleave.h,sha256=IML0_P8z1Slgh6Oa1eNZvnH6bhhWNmpCFiSjiFytvaE,8414 +torch/include/ATen/ops/repeat_interleave_compositeexplicitautograd_dispatch.h,sha256=9LAzkSRCYA-oJ_Fb_MmzONhbpf3Yz4djQedxkO5rIkM,1310 +torch/include/ATen/ops/repeat_interleave_compositeimplicitautograd_dispatch.h,sha256=O2dIH2-kSX689aCDPLlRZKkZ9DupddfHhQ0852JpTW0,1504 +torch/include/ATen/ops/repeat_interleave_cpu_dispatch.h,sha256=9Iww8eFEmAz5jX7j1mVgK1p4Xx_aXLAygK7xHsrOvAA,943 +torch/include/ATen/ops/repeat_interleave_cuda_dispatch.h,sha256=7IAXkPY4NuEVuYwq2De-qOXsd-t980GrrtjMJmMnrrA,945 +torch/include/ATen/ops/repeat_interleave_native.h,sha256=3n_037uiniieTx4ROPJJ2nXkt-gygAdHPauzEw34Dwo,1243 +torch/include/ATen/ops/repeat_interleave_ops.h,sha256=25Q01467amxiXS9TxBs3W9WAu6YT5zuQUWLZil3UMow,3795 +torch/include/ATen/ops/repeat_native.h,sha256=HWGEp0zOvGUNCem_Y3g1fpI1FdScrLV2KPTAFSgieOo,637 +torch/include/ATen/ops/repeat_ops.h,sha256=hjR8V_w4-L3Oi06gfjNIrDkk-yHa3ekXVPUlxWg7qUw,1781 +torch/include/ATen/ops/replication_pad1d.h,sha256=cBSoYZhNwxJ2MEkCcANx1V2L5pLhHWlobz8NhQwCG14,4020 +torch/include/ATen/ops/replication_pad1d_backward.h,sha256=wm4ULewMHyD4GYUgmoFbesKy-JO8bCop6vnfGFWhzzM,5183 +torch/include/ATen/ops/replication_pad1d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2IEkw4pcU81FSwrxLMVi0L3bexcM451kKgvn0Te5kik,1033 +torch/include/ATen/ops/replication_pad1d_backward_cpu_dispatch.h,sha256=g1P-RnTqVgXdTjmZsScSxoQoGcJ5gzzynEVNQFvwZ4Y,1639 +torch/include/ATen/ops/replication_pad1d_backward_cuda_dispatch.h,sha256=HfkSXEHw6AhLD1YMiGR_g8KY5OB-E6ghM711j_9QJss,1641 +torch/include/ATen/ops/replication_pad1d_backward_meta.h,sha256=ZYn2EamxXnwcuEbc8Gns2RkEM5BXWxR63bfhgHKFels,677 +torch/include/ATen/ops/replication_pad1d_backward_meta_dispatch.h,sha256=su06XgbaOXcY55GzyQ1TiaB0cpaTc6ex9gci6LrRWdc,1641 +torch/include/ATen/ops/replication_pad1d_backward_native.h,sha256=bPIi3HFCz_kCln5uIjYhAaNAjSbrsnCGhuptQOcCX8M,1012 +torch/include/ATen/ops/replication_pad1d_backward_ops.h,sha256=h9iOS0WRfeXo8EFwTcRH_nobOwrppCxubNhCZVr0LmA,2153 +torch/include/ATen/ops/replication_pad1d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=B65JiyhUGwLAnrkHNxR8llunN0ylFPrpWIn8Ngt9w3A,951 +torch/include/ATen/ops/replication_pad1d_cpu_dispatch.h,sha256=5vdngUhGxgt891yzbZd-M99emyv6mY_fE5b6OltuQfg,1365 
+torch/include/ATen/ops/replication_pad1d_cuda_dispatch.h,sha256=_udK4vTyifhNc0cc5GVPLh5hjw31KosQlmkw3kB5sPo,1367 +torch/include/ATen/ops/replication_pad1d_meta.h,sha256=K06zxmbFNVuByyWRfkvTjOSxKLWNF6xTL8OxD56cb7Q,636 +torch/include/ATen/ops/replication_pad1d_meta_dispatch.h,sha256=TpDplJz7jNziT26DF77bOWhyNoh6ijGsMUbOeDqM1sI,1367 +torch/include/ATen/ops/replication_pad1d_native.h,sha256=Er2DOyuWpTexlRz3-7bFgO3qw1ZBlWpKN8zV-o9HR8Y,889 +torch/include/ATen/ops/replication_pad1d_ops.h,sha256=GzIFmVOUuS0pYUVIUQVo1QvHXrnnURpamkAckzQiC_c,1849 +torch/include/ATen/ops/replication_pad2d.h,sha256=8-KN_qQi9bn2-m0r6HWbfZMmQUY_nSuBvb_hvyAaq6w,4020 +torch/include/ATen/ops/replication_pad2d_backward.h,sha256=tjA0Uno6C8dFM6r4RlIngq4_ojQ0rNXSxJoh1in0ADM,5183 +torch/include/ATen/ops/replication_pad2d_backward_cpu_dispatch.h,sha256=Dw6PYIcyv9IxZKPzOFcTllzs1z3XZQCx_w7klPZSDOE,1639 +torch/include/ATen/ops/replication_pad2d_backward_cuda_dispatch.h,sha256=GCtcIAU3AoYsUqbiN-nY7HsM1UbLAnI6rbs_yMUITX4,1641 +torch/include/ATen/ops/replication_pad2d_backward_native.h,sha256=OrjSEmLY5tL8HICCN8Z8FP4FgN0uygkymqjeMTG4Zbs,1050 +torch/include/ATen/ops/replication_pad2d_backward_ops.h,sha256=Je6BeIMfb-E3RF8DdOfuFtwoM1fSDmF4fdC_WKp0YgQ,2153 +torch/include/ATen/ops/replication_pad2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ieD7ss5o_Cu4l5VIZpZETrYsyhiF7sHxPWk72HYFNqU,951 +torch/include/ATen/ops/replication_pad2d_cpu_dispatch.h,sha256=-j36OH3yeO21cAtSh_l4oZkY8yGZwx1m9AfZlnuQAKQ,1365 +torch/include/ATen/ops/replication_pad2d_cuda_dispatch.h,sha256=A9yW4i1d7IxrTtmUvXKFD_kbNBeH3NEt9bIyxp84jbo,1367 +torch/include/ATen/ops/replication_pad2d_meta.h,sha256=hoWPoe437n7EcF9mu2uGLz7f31sMYYR3Nj7xEKb7Ukc,636 +torch/include/ATen/ops/replication_pad2d_meta_dispatch.h,sha256=ALbtt_vHvSBIaYwv2LK1jTuMZDqJW409nGQsrQuyMo4,1367 +torch/include/ATen/ops/replication_pad2d_native.h,sha256=gQ-yCMjVLSBRTSnFO9bhWJb_0zvPiiBH9M9wcOHtkrk,889 +torch/include/ATen/ops/replication_pad2d_ops.h,sha256=EQ6r6q-kSwCXRPTLGL3VXlw5zwhKRxs0fLYG4_DV41g,1849 +torch/include/ATen/ops/replication_pad3d.h,sha256=EQsgypOU2CAn_t_JdulZ9rSgeC6JmSGwKyKWWhXUt9Y,4020 +torch/include/ATen/ops/replication_pad3d_backward.h,sha256=eJzpQxpc3hqoWoxwBklTf48vaVL8J92LjUkwmK6lEGs,5183 +torch/include/ATen/ops/replication_pad3d_backward_cpu_dispatch.h,sha256=kKvQDDx_qJR2tgOhHm3ZaPvBq77hqD7wGcOdP4mYxvU,1639 +torch/include/ATen/ops/replication_pad3d_backward_cuda_dispatch.h,sha256=vodXlWgk52TaC7u1jLaHnYamjLzPeKGy97AbSOgfLwg,1641 +torch/include/ATen/ops/replication_pad3d_backward_native.h,sha256=sjG8i2GwjwkplR4DUv1rz9mMJdqYNKbsCawzSwa4StA,1050 +torch/include/ATen/ops/replication_pad3d_backward_ops.h,sha256=nHnA5ihBiXAagP2qFuQ0jtOebRrDz0Bd0wHfvtbUK_4,2153 +torch/include/ATen/ops/replication_pad3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=uyDfqY4gHGlIRN_LQpC26nKgaOy6GPGLf4BJlMPV9bg,951 +torch/include/ATen/ops/replication_pad3d_cpu_dispatch.h,sha256=VZ8-mPbxZ80KQBQD7PzD53AP_6PHBom1BBgv04ZiK1A,1365 +torch/include/ATen/ops/replication_pad3d_cuda_dispatch.h,sha256=OdqfuiAva2JDhI_V0MeBfJ63UwtoyuWZcf0FTAPFavU,1367 +torch/include/ATen/ops/replication_pad3d_meta.h,sha256=cMzcFxMszbVIjX3uny2BBtc6KaGvBnEYl4wL2WiWX40,636 +torch/include/ATen/ops/replication_pad3d_meta_dispatch.h,sha256=kcnE-uKqbCg6-zRfXEfXxu3CzNJmMvvsp1ndGV4RQos,1367 +torch/include/ATen/ops/replication_pad3d_native.h,sha256=txYTNugIWFg4ahwkVDaRJny5_SnsAD61bLCwDSmXAE0,889 +torch/include/ATen/ops/replication_pad3d_ops.h,sha256=MmRSx0NjJbP5w6uNcs6OZ98CNc5yOy8m54TYP882Kiw,1849 
+torch/include/ATen/ops/requires_grad.h,sha256=FxBu80SbCus5Vszauqo64e2wtWCXyNIxDIXX32dcJjw,537 +torch/include/ATen/ops/requires_grad_compositeimplicitautograd_dispatch.h,sha256=5oR2_i_vOVo_IpmKxtVG5qN8okWrUxmROxB0P7A1Z80,816 +torch/include/ATen/ops/requires_grad_native.h,sha256=9TmSk9awhQk0qhlVYMpjgyR_MXTT9c-SQo_nP9NDRoU,526 +torch/include/ATen/ops/requires_grad_ops.h,sha256=tPrIuCc40d26AGxy4BS9RUaNWgXzbb-5P6cAFKRi--I,1077 +torch/include/ATen/ops/reshape.h,sha256=mzJxmsJXhrLTZbtN6Hm7Xvn6NptsB5-KE5TQ-k9R5Uc,1454 +torch/include/ATen/ops/reshape_as.h,sha256=7LlBUFh2rnphdCvZxHbkDlxXFDv-_EeSTHxdnRRj8O4,534 +torch/include/ATen/ops/reshape_as_compositeimplicitautograd_dispatch.h,sha256=6KJfys2fTlzKd9weK25Namh34-wemqwUNZQyeR7kJQU,817 +torch/include/ATen/ops/reshape_as_compositeimplicitautogradnestedtensor_dispatch.h,sha256=P_wDS1fckul4dTBUCIqlFgRSJrcS4jNKbeIset92--A,841 +torch/include/ATen/ops/reshape_as_native.h,sha256=DSL4ix21iueVatYsPX0Sv05yUz8M_JH7rs_m28FnoaQ,619 +torch/include/ATen/ops/reshape_as_ops.h,sha256=tW6lD49o8I60X-JX-ZWWnAtItAJej0deNoVIIIfSJo8,1090 +torch/include/ATen/ops/reshape_compositeimplicitautograd_dispatch.h,sha256=d4BpregJxyrZ1vtH3p3Rr2Mht7jLIJSbJK1XeB9Ho6E,901 +torch/include/ATen/ops/reshape_compositeimplicitautogradnestedtensor_dispatch.h,sha256=g_DVWJvHI9RUAfT44-NtUYnxvL77l640TzA1xe4afns,925 +torch/include/ATen/ops/reshape_native.h,sha256=gIcLO7foUhRAe6dRaQwUC7dNHBIG1vIyPdQv5vXlJRg,629 +torch/include/ATen/ops/reshape_ops.h,sha256=X7ItQMPJEHDPZL-cLwW6LX_cSNZE-DcQxZXg1pF5BjQ,1086 +torch/include/ATen/ops/resize.h,sha256=gV5DVxw-UP_l0BTy9uvMWA3_G4Vt8oq4fIm3O_iu6AM,5449 +torch/include/ATen/ops/resize_as.h,sha256=3BEvVrxwsRkkflRbtIdpA9Bx8sHd4lUeV3mPK67IAa0,2004 +torch/include/ATen/ops/resize_as_compositeexplicitautograd_dispatch.h,sha256=rmjutVfOXQGFgF_m2rm9M1ZMNo8OxUPNJcsYavpoaNk,1419 +torch/include/ATen/ops/resize_as_native.h,sha256=KSh9nyCiCfN54GQE4MM6ewrDEItDpZaT6MYZe9VwZ8w,937 +torch/include/ATen/ops/resize_as_ops.h,sha256=hImSAl6JPR7B2d4eTNfdQ63OZGndD39jhNMJJ46uN-Q,3024 +torch/include/ATen/ops/resize_as_sparse.h,sha256=0P_ILypZ2kpxf7DHwC9DyvDdXeUJEFUiKBUY4qoVq4Q,1652 +torch/include/ATen/ops/resize_as_sparse_compositeexplicitautograd_dispatch.h,sha256=VN3KGWf4HBfVrXi055jmnU9JCaLdYlUqM-Ud5MXIkBw,1099 +torch/include/ATen/ops/resize_as_sparse_meta_dispatch.h,sha256=KBlvgmPomonjCHAvh0yyQzpEVGIRKBGJ9e3KIUiHyjM,797 +torch/include/ATen/ops/resize_as_sparse_native.h,sha256=a0XH1I9oNkGq0VCsB9KEFxRPhQ1kPgg9BIiM6wJjQBM,899 +torch/include/ATen/ops/resize_as_sparse_ops.h,sha256=xbPBx3JYxb-W7cW450OTOcLgyD8V1HxyJdnGRUajjCs,2580 +torch/include/ATen/ops/resize_compositeexplicitautograd_dispatch.h,sha256=jL-R1XjCrF1pbPDwoxi8r4jAYkmNTMqIF1CX-RLoZHc,1727 +torch/include/ATen/ops/resize_cpu_dispatch.h,sha256=z2YYp43CbFN9a8kh1EN4zYsGNMgHaK5qOXAJ5Yls7Co,999 +torch/include/ATen/ops/resize_cuda_dispatch.h,sha256=syZ3iBk7kyS6a3BotfqhmUHvjOUbDtl4A2VL97WYsO8,1001 +torch/include/ATen/ops/resize_meta_dispatch.h,sha256=XvDN7UnG7huDN4Mu-mmcKcWpUZ0_aHQoYwUwTc_OaWY,1001 +torch/include/ATen/ops/resize_native.h,sha256=82By7TIpXQraVRfo3pdnRrV7eTZa2lOevfvjWmYDxPU,1558 +torch/include/ATen/ops/resize_ops.h,sha256=-FlkjUchlmyYBCtpfgwFvA1ShYBiEtk9LR4ejlabgs8,2940 +torch/include/ATen/ops/resolve_conj.h,sha256=MudF3vK5tO9e6AImiQFbEomHml1gTOUSGhMlMLn3BHo,698 +torch/include/ATen/ops/resolve_conj_compositeimplicitautograd_dispatch.h,sha256=89u4evnzkK6yEmsmmLbvbDGhh5_HMkFsnS6xc-Qy22g,793 +torch/include/ATen/ops/resolve_conj_native.h,sha256=jWLdpUYVVQr9481XkT1q49NI4f0kjCYpVinF-ykJ9S4,503 
+torch/include/ATen/ops/resolve_conj_ops.h,sha256=368w-VkZ3OMaNMepN20tlitnAnAwO48m7uYDedv42w0,1010 +torch/include/ATen/ops/resolve_neg.h,sha256=x1kUycu8phv7n1Q77UxEYgMfzFonMjOeFhh-Uz840sg,694 +torch/include/ATen/ops/resolve_neg_compositeimplicitautograd_dispatch.h,sha256=YaSIg0nSU9sVfbTawcx2MCmKk-sJuzkLYp_0Oz8zJuQ,792 +torch/include/ATen/ops/resolve_neg_native.h,sha256=rKY0QwMNcKDevdbMOnnrZgt2zYrp4exgIwSw5uppW7c,502 +torch/include/ATen/ops/resolve_neg_ops.h,sha256=NGLWvlSmp-g8IJ1ipgKINibIPAmOWAWXChtfB4o47Gc,1007 +torch/include/ATen/ops/result_type.h,sha256=pH9AaQ2zddh-1BJO_JEz-l639KKW5SyKgPAgBPyKXsM,1493 +torch/include/ATen/ops/result_type_compositeimplicitautograd_dispatch.h,sha256=7NkrhBuWg5dmakLhZpL35auR7bMB2J1hZ_EUh0jpRBI,1104 +torch/include/ATen/ops/result_type_native.h,sha256=hw0siWa6Sz9_hm_3TH28snlwfERVv0DbSQjehqVjha4,814 +torch/include/ATen/ops/result_type_ops.h,sha256=Vm6M6u6j276rhEZ6BxlVSKh7IefSjL3U0c4d3CXYsS4,3145 +torch/include/ATen/ops/retain_grad.h,sha256=uc3XxF3rRfRd3SV0ugqIk2kx1qzIWQq-bGETIf-LGNM,535 +torch/include/ATen/ops/retain_grad_compositeimplicitautograd_dispatch.h,sha256=PtpQz0uG2TEDifP5bfFpRqpfSIWtoKNCTHyQ24UZS80,780 +torch/include/ATen/ops/retain_grad_native.h,sha256=gTPBYxlEiYGiIbjPmS-botVR5xYINn5vEnN9bx2dZ2w,490 +torch/include/ATen/ops/retain_grad_ops.h,sha256=a089L6v_kjROF6FLZJTOX_BL9SsmHdheKHq-A7aUITU,965 +torch/include/ATen/ops/retains_grad.h,sha256=-Uw_7wiF9IW88SGmh7Gmrs6CB9flRFPAsJ7tTHOu28E,536 +torch/include/ATen/ops/retains_grad_compositeimplicitautograd_dispatch.h,sha256=ZGDxL9qCba_6x1GAuA-y_kM2wtxrnJWg2ntvL4iqL9s,787 +torch/include/ATen/ops/retains_grad_native.h,sha256=hBB3QyEi-JPqVALMAPdrNmlAzkko5HHubHFRxPD0Rbg,497 +torch/include/ATen/ops/retains_grad_ops.h,sha256=fHsDO-J3NxiyRvZYN6FjO_PWTrwrn92VuBYQyy49fHQ,984 +torch/include/ATen/ops/rms_norm.h,sha256=4Q-GgNPSsL5fSPCWtHHyL8XF7hhFDljHuuUtMM8cW44,2065 +torch/include/ATen/ops/rms_norm_compositeimplicitautograd_dispatch.h,sha256=FfI7PMpXc5SOaR-I1Dyqr2N-fF6AtGZpUo26TqzVPvU,1109 +torch/include/ATen/ops/rms_norm_native.h,sha256=03LmpLqskSbpDbcxqCdFKaIWtsHVMhLcIqgM4rPkFtY,636 +torch/include/ATen/ops/rms_norm_ops.h,sha256=Eb84yWJUuhe82NN4L3qcvnzoB29DUniStX0Tvkpmr9Y,1365 +torch/include/ATen/ops/rnn_relu.h,sha256=8Wy2tayvl2LFwrHua9yoiGIHGveSgvUEKUMmoiXrf84,1643 +torch/include/ATen/ops/rnn_relu_cell.h,sha256=c5XSNUupG8kf625P5Hq1vEiqGuos3dkP02fyL69ymeY,965 +torch/include/ATen/ops/rnn_relu_cell_compositeimplicitautograd_dispatch.h,sha256=LQB705mlfXRhl-HLyo-arGYJ2zhej_fE5-RLf4mq2ms,958 +torch/include/ATen/ops/rnn_relu_cell_native.h,sha256=5_272rTJLycyC-IXz6TZrp2nPwB-vzLZITsv54pur3k,668 +torch/include/ATen/ops/rnn_relu_cell_ops.h,sha256=pMBt2NgfQ7_GoB3fDPC7sNzCDwoGXglFiyew7vfMIK4,1533 +torch/include/ATen/ops/rnn_relu_compositeimplicitautograd_dispatch.h,sha256=QEhKjkcPZ_rs_AITUt9a3saS21BK7bySzPcT6hpziwA,1209 +torch/include/ATen/ops/rnn_relu_native.h,sha256=ILTcVdJNCl67yCsJNzmtg4L9IIXcqusQCvY8JZxhtgs,919 +torch/include/ATen/ops/rnn_relu_ops.h,sha256=xXenvNM-A43s_7dPFd4PlmfP5MEYNhyB6sJX5QHd5Ug,2759 +torch/include/ATen/ops/rnn_tanh.h,sha256=tDWUIhpn_h99p9uvfJ06Q43zgmp3Tz_tncpFMvVkC_E,1643 +torch/include/ATen/ops/rnn_tanh_cell.h,sha256=EkulfOSC-HrcpkW7KAJIxaf3jzxXm2WzaLaJ-EO1jVU,965 +torch/include/ATen/ops/rnn_tanh_cell_compositeimplicitautograd_dispatch.h,sha256=VSKkFI4rHP79cDWFK9nPu6s-UV-NGjhLGXXPVXEdvkI,958 +torch/include/ATen/ops/rnn_tanh_cell_native.h,sha256=RzNhrwZowtrn3kqLPxem1YYTqmGQgJ-Jv9fxQ0tgVSw,668 +torch/include/ATen/ops/rnn_tanh_cell_ops.h,sha256=zGUpnA4f4L4NT2CiWXHOHXg8FQq7bYedhXR7YU4Je2Q,1533 
+torch/include/ATen/ops/rnn_tanh_compositeimplicitautograd_dispatch.h,sha256=QBhr1BMt1nuT7DyvwOIWqbiqQRt-L08L6a0YezzIdOk,1209 +torch/include/ATen/ops/rnn_tanh_native.h,sha256=89yeoQfrWDXYsMIAhRXxWz-ChT1JG16kkRX1Cx_h0ls,919 +torch/include/ATen/ops/rnn_tanh_ops.h,sha256=F1Zl7Aj5RZqs2JAUHlje0e4CVz2uce9vQi0Ja5o6pQY,2759 +torch/include/ATen/ops/roll.h,sha256=PaTghI_ixIMD8vg2NCIut9oHJsxgXBVAyGE-MpN4jpY,4043 +torch/include/ATen/ops/roll_compositeexplicitautograd_dispatch.h,sha256=yrtY4WiHi1xO9rWZXhIH_39VEakvwKUFaXRi29YXa6E,1254 +torch/include/ATen/ops/roll_cpu_dispatch.h,sha256=PDmrHaeEUD1lf9TfxfHkgAUgyrm1Ny1jV4IaNun3ORA,903 +torch/include/ATen/ops/roll_cuda_dispatch.h,sha256=oEOrE5DO1Iw5lOyLSUwUFUV8mdfojzDzoPC9y3SQyHo,905 +torch/include/ATen/ops/roll_native.h,sha256=csY-AbP8e6o2POiw_PXQMIFxCZ_yVqNdphvZ0XulodU,785 +torch/include/ATen/ops/roll_ops.h,sha256=EGjz8zPOM_dcm8z3TsUD3MpvKjSULzAtRKj_AqSEFso,1919 +torch/include/ATen/ops/rot90.h,sha256=wXLx6WIRNGL1Jf4bjs8cs74dBosOXjm-kMgYsidTuJg,1280 +torch/include/ATen/ops/rot90_compositeexplicitautograd_dispatch.h,sha256=jas3PD2vMVc9a85pOvaLfNI90-AKIqNZ5vrQ6OEQHvs,1058 +torch/include/ATen/ops/rot90_native.h,sha256=GPlcBDR0tFry_yNuB49tQTPZZUi_KXU0QAcICkiw9js,648 +torch/include/ATen/ops/rot90_ops.h,sha256=C87WIm3V3Cqya-W0otD_JPWRTY6j7bm-2pF3oJsdy3s,1819 +torch/include/ATen/ops/round.h,sha256=pN0bDa1KCkJhg2PR-x7VltbJSjhpZC1ihrSCrghzANU,2121 +torch/include/ATen/ops/round_compositeexplicitautogradnonfunctional_dispatch.h,sha256=1t6qSTHWo_ZoZILO00bcGcDnxDrzNVQvDiWjV2uQNr8,1004 +torch/include/ATen/ops/round_cpu_dispatch.h,sha256=HDNNMLc0ECrEwPUw29zyiTfb2j4NQCELr_V1UEdnmes,1284 +torch/include/ATen/ops/round_cuda_dispatch.h,sha256=NmvxGlPyUgytKrOEdS08Htl2IBugLbufMRPWCUq_EBM,1286 +torch/include/ATen/ops/round_meta.h,sha256=_Vj8EhQkCO6rQ8O1RfdFL3_GxwX7RiCPePYlQty8Z-g,734 +torch/include/ATen/ops/round_meta_dispatch.h,sha256=A58V5gtB4TRJkie5L7315368RnW7Q35UztM-THpySH0,1286 +torch/include/ATen/ops/round_native.h,sha256=ZbnZSEsXtROM6I24MKeTRWS85NwPH0c2sgU3SZz24Jk,1214 +torch/include/ATen/ops/round_ops.h,sha256=nJFXiScsvzy33up88ml_eLRTvRAUq-7cUikCDFXOo24,3963 +torch/include/ATen/ops/row_indices.h,sha256=p8yF-5w9OOJ9w293FkZZ1EUU5SRPViRQApW3UB1WO0E,535 +torch/include/ATen/ops/row_indices_compositeexplicitautograd_dispatch.h,sha256=YO359IT0Ukg7n7g0OhYYo16MBwbMdX_goHs8bsJKC2w,792 +torch/include/ATen/ops/row_indices_copy.h,sha256=0TMfZxE761E_JufPQggnspzPNeImG3UsfK_NdchvxYA,1167 +torch/include/ATen/ops/row_indices_copy_compositeexplicitautograd_dispatch.h,sha256=L2gWzJkjmO4U8r0yWunLyICuhBZEPvB5BItqiEFnFcI,911 +torch/include/ATen/ops/row_indices_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ZziPSFVL2rjWHgVTVvgRVqm2OjDGw0YF8BzDiyn6u3o,823 +torch/include/ATen/ops/row_indices_copy_native.h,sha256=9JiiXshsGu_ESg4jK1vt4NvbaJZRSgsBWVfX3IzvGDY,596 +torch/include/ATen/ops/row_indices_copy_ops.h,sha256=-zUpJ_FMtTkyPOsHXs80np3Fyv_wqtDdb9F5rxTplNM,1647 +torch/include/ATen/ops/row_indices_native.h,sha256=HUMDPFZskaDTwD4uQnlzJfS03ljthg8D9tDRK0GdQcY,581 +torch/include/ATen/ops/row_indices_ops.h,sha256=tCwBb7YRT6mgv600HuYeI4rOOqsbYbV1jV5ttdnBUTE,1007 +torch/include/ATen/ops/row_stack.h,sha256=xJdeJGmUgPyOVdNo9HqlFwrTBMmoEAP4DHYhmdUAEGM,1118 +torch/include/ATen/ops/row_stack_compositeimplicitautograd_dispatch.h,sha256=gJhE5x0cUwVQIDjpoY3daPmuONWQ_M9yrHsXnKTCW0Y,952 +torch/include/ATen/ops/row_stack_native.h,sha256=XNi-YxaMjyWqv83JNhYriSpZ-iBkk5at_PUske7Pw58,580 +torch/include/ATen/ops/row_stack_ops.h,sha256=lioo-oJqwiOruF9s-zKLlSrem6qnLzIelyMTDsM71Ek,1603 
+torch/include/ATen/ops/rrelu.h,sha256=QuvKF0iofRG5DblqJgrapikz5bJsR0guYLiH-BbxFxk,1390 +torch/include/ATen/ops/rrelu_compositeimplicitautograd_dispatch.h,sha256=GLecglCTWRtPPmzNO4J3NAoc-xxbzUlrRX4_TxqJYS0,1147 +torch/include/ATen/ops/rrelu_native.h,sha256=Qzwtj33mXnW7m_J0_LuIQq9YsLmE0Xc0MdlJSufgagU,857 +torch/include/ATen/ops/rrelu_ops.h,sha256=9Bk0EEmJ7c9LNCGk4_Pjubl0-uBw11oiFIAXsoInq8M,2284 +torch/include/ATen/ops/rrelu_with_noise.h,sha256=l0-AashV7DeXuVQH2yE1XJCW_2Pc1p9fkJZ-q-qCFqk,3228 +torch/include/ATen/ops/rrelu_with_noise_backward.h,sha256=IZAAvfRYEOstA6Myh2kY3pxEa-95AOJrLp4sTeB6jDA,2169 +torch/include/ATen/ops/rrelu_with_noise_backward_compositeexplicitautograd_dispatch.h,sha256=PhMavetigluvtOjyI-vGrTMfGUFhElstlmiMaMdW_U0,1441 +torch/include/ATen/ops/rrelu_with_noise_backward_native.h,sha256=Lv78oSiv-ydMXtB5a4Lv0JQ9qT-Y5NbYoJGQvEIPPvo,906 +torch/include/ATen/ops/rrelu_with_noise_backward_ops.h,sha256=EOtdIoQw4VawzWLYG8-_a-DQSd-Snl_zUY8T5xHRhFE,2665 +torch/include/ATen/ops/rrelu_with_noise_compositeexplicitautograd_dispatch.h,sha256=k9qDCBZf_qPtmcWP0wJ08P46wPW3JqfkLibqNL7yJPc,1014 +torch/include/ATen/ops/rrelu_with_noise_cpu_dispatch.h,sha256=5UEegohjITdG2YDzvEv_bkNKkIHlhk-3nHoBF-TQpzg,1648 +torch/include/ATen/ops/rrelu_with_noise_cuda_dispatch.h,sha256=NC7WktDY0Z2HFO-ajX9N1QrcS3NIE5Ty5kOj7Cz68ew,1650 +torch/include/ATen/ops/rrelu_with_noise_meta_dispatch.h,sha256=NP09QLgnocB1izCoSy_u2hNb1m4Y46nWfPgr5_1Cf3U,927 +torch/include/ATen/ops/rrelu_with_noise_native.h,sha256=gzrE__QQsqhUmhzrKBI-r_jQgKpAl2C0EK5g0QrdREY,2141 +torch/include/ATen/ops/rrelu_with_noise_ops.h,sha256=uJ5QAFabW0Tskf7BlTiYOEAoHd-HvGRmIMmwUTjtADY,4746 +torch/include/ATen/ops/rshift.h,sha256=vSeMoQxJrZzpolQ-AQ8o8hjEX4NTr5w7kMTw0XxkIkU,2044 +torch/include/ATen/ops/rshift_compositeexplicitautograd_dispatch.h,sha256=YFO00F9wjVnVIkWULrGI6ud-xaybWToxDVFFTLipJ4w,1170 +torch/include/ATen/ops/rshift_cpu_dispatch.h,sha256=9PKqZkNCEVT4FURnSdvS528c5OeE4F0s1pO-_tqARdU,1022 +torch/include/ATen/ops/rshift_cuda_dispatch.h,sha256=CTPl2vuMOh7k6-4uaPYOnPHD4aU-ZEGltPZvy2X6Xoc,1024 +torch/include/ATen/ops/rshift_meta_dispatch.h,sha256=xLZS95waZ2419o8RqNvJ1rXPdhX7GM2YdqcBJNhSQKc,854 +torch/include/ATen/ops/rshift_native.h,sha256=eoDBEabqjm-kHukDYvyrof90kLDlOU_ggoGAwKESEe4,1008 +torch/include/ATen/ops/rshift_ops.h,sha256=WEY6H0ccfrEV1yXP5Yz0RsWvJG8L1LPxtEhhZ0_50a4,4429 +torch/include/ATen/ops/rsqrt.h,sha256=9ax9vCt-lOa5iwofo8dUAdYFnC8ee6YZMuRaQ7eXL5U,1201 +torch/include/ATen/ops/rsqrt_compositeexplicitautogradnonfunctional_dispatch.h,sha256=l1mem_aMqs_afxBZsDUbquO64zmkGlgV0Inlsb3AkME,863 +torch/include/ATen/ops/rsqrt_cpu_dispatch.h,sha256=GePx4o9SZLmTZbSjNcQ0qmxO0Q2Xk7cJXtkM1WLwXfg,950 +torch/include/ATen/ops/rsqrt_cuda_dispatch.h,sha256=P-o_2zmpROnucF19yGkY3JAyPK6T_zresWBsypd0SH4,952 +torch/include/ATen/ops/rsqrt_meta.h,sha256=UrtGod0ypYlyImMbqZjrSSd3yJXdHcjgf_qPc6kNJfU,593 +torch/include/ATen/ops/rsqrt_meta_dispatch.h,sha256=2935-LHKnzpfRGNNw-vrcvsGeLA_meZe8MWXyZ2_Jw8,952 +torch/include/ATen/ops/rsqrt_native.h,sha256=kk3Mh45bxv4kV2YP6zTBQVglzNeBxUgicO-QzXBJPPs,616 +torch/include/ATen/ops/rsqrt_ops.h,sha256=S7i3qF4oWTdSYTOMML4AlKguaoHeWIWxsZeBu_NJ_mY,2088 +torch/include/ATen/ops/rsub.h,sha256=IN5zfKc7r0VImH85BjmklEKZ1tB9w6KvlCpC8xR9UDk,2239 +torch/include/ATen/ops/rsub_compositeexplicitautograd_dispatch.h,sha256=9lzFdUliNLT73X6u57rafXkMb08gK00w_iz47v9F0dQ,1361 +torch/include/ATen/ops/rsub_cpu_dispatch.h,sha256=irEbnAGTq_Cn8Dz9pNEOBnKuDrifR2jqdA1i9TJszZw,795 
+torch/include/ATen/ops/rsub_cuda_dispatch.h,sha256=lgqWSCQgBrVI8BHgevGCHRib_gRUt7EcGgeZ8gxeNCY,797 +torch/include/ATen/ops/rsub_native.h,sha256=NUtc-b6761x1sm9Vij9xsXRTGD7PXFRDiu0sLhydIB0,928 +torch/include/ATen/ops/rsub_ops.h,sha256=YVSEG0m5hIVYbQGOGLvu4a00aAeSirf162mWQQ_NysY,3456 +torch/include/ATen/ops/scalar_tensor.h,sha256=lRd7GkitTfIOBmFc2pr6D6fYgXYxKbNEk05jXNWFST8,1778 +torch/include/ATen/ops/scalar_tensor_compositeexplicitautograd_dispatch.h,sha256=Rao6xrRx8xxMFu4RY3qN1dykkC82BznOhrOlWp0Qi54,1192 +torch/include/ATen/ops/scalar_tensor_native.h,sha256=0Eik6ePV-P8ldXYhCcrGFVH_PpIH7GwXfhxD78v92CU,741 +torch/include/ATen/ops/scalar_tensor_ops.h,sha256=pGoAesPFjTFqUneekCcZk8VUMz5GOEZJmF0effzXiTA,2107 +torch/include/ATen/ops/scaled_dot_product_attention.h,sha256=MfsVXHVnmI6e3XUxldLrP5NFb5UPaPnXIaldGBJ5ZJU,1174 +torch/include/ATen/ops/scaled_dot_product_attention_compositeimplicitautograd_dispatch.h,sha256=FASSQTwwOQdibzMz7gLkQOXY9ftotKerhb8OGJy0LvE,1023 +torch/include/ATen/ops/scaled_dot_product_attention_native.h,sha256=dANKnDD_FUB55ltsVjep64-EBEINCUvw42BHH6958t4,733 +torch/include/ATen/ops/scaled_dot_product_attention_ops.h,sha256=IRmM2AqLwI0dZP2VJChlNOXVRmv36eMijtB6O4cUagU,1673 +torch/include/ATen/ops/scatter.h,sha256=vtL-IbgsyZxsAza2ayzd_Ys7W0pKnzgC9UMME2ya0gU,5183 +torch/include/ATen/ops/scatter_add.h,sha256=bnjx6_dAiv3eg6tUhXFG07jgVr978Llh7Mra4PLpSu8,1756 +torch/include/ATen/ops/scatter_add_compositeexplicitautogradnonfunctional_dispatch.h,sha256=hFx7ZrOHsNpUAg3X7O0FpPh5_fU2av3tLId0W4Q9VeQ,1001 +torch/include/ATen/ops/scatter_add_compositeimplicitautograd_dispatch.h,sha256=plN1ytGq59571OExjIZcNEwX45tjGXkVBsQ5Kugk5ck,859 +torch/include/ATen/ops/scatter_add_cpu_dispatch.h,sha256=bksw10tFnUptBPKmIevIsRoQas2Pj1F-C5SIuA2QAL4,1226 +torch/include/ATen/ops/scatter_add_cuda_dispatch.h,sha256=FF_AWAbvyfw5fFWAMvDVucjl4kxAV6LK68rt-UQZMig,1228 +torch/include/ATen/ops/scatter_add_meta.h,sha256=coYY9mFs_zcAdYgObH6jCrlHGjZold17z9oEM3VPFVU,662 +torch/include/ATen/ops/scatter_add_meta_dispatch.h,sha256=xYEVlhTCzji49Kpi4QM_A3NTsT0aeziZxZS1gR697TI,1228 +torch/include/ATen/ops/scatter_add_native.h,sha256=9E-hTfs92pqPcV_WdppDBCq5vM7ft9DSAEQCYoVIHSk,820 +torch/include/ATen/ops/scatter_add_ops.h,sha256=qb-TJkZOw6ElaSAWWXN_SRQEUMrOF9bCbuDc5EFeY3c,3547 +torch/include/ATen/ops/scatter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=sDtD6r_KDb4pV2Dk_-30-a-XQBXO7s5cJ9flXDIDDS0,1806 +torch/include/ATen/ops/scatter_compositeimplicitautograd_dispatch.h,sha256=BGdDFhMPmflHZKUBtPZbRk2K9jK6EkvtbiJ7U0nJhxQ,980 +torch/include/ATen/ops/scatter_cpu_dispatch.h,sha256=dHHW9SyFRFRg0SWf20hlFac7lsIsHim8VxeWKvJ4mSs,2992 +torch/include/ATen/ops/scatter_cuda_dispatch.h,sha256=CK92Em57PIW54wpf8qBNs2mze5w4mOo7N6rNu2vqe64,2994 +torch/include/ATen/ops/scatter_meta.h,sha256=3FH8f46ptBA4tkJ0uaqSqcVCW2TK2ak8-iQAe--YdTo,1279 +torch/include/ATen/ops/scatter_meta_dispatch.h,sha256=QTgWFyR_a425NODLTH--hntLOTnB-8wxeOIIx9uQw0s,2994 +torch/include/ATen/ops/scatter_native.h,sha256=oxNsSK84NWt-XN76ivK6k1RdCoebvfH0oIwT8FDiT4U,1674 +torch/include/ATen/ops/scatter_ops.h,sha256=XEJcpG91XI5Noox7m0WpjOZhS7T-YuWHkRNctELR8no,11871 +torch/include/ATen/ops/scatter_reduce.h,sha256=0r3E5mZJafOGMlh3TQCggxwtMfF15LEQeVsCQW4vVvw,1835 +torch/include/ATen/ops/scatter_reduce_compositeexplicitautogradnonfunctional_dispatch.h,sha256=7ASqyPz7YoMfo0UTzi1Tf-4hd4niAV11d5hSt8kquN0,1105 +torch/include/ATen/ops/scatter_reduce_cpu_dispatch.h,sha256=qsbCWmuMD70tjMUIYRxSjxyQaeJRrlgHMPf51EUwWdE,1429 
+torch/include/ATen/ops/scatter_reduce_cuda_dispatch.h,sha256=h_7clnLho8Q7RZO2tUe4HObZLVpGe8HLvS95LPiF6zw,1431 +torch/include/ATen/ops/scatter_reduce_meta.h,sha256=f4A00hXwIF4r6Dtu3_Ivs2k3efZjqZEzTUt7KU9GiPI,713 +torch/include/ATen/ops/scatter_reduce_meta_dispatch.h,sha256=QexjL4BfAZFJ6sG0_1tndIECGFo57JvnQgX9UTkesFA,1431 +torch/include/ATen/ops/scatter_reduce_native.h,sha256=47Hye9q5nc-r4Z3fw7CGPrZDmU4ZR2cKLd2oHLbj5WY,754 +torch/include/ATen/ops/scatter_reduce_ops.h,sha256=H6zvWb2wqtVzzjp7WQOw3kwFWiPlG8WXQalWB2PCuJY,3283 +torch/include/ATen/ops/searchsorted.h,sha256=A2FmDfunOt34ETynXO8CwOfNyG5N3LmDnEP1sZ9Kubk,3718 +torch/include/ATen/ops/searchsorted_cpu_dispatch.h,sha256=eEL3a6obq8YIBVabk_kJ-y9c-PjRUjWpY3091i4KD-o,2160 +torch/include/ATen/ops/searchsorted_cuda_dispatch.h,sha256=d2aqY5TccIeNVfHEXFdIaIM6bRfhJeV1EdfsJpuSxSI,2162 +torch/include/ATen/ops/searchsorted_native.h,sha256=oIwJYrMOaJGWu_ALJ-KlRGw_8lPkiCpdIObXArI0Ei4,2366 +torch/include/ATen/ops/searchsorted_ops.h,sha256=neA5tIpgyu7FpExayGS-pq92jwBDF1Oy5v43RvaEgyw,4863 +torch/include/ATen/ops/segment_reduce.h,sha256=NrB1A18FJbrQWilvYlFDxTu7HWQtAqAv6nPZbNGbIOA,2477 +torch/include/ATen/ops/segment_reduce_compositeexplicitautograd_dispatch.h,sha256=DwlhdC6g58yw-5laa8ppIWUQzlCe6Ub4UNIL4KRWx4o,1403 +torch/include/ATen/ops/segment_reduce_cpu_dispatch.h,sha256=TVEnN8qkIciGgIAY1pDJYGkcohuGdR9-jCGPkIvyZiY,1015 +torch/include/ATen/ops/segment_reduce_cuda_dispatch.h,sha256=k2xKNQLoZ5xcOTHxXIVyLhrW1zmqmkQdJbAEXzm96nA,1017 +torch/include/ATen/ops/segment_reduce_native.h,sha256=7XFKVT5vRfUQu_YZQqwuto0AMoOE8-ItSagpd8a5_EY,1095 +torch/include/ATen/ops/segment_reduce_ops.h,sha256=n8EJ8IRru58VFCbQrnvjUH5rBViK0_z8ytHwGbInS7o,3190 +torch/include/ATen/ops/select.h,sha256=v_s20gMMK_m4M7kQc_8swoZ0KObymzcOLvptS_Mfjhk,1699 +torch/include/ATen/ops/select_backward.h,sha256=I7P2YpTjrvLbHsDWBt2a3kMUxzfqZMqoYVatNK-jdjs,4924 +torch/include/ATen/ops/select_backward_compositeexplicitautograd_dispatch.h,sha256=Y8wVIf7YGWxdmxsJ8hvZ9z17lti7KR-nBpQ-Hbt5aVI,1372 +torch/include/ATen/ops/select_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=i95_jWByM6PQ9GAz3OpzPeawoa2UYXO1FeQmP0J-ZXw,1029 +torch/include/ATen/ops/select_backward_native.h,sha256=xUC_bm9rKrN3F-K4ZSFnUv4njIJs9o42SDBMV3-rzKY,752 +torch/include/ATen/ops/select_backward_ops.h,sha256=q4ozy4Ri48KHge40Ru15kVtbNW4Rsv2VrWnOlWFj7HE,2119 +torch/include/ATen/ops/select_compositeexplicitautograd_dispatch.h,sha256=iIB-luWpV-QvqBoy4D9JHmM13_uDIdn9pa0P2fQNOWI,909 +torch/include/ATen/ops/select_compositeimplicitautograd_dispatch.h,sha256=hEZSEQ0w_rWfgU2qIS7FzfMYaPfnpQf9TkYZi4hSzM8,819 +torch/include/ATen/ops/select_copy.h,sha256=K0RauFLME5hUPyVhRTc4yn4h1VQwKsZTlTkwvYJnLbY,3846 +torch/include/ATen/ops/select_copy_compositeexplicitautograd_dispatch.h,sha256=caDojXT-M_fcmyX7vX7IfF1oXZ_M516X0a4QnzmrSJI,1204 +torch/include/ATen/ops/select_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rs8f-fgmRKCOfsGST6tyv8a2s-SZru09aSQlBZMOBJQ,945 +torch/include/ATen/ops/select_copy_native.h,sha256=9k7rRxTFVDIXJCQ__Np148VZZJ0tGhfHQ5WFWR3qIKs,767 +torch/include/ATen/ops/select_copy_ops.h,sha256=dr95tCrfufF0eG_w-M8pxS4F7ONVUNfr9pKw3HhB5TU,1858 +torch/include/ATen/ops/select_native.h,sha256=6ExZUVuG8BqDbGs7D300EcAxCCoxo7MOQNstoIVcjjE,807 +torch/include/ATen/ops/select_ops.h,sha256=MurdvrM4voSWahObjERqeqhBFSTKgt9xmvqIVg_RbAs,1762 +torch/include/ATen/ops/select_scatter.h,sha256=nAtx9vXT0qplAs152B7CrZRynzd84WEKM9ZXAF1shR0,4287 
+torch/include/ATen/ops/select_scatter_compositeexplicitautograd_dispatch.h,sha256=8ShJCdbJHHORah1ndh575akqsGNQPINvBBYDEsfbJa4,1312 +torch/include/ATen/ops/select_scatter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=n2YQu7k-wO0oGuTo_zVWsBdN0IHdWy0wv4Vm1bL0Fxk,999 +torch/include/ATen/ops/select_scatter_native.h,sha256=UT2WhbyojPmVvQDAlSXTPUGQ1e8esL7F9AXd3fAzfyI,718 +torch/include/ATen/ops/select_scatter_ops.h,sha256=eqF8U1a3X7YsjSQYVpgrFzQmbO4A7tP4f3aGWYJnZsc,2013 +torch/include/ATen/ops/selu.h,sha256=j5qRJdHsbqctZsSTSmUpEaTtORgOchFFctmHaKzkLgk,801 +torch/include/ATen/ops/selu_compositeimplicitautograd_dispatch.h,sha256=nT2E4vcqiQzt9lg7CAfWyrIyxPF8auMDl5e3wFWGCEg,835 +torch/include/ATen/ops/selu_native.h,sha256=we_HpGftFuhEs2lPKBIGNEB7V7p-JiG_I1A5d9cNIAs,545 +torch/include/ATen/ops/selu_ops.h,sha256=DrSA5nuZ3uPzdrBvN-ojYaJ5KWmORZ0J3wWwGsXszF0,1484 +torch/include/ATen/ops/set.h,sha256=tzK0amS7NNaPBjhnA4OuodNY6dIcCbFOL7S2T8i_PaI,9372 +torch/include/ATen/ops/set_compositeexplicitautograd_dispatch.h,sha256=Kt-Euf2J5DU1-meA2Jnq5PRorFvflpvmvjAsgwc8xyc,2507 +torch/include/ATen/ops/set_compositeimplicitautograd_dispatch.h,sha256=PYvyFsbyLBxGWDQ7RrkKHzbf1ihcpb7hbV9a8uLt0sg,1049 +torch/include/ATen/ops/set_cpu_dispatch.h,sha256=Who8oSiMtiQ6j9xO71zkDgLdlWOSz7mC8p-q8y2hVV0,1185 +torch/include/ATen/ops/set_cuda_dispatch.h,sha256=IqEyt87MRnBWBGwov6q3n-tSP93Gw8qiGiuMqylcD8g,1187 +torch/include/ATen/ops/set_data.h,sha256=mHLZpZphEPFSQrpsw0vPJ6qbkdchUQcIYh4TOhzs2Io,532 +torch/include/ATen/ops/set_data_compositeimplicitautograd_dispatch.h,sha256=EIQk67In4TuRySETNt2AO4Aw5ZTFQsu5JhZxH1QB8kY,806 +torch/include/ATen/ops/set_data_native.h,sha256=UzJAKiJsnYsJ3wIHdqxfmm34HXpFjapgmLFENlbUgGQ,516 +torch/include/ATen/ops/set_data_ops.h,sha256=Kvr1p8Ew15LKHBSi0iDECO9Bp76islGPG69iompie8o,1051 +torch/include/ATen/ops/set_meta_dispatch.h,sha256=JLs2gmZ1qej7u3Q4GSO-ZG9uuX4_bHSSjHNaS8YL8_0,1187 +torch/include/ATen/ops/set_native.h,sha256=pgtRlSLxR37j8oSzX91w0W8l3gMrrmtdhXqLtOilTR4,2452 +torch/include/ATen/ops/set_ops.h,sha256=hGYKHYiVERj6VqhW4NnXHnPDq4muO2M0kQ33zVG5eMQ,9845 +torch/include/ATen/ops/sgn.h,sha256=J4t3jY-k6YYpDQKufoVofBpyuFyDZ-OB2Oief0FiU1I,1037 +torch/include/ATen/ops/sgn_compositeexplicitautogradnonfunctional_dispatch.h,sha256=pGWfDhE4fs1m91TWDUKYFNmkXXjozs04EFY-ebb5wqc,859 +torch/include/ATen/ops/sgn_cpu_dispatch.h,sha256=seBrTjeyfY1tVRS8n4u2BnyqixBa6pAbIxgmQ3_Y8M8,942 +torch/include/ATen/ops/sgn_cuda_dispatch.h,sha256=--BS9N162Q09rgwWK0493czXryDwZFaurpGILhdKDtw,944 +torch/include/ATen/ops/sgn_meta.h,sha256=suQzfUxkjmNfbC2UI3SJcu1OlHPKaMxq7i9mZcpFPuA,591 +torch/include/ATen/ops/sgn_meta_dispatch.h,sha256=MJo8OB8xdY1g_Zv6wZYNVr5Edekj22N4Ig4EX7DmAeE,944 +torch/include/ATen/ops/sgn_native.h,sha256=pEF--F2LaJ_0IcKoeOunOygUoEDieVsvBmGMKUpCDyo,1145 +torch/include/ATen/ops/sgn_ops.h,sha256=ssYKXKrcIW0sv9-BIxapSSxvL8BJ4AqQrd6zABE41XQ,2070 +torch/include/ATen/ops/sigmoid.h,sha256=dv9RJCfeOfPzQuWGcUB1UlOIXDCqcrrih1N51amW1oM,1227 +torch/include/ATen/ops/sigmoid_backward.h,sha256=gp1fTveUCJkHPV-risJkwLBpRfgUnCwKYlSPoK2APgI,1450 +torch/include/ATen/ops/sigmoid_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=42GDXAs3SCt_yepmTvlq6WO5RAnBLBGrRbMhajZaYxE,857 +torch/include/ATen/ops/sigmoid_backward_cpu_dispatch.h,sha256=HAPrWsfGIt8B3zSq0Tec3MpGZPYb3mgUF6u8kn64t88,1048 +torch/include/ATen/ops/sigmoid_backward_cuda_dispatch.h,sha256=3Ubuhsqn9or07XJ7svk0CXTyPxuVEFT_LQXTpotsFdI,1050 +torch/include/ATen/ops/sigmoid_backward_meta.h,sha256=lF0Z3FvJjAEYNfNC4kzDI6ot3_-FJCVedBLIUbjD5q8,638 
+torch/include/ATen/ops/sigmoid_backward_meta_dispatch.h,sha256=XGBhGd-kEj4akQJ8gS6BndioRUALbI1eB6I9uNp7eEw,1050 +torch/include/ATen/ops/sigmoid_backward_native.h,sha256=g2soqpGMhl2dFb0SpmMR0JVDhN5juIeY9lZPKtWnaQI,690 +torch/include/ATen/ops/sigmoid_backward_ops.h,sha256=IxgdDY-U2GjI9wob245js7jPloRh81FEu7i9S6B2gx0,1909 +torch/include/ATen/ops/sigmoid_compositeexplicitautogradnonfunctional_dispatch.h,sha256=FfAj29uB7rzgoAJwFWEodwa-g1rgAKdpCUYfVNH83nQ,867 +torch/include/ATen/ops/sigmoid_cpu_dispatch.h,sha256=zS_UWwaVs8oS-1eHcdEyeQ2ShQG7hgWW-xI72Gqkw7U,958 +torch/include/ATen/ops/sigmoid_cuda_dispatch.h,sha256=0vHCWe5GNB8uzU4u08S7GacsUFUDZ07aCgn7kCtAF84,960 +torch/include/ATen/ops/sigmoid_meta.h,sha256=Gm5L20b3vrrn7FYwDIvuOpvIf81vyeArh5GYzu4aYe8,595 +torch/include/ATen/ops/sigmoid_meta_dispatch.h,sha256=fIbjpruY-l5wjzLT_z-yfSUc1jWmwoACNeqwaz-7aAs,960 +torch/include/ATen/ops/sigmoid_native.h,sha256=N9UfY47FiC2m5iyJ898CjBN1dmzy-Q9PxiN0jtJX8PM,815 +torch/include/ATen/ops/sigmoid_ops.h,sha256=FZs_p9muXWYB_HLf9IGCCaphcJXf_OwYNoFv4cAaqIs,2106 +torch/include/ATen/ops/sign.h,sha256=7NKLlzNaDKG6H-c0bG5gPK_MvIyudkojO6VccB3h1KM,1047 +torch/include/ATen/ops/sign_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AnBZXob5ALWDtzx4eiR_uP5hMhc1XLfzoD3sjzR9ous,861 +torch/include/ATen/ops/sign_cpu_dispatch.h,sha256=oDa1M77F6eKBD5nVA7wjJTOqjWAhI6CtYIChfib_zXk,946 +torch/include/ATen/ops/sign_cuda_dispatch.h,sha256=PJUqmBKPFTY-AenlW-4F9Yi7oOxPDOCry8WDyj3hSBs,948 +torch/include/ATen/ops/sign_meta.h,sha256=5798xDNURdSpPf69dIf1-IXNDX_M7iIY5Mo7LMb5Ooo,592 +torch/include/ATen/ops/sign_meta_dispatch.h,sha256=aSqGVPDn6xbyq-u7a-HHOMtsxrNfahlmGDgkdDlnRBI,948 +torch/include/ATen/ops/sign_native.h,sha256=LMNowOtX6HpbhcR_eX4vv-9WzCxG7Z9ks5QAp2w664c,1027 +torch/include/ATen/ops/sign_ops.h,sha256=xq7__qvmMyz8ogijZkCoDE6muDZNfOxJ3lQPwGO0lew,2079 +torch/include/ATen/ops/signbit.h,sha256=5yL93ioZXtzXW6uZWh5jJ5NoAExlxnDjIszI21-J4y8,1077 +torch/include/ATen/ops/signbit_compositeexplicitautogradnonfunctional_dispatch.h,sha256=kKkf6NEQlaKGRzulQrVAhF9ZWhjI6rj-jYbaLwOkq9M,814 +torch/include/ATen/ops/signbit_cpu_dispatch.h,sha256=qkGA6e4o6la4M0065hToJVhaRy_iI-dadLcoRVV2DGs,905 +torch/include/ATen/ops/signbit_cuda_dispatch.h,sha256=x17JnaigaUK4m7C7tcBt3Aybh8P89HVYZevbxkNsJTc,907 +torch/include/ATen/ops/signbit_meta.h,sha256=XwIUIpn4yVvRFSFfuXDidy85ctHm8ynAgM-X12-ZOTc,595 +torch/include/ATen/ops/signbit_meta_dispatch.h,sha256=nTr-bE0pB6hspIiVh8E9BCePbZsXu12fYqA_UQ4onm8,907 +torch/include/ATen/ops/signbit_native.h,sha256=X84CpnbwvbCC0nNpg2hDOcDk4IKD-KJOUX1LAZsfYpQ,930 +torch/include/ATen/ops/signbit_ops.h,sha256=Jfz3zSnylrH_66Xm63n8Cu6oLKROdfWMPQSKHceyt5k,1593 +torch/include/ATen/ops/silu.h,sha256=8GeO480yp0V4kwKRQsILQ7W16uHfxHUey9AwmKnAUvg,1188 +torch/include/ATen/ops/silu_backward.h,sha256=dDhvZ0GaXAWDoNYD4BA1NOOffcsHhs8iufrAwuPX3Cw,1402 +torch/include/ATen/ops/silu_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=WDQWAKpJvS5IRD5_OHuriYHqdmUi_ptO4LCRA0ch1D0,852 +torch/include/ATen/ops/silu_backward_compositeimplicitautograd_dispatch.h,sha256=M2KdmsuX_Om2MTHEmEOdptwcVur1Nj4wdO7hMmG-FMQ,826 +torch/include/ATen/ops/silu_backward_cpu_dispatch.h,sha256=6X-0w_-no5bxw0soCFV5kUAEejP1X4X0AFt53dsmIVY,1033 +torch/include/ATen/ops/silu_backward_cuda_dispatch.h,sha256=hLG28Pe4xw-U1EqTb-8_NmAW1Rp-DzGMu1PclYRO2fA,1035 +torch/include/ATen/ops/silu_backward_meta.h,sha256=1MoNf4OjYN32VQaAu4Uj1K7aA3eAEQ4VcQrYw_WZzAs,633 +torch/include/ATen/ops/silu_backward_meta_dispatch.h,sha256=f3i3UBNc2kZe-t5G0rI8ToVoWHCaQzhmzsa3CLz8PYk,1035 
+torch/include/ATen/ops/silu_backward_native.h,sha256=G5YlEMK-MTPi3uGbj2dShepj0ITwqSLvcCQEFKFrw6Y,879 +torch/include/ATen/ops/silu_backward_ops.h,sha256=fSHOUHQBhVzeRgmujnxiR6zR0uXV2SUg7zf-iEJJpQw,1879 +torch/include/ATen/ops/silu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=KYlf324Dl5Ubv0YSFCHqw7kHsrDZtbwTVXmv3AgUQdk,861 +torch/include/ATen/ops/silu_cpu_dispatch.h,sha256=VVVytfdv2ozfh6FOdE6WsYB0mnKNDwZBRaZmHjChUak,946 +torch/include/ATen/ops/silu_cuda_dispatch.h,sha256=JW9B2cFf2HwVpfrHc4gMOjN6omnH3KDgQTsIrnjc71o,948 +torch/include/ATen/ops/silu_meta.h,sha256=ILn7eXmI2O6GEhNfdX44zwbTWDELrh_f0M9fTtmgDYU,592 +torch/include/ATen/ops/silu_meta_dispatch.h,sha256=OFG_xdztLHQ3jjodKRT-mkwp_flc1wtHAd7gdSh5q7g,948 +torch/include/ATen/ops/silu_native.h,sha256=KiyoPbzDX_yD2xF3AItJOsaIIE6p_JUNPpVivBnYzRA,742 +torch/include/ATen/ops/silu_ops.h,sha256=kS2wuYVRgwjGQ5tpaSg1xIYJSrZPfmP8ZcJzi4_JI3E,2079 +torch/include/ATen/ops/sin.h,sha256=wsz4c1VS3c9RHslmfC4m1ofccrtYXLNe6NETzNj2n1s,1175 +torch/include/ATen/ops/sin_compositeexplicitautogradnonfunctional_dispatch.h,sha256=cnq9P9Vnz2au5qnpyE7ijvoevE4cZaphR14UWRpiruI,859 +torch/include/ATen/ops/sin_cpu_dispatch.h,sha256=Nj7xfYwSi6fbQYi0OAwthr1-O-qoDX21yVSNDi1W36E,942 +torch/include/ATen/ops/sin_cuda_dispatch.h,sha256=L_S8610gXfg-Yy-7oK1Y7bAADRLiMEMhPACyxArpWv4,944 +torch/include/ATen/ops/sin_meta.h,sha256=jzQUNkpKIXkDz3OgPhffPafkGImWk7-i1gvDS_gYa34,591 +torch/include/ATen/ops/sin_meta_dispatch.h,sha256=aX5zQysyspQGluhVAGFAcoxUytjkwRNAXH-U-hgugoA,944 +torch/include/ATen/ops/sin_native.h,sha256=stYK0sVlFZuFPpEQwDYzrta1t-uyYdOdeO3pRjLnVe0,1083 +torch/include/ATen/ops/sin_ops.h,sha256=pP1eh5fJjdWWyOmBS4EULns9ijWk0_FDsrkWh5fzlBg,2070 +torch/include/ATen/ops/sinc.h,sha256=OpPNYsuNZ3T8p03lLT7lhJvq1T9Dy32Vo8YzPT_22s8,1188 +torch/include/ATen/ops/sinc_compositeexplicitautogradnonfunctional_dispatch.h,sha256=10fQqH2382lL_kVUYmIRrUT8dt_9JTQ7rbWe8kEQImg,861 +torch/include/ATen/ops/sinc_cpu_dispatch.h,sha256=VShegRl72ajRWh8ThOGtiRLSsZA2K4vTD0YSLLxT-SY,946 +torch/include/ATen/ops/sinc_cuda_dispatch.h,sha256=aMhBlVquVEAx9WekdypmZnFcf-5wx44OVjG0Gx8F3Oc,948 +torch/include/ATen/ops/sinc_meta.h,sha256=mXbg_KRlgUkj4DnGmSnbBU6pLS1gvrzjdYsQfnkk_R0,592 +torch/include/ATen/ops/sinc_meta_dispatch.h,sha256=ADndZIdsYR9ZrGqGuXe_4U2eCPIdD7VZYc36UQdcxW0,948 +torch/include/ATen/ops/sinc_native.h,sha256=oVq3v7XARRZu_Mok-d6X-nnNWm9a7-k6I65h7kuOCQE,613 +torch/include/ATen/ops/sinc_ops.h,sha256=nLmILtvwOzHRZUN2aVR8vcEFl_vJZLZWhXNiUHLAYEE,2079 +torch/include/ATen/ops/sinh.h,sha256=MNxY3DdDHFzctBYjDNT28JbF2l1BHrBwMkdMLBoVFf4,1188 +torch/include/ATen/ops/sinh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=TvOtiln_zo6hHlDxbFan8sKTIb3aujOhT5EhMPJGWnY,861 +torch/include/ATen/ops/sinh_cpu_dispatch.h,sha256=3dBxH4yZtunoGa8Krfiw5H-u4W9JwIjbInyfEiaAVic,946 +torch/include/ATen/ops/sinh_cuda_dispatch.h,sha256=0jlG3zlN_vP8_UOQ6x1GvhsjXKDqsueJ-CzwEyS3F2M,948 +torch/include/ATen/ops/sinh_meta.h,sha256=3fkEceSZHK72RumKEkczhORiphaToOMPlHUOlPkbPbM,592 +torch/include/ATen/ops/sinh_meta_dispatch.h,sha256=_NdcKiIcVyAtnGiNaOPg9yLpE0-PyKpt9v1fC4a54MU,948 +torch/include/ATen/ops/sinh_native.h,sha256=fhw2ULdDy61Hf74S3LWxpDBWLMrjjQ6ILBiBNU5EA5M,1027 +torch/include/ATen/ops/sinh_ops.h,sha256=Vqxo5iQ2fjka3b6abNKsghFlJtZFDGRxRDEXV-HkBW0,2079 +torch/include/ATen/ops/size.h,sha256=hvCM8VmfSupcDgADEAAYYp6OAH0o-RVWMLjPywE3Qoc,879 +torch/include/ATen/ops/size_compositeimplicitautograd_dispatch.h,sha256=YDVhtU-Cvf84yZDS9Nff5Vlm8YikwrAzyuf2rN57BOY,862 
+torch/include/ATen/ops/size_native.h,sha256=sm3YmpyzCKupefd5_CARwXv9YyzxkaGKp0loFY28EOI,572 +torch/include/ATen/ops/size_ops.h,sha256=ichEST8ndK_xq1yTqU7WxixU2cqqcZoBvCKPNdNR20E,1599 +torch/include/ATen/ops/slice.h,sha256=AEfAzRNDV_5KnInjcTlWIaGeJxCSYzaxywg0odM96nQ,2286 +torch/include/ATen/ops/slice_backward.h,sha256=SwyTdO-eqb1XXKEVT60w5P-qUGM7A3yvuO0UH6xOasE,5547 +torch/include/ATen/ops/slice_backward_compositeexplicitautograd_dispatch.h,sha256=j0gMzryZEjUnSmpMPwKYPnzDl6Tg6ll3o5Y_QN2h_E8,1823 +torch/include/ATen/ops/slice_backward_native.h,sha256=O2vMibLyNNZWY8mwEpU1ldabeHRuM4u7g3qpBnsMNZA,797 +torch/include/ATen/ops/slice_backward_ops.h,sha256=5YjNw1DZy3FGcVoYX8llZ2T4_2B9ASyLZTnx4D8jQeE,2355 +torch/include/ATen/ops/slice_compositeexplicitautograd_dispatch.h,sha256=O5EfT3h-5LWs2WSYu6HyxqEA4gyMoV5LoDBSrgPaXAM,1105 +torch/include/ATen/ops/slice_copy.h,sha256=ortfjfYZSp7cqVj6tS7gt7njDeWoDRAVFjpyJORjalY,6163 +torch/include/ATen/ops/slice_copy_compositeexplicitautograd_dispatch.h,sha256=Qy-8D8M7aXnyhqMag6NbdNvSzmElSCc_4UZc-6caZzo,1528 +torch/include/ATen/ops/slice_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rZI0WMrk2oPm2x4VHimocrW04Gl6rEg2FrwqBXSNV9s,1141 +torch/include/ATen/ops/slice_copy_native.h,sha256=ydbr3_0fUqt7e30cN4bFBsazHXM3ftcp5CuwHyrr-gw,848 +torch/include/ATen/ops/slice_copy_ops.h,sha256=BrlQ-7Cjo2ZfjpBS3YS9t7UVgndgtYJXbi8dR--h-6E,2348 +torch/include/ATen/ops/slice_inverse.h,sha256=lt9B8Y_ZXkvNFSxfyyt0rg_NCGWY6iuu2k1nCU8xUVA,2472 +torch/include/ATen/ops/slice_inverse_compositeexplicitautograd_dispatch.h,sha256=PSLGj5I2d6wdKxxeMWYY1Zz_YAU3xkGetoq4YsydepI,1169 +torch/include/ATen/ops/slice_inverse_native.h,sha256=zwEUS0Zj_vw6IWzIB7dJFRKKu6861p9t-JeE5wR8pCw,670 +torch/include/ATen/ops/slice_inverse_ops.h,sha256=BLRaV6yCPP5SVhY3mMnIZGkGuuxMY7bbkPt-aTSWmZw,1441 +torch/include/ATen/ops/slice_native.h,sha256=xmBqiWhcTKJQG6qqYOnkKGVPAzPVPfAGRT2LOLi5Apc,619 +torch/include/ATen/ops/slice_ops.h,sha256=IGiREs9nkmrQlTVnpLSOqbQBwcdTzi0A9i8pd8vN2mE,1357 +torch/include/ATen/ops/slice_scatter.h,sha256=6JtKDtNongS4sUux0NBkW7Mj8SZdMTFdVF7sbwaBDDc,6550 +torch/include/ATen/ops/slice_scatter_compositeexplicitautograd_dispatch.h,sha256=QSYvdbAfg2nVW5MESrWdhrLq6BTV4_d1YRL0qi3Tt1M,1636 +torch/include/ATen/ops/slice_scatter_compositeexplicitautogradnonfunctional_dispatch.h,sha256=KBpi-oxq8ojk0eVgtVoKbQhGKmEyv8Olh3HVC2SauCE,1195 +torch/include/ATen/ops/slice_scatter_native.h,sha256=OLBGLeYNqmDG9WvsQ7n6ieu5TqHwU2SzHC7KngMc3WU,869 +torch/include/ATen/ops/slice_scatter_ops.h,sha256=8bA-ncnAg-td6Ov8ohsBtyotabgEtYc1108af9UHzAg,2485 +torch/include/ATen/ops/slogdet.h,sha256=Ul4i9AyHlA706hWHcI32h6w-6rF7kT3OyRE8m2NFXms,1359 +torch/include/ATen/ops/slogdet_compositeimplicitautograd_dispatch.h,sha256=9RfMj5L5LEmhPb6Sf8QeN4rvur645HGMZbeTkZMxIbs,1078 +torch/include/ATen/ops/slogdet_native.h,sha256=bC7wMsDWYbWRspVXQZ0VjfIH1wF_r5MnRVlVHMqjWpw,655 +torch/include/ATen/ops/slogdet_ops.h,sha256=KqP2ReszAuoS-5eQ4_9q1fiyuW-GNnrwlL-j2xZxXVA,1890 +torch/include/ATen/ops/slow_conv3d.h,sha256=24JyUAvISfLsysamxcMpdiPsH4_bhLX9cIQmqQ8lpTA,6712 +torch/include/ATen/ops/slow_conv3d_compositeimplicitautograd_dispatch.h,sha256=PQS3mzlMvibAJLziIsywSom2Y1HIZ9J2hQfsQp8q9ag,2209 +torch/include/ATen/ops/slow_conv3d_forward.h,sha256=a8IvEqMI-G1t4fRPr-w-c_p-vLf1buVmFKBaS8pxnug,6842 +torch/include/ATen/ops/slow_conv3d_forward_cpu_dispatch.h,sha256=IUFa_5bokTc2ab1WHRrrzM8e6foEqt-DsU9OAutI7Uk,2145 +torch/include/ATen/ops/slow_conv3d_forward_native.h,sha256=spLemcAair5xcyJoPBX1q1bkjI87Ns2WyDS8viQCHbU,907 
+torch/include/ATen/ops/slow_conv3d_forward_ops.h,sha256=8tadYVORMPGOvMNEuX9u3tL2o00gyTSeJmBSqOqrTQI,2737 +torch/include/ATen/ops/slow_conv3d_native.h,sha256=Q-HYVHv4mLEYxFLfSUia7ztkeD4pwjJBoxkCMOXtVrA,887 +torch/include/ATen/ops/slow_conv3d_ops.h,sha256=YiBmGFyBtIJJ45hB6i6ECR0MkT_m4OsvAP1o9GpTWpE,2689 +torch/include/ATen/ops/slow_conv_dilated2d.h,sha256=EbFOlMhCQjk939ZhzdwHRBvYt8jrbQO6gtDtt35yBYc,7772 +torch/include/ATen/ops/slow_conv_dilated2d_compositeexplicitautograd_dispatch.h,sha256=Cc_92NcXIMdUwkDTx4PN7bnvJkLXuMH15sDgtT7sJfw,1897 +torch/include/ATen/ops/slow_conv_dilated2d_cpu_dispatch.h,sha256=VofbzGb39PBdV0HkyHneEz-vS7o9wShuhAleZGIE7-Y,1250 +torch/include/ATen/ops/slow_conv_dilated2d_cuda_dispatch.h,sha256=KVcTtmcrF9ubVzHP7SK70BPxwyCIvL21lNEYrSsQwe4,1252 +torch/include/ATen/ops/slow_conv_dilated2d_native.h,sha256=xWC2gPUeX1flJwx56AZldl5dJs3s-qX9gMZkxCX2Zns,1239 +torch/include/ATen/ops/slow_conv_dilated2d_ops.h,sha256=wi3CQQ5qtGe5hr5svklW4fkTnYPBan684jN-QB6Oqxo,2943 +torch/include/ATen/ops/slow_conv_dilated3d.h,sha256=TzcvrGciruzRQH1xSjBDEAOCFXF5hSKXAEjQoRi3v50,7772 +torch/include/ATen/ops/slow_conv_dilated3d_compositeexplicitautograd_dispatch.h,sha256=d4Y5DRvuGb8AYJg3aUlTpJZM-dQ-SpACH8rMD2_li78,1897 +torch/include/ATen/ops/slow_conv_dilated3d_cpu_dispatch.h,sha256=yipL5yhxgEcjpSrdIbctuOj8uEMUK0D9TrLr7FZKX9Y,1250 +torch/include/ATen/ops/slow_conv_dilated3d_cuda_dispatch.h,sha256=E6QM7xm2A9lPMZT7PSOmXL2MmFpXv4JmdChN1xhRt1c,1252 +torch/include/ATen/ops/slow_conv_dilated3d_native.h,sha256=8PnLL7vHZYpmihmgNOwgdmdh-MYwQFS3GHdhly7ja28,1239 +torch/include/ATen/ops/slow_conv_dilated3d_ops.h,sha256=9egNDDv7JJqgdoNUDyZdWgZRdhIBy2Wntl0qj0aGsl8,2943 +torch/include/ATen/ops/slow_conv_transpose2d.h,sha256=Vzrm0OBJ0Nn1PtmlmjuJ7bRlqLYPKIGANFoKhzWuyHs,8826 +torch/include/ATen/ops/slow_conv_transpose2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YeTE5ZM5NY4c2tNr_1YQDz4N6G12TvtWEIcpp9YLwAM,1409 +torch/include/ATen/ops/slow_conv_transpose2d_cpu_dispatch.h,sha256=nTYWQyu19B3-2pXyWLSwHCd6FjR6_It-NP-yW6oGevM,2665 +torch/include/ATen/ops/slow_conv_transpose2d_cuda_dispatch.h,sha256=IP4LBA_wstWz27hEzkfQbDtSzae4Q5qNirYlrQN2_rk,2667 +torch/include/ATen/ops/slow_conv_transpose2d_meta.h,sha256=loVh4Fljv0hmp9nQwRBjB3gY14uvO3xWXBGsrnF5xl8,830 +torch/include/ATen/ops/slow_conv_transpose2d_meta_dispatch.h,sha256=jBljmP5hB5IEhiQ5Iqzc4lEOXIle2xM8qWZaNzAywaA,2667 +torch/include/ATen/ops/slow_conv_transpose2d_native.h,sha256=wb-4wYdq_-MKWzgUMbRl8hTK12TTz2gkGq1onlQRp9Y,1303 +torch/include/ATen/ops/slow_conv_transpose2d_ops.h,sha256=sFyNjP6uETRRHNSeyNMZoe32jikJQivYRgCjjvIGuqI,3197 +torch/include/ATen/ops/slow_conv_transpose3d.h,sha256=-v5auu777MdhSLpvoW349nULPQVfKUcDX8kHFtjZnmI,8826 +torch/include/ATen/ops/slow_conv_transpose3d_cpu_dispatch.h,sha256=Zc94BANrG5M65FWjA8CT6XoZD6bFyAOKRcBj7n-70l4,2665 +torch/include/ATen/ops/slow_conv_transpose3d_cuda_dispatch.h,sha256=xJv1IVvAf-2KcLf0_bPU5QjqXlq16zygxKvAmcVA_v8,2667 +torch/include/ATen/ops/slow_conv_transpose3d_native.h,sha256=U6wfcLGky_qVbZ-1PNCJZik8qY7ztk60m5yil0OyNBA,1630 +torch/include/ATen/ops/slow_conv_transpose3d_ops.h,sha256=ssbSD8FugZ1tjj2ePVQGtF0pAdTN5oAn54Mf0LwRDc8,3197 +torch/include/ATen/ops/smm.h,sha256=AagcbFYhnUM9ltPmAYsgMxyCNe3c7o5l3_cVsVvBDA4,700 +torch/include/ATen/ops/smm_compositeimplicitautograd_dispatch.h,sha256=H5EKA1qMBjbEedUixyhLaLo4OXfXrLVanvt-zjR4c-c,809 +torch/include/ATen/ops/smm_native.h,sha256=X3sRMNUC7N_9GqkNPr_6lq-GDVC_uyGdEEEra8Xo8PU,519 +torch/include/ATen/ops/smm_ops.h,sha256=_wa72hZcnCEDxQJ2ruIW1oH4EGiDgtA5pGm5xkiezT0,1060 
+torch/include/ATen/ops/smooth_l1_loss.h,sha256=xerWzwudEKF9K3NZCpppMLIhy5XpWDbX2df4LsWfEZg,1600 +torch/include/ATen/ops/smooth_l1_loss_backward.h,sha256=WMLf5Eg1HYbLQhyhVvrDnIRkgcy3dSng_yFsgTOdoz8,1880 +torch/include/ATen/ops/smooth_l1_loss_backward_compositeexplicitautograd_dispatch.h,sha256=b3SH5E_yTpekwRVXxlVh26pNpe-krvSb5lORv52RByQ,895 +torch/include/ATen/ops/smooth_l1_loss_backward_cpu_dispatch.h,sha256=4No7vocgq616BUYLkXPGzr9zvNI8csE3mcvEARGd3pU,1077 +torch/include/ATen/ops/smooth_l1_loss_backward_cuda_dispatch.h,sha256=oBSQFUP4fDG84-NAh_AVAYWsqITscYvHnMAf8NbwL8E,1079 +torch/include/ATen/ops/smooth_l1_loss_backward_native.h,sha256=aU2c3TwtYMAWsRuG29ODktwPhHqSku4ViXAl6TPJuDk,799 +torch/include/ATen/ops/smooth_l1_loss_backward_ops.h,sha256=VcgnmKziN5Gr-3ppculjoJmYKh95F8mYTqxAyeVzpb0,2333 +torch/include/ATen/ops/smooth_l1_loss_compositeexplicitautogradnonfunctional_dispatch.h,sha256=X54sGX8vSXKvCfFu1F0IGj1oh0P5BQl7DegrJh16WSE,904 +torch/include/ATen/ops/smooth_l1_loss_cpu_dispatch.h,sha256=RyZvulnhfDQ7jPb0kfjXwG3lajx8_FYBSLCQdXLRx7k,1151 +torch/include/ATen/ops/smooth_l1_loss_cuda_dispatch.h,sha256=zfYOK9mcV6jtT1DtMozXHWXnptnkK6T7OAog64VHHug,1153 +torch/include/ATen/ops/smooth_l1_loss_meta.h,sha256=RLs_xJLvhKTQKSx6pjOYtMtGTOmI4XHvYyBzMWVwcu4,661 +torch/include/ATen/ops/smooth_l1_loss_meta_dispatch.h,sha256=sROiG49UogJcWLDhqYJiEOF_kwBT2aiIDkqDUgKIwJ8,1153 +torch/include/ATen/ops/smooth_l1_loss_native.h,sha256=m7eSQwL9MC38CtXzJzo3b0fxS0HqNb9olQ1gr3Nd038,702 +torch/include/ATen/ops/smooth_l1_loss_ops.h,sha256=5Z8BhylSNh6rxB_Dkyrg-mSiExvqDOzqEGj1k9_D5GY,2047 +torch/include/ATen/ops/soft_margin_loss.h,sha256=fnqB011T6L4HE7aQ7XvVnAJg0FRHxbFe9tdc40opMa8,1507 +torch/include/ATen/ops/soft_margin_loss_backward.h,sha256=y1xDwCjsNb9zFemwq5arIjlthxVcP9Oz6U-iKjgJ1_w,1807 +torch/include/ATen/ops/soft_margin_loss_backward_compositeexplicitautograd_dispatch.h,sha256=Xinuqb9JVo3ulOPyB3nk5aaG_pHSj6t5NHPE9_IXdDw,1251 +torch/include/ATen/ops/soft_margin_loss_backward_native.h,sha256=Dpb1OwNn_vPakDVrzod5ZLI3UmCbL8mW3PFpTM4iUx0,777 +torch/include/ATen/ops/soft_margin_loss_backward_ops.h,sha256=rlTAYXmH5f5rJagyaVgQU1Z-joOAiq_yvjwDpl41d0M,2253 +torch/include/ATen/ops/soft_margin_loss_compositeexplicitautograd_dispatch.h,sha256=28lPz-uwng0aUtuC400EL_Uqkr__iKoWQskPbB6elu0,1154 +torch/include/ATen/ops/soft_margin_loss_native.h,sha256=Gdp3hAZ_AFZMKH2yXWEZNOjY98elyl7P4C1FUDXpmaA,708 +torch/include/ATen/ops/soft_margin_loss_ops.h,sha256=g245K3T6iRrf6KJFfLxc4LaU-agvrNnHP4L_CqOrntA,1959 +torch/include/ATen/ops/softmax.h,sha256=Z4eC8BIgl9PWQvZFQvLcD-U8E8saDmweULG8mK6XiWw,1704 +torch/include/ATen/ops/softmax_compositeexplicitautograd_dispatch.h,sha256=8UfWBlYHfb7dLmyZs4sintXvpcRPzK4N4Isr-SwNB2o,1012 +torch/include/ATen/ops/softmax_compositeimplicitautograd_dispatch.h,sha256=KEtFpx3k7Zc0a2u2Eev9Bxtvc15kASVwecQn64Q-yCE,982 +torch/include/ATen/ops/softmax_native.h,sha256=uX5jDx8OCmZ5bJKWVagCTO3iJ629umG1L3zG7rRgDgo,824 +torch/include/ATen/ops/softmax_ops.h,sha256=m8JmCl-QE9cz06drBOR7C2ZWQKLHVwXVrxBYzW4YwL4,2709 +torch/include/ATen/ops/softplus.h,sha256=w97c5DlEizHunkiWHWmOk7avA2uHPgIcjClk4KIRHJY,1421 +torch/include/ATen/ops/softplus_backward.h,sha256=X5DuBlr6Ib1GI73rXzz48GejC5e7l8DM9hjqIOJ0YPk,1751 +torch/include/ATen/ops/softplus_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=jHCzG-UIFlWHsnobwepMox8pesK8XTCKfMqWDjC2dPE,911 +torch/include/ATen/ops/softplus_backward_cpu_dispatch.h,sha256=flEvb3ak-jp93nUW_0mg89xaAT8yba8FgaXhnq7MYXg,1210 
+torch/include/ATen/ops/softplus_backward_cuda_dispatch.h,sha256=YX7lbRJtljm91oqPRkl6j7lc93jX6iHx24zF6lrFRTM,1212 +torch/include/ATen/ops/softplus_backward_meta.h,sha256=bwk0kmuHGFlzp6zX75os4zHeIalUCmqKg8kZ2xSHEEc,692 +torch/include/ATen/ops/softplus_backward_meta_dispatch.h,sha256=gAlrKDXX6Fr3vmF4JZoZJ9hCZyYOSIek5qqaFRBYVqI,1212 +torch/include/ATen/ops/softplus_backward_native.h,sha256=9csbwF9HvC9LLY4PsnyrJcj0Rh1Oc1R4sVbW704NjsM,746 +torch/include/ATen/ops/softplus_backward_ops.h,sha256=mOwWUcAbseIDu4VUeNyimCV0XalrlkveFD1hlneAHxE,2265 +torch/include/ATen/ops/softplus_compositeexplicitautogradnonfunctional_dispatch.h,sha256=7u51NFc37ioyX4zXVBR4vLkwlyj9VZ7NgG1yZ_vsCW8,875 +torch/include/ATen/ops/softplus_cpu_dispatch.h,sha256=mjRB1usA9HrtTjvtjaWwukir8iaBj0BzzPz-C50sl0o,1083 +torch/include/ATen/ops/softplus_cuda_dispatch.h,sha256=keuVZOtBV0IuFlQt_ArXUWkIbPxJWWkBAaZcIVpHWHM,1085 +torch/include/ATen/ops/softplus_meta.h,sha256=FTZzF3cZPfv1DSKiErpp8ByeFZ_O8ozyuQHqCSXsetI,651 +torch/include/ATen/ops/softplus_meta_dispatch.h,sha256=o4j_A5uOrpGMHGl5cN_LfpEkDnBZSaNpXl8_SbjXWHs,1085 +torch/include/ATen/ops/softplus_native.h,sha256=7BWh4bpfP4z5fuOVRVi6Sf-9DDiGonUhQYogbsYJsGc,680 +torch/include/ATen/ops/softplus_ops.h,sha256=g44_SdOKJ6MuzLyl8t5ff0ZjA0y0UUoL2O28Ut7njeQ,1971 +torch/include/ATen/ops/softshrink.h,sha256=N6FErBhJthGY8WD4vv9Tdy5IBc-jfLAA1TfUPZyU-CA,1268 +torch/include/ATen/ops/softshrink_backward.h,sha256=MjKlSoHigutwyErBEF1zmkNf6lW8PAxHOEY_N2uIbx0,1603 +torch/include/ATen/ops/softshrink_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=qvwqbb4gIsgK0qs3T3V2_SYrRBU01iqZLVS34LmyU8I,884 +torch/include/ATen/ops/softshrink_backward_cpu_dispatch.h,sha256=SpdisQ-EgzR-mhQa9voGZ6whUUefbuEjgy3YSZPAKig,1129 +torch/include/ATen/ops/softshrink_backward_cuda_dispatch.h,sha256=aJz9quoJ8gfWRGaz8-PoI3DErKei2zH6LN0PKml2JnI,1131 +torch/include/ATen/ops/softshrink_backward_meta.h,sha256=hMtLiGU9sj5tl18M71C_BYVtZhxiVPqewFZsx7o4seI,665 +torch/include/ATen/ops/softshrink_backward_meta_dispatch.h,sha256=HWpwGhkX59-TTCTd5iBkX_b0MZFddRhuzsg9D_pNKiQ,1131 +torch/include/ATen/ops/softshrink_backward_native.h,sha256=eP2Qj6k68V5kb-1wuFaVF2VDXB5uqht9jFR7Sm41B3I,723 +torch/include/ATen/ops/softshrink_backward_ops.h,sha256=_srQSYKIhbpEgOi5MLdwcnXNLNt9m2-HbrOjmdhwlRo,2087 +torch/include/ATen/ops/softshrink_compositeexplicitautogradnonfunctional_dispatch.h,sha256=sSyizCuwigPqlHFqUb-Dr_8a5Jh8acuNQFJXOswmdg4,847 +torch/include/ATen/ops/softshrink_cpu_dispatch.h,sha256=t1c9PxSa5Vbgvs-906ccOaP5x2TCOATOy1g44R3nv1Q,1000 +torch/include/ATen/ops/softshrink_cuda_dispatch.h,sha256=e9Ocsm1uFQGFOS4cg--pxZmW776TYDHBr8v0iHH-Eh8,1002 +torch/include/ATen/ops/softshrink_meta.h,sha256=Vhw2mdOAcDQasoldo1_0tjHvUTkRCakA7b70yYVhuaE,624 +torch/include/ATen/ops/softshrink_meta_dispatch.h,sha256=3RnwuCWIdVap7lDkeb-ztwvxSHcEPhVTKR5RI1js0wI,1002 +torch/include/ATen/ops/softshrink_native.h,sha256=lixXU9Nc7m67OwFrCyrINjkGYOiA8-AqSREa10NwFCI,657 +torch/include/ATen/ops/softshrink_ops.h,sha256=jZQxrRDfWtOHhr0EaqYrhJpux2opU0c_UBgh6GsfUjg,1791 +torch/include/ATen/ops/sort.h,sha256=QXVmzCCGi34sWyBILerDw6ub29KX3RNpqrkqsWgOTJA,5358 +torch/include/ATen/ops/sort_compositeexplicitautograd_dispatch.h,sha256=Gh3eGud158L35fxOgsG5b3_yv-IdT9nWfhvGoKx_1UA,1177 +torch/include/ATen/ops/sort_compositeexplicitautogradnonfunctional_dispatch.h,sha256=FVStsb7vcurrACX1RvS8HMhEnJ1nkJomgeB_-NI35UE,905 +torch/include/ATen/ops/sort_compositeimplicitautograd_dispatch.h,sha256=oRroBHbJbXYkvO4mHH-B3nmPoNYw-UnTDQ3drgtPn3o,1724 
+torch/include/ATen/ops/sort_cpu_dispatch.h,sha256=GL-8gcTCMa1m638VkmXBBqRyhTqAI9iZVFzfgdu4xSo,1223 +torch/include/ATen/ops/sort_cuda_dispatch.h,sha256=cg411idtXI-n27amYelXTvUzgkGvXhcy8e0Wl8y_VP8,1225 +torch/include/ATen/ops/sort_meta.h,sha256=ifRf2Owtg0OCDCuqXt8jmwKnlnd1BTmRb3peVsIq_nU,659 +torch/include/ATen/ops/sort_meta_dispatch.h,sha256=dKhrqY6LpOnfTVkTs1DC_pvu-GrHYQNU3fiu1ZDgzDo,1225 +torch/include/ATen/ops/sort_native.h,sha256=mxxSRs2fM0mWNRmZQfbPWNn3eaUH3lqSxwzK7BX2u0A,1784 +torch/include/ATen/ops/sort_ops.h,sha256=MIhcI3IdbsikdWJx1wl7foAot3c94LJAkJEF4MXYuy8,7602 +torch/include/ATen/ops/sparse_bsc_tensor.h,sha256=6axKaOosp6ZOEgVw8vDZUHF6ii_XPZ29uBoCibNWU4M,3037 +torch/include/ATen/ops/sparse_bsc_tensor_compositeimplicitautograd_dispatch.h,sha256=k6szkAAeacNf7EFJt0dkok6PKv74Cx1JECE7SEdmzbc,1652 +torch/include/ATen/ops/sparse_bsc_tensor_native.h,sha256=Ab4sFQ0JZAf9e_lOZxWrW1wBix3pv82L5KHlmJM4oIg,1044 +torch/include/ATen/ops/sparse_bsc_tensor_ops.h,sha256=CCOYur2nUNpGYNcuP5CBLAv_ytNp7VFQxSxsNz3Inis,3167 +torch/include/ATen/ops/sparse_bsr_tensor.h,sha256=dTMdY1AmiIXDj7hOGwXH4laGa8NQAG8d6s3JayPpPWg,3037 +torch/include/ATen/ops/sparse_bsr_tensor_compositeimplicitautograd_dispatch.h,sha256=_F57OhMbWDpHu-0y2Aph8Lp2GvvUzTYEIzvgrFfBldE,1652 +torch/include/ATen/ops/sparse_bsr_tensor_native.h,sha256=e54re-cPqdPp-ffT--__hKtnjnYZ1mKoIr2PJ6Pb9qI,1044 +torch/include/ATen/ops/sparse_bsr_tensor_ops.h,sha256=MB985n5KyT4EczncsxneqrZAS-koJuNDlfAeVOCV8sQ,3167 +torch/include/ATen/ops/sparse_compressed_tensor.h,sha256=_Z4YodHPQJuELUlNmAgEWIATtBJ9jf7bH1VD_uR6Kcc,7049 +torch/include/ATen/ops/sparse_compressed_tensor_compositeexplicitautograd_dispatch.h,sha256=uLXM6vGy_5eF19vGw-YWGeU2T40Oatau2rfvp7M2r94,2246 +torch/include/ATen/ops/sparse_compressed_tensor_native.h,sha256=_ZXHUozQlcaxFNiZL0RHYpkffZ6NnfUjXRtjvws-3e8,1074 +torch/include/ATen/ops/sparse_compressed_tensor_ops.h,sha256=nNAuE47kaQ9rvsyLyenbuXg4Oh4kzpD0xoTeU9utW8k,3284 +torch/include/ATen/ops/sparse_coo_tensor.h,sha256=WGa8XNP1faQss7rNv5gSYNGMhJ7dkhP22wIN7ufGMxw,4343 +torch/include/ATen/ops/sparse_coo_tensor_compositeexplicitautograd_dispatch.h,sha256=26Ha9SgVxIS8HTlcg2Ajv8yBmZrGJ4HjD5o6Xc0Fkus,1205 +torch/include/ATen/ops/sparse_coo_tensor_compositeimplicitautograd_dispatch.h,sha256=RMrVUPggBbUBB9dTi_a8wMgmNvaGrQiYmFRuhmz13Xs,1684 +torch/include/ATen/ops/sparse_coo_tensor_native.h,sha256=RGwaopQdIyAbLPvPcS8wcoOfF3ix9Bsa-e6Jm-XXkHA,1384 +torch/include/ATen/ops/sparse_coo_tensor_ops.h,sha256=IbwrKjj9tZBP7sY8wWicIuMm2-MTgLJVkrzWdLXllDM,4809 +torch/include/ATen/ops/sparse_csc_tensor.h,sha256=MH9r1v1QNQSUq8W4o8YuyQDys6PmIdW5_xfKeP48CWM,3037 +torch/include/ATen/ops/sparse_csc_tensor_compositeimplicitautograd_dispatch.h,sha256=huZ-todArSn7pWDWA2HqDJmYcEeDiLkS-bnwYSoXXJo,1652 +torch/include/ATen/ops/sparse_csc_tensor_native.h,sha256=TwUnFfhw9pbWH2kqCuQu6DEp9ERdBpKY03q3Rneur2Y,1044 +torch/include/ATen/ops/sparse_csc_tensor_ops.h,sha256=BGZY69tCl9NY7erbSP9cOPvC4LDPzJHfdHQjvyYoGeE,3167 +torch/include/ATen/ops/sparse_csr_tensor.h,sha256=iEZgSh3rve3ZKqahiqKebgktAXAScDmXQqdEm2coGO0,3037 +torch/include/ATen/ops/sparse_csr_tensor_compositeimplicitautograd_dispatch.h,sha256=Htnq-HeY5uBngJjkX4NdXCEpT9Vi06_moQmB7CZhQBc,1652 +torch/include/ATen/ops/sparse_csr_tensor_native.h,sha256=4gh5eut-ZGmIE3LR0or2DKYQccM1kVcN9WipijqCUkk,1044 +torch/include/ATen/ops/sparse_csr_tensor_ops.h,sha256=L4sx5w0tvy7SXAQ91G-Ei7i8Gase3KCLQXYCdYHHAkU,3167 +torch/include/ATen/ops/sparse_dim.h,sha256=jyBnnzxa_tF3ZJHEUu-3NkNGUAwz5AmIEEqKZLidcfE,534 
+torch/include/ATen/ops/sparse_dim_compositeexplicitautograd_dispatch.h,sha256=rE6VSXwXV0AIc9C3HhWItRBL-k5fWWKWULj0pB8_eiY,788 +torch/include/ATen/ops/sparse_dim_native.h,sha256=Okb_oc5qO5wOEO01lJ0XEpDfLL0ixKtOxXWSI63ZsfU,636 +torch/include/ATen/ops/sparse_dim_ops.h,sha256=gDUhavGMKzNC0wqOgzVoekJIUwELRA0Bqt8g6ya9Q1w,986 +torch/include/ATen/ops/sparse_mask.h,sha256=28c7vi9pBMe82Lu2f8_2OpxVbalIZor-gnNxpYG6XDU,1050 +torch/include/ATen/ops/sparse_mask_compositeexplicitautograd_dispatch.h,sha256=U9BVzA6zY1-tnHHcoCBl_Ul211q91SmvTx0fYXf_YB0,951 +torch/include/ATen/ops/sparse_mask_native.h,sha256=6Da3NC7j12rb9DehbNQWlBxVHamp3jFpzJBphUKDOS4,739 +torch/include/ATen/ops/sparse_mask_ops.h,sha256=w9hPMw-KUKf1fJV_7UyNI_qw9jAUmcfHI8k4QXuBkA0,1783 +torch/include/ATen/ops/sparse_resize.h,sha256=ECNGdQ1uMNs-gwXynTobQmUT2JP3olRhNVEMF1LOZow,1560 +torch/include/ATen/ops/sparse_resize_and_clear.h,sha256=sa6ZuHnQgx4Uk_nazBqfrH-_Av0zPKzMCPpuJSy5BQU,1660 +torch/include/ATen/ops/sparse_resize_and_clear_compositeexplicitautograd_dispatch.h,sha256=3ob1S4Tq9BvQWe2tgOeDeZM-qGtnOEsgE_CimNYYE98,1204 +torch/include/ATen/ops/sparse_resize_and_clear_meta_dispatch.h,sha256=YOV7FGes2j0XepBqVR5K9r_DaN0O1NwG0owK2xcWmwY,832 +torch/include/ATen/ops/sparse_resize_and_clear_native.h,sha256=GmjbvteX71JZiLOSeIXn5uJXOQT13ZnX6Sj_fZRceV0,886 +torch/include/ATen/ops/sparse_resize_and_clear_ops.h,sha256=7gMqv8U8GPP3K0EuzeDl3OCCoe8rPENwpAPkq7kGDQc,2922 +torch/include/ATen/ops/sparse_resize_compositeexplicitautograd_dispatch.h,sha256=9FvFIfnjWmw_YrRSo5_HKLCR58Ywu4IOj3ObpGCIuMw,1174 +torch/include/ATen/ops/sparse_resize_meta_dispatch.h,sha256=xkVCZV_WdTtdTL7PQ7G29xhq3IeQ6BW7ndFyGOrCVVI,822 +torch/include/ATen/ops/sparse_resize_native.h,sha256=v6GjBwNnVH4aUuQwLfXBpqKJz-ZXFNKUnZkt0XlNkDE,856 +torch/include/ATen/ops/sparse_resize_ops.h,sha256=sfTcLQF58q9c7vZu7f5xrvyMOV6HI1YYhF7o85X_934,2832 +torch/include/ATen/ops/sparse_sampled_addmm.h,sha256=vNu0tZvAJDSqRL8zUzN-h_i0jHyniAldy8RHmLpzOs0,1767 +torch/include/ATen/ops/sparse_sampled_addmm_native.h,sha256=7mWdupa07zj6MET0nH2bfsctwC7q4VAsvQF13avjcEY,1240 +torch/include/ATen/ops/sparse_sampled_addmm_ops.h,sha256=HXric2NNEbZmeaUdVgJzzC_R0QBf-KVgF8ITYV82egk,2352 +torch/include/ATen/ops/special_airy_ai.h,sha256=orQkoPBQzQBcQgzetLWUiwy6OGS5eFnmOKEyp_BS654,1130 +torch/include/ATen/ops/special_airy_ai_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Amv4wP2VqL4D8ot35ScjxKmabAK7mB40c7sYntF1Hb8,819 +torch/include/ATen/ops/special_airy_ai_cpu_dispatch.h,sha256=KwxeOB29ZYH6vqVZwfRoR4pCQbHIq3RNzGPimOJXW6s,920 +torch/include/ATen/ops/special_airy_ai_cuda_dispatch.h,sha256=-ZXm68gsSx9xJY7yiHocwk_qXaVPnK52ubTCNaN8l2c,922 +torch/include/ATen/ops/special_airy_ai_meta.h,sha256=_lcMVTHExBgWfkZhTScRQIG6aRZvurCfwgZV7kXOUQg,600 +torch/include/ATen/ops/special_airy_ai_meta_dispatch.h,sha256=CZ4oNieL1263yueO04_ZuP5pTJd0k28LTaoU0B4QRIw,922 +torch/include/ATen/ops/special_airy_ai_native.h,sha256=hT1nZd_FCAE197mfoatN8fcxCPVe3eLv5KaX8IA8Sl8,643 +torch/include/ATen/ops/special_airy_ai_ops.h,sha256=6Bpj5cBq4Phr8Z5JTnXEvjNebTAAUQui3sIHibl1oyU,1623 +torch/include/ATen/ops/special_bessel_j0.h,sha256=vP6pAsQ5e6tOu7c_68NOrdOfMe8z3WfzitJxLaVBEwc,1177 +torch/include/ATen/ops/special_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=xtwonixHyzG6Pjd6avl17DEerDTXuWWOF_EZ5k__LB0,824 +torch/include/ATen/ops/special_bessel_j0_cpu_dispatch.h,sha256=i2sqMzzra49LRN7tTSdNZzLvDduc1r1asl9DEtZJC1E,935 +torch/include/ATen/ops/special_bessel_j0_cuda_dispatch.h,sha256=EvqgP8ENPN42-IpfjZgLL2y9-igY7XnRuxirLPY1g7c,937 
+torch/include/ATen/ops/special_bessel_j0_meta.h,sha256=7TZ5Mnyx_c9vVWqwSNLdcDXb__UV3kxPiFzSO8u4uCo,605 +torch/include/ATen/ops/special_bessel_j0_meta_dispatch.h,sha256=QSTrxKsnimBJsabuoGlzu207x4ryQ5APvfnJrqI_NSU,937 +torch/include/ATen/ops/special_bessel_j0_native.h,sha256=t5a9BoPw66sMt4j37jR-JsfA6igh33wIQ8t7AoM8Zu8,652 +torch/include/ATen/ops/special_bessel_j0_ops.h,sha256=rKtmIXY9c_sCZUvBmCW-slU7LmmLUGq6tNABz5GngvE,1653 +torch/include/ATen/ops/special_bessel_j1.h,sha256=6V82FYfal801SFoJ96FELcWnbhilTX_Dp7L57BLqmmo,1177 +torch/include/ATen/ops/special_bessel_j1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Vy0maUtz0KesRWYxuGXeNSlq8qDWEOOOvvXKgh7XeiE,824 +torch/include/ATen/ops/special_bessel_j1_cpu_dispatch.h,sha256=l3E-Ag7Veq4a11u-zO7-QW_hmm44rsNj3Y3WPyUxN6M,935 +torch/include/ATen/ops/special_bessel_j1_cuda_dispatch.h,sha256=QmmkgBNWB7M0YI3lO-13i5FGv5fNSBXK93Dvpcyvgp0,937 +torch/include/ATen/ops/special_bessel_j1_meta.h,sha256=t7_n3f1yBTNt-mkGuZzJpEzxsv33_-oXEYvEAZCrma0,605 +torch/include/ATen/ops/special_bessel_j1_meta_dispatch.h,sha256=6UIQkYz3b_AIlva5gHawBpfHwlPWLGZflcqXLKRkd7g,937 +torch/include/ATen/ops/special_bessel_j1_native.h,sha256=sfHT2IwWeK08NwpC2AFQWHLVyV0zV5XDJ76Z01-Hffk,652 +torch/include/ATen/ops/special_bessel_j1_ops.h,sha256=BXikotY35RN7zzlwjIS1grukhP35kOvFCMplswgyA5c,1653 +torch/include/ATen/ops/special_bessel_y0.h,sha256=Y36oKzLN6a8DZ9imZPxJjM-Ta3kuIyDoEDJ2mSVjPPM,1177 +torch/include/ATen/ops/special_bessel_y0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=iV-ALV78i7IRtSAb5a1oLX759nQRWQ7rnWcoYYemiTA,824 +torch/include/ATen/ops/special_bessel_y0_cpu_dispatch.h,sha256=ft0jCJCL8kgE9t9EpGDfBhLobsSq9Apdn78tKcHerKs,935 +torch/include/ATen/ops/special_bessel_y0_cuda_dispatch.h,sha256=gNcFocQeaOYzJlQvNaZv4IuC3sZWUpD6Em-60idytL4,937 +torch/include/ATen/ops/special_bessel_y0_meta.h,sha256=v0dtT9MdkW48CPDWSok4KdAp1hLxKQQMikKw7e7fVxc,605 +torch/include/ATen/ops/special_bessel_y0_meta_dispatch.h,sha256=lbD9DODjgu0isIwRbgSPOscwl9fgXB92HxgPrlr01y4,937 +torch/include/ATen/ops/special_bessel_y0_native.h,sha256=IkP9vHRTK1ggx6AKBSUop6siOHqFcAUzsBTjcj5pgO0,652 +torch/include/ATen/ops/special_bessel_y0_ops.h,sha256=DfEJNWF8gG5Vo2SGbcvEZSU6ZnrubdFoRX7HuC9nXrc,1653 +torch/include/ATen/ops/special_bessel_y1.h,sha256=asI_GQJl6N_NivQoXEhdFB6akB7lwTh6QGGUnd4DMh0,1177 +torch/include/ATen/ops/special_bessel_y1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=POr8He4--UU9eveTGUjZ3GEtIy0nTRr-SPTtm31x6ug,824 +torch/include/ATen/ops/special_bessel_y1_cpu_dispatch.h,sha256=iWuuuNJOjKZwTygVqAN6PZyF_XzPV-noHl5VJP-UKs4,935 +torch/include/ATen/ops/special_bessel_y1_cuda_dispatch.h,sha256=oL8g-OIfR0iT-FUXksiTzEGFtVgt4CsmGrXHvQUmq4M,937 +torch/include/ATen/ops/special_bessel_y1_meta.h,sha256=CTKS9qCb47ttro4TyJL12LXt5SFtEFsaECaiv1Zm5Qg,605 +torch/include/ATen/ops/special_bessel_y1_meta_dispatch.h,sha256=Jc30zvSKbpIAuAFWgFEJk53w6h1IPtN7QGgwxT2o5KI,937 +torch/include/ATen/ops/special_bessel_y1_native.h,sha256=U4lJC9awq20iqtl2gFSm-gfmUcdC4MPIzYbZ8V7NxKg,652 +torch/include/ATen/ops/special_bessel_y1_ops.h,sha256=HgKNJ-fCNyZm_D45EkNzQA7WdJ0Enrf6CKguaKcom9c,1653 +torch/include/ATen/ops/special_chebyshev_polynomial_t.h,sha256=1NNjxMYrekG4YziIrfrW2lPmCy1uZoQO1ruJotFNmUw,3159 +torch/include/ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h,sha256=o5NQlgET_c6mio0oc0qmWCULcI0L7X067bDf_L8CVgg,1418 +torch/include/ATen/ops/special_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h,sha256=d_lN-f_nTVmZX8zQilY0vf4BZPMqezFRkVjmjjBbp10,856 
+torch/include/ATen/ops/special_chebyshev_polynomial_t_cpu_dispatch.h,sha256=TEPuTZj8iShv5Ahxl0g7sxSXGR8x5AoUhJcoMsUwvgo,1031 +torch/include/ATen/ops/special_chebyshev_polynomial_t_cuda_dispatch.h,sha256=g5tgqGci_VoX8lb229tsYUGUhUve5Q0Z3m1jfldjsdA,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_t_meta.h,sha256=98-H5vlcAcAVaTO78N9Akf4uvic9dhYfIn6AI-2oinE,637 +torch/include/ATen/ops/special_chebyshev_polynomial_t_meta_dispatch.h,sha256=ksziK7OPv9cZHpWAwazGiNiD4NnYzcleiSN6FXn-ZpU,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_t_native.h,sha256=QVYILjyggytVpvTcMVU6FKVdook3Uh4hLkjCYWJPklQ,1150 +torch/include/ATen/ops/special_chebyshev_polynomial_t_ops.h,sha256=h3KcoF2BBtCJOvsPJvYfvieR6J5VLacj9K-1L5bGeC0,4739 +torch/include/ATen/ops/special_chebyshev_polynomial_u.h,sha256=U_dc5gNXfOcWdBHvjv9DKl6ESw6gcFN2CXxEmmiNgto,3159 +torch/include/ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h,sha256=I9PNS9nuS1-W3l8DNVWTkhNX4dCazib4oAP8U8dVbSA,1418 +torch/include/ATen/ops/special_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h,sha256=LqVAqo6TpjFk2c4iRRDkrQkFfhgmfens_1zqnRSVHwA,856 +torch/include/ATen/ops/special_chebyshev_polynomial_u_cpu_dispatch.h,sha256=achPgrbVjEnBw1_affjDGyw30BUyBE__Daa73TunGb8,1031 +torch/include/ATen/ops/special_chebyshev_polynomial_u_cuda_dispatch.h,sha256=IcRaLmuetBq9N6wNqm-aXGT-Us87SSpa83x1OxRxzFs,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_u_meta.h,sha256=ViFidwCDQhaPqRDV_3rWny3KKzlbksroRZXukTJhJUw,637 +torch/include/ATen/ops/special_chebyshev_polynomial_u_meta_dispatch.h,sha256=aH6JAMBnTxSi9sjBRJo9emWN8gIOPK696OF1PC0AFlE,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_u_native.h,sha256=MbzoQ8vZUphAA7pm0CBJOeAf3FJsM6VTMVGImj9ArC8,1150 +torch/include/ATen/ops/special_chebyshev_polynomial_u_ops.h,sha256=HvHoMiVcf_NAjmmoooq6Kl9dfg6vY7SwWa_UQSSKFQo,4739 +torch/include/ATen/ops/special_chebyshev_polynomial_v.h,sha256=Ru0Kc9WzT-bp8UG0d_Y--z-Ml_qvVxp-l9K6zPrEKV0,3159 +torch/include/ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h,sha256=k86OAeY4aHkvfLH4pcDcoK8f03gY2usfhO8vDHfEqBI,1418 +torch/include/ATen/ops/special_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h,sha256=lzrlZRdDCBIWmiQinwZCVWA1XAPpS0XvDiwBV6Ka1kU,856 +torch/include/ATen/ops/special_chebyshev_polynomial_v_cpu_dispatch.h,sha256=pdJqsvuGG3nl4gTYJ1fKJ6AA_WTVV2e5uqppcmcU6fM,1031 +torch/include/ATen/ops/special_chebyshev_polynomial_v_cuda_dispatch.h,sha256=nfWcXOOs1thZTlcZmPpxBvIUgT0Tb9tQD2VYko_J5KA,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_v_meta.h,sha256=Xin7rndrNUcJ-KleWksVLvhJjkCgfosJrUxDaArdrRE,637 +torch/include/ATen/ops/special_chebyshev_polynomial_v_meta_dispatch.h,sha256=0qpJS5xKqFV5aNRk5GfiMLprf5N_YehO--FDmnx4HPs,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_v_native.h,sha256=2gzib64fH_Pzap68Vf5zUalFN58pFG_gDR4Tyk5KXSo,1150 +torch/include/ATen/ops/special_chebyshev_polynomial_v_ops.h,sha256=-VFGoJ1HUsXjaUU_mFLMkN7jO9kTokwPpWt7l6ZhFAM,4739 +torch/include/ATen/ops/special_chebyshev_polynomial_w.h,sha256=dwuuWiLQ-8oQ7H8356IobLKrD7-4zWciL-143xNtk84,3159 +torch/include/ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h,sha256=2UbQSITM63dPybbqOKHES24vpuuYyBVvoTHy73NG52c,1418 +torch/include/ATen/ops/special_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OYEwJdFmbajf_uuCFtFpH6bq1VYM-WG3_MwO5YvCC80,856 
+torch/include/ATen/ops/special_chebyshev_polynomial_w_cpu_dispatch.h,sha256=cL1cea3menyNfAAzTVSEjS-6ozNE8wf5FlAbyWux4QI,1031 +torch/include/ATen/ops/special_chebyshev_polynomial_w_cuda_dispatch.h,sha256=Uu0TNIyZDUGlLSz_-OWsjq-pMsUbk9MWtKhDChVdR9E,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_w_meta.h,sha256=QftLrZmqHXOYLtD8IzkjKwsKQtELauQayAGw-r0QnmQ,637 +torch/include/ATen/ops/special_chebyshev_polynomial_w_meta_dispatch.h,sha256=0apF3r7baWfNbB4Da6OboDmGqGlyyomgZx470fWrmmI,1033 +torch/include/ATen/ops/special_chebyshev_polynomial_w_native.h,sha256=G2CxrUOD2bB24c-V4D-QC3O51dDLOhI-e3E0BZ9o9Ew,1150 +torch/include/ATen/ops/special_chebyshev_polynomial_w_ops.h,sha256=xWdtzXPHfM2AOMg0HTWcUsTkwnlxNN6DuRrL_zJbY8E,4739 +torch/include/ATen/ops/special_digamma.h,sha256=cAXjhSveoT_LQweaY0jZlwRw8aO3xOUkyUfpwKhXSt0,1157 +torch/include/ATen/ops/special_digamma_compositeimplicitautograd_dispatch.h,sha256=03ESh2gs2HmvVUySN3E62x9_lMG5OQMEMtJ21phBUrA,973 +torch/include/ATen/ops/special_digamma_native.h,sha256=1BD2enGwP52YPckpemOSbSsB6Yzf8Ha3WcjQsyfX34k,594 +torch/include/ATen/ops/special_digamma_ops.h,sha256=_IX942umsFUSJVhVJCziPTVnuxl4QZo-EpCub3ZRlns,1641 +torch/include/ATen/ops/special_entr.h,sha256=SjtBzFjMthKG8ADas98OlLbzce2V-mRb8ZqLS5Q_HBc,1127 +torch/include/ATen/ops/special_entr_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Oy-xOmQ3j37y5j3SFONqL7JrUhvUb6tLbDrBAa2l34w,819 +torch/include/ATen/ops/special_entr_cpu_dispatch.h,sha256=yDKFRIOIyB-wml14UzDgGnywrDH0Myr4-0bcrIGGp3E,920 +torch/include/ATen/ops/special_entr_cuda_dispatch.h,sha256=wkfA777S2tPmi3Nanqn_qkcBCp7Ha6wYEWmRNunRFK4,922 +torch/include/ATen/ops/special_entr_meta.h,sha256=YPS3xJA9nkOO7hUy-1PEzDeDf3Gb9qn0WdH3ouR9Nrg,600 +torch/include/ATen/ops/special_entr_meta_dispatch.h,sha256=WLjAIUIRyR6L_1UMu3jTcOLAbgiEDl_ZsiXyXoWZ9p0,922 +torch/include/ATen/ops/special_entr_native.h,sha256=YUecCsy0_l-BjTv_c6rN8Vsum6tenDcVufEV2Y1br74,637 +torch/include/ATen/ops/special_entr_ops.h,sha256=DVtDAhqcp2miNNiO2m-47MHRJRNd20fXFT9BRHUSMNo,1623 +torch/include/ATen/ops/special_erf.h,sha256=0i12LBzM9e5A9vylF1pvcVJY2TRCmWnZ9W2s5heE7QI,1117 +torch/include/ATen/ops/special_erf_compositeimplicitautograd_dispatch.h,sha256=QVpsjLg7s4fQkssVX8_UWYfSXTXhojFaH5E9_c6FT5A,961 +torch/include/ATen/ops/special_erf_native.h,sha256=3tfd9NCihuuY-tybbetazkayt3BbhqVzuYcG0VW49CA,586 +torch/include/ATen/ops/special_erf_ops.h,sha256=ab-DYc7QDE8YvgGoHxO3Oli0F2UJTJdvvYbT4xvQfAE,1617 +torch/include/ATen/ops/special_erfc.h,sha256=Pdpf_B8gUKGZeQv8ynWVc6cphW1A3rNR1nwGFZy0WSU,1127 +torch/include/ATen/ops/special_erfc_compositeimplicitautograd_dispatch.h,sha256=VFwe4o01vVZ-BuJ7VcZqwURekptbIPik_zLpQu9r6mc,964 +torch/include/ATen/ops/special_erfc_native.h,sha256=LI29SsPZ1zTMvu2UFBM1nluIJx4DVmimaKCTMTJzAnw,588 +torch/include/ATen/ops/special_erfc_ops.h,sha256=ion77s-717FVeoRHJHWoh9ajIAv29D2AOmxQiJhrCBc,1623 +torch/include/ATen/ops/special_erfcx.h,sha256=yr_gzYexYDb6s2Ny6WJOaDgDdcWfXzD7jXOivuZCz7Q,1137 +torch/include/ATen/ops/special_erfcx_compositeexplicitautogradnonfunctional_dispatch.h,sha256=0E5g1TSsxS5T8OerMgfuKTdomLELg7CWV_wTrSExv9M,820 +torch/include/ATen/ops/special_erfcx_cpu_dispatch.h,sha256=Med4-1_SOi3r_N_Y-i-8jcpMyevjsPCQCH-F51OOPig,923 +torch/include/ATen/ops/special_erfcx_cuda_dispatch.h,sha256=bQAMEBZ-0UMAgRUKRlxl8j348buMF-zkBOKZMokGAwA,925 +torch/include/ATen/ops/special_erfcx_meta.h,sha256=gcdJutu29W6X963yAVZvMEmpuKUoZI2_QkMW0sBu8u8,601 +torch/include/ATen/ops/special_erfcx_meta_dispatch.h,sha256=RVMpV2Ahi0SRQcLW9-BwhJJOCrAfDRzo1X-w1lcZAIU,925 
+torch/include/ATen/ops/special_erfcx_native.h,sha256=8aw28JM9mvi3EEkXHt7GjV_3a7D9PPZ1AnpQjz39EOo,640 +torch/include/ATen/ops/special_erfcx_ops.h,sha256=5rU7HXV4v8x74iU40bsntHKi2NGAJbeWqGlkXXLqOuk,1629 +torch/include/ATen/ops/special_erfinv.h,sha256=u4uXAs8S-ikn4m5S1onzaThduWh8SDN0qkDjvfjurK4,1147 +torch/include/ATen/ops/special_erfinv_compositeimplicitautograd_dispatch.h,sha256=pp6IewOxXdAARvT6hoJ3DNWwox8tSHVkTujbnlab34E,970 +torch/include/ATen/ops/special_erfinv_native.h,sha256=U6pLVT6dPpFAYILcEbyVaxYeiY8iUbG-yryxrAVCCso,592 +torch/include/ATen/ops/special_erfinv_ops.h,sha256=yudckEN9vIJpTIqLy735MEhAw46rn5_fzT7TLm7yWCk,1635 +torch/include/ATen/ops/special_exp2.h,sha256=yMkN8MTqmFb3GL-nfE0jbFMZAqh9NGoCXbIqa6fw3us,1127 +torch/include/ATen/ops/special_exp2_compositeimplicitautograd_dispatch.h,sha256=hZnvss0K63Fo52WOgh_ibNLaikq6-FhQljxkOYtCduI,964 +torch/include/ATen/ops/special_exp2_native.h,sha256=uLb9nVXSaDbdlzBvLhMJujbIAIBM4QWWwt3VqIH8hdo,588 +torch/include/ATen/ops/special_exp2_ops.h,sha256=HO4gu-Ii5k977deKBtJvIqNkJXEq7av_DlyO_wRFThU,1623 +torch/include/ATen/ops/special_expit.h,sha256=Hrsc3OpZ12TQBcQHwW7qhlboIujwzvQ2LO6ON9gSSkI,1137 +torch/include/ATen/ops/special_expit_compositeimplicitautograd_dispatch.h,sha256=JhR3dpoZ9a84JCTod3FYbQI_Ggy3PDeOLyz4MrRZru4,967 +torch/include/ATen/ops/special_expit_native.h,sha256=HtaiqUmnDSCgPxPogRqn_I2xKIwZMbNtXh33AZ1l0IY,590 +torch/include/ATen/ops/special_expit_ops.h,sha256=r8rHVHIBtajs3kVNG8cgd6OgBGHqzZduDf1sVaAaYbE,1629 +torch/include/ATen/ops/special_expm1.h,sha256=az9O1hhEBjJNiwcaXyG_Yu_V5G2OTMUMUJq6vCEAat8,1137 +torch/include/ATen/ops/special_expm1_compositeimplicitautograd_dispatch.h,sha256=pYMcUMtD8YKQYGdsnIqMq0RKlA-KwLmpFbOZL-qwWK8,967 +torch/include/ATen/ops/special_expm1_native.h,sha256=HpiQHi7vmG4heH9mHdUkMpH3fNknkptyyid73_rVcBs,590 +torch/include/ATen/ops/special_expm1_ops.h,sha256=FvAj51PywrcNuexuh5D-G5B7CSQ60UoOE3qATHbsIcY,1629 +torch/include/ATen/ops/special_gammainc.h,sha256=VidpZx4ClOeCypGxEnFC53POl_ZLV1-Ciw125Mh-D8U,1308 +torch/include/ATen/ops/special_gammainc_compositeimplicitautograd_dispatch.h,sha256=61VLuIzi1XD6_5I6YrdVgRny2vxvUthEUNsykTIPr-Q,1054 +torch/include/ATen/ops/special_gammainc_native.h,sha256=wQYBWI7SWae6Kkd82fpd1vVgYtcwsoTS0fzcvYKHOCQ,648 +torch/include/ATen/ops/special_gammainc_ops.h,sha256=UNCkfqFgLE1PiNcCPuWUAoHXH-at4uZ585zrGjHF6XM,1819 +torch/include/ATen/ops/special_gammaincc.h,sha256=blsaKGFAydX24bsc8riVCDi1QRIu_7L_DqAmwVPQPBw,1318 +torch/include/ATen/ops/special_gammaincc_compositeimplicitautograd_dispatch.h,sha256=jlzhPVr9SbBo-cGY_tQfwjIytOpMRGd50Mfh_qbNeRs,1057 +torch/include/ATen/ops/special_gammaincc_native.h,sha256=KqGABq_4mbeonTxONz-0UnAkqGuRAL7NUhJ92pABB-8,650 +torch/include/ATen/ops/special_gammaincc_ops.h,sha256=V1emGY-hrKYWPpXzOFxYs0K2pZEIFO_yyTDkdfPzVZg,1825 +torch/include/ATen/ops/special_gammaln.h,sha256=z1S5fHRAvbxEO0XgdiCdYhV1N4DXRMG62XJ06u14Da0,1157 +torch/include/ATen/ops/special_gammaln_compositeimplicitautograd_dispatch.h,sha256=93PnW44IoHNeQ1FeCJlDIVseTFilsAEACrDLtZyKl2E,973 +torch/include/ATen/ops/special_gammaln_native.h,sha256=Cyfiaphgt8p8Y_lW0-Okpvs-k9DzUouU8DvnNM7wVMU,594 +torch/include/ATen/ops/special_gammaln_ops.h,sha256=QJ-k4F_nWXMyqH130cpLjs4d754IIOkoNB5lRknKEOE,1641 +torch/include/ATen/ops/special_hermite_polynomial_h.h,sha256=F59QK_QBGB2m7PPUPpVEKT3aU-yCGNv2ETHSwz0KNtw,3103 +torch/include/ATen/ops/special_hermite_polynomial_h_compositeexplicitautograd_dispatch.h,sha256=m34rp6ZBPI6epQnmdLHiYHLuddKzCLitDKFgXlryGf0,1406 
+torch/include/ATen/ops/special_hermite_polynomial_h_compositeexplicitautogradnonfunctional_dispatch.h,sha256=60JoAAzdf2LgmVGzFBxmsv7lHWDrucWjR8nFENjgisc,854 +torch/include/ATen/ops/special_hermite_polynomial_h_cpu_dispatch.h,sha256=ubk5jtZk2q5VxeW2jYA9z7pVKcxvu1jBthns0LRFIKs,1025 +torch/include/ATen/ops/special_hermite_polynomial_h_cuda_dispatch.h,sha256=GyMAxLj5NuwvxbugJW_jGFFUVxCvkrktSF7XJnh7pQ0,1027 +torch/include/ATen/ops/special_hermite_polynomial_h_meta.h,sha256=sKfX-5L2iTwidn6lxWwtTe6pxxdA_QYkvOM_9fgcsOw,635 +torch/include/ATen/ops/special_hermite_polynomial_h_meta_dispatch.h,sha256=gA9rW32q1h8inVN2oISfOoK9YX14P23yQIK6TndVfQE,1027 +torch/include/ATen/ops/special_hermite_polynomial_h_native.h,sha256=X6mgSTVysVWfV-bRGCcfIlZlAFaqOIa5ViAKNr9cGro,1136 +torch/include/ATen/ops/special_hermite_polynomial_h_ops.h,sha256=txaZZfH332Q4VIMBWgBYPD9V9QLX66FJfb5HpShYBZo,4703 +torch/include/ATen/ops/special_hermite_polynomial_he.h,sha256=pOxj4zVfvCZcRAXzPd4H4xJmvO1kADp802pv7GIFZ4I,3131 +torch/include/ATen/ops/special_hermite_polynomial_he_compositeexplicitautograd_dispatch.h,sha256=aU2KbNhiEmMn9wcwz-LRWQHPSo9C8pmbzsNgCaLnbnk,1412 +torch/include/ATen/ops/special_hermite_polynomial_he_compositeexplicitautogradnonfunctional_dispatch.h,sha256=fKWSoFCUB9kdz6KP8n_HG6xqSORirw8kXQMTGS1ZQe0,855 +torch/include/ATen/ops/special_hermite_polynomial_he_cpu_dispatch.h,sha256=gYExlZI8zn2ECXPj8jPXC4plZwQPBQP-GnEv6yz1IAs,1028 +torch/include/ATen/ops/special_hermite_polynomial_he_cuda_dispatch.h,sha256=aiNvwdxbZ4NATDE_FoRZPPnlwwoyGJPqeaeLntselvI,1030 +torch/include/ATen/ops/special_hermite_polynomial_he_meta.h,sha256=rB2YaoLzOUFV-4YnD1FKaDRlApSmhkJvS-rKU1Msxxw,636 +torch/include/ATen/ops/special_hermite_polynomial_he_meta_dispatch.h,sha256=8XofOfwjbkKKheqe0Usa2sUaYzDJKuqaDRqEkixsv2k,1030 +torch/include/ATen/ops/special_hermite_polynomial_he_native.h,sha256=prnGDutz0vEH5QBu87vXot-IwWpEIulJPtV9sJHtmcI,1143 +torch/include/ATen/ops/special_hermite_polynomial_he_ops.h,sha256=6nFQ8Vc9W31HxaOsZwZExnrK9g9PjWRN9NrbARyGu_I,4721 +torch/include/ATen/ops/special_i0.h,sha256=zgPxWLwuZ83ZYzQ5QT3QFKGEcdp-SsipGXptvKJ6R3k,1107 +torch/include/ATen/ops/special_i0_compositeimplicitautograd_dispatch.h,sha256=FC6bK4L6wzaCsmEQJVuqZcrLWzpgrI5POwy00I3TwKg,958 +torch/include/ATen/ops/special_i0_native.h,sha256=twbgtEHl0sGVFkSyXp4_AQTQ3v-3EUE5jGlhTvG5dsU,584 +torch/include/ATen/ops/special_i0_ops.h,sha256=iTl3w2DEXKwLJe56b-BPNVz9dkudg7BKvSZKC5hMGYw,1611 +torch/include/ATen/ops/special_i0e.h,sha256=dNsk17nMf311AiWr6qV4k9LJjjKq_7zhha1gmdbyNX8,1117 +torch/include/ATen/ops/special_i0e_compositeexplicitautogradnonfunctional_dispatch.h,sha256=chO_hOoalTlL3kPdpnZSh0LYr6rwF9pofLiqWyfUkPI,818 +torch/include/ATen/ops/special_i0e_cpu_dispatch.h,sha256=ekD85aVu3HBLP-BsWcz4OyPgzGUTz0EhNDzN8U9zX4w,917 +torch/include/ATen/ops/special_i0e_cuda_dispatch.h,sha256=s8FGpKr_lTSKZna5TWmbmryFBCsEbQH6P7IicWTc8-U,919 +torch/include/ATen/ops/special_i0e_meta.h,sha256=hDRgiXM2OXctisyb7D0_0KZ8qztEZwoEadFTyjlvDV0,599 +torch/include/ATen/ops/special_i0e_meta_dispatch.h,sha256=dDR_9rUN3kQKpAI43dvb00LckS8ocpZbz-DNxXEDRdk,919 +torch/include/ATen/ops/special_i0e_native.h,sha256=OYILZoPced01TBC4WOQKLU0zcD7yAaijgQ8IFQHntJM,634 +torch/include/ATen/ops/special_i0e_ops.h,sha256=1V-P_1s4w45UEcXby_f6fH1jvcmP-w1RGAVdHzQWBOU,1617 +torch/include/ATen/ops/special_i1.h,sha256=jkUfnhVw3qnuwhlCW8Wt593z1JvN02lbqc1sYlzGVM4,1107 +torch/include/ATen/ops/special_i1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=LIx9ZthxW8-0o3XrkxYEPh16GfNHIrVMOFOAOVNKFnw,817 
+torch/include/ATen/ops/special_i1_cpu_dispatch.h,sha256=PUDzyuHZ3iUzxVxRt6M7n3iIq27t36wQNSdFg-uBk0w,914 +torch/include/ATen/ops/special_i1_cuda_dispatch.h,sha256=D3jwyxVD1DZf89IAqJeTdG6Dq78M2FobqI6CYPSABec,916 +torch/include/ATen/ops/special_i1_meta.h,sha256=ag1Sy4D-V0BWZaq1fpi9svHqCPg0WT4Cx3JDA9ZA2TQ,598 +torch/include/ATen/ops/special_i1_meta_dispatch.h,sha256=Inh9FdoL_PTsYd2HKGfIhrOn3bmKHOA9Pd1kP9f4_S4,916 +torch/include/ATen/ops/special_i1_native.h,sha256=Lb8kD1IKgD5xGWCqWAfgqEiw-WQ7Z6ZqfaUjfrkeZXg,631 +torch/include/ATen/ops/special_i1_ops.h,sha256=5JS3EBs_XCdy6qJdIv-DMuW5fXCsmyHfsOiAgktZCks,1611 +torch/include/ATen/ops/special_i1e.h,sha256=umd6e4L85sXsGI4G48X5nd0DG-XltOlP5ElUq5HX4t8,1117 +torch/include/ATen/ops/special_i1e_compositeexplicitautogradnonfunctional_dispatch.h,sha256=rvPb5XLD3XPMo1qePzdSVgiKEKTrXNzK3jxUUqOTaUY,818 +torch/include/ATen/ops/special_i1e_cpu_dispatch.h,sha256=qce2bKa9Y3t9HTlAWHITSQ-4IEqsrMfD53oSIujAyTw,917 +torch/include/ATen/ops/special_i1e_cuda_dispatch.h,sha256=HlPCLJpath9non9qW1fWOmZGa_mibsqRYIohFEvhNZM,919 +torch/include/ATen/ops/special_i1e_meta.h,sha256=sCCSnSsPFo0tuhcCu08xCY96_4tpqLxx42_JHsanSQg,599 +torch/include/ATen/ops/special_i1e_meta_dispatch.h,sha256=Rpfvp6cXDBQ-o3XkRNBGot0O8UwZ3WdLbfR30peXJTY,919 +torch/include/ATen/ops/special_i1e_native.h,sha256=LwfjbH2L5eNHWf3d0EEgYlZj2vhA_dbeMq32mEH7dnI,634 +torch/include/ATen/ops/special_i1e_ops.h,sha256=B5C9sBgZmvqyj134mBE-_lkiz6y1pVzZanemSWV8xhM,1617 +torch/include/ATen/ops/special_laguerre_polynomial_l.h,sha256=jpsr0iZ8rLQvHZJzZnl28Ijv6GJrFbSRRa8dGpvJfZw,3131 +torch/include/ATen/ops/special_laguerre_polynomial_l_compositeexplicitautograd_dispatch.h,sha256=CgA_yMdzdjAsprvwsoxc0u5WRDLnly0VxqgrHLzY_GY,1412 +torch/include/ATen/ops/special_laguerre_polynomial_l_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JLt2LI9rwhhhMYtS11cs-gg1It5VCbDxAonYbsUBiXA,855 +torch/include/ATen/ops/special_laguerre_polynomial_l_cpu_dispatch.h,sha256=qEEXSJturtKrdQplLfCAiFS9pICaPT5hb7GxE5Psvhk,1028 +torch/include/ATen/ops/special_laguerre_polynomial_l_cuda_dispatch.h,sha256=Mbh9lT1KVycQnrQaFm13vYpNoav4P29piE-JLhQqQd0,1030 +torch/include/ATen/ops/special_laguerre_polynomial_l_meta.h,sha256=s8eodiAnpQiWw1jmzPHDMT-OUEh4eGD06kz1QZcBdXQ,636 +torch/include/ATen/ops/special_laguerre_polynomial_l_meta_dispatch.h,sha256=riX_XkQtauQqIN8eFvO2yaie0fxPRAgmxxlz5TfAMoU,1030 +torch/include/ATen/ops/special_laguerre_polynomial_l_native.h,sha256=8VKvXluwL6-2ePUuFE5OnUMiPDHebccwmRxHZvmx904,1143 +torch/include/ATen/ops/special_laguerre_polynomial_l_ops.h,sha256=Co3_jp8B_xvTo5Map7_wjL9scIN83z7k4mvQ1R-2zHw,4721 +torch/include/ATen/ops/special_legendre_polynomial_p.h,sha256=GE67H1R536PoV1ZpJFpIooDVIDJnYSnaX2yg3hNRq4o,3131 +torch/include/ATen/ops/special_legendre_polynomial_p_compositeexplicitautograd_dispatch.h,sha256=lHeiCP-1VV6AOZ0O2JTH8apLew_e-jPNpeYPBJVYWM0,1412 +torch/include/ATen/ops/special_legendre_polynomial_p_compositeexplicitautogradnonfunctional_dispatch.h,sha256=mVKKg2KBQY1OxtgdlKG1pWncJ4PHddqYldqhENox9FA,855 +torch/include/ATen/ops/special_legendre_polynomial_p_cpu_dispatch.h,sha256=QdAiR8INLjbhX9kEOJGMO1MCil77TvYSEiEJeYIVYQA,1028 +torch/include/ATen/ops/special_legendre_polynomial_p_cuda_dispatch.h,sha256=Ty5fI1_M_A_aQZmwO19bz3zJfV5e_LzOpETu0Yu6GLM,1030 +torch/include/ATen/ops/special_legendre_polynomial_p_meta.h,sha256=89exe_Fy7jzbGWctL2u4mwrB9iByxo7Yj0IjOETJc8Q,636 +torch/include/ATen/ops/special_legendre_polynomial_p_meta_dispatch.h,sha256=9Q5vqvo7O1hCHqHpkoRME3Im23UEHCczgumeBZ5B144,1030 
+torch/include/ATen/ops/special_legendre_polynomial_p_native.h,sha256=tdNia9pLk9PGQJG8sroB_62LjI9gIlC0ODLMaWFn6nU,1143 +torch/include/ATen/ops/special_legendre_polynomial_p_ops.h,sha256=je-1vvptcr9rdmon4muTLsNTU6QAQpiWqicfcGJV9qk,4721 +torch/include/ATen/ops/special_log1p.h,sha256=jkMIh-IEioc-AaQhQHjcj0eEIeK60jkTNjou-53967s,1137 +torch/include/ATen/ops/special_log1p_compositeimplicitautograd_dispatch.h,sha256=zk3JiSnn-97kWqZPcJjqJpICCCVH-Odo-vtXAOVt-2g,967 +torch/include/ATen/ops/special_log1p_native.h,sha256=7mtZ0gXOBEaXZK8ckZmasagwQv4YGmKw6cA2Jcda338,590 +torch/include/ATen/ops/special_log1p_ops.h,sha256=CYIAO6LlPQ3T7QxUg6aJfWFrhGyEPQASM5ywWw-rbrc,1629 +torch/include/ATen/ops/special_log_ndtr.h,sha256=IU1k39FSupyLAb2I4BawTeAMi7ue_5LK0sGY0E0CNLo,1167 +torch/include/ATen/ops/special_log_ndtr_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8oKqlkHU-NKTgWRB7TPG3jTbBAhJpSRSo2pMYHuDmQE,823 +torch/include/ATen/ops/special_log_ndtr_cpu_dispatch.h,sha256=XTarxK9BvS44LHH0qDFbbRkez_3QSxYiakBLdBY_EpU,932 +torch/include/ATen/ops/special_log_ndtr_cuda_dispatch.h,sha256=93LhCOD56gz2r8pdlP1zk_YvrmpF-_ArJAlYX8Rwu6k,934 +torch/include/ATen/ops/special_log_ndtr_meta.h,sha256=WAvYAJjhLtz5nhwPQ0zYeC0jmQCQ2PWk_n5iR-7p-LA,604 +torch/include/ATen/ops/special_log_ndtr_meta_dispatch.h,sha256=UwSpn0Uqsbu2ymIymg08r_TZOEtyKiRSgDbsMFW08Ow,934 +torch/include/ATen/ops/special_log_ndtr_native.h,sha256=rhJxlc3w0KqiNfCSd7-BSQzWFIqFaY4j_1Vq2atE7gw,649 +torch/include/ATen/ops/special_log_ndtr_ops.h,sha256=2B-yCDvqWFDgGv2q684oh9p8ejs1X0fZTLAw8TQffPw,1647 +torch/include/ATen/ops/special_log_softmax.h,sha256=8Bn36srqO89duaWMK4EIT_JxNekTaiPX182dwVXAUHs,835 +torch/include/ATen/ops/special_log_softmax_compositeimplicitautograd_dispatch.h,sha256=NBglBi1rPNIvn72a6zQ89FID1P7oN3E5S_YrYQkphgs,867 +torch/include/ATen/ops/special_log_softmax_native.h,sha256=X8wgQ9eQ1PYbZZzOHLgTB5hnvpu5VpnQtmVSM1xMvgU,577 +torch/include/ATen/ops/special_log_softmax_ops.h,sha256=Y5jyzhY9RP0fy125qPe7cnMRG-TPYStrjM1oMd52S0Q,1207 +torch/include/ATen/ops/special_logit.h,sha256=7gbKdxcYaVfOQT15pqZERtpqj3Pw_S4sMfAEcQyVVQg,1320 +torch/include/ATen/ops/special_logit_compositeimplicitautograd_dispatch.h,sha256=9Z7bIexen-mhp_WlmyFfxiqFvIJPKmA7pngXtkqxgw8,1084 +torch/include/ATen/ops/special_logit_native.h,sha256=eaCFzR8j_bqFCt1MVkAbQeOurrrHIqtvzAG16u8xvtU,663 +torch/include/ATen/ops/special_logit_ops.h,sha256=nDjBBDPbXEAosy2UfOxl6Sqxt-NnGb7bkuCavtwv5-4,1829 +torch/include/ATen/ops/special_logsumexp.h,sha256=k9TarPQdVQrHuiTm20LzsYUMbprbyZbwNF7UziXYeP8,1432 +torch/include/ATen/ops/special_logsumexp_compositeimplicitautograd_dispatch.h,sha256=ke_moF6RRCcEhVGtB9v620hpcnQkboPWYrWQdIEA8XI,1096 +torch/include/ATen/ops/special_logsumexp_native.h,sha256=Ys90h-vWRTmP-c5yujfEPDeDBZ9ESr5UV7s7ntyd7P8,674 +torch/include/ATen/ops/special_logsumexp_ops.h,sha256=15O3pINivZclCDaGAng5xAcahGV6Me8yVvVHCEcw1ZI,1903 +torch/include/ATen/ops/special_modified_bessel_i0.h,sha256=2HhDLGBXrABWhKWwlc8c5pwiJ_RWk1MdWyc6PQ7kgqQ,1267 +torch/include/ATen/ops/special_modified_bessel_i0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=HLFgSX3hXs2AI25wV8KKLj1AI8lyQOTUS7Xt3vnAGQA,833 +torch/include/ATen/ops/special_modified_bessel_i0_cpu_dispatch.h,sha256=bEHaFoQYTdY3hwd0FH9bySWtowMW1JZ6XojAVU78zMg,962 +torch/include/ATen/ops/special_modified_bessel_i0_cuda_dispatch.h,sha256=gTf9VGiy4d-UefNxyiB6xzCl560f3xZQMVwpKKzUk-4,964 +torch/include/ATen/ops/special_modified_bessel_i0_meta.h,sha256=CM5EWckSttQiYGb_xlFttXJTUrFpS0WtpNEQ8XbXKvA,614 
+torch/include/ATen/ops/special_modified_bessel_i0_meta_dispatch.h,sha256=OcIDqLXl7jlNd3izU_bU_JWO7ImZnZhsBX5jkQcyGKc,964 +torch/include/ATen/ops/special_modified_bessel_i0_native.h,sha256=kJUcJRXLWOWtC6-Kof0tx6IhMnP7PzFqVnpiHl_DVQo,679 +torch/include/ATen/ops/special_modified_bessel_i0_ops.h,sha256=Dt6K0lDG3O9rTkHbsZbGzaij0L5WLjJRxldi0V4kyLE,1707 +torch/include/ATen/ops/special_modified_bessel_i1.h,sha256=a1Ua3_IfF4OvNXuY93ZHNMWnu5_btdDRrI23kJkzw8U,1267 +torch/include/ATen/ops/special_modified_bessel_i1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Ck8I9hzGPUgtTbSRW7mJ9pD4XQuc5qkiszpEx7HLJjA,833 +torch/include/ATen/ops/special_modified_bessel_i1_cpu_dispatch.h,sha256=N-rQJyt4ZdoBiCE6ausT0F0M9dmnJpNSDHSZarx3DG8,962 +torch/include/ATen/ops/special_modified_bessel_i1_cuda_dispatch.h,sha256=TdJQ3WsWlOMwC-6PxlECEIiAhaUzodVa0vfUSdW2FC0,964 +torch/include/ATen/ops/special_modified_bessel_i1_meta.h,sha256=Dt8RxVFym_KvisXzI6CAQm871RxxsjId5GF-2feH0LY,614 +torch/include/ATen/ops/special_modified_bessel_i1_meta_dispatch.h,sha256=7K-fZGJnS0YE54xZ_qo0WARfQ0gaKwsRJmuPbUwTC9I,964 +torch/include/ATen/ops/special_modified_bessel_i1_native.h,sha256=w_pw3gN2cqmOXybEDSydIsc_y3TT3Nc-3c9T5HsLc0s,679 +torch/include/ATen/ops/special_modified_bessel_i1_ops.h,sha256=fdHRiok8O2niwVXbzQUE95vfrG3BD6c0V2xhLv0WuQM,1707 +torch/include/ATen/ops/special_modified_bessel_k0.h,sha256=r0PCAirVepFzDDyn-D2RRfZcpbrlqtwZpoA25KuuT-Q,1267 +torch/include/ATen/ops/special_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=f0QDgrhWVXgtv04LJdN554JdjRrUL6M51GhGzFs4_ss,833 +torch/include/ATen/ops/special_modified_bessel_k0_cpu_dispatch.h,sha256=-YbLqF2OaRMhBqIK8IrbPGgUzARGC5WkNYWFNfFw_-w,962 +torch/include/ATen/ops/special_modified_bessel_k0_cuda_dispatch.h,sha256=UkXeGXKMQOwCdNuH-gl_857nTBMhdSINJBlOwRlAR-c,964 +torch/include/ATen/ops/special_modified_bessel_k0_meta.h,sha256=_FXfceYhVyAcrPzFCXY7_Be6snFMola9iq8-ojAfhXQ,614 +torch/include/ATen/ops/special_modified_bessel_k0_meta_dispatch.h,sha256=-E3m5VNU8JcGlfcnl5N69El7ui-PUIs7KGYFnENXilY,964 +torch/include/ATen/ops/special_modified_bessel_k0_native.h,sha256=mqdCWljmVF0u3C5ittkdSQiZ6WnFY3G9bQUs_UPe7IM,679 +torch/include/ATen/ops/special_modified_bessel_k0_ops.h,sha256=9Dq0kp-QXiCk_iyYdTsf5rzaciYPN16Lu8J1tSphfVc,1707 +torch/include/ATen/ops/special_modified_bessel_k1.h,sha256=lk5XgeXN0NElOhl5QOrP329e4J1bqDENzRNXiCFJqk0,1267 +torch/include/ATen/ops/special_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=xkAwVlLp-20POLv-NHC90QzjtFR8HjZ2kpyye5QUl2g,833 +torch/include/ATen/ops/special_modified_bessel_k1_cpu_dispatch.h,sha256=ZVQOTEOdwzHGsipmpPYkWX2dBMzneam5FAmFxQE9evE,962 +torch/include/ATen/ops/special_modified_bessel_k1_cuda_dispatch.h,sha256=c0lvUewSlPo6QOPNn0YH39WONqhm4TIplsFMX1xyJ1c,964 +torch/include/ATen/ops/special_modified_bessel_k1_meta.h,sha256=_1xUd9ZEoKcoBIg0x50BruTgmpQKTiX8cOtwraU0Xg0,614 +torch/include/ATen/ops/special_modified_bessel_k1_meta_dispatch.h,sha256=UIlwV3b26LR21-mHnnKp6TfKGsH-Llhm_KroDLDbC1g,964 +torch/include/ATen/ops/special_modified_bessel_k1_native.h,sha256=5jUOk_md_I8ybo9X6YvAbGUyj0XBXgsV72WMECSanTc,679 +torch/include/ATen/ops/special_modified_bessel_k1_ops.h,sha256=pDFGB5b6AGJ7jGQgB3pJsBlZ2MIg3pcYER4s9mi5pAo,1707 +torch/include/ATen/ops/special_multigammaln.h,sha256=QO1eDyZ-ccdmspSzZoMCF1TwsMs3uKqHdYaYyu0v89Y,1270 +torch/include/ATen/ops/special_multigammaln_compositeimplicitautograd_dispatch.h,sha256=hw4w5SO3683IsxpZfn-nlhu9nW3PdW3YundExJpENdI,1021 
+torch/include/ATen/ops/special_multigammaln_native.h,sha256=R8WWg359ahp-83flG74a60tNP6j9Pg_0z4CNzTm14zI,626 +torch/include/ATen/ops/special_multigammaln_ops.h,sha256=7a08EPeYSN7pAKDXs7y8GrsKBAvzr4nRdXErJbwIpBc,1747 +torch/include/ATen/ops/special_ndtr.h,sha256=lmLZVVmYjdR7cwbSn01-9zJxWiLOFULt8hxOL9u_koE,1127 +torch/include/ATen/ops/special_ndtr_compositeimplicitautograd_dispatch.h,sha256=uz__akHrt5feWbS2u6YANmHnMnuTOKujB4nBjgyrbq8,964 +torch/include/ATen/ops/special_ndtr_native.h,sha256=0pS-nh7oqpEQtOtO38AG-8VS6y4BOrli-JSJW_rBfqM,588 +torch/include/ATen/ops/special_ndtr_ops.h,sha256=JOn89V5Q6zHSqwk6nWpGJEiB6ynoj_BITN6YWATf66I,1623 +torch/include/ATen/ops/special_ndtri.h,sha256=TVR06_JPwwGtR4YavjTzCXsnQ6bAIaAXkrRO0Ou5mX4,1137 +torch/include/ATen/ops/special_ndtri_compositeexplicitautogradnonfunctional_dispatch.h,sha256=3_szt77ivZTtWepJ6PBoWTJro81aZ0iadfEggWAriHw,820 +torch/include/ATen/ops/special_ndtri_cpu_dispatch.h,sha256=8eXCICWXcrPFKNoa4o5Vwa_e1AsE8gW3yhFSpxV7AL8,923 +torch/include/ATen/ops/special_ndtri_cuda_dispatch.h,sha256=erHw-2Dtfcarv90QLeHgYt-Wr108P6H_PZHqP_LDlLU,925 +torch/include/ATen/ops/special_ndtri_meta.h,sha256=R8oxCuE_GOUWHDwjbcSuyAoYUIWoYqB2X27PDKSh9Qo,601 +torch/include/ATen/ops/special_ndtri_meta_dispatch.h,sha256=0Pp986Bqidq7tBPrZEzXJEx9O6UWSWubSto4tzC2IRo,925 +torch/include/ATen/ops/special_ndtri_native.h,sha256=RdtJTqOtBBjWj2tuKJUqay6b1aGdaDyswXF6Ut25UhI,640 +torch/include/ATen/ops/special_ndtri_ops.h,sha256=1Y3OFSpecH6-aQbU4-wP10ldJ7z5aYJWnAIUFVaRhgg,1629 +torch/include/ATen/ops/special_polygamma.h,sha256=GeoyQlsDEruUqGRAJNjYX4-g9Dvn3jpRJ698Q1Sjb24,1240 +torch/include/ATen/ops/special_polygamma_compositeimplicitautograd_dispatch.h,sha256=j4645TzA4mKHJpSl0oIq3i6RqrpyCHS8bmPC9wzZ0dc,1012 +torch/include/ATen/ops/special_polygamma_native.h,sha256=qM9vyK6NkZ6VOuWmkH6mHKiOja3cLy2tc9Uy4E9LkUU,620 +torch/include/ATen/ops/special_polygamma_ops.h,sha256=YG98Ph8TlXIdEbHJIfROb8XJmEOrNlt9vbKfaMXnh2Y,1729 +torch/include/ATen/ops/special_psi.h,sha256=7UllBvH18VFjmehiitUeFkxRLXNChsFNEyn_Fs7hyHQ,1117 +torch/include/ATen/ops/special_psi_compositeimplicitautograd_dispatch.h,sha256=pOocDEYPQrWvFsHVZTwNkAAdsb8zZQwTMp14jYB7Wg4,961 +torch/include/ATen/ops/special_psi_native.h,sha256=riOrdFH5B6_LzT1D0W6b8DxXITPtl_WXXJVueVvnzRc,586 +torch/include/ATen/ops/special_psi_ops.h,sha256=eV6HV-i2CICVkYwTNnxdAi7jIVCi47OGAbdydrOTa0s,1617 +torch/include/ATen/ops/special_round.h,sha256=ncobuEVp4jpsTqiNMGzwxo5yKyWt_Koy22Dlwus73K0,1276 +torch/include/ATen/ops/special_round_compositeimplicitautograd_dispatch.h,sha256=a0BcChLXvbz0RJTnP9im6d9hMRkaIbdNvnbmUN-zZ5M,1025 +torch/include/ATen/ops/special_round_native.h,sha256=7se-8MosVes3TE5czgcJpObg8Iv0abtHPbkVxBZR4ys,628 +torch/include/ATen/ops/special_round_ops.h,sha256=ITR5Dyd9ykT7plFJDk0D1tO0qlN-yN1HtCyEuGjsYNQ,1754 +torch/include/ATen/ops/special_scaled_modified_bessel_k0.h,sha256=tUffa6uRu5wEmZxS39rBlyvXMfUKFYq9_4XjHLIgOIk,1310 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=SBZ8i8FnC1JoLWfVjd4Ol99rBHfwCxJPF-II2m4G438,837 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_cpu_dispatch.h,sha256=gzcQ65b7jMXF8sQNvXpOYr16LaAXZyIXkjM8rLjJPNM,974 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_cuda_dispatch.h,sha256=Ic7kI5AbclWH_2gEZmaMBqcKuPpK7SnwU7WJsSPWUX0,976 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_meta.h,sha256=Vp7OUuXUW8rq_XUXDVEw1gEUL0pMf8NU2PT42meqVcg,618 
+torch/include/ATen/ops/special_scaled_modified_bessel_k0_meta_dispatch.h,sha256=V_QrMBXBBz6wm3RSJmS2UQOX4vTcHaEuI9awHoeeXRo,976 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_native.h,sha256=aUZkdm8FcQhmGHTG8Ili78NoD5Py6zXs-nGPE9EbS1A,697 +torch/include/ATen/ops/special_scaled_modified_bessel_k0_ops.h,sha256=JJL4YSQLkFXO6A_WLhi-0iP8jrD9mRXGf34mntQ4POU,1731 +torch/include/ATen/ops/special_scaled_modified_bessel_k1.h,sha256=pyo23-bum5ZXqO59s6-d18HETnSGbWFJR0Cd0hmSpXk,1310 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_compositeexplicitautogradnonfunctional_dispatch.h,sha256=9QMCHdywUDHzGC5szyx_e2nd5Cu5H-8Jqf4WRDllhDg,837 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_cpu_dispatch.h,sha256=h4FivR1kGh6iKEjdf96rR82UugIRty6xd05AXN4EIyw,974 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_cuda_dispatch.h,sha256=gz_V-Foo2SpW0MMoDlP_weFj1Uu-TnRxEIQrrg7oDJg,976 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_meta.h,sha256=HYIOIT-nwOdm0oR1hTGMnf3eY7tEOy0se9IPSOsHKQM,618 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_meta_dispatch.h,sha256=fljmUHgB0viruo7LhV402XsXv2gabuuJLCbM3A7Ku5o,976 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_native.h,sha256=shXMTkuyGV6syiPFoJJTjiN_nSCOPcWla5Xl4tHvZmU,697 +torch/include/ATen/ops/special_scaled_modified_bessel_k1_ops.h,sha256=dS-5akgAyxAjU6Oywa7TIHq7nAJF5Jg-cOouw8Xk3oo,1731 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t.h,sha256=Q2JeGp38g1I0czEq3pjpFCYyqbxowFwEehDpfCB1XOw,3383 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautograd_dispatch.h,sha256=eFengWdPsvAZ2e-PxHoovEY0oKhT0L_xVljlNkR2RNk,1466 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_compositeexplicitautogradnonfunctional_dispatch.h,sha256=zzJ3QItu3SxJSd3k-mcZ3-4gta4Xkp2fULX1w8tkC2s,864 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_cpu_dispatch.h,sha256=-CQz0IzawK4nMDLrWry1CMj7bgEHsntq9iLmbTjQD3g,1055 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_cuda_dispatch.h,sha256=wWga4OxxD9GII8-xXFa8JQjRFaE_qAEc202lHCnaxnQ,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_meta.h,sha256=vCkqSK9nfVulnsoJnVEd3HZH_QifnjGKfjHLQW0MjEA,645 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_meta_dispatch.h,sha256=oEGzr1eXiOXq5iOVteIvtLuPaIg2tUfzIfbe_peiNTU,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_native.h,sha256=CWbYK27l1YwSN5QAf337Xv335C4zsnzB5VoRoPbceJI,1206 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_t_ops.h,sha256=l7Ye6ZgmdSLLJbvuUh9_mZkkeeuHLA-ay8HMPQYu54s,4883 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u.h,sha256=kQf9RzzH9ax9wdw2GZ-MNPrLErUR8zjGQUPBDSaLuZI,3383 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautograd_dispatch.h,sha256=482jPanh3BYjKxpL31E6kNMQKi9Ntq7QNp7y3-SOjQ8,1466 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_compositeexplicitautogradnonfunctional_dispatch.h,sha256=d9JIdjUDzRMfBBe1LM_0K1gTnEDNMdDNKJtg279Q-5E,864 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_cpu_dispatch.h,sha256=wpiZ2PtPxV5cSBq_NaUCNJijKCrUFUgBNtpfSCzJ1KA,1055 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_cuda_dispatch.h,sha256=8Ypz42d_ITJoa27mh-uAXnu9P7wC5NQQgSvaKAq458Q,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_meta.h,sha256=3MGw8Mr7LW2xifpbprZc-vIMyvmfMw27GP9YCPfArX4,645 
+torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_meta_dispatch.h,sha256=JFdVKox_TJaDCcqI5pchOJ2_gF8fLcG1-RZlH36zvu4,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_native.h,sha256=KUfnMNwjZJ6a98y5N-66PVA-XsL6N6Tm5gvX6fw3GpU,1206 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_u_ops.h,sha256=Qalo73F_t1-xdxR-q9wAhKc7a5T5RtuHN7RFni82v3c,4883 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v.h,sha256=2PIKpzyHp-zIVFD4C5J2GCEHTK2uoPs_3YmKjkBEewk,3383 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautograd_dispatch.h,sha256=K7IDndJCkZ5lOXmukVfOhczVJ7wUNgw5auFD91DbAGo,1466 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2xJO2h_L4jVezkHzVtqr4cFpjQJuk1rdRoQR9DXsugM,864 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_cpu_dispatch.h,sha256=3LfxAC2fZII9vDjbwuhVFTfRGge4_EK5yfkmLDsZVSU,1055 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_cuda_dispatch.h,sha256=aM25RZTvjIo_rgo0oCLAmxqg8vyt31tna7UT29oyVlU,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_meta.h,sha256=oss1sdA5iZ1m0UD_Kcqf-F6xqHwH3XXskhfAovqmpf4,645 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_meta_dispatch.h,sha256=6p2EHjmhU_X8PpmC4Je8bkimF_j0xU87rn39J5U-NFs,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_native.h,sha256=zvRst_Ki2rOrdAdYzYg7DOaaJI8Ckrd8THQ3cbnraB8,1206 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_v_ops.h,sha256=Zx3WMHm9VXeXnEmNJiiVpyTUrdMJCWNbp3aMReq6Z5k,4883 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w.h,sha256=WXdoLhGh8Yx-6LUHDt28QSBwRZuTjXHLrVnmY0GbNKk,3383 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautograd_dispatch.h,sha256=lRD6u4ghpDcncwY-G97IfUuDUAEadQkDX2nhKk4bsOo,1466 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ZFmx2Kp8LXhrU9rZuGCZlughdkdUhaonKTGsjknaxZw,864 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_cpu_dispatch.h,sha256=lAIJo1aI1hhmBwM1dDEQg6cWlngAhlDLx3pQxVfM9oA,1055 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_cuda_dispatch.h,sha256=wpV5mueRkYMLGxbPmRN7YtTto9TBixK_uToGUkJo0YI,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_meta.h,sha256=e34RCUCUKgv3Z4zZVTFgq8X6wF0-dxksvACA2KrKb5g,645 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_meta_dispatch.h,sha256=ck8wVnYQ0qlQCXQGecxlCu0OM8Ev1QiQSHg56FTa-Bs,1057 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_native.h,sha256=NMLwCGvuxMguSR0RXki1Cm7qG5h8-ef4wO5z1YkidkQ,1206 +torch/include/ATen/ops/special_shifted_chebyshev_polynomial_w_ops.h,sha256=Itp-NHY2BhiNTZIackBPRZiqA3QtwNjhBE_VDjwyKRc,4883 +torch/include/ATen/ops/special_sinc.h,sha256=zUO5l2jlkremtDcoI8EKqpJ2StzXHGzem7ebddwVOwU,1127 +torch/include/ATen/ops/special_sinc_compositeimplicitautograd_dispatch.h,sha256=Pig8xw-Mvozz3Nv5ATew1PrK7Tus-q0CJqGavKChbN4,964 +torch/include/ATen/ops/special_sinc_native.h,sha256=YOH0kxAUQhnjDdUG-nowe0mjwk50ld97uElbGNB0-H0,588 +torch/include/ATen/ops/special_sinc_ops.h,sha256=RcXLzubTBxYBA0nBlgVK88gzBSWfXgCYjapUIqj8Ngc,1623 +torch/include/ATen/ops/special_softmax.h,sha256=bCAna7JPoiI-CBhT5RhF1weORqCKpAuK9C3dAJfBFSM,816 +torch/include/ATen/ops/special_softmax_compositeimplicitautograd_dispatch.h,sha256=OhjDEi8LPlFac84GAS_RcNBpmxNUrKmxXz2vxqQcIfc,863 
+torch/include/ATen/ops/special_softmax_native.h,sha256=oODjhQsxxOl4pHVaHgAwUmg9V-Arm-sbl3pEWgimZd0,573 +torch/include/ATen/ops/special_softmax_ops.h,sha256=Y3mhbydKwDP9kQM77i_W0IBRgI2gn216vKE0PEHkmAo,1192 +torch/include/ATen/ops/special_spherical_bessel_j0.h,sha256=IqCbzlwJ1z9YKa2GsI1n7vncpKgipd3e53LUjd2RJKg,1250 +torch/include/ATen/ops/special_spherical_bessel_j0_compositeexplicitautogradnonfunctional_dispatch.h,sha256=CLbmwUwKPr6Tfg0iyysDxUN7Ak66xLgZAShZGsi5wkM,831 +torch/include/ATen/ops/special_spherical_bessel_j0_cpu_dispatch.h,sha256=Htud9A5I-95D7avL5hBOd7hZPXlwO93_jNl1C9JVckU,956 +torch/include/ATen/ops/special_spherical_bessel_j0_cuda_dispatch.h,sha256=5IcW2tHIDGb7YidRR3wPXhI2lsvkskiN5N64q1owgkk,958 +torch/include/ATen/ops/special_spherical_bessel_j0_meta.h,sha256=xJjDa1d7qsg_GDAcOrc6MHmBNjGWVMcpoH9R0nBuui8,612 +torch/include/ATen/ops/special_spherical_bessel_j0_meta_dispatch.h,sha256=fVbQyjaU6OZoPJABz-TzjK2KJaYg3g_cfTnQKzmvFto,958 +torch/include/ATen/ops/special_spherical_bessel_j0_native.h,sha256=ePSwyP5xWCBCLnsleViOhyJjEf3med1Bv4TiJmGGdA0,679 +torch/include/ATen/ops/special_spherical_bessel_j0_ops.h,sha256=J5yYh5wozHtqYq-ZxgxhJg1yVqbKAk6czmksQgFqAHc,1695 +torch/include/ATen/ops/special_xlog1py.h,sha256=1ERWdkOgfcap4i4l7NJO0zX7D9CGRbMO4Iru73AYFMo,2970 +torch/include/ATen/ops/special_xlog1py_compositeexplicitautograd_dispatch.h,sha256=L0D6xcx27G99rgNk-oCKcRiutuo7X8Ayy4v_o7myJDs,1370 +torch/include/ATen/ops/special_xlog1py_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ToFOTXv3U74TODdVmJw2a7PqvYYjmBnQ565ErBsHP5Q,848 +torch/include/ATen/ops/special_xlog1py_cpu_dispatch.h,sha256=Y5hjl4ThSuaSHcnxb7L5tCOAFZF6A-PjbHNsDthAbi4,1007 +torch/include/ATen/ops/special_xlog1py_cuda_dispatch.h,sha256=R4wS-SaSmp0-KBsEgG-9sqS4iB1dGToemgxL-1wij-M,1009 +torch/include/ATen/ops/special_xlog1py_meta.h,sha256=Sk_2mEsqhwiI5eAfrltuWeHt2xKhsOY7ZDxDs-BSAy4,629 +torch/include/ATen/ops/special_xlog1py_meta_dispatch.h,sha256=1lw4xjjbzVxZRcR26SC2QSIarZJlGzDP37SLUtwRy44,1009 +torch/include/ATen/ops/special_xlog1py_native.h,sha256=JY8C6Y7rwN1d2Sk4JyC8wg9yCwzwimS1u35Z_7CqmuU,1080 +torch/include/ATen/ops/special_xlog1py_ops.h,sha256=kkT8lPR5t-HYZalfnrMeYGunER1tr2CL-ol6NiuRm94,4637 +torch/include/ATen/ops/special_xlogy.h,sha256=LdnU2FUfq4gaGwkxU3zwjySlwTbuFmIwnqHWB68JUKQ,2914 +torch/include/ATen/ops/special_xlogy_compositeimplicitautograd_dispatch.h,sha256=RfCswV1fxJwoUg8vNyInd95B9bpdu6vn0GqALEyoHps,1671 +torch/include/ATen/ops/special_xlogy_native.h,sha256=7ie24EzC77ueKyLan4GRQOmwtXqGvgJA8spCBarAWwg,1042 +torch/include/ATen/ops/special_xlogy_ops.h,sha256=x6O_vPRWJ9NCeWLOKfb_oUfID81YiWvr8nuKOKPgtAE,4601 +torch/include/ATen/ops/special_zeta.h,sha256=I6F06YUz-hjobz9dLSLN3N9mZJawwzGy1YAZSX-Zo-U,2886 +torch/include/ATen/ops/special_zeta_compositeexplicitautograd_dispatch.h,sha256=z3tfbizWWEeAbY40USOMykQo-5CUO5RYFAiWnNzQ9VY,1352 +torch/include/ATen/ops/special_zeta_compositeexplicitautogradnonfunctional_dispatch.h,sha256=kAUD_m68jxgfMqzs4AJsTK2MXfj5TsL8Lqd9UfdtSIg,845 +torch/include/ATen/ops/special_zeta_cpu_dispatch.h,sha256=ZrjIc0US_ul8CuL4Q08s79k6Xn9kb1iXSDgYc790iec,998 +torch/include/ATen/ops/special_zeta_cuda_dispatch.h,sha256=RLgaIA4thDMt-ujTZwHOU4Z1FsnAIfz0IaGYNqAvXs4,1000 +torch/include/ATen/ops/special_zeta_meta.h,sha256=iXNL1p9N3agG2jiP6nBifF3kLqrc8MmWcQbu7nnTOLs,626 +torch/include/ATen/ops/special_zeta_meta_dispatch.h,sha256=r2kF_JR5SyrMO-byaAOpsWJC6sBfaD8xMrUX4rMK97k,1000 +torch/include/ATen/ops/special_zeta_native.h,sha256=uQSdJHvGDVI7QVWN3UoLxsOYudWRNp5anZbBr6Cv96o,1059 
+torch/include/ATen/ops/special_zeta_ops.h,sha256=aB9UEIhf8yrw6b1n0NRKJrAx59yF-2myFNPpQjs7Jws,4583 +torch/include/ATen/ops/split.h,sha256=8xYbPG2aKBQE12pLXltl-xlZbH0rqgV-eqWH6NLDMiU,2779 +torch/include/ATen/ops/split_compositeexplicitautograd_dispatch.h,sha256=KTpCl0gjkeXxMx1X86oDiJkoGAc9mLuCmexo6-LVDEs,951 +torch/include/ATen/ops/split_compositeimplicitautograd_dispatch.h,sha256=8x1M9RicuUjx-r2IzayYR79lN2Ots7IdKSO-3X7P3y4,967 +torch/include/ATen/ops/split_copy.h,sha256=zcPcMGWm4M5CNZ2m3sS0ydik_Z_1urEH6OuhCwrxvaw,4039 +torch/include/ATen/ops/split_copy_compositeexplicitautograd_dispatch.h,sha256=gR73PuZTNwcSSnMBAGGkmrV8ekqEtAG2k-DEDhRkhF8,1200 +torch/include/ATen/ops/split_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=5Byu1HwXDM9au6wdpRcE_GaJ-Kd9ZnG_27jWWPT2GfU,987 +torch/include/ATen/ops/split_copy_native.h,sha256=oQ_S8jj21TrDmqEdA7IveQL9hyM6z01n85i8KYjRVzE,686 +torch/include/ATen/ops/split_copy_ops.h,sha256=p6abpz9YPy8qkMkMGv7WSNajXjORPYAHv7Q_mCs9DbQ,1927 +torch/include/ATen/ops/split_native.h,sha256=FkYhfY3Zy--vRRKqy13-ww0_4wc91ynKr78CzhM5SkA,669 +torch/include/ATen/ops/split_ops.h,sha256=OVqzhKkrgpCKWeW1o3o0cc7t-2w32Cew6f3tqIoZqKw,1922 +torch/include/ATen/ops/split_with_sizes.h,sha256=4FCLvRGdEgblg5Ltf9KweIXBaSM03OZCnDcyJ3B9fSA,1789 +torch/include/ATen/ops/split_with_sizes_compositeexplicitautograd_dispatch.h,sha256=K74e9PyyuQTNtLe10vaqzfwixlXUtsU4usXfFu8ujn0,991 +torch/include/ATen/ops/split_with_sizes_copy.h,sha256=iXdrE_8FElGG2xx78zVkJvaIlqRJAIwxZxyy9CfvPFo,4548 +torch/include/ATen/ops/split_with_sizes_copy_compositeexplicitautograd_dispatch.h,sha256=l8OImAAcFMwN8rbDWua7xgkUCg7oA9SIVCpQ2pAU3TE,1280 +torch/include/ATen/ops/split_with_sizes_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=j9GHV6yWQRzA2jG7W9RqloV3J_CMGaFyAzC8bD789_0,1027 +torch/include/ATen/ops/split_with_sizes_copy_cuda_dispatch.h,sha256=LroLQJybFIfI0agZBxwLhHc3VEExQHHej52UYY6OKmA,1238 +torch/include/ATen/ops/split_with_sizes_copy_native.h,sha256=zcVU3tcW0h_3wPisjPnnFs6TMnR2HtDWruay5KsopLs,847 +torch/include/ATen/ops/split_with_sizes_copy_ops.h,sha256=PDk6C-DXPJPjT2qoY5WzbcoikQ85s6FRCQ6dkwnUgjY,2010 +torch/include/ATen/ops/split_with_sizes_native.h,sha256=-nAMhd1riJED6Vn2iUiijruY0cPawzbK3HhK-ciRs04,697 +torch/include/ATen/ops/split_with_sizes_ops.h,sha256=r_a1-W4-qyKwlX1Y3s4MYO0iIwt3rnoforOyk-frJi0,1229 +torch/include/ATen/ops/sqrt.h,sha256=1qJjAZEW7gXdCBQCudtlkMpF0b1QefdjlW_cl_-nSAo,1188 +torch/include/ATen/ops/sqrt_compositeexplicitautogradnonfunctional_dispatch.h,sha256=QTqXs_SxcuKDLV-CIwIOZJg5CxlpAtuSuh-6LBXXPJA,861 +torch/include/ATen/ops/sqrt_cpu_dispatch.h,sha256=xDERcizmbnhFj-XzQy9iLNNwBc032VfLBZqNheoEwx8,946 +torch/include/ATen/ops/sqrt_cuda_dispatch.h,sha256=q8mEolLONkn2r1QG9q12_xSNlN57nHYRpyJb84pdaF0,948 +torch/include/ATen/ops/sqrt_meta.h,sha256=8Xuf7pZfIYKbz4YB8dCIXD81f5BarUnGvk2Q1cRuqxE,592 +torch/include/ATen/ops/sqrt_meta_dispatch.h,sha256=jeBk8tAUY6yMPuwllKNpn8wzTokIZ7VSCR0tUoGle1w,948 +torch/include/ATen/ops/sqrt_native.h,sha256=Hnx2cPsD284EJXFMylByhktu7UR1rpdWnk1Fzj0CKIU,1093 +torch/include/ATen/ops/sqrt_ops.h,sha256=HIH-1KWS1v6zUa8KcQchf-q7mEk_XqIRLFko97xuaHs,2079 +torch/include/ATen/ops/square.h,sha256=34TKSbJtwCYyRjK--YIgiJE-EPaBVIGRtQXV8hzkjTI,1214 +torch/include/ATen/ops/square_compositeimplicitautograd_dispatch.h,sha256=FQI127UPL6GuNgFX_7zp2RK-Oc1w8FrvXSNVqd-vCKw,998 +torch/include/ATen/ops/square_native.h,sha256=ptF847iYH-zKvHO15-xsGoWgLcflniS1-78mT6yFNOA,628 +torch/include/ATen/ops/square_ops.h,sha256=a2Y5t2EMjdEIUcpTDdcGJbH6tz4GiGE_pQnDBxgkDk0,2097 
+torch/include/ATen/ops/squeeze.h,sha256=oTNrL2TDqNBFRa7RJsBz2v5D5Q0LRUbCOTk8yepIUls,1258 +torch/include/ATen/ops/squeeze_compositeexplicitautograd_dispatch.h,sha256=nJ2h24X44PKWfSFVb9ivWT-NoaLM4sVOeDnpJvOSmLI,1127 +torch/include/ATen/ops/squeeze_compositeimplicitautograd_dispatch.h,sha256=YQa0AQDbokuxGIt3VY5htPMX2Th7oFAFS5aYLktU3n8,875 +torch/include/ATen/ops/squeeze_copy.h,sha256=BizOcjy7Tsi5aQPxDe6VJci91DBSRG2lokBpMq9WNow,2559 +torch/include/ATen/ops/squeeze_copy_compositeexplicitautograd_dispatch.h,sha256=WA0GcLBGi3tbDYKdYFudmF5WTEDoDJohcNqvCoICjO8,1313 +torch/include/ATen/ops/squeeze_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8ldT8yCEs8zO_0M2FwEDmUOxslTqDkU60Vypu6F6lS8,975 +torch/include/ATen/ops/squeeze_copy_native.h,sha256=G1R-1z1f5ZEY8w4HhrGsVEuAttKbjQLgH8nYKkZI-z0,966 +torch/include/ATen/ops/squeeze_copy_ops.h,sha256=lnYktgqclf5L7aSblRU8wsGjiN8RuQfqR_5djK9Ixlo,4199 +torch/include/ATen/ops/squeeze_native.h,sha256=dS5RyQlYy7mKTIDYo3AEpxgyT8IQ1mGCcvEVKDI3F84,1443 +torch/include/ATen/ops/squeeze_ops.h,sha256=qbPVV7hhbhB7VqqiHdrppL48EcVXgtuEVDL8GC58ixY,5051 +torch/include/ATen/ops/sspaddmm.h,sha256=gkBuEPEsDYs7leQ0G2iPnfjgzV7XRZ3I6rqbMHk6Kak,1647 +torch/include/ATen/ops/sspaddmm_compositeimplicitautograd_dispatch.h,sha256=7taBNpNI3vjCYffW1-58SRhFeRatB6nrRr5R_by0674,894 +torch/include/ATen/ops/sspaddmm_cpu_dispatch.h,sha256=Cq7ZjM8zBltKxYimHgAvn1HheG0qcVl_xjBbtroUaJw,1057 +torch/include/ATen/ops/sspaddmm_cuda_dispatch.h,sha256=UeoJBvARDr6NndrsHanRfMAxHizElTxoQtBar7AEs7w,1059 +torch/include/ATen/ops/sspaddmm_native.h,sha256=F8Akt-sUpEmohdtLldJjNuWKaMFiL-5qD3_duyD7rYg,1374 +torch/include/ATen/ops/sspaddmm_ops.h,sha256=FoHQ9XSiRyv4lEweNDUAy2h8uG238VJ5fFSZiLJWWW4,2280 +torch/include/ATen/ops/stack.h,sha256=bh3CJUTIDaLPGrtpdqrZXmd7Q8jjtAFLdR3oXW753j0,1169 +torch/include/ATen/ops/stack_compositeexplicitautograd_dispatch.h,sha256=FGeWHMT47DqF8bHUaZ089oPkdjKAVrzG_FU4lMhO-5s,983 +torch/include/ATen/ops/stack_native.h,sha256=mt6smclalD-vwdoqDw33XANLjUXtNRDGmxY9LRwVhMQ,600 +torch/include/ATen/ops/stack_ops.h,sha256=pXHsD5ZHBQJrG9qVQqciRhXwlI7eV58lKL_838SY4RQ,1671 +torch/include/ATen/ops/std.h,sha256=kzw4zySdHcClM1jNzZ75wC-01HtEK2YG_ASusmH19wI,4958 +torch/include/ATen/ops/std_compositeimplicitautograd_dispatch.h,sha256=JD5wk81m-1k51m5dxK0tWDNRS2WC8U0xi2bC-5-TMzA,2053 +torch/include/ATen/ops/std_cpu_dispatch.h,sha256=6NK1BLZThgasuP0SyTfjiqfw0o1iaE2BzR05dEJNi-c,1238 +torch/include/ATen/ops/std_cuda_dispatch.h,sha256=5vV-v5ayGmLXntSPlbNtzVCt_EOQD8bKmQef2N0x33g,1240 +torch/include/ATen/ops/std_mean.h,sha256=lQC-Yhs2-k6xhb3iF2i7bZmYK93Rjo7OpDDGrE_fSRA,3277 +torch/include/ATen/ops/std_mean_compositeexplicitautograd_dispatch.h,sha256=_Bq5FHEqHURqhDHdygfJl_eNh0VSfkaO9_V4qaAWkQU,1207 +torch/include/ATen/ops/std_mean_compositeimplicitautograd_dispatch.h,sha256=MQPVwNDlvML_Yfwb2YAD8Jm3hIjN84NL8F8nxRxjP30,1299 +torch/include/ATen/ops/std_mean_cpu_dispatch.h,sha256=fFzgs0mu5YMsZ7Lb2ocCnhaqAHQJle30HjM2OWPUIhQ,897 +torch/include/ATen/ops/std_mean_cuda_dispatch.h,sha256=_qGAMEtJ0mRRsLUB-RGqGC_xR1m7Yr4W512wqdwWofo,899 +torch/include/ATen/ops/std_mean_native.h,sha256=rWpowP7VAUNbXFjmcWUXUSdWiADh6gywBpKozBFgz54,1463 +torch/include/ATen/ops/std_mean_ops.h,sha256=6INFCY7uuS2KD-6tDdYbI5t7e00LpkbKzfW2mJHg-Bc,5780 +torch/include/ATen/ops/std_native.h,sha256=HJk7AxMJKlZLgOZc9zrIi8Zx7WlI_xxSu8YksfQzwko,2043 +torch/include/ATen/ops/std_ops.h,sha256=1eb_JOGk7bZk3_gT7bHz8_NrN7tGudwBVxAfUVQkzhg,7621 +torch/include/ATen/ops/stft.h,sha256=qI7p8uDlsylyRM8e_J0RMMiZazi6l_-y7nkCRUY45Dg,2108 
+torch/include/ATen/ops/stft_compositeimplicitautograd_dispatch.h,sha256=0v3EH32sSquDJuAWZ16T9bzrleCZOL-z1kit1C4zuPs,1540 +torch/include/ATen/ops/stft_native.h,sha256=Spvtsi7ZuJ9EftsFdNQu2G_Cx59ByIcnfh_5ggUQi2E,1289 +torch/include/ATen/ops/stft_ops.h,sha256=nSNjVVClJfqn_3cMs1v_8ws7WX16UT7D3vIn1lUL3SY,3385 +torch/include/ATen/ops/stride.h,sha256=6xVzr-fVspb1scf-JRJjaCxev5xYSPUibUWfvr9_PY4,893 +torch/include/ATen/ops/stride_compositeimplicitautograd_dispatch.h,sha256=6vL2YebnNRUj1JnAEWmH9sNdvds9MoabuamU5YNrAKA,866 +torch/include/ATen/ops/stride_native.h,sha256=Ixw7a8ed3_wj-Ifn8o9ovezLutkCbds2Aze2NMPTMxg,576 +torch/include/ATen/ops/stride_ops.h,sha256=Cpwqai4xbBrkfwXzMPJj-DlRHtrDVt7rI2g1ZRGQx5w,1611 +torch/include/ATen/ops/sub.h,sha256=oxXQvlfUO-KtFtAaAoVhwWIuUucRPxAXb3u6jDC36SE,2192 +torch/include/ATen/ops/sub_compositeexplicitautograd_dispatch.h,sha256=FOtDJwJ5pjNWtI6hyV7n52g9crqkmWe2EnIO5Jzp4_M,1200 +torch/include/ATen/ops/sub_compositeexplicitautogradnonfunctional_dispatch.h,sha256=zUwPmg40X560MIuhO6OC7j9nh8AtX5rADjxX7t6DWzc,967 +torch/include/ATen/ops/sub_cpu_dispatch.h,sha256=93Qns8HTG1If2GJX9QufWlNNM4Je38OXslh7cOq2TmM,1156 +torch/include/ATen/ops/sub_cuda_dispatch.h,sha256=wJbtgTEKFZ7beP-drtA4x4W9b0y-MOll40NxcjO3DSw,1158 +torch/include/ATen/ops/sub_meta.h,sha256=gmum-cYnMROxFrm5hxTYTaERf0VMoFtyzSi-hrwG7A8,650 +torch/include/ATen/ops/sub_meta_dispatch.h,sha256=HXyc6o8InKAmyhHTqTQmdO7ycpHEQYI28JMb8P0hKtE,1158 +torch/include/ATen/ops/sub_native.h,sha256=qwLm2NFV3E9mBhuUJfUNL2EgdYEKG0W4WuJvGrxf3Hk,1614 +torch/include/ATen/ops/sub_ops.h,sha256=gUhk7lVcyGiiN4gVsmg4vzao72LdN0y6ho2QMxM9ASI,4816 +torch/include/ATen/ops/subtract.h,sha256=GX57cG_iBm2ZqlZbwcsESKNNDReCtLsRdHKEcFVNqsM,1654 +torch/include/ATen/ops/subtract_compositeimplicitautograd_dispatch.h,sha256=5cKFIS_u6v72A2Qi2v1woyG9YGwAroMmN6V8zPerU5M,1439 +torch/include/ATen/ops/subtract_native.h,sha256=B7pTgb-mzqJZ3pjx3zLdJgf0COd8Aad8wGVUpTNEj30,1013 +torch/include/ATen/ops/subtract_ops.h,sha256=fQ4ICLLZgv407p67KelXEK3PbCLWfb9Llgcdi9NBaWk,4104 +torch/include/ATen/ops/sum.h,sha256=2UkpARt05wYQnmLX1Ncq_ChtukhmMTcEdFknsYDPr-Q,3475 +torch/include/ATen/ops/sum_compositeexplicitautograd_dispatch.h,sha256=5UR4vjeskuIBslnAYZr8bI-RL1NH2ocjUy4fmfUeurM,1084 +torch/include/ATen/ops/sum_compositeexplicitautogradnonfunctional_dispatch.h,sha256=R_UdzF96ZmDJiKx2-4N3pKPjlT8yXoSTAkLnCIKKheM,913 +torch/include/ATen/ops/sum_compositeimplicitautograd_dispatch.h,sha256=TP2LOM4Rn2nErQSGvLci5lZFPXNckOqzFzNtLnJKMUw,1201 +torch/include/ATen/ops/sum_cpu_dispatch.h,sha256=O-s0dNgF3qBLXU4QqsUdDnI6VXKr_ZZZV1leRTovTKc,1181 +torch/include/ATen/ops/sum_cuda_dispatch.h,sha256=bhcfRmDZXsby6LRSeuiSRd0oElDZcvH1qXDHWZiu47Y,1183 +torch/include/ATen/ops/sum_meta.h,sha256=z2etIzNhVWUE-9uxuatFXK5j0Vt5zk4mGDol1gSa9to,685 +torch/include/ATen/ops/sum_meta_dispatch.h,sha256=gjQ5Zie1P6dTptFMNP_-S6GY4g8MtcwAUZIlXJW8EJY,1183 +torch/include/ATen/ops/sum_native.h,sha256=BpxokbjnUSMrON9_4jPPFaI11oiBYab8dRNqfOfGEa4,1957 +torch/include/ATen/ops/sum_ops.h,sha256=_QH_ChzXr3DpZTP5VCRO_pBiJYVHIJACqsdnRtXAqG4,5276 +torch/include/ATen/ops/sum_to_size.h,sha256=PmeHkkahC_NT3Cw5GZdLypn3J_0UROYGp71YhMzTb8o,1047 +torch/include/ATen/ops/sum_to_size_compositeimplicitautograd_dispatch.h,sha256=Eh3b2z2itQPdXdD5HDj4LiFpleNdBIemtyxIrSP7fkQ,907 +torch/include/ATen/ops/sum_to_size_native.h,sha256=e9G18CwfeNgvsr75xtM8gdWS1WGokb0Ww9iVIDim1gE,535 +torch/include/ATen/ops/sum_to_size_ops.h,sha256=ymIXGbFMKvTaEGxB2Ot11bbrVALq0fR65ihvmzxHFNI,1089 
+torch/include/ATen/ops/svd.h,sha256=muM6kwc_W83AFz4mnyGO5vIL_DVBTah03sXaHwkxbIA,1625 +torch/include/ATen/ops/svd_compositeimplicitautograd_dispatch.h,sha256=xMS1Vg4TEhVGTm6ZcWYI7AEawlySwc-1T2pQwog0YLs,1217 +torch/include/ATen/ops/svd_native.h,sha256=NJjgYqMHrRVhfx4KHm3TduQDUvzQYYgHpTnOa9Y8i50,742 +torch/include/ATen/ops/svd_ops.h,sha256=jP3-D6devWc7c_4FiTG7f_TlOF4L3-mq1LDrMkTT2EA,2173 +torch/include/ATen/ops/swapaxes.h,sha256=BdlxkS_k1tqAuQ-mkwKx71WjFkah1K7Ku0lZyU0JeU8,748 +torch/include/ATen/ops/swapaxes_compositeimplicitautograd_dispatch.h,sha256=oGE0ftTl27XWLqNeC_OZLhyrQBb0uhb1fQ0bFdkZR-0,903 +torch/include/ATen/ops/swapaxes_native.h,sha256=K-M3iLJXOR5dbG8-7aNk2sriyXEneFYglEKmi5D81dI,613 +torch/include/ATen/ops/swapaxes_ops.h,sha256=Rt05WhSVbZeAvC_Y3XRfKkVfPq4O7BX3bE5rv9wHOXg,1714 +torch/include/ATen/ops/swapdims.h,sha256=BoVWK8Q3YAOta-q7ifnL7T0IjXPv8aepqa1vpOTWS6o,742 +torch/include/ATen/ops/swapdims_compositeimplicitautograd_dispatch.h,sha256=ShpqJ1A8XdJn_op5-to-H5VQcnZ4I0kvK8ewFCY4dPE,899 +torch/include/ATen/ops/swapdims_native.h,sha256=vvDg4tfsP-3C4xNlBLPXOpM3y28M4p4WtPq0qoAuT5I,609 +torch/include/ATen/ops/swapdims_ops.h,sha256=0AacSm_TxuUHsP3FZmiknj_DNLcSFPP3zqR739yxICw,1702 +torch/include/ATen/ops/sym_constrain_range.h,sha256=ItyeYxTYyA7wTc0Baor5ULA3rn1TA93dIiEG0WNZJCM,843 +torch/include/ATen/ops/sym_constrain_range_compositeexplicitautograd_dispatch.h,sha256=B0ukZ5ypBaWkS3hE7yJlXjVpF-54z_lZK__0T8vMQW0,884 +torch/include/ATen/ops/sym_constrain_range_for_size.h,sha256=gWuw_tV4vECOjYN1spKAit_iYFvaDxebCe4IxnoqnzQ,879 +torch/include/ATen/ops/sym_constrain_range_for_size_compositeexplicitautograd_dispatch.h,sha256=6VXxgl-sp6GRWFPulOGLHABROczqfj3sJVYz16xXXUY,893 +torch/include/ATen/ops/sym_constrain_range_for_size_native.h,sha256=T573Mt0qZ5LCRa4GOP4RSLnnduFF0jNXBPW4lZuCnK4,603 +torch/include/ATen/ops/sym_constrain_range_for_size_ops.h,sha256=q7E9K9Qz0TL9lryoddOFFCTgoMypHdEhyEF8eljGDSE,1235 +torch/include/ATen/ops/sym_constrain_range_native.h,sha256=i5h-Vju8P_CaRSivnO12EOl8_0dgtkxF_Sk9dbZmSd0,594 +torch/include/ATen/ops/sym_constrain_range_ops.h,sha256=tWK6BzYcZ9iG91pySu0Ht-ajb6VWvRFuvrlq2TZ7p1E,1208 +torch/include/ATen/ops/sym_numel.h,sha256=fIXgKOTw-A05I4hi2TmmkCmgEIZEdzymS6_LrHJhTGY,692 +torch/include/ATen/ops/sym_numel_compositeimplicitautograd_dispatch.h,sha256=FxBVqhhU4ffeN4wq3dt59FYjKPWB9oiyM7voaFa4JdY,791 +torch/include/ATen/ops/sym_numel_native.h,sha256=1WnxowHMnm7LjuWj9_zqrNTIPlv-_DKCicGitrJRPbE,501 +torch/include/ATen/ops/sym_numel_ops.h,sha256=oLGRSiwdiGFwEh8kLiSHGF4hDU6b5yt0t6cfuXw0H0A,998 +torch/include/ATen/ops/sym_size.h,sha256=Asal3Nm2P0DIrGI2_OtpWwR4dvBOUQJidL-qmFezgTo,723 +torch/include/ATen/ops/sym_size_compositeimplicitautograd_dispatch.h,sha256=LHN-y4XpKWxgq9YlcoydfvNSblhxF9dh7MkC8WTuMvM,803 +torch/include/ATen/ops/sym_size_native.h,sha256=YVBN52_BgSyPj81s3Ouhq4Xp2LK9HSlgoAh0_vakjCg,513 +torch/include/ATen/ops/sym_size_ops.h,sha256=yJ0IKxvHyYonMbCljwHKJyXU0qXhKr5obFjefoWV98c,1050 +torch/include/ATen/ops/sym_storage_offset.h,sha256=5lMMJGF-V3LFIrIEo1gjbmxaghRWRIuB_5xoGXVX77A,728 +torch/include/ATen/ops/sym_storage_offset_compositeimplicitautograd_dispatch.h,sha256=61lFZV_MSiP5FzfmHHjHNGP07qyu5FpDGg8gsgW-ydg,800 +torch/include/ATen/ops/sym_storage_offset_native.h,sha256=-k44eCmGjXyRBz1UxuYK92R3LJs2bEoEgS_z2-U2zvQ,510 +torch/include/ATen/ops/sym_storage_offset_ops.h,sha256=Ct-_AvbdUtV9I4DS4aO_2Klpv3ihNwKnYKHA6pSrC8w,1025 +torch/include/ATen/ops/sym_stride.h,sha256=PhHtREbw6vmK5Hn993JV7EadcXaTwjojAubk9ZMlLBI,731 
+torch/include/ATen/ops/sym_stride_compositeimplicitautograd_dispatch.h,sha256=B8StY3b3zATPJlD6QEyLdeeD-sYuX4KwHPreO7UtS-M,805 +torch/include/ATen/ops/sym_stride_native.h,sha256=zc25M8sfzDxMdiE-5IQoY3G-z4PMtOCokW6qh-K8ZPs,515 +torch/include/ATen/ops/sym_stride_ops.h,sha256=weI3mDZibbtQACQ8KKvHg-xcOP0Sdsd2QAtVhBVflw8,1056 +torch/include/ATen/ops/t.h,sha256=whm6p6VkdUX3oZBNdNLNQzJpDJ8_qJfbgKTXnmp9vao,654 +torch/include/ATen/ops/t_compositeexplicitautograd_dispatch.h,sha256=xD0voEmyjpQOBrbDOjsmLShs1qc0uT7JevKfeRSlDvw,829 +torch/include/ATen/ops/t_copy.h,sha256=vAJ_PRBD2w8N7gPQzTMU9Djupl5MmbkdSvBHbfSBoiY,1067 +torch/include/ATen/ops/t_copy_compositeexplicitautograd_dispatch.h,sha256=KBw6tUI2fMm4OdzSkN1xeKh0MB9qvlXEjH0qWbaIM-w,891 +torch/include/ATen/ops/t_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Fv20Mcc11YCreyQqzrL2E2VJnVpKtcxWZq1B1bhjVBA,813 +torch/include/ATen/ops/t_copy_native.h,sha256=vpb6cfIJosVn13J4PfMaIzToSSihI9blfm_UQBV1D2I,576 +torch/include/ATen/ops/t_copy_ops.h,sha256=8hsnoufwD0nhE2NkY_KyDacZGi6QGZj9ryldyoWW5hU,1587 +torch/include/ATen/ops/t_native.h,sha256=HutAtMAC9E04xvXdCmGHnn1SftPhKSPxco049CW6-SE,539 +torch/include/ATen/ops/t_ops.h,sha256=jzKrbLcUz60gOP-wnjdrEGPuafeNhIxWE5jHUVCzHEo,1472 +torch/include/ATen/ops/take.h,sha256=o1K26p4W1faVyTM5KjqM2SP2Wf6fQmr48VX1na2B1P8,1188 +torch/include/ATen/ops/take_along_dim.h,sha256=yWu9sXR8ukr8ZtrZnk7aC3UZg-GJT-_Oza9VxGBfhfE,1486 +torch/include/ATen/ops/take_along_dim_compositeimplicitautograd_dispatch.h,sha256=_bC9R7OcwElfc5z0yIiNchLk0c7MLoZxA0Ei1Q2ykGk,1174 +torch/include/ATen/ops/take_along_dim_native.h,sha256=_e_QJTRU2_FxEznJw5ccooq3rNyiKB08m59DMvZuISQ,723 +torch/include/ATen/ops/take_along_dim_ops.h,sha256=SmDCzIyrPVvwKjDe4mXtL6gP6S8HTxfpOduXlqN1now,2021 +torch/include/ATen/ops/take_cpu_dispatch.h,sha256=fXlVEWiWRrKsg69P0sPensGZr7ZQyg52tw47Q1_SyC8,974 +torch/include/ATen/ops/take_cuda_dispatch.h,sha256=sRzCQEz5iL7boJ2Db_SOpTeUqTJX3ZxXyoyxUvyBMKg,976 +torch/include/ATen/ops/take_native.h,sha256=oe9QqfRwBDs-K06z97Zbb3TwOrUARHS1nVkOt7GAHjI,624 +torch/include/ATen/ops/take_ops.h,sha256=mItR5MIYe4Eri83xNPNITKTCCBzqyKIzyROq1PvZHrk,1747 +torch/include/ATen/ops/tan.h,sha256=u4lLjKaZdqUEYfuzSNJFmzdzeQM_oJvjS-D4rtFe1sE,1175 +torch/include/ATen/ops/tan_compositeexplicitautogradnonfunctional_dispatch.h,sha256=pBs8FI8R7mCvmk-h81GnKW946DrruhxIfMoFWhEYu2U,859 +torch/include/ATen/ops/tan_cpu_dispatch.h,sha256=RFqr3fX5eZQam5a6Ts-orS7SSfp-K6NuJQXJZPGLAaw,942 +torch/include/ATen/ops/tan_cuda_dispatch.h,sha256=onmoa_XDe_v5bsUCKB7ZlpNaKgiPPOpKSvUywHO-55g,944 +torch/include/ATen/ops/tan_meta.h,sha256=AiVyRIUUr7PbNvnoiT5kiZajliLwA1tlCKBUG3e-6J0,591 +torch/include/ATen/ops/tan_meta_dispatch.h,sha256=SGqo2AYJY1sOBMju9D3LQSZ5SXg5Hu_fwT_JrJlpXhM,944 +torch/include/ATen/ops/tan_native.h,sha256=Denjuqb_1o6RzpB7tLJaS__9OuKCr5Jyk77r4GP0SXQ,1018 +torch/include/ATen/ops/tan_ops.h,sha256=OdxfkEHdfxMKxkF7Htbp3V4dACS5LG8F1eGdcWCF6rE,2070 +torch/include/ATen/ops/tanh.h,sha256=a_D97nnYItjW1vBUdBYYHKVaTEjv8OpAZfx7oZT-sMo,1188 +torch/include/ATen/ops/tanh_backward.h,sha256=SiId5JAwPI-j6-LexnU1iVhhvcOrwNA36LqhVgyAlaI,1420 +torch/include/ATen/ops/tanh_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=5vkHy9Z7dXvFgAXI9X_C98wBcQHmtJkOuSNzwXh3qms,854 +torch/include/ATen/ops/tanh_backward_cpu_dispatch.h,sha256=eVmoekzy0wRpt5ECOJaJVddPHvjgpMCXiGKOHM_bFaA,1039 +torch/include/ATen/ops/tanh_backward_cuda_dispatch.h,sha256=nstMIiMoLl2jZn585bg4I3fUIKSS2UJxiYRzN64hRBM,1041 
+torch/include/ATen/ops/tanh_backward_meta.h,sha256=SNywMh0DQTs1jBequXNFjsd2EyhLBS6tq6BvqSdnMmE,635 +torch/include/ATen/ops/tanh_backward_meta_dispatch.h,sha256=pighWhoHvgbfTvfYuRpI3Q-Gf26PiJ0JMZNprR5Dz4E,1041 +torch/include/ATen/ops/tanh_backward_native.h,sha256=oOc5loe-TQsJ5xlxXRgc1eFTDzW2DJjdfbE4dXatpG0,681 +torch/include/ATen/ops/tanh_backward_ops.h,sha256=BfMOxA4xGFIxu0na9hkOxp4Q64WHxMK4MxvGU6XGvGY,1891 +torch/include/ATen/ops/tanh_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ZPC58k7Ib9eVAKnS06i6cK61JbZ5-Mk8oOFJLPqLFmg,861 +torch/include/ATen/ops/tanh_cpu_dispatch.h,sha256=CKlMvSAJl3V4l6BXHvR1wDlpnkMln-G22gAAV9Ba90w,946 +torch/include/ATen/ops/tanh_cuda_dispatch.h,sha256=RpSI8PZZgmMcktjtoNS_BFIRS1RFIUbjBkGFX63uOTo,948 +torch/include/ATen/ops/tanh_meta.h,sha256=bnrdlMCDZJJbCs3fm1l5RH0QT2aC3ckWMiNfA_iGoJI,592 +torch/include/ATen/ops/tanh_meta_dispatch.h,sha256=lUQSJ6owl__AEBRJrSfhT50MHSRLtF0ByGnRcKWjbHM,948 +torch/include/ATen/ops/tanh_native.h,sha256=vkT47hO00F40azhAVDV9RqunD8hJ_5ji15fIDCF3rBg,1340 +torch/include/ATen/ops/tanh_ops.h,sha256=UcrgQtwzEQ7XcA_qLxplGI_BiVYUWNDI0wDuwMpRM0U,2079 +torch/include/ATen/ops/tensor.h,sha256=dnscgDxX-RB-pvMxgnPEjrIhdi8Yk7VkZFtWcwrUp3w,1661 +torch/include/ATen/ops/tensor_split.h,sha256=c74ALSJObz5slH3q_4JrxQmyD42IVg_s5vkpkPiggzU,3281 +torch/include/ATen/ops/tensor_split_compositeimplicitautograd_dispatch.h,sha256=WQI72ijlbuUnyK2S1K0eM1Y3oIbNlFjSg71Pijuk848,1342 +torch/include/ATen/ops/tensor_split_native.h,sha256=CSzksmPgkHMLxGJe7V_NZNN_cnVz7cQ7D-ywpk96was,844 +torch/include/ATen/ops/tensor_split_ops.h,sha256=C2empylzIWxMpSrcdrBV6Or6v-xXHumsiUbd0kak6b4,2823 +torch/include/ATen/ops/tensordot.h,sha256=Ksqg-7DfwCceSpNEXmacC8-ilOkC-18ouT9qosbvMac,1577 +torch/include/ATen/ops/tensordot_compositeimplicitautograd_dispatch.h,sha256=LpR9nH6HOyAf7XRldjdpdAsDjv2N2caRTaOFg2lzbH8,1198 +torch/include/ATen/ops/tensordot_native.h,sha256=Oy94bx9hLoUSFjkR6DW-nc4_i7y1XMaJeu_qVjrcVXs,744 +torch/include/ATen/ops/tensordot_ops.h,sha256=Lnw4nZJfMzXrH4xxwIkPWrST9o6Y3QiV7rYMnSBbCBI,2135 +torch/include/ATen/ops/thnn_conv2d.h,sha256=xKJLB5NYYICAVId1Pue9EJCOQyvFWfDheuyjFXuDxPs,6712 +torch/include/ATen/ops/thnn_conv2d_compositeimplicitautograd_dispatch.h,sha256=tNISurfKWd-_-iyKd077uwbOpTHspvS6Vhk5y_fACf4,2209 +torch/include/ATen/ops/thnn_conv2d_native.h,sha256=yjvPlzdQQJkVLmE8QDoMMOmwYfLQWBt3mlveKwnjN0c,887 +torch/include/ATen/ops/thnn_conv2d_ops.h,sha256=X7OycCWVdVyYWilcVJ0SdkQET6SdtrR2-P4Je3DBDqM,2689 +torch/include/ATen/ops/threshold.h,sha256=XULTkm4QsFrOB8sHkju9qYyDAOtk5XB18qlH8RI3BWI,1677 +torch/include/ATen/ops/threshold_backward.h,sha256=CXGsNLLEC8m0lQAmWQWIBJyjUqfz6RlSVmhYMrkduBw,1629 +torch/include/ATen/ops/threshold_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=v3-r-aKC_tAQ6gs2T2WFy-Y0t_hBFP_mo9oRg_eZI7Y,887 +torch/include/ATen/ops/threshold_backward_cpu_dispatch.h,sha256=6X3y_K6yfwUcIHHJeytWmpB93VWxLYv_ADBMBLh_O5A,1138 +torch/include/ATen/ops/threshold_backward_cuda_dispatch.h,sha256=44hFSR599cIFKsTR_3-2eVMFkCwnfAv_e_HnkrKtNac,1140 +torch/include/ATen/ops/threshold_backward_meta.h,sha256=7tnrDJDWm2XiW4jTgxZs6W3XA2xo1rRAGQS-hQnB0K4,668 +torch/include/ATen/ops/threshold_backward_meta_dispatch.h,sha256=T9dHLJlXjkvhOIc522RcvculYi6gA4D3P_duepzaLHk,1140 +torch/include/ATen/ops/threshold_backward_native.h,sha256=rYnJ9t404C-sGAfWMLIxzEqiHy0TUEDqqE66LQWAUGE,1620 +torch/include/ATen/ops/threshold_backward_ops.h,sha256=Vvt2rrdBxdrZ_jpu7Q-osZ6O9RDItGNDIw5DkWGnEHs,2105 
+torch/include/ATen/ops/threshold_compositeexplicitautogradnonfunctional_dispatch.h,sha256=NBplj1j94L9bqD-xy73StPvtWWxZ4k42NYDz5dOK4P4,983 +torch/include/ATen/ops/threshold_cpu_dispatch.h,sha256=euv1dRGCprhE5NRuXFphAkl-ZWKtcVGVlcRZKRFayos,1190 +torch/include/ATen/ops/threshold_cuda_dispatch.h,sha256=7FjjxNkPEXJj9dGAj92pFi_3ehE7SNU-j8u_03Sh-2o,1192 +torch/include/ATen/ops/threshold_meta.h,sha256=pCwze0Q8fxY5INRzMNr4RofjI4aFRIiFLI8L6zo80uk,653 +torch/include/ATen/ops/threshold_meta_dispatch.h,sha256=CwyHogBaYaehlLp4f1LrUjupL6GJYJ9ntDqKBNnH_oA,1192 +torch/include/ATen/ops/threshold_native.h,sha256=317u55lQOmHYv2DhvksVbWBSdWCvXiPSemidnUdoWlo,812 +torch/include/ATen/ops/threshold_ops.h,sha256=dFVP-xRtFaPKnng8PrGS-VQx-XHocR1hS73HOXpt5Es,2676 +torch/include/ATen/ops/tile.h,sha256=RxKSsWYKKErFz6W_-lzlMhXOncj-UYeccvG7pAM_7Ns,1399 +torch/include/ATen/ops/tile_compositeimplicitautograd_dispatch.h,sha256=9E9XZyp32X6y2Y0SoaqBZDnT7ljPl_nJlivqAbChie4,893 +torch/include/ATen/ops/tile_native.h,sha256=C3FtAPnYv9az7fhVB4X-rEFzLDnB27ooWg_NyQAaFmE,528 +torch/include/ATen/ops/tile_ops.h,sha256=z0G-aQJQRxoSXQTWLEtqqjxZD0uCQAOw9XYKZZ3NX24,1068 +torch/include/ATen/ops/to.h,sha256=9LAZDHwJs6I_dCYxA15GRCurgTkGRLNpqCUicaPE5Ks,526 +torch/include/ATen/ops/to_compositeimplicitautograd_dispatch.h,sha256=H6TJAh_zD07d7CSZ78w6XB1KJ1a4BEr7SoepTjb2WmU,1754 +torch/include/ATen/ops/to_dense.h,sha256=q3UA8lJyVj2pkI5T92VuxfXoknc4LmlAkqJYenAZYOo,532 +torch/include/ATen/ops/to_dense_backward.h,sha256=KoQn_ez-G2ZdkN_IQtKqpm_ViWfdx7m4-LJzb9l4v1A,846 +torch/include/ATen/ops/to_dense_backward_compositeimplicitautograd_dispatch.h,sha256=wEuqZCNCo9JCTGZ6wL1QUW4enmF5857ZeajyZgRfR5k,874 +torch/include/ATen/ops/to_dense_backward_native.h,sha256=fbuYnCtQbqpOG77DXmzf4lQVDahmth4GtDgQMJPTw_E,584 +torch/include/ATen/ops/to_dense_backward_ops.h,sha256=4_s38jts_HVUJEyEsqTAZekrm81YULPHAf7rzX0XFTo,1222 +torch/include/ATen/ops/to_dense_compositeimplicitautograd_dispatch.h,sha256=JjrYQXoRBrk8eZx-7a7bXsjPKMEy-FwZv5E5GEzqiSQ,893 +torch/include/ATen/ops/to_dense_native.h,sha256=XiwWqRfbWHygC6NdK8VcRqLgh_h6es7RKBTvUoHr5bI,603 +torch/include/ATen/ops/to_dense_ops.h,sha256=XJwpEtX-C8af4IiLVeU7ZM78urraHMiuRKLBvW6YdMw,1247 +torch/include/ATen/ops/to_mkldnn.h,sha256=yjNLB0N7J-hE5viHyKWAG8O0q1u46BF8cbgcomK65VQ,1103 +torch/include/ATen/ops/to_mkldnn_backward.h,sha256=qwHXLD13fi_6wDOoxtpo4MJFKZ0epQP111pBdmpAowE,763 +torch/include/ATen/ops/to_mkldnn_backward_compositeimplicitautograd_dispatch.h,sha256=2-2f7I0xondbBFJDNmdU-R205uPBFbVVc1rio6jQ7u0,825 +torch/include/ATen/ops/to_mkldnn_backward_native.h,sha256=FNFQTPkxTxIvVdNr0qfWNEGRfA39GZ-BoTps30jJk_o,535 +torch/include/ATen/ops/to_mkldnn_backward_ops.h,sha256=ggAfrj52VeI0GDYq9ZC93T3C0AVU6X1GriB8gZLYNjM,1108 +torch/include/ATen/ops/to_mkldnn_compositeexplicitautograd_dispatch.h,sha256=r8kW2XeUSEl_3izbmtiKU3IjrGScq3bp40abLXGvXmU,990 +torch/include/ATen/ops/to_mkldnn_cpu_dispatch.h,sha256=PgBJkBxazUvHIXWazXSEvT-fq6zgymEXIH1KIQY7UJ0,800 +torch/include/ATen/ops/to_mkldnn_native.h,sha256=7mWH-x_zdjlNe6bivmcWpgnRxwuBzu3fP_y6wxlMXmY,681 +torch/include/ATen/ops/to_mkldnn_ops.h,sha256=ofYbC7GnIebhBIXTTlZmuO8fD9L7_RMNgVLBAg2C8sg,1875 +torch/include/ATen/ops/to_native.h,sha256=z9qeTd7BVGvoEix6D9igSI-vm3H9lx_NzDPGYXUp9SU,1316 +torch/include/ATen/ops/to_ops.h,sha256=X4zrl9sqipkHG4-xkkbBH0DwqmEjcQQ3NkVIPvbdtaU,4514 +torch/include/ATen/ops/to_padded_tensor.h,sha256=piYm-sjljVtNUpmZHQG2LBQ7cWpIYef0AXJyfEwlt1I,4431 
+torch/include/ATen/ops/to_padded_tensor_compositeexplicitautograd_dispatch.h,sha256=C_yI17c00Oqo-7SdCx0r9jjZZH5XtJJFgzixZvz5JvE,1352 +torch/include/ATen/ops/to_padded_tensor_native.h,sha256=T-gULykHKBqn7c7l-05DLRWEWR6SfrYKMZ3QzBgXhK4,899 +torch/include/ATen/ops/to_padded_tensor_ops.h,sha256=Po3dVBvqqqvETb4d6fof3FXRSKQwydsCWnydwlFlAak,2029 +torch/include/ATen/ops/to_sparse.h,sha256=5bqeDmBgsHJ2qVSZbF8mATLuqIobqd19CgMr1QWe3EQ,533 +torch/include/ATen/ops/to_sparse_bsc.h,sha256=LBkSU_Ir6f8FJ7_WzKy7kFCy-aqXvUZd0l6cBfkZ4r8,537 +torch/include/ATen/ops/to_sparse_bsc_compositeimplicitautograd_dispatch.h,sha256=mjZCrY4C8JsSXibaGWxaJQLrqw7ZIsTgAdMaQFSRk9A,872 +torch/include/ATen/ops/to_sparse_bsc_native.h,sha256=5jjs1AXgt-8kqyAeThS3qznKxt9qVPLDsLxvJTAeKrg,582 +torch/include/ATen/ops/to_sparse_bsc_ops.h,sha256=AoMNRjY41YCceewZlxm8YAWZkKcMuHgLOLEdMJh7rQc,1215 +torch/include/ATen/ops/to_sparse_bsr.h,sha256=-qqEFf-D0A8ep_L5-wqGbBGlLihCT-KM8ce3K_f_0xo,537 +torch/include/ATen/ops/to_sparse_bsr_compositeimplicitautograd_dispatch.h,sha256=-Z8a8srtJ7gxPkZR95aLWf8k036sILO9SdnJbiw3vZ0,872 +torch/include/ATen/ops/to_sparse_bsr_native.h,sha256=kI6T9qT758iTwLUAZewTJD_hhnUJsye-CH4fNeOhZrY,582 +torch/include/ATen/ops/to_sparse_bsr_ops.h,sha256=FW_0aEy8ww-romy7L0Eb-xWLjwI_Q7mFtY_sZE9WWcg,1215 +torch/include/ATen/ops/to_sparse_compositeimplicitautograd_dispatch.h,sha256=nOh5gyjB7wJTt2b4HpiaoJrBE9Rd-tbLkEgb-xQOLWc,1020 +torch/include/ATen/ops/to_sparse_csc.h,sha256=jU-IclOSjYsqqe-w4mb5jlHaA2JmsKLrmBXRbkq1mbo,537 +torch/include/ATen/ops/to_sparse_csc_compositeimplicitautograd_dispatch.h,sha256=lrsSgg9971D0-KvQJDzw659HCvq21Lkx0_vUbli5UrM,845 +torch/include/ATen/ops/to_sparse_csc_native.h,sha256=g-JursgyGt_UN0QnjCFNIj5_OIjgVthUIuLk-TGNC_8,555 +torch/include/ATen/ops/to_sparse_csc_ops.h,sha256=HgX195Pk_8Wyc3wsvZwUUMgbYtXuZQ09HIcdQXqczOE,1126 +torch/include/ATen/ops/to_sparse_csr.h,sha256=EBPziwnqB3UotOTTWuCNvzjbwe4k4OqFkS9EIYYXfMw,537 +torch/include/ATen/ops/to_sparse_csr_compositeimplicitautograd_dispatch.h,sha256=OAznos8X8okvLGLIngbV0ae17JUb3CvfR1wlhM0iia4,845 +torch/include/ATen/ops/to_sparse_csr_native.h,sha256=0ju4clM1dp7gbQCODVfLyHNg8AB-IHL1zysgsrCcltY,555 +torch/include/ATen/ops/to_sparse_csr_ops.h,sha256=B2kBxkQ9i75LLI7G4h7zrEYkS76eQ4OTAkjPSoXXKZs,1126 +torch/include/ATen/ops/to_sparse_native.h,sha256=dTQZV1n2LGsYgTdvZ66kuYe3A1KGMBRXNTCptbuoTfM,730 +torch/include/ATen/ops/to_sparse_ops.h,sha256=lS52LP12DYPiAt9oIucwamHbCVjnG91ujxtV5SBMB10,1975 +torch/include/ATen/ops/topk.h,sha256=bR9QI01kMRBxftMsgtR2oj9Xtja1XdlMhAPRWdnLBQk,5215 +torch/include/ATen/ops/topk_compositeexplicitautogradnonfunctional_dispatch.h,sha256=f_p_4NoiUwQxxZW6kzgIdW-aGqkBYj-My2IqBklqHNQ,1053 +torch/include/ATen/ops/topk_cpu_dispatch.h,sha256=YEGiTCimkgDo0_J3XD9LzRx_9aapD0qYyRIz5l0PpY0,1753 +torch/include/ATen/ops/topk_cuda_dispatch.h,sha256=2X8nFMfr2SoT3KKypOAtbYlZA0A1s9IGfRHVvTpws8I,1755 +torch/include/ATen/ops/topk_meta.h,sha256=3PKM5AXzwFkSQwjKNuOmn7aTN25UDQlnQkh3HjfhcaQ,643 +torch/include/ATen/ops/topk_meta_dispatch.h,sha256=FBhegHhvUoiehL7D7ht8bnHH9pHfaMjG21zZIeACn4U,1755 +torch/include/ATen/ops/topk_native.h,sha256=WQ-w4Sm2AHPcI5HX9n7Bpuwk-yucppX-gQF8XqlmWBk,1082 +torch/include/ATen/ops/topk_ops.h,sha256=f8UuWf-g5gSJHTI6N9GLK8JK_Xi2Q5eUirJebIdJqDE,2287 +torch/include/ATen/ops/trace.h,sha256=8sz2moijgweUanJ1tCyxjPN4p3oLuZP0tUfWZFVXsZw,1057 +torch/include/ATen/ops/trace_backward.h,sha256=GddFyN8x8U9oeNjicL1jZC-mvHnnTzQrymIpgnnWgCE,1519 
+torch/include/ATen/ops/trace_backward_compositeimplicitautograd_dispatch.h,sha256=f8o4APXS3-pPo88Qb2ThHKRGVAvchAGncNu0ytrRkcE,915 +torch/include/ATen/ops/trace_backward_native.h,sha256=xQ711dYVQbaoA-8SLxYvMsk4IDmT3O_x6ErIPMqKPG0,539 +torch/include/ATen/ops/trace_backward_ops.h,sha256=4W50Ftoc6d_GU5Ul_PvHRlpBwZ3kwSgb11y5SuEPYQI,1101 +torch/include/ATen/ops/trace_compositeexplicitautograd_dispatch.h,sha256=kRVojARnoGzIHZ-w_USXmeJnVgkvwPsZZCcLFNJgnlk,889 +torch/include/ATen/ops/trace_cpu_dispatch.h,sha256=14PNz0vYwezNZ5bEkIskUDMxn9QObxgVx8G79ShHdG8,742 +torch/include/ATen/ops/trace_cuda_dispatch.h,sha256=_NE1MjNWnM2o1tm9r0q2hYuHjh26_rYVRZ7GW721XG0,744 +torch/include/ATen/ops/trace_native.h,sha256=Fd1z1kNNMMfIOxP6cRgJUkUhvtweM2S1OTAyAQ6uy4M,637 +torch/include/ATen/ops/trace_ops.h,sha256=K7jP4g5ioMDHK0sz3duflRLwFssV2YBehHUAL_HPjVY,1581 +torch/include/ATen/ops/transpose.h,sha256=Pw18jkM0z4IYJr3ioxFnYUSQv00euCa04sXzlb1laiA,1001 +torch/include/ATen/ops/transpose_compositeexplicitautograd_dispatch.h,sha256=RDPkkn2vvSn0bFBo3l_-VYeBwwz8Ix71O-FfrWwPQk8,901 +torch/include/ATen/ops/transpose_compositeimplicitautograd_dispatch.h,sha256=P0I4h9H5Q5rWIPtpzR8tjqujhQjeSvf2obi1ALZLmVA,826 +torch/include/ATen/ops/transpose_copy.h,sha256=ZI2uwCQxJ6k60wL9-X52IBUpOA3tjDOVQ2LqXjXLcc4,1351 +torch/include/ATen/ops/transpose_copy_compositeexplicitautograd_dispatch.h,sha256=0RL4PW-Pr_D_2hOPmLA9-Ia9BG3vBz_9yurHKWYOmfE,963 +torch/include/ATen/ops/transpose_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=2wNSZIEtMSx-hm5lyykkLKWo-agQKjomZev2ec90nM4,849 +torch/include/ATen/ops/transpose_copy_native.h,sha256=Fava16VMMer9s-0r7g1MfuGN1-yyDx0oZariBFQUbro,656 +torch/include/ATen/ops/transpose_copy_ops.h,sha256=8ThS7cMD4KruPaOI9cOtzBK7sp0OZGAZ_GPJllhuxvk,1846 +torch/include/ATen/ops/transpose_native.h,sha256=DGogYgO8nIFIMDupzCufTikHQc1OLSAmKx1d8Tchbk4,798 +torch/include/ATen/ops/transpose_ops.h,sha256=VRtHvVnmPZrKtlCBQTYXN4TEMcclemyP3fUjJocD1N0,2394 +torch/include/ATen/ops/trapezoid.h,sha256=c3eotAq6zsXWc-ABSlykkf8p0M-XFIIShEosqS5JGrA,970 +torch/include/ATen/ops/trapezoid_compositeimplicitautograd_dispatch.h,sha256=O7TSDzWShceroJ3CLlJ0hq-9UmOca7IOqklx5OVZXmQ,921 +torch/include/ATen/ops/trapezoid_native.h,sha256=4leKQOVjDlgxh1zFNn53JUZHg-beGMyOZNGQIIM3AGg,631 +torch/include/ATen/ops/trapezoid_ops.h,sha256=BUyiXwIlQsGYnBr7Ixn9SNt2j5jVIJgpk2uwWxTXn_o,1763 +torch/include/ATen/ops/trapz.h,sha256=4IaS54HuapraI2fWVdIojwHraNrb-o1scY57uA1i1zs,929 +torch/include/ATen/ops/trapz_compositeimplicitautograd_dispatch.h,sha256=5V6N34UBW4kGMI3Hjtjf6ZzMhvod1xifgYC7o5U3sxY,901 +torch/include/ATen/ops/trapz_native.h,sha256=j1CfPSlc89ZTnuS16gnQ9nEoInAS4JMVX6560TqqB-o,611 +torch/include/ATen/ops/trapz_ops.h,sha256=SL0LR5Ix02pMDvv28Mu4smh1f7JaKbhEzDCWRSb3BX4,1702 +torch/include/ATen/ops/triangular_solve.h,sha256=UGaVE_KsVzABYSdhyrvM-dQav8Vulji7oEhZCRtj6WI,1991 +torch/include/ATen/ops/triangular_solve_compositeexplicitautogradnonfunctional_dispatch.h,sha256=eOwaiPHrq_IT4rvgsXoRFf5PWugDYHCQjQ_pmqmgYcQ,935 +torch/include/ATen/ops/triangular_solve_cpu_dispatch.h,sha256=Djl32Nb6KoagVZmOfQUKxMTnKJgaDosRRzdDaVujeXQ,1283 +torch/include/ATen/ops/triangular_solve_cuda_dispatch.h,sha256=GG581uC9rW8eyOq49DygUlTxeyP_WN2ZpX2yYRMcgLg,1285 +torch/include/ATen/ops/triangular_solve_meta.h,sha256=d7_AyXDRLWwhKm2cF_7Sh_Fo9hNyJnb1EYBMtfi_nqs,674 +torch/include/ATen/ops/triangular_solve_meta_dispatch.h,sha256=lG_CfOtV6n-CN5br1jRfAjhgHxVDtJd-qWHgw9QWIiI,1285 
+torch/include/ATen/ops/triangular_solve_native.h,sha256=ghpno7brCSFM9H9fB8mPg8UgtWf5RD_HKksj7SV7-AI,1170 +torch/include/ATen/ops/triangular_solve_ops.h,sha256=jUrzUtyEnOJpSm65rWP-ByBsVcQeuPW7Sld6Kx6-K_A,2437 +torch/include/ATen/ops/tril.h,sha256=YTb6kcT3ieatUmb9FoHFyEe3KquXJItzWaeS0Us8HKQ,1183 +torch/include/ATen/ops/tril_compositeexplicitautogradnonfunctional_dispatch.h,sha256=d6hV1Fqj3vYDYSREONBPDO2jar_vhDVkYn__GVI0UEM,901 +torch/include/ATen/ops/tril_cpu_dispatch.h,sha256=KZJXlR1zq_C7h6cRm1bZfrSemFfVSILbvkNc4nLkRec,1024 +torch/include/ATen/ops/tril_cuda_dispatch.h,sha256=7lVLslO-FcHlkqkDARoTCuU57e5TPXorSgePN_KScbY,1026 +torch/include/ATen/ops/tril_indices.h,sha256=gL9xsYZyEWxe_V1Jy8o1AnWkhT22WtFwRlOYJTD_4uc,2004 +torch/include/ATen/ops/tril_indices_compositeexplicitautograd_dispatch.h,sha256=4kBB6CgV91Fwk89J7lrQ1h6bDfgsftFiHEPYxoUMHiI,939 +torch/include/ATen/ops/tril_indices_cpu_dispatch.h,sha256=FEFCGsqhOrYXS0vRY4TNW_lxHfdEH6w7a7bAMTUT6EU,1028 +torch/include/ATen/ops/tril_indices_cuda_dispatch.h,sha256=vvqI9t7kFzGfsSM_igDxGC7-IqNqLq-3fy86WulS7bo,1030 +torch/include/ATen/ops/tril_indices_native.h,sha256=jZwy9kWZIfl-uTwMGvD9GY8nilP2AEiACvD7gV8urOM,1027 +torch/include/ATen/ops/tril_indices_ops.h,sha256=6ttc_M8x6iHwSpws0zgCOLfZEIaOx-llkFb53M8ZyG0,2239 +torch/include/ATen/ops/tril_meta.h,sha256=bs9n2gI_P3OIpHpDGkE0-ZCxKDsYLpkQP1JG0eTnpH4,610 +torch/include/ATen/ops/tril_meta_dispatch.h,sha256=CHipFTMeJONiVhr3hU3THW8dqNUEFKKSLk2Fo2bd5UA,1026 +torch/include/ATen/ops/tril_native.h,sha256=guH2jcdtKgHNVNis9MPi5ogt_LTQyoxM3jzvrKxHEJs,790 +torch/include/ATen/ops/tril_ops.h,sha256=lYpPPKy8TPAgnDoCnCQQLO1C_vf77bQhNZbF3GBeW10,2262 +torch/include/ATen/ops/triplet_margin_loss.h,sha256=Rs4ZbXxmnAinAR9XKgvuDtlU0k9dp_WEtAIdXGH9dBA,1059 +torch/include/ATen/ops/triplet_margin_loss_compositeimplicitautograd_dispatch.h,sha256=TVOs_9VsxDQlJ_FYR9GOVCqZULpdwWd5RxXKY_TELLk,965 +torch/include/ATen/ops/triplet_margin_loss_native.h,sha256=hJJ8VXGYbIWOeiaS32U9Cu3SPDbKpnxeAtn00ZLm9F8,675 +torch/include/ATen/ops/triplet_margin_loss_ops.h,sha256=u3pICziw9j5MXAv3YVTGrhgrYOnUu5dbf3XQz4QqR4A,1477 +torch/include/ATen/ops/triu.h,sha256=nrkFLbsT74OVwHcKq7BXx0o73weI-AOxoKMJHBHb25E,1183 +torch/include/ATen/ops/triu_compositeexplicitautogradnonfunctional_dispatch.h,sha256=AkXr_PC9K288WcUkfjUIDA1z8I7qj7dogup0VWcvIaM,901 +torch/include/ATen/ops/triu_cpu_dispatch.h,sha256=DflxDDjsyThC3UAXpljK2Fc11cYntn-50bP2qGLW7GM,1024 +torch/include/ATen/ops/triu_cuda_dispatch.h,sha256=hwOdALS8Nm-Lj2qwmV7rpS0Q8W6YwUCT5MplZ3TnCN8,1026 +torch/include/ATen/ops/triu_indices.h,sha256=lD0JeJeqmNXVfdWEmvQyHXllWNwMQ80sceW7NeS6l-o,2004 +torch/include/ATen/ops/triu_indices_compositeexplicitautograd_dispatch.h,sha256=l33gqAub1sdj4eDj6IivjYK933rNzFFU5oqIA96qhXQ,939 +torch/include/ATen/ops/triu_indices_cpu_dispatch.h,sha256=lLTez5ya8UF_cWqa1yiXcT4LtUxkNrBTle0WUGHmWeY,1028 +torch/include/ATen/ops/triu_indices_cuda_dispatch.h,sha256=p5RwI2sVPQRb7pFrK9SxkSMNSCtT7D5WghSqcg4WALU,1030 +torch/include/ATen/ops/triu_indices_native.h,sha256=X2FqVjqNNOl9n154uCBgQSHlqxQ48R72I8pRFkrhKAU,1027 +torch/include/ATen/ops/triu_indices_ops.h,sha256=CalGEFZXRcyMZVeaLBo0kfDExJZmpY_unyEpW8PAQto,2239 +torch/include/ATen/ops/triu_meta.h,sha256=0gU3hfKtAB2loXiO3EYtrAU6id-Wci0Lsu8aB1VanBA,610 +torch/include/ATen/ops/triu_meta_dispatch.h,sha256=c-tUUDjt2FtATC6Bw-Uwy5kI7xSIyX64O8kByu9_rL8,1026 +torch/include/ATen/ops/triu_native.h,sha256=cGbzr6lJy0rCI71RA4r-VDKS0YUeap8O6tNZbOLKiR8,790 +torch/include/ATen/ops/triu_ops.h,sha256=wxfj0abrWevnKVWJ979bGE_IYVED-ua5Fyj1va04q-s,2262 
+torch/include/ATen/ops/true_divide.h,sha256=yoRC9wCUo3koTgMC53EOT3iSui9NvOEGtOW_n-hFoSA,1488 +torch/include/ATen/ops/true_divide_compositeimplicitautograd_dispatch.h,sha256=-i6GmGjJzPZbM6KK7AFrEXNU5Wq1KWSZLZWbiMiLj9g,1291 +torch/include/ATen/ops/true_divide_native.h,sha256=2XFtI73sPI5FzpQg9Vb2MQQUfvZVTwyE7Jpylo8zMSo,890 +torch/include/ATen/ops/true_divide_ops.h,sha256=ScySZi28HVBtJnV-uHuwhV4dWCy1v8x_A3WFCY6-3Ug,3703 +torch/include/ATen/ops/trunc.h,sha256=2Cmt6g0yWY4zsbxDqI-4KuCb_SIK8duRJ_ISEgLALc0,1201 +torch/include/ATen/ops/trunc_compositeexplicitautogradnonfunctional_dispatch.h,sha256=ySkIUCH2DLCNFatJXj8pIS89gjOsAvgtCNmFx1QnYxQ,863 +torch/include/ATen/ops/trunc_cpu_dispatch.h,sha256=80T6Na4aQIBVooU0bR9WV6tlH2E230J6t9_j8R6UQLs,950 +torch/include/ATen/ops/trunc_cuda_dispatch.h,sha256=RvKDQ7snoa2GFts9-S5-QbseNSStd50Mku2vBV4IxMs,952 +torch/include/ATen/ops/trunc_meta.h,sha256=tkm_9EHdc6pognZx3fsjDb5UW91x7uFyEWwf2DMUjOc,593 +torch/include/ATen/ops/trunc_meta_dispatch.h,sha256=K3Ov1qUK_LF2Z2nV_0zZVwSxdAe_jKGWc-VeCouoRP8,952 +torch/include/ATen/ops/trunc_native.h,sha256=ks0uOEgqhLRl_AHg87j5DDODe-Orz5MRUG1yWCa07Uk,1036 +torch/include/ATen/ops/trunc_ops.h,sha256=OBrtyCC3BKKBkllQXDDS681C2THiF260n2rjtvL7Vw8,2088 +torch/include/ATen/ops/type_as.h,sha256=wSq3kVFeycRaoXPuuKxu4KuFOIGnrVCMqlNaTpqFpBM,531 +torch/include/ATen/ops/type_as_compositeimplicitautograd_dispatch.h,sha256=85RLzfMI8a_0YYCf9BDcYR8TqGIC51G2KaD5NwE7Ut4,814 +torch/include/ATen/ops/type_as_native.h,sha256=IMKc9f7yqAdytUW3YIs6lxTndVuOyfuoBFhONiqIjPs,524 +torch/include/ATen/ops/type_as_ops.h,sha256=JzrS51pXzIDWXuVQ8XSytBxMo_g6JTZe1C_coalx2t8,1075 +torch/include/ATen/ops/unbind.h,sha256=qvH1brPVKz8rRAzaQAgA4tOMVmahYVNFIqtL7_FOKlo,954 +torch/include/ATen/ops/unbind_compositeexplicitautograd_dispatch.h,sha256=VBc2esqM2gUTx7HbEvcnCdaBuhjfMfBo9zUQ0eTPTRM,817 +torch/include/ATen/ops/unbind_compositeimplicitautograd_dispatch.h,sha256=4fdvEgQy16o6YGhNDOUVyfurpJ0MBQ3mlK-83hLc-vY,819 +torch/include/ATen/ops/unbind_copy.h,sha256=2YrisXb5jv4wLHdlHZG0-zqUOA1H9ONNdqbZ4hZPyaI,1225 +torch/include/ATen/ops/unbind_copy_compositeexplicitautograd_dispatch.h,sha256=vJWiFGUd1iUn7r7NR2CuFdJmrR9gl4kbDdr06ycUHkU,917 +torch/include/ATen/ops/unbind_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=wmrLMNM0Vmdq25Z62vTF8j1nVjRUrkI2ot_LYY-vj4k,848 +torch/include/ATen/ops/unbind_copy_native.h,sha256=KI7Tfi9wjLx6WLSLVDHd9YvhbC3F1ylT3ox5vnO5Im0,631 +torch/include/ATen/ops/unbind_copy_ops.h,sha256=kFCMqziDm9BC0Do0XylBFOOVOvk1YusBnGdD6Vkvpz8,1755 +torch/include/ATen/ops/unbind_native.h,sha256=swwPrH-Z47_CBu8dv7lI3tcL6eB0SYPXjcWFPK9ITRw,712 +torch/include/ATen/ops/unbind_ops.h,sha256=Vi-l085OJKoDqCikdHEllCrZfO4WrchJmezLJglns1E,1753 +torch/include/ATen/ops/unflatten.h,sha256=MA8jjZcDevB1Digleya8j54jCOauPE1y5YGgJJ1FeOI,2851 +torch/include/ATen/ops/unflatten_compositeimplicitautograd_dispatch.h,sha256=qG-ogyLkW_XZd0TXSDXP1N0kIIahfgJOUT0-_VI7bY0,1184 +torch/include/ATen/ops/unflatten_dense_tensors.h,sha256=FfzzHD3waVXd-MyaOfUwVU0f7fELMCsjf5ujWzX4eoA,804 +torch/include/ATen/ops/unflatten_dense_tensors_compositeimplicitautograd_dispatch.h,sha256=wzFXh8xt0yObAgmpmRYH8m0YZBETDSt0ckXqkdcBJ64,843 +torch/include/ATen/ops/unflatten_dense_tensors_native.h,sha256=ke86BIWEJhdid15MxdksL_Ms40gxPBkdUzX1F-v_tAM,553 +torch/include/ATen/ops/unflatten_dense_tensors_ops.h,sha256=jT_uygGsHVGcUHrTNQSv4nMidR1GGsU1_oQRlUJD2iw,1166 +torch/include/ATen/ops/unflatten_native.h,sha256=bvmUNPmqTXuTZtm7TBjpUW49hpKscF5mXIPEVA-3hhA,687 
+torch/include/ATen/ops/unflatten_ops.h,sha256=HjpYw_mxiNydpNuhP9Gyt3MgoGM2JfWsTQeBGE8-aOQ,1927 +torch/include/ATen/ops/unfold.h,sha256=Hz-hy4xt8BdMm8UyuWl10m1MDuKjXNx2fpObOgxGXf0,530 +torch/include/ATen/ops/unfold_backward.h,sha256=K4OBa03iHfrN51kS9U-RTzYDWgrUPgtW6UkRdxQpsxk,5032 +torch/include/ATen/ops/unfold_backward_compositeexplicitautograd_dispatch.h,sha256=X6TlOICYmRzLERJjjg3suKVcHhurpWOJrFpGS8Posws,1400 +torch/include/ATen/ops/unfold_backward_cpu_dispatch.h,sha256=mySslHJMJT1vl2c0iCEDQVOhvEgQg44l966xRLT6Qww,973 +torch/include/ATen/ops/unfold_backward_cuda_dispatch.h,sha256=ZRzeTO60a6imVlObC7TNEvGBHTYi9mnfTbnsqiwua5c,975 +torch/include/ATen/ops/unfold_backward_native.h,sha256=zhjeLTNLAAZNha3cj3CdPPmpUySz9Pao8dnZuYroNJU,751 +torch/include/ATen/ops/unfold_backward_ops.h,sha256=UpqZfulQKWVrcAFx0_U9HaXhfJNYNiZs4u7gCXMOQRY,2153 +torch/include/ATen/ops/unfold_copy.h,sha256=hetHdtimt1g1s2RTseDz3fz6cx2tVMkboCOz2I0W5CA,1432 +torch/include/ATen/ops/unfold_copy_compositeexplicitautograd_dispatch.h,sha256=Y6jJS2Dj_pHG_yZxhIKPMjXbazuziUMb2fzz54NEkT8,995 +torch/include/ATen/ops/unfold_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=bz18eSyZ2LqTv6yj79g6cDPI96rQtgqUQK33LHT07ak,865 +torch/include/ATen/ops/unfold_copy_native.h,sha256=COp28Zfob4GKMqc1KHcgbU8XPGgxEaVnI96OFiBlUCM,680 +torch/include/ATen/ops/unfold_copy_ops.h,sha256=XF_T5qN7eoiVObvqXEHUC-lJ3xz0aRrshB0-e5O2VP8,1929 +torch/include/ATen/ops/unfold_cpu_dispatch.h,sha256=8qDJpuX5LZsxyVUaOCBP7YuxjcYH3z6NBphwLX6nKHE,790 +torch/include/ATen/ops/unfold_cuda_dispatch.h,sha256=NClqq7jQEbpy-wSBSwlkkVYZAI1AzwJwY5GEXuujPLY,792 +torch/include/ATen/ops/unfold_meta_dispatch.h,sha256=Iroqc78YtEq7L5iyuGCcfrJLwd-O7FWDVXZO1P4EK_g,792 +torch/include/ATen/ops/unfold_native.h,sha256=m6-LBnBemuCqc8go2yKtgOkKVG826vCCIU9q2TZ74go,544 +torch/include/ATen/ops/unfold_ops.h,sha256=4IiwrHKCkgKjsBCXT3NS7VhYw1eNhGtHG5B_MkhIhS0,1148 +torch/include/ATen/ops/uniform.h,sha256=_li4zrHSqBmZLlaTKPl2XX92c50TdD5aMZBefXCuyew,1538 +torch/include/ATen/ops/uniform_compositeexplicitautograd_dispatch.h,sha256=iRTlkF-WotIhTNBoXlk6eMSWFOYAHp0lOydMlB8uWME,1185 +torch/include/ATen/ops/uniform_cpu_dispatch.h,sha256=lpcy-vl0fY5WjsgNyO8j2AOfC3S0jQt00yJ5E6k9Pq8,826 +torch/include/ATen/ops/uniform_cuda_dispatch.h,sha256=9vTxbtUY6X2_fS5nt9tYS8OzJ41uCsewcJK6T2ezwMs,828 +torch/include/ATen/ops/uniform_meta_dispatch.h,sha256=TMeS22XPGiN33ofP57PFqwfFD36G2dfjjYTdC3GJWaA,828 +torch/include/ATen/ops/uniform_native.h,sha256=lI4YrJbcUN0goGYMLKwlkXtpFwmkE4gcHW8r8Kyi8fk,1010 +torch/include/ATen/ops/uniform_ops.h,sha256=C0A-Ilr-2yyaT352Kft4kvFXM5M9H3DZtx3HsQqWsq4,2811 +torch/include/ATen/ops/unique_consecutive.h,sha256=FqglQx4xyju3oewtfkWp9Qzsrk-0I8vKJ7DjB5cUpUI,2126 +torch/include/ATen/ops/unique_consecutive_compositeexplicitautograd_dispatch.h,sha256=qY13U_K_bHnl9zg6spcU5Vw3YSIFbOb1psI2ufg3k-o,1242 +torch/include/ATen/ops/unique_consecutive_cpu_dispatch.h,sha256=yzs_sUkeRyzPGidMWxUUXL_kaa9r06zYiQ4w0suUiVc,889 +torch/include/ATen/ops/unique_consecutive_cuda_dispatch.h,sha256=URfECXryLg3Ast6tL_YeH8kiLLVl1VLlkicdcJXDm8M,891 +torch/include/ATen/ops/unique_consecutive_native.h,sha256=YPt0JxiR0cOZOxTsXQnQA2vIh5eDN5J8wSadcLksgPY,1094 +torch/include/ATen/ops/unique_consecutive_ops.h,sha256=ddvjJsl6jZLaytPt9m6_j6g1OyXsYd7hdF--p8qxYgI,2568 +torch/include/ATen/ops/unique_dim.h,sha256=O4NBW5D6hFdpg1nb6D231UKOSqcZWWOFpljtg4H0yiI,2074 +torch/include/ATen/ops/unique_dim_compositeexplicitautograd_dispatch.h,sha256=y1_6orY3E-XQBLugrxFJqNfeaEBNqAdiFOWMxLzU2VE,1208 
+torch/include/ATen/ops/unique_dim_consecutive.h,sha256=s3Jqh2hsndoQK1iQGEyThAe7rpm0bXBvyaUdo-flxLU,2067 +torch/include/ATen/ops/unique_dim_consecutive_compositeexplicitautograd_dispatch.h,sha256=g41rJjAv7oJqDg7BfCyymS5Qc2Qx-6ZKS7OJFEgZ1uc,1201 +torch/include/ATen/ops/unique_dim_consecutive_cpu_dispatch.h,sha256=qwuCbU_H18a199qJrFR_DPJWWGCrGKRlDU51mQkOx2k,861 +torch/include/ATen/ops/unique_dim_consecutive_cuda_dispatch.h,sha256=er0BXohkuIZ2lc12VTzR52C_1lSx9fFrJW8PsC8vj3U,863 +torch/include/ATen/ops/unique_dim_consecutive_native.h,sha256=H89MAtYxps1GsUEm1UH93suG--d49DMyhjW0RNgeSgs,1025 +torch/include/ATen/ops/unique_dim_consecutive_ops.h,sha256=lslbnNQeuP8c8MYhj-qFlo-Y4mP1g0dO75hi0k2xdho,2478 +torch/include/ATen/ops/unique_dim_cpu_dispatch.h,sha256=ZKfDjH325UB23KNML3HyV_l7f-4Vizf0TXwv9KkNrgA,867 +torch/include/ATen/ops/unique_dim_cuda_dispatch.h,sha256=OT_hGwkZwCtktIYz8BjmHEFg_2EZTcVPM35Q8R7QYIo,869 +torch/include/ATen/ops/unique_dim_native.h,sha256=YXGjJxEnxAxnBYIgEXBBP-pwsO1UH1Ao4YOQtaZiFro,1038 +torch/include/ATen/ops/unique_dim_ops.h,sha256=UQwcO_c1uJ3vGylmt7VMpRn7ChbUEWeIlLCn0ZAyFBA,2506 +torch/include/ATen/ops/unsafe_chunk.h,sha256=C1xCssz2Ew2X_Zpej86d3-mhB6oPwOLzfIFF-_N8oj8,776 +torch/include/ATen/ops/unsafe_chunk_compositeimplicitautograd_dispatch.h,sha256=fcabCk1HfnlxeZODPSbvYnUiWHdvRNcrWnP8AZGqWoo,839 +torch/include/ATen/ops/unsafe_chunk_native.h,sha256=o7W0uUGdJiul5FBmGNUug0KidXR18T1lByjtV1GbJAw,549 +torch/include/ATen/ops/unsafe_chunk_ops.h,sha256=GBkD84So3WVlZTZGrwxXV17UJOSk87xeY_XpKdMjD94,1150 +torch/include/ATen/ops/unsafe_split.h,sha256=nKTZZNm62YYfnd4LppeBJsrYia_kmiq9bZ38HZMgvuQ,4101 +torch/include/ATen/ops/unsafe_split_compositeexplicitautograd_dispatch.h,sha256=3DG0aAI0zjfqQXAixYJ8SM5W8KYucO4USYlW6uOsuMU,1441 +torch/include/ATen/ops/unsafe_split_native.h,sha256=vqLgKF0etz6qOxhX4nGjlITp-RCFS49-XRFfc1sQGuQ,683 +torch/include/ATen/ops/unsafe_split_ops.h,sha256=cEPcHzxLFPcY4sbqfzguCupzL0khv-nm9ejPs3VaQFY,1939 +torch/include/ATen/ops/unsafe_split_with_sizes.h,sha256=LsQ974mSJrQ9HFHCx1wsxPR1J3wQdGhwXAXo4vc4Bak,4610 +torch/include/ATen/ops/unsafe_split_with_sizes_compositeexplicitautograd_dispatch.h,sha256=2JpEE6Cu7LjrxEVlwGdAS-T09yxHY5f1mzjkTzCMgzo,1561 +torch/include/ATen/ops/unsafe_split_with_sizes_native.h,sha256=tMuLSxkbAZQQHYg5AYjeKzwsnitnTqippvlc-W-02zU,716 +torch/include/ATen/ops/unsafe_split_with_sizes_ops.h,sha256=WfZK4ScOxg2cGW9xrR-hOtx-76mrvRSXQwbOz1Zs1Hc,2022 +torch/include/ATen/ops/unsqueeze.h,sha256=5lecauST4q7sLJU_44ePBg9dV2zhEmmmUe8NUYSfyHY,713 +torch/include/ATen/ops/unsqueeze_compositeexplicitautograd_dispatch.h,sha256=63eJYG9i0irBU1_fxnfqbhCvpPCLi0eE37gR8-6SKKk,871 +torch/include/ATen/ops/unsqueeze_copy.h,sha256=6Arl97oGaWZtZNLYo_WdmVJ9uuR-jwHeM-Jur-j3HBc,1228 +torch/include/ATen/ops/unsqueeze_copy_compositeexplicitautograd_dispatch.h,sha256=3YhRf-kIb_8-fO8kvkoyw6hCHAvIS6gNteWl-jkvms0,933 +torch/include/ATen/ops/unsqueeze_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=Z_nOuAqMbZNfM8Waf3g0WBt77BuqcywNZv6zHx4QCCA,834 +torch/include/ATen/ops/unsqueeze_copy_native.h,sha256=Gv4XrN0dJK28B2-rcwZhaQphdn_I4aMAbCWM251cBSY,618 +torch/include/ATen/ops/unsqueeze_copy_ops.h,sha256=OZJYMqo4G5cPMCHR942YH2L2lV3SxBwv1MDFIG32PKE,1723 +torch/include/ATen/ops/unsqueeze_native.h,sha256=SxHUB50hFp56gnecAmQksUbTqSiYoytpiiIVNMaQHQo,818 +torch/include/ATen/ops/unsqueeze_ops.h,sha256=byHbzcLyI5E1_m9hUzOox1yvbi0HGnifuvecgg2h5J4,1608 +torch/include/ATen/ops/upsample_bicubic2d.h,sha256=wuOWvSgHL9SMvKTRTfboxwurUROaQ5IXNs6IvTZ_ZHU,8052 
+torch/include/ATen/ops/upsample_bicubic2d_backward.h,sha256=SL1tjglXlBaQPuxEOlj1v18eOPe2H0lFzjvn4J92ojQ,7776 +torch/include/ATen/ops/upsample_bicubic2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=XuASkxn3FXfw2de7UBCN8T2-zR1L0ZEfI_u1hIEz5wM,1289 +torch/include/ATen/ops/upsample_bicubic2d_backward_cpu_dispatch.h,sha256=ItQ345zw-n_Y1gQd5htEZtPA3bhaKloGa--W3xYIo08,2347 +torch/include/ATen/ops/upsample_bicubic2d_backward_cuda_dispatch.h,sha256=Eg1w3ZU7NNNdvUVGNxQjY29XIOKQ1t8ezhvwWgcd_BM,2349 +torch/include/ATen/ops/upsample_bicubic2d_backward_meta.h,sha256=xcPYxxOFX096wABeZah0x6VlQy9zKkFycX9ebTVh0yg,779 +torch/include/ATen/ops/upsample_bicubic2d_backward_meta_dispatch.h,sha256=kFQkYY9-Iuv0Z8mfQSdHVI1B6UPaoH9EpZVgGXzkvbI,2349 +torch/include/ATen/ops/upsample_bicubic2d_backward_native.h,sha256=V0LsKdONMWqIE6Q2F_KxnkyG86l0L4Xdk3_ztjl7phk,1219 +torch/include/ATen/ops/upsample_bicubic2d_backward_ops.h,sha256=fcRouSCm7zsetye0-koZZh-vH4czfn81TwVKpEBOS8A,2823 +torch/include/ATen/ops/upsample_bicubic2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=TzU5b4aONMrjTDJzMwLyYQqhe6WKgQzEHKTKl1epktQ,1197 +torch/include/ATen/ops/upsample_bicubic2d_compositeimplicitautograd_dispatch.h,sha256=CSvDoYr_0I7AIgmV8OMpJ6z_9-82854m8voLz1LG-vU,1098 +torch/include/ATen/ops/upsample_bicubic2d_cpu_dispatch.h,sha256=1ejJQSCITxZVtn7RXODe7WUHxRJy8lHIJBuwITnpHg4,2043 +torch/include/ATen/ops/upsample_bicubic2d_cuda_dispatch.h,sha256=QQK38suSkKX9wvF1MgKGTUM37ZI5zYZO_5q46nyuRhU,2045 +torch/include/ATen/ops/upsample_bicubic2d_meta.h,sha256=TS6RuTvonrePod5h8w7jlZUUSIMJ6REug6fmOyH-HSY,729 +torch/include/ATen/ops/upsample_bicubic2d_meta_dispatch.h,sha256=SumDm8y57Q1Rx8Di44DL0cDxu277JvPShKXKdIFejl0,2045 +torch/include/ATen/ops/upsample_bicubic2d_native.h,sha256=pCm33Zll4YjVikjoTc01TmlQeOMpqrKel4T6g-PnwFI,1256 +torch/include/ATen/ops/upsample_bicubic2d_ops.h,sha256=MjZO3_y_eTLeRK7KI4IHA7XRGmQp6amR6_pi_nHYMOY,3398 +torch/include/ATen/ops/upsample_bilinear2d.h,sha256=WSo1h2E8jh1NvgD6K88o11ozRj-ed5feNcAD4mLyOLQ,12051 +torch/include/ATen/ops/upsample_bilinear2d_backward.h,sha256=ACxRHn7vn8ftjNxPoipRX0TnFYUecg_5a2H67S5UJTM,7807 +torch/include/ATen/ops/upsample_bilinear2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=vk735hgYb6n13q0ol3aLFa-h_WUp6vf4tEjQQ9SGIME,1291 +torch/include/ATen/ops/upsample_bilinear2d_backward_cpu_dispatch.h,sha256=BsXJVsLYs3K8KT4WTlLJ5ycx3-7uOqTVAlwZ5kxEJq0,2353 +torch/include/ATen/ops/upsample_bilinear2d_backward_cuda_dispatch.h,sha256=lQWg6DtQBpQVH1ElNfdVwbko3srNacIdrzc0wO01XXM,2355 +torch/include/ATen/ops/upsample_bilinear2d_backward_meta.h,sha256=LLKKXIx16BDgD1ReBytoSUuJzivc303jgfLRAAJncws,780 +torch/include/ATen/ops/upsample_bilinear2d_backward_meta_dispatch.h,sha256=3hWx4pKxXopnnD02gbINkpX3lBbyqsVX7GPAdk7dU4Q,2355 +torch/include/ATen/ops/upsample_bilinear2d_backward_native.h,sha256=TdmeL1e_C7BSFo-eY6otqIgLh_mPCoQlmNNIna3_KKE,1224 +torch/include/ATen/ops/upsample_bilinear2d_backward_ops.h,sha256=ruOAxjutTIj0N0xV5dSqOSP66ZkltlrMbLemJI75WR4,2829 +torch/include/ATen/ops/upsample_bilinear2d_compositeexplicitautograd_dispatch.h,sha256=pAvSAKsXGH-UCd6lL0WkLANg5ZIPYM7Lp7SLi2R_y3U,1566 +torch/include/ATen/ops/upsample_bilinear2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=93KIA_G1wb3SDcrye4TZ9tb4Coix3fD5HdupcgNDZ58,1199 +torch/include/ATen/ops/upsample_bilinear2d_compositeimplicitautograd_dispatch.h,sha256=NYZHs7u1xmDjZm6E8T6j4V4ZXztqluzhpgdNVoH1szo,1100 
+torch/include/ATen/ops/upsample_bilinear2d_cpu_dispatch.h,sha256=9BLGh7UGlGsH4ityCCBBnoKiGt2HaFRFsTAFHPQDIJM,2049 +torch/include/ATen/ops/upsample_bilinear2d_cuda_dispatch.h,sha256=MUs-JDYDHgYI7ddPh_VNrxdkjpxv7xj1FMVQPHJJ7CI,2051 +torch/include/ATen/ops/upsample_bilinear2d_meta.h,sha256=q414S1NVBVIVIilKW6rNJ7fT26RSkAjI22jANzKPFQ8,730 +torch/include/ATen/ops/upsample_bilinear2d_meta_dispatch.h,sha256=VhNKW6xWmcgGgjxVVxp8DtFnXBFR_oUjbSV4JfuUZfc,2051 +torch/include/ATen/ops/upsample_bilinear2d_native.h,sha256=5gQ-ksRd_1lt7A6keIDiM979RupCGU1Co_X7xKaJ75w,1708 +torch/include/ATen/ops/upsample_bilinear2d_ops.h,sha256=qOAqRPTxaFoULS1veDNrrilb5whmMm5O587SpegBMEk,4428 +torch/include/ATen/ops/upsample_linear1d.h,sha256=Rj-HaCG4SGHmHFzTf-VAuvzFpCuIGZf0My3zZwRdcO0,7171 +torch/include/ATen/ops/upsample_linear1d_backward.h,sha256=Lf1Rbcsv4Sv0W4sxW3zB0j_5_LJBZn_oD5_snH52FPs,6905 +torch/include/ATen/ops/upsample_linear1d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=OoUKcd8k63EzA4g4EtaS1HafxZWuKyB0WdR0rH_H1tI,1185 +torch/include/ATen/ops/upsample_linear1d_backward_cpu_dispatch.h,sha256=MIwOgnYMtXbai0qbboaWo0NALzQrYm13d4KlpUML1EE,2065 +torch/include/ATen/ops/upsample_linear1d_backward_cuda_dispatch.h,sha256=KvWyPTLjWkI3AvNLBEZFPveEg1ZIVvM-yhzzGfijuyw,2067 +torch/include/ATen/ops/upsample_linear1d_backward_meta.h,sha256=LBWUiMdnGDXdarIGquVTHXKsIzvn_46rlA8-wyVdoHs,742 +torch/include/ATen/ops/upsample_linear1d_backward_meta_dispatch.h,sha256=qo7VDF2h5NLnmbwSGVtirve0eSAdW8ZJkgPslTON1gA,2067 +torch/include/ATen/ops/upsample_linear1d_backward_native.h,sha256=Z9Bs-5hi7-33-Wtf5_hGmD0zYtf0KOE9WHZ2J_M-Z9I,1142 +torch/include/ATen/ops/upsample_linear1d_backward_ops.h,sha256=rotfaW2kmKNyHbsTuzBqotsmBrZQq5JR-BgHv_rJlXY,2575 +torch/include/ATen/ops/upsample_linear1d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=CD9mAId5O_8_P803VJ6cVP4f6Dm4tuA0nMSM_yOQNmI,1093 +torch/include/ATen/ops/upsample_linear1d_compositeimplicitautograd_dispatch.h,sha256=KG1-oAylEfz6xcmcIbouLqrVJ0E9aNzlFcTXzZVsr-Q,1096 +torch/include/ATen/ops/upsample_linear1d_cpu_dispatch.h,sha256=H5MPHGm3UAxU_hsLDByovlhsazptWQ_dXKQ3LhKjQTQ,1761 +torch/include/ATen/ops/upsample_linear1d_cuda_dispatch.h,sha256=F27pZ-7oU0MAzY6dKYOiST7u1KXBEG1QD0LY-5Wmcyw,1763 +torch/include/ATen/ops/upsample_linear1d_meta.h,sha256=IgoBmrS5UURphB3GWWDf97oAsTwYEhOINO0PBdRZEa8,692 +torch/include/ATen/ops/upsample_linear1d_meta_dispatch.h,sha256=E6ZFb-kZwGMY57DBgBdiD22C69FmmTVeQsKDlk7jiTw,1763 +torch/include/ATen/ops/upsample_linear1d_native.h,sha256=ZYCs1Od0AS6D-o-typ5N8DDbJaY3w6um8pEhuDOlLuI,1178 +torch/include/ATen/ops/upsample_linear1d_ops.h,sha256=jsxoGTX8MMI_rn3rEH0WzoPDoqwFLE51-g275HZOmFE,3147 +torch/include/ATen/ops/upsample_nearest1d.h,sha256=NowHSJ1sX0L0d28LSL6bhP7fPCZ4o10jr-ykOnXBXH0,6492 +torch/include/ATen/ops/upsample_nearest1d_backward.h,sha256=oj6T_86EcRgcR_T5hEsIWeOQrmMf6HYjoyr0e0EnmH8,6396 +torch/include/ATen/ops/upsample_nearest1d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=4M6fmkCUAoK8kM0Jxq2SOFZl7X8wHdHmzPZvc4Es8jM,1147 +torch/include/ATen/ops/upsample_nearest1d_backward_cpu_dispatch.h,sha256=55IlgXLw428FAXQEU9PfcynC6V1GVv_OLwFWTeHGK2U,1951 +torch/include/ATen/ops/upsample_nearest1d_backward_cuda_dispatch.h,sha256=AH7gLor9HmH6lLoJSsdGxNmz7vyJBNCgZBdh1Ax6T8c,1953 +torch/include/ATen/ops/upsample_nearest1d_backward_meta.h,sha256=Bde1Z6Fk0SUH7ykfeATxKbWQaIMsFHiveTmsJZIfVno,723 +torch/include/ATen/ops/upsample_nearest1d_backward_meta_dispatch.h,sha256=QNMdQIlLw9m5eYqvKkHdHdgVexGvJSi62HeIqjmjOqU,1953 
+torch/include/ATen/ops/upsample_nearest1d_backward_native.h,sha256=yHbguw7-7UzQMm3xMPwWQ7Hnd7un9K8yd5PGqBK8DlE,1107 +torch/include/ATen/ops/upsample_nearest1d_backward_ops.h,sha256=eRUsk6lnGEtg-hj08ohkYvH2QM6dTGK0tuGNfiFMLu8,2449 +torch/include/ATen/ops/upsample_nearest1d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=-nJwXrQl1yJ6L46Ig69ErH0MgKmkw1nt6Ejep08actA,1055 +torch/include/ATen/ops/upsample_nearest1d_compositeimplicitautograd_dispatch.h,sha256=PcOWmfoRoDa_XsX-KJYaL6W52VVJ2D0Bowtd-cyihlE,1058 +torch/include/ATen/ops/upsample_nearest1d_cpu_dispatch.h,sha256=5yAyqJ1AEhhhoAfZuZ-0-Y0SBw05_Pl50NCR47_-oZk,1647 +torch/include/ATen/ops/upsample_nearest1d_cuda_dispatch.h,sha256=5ZkyPMr5t2cmvj158OFQmmFSbXpXUZU2QWBMTEPkn1E,1649 +torch/include/ATen/ops/upsample_nearest1d_meta.h,sha256=D-fRkBIkcq0d_jPUFsvFtoEu1UlyBKka_Cxde9GpgjE,673 +torch/include/ATen/ops/upsample_nearest1d_meta_dispatch.h,sha256=7Xl7My0VThvCL_xMKmydZuzfSB_na_OfCzTJzGjz6vc,1649 +torch/include/ATen/ops/upsample_nearest1d_native.h,sha256=pjGvaaKm2-ckToVzocwtUKxarpLh-aIL5hGmuWamqTM,1124 +torch/include/ATen/ops/upsample_nearest1d_ops.h,sha256=zuN9DNIEuwW-6HGuXo4f788A9QZdYC9UxlVkkAe8Ujc,2958 +torch/include/ATen/ops/upsample_nearest2d.h,sha256=aeWWEfoc170QxgTiQlpsFDiuXRPHZJLTC9eHaa69o6Y,10910 +torch/include/ATen/ops/upsample_nearest2d_backward.h,sha256=5Q4eDz8UYjVbF4rk5fo-R92Ie6uxfnUZhUg58AOICH8,7236 +torch/include/ATen/ops/upsample_nearest2d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=nJGtoSr-7DGBNjhjoiHrNZRX9ZCtkUAuk0EaCnvhHYw,1249 +torch/include/ATen/ops/upsample_nearest2d_backward_cpu_dispatch.h,sha256=zwGFhoXcLouxCajx_yViywx0cDdNjRdhL3Kka3y9OxM,2227 +torch/include/ATen/ops/upsample_nearest2d_backward_cuda_dispatch.h,sha256=aIThR8Ctr4BV2ZjYkHWFu5fX3J42Tcf_-Z4kcQZeYSs,2229 +torch/include/ATen/ops/upsample_nearest2d_backward_meta.h,sha256=HT-QIrqdCAQ_Fa8g-bbkKg5nkjoOMEY-5DFUfOF2ltE,759 +torch/include/ATen/ops/upsample_nearest2d_backward_meta_dispatch.h,sha256=d4buOUOBamCDZ9GI99fIEXFwMmlNZVLGERufOIGvu3U,2229 +torch/include/ATen/ops/upsample_nearest2d_backward_native.h,sha256=_uCBTXJ1blaUapXZwyJ500CWh5zAFDiD0BYholfV3ck,1179 +torch/include/ATen/ops/upsample_nearest2d_backward_ops.h,sha256=WHaL9ULUnoA1QHTAIj-BUTcLLospx0-CZqqRWebUIZc,2691 +torch/include/ATen/ops/upsample_nearest2d_compositeexplicitautograd_dispatch.h,sha256=dydzujLipjF9LaGU0aTg9390Xf2FojXMcUzrMob1pxI,1482 +torch/include/ATen/ops/upsample_nearest2d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=tE6WHnwSvCMyEGZIm0R95bIKEOr9JT-echFb-l9bEl4,1157 +torch/include/ATen/ops/upsample_nearest2d_compositeimplicitautograd_dispatch.h,sha256=h4uFoFMxMJG334d5W3h5lNpqrsxQLKcwX7pB4vgVidg,1058 +torch/include/ATen/ops/upsample_nearest2d_cpu_dispatch.h,sha256=zRdzO9CBgOBv3IRFuKxNhnVrtPsdz_m52R5hbo71o_A,1923 +torch/include/ATen/ops/upsample_nearest2d_cuda_dispatch.h,sha256=RMLFfd0dEh6VDbiSZf9WqocgMJtKlKnWdgDEtE5z-O4,1925 +torch/include/ATen/ops/upsample_nearest2d_meta.h,sha256=qJbYNo1qBjpWvjPyRjjtLRP3HutH8ekNH27Kw3iTlrE,709 +torch/include/ATen/ops/upsample_nearest2d_meta_dispatch.h,sha256=VDxcEb3mw1PucssKcl8B1u3FN_aa4XGJV3kTxn_iyKE,1925 +torch/include/ATen/ops/upsample_nearest2d_native.h,sha256=vgA6P0Gyw-zWjr3O_gJs_OwpoOBO8AwZRbF-GMTFsgc,1600 +torch/include/ATen/ops/upsample_nearest2d_ops.h,sha256=7A3MOWgTd5eFlx26gj81nLATtoRyjHUjg_7v5qF4x-A,4152 +torch/include/ATen/ops/upsample_nearest3d.h,sha256=WNgPZAgT1PmbW7_O3uxJUsC4BJl59fDqY9dkrgNPUq4,8112 
+torch/include/ATen/ops/upsample_nearest3d_backward.h,sha256=45dghos_p626LNuauhYfeKDfbjMiJ0Fqj1Zc6B0U-p8,8016 +torch/include/ATen/ops/upsample_nearest3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=IJkn4Vjay52HbIs0fXvtRKw3KJFcrFHc5t2wisaACR0,1347 +torch/include/ATen/ops/upsample_nearest3d_backward_cpu_dispatch.h,sha256=ldWPRLPkGSjKBM4JxH6fZOMJtEUvlYN7cOfMh4iP9Q4,2491 +torch/include/ATen/ops/upsample_nearest3d_backward_cuda_dispatch.h,sha256=eI3Pm_JjQyykkfvO6gOAnY6VmuLajdVvDEBRJWczd3U,2493 +torch/include/ATen/ops/upsample_nearest3d_backward_meta.h,sha256=fYDO19QS6tYkQ4eUdmu87Sa1ma4dWBREcqkGkvqsgT4,793 +torch/include/ATen/ops/upsample_nearest3d_backward_meta_dispatch.h,sha256=5yLtDWmC8tMf9idRefUglHg4jZ9b_E13ua-VM1x1k4c,2493 +torch/include/ATen/ops/upsample_nearest3d_backward_native.h,sha256=K2McR9Uy8cO67hsSraSXkBD7mkt_9ZjmXFpFbmmU4KA,1247 +torch/include/ATen/ops/upsample_nearest3d_backward_ops.h,sha256=mmtINjLscUKMio9Vb8kSqiv_cgrBmGVeaxf2ENQlnm8,2921 +torch/include/ATen/ops/upsample_nearest3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=YMvolGvFWnYv6cJQdO2xUd9jIsgGFFdLWBUeYLAv8v8,1255 +torch/include/ATen/ops/upsample_nearest3d_compositeimplicitautograd_dispatch.h,sha256=SEeroX-1Z7mo_0cjB62tABj3uvXm-j_OTF28ZARHxl4,1058 +torch/include/ATen/ops/upsample_nearest3d_cpu_dispatch.h,sha256=mJKJTLe5qaU7ckgUVvNWrt4SpiDH9woTlekkOW5KLWQ,2187 +torch/include/ATen/ops/upsample_nearest3d_cuda_dispatch.h,sha256=BFbjcGRPEZz0OH8RnnEQzyQUvCxTAAL1SdD6bT11lI0,2189 +torch/include/ATen/ops/upsample_nearest3d_meta.h,sha256=KdXNr1hPj023PLt5FXTd9jcTaTW_WhTTRMr4DS1VTV8,743 +torch/include/ATen/ops/upsample_nearest3d_meta_dispatch.h,sha256=n4NaU-8bhSofRN-dd9BimvvMexyZOewk_4FRwBAMMT4,2189 +torch/include/ATen/ops/upsample_nearest3d_native.h,sha256=bj0rz9OHS8jXvauZ2sN3V4LV-fB3cWx6oZxRqiUSS4I,1521 +torch/include/ATen/ops/upsample_nearest3d_ops.h,sha256=6vXlu4oP9ISY9xDh7Uj-60hCk7Pl5lrwUxGhpIWYXs8,3430 +torch/include/ATen/ops/upsample_trilinear3d.h,sha256=LCpZ2KBY3HvOft4Hv8T8v6uBTMGdXfUEB0bNFyhdrXo,8914 +torch/include/ATen/ops/upsample_trilinear3d_backward.h,sha256=OAuDRc5S6_qCT4rDc_SVdl2az7_fB3a_I3y10CRBFjc,8618 +torch/include/ATen/ops/upsample_trilinear3d_backward_compositeexplicitautogradnonfunctional_dispatch.h,sha256=L3Q6ZH1CmcCaUEONYSoRMSr6QmpVL_a-nU-m-CJmZE0,1391 +torch/include/ATen/ops/upsample_trilinear3d_backward_cpu_dispatch.h,sha256=6cvJ5D-1GDg18XVQYvVTrgTtvfKWqeFBf4CwEERAbnQ,2623 +torch/include/ATen/ops/upsample_trilinear3d_backward_cuda_dispatch.h,sha256=8qvEeuTaCutgQAIKyj1UY59Dt_PzVvLtSFcjN3xujMU,2625 +torch/include/ATen/ops/upsample_trilinear3d_backward_meta.h,sha256=dgiqXbDUuXKEzaKJUUltihAkV-tLkS22Mji0IupelQY,815 +torch/include/ATen/ops/upsample_trilinear3d_backward_meta_dispatch.h,sha256=tnPHVcE87-GtmM7nZRCHM-eeh0HrbdslpOSiqJ3yjLY,2625 +torch/include/ATen/ops/upsample_trilinear3d_backward_native.h,sha256=PQktrwMyni_15ZtubWnacZMcbZV0foBErDfutjYGxuE,1297 +torch/include/ATen/ops/upsample_trilinear3d_backward_ops.h,sha256=ySYgO18EpaXlc9zlXMhMuhGlLN0psf1tKvFq3TsiTn8,3065 +torch/include/ATen/ops/upsample_trilinear3d_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JoU0gH3ZnUCGb7pqp3VIGaQeRxpqrpSpK2IuHkdRZKs,1299 +torch/include/ATen/ops/upsample_trilinear3d_compositeimplicitautograd_dispatch.h,sha256=c1zSIgrMzPALp8pbTeR69hMYcHwh3k6e_0E0jwV7BKc,1102 +torch/include/ATen/ops/upsample_trilinear3d_cpu_dispatch.h,sha256=_9KBJmVGwPeGHdfqNzoLPgFS3JhnBD3_3AYYp98f6L0,2319 
+torch/include/ATen/ops/upsample_trilinear3d_cuda_dispatch.h,sha256=NGAkKleLpA7vIIcQ66UQIM9_-PfC41KlLeR1U65aJqw,2321 +torch/include/ATen/ops/upsample_trilinear3d_meta.h,sha256=EnLBytnftJR6XTM9kTf_WEnQyiN36LBWvhNE4aHu7jw,765 +torch/include/ATen/ops/upsample_trilinear3d_meta_dispatch.h,sha256=jfbCcqFXxIYtK6FTNGzw6QhqkpoZvo8Sz2duWsCZewg,2321 +torch/include/ATen/ops/upsample_trilinear3d_native.h,sha256=3999a4ge1idaFWznhTCs0XbeiIgp-OOLblOADVaiIUQ,1336 +torch/include/ATen/ops/upsample_trilinear3d_ops.h,sha256=J7-MeK3p1Twj7E2jNmNmdvz0tPYUlPnCW76-AxQscbs,3646 +torch/include/ATen/ops/value_selecting_reduction_backward.h,sha256=eQ8isughyo1-mPVCVfhUtaqK0Ak0msBRrh84DlH5Q40,2129 +torch/include/ATen/ops/value_selecting_reduction_backward_compositeimplicitautograd_dispatch.h,sha256=ij6aUjF4-noXbQoMy_aLMkqMysS8Pg6_6luy5kC7Pvc,1065 +torch/include/ATen/ops/value_selecting_reduction_backward_native.h,sha256=l7aR6k83IHT3FVkpXsjDmcT7hT9sTYIiCtCxuamS7Qs,793 +torch/include/ATen/ops/value_selecting_reduction_backward_ops.h,sha256=NjJ8WvCynLdhuSoR4hC_iSTZGRVXK0lZNSW5M0lBt_o,1345 +torch/include/ATen/ops/values.h,sha256=k-UqgIAIM0FDPdn91U2U28l4gzxfsE-P7w9pWTP26ww,530 +torch/include/ATen/ops/values_compositeexplicitautograd_dispatch.h,sha256=uyvonAIGjrhbHWn1Jy75pJA23E-M1Ah_ahg-9WcqkDo,787 +torch/include/ATen/ops/values_copy.h,sha256=AAnupvbzBU70H_XTrdC7G8AfuGM3IpTehEi3pJAHyhg,1117 +torch/include/ATen/ops/values_copy_compositeexplicitautograd_dispatch.h,sha256=xftxpp8ajt9nyIrR25CSSKS4ZkhDUNbDI0cIUG5XImo,901 +torch/include/ATen/ops/values_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=E8kiCktbelfbw-u9vnXU4FTiRN9BReQH2tPcH-QZYRw,818 +torch/include/ATen/ops/values_copy_native.h,sha256=FBA-BQOLL7w9OXJ1sh2Ot7XVAc2jQDA3ryLHD5wBy6I,586 +torch/include/ATen/ops/values_copy_ops.h,sha256=9L37jgrYUJaVkqSfzfzWqVWbsoOkK4_XXwMF7oX0SNo,1617 +torch/include/ATen/ops/values_native.h,sha256=LMhz4Ydh2-mcabqIhqJWTGPlp4zfYTpeTcYSSGQxD9M,695 +torch/include/ATen/ops/values_ops.h,sha256=4jps5U9ZLAeWlw_W_M0yLNXORaDhxeQ-0uX7jRJZhF0,992 +torch/include/ATen/ops/vander.h,sha256=T8FeATIs0rwNmLfvhyqyCl0BkmGQ3zD3F55uP7wI9nc,776 +torch/include/ATen/ops/vander_compositeimplicitautograd_dispatch.h,sha256=oyB3L18oH5QRzlm_8hgA51Pi3p0isWC9IXfwBOwXPgk,850 +torch/include/ATen/ops/vander_native.h,sha256=9ltXuySydeSymfG_yxZQbmPRPCC8eK5GxWNnS7c1Lic,560 +torch/include/ATen/ops/vander_ops.h,sha256=Hmkb7su8aBf8VloyPllsEtYo5pOpN6euFK2r_9gLSIs,1135 +torch/include/ATen/ops/var.h,sha256=zZCTsjn0pBMBL11xmEIX4XLUwTQjhxdavMsEFAIA0UU,4958 +torch/include/ATen/ops/var_compositeimplicitautograd_dispatch.h,sha256=C-I1JNjK3IfE8KE1yH32BkdgMyrrZkFpddnUYa4TGwU,2053 +torch/include/ATen/ops/var_cpu_dispatch.h,sha256=ul_RTCsdjuf5m6uHrx58mGhplGdTAHbJ9S_djeSgURQ,1238 +torch/include/ATen/ops/var_cuda_dispatch.h,sha256=Rok7zReh-SaOLkelSeZ2bHuZGmD8N--9yFwg8Rx_Q7s,1240 +torch/include/ATen/ops/var_mean.h,sha256=82-N0VSaEDWTMg6dIleJ7ajUeVfLGR2vqDeFgPwHvyg,3277 +torch/include/ATen/ops/var_mean_compositeexplicitautograd_dispatch.h,sha256=e36pavNPOtEriAiHJMCZJz03NlrVa3pBvrPfL6ZPq2k,1207 +torch/include/ATen/ops/var_mean_compositeimplicitautograd_dispatch.h,sha256=GPInzmCMJBTmunSmjfsk-uinGxKEEes0mmoWO0-_7RY,1299 +torch/include/ATen/ops/var_mean_cpu_dispatch.h,sha256=CMH6Pu6kzmixfgvxymSBleIgNd03uNDCpZGbTVDhjw4,897 +torch/include/ATen/ops/var_mean_cuda_dispatch.h,sha256=BQ1n8X1Z8ZZJxS29kJqOz3MWYRUp5NyjuNVgv2iUZZI,899 +torch/include/ATen/ops/var_mean_native.h,sha256=VcahC15WmKmyQO7cspNBvgJy1n11UesPtMzYAYJB2X0,1463 
+torch/include/ATen/ops/var_mean_ops.h,sha256=e2enx1-mSyRDPtK8kWnpFviTmSH4mRvaKtWuDFRFFDg,5780 +torch/include/ATen/ops/var_native.h,sha256=rr_bRQoXuxpZZpF-dbnFQR5XjsEQCrPj4xJFPRwj4NY,1669 +torch/include/ATen/ops/var_ops.h,sha256=ScOktHV9wGAEvUazht-gT2xgWO8S3ymzkL7_iraaSy4,7621 +torch/include/ATen/ops/vdot.h,sha256=eM4UOl59HT7JfSpTzDKpXo8_UGKXpUMXdCz0kB7HXsg,1188 +torch/include/ATen/ops/vdot_compositeexplicitautograd_dispatch.h,sha256=E8VVhPIqOg5IcE6xcOS7AZqBUQCXbjAlPWt_R92yp_s,939 +torch/include/ATen/ops/vdot_cpu_dispatch.h,sha256=1hJTxEHinzHqsAAWaj8QZiCw9t6xZeKb4KkcYy0Neuk,767 +torch/include/ATen/ops/vdot_cuda_dispatch.h,sha256=HxJI7igtju5ecH6p0uaXmvhwtTIpX4fds6uo0umq3_4,769 +torch/include/ATen/ops/vdot_native.h,sha256=YePyO6JeBZ77KMJjm75sBkXscAG-W2MQ7R0FRFxdNNA,708 +torch/include/ATen/ops/vdot_ops.h,sha256=NaqLqyqVGEhYDaUj4n7NLqGuDLp13K2BaAtlNX8IyLw,1747 +torch/include/ATen/ops/view.h,sha256=N7rhLNCRtcyXRDd47EOv9qNKPJoN1eD4eOmAvPWp4lI,1012 +torch/include/ATen/ops/view_as.h,sha256=bFtjzsgAfIWFhbpnS7P6A9zRHBeDZK73wyYAdtpB0N4,531 +torch/include/ATen/ops/view_as_complex.h,sha256=ZeBMAtANMt4UHYF012Kw7vEFmv7-whntUVwIF9cPoqE,710 +torch/include/ATen/ops/view_as_complex_copy.h,sha256=FozQmRirUMgr3_m34OGQJemwXOrc6-5bL6vaRku8_p4,1207 +torch/include/ATen/ops/view_as_complex_copy_compositeexplicitautograd_dispatch.h,sha256=bMT-WlbPIlrTDnYVEp4Goy3XxU08BoFeAjp-oZt_w7g,919 +torch/include/ATen/ops/view_as_complex_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=FSJ0tnnY1w88IW3wbPKE2MSNjARQ6S05ZNQxvPYHpEs,827 +torch/include/ATen/ops/view_as_complex_copy_native.h,sha256=gq5h-g0Wc3IPr_jz1VUt_RN-Kk622z146EgTN7UJVX8,604 +torch/include/ATen/ops/view_as_complex_copy_ops.h,sha256=wZuG-kRkKIeBzQ5TEidMDkrzIp6NUbICYy0EcYVM_YM,1671 +torch/include/ATen/ops/view_as_complex_cpu_dispatch.h,sha256=xoc7rm0ZbzuZlmSvQlJc__2LRx2-MRdM9-Cys-3BYpQ,752 +torch/include/ATen/ops/view_as_complex_cuda_dispatch.h,sha256=9HgD-RxZNSTgRKUJin00gvQ69bKYGQwQrjQbiCoOAuQ,754 +torch/include/ATen/ops/view_as_complex_meta_dispatch.h,sha256=VyP5DiKP9fZRWCeCXDiQU6Lz4oiX_EdI-2NdeyXFiAY,754 +torch/include/ATen/ops/view_as_complex_native.h,sha256=eqNIYfNEFp3Np0ILX5b2AYG84ACs04KNOSYAkIwSN_E,506 +torch/include/ATen/ops/view_as_complex_ops.h,sha256=sxcegFjswtztoo3KDWCv-3kYYHkzQ1T-Gw-rYbD2Av8,1019 +torch/include/ATen/ops/view_as_compositeimplicitautograd_dispatch.h,sha256=qB1D47Wwuol5X9ulRcPFOM4MwR_PELg4ZTC76-qgjmw,814 +torch/include/ATen/ops/view_as_native.h,sha256=IGOKolZpI0CIKTaXNCdZY2mBx8Zvyba7tlNE5JmERTY,524 +torch/include/ATen/ops/view_as_ops.h,sha256=zrcXTyFH2AoOvzZCBQ3SnDg4Zr-tJ9kik_wgmJACbmU,1081 +torch/include/ATen/ops/view_as_real.h,sha256=W9PKR_OOWxVgpW-VZANIe928Sqa7fv-RRD93wnyOvas,698 +torch/include/ATen/ops/view_as_real_copy.h,sha256=g-M6P_EUghSsXlYhyIEF-bxrdssvtYhupusO7gb50RQ,1177 +torch/include/ATen/ops/view_as_real_copy_compositeexplicitautograd_dispatch.h,sha256=6VL11NOlTz3oSjRikYdcmpnlkYmk1qQ7ZtZaHnOUuQQ,913 +torch/include/ATen/ops/view_as_real_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=JDMpQhJ3hBhW3jPA3oont6xzZ4i-_07xDCG5CsGRY40,824 +torch/include/ATen/ops/view_as_real_copy_native.h,sha256=CgQPqUQ6hRc7Ls8aDITwA1UAa0Qnzf-UEB12p4VYgVc,598 +torch/include/ATen/ops/view_as_real_copy_ops.h,sha256=imS-n4jfSOL6Uy9wT2UJb81bbv1F8KtNeB2sjCSvYbw,1653 +torch/include/ATen/ops/view_as_real_cpu_dispatch.h,sha256=LeQXzCkAGbC6Xzeg-vJSajh7N1rUTD5tsJA2K_z3U2Q,749 +torch/include/ATen/ops/view_as_real_cuda_dispatch.h,sha256=bJGsgNinIQfyi1L1Vo0iW2-CKyRBJfim42VYAnaEae4,751 
+torch/include/ATen/ops/view_as_real_meta_dispatch.h,sha256=IP2l3kJsL_B4fvAa5tEgpCH4CcQQRAUIQykOxgqqe7g,751 +torch/include/ATen/ops/view_as_real_native.h,sha256=uxprV_Qq-pWTW0pwYDV9XY3PUPGljy1XGj7qbUcHjs8,503 +torch/include/ATen/ops/view_as_real_ops.h,sha256=15pzINXhc01CWnP-mfofscbUaf7TuWhVUM-03UJ3FDQ,1010 +torch/include/ATen/ops/view_compositeexplicitautograd_dispatch.h,sha256=jQRxSQo3p8sfWgFOEFxuFVF1gZGfFmDTCJw6TuW1BaE,807 +torch/include/ATen/ops/view_copy.h,sha256=YqwEREeXXXD6gXnGMoO94XWgpdG1LeeqGZ2SxJbzmzM,4419 +torch/include/ATen/ops/view_copy_compositeexplicitautograd_dispatch.h,sha256=cxTa3cCL06Htmt5i9vEHRy-qFEd75LC52lwb76dbTOs,1381 +torch/include/ATen/ops/view_copy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=t_CPjSaDl3Cjid2Obf29g6y2XrcdJBk7aFpgtANxSdw,1009 +torch/include/ATen/ops/view_copy_native.h,sha256=Yy6nto9DUYy__hLDsljk475kHjbGhMoob2Ux0RohgfU,844 +torch/include/ATen/ops/view_copy_ops.h,sha256=z-RiOlbxxN4m-JE2ohUdqfIMpviAvU-ZVlWt6Vk98cI,3102 +torch/include/ATen/ops/view_cpu_dispatch.h,sha256=n2jPSa99qQPuhLcn_v3K-q7qjc94JJJHRbcDTBDcHdQ,849 +torch/include/ATen/ops/view_cuda_dispatch.h,sha256=WgiomnLlle5RKeVz_lKqFC6wWW3W4yDKpLstLrsyA_U,851 +torch/include/ATen/ops/view_meta_dispatch.h,sha256=FBoshyq7zxNK0fjaPdmY0K_9WzzHgwgawkyboVPSX5k,851 +torch/include/ATen/ops/view_native.h,sha256=joe9VcsacubphBxlcasSU6BlPXLLlq37uoJSjpbhOxU,762 +torch/include/ATen/ops/view_ops.h,sha256=R_eHSuvnCEDmZn2pLhh5YCX9RV4YfMwwuX8sVgBwrmg,1680 +torch/include/ATen/ops/vsplit.h,sha256=WkmDRHO9UNXqIQrC_5grrrXJHYk37OsqzQkvzFhsA5E,975 +torch/include/ATen/ops/vsplit_compositeimplicitautograd_dispatch.h,sha256=HJlRlXL2vTRqpLaoJFSb7BBMQWAVyaI0QVcAwFLn-lM,915 +torch/include/ATen/ops/vsplit_native.h,sha256=t_oWLu3_Q4yYRJ2rRhGEBuXEzYGKHzvk-pqMgnSv9Tk,625 +torch/include/ATen/ops/vsplit_ops.h,sha256=ijDXbxeRjizA7FLCmBNUSeef081qB1dd8XVfHNEz8nc,1782 +torch/include/ATen/ops/vstack.h,sha256=_xAGJFNNXHckHd7dzqPN-inyLmjlVewKQ3eWO4mEepA,1088 +torch/include/ATen/ops/vstack_compositeimplicitautograd_dispatch.h,sha256=0X2j3a0gKnjd7xX3yr1qmyK-9J73QUkm2lEolh2r9Js,943 +torch/include/ATen/ops/vstack_native.h,sha256=25Igc5K0XDE4JxYRO87JE5zx91Ouj7G9wnZjL9_9ya0,574 +torch/include/ATen/ops/vstack_ops.h,sha256=NrNqj0uPGZxjxsFlKTjyfhxoVgzQ_liujfebarWc8iQ,1585 +torch/include/ATen/ops/where.h,sha256=gE0038s88dHvUVTpWYy61IYVW4zPjlbN1SMQuVsDpPc,2363 +torch/include/ATen/ops/where_compositeimplicitautograd_dispatch.h,sha256=u9XSyNhiBrZTXoL_YZN8FEjw6qEB2z7j4cU47Y5airA,1136 +torch/include/ATen/ops/where_cpu_dispatch.h,sha256=M4GSwmWLmHp68XOOx07T5rW7r7YKZfLU3sj5gAy0IqQ,1067 +torch/include/ATen/ops/where_cuda_dispatch.h,sha256=TG4f-kt_FT92hAdMlqqiqwtn24Gfx3w15j3NLPfXt_o,1069 +torch/include/ATen/ops/where_native.h,sha256=MS840yOF23UGnAfnaznZOYLFN3gufpcoHDY1-2ZKpGw,1365 +torch/include/ATen/ops/where_ops.h,sha256=dep6yKKdyJaSJD-OfAGk15dhNH9xRlYfnGcR9a4cRss,4711 +torch/include/ATen/ops/xlogy.h,sha256=Vkstjb8S_qdvtuZFVAYeJy6-3f5pix3KqHsqBM4JzQU,3142 +torch/include/ATen/ops/xlogy_compositeexplicitautograd_dispatch.h,sha256=Fua6Bb1hL_ez14x6w87In9H80nsWxkM8f5OjkSxDpUE,1387 +torch/include/ATen/ops/xlogy_compositeexplicitautogradnonfunctional_dispatch.h,sha256=8A8cxcpB7qXa2xRgIJvTMfQu5zARFDjWovxsjAT8Sd0,915 +torch/include/ATen/ops/xlogy_cpu_dispatch.h,sha256=_uM9Ly9pe1YdAD7wbeK62Ug4sKgYssXBlxgTDyttekc,1054 +torch/include/ATen/ops/xlogy_cuda_dispatch.h,sha256=kMK1bTI2WrAJfNdpEgv9SFHzyx6nXgUHOiKhSZEUd6E,1056 +torch/include/ATen/ops/xlogy_meta.h,sha256=EYqE_M6W91V7gQvuBXgB4pAYCyQm8gZZmeVyvw29d1M,626 
+torch/include/ATen/ops/xlogy_meta_dispatch.h,sha256=roo_sGcZDsgoM8OB-MPcccrHLoDbwFpsrFc1mDojyOY,1056 +torch/include/ATen/ops/xlogy_native.h,sha256=1Y34uPoEBVvi8Ij3PKZeW738GMfBPbHkqV43pAnrqko,1094 +torch/include/ATen/ops/xlogy_ops.h,sha256=e482vcl4AbR3DRH74Jkk5vUHjWdfWfY94fnpPmhbpmk,5733 +torch/include/ATen/ops/xor.h,sha256=DIA_dwRIsrksvPUST6ha_869YPfZlpkgVX84yN8PZ8c,933 +torch/include/ATen/ops/xor_compositeimplicitautograd_dispatch.h,sha256=IzDtA1z-IxQlbZJjHRh8u05qdl3zDSuBWk4HwCWd64k,1054 +torch/include/ATen/ops/xor_native.h,sha256=zsFDSID_QxqW2QdnCzKWT2UwD6wYBf9aiH7VWKtz6fc,764 +torch/include/ATen/ops/xor_ops.h,sha256=W6Zo0oi-F8cZiqYir_eTpLbgsegin38X73WvBoGdC6I,2953 +torch/include/ATen/ops/zero.h,sha256=-LDXwGvsz-dE8ZWnWIiKBTd68V3f9yb5IeMAkNeXcBg,1188 +torch/include/ATen/ops/zero_compositeexplicitautograd_dispatch.h,sha256=QlaTjwoJJnzzzu8pNeciaRJq8_6fu3aU-FfNFUiiTHM,940 +torch/include/ATen/ops/zero_cpu_dispatch.h,sha256=wc3iblUe1CUaC_wgcnlt2Kw4nOzJ7amKWFlzcetyghk,738 +torch/include/ATen/ops/zero_cuda_dispatch.h,sha256=N7oe6reV0lo-zYx-3IrjmU_qmSbxElKHiFNg1AmFKuY,740 +torch/include/ATen/ops/zero_meta_dispatch.h,sha256=bbnHpAFaXZdULTeCL5dPv1QABYrWj1kPAfCv0xW-QeE,740 +torch/include/ATen/ops/zero_native.h,sha256=gMbKhdnOMZFoW_Ge6O_6RAnxTbT8DXYnh1BU7iQuCWU,909 +torch/include/ATen/ops/zero_ops.h,sha256=jFqgrTAdgZxfy3nX5D7fQYmCSflNM5cCLarUyLJ0hSY,2079 +torch/include/ATen/ops/zeros.h,sha256=4mGuH3i5PWDH5510E3WrGR-T77JvU6pIbzDFCO70zC8,7014 +torch/include/ATen/ops/zeros_compositeexplicitautograd_dispatch.h,sha256=H98TZx0qhaVvGhg0ru7BHO57T6WE5BPunXbUTN43hQQ,2220 +torch/include/ATen/ops/zeros_like.h,sha256=_E53EpvTrEjWbx6BtfV4m1tPEXb3ghUCZI0UgtuvjMY,2265 +torch/include/ATen/ops/zeros_like_compositeexplicitautograd_dispatch.h,sha256=TM2d9v6A5k-j2pASxo0_d_31eYgDjf12r94yRdSgziM,1418 +torch/include/ATen/ops/zeros_like_compositeimplicitautogradnestedtensor_dispatch.h,sha256=GCkQe9Gtv17YbQXAeESgJZaKGq8OJgQQ_5ejPj-kgEU,1162 +torch/include/ATen/ops/zeros_like_native.h,sha256=GezOoL8r7LHWzaK5wTY6C7pjIJBGEyFTlGq81M6gP_U,854 +torch/include/ATen/ops/zeros_like_ops.h,sha256=My1msjrJsdhZDB1y53S5MxwI6OP_oxcZxlyJ-Xl_-bU,2441 +torch/include/ATen/ops/zeros_native.h,sha256=vhr72fXmjATzTPxHHyQQOrEbIShvr2F6r5IeSyBmrdU,1187 +torch/include/ATen/ops/zeros_ops.h,sha256=JwEYONEFEh4VFSxaZkMq7sfChj_E9SY2cEPsYy66a1c,3968 +torch/include/ATen/quantized/QTensorImpl.h,sha256=pqmzLHkZCxt1edGjaix8_LN9AygLSO33IeW-RfafAPk,4146 +torch/include/ATen/quantized/Quantizer.h,sha256=6U26qidaxzNDIztSk2I7ePEKYcfACFN9Z22isgJfnX4,9511 +torch/include/ATen/record_function.h,sha256=fl-rDk6yYrcGCtMOz0vF5JFGHwcKDGz6SQsAPwAx2Dc,25096 +torch/include/ATen/xpu/CachingHostAllocator.h,sha256=6zXyfMNmC1b0yNki2O0PPqfrvc8Z3-4kOZjhjnQbIHM,1395 +torch/include/ATen/xpu/PinnedMemoryAllocator.h,sha256=BK_1FPGGXjlJ0vWz_Hp3R6G04PdpZRCousY9uSjPu9A,257 +torch/include/ATen/xpu/XPUContext.h,sha256=qlZxvohditoKXBFYHEPhUGsMgtxUoTW1b4OCoo7YxuM,478 +torch/include/ATen/xpu/XPUDevice.h,sha256=oLH_-c-OsKo6nEhPBK94hy6Nba9Xw9TqX2Nlfec6j9U,280 +torch/include/ATen/xpu/XPUEvent.h,sha256=Kphr2Q21RJbCIFmS8JY4Icitn0MqiBKGGlJN-Oancb0,5979 +torch/include/ATen/xpu/XPUGeneratorImpl.h,sha256=RpehSLWR0sWI25EDFPnkqPbNHysl5NBqc8oTS_5ThJY,1271 +torch/include/ATen/xpu/detail/XPUHooks.h,sha256=WmPlKe2nvcra6wXdpS-NfNyYN9WvdtrcdIeSR7WHvl0,1210 +torch/include/advisor-annotate.h,sha256=88_l3-oFl4ZHh6FY-oabGMt99DtwUF1eZSZN-hHr9Bs,24196 +torch/include/asmjit/a64.h,sha256=djGtItNI8qS4QjX25vuRccerjwuGSo77XtuwwRxgd6s,2216 +torch/include/asmjit/arm.h,sha256=zWTOIuv7vNOzNOmYZ-sixZEYiysgToyiTNpKEMzDISQ,3859 
+torch/include/asmjit/arm/a64assembler.h,sha256=BA1IVrZLL5LCXbSaAX4g0LSsYLzFbiCxu9axW9UEx8k,1355 +torch/include/asmjit/arm/a64builder.h,sha256=TwJ9rfdL9Yq4b01H337jHnkoAAcRifAUMOhpMmgjtkw,1220 +torch/include/asmjit/arm/a64compiler.h,sha256=3CPX3cJRGTsXbzZGtaKC3NiIuASyPOlhJ6ekX4twwfg,10675 +torch/include/asmjit/arm/a64emitter.h,sha256=8kU9UjLMyZjNPciWlA3Krvj2PAAvTfuHLKk73ATfVAg,48367 +torch/include/asmjit/arm/a64globals.h,sha256=Suvq6QfPkPlvH5UvrqDz1YcHQqgfjlTlqfdyo91neFE,129927 +torch/include/asmjit/arm/a64instdb.h,sha256=oc-UpYBNHoY2snJY5dumNi0PNS6LYFWyH6MuyfKNgFo,2004 +torch/include/asmjit/arm/a64operand.h,sha256=XDDq7eekIKkHWu1fzmV4eILGjwNo1xUgC473sXl6xTo,29310 +torch/include/asmjit/arm/armglobals.h,sha256=F1WJ61iUIlCbNLcWmr8Bn1g57PL4KFZXusomLyHx7pM,475 +torch/include/asmjit/arm/armoperand.h,sha256=kK5Es3gCz8AvP4qhsKc_iUG6IiPVISqdbJphwl80hYw,21382 +torch/include/asmjit/arm/armutils.h,sha256=L_ccAMeu3UbYY0Y-HIKggfXy8tcR144qmDYmB47z00Q,8812 +torch/include/asmjit/asmjit-scope-begin.h,sha256=rANOXt4mpvRSjyLRHdm8Oi4AT5xuMyJ5-3lETvxEbb0,335 +torch/include/asmjit/asmjit-scope-end.h,sha256=F6WlLszbx0jQ4m5r8vaqRy4Hgv27L9LFkQjX680zhrk,249 +torch/include/asmjit/asmjit.h,sha256=EFOZJgxFwyWJEqJ0SfiQzDS8Myp71RaHJ688KmoqAos,1281 +torch/include/asmjit/core.h,sha256=ZdZ3PKr2E88RXaOSzgVWeOAaWzhPaxWUYrGIdppXlgU,101324 +torch/include/asmjit/core/api-config.h,sha256=1MlO7ceXEJb5Gq36e3W2zpBeZZuj1berNDjuPIgStHY,24356 +torch/include/asmjit/core/archcommons.h,sha256=ycUZNR1l3WdM9zBohSWfIILfpa7-ZWYETE0wv8qzdj0,9871 +torch/include/asmjit/core/archtraits.h,sha256=A_MCYbvIrkc-GZVBprudhokys-PML9o2LetUDjBeUpU,11117 +torch/include/asmjit/core/assembler.h,sha256=QUsLRn5VE9LksQwAwfvYEiBmAagoJYKH4gv5zpy2jqE,4415 +torch/include/asmjit/core/builder.h,sha256=QxaRyhvu8EmjHYliQUKTdwifqq9X2T06BvIo4BxO6ZU,53980 +torch/include/asmjit/core/codebuffer.h,sha256=c5JZHIknD5vaOUYX9O-YYrwUoCG4IarNa8UAaZNA8q8,3441 +torch/include/asmjit/core/codeholder.h,sha256=DnS26N7QGzrmF94rburCcP8R9W_lY7CUhr1fL5FieF4,45859 +torch/include/asmjit/core/compiler.h,sha256=es0MjZI5As_6FxlQl9yVDU_kyEq65KEHEj-dFghKZJc,28335 +torch/include/asmjit/core/compilerdefs.h,sha256=NteYyu_5robbyWQY2eX_ucUrqhYufuvV2F3dMw8PfV4,7269 +torch/include/asmjit/core/constpool.h,sha256=dNnuaiBRiQcwFk4tWvizqpJqF5KwuGoNnBsH1W6fePA,7321 +torch/include/asmjit/core/cpuinfo.h,sha256=azKKObLENBUCT9CRJbEtrzGsknL4zVi8RptrRlQ3IVc,64610 +torch/include/asmjit/core/emitter.h,sha256=Ii1jDXHrIUCBPuyN5nriVgrColLLWWji0qwMu49loOs,37968 +torch/include/asmjit/core/environment.h,sha256=MSkUnNxWgu0RMCYCxnh0pJNHaJa0F6ga3UZNFertG34,19328 +torch/include/asmjit/core/errorhandler.h,sha256=QF8Q9bmGYYzH8_51xS2SG5v_CS-Oy114u3MSmLuxy7c,8078 +torch/include/asmjit/core/formatter.h,sha256=qe-eLdYKeKaCrNN_snuZc84AkbFu19mAIbyFla6thqc,8448 +torch/include/asmjit/core/func.h,sha256=qYuc6NE_G02hQTaargMEtCKur2O2EOvhJD9eT4v7Aos,72016 +torch/include/asmjit/core/globals.h,sha256=qIqTn5MZBFGeeCNuKNJt8Ku3JP9WlKPex17BocZnOkU,15343 +torch/include/asmjit/core/inst.h,sha256=W-lWRZSPxy7CXiuVL7_pTIuXUiznSP_b4itqKYC_Xtg,32814 +torch/include/asmjit/core/jitallocator.h,sha256=JXEIk0gTmIaHbXrHwd5MgOAFW2Uo5Re2V1PDwd3orIE,24153 +torch/include/asmjit/core/jitruntime.h,sha256=D_y9YN7wI0SJaYV0YmILlhzOCB2LgbsEgthTxuu8D3Y,3243 +torch/include/asmjit/core/logger.h,sha256=FHcKLHoVzDgNhJxrmSKB5EUpuGP9gtpR_32pIuERjp8,7440 +torch/include/asmjit/core/operand.h,sha256=p_jHDCYR-vnrGhJiqP-FRe9IwBwFizyhwUz32Hq1KBo,87694 +torch/include/asmjit/core/osutils.h,sha256=pCyUFxb3YUMJryM7-ucK4poh5j6JbNrgn0XOkHjx3LY,1252 
+torch/include/asmjit/core/string.h,sha256=1paEIKv01ntQaKhCGWEDZ7TL4RvVquCQhRYebkLZCgg,13479 +torch/include/asmjit/core/support.h,sha256=1RPqdvgbq7GT1H3gd0D-pH1Sj2G3wzvefyD9sKi8V3I,76848 +torch/include/asmjit/core/target.h,sha256=u_AFXpA-_OgrpMqgn2iTfKTDndNaJAWfZkbHfguEAQI,1582 +torch/include/asmjit/core/type.h,sha256=x8KmQYzC7y8gX6MsJ5OqzF3ye8EJQJ4PWSGnBp6rBsk,18171 +torch/include/asmjit/core/virtmem.h,sha256=gNq4lO6YrsVCReuGGkOX-fA7xVfSSES5aDaK4j4jfOw,13602 +torch/include/asmjit/core/zone.h,sha256=kGF9Ulji5a_fZyJKd561MT2cnvVP6yCnfdf1uVhEe8k,21290 +torch/include/asmjit/core/zonehash.h,sha256=8gPKg1Bqziux4H1-rPPpwfjIM43CPzJMTakzOp4h22I,5444 +torch/include/asmjit/core/zonelist.h,sha256=G7oR5aCxxmOcvPIinwNJ1UKm8dK5pyBUwKHkBpDaFFM,5002 +torch/include/asmjit/core/zonestack.h,sha256=txC2Z_rvI_TSHDICohVm2NVBoQYISFRBPsZrRBNEWRg,6399 +torch/include/asmjit/core/zonestring.h,sha256=kB355b76knsVj5WtR4-8tQEYXw3ikGFECH5xIcvUSg0,3348 +torch/include/asmjit/core/zonetree.h,sha256=mdzxc7mB4HPgh49SZbe9-DYjVfq0Qf1vRsTr0Hv7dRs,11723 +torch/include/asmjit/core/zonevector.h,sha256=b3JKdSuxd-C2OH4i6xyeyAl_x6R5bCF3PzNXSg5bJUk,23270 +torch/include/asmjit/x86.h,sha256=2DfiBdRVR1_oZobrg9JBm3xTpOGglPNDivtM704zVF8,3985 +torch/include/asmjit/x86/x86assembler.h,sha256=Yu8yu6Wpuha5DeUbXpmgE7FxRwwo7p0FOHbkedFjjO8,30106 +torch/include/asmjit/x86/x86builder.h,sha256=8siYYqfiji2lVYUQEmq4-cft1vl6C4P2yPPzzs59UD8,14367 +torch/include/asmjit/x86/x86compiler.h,sha256=Z1THi3dNdOvKbC1mMQKXsVYf-pvfGqoR2mSbGCSSTy8,31234 +torch/include/asmjit/x86/x86emitter.h,sha256=NQlz5KZsNzCcqLgV_d7pofjKDTm4ph159RQZ0XJmPEY,313874 +torch/include/asmjit/x86/x86globals.h,sha256=G3a2bgUYYQXHVJt62xBM4GV86UWk3IyrhSnu0xYhVYs,163006 +torch/include/asmjit/x86/x86instdb.h,sha256=AJP8YxhWvCuNt0MsDGkDvldRpeoV6XfHpdGsV-VIjxI,29916 +torch/include/asmjit/x86/x86operand.h,sha256=NgwibwYMf4NtMyLwpgoDicHh0eDHJiTlX06SqND5OWQ,60164 +torch/include/c10/core/Allocator.h,sha256=Pw5loug9xGH7DRDVNT5ly0tetGHivs4Oy5rWD4uJ6_I,14188 +torch/include/c10/core/AutogradState.h,sha256=ZhM1n11mBPE8SX2uupQ56GhAnF0NdCnldwdyicROFD8,1727 +torch/include/c10/core/Backend.h,sha256=atgIJBSt1xUumivPO0dQcSoFImO5KsvLsyUmoyF118M,11683 +torch/include/c10/core/CPUAllocator.h,sha256=bqA3qxrMIK0HL6Quu-mCy1wtQy4ftnyjlI6tQGFzxV4,1747 +torch/include/c10/core/CachingDeviceAllocator.h,sha256=Csyj5eZge_igYduq0YJVUvdLFKe_4R0AQa2vZ0Aak4I,2064 +torch/include/c10/core/CompileTimeFunctionPointer.h,sha256=Qtv7YgDwtd8j7wDAHnP_kjuS6aE5fJn8RjIgmj121ZE,1757 +torch/include/c10/core/ConstantSymNodeImpl.h,sha256=P241GMoIkFyDl2xOGWtr0tKsT_rmB-oAx7qbhFPESUE,3100 +torch/include/c10/core/Contiguity.h,sha256=40K5-LOx6kzyMJHeWDDtAMGzB3KrPTQt5ZFeQTSsOe8,4296 +torch/include/c10/core/CopyBytes.h,sha256=CqIxdD94B-C67drW7bXpRuXmChAnny9wmjLebeghAR0,1391 +torch/include/c10/core/DefaultDtype.h,sha256=4TOAW9l8_QQb4A0knfTzOTS7Wj4I-pHZ54knSe2yXhk,409 +torch/include/c10/core/DefaultTensorOptions.h,sha256=NqWkP5yL8imFcEs9-nRyw1kk7SEMkMyBjn8lth-MPx4,1109 +torch/include/c10/core/Device.h,sha256=S9U5qNNxSz1XU8Ennew8jcB8U_kJlJm39Cym_i4N7Ho,7111 +torch/include/c10/core/DeviceArray.h,sha256=P7ppZwyYIwCp-h9VxBrx9Noq42brN3Mi12LpupKRdDM,711 +torch/include/c10/core/DeviceGuard.h,sha256=GRSDjGoEF9iXag7g7kxa96GLcxkouhWBWH0TJ5nXnnM,8012 +torch/include/c10/core/DeviceType.h,sha256=rBgZeiobReeyGm14dUu6MViAiigpe6IGBqCSkqNea1U,4599 +torch/include/c10/core/DispatchKey.h,sha256=fjuU2bsU43JSujmtGkWkeX53KGc17R8lkXU6miFDegA,33117 +torch/include/c10/core/DispatchKeySet.h,sha256=mXTUJ3fCSvNn7GQ-e7B9XuQnbP5wyyddb2SNHuRULUU,42706 
+torch/include/c10/core/DynamicCast.h,sha256=bAPgDSo2kIEXBFmP0qtmNcsHW-Lw5kBXGrqe0CrGo3o,4569 +torch/include/c10/core/Event.h,sha256=TuhUhRY_ynKX8EtC35Ep8oDqo0lI_ntobI99_aN8PEE,4600 +torch/include/c10/core/GeneratorImpl.h,sha256=MiWj9-Q916YzqoDt4MO7oyVqphbMjb1hcQutRid6ZYQ,4048 +torch/include/c10/core/GradMode.h,sha256=47vzeqGaYa5YupZ7mTBQh277OPwTmragPaoJrHV2iHc,1713 +torch/include/c10/core/InferenceMode.h,sha256=_OiqGDbD2GsAtAgleocA-_Lf9JJDn2VbaqTXFIhcHfE,3854 +torch/include/c10/core/Layout.h,sha256=_1HYkkioObPgW8RvtzbQaOPyhOWBttoHhzrA3667T78,2021 +torch/include/c10/core/MemoryFormat.h,sha256=rCpJfqhf-kvJGSKqnpw7Bx07juBpuCCiOZ-3GlP1jOc,9691 +torch/include/c10/core/OptionalRef.h,sha256=MaWkZyUjzzjcK1TL74XxiAkG3DBHqtiFdNhSYfaked8,552 +torch/include/c10/core/PyHandleCache.h,sha256=li6VXz_1QDxMFy1kwowqBJJepMJmmWxjCSWnn3KgPfo,3177 +torch/include/c10/core/QEngine.h,sha256=6TliQpegM-W1EI37iiGbua2f6bRc3x-DNqAHp84ZQgs,1056 +torch/include/c10/core/QScheme.h,sha256=7j0NnGluksYgRWKanM0nrOlIH1DkJe0UbCfWmHHpui4,1616 +torch/include/c10/core/RefcountedDeleter.h,sha256=cwTSGQsRjcoVct_a4Qjb4DD1LgXyA-tqmZ9K4RQ_vHI,2303 +torch/include/c10/core/SafePyObject.h,sha256=q8TRdh3z7_m_1Ws6fb31u30VTygEyjLZ4MgD7R98HlQ,3861 +torch/include/c10/core/Scalar.h,sha256=Xb2BxMF7YSbzveOY2mO06drfg7Bgca9l9pkJHQJDR68,14458 +torch/include/c10/core/ScalarType.h,sha256=GIkDqcN5upcSGVc051WX-pHvc3bDOpOleYE-tyYrLg8,25354 +torch/include/c10/core/ScalarTypeToTypeMeta.h,sha256=Yb8CJIsuEpR5iKhSBXsIe1ZOE3kKw6jS1SEVJpnNXXA,1414 +torch/include/c10/core/Storage.h,sha256=evyRcgbcZUnXPLTfy0wuLYtCwEVeQrjNVvssh61SskE,7331 +torch/include/c10/core/StorageImpl.h,sha256=KtOtNb3tC4p-ThTzh6DavdkCislImFNmM80IlZMJW8g,11754 +torch/include/c10/core/Stream.h,sha256=vbajo_A9mHAJ9n3hvioeLGtr5GN3YHskoPqSscQwUrA,6541 +torch/include/c10/core/StreamGuard.h,sha256=b_nbBTtaG5AZyjTB3tyKgQuVAVpTq9TeIL8l8tobfZU,6763 +torch/include/c10/core/SymBool.h,sha256=isiOOG901Nz1IQg7D-hCMm7aWTjpK_LyHeIPvz-14sg,4635 +torch/include/c10/core/SymFloat.h,sha256=--ySFNfTYX69XkqKqo1FaavJ6ZARDbGm0ohf0Hrg6iU,3675 +torch/include/c10/core/SymInt.h,sha256=lyN5owiUwhvxMFg2pNxR9gCBbZGDpN7ZEJ54ltZ6iXQ,14328 +torch/include/c10/core/SymIntArrayRef.h,sha256=6Tfc76EmWZMYXpsaRvR2xLCIkczIJM64E5XWZn9HJjQ,2803 +torch/include/c10/core/SymNodeImpl.h,sha256=JRGtHSiVEPrp5in_OmmE5mzC5N1HWw3Ftuy9lY9YIO4,7827 +torch/include/c10/core/SymbolicShapeMeta.h,sha256=tRHea6rgWTdFit94bXHuiYYGMJYuqdQ2zG-_Dec1cnk,7565 +torch/include/c10/core/TensorImpl.h,sha256=wVrY14VnXRdRV3m-VWuVfv9KKv6tVZBuifqApPkHVlA,117789 +torch/include/c10/core/TensorOptions.h,sha256=HL2X6cuhQDxTnY_GzhCqa-ntxdk1xru1-qwe5Ezs2Eg,27733 +torch/include/c10/core/UndefinedTensorImpl.h,sha256=Trm1I2YLk2_XVOW3XIDawzjv_Zyajh_hrwqIF9gYq9o,1270 +torch/include/c10/core/WrapDimMinimal.h,sha256=RVeBQnnUrB7tDolSO1lt4PSc8wxz-O4c970myF9Fj2I,1408 +torch/include/c10/core/alignment.h,sha256=QsRyJjWPY5U05bGLLwoW_CMJxh48Jdy_icDH33bNk8A,585 +torch/include/c10/core/impl/COW.h,sha256=yaU9qBov0GXwDZ_1kMM6OpDHoh42JzteCgb-N6cCK50,1088 +torch/include/c10/core/impl/COWDeleter.h,sha256=cLPaLNuT94Gc-lJ--NVfsKcDG5pwUGKhnNc4MulCmJs,2164 +torch/include/c10/core/impl/DeviceGuardImplInterface.h,sha256=0BBGhrlLyPL1OeTu_dxCthlMEEQDRqs4C5BV69fqaWw,13928 +torch/include/c10/core/impl/FakeGuardImpl.h,sha256=78KZG-_Q4vDrWVX3XoHzCM_R60Cj-UeE6hLxo8Ip9_E,3273 +torch/include/c10/core/impl/GPUTrace.h,sha256=N1AIMkj_KzUHcGrB3uBHtTIv3qeq_xH1dyhf19wLnik,892 +torch/include/c10/core/impl/HermeticPyObjectTLS.h,sha256=P0hNLP7LmSy7vKjTXf6JNLUCnRpUES4ArPafubV3uo0,2608 
+torch/include/c10/core/impl/InlineDeviceGuard.h,sha256=PQSeBpomvC8i9TY-K507fQ2aOD7HMwrBVDWi-QtphJQ,16344 +torch/include/c10/core/impl/InlineEvent.h,sha256=h5Gz26FbK_MTeyN18AgSvXpwFC6GOEp0Vo3Rw9JM4B0,4455 +torch/include/c10/core/impl/InlineStreamGuard.h,sha256=K26HG8eC90oUnI5BoFsRI9hw9Yo8WzoiZYBgtbZSLlA,10174 +torch/include/c10/core/impl/LocalDispatchKeySet.h,sha256=ifimMBO12nuxnOY7WeSGjTzBkULjgw3IIvWM-xX4Hcs,6702 +torch/include/c10/core/impl/PyInterpreter.h,sha256=j3KaUHwK1JrF1L1Lx6ZujKF4eshj6VUcJGZzSTHqID8,11591 +torch/include/c10/core/impl/PyObjectSlot.h,sha256=vtmvt3eDxc71yBK5YV6Fnui3xYRn4v_PsYbVU57xT7E,8357 +torch/include/c10/core/impl/PythonDispatcherTLS.h,sha256=wjLK7XU2OoADqp_Ke4qbrD4z7T8Pqm9DcfFuLjgTF_0,869 +torch/include/c10/core/impl/SizesAndStrides.h,sha256=n5yL18PxZHDWKUoS1Bt2p7ksoGTo-XR2_x6w1v5nh2Q,8992 +torch/include/c10/core/impl/TorchDispatchModeTLS.h,sha256=I6wtkhL6mJsngjP99VuJ10RDAc9W2Di67rgsFWncfMc,2348 +torch/include/c10/core/impl/VirtualGuardImpl.h,sha256=5DAZQUrFIRZaGZKHx-qKw8juxyvAxxpbpxsxKeAYnJg,3410 +torch/include/c10/core/impl/alloc_cpu.h,sha256=hIJjYjjRh2XCTlznXwOWxl4JKd0uCkMCiyKnd9PoRM4,699 +torch/include/c10/core/thread_pool.h,sha256=9KSES5oj1cVeE3Hb2u4h4iqjx1NFmkGnF6bvGbN29ls,3117 +torch/include/c10/cuda/CUDAAlgorithm.h,sha256=PkpcRRGz2M31usHY-HM3UvG3-GN1o2EaHAa6Uw1DxlI,1072 +torch/include/c10/cuda/CUDAAllocatorConfig.h,sha256=uRy9YTvABo9sELmeiMBSAUjBHSCAiOMtv2kkBgefzxo,5270 +torch/include/c10/cuda/CUDACachingAllocator.h,sha256=z2Y5A1UG9lFSxll9Kj0uSQOHpHyspr7EirKOb6U1HR0,18640 +torch/include/c10/cuda/CUDADeviceAssertion.h,sha256=fNEtfhFYUAuzQPUJ2yTpRrnfOIM-DPc31Bn36Uh1vp0,4260 +torch/include/c10/cuda/CUDADeviceAssertionHost.h,sha256=hK9zQPDBJidFOB1yaPN0a6BSpKjjvC_nT2reCJV_rMU,6795 +torch/include/c10/cuda/CUDAException.h,sha256=NNXsMmcGaDauVYAZ0QW6SutjpRR_ozb9dAAHU--Hd7k,4508 +torch/include/c10/cuda/CUDAFunctions.h,sha256=3Ads-CBkgDRI7nV4cALh8Ee933k4pfl6hty-0vvR5ek,4080 +torch/include/c10/cuda/CUDAGraphsC10Utils.h,sha256=3HEvxxoXf3V5W0ryuUKe0h421SuwCTqnEqIPXNUkvZU,3073 +torch/include/c10/cuda/CUDAGuard.h,sha256=noDNlclbih9bmWARYG4c3D5sDGdTupaqI9Kfg6XwPew,11651 +torch/include/c10/cuda/CUDAMacros.h,sha256=qVfPpZL5u0zMaHAu7ooLzHPYJQtb7PRJc1TieW-SZC4,1530 +torch/include/c10/cuda/CUDAMathCompat.h,sha256=Dh6FPVzbwtZ2D0KSnHSQhGhwwkC__1UstoRm43-pRK8,3698 +torch/include/c10/cuda/CUDAMiscFunctions.h,sha256=tIxBpAtFQuflrO9PS1HtWQ_3LF6miD9IxKtupF99rQA,318 +torch/include/c10/cuda/CUDAStream.h,sha256=IU8EFReoaym7tWRlWsxYntRsOrtQOSgunBaZJ1WgU9A,9896 +torch/include/c10/cuda/driver_api.h,sha256=11sypFT6thcJNUpVUmwelw4dfkdaSjjnMwDuruD1nSQ,3533 +torch/include/c10/cuda/impl/CUDAGuardImpl.h,sha256=Fbt-E3IM1VnRZqxc8NzZtUhbVnos00EoPVJNOjyTk-s,9562 +torch/include/c10/cuda/impl/CUDATest.h,sha256=9ClCwSnel2tvAcFMBN_nEDI5x_faXUekDpUpwHv9ukA,123 +torch/include/c10/macros/Export.h,sha256=D60GuHypnvJJlD7kWHdjDo9bEVmKZ-8AZh2BzbNw9I0,2605 +torch/include/c10/macros/Macros.h,sha256=TbVxQrglW48pGYOk8gGt0QFSNg737cmhtksH4bIfQ4w,19202 +torch/include/c10/macros/cmake_macros.h,sha256=3uIcn6SP6ytGsow9EZcSP6mQDzprVjggpmuXtovhu3E,451 +torch/include/c10/metal/atomic.h,sha256=Tmfz0XzEEdIfYB76GI8M586V8T19mXIFu-PD4tb0jbI,2969 +torch/include/c10/metal/common.h,sha256=a5EoaWp1_gmO8EPexN1grV7PPE4lK0_8lXweG0C7VEg,1479 +torch/include/c10/metal/expm1f.h,sha256=4QMb1xalCKV3bQiUKewb3zOvrcuybzdhmIhLbk4r3ro,3693 +torch/include/c10/metal/indexing.h,sha256=YyPbV4FikEYlqBUy74fDsL5YrQX1if8uAG5mTUtAjeM,24593 +torch/include/c10/metal/random.h,sha256=vQ5wBdA_Bp2wkgxJfwZ4d8_Gb2JWBxl5YgxQt8DJNHc,2295 
+torch/include/c10/metal/reduction_utils.h,sha256=0YwQiEjmHOpbQWnbdwy-f1YAhO_hVWSE9MbSJQQyMLE,5544 +torch/include/c10/metal/special_math.h,sha256=PGvDZ9H8RVFIlaVF-DAOnLeaDK_5cqpfPrnRP9oNt_A,52147 +torch/include/c10/metal/utils.h,sha256=PIDGu59SZkK3dq6L6qzKRAq8BkiNcsuUqytswGp0GwU,8200 +torch/include/c10/mobile/CPUCachingAllocator.h,sha256=aRendNtLEleFdLJGVXxutf5k88qn1r70EFhhRmQclyM,4271 +torch/include/c10/mobile/CPUProfilingAllocator.h,sha256=NfPhfNBa4rYLR81bvrpBhee-nYfKDeTAHlmQJVabuWs,4888 +torch/include/c10/test/util/Macros.h,sha256=MnKFKSBMhif6iUNNfRZ2lVLFYE2x08m-wahFTJcUljw,185 +torch/include/c10/test/util/complex_math_test_common.h,sha256=fwJ13-pVA6_Gq390say41OWxwjW6_den8XKcwaMD8ZY,22631 +torch/include/c10/test/util/complex_test_common.h,sha256=b0OxaYLhDqXU0GOx-YcTM6QpdCqymY7aEyMVH43BBic,21204 +torch/include/c10/util/AbortHandler.h,sha256=ujDWRKtwl6vHRourVpDS-uxTfiuT-49RyJ2ieGQtmfs,2275 +torch/include/c10/util/AlignOf.h,sha256=Vow6fLPAmaw50-uGorJM9SXA-q_DePRScuzXAjJ45dk,5082 +torch/include/c10/util/ApproximateClock.h,sha256=a1P_BAW7MWwE4VneT1RLBTPYpjHDOk0sCCVXhRQNk08,3597 +torch/include/c10/util/Array.h,sha256=4jX8gB1pKUtj4h2KCR1gCSX8GkhugTkVJ1AP7oeoCqk,468 +torch/include/c10/util/ArrayRef.h,sha256=KMuIoEPNUa4WtVpCxEvduW4KnaWBttZZe4yq4QkIofM,11263 +torch/include/c10/util/BFloat16-inl.h,sha256=uxKFaocqALYCvjOVhotf89u52WCSG7jg8Q8t-Jn4_m8,10605 +torch/include/c10/util/BFloat16-math.h,sha256=gbAMdYizIDrZTAxtE7-GsEEKO-VDsySI2-AU5AUBM8c,8816 +torch/include/c10/util/BFloat16.h,sha256=1jiZcRJCxXppOHGP6r1uLgBMWJ0wP0TMc1wAgLq4wcQ,3290 +torch/include/c10/util/Backtrace.h,sha256=m29_pRzFaJ_uhOYh6a8cMNOCmbDnZXLIdCZLak98PX0,828 +torch/include/c10/util/Bitset.h,sha256=-04rWF5Q9cbBhnbiwKHU8QrUtFwxtmCwbg760xvPst0,3539 +torch/include/c10/util/C++17.h,sha256=vMckK6UB7NGlw5eV9JhO-EjGAtKg7cafQ2CfaLLNgII,2489 +torch/include/c10/util/CallOnce.h,sha256=I-v_rNP5ddKmMWcyWWLRWH1VtiReKSl6yrzuUIxWsiY,2133 +torch/include/c10/util/ConstexprCrc.h,sha256=NKrdWhdczSvNnviRjIyPlS6XkrBdk9iqBn7QvUQ-ENk,6727 +torch/include/c10/util/DeadlockDetection.h,sha256=tNO7WTdwcdi1i7zZkYWxPsVEkyelUmIf4CfGofqoUDY,2269 +torch/include/c10/util/Deprecated.h,sha256=QBkJ24TmV6Jjw355jBhQnI1TeSvWcMD7a_uDdC3y_mA,3681 +torch/include/c10/util/DimVector.h,sha256=zGFxEDdeRR15Iqabllcx1W6ch4yJzDZDPDZyKoCpAnk,461 +torch/include/c10/util/DynamicCounter.h,sha256=gds0mbFHHQJjUXRUpDPOOcatAUUYd7fhXYMEtIQmgEQ,1353 +torch/include/c10/util/Enumerate.h,sha256=lAYnklmUKea_pB6SFzJAOSPU2Ov3ZLrXEXiNLMnf9E0,4160 +torch/include/c10/util/Exception.h,sha256=s3Bwp3iPt9DvhsPMlW6qm-VOwtnuDf_3NqgIqTSduiE,30507 +torch/include/c10/util/ExclusivelyOwned.h,sha256=NR3B9rI2vLLDQQHsQUi_BcLFqVKBBAjjEHInr3zWdLg,4593 +torch/include/c10/util/ExclusivelyOwnedTensorTraits.h,sha256=KvRKPSE6bR-bKt3R3rZfFJL_VaAD_tZKot7-VeUYX9U,2269 +torch/include/c10/util/FbcodeMaps.h,sha256=SJSQ8MgjIUamj6vrUFZVep8U873y3QAS0ij2Zej2KWM,757 +torch/include/c10/util/Flags.h,sha256=6iGVRyUCmE5hvT_JxUn0h3GXd7-ONgaDLapYAaRG8fg,11142 +torch/include/c10/util/Float4_e2m1fn_x2.h,sha256=H7ubpiLcyAloHjhEfth7Hdb0Gt9KLvYLiQsckljjK0g,890 +torch/include/c10/util/Float8_e4m3fn-inl.h,sha256=1NtlAKrsh2hwf3wJg-nCz22Mdxq_ZDU8BgIk7VeHNDQ,8840 +torch/include/c10/util/Float8_e4m3fn.h,sha256=I3_cTd0OnCIsE6zqyfa6TjtfdDj-k5_I_g52ff9P1M8,8386 +torch/include/c10/util/Float8_e4m3fnuz-inl.h,sha256=DZFzBgZMtQL5ZAQnnSyxv7zr2lfdXM59LGNC6V1kuPg,9275 +torch/include/c10/util/Float8_e4m3fnuz.h,sha256=7ZEp41apohF4ELf2YZ5fxHSud63RXZu_QZn-Q5Spt8E,3953 
+torch/include/c10/util/Float8_e5m2-inl.h,sha256=7gNHBuyy8K5aB4xKzy_bVHCk5hW_enihewXolu9Wasw,8938 +torch/include/c10/util/Float8_e5m2.h,sha256=HZorSbogIAb2BNd4yGZBhVXVa0ssCcraAtyc6YrkxCY,4472 +torch/include/c10/util/Float8_e5m2fnuz-inl.h,sha256=hiskc7xHOf_XeoZRpPE5RilH797R_7UQK45lPTfnbFA,9510 +torch/include/c10/util/Float8_e5m2fnuz.h,sha256=Y4zolDc4NZxxvacHSzEJLNS48neKEcGcmSWHZhuBnBc,3981 +torch/include/c10/util/Float8_e8m0fnu-inl.h,sha256=6RKMFBihqp0yJj1YEY0UKrGhU9G2sBFZQ8sEfpIHamE,3897 +torch/include/c10/util/Float8_e8m0fnu.h,sha256=Bro8xi8ZZY-KmscEY_plcnFUMCjTigv-3YUB-92BcOE,3218 +torch/include/c10/util/Float8_fnuz_cvt.h,sha256=49KTipRc-wsE72WwRYsYty8Ys0g-qFtmyzE9ZPIWGtE,1796 +torch/include/c10/util/FunctionRef.h,sha256=rhSZh0rfjfawCqtxt-7xVjP4ZGqqXA8ZRMHdUUI5h7I,2369 +torch/include/c10/util/Gauge.h,sha256=-YZIuodEL1f73DAx9Qq9xCB8wq5AD0cHUTH5ydYJvXM,1225 +torch/include/c10/util/Half-inl.h,sha256=Tq8fx59YJi4gzfOB_XcCUdY1e31KYx-UlAVMHeSZtPY,10502 +torch/include/c10/util/Half.h,sha256=GQPJ2bAa89OlXhUxsPKjOKgYAnt9WxUE80eNcDTqQ_4,16764 +torch/include/c10/util/IdWrapper.h,sha256=Gyk3y6BYwI31FQtH4AvGuuHzfFLVKRKv2JxEcZQS5lo,2413 +torch/include/c10/util/IntrusiveList.h,sha256=fD2KNO2QYNhuHqJcTTYMT_cuXCy4tz4lItNg0esmIyw,4701 +torch/include/c10/util/Lazy.h,sha256=iEM0qwedVFjSheESRNMksC4nud-LvX6X2ufosnr8abA,2943 +torch/include/c10/util/LeftRight.h,sha256=h6r_hS5EMcvysuCZRcGhILU8Rucyr62uxVtwCuiQW-o,7438 +torch/include/c10/util/Load.h,sha256=7q6yQulhnuYn3qh7XxH-lBB8cRRf8RKHC7CDJNqTGx4,962 +torch/include/c10/util/Logging.h,sha256=AJt7r-aGokRcLbYkJT98lPBE96Ercz_UJe-O6OLBIIw,14457 +torch/include/c10/util/MathConstants.h,sha256=T_nRX-_RnhOKAMJ552WH5GwDRBSfPxqwFdPElthflv8,3832 +torch/include/c10/util/MaybeOwned.h,sha256=2Gok4htmd4BkvVeMBfB-BRb5ZzgB4QVsUxkUItr8wFY,7392 +torch/include/c10/util/Metaprogramming.h,sha256=RbKZjs5wH3EIrjhhSlhvcZUkni55U8hOwNDCNimVfng,7255 +torch/include/c10/util/NetworkFlow.h,sha256=C9ILbWH04QxY4yUVg5GdlBadgjv-27ompbM8Rbmu1mE,1191 +torch/include/c10/util/Optional.h,sha256=_u1CygRqSmhAjrJ1QWBZTdXIOnge1hvjgqHYdgN2aAE,2055 +torch/include/c10/util/OptionalArrayRef.h,sha256=JxYmjjpvgIQYf_t8i_hiTRmytLodfzJVWBqFhzuymY0,7399 +torch/include/c10/util/ParallelGuard.h,sha256=IiIWfO8lQ4OEdHuZlmeSWzRMyQJIdv1wyD07QkfLcxw,393 +torch/include/c10/util/Registry.h,sha256=gd4PTiFAtb5FO84g8rnZ8YTii7kce-9M7S2p2zv71uo,13589 +torch/include/c10/util/ScopeExit.h,sha256=nZw3J70LmIhAumG_DJKU5QgAVqJzA6Qx-1OELuBeOS0,1309 +torch/include/c10/util/Semaphore.h,sha256=vqkrxemSzcsdjKGxcLeuFl733vi2jJWk3e4rMOZuGhY,1489 +torch/include/c10/util/SmallBuffer.h,sha256=D6oIZWdBVl-IJzojI_mHzws9QQYJPiESQGaVb8-x4a0,1849 +torch/include/c10/util/SmallVector.h,sha256=RzZn3VQXVKvb6iq6e19jA9KPu0HutuqSoAujZPaWiwg,50472 +torch/include/c10/util/StringUtil.h,sha256=AXsocdeccL_CQAyFvagum-Zo35Id-yV75AD6hOOVpAM,7064 +torch/include/c10/util/Synchronized.h,sha256=1HA5rMQ9FEa-FKYMRG_QaDph1j-i8Jq3Px7dBQBZKog,1987 +torch/include/c10/util/ThreadLocal.h,sha256=OoKxldpp2VmTXC3BKdxoBz0j-vZrEfqRlAeAuv6pNRc,4176 +torch/include/c10/util/ThreadLocalDebugInfo.h,sha256=tGzWncNFdkgTwUtU8xcJnwpBoanXzesd-MiHmxVmjZY,2750 +torch/include/c10/util/Type.h,sha256=2rBcjd6k4fULIfjaKstAi-WEfDDMOAa8MI_QzzULY08,676 +torch/include/c10/util/TypeCast.h,sha256=j9oRVbFHoyXgizeaACj-0BurtTak2pJxovdWQJVFhYs,6875 +torch/include/c10/util/TypeIndex.h,sha256=UfR4A4EqejS8QU3q60wBytxhHLo5I0NTzNaECTM9Hlc,4820 +torch/include/c10/util/TypeList.h,sha256=JWFT6tsHhDE-ZV-6PcBBWr3bmePX04gYL0uVDgR6xLs,17344 
+torch/include/c10/util/TypeSafeSignMath.h,sha256=pudcT5Vt_mdWtaHNBV0itRUErR-7mna6wj4ZLVIjpQA,4503 +torch/include/c10/util/TypeTraits.h,sha256=BTFqiz4XtVVACPYEqdS2ULVxdw9FRCw4fA5pIXg_etc,5489 +torch/include/c10/util/Unicode.h,sha256=lpUvrT67LUeCakllHJCKTBzkNcCZHT4mVbnGMZtLqmA,309 +torch/include/c10/util/UniqueVoidPtr.h,sha256=fKsnsHNodnU7EhohEynDVQNS0I47hk_sm2zK-OX8uRs,4710 +torch/include/c10/util/Unroll.h,sha256=3k7_PM75c8paRTq2CpLj_4cvDn3fdjw9B1IXnnu_U7c,873 +torch/include/c10/util/WaitCounter.h,sha256=E8Zxt-Z7tjbpNux-pCb9R7S7wxI5mtsAhm0RAg84us4,2687 +torch/include/c10/util/WaitCounterDynamicBackend.h,sha256=HZdDVg63SSoLXK2geocY9i_6afv8sU4n6FKfPIdlHRk,683 +torch/include/c10/util/accumulate.h,sha256=9zdW_RxPab3W3_-3_aab_aW2Idd4Z2TWQboiH50YVh0,4156 +torch/include/c10/util/bit_cast.h,sha256=zTGomP3-CMLUx4biuITnogMKntxFdLx0Z7SUvcIiEiE,1291 +torch/include/c10/util/bits.h,sha256=NrAdOy2mgs3RjDkY0zE73PdiYNBY-uNg4cauCwrgWp8,1510 +torch/include/c10/util/complex.h,sha256=9EptwtjtfD-QlRgmg2poe0UchS2sNWaYQeJz0gQ-Stw,20196 +torch/include/c10/util/complex_math.h,sha256=cgu1AQT-CgcTrJ0f7_sAvevGcA_lNPqYOJZMFBff9Bw,12939 +torch/include/c10/util/complex_utils.h,sha256=3Yngew12_8eTlbnsceI36rKrzxi5iO6WSaE1Z-OnI44,1123 +torch/include/c10/util/copysign.h,sha256=9RyUcCYzAY5PkzFl41cAKcY2nSS8dRjemAXv6gkM4G0,859 +torch/include/c10/util/env.h,sha256=89FjqFItz_PPzIbJcnHnlH6EDp9fRBpO4s5-E_V-aKU,925 +torch/include/c10/util/error.h,sha256=eo-t7Li9Zgxd4CAPyv2dLXdqGfRPc5-mvs_uTHdaNxM,216 +torch/include/c10/util/flat_hash_map.h,sha256=3yg71Csw4RBnum-Ya8BgSdll3hc5YM8MQmUJKZ2kqn0,64033 +torch/include/c10/util/floating_point_utils.h,sha256=Iu9zi_5mkaQC7PMxa8q-IcilYzX-uvje3ooTGHuRpNo,831 +torch/include/c10/util/generic_math.h,sha256=-ncCZXvf1G5prh5eM9Q02PdxeA2dqS2xFhy21KSZQEA,2994 +torch/include/c10/util/hash.h,sha256=CPznZphGBx1Hb2FL9mriCFQ1IE9ODw0rc6pzn6TvMgg,11485 +torch/include/c10/util/int128.h,sha256=YyLqDlFb6OEOUBigwS2SJOz86hZvb1wmBq-7DzVQG5s,12850 +torch/include/c10/util/intrusive_ptr.h,sha256=fPHnXz5R4kfEat1jSfyaF2QFvnFTZl1N462LWwEjd7k,39599 +torch/include/c10/util/irange.h,sha256=0f2P8mttyJfD9Pos-GxqNOOAg-rLGIPhUU7TRteFOkw,3563 +torch/include/c10/util/llvmMathExtras.h,sha256=JO8YS-DFmWkZZFUaXRiH-WIhHpr03MieNWEUgyi9cfU,30348 +torch/include/c10/util/logging_is_google_glog.h,sha256=FjFpo1Ie5UVCcieQeyYyQ6jRGBHK71M6x1kIBSUSlBs,3903 +torch/include/c10/util/logging_is_not_google_glog.h,sha256=-jZBy-pEJ4t5KUi4gao3CIavPKnfCdsMJ7fjtv4oF68,8939 +torch/include/c10/util/numa.h,sha256=CZDHyRgVpsHAbkqHL1uQC38w2BC0vaKyrAKhM9qNZtI,754 +torch/include/c10/util/order_preserving_flat_hash_map.h,sha256=JIk_MiQDTp7LOEAuTYSvfOHeNQQC5ZSuOA_ZzNRSe48,67669 +torch/include/c10/util/overflows.h,sha256=4sSX4GTYPaEbGCw_HRSnOLUBa6VSzcWcrH1gFZNtg50,3567 +torch/include/c10/util/overloaded.h,sha256=pWxJ5C3KJA6hiG238bVvpf9kzc4nTU_0wFkC1fSFGwE,758 +torch/include/c10/util/python_stub.h,sha256=VCeDRiGbtYhM3F7bhp90T2hFV0hOlI_g8iecAD84L9M,60 +torch/include/c10/util/qint32.h,sha256=4qZaXFtZctkC_WlN6T7s84vyC9Fn-iIWA6a5w4_wAuI,337 +torch/include/c10/util/qint8.h,sha256=bC8LXXkIZ2vxE6KxotIjMCPdn88svMHZ6LT5mtFsrH8,492 +torch/include/c10/util/quint2x4.h,sha256=2mNg6Tj78dJYc0bAShraGFNGewSzZ2E2CS5bAjSsyd0,385 +torch/include/c10/util/quint4x2.h,sha256=vxaKT_0BXlQU1Ytn68cUwj-0o0kgcZ1pONQsmXYDilc,385 +torch/include/c10/util/quint8.h,sha256=Y079DsrjnloXht4_oHu7RVr-rWi-Vz-ZJ6QaiBz6N0o,338 +torch/include/c10/util/safe_numerics.h,sha256=UasPMg8Tyhollrd656-JjDubbzNrxGXzcVXbsBWDFJQ,2350 
+torch/include/c10/util/signal_handler.h,sha256=Wqn2whBuCO3Y6CHQBZJ7zW7UWPeWgzAe4Y0Pe6Hv_Qs,3886 +torch/include/c10/util/sparse_bitset.h,sha256=1EWTfU42hbr5JKG-oihf-_q2EnopQZ1FQSAEAE_cX2c,27570 +torch/include/c10/util/ssize.h,sha256=1yNGwkdeiXHnJI1yeRGRBMEoQMlEXQVHXWb-UnDu-pM,1415 +torch/include/c10/util/static_tracepoint.h,sha256=5jHsYnFs-Jodv4EASQvu78f60SxO0mQsdybEHjtiuqc,1110 +torch/include/c10/util/static_tracepoint_elfx86.h,sha256=KDu0MW-JMi-W3QjLhwj3CIPP6nOmcFtOYn3PaLac0fQ,7817 +torch/include/c10/util/strides.h,sha256=NZ_hqBRgcb27DdB2ZHzIKZA2NjjGT1vLrfTMI62ssJI,654 +torch/include/c10/util/string_utils.h,sha256=2uUbVpr0WvrA4ojqooVb_1Ikw6d__TAPjOyPSxLt_XM,468 +torch/include/c10/util/string_view.h,sha256=dJw3dpgD0Fob-tlEKul-ROZDcwhI7BPv9Nf5HDdtc4U,19053 +torch/include/c10/util/strong_type.h,sha256=Xyvi4CDJvay01Nh48_n_RFFzhwL8J2qYRVlRI2M1fgI,37198 +torch/include/c10/util/tempfile.h,sha256=IvQs4AiEMvPF9WLbsf06i6BER7rmbUvZJ0JQo5hYh_g,2849 +torch/include/c10/util/thread_name.h,sha256=pl56GeGZSUzVtC7unjdVGYRNEeAr5EnU-tXa1H4pg9o,199 +torch/include/c10/util/typeid.h,sha256=H5fCrdMRBg9SGKjrgfuhMHxME2qxAyMSfKBwFxi2xpY,24045 +torch/include/c10/util/win32-headers.h,sha256=Zgjt_19ZYB8ZUL12KSLv--q7rzs52rwOf9qvVdI1VgI,918 +torch/include/c10/xpu/XPUCachingAllocator.h,sha256=olW7f6jR9vMfBKIIecbeUt0tVe4lO5B9CfpK6v2tKxc,687 +torch/include/c10/xpu/XPUDeviceProp.h,sha256=moRO-WLsA4-5JVQjMtsKj5lOmlvNO3vDXWRqAf1VaiA,12629 +torch/include/c10/xpu/XPUException.h,sha256=2OR7PgcRlFFyTmlsv9BwNV51-XrcvLp1vkDWGyrRBW8,437 +torch/include/c10/xpu/XPUFunctions.h,sha256=z7Bq3UCTyYGeb-QyCycjrXlrQKh0Tv0Lzy62sYS6YHU,1314 +torch/include/c10/xpu/XPUMacros.h,sha256=IzRLsmSJKd1BoXKj47nJRzrlxru-yYSNJW2RnRzJt8o,903 +torch/include/c10/xpu/XPUStream.h,sha256=zl2NcmWZHYoRtGu85HKhy1cWtdepIAN8TcnDdqDhhjU,6935 +torch/include/c10/xpu/impl/XPUGuardImpl.h,sha256=hdOVi34E79yS0J3gZzEgyLWkDdRILEvmwI1ryLh68dU,7256 +torch/include/c10/xpu/test/impl/XPUTest.h,sha256=Ja8BcBlT6wpb9smo_ozcJTj_iMv7MuDsd6vbjhuj0s0,487 +torch/include/caffe2/core/common.h,sha256=YMy1-75xkcPm55BVeXGnncWBHPSerePCLKSgrClt34I,1420 +torch/include/caffe2/core/macros.h,sha256=VQTQ-BTyu8Dmfd91cx0YxYR0xO_ExGs6Kbf8oGGGZAU,2497 +torch/include/caffe2/core/timer.h,sha256=IMrjxi9ztMA_Z3hkHt6xp_-ylEB7_GgmnanXmZQwdJY,1266 +torch/include/caffe2/perfkernels/batch_box_cox_vec.h,sha256=m9kV2r5MWqDHs2FsfWLG44XLu7S4VzAz5pzl3gMTQxA,9671 +torch/include/caffe2/perfkernels/common.h,sha256=cHReHPpBGFORB3R6Tv2c5UJUHl0VtvyzSaEMZoOyB7w,6289 +torch/include/caffe2/perfkernels/embedding_lookup_idx.h,sha256=NMmeMuEwFWfFitzB5DknxhiZx3oYf_sjqntRNv1BoA8,1731 +torch/include/caffe2/serialize/crc_alt.h,sha256=-njmyCvsAUaggjiT_xe8hK78I3jIeZ3inf_3xqURJLk,76840 +torch/include/caffe2/serialize/file_adapter.h,sha256=pqdXMFs5RlSlCkGwCOPl9X5Wy5vLOXqIe__oSAdGxl4,902 +torch/include/caffe2/serialize/in_memory_adapter.h,sha256=E1tr23qX0f4vz9Zl3XIyukpm4jdXm3Jx08sAylz-LMg,671 +torch/include/caffe2/serialize/inline_container.h,sha256=hzRtSg4U11eMD_UcJmyVpZnZrd4lzYeOf4pQVI248Sw,10844 +torch/include/caffe2/serialize/istream_adapter.h,sha256=U-IJl8oVGnGCCe1x3a_fK3WDVaB4OjHIHAlX_7MUK9I,696 +torch/include/caffe2/serialize/read_adapter_interface.h,sha256=-9INOfFdWhTKfYcM5dShGZervkqitBsJckrdVmTlwOg,579 +torch/include/caffe2/serialize/versions.h,sha256=CpTjKy5rx-Ry7glOjW71vZZw8XvRABJOWyW9eCtBIO0,6781 +torch/include/caffe2/utils/fixed_divisor.h,sha256=3kjW3gNN8czzoj86HgY_FAAjYFubjIKVbo94iYAVjf0,3664 +torch/include/caffe2/utils/proto_wrap.h,sha256=aSbp1CMYKftEqRkbf9QJTV-viEhgHjj9pNapd3_06Vc,1297 
+torch/include/caffe2/utils/string_utils.h,sha256=AsPaZQdJvUXW58v94YVD80P7rN7FNddh87mG2TUwujA,1257 +torch/include/caffe2/utils/threadpool/ThreadPool.h,sha256=qT7K9GTUBs2N076wLiNFUjD-oMKWJ2thNsM1aBp4CoY,2455 +torch/include/caffe2/utils/threadpool/ThreadPoolCommon.h,sha256=_xiSyQfQkfJ-xBAuxUBIOlEsQrLFbOyUgVRTpKgG5I0,697 +torch/include/caffe2/utils/threadpool/WorkersPool.h,sha256=JwCkJRXVGINgN6xWzCypnGndiu-vMkbgbi5Ffg5LDvw,12014 +torch/include/caffe2/utils/threadpool/pthreadpool-cpp.h,sha256=2cj9asy-5Ke4xS5KyWgKn6ThwPwTecY9hFLONGc020I,1519 +torch/include/caffe2/utils/threadpool/pthreadpool.h,sha256=X7hV7u39tMHn0WlGW9loh2cQ7cpoBGEZPbfv674vQBM,6537 +torch/include/caffe2/utils/threadpool/thread_pool_guard.h,sha256=IdCXhmsn9T3gp89LQPJuzv9cQvrHRV5pmnW_2TXPY0E,590 +torch/include/cpuinfo.h,sha256=ZJOilHLb47gS_gmUg3Ty36G53-fO8IT4IIjhaDBDFBI,57495 +torch/include/dnnl.h,sha256=197ADCTk4KfAFiwPzxlDHbipAi-PNnUchz1HcQV7sm0,848 +torch/include/dnnl.hpp,sha256=bkt_n8-s38qI9DhsjZqND8YKdsnGr_T62RpFzhTFn8U,856 +torch/include/dnnl_config.h,sha256=ljcLU2XTZWBcQe19UeLWX1BBv4b_taUxd86kHZDwMug,876 +torch/include/dnnl_debug.h,sha256=t8eV2XF0PY2f06zMek1bykicN-qJyUYWwwaG50FK0s8,872 +torch/include/dnnl_ocl.h,sha256=23IJYtbIm1B0T3LqCNgOqnVx5Nk60QV2UygUpSc-S4M,864 +torch/include/dnnl_ocl.hpp,sha256=A7hJIrmDu7Ql4zOyPVA8E-gMQVqyhU5-43LHejHCKv0,872 +torch/include/dnnl_sycl.h,sha256=x1mY9vbUhlz7UjQMDyaMqqR8NlTyXQkT1tsjl4ihpM8,868 +torch/include/dnnl_sycl.hpp,sha256=XAnrFdrzu9ee-O1NM4ZL77YcrB-yfOIg1mGFl7DTIuQ,876 +torch/include/dnnl_sycl_types.h,sha256=q8w4_h8djuUrkWhFPtNDZi9OM1eKs_ZPDmmKuA57_iY,892 +torch/include/dnnl_threadpool.h,sha256=fGaECxUqTtsCn_WkCMp-9d5YYnA9WzWU_zNTd3JXMCs,892 +torch/include/dnnl_threadpool.hpp,sha256=JUl2UErwRBqgIT_Itlvaw7xP8toM5F35a8kPqDisClM,900 +torch/include/dnnl_threadpool_iface.hpp,sha256=dG9j4_Cu7T4DJiu5ejGNi1eN6OlPdoXA2oCrvk3yfAE,924 +torch/include/dnnl_types.h,sha256=2xkCNCfJxnlX89OOlCsViG_iiuJDBDKARww4VqYdb0Y,872 +torch/include/dnnl_version.h,sha256=Dg-UqmjPaJxjkddEV7DbjqTjHHBzlefYGZsT1rmyro0,880 +torch/include/experiments-config.h,sha256=qsZ5gKHOlaUxvUt8WDcpKxJ-xYtPBWKNwh7PBpsztqE,537 +torch/include/fbgemm/ConvUtils.h,sha256=V_xhmY9HBDX5mtiSgQKUGV8t-WPUZavnS6NpENCf0Cw,6439 +torch/include/fbgemm/Fbgemm.h,sha256=ke0lBtfQGA07zK1up91aSH1PPGmQJc_CdyiStjqe40w,44295 +torch/include/fbgemm/FbgemmBuild.h,sha256=Es1UPz8H-BAK-6jitDLkekQYiMF4Lv1pguVzuEWBVU8,3430 +torch/include/fbgemm/FbgemmConvert.h,sha256=YiY23m6nfg_luhGgCrMZQFRUTfh6y4BfYooWLnuUUEg,4671 +torch/include/fbgemm/FbgemmEmbedding.h,sha256=4MTmuyPkGGlmcl3FfPQwRIgWe47t_hxn1NpI9HPtloA,12524 +torch/include/fbgemm/FbgemmFP16.h,sha256=QSPuLLbt8C7RZjI9u6vNdZMm4Q1c6EkE446cVIDOdF0,1457 +torch/include/fbgemm/FbgemmFP32.h,sha256=T8NvKu7vK8hdIG6VWI0oS-kh0nWe60aEKxpgX897FCY,1327 +torch/include/fbgemm/FbgemmFPCommon.h,sha256=ewKby_N4EA4D1cOuA0TaJIG9fah2MLSjwpvW19hibhk,10724 +torch/include/fbgemm/FbgemmI64.h,sha256=2ui6csRe0392OsjOcTH3Hymj1EBn_bHExURDTzs42Pk,612 +torch/include/fbgemm/FbgemmI8DepthwiseAvx2.h,sha256=4AeBVaxDYRJnsJQTYtxoK_G9oFqLuzelTFkXsvNUBJE,3360 +torch/include/fbgemm/FbgemmI8DirectconvAvx2.h,sha256=BcAJCJkpSBmEMwWBrNQeeLRqa4an1IURu7OEhuuANZQ,1636 +torch/include/fbgemm/FbgemmI8Spmdm.h,sha256=AtPELx7TUmK-nXCtg_2ZqraxwDEPIbAo8DVWvZEGP54,3528 +torch/include/fbgemm/FbgemmPackMatrixB.h,sha256=tkkEdD6FRwUICQ2ybIoSIaWW8KeBqAeZGOoslZ37V8E,9369 +torch/include/fbgemm/FbgemmSparse.h,sha256=Iii8puQrqhMWG4KWevTe5KpDEa0x8ns6qIuZEbkut2E,6601 +torch/include/fbgemm/FloatConversion.h,sha256=2LcDLFvfoEv5SHWkQ5vHk3KVbewhVG4ACBH4TVlRguU,11625 
+torch/include/fbgemm/OutputProcessing-inl.h,sha256=HrKF0ZhV8PvJS5-GhKscJaGlDS8b2f-DGx6WAQpXL30,10519 +torch/include/fbgemm/PackingTraits-inl.h,sha256=m4CXwerkBuGo9GU5CRnB1kaDU8RJ0xpYRzgODzN6zyc,21321 +torch/include/fbgemm/QuantUtils.h,sha256=ejJv-Lvy6S7tYg1xlsjAIHKDDGF5vcy-uLpzalnabKQ,12592 +torch/include/fbgemm/QuantUtilsAvx2.h,sha256=3iAXW9EIRjakGUQZO7DujS07BfUKPMiwxQiXSvEXdKU,4778 +torch/include/fbgemm/QuantUtilsAvx512.h,sha256=bBuJI57l4hU-l-ro_2hWR5ldIzaXzuR9mtlkqT5Vb2s,939 +torch/include/fbgemm/QuantUtilsNeon.h,sha256=nVdytq5IO8c8oeHB_YZ99bp1vqmCwhaWxlEWbjqt7xI,848 +torch/include/fbgemm/SimdUtils.h,sha256=9uKhOW7sZ26PeZzCr4XOVbLEjEy_ymaOx7Jf-RMqaY8,1813 +torch/include/fbgemm/Types.h,sha256=k2Hk5muQfipHdHEk-tWcT1_zpUITPACw3e0avHI5pag,572 +torch/include/fbgemm/Utils.h,sha256=F4Al__Pt3kzZnf58YEExhqjdiM9bvfhXrNBJ6zZlpkk,13750 +torch/include/fbgemm/UtilsAvx2.h,sha256=ecit6WHbJhU0k6qyDeM8AqmqH2gIy7XOfqq-nVOgoDA,2379 +torch/include/fbgemm/spmmUtils.h,sha256=zhJ-eZiD8qTY_oTM5i9kQ4TGvfYlRCHqQoPuWYw6asY,1448 +torch/include/fbgemm/spmmUtilsAvx2.h,sha256=p7kgmYtD3P1eiXm8DfOIFapMtG1cn1hAdbAODaZZ1sM,1130 +torch/include/fmt/args.h,sha256=H2r4yYDQFMLgG3cfZNdvSkoLdmO5v9kFsmBq7qJGYtU,7400 +torch/include/fmt/base.h,sha256=vJ68BRXRjjQhzobFJsrlmTiQIc1fNwF2w6AhepxCnZc,106979 +torch/include/fmt/chrono.h,sha256=hgv_j1UIY6opwA2hG3qlGFZqotrUKVDeVIK9OVL072w,82048 +torch/include/fmt/color.h,sha256=AR0wrCvOwW6Y7HcJyqxDjuvwW71R4GiApEeevb8Jtrs,24927 +torch/include/fmt/compile.h,sha256=B2D2oXvGepPFne2xGnfuBbAmzchz6DdRB7Eh4Kt-If0,19331 +torch/include/fmt/core.h,sha256=C1urBHRZZCGaAaCXycmXUwmxmMk4q-V6TdGFyM0k6hI,192 +torch/include/fmt/format-inl.h,sha256=TG2oeXz-Fh-KW7u13MHFtWE1XJoOaw6DrXKdg0uk4C4,82933 +torch/include/fmt/format.h,sha256=EESGzdbeV3pZW2aqax2XMPg7pj2JSff-L7kM02RvKis,162037 +torch/include/fmt/os.h,sha256=x8NdrXS2cTLVbzuIpj9MvahvDC6nzKQFLP9rdvbWgjw,13213 +torch/include/fmt/ostream.h,sha256=PjDX4pnVh-zxOAvDrQgM2sOQyLNDKXy8ez9aPajY5Ao,5191 +torch/include/fmt/printf.h,sha256=mxaO5_bzD_SRK_XDiAoUqO7VjWucadr-YlE-E6UJYao,21073 +torch/include/fmt/ranges.h,sha256=FhAXrPbt_Z8S1sp2lZUUwN3I2nXOlIydJJFFOuOEgJw,29061 +torch/include/fmt/std.h,sha256=RJ_XpgodoNVQbb3xgQ0wIu37hpRPDi3y6TkAfAomMFA,23005 +torch/include/fmt/xchar.h,sha256=QEPyJDX63IQDOAGC78J6sZdgrF6s2LGZcOK6L41x7mc,14005 +torch/include/fp16.h,sha256=9wY9ucslvU9l-I5jOXPjiekCwDjQnYFjrU70U1iCpnI,152 +torch/include/fp16/bitcasts.h,sha256=jX2p5o9IuZZOYrr5b-oN8L24uJHMykS4yVxofgRppNg,2245 +torch/include/fp16/fp16.h,sha256=cAzp2gXu5TDSk7ml29wpi24tBwgOwiGWZ5OUzavpJBg,22453 +torch/include/fp16/psimd.h,sha256=mV3GwrsOSuqRaaaO4gttcpMn1ZW52te-MT9loDhidb0,6976 +torch/include/fxdiv.h,sha256=VvlB2ahMMqgqWvClGyE8B8qJhLkOJaEkxW6aLArpmIQ,13574 +torch/include/google/protobuf/any.h,sha256=Ge9EyYYtc0ovvfvpKPkjh4pFp3zCrPWtrdPGpBH_ds0,6330 +torch/include/google/protobuf/any.pb.h,sha256=e-lKiwm1Cz-lFGYCKa7L515J6_aC8SJNgqQ6Lg0j-Z8,15856 +torch/include/google/protobuf/api.pb.h,sha256=nCxAK-DNBFIIg_mB8iu40YxkAEe4zPA8Vni21xB-QbM,57585 +torch/include/google/protobuf/arena.h,sha256=57lJpSjOK1NcLV9NkXDNmKgTJB9_-tt1H8AYmTGZFyA,32065 +torch/include/google/protobuf/arena_impl.h,sha256=NXjpdAchIxNqsjdg1zrbMYQOOojcjquF1bg9zGa4_lc,14685 +torch/include/google/protobuf/arenastring.h,sha256=d5FEeFpGAYqtk7q6u2IzF7h326dxLT4WBjlXB69_JXQ,15079 +torch/include/google/protobuf/compiler/code_generator.h,sha256=8BIQPvMANw9Wm5AN0CeFLTANm8nNwCego7TMoIpRXmM,8045 +torch/include/google/protobuf/compiler/command_line_interface.h,sha256=exbuW_ad068OAUHgapOSORpUCe5y9a3VUGgFOldTFok,20892 
+torch/include/google/protobuf/compiler/cpp/cpp_generator.h,sha256=NoTJ_0Grl0OATalaC2Di4qLcRgdpCrVZzmRbPcWa_qE,4275 +torch/include/google/protobuf/compiler/csharp/csharp_generator.h,sha256=mNHaWz5EznYnkYzeZp5VjwRsXEl6ff6hs6LlIEtm0vs,2872 +torch/include/google/protobuf/compiler/csharp/csharp_names.h,sha256=KzVEcDMO1ibG2tmZSyPn9Bn9ES42B5ZAZdusux0RqjI,4197 +torch/include/google/protobuf/compiler/importer.h,sha256=TaKXPDIF0kDhUZo-ABHKlcpq8vbHGtGloU23F1cOlTg,14592 +torch/include/google/protobuf/compiler/java/java_generator.h,sha256=xul7sXt8K16U8OJG0LYunSRA4uMWa-Gy0SsAihCb-Ag,3140 +torch/include/google/protobuf/compiler/java/java_names.h,sha256=NLe5sr-Ss3wkXtjK4Io2NrWnfGc1HowFlRtMH04eg-E,3721 +torch/include/google/protobuf/compiler/js/js_generator.h,sha256=vV3Tx0FQFx9RZsfu4HlRBd7fbKj0im_7CH-HH2WJJi4,16269 +torch/include/google/protobuf/compiler/js/well_known_types_embed.h,sha256=pFjsxe2HcOE4CTBtCmPX-h3YqY7-eDEPfMBYp-95Gf4,2024 +torch/include/google/protobuf/compiler/objectivec/objectivec_generator.h,sha256=mMBSBuPxyPQDX_z7XuxfiJJLc8cuKZrBymaX-Iz5llE,3512 +torch/include/google/protobuf/compiler/objectivec/objectivec_helpers.h,sha256=luBHKPF_PTGUXGoJTx-rfDtORqlneOcLmg6UmSnh5SA,12622 +torch/include/google/protobuf/compiler/parser.h,sha256=b2P4GOE0kfhqBOgtmHfBYuyFQQKdShwzk21S-vrPp1w,28428 +torch/include/google/protobuf/compiler/php/php_generator.h,sha256=Eryflqz7EmP5SuQXIJMwHA_AjQgJ2XR9Dy74V6m_SoQ,3694 +torch/include/google/protobuf/compiler/plugin.h,sha256=ANI_Q87lM3Mij1wgmD6ZUiP0Zwyz7WnBd9Md14BTSX8,4443 +torch/include/google/protobuf/compiler/plugin.pb.h,sha256=gvde8enbxPwUGV1OnLtMfrskHLBMR__Aiw7O6T2rnoo,78145 +torch/include/google/protobuf/compiler/python/python_generator.h,sha256=R7gUXEM3Y_V3iEvnJ8131LhSG-EaiHN3fv38W_b8qZc,7919 +torch/include/google/protobuf/compiler/ruby/ruby_generator.h,sha256=PpgAHc5GP-Kffse298wbBqoHR3SBHwXh8nCC97fZ8FI,2879 +torch/include/google/protobuf/descriptor.h,sha256=h6yWawGrqJh_VS5nApNVIP7mhY9CUUzRNKxSZ5tDA0k,98817 +torch/include/google/protobuf/descriptor.pb.h,sha256=3SK7CUvuOLhK_kIZIsUCC9teWiHE0MNwYEnBpAQnSfQ,557651 +torch/include/google/protobuf/descriptor_database.h,sha256=IWD3ZNF9UfVnCwVlzuKiSdtEim3OLnzhDZs921L5MEQ,19159 +torch/include/google/protobuf/duration.pb.h,sha256=KKuh5fguYPPdPxucfnP9ln8uNcKK-UQ7u8IKgk78R0I,10141 +torch/include/google/protobuf/dynamic_message.h,sha256=-oEboRZYGXELFfrUB-JFM3ZW9q-mEoMJi9VDWoa8m_k,10225 +torch/include/google/protobuf/empty.pb.h,sha256=6A2XPXqdNBBMPCVbbFuFpkav04bf8qG2LBMZAw62fFQ,7895 +torch/include/google/protobuf/extension_set.h,sha256=s_65Z89dNd3X_UsyqfWNNTWbRJdD-ckCHeUj4x91va4,80049 +torch/include/google/protobuf/extension_set_inl.h,sha256=kCdW187CehfUkWJyFqhc-wNhzNj_27xvFEcAzJiRBZc,12713 +torch/include/google/protobuf/field_mask.pb.h,sha256=Lk9g-yecJhKU-nH0m0nN553iirPzSVJSiueu-qp5eYs,12139 +torch/include/google/protobuf/generated_enum_reflection.h,sha256=ICAODkNIyEg4H-eT5COpyK1kdXLQwJpzolPmcfi-Gtc,4091 +torch/include/google/protobuf/generated_enum_util.h,sha256=GF6T_Y6vkYDlU4OriGrlmVMvp0QMmeXoucI6frM4v2Q,3349 +torch/include/google/protobuf/generated_message_reflection.h,sha256=RHe3qC-qabXriV5Ojrs0kl_9cauuyyPkkpsoN9HdXF8,13375 +torch/include/google/protobuf/generated_message_table_driven.h,sha256=PCrYOTqfBGtkeXnCqrJRlJrl9YMIHMZ3_Opw3Z5L8Z8,12949 +torch/include/google/protobuf/generated_message_util.h,sha256=ia9SMnjXN54d870C3mNEsH2_bifNYmroXRrCcYHpKOw,9820 +torch/include/google/protobuf/has_bits.h,sha256=hRoK2Imit2tLZeMx4xExavhTbr2y4-G08Wq5kZvlcVI,3644 
+torch/include/google/protobuf/implicit_weak_message.h,sha256=YDJjaMvj9lrwzJtbhf4qbwvZoMX_5z2DUFvMlpw_xYA,7204 +torch/include/google/protobuf/inlined_string_field.h,sha256=LgIOsI3lBYKaUXyortSD3oe3C5HCUi0z16gH-cK_W_I,9592 +torch/include/google/protobuf/io/coded_stream.h,sha256=J9dM7dYEa0xEN-8szSisS3bZMPS3G8SczIlA5-DfY4E,71090 +torch/include/google/protobuf/io/gzip_stream.h,sha256=3sq379q8c2L3v9LRj-YJDF7s5yOLs2WAPL8f0hrlvMc,6884 +torch/include/google/protobuf/io/io_win32.h,sha256=Ao1qb-z_6kdFzzuwOAfwgbnpXHbNC4fMcD6XzO3V7C8,5523 +torch/include/google/protobuf/io/printer.h,sha256=_Mx2HN-0F9Ip4Vr_RZ4qhm2iiGJ2eRL8Na0vklT_szI,16319 +torch/include/google/protobuf/io/strtod.h,sha256=Rrdr4b0XsXMYDuNNCgcGe88opErSDtfPvE7ZPWWztrE,2486 +torch/include/google/protobuf/io/tokenizer.h,sha256=RtMWEB51YNn58f3B9Vu2w1wk_SY0pXmWMuLphi4_D9g,17170 +torch/include/google/protobuf/io/zero_copy_stream.h,sha256=EJDS2o7CPvX0Y0Jny4VGUtXOtB_d56o07-5Lg0lRsLA,10511 +torch/include/google/protobuf/io/zero_copy_stream_impl.h,sha256=OCneH7X4bYT_wpo28PiKZGgQr_D0sTFRSRm3V4tE2p0,13526 +torch/include/google/protobuf/io/zero_copy_stream_impl_lite.h,sha256=O_BSFAUzsrDvEu8ZJOFjzuMbflc32sxn5ltb0vkZ70s,17232 +torch/include/google/protobuf/map.h,sha256=lI2KgubnTvZ6zYcpk8Zd34jLgyjXpf2FcPmlviMtX2w,46029 +torch/include/google/protobuf/map_entry.h,sha256=A6-xUtkiMUOkr_VQurHZxdaQDGv3263sVdLdCFBMVbE,7443 +torch/include/google/protobuf/map_entry_lite.h,sha256=3PCTqJ_hGx8bdjxC0jfnfrbT6gJ6aC4rz9LykCha6ug,26549 +torch/include/google/protobuf/map_field.h,sha256=M4UNz61jn3GnkbVTF1YLM5AxI3Pcmr1LFqg3_UXNlPs,32269 +torch/include/google/protobuf/map_field_inl.h,sha256=iX4u-CLfPcW6o_MB-oKHZrKs1LJE9f8KEH36GmUCsIo,14869 +torch/include/google/protobuf/map_field_lite.h,sha256=NMyI3YENp3g28CqVjhM2zEBg7QnCufQH2fY8QtOkmkw,7749 +torch/include/google/protobuf/map_type_handler.h,sha256=b2d-vaQD3-J71kacxx15Ue54RjUFP4A0K75dMspSMAM,39889 +torch/include/google/protobuf/message.h,sha256=Nw2L3VPOcYJUoJU1hmnkta7BtsA0B3Odd6B3vm20_ZA,63570 +torch/include/google/protobuf/message_lite.h,sha256=4yyDLhenmidYrypwxK-JCecBsMHCGc-Y0xuCOmrjolw,27178 +torch/include/google/protobuf/metadata.h,sha256=ETQ89HbN5rpryCb-wF_4jU1vCqntR3t_NPP6-r1pdfs,1893 +torch/include/google/protobuf/metadata_lite.h,sha256=9kMql6_iLt4ayfTvhXLZSypA7tyCazslecT2JgrqETU,8508 +torch/include/google/protobuf/parse_context.h,sha256=wQiqBDReoptoxhJwwnPX7yCiuPQpvVZsGyPUKqhdpuA,30423 +torch/include/google/protobuf/port.h,sha256=664-zCuZAp87iKzn0uQjn4yGRSOecoUS4q8k57ZfDG0,2093 +torch/include/google/protobuf/reflection.h,sha256=zXMMNn778SRyJtLB0MEfCNMSiZqWYoc8-EndelWb_7w,23280 +torch/include/google/protobuf/reflection_ops.h,sha256=ynPXGlopQnQhn6f8N-3C6X3zFYx5w8EBNWgnVC7VcAA,3901 +torch/include/google/protobuf/repeated_field.h,sha256=t_ZToL1PBLAIzurbcOl3RMI2M_GzmD428eVaxQ8wslI,104438 +torch/include/google/protobuf/service.h,sha256=eCPoZVQ-kb62ExxHS-_y_DLy80sMzsBOTofx3ac8ThY,13452 +torch/include/google/protobuf/source_context.pb.h,sha256=Mqf0W2EuzkfZRQo1xdDhbyB5B5wfA_9zvX0UVqGeqLw,11825 +torch/include/google/protobuf/struct.pb.h,sha256=cSdvlcUuv4hFIVhhq0S7UNpJsicnOwDIyh8LDRD1nTc,43476 +torch/include/google/protobuf/stubs/bytestream.h,sha256=-8YEqd5LAzp__x6gQJwAHIZfyeuLBzVP29aqk09PCMg,12110 +torch/include/google/protobuf/stubs/callback.h,sha256=-44WZZQvYAhglxLQQ9eG4TkPvd6hul0cec5aueYaQ1M,17661 +torch/include/google/protobuf/stubs/casts.h,sha256=K1qzWJwXH2PZx6qVBpl0wHvIpwNe8bzwotTxuWO-p1k,5872 +torch/include/google/protobuf/stubs/common.h,sha256=Cey2z0OzIfhOIn8rB0sIQ4V8eKBz_kaUmZgYflMNoLE,7485 
+torch/include/google/protobuf/stubs/fastmem.h,sha256=aue0bu84RB1fEDrjyrScm-NVfrXEdiEZXNuqH5dXoC8,6173 +torch/include/google/protobuf/stubs/hash.h,sha256=CT5B576qKeJXvEfgTi-OFYJme1D1B5WUG3APj1-jpwg,4238 +torch/include/google/protobuf/stubs/logging.h,sha256=5NZ5ueYBHEiTC1Vaz2Kd7w8Oz2BNy4SR1Gi0lquIjHk,9156 +torch/include/google/protobuf/stubs/macros.h,sha256=k7R8MMAxqhnYIr0e978uRukLLBgEX5E6kwUT05O2GKo,5023 +torch/include/google/protobuf/stubs/map_util.h,sha256=K2YiSLAfduI5-rO37KUVkGdJSI-Z1PADLDuA5B9pPbk,31982 +torch/include/google/protobuf/stubs/mutex.h,sha256=kOeau5N3SN2MOgzUGjSkWKBjaq1f_xdc33I4nkwEW1k,6343 +torch/include/google/protobuf/stubs/once.h,sha256=lAejegVRA6f-Auy-nc_aqddnNcA5IUsxwYJK81luzTM,2239 +torch/include/google/protobuf/stubs/platform_macros.h,sha256=2KxSSK2gu6JX-SfLguHQNgWjmCWpeR5UtiDUFyczi9Q,5242 +torch/include/google/protobuf/stubs/port.h,sha256=T2AvQKNhNF2c--sI7UbJSkBKllqVfbX2y6Xh7xrlwQI,13149 +torch/include/google/protobuf/stubs/status.h,sha256=0uRycX0joXBJ4zDUHcZYOnE5dCNI8sYfFlxnGYGADPc,4064 +torch/include/google/protobuf/stubs/stl_util.h,sha256=Eq45yPAewWQFTpXk9AvhfgSyV0kF_s0GFiIfi33iGD8,3354 +torch/include/google/protobuf/stubs/stringpiece.h,sha256=bEcbpEVSsSB-mTQnGEcYyDFHsCvNA4MamqOhPtlEMAk,18350 +torch/include/google/protobuf/stubs/strutil.h,sha256=OZVmMOvDgcFs9oAcG0OdhYqLfcoxqep4MAO-Gh6BxTs,39724 +torch/include/google/protobuf/stubs/template_util.h,sha256=6lBI5Bik265rQ3O87Neej81W--W67IGCPR4cpFMMzEc,4972 +torch/include/google/protobuf/text_format.h,sha256=tntU_k91dK8SgVjEla7KSssm_u8yIHWd1m32NBN_5fk,29637 +torch/include/google/protobuf/timestamp.pb.h,sha256=bppH2WTySQDQc2IkH5tYqMMH9Ora90q7KMoELjXbPQM,10204 +torch/include/google/protobuf/type.pb.h,sha256=xIElhkPASkzS_yMFeCjLCpoHXK9U4avOqGTespl11cA,99385 +torch/include/google/protobuf/unknown_field_set.h,sha256=6MGP39xrGeTuAncQnnvWHQk-9HDLUtqMhC6iNmyvTII,14812 +torch/include/google/protobuf/util/delimited_message_util.h,sha256=ZiZCkIsm81bdjYv4qpjPeH5cYlLoZNngaPX-Z4WV3ZI,5507 +torch/include/google/protobuf/util/field_comparator.h,sha256=E3wejfY6dZtN03RJfQFx__XqimupVSRzEC4ffKW-DWU,10762 +torch/include/google/protobuf/util/field_mask_util.h,sha256=rXd2TBrL4d0rtHHzJY0dIaAk7vs5YzLUsOGpWVFwUec,11638 +torch/include/google/protobuf/util/json_util.h,sha256=SdFGCExijaxtAfJDFjWOjC5gQY67dwFnrFRbAh_0QGs,8632 +torch/include/google/protobuf/util/message_differencer.h,sha256=9Opn0TErZH65gRg4OYrgV4WqvgRDO6RIz_gca48auzQ,46214 +torch/include/google/protobuf/util/time_util.h,sha256=qw_fNQBBv8wiffUzowKqpym0gb_hVBqwS2AlpFGq8Mg,12465 +torch/include/google/protobuf/util/type_resolver.h,sha256=VB9TxWkKOep1xi0lgwTs-9pw-BRSs0GK754ODFrC2BM,2928 +torch/include/google/protobuf/util/type_resolver_util.h,sha256=heF9HkWW_hqWymgvj6HbbJKoS56lMUqpG_FLl0wCKMY,2464 +torch/include/google/protobuf/wire_format.h,sha256=zbpaIns9jxSwFv1MNxpUOpfKA0x3S_GlIiTf35t-Zrs,18067 +torch/include/google/protobuf/wire_format_lite.h,sha256=HSJls8NEUFJoIgQoKq1xF11VcBSVk_eHZ1pBpKkbuJA,85514 +torch/include/google/protobuf/wrappers.pb.h,sha256=AqzpnZSpMHEcFrjQehXg8Ta58rfmdOR5laEACBb1NQ4,60790 +torch/include/ittnotify-zca.h,sha256=q2VemRR0e37b3t6BGLsu22fejJT7jLMnsOCL1yaxNBg,3834 +torch/include/ittnotify.h,sha256=oJ-vny2tQSkbos33V2GB5HG238Se27gJLzrgzklaKYw,201506 +torch/include/jitprofiling.h,sha256=_IP7fjMpBSRli0g_Gry_3e1t32-IKZxu--pRvkuj4Cw,30260 +torch/include/kineto/AbstractConfig.h,sha256=BoqCfIP4kU-WVbKFXOGLfYoIQFsUoFmVX-00frKRlso,3902 +torch/include/kineto/ActivityProfilerInterface.h,sha256=5yFNeiX6Tl6p-ZHEztZ_35NO0Gu2d6nK5C-F29yFS-4,3660 
+torch/include/kineto/ActivityTraceInterface.h,sha256=AfJSgF2nKWu19rrYjKnZyyZpx8XU1_ABJGmKJ6YDhv0,612 +torch/include/kineto/ActivityType.h,sha256=MFlqsZgw9OKQSOX4TtUCnCdx4VjUOdiqA0STfOqiimo,2348 +torch/include/kineto/ClientInterface.h,sha256=rKNIGUqu3QZisKVMg6tbXAOTw6j3tD7o-9zKzvqqh7w,683 +torch/include/kineto/Config.h,sha256=i9M166iO5lOrI-n8hYvRSgXtSj1sUs3niJfFInBXgR8,16214 +torch/include/kineto/GenericTraceActivity.h,sha256=I9tabMXyC2Gqnza7WZaKghPukUDq7d4e9ox2_z8IGxI,4056 +torch/include/kineto/IActivityProfiler.h,sha256=VbKtofECKgu4fNIT9mmxYrtRao6s52V5QG2uofL0YPo,5281 +torch/include/kineto/ILoggerObserver.h,sha256=sdhAY7BDewfk4ocM2tmMDK9Sl-CTKsqIu8HaUg1d1G4,1883 +torch/include/kineto/ITraceActivity.h,sha256=kmDp_Jlz0NqQmPQT5iKz1YM_0zMFIwxG6X5pxHE2OHg,2112 +torch/include/kineto/LoggingAPI.h,sha256=KeswfY74LEeu1XqdxvhTgloIVSCaPV6veIq86jSzBxo,360 +torch/include/kineto/ThreadUtil.h,sha256=Sten7S91qJXLYRw_CjSvC3b8dgcrjSwycDV1btxRWpc,961 +torch/include/kineto/TraceSpan.h,sha256=2sRObY5MdaE-qe-MLWn9zrhXJa4zETEQ-XNxuhewZ0A,1014 +torch/include/kineto/libkineto.h,sha256=4cpqkK872cQK4AZmup2lUp3g7jSiUZIcBA11IemhIE4,4048 +torch/include/kineto/output_base.h,sha256=9UezVio6Pliq84sGZIARv6UHxCqNvr3zuGcNZYuo8aM,2242 +torch/include/kineto/time_since_epoch.h,sha256=VgfWFAlYqjzCnBCnUOTjlBwm8VE6xLzbBUpJRTwJbiQ,537 +torch/include/legacy/ittnotify.h,sha256=hyaaawLsWXLmSVzdtvRJbPcTL5CXn-A4NLQbxJlVtM4,38378 +torch/include/libittnotify.h,sha256=rq5GmAAVspXzw_iNUOlsk-fxDvKeghzWQ5M8ijQnWWE,588 +torch/include/libshm.h,sha256=TNoEjSFY6vmf9pi1lm8F0qLk_SthMxLDGGTrdKeog7E,815 +torch/include/mimalloc-2.2/mimalloc-new-delete.h,sha256=3U8lyuUyCdRdc_jmornCGej8dDTZfyDrqa8ai4UAMP0,4044 +torch/include/mimalloc-2.2/mimalloc-override.h,sha256=JD2xsHPOmFhzVFt0YohUbvuMyt65hg5nSn8PcBr4usk,3165 +torch/include/mimalloc-2.2/mimalloc-stats.h,sha256=F8rsJLS5cMmaapkY0YMD0gRMM0i0xtPvi4cP_9fsIgA,4047 +torch/include/mimalloc-2.2/mimalloc.h,sha256=VchIef0yboZrqOYeJFND07Cb88ne_M7RpmZ7W-tRq5c,39953 +torch/include/oneapi/dnnl/dnnl.h,sha256=-6CmJvlNn2QmcCQbahaNSIOW8XIiIzBE44jlNv_cfOY,187834 +torch/include/oneapi/dnnl/dnnl.hpp,sha256=dTf8CqkXGT0gAOA4ospiIDpfMdM6FdZZWoBLjOpUAXE,663366 +torch/include/oneapi/dnnl/dnnl_common.h,sha256=g6VyOPYLD6L5u4wOhUXzkR8keb_sRuTrXTTn7wlbjDQ,5831 +torch/include/oneapi/dnnl/dnnl_common.hpp,sha256=mM-X6rB9zfNYeF_e7w1ZkjA3iIQq5DhHK2Z7B6G5hA0,16192 +torch/include/oneapi/dnnl/dnnl_common_types.h,sha256=8wwRmEkfEeMt4S7rpM5my6b6FOLCkCcZEh6N-tVAta0,8465 +torch/include/oneapi/dnnl/dnnl_config.h,sha256=Y8SgsN3EhLvEbbESSziCOPkWX8N9Qu_XEpIOWSeDJ3M,6467 +torch/include/oneapi/dnnl/dnnl_debug.h,sha256=GakV4Fk9_9sXz67gVTrR5A_7_-uaSbGppKZSkRZvbvU,2371 +torch/include/oneapi/dnnl/dnnl_graph.h,sha256=amhasxb7JbT7k4YMp3kxcak0CL6D8gAqljV9NGmW83s,33498 +torch/include/oneapi/dnnl/dnnl_graph.hpp,sha256=R9DlsdiPAfarcFoCoRTecFYWY9LJzEjc9N7PR6pISGo,67058 +torch/include/oneapi/dnnl/dnnl_graph_ocl.h,sha256=F-eJRiksOWzIcyVBiXjsEGh9aaqNPZO0FKupqNVA9Rs,6077 +torch/include/oneapi/dnnl/dnnl_graph_ocl.hpp,sha256=uDUzNqFO62luFakwMnn23qarJdZIFEAiQUv_JGyg82Q,5589 +torch/include/oneapi/dnnl/dnnl_graph_sycl.h,sha256=kbzQuEbkPxXo6i-X-uznI24V3cHr6zrT-srMWexWaj4,3752 +torch/include/oneapi/dnnl/dnnl_graph_sycl.hpp,sha256=Ussr5os_qD39lJAK9mJR_WBWFGNjva9PDZjWSzrzcVo,4468 +torch/include/oneapi/dnnl/dnnl_graph_types.h,sha256=L_EhraQzcxMmLh8Tjzkg6Fq-wQkarxQhiG2RSPPxO_E,17275 +torch/include/oneapi/dnnl/dnnl_ocl.h,sha256=kEGI5dIvPy77sLt8PA8YWkP8_yCPJ0KZCJSxM7up9cA,12135 +torch/include/oneapi/dnnl/dnnl_ocl.hpp,sha256=dDAyEPPkmYhBLBL3QdBb_H_MZHa28PdDB46hBiMSeVs,18224 
+torch/include/oneapi/dnnl/dnnl_ocl_types.h,sha256=bTP37xABKMcS0U3W5hJ0BV1FLvJop9DipWhbhAnDADo,1389 +torch/include/oneapi/dnnl/dnnl_sycl.h,sha256=ZkHn2dVls4cQMSaUKZs7iiWlLmYoKVKqRXwua-1iVtA,8551 +torch/include/oneapi/dnnl/dnnl_sycl.hpp,sha256=SkwgVh6fjjyjFBtKYkEyNtWUvYw9xV2cKncnBIvZgyI,15228 +torch/include/oneapi/dnnl/dnnl_sycl_types.h,sha256=LsxIUa6JePeHvymaHT8XvHS7FMQhcExAZI_mr5Z90YM,1401 +torch/include/oneapi/dnnl/dnnl_threadpool.h,sha256=YjywiBDIcpLmTkwdJAiXAzuerpynFfOyBNUKMyuehcs,4642 +torch/include/oneapi/dnnl/dnnl_threadpool.hpp,sha256=EQ3Gcj49FPivy3mlwfT3kMUQPKjZ7CmlJbuT3o_-w74,4440 +torch/include/oneapi/dnnl/dnnl_threadpool_iface.hpp,sha256=ZypjM7e5TVtpNFP56PZZRy5_JY1FjpF-ggU0LdzOU9E,2279 +torch/include/oneapi/dnnl/dnnl_types.h,sha256=26sKaWOezZImgUkwKvYa0TmsPpAw1ty7BMRznOp3mZ0,101328 +torch/include/oneapi/dnnl/dnnl_ukernel.h,sha256=goEgG-I9ML7IpIKlNPtlGdmIYLMH2Uw1if6t_ikS78k,13853 +torch/include/oneapi/dnnl/dnnl_ukernel.hpp,sha256=FVJP3aSPtxT8YIOWtdH63b0lP6X7LrXivr1O50urTVc,17789 +torch/include/oneapi/dnnl/dnnl_ukernel_types.h,sha256=szo30VIG4GHYCMfBu-PJSkt4m2N9xeyC_oGpVexDr4Q,2710 +torch/include/oneapi/dnnl/dnnl_version.h,sha256=EWijEB-YWqOuY-G3IHmp4L_WMXi56XuDsjol79QFUpw,1045 +torch/include/oneapi/dnnl/dnnl_version_hash.h,sha256=wrGM3Z0tmFkG79uVKHt8CkFek4NpoCLjHEhExjQ7RGQ,1277 +torch/include/psimd.h,sha256=jj8l40ewHhPbMi8a-Ih26lnB9Ax1FcrmQK8edLoUQR4,46888 +torch/include/pthreadpool.h,sha256=e3cKQt9aPHGcXyG6UHUurnj3dVqCfOct6m8-JQ-M3N0,101883 +torch/include/pybind11/attr.h,sha256=geF8hJN4ojg0PJHZpGGdKcSOayZX04DUtxcd4sa0PNk,25024 +torch/include/pybind11/buffer_info.h,sha256=qR7pghnWsDMQva0ewn81wtYD1bAf1C_61zI6AaSkiQw,7986 +torch/include/pybind11/cast.h,sha256=DmP4Q8Kt9OGL11XMXHgeWEtRb6VzN_wUSYxAEpcNHdY,73551 +torch/include/pybind11/chrono.h,sha256=EYBIP7LuEjp3vBy-Mg-YLLAlnYTKBaSOVQn8tSjtisU,8683 +torch/include/pybind11/common.h,sha256=9flrAhYFcDioFXEMfL-sYGZYQmFsUdycR1GSnVVKKA0,122 +torch/include/pybind11/complex.h,sha256=0FPTn_YbTQsFUCJw2vAZ_zWcrF4LsfGKRUjUrWz-ZX4,2170 +torch/include/pybind11/detail/class.h,sha256=DqtnTvQGVJRLNYdenb1RcuX_PZPVdClQVQ31X1Edmz0,29793 +torch/include/pybind11/detail/common.h,sha256=421RMsizfaS_sg9Hs1cBq0IkeNgObVFPuHAlfzcRdwU,55995 +torch/include/pybind11/detail/cpp_conduit.h,sha256=7OsUPJSdem-zYpowExADZdsJGR_fycGFUsftn_Cfcwo,2666 +torch/include/pybind11/detail/descr.h,sha256=RbtOT-vph-aUgV7ZqpjXl_T-nEQinQlQi35-cjCCHag,6207 +torch/include/pybind11/detail/exception_translation.h,sha256=6nZ1JDvOSJD8VkpmEcIc8E1GReq60uriLHKcM3OoBGY,2671 +torch/include/pybind11/detail/init.h,sha256=JpSmNz5cqLtJs0IL3OC9VAfpff0fB_kRZrewTjGMqCk,18419 +torch/include/pybind11/detail/internals.h,sha256=71O3DnZoeI8KsoKng0_5TN05fi7gyQtqWfOuEMd7-mM,32751 +torch/include/pybind11/detail/type_caster_base.h,sha256=lFxAAs3gdpyL1Msevr6djJ8E4Zf1gjsIGkej8tvYWA0,50133 +torch/include/pybind11/detail/typeid.h,sha256=ybfJGams4o0qi-zst7kspHjNNgpGnAldtG8Pwcgmtec,1690 +torch/include/pybind11/detail/value_and_holder.h,sha256=mgKBhVvSU2oHdYnCy8T8dXW9anJIvWO0XPOEZQjAxTA,2891 +torch/include/pybind11/eigen.h,sha256=9kGbkSF8hQAd-6fIqkJK5qH-x7-MtleA1deKtribDNQ,328 +torch/include/pybind11/eigen/common.h,sha256=sh9Thp2j0lIvGx6UT5k91ph2R66LGRsb-rOZ4_evrhg,387 +torch/include/pybind11/eigen/matrix.h,sha256=eQAC0l2zvbmoHkblIe7dUapohhrx6ixdMYW_3H9O6-E,32857 +torch/include/pybind11/eigen/tensor.h,sha256=VLh9o-uldgKtGmXnjQ7JJzmchNyzoF7Gs_YjdpZGQ2Q,18899 +torch/include/pybind11/embed.h,sha256=aOeJ_Wx9tC3zYNchIZQ0NeschneUSHKAJG5RLsB1Y38,13675 
+torch/include/pybind11/eval.h,sha256=KARGCuiN1slB6oxi_xVd5QUTuS3MZNqTQEkgafXUD48,4887 +torch/include/pybind11/functional.h,sha256=B4w7SqP4zD41uz-Vx6ao26qoDMv9synOEgnZh3rKEe4,5416 +torch/include/pybind11/gil.h,sha256=0JXpF-tiE3G7L7_XStRxJWmO75-CZWZRGgZtdp1QNRc,7921 +torch/include/pybind11/gil_safe_call_once.h,sha256=SniDYmGZ6CxWWHbp4JoQQOFpydBUuaq49u3hUkqohHA,4093 +torch/include/pybind11/iostream.h,sha256=lcL3-ToaPhHzyc6iW1p9CiOZRPeT3NVuHrYTo0GOo1Q,9127 +torch/include/pybind11/numpy.h,sha256=LLG8vHy5KrjYLn_j62JtIktJ9CWs5X0Udiv82LC1crQ,86581 +torch/include/pybind11/operators.h,sha256=kD1nk8bD8VV-n1Cj_JXFtuNSkShwwB_bM8BUTFNA86A,9305 +torch/include/pybind11/options.h,sha256=geFwXpZBmkIqkoQQvUmenwPA5wp0dBeQs0kYuqMG_PM,2826 +torch/include/pybind11/pybind11.h,sha256=PjIqHcwXX_CahL1AmpSwZhb--COQW6fg7iBZOYfmJ9c,132876 +torch/include/pybind11/pytypes.h,sha256=rFFKMOFkmTi1dUMkej5JTvv90j21quIQpJ3YJzobORg,102500 +torch/include/pybind11/stl.h,sha256=GtoQeZkwTSMcKT1DMsMr9pX2LzpxOPfI8An_ZrEZ9Ww,15980 +torch/include/pybind11/stl/filesystem.h,sha256=Nkd469hk_vA0Ul0XfEjO2J-jHcIK6aZQjkIcftkuSPc,4683 +torch/include/pybind11/stl_bind.h,sha256=wd3k95ydKxFJPrRef-sp8CqvrkF3jsf9JQYGFiuQF8M,29317 +torch/include/pybind11/type_caster_pyobject_ptr.h,sha256=bIA5W1G3unHYLjHwcAz77FybP1uq-nlLSa038EhZEZI,1990 +torch/include/pybind11/typing.h,sha256=DGjlv4Kv9g-s14FXALmklL8NNG22TXCmmvIxZBAsj0s,7242 +torch/include/sleef.h,sha256=_fbfekfennyOAb5SstH2_tun0vdw6BipUo28a3t9cEg,273502 +torch/include/torch/csrc/CudaIPCTypes.h,sha256=9pv67zrx7Mc-xZ_Ob668wkUzhF0GRxBRX_9f3TUS5kE,3540 +torch/include/torch/csrc/DataLoader.h,sha256=ouSjEde-5djXm1E_t-79-x1QXFmFh_iBSviIFcV4-3Q,228 +torch/include/torch/csrc/Device.h,sha256=nza6IrUf7JN92fkj45ecpLDVWIuGqnurRPizAuuYliU,507 +torch/include/torch/csrc/DeviceAccelerator.h,sha256=aEOiPJkbj2p_GwlWfmxKvYN5vDtC-2BGYDBa1D-q9uY,184 +torch/include/torch/csrc/Dtype.h,sha256=YXtpmWT_tziQOZZ_kVzgxWUxKmzvmlumWUNH1yusrX8,868 +torch/include/torch/csrc/DynamicTypes.h,sha256=O1Wo3Zwelnb3Krrk_qaRZIe37OBPiatsC0gYMyeev2c,1099 +torch/include/torch/csrc/Event.h,sha256=LiNdG9Z1Dyd84bk6mXlIYeS4OUdiAua8Z2tcGyYz0Lk,583 +torch/include/torch/csrc/Exceptions.h,sha256=tVhfBCDD4U_MZMfAOMDv0cIfNmZ8udmTQB-kOry3oyg,16547 +torch/include/torch/csrc/Export.h,sha256=Uv9Agb-SWcmOYrBU3eYs6SBXyqm_TdzkS-dzR-dbBCg,166 +torch/include/torch/csrc/Generator.h,sha256=0rzN_H_ju2yEssqxaWaZRud7GPHKqwLh_ALzhLMOQ_A,1087 +torch/include/torch/csrc/Layout.h,sha256=npIOIOC2Dg_0dn4l7ufMMv1nNth1B8iTkPXeWFmmd8Y,612 +torch/include/torch/csrc/MemoryFormat.h,sha256=wP9xpQ1FG5Z6NgtMOu-0pGOab1GC9gzFDAjHBBVgMzI,711 +torch/include/torch/csrc/Module.h,sha256=f2qMc1yMQ1AOYVNYG5kf8HK09KZhbrLJylF8Ib9-wjs,107 +torch/include/torch/csrc/PyInterpreter.h,sha256=Pwc-SiT0yCvAdvOPULWMU8C0WfvbfJzJqu_42ODIj-c,396 +torch/include/torch/csrc/QScheme.h,sha256=wbBTAButiBub8zU4OTEFoyf0vanc_g8jw409n9ql3rU,635 +torch/include/torch/csrc/Size.h,sha256=yf7UKaKa93MCeuNKDxDkbTxmCIY6oSzd6LbSQBy3eZA,492 +torch/include/torch/csrc/Storage.h,sha256=JOhQtrklgPuTUmqJguKHgJukHy16oycwVMTTHVdGXwY,1630 +torch/include/torch/csrc/StorageMethods.h,sha256=dhpWpaNpaXudEqrL9QO5bfq97An8z3oDuavPN4AIO_4,140 +torch/include/torch/csrc/StorageSharing.h,sha256=AC6h1Xe3MGR7kbvrSOlGaYaMmR_iy13LoDdSNW5w5wM,147 +torch/include/torch/csrc/Stream.h,sha256=S15DzHlfLvD5crMVNvyNRPuMAo2nmI5-jppeWza5-SM,711 +torch/include/torch/csrc/THConcat.h,sha256=eEIG2mZkE_vkoIeb8ukEXc_w29GJaOjU2tjZJ5QU1A4,710 +torch/include/torch/csrc/THP.h,sha256=kUXIC9-QpnXbS1dB9b0h4tMwSlfHSQ2QpceHFYLBqno,924 
+torch/include/torch/csrc/TypeInfo.h,sha256=OkWOxkySfAFhjx9wHUPZsF5AUacwA_v3RsCme4YCfJ4,593 +torch/include/torch/csrc/Types.h,sha256=2oIivVbts_6A6S0q0yMrMgAVewQBQ1M8wpDW1sp6xlA,176 +torch/include/torch/csrc/api/include/torch/all.h,sha256=Jumw_rbjIUE0f_gyKhOVVXPYwVGU-QdhDbFbShIQfio,587 +torch/include/torch/csrc/api/include/torch/arg.h,sha256=gpx6jp12QtQIZWUL0CLn7T7cBiJOJtotidU20H81UvA,1450 +torch/include/torch/csrc/api/include/torch/autograd.h,sha256=6uNE0-m_doNLsa_kJ94wxiJS_FVxJoCloAOwWtTYrTM,177 +torch/include/torch/csrc/api/include/torch/cuda.h,sha256=54zHsN69aEXymwBeoSsmLpJfAtU2dsUUFzyvv5-gwDA,761 +torch/include/torch/csrc/api/include/torch/data.h,sha256=JSePfPMjjOmjHlL22Y4TfeoSeuDrm1C3FxFpfq03n1A,310 +torch/include/torch/csrc/api/include/torch/data/dataloader.h,sha256=WljVzhHwY43Yqz6G0BPj9HYX1Ze_jxcrh6gdOoPaMBs,1951 +torch/include/torch/csrc/api/include/torch/data/dataloader/base.h,sha256=P5W2Ykztewa_vDiU75JljejsG3bmjTG6w1ds7kyTqe0,9475 +torch/include/torch/csrc/api/include/torch/data/dataloader/stateful.h,sha256=9E0WaH1e4ku6U1bWiWec-SSbuMX6-q1DrM0Sz-K4mLQ,2396 +torch/include/torch/csrc/api/include/torch/data/dataloader/stateless.h,sha256=sEHE51VoCxP8SbhyuqFX3WvwwO1t2vC-LKBMZS3XDYk,2826 +torch/include/torch/csrc/api/include/torch/data/dataloader_options.h,sha256=B7UaGxNk09Qf54IZcwAVO6vm41auX7wLKnGbVT_lIZI,2260 +torch/include/torch/csrc/api/include/torch/data/datasets.h,sha256=wUJmQwIgXsajJ4X2znwBM4ssPGsK566fI1QjiAoItY8,298 +torch/include/torch/csrc/api/include/torch/data/datasets/base.h,sha256=6IGGIzgMkv19MHbLmxCOiVg1KdbH1gcx4te9wHA88-E,3272 +torch/include/torch/csrc/api/include/torch/data/datasets/chunk.h,sha256=Pfd_uHqOEiUKi-acPUOSYp7nnxmkXJIm0U3vlbZwS5s,19703 +torch/include/torch/csrc/api/include/torch/data/datasets/map.h,sha256=AkRDG2lN6tXhy16JskHG5HwIKlFvEzFxmrmw_dSOyCU,4186 +torch/include/torch/csrc/api/include/torch/data/datasets/mnist.h,sha256=j3T47FpuoQIeL7RbZ0zz5Sxtbhhdr94MCChyepWdgHY,1273 +torch/include/torch/csrc/api/include/torch/data/datasets/shared.h,sha256=vSLVzxPJGfHu9R50_b2xmKSOBtQyKPbuXAvIhnljoxc,2674 +torch/include/torch/csrc/api/include/torch/data/datasets/stateful.h,sha256=GXh7FJBgaGzOiGRuMtLEfSOz1DJcq-BBoP6g2S0abPU,2298 +torch/include/torch/csrc/api/include/torch/data/datasets/tensor.h,sha256=1TQ_YwsX7kY_3n1YkjcsOVk-fTpI1lspM_uYVSEAZiQ,965 +torch/include/torch/csrc/api/include/torch/data/detail/data_shuttle.h,sha256=kArrte-eilkriikhyC6BJQoHFTBldT4k1KFqHjTz_mI,2668 +torch/include/torch/csrc/api/include/torch/data/detail/queue.h,sha256=n13QbGtFvhFM9Gf3cW46JWfWrUS5I4MTBg9RTwv1x0U,2535 +torch/include/torch/csrc/api/include/torch/data/detail/sequencers.h,sha256=J_dxBIdqDA-CrrrTCQyInb2h1_ii9ixKQ3I5zz3IsHg,4547 +torch/include/torch/csrc/api/include/torch/data/example.h,sha256=Ch2uVjs5iYjdqoEago6kFr6PmpnvWVDmH2GfXAvz9eA,1342 +torch/include/torch/csrc/api/include/torch/data/iterator.h,sha256=vPZq-Xruuoz6wrhUQMAV3FkdLeH7a8tjQrrOgmtyZ-I,5483 +torch/include/torch/csrc/api/include/torch/data/samplers.h,sha256=vvBX2qsD_kusBAOouG-B77963lUnYQRShVwo1p8qBT4,327 +torch/include/torch/csrc/api/include/torch/data/samplers/base.h,sha256=M2seapSM28wGSwivQQxxJiq2BnUt-HkfSSv4qlnz-7c,1207 +torch/include/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h,sha256=noZDLYKy3cYkaUKaOHBtfTglrVqwFHNntTB4yekJ-Hg,523 +torch/include/torch/csrc/api/include/torch/data/samplers/distributed.h,sha256=UpdzEmPAVFtb677R1FcH9TCHG60CvTh9uB4Y5UEq8hQ,4193 +torch/include/torch/csrc/api/include/torch/data/samplers/random.h,sha256=cbeXVJgo0xNY1qtQcje092zKZyYfXTJ3aFzuAs3ezCU,1510 
+torch/include/torch/csrc/api/include/torch/data/samplers/sequential.h,sha256=42rdY16wKnmI6morZnDWKAr6Mb6taWZ4EgWdNnWOF24,1238 +torch/include/torch/csrc/api/include/torch/data/samplers/serialize.h,sha256=2b94CwZ8MofjMkr8D2nwUn6Al30xqY3Rf4Hpimw7X44,681 +torch/include/torch/csrc/api/include/torch/data/samplers/stream.h,sha256=a6JpRvbXMgVpvKDmmIGYk-Cm0kE3IHayx3Pyv72kA6E,2030 +torch/include/torch/csrc/api/include/torch/data/transforms.h,sha256=YV_eW7I4UJahbvu9a4ZH6__0ZAujBo-AE4OKP3l-FLI,229 +torch/include/torch/csrc/api/include/torch/data/transforms/base.h,sha256=FKOvVycx8HDXrODAmCyUv9LzhwCBSwMzfX9msz_C_gA,1628 +torch/include/torch/csrc/api/include/torch/data/transforms/collate.h,sha256=PHV1CM6T4MDiZBZ3kfKCGF9StfGk7X74On0KL8FNf6A,1094 +torch/include/torch/csrc/api/include/torch/data/transforms/lambda.h,sha256=ub8D70EPnUN9H2Qr0BrSnfccPZQWWU0QW63gE28Jq5I,1711 +torch/include/torch/csrc/api/include/torch/data/transforms/stack.h,sha256=LIo2macXr1UqVvWVM0NRN_qgxUGZJMr94pBOD6aDLKM,1419 +torch/include/torch/csrc/api/include/torch/data/transforms/tensor.h,sha256=lrsFzVHjLSEoj_DZHth11M0j3ki9zfPca5cZ_nLlDN4,2496 +torch/include/torch/csrc/api/include/torch/data/worker_exception.h,sha256=2yyxEYuZBorHkjlG5vJy_GK-ful6S8SYe1MaipF-znU,1156 +torch/include/torch/csrc/api/include/torch/detail/TensorDataContainer.h,sha256=AtKonXURAwCqQ2MD4IEY52902a8ex0En_TBr4mAVV50,13301 +torch/include/torch/csrc/api/include/torch/detail/static.h,sha256=yUXeXchKmQBMSjdPGadBJkQYBOHXtmCN3yGsBlZtAIw,2152 +torch/include/torch/csrc/api/include/torch/enum.h,sha256=BZ6JUN9ZRz0cGJy7F0sI8IIL5SbXwCs6fpIwEKAl65w,7660 +torch/include/torch/csrc/api/include/torch/expanding_array.h,sha256=jehwXKTKF-FMiYQuhYt5t6NPpvTOAVJ8QyO9-Jhtyes,6855 +torch/include/torch/csrc/api/include/torch/fft.h,sha256=jGZBJjJiskBFAn_iHawqxuEzM-Y9vreWynJgxt494R4,12494 +torch/include/torch/csrc/api/include/torch/imethod.h,sha256=gxHtBrxmLIqObKDPu7FnV7Q1DydUIZKHtBYvMEp5EiM,1793 +torch/include/torch/csrc/api/include/torch/jit.h,sha256=My7d0xwa67IJw7zeZwDjbK92r6GgAxfI-R4TFofeqJQ,922 +torch/include/torch/csrc/api/include/torch/mps.h,sha256=rjZBEqDgrmoqT2Fo_dXiHSBXk9vzePP9n8Sp4647W_M,1236 +torch/include/torch/csrc/api/include/torch/nested.h,sha256=2A7B5ww1tFrxqYxjXK92ZSIMon6Taj8XibEC_qwFHB8,2866 +torch/include/torch/csrc/api/include/torch/nn.h,sha256=JZOcrLr-MSI9dnJY1SpdOr1zAZvmRLQN4GrrJdy4qbA,261 +torch/include/torch/csrc/api/include/torch/nn/cloneable.h,sha256=Vj_2Opo_bbngnnCTSS7GNInDojg2cax1yP5K15vKCRU,3995 +torch/include/torch/csrc/api/include/torch/nn/functional.h,sha256=zX8T3cbwAs4nMt_LVbs50CaGzF4v9drbamY3EPrP6X8,659 +torch/include/torch/csrc/api/include/torch/nn/functional/activation.h,sha256=bCzw2M58kw6MvZUcKcC_ZxHo6aIzuMLKtEJ5hUHQF_g,30735 +torch/include/torch/csrc/api/include/torch/nn/functional/batchnorm.h,sha256=hoC_mRYifDCKnKnnnyxmOMpulzEGG9rlP4h8latmHqo,2094 +torch/include/torch/csrc/api/include/torch/nn/functional/conv.h,sha256=ZdT_KmqHPokCqIcRaHr0mPN8x9zDQ88cpdQ5jOl1B-8,8394 +torch/include/torch/csrc/api/include/torch/nn/functional/distance.h,sha256=gSF9MSgmcTM5a9IIEfJX7qk4vIv1_jYaUjPPAZHeWxA,2583 +torch/include/torch/csrc/api/include/torch/nn/functional/dropout.h,sha256=sJUOhVigTl8TqH2FqZP3I9r6qeLW9lzefOF22-mPnus,6766 +torch/include/torch/csrc/api/include/torch/nn/functional/embedding.h,sha256=eWbgE9tAbXYbAx9Cy2-5raQ8gClJrp0x3pgWk73nw1g,6565 +torch/include/torch/csrc/api/include/torch/nn/functional/fold.h,sha256=ytAy8mIWrd9yXoZSzDW1koQSkecK2dYm3gN9aRBIs0k,2834 
+torch/include/torch/csrc/api/include/torch/nn/functional/instancenorm.h,sha256=FBqHD_Ju9tLNsZYoR5_KcORYvBKJLSfKReDdMiDlyY4,1614 +torch/include/torch/csrc/api/include/torch/nn/functional/linear.h,sha256=Mt7UqldDF5j71MkbV2SDIfIAyCBH3ER6ODxVGzMHhOg,794 +torch/include/torch/csrc/api/include/torch/nn/functional/loss.h,sha256=AuoeKpt6hxHYDzL5zVzWKvRV4eR8Kt04WCHTCeCZmmw,32880 +torch/include/torch/csrc/api/include/torch/nn/functional/normalization.h,sha256=RtGOAUFdKPkfzTUDVLbAT6utt7aE3VnEfOQbNiziXb0,6178 +torch/include/torch/csrc/api/include/torch/nn/functional/padding.h,sha256=MmBGRAQis39GvajCySoQsDIRua0PzpHIzNOQXf3SMig,1728 +torch/include/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h,sha256=MqO_tzcfC1sZsqGaZ5NK3ZaIIRIK7O-yO3lt6x9am8M,1336 +torch/include/torch/csrc/api/include/torch/nn/functional/pooling.h,sha256=XTKRlIb9cUFiktt1t5bAcsmpwm_l7PVi10sVEQsT-hg,36559 +torch/include/torch/csrc/api/include/torch/nn/functional/upsampling.h,sha256=ad0zoZsvAVPyL88xa8o3UOsKQapifR3-jC2t-msNYNg,11014 +torch/include/torch/csrc/api/include/torch/nn/functional/vision.h,sha256=talkaLVeHJhyZAYEilrpUNSQ3YLbvo2W0AcCbr_JVtI,3698 +torch/include/torch/csrc/api/include/torch/nn/init.h,sha256=W5CuPhlUI-1Xa8Rxf1KP8ZZNparhNDQTiPsWxxuMsrQ,5041 +torch/include/torch/csrc/api/include/torch/nn/module.h,sha256=-kfGkL9zJBx2AFkwdn3YXpzKZNURzQZKG4gwj1V4fn4,27513 +torch/include/torch/csrc/api/include/torch/nn/modules.h,sha256=GunxX9fV82A7Kq2pI5YW3wCrcdS7cu5TxR1zQ0kxZXs,1325 +torch/include/torch/csrc/api/include/torch/nn/modules/_functions.h,sha256=HXF9L2ARZ3opmbdr-9RvisKAq_Ny44yW3UvELoW1tew,678 +torch/include/torch/csrc/api/include/torch/nn/modules/activation.h,sha256=OEyFNgzMNYob_F0MdZ8MMM7wU4Cs6dc9uktCPMokEj0,31190 +torch/include/torch/csrc/api/include/torch/nn/modules/adaptive.h,sha256=yp_C-2O358_cXvSMt3-tI9cUM2_DDSlUSxCBm6GUoSs,3626 +torch/include/torch/csrc/api/include/torch/nn/modules/batchnorm.h,sha256=9ydXVI_a7VwAQXtfebC1L7nWRlItBA9VD0dNvrjTkOY,8467 +torch/include/torch/csrc/api/include/torch/nn/modules/common.h,sha256=b0qmYH3ShSjXMv87F_RH3CUQhulCBVjxBUGmUzXHh94,4447 +torch/include/torch/csrc/api/include/torch/nn/modules/container/any.h,sha256=tm3dXSI7bZU1vcITRxSxDNozIchIfqGxqP1x1HBFy5U,13769 +torch/include/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h,sha256=OSztxH_2ZmvIP8lSBXButBkbR3pyTnUuJpCDHrvxpGE,5131 +torch/include/torch/csrc/api/include/torch/nn/modules/container/any_value.h,sha256=UmIaIPBR4U9kRpCtQeNwDVDkVIrjaL3vlq2uNfhOPXQ,4262 +torch/include/torch/csrc/api/include/torch/nn/modules/container/functional.h,sha256=ey0u5YIGsQHIWR16HvcjYeIwdDRJN3AgOQzM0MN23wA,3444 +torch/include/torch/csrc/api/include/torch/nn/modules/container/moduledict.h,sha256=SiSZed3BTalywuHf9eI0ROae2bjtK3XKBfMw4Pm6bJM,8682 +torch/include/torch/csrc/api/include/torch/nn/modules/container/modulelist.h,sha256=RQ8yv34Zv-ObiMr2G2CkOpAkGRYO9XkTiXe_XILdV5o,9214 +torch/include/torch/csrc/api/include/torch/nn/modules/container/named_any.h,sha256=SvvwroiyhcCiiGc7OxblllG7hNVQ3esAgaGr6xnDOIA,2523 +torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h,sha256=YmZTNDY5uT9hiSk0v6Sl_W8r-6SLdX7AwZZOtVQri3I,4603 +torch/include/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h,sha256=wJcq55aygVDk-RE9Ytpu3QD4vJfsexUyzz61sAMW_18,5744 +torch/include/torch/csrc/api/include/torch/nn/modules/container/sequential.h,sha256=UHTpAefUQud9LW_zIhtAKqP2WhhtUhIzXDyX4QzueNk,14121 +torch/include/torch/csrc/api/include/torch/nn/modules/conv.h,sha256=_Q0UGDWIgnOskdfSsjOJFfpKYJzaoy2x5V6DN_t9Gi8,16707 
+torch/include/torch/csrc/api/include/torch/nn/modules/distance.h,sha256=SAe11TTrqBIPapxAMjEeG1sJPAAxcsG87lF1ikfcrxw,3140 +torch/include/torch/csrc/api/include/torch/nn/modules/dropout.h,sha256=2SJBdWU2GwzZ6_aT3amJ26Twb64rGwZeSwpy0F-y-XQ,6584 +torch/include/torch/csrc/api/include/torch/nn/modules/embedding.h,sha256=ZJWikE0EPNqRFxYN2qRkJZORKuPmEJ1lT-TKNv9DIyY,6220 +torch/include/torch/csrc/api/include/torch/nn/modules/fold.h,sha256=iK0aksf3_p5zkCBNSY__-zqBBWXfHbC8jrckeO2Mslg,2907 +torch/include/torch/csrc/api/include/torch/nn/modules/instancenorm.h,sha256=h-ph1ormC0r6PwmcY51hcrZgprYxZ5gd8Y-DyyLtXAY,5667 +torch/include/torch/csrc/api/include/torch/nn/modules/linear.h,sha256=jNICd9KiLRBxsHvga8z8Um3U_X9ZdXODARDYmY8IMqI,7699 +torch/include/torch/csrc/api/include/torch/nn/modules/loss.h,sha256=pJoI1IYP23YZZ61wsjVJd3ictiaKbU8PPxilRqGF8s8,31785 +torch/include/torch/csrc/api/include/torch/nn/modules/normalization.h,sha256=XNc6wme4R8PDeY66luEnBnm0owuJ-xemCcCz1KJ51GM,7128 +torch/include/torch/csrc/api/include/torch/nn/modules/padding.h,sha256=7JAN1oH3ldNNipyWjU2OjYX7rdO1hJ7MYoe9V1yJMzo,14722 +torch/include/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h,sha256=5fHrAwDgOuX-ie0vbqHDJFBukTvrfYnf_3tBmO6oULY,3195 +torch/include/torch/csrc/api/include/torch/nn/modules/pooling.h,sha256=X10Qj87MiX9p7dBt3dLJhO5XY6qdRZw80_82uJ3HP9A,30417 +torch/include/torch/csrc/api/include/torch/nn/modules/rnn.h,sha256=hORYIzjHA5QCH8haf_d42K6kpMPMolZPPSzpDXPUGCo,13847 +torch/include/torch/csrc/api/include/torch/nn/modules/transformer.h,sha256=dsXo7v8hEShieTE1fj8HbGiMXPEDcldTjXUZsiQroek,5463 +torch/include/torch/csrc/api/include/torch/nn/modules/transformercoder.h,sha256=Wszd-oZzK_ZiuG4ZoJjTVI9ywE_Bm9IVcquf6mbLAqE,5355 +torch/include/torch/csrc/api/include/torch/nn/modules/transformerlayer.h,sha256=20k8qWWPNRcsIHUXR3SC8jt0xl0nDoErObB8x6OjQrM,6600 +torch/include/torch/csrc/api/include/torch/nn/modules/upsampling.h,sha256=vJLRyljN_wGMFULjU_b0X6MMDtK2dQdXg2BrshIU3p0,1672 +torch/include/torch/csrc/api/include/torch/nn/modules/utils.h,sha256=77ADrqfmMv3wO5RI-rflBwVePNGGgve9P59eKMa0pc0,1448 +torch/include/torch/csrc/api/include/torch/nn/options.h,sha256=mio9iZxv4kWN8PNBt7OkCpSDfVCAojnPrT7I1TcEXKE,663 +torch/include/torch/csrc/api/include/torch/nn/options/activation.h,sha256=BzP55Y8lvRf__XsGbqstjkoMatVn_JjA68Konx08LZw,19731 +torch/include/torch/csrc/api/include/torch/nn/options/adaptive.h,sha256=KMTXirLcljBZFaGwFLN-hHxlYFkWaVMD6oADeVnqwh4,1096 +torch/include/torch/csrc/api/include/torch/nn/options/batchnorm.h,sha256=o6rq1KvTXOzhlxeJCxlvwfCeUHVvAt8Chx5f1V3bfNc,2852 +torch/include/torch/csrc/api/include/torch/nn/options/conv.h,sha256=MloMisklg395xqZYC4x0MMZAsq69MGFaE0lvjlCpUhk,13859 +torch/include/torch/csrc/api/include/torch/nn/options/distance.h,sha256=Bu2OqEd-azQqxyI9Dp67EoCIk3wE_9agkyYaSY_v5Mc,2058 +torch/include/torch/csrc/api/include/torch/nn/options/dropout.h,sha256=vbu9sp1X_EwCuasTaIKK2uRzzP3vQ5PPApUw-BcpfRc,3173 +torch/include/torch/csrc/api/include/torch/nn/options/embedding.h,sha256=4BF84awReG7K1YT7kXB61DR_AUA3dWRI2n8wxOS6PXQ,11882 +torch/include/torch/csrc/api/include/torch/nn/options/fold.h,sha256=NN2303_kh8nWrHaSCBPET4HGaXFPt3yeTwnNQVd7HAs,3008 +torch/include/torch/csrc/api/include/torch/nn/options/instancenorm.h,sha256=YvSLzhlx-AA0Adp9KVKdBHxxQN9g9TeW9OjNso9Z8SA,2383 +torch/include/torch/csrc/api/include/torch/nn/options/linear.h,sha256=CvCuiSUPbFO86oPLmMJ9RUwmR-zWiZvOkHJNmfQHO4k,2872 +torch/include/torch/csrc/api/include/torch/nn/options/loss.h,sha256=6pWAL2HGoQTT5VdB0p07o-5Y03Esm2U3ydQpxBqZ96o,27502 
+torch/include/torch/csrc/api/include/torch/nn/options/normalization.h,sha256=xBvr3cyBlAML8KdmBXZpTDhadhZcHKutduSDfk6ZxBU,5687 +torch/include/torch/csrc/api/include/torch/nn/options/padding.h,sha256=4wDGe-Wc7wLen0qOL8MAolO6NtdyL6Z6oSd1Jdo61EU,7052 +torch/include/torch/csrc/api/include/torch/nn/options/pixelshuffle.h,sha256=NRGKmzhvRNdQ4q53NKnqGi3GR4d2m0iCiQ_VOTcWD6s,1695 +torch/include/torch/csrc/api/include/torch/nn/options/pooling.h,sha256=sJemTYIPOfO6jT4xfZM0ELgriirsfZLstFFwL8Q7CFk,18312 +torch/include/torch/csrc/api/include/torch/nn/options/rnn.h,sha256=MMOwBbf8hrhq7YoyPZAjEkzOhmWgxsiHu5JqHd2hF5U,8429 +torch/include/torch/csrc/api/include/torch/nn/options/transformer.h,sha256=_cay4wJCH_fPl7QAP-uqVzKcgghjJxNmK9QJA0wZpqA,1876 +torch/include/torch/csrc/api/include/torch/nn/options/transformercoder.h,sha256=01DaY-884a5FmcY0-HqhvRkgyl3Nj2QHSQ6gd30B-YQ,2393 +torch/include/torch/csrc/api/include/torch/nn/options/transformerlayer.h,sha256=XzgUERWTKwi5uOIdBBRYm5fVuh8jNDaG9Vt67IhtsWI,2129 +torch/include/torch/csrc/api/include/torch/nn/options/upsampling.h,sha256=_E8Q2aS1KdZdT2rw4edkDAyKDqhp1e_GEWI9mnusd7s,4233 +torch/include/torch/csrc/api/include/torch/nn/options/vision.h,sha256=wGUxA24QF3LH0CtzkMQegpwNjSpNmoZwmFIxz4kb8xY,1120 +torch/include/torch/csrc/api/include/torch/nn/parallel/data_parallel.h,sha256=gjR1hJ3HPl1M270x-wLM8-_m6RBMaErmzID5NDlI5C8,11493 +torch/include/torch/csrc/api/include/torch/nn/pimpl-inl.h,sha256=qC12dFuUBAQqs4ubt1iPaMiBwTwCG97wATjA8y2YIb4,3300 +torch/include/torch/csrc/api/include/torch/nn/pimpl.h,sha256=59N_KVHDwsFoy_dhQHmX9-vgv3fZO6nDUuJ3nUIsML0,6825 +torch/include/torch/csrc/api/include/torch/nn/utils.h,sha256=8qAKIZqgGj4uPo608t1C3tSXaxcCpHcZcgumbcJ9wy4,136 +torch/include/torch/csrc/api/include/torch/nn/utils/clip_grad.h,sha256=h_qK2Ro44s4aTu1KwbXdvNNsAlV6d3DKpIpxtsv5-dI,4982 +torch/include/torch/csrc/api/include/torch/nn/utils/convert_parameters.h,sha256=Gua6BEKx0vosR-z7utQl5Ab3M0J-6HwMG5mWJwPUMEg,2454 +torch/include/torch/csrc/api/include/torch/nn/utils/rnn.h,sha256=AlgTB-JuS0BLv6txK3OWwptntyhzTuvs24IcY8boSzo,13141 +torch/include/torch/csrc/api/include/torch/optim.h,sha256=YR1ubuV1iVUXZ0EoYrUHG1ebXDixaFPvYp2kVBAmknI,407 +torch/include/torch/csrc/api/include/torch/optim/adagrad.h,sha256=42J4nM3NiYlEas_l64Z1in0GM7uDSIofesKx0JXUzAg,3296 +torch/include/torch/csrc/api/include/torch/optim/adam.h,sha256=QqNql4Tw2jajbYuyDMoHYb-nnRKmIEXaNdAc5dzlJQM,2966 +torch/include/torch/csrc/api/include/torch/optim/adamw.h,sha256=02zM064CHPHRXqkej9eQy57csVKmU6_uxaNREPN2Uqk,2986 +torch/include/torch/csrc/api/include/torch/optim/lbfgs.h,sha256=qnuh-4fIgn5FMiWKPgFzpT-5ahzC5iNu2N_cXU6KPmE,3545 +torch/include/torch/csrc/api/include/torch/optim/optimizer.h,sha256=K64Sfyxha6YYagUUI3hJrZla0WJQ6cKkNBZR2NHP7sg,8349 +torch/include/torch/csrc/api/include/torch/optim/rmsprop.h,sha256=fDdDhEr5Lp8yMtYyQDv8KjLbYnE8_tf9tmS4UwmdBMc,2989 +torch/include/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h,sha256=-qxn3h0pDUzlVMATyd391_f7LHyeqppeo-qXGr2BirY,1157 +torch/include/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h,sha256=LSdIxkaHd3OuVUanpWLjih8Wk2DV0ZHhRSinUL-_wT8,1439 +torch/include/torch/csrc/api/include/torch/optim/schedulers/step_lr.h,sha256=nugegH2ZuooK4NbQGBPdVeG5euxCoLFv-WulHiQ1xqc,419 +torch/include/torch/csrc/api/include/torch/optim/serialize.h,sha256=-RtCdNSe9w8zvkEQVzNcVxIivQ_BNBHIloWZIfQi4gI,12988 +torch/include/torch/csrc/api/include/torch/optim/sgd.h,sha256=az6sQ_Q_Yn28mPORhFVTBb0Oemfiw-pNVfbOqiy4XgE,2707 
+torch/include/torch/csrc/api/include/torch/ordered_dict.h,sha256=a5Xa-V8q0g7k4svbq7tKRBEtzEOeG1ly3zGykgH_uxQ,16784 +torch/include/torch/csrc/api/include/torch/python.h,sha256=1212sCcLu71x9-Vq5wZvqQernPBYXgpb3QESbxWhffc,10163 +torch/include/torch/csrc/api/include/torch/python/init.h,sha256=1BdHmrhl_REQ_M1kCKwBYJHI5otlh0RG9lDH0o6KIwI,212 +torch/include/torch/csrc/api/include/torch/serialize.h,sha256=SVhq2UC2Urt5A4vd6gF_g3EWKeL52DRxgoII_RPsgYg,5388 +torch/include/torch/csrc/api/include/torch/serialize/archive.h,sha256=eFWJ0rMIETBxQ-RzrbeJPinxHqQvw4HgItM4yTKstEo,105 +torch/include/torch/csrc/api/include/torch/serialize/input-archive.h,sha256=x2XkZnEQdckjsx4rZE7RNBVD_6v9OrkLHV_og7GRdBw,4070 +torch/include/torch/csrc/api/include/torch/serialize/output-archive.h,sha256=3ZeFuXRh0iHEVBu8SDnZSGKF3d1nmRqSAYgDNNEd2kI,2369 +torch/include/torch/csrc/api/include/torch/serialize/tensor.h,sha256=8THFm0NshLaq3TF5ft_O0-auH_b-w3zYAEQbqRdglDI,452 +torch/include/torch/csrc/api/include/torch/sparse.h,sha256=f5uTnnP3yHl4yQrD02o9R4LNFMt1cxusttlmyLRdEWA,40 +torch/include/torch/csrc/api/include/torch/special.h,sha256=DrB4Oq4E3xErMaf-ASb3Lgk5_SxPbmm4IqQL-9_-MfQ,39688 +torch/include/torch/csrc/api/include/torch/torch.h,sha256=diC5AEPIQ8C8M3WRszaaz7tfPjWeQzCxQQ-gczSPb-s,162 +torch/include/torch/csrc/api/include/torch/types.h,sha256=a2BaFOLbb33XTZEp0zoYcfrpV2kiaT-bXCvKZSGVmkY,2458 +torch/include/torch/csrc/api/include/torch/utils.h,sha256=hSen8TPO7Pzi6BHSk-UyWGrqv7suDpe2isMSYhDao1c,3702 +torch/include/torch/csrc/api/include/torch/version.h,sha256=K5xAhK7kWbzmmzMT-Csg_Gc9ze_aWGYMW82_Y6Em5s0,825 +torch/include/torch/csrc/api/include/torch/xpu.h,sha256=YNOirmQkZYOOgCH58b9rmec4zy_n5Nb-P-bnEvDbhj8,627 +torch/include/torch/csrc/autograd/FunctionsManual.h,sha256=axwjeq33DCcDcQFblzAmWhD3OIZ8qQP-f_tuo6Qv14U,34186 +torch/include/torch/csrc/autograd/InferenceMode.h,sha256=193BAyIQipMnEgZCHFb_SqewEpjYx-_cFMLnSOtytOA,166 +torch/include/torch/csrc/autograd/VariableTypeUtils.h,sha256=QUmZse8Awe3RBcrlKXid4MIklGevPkIhypxvk22Qno8,15008 +torch/include/torch/csrc/autograd/anomaly_mode.h,sha256=EdlJ3X6jrq1yMKVnOdg842anobabdVZUwl8bt0LzSFw,1796 +torch/include/torch/csrc/autograd/autograd.h,sha256=VXpDwnhJ4ClN6exWax3Ss7IZcYSJAopuNskWEOhq5I8,5413 +torch/include/torch/csrc/autograd/autograd_not_implemented_fallback.h,sha256=VN5hLPF82ULRZ0OnQeownr9yMKcKq4NUGZUHgiT1ILY,1174 +torch/include/torch/csrc/autograd/cpp_hook.h,sha256=z157wjum4Bi8rG0gKZ5KIsfP24tLQvcJXv_XLq1XOUU,991 +torch/include/torch/csrc/autograd/custom_function.h,sha256=fJ1mHO0oTjEQg4QDyp2MuJshKJ-xH7xa6UdpHK4IcKU,21929 +torch/include/torch/csrc/autograd/edge.h,sha256=d4pmtht0HycmtESGyi8e5ydG7-jAX3xwjOH0jy8hSDY,1671 +torch/include/torch/csrc/autograd/engine.h,sha256=AmzhrxDAqx39Hn_QKDm9i5GHaI0MMs9G6EYULkYmWuE,11094 +torch/include/torch/csrc/autograd/forward_grad.h,sha256=DSDZoiS8c9ej2S21GqKHBXocqHUxlTmZlnzyhoUm78w,9147 +torch/include/torch/csrc/autograd/function.h,sha256=0KvoS6_lYKA1ru23b0P8vrAdL6ffSSj9PpdoEFJTFzo,31446 +torch/include/torch/csrc/autograd/function_hook.h,sha256=c4DoTJ_KCtMcMS6FF0Luhbaair0JDsODwf0HqTUvHik,2302 +torch/include/torch/csrc/autograd/functions/accumulate_grad.h,sha256=bnYVO7zXYh7IUgL3dY50aGiiXEHZ2CPFbxF7zGJAoJ8,15199 +torch/include/torch/csrc/autograd/functions/basic_ops.h,sha256=aZwYgZg1d1i31xcHQN66f1PhjBXp95InnW4IB8GCh24,3508 +torch/include/torch/csrc/autograd/functions/comm.h,sha256=V1A7lmYmEC86xUkKZvASExJUwbbnw1EcW5rUXnl5dUs,1242 +torch/include/torch/csrc/autograd/functions/pybind.h,sha256=AeaElJ99I4XWPp1P4a4RRjcn8ANh3be4w8gkZ_EqTEc,394 
+torch/include/torch/csrc/autograd/functions/tensor.h,sha256=_LlY5I8HqvHE7HXpRlYe64NEth3e__UqB225Y8CR-gw,7454 +torch/include/torch/csrc/autograd/functions/utils.h,sha256=5MkoQtEIDo9fYNuhHqp7aFpC4xAiny8L2JBDOdWrz-E,3345 +torch/include/torch/csrc/autograd/generated/Functions.h,sha256=jro2EIk-J07BjIWVQNj1jCw0dfoNFxI3rqGIA5_d75M,530251 +torch/include/torch/csrc/autograd/generated/VariableType.h,sha256=G3xk8KIJ9SkUguyjKxAPzWa31PacmPPc_9sv5OGDAz8,1560 +torch/include/torch/csrc/autograd/generated/ViewFuncs.h,sha256=W5q3AbdVZ6HxSsFvhEP8Z0ZJe1DDz-5-7eo11dEo-Jo,38278 +torch/include/torch/csrc/autograd/generated/python_functions.h,sha256=pPMQllLtVVzpn-jyNnrF3FbpG_sFGeuMbTROwA7ESq8,916 +torch/include/torch/csrc/autograd/generated/python_return_types.h,sha256=HMnZGi66DMSndlKqtTkQxcRO9onCwcO7MxXxxK3IlsM,4160 +torch/include/torch/csrc/autograd/generated/variable_factories.h,sha256=rXMr2vr4XpI8C24U1ddNr62dC9B3quvzbk6UzyY8Rt0,57399 +torch/include/torch/csrc/autograd/grad_mode.h,sha256=FeBYVzVyI9pozFfZJix2cp5tcJVcdwYUaD59FcZQqjc,221 +torch/include/torch/csrc/autograd/graph_task.h,sha256=0px1K5XcMkjhno24QoVNPwVhSsdl9C9tVXbxKivgy28,9605 +torch/include/torch/csrc/autograd/input_buffer.h,sha256=sjdmEzKOpNvpDvlxaQthaXoyi6q3YMzyYDEG1v6M__g,2051 +torch/include/torch/csrc/autograd/input_metadata.h,sha256=vEEkUGXuU75JKvxqS-TX9CeEO5jfyJd17uegqsCTS0A,3095 +torch/include/torch/csrc/autograd/jit_decomp_interface.h,sha256=L8y4EHpLmjbXlmC5IGOL8UgEJTwsypOg1HGPXX42Woo,1878 +torch/include/torch/csrc/autograd/profiler.h,sha256=ecH4z6CHlutJUcsQ0JEvqxuhBY5ezLYtNrhUHs9iCg8,116 +torch/include/torch/csrc/autograd/profiler_kineto.h,sha256=lFmlLzJqXAu0pxtxBAq_Mn-ggdJQMEYT8W59NoCIat0,8408 +torch/include/torch/csrc/autograd/profiler_legacy.h,sha256=ZbumxB1m3tjcxrun3BL-dkHINbkfiY7dWkZU74hGgHI,11079 +torch/include/torch/csrc/autograd/profiler_python.h,sha256=_cad-WJ8TurfdWXwU2ZaHhEN8Pme3A-kWMz9qvX4OVQ,91 +torch/include/torch/csrc/autograd/python_anomaly_mode.h,sha256=d8v4n4iablIu_Loyf-oj7aXTAsG-PB9GZ5FO3qguLqU,1229 +torch/include/torch/csrc/autograd/python_autograd.h,sha256=aX5wGQGP5tmKz82C_LG61I4LMgG-aikeGwbAgtYPsjw,441 +torch/include/torch/csrc/autograd/python_cpp_function.h,sha256=6AiI0uVli-CUI8rprvwT27GQu-BrtL8Z4KfQIOwCCHc,5375 +torch/include/torch/csrc/autograd/python_engine.h,sha256=Nioh82rxHwpwpIppJlxkEjfAuBK-TFkOFtUVFfU63WQ,1300 +torch/include/torch/csrc/autograd/python_enum_tag.h,sha256=2JZDAYZiXXH0QV-jY5PJeaT4Lc7p7EVU_f3Q1BPAKmY,127 +torch/include/torch/csrc/autograd/python_fft_functions.h,sha256=mQWbFZ0nW2CAy8X6WUO52g721q0gDugBTlY_K7-yX_k,143 +torch/include/torch/csrc/autograd/python_function.h,sha256=C67AqRRm12I2KBQ2qIrjcqWUIUACOMmBEic14sY5Ks4,5366 +torch/include/torch/csrc/autograd/python_hook.h,sha256=m3Rgl1sZjzLUApdhQg8La-Xii5Ghqrr4j86EqygO--0,2126 +torch/include/torch/csrc/autograd/python_legacy_variable.h,sha256=Qy7FTeQa8Y9KyFUoUhyZHpJMX35Pg6HEljILbZ-CV78,269 +torch/include/torch/csrc/autograd/python_linalg_functions.h,sha256=sfJScJPSRcHt8zgbZlRHms-x3Hhv6EM1hJ-1DC9DDCw,146 +torch/include/torch/csrc/autograd/python_nested_functions.h,sha256=MakQIRwTx2shIrMXT0tY7u1Mnu613AMsc3k6kLcojwA,218 +torch/include/torch/csrc/autograd/python_nn_functions.h,sha256=Dy0X6XmSlH8B8qEgOCaVuSIu0zlYqF9QMaOUt4qWjuQ,136 +torch/include/torch/csrc/autograd/python_saved_variable_hooks.h,sha256=TNq3kgdZpn9En7f-TRM2M90JxOodWb70zGtIa4vXG1c,1115 +torch/include/torch/csrc/autograd/python_sparse_functions.h,sha256=bZFCyXcB36euWvlWJPDU969tUk9FneIxvA7EUhOXxxc,146 
+torch/include/torch/csrc/autograd/python_special_functions.h,sha256=LrEhjVvp3SYUlFp_q9xffRcV_8dzdwD0WvpLgO_5PTc,145 +torch/include/torch/csrc/autograd/python_torch_functions.h,sha256=rNTuwe2TI8bmjakJdrpdoESMANNaZmjVJQkfTeDuUDM,691 +torch/include/torch/csrc/autograd/python_variable.h,sha256=cY0NWSxrpAu5v5xy_W4hueczXdPA9dItxMas0ToVW7k,3626 +torch/include/torch/csrc/autograd/python_variable_indexing.h,sha256=mxm3pbMSiwd37CIxlJwjcoiHMPZw33vKEE-O8z2HWQg,2851 +torch/include/torch/csrc/autograd/record_function_ops.h,sha256=K3KNTCuo1XZdAzmtScEUBWcFOEygK9NTKPTOYeisPck,962 +torch/include/torch/csrc/autograd/saved_variable.h,sha256=eVDF-5AvAg8KnBpjYGFGC8R553aSwy2KS0FPpel0yjc,5156 +torch/include/torch/csrc/autograd/saved_variable_hooks.h,sha256=Xo26gw1h4-Vg_Jpw2HuTp7XnFRlcj2OHwJOq3WJqy0g,566 +torch/include/torch/csrc/autograd/symbolic.h,sha256=8D-6u1srDIUgsqYN9qF2KxKAULqP3piHcJGl6XMpQLU,316 +torch/include/torch/csrc/autograd/utils/error_messages.h,sha256=clVeU7UoYAJbkh62Svx2l_EAfZ9rMCl-Rq5ESyeZhCU,513 +torch/include/torch/csrc/autograd/utils/grad_layout_contract.h,sha256=ZI_XlM2JE0ODff8W28tOjKA4Bk3rdiysStE0VyeN-pg,2898 +torch/include/torch/csrc/autograd/utils/lambda_post_hook.h,sha256=aEgWj42AUw0a0XjPUjybf3KgRjBlL4H2kP-CtlNnLQU,1448 +torch/include/torch/csrc/autograd/utils/python_arg_parsing.h,sha256=vNtDgR6jTSPepG8cLO1ZM8Q9iD7UzabMSWgWS6clel0,1472 +torch/include/torch/csrc/autograd/utils/warnings.h,sha256=4m8p8eQuh4wTaUK5vOz8sNnSjf5POxLVzncfFloytBE,607 +torch/include/torch/csrc/autograd/utils/wrap_outputs.h,sha256=udjY-AHJ5_Blkp1JjWewJiCgJqx3SnJJWxYAR2zvY3U,3892 +torch/include/torch/csrc/autograd/variable.h,sha256=csPLFyRLovS2YBznXsfemSjkB1-b_jKmDjShsniVN9w,41297 +torch/include/torch/csrc/autograd/variable_info.h,sha256=dhwpuE7XE53h9MGXvR7KuNQi26AbPeByXb9kXQISdC4,631 +torch/include/torch/csrc/copy_utils.h,sha256=pBBkjwDDzr0RewFfbQgg5T0-6FTE4rb3_c3Vmnkza5I,1472 +torch/include/torch/csrc/cpu/Module.h,sha256=jbTvupqsG3uL3KDAZBLu4I6A4dBFMFksavuz7qSuuA4,147 +torch/include/torch/csrc/cuda/CUDAPluggableAllocator.h,sha256=VaTW5rLxLT2M-P0s5UBDxlRbrslYiYgYQ_1aypgz3pQ,7260 +torch/include/torch/csrc/cuda/Event.h,sha256=aMb8Wc0rYHXnECz6v0aQDok1bZGVSJ5-F5DDpSL0DBU,452 +torch/include/torch/csrc/cuda/GdsFile.h,sha256=CxtufewBQZxLrn0tDtBmbMsax0pLl3CiNVPXL6JPn5M,201 +torch/include/torch/csrc/cuda/Module.h,sha256=EV6U4_0L3f0lRPF8ONY75vssAWdO2Mliz67fh7_I9WA,496 +torch/include/torch/csrc/cuda/Stream.h,sha256=TEh5ybU34VDMzc1F4thXe6qTAGqctH9h5JM4PIJxBMU,524 +torch/include/torch/csrc/cuda/THCP.h,sha256=RQDpnKGGy0jAvI7J1yFplhf0QUlDDgvpJtVeMbZrE7E,231 +torch/include/torch/csrc/cuda/comm.h,sha256=3aMxp07JvvCeUQ1CGwOAY6oZUFZaLEzV4Vxt5dtiLo4,1566 +torch/include/torch/csrc/cuda/device_set.h,sha256=Yyw5e5j2ukJaEUgyh904RfmrHDoU5wqCercO5lcysjc,196 +torch/include/torch/csrc/cuda/memory_snapshot.h,sha256=2c5pDVUDOT4kqhMI_Vs1yL0OT2PAaj_PgYY5fhKEIqg,1014 +torch/include/torch/csrc/cuda/nccl.h,sha256=o6JGJc0LyYPYBA-oSXGXRPG5bXM5QgpNGMj5RuRG2oc,6090 +torch/include/torch/csrc/cuda/python_comm.h,sha256=rTmmYE_D4vtOx0bHH5U_hiogUJsnEfJL6TR8LqKow4Y,179 +torch/include/torch/csrc/cuda/python_nccl.h,sha256=5fIyxj552ZTtSeNK7q2dQjlnrphEsBORyALmzkRxRGo,695 +torch/include/torch/csrc/cuda/utils.h,sha256=eLzPmPc-H33WA2WsSiYuIe9H1-qgPKXWuvdwCIKH2aU,224 +torch/include/torch/csrc/distributed/autograd/autograd.h,sha256=8eMohrXv7gpPN0ny_owVMwdZdAxg8u79l5-jEORecCQ,1672 +torch/include/torch/csrc/distributed/autograd/context/container.h,sha256=h6OxlArf2srgN5K7Mm4gFh_Yi1Q8HWdyNsceuHa1K-U,6546 
+torch/include/torch/csrc/distributed/autograd/context/context.h,sha256=fRYq9csinqEc8bG1c0_nANK1oBygw6FVLPvBopUd7pQ,6785 +torch/include/torch/csrc/distributed/autograd/engine/dist_engine.h,sha256=phsyfulRrCH2EiIOQS1mVGzYsaCfyMjhRvzzihMCTB4,7590 +torch/include/torch/csrc/distributed/autograd/functions/recvrpc_backward.h,sha256=eUbAkHAt1dgWgFRx8z1WWWxFsfYciGOqgInRhYgLjF0,1707 +torch/include/torch/csrc/distributed/autograd/functions/sendrpc_backward.h,sha256=6F9XfEg5EDH8Y9i0DkneB6p2aIb82A34llcr8H25TIA,1356 +torch/include/torch/csrc/distributed/autograd/python_autograd.h,sha256=SJ9VGYOml3FEy2ZwS-NEHeBMiAupmAnx6NGleFIr594,183 +torch/include/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h,sha256=LOywcaxn0TS2Ar0eQZv6cpfQBUNhYb2H7b_V2e9bTCA,721 +torch/include/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h,sha256=O-fKhg3HjwvTSn4vCRc5xHLWEXWw8GhV_BQBftTu0Fo,861 +torch/include/torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h,sha256=pwnUH1latd-VHlw4RhwNv0JJwMLZnZ-1fXVX0PC8z5g,680 +torch/include/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h,sha256=C4kC-BYc2BDIHoJpHZ_sIItoHq5SCjf8I3a7Zs6geGc,1286 +torch/include/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h,sha256=HJ5LjVLxuxKYkamVpyyroy13uBbpCYUiHTOmfTF4FLg,774 +torch/include/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h,sha256=YyPuEk3MOmtBd51YhPdRD0s41Jqw__UjbwVOhCiN04w,3600 +torch/include/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_req.h,sha256=NgxgJO4eLiPX-nuweNsEWJdNpO2tWgnFacRnUlACURM,2571 +torch/include/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.h,sha256=9QOWj3K8GOOUiLyoEBaZsCo6zbGJzOaMibevcGy06lI,2528 +torch/include/torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h,sha256=3V-7R4PDVOyhqLMd_2qrUOLA7K4cfpINu6gQuGSEuXs,1228 +torch/include/torch/csrc/distributed/autograd/rpc_messages/rref_backward_resp.h,sha256=CBrsyzpqFZ2ftg3E7Zyt1NHLJyKrTZEYmkzafpyjFKQ,525 +torch/include/torch/csrc/distributed/autograd/utils.h,sha256=3KPV31Q_DCt0-A0ETgiYMOC18LmXbYhz19bPzPAg1PU,2703 +torch/include/torch/csrc/distributed/c10d/Backend.hpp,sha256=QIDEEZkCW0Xj8FjjuBZDfww8s4SEYIQhv0U8G3ABhkU,15233 +torch/include/torch/csrc/distributed/c10d/Backoff.hpp,sha256=2bBiso7arafyzx-khiAXF0N3TdgV_F2SbAtKSAvFzmg,1103 +torch/include/torch/csrc/distributed/c10d/FakeProcessGroup.hpp,sha256=5T9JpomkQq552HWEBV9sWanFR35gmK0OnKmYM0TepHI,6810 +torch/include/torch/csrc/distributed/c10d/FileStore.hpp,sha256=0KA3DOoLojHaiIcfsw7DTf462O-AX3rpXQGhqAQSucU,1604 +torch/include/torch/csrc/distributed/c10d/FlightRecorder.hpp,sha256=9yNohYkn7KPMmHwMqbJMQ3E6CPm4fPHOaUYk0HJ56ZE,10347 +torch/include/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp,sha256=RK5UTNSCVqhi4XkUL4u4UPMA8JOejQvhm8qAmwUAN-U,18680 +torch/include/torch/csrc/distributed/c10d/Functional.hpp,sha256=dbmYPuW47p4fsgxUn8uwAJO09qjWrAHLh5TZQ8i2m7w,73 +torch/include/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp,sha256=evZt6fzwGn3Zo5oHXidOhjEzh-UUmR-KolCFkPKTVC0,875 +torch/include/torch/csrc/distributed/c10d/GroupRegistry.hpp,sha256=d68PbqmRNkzKUji9Inzp-4gRofDuC8j12z5e2sc_uwQ,591 +torch/include/torch/csrc/distributed/c10d/HashStore.hpp,sha256=JkPb79f0d23DnWUQHxOebnPCPRe2iTMSsY7EHq8sH70,2263 +torch/include/torch/csrc/distributed/c10d/NCCLUtils.hpp,sha256=QM9WejAan2KqljnypKIrN9Gw9jINm9-kX0bRqfScNiU,16519 +torch/include/torch/csrc/distributed/c10d/NanCheck.hpp,sha256=hnw1fr2PjtWVnorse_5LZwvZQ2AlYh9ucuvrhjQPBBw,344 
+torch/include/torch/csrc/distributed/c10d/ParamCommsUtils.hpp,sha256=h45mB4kv81PQZbh_ekCZC0NqoNR-xovd6gkrJTRBJEc,8933 +torch/include/torch/csrc/distributed/c10d/PrefixStore.hpp,sha256=Yo03dWyYuKfHwVALQkGTttz4tb1NbB2RLiWCHBUFWUo,2297 +torch/include/torch/csrc/distributed/c10d/ProcessGroup.hpp,sha256=UoAMx4JAQrlCVSxYJpbg-I4N6ISxTVMRBIHk5sF82vY,36285 +torch/include/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp,sha256=-IeNgGZSvmUKmMPMiqJ6BVL8gyTEYbvTeM1DXnIztqw,16373 +torch/include/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp,sha256=0XP2a3JNU1Y1CO2ciHibsj89st76Tp_KX7Ehwh3OiW4,21793 +torch/include/torch/csrc/distributed/c10d/ProcessGroupMPI.hpp,sha256=xSA8FYLRP1JXvhjv1yCeWHNQ6yl9hnVhqvHgEWp5nic,9047 +torch/include/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp,sha256=jX5vZP41LbH_vKgqRMnSFm7g9D_vT_hKLShKkqxPmhQ,55947 +torch/include/torch/csrc/distributed/c10d/ProcessGroupUCC.hpp,sha256=D6xjgH-8AaltWLeMCkgaiOhow4yyzIqcY0vR1J3uyQc,11622 +torch/include/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp,sha256=Y5xA54sYnDFe0tHwtIkETJNHbEK5ywvPGt7SMJpOI5w,5217 +torch/include/torch/csrc/distributed/c10d/PyProcessGroup.hpp,sha256=_6_0SrDOetrTj-dirxiYdqiArPnbA4E3ecY6TwtZj3c,10370 +torch/include/torch/csrc/distributed/c10d/RankLocal.hpp,sha256=_vXAoWTPQM0caOkBbMmdL0QHrlCuMz4u0JSy8PkW9oY,2360 +torch/include/torch/csrc/distributed/c10d/Store.hpp,sha256=jMEY-T1gWDvQF3-Gduc8qMsh82EMLi9n8mE9WFsdXR4,4453 +torch/include/torch/csrc/distributed/c10d/TCPStore.hpp,sha256=iE64OY4JDdYkzTHH2Q1GdlVCtGbJWFsz8uQ2q3OfEZs,5223 +torch/include/torch/csrc/distributed/c10d/TCPStoreBackend.hpp,sha256=961NTaT89aH7wZP6R5ahTW_DDLnCIYNGUiIVgEftnhg,1644 +torch/include/torch/csrc/distributed/c10d/TraceUtils.h,sha256=a2nRt1fZKRTabyfR_YDlfeU-2-COVRGtDHXjKTj4PoA,9405 +torch/include/torch/csrc/distributed/c10d/Types.hpp,sha256=eqY9HdZC94ZQUrSoe0MGpPDv9HDjOEDTFdP-COmhO6M,5199 +torch/include/torch/csrc/distributed/c10d/UCCTracing.hpp,sha256=seQKpba3pMPumoYGBRw8tYG0t6RsH5nKboSjKtdOII0,2359 +torch/include/torch/csrc/distributed/c10d/UCCUtils.hpp,sha256=ffqb3c9uAjEjedeuVkZRGZOMT7Si6yMjwDLlgt5eSVc,6547 +torch/include/torch/csrc/distributed/c10d/UnixSockUtils.hpp,sha256=lis-Iqe_djdBbahb01JINpHRtzDW0dfbd5YKK2zYwEM,575 +torch/include/torch/csrc/distributed/c10d/Utils.hpp,sha256=WecTsWUYFiQbByZwHvyLRIoZ0brp0JUa8k7B0q1jU6E,24554 +torch/include/torch/csrc/distributed/c10d/WinSockUtils.hpp,sha256=n4sMpcezCTCeAXynNtZWwx3MaOFCtYLBssj2BKhLXpQ,566 +torch/include/torch/csrc/distributed/c10d/Work.hpp,sha256=dbQ6TfEGW3MwFVDdy3sEwXBXl0YDABEkdU5xX5malfM,5690 +torch/include/torch/csrc/distributed/c10d/c10d.h,sha256=BYk8QfECnHp82Mq6sHAkaH7EYA2pWhroNrS9IcbZJmc,175 +torch/include/torch/csrc/distributed/c10d/comm.hpp,sha256=wsAGB0FCMY7in_-1vc4mZ0lanBxVc7XuKYv3bom7SIE,4565 +torch/include/torch/csrc/distributed/c10d/control_collectives/ControlCollectives.hpp,sha256=jXcHjvT8im48NCN7Y2T6aDXpRnMcfpOjCJb_q_3OsBk,1764 +torch/include/torch/csrc/distributed/c10d/control_collectives/StoreCollectives.hpp,sha256=kJmrAP8g2XLrDklp8bQbgCdiP2wAYolbP8Gxk_-SFO8,2011 +torch/include/torch/csrc/distributed/c10d/control_plane/Handlers.hpp,sha256=e26N1BWaRPbm1aPDPO2HZoIrvEKViOrBB4ZqeYqsoO8,2213 +torch/include/torch/csrc/distributed/c10d/control_plane/WorkerServer.hpp,sha256=7fRd517SkPntUDVH3DtW3NwqJCiZvcyeqEQhRLvZdko,624 +torch/include/torch/csrc/distributed/c10d/cuda/utils.hpp,sha256=iI43KT3i-IxYCXBdy1rH8lFYkLXT5CZZd4LMl_sFeMg,240 +torch/include/torch/csrc/distributed/c10d/debug.h,sha256=uojfbUpL-GFGL5HO_bFvMTB8JpeVgpRL8U1JHKUkt8E,627 
+torch/include/torch/csrc/distributed/c10d/default_comm_hooks.hpp,sha256=txBto9SYRccxVfZDxEb-7OLwUdSeS1ylQj_g9xJkWOA,1797 +torch/include/torch/csrc/distributed/c10d/error.h,sha256=G_zLJQtkz2E0IBank46M-5oGheyw0AgOXtPgNXLCXz4,1390 +torch/include/torch/csrc/distributed/c10d/exception.h,sha256=vkm2TiFG5esOl2iXT7pInXJ_SpnqVnDg9Wb_F2eFeOQ,1365 +torch/include/torch/csrc/distributed/c10d/logger.hpp,sha256=ccX_uJ7Ee1UPtqLBo6d73BkbVxolJ1P5zr43DBtUtzI,6592 +torch/include/torch/csrc/distributed/c10d/logging.h,sha256=b0ZZ8Rj96NUThQaAIYr6s1dM9DxMV04QP715MostIaQ,1877 +torch/include/torch/csrc/distributed/c10d/python_comm_hook.h,sha256=3oee1thi2jWjRaO16qAu9uSx-edXcP8npSMsqFQtCug,1107 +torch/include/torch/csrc/distributed/c10d/quantization/quantization.h,sha256=WNqEC2TYR_Cvguo0QnNPr5t3YERSyluAVR8DrTvA_Wg,470 +torch/include/torch/csrc/distributed/c10d/quantization/quantization_gpu.h,sha256=Nky_P_Me-WGnvocia4fajuMCU63dkMlaZdLDsHIS7Kg,472 +torch/include/torch/csrc/distributed/c10d/quantization/quantization_utils.h,sha256=BfHEROlOBpKUgJW8YvQQaMzeYjRvawG4TIG16Mv-61g,1290 +torch/include/torch/csrc/distributed/c10d/reducer.hpp,sha256=u1s2eJJjCxW5oCwRhvRXqIq5p62iZTo7AVk8duhLMwU,26684 +torch/include/torch/csrc/distributed/c10d/reducer_timer.hpp,sha256=3ZMaL_ZrOYkmlQjyd2oJijrKJlnZG_zQI7dvMvAsa58,2465 +torch/include/torch/csrc/distributed/c10d/sequence_num.hpp,sha256=uLSofH2HajWKZKWoFbntNuuG0P54YE7upTTjVeJyb7o,1770 +torch/include/torch/csrc/distributed/c10d/socket.h,sha256=DQC8vb14SVNznumqWw5LlCXIui1jDTNUu7FaO-0O-4c,2565 +torch/include/torch/csrc/distributed/c10d/socket_fmt.h,sha256=hyhMosJNwTUEKdP8FZEDowWRw9WRjNtHQUqySfWTExY,732 +torch/include/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory-inl.h,sha256=Xq21zJ2MKTpUTFwJQcmAYnPp96duAdekV_MWqTHIpMo,13341 +torch/include/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp,sha256=hDNFcQtL5ouiTzIjslhNiHCpqy6Dd7VpDXTCtyHgMaA,3725 +torch/include/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp,sha256=ab4v0UX15gUPuPdHdN89fmYknpHABexyMs4asQXPyVw,378 +torch/include/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp,sha256=uRJJCDZek9SIr8UBq5htdVjTi_Qnk9NBpKb64BP89Rs,3081 +torch/include/torch/csrc/distributed/c10d/symm_mem/DMAConnectivity.hpp,sha256=g0MM3zaVqxPxDArrtmp_CexKFkucY1Ik6lwm7bO8k-I,1315 +torch/include/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp,sha256=YllIM3SlbkNxJKEPEX7Rop0yLnHsQSqd1KRB61qlV0k,7530 +torch/include/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp,sha256=fBm9SWqnMTrA6JgPmuVKKtxFkg7mdGJAQTBx1o78P3w,2416 +torch/include/torch/csrc/distributed/rpc/agent_utils.h,sha256=qrvgzcHuQ22dKeVtQ2Gci-6gSlhABZbyUDQG1GLTekU,1671 +torch/include/torch/csrc/distributed/rpc/message.h,sha256=9MUqhNSeqEoAR55BS-HWLuUtQQiutQFmwTzTnPqL45A,7826 +torch/include/torch/csrc/distributed/rpc/metrics/RpcMetricsHandler.h,sha256=p9T1aMFMsBOdtM_LsQz2A3_4VeNVnda2j7z_1c9DGug,1608 +torch/include/torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h,sha256=2rbLYNDB1jLb-et0YMl-Gumk9N2v9ieGKz6FK5nIeJo,2281 +torch/include/torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h,sha256=PGqHa2qYCvNs9KEFO--vmg-xUHRnyg8Aun3fXkbb-Ow,4528 +torch/include/torch/csrc/distributed/rpc/py_rref.h,sha256=p27_kd0nj8SPjvAB7wavVTgeKCNbRKSrM637WKV7DiI,3046 +torch/include/torch/csrc/distributed/rpc/python_call.h,sha256=Km8Onke3JkXd9HJ7ig_Ka6xJIYnWG1Pd6dMBZkdPA9w,840 +torch/include/torch/csrc/distributed/rpc/python_functions.h,sha256=l8bSvyAmMkQMjUi6alxpoAY8zQE_Dv4LQifgguVSJek,2323 
+torch/include/torch/csrc/distributed/rpc/python_remote_call.h,sha256=vnSwA3ai2j6tILZtMfILEbNISdWgj1o7pgObaoCq44Q,1374 +torch/include/torch/csrc/distributed/rpc/python_resp.h,sha256=ojnguIPzRZVY5AAJW7CpPmRq8ymcc7AIODDcBqWooOQ,644 +torch/include/torch/csrc/distributed/rpc/python_rpc_handler.h,sha256=88tmtQ9ddnd75hef5BElK1Wgv5UHzno3gdLDX0WPwyI,5083 +torch/include/torch/csrc/distributed/rpc/request_callback.h,sha256=tqOzvaJ3v_TNXf5Tj-I78Jdt6XPgWHjnVqdzLnejLuw,1258 +torch/include/torch/csrc/distributed/rpc/request_callback_impl.h,sha256=gi92W0qZJY3ZantZaQWCCbqAVaQNUgC0JPes4U5MfLQ,2143 +torch/include/torch/csrc/distributed/rpc/request_callback_no_python.h,sha256=KnM5WGTt9IsSAmRtp245auwa8wV2dG0sEBYsrndoWCg,4019 +torch/include/torch/csrc/distributed/rpc/rpc.h,sha256=Gw7c_2sprUeX5kBDuYhOVHEm9HJTYPYFi1i8U3fJWAY,173 +torch/include/torch/csrc/distributed/rpc/rpc_agent.h,sha256=B14K8_tWCHaKsDidL4UIp0dPAV_ljYLZATKjEOp32wU,13936 +torch/include/torch/csrc/distributed/rpc/rpc_command_base.h,sha256=w_6S8v2U9vkFHSOtUVnKJqi9RCgf3HtUHqa1jb0PxaQ,701 +torch/include/torch/csrc/distributed/rpc/rref_context.h,sha256=FSH_ffYRXpkPuaGkSHFWcEkU00k86GQTKKmzoCU136U,16104 +torch/include/torch/csrc/distributed/rpc/rref_impl.h,sha256=Xc0PoXVTkQJBWatsPxa2ltCc-H99NVvdHDSuJTzBl14,16874 +torch/include/torch/csrc/distributed/rpc/rref_proto.h,sha256=e6_V0vxYC_J64iAuFUX-TO3a9d2RgFwRWcnAiNVZHz8,5571 +torch/include/torch/csrc/distributed/rpc/script_call.h,sha256=zA1CxolnLoqbARQZ5hyPBGWCx6-67TDH1XyPV3KIy00,2604 +torch/include/torch/csrc/distributed/rpc/script_remote_call.h,sha256=qW0dWvwnzfP6GvDh1YDXf0YmLSVjS_uiXChrgT4TBz8,1798 +torch/include/torch/csrc/distributed/rpc/script_resp.h,sha256=SbNhjefmzekaZFW9K-CRsKXY3tJ_fIGZ5t7hm6_a5Hc,722 +torch/include/torch/csrc/distributed/rpc/tensorpipe_agent.h,sha256=wh43ClhKnlqahbbSsQhlYfDeJSbwsEyVfpchLZAKki8,18150 +torch/include/torch/csrc/distributed/rpc/tensorpipe_utils.h,sha256=KpL_krReoyjIlhEwTE-lPmNtp4LaESfKDn4moQzmsZI,4842 +torch/include/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.h,sha256=ONTxwhtEZt9WK3BEJww3iadVv93ZpN45tPyxeylvx5A,3890 +torch/include/torch/csrc/distributed/rpc/testing/testing.h,sha256=S1Djh22fxeNZ0wffOFA9dFSj56A1alBdSWW8HDMBk5Y,191 +torch/include/torch/csrc/distributed/rpc/torchscript_functions.h,sha256=-wGnemIOTdGbbBq7LZNqwQ3q9b3YkZJQ_5DhoBpBCBs,1693 +torch/include/torch/csrc/distributed/rpc/types.h,sha256=HFeYONFXX_uafJe_wrQmHxfiGMyZVXVoJ-edtdZWbus,2269 +torch/include/torch/csrc/distributed/rpc/unpickled_python_call.h,sha256=6ryczccXY4lxkym51s4QxHO9IC7v34KXanqFaOKDN4A,1426 +torch/include/torch/csrc/distributed/rpc/unpickled_python_remote_call.h,sha256=t1QDHI_meY0GcrbRL-Zkq9Ji016l4AUShTzQizQRI8E,1240 +torch/include/torch/csrc/distributed/rpc/utils.h,sha256=5--QvzmrtHoIs0FbiKZky-z_n8H_wQdDLyWKQMO4NPg,3876 +torch/include/torch/csrc/dynamo/cache_entry.h,sha256=ddKGf90W9jqUhwbFtUFMXQHETRdaYQqIIZIZBDw8uHQ,2911 +torch/include/torch/csrc/dynamo/compiled_autograd.h,sha256=sZI3DHCcv_h8ZwiDNATPdHKeSfCUGnTyEuAGrPJXQUE,51912 +torch/include/torch/csrc/dynamo/cpp_shim.h,sha256=y53NVkpBYkn6BuNe9UR1B_9RvzFnMNdmxb71hkECDsE,372 +torch/include/torch/csrc/dynamo/cpython_defs.h,sha256=2UVUrNUBq4iFbZh-uidp6P84F6YxSXwuqGc_Am2boLU,990 +torch/include/torch/csrc/dynamo/cpython_includes.h,sha256=axbTYiArJ86y3tRM7vnkOJRwL14c_fR7W_S7MfqEHXI,1063 +torch/include/torch/csrc/dynamo/debug_macros.h,sha256=UW_vMc4MUh7zpmWRmTLRdaD-vLubIeao4F6POUOogMI,3630 +torch/include/torch/csrc/dynamo/eval_frame.h,sha256=UFSXnLKf8Mw72FWknwYlxjH55ltpvixbkaNhcOnKddQ,1547 
+torch/include/torch/csrc/dynamo/eval_frame_cpp.h,sha256=SYvnkPVmR-se3wQQrrt3hKp8dox42VwkHKriPys4LgM,496 +torch/include/torch/csrc/dynamo/extra_state.h,sha256=cjkQ_EawPD4IWY6RORC5cjixem6Y3iApdY3bEer8pus,6525 +torch/include/torch/csrc/dynamo/framelocals_mapping.h,sha256=WNx6nk0C1aFGIIuH_epdAEsrWVNsFDVbQo7vFwP5RdU,2873 +torch/include/torch/csrc/dynamo/guards.h,sha256=D-__v1Ytia6Xb--q5CYt0TElNqN0F3ldQDO5sKjwoLU,3105 +torch/include/torch/csrc/dynamo/init.h,sha256=3r8aB_SJQwAFxd1lI4Qfg_-D0I5WJa29syXdYrCoZak,198 +torch/include/torch/csrc/dynamo/python_compiled_autograd.h,sha256=vwBYmBcehKPEbA2fk9h-WRYxhP46XUTwCqPyc3HB14Y,222 +torch/include/torch/csrc/dynamo/utils.h,sha256=NOBl110ylKh_W0rYoDgw_nuvytcc8ENB_DBYBWwD2vM,536 +torch/include/torch/csrc/export/pt2_archive_constants.h,sha256=hijkK6iB9rexFZq4ywuI6uvyzFdkk9fTm6VqBfkMZU0,4334 +torch/include/torch/csrc/export/pybind.h,sha256=D1GJutOFb8pxET3IYLDAY06Cq7pVckzqArgEoXiZm-I,149 +torch/include/torch/csrc/functorch/init.h,sha256=V60RPDzxfSj72L41Se-rGJy3cHH97pMfwaKfwnrUesY,113 +torch/include/torch/csrc/fx/node.h,sha256=6EsyTiBcbCqAKpKjRNc52Yb0Oy35GvLOIALiVGCVO_w,136 +torch/include/torch/csrc/inductor/aoti_eager/kernel_holder.h,sha256=ULQjND0ZRRJInZM84I-Eb38KzyTYIM9k2XDBMXEh3nk,4225 +torch/include/torch/csrc/inductor/aoti_eager/kernel_meta_info.h,sha256=b2mrWbYi61w8pivbhZYQlLV5k0OvQjsm2o9WlX0azqg,5854 +torch/include/torch/csrc/inductor/aoti_include/array_ref.h,sha256=31kfZDdXzC1BJPG8mJvdi0YkcFxgw75Wu9vIW5FeZg8,308 +torch/include/torch/csrc/inductor/aoti_include/common.h,sha256=alC3tn_H3EMHbWLbg5b7xmeJrqfU74DJLG2Imf-pX6M,432 +torch/include/torch/csrc/inductor/aoti_include/cpu.h,sha256=4wmBuRxjCcvHrZcRM6H6pso96Q_MWYehwK-A_ePY510,136 +torch/include/torch/csrc/inductor/aoti_include/cuda.h,sha256=AvG6S6EgF33UtlWz1-EE62Ew5YCHd9u-n5_CbnSgtb4,137 +torch/include/torch/csrc/inductor/aoti_include/mps.h,sha256=DoGl_v4dop0u0CB5BVfWht8AGuk6ljviHWXLi2h2u-A,136 +torch/include/torch/csrc/inductor/aoti_include/xpu.h,sha256=HC43-lWmBN-pTLfkYgkDBbpD7jn87VBSx30lyCr_QOw,136 +torch/include/torch/csrc/inductor/aoti_package/model_package_loader.h,sha256=ZWQXeM49BCR2zIBg_egjaSesbGXhwP8wLxJpwNm_2nQ,1695 +torch/include/torch/csrc/inductor/aoti_package/pybind.h,sha256=RAIX3Q47416LRGoxepGQE4OMq5e9aHuYvxOptIatlb8,156 +torch/include/torch/csrc/inductor/aoti_runner/model_container_runner.h,sha256=ymoDvobrtIpD-K-6KNX-9oYTapqm58dB9Q_5rnq10eE,5303 +torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h,sha256=jlb71c48i4UD56lT7UOuaR3KXjcLjLrGtMPGVjYI1ns,497 +torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h,sha256=zAutzOBOlLJDgKQK_ViqVQIrpmm4bqU-5sU4dKdl22s,1166 +torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_mps.h,sha256=SyHFzYvRiSeI5yWaMcvjilqZac6Mg3ec5cbTo6c9CVg,474 +torch/include/torch/csrc/inductor/aoti_runner/model_container_runner_xpu.h,sha256=xPyGaO0GCScys5yt3xykRCN4K3CErVPw40k5n_XJndk,1253 +torch/include/torch/csrc/inductor/aoti_runner/pybind.h,sha256=uI5Ys7HHCTfmP8doYtdHWgKkB3k4pG399HuFsU080zY,155 +torch/include/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h,sha256=b-LkwWn848zELwy_pd19kWdz-FnsVmCx4Vaox30bkoU,6498 +torch/include/torch/csrc/inductor/aoti_runtime/constant_type.h,sha256=xmukWWFceuYJN1axSJcScLy2_eJp6m40NmDvYjKM9xE,547 +torch/include/torch/csrc/inductor/aoti_runtime/device_utils.h,sha256=6_H0_ha-i8gn9uozsRBINIEnYPzdK74oEyd46H85PIw,2387 +torch/include/torch/csrc/inductor/aoti_runtime/interface.h,sha256=EMDWljE609adwZ4HBrJbCEuEcKtPCoCeEjQaR0JWRLc,9912 
+torch/include/torch/csrc/inductor/aoti_runtime/mini_array_ref.h,sha256=lmh2GTSWWUzh1mSDWj7L0-gUhhY_iVQ6hTUhLzyeeko,4886 +torch/include/torch/csrc/inductor/aoti_runtime/model.h,sha256=TIEu2Qx9ap4_uC1fiwFD3d-D3ogUs4gEBRpXeFe1qcA,24699 +torch/include/torch/csrc/inductor/aoti_runtime/model_container.h,sha256=LWrpOxozaYgu84LioXKm5us2XqbS1ldnCAteMAoKcCs,28037 +torch/include/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h,sha256=4SGOMh3z5CUQ8LR0rKWDSHSEQax1K8leNAcKTQnqhR8,1628 +torch/include/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h,sha256=-TqI44NkqbLO1rZINqtybBtSA0hDT2dI3g_1v5bzjCA,6136 +torch/include/torch/csrc/inductor/aoti_runtime/thread_local.h,sha256=Gmf_Q0GrMZz_z1y3C0m_xCToXZq0wG_tVxUD9BUGhSc,4512 +torch/include/torch/csrc/inductor/aoti_runtime/utils.h,sha256=9lh3yBf7BGg7LyHhVNbRF0o1ClLzuzGYBAF82sb_pQk,10916 +torch/include/torch/csrc/inductor/aoti_runtime/utils_cuda.h,sha256=oGfqWlUzdYGW7ohfG-fx7pA7-SDFWUiXkGZJtUR9VrM,1888 +torch/include/torch/csrc/inductor/aoti_runtime/utils_xpu.h,sha256=Nx6KSTwx4GrkLjoZ9wRGMx2zwtMEcL7EFT3p9EZP8v0,1771 +torch/include/torch/csrc/inductor/aoti_torch/c/shim.h,sha256=fyNAAlEvAhfiu9W-MY-utQ0ext6EdjxfuU0vIKEsfF8,30291 +torch/include/torch/csrc/inductor/aoti_torch/c/shim_cpu.h,sha256=t5MZCIOz-oD_bTR5KJlrZ6zgQlYv7ib1t2qBFOF5_aY,7230 +torch/include/torch/csrc/inductor/aoti_torch/c/shim_mps.h,sha256=vQ0dU4cfM0FvS3VNxX9XukjVqSsd56Z0-darKcsvztk,996 +torch/include/torch/csrc/inductor/aoti_torch/c/shim_xpu.h,sha256=pQhG-GZwiSbCiuBQdHocQ7fmQyxewvlIgcnq1r2R4uI,3217 +torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h,sha256=2hkQGeYovQs5OhkyOwlirGQ75sjA8WJZn74AjL4-VHA,30693 +torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h,sha256=NxCau-2X80ZVJqETUINsk-q_pdUB2CYYHUY_49Cz-lA,35314 +torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h,sha256=Y5xUoPstZHmOr1WDH0L13deRnDYlZGCyr3R1t5jAGrU,20529 +torch/include/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h,sha256=0CJtkuTiZZQvUbSfiVj-zzcrGrz54s33Ae9r_WK7cQg,10976 +torch/include/torch/csrc/inductor/aoti_torch/mkldnn_tensor.h,sha256=CqIk59pjogj-3How6KSgkmzrVzV1rpmet5iNa6HykxY,388 +torch/include/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h,sha256=3n8tBFWUFY0up0_MIhrJU5WETPryv_6-qeKOdN7eTuY,4792 +torch/include/torch/csrc/inductor/aoti_torch/proxy_executor.h,sha256=8zEU-J3-OebP1YFXiK1Cg9bMuiQiSbwguPR6q9qKJ14,894 +torch/include/torch/csrc/inductor/aoti_torch/tensor_converter.h,sha256=71yBXDDFWsVuReMXXQOW3CAidygIUUMsvMbFpUI9TWM,979 +torch/include/torch/csrc/inductor/aoti_torch/utils.h,sha256=RDyVR9VLTfOu8MNMpmQJ_ElAurQ1dC7YEWdduPPXdTI,7526 +torch/include/torch/csrc/inductor/array_ref_impl.h,sha256=LOKCcD1qM67dsufjnCYnTP2_QdyvhVqlNtgciflsMMc,3079 +torch/include/torch/csrc/inductor/cpp_prefix.h,sha256=DL-8FilY1sfEawkFCUcudconVvK9d30RmlBEtaZTmts,36787 +torch/include/torch/csrc/inductor/cpp_wrapper/array_ref.h,sha256=2NKc9K3-z_ns6-KBkhzW2vos2WgLZQf8sbG1UJ_T9Z0,307 +torch/include/torch/csrc/inductor/cpp_wrapper/common.h,sha256=S01_qo6GDsOPqA3nv9IkPb2189dHT5ILEFT8IAZjrhI,2037 +torch/include/torch/csrc/inductor/cpp_wrapper/cpu.h,sha256=FKENWHdjm4yd-PP0qR2XLABsDq2JoEEaP6vac9HIjkk,135 +torch/include/torch/csrc/inductor/cpp_wrapper/cuda.h,sha256=LkbXsnXwwpEXx4ly4a_DTDK68eKz3cb08O-6czmzj_Y,136 +torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cpu.h,sha256=UDvy9ybp2EY_PA1L9gdUvbVRsYYSZnDkX-DPoklRA58,82 +torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/cuda.h,sha256=VpQwdE1r0OsGXJnomc0XPBLLGv-ZUgeElkUv0JHjw-8,141 
+torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/mps.h,sha256=9xisORQ_W5drKQGBe4w8GqZj2LwZxt38FPlipE9AAnU,138 +torch/include/torch/csrc/inductor/cpp_wrapper/device_internal/xpu.h,sha256=0M79DKn_nYsb-MIp-dvvwEKaSFoYp8hFo_4rDfHX09g,208 +torch/include/torch/csrc/inductor/cpp_wrapper/mps.h,sha256=B3lw3Tz6UXaFhtTUF0Q9n9gBVJ1DQA8N4yNMBjsQHcM,135 +torch/include/torch/csrc/inductor/cpp_wrapper/xpu.h,sha256=B5aN5w26B2u5JM-ugDNFMD_BCppJWw3kjBmIXRzgIRk,135 +torch/include/torch/csrc/inductor/inductor_ops.h,sha256=URQx1CeVjdwSYPDcv9giF3pd3L3z8IqjFzaCi_Ke9Us,1134 +torch/include/torch/csrc/inductor/static_cuda_launcher.h,sha256=tS5mYQ_Pdas3I3esDmg1qUsksJq_LV2fkWuGBitu0to,225 +torch/include/torch/csrc/instruction_counter/Module.h,sha256=FFMxRKeiugNoAprzQ2mdMGaAkBgU2D8R8SvjdvGk2aw,179 +torch/include/torch/csrc/itt.h,sha256=5-DtqMeRWHUQV9XeZqKukuhVc_fWdHv6cEwG5jOsqug,189 +torch/include/torch/csrc/itt_wrapper.h,sha256=yBuGzG2jfgS6iA9dGICG4F6Lg-ktIjBVnAZwmy6Ryls,332 +torch/include/torch/csrc/jit/api/compilation_unit.h,sha256=Mg87hr5lqJ-MX2XBEGT1qjB8_3vmxFKn6kDyYzCb9fE,12047 +torch/include/torch/csrc/jit/api/function_impl.h,sha256=8Sx8hgcp_xCBRJE5aPs05Sd074VBR8VBafaORTv_lbY,5882 +torch/include/torch/csrc/jit/api/method.h,sha256=uPYzSKariJOQyUUfdUz2BzKUZ2zBJqOSrff1HtUcyXg,2454 +torch/include/torch/csrc/jit/api/module.h,sha256=_Ej9BRoHTEKtACUfH95SdZ_AoMVNJNxpIxwiYoDPFy4,24166 +torch/include/torch/csrc/jit/api/object.h,sha256=bstVFYgLSji58KKsFdf0010mMYz4hJWLMMAgveQhTYU,6283 +torch/include/torch/csrc/jit/backends/backend.h,sha256=pzqQlIxLUSDiSohgBYQYUB58YqJVegbEsGg-yIXF-AI,3946 +torch/include/torch/csrc/jit/backends/backend_debug_handler.h,sha256=3CnyAp-WLs-AdPP4c6klqIVBA0VFcuVCt8roZdi0DcI,6469 +torch/include/torch/csrc/jit/backends/backend_debug_info.h,sha256=1v4JAyyk2BR_6QMVACbXkkIz7NL1yKwyGkJcJxOvPEg,2376 +torch/include/torch/csrc/jit/backends/backend_detail.h,sha256=EOp28wNxYDY6ogwcmZD46JTTNJMdTuuoWgphinR-_44,1118 +torch/include/torch/csrc/jit/backends/backend_exception.h,sha256=9jt1tVkm1yi_yPh7Kk-oM9DUW3iYO5jH9DKthW7j2HE,2172 +torch/include/torch/csrc/jit/backends/backend_init.h,sha256=Y93ezithNdwKk_Ax5nYomiJLExnweTHkqopuL_W3lE8,261 +torch/include/torch/csrc/jit/backends/backend_interface.h,sha256=csuI3iLS_VMmBW88FwjF5BSMConaz8_ybgEUfJiUq-o,1191 +torch/include/torch/csrc/jit/backends/backend_preprocess.h,sha256=KJILp9Du8kYOXp4g8wBzMhGYHrEqtDxOeb1KU59W4Kg,429 +torch/include/torch/csrc/jit/backends/backend_resolver.h,sha256=PRb9UcGEa_43eMh53JrGg4SC8FL-1GPPGTIlyysy7kY,260 +torch/include/torch/csrc/jit/backends/coreml/cpp/context.h,sha256=0dn3_rveGb95tZo4JfCYiKRC4LxCXnn1tn_rqC4djQU,460 +torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLCompiler.h,sha256=6rQrQ0IUoTYieEgQojuZDJWxhvBAyAiz7kLcAVWl1gI,555 +torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLExecutor.h,sha256=5jo0jQuxc8HA4heWmGTVEvAH-pbkKXJ38eXRaLBSbP8,433 +torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLFeatureProvider.h,sha256=GqxT3QgIDRPRIs4X8ZuwmFQ-MDWpjVrWV-o0yv1R7og,366 +torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLModelWrapper.h,sha256=IFnDGeNG8m8MGQ-bhiKej_PaBNPeJP4RAuiN2HvdOpM,963 +torch/include/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h,sha256=dyM-5ogXRnZof7SKHkELEp7jmgnpx_6G02a5NytU-oU,674 +torch/include/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h,sha256=zf2M-hDP5f9u00wqJ5Zi7oGuePbLNiO1hN45PQrT0yY,806 +torch/include/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h,sha256=9PcJCDwZFXzn8gIyanwfl9kEwHC75U32cUHHi9NrkMQ,1635 
+torch/include/torch/csrc/jit/backends/xnnpack/serialization/serializer.h,sha256=Z8sb3aEoKcYqul0pLG8znTy3T0xEZKM4RtHwfINLrXg,2868 +torch/include/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h,sha256=-0u8qqYGHVxU_UfzyftRhZ1a8hh9ziEPRr7EfYkmi6M,3332 +torch/include/torch/csrc/jit/codegen/cuda/interface.h,sha256=6Vij6NpgQ5ONgS_eAteleY2HWsHXBDX0QoBps83Nn6U,1906 +torch/include/torch/csrc/jit/codegen/fuser/arg_spec.h,sha256=-3zHuDHLmbp68pCiKq8hbaasVoIUy7_QAi08Gy64Z-s,1408 +torch/include/torch/csrc/jit/codegen/fuser/codegen.h,sha256=qvNX2oC_sghA5-x9iSnLcUQCNSemSU01WMsPFTfPnSY,769 +torch/include/torch/csrc/jit/codegen/fuser/compiler.h,sha256=gU3Mo96RotF6hhlCQCV0TDYk2QhzsZ9xXkpJPVAXwGM,1890 +torch/include/torch/csrc/jit/codegen/fuser/cpu/fused_kernel.h,sha256=zUjRgDsYE2WPMMLiaiQ1SUF5D3H7tfwm3tuHP_PjKa0,1038 +torch/include/torch/csrc/jit/codegen/fuser/cpu/resource_strings.h,sha256=ADyOe8j-XaZpG6b4S_xDd8CqPjhGqeNT02V56kEP1_8,2379 +torch/include/torch/csrc/jit/codegen/fuser/cpu/temp_file.h,sha256=vCRiFUGLwLBLO61CMPdTEHbNca50rEG8WxFM5HGa5wg,3026 +torch/include/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h,sha256=M53IzF4yskqpUCfpCILKWC5QJTnD9fudo_jI_wvuesc,1570 +torch/include/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h,sha256=YeHR5HJQ8KKVPSb8d7yYsazMxZnm5BkavRSVP03dR_4,11153 +torch/include/torch/csrc/jit/codegen/fuser/executor.h,sha256=R-djq_29ocuOs70tz9_FwKrf90EmwyPyCRtDs-6BEUA,519 +torch/include/torch/csrc/jit/codegen/fuser/fallback.h,sha256=ovTuS1yBVyfL5mkBefwwuCII03z7jBVvScpWORemVqk,185 +torch/include/torch/csrc/jit/codegen/fuser/fused_kernel.h,sha256=SXAMHTqPLuj2JBdGL8ap4kQ8ixBTTetI9ARt1pD_lgs,3403 +torch/include/torch/csrc/jit/codegen/fuser/interface.h,sha256=zvAlUqnWQMydIH52AptfbRP6LirQzxwXuaUQPu3ehSo,1776 +torch/include/torch/csrc/jit/codegen/fuser/kernel_cache.h,sha256=CVjOVccu8HQrogpoi_qQuQ1EM316O4jbDJIjpmVQFL8,1027 +torch/include/torch/csrc/jit/codegen/fuser/kernel_spec.h,sha256=f4W0O_39LmQqSoFdtfong2uuXXw9D1nupcSgzTXb_pQ,4545 +torch/include/torch/csrc/jit/codegen/fuser/partition_desc.h,sha256=X3OoD0YrciQ0UoCfJwmjHyhlIR3Jpbs-F8hRzQLXi_4,1813 +torch/include/torch/csrc/jit/codegen/fuser/tensor_desc.h,sha256=OX7bt-im9Hv32JhHHMKUg8WU9JIBfhw-vXckODY2TJ8,2799 +torch/include/torch/csrc/jit/codegen/fuser/tensor_info.h,sha256=8GuG9pkzbvG6LnXxc_QZBE-SS3wWAAZVLRXuOptfGS8,560 +torch/include/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h,sha256=pPlYlwsNcSnXf6es7tvKiisKcEdv87uqMbXaHvmgyD4,7962 +torch/include/torch/csrc/jit/codegen/onednn/decompose_silu.h,sha256=bGoRYlDNrwTJ0yO0o-TML6cXGENkCvURFvB6c2vZ6mE,197 +torch/include/torch/csrc/jit/codegen/onednn/defer_size_check.h,sha256=6UgJLt_gWlTOLoA-EQmEGM48pJOzbMv3Ur47UoP2CSc,191 +torch/include/torch/csrc/jit/codegen/onednn/graph_fuser.h,sha256=9lmfs_-swItqHXzXGQEI84aIrnMd6ZX_9Uw3kXLdGCU,1263 +torch/include/torch/csrc/jit/codegen/onednn/graph_helper.h,sha256=gMLdtRPBo5miGaR2sRxhGYV-_75tk_33qY2x3Gl2hE8,2562 +torch/include/torch/csrc/jit/codegen/onednn/guard_shape.h,sha256=YrX8-Jy5a7ZhMP_TFg_uyEk1z0ux8b14BnYwqiTbRQc,193 +torch/include/torch/csrc/jit/codegen/onednn/interface.h,sha256=vsFFNskFxmM-29ELH61Tku_IJPjwjOGTFdG5IAwS0ag,1456 +torch/include/torch/csrc/jit/codegen/onednn/kernel.h,sha256=WtGKNf-MnAtHU3VwPTzE6aq_-A_BnTk7i497ZBqLYTA,2783 +torch/include/torch/csrc/jit/codegen/onednn/layout_propagation.h,sha256=KaE7EzolYSPcm030zW4yPVQEcp89zoNAHjqe4ETmzSk,198 +torch/include/torch/csrc/jit/codegen/onednn/operator.h,sha256=Mx8u1ARbUIm-kIfpxEPZZKYu4MdUObRFoCA0qR15yvs,4068 
+torch/include/torch/csrc/jit/codegen/onednn/prepare_binary.h,sha256=AEt60IH84VEuAfksrt3jBUkR06UCxnWTyZ93kcNGcgU,519 +torch/include/torch/csrc/jit/cuda/cuda.h,sha256=ycs5RN-uDdKY2tACKP6_U8vqfr7j1jrqvhWnP8fAae4,5336 +torch/include/torch/csrc/jit/frontend/builtin_functions.h,sha256=4F9sRNoyjnej0qNuG5ORlLQPOggELcjv59-MMhkIyrc,224 +torch/include/torch/csrc/jit/frontend/canonicalize_modified_loop.h,sha256=NdZUmPkzv1bJ3zJAvdxmdHpsBPeywOvrMlNBP2npHEM,301 +torch/include/torch/csrc/jit/frontend/concrete_module_type.h,sha256=Bp9OJ_pI6Cw4OPFTw3eZDY9tSSN23m2L0g5GE4OGykE,9276 +torch/include/torch/csrc/jit/frontend/convert_to_ssa.h,sha256=Oa1Py44TYRPySdHxgo0AH0xe-tWEPzKRDwPQ09j8Wxo,316 +torch/include/torch/csrc/jit/frontend/edit_distance.h,sha256=298poj7MQ1zhKAGnvv3TcctB1F2Bl8hAGzWawug6MgQ,242 +torch/include/torch/csrc/jit/frontend/error_report.h,sha256=_PxkIzXHoAKR8Hb4UnFqBS2ah-zRQwNG4lFJyb2gzRE,1487 +torch/include/torch/csrc/jit/frontend/exit_transforms.h,sha256=fAvbMa66s6Yg6iVe3rPHPw_e99o5ayFT2Gsfgg7_1Ec,203 +torch/include/torch/csrc/jit/frontend/function_schema_parser.h,sha256=0ukRBdlwx5DdLpru5FyXjeQ6Hcfp-rGq5zYSYpf7FAA,827 +torch/include/torch/csrc/jit/frontend/inline_loop_condition.h,sha256=LyjsxdwYcO_cdgJyVQjkbmQX4qoKtScY1h2q9USEUEk,341 +torch/include/torch/csrc/jit/frontend/ir_emitter.h,sha256=628nP0Pil7zEyJOCIxJpV6w89SfhzrdNvT0BWxfIAwE,535 +torch/include/torch/csrc/jit/frontend/lexer.h,sha256=nA_aaVby1RA203w420esK04xjMCFkIJACL05OntA4Zw,19581 +torch/include/torch/csrc/jit/frontend/mini_environment.h,sha256=vArMlMby1-WzFvQ9UEuqZZp_uXb-rEayC506ggWYkN0,1430 +torch/include/torch/csrc/jit/frontend/name_mangler.h,sha256=FTAcvLySQz1XsS9QcjVWX9yEh3G64oObPWPy6UzQwec,655 +torch/include/torch/csrc/jit/frontend/parse_string_literal.h,sha256=s_JAfhA31IRCI6QlseCs0ISgn-_tIto1Et47XBBoLMo,2381 +torch/include/torch/csrc/jit/frontend/parser.h,sha256=Lq5GFcVTgEVsmBJu7TZXam6hF5OLeQJEPAJ3u4ct5rI,680 +torch/include/torch/csrc/jit/frontend/parser_constants.h,sha256=dsNcQqp7ZY6R39856i_HeRT1EOf8umfQDzLy4rLqzEE,157 +torch/include/torch/csrc/jit/frontend/resolver.h,sha256=3xnO5o31Dn2GIb_f0QLxEZysjMpZ_WFTwpekpAanDlk,2023 +torch/include/torch/csrc/jit/frontend/schema_matching.h,sha256=SqihZOutY_LiJfKP9Xt3L_nwY3XyCnXCYq-uenATYmY,2176 +torch/include/torch/csrc/jit/frontend/schema_type_parser.h,sha256=h06RpwHLCihkWfoFwlBc-vYKeHd6IX79vLOmTYHXKws,1245 +torch/include/torch/csrc/jit/frontend/script_type_parser.h,sha256=5ao-exigVQzU5JCQU-HeXJQgsEprLwC1fbe3uEJOcFE,1633 +torch/include/torch/csrc/jit/frontend/source_range.h,sha256=zqyY4yW1BG_wV21RU11Txt55svGZUfmUoxMiz_rM9as,17516 +torch/include/torch/csrc/jit/frontend/source_ref.h,sha256=7FsrsDhxsekba5xhG-doZC_qT2TUZEFrIq4dCQWh5RM,1333 +torch/include/torch/csrc/jit/frontend/strtod.h,sha256=IJyN1or0ERowFppBbzb_7Io3NTXWHZ78fNmpsovW9u4,226 +torch/include/torch/csrc/jit/frontend/sugared_value.h,sha256=6bAuIi7OdmKCIrgk1G8mC2Q42JkHpkorCLoz1BhFSN8,28758 +torch/include/torch/csrc/jit/frontend/tracer.h,sha256=OTyT436rrNEfY4PQQFUJyvTv__DESdNtkKcJwuCJAeU,13203 +torch/include/torch/csrc/jit/frontend/tree.h,sha256=CukHclFOH1oeAMUGXwN8kFKcztXR46qYDH1MT3p0zDA,6817 +torch/include/torch/csrc/jit/frontend/tree_views.h,sha256=lS9VEFXWY_BJa0cC1TN66fVt_j_bNjIY1-77l53eAeU,38480 +torch/include/torch/csrc/jit/frontend/versioned_symbols.h,sha256=OELfn9VjErHmLvUQSyFO7sESgQpjOYUV0LUeMD30KKY,614 +torch/include/torch/csrc/jit/ir/alias_analysis.h,sha256=LkvdHm545e7EF9PPs7TbT5ZyIywJj8LvXjhz4Bfy_hE,14652 +torch/include/torch/csrc/jit/ir/attributes.h,sha256=2MROMiINbtwiOwoUUvxNKJOpCAYx6WAlVMHMRLyq-Sk,4964 
+torch/include/torch/csrc/jit/ir/constants.h,sha256=NdxA50B_cLga8YY2Y4ldvRuUvgof7msaCrCQ8CypKAc,2068 +torch/include/torch/csrc/jit/ir/graph_node_list.h,sha256=bxfn10Gwj7NzEW5s9IOhs3meiVhk8qebUlmuntnJPBY,6553 +torch/include/torch/csrc/jit/ir/graph_utils.h,sha256=EMBJevQAxN7ApNRS_h4eSJrTUyFgfWC5BDqkJaIYADE,527 +torch/include/torch/csrc/jit/ir/ir.h,sha256=2PXA1UJhVat7e1DvPqSOn6-H0iuAe8vZE5xM-hwe84s,55725 +torch/include/torch/csrc/jit/ir/ir_views.h,sha256=CsDfwICxsdr08ypsJkb2KetFG0Un9tpYWKGyMU48ZFc,4785 +torch/include/torch/csrc/jit/ir/irparser.h,sha256=NL96TqWRGOh6JrY7J8TWANF9n02pljay6ILyOY-RrVk,1133 +torch/include/torch/csrc/jit/ir/named_value.h,sha256=vhaJn0MLQ8EUp4AUp8vEn3jMWqaii1vdo9ZJrOYVxtc,2477 +torch/include/torch/csrc/jit/ir/node_hashing.h,sha256=QFjLearfbSj13z3JeMZzPCLtsq4lq8VpCH6kXH4DZXU,280 +torch/include/torch/csrc/jit/ir/scope.h,sha256=Qz2eRjUwyfTD-FhHPiMaJ-O9Z2ZqTQiI5vt9mcz7CF8,7366 +torch/include/torch/csrc/jit/ir/subgraph_matcher.h,sha256=oH74YsHmaHE0UBQN06D9aaXLbds4IB5mmLAO8t8T31M,3198 +torch/include/torch/csrc/jit/ir/type_hashing.h,sha256=vatGNhFjsoXjIdZ54jZTN6DGZRAvSgAOR9jKSwn3ZlU,452 +torch/include/torch/csrc/jit/jit_log.h,sha256=wJ8CP8QRVaGqptVfQq8r5IZEZDVeh3YRXxZ-0-4AqSI,4933 +torch/include/torch/csrc/jit/jit_opt_limit.h,sha256=11752ePq04ymEDhooj5_PD12S8f2d71vgOPNfSyStJA,1418 +torch/include/torch/csrc/jit/mobile/code.h,sha256=4VGjG_UD8Wjc-h7OdprN-TNpuhkvWha8jboUip842n8,1112 +torch/include/torch/csrc/jit/mobile/compatibility/backport.h,sha256=JjEPHlOHG6cVOfugh6UE_RB2qTIUXHg2ceLdIzPeEmA,667 +torch/include/torch/csrc/jit/mobile/compatibility/backport_manager.h,sha256=flqHr8tI1NXHDnFEbq-_H70m5eUerh5itc_zwne2Dl4,1207 +torch/include/torch/csrc/jit/mobile/compatibility/model_compatibility.h,sha256=kWd1xeKcvCvq6eX7njJUIQxyg29dhqsd59ZtLOazDoY,3716 +torch/include/torch/csrc/jit/mobile/compatibility/runtime_compatibility.h,sha256=06Tqz75fcxzkaX-AKg032sRe6MSstwvFp7eVnIm74oI,1239 +torch/include/torch/csrc/jit/mobile/debug_info.h,sha256=z2aHlt70uRSsmJnJAmtMMEEsfK7p8GWCgCYLP8Eseig,2260 +torch/include/torch/csrc/jit/mobile/file_format.h,sha256=gywdJo07PDAiQnt1gFMuUZKm37yIaD86SslXJpnR2LU,6790 +torch/include/torch/csrc/jit/mobile/flatbuffer_loader.h,sha256=gNqKdRa_MNagNTiqyuz62EtoPKzSBhI_IEkh04qiSQA,5097 +torch/include/torch/csrc/jit/mobile/frame.h,sha256=-80NLYXAVmCWQ_XJFWPIIDn-l8sfKDB7X-OxIAGmQvY,860 +torch/include/torch/csrc/jit/mobile/function.h,sha256=3VrR7u1GCFZJHHGWSPnyvJSfNdb2lCTorEKm_TKb8aI,2984 +torch/include/torch/csrc/jit/mobile/import.h,sha256=wSqEiA5KVRKXZqPuq_dzwU2UhubGjSd9q-xi2Xfrdog,3948 +torch/include/torch/csrc/jit/mobile/import_data.h,sha256=HdKa8Cv8E_c-pf8iV3PqX4jO2H0gJxRW5IF4B-6gxmc,1031 +torch/include/torch/csrc/jit/mobile/import_export_common.h,sha256=ujs6H-TVov2Mo3XkXns0tkjehgeU-YtSYtAfqITXG30,492 +torch/include/torch/csrc/jit/mobile/interpreter.h,sha256=-drviTPHiTO3kpEK33oNwfpR-Q31W1oZkgV7se5tO54,664 +torch/include/torch/csrc/jit/mobile/method.h,sha256=6Zk-Sl__upmINdXFYLrh2W2gd-9nFYs0J-Y1Sdh8-Co,865 +torch/include/torch/csrc/jit/mobile/model_tracer/BuildFeatureTracer.h,sha256=ZUgb3LnIOGIog2ixeybtOyP0k0XUzrtRycnSYOFDmhQ,998 +torch/include/torch/csrc/jit/mobile/model_tracer/CustomClassTracer.h,sha256=Kq1_bQ9URqh_fc9ZpUUsdMkdpp5-mG52XDGnVDI2_i0,996 +torch/include/torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h,sha256=DCL8-tB1qVvtHJA1cLF9zJGgSNQGb7RGIyrmde73qQE,1249 +torch/include/torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h,sha256=VDYf-k22AuTyLn3RSOroU397i8MsOcQ5UW4qF7JnFDg,5232 
+torch/include/torch/csrc/jit/mobile/model_tracer/OperatorCallTracer.h,sha256=6Ts6EItduKVmi4zI0WL2bexNqfd6iHB4amlLlZLrmpo,950 +torch/include/torch/csrc/jit/mobile/model_tracer/TensorUtils.h,sha256=qWAw4qo_ixk1o1WkckPBCUGp14lwvAVi26OfivUT6Z4,411 +torch/include/torch/csrc/jit/mobile/model_tracer/TracerRunner.h,sha256=7aJ8-3K7gZX8D_QWGv9rbNlasU048IIamkXXIOMTW_4,1146 +torch/include/torch/csrc/jit/mobile/module.h,sha256=t3yHtRrUFDcjojfXcY5VO7CC4qePDek_kK_yeUiMBBg,6114 +torch/include/torch/csrc/jit/mobile/nnc/aot_compiler.h,sha256=v_B0JURAdq1k9X8mJ4xjzDla3WjC1s6dZoQyvDc-SJo,648 +torch/include/torch/csrc/jit/mobile/nnc/context.h,sha256=_0HaGHsSjwZ1qUCyeTNj_clP3tIb8AbDVs09F1AqR_Q,6839 +torch/include/torch/csrc/jit/mobile/nnc/registry.h,sha256=iXP_hCIp-2aJqXz0jMRsdUAjK363KyqNxPwbdaOCCys,1200 +torch/include/torch/csrc/jit/mobile/observer.h,sha256=9mqQiOVpWI3F6ubHH1xt08bFU5D2K16DSbrgH78xsso,3747 +torch/include/torch/csrc/jit/mobile/parse_bytecode.h,sha256=TzRSkL0-i6eNqKs-3WpR1scnVyuu0DiXCjdDffrI4zA,761 +torch/include/torch/csrc/jit/mobile/parse_operators.h,sha256=CQCMmhxqjebUU3Jlb5ufa91izsouvMnODIDSeTOJvP4,734 +torch/include/torch/csrc/jit/mobile/prim_ops_registery.h,sha256=wy3TcUFiyTAou_Cp-GbOjCCoXdnzP8qeyXsgKFMa4SI,623 +torch/include/torch/csrc/jit/mobile/profiler_edge.h,sha256=h25qR32gyqvwXoivJj2lIPgBxmq3YsQKx_Kq1TZ7J10,4603 +torch/include/torch/csrc/jit/mobile/promoted_prim_ops.h,sha256=cQfCafQMFg3gtS8XSKMzKqmit1UhouxYNz7zIzTjlCo,1104 +torch/include/torch/csrc/jit/mobile/quantization.h,sha256=rJmEme7p7nXFD9wHLSnbgVv2ms7FiCS6ZyYjFs0bThA,1272 +torch/include/torch/csrc/jit/mobile/register_ops_common_utils.h,sha256=fCXTmQClnq3cgy10xggZMwNUEf4NTWqOFctko5TYUBU,1740 +torch/include/torch/csrc/jit/mobile/train/export_data.h,sha256=Vt_Eq5qTveggPQtwkCTJ0Yx-t-LtR_GtQZHaVyPABUA,1645 +torch/include/torch/csrc/jit/mobile/train/optim/sgd.h,sha256=O0W_mTGF7TC6DKqxF1-wDnu2Am47YswqVcj4hzg8IT4,4456 +torch/include/torch/csrc/jit/mobile/train/random.h,sha256=OaZaWF48oJPBO1qmA0Vsn7LfrI8tKgvgouyfo_Om7To,1569 +torch/include/torch/csrc/jit/mobile/train/sequential.h,sha256=XWkaXz6HMDVEZOKzU0R5GajZDlel4dapBG1p5CuahXE,1269 +torch/include/torch/csrc/jit/mobile/type_parser.h,sha256=cz4-KlIrtJvVmz7WERSoebax574lQl9D0VKZO3agLjE,1497 +torch/include/torch/csrc/jit/mobile/upgrader_mobile.h,sha256=G2QUu_Eat54pEUskkbZ1Y_bxfsVxDUBiTb5E0uZzby0,944 +torch/include/torch/csrc/jit/operator_upgraders/upgraders.h,sha256=V8d6OSlDYYqG7BgBMN8AU9XT7Qa1QJt5NY6kwxncxMU,1430 +torch/include/torch/csrc/jit/operator_upgraders/upgraders_entry.h,sha256=_qr87ttlYkmdLTsBXw0pbbdv6kKBSfFxWmcZW-ed9pc,542 +torch/include/torch/csrc/jit/operator_upgraders/utils.h,sha256=-tn16NPKa5lKKyUg2PorMaqKWUWx5nHTgqLbFcw5u5M,1744 +torch/include/torch/csrc/jit/operator_upgraders/version_map.h,sha256=mkdL5qtrDrJcXuOL_brxq5-WjzytPI4lvikwOAtq3Wc,920 +torch/include/torch/csrc/jit/passes/add_if_then_else.h,sha256=dcvmjLmI-kz7g-lXMytvkDC7aNA0sVT2nrjgAu2wAm0,172 +torch/include/torch/csrc/jit/passes/annotate_warns.h,sha256=ejLWN3S4ilEZXX2duhshRHTkvFyHqvco-iihbAvm7dU,176 +torch/include/torch/csrc/jit/passes/autocast.h,sha256=V0Bn81ROnkg4Tqfa35j_R9GSN45f3Sc3Ih4W0Tcn274,255 +torch/include/torch/csrc/jit/passes/bailout_graph.h,sha256=XWzXGZM0K13Q4fysbBPU94UbiuYDQA-vKzaDtreMFzk,1122 +torch/include/torch/csrc/jit/passes/batch_mm.h,sha256=s99CwEK_3jy1TMj2DCpeiJdN7TxPhfKikdvW0ncuh3E,140 +torch/include/torch/csrc/jit/passes/canonicalize.h,sha256=aYp4VuZppo-VCTO4KSFuVpOmfHNdanPsCLjRUoLGef8,487 
+torch/include/torch/csrc/jit/passes/canonicalize_graph_fuser_ops.h,sha256=UpyfaBZsCzDFpjoVRb8d9EW8NjE5J8P2tn0If0MZGEQ,154 +torch/include/torch/csrc/jit/passes/check_strict_fusion.h,sha256=VF0ILt9ArxfZB61WeUfpGbR47HpdNaTZcCZMebkg0kE,176 +torch/include/torch/csrc/jit/passes/clear_profiling.h,sha256=UIKuo8kMevVWOO2W1zC8OCj9G3iRiDxEcb8lZiFyIdI,485 +torch/include/torch/csrc/jit/passes/clear_undefinedness.h,sha256=gBC8u8lvjkfzwdYSUbwY6ofVRfRpHOzukFjJN2Eke-g,872 +torch/include/torch/csrc/jit/passes/common_subexpression_elimination.h,sha256=K4XU_GRf6iB0xTGRFjSDGqyXNJp9eG0PVZVUtgGH_Nw,171 +torch/include/torch/csrc/jit/passes/concat_opt.h,sha256=FCQbsxyThTrUJlHOUP-k1RGBuTsRDY2u-H_UILx5YjE,542 +torch/include/torch/csrc/jit/passes/constant_pooling.h,sha256=sUsM4TrSB00HUVnsygsG6aJbgWUDrAtngn3zKyzm1TA,154 +torch/include/torch/csrc/jit/passes/constant_propagation.h,sha256=4UCLhsi2MCQF1WPTbk7C7X9mCKa-UIQ3VOyE5eLT1og,1319 +torch/include/torch/csrc/jit/passes/create_autodiff_subgraphs.h,sha256=0wbxpbVhMU6hx3FauJpX8h941m5LLaYExFIBbDjAOn4,530 +torch/include/torch/csrc/jit/passes/create_functional_graphs.h,sha256=gUMV76y9GSy0mD3E56z6N4NK5jL5CisdVg1MRamGCAA,296 +torch/include/torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.h,sha256=AmJUEziJBF1wDZ-grZK7qt8gECxKGJjRWsmkOE9hyZE,359 +torch/include/torch/csrc/jit/passes/dead_code_elimination.h,sha256=VaTe24fv7LqTKZemuAbIXUKBExhu_-VZW-03tX5VfdA,1599 +torch/include/torch/csrc/jit/passes/decompose_ops.h,sha256=ELlyFEtYauiekdgPHDHiaIYu1us9xPApTALe_kGd9YM,145 +torch/include/torch/csrc/jit/passes/device_type_analysis.h,sha256=q7ny510G-t2PVO8cfOtP8F7kfjETqDK5G4B69V0P-0Y,253 +torch/include/torch/csrc/jit/passes/dtype_analysis.h,sha256=XHh5YgEVKdUjXDMdAOZiELmW46_vCSIxWF1QO2cAQ0Y,404 +torch/include/torch/csrc/jit/passes/eliminate_no_ops.h,sha256=hu-uCzFMLtu_fY1fzYENDLumRucBAq_A_k4wnCNiRwQ,507 +torch/include/torch/csrc/jit/passes/erase_number_types.h,sha256=yoaE07-5t1xMvWZqQUu6YromqGfqenaFVVSZdHu9Z7g,809 +torch/include/torch/csrc/jit/passes/fixup_trace_scope_blocks.h,sha256=RThuyoHEnO_AZ5y7kQvcCgzT3wMQ_3kd8_E369wnYHg,1693 +torch/include/torch/csrc/jit/passes/fold_conv_bn.h,sha256=h65Xo1OEkLuMAtQB_ciXpCkgvQmEhT1bgU1fr5giT8k,1005 +torch/include/torch/csrc/jit/passes/fold_linear_bn.h,sha256=SQmgKx4HWs7-f-pAV6J1RDxD27eQGSKSpmoWcx6y9oA,693 +torch/include/torch/csrc/jit/passes/freeze_module.h,sha256=rPAuBt2OKF0tEOKIgikUx_HaZf6Ce8ejnAqGs15g5MQ,1253 +torch/include/torch/csrc/jit/passes/frozen_concat_linear.h,sha256=lY6jWiZZtjuUNNa7RVbe5WrQ5xDOSWYaD15dE4dCkHM,263 +torch/include/torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h,sha256=lX_CYHpmGb3yU-tuua2oRPwHldnR8b5N4754_oXUxco,317 +torch/include/torch/csrc/jit/passes/frozen_conv_folding.h,sha256=B_eHlqjVyXP-map0JztrodoybuLFZeblcvA0B88VJdM,869 +torch/include/torch/csrc/jit/passes/frozen_graph_optimizations.h,sha256=Af-IrgzE6KTWcxEcJxymH5nGWL-pXtYRmPi84HJGPko,461 +torch/include/torch/csrc/jit/passes/frozen_linear_folding.h,sha256=4YMcYUPNHWZmtQ34NEg8jpR0pKcvQoyJ68MsUAGkNFU,357 +torch/include/torch/csrc/jit/passes/frozen_linear_transpose.h,sha256=GDXwVyTb9G6CnQqXy7jEzdVVT-FDZgT40mb8Fzbfb-o,272 +torch/include/torch/csrc/jit/passes/frozen_ops_to_mkldnn.h,sha256=u3C22Kt3bfLAFf7ZnCseN5Iz-ApDVqTIFpv6jcpsFYY,402 +torch/include/torch/csrc/jit/passes/fuse_linear.h,sha256=Mfh491MMbPrY-P8vDI_GWJ5UxOiL-FdlV5t0_hT2fMM,765 +torch/include/torch/csrc/jit/passes/fuse_relu.h,sha256=7D9ow7eifHZPsLz-ThfZCUi2rMI-QD8lzruqHBk4l78,257 +torch/include/torch/csrc/jit/passes/graph_fuser.h,sha256=-ylGo2irRTcW-_h-j7CUVZMCR3_ShZq9HkYHSuojJp4,1261 
+torch/include/torch/csrc/jit/passes/graph_rewrite_helper.h,sha256=4G5o-13_8KuTQOsYbIPuK-KIW51NWSPX-sMRvjFGzv4,1785 +torch/include/torch/csrc/jit/passes/guard_elimination.h,sha256=qvjUCfO2lMhbYawQkhDcu_gn4fuy93d5b3dEXOywRyU,368 +torch/include/torch/csrc/jit/passes/hoist_conv_packed_params.h,sha256=ueCANfsR2c6E9Kukn3pt1LTS3rm6se6kGvSwnXw1VvE,196 +torch/include/torch/csrc/jit/passes/inline_autodiff_subgraphs.h,sha256=KYtqYzbgY_oq9KMD5e-h46JBLY9_vACAsSuWfYOE2b0,263 +torch/include/torch/csrc/jit/passes/inline_fork_wait.h,sha256=cX-w802NZD7uAqPLtcoCjzupgAIsuQ81PFjb9AxgpSc,536 +torch/include/torch/csrc/jit/passes/inline_forked_closures.h,sha256=kJCQ_Lyp23FfIK_nEVgoJRtqWCYeow5nAHGYyCbrRMo,212 +torch/include/torch/csrc/jit/passes/inliner.h,sha256=rt59r_rQr0A9O2a2KqW3tpYk5eoqjAS8v51LBqluJt8,241 +torch/include/torch/csrc/jit/passes/inplace_check.h,sha256=alaKbr-BZbsHXj7-ZpweMytEzob--W7fNFGHNMVKmGI,145 +torch/include/torch/csrc/jit/passes/insert_guards.h,sha256=oKWK8r3kux1B3idcwlSnjr3vjFUirbwioR4YPjirKn4,433 +torch/include/torch/csrc/jit/passes/integer_value_refinement.h,sha256=Y8uGXpioItsziL7rpIpkF-ihJ3eWOmyuITPbGl2tovg,219 +torch/include/torch/csrc/jit/passes/lift_closures.h,sha256=rMGYTbP1F0NwClMG5-qFzIOrsEFxgLngGD4cnnS3s0w,207 +torch/include/torch/csrc/jit/passes/liveness.h,sha256=aw9FfI4AMJMASfUD5fu1avd4etUw3iNCPxFkr3MpizI,647 +torch/include/torch/csrc/jit/passes/loop_unrolling.h,sha256=KfTgzJJ01PD7XAvBXbFzKV6Xn6ASEPpdyKby6-8rJmI,1015 +torch/include/torch/csrc/jit/passes/lower_grad_of.h,sha256=Xim2wnLaUZG_mi0r-8e40xxPnXYueqJrKDLvrwVlNio,338 +torch/include/torch/csrc/jit/passes/lower_graph.h,sha256=R7G1z-UEjln1ODrBIwNGEYbda8c4UdaDjjiGqSZfQbM,745 +torch/include/torch/csrc/jit/passes/lower_tuples.h,sha256=VCYP9LnLH2YI42u5FOsfkxa0W1kp8bWfLI_yIsO49qY,659 +torch/include/torch/csrc/jit/passes/metal_rewrite.h,sha256=ZKMnk3C72ikHZNo3gBiw80itcDowt1k8OWp1iTg-qR8,596 +torch/include/torch/csrc/jit/passes/mkldnn_rewrite.h,sha256=t_HgLt35LSXudNEPdughLwxpA7eMRUg9defy1pvpPkw,637 +torch/include/torch/csrc/jit/passes/mobile_optimizer_type.h,sha256=mVAFakGZ3_bSZyaUL4UnYHEeZlSMKKCwscJqgxyUK00,250 +torch/include/torch/csrc/jit/passes/normalize_ops.h,sha256=lSTM0ZR6ewYu_0_rsl65DhpV7WZL1mgP0gjkxYK1UW8,527 +torch/include/torch/csrc/jit/passes/onednn_graph_fuser.h,sha256=NwmfgPpN8b-fS-Ji6G96eucOPjVRECwHvXzQdqX4jFE,1460 +torch/include/torch/csrc/jit/passes/onnx.h,sha256=NnXzPe6ubhvicyjbavRbJOz-6NV85zcxcvQ3XAvu2jY,860 +torch/include/torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h,sha256=lhPru2ScME__er2OeeAmvQQ4V1E_OS1UsOn_227zg4U,227 +torch/include/torch/csrc/jit/passes/onnx/constant_fold.h,sha256=TBUO3nsGZPMl0PzFWqV_iW6OvJYBrkGzJRlxoLTL-aY,725 +torch/include/torch/csrc/jit/passes/onnx/constant_map.h,sha256=xyYNuI8XhTstLZ_lvFijY1OORXhbZY4Soo_HEXAlcyg,4534 +torch/include/torch/csrc/jit/passes/onnx/deduplicate_initializers.h,sha256=Waj6_KjAW4nebpDEDWkJAmX2bzHggtncrDj_Q2tYehU,261 +torch/include/torch/csrc/jit/passes/onnx/eliminate_unused_items.h,sha256=VeP4A_AAexl23f_uGPXhpNKe7ZmYzPNnafcBgLH3wYo,357 +torch/include/torch/csrc/jit/passes/onnx/eval_peephole.h,sha256=O8wGA8lw3Zdue9Eg-4IE73b6YWa7r6YGCJRpRSJUz4I,233 +torch/include/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h,sha256=xjcuQBFwoU3UowiuueA7y8svRBVKIJSvwZLTXTLrpkk,231 +torch/include/torch/csrc/jit/passes/onnx/function_extraction.h,sha256=uEDFqHs4bp3gp-XUIh0L7m1BH_deTl8mhKUnw_zuV0g,2315 +torch/include/torch/csrc/jit/passes/onnx/function_substitution.h,sha256=qiuUj0XaJozC0NtMmlbLitPgarR_sM8Dd4Bee_erPcY,144 
+torch/include/torch/csrc/jit/passes/onnx/helper.h,sha256=_tO1gDy7NaObXZr1ztB17kSs2YgEInk-A7gBuyk7e7w,2206 +torch/include/torch/csrc/jit/passes/onnx/list_model_parameters.h,sha256=CYaKIZx9kDzHi6Okx9486u5vt0qz5G4GOQb3FmYxzso,250 +torch/include/torch/csrc/jit/passes/onnx/naming.h,sha256=PJPAQQaHoIQzeZqOclf3A9-F-nWTfCrJh4aAKtYurGA,786 +torch/include/torch/csrc/jit/passes/onnx/onnx_log.h,sha256=hJBQt1tzAGcQkMS278bus1KFAvVlyjfpJ0578GvGRaQ,607 +torch/include/torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h,sha256=pEAemSuhHrnl5eltGDQzA0n0sWxdAELVVBoV5DkVvSY,184 +torch/include/torch/csrc/jit/passes/onnx/pattern_conversion/common.h,sha256=IDerh2TkFAmGis6r6kXjAdsM-TaAkNHUYVUyMlGqDsk,353 +torch/include/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h,sha256=DlerVR0Qky1dSLIZsbaebax6x3Rbp7N3HtDWkeGfdXI,2139 +torch/include/torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h,sha256=9E_Y6NTGFQqY-BORbSDENWvFziA4Gs3APfo8LcqLrk8,1394 +torch/include/torch/csrc/jit/passes/onnx/peephole.h,sha256=xt3YKszdA7TneicQaZ695GaMHFQOiQ1bW7CrpQTDmhA,225 +torch/include/torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h,sha256=rrj68DNFeKkWFFnc-udnd9QhKJwWp5c2rOkWzGcQWXs,468 +torch/include/torch/csrc/jit/passes/onnx/preprocess_for_onnx.h,sha256=SvD326s7LRZjs-gPwdbSlqC0JrMqKeo1NaGp0sHnN6Y,164 +torch/include/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h,sha256=2e7VkJV4MkbDKqVJL9W4EBGcSOa68xP3amoboiu26mQ,212 +torch/include/torch/csrc/jit/passes/onnx/scalar_type_analysis.h,sha256=Re-Fa1Yw3Mo9_A0ip2L81l2hsuRQlE5IsRjjMQCNvUM,293 +torch/include/torch/csrc/jit/passes/onnx/shape_type_inference.h,sha256=sxa8_75j8K912SQProq0oVdYTClGBhBRyy3lrpiundA,4053 +torch/include/torch/csrc/jit/passes/onnx/unpack_quantized_weights.h,sha256=zxphW_0_5OeyMc2mUCu4yGOOImp1JtvD-pubMnXivmo,444 +torch/include/torch/csrc/jit/passes/pass_manager.h,sha256=l4MXB13x4nSRKJX-vDLdP0_JVR35O8DX18NtgdhYOIA,4683 +torch/include/torch/csrc/jit/passes/peephole.h,sha256=SGoqBxVPKIZHiXvDx26TqHHGLm3CWkZUSogGesvdn7o,500 +torch/include/torch/csrc/jit/passes/peephole_alias_sensitive.h,sha256=qhxTvRrotVXXlBB88s8trKsY58S4-XyWixWDJWz5XhY,425 +torch/include/torch/csrc/jit/passes/peephole_dict_idioms.h,sha256=NSL8e9v_A_i1qadCOmujPiHUqzjQQLTvRfuvBr_R2zo,1011 +torch/include/torch/csrc/jit/passes/peephole_list_idioms.h,sha256=TgMT2rwsmuGbYY488tvcbchEKSnbIJ1ivKm_5WqDuJM,2048 +torch/include/torch/csrc/jit/passes/peephole_non_tensor.h,sha256=Dg90Zq4nmoqH-tFua_-pN5pnD8dobjYSAxEdxR2ba4c,329 +torch/include/torch/csrc/jit/passes/prepack_folding.h,sha256=2IifDgzdYtvY04FYZVo7qV8Ygb8XwU0eEHG8GIjFy74,348 +torch/include/torch/csrc/jit/passes/quantization/dedup_module_uses.h,sha256=REGOsSFeDBbMkjKPzjs9Vkk3DZey7wHKa67zxI-Qnqw,827 +torch/include/torch/csrc/jit/passes/quantization/finalize.h,sha256=V106jDTjAHDmqz4-KEsfyHWDjecMyQDuGBzu0xkn5EU,2357 +torch/include/torch/csrc/jit/passes/quantization/fusion_passes.h,sha256=7ZC9zvrO648vTl014p5JEvngzpqS4awaSuC99yYrKzY,173 +torch/include/torch/csrc/jit/passes/quantization/helper.h,sha256=wEGUIpC4ua8ukmmDakxEkp7m5RSTEA6EH4rCJW_Y7-0,7673 +torch/include/torch/csrc/jit/passes/quantization/insert_observers.h,sha256=5dJD4lsCeT62xLjxeQsrocg9cEWwip0eaiWQN_JkXaA,2392 +torch/include/torch/csrc/jit/passes/quantization/insert_quant_dequant.h,sha256=c6RL7fJO4kmVgNkVaN3D3RJ0l_idqt70c45zjq9qqz8,1469 +torch/include/torch/csrc/jit/passes/quantization/quantization_patterns.h,sha256=89TyLFgRFha6Y9EeQtCUThWsbvq8UWV180Lx2PPR_vk,54629 
+torch/include/torch/csrc/jit/passes/quantization/quantization_type.h,sha256=nN0kBbzzOWvU8-cym3GnXRwOY_Qn-60SeQ7azWthSho,346 +torch/include/torch/csrc/jit/passes/quantization/register_packed_params.h,sha256=MlxoxPyatVQplrR4UfQ45-ubozVD-DKrbzAuOVn9r-I,507 +torch/include/torch/csrc/jit/passes/refine_tuple_types.h,sha256=vChizt1g-fNzM49vbW_g0rqRGZF5Dpa26jcWuQYNc3s,252 +torch/include/torch/csrc/jit/passes/remove_dropout.h,sha256=3XTWl9rQP3PAEWJOqgKTQzchO3np-Ize7IeqBvv1Jxg,267 +torch/include/torch/csrc/jit/passes/remove_exceptions.h,sha256=fcPmv_oL8XsCgQzSJmXW9OnwL3Jfi4mebuIDOVY2fgs,950 +torch/include/torch/csrc/jit/passes/remove_expands.h,sha256=s599611cNHfeNVClS1e4H0FghxT_8g02MYyzpSqfePM,152 +torch/include/torch/csrc/jit/passes/remove_inplace_ops.h,sha256=ixFlR3GNHU1CivQUuN9thbm7PqnLoaVDf3XDiUAL2ew,283 +torch/include/torch/csrc/jit/passes/remove_mutation.h,sha256=OeC0r_DS6QbJNxCqWsoWH_BSlqaSSBl0byXoYVUX2kU,2718 +torch/include/torch/csrc/jit/passes/remove_redundant_profiles.h,sha256=JDsg0Wn2nN9J61bI3T87NiLlqLk4VvSRXC323FFOQtg,246 +torch/include/torch/csrc/jit/passes/replacement_of_old_operators.h,sha256=lGGNAmtWZwekKUXdExnjy1nmRdch2TE8NasBmE9qXNI,448 +torch/include/torch/csrc/jit/passes/requires_grad_analysis.h,sha256=w4D169aMg6iwJd5opFZocBY7qqNDFSCQing5F4hs6VY,235 +torch/include/torch/csrc/jit/passes/restore_mutation.h,sha256=cCkja5DhfcdlC3rICG17CxLVTEPP5EyohAV5sYWIKVc,1862 +torch/include/torch/csrc/jit/passes/shape_analysis.h,sha256=03Idfln6irp14B6zPLYb8bBWc1ui97yD0fs6uo9WodM,1118 +torch/include/torch/csrc/jit/passes/specialize_autogradzero.h,sha256=hilGFWW_AsUtoaB-4GObtHGZp4hoglaE-tOHME4lmf4,650 +torch/include/torch/csrc/jit/passes/subgraph_rewrite.h,sha256=3KloxA771sKO-_yiIxPM3Y5WlkjZrwcRFps6QUbxUPc,4202 +torch/include/torch/csrc/jit/passes/symbolic_shape_analysis.h,sha256=ZrxpUfV3fkR0c2bbpLS6Jb5eSPLjz_wVT-c-KmZS7wM,2134 +torch/include/torch/csrc/jit/passes/symbolic_shape_cache.h,sha256=097_N0DCW45tUH4oNU_-QeisDu0V-ZrXJ3LXEhuMj9o,1629 +torch/include/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h,sha256=AJlatJNSBQMwDrKIR56EG0OAEHHqFUsmX-bVoEVTMxQ,2390 +torch/include/torch/csrc/jit/passes/tensorexpr_fuser.h,sha256=5AZ7r-e1MzBcFGd_rI4M1ztJk9CmPRSKb2uU-6WMwgg,2777 +torch/include/torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h,sha256=vV5xyRqr1gazDv8qlmtRgegl1HbkbZUWAqn_MipUfbk,733 +torch/include/torch/csrc/jit/passes/utils/check_alias_annotation.h,sha256=Z-8dro5xBbE1irFPfMgUZkzmkPcKiN9RbFkHHipg8rI,607 +torch/include/torch/csrc/jit/passes/utils/memory_dag.h,sha256=GxHzD71a4CRwl2x_qEySWy9NlsO6q2Vp68pjzSEVFNs,6565 +torch/include/torch/csrc/jit/passes/utils/op_registry.h,sha256=Ap96JRHBTGaCMCqgw7KKb7RLH0UGY5ydmsYYfXWhzg0,1037 +torch/include/torch/csrc/jit/passes/utils/optimization_utils.h,sha256=dhmbwBWDJFz6E3FnpQ0yhP-TDfkcizBkQ4GGBPaJHkU,231 +torch/include/torch/csrc/jit/passes/utils/subgraph_utils.h,sha256=0qet1_bMOwUiPSbHbE5XIBhs4VkkEWvIOLJX90xZB3g,2430 +torch/include/torch/csrc/jit/passes/value_refinement_utils.h,sha256=XOX7uCFmw_kcG_mtVNinZ_a5KKOrFkDj9XzQlXV2zrw,2686 +torch/include/torch/csrc/jit/passes/variadic_ops.h,sha256=0MNjgAPQzhZYwpVlEnw_JwC6KJm6m0ZACInMPtyqDjM,914 +torch/include/torch/csrc/jit/passes/vulkan_rewrite.h,sha256=d-h1S7rul1W26blZLbU4ogjM4Y-PerE7_HlBh4v4r9Y,689 +torch/include/torch/csrc/jit/passes/xnnpack_rewrite.h,sha256=Ie-FJR58r0hxfMWRcdxGsRQc1XrhJARcJK6BtJld0zw,814 +torch/include/torch/csrc/jit/python/init.h,sha256=i2NncE1TVtt2PJeNwjHo5LAGO9Fh0Vl5oS-klXs3CFE,152 
+torch/include/torch/csrc/jit/python/module_python.h,sha256=5OMtEddhY1S10D3IKlu01MtNhkozsW-UnqLqUsCr6wI,1937 +torch/include/torch/csrc/jit/python/pybind.h,sha256=gWfL4KFhNgbkTZANADMYRx8NPfkQSUUJpnwOgGx69XQ,8126 +torch/include/torch/csrc/jit/python/pybind_utils.h,sha256=lLGXS3E4nzbBshmq1HqJnzPQnP1bcDqaLR-sJa4AFy8,45871 +torch/include/torch/csrc/jit/python/python_arg_flatten.h,sha256=Z5u21b2HupprHErlyNzMHg5NeB5O8U_VNC0Iq3JOPhM,3646 +torch/include/torch/csrc/jit/python/python_custom_class.h,sha256=bOfd7gbVnt0yeQH4apGbXWxQj10YpoO6JxROp64Opo0,431 +torch/include/torch/csrc/jit/python/python_dict.h,sha256=RZmA_owTGyb2lHQY77UUgZxvX4xh_A6Zd7sjqw1zDiE,3510 +torch/include/torch/csrc/jit/python/python_ir.h,sha256=zEkEnyqkMJ3V1zm07dS1fhOYlJr7dND2r8wGU-iPyDs,1743 +torch/include/torch/csrc/jit/python/python_ivalue.h,sha256=wyCb3W5fKu3TAwbjkikxulHqSOToOOjyKZd5cRsP2lM,3793 +torch/include/torch/csrc/jit/python/python_list.h,sha256=-r3r_ZGaaZLa-o_1_NIq3r-tP8zUe5Xcx64B7UAp4z4,5722 +torch/include/torch/csrc/jit/python/python_sugared_value.h,sha256=E_B0DXGfZASCIcX8oNdFwoaOzMrxkOBvk0HHQKDgY8A,11713 +torch/include/torch/csrc/jit/python/python_tracer.h,sha256=zT0AtDtdV-AMfHQLjV6M7IIdFmusEkAn5NkyLiUUDsE,1263 +torch/include/torch/csrc/jit/python/python_tree_views.h,sha256=dx_jIKmysQU6YGhnzJ-R3kHwseS1umv6zplBxn7X_mA,159 +torch/include/torch/csrc/jit/python/script_init.h,sha256=Efq6UwLvwHh1TZ5rUaarRmRVqxKiWMwXvT0YS8klZbU,159 +torch/include/torch/csrc/jit/python/update_graph_executor_opt.h,sha256=fURD0rimDGSdPQgPy19tyOcXj9h6uylfc8FcxuShCZ0,191 +torch/include/torch/csrc/jit/python/utf8_decoding_ignore.h,sha256=O7yWMF19h_RYeJ3sekAp5gTba0RlxhLJtp7am4VaXAM,185 +torch/include/torch/csrc/jit/resource_guard.h,sha256=LE_1QkcwmF7tSI4eOLsRLzjObKlPWcGbZNDyX70rPpI,465 +torch/include/torch/csrc/jit/runtime/argument_spec.h,sha256=SPh54CHucBUmxHi23Fa9qKX_lVZXYsO6_TeSIWRsUG4,16848 +torch/include/torch/csrc/jit/runtime/autodiff.h,sha256=WAqL1t6tXLfnn9egoeQaYVf_FZp02k5xXvdJoFBtasU,4024 +torch/include/torch/csrc/jit/runtime/calculate_necessary_args.h,sha256=AxCQixOYzEfhnYkM7BxpKWqcUVC7TY3vkjTpVshzyZI,2350 +torch/include/torch/csrc/jit/runtime/custom_operator.h,sha256=SoxwmEw9BUzaKn6eLSK-xiukMq2yFhdI27jry4uHWH8,1084 +torch/include/torch/csrc/jit/runtime/decomposition_registry.h,sha256=tSvQIYNQoak4c8w3Gxq8N_ek4U17wmaupPOhdLTtn24,1078 +torch/include/torch/csrc/jit/runtime/decomposition_registry_util.h,sha256=C8UGw7tjCwqrFID4Wb6IwiYV_OaD5p-hG7ZD7CC7gzI,273 +torch/include/torch/csrc/jit/runtime/exception_message.h,sha256=cTfv8KDSRsx1MK9HPplMs5lM_tWCoOj5bGSbeIoktug,643 +torch/include/torch/csrc/jit/runtime/graph_executor.h,sha256=LJx6TLg5F0UyM9J3dt5m4InHdFvnP6hfmtUAB68eTY4,4806 +torch/include/torch/csrc/jit/runtime/graph_executor_impl.h,sha256=9Md9kbxdduufJv1rUh0Uz1Iqv_rT1nwJhLlxg4bpHWI,4140 +torch/include/torch/csrc/jit/runtime/graph_iterator.h,sha256=NFoYMNVy7HWGAWQyE04wjwDD1hjJh0hmAqY_r83kwpk,5085 +torch/include/torch/csrc/jit/runtime/instruction.h,sha256=OvdnlykumT9G3kJHp_S_fLw1PlblXlvQCkcSH-qYsXc,5721 +torch/include/torch/csrc/jit/runtime/interpreter.h,sha256=r15ZiZ-E5fCuYvDiGQAtbae4wb1p5LvQCbdJxRU-q10,5151 +torch/include/torch/csrc/jit/runtime/interpreter/can_emit_inline.h,sha256=YFbw9WzmVmYDVPJ3Ydcy1L89cZeXZ9dXUve_s1PsXjo,4060 +torch/include/torch/csrc/jit/runtime/interpreter/code_impl.h,sha256=aOU-3yHGxtmRqClIjBLejGzkZ_BHs6gxEu96kmbQi5s,34360 +torch/include/torch/csrc/jit/runtime/interpreter/frame.h,sha256=OsN_d6h1BtCtbmapkZc3Nm2D-NAePl_KN_K900QTaoo,1150 
+torch/include/torch/csrc/jit/runtime/interpreter/preprocess_graph.h,sha256=l69o-hLGrOLdRV6jzNpgsbIyjmm7gGuHNd2nvRau2CI,420 +torch/include/torch/csrc/jit/runtime/jit_exception.h,sha256=L10AkGIUIHZR3Ttg2jjHsfWReNbHrK71QcRjfMabDBc,1207 +torch/include/torch/csrc/jit/runtime/jit_trace.h,sha256=ULJfqZ5MFDMqJ09tCng6pJiclhumsg_uahXhrDCFn9Q,215 +torch/include/torch/csrc/jit/runtime/logging.h,sha256=DYzDmA73_agtwyTa5Ri2O3JhoC2YexSyUYlgjQcyoFE,2703 +torch/include/torch/csrc/jit/runtime/operator.h,sha256=IuPqIVgpCrkluNojIFC5iEXv5GL0Dd1NIoULUrG6UXs,11961 +torch/include/torch/csrc/jit/runtime/operator_options.h,sha256=NpHNZCfmsCujmj0uDOVC4fk1CzghZ-O3BS1HBQ5KH58,173 +torch/include/torch/csrc/jit/runtime/print_handler.h,sha256=X-jF8DYdCWxsXtXFHwYrMKIHMcZB9rELjEbmEEIKH24,323 +torch/include/torch/csrc/jit/runtime/profiling_graph_executor_impl.h,sha256=nCwxaKXBq0RqOz0Fzozo_np-RYBy8Q-Fy5p9YDjfhX8,3036 +torch/include/torch/csrc/jit/runtime/profiling_record.h,sha256=H9NCUEMXezp5h0jZ4XBWdGcvhkO4un-rBM9NKOvaaVg,8748 +torch/include/torch/csrc/jit/runtime/register_ops_utils.h,sha256=DoycQXlcJP9mtuxlhFJ1gCqLbR6l9ih9JKiFgidgzXo,43495 +torch/include/torch/csrc/jit/runtime/script_profile.h,sha256=OMAxLZ2iIl44AlYJrJHK2ZDz9HYiCCtxIFdYt8T1kXg,2716 +torch/include/torch/csrc/jit/runtime/serialized_shape_function_registry.h,sha256=DZobbkaPdoOtqMPiuzaVbCK9zj8KZ31qQkQR_MsZneY,371 +torch/include/torch/csrc/jit/runtime/shape_function_registry.h,sha256=3Vv4RjMqcac6Bv2VsyZYjfAfdWCtZ0a6S-E8ApDGdrg,255 +torch/include/torch/csrc/jit/runtime/simple_graph_executor_impl.h,sha256=HrWuo2wzNekmYpQCCaCD6jxi0AReX_1B6FmZDj8DxS4,666 +torch/include/torch/csrc/jit/runtime/slice_indices_adjust.h,sha256=mWIa0irSeq1DB9Q-oylD5p94QCL5e3rL88xAQOiCpAA,808 +torch/include/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h,sha256=9Xv_tClC1EH3RepBxmRK1IWb1VFVtLrcL2FQhesMcDs,6644 +torch/include/torch/csrc/jit/runtime/static/fusion.h,sha256=ObBrJudadrZoQQK7FqmvLT-rcKKMEC6oey3-3ICHTb8,322 +torch/include/torch/csrc/jit/runtime/static/impl.h,sha256=HycnSEw2Lz-T-f9_n2skb2qv7oI9HKfJzF9fZdCdUYk,37118 +torch/include/torch/csrc/jit/runtime/static/init.h,sha256=7zrvHBS4oRmsvl-XcVgzoWT9HTH5bP8fQdgib0YsSC4,156 +torch/include/torch/csrc/jit/runtime/static/memory_planner.h,sha256=NqrXm7QmM1PgFqUx2trFhf2neWfOSKqioVvwFiajmf0,10215 +torch/include/torch/csrc/jit/runtime/static/ops.h,sha256=C-_Xu9t1cHmNqZXdHMWnyIXk2Q68eCZnbvnLI-aHXjA,5735 +torch/include/torch/csrc/jit/runtime/static/passes.h,sha256=d2LhcPb3vOnPVFirwaOgXcw3WC5mOmK6rhLAF-rb_8s,3775 +torch/include/torch/csrc/jit/runtime/static/processed_node_wrapper.h,sha256=2l_rKRHsxVDfIMwDYjyUA06glWOBvYmfQZ5jgtS6CMw,6806 +torch/include/torch/csrc/jit/runtime/static/static_method.h,sha256=eKOBtj41dQ3US1du_ktlGXtYVJntvJhOAJwljXdioyo,1372 +torch/include/torch/csrc/jit/runtime/static/te_wrapper.h,sha256=csyyim4vwfUoH_JSHaE0dCUjIAzaLTNhsll-Zf9t0b8,1189 +torch/include/torch/csrc/jit/runtime/symbolic_script.h,sha256=K2HNGLdirhFTj3wzCWgthEJZMGsx-IePHNDCeZZz6Zo,580 +torch/include/torch/csrc/jit/runtime/symbolic_shape_registry.h,sha256=MhHhsVN-mSlKMLz-lbopdJd8BA5k10jqoTy89pFdjDo,2871 +torch/include/torch/csrc/jit/runtime/symbolic_shape_registry_util.h,sha256=B8VZ11dt7_0nNGhb2pcliu6f2ASe49cDL4ahnyTFqBI,363 +torch/include/torch/csrc/jit/runtime/vararg_functions.h,sha256=yVqpOXalASEc5Qjrm-dVRFhfzC8Dds98m4v6CpwoO9U,1188 +torch/include/torch/csrc/jit/runtime/variable_tensor_list.h,sha256=DwXjvY5GRUpYqvZ8LBqEsxAyN_V2og2TQrSUbq62ZSw,544 
+torch/include/torch/csrc/jit/serialization/callstack_debug_info_serialization.h,sha256=MU5GXzo2YLyUANqZMvCg-wQi5k4pPoop0orroqkVd0M,2693 +torch/include/torch/csrc/jit/serialization/export.h,sha256=DjpDSL8S-dY-PpAsoYIn8kziA40PiFbDhUoREs5GkBQ,11819 +torch/include/torch/csrc/jit/serialization/export_bytecode.h,sha256=TVcC6u42NlcDOWs8ceX1IiDxuP9aAqCV4OWUVgH2mt0,1409 +torch/include/torch/csrc/jit/serialization/flatbuffer_serializer.h,sha256=_DAsXKeDUDO_pXRUpkyKWZv0amjGtpxqMJaR4mWkCZ0,3135 +torch/include/torch/csrc/jit/serialization/flatbuffer_serializer_jit.h,sha256=ZsO6A5qt5Y4qNn0BJRCoVijB7f_DCtl4HYCDn8cbU5c,181 +torch/include/torch/csrc/jit/serialization/import.h,sha256=XNNJHmfuqdHkDywerTFLYITUC8t6Je1HMA-wAB_YE2k,5162 +torch/include/torch/csrc/jit/serialization/import_export_constants.h,sha256=Gxlb2d2w1NFWJi03CP4Na-mE1dZw-FlqydW5ULzsEx4,664 +torch/include/torch/csrc/jit/serialization/import_export_functions.h,sha256=rCCONSFPiXASVesUcMacR68WzdXl75hgUxeJblfCj2E,404 +torch/include/torch/csrc/jit/serialization/import_export_helpers.h,sha256=AYBaYN0Z0GqmOYYTCs4Wyb_iiNREkGYclEP-vTmh7G4,679 +torch/include/torch/csrc/jit/serialization/import_read.h,sha256=FHmMUfTYXlR65B-pV27k79sv0K0rYqqqJZaBxUSH9Jc,859 +torch/include/torch/csrc/jit/serialization/import_source.h,sha256=rzGq3Aw4H6RfD2iNqoGTHdiwwAdAjOs9xCLumld6lJ4,3524 +torch/include/torch/csrc/jit/serialization/mobile_bytecode_generated.h,sha256=G8poBLhQdXugNI9laSJNVDKr7gXIhTqKU_xUPnCSrkY,101210 +torch/include/torch/csrc/jit/serialization/onnx.h,sha256=4m-NDRqmWGCY1Ll2Vwz5dykJ7-3oyyArKB9EdvK9H8A,531 +torch/include/torch/csrc/jit/serialization/pickle.h,sha256=PE-d5PygGo_OsYlzhMypL9NACgnNIB4ohsdsMAIXzUw,4830 +torch/include/torch/csrc/jit/serialization/pickler.h,sha256=Lg7jjsR73TriNldCKmZKea33YYPAk-dnAWx-egSxAfY,13946 +torch/include/torch/csrc/jit/serialization/python_print.h,sha256=iHMkujGOtaGqSDkzF3NqmhfvwokJlQYK26E902KcCSE,1365 +torch/include/torch/csrc/jit/serialization/source_range_serialization.h,sha256=p8lpip1MX0PbOoUA9-E4yiqsedEqu-n6S6GOnhCjR8Q,1732 +torch/include/torch/csrc/jit/serialization/source_range_serialization_impl.h,sha256=4uwobNCEgslxm0X4i96dpFOTuKk_PKlzzAlqhHaU05A,708 +torch/include/torch/csrc/jit/serialization/storage_context.h,sha256=l3lzODCxwXvQh0su8CA4mDTBK7T8e_pGIz_1hZXewrE,2570 +torch/include/torch/csrc/jit/serialization/type_name_uniquer.h,sha256=pWy5GkA6Qti9gHaGdcNuQeniu5n4KYL-xDs7i6X51T4,785 +torch/include/torch/csrc/jit/serialization/unpickler.h,sha256=Stsh6cJdXT45mxSIAY0-HpSCWDj-chWBtxDU2xLiOZw,7853 +torch/include/torch/csrc/jit/tensorexpr/analysis.h,sha256=3JWD5XSNYp9i9DpmRYgaz0Bjc5CPl-UCDKS56aAyFN4,9380 +torch/include/torch/csrc/jit/tensorexpr/block_codegen.h,sha256=liHXcA3i3kUsqudgdsvRXoUJYCcExoATjUT7Nt4Gj0M,4436 +torch/include/torch/csrc/jit/tensorexpr/bounds_inference.h,sha256=ZiLA8S_RVeH6UXdxePWfI0oGCkc6Vx4n8w23cJV1Qp4,2245 +torch/include/torch/csrc/jit/tensorexpr/bounds_overlap.h,sha256=iX6jNPaYlqwLweHNshznve8uD1VNjR2NyHntihxbuGI,4592 +torch/include/torch/csrc/jit/tensorexpr/codegen.h,sha256=96FV8Ymzt7OhHcw7GKiBsVNrhm_4x5wRYYQ1UlMcIKI,7705 +torch/include/torch/csrc/jit/tensorexpr/cpp_codegen.h,sha256=s-g-450KR1v2ko1YK25xjAcglNyhFhdix224BYF-iQA,2445 +torch/include/torch/csrc/jit/tensorexpr/cpp_intrinsics.h,sha256=ORBrZam_67_JHZ6Zqrt_O4kI-QHj4u1iER-lOPLsGg8,665 +torch/include/torch/csrc/jit/tensorexpr/cuda_codegen.h,sha256=pQId-NRLSeQAwMRDNR4vz4qbzyXl8NdM8PPSe-YkvGs,8520 +torch/include/torch/csrc/jit/tensorexpr/cuda_random.h,sha256=O3UqOWAsQJqLMzl2UNTLLJBUVCOQ8drk08tF4iKzMqA,2692 
+torch/include/torch/csrc/jit/tensorexpr/eval.h,sha256=Z1cYSzYi90UOa_HjudZCyCd5SjFecIX3svHTJnl_45g,10191 +torch/include/torch/csrc/jit/tensorexpr/exceptions.h,sha256=wRCyB6MGrwLign2m3wlTF3b37m3STayurv-LAmXMiJg,3276 +torch/include/torch/csrc/jit/tensorexpr/expr.h,sha256=2aRhdflzESw0gakl5Abb2aa4MFmQbM6Bs6bODxQYeYA,14674 +torch/include/torch/csrc/jit/tensorexpr/external_functions.h,sha256=bUfCYwARtco5y5ft0ebhhR5lyBDIS-Wi_QtLYQf8tig,3544 +torch/include/torch/csrc/jit/tensorexpr/external_functions_core.h,sha256=ufuYOMKvNfZdYellSyZu96-FIU7cVSzZrB5MOIVSDc4,478 +torch/include/torch/csrc/jit/tensorexpr/external_functions_registry.h,sha256=DnF7DHXa007fPnmf1iGFyo5SZPOUFFHrSE6xtx6NtJE,2363 +torch/include/torch/csrc/jit/tensorexpr/fwd_decls.h,sha256=RpPesLU9wJm1lsIk76BzFQL-AoRzawQLusl8CGd1BrA,3115 +torch/include/torch/csrc/jit/tensorexpr/graph_opt.h,sha256=Q0Yq7vwSoBY_ewgDxizoQ4Pz8nuOTiuIroulsHiGK7Y,4545 +torch/include/torch/csrc/jit/tensorexpr/half_support.h,sha256=ZmcXcPVR7o21jNd-PjgaYy9UVE9my_D5Rh_tCbiOrWc,6101 +torch/include/torch/csrc/jit/tensorexpr/hash_provider.h,sha256=wqko35xT5qnFpLQ3y62E7D3ZFGWThWPmB2uMNlqhgn8,7814 +torch/include/torch/csrc/jit/tensorexpr/intrinsic_symbols.h,sha256=YZwpuvvFwi-H8C9YZt2AXN1DDoZHAxQ02X11WAvOhss,442 +torch/include/torch/csrc/jit/tensorexpr/ir.h,sha256=YylpaoWdERv8OebjprEshbW3P44D0EkBLUZj3FkmRQ8,23871 +torch/include/torch/csrc/jit/tensorexpr/ir_cloner.h,sha256=UufUmbey1b3-FIU22GuVJOsHCJhR6omZVyJHlq2NqDM,2405 +torch/include/torch/csrc/jit/tensorexpr/ir_mutator.h,sha256=dvuNFbPxC3NPrLSdUuf9jArNh3ziQbxjGvlLUalt5mY,2431 +torch/include/torch/csrc/jit/tensorexpr/ir_printer.h,sha256=03qn-fR0SRw3-VNLv31lHUYvx8U6Qp9Fnnd-pGNTCi4,4322 +torch/include/torch/csrc/jit/tensorexpr/ir_simplifier.h,sha256=-ZFffvZpK0anEx-gdbb_GkiJiIG01FPvxcKo8DX8b0M,15803 +torch/include/torch/csrc/jit/tensorexpr/ir_verifier.h,sha256=5MPFYIFqxeiwG9RjQw57XcGCqJ3YfZLXMbr8a0LKzOs,1349 +torch/include/torch/csrc/jit/tensorexpr/ir_visitor.h,sha256=hcxYtN0kl6GH2Rk-Hn07E4BGla2MnsDXFEHpp6i_2zM,2247 +torch/include/torch/csrc/jit/tensorexpr/kernel.h,sha256=KyVVNKF8SP4o5g3OUaRgIQQ86unde7WEBD1TD4Fzo3U,13770 +torch/include/torch/csrc/jit/tensorexpr/llvm_codegen.h,sha256=la6pZYUN0gQWZX86APTSWjO1XpMQMxBFRy9Fvuiv4yA,3982 +torch/include/torch/csrc/jit/tensorexpr/llvm_jit.h,sha256=PC36me1ahcY9_uiebc_S0KL6YjRNWtv-ByJeHiRRjp4,2017 +torch/include/torch/csrc/jit/tensorexpr/loopnest.h,sha256=kz9sDN71Rmpwf-V6XThaHJt5vN5DfqEs2ijsQbMqTz4,22323 +torch/include/torch/csrc/jit/tensorexpr/loopnest_randomization.h,sha256=9PPU1L-TRQKnFzEKBTmdjQORv5I9ou0d92kSrxvAHAY,318 +torch/include/torch/csrc/jit/tensorexpr/lowerings.h,sha256=S0A7TugHStf2EZabEX5MFRMK7Af4cTAKZ2BgGzQl8Lw,1325 +torch/include/torch/csrc/jit/tensorexpr/mem_dependency_checker.h,sha256=_JO59BH_gfJC9VPIIQYad7BZ2pZOCJYQwnkzJGWiuvg,13870 +torch/include/torch/csrc/jit/tensorexpr/operators/conv2d.h,sha256=Mm0g-Ah4Mb7aMRT4_7VMKUy8vsrvytXkyUVycO2kcE4,2994 +torch/include/torch/csrc/jit/tensorexpr/operators/matmul.h,sha256=HGUVbfZPINHB3o7pve_A0XmZHSTMwUNLXuHC1JvrLe0,623 +torch/include/torch/csrc/jit/tensorexpr/operators/misc.h,sha256=7OB1Q1VnX1KsVXVAclsGcInM3KYm66fQAdpbl-swA5s,3382 +torch/include/torch/csrc/jit/tensorexpr/operators/norm.h,sha256=AwYdiB8yQ8-itzzifuqTFVCGd5dkLazgaNSSDnGhEjs,387 +torch/include/torch/csrc/jit/tensorexpr/operators/operators.h,sha256=-h6eXiEoQ3yAJrbBJY_nN3y2GJYtxWvXRs9FhcybUKQ,481 +torch/include/torch/csrc/jit/tensorexpr/operators/pointwise.h,sha256=ObRtjRwJnqvYIzcj_mvYBzvAJQ9CHF2i0KoXRA8vxuo,3241 
+torch/include/torch/csrc/jit/tensorexpr/operators/quantization.h,sha256=Kp5ZgMZQZ-aCsJMZXUANdj_LpOGhZXTJqQIiljyHrZs,5688 +torch/include/torch/csrc/jit/tensorexpr/operators/reduction.h,sha256=0uLPBIgc_6djysvpAWDZfozD7XSazG2L0lfO-Rqgpqo,1137 +torch/include/torch/csrc/jit/tensorexpr/operators/softmax.h,sha256=uV-dMYJaYnFdEm4Tc3EH1GE_YaLRNYoZ7QYBjB2HvKk,334 +torch/include/torch/csrc/jit/tensorexpr/reduction.h,sha256=DJZ6nK-0johan_sOC-zHqreMT7JOGSSYOe7ox5gYcr8,9172 +torch/include/torch/csrc/jit/tensorexpr/registerizer.h,sha256=LSDiQlQYGsPVki_wR7KPhLMXn7zak_EQwsc9GVMkIQw,12944 +torch/include/torch/csrc/jit/tensorexpr/stmt.h,sha256=xK1rxiy5xyEmU8Yh0CypQ7ZA6Uhomcs9Vfu4faWH-sA,24812 +torch/include/torch/csrc/jit/tensorexpr/tensor.h,sha256=5PcRflwqxUlES18h4j2SSqiDyLjN0AdCPSLTLZnHOHo,10820 +torch/include/torch/csrc/jit/tensorexpr/tensorexpr_init.h,sha256=eBfap1CqnDTy-t-gaB4damjiV0mQk8mSdFazdubR-Pk,252 +torch/include/torch/csrc/jit/tensorexpr/types.h,sha256=metxvN6sSJUjJ7-JujLs5uDLrE2E6m9htEFEVV7bQnU,4438 +torch/include/torch/csrc/jit/tensorexpr/unique_name_manager.h,sha256=YmHJxk30IzeuE91fjBDviOV_iCnC8Skt893lqTMaTkg,858 +torch/include/torch/csrc/jit/tensorexpr/var_substitutor.h,sha256=dTOt97UDkg_Q1CgsmM1MSz1CwKgFDR6F0rlz3R1mNVg,1721 +torch/include/torch/csrc/jit/testing/catch_utils.hpp,sha256=S-huxpEPbTDaG3TIeZ50HL0YZW9A2UZfC-mOaann704,363 +torch/include/torch/csrc/jit/testing/file_check.h,sha256=wVDvTscyD8b5NV2J-ggOKkJVMcOEoWDDIHW29uSNWs8,2622 +torch/include/torch/csrc/jit/testing/hooks_for_testing.h,sha256=v4fjObrJmERmIG3ImKR5maWb5X4LYBI-l-O06xLlzME,597 +torch/include/torch/csrc/lazy/backend/backend_data.h,sha256=5_qzuWIuBz395o1tSqopnqaz0VW9Q8jap3j12BIdVwA,1239 +torch/include/torch/csrc/lazy/backend/backend_device.h,sha256=x1D4GYTSIoAIZahnUlpzU9cfDeZixosQ3_hsO6buX3o,3028 +torch/include/torch/csrc/lazy/backend/backend_interface.h,sha256=HNM_XmQLz0WS-KzHVstburyAyVDBlt-G4yYDsfOVr8A,4958 +torch/include/torch/csrc/lazy/backend/lowering_context.h,sha256=pUw9vJKF9VYdslvvRbjsopjpmYZb_nGVjiQtjtGOfEc,3348 +torch/include/torch/csrc/lazy/core/cache.h,sha256=70uhOoRL0qhQiE-lMIp2Sf1EiXsDXQYYQz57DDwJ1qo,3868 +torch/include/torch/csrc/lazy/core/config.h,sha256=cgH-rOAOxVHSgvMTF4C9gIjhTaiTIiatACTstS16158,937 +torch/include/torch/csrc/lazy/core/debug_util.h,sha256=c3kofLV_LR00CqleDhF2uRo43LVAtGezqkcmDfT4BgQ,1319 +torch/include/torch/csrc/lazy/core/dynamic_ir.h,sha256=Td8oIhgzIkB61n7-9MWJ7okufulgGGAAO1v9snmMur8,1445 +torch/include/torch/csrc/lazy/core/hash.h,sha256=1pFMFXu_psrBAzbLpVXJKzcclyENvPTR5O_3VbwvFJA,7841 +torch/include/torch/csrc/lazy/core/helpers.h,sha256=xtw1xpuRHNEg7Fp1M1OIiJylJRnF-KW7cYjX7Gh9GsQ,2293 +torch/include/torch/csrc/lazy/core/internal_ops/ltc_ops.h,sha256=7sT0crht4Eo41uQT1mAPD9NO6dWtFbNQM6vtevfQa4I,1499 +torch/include/torch/csrc/lazy/core/ir.h,sha256=DIo2NLA7uFiAREGaha67Rp8e4QqfVnYJhICRYqvrftE,8094 +torch/include/torch/csrc/lazy/core/ir_builder.h,sha256=nwYYsVeRDkFfAqPoVFgxQFv_fadWHJBLP3GzZco5uUQ,4857 +torch/include/torch/csrc/lazy/core/ir_dump_util.h,sha256=6iXALA8Nkqit0OfOPEjmFu49ZHa1uO1i8-pHZgVqgO0,694 +torch/include/torch/csrc/lazy/core/ir_metadata.h,sha256=J7UqtJcBWbQriezL6N-cgEHpnjatQaaHWrIKuGFOTbQ,1374 +torch/include/torch/csrc/lazy/core/ir_util.h,sha256=kIQyO2Qrh17P-PF43FuT7YRfaewUOblvoi3egz5Pu8M,1411 +torch/include/torch/csrc/lazy/core/lazy_graph_executor.h,sha256=QNTqnsIbzi2lXLIHH6efXBMMMW-YmpGF0xXM1dAL9pg,15447 +torch/include/torch/csrc/lazy/core/metrics.h,sha256=lIGMK8WAR3ZrHl6lneEdkTNm7SA1_MNb2NYiDWy69r4,8539 
+torch/include/torch/csrc/lazy/core/multi_wait.h,sha256=1T2cpsXRLr1ZiB209meYPHRRaeCsakSlcKEYfe6dVJU,1809 +torch/include/torch/csrc/lazy/core/ops/arithmetic_ir_ops.h,sha256=CKs9ptKDHLDlRJGXz9RRI7P9Le9NkHsj5OkAC80tUC4,393 +torch/include/torch/csrc/lazy/core/ops/utils.h,sha256=Gzg-LTfHiatybYnBYTK9AcpaiHL6aCyRqTtdRgsr5jo,1024 +torch/include/torch/csrc/lazy/core/permutation_util.h,sha256=eF4t6wGgtloxrMym_EINFNuK45rPbXD6bT2svR6qx1Q,1293 +torch/include/torch/csrc/lazy/core/shape.h,sha256=7Vj7SxcQNCmCBA4YFY52BMmjzkq-AwMyBqQd3Va1nLg,2093 +torch/include/torch/csrc/lazy/core/shape_inference.h,sha256=Un1MKLfU_6bVp64BNshwJSWx0dLheqrIkgzJSwvmE_8,15524 +torch/include/torch/csrc/lazy/core/tensor.h,sha256=p1KdUk5jrz1Ryes3swhTDpguVEvLNlDl9vxyh5VM0Rs,10134 +torch/include/torch/csrc/lazy/core/tensor_impl.h,sha256=X9q5j6wO5z--jx_n4Qg8Cl4FqMolucocLgp0pjn2HxA,1945 +torch/include/torch/csrc/lazy/core/tensor_util.h,sha256=LPmanbOk2iGoza9Zdbw7echc64CsXgd2_2cB2s5wIpc,2612 +torch/include/torch/csrc/lazy/core/thread_pool.h,sha256=uycqgsX_CIJGVLv4GLOqaqXKmVMhdDuigVgCKttf7b0,828 +torch/include/torch/csrc/lazy/core/trie.h,sha256=La59BZ24moOPekhUpUDtLQ2exV4SFXJpug_kq7MITBA,2269 +torch/include/torch/csrc/lazy/core/unique.h,sha256=E9x6KDNslXzFCFQR-JA0RB40KgqM_KjNA4BG2poYJC4,1227 +torch/include/torch/csrc/lazy/core/util.h,sha256=h_7ihPzzer0m9rndz0s5Bw9MWd0HtwtvVi8es8XCtj4,2998 +torch/include/torch/csrc/lazy/generated/LazyIr.h,sha256=dlFsmPNuf69ksOxSCLIM5UqHvYTrBCjY9KC7WnTCAVI,306129 +torch/include/torch/csrc/lazy/generated/LazyNativeFunctions.h,sha256=ARI21TqdT7f13r3OHwnjlVE54b2pkqAyY2r4vIbX4d4,22948 +torch/include/torch/csrc/lazy/generated/LazyNonNativeIr.h,sha256=v1SsQyg0JU5Dcqnv6ljEVgRKE8sB8_HTk45N4kFOkCg,4098 +torch/include/torch/csrc/lazy/python/init.h,sha256=hWFC9Mi6Pq6UHI2Wuvs_G47gtr_uE1iu9S4BitNipPo,234 +torch/include/torch/csrc/lazy/python/python_util.h,sha256=IqFlmiEiO2PEGH2jJEqU045I1vBN86Ot7KyJFl_lbZc,328 +torch/include/torch/csrc/lazy/ts_backend/config.h,sha256=qmJyDaN_LrFM0CWt02XRK4vynnR15TnAM5FcGnHR_fI,213 +torch/include/torch/csrc/lazy/ts_backend/dynamic_ir.h,sha256=rTpCrZVDlL0gWGK4Nxj1ZOaFN8Z3Mh7s1kHn-4Y8WeY,2433 +torch/include/torch/csrc/lazy/ts_backend/ir_builder.h,sha256=xQkcSRrQTFkUetm_vvc7TeLi_KKLdfjU5R4AeAK87PM,2460 +torch/include/torch/csrc/lazy/ts_backend/ops/device_data.h,sha256=OlhSlHYADj2f_pCFyFxLCg2_ovuVDocBaLmD5T-09pw,1346 +torch/include/torch/csrc/lazy/ts_backend/ops/generic.h,sha256=opsimakMwVh4cJNO0vHO9omoeKj3HM-1-AiBr0CnYZA,1480 +torch/include/torch/csrc/lazy/ts_backend/ops/to_copy.h,sha256=KAKshdejriC2X3nGYOVyzWt0dPEMG-_36SR8mgfcY6k,4204 +torch/include/torch/csrc/lazy/ts_backend/tensor_aten_ops.h,sha256=1ovuhgTENNLMnXJs_QBRUxeip6_9ky_0sfrPUWvDQYM,538 +torch/include/torch/csrc/lazy/ts_backend/ts_autograd_functions.h,sha256=8hSv5NZcjHP20ZXv3rMOA8Q9cpSFCMSxgMpcvb0xWKw,648 +torch/include/torch/csrc/lazy/ts_backend/ts_backend_impl.h,sha256=Yx1daoyyzWewtSOU2lrIvc6sh0N9hKsk07kHqFYtvo0,1292 +torch/include/torch/csrc/lazy/ts_backend/ts_eager_fallback.h,sha256=2owJOOBAT2CtF5gLNmjAm40injSFXuP2giwQmp1U5MU,717 +torch/include/torch/csrc/lazy/ts_backend/ts_lowering_context.h,sha256=kW3OS0HkfD_7LEcgTuli-cliwfO72ybciPBO91Z4taA,4658 +torch/include/torch/csrc/lazy/ts_backend/ts_node.h,sha256=UouGY47AKB-SesjpK_Dz39Gc8tIzYDraaVTpyMsDpQ0,3455 +torch/include/torch/csrc/lazy/ts_backend/ts_node_lowering.h,sha256=CXTtiMYBFQvHocnkCp49zqRkAIdQCmMHKUVvqMDvJhc,481 +torch/include/torch/csrc/monitor/counters.h,sha256=riXkkGNp1YXaKTVSeihwZ2TOM6iKQ-hL4dASv-EW_zo,8123 
+torch/include/torch/csrc/monitor/events.h,sha256=3UA_fQgH8o7aAxER5s1wdbcvJgJXYqFiz5SIOEWw_GQ,2763 +torch/include/torch/csrc/monitor/python_init.h,sha256=gx3t5q0mGdo2BKXCdljdSCw8_VsJwulFh1HnVGz9S3c,136 +torch/include/torch/csrc/mps/Module.h,sha256=iQvunKhwFtyNTGIEQ7YIdWhtFhEFwuYf36DXvM_ZuUI,183 +torch/include/torch/csrc/mtia/Module.h,sha256=5nB_fDypshZxGkqeRODtQNslobvR2x2fCDhg4So8gGI,188 +torch/include/torch/csrc/mtia/profiler/MTIAMemoryProfiler.h,sha256=p4fiXe0_nqrNoL1pDpusdEsE6qyTpxJsa_9q5JrxqRg,567 +torch/include/torch/csrc/multiprocessing/init.h,sha256=cjMqRD4hl9111s6FJDuRJcpvuvHLn3hr1dAjquDELMM,177 +torch/include/torch/csrc/onnx/back_compat.h,sha256=9qEjFwxGFF7OWUKRAwlP4pLw9mTDO1UmxwoG2kphFf8,1049 +torch/include/torch/csrc/onnx/init.h,sha256=KY5hHcYArCb0baUZBLMJMVbgtkyWYzWY0NwGUQeNpto,155 +torch/include/torch/csrc/onnx/onnx.h,sha256=V0LR_nLSo_seyDsfKaDTFkJLOXvib-d5ZJwVAZ3v5gg,525 +torch/include/torch/csrc/profiler/api.h,sha256=VDYmDghgFVbw674serVq3iULb_TDXeTebyydXpen-iQ,524 +torch/include/torch/csrc/profiler/collection.h,sha256=Gc7fv_X2ZbZi2TQvD_99knHFQeN-o8U_8W0OgOcXXWw,21458 +torch/include/torch/csrc/profiler/combined_traceback.h,sha256=5RsOp6KLPZVlukfEbf7tCiXD8K0iyia_NlKXX8neYOg,2533 +torch/include/torch/csrc/profiler/containers.h,sha256=svP9-cEDGKKH726PT6pWM-tY9BSKoX3IIq09LPWcoCs,6197 +torch/include/torch/csrc/profiler/data_flow.h,sha256=nkT-yS8N8confbvR4b3nvXlfzeRkvJLbhesLjJBPDeg,3716 +torch/include/torch/csrc/profiler/events.h,sha256=zc_yOwkIz3Ob-cF8VNpuz8DDOOM2gMORUdiEN42argg,1065 +torch/include/torch/csrc/profiler/kineto_client_interface.h,sha256=y1FiwkloZ5_rAflq5ULLBUAxarewrFjMXWfmgAaZgq0,257 +torch/include/torch/csrc/profiler/kineto_shim.h,sha256=bY-ErpxwbTFE1RzQ3_zy_hnYJ0OhnkZK51vgcdzHYl8,4063 +torch/include/torch/csrc/profiler/orchestration/observer.h,sha256=akWTaMGVAh8FplTFcWnS1gGkRF3GAPrgKt7D_hbuDHs,7005 +torch/include/torch/csrc/profiler/orchestration/python_tracer.h,sha256=FqjOD-aKmgKqHqVWjMQGtRM0EnyfwBm20r1Rl8k4D2U,2337 +torch/include/torch/csrc/profiler/orchestration/vulkan.h,sha256=ojZficMKdnSF1EIhOZGy8JJgtWZiD-gBmIxbfJJt6pI,798 +torch/include/torch/csrc/profiler/perf-inl.h,sha256=YfL6n00OhTVelUTLD82eFme9OGa0p74-D_AaNVVml4E,1472 +torch/include/torch/csrc/profiler/perf.h,sha256=tNW756FrJGzlUqPdqwZou-8StifyRzxjvpsukdu5lNU,2685 +torch/include/torch/csrc/profiler/python/combined_traceback.h,sha256=XGYBx7KzfcV09V2a2YggwBhFaC-UU717qDtLoVxw1MA,1019 +torch/include/torch/csrc/profiler/python/init.h,sha256=N0X-d4k_ttDTl8GOxydYVZOZNT9JAHzgPCeStcXIVI8,1029 +torch/include/torch/csrc/profiler/python/pybind.h,sha256=VCzcQYPzhqRJSd6b1PpxYEQJ6tzmJbleiQHzwS6W5Sg,1309 +torch/include/torch/csrc/profiler/standalone/execution_trace_observer.h,sha256=LblaS4ngs8iimSpcXMAaiNIzA8Qp8-PhQEsukVehXwY,649 +torch/include/torch/csrc/profiler/standalone/itt_observer.h,sha256=GDExohXxarK9-LosgQPEPiEZqo-18vpjs9Y60aBRuMQ,233 +torch/include/torch/csrc/profiler/standalone/nvtx_observer.h,sha256=UiHeamK3HzENh8pwzw6bAiAx4v3ql_fRAWE8odQ7JTo,234 +torch/include/torch/csrc/profiler/standalone/privateuse1_observer.h,sha256=CcwIKWKtn2DwLNSxQLAYH6-gpHfyX7G7dJovefpzF9Y,1491 +torch/include/torch/csrc/profiler/stubs/base.h,sha256=nK39hvhjJ5W92HQNRrYkVy3A2rVuBPMrpuhGtLdWHZw,1738 +torch/include/torch/csrc/profiler/unwind/action.h,sha256=NWFv3pnJ3lL_2ZuT_a0JIU77cEVlDIDEC9uk1IVeRjs,1483 +torch/include/torch/csrc/profiler/unwind/communicate.h,sha256=uVPh1LWiIQ78Jg41RHTvQoGnhppWefvkxHTPMO_o6VQ,2325 +torch/include/torch/csrc/profiler/unwind/debug_info.h,sha256=XTBWho4LxjapTCeWeGO26AJJgwJ8yxC19se2ob-I7l4,9380 
+torch/include/torch/csrc/profiler/unwind/dwarf_enums.h,sha256=peZd5vn4hi83PUium-y1YRxhkSw4FhJ8osbpr2vWX1o,1156 +torch/include/torch/csrc/profiler/unwind/dwarf_symbolize_enums.h,sha256=a9vyfzJchOzGSDXVD0FCY51UgRcDfzwscFcwjuiQQqA,4837 +torch/include/torch/csrc/profiler/unwind/eh_frame_hdr.h,sha256=jDVryyMGD9DeXsvCGk9GHKQyZUrz1l-e7RwBeFuRbaM,2780 +torch/include/torch/csrc/profiler/unwind/fast_symbolizer.h,sha256=t0D681TKYCGhL5_14xvGoMPNr2lADeUj4xGGYacNnO8,3411 +torch/include/torch/csrc/profiler/unwind/fde.h,sha256=_F3ei86PWfFnNGDNQBqn480qpEhwNolhS7xbxIojREs,12870 +torch/include/torch/csrc/profiler/unwind/lexer.h,sha256=ZeXtLTSiMu2hTPB8MW6GsBBLoGcdH58mfL7EPu85WJA,4040 +torch/include/torch/csrc/profiler/unwind/line_number_program.h,sha256=DwmZKOUcZFgC3nrkTRg4vGVhkD-C7mUT1qBH9WKX8n8,11153 +torch/include/torch/csrc/profiler/unwind/mem_file.h,sha256=96rzjpax9_eQhJhr16aY2oGwuC9iYqF16rlduwW3iXI,4748 +torch/include/torch/csrc/profiler/unwind/range_table.h,sha256=-MGJlxR7Ct7g11UXXMMqhzTBRzUavgohzK00VeUe11c,2182 +torch/include/torch/csrc/profiler/unwind/sections.h,sha256=YgYLGnzhB1LtD2j3z25rarSNDIpic108T9BXaUfie3k,3788 +torch/include/torch/csrc/profiler/unwind/unwind.h,sha256=kpqpoYm7bQcVzwAP9JOFAwzrlYK9qcPKt_0X4WtfXB4,1176 +torch/include/torch/csrc/profiler/unwind/unwind_error.h,sha256=7q1zifDqVVUdEx68yu96uDIZuhpGsQazqyj2KzEJu58,927 +torch/include/torch/csrc/profiler/unwind/unwinder.h,sha256=-eXlBJB2ehVWOdqtOguPy1I_LsE_8LQjovA5NpN8r6Y,2374 +torch/include/torch/csrc/profiler/util.h,sha256=Budh-zb_pqXXt0DYU5_WEfLn4Jd7-Ej1WnppNYgrEGM,7064 +torch/include/torch/csrc/python_dimname.h,sha256=c77Jb9MhopjkGsQarpNj3EwA_-HPdvWkyIg6WxjIHSY,221 +torch/include/torch/csrc/python_headers.h,sha256=WqhGrFwwlDZfSUaUPQD3kLuzuzzXWYPS0lWW6s8XG-g,674 +torch/include/torch/csrc/serialization.h,sha256=n4vPtf32pqv0OQqnBdBueb3Q96g5TzqLu1wtQaiZjVE,708 +torch/include/torch/csrc/stable/library.h,sha256=Ba-CPF2FdgRTcrrbGp-0hdF5EG4I5MlcdZDJvsylpFM,13921 +torch/include/torch/csrc/stable/tensor.h,sha256=olgjhxXQiRL8pWDSCIkJlC6mX0hsFvbWqjhOrBegz7M,4528 +torch/include/torch/csrc/tensor/python_tensor.h,sha256=cZ7B9YrJARCRQ1xlpPICDUr3LG00yGkuVDZRB-lELAs,1235 +torch/include/torch/csrc/utils.h,sha256=LDhvHm3LSFl6n1a_JPfCDSsjE6pJDJOaCJuX7m-hLfo,9308 +torch/include/torch/csrc/utils/byte_order.h,sha256=F2L2YHm7KTyxCu2UKFHBjvQuLuU2eCbZ_LP-R_fEpb4,2335 +torch/include/torch/csrc/utils/cpp_stacktraces.h,sha256=p1NI6h3a3IVSXYpCZ19AegaJ7QpFlEpL665-B5qgGvM,239 +torch/include/torch/csrc/utils/cuda_enabled.h,sha256=sdbqDaSxb_PpKkxb1k6oEoOntjSb1KSeduhudzAr2wE,183 +torch/include/torch/csrc/utils/device_lazy_init.h,sha256=yZoW0KrEssRzY_67g7R_DCI62McRo1P6pgUF11JkPWI,2901 +torch/include/torch/csrc/utils/disable_torch_function.h,sha256=9fxr63O2NxYu4sRU_rZKaAN8O8VbVG7rQQU5XY2zx3k,1913 +torch/include/torch/csrc/utils/generated_serialization_types.h,sha256=oL5ib7LY_FQJcagdCOeoa8cpsbgZU5gFkKFuTICMW3k,120609 +torch/include/torch/csrc/utils/init.h,sha256=WOmDTZNh1OxBs_DDnptt4q8N-gplV6-7O15qv7XczMc,202 +torch/include/torch/csrc/utils/invalid_arguments.h,sha256=KH-u_vRNZGchvpQ4GmmzN09i430y8C7Mu8FGZZfJ2wU,317 +torch/include/torch/csrc/utils/nested.h,sha256=v-jYzmYBZFOoVrwYszacsetZnfLptAnYKWoWcErOH_0,321 +torch/include/torch/csrc/utils/numpy_stub.h,sha256=x0yEqAooYgpFD1pUVnpCauGU6T3_tc_asQxyOU_tVp4,420 +torch/include/torch/csrc/utils/object_ptr.h,sha256=TewuaL3Kp-6hse_HX88itFCVHSoYAxjos48RSmHPUYg,2081 +torch/include/torch/csrc/utils/out_types.h,sha256=E1IeEC0JsOUEny8HRRotbSQtmLPD763f1se3hs21n_A,335 
+torch/include/torch/csrc/utils/pybind.h,sha256=xz2ZCrG6QkOIi66pcb8SfeGHg0vOmhvDJ4MGKs6MLfI,13430 +torch/include/torch/csrc/utils/pycfunction_helpers.h,sha256=2xLUpv3UGdJnmupyEqnBS0bnJ3gyOFmPDrijPMUwrbg,398 +torch/include/torch/csrc/utils/pyobject_preservation.h,sha256=WX-3hZ_WE7_VRWNKNIF63mNHZTDIOMJF9Zwev1lsApo,188 +torch/include/torch/csrc/utils/python_arg_parser.h,sha256=d8J04mrA4Hgp31V3Kuhkpe_6FBgH_KykFMKWYCqzuSw,42857 +torch/include/torch/csrc/utils/python_compat.h,sha256=3OXStf9tvgdXhkCo5PmSy5Vk56epD5XTUVEsTXXGo7Q,1223 +torch/include/torch/csrc/utils/python_dispatch.h,sha256=8wC6-4OmbtE0PcYvWp6evSH0ml6C5OJWyZaYtFFgcUA,413 +torch/include/torch/csrc/utils/python_numbers.h,sha256=vzVPyrATzdb9FIxSR0Hdpp4DIBS38O5-vrtctdxU7_U,5655 +torch/include/torch/csrc/utils/python_raii.h,sha256=zIKzo6cTPq-NwbHj_ybMP2t_Z82GaIbn2xPHHOHz_aM,2742 +torch/include/torch/csrc/utils/python_scalars.h,sha256=jdpsZUCupeg_NwsIrRxiW1b7yxZkfiOrWgJuB1UGciE,6309 +torch/include/torch/csrc/utils/python_strings.h,sha256=gUtNsJ4-ovhBz_u37Fxwl9-AFeaH0m3qAzvrFNUT568,4522 +torch/include/torch/csrc/utils/python_stub.h,sha256=VCeDRiGbtYhM3F7bhp90T2hFV0hOlI_g8iecAD84L9M,60 +torch/include/torch/csrc/utils/python_symnode.h,sha256=gmn9gtwsM4u8iC7GWGNKEab1dZlq-lES73QgO_X_kxc,10480 +torch/include/torch/csrc/utils/python_torch_function_mode.h,sha256=C3ss_MLJnMYBp1b1PqVHFB5niuuKdQKGlNCb-okieHo,869 +torch/include/torch/csrc/utils/python_tuples.h,sha256=QORmVT-FLZanLaABuBgWyZGLr1M-s1u2EdAeiV2yzYA,753 +torch/include/torch/csrc/utils/pythoncapi_compat.h,sha256=WIP_HjhbJeuk1Lj9_d2d7WjhCJScC46ZSY44JD137J8,42399 +torch/include/torch/csrc/utils/schema_info.h,sha256=gGgvIB9UUxusvUKiU2FGB1NZol6QX_CMTnDUUvwIDRA,3905 +torch/include/torch/csrc/utils/six.h,sha256=Imt5y75tGjdt3hxxB1-6Cr0GEB7WOXIuE5zAz7xWh8k,1541 +torch/include/torch/csrc/utils/structseq.h,sha256=noma5ebHQ5PI7XzoQb9PXJqPLDg_32BnNH3tpuZVPuA,150 +torch/include/torch/csrc/utils/tensor_apply.h,sha256=Y5cJu4HIc1y89OO3ZLh5GgBHBDpmQaD2sp3TcHI2KOM,447 +torch/include/torch/csrc/utils/tensor_dtypes.h,sha256=hRHSOd8LhSL4eXnGxFabHBig4MJdM24wvivOfhWth3w,255 +torch/include/torch/csrc/utils/tensor_flatten.h,sha256=hFF35zw2mnoU6y3ZmnHJ00PYOk3ve-EFuICPGvS1-HA,2829 +torch/include/torch/csrc/utils/tensor_layouts.h,sha256=qKI2EwdGjauNW01ATiYkdTd5ZuiODzkpC1h4czWhpDg,76 +torch/include/torch/csrc/utils/tensor_list.h,sha256=ajfcQNgYXWV-gghEwYEMvFcol_KX7JASy3CNdArs1xM,180 +torch/include/torch/csrc/utils/tensor_memoryformats.h,sha256=-8M0flDxa5ta1qGQdbRsTOLHJvtV2O0UU6t1h6NcIwI,337 +torch/include/torch/csrc/utils/tensor_new.h,sha256=J-1Teo4J7YFL0_JQkDmaHQvbarWwPgYwYfq9kKeht8k,4167 +torch/include/torch/csrc/utils/tensor_numpy.h,sha256=PwjkRySXdzKqbei0TQ_z_a1DfY6424XNVHvvo47Q-NE,842 +torch/include/torch/csrc/utils/tensor_qschemes.h,sha256=4ZDdJyV17S6A_sHrflkKHQQDLts00eMnHTbYtRYO0U4,183 +torch/include/torch/csrc/utils/tensor_types.h,sha256=rpPVlyeZgmfiZPIgAUIv5ORjMEagMhGxzKBD1Lin1mM,689 +torch/include/torch/csrc/utils/throughput_benchmark-inl.h,sha256=4YzEYGdbWWRrALuXXscWXpAsTdXJgHA0pFDEQSbGkF4,6507 +torch/include/torch/csrc/utils/throughput_benchmark.h,sha256=qFhVNso3G0zSYn52KFzc3pDHxZGNHqM0uF5Y6H--ixw,7136 +torch/include/torch/csrc/utils/torch_dispatch_mode.h,sha256=QcfjyWuX4LUF8qPrc1JCL9MbdmrmstRnBlqQodiu5SA,2343 +torch/include/torch/csrc/utils/variadic.h,sha256=S8NhLDb4XN6o-BoT-NwpgbbDOLRy33vXU4GH8m6-CX0,3460 +torch/include/torch/csrc/utils/verbose.h,sha256=AfAfwZ-pRGuwdRhbG8WIIuFssaG8vFw0TD5QZ16dr4I,146 +torch/include/torch/csrc/xpu/Event.h,sha256=0scFKbrXeW87Y--wFjem-g56LpnH_cj2SHYRnJN2TV0,385 
+torch/include/torch/csrc/xpu/Module.h,sha256=7jE8gXvVL5gaPZpXNW1gSXrIRbN2AxVzetzhDPEAxSg,187 +torch/include/torch/csrc/xpu/Stream.h,sha256=nrZqsE-nbQYpb8xlspR2WY5j-AL_SJ4-K4q_ErjJbnA,454 +torch/include/torch/custom_class.h,sha256=sd5Xh2wywlph-3SpkcJX6CIDnFhephTE6bGXDCRPz7Y,20325 +torch/include/torch/custom_class_detail.h,sha256=31p77AqYexYtAkIiESKx_wUae7g_wXnpxZEoknlrJt8,8067 +torch/include/torch/extension.h,sha256=uEVLy1F9rBRiArbwg_TN-X4I7nXCoaRD9g0DyL8b_MQ,222 +torch/include/torch/headeronly/macros/Export.h,sha256=APL9s0nR-cv6mZHVOxEwVX1lYV0JNOR_2omGN1GPYws,3494 +torch/include/torch/library.h,sha256=iHd-NSkPoMxthZBWaDt4-irNgvL_5eGrThLlRPXI3fU,43286 +torch/include/torch/script.h,sha256=GndQ6oDpXBlnYAQvwUDhZ8p9olslwNmw89SfaUumvvM,482 +torch/include/xnnpack.h,sha256=o4HTOrGgf4XKoNzdxrXzclw3Tfvvl2QkEDW_ZnRlueI,205858 +torch/jit/__init__.py,sha256=K455Yco96mzdfUFBduop1nuUdzVBaXVSnhWP0xeT1Sg,8639 +torch/jit/__pycache__/__init__.cpython-39.pyc,, +torch/jit/__pycache__/_async.cpython-39.pyc,, +torch/jit/__pycache__/_await.cpython-39.pyc,, +torch/jit/__pycache__/_builtins.cpython-39.pyc,, +torch/jit/__pycache__/_check.cpython-39.pyc,, +torch/jit/__pycache__/_dataclass_impls.cpython-39.pyc,, +torch/jit/__pycache__/_decomposition_utils.cpython-39.pyc,, +torch/jit/__pycache__/_decompositions.cpython-39.pyc,, +torch/jit/__pycache__/_freeze.cpython-39.pyc,, +torch/jit/__pycache__/_fuser.cpython-39.pyc,, +torch/jit/__pycache__/_ir_utils.cpython-39.pyc,, +torch/jit/__pycache__/_logging.cpython-39.pyc,, +torch/jit/__pycache__/_monkeytype_config.cpython-39.pyc,, +torch/jit/__pycache__/_pickle.cpython-39.pyc,, +torch/jit/__pycache__/_recursive.cpython-39.pyc,, +torch/jit/__pycache__/_script.cpython-39.pyc,, +torch/jit/__pycache__/_serialization.cpython-39.pyc,, +torch/jit/__pycache__/_shape_functions.cpython-39.pyc,, +torch/jit/__pycache__/_state.cpython-39.pyc,, +torch/jit/__pycache__/_trace.cpython-39.pyc,, +torch/jit/__pycache__/annotations.cpython-39.pyc,, +torch/jit/__pycache__/frontend.cpython-39.pyc,, +torch/jit/__pycache__/generate_bytecode.cpython-39.pyc,, +torch/jit/__pycache__/quantized.cpython-39.pyc,, +torch/jit/__pycache__/supported_ops.cpython-39.pyc,, +torch/jit/__pycache__/unsupported_tensor_ops.cpython-39.pyc,, +torch/jit/_async.py,sha256=kXbyVaJOm080a82emXapZ8_kqW0yrPmQDLZ2fPrla4A,3942 +torch/jit/_await.py,sha256=_p5IAg33hSyJM2NnapAfT0UloKLfePWzRKF2R78zCOI,879 +torch/jit/_builtins.py,sha256=bm3jOgt8NURNUKgepCBjCu6rqcB5nHfhp6-BAtGAmCU,6995 +torch/jit/_check.py,sha256=ah21m0MOnMaCSMz8T7DZNsimP8CMBzkfWLW6dhccAj8,9635 +torch/jit/_dataclass_impls.py,sha256=4xZuB--tEbeKzk2jlVdedBmgohlj_2XQT18o_wU7ezM,6862 +torch/jit/_decomposition_utils.py,sha256=4ODk6U8gz7PWPu4GjxYmVB5Vi_bS04vNgHtcM13N6ao,416 +torch/jit/_decompositions.py,sha256=lEOLhEGDVsdn1NcAuco9RV5ujeavRyPZ6643bnORPzs,4511 +torch/jit/_freeze.py,sha256=NGxGJ2k8QQqCkn5bXVDOvIEVBSYVzFrNcBZJTeu8gcQ,9737 +torch/jit/_fuser.py,sha256=PrBTuJvsbIA7U3DkVMd9G1tXbUtBq3_gKUvUtmkNgqQ,7243 +torch/jit/_ir_utils.py,sha256=rDtTCH29vQtvrFSLTD4z0Dk5sEhxnl1-rS3laLVkLs4,919 +torch/jit/_logging.py,sha256=08M1dFZ7pNDPG0CK8U3KRGsDjDgGCLEqQlJDQJg8JzY,268 +torch/jit/_monkeytype_config.py,sha256=N8bensIEYo-Aahr34uPdq-Mz0znDeB4v21wa2bFU5QE,7483 +torch/jit/_passes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/jit/_passes/__pycache__/__init__.cpython-39.pyc,, +torch/jit/_passes/__pycache__/_property_propagation.cpython-39.pyc,, +torch/jit/_passes/_property_propagation.py,sha256=Y94CvweQZaPce4TrySM4uGEkuvxLKXmxt23MnLPfB1Y,1484 
+torch/jit/_pickle.py,sha256=spTmVnJR8c_mLk4sLGyNMEd9nB0XfmdH5eHaNmTPccg,1208 +torch/jit/_recursive.py,sha256=mQs_g92M54sa9mUJoqGVpj0r-oS8QgeSDReG__ziDL8,43436 +torch/jit/_script.py,sha256=NinAmr4yMcalRSuuLcA8KKiHSuggS9Opa6IjJa3q6PE,67043 +torch/jit/_script.pyi,sha256=01y2aZWtiXJ9Oz4Ojmbxe0w8xWfAkqPHloh72BRYjJI,9742 +torch/jit/_serialization.py,sha256=aTkf58xNJ5T9KGQfFQga20EFJmJwyZyHrQL64lhSl-s,9967 +torch/jit/_shape_functions.py,sha256=svE6tK_XQqmja_AFQvnHrKS8mH7WptfNSI2FOfIP9vo,46743 +torch/jit/_state.py,sha256=wwV1GVB92LiD_I_CO0bK6e13IUN8ZWnyVqsnwYFMPj8,3887 +torch/jit/_trace.py,sha256=b3aTC-PYpxDSSMslrDeLflT3ZlHo9voXiMF_vkCEd7A,59884 +torch/jit/annotations.py,sha256=VTWLd0XIhQQ6oTx_7Z5tSM0tt_r6mQoAZNAW5f9nh1k,18435 +torch/jit/frontend.py,sha256=NxEDB4PvMerS8xOdf8dtdV2XOIuqcOzMMnUROLIOm6E,46521 +torch/jit/generate_bytecode.py,sha256=VW9-Kk1a7DkpbfBc2BjpFUNe0fz-iSAVLxmYcNWmHXo,1075 +torch/jit/mobile/__init__.py,sha256=4p-BxWnFFtjNR26xqiTCSPR4PfEw49pGs_mE6zb8Qus,8749 +torch/jit/mobile/__pycache__/__init__.cpython-39.pyc,, +torch/jit/quantized.py,sha256=BzEDKVDO7Ehr-Ngjzx6wBHAvDqzCiqxWv8ztINiuE8Q,3293 +torch/jit/supported_ops.py,sha256=BVQjM7_HLl1pXdRAR3xDP3F3uTcVaDvfpHbSu1HeGgk,10611 +torch/jit/unsupported_tensor_ops.py,sha256=N9PJhDgLXV6GjzlKz5cbPyMNjh72AVzBOiaPl8rLU4E,2065 +torch/lib/XNNPACK.lib,sha256=O_XJj2lPRYf1oZFznqjdVloGloKESIKOdJG5yMpdb-I,14049460 +torch/lib/_C.lib,sha256=LCDxkYo2aMyQK90CeAI-D4D9c1y78JsxA0gwox3Z-Go,1908 +torch/lib/asmjit.dll,sha256=YBlDMmQEb7VbJ1sPgD1wbPcFrzrAbql_CHsKpJqs3d0,367104 +torch/lib/asmjit.lib,sha256=sKb4ZZ1BrJlukXDOJg0eBlkqQm2rncDsMH6QhJjEdOA,140044 +torch/lib/c10.dll,sha256=zO23pwETaf3l_qeyRTKE150cmG9HyU6uOke0HOE1j9E,1046528 +torch/lib/c10.lib,sha256=O5ZT9l9ZwJkVqkJfsevJq79fH1JLlOmOYHwjARg3JLo,778152 +torch/lib/cpuinfo.lib,sha256=O9j__qH08novgb2zpW2wicRtrLXnncTJIfP5Y79ru4Q,601196 +torch/lib/dnnl.lib,sha256=UMVNtiI3T67VU-JpinkkskRBTJrcR5GxuQ3GYwD1Jdo,694253678 +torch/lib/fbgemm.dll,sha256=3HRFT930Be00P9jpXi0ojTdDl8LWsGWy52IDeNJzPsU,5721600 +torch/lib/fbgemm.lib,sha256=tu4rBCEkHfK6dN53cI5SGzptcTJ0dJPTtNAq5WYq-A8,1447968 +torch/lib/fmt.lib,sha256=8D7XGIb0t6VMGKULlAknBPaGv6vx6TximcJXZgnd4yw,3382620 +torch/lib/kineto.lib,sha256=3x2GXEUu_DX5uvISuvc4AYxGkUIoJnGNGNCfp1AsRv0,57893210 +torch/lib/libiomp5md.dll,sha256=2dZu0l8aDqcl-jpBsiz9XRgsGdvkdx2ckMoCrXRm9qE,1613680 +torch/lib/libiompstubs5md.dll,sha256=2us55M7Tv85KYitjuLAoFhPrgYAmb43R3GeDRt8cdpc,43888 +torch/lib/libittnotify.lib,sha256=t-nkhV_Y-vodjIaTZHE604qT27CWEUcKLWsE5B2-X28,591548 +torch/lib/libprotobuf-lite.lib,sha256=Cq5ZKZtulqRQI3-hEqSwNFA496Qsj18JQS-P7SVAdQk,4661164 +torch/lib/libprotobuf.lib,sha256=qmc80E_vmwpdRTd2lGmZ_zvoCJKWs8mw7fZs3UxngZI,32694610 +torch/lib/libprotoc.lib,sha256=Q5RPCISbhvkXXQerniks6RjGKYShOBgXQEbZKSn_R5Q,38476918 +torch/lib/libshm/alloc_info.h,sha256=ZFoLb91AQm4_iK-eFNlbI2n2tdjVtW5KCggYCGVnraM,113 +torch/lib/libshm/err.h,sha256=s_jsA6J6KTGy6ESHZPE9_S9NhVfC-kfotSi0m5PXeIw,1283 +torch/lib/libshm/libshm.h,sha256=-QjWfGIwq73UdOiBFl6_OSn9ODb49XQANpUbZm81yms,1242 +torch/lib/libshm/socket.h,sha256=3qpi-8dZkEkcD9_1LiT0N5nOe7AAhMuRKAu4O0-Johw,4571 +torch/lib/libshm_windows/libshm.h,sha256=TNoEjSFY6vmf9pi1lm8F0qLk_SthMxLDGGTrdKeog7E,815 +torch/lib/microkernels-prod.lib,sha256=1A5RvSuyX2B1j84gDS7qeLul65bpdxvsTYA-xnhuwR0,20267954 +torch/lib/pthreadpool.lib,sha256=iOLaPl55BHg1oh0AlwfqhVty2dx45G8XmIakHwdfdcg,768704 +torch/lib/shm.dll,sha256=wJF2JU_KV2BhCIxEUsi-KV20XJVPJfPF8V4CzuDULvQ,14848 +torch/lib/shm.lib,sha256=iPB__UNwNooisDQLSEnB_XOwM6DumP4VFs4N4Q8ozDQ,3790 
+torch/lib/sleef.lib,sha256=EWTlzuGpvUYhcEIhsV-KJOoVNrxc8mdXnDqgOX1Tm7w,8776772 +torch/lib/torch.dll,sha256=-C4EJiL3TVsSYSQIildfccp04b5JKxZQ2f-cmgcgWrE,9728 +torch/lib/torch.lib,sha256=ox5HCNJyfz_YG-aVf-j6YVhxnToDoiyzjJ-hrWh2bao,1832 +torch/lib/torch_cpu.dll,sha256=ZxRX06Ytes_3wox8oc2d0fRrtsaogZI2WejmFA1nuho,255934976 +torch/lib/torch_cpu.lib,sha256=CLgTkxkaxH6_Y-kqrYtl7OiQ2G3VHrHnKU8b4-SW89c,29046564 +torch/lib/torch_global_deps.dll,sha256=qHrd1B2XLHHvQG3idFhoEszX_CU_pqlQSFS9daBkOlw,9728 +torch/lib/torch_python.dll,sha256=LNKJfRY_ECk0Hzrmf8zNX09P57TWLsroynZxKOMUD3M,16310272 +torch/lib/torch_python.lib,sha256=iRAXjrsUF1tci5zNJ7OLI2Apc5knO23TMSvMczt3lSk,287836 +torch/lib/uv.dll,sha256=-laeaC_F-3qOuUxoKa-fMKVpdI27xrzjlzXUi8lgvPg,195072 +torch/library.py,sha256=r_vrBF12krhKHXjFbO3bqiJNwK7HrNieZHLWjwncHDU,64269 +torch/linalg/__init__.py,sha256=5L3It2zGIUH3DA9d8AdV2wb-WQIU6Vj2m_gLCKh9jE0,117887 +torch/linalg/__pycache__/__init__.cpython-39.pyc,, +torch/masked/__init__.py,sha256=e18vUUI9fNnG5sZqFPBMrbFlXzoPIdM8Fxc--5FuJqk,985 +torch/masked/__pycache__/__init__.cpython-39.pyc,, +torch/masked/__pycache__/_docs.cpython-39.pyc,, +torch/masked/__pycache__/_ops.cpython-39.pyc,, +torch/masked/_docs.py,sha256=OmczevR9oCzOcHbugawMpdGV8skR2b55-8VrHHYyQuo,50645 +torch/masked/_ops.py,sha256=j-v0IVvb1ZmG1tdDk_s0DOU3yPb6enFud2PrQNVF3KQ,68013 +torch/masked/maskedtensor/__init__.py,sha256=DxojsToGqewZrb_P__TjYp8ZAcQuuC9kcj-6ZcKsduw,351 +torch/masked/maskedtensor/__pycache__/__init__.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/_ops_refs.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/binary.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/core.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/creation.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/passthrough.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/reductions.cpython-39.pyc,, +torch/masked/maskedtensor/__pycache__/unary.cpython-39.pyc,, +torch/masked/maskedtensor/_ops_refs.py,sha256=xXLqfsv8mDpnft8DVv0rKahc7LQp2iv4qgTtuu-7ArI,18087 +torch/masked/maskedtensor/binary.py,sha256=QYnUB1ot3DIHhXpeXS1weJKaYjHY3P4FqtnyjsQVOJI,5696 +torch/masked/maskedtensor/core.py,sha256=Q_GYALbzGX3KnFsXHnXn4tjcYIDZdE5ZRJxDjsBsvHI,13181 +torch/masked/maskedtensor/creation.py,sha256=GnNKnASrA2uOjSw4LDrC_dqxnl8HHEqZE12FUQnncGg,629 +torch/masked/maskedtensor/passthrough.py,sha256=jih-ZjwTDgw8rdb3qyFP0hkpvwaLoukfdG459X-RB-w,1497 +torch/masked/maskedtensor/reductions.py,sha256=rDnrTaohpk8qsPxA9gdqrrnyS1OFNJCFRj8Dld25R_o,5761 +torch/masked/maskedtensor/unary.py,sha256=_TGZ2JZ484PpXADd2in0QK4vemViMUANzcIK_pZKkqc,4391 +torch/monitor/__init__.py,sha256=fG_MJlz599HVJWQWjjMs3rNIV_AKVklcWIU2393NJ7g,1317 +torch/monitor/__pycache__/__init__.cpython-39.pyc,, +torch/mps/__init__.py,sha256=kKjTR8XFDZVS96Fhr5YqpEPcyM-9-xjCz7DMljcPaTk,6475 +torch/mps/__pycache__/__init__.cpython-39.pyc,, +torch/mps/__pycache__/event.cpython-39.pyc,, +torch/mps/__pycache__/profiler.cpython-39.pyc,, +torch/mps/event.py,sha256=oHh5ChjsP_L3WaJOs_KsUEv2-g4B1-midtwXdnpUX6Y,1775 +torch/mps/profiler.py,sha256=_AI_Aa0SjH9LfFcpuzX5KsBl2KL63TVzNmpurEomifc,3370 +torch/mtia/__init__.py,sha256=2X6E1sPRrII1TTm5ArPmt3OmhmTu7oNitOIXiCkswMg,13860 +torch/mtia/__pycache__/__init__.cpython-39.pyc,, +torch/mtia/__pycache__/_utils.cpython-39.pyc,, +torch/mtia/__pycache__/memory.cpython-39.pyc,, +torch/mtia/_utils.py,sha256=GLEJ1vMhlwA4adZskeiHSVOdzy0kGgdDpeoIAg7Qt-k,1635 +torch/mtia/memory.py,sha256=C1f9N4PpjqY3rq8UP2NxZ60o_jtoKOtSgxVkHvwHMH0,1814 
+torch/multiprocessing/__init__.py,sha256=IR-LR_qNSZH1R8B22NDpJeX0_Z6puIWgIZyRWnZSB3U,3578 +torch/multiprocessing/__pycache__/__init__.cpython-39.pyc,, +torch/multiprocessing/__pycache__/_atfork.cpython-39.pyc,, +torch/multiprocessing/__pycache__/pool.cpython-39.pyc,, +torch/multiprocessing/__pycache__/queue.cpython-39.pyc,, +torch/multiprocessing/__pycache__/reductions.cpython-39.pyc,, +torch/multiprocessing/__pycache__/spawn.cpython-39.pyc,, +torch/multiprocessing/_atfork.py,sha256=HM7eWPbd712wegzQ70K_oD4_tvrcfyJDCZ1cc5lAXnE,825 +torch/multiprocessing/pool.py,sha256=qI_YbMLcImtdz3uIDq7hCRBhAdZkFXrDlA1f6Yp2ivk,1795 +torch/multiprocessing/queue.py,sha256=6wenEUd1uGcL6Afud00FdBNaTZ-8LaL2n_vNdrUEtb8,1520 +torch/multiprocessing/reductions.py,sha256=OWe0sf30KUR0I_zyWpwktsYDX5TiU4eczsyzBo90BP0,23787 +torch/multiprocessing/spawn.py,sha256=CTwJTMoyvH4e-E1t8COOV8r1HCkdsRa3qHEnIwjjDoE,13144 +torch/nested/__init__.py,sha256=BCCupu5if-tfcmIohcnR802qgxOq_DXnxObazTzVXyI,22430 +torch/nested/__pycache__/__init__.cpython-39.pyc,, +torch/nested/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/nested/_internal/__pycache__/__init__.cpython-39.pyc,, +torch/nested/_internal/__pycache__/nested_int.cpython-39.pyc,, +torch/nested/_internal/__pycache__/nested_tensor.cpython-39.pyc,, +torch/nested/_internal/__pycache__/ops.cpython-39.pyc,, +torch/nested/_internal/__pycache__/sdpa.cpython-39.pyc,, +torch/nested/_internal/nested_int.py,sha256=_de6OwdoMjMwYTQJODH2-SMPTMPrey0X83NuUDlEDXY,3306 +torch/nested/_internal/nested_tensor.py,sha256=R4yayeM-KaCxiK-hGtY7HTiKjWNZ11ypbXH70Vi8oTE,25400 +torch/nested/_internal/ops.py,sha256=oCxIw2voV2C8313jFJhmZVT_vcsqojQHZg7S16myfck,100365 +torch/nested/_internal/sdpa.py,sha256=1nFTSA6FkPXmAoNd-DZtaCvpq1LZkr3KBEHgisHAu10,35418 +torch/nn/__init__.py,sha256=QleWrbVGUa1f4u5EE3FGCN7GxWq5Fo65er1doQEbmCU,2487 +torch/nn/__pycache__/__init__.cpython-39.pyc,, +torch/nn/__pycache__/_reduction.cpython-39.pyc,, +torch/nn/__pycache__/common_types.cpython-39.pyc,, +torch/nn/__pycache__/cpp.cpython-39.pyc,, +torch/nn/__pycache__/functional.cpython-39.pyc,, +torch/nn/__pycache__/grad.cpython-39.pyc,, +torch/nn/__pycache__/init.cpython-39.pyc,, +torch/nn/__pycache__/parameter.cpython-39.pyc,, +torch/nn/_reduction.py,sha256=t6PgLjrweCbwTcC6udIJ0XU8QlHFeXIGLDXm13rkdz4,1685 +torch/nn/attention/__init__.py,sha256=Mf7mnBcWnX6HY46bTW8eNQLNPmfzN53KePUy1jiBj0Q,6078 +torch/nn/attention/__pycache__/__init__.cpython-39.pyc,, +torch/nn/attention/__pycache__/_utils.cpython-39.pyc,, +torch/nn/attention/__pycache__/bias.cpython-39.pyc,, +torch/nn/attention/__pycache__/flex_attention.cpython-39.pyc,, +torch/nn/attention/_utils.py,sha256=7Ajx_rNjh2xXYTHCYwM-qNar0ItU3Y0gSCBtdZWUUNk,2099 +torch/nn/attention/bias.py,sha256=XBUMumd1o0J0EoMWbFjHlNxmW-tZr495SZvZnLYbBRY,13792 +torch/nn/attention/experimental/__init__.py,sha256=w99ul3HlG3CaoBcG3SJ9r3N10hPNAqirDfntLsa3HFs,112 +torch/nn/attention/experimental/__pycache__/__init__.cpython-39.pyc,, +torch/nn/attention/experimental/__pycache__/_paged_attention.cpython-39.pyc,, +torch/nn/attention/experimental/_paged_attention.py,sha256=ckPrCWZVLffqBVBlZN61gYy3v3pPa-WEIFjN43aTHf0,12373 +torch/nn/attention/flex_attention.py,sha256=iZ8IMGmZkI2ZFFkhpvQt1Ul5atvqUeKvDuJtnvjzfmM,59668 +torch/nn/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/nn/backends/__pycache__/__init__.cpython-39.pyc,, +torch/nn/backends/__pycache__/thnn.cpython-39.pyc,, 
+torch/nn/backends/thnn.py,sha256=m5HpIz0NfjF-4O5ayPXPi1PFdPEkFvihp4ryASstGDs,152 +torch/nn/common_types.py,sha256=TvKlQWiBKifYXQdxeXUZEsKsSYlkbdyc9A2r3L2uhYg,2233 +torch/nn/cpp.py,sha256=Lls_nq7pKdoQFE9hX6YyhTHi4-5Rgh-VNt9m4QE6-MA,3106 +torch/nn/functional.py,sha256=_BPgXwS5VpZNzffduJ1M1VYMnnkUkvi1L8ToCpBSz38,250349 +torch/nn/functional.pyi,sha256=jF3a7FG5KyfR6j6zdFEeQMTIhUXCyciCvMHIX8f53a4,27825 +torch/nn/grad.py,sha256=MG54cZn6HW7CVLhDZHKqmiOj1WQqhtk3Oo5bFxeudDw,10208 +torch/nn/init.py,sha256=rWCSKrpnLqq0E4l0qdbQ22HTSxalOSrjJqUPwU1VUwA,26511 +torch/nn/intrinsic/__init__.py,sha256=4y35Rm3epwUxcsIUr4kYd3j-T9DPshjMUAvHizLDHqs,733 +torch/nn/intrinsic/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/modules/__init__.py,sha256=UBTOY6kdnyOWRjpGh6DSYu10sWxK7G-XvGFvSRQSiik,550 +torch/nn/intrinsic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/modules/__pycache__/fused.cpython-39.pyc,, +torch/nn/intrinsic/modules/fused.py,sha256=340yk-lF--9pZSnoax5nO4lDvMGs4P2ioIhAgchFZF4,596 +torch/nn/intrinsic/qat/__init__.py,sha256=p-5evJa2FgDM8QCPqTuBGITlflcXw7Yg0RCyg937f9w,60 +torch/nn/intrinsic/qat/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/qat/modules/__init__.py,sha256=0DDjW4qnb6yRV6vAPx0A8gYwVJMamrLF2RrRNDn2by4,669 +torch/nn/intrinsic/qat/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/qat/modules/__pycache__/conv_fused.cpython-39.pyc,, +torch/nn/intrinsic/qat/modules/__pycache__/linear_fused.cpython-39.pyc,, +torch/nn/intrinsic/qat/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/nn/intrinsic/qat/modules/conv_fused.py,sha256=YtvXTOgobESVF0CmsvszC_ZunZf5ybHY7_skqd34Pxg,894 +torch/nn/intrinsic/qat/modules/linear_fused.py,sha256=frdmuGPzasCk3Hmy75kIBmTgFZwMejSxAvt-_M2sWRU,471 +torch/nn/intrinsic/qat/modules/linear_relu.py,sha256=wYa_ChQtouRtRGxQ0xokIfjJyT2MeWBeC5Iak2qWSoI,471 +torch/nn/intrinsic/quantized/__init__.py,sha256=TzvJoBi6HMEyZWBEcCh6DLsmodGy3-6S2ClPPi4Q3No,350 +torch/nn/intrinsic/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/quantized/dynamic/__init__.py,sha256=9zpAuvn27ttR3rNI37aoR0Ml_bqCA1_WmyZdvlO5RdY,74 +torch/nn/intrinsic/quantized/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/quantized/dynamic/modules/__init__.py,sha256=f3fbpMxfHM5-gmdXbJT6HTJ1XBN6BgLXm_9jj-csmzI,120 +torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/quantized/dynamic/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/nn/intrinsic/quantized/dynamic/modules/linear_relu.py,sha256=tgaDRubiJ9y1u1USs7UUDyQ0l9mB_WzQOPDIxPka7ho,103 +torch/nn/intrinsic/quantized/modules/__init__.py,sha256=7epBP6GiGh24LUh-kuuuKkEZeubs2_mI3s8AkWD5szs,396 +torch/nn/intrinsic/quantized/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/intrinsic/quantized/modules/__pycache__/bn_relu.cpython-39.pyc,, +torch/nn/intrinsic/quantized/modules/__pycache__/conv_relu.cpython-39.pyc,, +torch/nn/intrinsic/quantized/modules/__pycache__/linear_relu.cpython-39.pyc,, +torch/nn/intrinsic/quantized/modules/bn_relu.py,sha256=mexLG77ELYLoVfXlSKcigmjc0NiePBNZwClR2mpDFFE,118 +torch/nn/intrinsic/quantized/modules/conv_relu.py,sha256=7isVCn-xK1Gmf8gCoGPC8iwuIRm6PQntHT9s0qufmHw,157 +torch/nn/intrinsic/quantized/modules/linear_relu.py,sha256=bFM-bT1HI7uu3zzCFPo-_KYfxO6Lr3-VPid5_EwHMpo,95 +torch/nn/modules/__init__.py,sha256=8sSanMhe5aUOD61p2kMhLvd7fte-aBIJHI_8ELsbdgQ,6828 +torch/nn/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/modules/__pycache__/_functions.cpython-39.pyc,, 
+torch/nn/modules/__pycache__/activation.cpython-39.pyc,, +torch/nn/modules/__pycache__/adaptive.cpython-39.pyc,, +torch/nn/modules/__pycache__/batchnorm.cpython-39.pyc,, +torch/nn/modules/__pycache__/channelshuffle.cpython-39.pyc,, +torch/nn/modules/__pycache__/container.cpython-39.pyc,, +torch/nn/modules/__pycache__/conv.cpython-39.pyc,, +torch/nn/modules/__pycache__/distance.cpython-39.pyc,, +torch/nn/modules/__pycache__/dropout.cpython-39.pyc,, +torch/nn/modules/__pycache__/flatten.cpython-39.pyc,, +torch/nn/modules/__pycache__/fold.cpython-39.pyc,, +torch/nn/modules/__pycache__/instancenorm.cpython-39.pyc,, +torch/nn/modules/__pycache__/lazy.cpython-39.pyc,, +torch/nn/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/modules/__pycache__/loss.cpython-39.pyc,, +torch/nn/modules/__pycache__/module.cpython-39.pyc,, +torch/nn/modules/__pycache__/normalization.cpython-39.pyc,, +torch/nn/modules/__pycache__/padding.cpython-39.pyc,, +torch/nn/modules/__pycache__/pixelshuffle.cpython-39.pyc,, +torch/nn/modules/__pycache__/pooling.cpython-39.pyc,, +torch/nn/modules/__pycache__/rnn.cpython-39.pyc,, +torch/nn/modules/__pycache__/sparse.cpython-39.pyc,, +torch/nn/modules/__pycache__/transformer.cpython-39.pyc,, +torch/nn/modules/__pycache__/upsampling.cpython-39.pyc,, +torch/nn/modules/__pycache__/utils.cpython-39.pyc,, +torch/nn/modules/_functions.py,sha256=XuIEIBkqH1BctNBRAdudLL3xcSPzgBqBHS2pBNz_y5s,12310 +torch/nn/modules/activation.py,sha256=iIhJQE8zDPiijlOqF2btSOp96vL2lKKrpZMYPeVM8as,60163 +torch/nn/modules/adaptive.py,sha256=EFV6AHI8h1MXiKaPx3ucxHK4mSaMzfKt0Mn3YLvj68c,12767 +torch/nn/modules/batchnorm.py,sha256=iQvqK0pGLeQJrv6eQZm9yP-UXRRXeVos8TpOSZc72Ls,39278 +torch/nn/modules/channelshuffle.py,sha256=TNNIWJuBR_ZFfE523qd5TE4rehfntN1R0fGHHm8m5kc,1604 +torch/nn/modules/container.py,sha256=mdlrurmukleB74b1-IzlWpFMMWT38gnI4WEsQuL_ywM,37970 +torch/nn/modules/conv.py,sha256=6qfIsY7OtHO3GTk83HgHvvrdvudxkAN9xZ-dy7SAyQQ,78026 +torch/nn/modules/distance.py,sha256=e2WyoltAHkHKKIUQmTOlsY1qmKJ2Cnitiu5weRZctFE,3356 +torch/nn/modules/dropout.py,sha256=S26ciN0Mr8rEUU9xz2m-TptVR5w5t2gbWWGEEQxRMuc,11492 +torch/nn/modules/flatten.py,sha256=C2eFKxm_GdQu1TuiP1T6eoGrtg2Jeo5A3fcKACoYQYg,5695 +torch/nn/modules/fold.py,sha256=vsiNxfT1eJ5xbQcf4nGwEdmwlDXvoFhTfo4NpaWLwTI,13282 +torch/nn/modules/instancenorm.py,sha256=VyQ9z1O5biu82G6mj2xJZsWHszddbqBmjNhDqm9R-B8,20790 +torch/nn/modules/lazy.py,sha256=o3kmdNIyZdaKVUgCFEGOeIURPKiJnhCt38hMj4WcwSI,11887 +torch/nn/modules/linear.py,sha256=WLjOHyooif-sH3gXUi2u9GZQXs1RI9F5pdhsG3Wv_3A,11571 +torch/nn/modules/loss.py,sha256=GQojkTVXQJ4jp4PFb7iwtspUBtDNzrw6jGvJ-vHyDhk,95866 +torch/nn/modules/module.py,sha256=zuYNpjrkzwrrFrI8uE-tPZjuzIVUkYGcuLOABYGgS9Q,129809 +torch/nn/modules/normalization.py,sha256=rQwQgksrOjEgr0pen_Vw-8ozOC0UbmTIBJpHDUUgjyE,15370 +torch/nn/modules/padding.py,sha256=fN_gP9pmu1K5UeQKD5whrMOESGA_XTgR3WqBwU9amFs,31768 +torch/nn/modules/pixelshuffle.py,sha256=2YoVUAIk-lViz8KlwzPIXFBkDWMJ63QKW1Hk4mA6kv0,3795 +torch/nn/modules/pooling.py,sha256=ZDw4u-7KOxfwSdgd5GoJHBBJ6kiolz1YhTQq2WpKQBo,61048 +torch/nn/modules/rnn.py,sha256=TKl8MTNJSnshuwf9-Q7ASyfXZKSAPwL58Tn7YWFLXss,76334 +torch/nn/modules/sparse.py,sha256=8jA8POM82pbFEopYo2NEWO7KBwvMb7o9W8qTVn9yf0M,24632 +torch/nn/modules/transformer.py,sha256=kx7Hthu2YzfybUjjem6bRPcGdj5ViMZavUYubDH7nB0,53098 +torch/nn/modules/upsampling.py,sha256=g8phjrdP0hWbEv2SExmPVvKjg_1yfm8jVqUqD7HSe7o,11834 +torch/nn/modules/utils.py,sha256=DKpc-xpm2VNwZvxLeKhgghLI-dL7Ztm4ZNLRNK9lJvw,2660 
+torch/nn/parallel/__init__.py,sha256=bH4_qdrQvxrqhlyf-9hHEzGdC5qcxL288zpCxlbFUEg,787 +torch/nn/parallel/__pycache__/__init__.cpython-39.pyc,, +torch/nn/parallel/__pycache__/_functions.cpython-39.pyc,, +torch/nn/parallel/__pycache__/comm.cpython-39.pyc,, +torch/nn/parallel/__pycache__/data_parallel.cpython-39.pyc,, +torch/nn/parallel/__pycache__/distributed.cpython-39.pyc,, +torch/nn/parallel/__pycache__/parallel_apply.cpython-39.pyc,, +torch/nn/parallel/__pycache__/replicate.cpython-39.pyc,, +torch/nn/parallel/__pycache__/scatter_gather.cpython-39.pyc,, +torch/nn/parallel/_functions.py,sha256=mEWfRi1CcTG2h5ktic_LcmSfpzARhUA8W8YYicSEPqQ,5080 +torch/nn/parallel/comm.py,sha256=bstYkYKrGI9Amw352i2DMGtzrY0x09GC03RjTPT0uWw,11152 +torch/nn/parallel/data_parallel.py,sha256=cDZQdmCWcUQIqlzJhjAnOtvg3kpuQIoYsvJLDIh3mg0,12047 +torch/nn/parallel/distributed.py,sha256=QN3iPLUHH-mgMAAs8GMecwr9vDhLhNOpGFarLmOxLv8,111550 +torch/nn/parallel/parallel_apply.py,sha256=XAkzt6619cET74bo_mb_LSLRTkF0Rr7fZvCs4nbLm6U,4595 +torch/nn/parallel/replicate.py,sha256=nyLtiCZp_1UYCarXkCOsI61FSQcIlfOsLKhYlm5hKnE,7224 +torch/nn/parallel/scatter_gather.py,sha256=2Xyb_G2W6u7gojpjtzwo1ZleppoKMJZE8Ne0qDwQjik,5120 +torch/nn/parameter.py,sha256=wCmhaClAFYIUtq48j0nggQ_vzYj5Xc-Eq7IJ3hnyplQ,11658 +torch/nn/parameter.pyi,sha256=g3PqT5tx0bRojYmW5A72aRr-gGie5KzEHlDegXcQUaA,1145 +torch/nn/qat/__init__.py,sha256=pHegh4R0-F97KQg3aD-sS0iMSrzniaH9gpLhLt1HAVA,384 +torch/nn/qat/__pycache__/__init__.cpython-39.pyc,, +torch/nn/qat/dynamic/__init__.py,sha256=IFxcIU-Zi7ju-VXfx5LO6qV1pzWSJGivHrShc3-VRrQ,216 +torch/nn/qat/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/nn/qat/dynamic/modules/__init__.py,sha256=xhaQBKF6ynIsfGH7lkql1W9yFDyJnVsQ1rhZ1ATMaDo,82 +torch/nn/qat/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/qat/dynamic/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/qat/dynamic/modules/linear.py,sha256=Qp_oQPS9HJUi2t0nPiFzUtUMkwUp8XW9lr3t6C7WzCY,427 +torch/nn/qat/modules/__init__.py,sha256=6Sx-ipG60zKKr96Gfu_oCeZ02h1j-WVyPB1tR7DQ_1A,522 +torch/nn/qat/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/qat/modules/__pycache__/conv.cpython-39.pyc,, +torch/nn/qat/modules/__pycache__/embedding_ops.cpython-39.pyc,, +torch/nn/qat/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/qat/modules/conv.py,sha256=4_HHeeAfHpwkekhOyE02Plind4otYDPXvYL_PyfIKaI,417 +torch/nn/qat/modules/embedding_ops.py,sha256=tNse_KPodjxDvvr8d_wkgshCOp3SqL3Zd5mbORnvJ6E,472 +torch/nn/qat/modules/linear.py,sha256=1BvjhZ-tBetm7yk0qFGY7rx2jQDi7Qips2VNCHgoTEY,403 +torch/nn/quantizable/__init__.py,sha256=-0datFuoOimv7Qda2O56u2mPcV2YE2mBLW4ErmOMt8k,58 +torch/nn/quantizable/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantizable/modules/__init__.py,sha256=2r8p4LEolirX0dhKNAvuJi4ubW11KIaL5leZFJPTYn4,216 +torch/nn/quantizable/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantizable/modules/__pycache__/activation.cpython-39.pyc,, +torch/nn/quantizable/modules/__pycache__/rnn.cpython-39.pyc,, +torch/nn/quantizable/modules/activation.py,sha256=ZkWt9LBHI95Tp37DGPWL7ywizIai0tVN7guf-jB9y_w,451 +torch/nn/quantizable/modules/rnn.py,sha256=Ft5W_4bwJrjTQZ_dHIrwjF2WG_D3c1OKfZLdFXpgaY8,440 +torch/nn/quantized/__init__.py,sha256=REwHEIXAsudYLljnHsXvh8cIOPXMdIa-VtcJeXgUQxA,810 +torch/nn/quantized/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/__pycache__/functional.cpython-39.pyc,, +torch/nn/quantized/_reference/__init__.py,sha256=WbDZWiGzHm3MlXrSDJTj7v8eCeEzcDuss7lOO2nInYU,67 
+torch/nn/quantized/_reference/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__init__.py,sha256=8gzog8X5RxwQGehxxyt80qsWpckIe8OQzRLhnP07fk0,1057 +torch/nn/quantized/_reference/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__pycache__/conv.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__pycache__/rnn.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__pycache__/sparse.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/__pycache__/utils.cpython-39.pyc,, +torch/nn/quantized/_reference/modules/conv.py,sha256=pjn9wbE5ASUlOevzsmw8SDxGlOADYqKZMmGliV52nL4,600 +torch/nn/quantized/_reference/modules/linear.py,sha256=pPKp0fswx4aAvvKnNqtxCpcx40DXqygxOrqeEB688qY,462 +torch/nn/quantized/_reference/modules/rnn.py,sha256=1b6FGQwK5dbTBCN0KIwuBhCLfS8nmAsTvQqNok17HZM,543 +torch/nn/quantized/_reference/modules/sparse.py,sha256=hCVj2Gqe5uV_6ptxLpvjKlAO2ddwuyy26yZzjNQ5ikU,479 +torch/nn/quantized/_reference/modules/utils.py,sha256=xgb58XnUhfBnLgeNHrBeCf7BqI7zM_osXVI0EgeY62I,608 +torch/nn/quantized/dynamic/__init__.py,sha256=zUGrFi6mSGoJGYz5xRd8uDIvmz6tIHQwKXhlalyECKg,59 +torch/nn/quantized/dynamic/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/dynamic/modules/__init__.py,sha256=dOUhdslM3ZBE2rYYjF4PJppebHDMcv_SrZpbMxxRzuU,1036 +torch/nn/quantized/dynamic/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/dynamic/modules/__pycache__/conv.cpython-39.pyc,, +torch/nn/quantized/dynamic/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/quantized/dynamic/modules/__pycache__/rnn.cpython-39.pyc,, +torch/nn/quantized/dynamic/modules/conv.py,sha256=hh6gMpib0xiYrIFcZkV8feUdafY4VExsggB6Mw6RMn4,697 +torch/nn/quantized/dynamic/modules/linear.py,sha256=gRKnDAxvb0dY7DvrVxrt6VUN4_YkS6iGdQr2ZAWxqRg,459 +torch/nn/quantized/dynamic/modules/rnn.py,sha256=mHRaypMPwS3iaj0nFQp4MIseczl1wLZobltW_b5U0iE,774 +torch/nn/quantized/functional.py,sha256=j7MGLQ6xhTh3XPiEJW67BRYBggrOMGdyu7zD-E9Wor4,286 +torch/nn/quantized/modules/__init__.py,sha256=UxxpMDG7z7jPwzAgKdR6wEH5vTOJ9Sgz_u0aJuuMNFQ,2198 +torch/nn/quantized/modules/__pycache__/__init__.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/activation.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/batchnorm.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/conv.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/dropout.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/embedding_ops.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/functional_modules.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/linear.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/normalization.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/rnn.cpython-39.pyc,, +torch/nn/quantized/modules/__pycache__/utils.cpython-39.pyc,, +torch/nn/quantized/modules/activation.py,sha256=xW027WhWB7ZNliQOe-FBGUiidOL9ICVQP0h7Ez-qphY,548 +torch/nn/quantized/modules/batchnorm.py,sha256=NXrd4bZpkNotYxpUMMX6CbkCkK8cGpgDrcfA3TZQBak,448 +torch/nn/quantized/modules/conv.py,sha256=1QiyTXQYOFuLE-p6r1sXk-SWqFZMucTLY_1fBq7vVls,695 +torch/nn/quantized/modules/dropout.py,sha256=e2wkofacz2MkquH4AmVOoZNAcInnRj8ELLd9Y6FkVFc,456 +torch/nn/quantized/modules/embedding_ops.py,sha256=jMga5rGAoWJBb1HlZAnaoDxe4jqa23rfE3eFKtgej1E,565 +torch/nn/quantized/modules/functional_modules.py,sha256=ig156T5ByLXZcE8WHAcI1i-grqMRVXI3gg0oYioJegs,572 
+torch/nn/quantized/modules/linear.py,sha256=pb0RdBMmEoA5yy6pbbxMH5bUEyJGj6gblHRtk7djFUk,495 +torch/nn/quantized/modules/normalization.py,sha256=f8sdj5vc5g7Y6JbRfmfBrcAXdkJhE1p1PojMxfHTGjA,652 +torch/nn/quantized/modules/rnn.py,sha256=sWTce57bC436QZ2xtG8xBO-fCQmIxwjfwIwl-0FRUl4,422 +torch/nn/quantized/modules/utils.py,sha256=3RhDoXUIo5ZJuAkqSZJ3PUtfvfFLyIF41e4SSNLgoMo,556 +torch/nn/utils/__init__.py,sha256=ihgUt20V8RRyUj2WDdJT_nfJoCgVPLs7KrihljN7wpM,1289 +torch/nn/utils/__pycache__/__init__.cpython-39.pyc,, +torch/nn/utils/__pycache__/_deprecation_utils.cpython-39.pyc,, +torch/nn/utils/__pycache__/_named_member_accessor.cpython-39.pyc,, +torch/nn/utils/__pycache__/_per_sample_grad.cpython-39.pyc,, +torch/nn/utils/__pycache__/clip_grad.cpython-39.pyc,, +torch/nn/utils/__pycache__/convert_parameters.cpython-39.pyc,, +torch/nn/utils/__pycache__/fusion.cpython-39.pyc,, +torch/nn/utils/__pycache__/init.cpython-39.pyc,, +torch/nn/utils/__pycache__/memory_format.cpython-39.pyc,, +torch/nn/utils/__pycache__/parametrizations.cpython-39.pyc,, +torch/nn/utils/__pycache__/parametrize.cpython-39.pyc,, +torch/nn/utils/__pycache__/prune.cpython-39.pyc,, +torch/nn/utils/__pycache__/rnn.cpython-39.pyc,, +torch/nn/utils/__pycache__/spectral_norm.cpython-39.pyc,, +torch/nn/utils/__pycache__/stateless.cpython-39.pyc,, +torch/nn/utils/__pycache__/weight_norm.cpython-39.pyc,, +torch/nn/utils/_deprecation_utils.py,sha256=kt1ThQiQtA-VAUgLYREkyUk8zOSG6kYFsF6n9rDsbXE,1728 +torch/nn/utils/_expanded_weights/__init__.py,sha256=8_VaTgAbp0cjDp-P6-5IHjTP2iKEAI7AOhpj1OdxIZ8,462 +torch/nn/utils/_expanded_weights/__pycache__/__init__.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/conv_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/conv_utils.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/embedding_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_impl.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/expanded_weights_utils.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/group_norm_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/instance_norm_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/layer_norm_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/__pycache__/linear_expanded_weights.cpython-39.pyc,, +torch/nn/utils/_expanded_weights/conv_expanded_weights.py,sha256=mjpmUHnpJoBKL1SOjuJMdArkxhLY2FBVb3KqF_CpAko,2895 +torch/nn/utils/_expanded_weights/conv_utils.py,sha256=RMKLRtdJKiMu8fRlpt557u_xm-vqAbMYBkQyW2wo15I,11083 +torch/nn/utils/_expanded_weights/embedding_expanded_weights.py,sha256=8MmVgtGfPYmy7J1xYoX5ZoirsQGyhWlFv0J2NZDa6hs,2933 +torch/nn/utils/_expanded_weights/expanded_weights_impl.py,sha256=q68MqE2QfatHuwyICHB2pG0k7KGN0tb9pZpL_2mmQV0,6464 +torch/nn/utils/_expanded_weights/expanded_weights_utils.py,sha256=L3rIvTtpOgkJrkbUaJllT01vIS-8roMYWvXCT1hjS1Q,7770 +torch/nn/utils/_expanded_weights/group_norm_expanded_weights.py,sha256=AicDIUgPrnJOlBKZH0ZTPZvlSUIyp64h0uUcrMYJVMA,3554 +torch/nn/utils/_expanded_weights/instance_norm_expanded_weights.py,sha256=7DulFQv-D7E3fGX1_QsDtBgGRaNhKh_cSn9WX5rms7M,3829 +torch/nn/utils/_expanded_weights/layer_norm_expanded_weights.py,sha256=S7oSZa0qR6MaGv1voPoRNCUe7Haq1d0WrMSHfwTeK8A,3333 +torch/nn/utils/_expanded_weights/linear_expanded_weights.py,sha256=n61wv8k_ww5oPaGgdn3hmFZ8I1nCKnsH3EQ6q4CCE1E,2278 
+torch/nn/utils/_named_member_accessor.py,sha256=J2bPnJ5z_8GWXCWSLaCQhiSqfuloHJhmNGHydkxw5Ms,14535 +torch/nn/utils/_per_sample_grad.py,sha256=UaXBpMK9m5uTrti9mEsmvi6L_-Xm181NUG8X7KWp808,5901 +torch/nn/utils/clip_grad.py,sha256=awP1OAA9nvuzcTXCcLdd3aODVXtoJYFPr-7xgwUek8w,11444 +torch/nn/utils/convert_parameters.py,sha256=ApmRzAur3mr-9Zlw3SrlBJUIDvYab7ioIDUqEN-R59w,3335 +torch/nn/utils/fusion.py,sha256=-83z21e6AQt274kF-SRgAryTy2mfSTZrq69JimdKsWA,6624 +torch/nn/utils/init.py,sha256=IDz4UrKDFlvKbHJVDVPblAuLZll9uQOaj-_NtZFiLok,2305 +torch/nn/utils/memory_format.py,sha256=04_MY_hh2weI1vMsWLDEELl0WjpGO3aCxYTXR8K74rY,8376 +torch/nn/utils/parametrizations.py,sha256=33V8VkP-mhipNtUWrCI5lw1OKbjHxKSMM5_7NM8L4lU,26294 +torch/nn/utils/parametrize.py,sha256=NpYR8OFd9-1394KJ81BEP-64TMcbeC6_A2pgirtayT0,37046 +torch/nn/utils/prune.py,sha256=3jlxXGgktgu0ddYiAa_Y_i1oobD7dbyWKW2HLPUMxUw,59470 +torch/nn/utils/rnn.py,sha256=_qj_cQH_6rY5K8g445pHTfT57ZyabkoJGuTaE7iQZi0,23829 +torch/nn/utils/spectral_norm.py,sha256=EyHmvfJR9RP6djNftBsQyyoG3YKW8QTk4mZW_SIF-2M,15281 +torch/nn/utils/stateless.py,sha256=I8LUFGr_zryFaUDI9AFFrYZye-f04gxFcAiCSZuJ9j0,11978 +torch/nn/utils/weight_norm.py,sha256=-CqR4Ej1AUeD9Lnz8K-FpvLFEzc-nWqv_7gRNby3baU,6047 +torch/onnx/__init__.py,sha256=TF9Wg-_8eQNI9GgmO6A4f-la9BF7e34VgTtgq4hgHg4,22300 +torch/onnx/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/__pycache__/_constants.cpython-39.pyc,, +torch/onnx/__pycache__/_experimental.cpython-39.pyc,, +torch/onnx/__pycache__/_flags.cpython-39.pyc,, +torch/onnx/__pycache__/_globals.cpython-39.pyc,, +torch/onnx/__pycache__/_onnx_supported_ops.cpython-39.pyc,, +torch/onnx/__pycache__/_type_utils.cpython-39.pyc,, +torch/onnx/__pycache__/errors.cpython-39.pyc,, +torch/onnx/__pycache__/operators.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_caffe2.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_helper.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset10.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset11.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset12.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset13.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset14.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset15.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset16.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset17.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset18.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset19.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset20.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset7.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset8.cpython-39.pyc,, +torch/onnx/__pycache__/symbolic_opset9.cpython-39.pyc,, +torch/onnx/__pycache__/utils.cpython-39.pyc,, +torch/onnx/__pycache__/verification.cpython-39.pyc,, +torch/onnx/_constants.py,sha256=5b1AY0Elyom82XlAEtUYhlzqRv-iRJP824v36zS_1TA,555 +torch/onnx/_experimental.py,sha256=Mz3VWbroPljcVY3On61zd1TRA50-LD8AMj_XKzpfwT8,1061 +torch/onnx/_flags.py,sha256=wzM12Gi0XO6liu_h259qcjAJqyQ6KXD172jgilW3Ndo,1285 +torch/onnx/_globals.py,sha256=BGop7wbSNRHOgA33mNTPnD0KBLm4Qlhy1gnqIKkYt7M,2854 +torch/onnx/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/onnx/_internal/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/_exporter_legacy.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/_lazy_import.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/io_adapter.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/jit_utils.cpython-39.pyc,, 
+torch/onnx/_internal/__pycache__/onnx_proto_utils.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/onnxruntime.cpython-39.pyc,, +torch/onnx/_internal/__pycache__/registration.cpython-39.pyc,, +torch/onnx/_internal/_exporter_legacy.py,sha256=nbznEQOs8htwfv9K_vHdKnbRD_49RlEFnD_00sUmBkI,20089 +torch/onnx/_internal/_lazy_import.py,sha256=EAyaR5oO64UZa567njd1UqMgefGP3oeDEQEpwxmQudo,1269 +torch/onnx/_internal/exporter/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/onnx/_internal/exporter/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_analysis.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_building.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_capture_strategies.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_compat.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_constants.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_core.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_decomp.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_dispatching.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_dynamic_shapes.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_errors.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_flags.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_fx_passes.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_ir_passes.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_isolated.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_onnx_program.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_registration.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_reporting.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_schemas.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_tensors.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_testing.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_type_casting.cpython-39.pyc,, +torch/onnx/_internal/exporter/__pycache__/_verification.cpython-39.pyc,, +torch/onnx/_internal/exporter/_analysis.py,sha256=JGOYBx6iH1JaK-ZA-o3iGvc8XdtQ82hZZUNJLgD8cFk,9104 +torch/onnx/_internal/exporter/_building.py,sha256=dxVLl24IlvuMo9aYCKfC9Yb47eRwq4u7mKDAZvamvUo,31568 +torch/onnx/_internal/exporter/_capture_strategies.py,sha256=97-5iZ_abYG1_eEfZiV6Fe7Rq3Qye5fwmRPEXe-H8Eo,10024 +torch/onnx/_internal/exporter/_compat.py,sha256=cBCFZEEQIzIjDdR6oCK_T06_Hfaw0C1oRdGudcXGURE,7760 +torch/onnx/_internal/exporter/_constants.py,sha256=I6tb-H1ESOnnkAkPvemWkejHs0NF_CEFz47Cs67Tcso,385 +torch/onnx/_internal/exporter/_core.py,sha256=4NhUoLZkkGiUjAOAx2BXVr9ZLwiLsWeAkzJMAKxKHu0,68964 +torch/onnx/_internal/exporter/_decomp.py,sha256=etnfesJBD4Xc8-ATGRHAUUvhr1bFCsljduNJIKzjJPE,2911 +torch/onnx/_internal/exporter/_dispatching.py,sha256=H7a3GsWQt2XymfYznasOrg_WFD6jJpYXDg26X-T7Mio,15169 +torch/onnx/_internal/exporter/_dynamic_shapes.py,sha256=1s0hBaL9kLj-vxi8HH4mQ7pJU6FMnwSn9xZJ5vbj5vg,14343 +torch/onnx/_internal/exporter/_errors.py,sha256=tOTcRIR3plnuELRHQYP5jA48Cb7r5IF_Ynp5EZotxJQ,556 +torch/onnx/_internal/exporter/_flags.py,sha256=-UESOX9ZYCXUSLi9p1QYHhcXzxqYYMqPKwRdeUCe-8w,667 +torch/onnx/_internal/exporter/_fx_passes.py,sha256=ITYYENDoQSS8J0x9odBFGkwipg93uEh9wvRiEPdBEzA,1768 +torch/onnx/_internal/exporter/_ir_passes.py,sha256=blhtUQdhRrKimFNqn5Jil4rBmzLYA-_uOLWTyRNdE0I,6002 +torch/onnx/_internal/exporter/_isolated.py,sha256=DPAVJ-kaj6UWwq0Mlv0iuxW525K-SKcwMEbBYrVDrIE,1877 
+torch/onnx/_internal/exporter/_onnx_program.py,sha256=OLoQ4Psg25AJU5PEVVX7QapeUSKCVduCWRRJLXt_H0E,18662 +torch/onnx/_internal/exporter/_registration.py,sha256=aaVXDwWONks9U0jxF32ijfbLOk34iucP1b1GttvGCfw,12736 +torch/onnx/_internal/exporter/_reporting.py,sha256=z7zL1BdniFCMYL90uHxde4a8yrNjYb1TTxxFaRHDrHI,7603 +torch/onnx/_internal/exporter/_schemas.py,sha256=Jc9yFNWE8LlVrlg1Tw6OVo6qO9Zk8GEMTU_dRfzOFjQ,21168 +torch/onnx/_internal/exporter/_tensors.py,sha256=3OqfVx1KGjEM2ZvTuzU8G8j1GvjqcqhXQk8EisHRdNg,2562 +torch/onnx/_internal/exporter/_testing.py,sha256=zCtI19OBDGkDZe9x9JNrv7rYhwPtJDnS2vKOSzX3YgQ,3414 +torch/onnx/_internal/exporter/_torchlib/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/onnx/_internal/exporter/_torchlib/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/__pycache__/_tensor_typing.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/__pycache__/_torchlib_registry.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py,sha256=P0WkH1XZKoofIYcL_QL_BobXDwL27dcjIacC_WBVBOM,2120 +torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py,sha256=WAQwLhd1FsiF37wSiVFTuFuwuKNGzAI9owRErfFIV30,2815 +torch/onnx/_internal/exporter/_torchlib/ops/__init__.py,sha256=vpeNP_4-RXUtQZ8xOUn5d8ULRmK8D1NbtzemMzdZ8_s,186 +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/core.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/hop.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/nn.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symbolic.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/__pycache__/symops.cpython-39.pyc,, +torch/onnx/_internal/exporter/_torchlib/ops/core.py,sha256=igjVKhil3zfdZV85ef0yg6EZ2BLoSZboo1NjeurIpgo,1589 +torch/onnx/_internal/exporter/_torchlib/ops/hop.py,sha256=8kS7MlXLbjS8NcqJ2EKcslwae2Uz0xSAid0nLjqMMio,5466 +torch/onnx/_internal/exporter/_torchlib/ops/nn.py,sha256=TgPOw_fK0RAl-FmTnr2bcVKh_Ipf0KD4LUbCiyd5VLo,10754 +torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py,sha256=ihJ2Jm7eJ9aQdM3bH0rMDr-5WIPbRNlfLEwvdRh1i_0,4871 +torch/onnx/_internal/exporter/_torchlib/ops/symops.py,sha256=ESHQ6tSOFIPKc0zVLxd9w3kYYi676zbdUuHyFQpeE7w,1188 +torch/onnx/_internal/exporter/_type_casting.py,sha256=xGjq4effVt7qIsoxywnPI1LnmLBGTFyEIzYtlINpM-g,1193 +torch/onnx/_internal/exporter/_verification.py,sha256=QRo8bSY-Q4VDJpm5ghlA395eerA4-YQvbxhSafJXvXw,12911 +torch/onnx/_internal/fx/__init__.py,sha256=JYi5iirfH0wnyf-7etmNfyP2Kkqi5kajEZSd8yMjOQ0,180 +torch/onnx/_internal/fx/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/_pass.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/decomposition_table.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/dynamo_graph_extractor.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/fx_onnx_interpreter.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/onnxfunction_dispatcher.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/patcher.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/registration.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/serialization.cpython-39.pyc,, +torch/onnx/_internal/fx/__pycache__/type_utils.cpython-39.pyc,, +torch/onnx/_internal/fx/_pass.py,sha256=HnJ6kXj2FzlklMjV6g5vtje-LZJzq8BYTGExmw7bnGo,8825 +torch/onnx/_internal/fx/decomposition_table.py,sha256=mXdn4e4BCmIISiJTO09WsDz8Hpxt6Cv6X0k33CpKMVQ,5216 
+torch/onnx/_internal/fx/dynamo_graph_extractor.py,sha256=RmH-oiqFg8Hxqte5E2ZXyJGKt2K9BOXSqc2CG44pUwM,8517 +torch/onnx/_internal/fx/fx_onnx_interpreter.py,sha256=hcanBP1iK4vEqvfAyofz0K0kpGk1M8ZIhWv3h--r6Os,31922 +torch/onnx/_internal/fx/onnxfunction_dispatcher.py,sha256=QMv0wXLoKoCiiULMFa7-_wUwsJ-fDe_zNC4_RY_B4aE,31138 +torch/onnx/_internal/fx/passes/__init__.py,sha256=4WtCudUZpJnuoAJIr_yAUSwYlQ81uXN01_J-sYJ4dNI,570 +torch/onnx/_internal/fx/passes/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/_utils.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/decomp.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/functionalization.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/modularization.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/readability.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/type_promotion.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/__pycache__/virtualization.cpython-39.pyc,, +torch/onnx/_internal/fx/passes/_utils.py,sha256=1Kw1B3Arm2Wf_ApoIV7O1Z5lxV634oX1nBDA2come0Y,4311 +torch/onnx/_internal/fx/passes/decomp.py,sha256=0KXDtMQr7E2FokCXjCA8oM4Cf-vu1JjBPBOthchGpAA,3633 +torch/onnx/_internal/fx/passes/functionalization.py,sha256=I31MYj2XLi4j0asTmCmJsaHJBadeGZ_4oXyCb5YI4tQ,6413 +torch/onnx/_internal/fx/passes/modularization.py,sha256=RBo6yB-96H0WJqfAQV5GMhZWidebxxT6EO3hSwj1nvo,34878 +torch/onnx/_internal/fx/passes/readability.py,sha256=6DeP1OXwquGdV6xAVVIGV1eQK82j2a9B2i6yW2Pxavo,5613 +torch/onnx/_internal/fx/passes/type_promotion.py,sha256=MbX7_4mLEvsE0d-kS27Fk-N4ADtjiY-1VfV7Nsl9iGM,66247 +torch/onnx/_internal/fx/passes/virtualization.py,sha256=LH-lzv5l8WZtH3ivhbKk4zq303W--TilfBSOKslJLDM,3909 +torch/onnx/_internal/fx/patcher.py,sha256=-vZUHWz7z9a4Vi4rykZ3xwt0jyeJjOTVW4v7fd1TTV8,6182 +torch/onnx/_internal/fx/registration.py,sha256=UhahO2FHsh4oXasYu92yQSZ8gk5hMHcbbq_ZKhk8bdo,3071 +torch/onnx/_internal/fx/serialization.py,sha256=l9uNKWksYse7YeBPwm-WFhSNxRne281aRk5snEnH5kk,11868 +torch/onnx/_internal/fx/type_utils.py,sha256=pba_pnbWsZxdx1eJpVwPRUHKOHfuiLR_ogqkNaunInM,6031 +torch/onnx/_internal/io_adapter.py,sha256=4ho58iZWeLJRzanl-u6sUqaYaNvghQfN0gYEsE7QEyc,23698 +torch/onnx/_internal/jit_utils.py,sha256=ENyLxf1aAojdd9K4ZcviwvFT5-9BoIMr86_uhu26yI4,14503 +torch/onnx/_internal/onnx_proto_utils.py,sha256=F-a1Tt7l9koXQ964HHfN5eI50bs5v9kTdr7mFDqOSh4,9449 +torch/onnx/_internal/onnxruntime.py,sha256=7SZ09_q_j2jAyS-Sdim_PgNQGNQLdUj-OtqLc3toWgQ,54204 +torch/onnx/_internal/registration.py,sha256=P49PZ2ttMXpXCb7Cwl9VE6Pfzmx8hucoxfI1Fr9sVZ0,11456 +torch/onnx/_onnx_supported_ops.py,sha256=64OfuAmDPWKSnhl3qx7GPGanqTVnSwCi42Pn6GOHXks,3417 +torch/onnx/_type_utils.py,sha256=abq61yp8OE_QvADAJqpAKmwyB7rg-EJvP0Gzg_kq6EM,14331 +torch/onnx/errors.py,sha256=PPAbjgyl4UT1_NRFr05pGF_mhZfcz-p86KQhlDD1w7E,3565 +torch/onnx/operators.py,sha256=2tRd6BZ2IT8lhvFX20H5RGyKi-4GWhpnRvbkKwiJQtg,1232 +torch/onnx/ops/__init__.py,sha256=cxWYdkgeobFncUh7e9GG_D97HLNyYnmtivY3U-9q2JU,21411 +torch/onnx/ops/__pycache__/__init__.cpython-39.pyc,, +torch/onnx/ops/__pycache__/_dtype_mappings.cpython-39.pyc,, +torch/onnx/ops/__pycache__/_impl.cpython-39.pyc,, +torch/onnx/ops/__pycache__/_symbolic_impl.cpython-39.pyc,, +torch/onnx/ops/_dtype_mappings.py,sha256=-jODxXM2EjOTWemOG6JSmbE1q35Q2ZOhVosxrF1JgxY,863 +torch/onnx/ops/_impl.py,sha256=gB30Wj4lYZMhuEfPW6i0dZBQxpyG7Ac4na0guTdZHUc,14286 +torch/onnx/ops/_symbolic_impl.py,sha256=7cpIQ9VLPWxb3UfIAUZ1jqEF5qyupjczF-6Doe9Gaew,12100 
+torch/onnx/symbolic_caffe2.py,sha256=0UX9WzKVUq9TCBo4ZvbbMZZqBiyoVTCp2U-SAUlmBQ4,11332 +torch/onnx/symbolic_helper.py,sha256=i-hyBWRZFhr_Zt-_Er60eU3GX4P6ZM4MrNsekYA0ORI,84578 +torch/onnx/symbolic_opset10.py,sha256=V6L6EHcI7177iwmxqU6ulV5Tfe4fd7A7jKDgEIHPMVY,38616 +torch/onnx/symbolic_opset11.py,sha256=LXmtww1QfxE0bKw0Ak7SQtbT2tbOXCEWf9Wi4wUGqfA,54811 +torch/onnx/symbolic_opset12.py,sha256=X-MUrvq1vcXVE94I2pMw5_NB6JnqOza8-KGOJ4pgNcg,16123 +torch/onnx/symbolic_opset13.py,sha256=yWLnqSV2Ksk7Y9fYpEOOPVlomDA58ExAfbggBiXkwBk,42367 +torch/onnx/symbolic_opset14.py,sha256=_zZI2y5i25ZHiQlVlAfb4MlGDfoyn5LiX60etkVYK8I,9758 +torch/onnx/symbolic_opset15.py,sha256=hHHqvUgSJaTzEsrIxeIMbLjvFkP1tXzI5BxgIPVYDms,2952 +torch/onnx/symbolic_opset16.py,sha256=Bf94S9mAN9SU-npqhjNXwKPyFs6U8r5dg5tQhtvX8ao,6595 +torch/onnx/symbolic_opset17.py,sha256=a8lRYyXTVG2nnmLOqgM1I3aNhYnfzJ6G4MaBp2uEwgU,8278 +torch/onnx/symbolic_opset18.py,sha256=OuAvH3IzEgZwBtUZCBvgkjX9ieVsCZteptMuYE7yeEE,8357 +torch/onnx/symbolic_opset19.py,sha256=jb9QOasxZxzfPOfINkOmOZQ-XDp_k4DLZAp6QbMCvPE,567 +torch/onnx/symbolic_opset20.py,sha256=w7ITFC4UOpOdL22sSX064fwg7T2fWKyJ2vQL_vufh_k,2538 +torch/onnx/symbolic_opset7.py,sha256=IxAcpgFoamEGQygtwGX_Ftfqt92o77k_XYmESHFLLWs,2185 +torch/onnx/symbolic_opset8.py,sha256=iSSWlbGVnme0K4VrDOZ34XXAMEpB4nZnLHTJVi-BkEQ,15455 +torch/onnx/symbolic_opset9.py,sha256=YxlK3THICu7wfVnDQb2PP4cGpm_u6_qlCY6z_wJM-7M,231898 +torch/onnx/utils.py,sha256=ZG2yXC46Sbr9yby7sjsJMB-spWOONvAEEm9L__N4IeU,76747 +torch/onnx/verification.py,sha256=pTkth-DEOg6f2n-881pCfUweYMzscDMJvogLrButMw8,72401 +torch/optim/__init__.py,sha256=Y98Sxa2HJ7ZGOLDFXCnRt5pJeni8pJpZuil6Jk3o2js,2181 +torch/optim/__pycache__/__init__.cpython-39.pyc,, +torch/optim/__pycache__/_adafactor.cpython-39.pyc,, +torch/optim/__pycache__/_functional.cpython-39.pyc,, +torch/optim/__pycache__/adadelta.cpython-39.pyc,, +torch/optim/__pycache__/adagrad.cpython-39.pyc,, +torch/optim/__pycache__/adam.cpython-39.pyc,, +torch/optim/__pycache__/adamax.cpython-39.pyc,, +torch/optim/__pycache__/adamw.cpython-39.pyc,, +torch/optim/__pycache__/asgd.cpython-39.pyc,, +torch/optim/__pycache__/lbfgs.cpython-39.pyc,, +torch/optim/__pycache__/lr_scheduler.cpython-39.pyc,, +torch/optim/__pycache__/nadam.cpython-39.pyc,, +torch/optim/__pycache__/optimizer.cpython-39.pyc,, +torch/optim/__pycache__/radam.cpython-39.pyc,, +torch/optim/__pycache__/rmsprop.cpython-39.pyc,, +torch/optim/__pycache__/rprop.cpython-39.pyc,, +torch/optim/__pycache__/sgd.cpython-39.pyc,, +torch/optim/__pycache__/sparse_adam.cpython-39.pyc,, +torch/optim/__pycache__/swa_utils.cpython-39.pyc,, +torch/optim/_adafactor.py,sha256=PoexObjbEnOv_BIgsPxgFDtYZQ6epCfF3-GdGqk2OG4,28642 +torch/optim/_functional.py,sha256=r1WAeqm9rcCRK38vqcMwaKQDzkQ6MQfLVpLUVBnruIY,3326 +torch/optim/_multi_tensor/__init__.py,sha256=1B9IB5OE_1Ecb8PPtszzPNIpsguYWwVHCQDcmM5wynw,1058 +torch/optim/_multi_tensor/__init__.pyi,sha256=X8HBdb-Ka3ICBb5zpGJPT68rp59dOYJF8KXahMT6TAs,552 +torch/optim/_multi_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/optim/adadelta.py,sha256=o7ZBTELcbTkwhVpLfrhew93Cb9aCiqwu-WuwbQZ_G6Y,17151 +torch/optim/adagrad.py,sha256=7whXlkSs6BSuxL2OYl4rB67X45NyH6VRaT4V6aWqGes,21393 +torch/optim/adam.py,sha256=iuNtHd0cpxG8cIaUo8heVP0P7KmZELk9IAZDat8O4ZM,40219 +torch/optim/adamax.py,sha256=ho1WzZHq8_yIV6oKzNKMVjC1ujIRYgzZsJmH8mlRjng,17827 +torch/optim/adamw.py,sha256=zblmAKoZM9vavBiAgT6Qoq7u7Vka2wcGi3I8TD1WbdU,7525 +torch/optim/asgd.py,sha256=gOPMSbuL1otbP9DTzIZ3uIC4DZ0Z7-EJvcwVrbOW8ys,16764 
+torch/optim/lbfgs.py,sha256=cRSvAgFTRKRA6ohwx4eeZalFJItX1oIZ4lXeXEIXofM,18649 +torch/optim/lr_scheduler.py,sha256=4d6D-Mbt0EaeyXIB6s4l8Yuzmo5fq6bXzq6YJEQg1Fc,86329 +torch/optim/nadam.py,sha256=YvW4l9zJOq-VyOCpaQzSvdt-1JJpOlxLhnbO-hogsqY,27121 +torch/optim/optimizer.py,sha256=N3EJ5Wi0UCpYENfihXNeD_M9H1MQzsooaFHj3YG9Afk,50982 +torch/optim/radam.py,sha256=BKHhouNN0neg1BcIP0oPxZ783A4p68axCR2JuoVO1cM,25304 +torch/optim/rmsprop.py,sha256=QxBX1E9itYZ_MYMnbeP8_GW8EGvsYeYdFcnlkNXYFJk,20975 +torch/optim/rprop.py,sha256=J5RLtkdMygH_78A_pJ92zPzINE0BUlXPBg9T__4AAP4,18002 +torch/optim/sgd.py,sha256=v04tnUq6kzQkd9I6UuUnZh7SSqAXJ-5GxdRI_Pr1JBQ,20602 +torch/optim/sparse_adam.py,sha256=BHqndDKExnS4x2oKfiGTm8jT3jysl929ISqKNgJIovQ,8142 +torch/optim/swa_utils.py,sha256=6ngsV5AOGGeUxvYa-NCSLx2Tsn8qlLF1w-xNSbYHxZo,19650 +torch/overrides.py,sha256=e7rdlRdsN5uULM0MK490njyu_eqhtwak2hjcQwTCyaU,107333 +torch/package/__init__.py,sha256=LV4HkPItiQkYGteUQapkAQLVBv5mYC1zTfYdHSXZxbg,400 +torch/package/__pycache__/__init__.cpython-39.pyc,, +torch/package/__pycache__/_digraph.cpython-39.pyc,, +torch/package/__pycache__/_directory_reader.cpython-39.pyc,, +torch/package/__pycache__/_importlib.cpython-39.pyc,, +torch/package/__pycache__/_mangling.cpython-39.pyc,, +torch/package/__pycache__/_mock.cpython-39.pyc,, +torch/package/__pycache__/_package_pickler.cpython-39.pyc,, +torch/package/__pycache__/_package_unpickler.cpython-39.pyc,, +torch/package/__pycache__/_stdlib.cpython-39.pyc,, +torch/package/__pycache__/file_structure_representation.cpython-39.pyc,, +torch/package/__pycache__/find_file_dependencies.cpython-39.pyc,, +torch/package/__pycache__/glob_group.cpython-39.pyc,, +torch/package/__pycache__/importer.cpython-39.pyc,, +torch/package/__pycache__/package_exporter.cpython-39.pyc,, +torch/package/__pycache__/package_importer.cpython-39.pyc,, +torch/package/_digraph.py,sha256=b1g-2YEbIamHIaMRbaLksJNZX-DbYreF8p0_4dsKFsQ,5803 +torch/package/_directory_reader.py,sha256=oqpFQhKRKbgVcfJMLI7P85eE-O07eOg4pTltXh07SrI,1981 +torch/package/_importlib.py,sha256=TFT6uwqOIDOXOYtxn0oKH69rVcvmEIVAw2Nm7SfQQK8,3093 +torch/package/_mangling.py,sha256=-OZIqmgJ88B7qhpopxLczrlW3rQ2CHbp_GUKO8sJYGo,1955 +torch/package/_mock.py,sha256=dLBJ9n8wre0Zbb3olFgZLdRpzDMM9C65kGnybzk_4Ew,2989 +torch/package/_package_pickler.py,sha256=7zgn_XpDO0TA3tcL8e5lKyvpTLH1_oH__nA9xdZ-u6w,5117 +torch/package/_package_unpickler.py,sha256=v7xjc2QM4L1xInVHrpbc21dvzQmhDs01TJC82l3Q0OU,1019 +torch/package/_stdlib.py,sha256=NrqwuS6-t8H8y4ub3Or1hwwP_eosh30w0AMd5CUbwRg,4313 +torch/package/analyze/__init__.py,sha256=MSt9jByAbADeUqxkp02TPZX9Zs0VAECwSv4fx0IvBII,132 +torch/package/analyze/__pycache__/__init__.cpython-39.pyc,, +torch/package/analyze/__pycache__/find_first_use_of_broken_modules.cpython-39.pyc,, +torch/package/analyze/__pycache__/is_from_package.cpython-39.pyc,, +torch/package/analyze/__pycache__/trace_dependencies.cpython-39.pyc,, +torch/package/analyze/find_first_use_of_broken_modules.py,sha256=_hctT_gJgARx28bKpOhBef0DsfPX3Airb2ZFTIkMC6A,1065 +torch/package/analyze/is_from_package.py,sha256=FDGNhzhtyWTw_fVWe750ufQluGPGlD1MDjOEeyJ3lA4,420 +torch/package/analyze/trace_dependencies.py,sha256=W6j0_Cv-ZT1wE_NABog2DIYVUw9rNF3_I0Xf97d3-yA,2300 +torch/package/file_structure_representation.py,sha256=e1HVaM5ZJy_vqVx_yEe8be-sSjOubmJMDeI__3aOyX0,4890 +torch/package/find_file_dependencies.py,sha256=SJBzyuKQ0ioOMSR2Kc7GDEzYarTdnYI0u14QXnubu7g,4075 +torch/package/glob_group.py,sha256=hkEjc60_fFAoEGg_M2_tUk_wrmqh5tUlYY2u5hjSf2s,3750 
+torch/package/importer.py,sha256=zWc6aINx_4dJnUPuiESosFY128w7q4imutpAZl_iDco,9126 +torch/package/package_exporter.py,sha256=uDftOFOs3bxB_90cPEXCNpRnMAfewrd6XM-NP1mmck4,52051 +torch/package/package_importer.py,sha256=IyhcfdvN-Jm8XrhEFkrdK2PBRhobk0hBaYlgEmKgDKo,32454 +torch/profiler/__init__.py,sha256=faV3z5jbtBVvPuqRYtPz20bWAz__WiRzMh4o1wCRZTY,1631 +torch/profiler/__pycache__/__init__.cpython-39.pyc,, +torch/profiler/__pycache__/_memory_profiler.cpython-39.pyc,, +torch/profiler/__pycache__/_pattern_matcher.cpython-39.pyc,, +torch/profiler/__pycache__/_utils.cpython-39.pyc,, +torch/profiler/__pycache__/itt.cpython-39.pyc,, +torch/profiler/__pycache__/profiler.cpython-39.pyc,, +torch/profiler/__pycache__/python_tracer.cpython-39.pyc,, +torch/profiler/_memory_profiler.py,sha256=kxUfpVlLYem65knKLZ8jwxznkQ1HasKGzwhNZToPH3w,49360 +torch/profiler/_pattern_matcher.py,sha256=XR-I9PTMuiSZIDGYNIweLHrycjNjprvJhv3zMuGMbik,25400 +torch/profiler/_utils.py,sha256=ofTOXVQDVzXFmz5kXAlWzwEI3dHC8I5qolmPf713feg,14322 +torch/profiler/itt.py,sha256=LRtz7xVIVAhJN9CxweUo6SpcZQxswfoiT61gVhZZYMo,1862 +torch/profiler/profiler.py,sha256=RgAGRr8QCl9uMe7COPN_PwX2Zk7Ybtp6gDrjDVUZSlc,46358 +torch/profiler/python_tracer.py,sha256=aDA0xiGsxbxr2JQ1m04WgIopUU2qCO1saJR-l18Yrrw,495 +torch/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/quantization/__init__.py,sha256=QfStNsp4EIUCQPnoGEX2-C6F8nkWjGwdZKUiSxwu5Us,2740 +torch/quantization/__pycache__/__init__.cpython-39.pyc,, +torch/quantization/__pycache__/_numeric_suite.cpython-39.pyc,, +torch/quantization/__pycache__/_numeric_suite_fx.cpython-39.pyc,, +torch/quantization/__pycache__/_quantized_conversions.cpython-39.pyc,, +torch/quantization/__pycache__/fake_quantize.cpython-39.pyc,, +torch/quantization/__pycache__/fuse_modules.cpython-39.pyc,, +torch/quantization/__pycache__/fuser_method_mappings.cpython-39.pyc,, +torch/quantization/__pycache__/observer.cpython-39.pyc,, +torch/quantization/__pycache__/qconfig.cpython-39.pyc,, +torch/quantization/__pycache__/quant_type.cpython-39.pyc,, +torch/quantization/__pycache__/quantization_mappings.cpython-39.pyc,, +torch/quantization/__pycache__/quantize.cpython-39.pyc,, +torch/quantization/__pycache__/quantize_fx.cpython-39.pyc,, +torch/quantization/__pycache__/quantize_jit.cpython-39.pyc,, +torch/quantization/__pycache__/stubs.cpython-39.pyc,, +torch/quantization/__pycache__/utils.cpython-39.pyc,, +torch/quantization/_numeric_suite.py,sha256=Apsbav_2_HtzAG7591KPYCABK3u-2ems0f91ckZYvqM,807 +torch/quantization/_numeric_suite_fx.py,sha256=WTMnaKmkQAsB6IgE3Eajh-x8fhXufgjjTa1fFpTZIs0,778 +torch/quantization/_quantized_conversions.py,sha256=UPArN_a1-aNu9eczrvj54Jgzgzolgr6tOK0mqFBzKmw,4454 +torch/quantization/fake_quantize.py,sha256=Ug5k6GxDURxlTHuKstQdkh-zEVDUmeJ4PNvCkat3-rw,1047 +torch/quantization/fuse_modules.py,sha256=Zc_rz6dKEugZhl8E7JfCM21_BaXcwNzw0RvmqJGw7gU,754 +torch/quantization/fuser_method_mappings.py,sha256=oIVWF7hZ6U2OFS-oKw3QjuixmXaFXTW3UcE86Qk_SrQ,526 +torch/quantization/fx/__init__.py,sha256=qBBzAdCsV8Jp2yNtquj0C6jkrGaNx3Qc2PsyLdXh330,609 +torch/quantization/fx/__pycache__/__init__.cpython-39.pyc,, +torch/quantization/fx/__pycache__/_equalize.cpython-39.pyc,, +torch/quantization/fx/__pycache__/convert.cpython-39.pyc,, +torch/quantization/fx/__pycache__/fuse.cpython-39.pyc,, +torch/quantization/fx/__pycache__/fusion_patterns.cpython-39.pyc,, +torch/quantization/fx/__pycache__/graph_module.cpython-39.pyc,, +torch/quantization/fx/__pycache__/match_utils.cpython-39.pyc,, 
+torch/quantization/fx/__pycache__/pattern_utils.cpython-39.pyc,, +torch/quantization/fx/__pycache__/prepare.cpython-39.pyc,, +torch/quantization/fx/__pycache__/quantization_patterns.cpython-39.pyc,, +torch/quantization/fx/__pycache__/quantization_types.cpython-39.pyc,, +torch/quantization/fx/__pycache__/utils.cpython-39.pyc,, +torch/quantization/fx/_equalize.py,sha256=ZHnbL_PV1fMJyNvBNyYOgwJ3nSjTpPq_nVM1F5iDHFw,1288 +torch/quantization/fx/convert.py,sha256=s84rvG4ApmHJhZRJF2z_EVSeaS2hqSpWa799IQ__p3A,395 +torch/quantization/fx/fuse.py,sha256=t1CFzL6X4n0hrEM9WsoAxGkVADvbkuzq1U_2kTgWpzg,389 +torch/quantization/fx/fusion_patterns.py,sha256=EMbl_P6a5OUYm1H3wQxmvSnoRAUnC42Tf5reTOzNdtI,424 +torch/quantization/fx/graph_module.py,sha256=aEpEBJd6ylKA5nf-S9ZoOuMwLNh0i9o2KWOUUkUYHDk,590 +torch/quantization/fx/match_utils.py,sha256=f_Uw5eWYMDpUHMrijx-u1EfDZ2XdjJgwvONhQbtZt2E,470 +torch/quantization/fx/pattern_utils.py,sha256=cJ6hXpHQiCnlYAtE8NAlDR3Zp21T20LTV_HXY5jFOEw,1333 +torch/quantization/fx/prepare.py,sha256=fTOOjMADuXBsx8pF_bs81c2Te-pENKkMJeL3b9-ETVU,395 +torch/quantization/fx/quantization_patterns.py,sha256=cieAvymJVrFxnNsOeuBGjhd31UrK0TIBqY_ChNZ23kI,2135 +torch/quantization/fx/quantization_types.py,sha256=uD3t3Tht8BISqBotXe5JPgt9BPmXTz64Ple8zn0BCJE,404 +torch/quantization/fx/utils.py,sha256=lCEIQa2UZ21-Rlhhijz0383Et25BJCBMuG-auvf-Fxw,743 +torch/quantization/observer.py,sha256=ic0SaZsyAfjjbxdt4qRRXiU_2eiZ1vMEQZIubPYrbiM,1114 +torch/quantization/qconfig.py,sha256=_w7zSYKVivzFbjG3fLEamClj2AoBhj953gUyHnawiEU,940 +torch/quantization/quant_type.py,sha256=iAwgxvRTAWOq6UWwWOInYIlfoyqAKrsirv_eefd7dCE,409 +torch/quantization/quantization_mappings.py,sha256=E5Bgvn9qgeOaMyTTMgwSdoa4KsDHcv-kPl8YybKgWHE,1176 +torch/quantization/quantize.py,sha256=_2bxz1Nxy2R3jer569e0OWCAvOJDOBRNg_WLkOB4HCs,834 +torch/quantization/quantize_fx.py,sha256=VFgy22MuCnZgzzSEvgo361E6rQXpWAGyb8HxzrjOHt8,762 +torch/quantization/quantize_jit.py,sha256=4-N7zY1HEFSfBbxNiz43hYlZVbWxekcWSLo7Y1y94dA,740 +torch/quantization/stubs.py,sha256=3fCztwAMksJYzQbZlBvS_v3u4-QJ7jKdGIJUkCcHT-0,402 +torch/quantization/utils.py,sha256=TFeb5f6X76T6oNKThjsiS3TpFXRoJgTdiO4MNdaEDTc,862 +torch/quasirandom.py,sha256=xK2fnd-Wyvca2WuluggUJtNmxc81awNO88oK8upiCC4,8165 +torch/random.py,sha256=FIT9Qn0kVnJPfc2mXMzTbYMT9J5FtOduD76i-t1RBxM,7400 +torch/return_types.py,sha256=ohmoEziZmcBXQrJr0_LLqaJIGQchUZEsI6miQPwUQBQ,1536 +torch/return_types.pyi,sha256=gJzMkb9SBctMhWoppxGbxeBO50Gj_wD824luxPu-YN8,18442 +torch/serialization.py,sha256=z9aXr0czC_jirQUzgTTOl9yXxf9qer58pFp_f0B_Muc,87011 +torch/share/cmake/ATen/ATenConfig.cmake,sha256=CTO_ZzHAwndUCAoqNmZCTxJySMBqqgymTW_dgh8HFp4,311 +torch/share/cmake/Caffe2/Caffe2Config.cmake,sha256=W8adntTMifZ3u6oo6Skid9nLvTN3-cJRTvzuobj1ikA,5484 +torch/share/cmake/Caffe2/Caffe2Targets-release.cmake,sha256=WmF_gCh039CyLxXNdkddLr1hOm97HlbY_CQH00rZDWc,1869 +torch/share/cmake/Caffe2/Caffe2Targets.cmake,sha256=j9EEu9MgCKJEgP-CfwdPCdm4YeDytZJAYnYrqsSZV0M,7211 +torch/share/cmake/Caffe2/FindCUDAToolkit.cmake,sha256=aHbgBCQAUIPxcf02qlKwhM3Z0mzy47NvofDV_YA74Zg,39849 +torch/share/cmake/Caffe2/FindCUDSS.cmake,sha256=sA21H-knRZmCcnJmM87dUss3KpT-M2ykBfS_ukZz75s,2765 +torch/share/cmake/Caffe2/FindCUSPARSELT.cmake,sha256=_3gRzt3JhJLSoxGANz8olAJc1HYgZriCb5U-VuxAgeY,3135 +torch/share/cmake/Caffe2/FindSYCLToolkit.cmake,sha256=L2LFn89vj4k4Nb-OMyEEIPP9nhxAFRlS8cEm47LgJa4,4689 +torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDA.cmake,sha256=yVY9ZuazsjOWvko_j0ebC9zXNBgDBqn7yFFUmuGWwdI,536 
+torch/share/cmake/Caffe2/Modules_CUDA_fix/FindCUDNN.cmake,sha256=98AV65I9r6hXZ3djts25pWsUGSJij0uwaDLlNvr-OPw,3163 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/CMakeInitializeConfigs.cmake,sha256=pSX-UCD4EzIprk563tpH9yFmGH_ziG6NwUQt1GHhhos,1697 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA.cmake,sha256=6q26iPrkPNsNeOKwauXiZTXhjULqKIf_D0BpYTDO9hM,88641 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/make2cmake.cmake,sha256=p7tz2-DwvWVTxW082qsoMSXTwWlKk3AWw32drCNkjFk,4031 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/parse_cubin.cmake,sha256=T__r-AaXPXniuqi03AH2Pm4QNH_cQDJYQTHHCwBXHtg,3548 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/run_nvcc.cmake,sha256=41q0kdvB4N6yKC7Lt80TFblNjVs7i8SwuuN4ncDDQ28,12116 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake,sha256=m3w7BoWJYLSnrOKac-wyDLgTDJ1StLnA8oP7TwLQ8iY,11872 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageHandleStandardArgs.cmake,sha256=OJQRMapUAJYF98ivq3YLfTzYXY4NUBARyPwsbNmCojM,15288 +torch/share/cmake/Caffe2/Modules_CUDA_fix/upstream/FindPackageMessage.cmake,sha256=aggekwaBIweoGRvyGnWWSLFOfGhinyHgkP-AFdXlIME,1611 +torch/share/cmake/Caffe2/public/LoadHIP.cmake,sha256=27zUk9G-OZkoeiAvBakL27_SzzjzQNkKhJAoj3rNIWg,10432 +torch/share/cmake/Caffe2/public/cuda.cmake,sha256=vDPJAkyuQuV9Lq7iTWApy7q-QIJBaX3zpb4erksUzmQ,13804 +torch/share/cmake/Caffe2/public/gflags.cmake,sha256=lqCHXupLZgbofqyIO2MunBGYZB-1EJPUSx6QZoJxDIc,2703 +torch/share/cmake/Caffe2/public/glog.cmake,sha256=4rEgDMK1s552hsGVEjG6BwIETKu4y1wcwhGu_PPgYRo,2390 +torch/share/cmake/Caffe2/public/mkl.cmake,sha256=smEnmDMPnXJI24n7YQfuMgGpancHt8np3Fmwtn-sZAo,1357 +torch/share/cmake/Caffe2/public/mkldnn.cmake,sha256=9it-d7b1GLo7NomimCLT3kP3N0iZErkHhr4pU2pE-to,462 +torch/share/cmake/Caffe2/public/protobuf.cmake,sha256=HaapG099tgG6ti9wx5IxIEdm2nEcwz-cOO53xQ8b6ss,4095 +torch/share/cmake/Caffe2/public/utils.cmake,sha256=UHboz87O8h_EJ07tPKKdGavYo_FFRX4GjqoUmJj5jfI,22996 +torch/share/cmake/Caffe2/public/xpu.cmake,sha256=EHVWgC3ZDI2zzhsbfMkjCZShWpw_e3xYeaoS_TfSNfw,1208 +torch/share/cmake/Torch/TorchConfig.cmake,sha256=NMifcbR-beyYwiPIOGoLG-REiTluMGbLa5jWJ3D2QbM,5238 +torch/share/cmake/Torch/TorchConfigVersion.cmake,sha256=Ri-Hq0C3Wo28Fy3i3cIdwVGr7O9DP9KMZ6DA0ohgLCY,377 +torch/signal/__init__.py,sha256=vsiLo7rYe0hiZTWzhpL90GTSQVh1rnENJp-aYdjb7nE,50 +torch/signal/__pycache__/__init__.cpython-39.pyc,, +torch/signal/windows/__init__.py,sha256=Hhn47ZK7d9faPk0ErykkdULI8qMtln9E-7s1leOIyp4,411 +torch/signal/windows/__pycache__/__init__.cpython-39.pyc,, +torch/signal/windows/__pycache__/windows.cpython-39.pyc,, +torch/signal/windows/windows.py,sha256=Ld6dHhTNDtCpYsibFBjjfMpzHu3UFUmVlz1hvXDPAdE,23544 +torch/sparse/__init__.py,sha256=z6lJtMEcqeQbl87C5lG8UEMUtlalOp5q0C_MwrnCo24,26221 +torch/sparse/__pycache__/__init__.cpython-39.pyc,, +torch/sparse/__pycache__/_semi_structured_conversions.cpython-39.pyc,, +torch/sparse/__pycache__/_semi_structured_ops.cpython-39.pyc,, +torch/sparse/__pycache__/_triton_ops.cpython-39.pyc,, +torch/sparse/__pycache__/_triton_ops_meta.cpython-39.pyc,, +torch/sparse/__pycache__/semi_structured.cpython-39.pyc,, +torch/sparse/_semi_structured_conversions.py,sha256=rAHYgNaL71jO5lRc5ckMm2yWcSZloEFYJ7OwkuXwyiM,14372 +torch/sparse/_semi_structured_ops.py,sha256=ugBlDfqJDVKPSFo6jSyQLr-2_9FpG9803oOzO5wUMCk,6569 +torch/sparse/_triton_ops.py,sha256=hHMWwDjIw4kUp7Bd8iFMvPXnucywZBq6dM9Gi8FdKvc,88679 
+torch/sparse/_triton_ops_meta.py,sha256=l1umS0R9EY1nYrG-e3Ld81CTbLyX8n4CpygIzk-Enzg,508540 +torch/sparse/semi_structured.py,sha256=Em1S2PgV64s7odlZzyAXzZHmAJWC-gBwZVHSsY-GZGw,28701 +torch/special/__init__.py,sha256=SYuCuBexvur4LzY0XZB4c9yhDCWZqzXbY6sWe5-yVsU,34294 +torch/special/__pycache__/__init__.cpython-39.pyc,, +torch/storage.py,sha256=09dIpZtcbfkdwYB8h7-a0Es-NmkNHL3yBIIyuOe0iBY,53904 +torch/testing/__init__.py,sha256=Yy0wz_YYjsQIl6wPWAXNBofkITgtS8K4V4uUqUfV0Mo,192 +torch/testing/__pycache__/__init__.cpython-39.pyc,, +torch/testing/__pycache__/_comparison.cpython-39.pyc,, +torch/testing/__pycache__/_creation.cpython-39.pyc,, +torch/testing/__pycache__/_utils.cpython-39.pyc,, +torch/testing/_comparison.py,sha256=sEmtsHDRNBiYCwSGTP5-1WXqgbljiR4_NsuhQ_sLfEU,67714 +torch/testing/_creation.py,sha256=7u2AjOdwoomPQ4X-_akPTMl7Hp7ZyTNmL1TM8VDM5Hg,12474 +torch/testing/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/__pycache__/autocast_test_lists.cpython-39.pyc,, +torch/testing/_internal/__pycache__/autograd_function_db.cpython-39.pyc,, +torch/testing/_internal/__pycache__/check_kernel_launches.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_cuda.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_device_type.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_dist_composable.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_distributed.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_dtype.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_fsdp.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_jit.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_methods_invocations.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_mkldnn.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_modules.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_mps.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_nn.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_optimizers.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_pruning.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_quantization.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_quantized.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_subclass.cpython-39.pyc,, +torch/testing/_internal/__pycache__/common_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/composite_compliance.cpython-39.pyc,, +torch/testing/_internal/__pycache__/custom_op_db.cpython-39.pyc,, +torch/testing/_internal/__pycache__/custom_tensor.cpython-39.pyc,, +torch/testing/_internal/__pycache__/dist_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/dynamo_test_failures.cpython-39.pyc,, +torch/testing/_internal/__pycache__/fake_config_module.cpython-39.pyc,, +torch/testing/_internal/__pycache__/fake_config_module2.cpython-39.pyc,, +torch/testing/_internal/__pycache__/fake_config_module3.cpython-39.pyc,, +torch/testing/_internal/__pycache__/hop_db.cpython-39.pyc,, +torch/testing/_internal/__pycache__/hypothesis_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/inductor_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/jit_metaprogramming_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/jit_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/logging_tensor.cpython-39.pyc,, +torch/testing/_internal/__pycache__/logging_utils.cpython-39.pyc,, 
+torch/testing/_internal/__pycache__/quantization_torch_package_models.cpython-39.pyc,, +torch/testing/_internal/__pycache__/static_module.cpython-39.pyc,, +torch/testing/_internal/__pycache__/subclasses.cpython-39.pyc,, +torch/testing/_internal/__pycache__/torchbind_impls.cpython-39.pyc,, +torch/testing/_internal/__pycache__/triton_utils.cpython-39.pyc,, +torch/testing/_internal/__pycache__/two_tensor.cpython-39.pyc,, +torch/testing/_internal/autocast_test_lists.py,sha256=HIEXBGxQR_ET3kQOPLbM7CoXqStiaBUbViutWe-R2WA,28863 +torch/testing/_internal/autograd_function_db.py,sha256=cozNmhOxnJzCZ5Nju7IaPyLDRfzMB_HgEIOMuQdj4Nk,20246 +torch/testing/_internal/check_kernel_launches.py,sha256=xwVlNz0cv7YD6tL_fXZBtqB9uycjStvmSf9ObjBvkqM,6191 +torch/testing/_internal/codegen/__init__.py,sha256=DiaaC4wLY-kzeqPWCEVKCSzOa61Vli38kuG2Hd50x3c,23 +torch/testing/_internal/codegen/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/common_cuda.py,sha256=_NwvmyGupA6y2eE0P_7RXDLNQVT80vfraROsQkRHahI,15075 +torch/testing/_internal/common_device_type.py,sha256=VnQAQJvhmPAhAq164WohCO7mewzMnG5GIUB9Np46wo0,74886 +torch/testing/_internal/common_dist_composable.py,sha256=kvK0gEZLUA1afQm0TBYUmp-xBIvLPcW0PgHo5m0WaoM,3689 +torch/testing/_internal/common_distributed.py,sha256=O1G_7OzZK7CsVuWyHDArJ2Ffx9Gu2OkELgpdH7wk2IU,66436 +torch/testing/_internal/common_dtype.py,sha256=uhzUnIGATt2F0VKXmd14v1wpsICCELNsowKq3-CcjMQ,5078 +torch/testing/_internal/common_fsdp.py,sha256=NwSwy71SWaz5maAH1xEMBSWAeykBXS8wbeYSU79FMxc,59595 +torch/testing/_internal/common_jit.py,sha256=dHyjppST3eDeRdciYKBZiOcYPIjfWHFM132jvVY7ano,16170 +torch/testing/_internal/common_methods_invocations.py,sha256=O_wu4_8m70VjvtzYn_qzHQTbYfgOIh_1CP-EQ4eM_As,1231573 +torch/testing/_internal/common_mkldnn.py,sha256=0azW5_XzF2Hl9Bv2OUXujAXljkooaKUgAUP7qOWGDJw,2363 +torch/testing/_internal/common_modules.py,sha256=ufrzEZFNRPHgNRuYfWTobIxQ7vrk2tRoANTMt5WeASM,221007 +torch/testing/_internal/common_mps.py,sha256=HYzbmqOvHomnqPBM_a6hbeJ-VblJEQy98jt3RWygkzA,38670 +torch/testing/_internal/common_nn.py,sha256=R4ylxtcwejchXMglxRbBcHB7aL2Dm4sM0izMkR5J2gg,176504 +torch/testing/_internal/common_optimizers.py,sha256=R57l2MLTGJ2AM-l0JK3Km9dLYgtgDwbGe2X1Zr0rvgc,84567 +torch/testing/_internal/common_pruning.py,sha256=PE-FAOHFUB-0BUHf5GWU6xZe-X_bZaCsl0hT8phb-hU,14040 +torch/testing/_internal/common_quantization.py,sha256=0gvF0tEENGYUqvdqQ_Nql2nv-ZonIE-IUjszK6g1pSM,119193 +torch/testing/_internal/common_quantized.py,sha256=ciRoXO-SUgyso5bTqq4ee4_80GcTJXswmj-Explxj4c,18195 +torch/testing/_internal/common_subclass.py,sha256=bPIhwnCHbcWnZmUtub5K7-UfJyA6khTnEkayiGRzDpU,12540 +torch/testing/_internal/common_utils.py,sha256=thBZlD0IWv5wMgBOAnArmH0RH5Dr703pNsyccwpxukw,242645 +torch/testing/_internal/composite_compliance.py,sha256=YPcHkBWMonoRoEjVsyBXaNOZDtki1JzS55r8KQQ2rVM,26633 +torch/testing/_internal/custom_op_db.py,sha256=FeY-PkrUDfjFsiZKZYt5visIK5Qq4QB_0w7Q3m0HRvc,20230 +torch/testing/_internal/custom_tensor.py,sha256=NcKu4pXqHR8uVI-6flWj32Rm9330R_2OEkou1ZZi10M,5378 +torch/testing/_internal/data/__init__.py,sha256=DiaaC4wLY-kzeqPWCEVKCSzOa61Vli38kuG2Hd50x3c,23 +torch/testing/_internal/data/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/data/__pycache__/network1.cpython-39.pyc,, +torch/testing/_internal/data/__pycache__/network2.cpython-39.pyc,, +torch/testing/_internal/data/network1.py,sha256=oui4qSVOL39z1DqyRW0Ua-rhpC1weEGhQTpZfy5dAKg,179 +torch/testing/_internal/data/network2.py,sha256=bDrLEHTqoZYGLSt_wXgE_CAW0H6dRB1Pqr3KBQqk1uw,210 
+torch/testing/_internal/dist_utils.py,sha256=t40OBrDA79mDmyf_ke_AigsadOrL5e_adwqjf1LJfXw,7454 +torch/testing/_internal/distributed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/checkpoint_utils.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/common_state_dict.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/ddp_under_dist_autograd_test.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/distributed_test.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/distributed_utils.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/fake_pg.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/multi_threaded_pg.cpython-39.pyc,, +torch/testing/_internal/distributed/__pycache__/rpc_utils.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/__init__.py,sha256=EhTGyPtX-bzRNfqTY2Ql8gEScx4267U_pGa_h69kMZU,28 +torch/testing/_internal/distributed/_shard/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/__pycache__/test_common.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/sharded_tensor/__init__.py,sha256=zrSJGcq_ttC1xU87DvgM2s2G3wRHtxVJg67CRB4U0WY,3307 +torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_ops_common.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/sharded_tensor/__pycache__/_test_st_common.cpython-39.pyc,, +torch/testing/_internal/distributed/_shard/sharded_tensor/_test_ops_common.py,sha256=7UGyAp6lITtyl5DVz865Nqgvxmb5gONTkMeX4g4asog,4148 +torch/testing/_internal/distributed/_shard/sharded_tensor/_test_st_common.py,sha256=Pms5o2nHLLjYS6ZGAI7n9jqZry6w-AN8-Huh4bwswH0,1674 +torch/testing/_internal/distributed/_shard/test_common.py,sha256=upgmNyYS-cU5DBgh1gPd-HrkiPMqpOJanuyQuSZ_qkY,1260 +torch/testing/_internal/distributed/_tensor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/_tensor/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/_tensor/__pycache__/common_dtensor.cpython-39.pyc,, +torch/testing/_internal/distributed/_tensor/common_dtensor.py,sha256=L1Mhzi0DaqXdSUG1aREgvOIn_CY4DoVdwVqEtOg3U-A,22384 +torch/testing/_internal/distributed/checkpoint_utils.py,sha256=hJoce4WTrrCkoZieRizGWqyBP8Hffsi9x83gTg4Osa0,5295 +torch/testing/_internal/distributed/common_state_dict.py,sha256=oo4MLbGpqvK14qgQOcY--A6mcexLIwVLTMRA-VXgUdI,6875 +torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py,sha256=xoXwoIAPMGoiMlTa1z-mLxrMVa1U6KAmt-YTURBj0w4,27696 +torch/testing/_internal/distributed/distributed_test.py,sha256=F2hJ--aiigQx_IG1k6n69LLKZYa0HdRtjhfZ9n6mL7c,447435 +torch/testing/_internal/distributed/distributed_utils.py,sha256=jfE3w1tGuKZDo8_6jVsZ9n9psiWxw3SYcdkblc55aVA,2017 +torch/testing/_internal/distributed/fake_pg.py,sha256=zmRoPBdV303qOAsEVrOBVSDPwVhIifFVMU_riSTCzXw,1061 +torch/testing/_internal/distributed/multi_threaded_pg.py,sha256=1TBV2myPu5FctlGX3x3RS2g2e_F9KaOceF3IS889Sr0,19875 +torch/testing/_internal/distributed/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/nn/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/nn/api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+torch/testing/_internal/distributed/nn/api/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/nn/api/__pycache__/remote_module_test.cpython-39.pyc,, +torch/testing/_internal/distributed/nn/api/remote_module_test.py,sha256=3q1VlPi-oT-t8i4Fmm5_JOqk5ktjNAtsLUtRjj4xLN0,30479 +torch/testing/_internal/distributed/rpc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/rpc/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/dist_autograd_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/dist_optimizer_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/faulty_agent_rpc_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/faulty_rpc_agent_test_fixture.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/rpc_agent_test_fixture.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/rpc_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/__pycache__/tensorpipe_rpc_agent_test_fixture.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/dist_autograd_test.py,sha256=BS3wKQGw7UjmQfwgtJfP8dLiuSWy4oYzPi5kxzGJtDg,109837 +torch/testing/_internal/distributed/rpc/dist_optimizer_test.py,sha256=Ih4Iuy7rAxfqUVhPHxMNk2p5sDH_Cx9_l3LcpLZ-Rso,10855 +torch/testing/_internal/distributed/rpc/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/rpc/examples/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/examples/__pycache__/parameter_server_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/examples/__pycache__/reinforcement_learning_rpc_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py,sha256=3XR1KBuD5V6fIqjdYEiVxzsM8O98HN7mgb7oIsiFUO0,4695 +torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py,sha256=md4V57qCVa8zwg018HveBUgEXsB9Ab9FrBd8SpQq1Io,9520 +torch/testing/_internal/distributed/rpc/faulty_agent_rpc_test.py,sha256=lpgXkSyctpgP6k2euI7ZBkKr_7hJDt8661To-2VMDTQ,14576 +torch/testing/_internal/distributed/rpc/faulty_rpc_agent_test_fixture.py,sha256=OLnDlz-dvddeG9-FxwyK-INzQCrGPrw5wwA6ik_FREY,2206 +torch/testing/_internal/distributed/rpc/jit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/distributed/rpc/jit/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/jit/__pycache__/dist_autograd_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/jit/__pycache__/rpc_test_faulty.cpython-39.pyc,, +torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py,sha256=6HnwmARmcwdrm5qDXrDtTiDPFnqv-N_MgXlK_Taxgy8,4285 +torch/testing/_internal/distributed/rpc/jit/rpc_test.py,sha256=rFG9iZG-9QFAxtEODqgvxY6gGxKybvTKqWhTCu4H1zk,48263 +torch/testing/_internal/distributed/rpc/jit/rpc_test_faulty.py,sha256=UbP5mSXW56iDLD5_QqR2HIogfUWGnEEt8sguHJDLQww,8193 +torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py,sha256=zhpzVAar7GiEroCkwV4YuWLnGnG1iFQ42zo_V2PLfVA,1937 +torch/testing/_internal/distributed/rpc/rpc_test.py,sha256=WpcWE31J0GBHyMPueW7dAwJo_dv71qwKcfmG3BDol0U,232367 +torch/testing/_internal/distributed/rpc/tensorpipe_rpc_agent_test_fixture.py,sha256=hZ4ub4qtB50YYBoB2ckI_WoCFBB_ry4rGEWzJke-77Y,1002 
+torch/testing/_internal/distributed/rpc_utils.py,sha256=hbwJG6YwJ-KtCseuXJKCkvPfXOsaticWbN_NVH71-5A,6803 +torch/testing/_internal/dynamo_test_failures.py,sha256=2Oz22E8Y0txvlHWfpgjGsOTDMolpRpo2YVNfgCig9-M,5584 +torch/testing/_internal/fake_config_module.py,sha256=rbyPoOYaFz9KFL521Lkt5XSU-CEexSHiiKQFLqEmc_4,1297 +torch/testing/_internal/fake_config_module2.py,sha256=-Uog-78Z6WL4MlhJQWg0Hi32H0JQVya5w9LZHpdbkxM,356 +torch/testing/_internal/fake_config_module3.py,sha256=JuTtpvkmOlFbcQqJPTTuKOogjxtD7kOVUGMD7Gj1DGo,229 +torch/testing/_internal/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/generated/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/generated/__pycache__/annotated_fn_args.cpython-39.pyc,, +torch/testing/_internal/generated/annotated_fn_args.py,sha256=gZPOaTQdeZR-xzYzehNS71r_S0__yaPOgFBXidM6--o,553385 +torch/testing/_internal/hop_db.py,sha256=tRLv4YB9w_IoexkOEBGePr8-UB3h7tja_3bc6w6AM8k,14355 +torch/testing/_internal/hypothesis_utils.py,sha256=QOpvAf9jAtt66RVHBzHokJfARpZ_J4eWICC78tPouU8,14989 +torch/testing/_internal/inductor_utils.py,sha256=_sVnaQRqKPjc1e36FaZN211hA-7iZbQru4NTj5Xr7AE,11481 +torch/testing/_internal/jit_metaprogramming_utils.py,sha256=IAnkB10EMJ-UwJZrhGgWyHiOzKSuSMpDlfrG8ECh6Y0,34782 +torch/testing/_internal/jit_utils.py,sha256=IB_8IV6Uk6ZxOcejcNWWwl-BQomNCCMeBokYCEfvxNg,34844 +torch/testing/_internal/logging_tensor.py,sha256=j_vIvb6D0IJB39z6tlMh7zTRe-WaGnGDW5VcvjabC-w,6814 +torch/testing/_internal/logging_utils.py,sha256=H_va5SrUNRI_O0qnHakwPZobJJKLsBwJboP_MQyMkvY,8448 +torch/testing/_internal/opinfo/__init__.py,sha256=gMzdjsnJzsmJrkvmnXErEc9HTeG-5r_AIznCn0YKhog,120 +torch/testing/_internal/opinfo/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/opinfo/__pycache__/core.cpython-39.pyc,, +torch/testing/_internal/opinfo/__pycache__/refs.cpython-39.pyc,, +torch/testing/_internal/opinfo/__pycache__/utils.cpython-39.pyc,, +torch/testing/_internal/opinfo/core.py,sha256=YdV4VNd2ZI771wxiB8ZJpNja38-1KW8lT1zHFM6dVC4,127164 +torch/testing/_internal/opinfo/definitions/__init__.py,sha256=IThISQWPf5p3Khcfvbbgw_fENb84fTYqJpi3bEZ07ZU,478 +torch/testing/_internal/opinfo/definitions/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/_masked.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/fft.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/linalg.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/nested.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/signal.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/sparse.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/__pycache__/special.cpython-39.pyc,, +torch/testing/_internal/opinfo/definitions/_masked.py,sha256=LaSKtvZSyaYpKZF-fGCaxghSxmuxc82b4BxaHhnGZNw,47647 +torch/testing/_internal/opinfo/definitions/fft.py,sha256=VYe6XSYakDmeShN8ii30dmGU-oFtdJkkIz6DxOeNbdI,30254 +torch/testing/_internal/opinfo/definitions/linalg.py,sha256=t5wJ_eyjT52jqmeN0AwAVq7MpWgOSIkNY7oK67G3utw,86359 +torch/testing/_internal/opinfo/definitions/nested.py,sha256=ls7MHeB4xL93xWdojJBHJEV63cAPCWf4JDNJDSxOO64,61313 +torch/testing/_internal/opinfo/definitions/signal.py,sha256=gLJuXdJszNBVMXk3JfoM0fsfDnJbDsMkwB1u1rHoP3c,15801 +torch/testing/_internal/opinfo/definitions/sparse.py,sha256=if2e2kOzuV462vHxKmt5t3XNZUpfGes1uEgU1iWunwA,34705 
+torch/testing/_internal/opinfo/definitions/special.py,sha256=8J4KJL9M17P7poZnnumh2C-2iYfT3sJ2qPPttTUbzjU,28379 +torch/testing/_internal/opinfo/refs.py,sha256=R9pnC9VvTcBBHQpkhYzlI3y6Xldhln2jV2RNp5vooGM,8246 +torch/testing/_internal/opinfo/utils.py,sha256=7lte4qJU1_t0YkNneJw5U6Is2YBxqFOzv0x6mO9aK6g,8996 +torch/testing/_internal/optests/__init__.py,sha256=FSuVnX16NUi5lWMkTRnykRD5t79ixmQnP5IzAda1sno,379 +torch/testing/_internal/optests/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/optests/__pycache__/aot_autograd.cpython-39.pyc,, +torch/testing/_internal/optests/__pycache__/autograd_registration.cpython-39.pyc,, +torch/testing/_internal/optests/__pycache__/fake_tensor.cpython-39.pyc,, +torch/testing/_internal/optests/__pycache__/generate_tests.cpython-39.pyc,, +torch/testing/_internal/optests/__pycache__/make_fx.cpython-39.pyc,, +torch/testing/_internal/optests/aot_autograd.py,sha256=-MCMJUpoRAVpJtoILoZLR5A9R_pTxWazy9Ufa2p518w,6589 +torch/testing/_internal/optests/autograd_registration.py,sha256=yY6JcMW5iTtZ--bNK6IbDsxvWcWGEyYFwcUtsc7sN3g,5824 +torch/testing/_internal/optests/fake_tensor.py,sha256=fe_mSRHFQYL6mzF8fybaBrLIOW9N6Retztpl_SNQXMY,269 +torch/testing/_internal/optests/generate_tests.py,sha256=zMiLrHStRfLYF7mxrqt2xA_LveaBAGYiTr8MvrmzN8Q,32628 +torch/testing/_internal/optests/make_fx.py,sha256=C4PqX2PYU8ZmQ5DF6Pc_VoBYidKxAQz-SLOWoc7onG8,3357 +torch/testing/_internal/quantization_torch_package_models.py,sha256=ZU4aq2H8SYQMgokcw1oOgalMNSFS5ekcjSeZn_49s-M,984 +torch/testing/_internal/static_module.py,sha256=KijDwJpMQ0RXVubvB5tSBEAmwEQLz_t8ZnJLkki0i3g,920 +torch/testing/_internal/subclasses.py,sha256=fkzG1SxYCzKMXwa0R6BoaNtjFw9vITB3QzQOSGN1oyY,2608 +torch/testing/_internal/test_module/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/testing/_internal/test_module/__pycache__/__init__.cpython-39.pyc,, +torch/testing/_internal/test_module/__pycache__/future_div.cpython-39.pyc,, +torch/testing/_internal/test_module/__pycache__/no_future_div.cpython-39.pyc,, +torch/testing/_internal/test_module/future_div.py,sha256=fM41286qBrBQbyt0_B_o-6nP8hL7CLgfhfHs1iWkDv0,124 +torch/testing/_internal/test_module/no_future_div.py,sha256=Xrd0DiR35Q8NNHYXdwr2R_sw2KM6ei-pNRBcMBa-rI8,156 +torch/testing/_internal/torchbind_impls.py,sha256=Ll2opciPt3A_rcLNFqt5f8ToEmF-S5V63Qgxau6XenM,5436 +torch/testing/_internal/triton_utils.py,sha256=kk6Ehr4aeiAcEc_ISoxgoPKD-sVR42CNojFPcG9i9f0,30298 +torch/testing/_internal/two_tensor.py,sha256=g024qYc9s7IogGcjuZhW3RlX3IgWEqjpI5R4Aq12vjo,3701 +torch/testing/_utils.py,sha256=rO-KBh1WmjoqfLGkJRpZBX6YNRhX1-5OMpQrvGmSpig,2091 +torch/torch_version.py,sha256=GebXGfR-qq13F_bhl1f68SgD_LH64QDaD9FR4_joYEU,2596 +torch/types.py,sha256=y9m8oeipOt2Sj4keEtEETVdHMi7ARkhFSnZRyPzDdVA,3817 +torch/utils/__init__.py,sha256=Q4ZUQMi6-SD-OVnBUyKw0ZSneqdhyWV14y1wsjpmC9Y,4174 +torch/utils/__pycache__/__init__.cpython-39.pyc,, +torch/utils/__pycache__/_appending_byte_serializer.cpython-39.pyc,, +torch/utils/__pycache__/_backport_slots.cpython-39.pyc,, +torch/utils/__pycache__/_config_module.cpython-39.pyc,, +torch/utils/__pycache__/_content_store.cpython-39.pyc,, +torch/utils/__pycache__/_contextlib.cpython-39.pyc,, +torch/utils/__pycache__/_cpp_embed_headers.cpython-39.pyc,, +torch/utils/__pycache__/_cpp_extension_versioner.cpython-39.pyc,, +torch/utils/__pycache__/_cxx_pytree.cpython-39.pyc,, +torch/utils/__pycache__/_device.cpython-39.pyc,, +torch/utils/__pycache__/_dtype_abbrs.cpython-39.pyc,, +torch/utils/__pycache__/_exposed_in.cpython-39.pyc,, 
+torch/utils/__pycache__/_filelock.cpython-39.pyc,, +torch/utils/__pycache__/_foreach_utils.cpython-39.pyc,, +torch/utils/__pycache__/_freeze.cpython-39.pyc,, +torch/utils/__pycache__/_functools.cpython-39.pyc,, +torch/utils/__pycache__/_get_clean_triton.cpython-39.pyc,, +torch/utils/__pycache__/_helion.cpython-39.pyc,, +torch/utils/__pycache__/_import_utils.cpython-39.pyc,, +torch/utils/__pycache__/_mode_utils.cpython-39.pyc,, +torch/utils/__pycache__/_ordered_set.cpython-39.pyc,, +torch/utils/__pycache__/_python_dispatch.cpython-39.pyc,, +torch/utils/__pycache__/_pytree.cpython-39.pyc,, +torch/utils/__pycache__/_stats.cpython-39.pyc,, +torch/utils/__pycache__/_thunk.cpython-39.pyc,, +torch/utils/__pycache__/_traceback.cpython-39.pyc,, +torch/utils/__pycache__/_triton.cpython-39.pyc,, +torch/utils/__pycache__/_typing_utils.cpython-39.pyc,, +torch/utils/__pycache__/_zip.cpython-39.pyc,, +torch/utils/__pycache__/backend_registration.cpython-39.pyc,, +torch/utils/__pycache__/bundled_inputs.cpython-39.pyc,, +torch/utils/__pycache__/checkpoint.cpython-39.pyc,, +torch/utils/__pycache__/collect_env.cpython-39.pyc,, +torch/utils/__pycache__/cpp_backtrace.cpython-39.pyc,, +torch/utils/__pycache__/cpp_extension.cpython-39.pyc,, +torch/utils/__pycache__/deterministic.cpython-39.pyc,, +torch/utils/__pycache__/dlpack.cpython-39.pyc,, +torch/utils/__pycache__/file_baton.cpython-39.pyc,, +torch/utils/__pycache__/flop_counter.cpython-39.pyc,, +torch/utils/__pycache__/hooks.cpython-39.pyc,, +torch/utils/__pycache__/mkldnn.cpython-39.pyc,, +torch/utils/__pycache__/mobile_optimizer.cpython-39.pyc,, +torch/utils/__pycache__/model_zoo.cpython-39.pyc,, +torch/utils/__pycache__/module_tracker.cpython-39.pyc,, +torch/utils/__pycache__/show_pickle.cpython-39.pyc,, +torch/utils/__pycache__/throughput_benchmark.cpython-39.pyc,, +torch/utils/__pycache__/weak.cpython-39.pyc,, +torch/utils/_appending_byte_serializer.py,sha256=Qob0trGcpwzQL3WH0gYR0JqHHLHgEYjH1uJDRGYijYU,3785 +torch/utils/_backport_slots.py,sha256=9QvJuHDztf-GKrro-38IPSy96PvBVM5bKTgn2W6_dUY,4721 +torch/utils/_config_module.py,sha256=44LJPz2D19I5sYyJ1cI0WFV8VvDmfIahzjXv8w99OrU,30604 +torch/utils/_config_typing.pyi,sha256=BXR7CFyr6Mn8BWt29IX8wwx3_oV9JTl7-8jUG1gi3qY,1253 +torch/utils/_content_store.py,sha256=ygoZJboYWVAz3qTZFvQ6vjYYisS69oH-61r8-RybD1E,9461 +torch/utils/_contextlib.py,sha256=jbtVyCb_YYadB4JMihkB_iY50DNu24mrc8IJAw5NUTU,6187 +torch/utils/_cpp_embed_headers.py,sha256=IxtvpYUAatGxwZzH38yYUgCj92q6cAYO6PaPGAakZas,1864 +torch/utils/_cpp_extension_versioner.py,sha256=8KHu7uVvfdPyFB0SWBclStfTaTz9Lgz0tVTCy6PiqE4,2001 +torch/utils/_cxx_pytree.py,sha256=2qWF3tkeDGrgjOqMUp9YwxbjLyM4nFKvysxHIsKnTSI,39413 +torch/utils/_device.py,sha256=DkcAevLyzKw0aBQpHz3gCJLUtz2CHtLJ7U5GdbxKQ8A,3827 +torch/utils/_dtype_abbrs.py,sha256=EltGnPMGj6vL9wHSfjE7yhw6K7W23ysqn6sTgIdgFA0,785 +torch/utils/_exposed_in.py,sha256=p5zMXcgQljOCL_cNC7ZnBv1mRBdXE0Uw9AU6Hw0ncrc,713 +torch/utils/_filelock.py,sha256=YpLSnQKHLH6Orn2A9dk480RGiIRPxWzaNPd00OZdRgc,1609 +torch/utils/_foreach_utils.py,sha256=k9VtVF4s0L6xX1g6qRRJmJsJvnGj_D22V6yNXB_970g,2487 +torch/utils/_freeze.py,sha256=4DWH7-A2_JFRq0znUi4tTHVT771K9YZ4mh-f41yt8PY,10288 +torch/utils/_functools.py,sha256=DqCmEGGfYhwVLJ6KQa3h4jCmmzMdDMn3rcs_AEE4A6c,1467 +torch/utils/_get_clean_triton.py,sha256=yo27UOHHhAXlXIQGYaFTC8pty_jyb6PK6AsgEibqqqg,7167 +torch/utils/_helion.py,sha256=ZOn4JLVBhawXV0B6YDvXUxI6w7h09TYlfBoyDgZVp40,381 +torch/utils/_import_utils.py,sha256=9mXbDXf8DzpvYa10Fanj8KKAYaR4j13jhdpyOfsYzGo,1373 
+torch/utils/_mode_utils.py,sha256=--_9vhBRDmfUGMC5CX-tyoS-2zjvlGFxCswpAVP8b1g,270 +torch/utils/_ordered_set.py,sha256=Hszl6ejHlnJ8-qAemmXSkzDb6tSwofcJJo-sU6v1pFg,5688 +torch/utils/_python_dispatch.py,sha256=HBm44FhMvh_3hofmU3v3rxpOR5cPOmPtozCnE5eTZAk,29278 +torch/utils/_pytree.py,sha256=ti9Obm-1kFoTxTb-UmOsvbo2-QKqXDuCYBVm-YDo0_E,73877 +torch/utils/_stats.py,sha256=ZgWmhozlrgZMaTBytEIYfz-Qtv-Og-m5-_ese3F87q4,1044 +torch/utils/_strobelight/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/_strobelight/__pycache__/__init__.cpython-39.pyc,, +torch/utils/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc,, +torch/utils/_strobelight/cli_function_profiler.py,sha256=0mv1QcJxvS8TmYjtE2x-CFh3XU-6M9QDiRj1qJCFIhY,11653 +torch/utils/_sympy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/_sympy/__pycache__/__init__.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/functions.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/interp.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/numbers.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/printers.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/reference.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/singleton_int.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/solve.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/symbol.cpython-39.pyc,, +torch/utils/_sympy/__pycache__/value_ranges.cpython-39.pyc,, +torch/utils/_sympy/functions.py,sha256=MpW562iynuoIspY84DVH9oxBcHwmAByHdmwoYBY1KQk,52194 +torch/utils/_sympy/interp.py,sha256=lxKYev7VICwZ09C7Smv7K1Trta6o8mvm4xPDwc8ED-Y,7366 +torch/utils/_sympy/numbers.py,sha256=FvfpF4qMhvL2_U6uss9AEM3KtzVWVCRiAVT3Dzfju00,11798 +torch/utils/_sympy/printers.py,sha256=ndDeGRjIcLj4skY1zegGHPMXPbDD5WXxqzxdZuS5vZA,21061 +torch/utils/_sympy/reference.py,sha256=XX9JLmhoUnrKshsPPtfTZXqgc5VTE-eouCEDlPyFQxc,14185 +torch/utils/_sympy/singleton_int.py,sha256=qmZE5Y_eko7F0h6KPB45OfhG3eeWTy_xrefo4N4RHm8,3063 +torch/utils/_sympy/solve.py,sha256=zi7T-_GQg93xHRYQntivNT26saPHbxS9ICElEMfDuHo,6687 +torch/utils/_sympy/symbol.py,sha256=DEbfN8TSXIAzks1tut2st6qbic5_ROJ6oTGwMP7OCf8,3820 +torch/utils/_sympy/value_ranges.py,sha256=Umd02rVGTL_AniT0qin2nxqW42jw1P9jIaH88EE953A,36272 +torch/utils/_thunk.py,sha256=ANhkLaR7G9l0ck5ozEW-sTaDiruj-3BWkjgF9AnhTxg,653 +torch/utils/_traceback.py,sha256=GWIp0q3p5SumR6Zcb8HPV2JfidhKTHPkdvZoUS0UVb0,10508 +torch/utils/_triton.py,sha256=Paxta_MGhteCZNY8lZz4GajhuGCseG3OJhgVF-apjsE,4517 +torch/utils/_typing_utils.py,sha256=RLt9ulicp151yZcRNF2slxNkMuGmMhOLMMxWFO0IhcE,392 +torch/utils/_zip.py,sha256=ze4h6g6SGOGYbz3nNhNX0VOXcsOR3und13fpKV77BE4,2542 +torch/utils/backcompat/__init__.py,sha256=xNYwRpO2wm2roB8nqa-EEaeMRkBltXE9VB68R3bAH5g,689 +torch/utils/backcompat/__pycache__/__init__.cpython-39.pyc,, +torch/utils/backend_registration.py,sha256=yrZ2JEfIpT-h8cHjyoR76x_G-Z-6-_tkHbnWzS0lJM8,19925 +torch/utils/benchmark/__init__.py,sha256=HKYS3h8Ga20QnvZJYgekjWd2Dc82oreraUxoRCfNmjo,417 +torch/utils/benchmark/__pycache__/__init__.cpython-39.pyc,, +torch/utils/benchmark/examples/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/benchmark/examples/__pycache__/__init__.cpython-39.pyc,, +torch/utils/benchmark/examples/__pycache__/compare.cpython-39.pyc,, +torch/utils/benchmark/examples/__pycache__/fuzzer.cpython-39.pyc,, +torch/utils/benchmark/examples/__pycache__/op_benchmark.cpython-39.pyc,, +torch/utils/benchmark/examples/__pycache__/simple_timeit.cpython-39.pyc,, 
+torch/utils/benchmark/examples/__pycache__/spectral_ops_fuzz_test.cpython-39.pyc,, +torch/utils/benchmark/examples/compare.py,sha256=lMgFmYuUuudfWeYBmBbhcTJCivzvX48G08V6IwXbILU,3014 +torch/utils/benchmark/examples/fuzzer.py,sha256=44i30rV1ZI2XvjjmgQwKn9uRHhPMzWu9-5q6qkiI7fI,2736 +torch/utils/benchmark/examples/op_benchmark.py,sha256=jL6vlb3CKQsrdJHtIPymHXoVlTA90IGs9VzZ-Wbv94s,4332 +torch/utils/benchmark/examples/simple_timeit.py,sha256=_yV8VfpGOyaPMGfc6c9i2XvTzFXZZWhUYvO3e9isY-0,566 +torch/utils/benchmark/examples/spectral_ops_fuzz_test.py,sha256=6sdnmWrpcPVZFfl7BxzIOCFrCFr8bVXTNGIpE9beZiU,4893 +torch/utils/benchmark/op_fuzzers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/benchmark/op_fuzzers/__pycache__/__init__.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/__pycache__/binary.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/__pycache__/sparse_binary.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/__pycache__/sparse_unary.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/__pycache__/spectral.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/__pycache__/unary.cpython-39.pyc,, +torch/utils/benchmark/op_fuzzers/binary.py,sha256=ysSfD2-iIH_4Z2qsFaNoKR1HtyMUK9LulcvG_6hil4Y,4243 +torch/utils/benchmark/op_fuzzers/sparse_binary.py,sha256=o7xSLUC8cuDBktdc7L3CXFR5o9Ryh6o-0YYqUFA4f0k,4325 +torch/utils/benchmark/op_fuzzers/sparse_unary.py,sha256=FGesxlpDbBsz-2dSrA91WUIC4ZsQmNuBXlgJdTlEBsI,3329 +torch/utils/benchmark/op_fuzzers/spectral.py,sha256=4VrLgIYZy9d25Wua5GJq9Qvl0XjZcU51gAtd92jiY1k,3718 +torch/utils/benchmark/op_fuzzers/unary.py,sha256=uGKaDaT895NcmpHPZbYascVqFqVkOdEF3MJc6JVTRSQ,3228 +torch/utils/benchmark/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/benchmark/utils/__pycache__/__init__.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/_stubs.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/common.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/compare.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/compile.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/cpp_jit.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/fuzzer.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/sparse_fuzzer.cpython-39.pyc,, +torch/utils/benchmark/utils/__pycache__/timer.cpython-39.pyc,, +torch/utils/benchmark/utils/_stubs.py,sha256=wWKaADCERqO1JPDw-imxbRc0ByvaeHELX0a8y9Ci5Lo,1040 +torch/utils/benchmark/utils/common.py,sha256=3ZUXwEuG7X3dsP8ex5qKC9pkS5lTknk9fYUUoQaf8N4,14016 +torch/utils/benchmark/utils/compare.py,sha256=3j9Sub2AlbVjsrP5l418HRn9tgXRqwZWV5kkTbH7vKI,13615 +torch/utils/benchmark/utils/compile.py,sha256=h-W_aPnsluR52UYGZDVJLEYwggaRw0gCuQiuonRFvsI,7793 +torch/utils/benchmark/utils/cpp_jit.py,sha256=x0JzGYRMKOC7Pdg3l89UWjK4KnG4xX5R64_1iISwD6Q,6977 +torch/utils/benchmark/utils/fuzzer.py,sha256=aBap5jTVLAhp42ro1u6OC2N2IEyuVchZ8IAftOkPqeI,18829 +torch/utils/benchmark/utils/sparse_fuzzer.py,sha256=xRYh_eiM7WOAccrRpQXfjqfbb-XGKEgPTEdXP6qq_5Q,5281 +torch/utils/benchmark/utils/timeit_template.cpp,sha256=Va3o8cMDuSgEBzW9lRDPnLjLme5AXDjxlbYz7t8oFTU,1052 +torch/utils/benchmark/utils/timer.py,sha256=QyoQgjAlLST395OCraracXiZkq_9ozUJYkMJfqvmfZs,21740 +torch/utils/benchmark/utils/valgrind_wrapper/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/__init__.cpython-39.pyc,, +torch/utils/benchmark/utils/valgrind_wrapper/__pycache__/timer_interface.cpython-39.pyc,, 
+torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h,sha256=XCIF1Vdbl-Gx3F6t3FfIxGI5Nxg_fBriBQTvtWVAASQ,5873 +torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp,sha256=gPk9mKdSUgL8ld02ctjzbpYBCNwmzmejCGcrlzaTErI,848 +torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp,sha256=Mz8VIs_lJ0-swUxXsp7mX3U-B_IQNocow7BJAZGkmgA,1744 +torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py,sha256=haL3j8CU7e574bv100xV36xQmv1WBEHZ1tLWHFXCdik,38172 +torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h,sha256=duxM9eWNw9-pJLX9amnu0-5RBYcglE_pF-R0la7YS6s,429810 +torch/utils/bottleneck/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/bottleneck/__main__.py,sha256=YdrGlMlmwjPZfMb4O3ntOPmIz0GonrJjwNjA2FwVhc8,7419 +torch/utils/bottleneck/__pycache__/__init__.cpython-39.pyc,, +torch/utils/bottleneck/__pycache__/__main__.cpython-39.pyc,, +torch/utils/bundled_inputs.py,sha256=hIcwed6j3cvppiSXK8nMIFF74T4VFSrzcQzb8YijI54,23052 +torch/utils/checkpoint.py,sha256=gdsivN-LSjTrCeUg-SfNxIlZj3l44GFULKq8GKXpfSs,69657 +torch/utils/collect_env.py,sha256=3jdOHEUDGJYVOVruOfq2smyzxrCNwdm6cAY6JtxkGkk,25200 +torch/utils/cpp_backtrace.py,sha256=2Kqo--zoORQcKGnK2UFIFfhw9QfunH-iE00FuAMtUFY,495 +torch/utils/cpp_extension.py,sha256=2DTX9ud03W13ZMTODXkmz57S8h21Os1lPoM66wGpauM,135426 +torch/utils/data/__init__.py,sha256=hLegEpoYshuRLS1FBfcK4lEFWutzOPnNLSyw0XRBP_Y,1731 +torch/utils/data/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/__pycache__/backward_compatibility.cpython-39.pyc,, +torch/utils/data/__pycache__/dataloader.cpython-39.pyc,, +torch/utils/data/__pycache__/dataset.cpython-39.pyc,, +torch/utils/data/__pycache__/distributed.cpython-39.pyc,, +torch/utils/data/__pycache__/graph.cpython-39.pyc,, +torch/utils/data/__pycache__/graph_settings.cpython-39.pyc,, +torch/utils/data/__pycache__/sampler.cpython-39.pyc,, +torch/utils/data/_utils/__init__.py,sha256=s6_rW5jvxF4QgKSFRFdMn6Mq6M8fGUE4jd1jSbY300k,1679 +torch/utils/data/_utils/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/_utils/__pycache__/collate.cpython-39.pyc,, +torch/utils/data/_utils/__pycache__/fetch.cpython-39.pyc,, +torch/utils/data/_utils/__pycache__/pin_memory.cpython-39.pyc,, +torch/utils/data/_utils/__pycache__/signal_handling.cpython-39.pyc,, +torch/utils/data/_utils/__pycache__/worker.cpython-39.pyc,, +torch/utils/data/_utils/collate.py,sha256=LrJTitn98hTbdiPq20HkYlyO6YLkuKG5oIf80Ua-Wxw,16362 +torch/utils/data/_utils/fetch.py,sha256=v2i1UiJf5WVWnNnuZzkLuouDlhthMYKqhop6mu2207k,2008 +torch/utils/data/_utils/pin_memory.py,sha256=DmVKNwBC5WqFhBY-MzXW3ft36WsketF6kn3G-cZHWjU,4561 +torch/utils/data/_utils/signal_handling.py,sha256=CAh-YmhwRc9N3o-Lh0lorW75NA8SEBD_aBwtfbM0mys,3250 +torch/utils/data/_utils/worker.py,sha256=0YcrPgL-fRskmQVQ35iRqLqBMnq5TC2R5qqqHdHi5ic,14174 +torch/utils/data/backward_compatibility.py,sha256=-KXzq0gPoz0oWnCs8fz0i1m1i3-e1ZBOU6_kKUf-KxA,320 +torch/utils/data/dataloader.py,sha256=YWHBKBRkIUDBab6T9D3zqsk0E2s7eQ-n1La8zrv0Y2A,81249 +torch/utils/data/datapipes/__init__.py,sha256=LwlHIcce6VnPe17EZHBEkX6e56_OIqO-gHTi8l5bUrU,89 +torch/utils/data/datapipes/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/datapipes/__pycache__/_decorator.cpython-39.pyc,, +torch/utils/data/datapipes/__pycache__/_hook_iterator.cpython-39.pyc,, +torch/utils/data/datapipes/__pycache__/_typing.cpython-39.pyc,, +torch/utils/data/datapipes/__pycache__/datapipe.cpython-39.pyc,, +torch/utils/data/datapipes/__pycache__/gen_pyi.cpython-39.pyc,, 
+torch/utils/data/datapipes/_decorator.py,sha256=-xGjiQF5wEm4GBNkLTNXPWM8-NJBlV3QBMdbFOSUhq0,8045 +torch/utils/data/datapipes/_hook_iterator.py,sha256=s4S1DGvxPhTE5xNCGDe5Sw41Yev09jMrAkcWHoyW6bc,12228 +torch/utils/data/datapipes/_typing.py,sha256=EfE0xahXP3ZHlJAAj444nsd_NXsOayEEsCjG5BqMEm4,16776 +torch/utils/data/datapipes/dataframe/__init__.py,sha256=0IbAghO5FXrmCE4FdZuzNIiB9bLdRai-Y0AigyQopPo,342 +torch/utils/data/datapipes/dataframe/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/datapipes/dataframe/__pycache__/dataframe_wrapper.cpython-39.pyc,, +torch/utils/data/datapipes/dataframe/__pycache__/dataframes.cpython-39.pyc,, +torch/utils/data/datapipes/dataframe/__pycache__/datapipes.cpython-39.pyc,, +torch/utils/data/datapipes/dataframe/__pycache__/structures.cpython-39.pyc,, +torch/utils/data/datapipes/dataframe/dataframe_wrapper.py,sha256=fDJVNgC8pC10rKyyn2Z9Z3VJL8aJiQzLEg1P8aIJYh0,3421 +torch/utils/data/datapipes/dataframe/dataframes.py,sha256=rJMweyoiHC_NxrMjajyD2JXIqdVBfkhiX14M2L3oLoI,13925 +torch/utils/data/datapipes/dataframe/datapipes.py,sha256=lF1da8hiatwnyhOqMwhoIvhlFiLFuPc3eyaaJdflaxA,4673 +torch/utils/data/datapipes/dataframe/structures.py,sha256=vIuh8h4Gj_mWZy49EsfYLyI1khN2NiZcMMiDTIYDyYo,684 +torch/utils/data/datapipes/datapipe.py,sha256=V51_PXA4nxNZqcDe_O4u6dNgQs0gbCd7LwN2euumrF4,17212 +torch/utils/data/datapipes/datapipe.pyi,sha256=563itfrD-HNZXXM40dZfKt1AqdYUYotfIQbRHCS8-Bo,33141 +torch/utils/data/datapipes/gen_pyi.py,sha256=DkJsnOM2LiKTgk_30m5gOwb5WKs8td8uw6Zs7lM4obI,12154 +torch/utils/data/datapipes/iter/__init__.py,sha256=GlUTH5tAd0CqDbBpwB4kP3KcQ9Vf7XmolxZ7sEM18z0,1880 +torch/utils/data/datapipes/iter/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/callable.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/combinatorics.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/combining.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/filelister.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/fileopener.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/grouping.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/routeddecoder.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/selecting.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/sharding.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/streamreader.cpython-39.pyc,, +torch/utils/data/datapipes/iter/__pycache__/utils.cpython-39.pyc,, +torch/utils/data/datapipes/iter/callable.py,sha256=Lfl4RlMWW9err2Dp6mG1Ow0DKGmwP2Zu_WKWIajrXiA,9307 +torch/utils/data/datapipes/iter/combinatorics.py,sha256=xNfHqi0AkldTxyyjcUTDlZlw7AQZ2zgcTijyVoVd4VY,6650 +torch/utils/data/datapipes/iter/combining.py,sha256=eHB7xZLvTpXhP5PM2-3q0eDQ3lqE5cDmuvtLCJD3nag,28014 +torch/utils/data/datapipes/iter/filelister.py,sha256=oI1QbmtlVZyQZ0YGFUHjykKr-jkCGfHG2fWdDWj_kdE,2664 +torch/utils/data/datapipes/iter/fileopener.py,sha256=Oz9Rxh-Zn3H7om4lxZfhWPgOrDu1fueHv6zcNUH-3BQ,2900 +torch/utils/data/datapipes/iter/grouping.py,sha256=lwVE1Hiq0aUw31ZjOO-O0khi4o8m9yR-zHGgMuPTqZE,12676 +torch/utils/data/datapipes/iter/routeddecoder.py,sha256=mp4srkbAmV87kMrIN7bdF-wpoTRhy-FXea0votY7U_s,2801 +torch/utils/data/datapipes/iter/selecting.py,sha256=c4K-ul55O8Ik-obCsshKmTQQV8ADKrk2XDw-ngE9pG4,3415 +torch/utils/data/datapipes/iter/sharding.py,sha256=zGmpVA8negHYPkBtXaN5WJ-nQIGutgOqmbCUYdoTL6s,3611 +torch/utils/data/datapipes/iter/streamreader.py,sha256=qS3wKPUit_9m5WQld_RpuWwYo9n3rixM1LAGC_qSPnA,1606 
+torch/utils/data/datapipes/iter/utils.py,sha256=sMV9f61yUSGJUr-5iqntIFonG4qkiNPb743DXJhvnNc,1863 +torch/utils/data/datapipes/map/__init__.py,sha256=oaALDshhVqcb4hGlksTfovGF1hTp2uLsKUntU_oUUpE,686 +torch/utils/data/datapipes/map/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/datapipes/map/__pycache__/callable.cpython-39.pyc,, +torch/utils/data/datapipes/map/__pycache__/combinatorics.cpython-39.pyc,, +torch/utils/data/datapipes/map/__pycache__/combining.cpython-39.pyc,, +torch/utils/data/datapipes/map/__pycache__/grouping.cpython-39.pyc,, +torch/utils/data/datapipes/map/__pycache__/utils.cpython-39.pyc,, +torch/utils/data/datapipes/map/callable.py,sha256=mq5gvLNJu7iKatrRwQHK3ivivGNVRakROlJXCXM0GBI,1925 +torch/utils/data/datapipes/map/combinatorics.py,sha256=MkYLNcYlFjJjjfTCKCGZLcmAXIRRvr93St8-WjVQOWM,4321 +torch/utils/data/datapipes/map/combining.py,sha256=y1kJbS7_boC0FqpHOvlXh5Muuww0FRurlZDxP3NVu_8,3804 +torch/utils/data/datapipes/map/grouping.py,sha256=FGcEv44ptM-0hwxuT5kqtUzeFGFmcgCHk-_TMjwH_0I,2531 +torch/utils/data/datapipes/map/utils.py,sha256=T-o6wYR7rrC4bkQPW6ny068mX_oC3KhXOj9tczLvl6E,1628 +torch/utils/data/datapipes/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/data/datapipes/utils/__pycache__/__init__.cpython-39.pyc,, +torch/utils/data/datapipes/utils/__pycache__/common.cpython-39.pyc,, +torch/utils/data/datapipes/utils/__pycache__/decoder.cpython-39.pyc,, +torch/utils/data/datapipes/utils/__pycache__/snapshot.cpython-39.pyc,, +torch/utils/data/datapipes/utils/common.py,sha256=x_qCiJPwS8xrtwTmI2qjR-5ayd4qjtTvbOQTahVKMxw,14112 +torch/utils/data/datapipes/utils/decoder.py,sha256=rJBIhroahFtXnKTmN9-QQTY2vQ3FYFz4ylSwGDyf4QE,12265 +torch/utils/data/datapipes/utils/snapshot.py,sha256=spA3vQv1uRcp1x-C2QYCxtKmd0O3zxwhMydbuX1_zQ0,3167 +torch/utils/data/dataset.py,sha256=vIOElaUV4nYh1JAVSZo-VPDJp-Kj5PSuFq7wV0FADIA,19952 +torch/utils/data/distributed.py,sha256=TUnjMHLIENqu7wGLajGQGTdqw9fW2YEbe2EoZsfP2E8,6275 +torch/utils/data/graph.py,sha256=k96CtTROv1wNEUAoiUsEeIGLaNVnBQM8cUhKenE18xg,5960 +torch/utils/data/graph_settings.py,sha256=0Cfe7DWL_AQFMq5VjEWhvNJnhLXqPr8cQEuG63RQQk8,5714 +torch/utils/data/sampler.py,sha256=fBMHLXRY3tSAc-qlg01DwzYX-uKttSvJU4cW_7RvcXk,13184 +torch/utils/deterministic.py,sha256=B4D1Q9GMAqQWEyFiOknB05rbroRGQ2uh7nfC_5oaV2w,633 +torch/utils/dlpack.py,sha256=7x6fU3fVJ6lLtvBLA8a-jH3mDIQRSRY01wy1FWgYudk,4636 +torch/utils/file_baton.py,sha256=bsaHUPgXI7Nz-Gb7jDzuiExAwINbAcz7v6R-MPgl8Ew,2103 +torch/utils/flop_counter.py,sha256=3-FKig-Fcs6hwZyH5s6yS9_qMfeTVbwGcdrLYbNkCXw,29567 +torch/utils/hipify/__init__.py,sha256=aF1qTFTEaCr30mm2ii0Wqs4ybLcLlwuiMraorC40TkU,34 +torch/utils/hipify/__pycache__/__init__.cpython-39.pyc,, +torch/utils/hipify/__pycache__/constants.cpython-39.pyc,, +torch/utils/hipify/__pycache__/cuda_to_hip_mappings.cpython-39.pyc,, +torch/utils/hipify/__pycache__/hipify_python.cpython-39.pyc,, +torch/utils/hipify/__pycache__/version.cpython-39.pyc,, +torch/utils/hipify/constants.py,sha256=rw3CEBOON8oXjimIdKy57SkimpFnPbE5nnfdgLAxvuQ,1236 +torch/utils/hipify/cuda_to_hip_mappings.py,sha256=HkmwmRxssfMOpDmTAYJr0Rt5npmTO_KdRR-qp8B7jpE,367220 +torch/utils/hipify/hipify_python.py,sha256=KmORIYP3mR7dJREYl6w2ex55PXYfGYcatVrH3shMeRQ,48230 +torch/utils/hipify/version.py,sha256=crlYURA_2Ujkpz7iVPHovvxZRtvt4QBXF5AAFJzRfz0,23 +torch/utils/hooks.py,sha256=DVIQEFrhzlyNwibYmRDb-6CwkjliFaujpFlzasvTB5g,10238 +torch/utils/jit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 
+torch/utils/jit/__pycache__/__init__.cpython-39.pyc,, +torch/utils/jit/__pycache__/log_extract.cpython-39.pyc,, +torch/utils/jit/log_extract.py,sha256=4RNvziAAxYPhmFGbobjrj5_d811PO-qd9LfK7BwWi8k,3864 +torch/utils/mkldnn.py,sha256=DrhTx2yIld-PoGHs0VjA1XVqHcyK9nf0_2AW31SPeTU,8142 +torch/utils/mobile_optimizer.py,sha256=jp8-0WRXFvm6cl7jnPJGCup57cqdlPFLgD5b07TkuDU,6565 +torch/utils/model_dump/__init__.py,sha256=bthMwrGbjWg86Fh0hruOayBxTbyCyz80apAjktkAQMA,17192 +torch/utils/model_dump/__main__.py,sha256=svI1TUOVGqYtclCPw-W63BvkPoNLcCr7aDbGLvygZ24,84 +torch/utils/model_dump/__pycache__/__init__.cpython-39.pyc,, +torch/utils/model_dump/__pycache__/__main__.cpython-39.pyc,, +torch/utils/model_dump/code.js,sha256=scnF-tVarNCTh96M1XMkgLtxOW1HhjLDmQuepBT1ygQ,19940 +torch/utils/model_dump/htm.mjs,sha256=JOcbbeVY2ZKboFvrDklqA61fm3gOUUbGIpGurydshqM,1232 +torch/utils/model_dump/preact.mjs,sha256=FCb9XuNKpcEpGfo2Vk-0Xrem_mEjEMuAHa48ViJ8R5w,10080 +torch/utils/model_dump/skeleton.html,sha256=HfQCnjnPz6n9HwVKm-kN9tpxlNU18c9V2TMBCIPzxeU,405 +torch/utils/model_zoo.py,sha256=IYGYw6RMBc0ZkJbCnYnSHMScUIJIstopRVNiByTeS6E,119 +torch/utils/module_tracker.py,sha256=obqgqa4Elet-AJecEdUY-FdXMEoJ_VuJTun82tSjRCU,5544 +torch/utils/serialization/__init__.py,sha256=BX4TNoGsLpWO3SzTVc7H0MtPy4y6Tondm1oskIcBUEg,22 +torch/utils/serialization/__pycache__/__init__.cpython-39.pyc,, +torch/utils/serialization/__pycache__/config.cpython-39.pyc,, +torch/utils/serialization/config.py,sha256=jpeSFyWBTi37pAGMzTvcT0n_l5Biemq2L_jA2GCjIlk,683 +torch/utils/show_pickle.py,sha256=g0eCoQ25tRihN7x35na_A6qjj804O4wtq2sr4q117HA,5542 +torch/utils/tensorboard/__init__.py,sha256=VsJEmAeRYGsyXj7FSGn4UlhQIk8LlNI1KLyznx5TNuk,499 +torch/utils/tensorboard/__pycache__/__init__.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_convert_np.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_embedding.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_onnx_graph.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_proto_graph.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_pytorch_graph.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/_utils.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/summary.cpython-39.pyc,, +torch/utils/tensorboard/__pycache__/writer.cpython-39.pyc,, +torch/utils/tensorboard/_convert_np.py,sha256=ARTgauWHS_k-n2U3IL3l4drNc6GVSeuZYokThfUuWuQ,767 +torch/utils/tensorboard/_embedding.py,sha256=A4LJrdz1lAVpaqAkx06W0HlqXph0pkb4RjwUNXupBcE,3310 +torch/utils/tensorboard/_onnx_graph.py,sha256=NwyaXYZBFdd3ncLR_o5tUi1xSEsXY_eyF4v9WBDgcpc,1948 +torch/utils/tensorboard/_proto_graph.py,sha256=Bj4ltnvvzYLKFrU1AVMsO8oIn41N8wlLx3_za-yQQRQ,1812 +torch/utils/tensorboard/_pytorch_graph.py,sha256=H-U0VdVLQrxzkM8-s-NUTPK7ueiG7bG9gca6HTJ4v90,14068 +torch/utils/tensorboard/_utils.py,sha256=UWCGh8XeP8639nwRpQHzlc2Y7zKkiTklHd8f_bM-vrw,4314 +torch/utils/tensorboard/summary.py,sha256=ii_Xp-n4rQvyODASAuO5mHWeG_j8u1ARaD_9hMHLQBE,35444 +torch/utils/tensorboard/writer.py,sha256=GUnjAoEfyVshyvg0TlcUk8cDeHJWydduauOHi2LkXOo,47867 +torch/utils/throughput_benchmark.py,sha256=D2bkOW6DmV3EL9gUtyOtCrUu6ttLy0_qTvrH2ESaokM,6662 +torch/utils/viz/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torch/utils/viz/__pycache__/__init__.cpython-39.pyc,, +torch/utils/viz/__pycache__/_cycles.cpython-39.pyc,, +torch/utils/viz/_cycles.py,sha256=pud6H4azf0_rJGM9WrVHG87FNBZc7EThwHqKz5g5MVk,17239 +torch/utils/weak.py,sha256=UhSqlsHRp6C-IhwboMwBSA5C6b-ezHNUUxiWN6J6k6A,11504 +torch/version.py,sha256=bb7_-IfccCBrAkdYU3BE6wIkKJLh7R2DndbkvYhKRME,286 
+torch/xpu/__init__.py,sha256=6vRc6PO9YfIWM3Ma4kacxQniLeQ_l0KBuiCk9K6BjR4,18499 +torch/xpu/__pycache__/__init__.cpython-39.pyc,, +torch/xpu/__pycache__/_gpu_trace.cpython-39.pyc,, +torch/xpu/__pycache__/_utils.cpython-39.pyc,, +torch/xpu/__pycache__/memory.cpython-39.pyc,, +torch/xpu/__pycache__/random.cpython-39.pyc,, +torch/xpu/__pycache__/streams.cpython-39.pyc,, +torch/xpu/_gpu_trace.py,sha256=pDGjG58vYLW1Jm-uaJzIdyInbauIcDz9ry6sIsLcUlM,2424 +torch/xpu/_utils.py,sha256=M0jSr-gb8thX6GBTyCew6Zh2BpH01z2N8H26T5QH1RY,1630 +torch/xpu/memory.py,sha256=0PnU2vQ6WWHH9SodliP6p_sNVLFst5TU1gFfJFEOF5o,8236 +torch/xpu/random.py,sha256=55eJMHlJKCPjtDIOzYPrBx8eS3X7_iXjEBEPZoCFGaU,5416 +torch/xpu/streams.py,sha256=3qsCIc6VuKNUZR_RKiXcx2ndywowsbL1zeHhAfJf-P4,6031 +torchgen/__init__.py,sha256=H-SsfkUmMfiwvWtfi5h_wA6sZ18TnfDDgB_31K0Yt50,358 +torchgen/__pycache__/__init__.cpython-39.pyc,, +torchgen/__pycache__/code_template.cpython-39.pyc,, +torchgen/__pycache__/context.cpython-39.pyc,, +torchgen/__pycache__/gen.cpython-39.pyc,, +torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc,, +torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc,, +torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc,, +torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc,, +torchgen/__pycache__/gen_schema_utils.cpython-39.pyc,, +torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc,, +torchgen/__pycache__/local.cpython-39.pyc,, +torchgen/__pycache__/model.cpython-39.pyc,, +torchgen/__pycache__/native_function_generation.cpython-39.pyc,, +torchgen/__pycache__/utils.cpython-39.pyc,, +torchgen/__pycache__/yaml_utils.cpython-39.pyc,, +torchgen/aoti/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/aoti/__pycache__/__init__.cpython-39.pyc,, +torchgen/aoti/__pycache__/fallback_ops.cpython-39.pyc,, +torchgen/aoti/fallback_ops.py,sha256=cD7wwoc-RBifti01rcOglzplPMOy41M7Kr1egFop5e0,7658 +torchgen/api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/api/__pycache__/__init__.cpython-39.pyc,, +torchgen/api/__pycache__/autograd.cpython-39.pyc,, +torchgen/api/__pycache__/cpp.cpython-39.pyc,, +torchgen/api/__pycache__/dispatcher.cpython-39.pyc,, +torchgen/api/__pycache__/functionalization.cpython-39.pyc,, +torchgen/api/__pycache__/lazy.cpython-39.pyc,, +torchgen/api/__pycache__/meta.cpython-39.pyc,, +torchgen/api/__pycache__/native.cpython-39.pyc,, +torchgen/api/__pycache__/python.cpython-39.pyc,, +torchgen/api/__pycache__/structured.cpython-39.pyc,, +torchgen/api/__pycache__/translate.cpython-39.pyc,, +torchgen/api/__pycache__/ufunc.cpython-39.pyc,, +torchgen/api/__pycache__/unboxing.cpython-39.pyc,, +torchgen/api/autograd.py,sha256=wEFa6DS2tEFksooNJs5xxj_O6goFgO1h35zsusSq-SI,39833 +torchgen/api/cpp.py,sha256=ef063uW9vxZL6KfG5GKKfdVA9v5OD-YKbNo0ajUgU_A,16748 +torchgen/api/dispatcher.py,sha256=mync3lYG5xKuaeehEcXdabZVV09a3HRRNxVKvL9lgr8,3604 +torchgen/api/functionalization.py,sha256=wUQBTWDpxOyCqOBCmHTZ5AavHy9jl1Ptn4WyGiTuY4Q,7765 +torchgen/api/lazy.py,sha256=j1A6sqn5QzBgMybI_myQRts1Kyd4MDHP5KkgRywEBlQ,17521 +torchgen/api/meta.py,sha256=jtTIUQQvK4WQwYBPiK7CBfH9ih7sYEPYnWlXcjJ7Nvs,496 +torchgen/api/native.py,sha256=pwJQt14yWscGRsOzXtl9Y7CmY-xlmibmkkx2HNg2rq0,5364 +torchgen/api/python.py,sha256=nkwki5yLfJ1ekCQWLgqr8fhgcHA51WWKk0mruBTeu9A,61235 +torchgen/api/structured.py,sha256=VlNYMT8aCGdZZjV9QmJro3yA0CQ4ytMIuO09a5p_C-U,6273 +torchgen/api/translate.py,sha256=qQjNcsYlvUZdrQ70rqxPpDIsKrMEaYyzXJvcZC0M9Oc,19734 
+torchgen/api/types/__init__.py,sha256=C99zr_VE5z5WFZKdk93VW3ZN22NT4Hjxib3Pkk9T2VQ,149 +torchgen/api/types/__pycache__/__init__.cpython-39.pyc,, +torchgen/api/types/__pycache__/signatures.cpython-39.pyc,, +torchgen/api/types/__pycache__/types.cpython-39.pyc,, +torchgen/api/types/__pycache__/types_base.cpython-39.pyc,, +torchgen/api/types/signatures.py,sha256=bj_4UEkr96yFcm3UyzIpeBsWLkJa8FCNFJWm6TyExBg,16150 +torchgen/api/types/types.py,sha256=GzqS7_2xQyrArUs-tTw5c9L56wObzFg_uSbdUPFAhnA,6320 +torchgen/api/types/types_base.py,sha256=nKJ1diqqgQjOz_90qbQmwoh20-MrPy_j_lgXBYEOvKA,7422 +torchgen/api/ufunc.py,sha256=_CZftxat-hnDWlLlOHxtgm51BlusVVlWLzDmphJM3fg,6902 +torchgen/api/unboxing.py,sha256=eOLbJ1diUEiUN_xpw050jY1Nc6hUAEuf1kOSaIOOKbA,9622 +torchgen/code_template.py,sha256=IkcE64dLjeXrEHCzgmxzE-N7z-X7HyTwwinPcbjGRd4,3319 +torchgen/context.py,sha256=uNy8F--BVMSo-G5lBsbzuTSPTmC6egZeXIoaGhy6png,4190 +torchgen/dest/__init__.py,sha256=FbOiYHLsyT0FCZN7MHvnx0kkgdnee4TmmSueeDyqbQY,824 +torchgen/dest/__pycache__/__init__.cpython-39.pyc,, +torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc,, +torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc,, +torchgen/dest/__pycache__/native_functions.cpython-39.pyc,, +torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc,, +torchgen/dest/__pycache__/ufunc.cpython-39.pyc,, +torchgen/dest/lazy_ir.py,sha256=XUWjQi8YpoEng3rE6tMEU5T9LVUX8-4pmqOQFf7hfVE,29697 +torchgen/dest/lazy_ts_lowering.py,sha256=ToSekeRQ09J9NOutqYAgfSxQQyGElWq18oPU_LYzoig,1879 +torchgen/dest/native_functions.py,sha256=RxIu9Jg92XUsNWp5wng8-tb3cafnzFLJRdl2srx66Dk,3255 +torchgen/dest/register_dispatch_key.py,sha256=UaGMCaAHUxCZhVocwpfIRy0fhfhusYNMLD4aeLRg-Rk,42500 +torchgen/dest/ufunc.py,sha256=aZpi7UMvDbB6lHQlCyS9bG5tg3_CDURAeNsawGw3hyM,18390 +torchgen/gen.py,sha256=zvuBR3NHiANedM_lh4OrWSOg_RFWYZXp7qaxpfepNEM,116470 +torchgen/gen_aoti_c_shim.py,sha256=LYUrapzhC1Xfm0g6XPhfuqrPkPy3SX56u-21grwPgdE,26513 +torchgen/gen_backend_stubs.py,sha256=fUfzEX40VFU2J5D6smR8op02k7xjqXgH3nxEEz5ZT18,23015 +torchgen/gen_functionalization_type.py,sha256=IEL7u5cPJVK8lpjnFRoyZRzuZAAdTGlrhS0U1K90p7o,39117 +torchgen/gen_lazy_tensor.py,sha256=k4exfZuM8Hooyn9NV7qSkkrzJeyzcJC9T93pj1U5A1A,23315 +torchgen/gen_schema_utils.py,sha256=vEO2-RS1uF-qEPhEY6UW5YdO3irQUHHQ9gTRmA-_p4M,3414 +torchgen/gen_vmap_plumbing.py,sha256=6prLLXZOoTaKKv-8goJ_oDXdlBQHBLSjDLlf4j6EIC8,9670 +torchgen/local.py,sha256=KfqAAZvSz2R1I88q-Q3LthJYh7Fgur346eVpOfM-7ww,2229 +torchgen/model.py,sha256=jSQyx6sNbtSf3wVNqqWETa7DSKWovHO3tXm6hJJtDcE,117284 +torchgen/native_function_generation.py,sha256=pJSFCzxD8xgrRQ5V26kEVig3Xkvg_HbrmewzFRa7P34,30416 +torchgen/operator_versions/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/operator_versions/__pycache__/__init__.cpython-39.pyc,, +torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-39.pyc,, +torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-39.pyc,, +torchgen/operator_versions/gen_mobile_upgraders.py,sha256=Mup4EQSI9jyWKLJuvF6kKzAQYJgsZ0KzrktPEFSo8GM,12774 +torchgen/operator_versions/gen_mobile_upgraders_constant.py,sha256=W57M1HWkrQzY_Vw259_PPBTRH41YVC2fFg5k2Tt_yCA,250 +torchgen/packaged/ATen/native/native_functions.yaml,sha256=fsEeriHhKM5kovGYEYJEbmc94dQgG4xsfopsG3DlCwU,622333 +torchgen/packaged/ATen/native/tags.yaml,sha256=r0NkNCfqLUpnRPmW8sdMiGErvx3cIu-e6C2mr9t2VQ0,5379 +torchgen/packaged/ATen/templates/ATenOpList.cpp,sha256=2-bfZSImHz9rcx9v1LjiY1eT-btNz6A5UIMMi-F1i5k,1095 
+torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp,sha256=S7KFBelXFjlC6rxDimK2F-b5-WquNjJbxP1b_nhP7kk,2150 +torchgen/packaged/ATen/templates/DispatchKeyFunction.h,sha256=ApTWxDG1RjaC5DTjQpxbrkCaQmfWbdxUatDLM3nBP_E,725 +torchgen/packaged/ATen/templates/DispatchKeyFunctions.h,sha256=bW-d7rkZRIqmVkwOl8UmRLtjboz1XLb4nJUJ6ArRjC4,1966 +torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h,sha256=ZAL9FVX2z3e-Dy7rFx7LpJ6B4Rgmb1n1zPFKlXjRJ9g,846 +torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp,sha256=eKjvIeiCkoiG-imx-bdSaSijsa6YCPddR3eQE1H4eyE,197 +torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h,sha256=Dc80YgjeS7IKhlrvDbcFfF9dK1gRHKzAoAlDWHZpvGM,403 +torchgen/packaged/ATen/templates/Function.h,sha256=5BsGAqKe4rPVdAe-3M5Y_aPumTGrMY2K6obXKeZPM0A,546 +torchgen/packaged/ATen/templates/FunctionalInverses.h,sha256=6jdfM2a79k81bY_q1Qa2hYfBXlb19OZ_kMA2nxPS-GM,1264 +torchgen/packaged/ATen/templates/Functions.cpp,sha256=ZwVpQ6GifRW1oFZEVxIXEJL0YE3GG-sE0x-4li-pEac,3190 +torchgen/packaged/ATen/templates/Functions.h,sha256=m5ChjCleu5it8axXbNuLec9pU4U0Y0HBCxarkQAQfQI,4820 +torchgen/packaged/ATen/templates/LazyIr.h,sha256=o6jbaNUmVesCtDPi6W9REt8_hRrS6svFAe9xHiLwbJs,604 +torchgen/packaged/ATen/templates/LazyNonNativeIr.h,sha256=PKbQsvyOzry_u7iAtUnzKrF04BAXTIYY3zOMSnoDLBc,189 +torchgen/packaged/ATen/templates/MethodOperators.h,sha256=NClKeE6gjJAQpsZQGwOM6WqR3nmZOdIXMWkBi1ZldWE,854 +torchgen/packaged/ATen/templates/NativeFunction.h,sha256=m6DEzxpD_6O3VM_vBiMIM7iKktWt5-Y0CxaJ3LgDFgI,383 +torchgen/packaged/ATen/templates/NativeFunctions.h,sha256=Sbcv1aMyr_7HDs_ZOrHbdgSW7-Df9DHenZ1e7DhMXYo,1182 +torchgen/packaged/ATen/templates/NativeMetaFunction.h,sha256=cGSreOfBl6AZDUPu55pB8KHQBpZy-b9-kp7SZ1N5xLg,475 +torchgen/packaged/ATen/templates/NativeMetaFunctions.h,sha256=4yheSqOheiuuy2esGccOJ0oAyYzhOwyC3_04nrteL-A,325 +torchgen/packaged/ATen/templates/Operator.h,sha256=YHP7ZPqEV0yeCqz7TctBtp7CgDzWLZfcIkkghb16gOk,467 +torchgen/packaged/ATen/templates/Operators.cpp,sha256=bawk3V-uv13iD-ZDflZzzSpkKiVjN-rx22rvqaluYVo,366 +torchgen/packaged/ATen/templates/Operators.h,sha256=lLjk91v_znrZfyQ4-7vqPd2x4Rch9Qj2WuorJOUWdgc,3274 +torchgen/packaged/ATen/templates/RedispatchFunctions.cpp,sha256=cZTM_p-qgCTl6FnXmv4lEFxAshOZ1AyAPSRTdZq2Iqs,322 +torchgen/packaged/ATen/templates/RedispatchFunctions.h,sha256=j3mkCuYJU4Lk6qIUXzeBlF9_Y3Ey1lZvfbRbP_T1SJQ,914 +torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp,sha256=gqG1UAo7Y7TmIxdeSI0snk4XZKMGg_qRsYibtqzGYaQ,781 +torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp,sha256=V_V0w8dRrA4mXBTX_p2bKzy2xa_NceoH0zz4fJRaClg,1160 +torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini,sha256=KfseidcHPRLoKSZJcpYSJiZeaNVoR5UvNpdo2Tz63II,477 +torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp,sha256=kJQWVq42qjXG6I9AkOLMq3xOQh2lH7wICQ14frVNTMg,1629 +torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp,sha256=QN9m6cpSyEneXQYa9_4cF9Feb57-g4cpgdMCRFJFc1Q,3580 +torchgen/packaged/ATen/templates/RegisterSchema.cpp,sha256=uvGTom__-QBGFWs6mdwiOuAqULY8PtkSQ-W0AVzIgrY,396 +torchgen/packaged/ATen/templates/RegistrationDeclarations.h,sha256=omjHl6z7mF7KpRdd7PKEbHplEHACz0IQC3_x0nvO4vo,164 +torchgen/packaged/ATen/templates/TensorBody.h,sha256=BWQ6y2UM2VQcJS02D4O5raUY6AcsbBZOpmkXwGw_bP0,29901 +torchgen/packaged/ATen/templates/TensorMethods.cpp,sha256=JKGLG8yhCI4oYWGnepE6UdU62hsdezzND6T7Q7itBA4,2674 +torchgen/packaged/ATen/templates/UfuncCPU.cpp,sha256=sNuukoCbDy034J9FalXIP1B49QqkcRqg0nGrKMhzSFE,464 
+torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp,sha256=rtFLOYPfpBLyUAIeI_MSgnNMIxpGFvPd-jCtxwQ5gdU,364 +torchgen/packaged/ATen/templates/UfuncCUDA.cu,sha256=9nuaFvMTZlWn4_FdRL4gB886BG7kqS6MEW62s1AhnKg,515 +torchgen/packaged/ATen/templates/UnboxingFunctions.cpp,sha256=U83PTJxwRjKE2tkNzKmigWYHbFLB0a_MIUdWLj8atPY,744 +torchgen/packaged/ATen/templates/UnboxingFunctions.h,sha256=uOkK0X-nci7gucn7MD1c5pL-Fh52I464rOhzYJWHSyg,1058 +torchgen/packaged/ATen/templates/aten_interned_strings.h,sha256=ySu-vkTTqSEIOpgKhCg_v7tJab1eOWBEqKUGBePJr4g,827 +torchgen/packaged/ATen/templates/enum_tag.h,sha256=noIV15PR8ybSfEW7LdykYBK5da9O_NRM8QW-DV6WKlU,189 +torchgen/packaged/autograd/BUILD.bazel,sha256=5k-uc2W6OMFb0NuoB7VA91YyflbdR0YEmApN3V6r5SI,108 +torchgen/packaged/autograd/README.md,sha256=Z_hBdZ0049gw9XLEVrbmrxuEKF59h8pHnEDP3iZZ_aY,150 +torchgen/packaged/autograd/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/packaged/autograd/__pycache__/__init__.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/context.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-39.pyc,, +torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-39.pyc,, +torchgen/packaged/autograd/build.bzl,sha256=WB6MNefjiTn2EynM2msiIxLhBsrr_QG_gkQK42VN9do,362 +torchgen/packaged/autograd/context.py,sha256=2imEJAHPJDT0t-Sa69FUtKS_-4b9FEU5MiSiezfmprc,976 +torchgen/packaged/autograd/deprecated.yaml,sha256=nqwyinCg7yaWqCSYnt2K3e4cN9UhOn4EHnyjHqPbFAA,6384 +torchgen/packaged/autograd/derivatives.yaml,sha256=CfJG2dvlA4lmIhi6ePkR5wBxI6Yn53ewCp51YSFu0hk,184798 +torchgen/packaged/autograd/gen_annotated_fn_args.py,sha256=3ujwdrI7x5rPZuvegMM-K-AQVdqHcvyoUdu8B6ZTHr8,4610 +torchgen/packaged/autograd/gen_autograd.py,sha256=RdGuCFgSWUBk8zRH7LaHA-wzlmOF2iY5n2jFPGuejSk,4764 +torchgen/packaged/autograd/gen_autograd_functions.py,sha256=Q5r6e3ezSXHt040NbvDcV9tdmszoGbLLTtio0qzsXu0,39182 +torchgen/packaged/autograd/gen_inplace_or_view_type.py,sha256=K-G8QTIov-MY4qI7uTLUKe5whiv3Cu0LKoE671tYpLI,23394 +torchgen/packaged/autograd/gen_python_functions.py,sha256=pU0S5ZzS0--r3CypP4Yhe6HlrrMtW8caikmoPQQzRI8,47763 +torchgen/packaged/autograd/gen_trace_type.py,sha256=75YWS0C2HpjEowo4s50E9WarYAF_zB3eFrcqJWClcfY,19513 +torchgen/packaged/autograd/gen_variable_factories.py,sha256=ohsSThgYeGz1DkDjZmkwDEHroFd3tFn5WNKGvd7n4FQ,4595 +torchgen/packaged/autograd/gen_variable_type.py,sha256=Zth4LPpdsXUFkZ_vOjRVTCQv7bXtwYbJ4i-KSAjeHAg,86007 +torchgen/packaged/autograd/gen_view_funcs.py,sha256=_IZ2E1angL25yGRYJDit1smJXRpBbD7UyFurJYi3NC0,11928 +torchgen/packaged/autograd/load_derivatives.py,sha256=T6S1kHXd2fLoaP5HX1lJq6SmHMZOC-N645ITZxwAYII,41357 +torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp,sha256=Y_aTR4cYflFtyBbWYS4ukPk7YVV39GgZbbsXSJFx-PE,828 +torchgen/packaged/autograd/templates/Functions.cpp,sha256=wbbspcHjtkWMoo65sb-oz-v_9a4oeUMjHz182ByPzLA,1522 
+torchgen/packaged/autograd/templates/Functions.h,sha256=6SmhZf2IIxPrVvf5xjaoLeCmkBfLLFqnFUhlSOhGXF0,1628 +torchgen/packaged/autograd/templates/TraceType.cpp,sha256=atz5-v313Iyl9rgCfl47_ejuCZfLey1bBwrzXqXaAZc,735 +torchgen/packaged/autograd/templates/VariableType.cpp,sha256=hALVtwdH4Tax8xlZ0P3mlV8DeG8cg8_K00CVDFBo16M,1917 +torchgen/packaged/autograd/templates/VariableType.h,sha256=YR-TpqKjlrG5RHXJx_-XhQiUuI4bmMrm7OR_OhOlAcs,1522 +torchgen/packaged/autograd/templates/ViewFuncs.cpp,sha256=ZaBU6mFMxrFHUV-mJ5C5G-bnMCUJJIERx6mkv6FlA4s,283 +torchgen/packaged/autograd/templates/ViewFuncs.h,sha256=97-ircmuaEKtzul3qUwdmYAS2QdRwCwPYFanyWQ6Y5c,526 +torchgen/packaged/autograd/templates/annotated_fn_args.py.in,sha256=lRK-e23JYHNU9tOhKCPyvHe0VCkfE0oCTHPY870xcms,210 +torchgen/packaged/autograd/templates/python_enum_tag.cpp,sha256=yJ8tkTuDfywePj1rHVMFUAb1W94bLXzvEvcIHl0JFGc,510 +torchgen/packaged/autograd/templates/python_fft_functions.cpp,sha256=hX3IY0ydw_IKpZtItWPAufuPNEb-VnwIPBh7XNCIhH0,2032 +torchgen/packaged/autograd/templates/python_functions.cpp,sha256=9hWIAuGR4McZ3-3eGByN_-Xrs3MarHG3YZc8FX69Ztw,1158 +torchgen/packaged/autograd/templates/python_functions.h,sha256=sPbeuOK7oLweIhgvOSeaiowQoAy4OGT95GwAxo9uprs,362 +torchgen/packaged/autograd/templates/python_linalg_functions.cpp,sha256=Bn0M_Y087pMuOI-2yKIAc4f6zH2j_STCPLeWnbmZunc,1682 +torchgen/packaged/autograd/templates/python_nested_functions.cpp,sha256=jUWc9gx-WTyO1IU3x705D67icWO8zihMd12Q3d02lLM,2110 +torchgen/packaged/autograd/templates/python_nn_functions.cpp,sha256=gb6nXP7vliFyuZLREsv4d0Gs0WFfHs9-pWj63BaLEYQ,3614 +torchgen/packaged/autograd/templates/python_return_types.cpp,sha256=rF1xv6sEAKkT5Qr4fFUulozpXikWddFyakECVq9Sn_I,1271 +torchgen/packaged/autograd/templates/python_return_types.h,sha256=qezAy6s6ZpEt5CfCINVH2Pnjn1sRMZcFa_kCz_7b3GQ,212 +torchgen/packaged/autograd/templates/python_sparse_functions.cpp,sha256=J2mvrC6J71O3iEbob45odDb7uuY8UvvpQSiJRSn8nYg,1618 +torchgen/packaged/autograd/templates/python_special_functions.cpp,sha256=M8gD-seaEjkXjV0Qq4GEURKH_aGWYjPcun8GoMwD2Ek,2051 +torchgen/packaged/autograd/templates/python_torch_functions.cpp,sha256=GiTOx6SYABBjRDQjYbLE1_wZehaEYj6D2L_KqTTTvQA,2694 +torchgen/packaged/autograd/templates/python_variable_methods.cpp,sha256=cfEjtxQhAqku8Zjm16qZUvqvWVDJqrk-t7Gm-oJ1I54,55306 +torchgen/packaged/autograd/templates/variable_factories.h,sha256=QFuWYylBWC-m7MdtD7y_SJV5hb8tLfPIMC6ZVPLy6AI,5772 +torchgen/selective_build/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/selective_build/__pycache__/__init__.cpython-39.pyc,, +torchgen/selective_build/__pycache__/operator.cpython-39.pyc,, +torchgen/selective_build/__pycache__/selector.cpython-39.pyc,, +torchgen/selective_build/operator.py,sha256=hNnVnr_mBoWaKYOQ1pRe8IPVjdc7VOrOYitha4iWt7k,6680 +torchgen/selective_build/selector.py,sha256=PQqSTqfo06ancrqAkGnfUEuG9FFaT6jSg3QrUPO47A0,13018 +torchgen/static_runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +torchgen/static_runtime/__pycache__/__init__.cpython-39.pyc,, +torchgen/static_runtime/__pycache__/config.cpython-39.pyc,, +torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-39.pyc,, +torchgen/static_runtime/__pycache__/generator.cpython-39.pyc,, +torchgen/static_runtime/config.py,sha256=jzyW-CalPLeaX8djIxs1uiJgEuBS7NZC7EzTHPVg9rs,14875 +torchgen/static_runtime/gen_static_runtime_ops.py,sha256=QzQWI-JEDMYVakP2hHhu8bRVRt5Y_GGC6cBzBJenUAA,7639 +torchgen/static_runtime/generator.py,sha256=KQc0SilAICJFjnDco0yoCrbX5JKuhjvOZoPL_dgVp2Y,27930 
+torchgen/utils.py,sha256=Nh865ZF7tjqJ1n7xzcAyzD_Q2OHdELD5phteKuT1xe8,19152 +torchgen/yaml_utils.py,sha256=liGsJa9N5rQ41u4RvEd6-1oYbiS1aPThHAKDynPqHnA,1106 diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/REQUESTED b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/REQUESTED new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/WHEEL b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..b7132af6d27aab726a7499fc58ccd63c206a0a33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.43.0) +Root-Is-Purelib: false +Tag: cp39-cp39-win_amd64 + diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/entry_points.txt b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..43e5bc29823736ad435669527c514ef7980aaf7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/entry_points.txt @@ -0,0 +1,6 @@ +[console_scripts] +torchfrtrace = tools.flight_recorder.fr_trace:main +torchrun = torch.distributed.run:main + +[torchrun.logs_specs] +default = torch.distributed.elastic.multiprocessing:DefaultLogsSpecs diff --git a/phivenv/Lib/site-packages/torch-2.8.0.dist-info/top_level.txt b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..90d81bec1d35eb3996334268974885a5398b6c6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch-2.8.0.dist-info/top_level.txt @@ -0,0 +1,3 @@ +functorch +torch +torchgen diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_VF.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_VF.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..244e60d870e67e4a28cc2c9478287e380a597b31 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_VF.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/__config__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/__config__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed311f4913225f6f39d3d8328a05e9000fc33572 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/__config__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/__future__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/__future__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ba6bbde7871a568389ef06878cb03978ded648a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/__future__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fe86ff511cab107b1cb7c6bac213c73cf63392f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_appdirs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_appdirs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40b67e1404dc485185291866e0cf703f1a7a5612 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/__pycache__/_appdirs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_classes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_classes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2bbf033e392f637772e6829c73f7580beb0981a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_classes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_compile.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_compile.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f86b92f50f0787d9b82e9c93d666cd1ba1532ea Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_compile.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_custom_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_custom_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8835e22749101065b295dd1cd7c22a39af8761f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_custom_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_deploy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_deploy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9be18b7ce41d614af0b170e575016ad3aa57d96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_deploy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_environment.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_environment.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ecf2c16e4d699fbcf174382b80aade98a6a1b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_environment.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_guards.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_guards.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21fa20f5f5d054de4aa67c9f0322e8841808d8b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_guards.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_jit_internal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_jit_internal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfc207d54783201c10ac67ff59944729c769a0c1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_jit_internal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_linalg_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_linalg_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..897ea1de05dd184d37748b92c2b36bc24a70d419 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_linalg_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_lobpcg.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_lobpcg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5e5e69775282e90ae0266cfe3b7332635315cb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_lobpcg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_lowrank.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_lowrank.cpython-39.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..2fb3aed41bca7504be73bbb0ce8fd2e605960d33 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_lowrank.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_namedtensor_internals.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_namedtensor_internals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a641feb7f5539f0b888ab5e92d71b69d19558e10 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_namedtensor_internals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b67c229e00b295162756f28b51a6457220fcaa99 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_python_dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_python_dispatcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1d8ff8ddbb5bc1dc5ecc5845d9488059f6aff1a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_python_dispatcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_size_docs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_size_docs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f8b3bbf0845651edae516cc3de30979336ac687 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_size_docs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_sources.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_sources.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f82b20b55ffc209c14eff45ffe73812abdc705cc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_sources.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_storage_docs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_storage_docs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51f21a04a31681aebd951fd43bb47bd0e1119307 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_storage_docs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_streambase.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_streambase.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68bf62624a0a317e7fee46c7fa01bdbf07ab1b94 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_streambase.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5ce5548fcc8042fa1c7939efff009d0853f950b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_tensor_str.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_tensor_str.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a1d6abf63e5d4daba5e171c5606ebfd2637fce2 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/__pycache__/_tensor_str.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_thread_safe_fork.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_thread_safe_fork.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55f6bfdce40c16eebe842f0a7dc245645ecdc69a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_thread_safe_fork.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe3394a79bb6781428f0f0ffde65eeb60a9b7ec0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_utils_internal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_utils_internal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c9486f71b51d83a3407b7e81ccb1bdb4b1604ac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_utils_internal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_vmap_internals.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_vmap_internals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7361e24f545900501d1f7c900aaa4dfc1581c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_vmap_internals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20af00bd8b652a53467e0c92ef5b40f58e33429c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/_weights_only_unpickler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a8fad5067938b4291a5c7079cc18301ebe59011 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/hub.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/hub.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e780b2c9fbd906a366c4dbca4e8c23099076a08 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/hub.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/library.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/library.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7933adb15e93c730b1531aa29a28d273cb07f2c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/library.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/quasirandom.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/quasirandom.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82b9c8363eb7c17ef302021abb352ab642ff7bd4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/quasirandom.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/random.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/__pycache__/random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08c44721250e5811b1986821afdaa2e9bbcfea0a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/random.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/return_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/return_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37535a943f799efbdf147ddc53bb4e1850ad5ed4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/return_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/serialization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/serialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9652f219e8402e5694747bc8d3c1c507fd814ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/serialization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/storage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/storage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b48cad9414c7dc38901a3458b1b89cc16d4aac0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/storage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/torch_version.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/torch_version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f9c11f87bba7206f452db0aacbf67bb7bd1b489 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/torch_version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4931fc5c3234aed77cad4bfcd9204bb9a57910c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/__pycache__/version.cpython-39.pyc b/phivenv/Lib/site-packages/torch/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35f9c655f8d243cefb81f6213990e64fe607f8d1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/__pycache__/version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28a2a9a30d3a9e5f97dee965673dd3bf80e5f1ac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__autotune_main__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381cf5468e83278823d1ceb928d851d6b41e66c0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..09d600effdb3ff843abaa3ef794862ac5ed4cc00 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/analyze_preserves_zero_mask.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a284de710893ac70a96b0389e8efa091fe86c7fe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/aoti_eager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/async_compile.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/async_compile.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0afc718bb4584afa635a62a37b013e07bb3cca21 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/async_compile.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a068c8133c41b909222ff317d467a7f56969fa9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/autotune_process.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63e35c820bd0b4918986df87c2ee6ae018814636 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/bounds.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/choices.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/choices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c03739abea9de678b495a1813f18f21ba917ded3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/choices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0f651cfa47b9f133b22c70e8b28c4d77c9968f7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1633e1ba126f2ff023d220f659749a0f6bed81ee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comm_lowering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e9b1be297681306a58cca2b70e070346ad688e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/comms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7be1e4d850ab6f3f4a094f964bec0570d8974fe5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39304dcd288042d26664d6c235cb3b0c5606ece Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_async.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d8854b32da5ad892fd107fd545dcfcaeb11496b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_ext.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f79db119f8ff9d3c53830db5652b5ba2a71d04a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compile_fx_subproc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afe4027f494f3b4d93955388451ff26e9eafe65e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/compiler_bisector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b9e3752abe2e60be1c1c281268ecb8a45b1b6d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..880babfaf7b238861bc71a95556e1beed015b2da Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/constant_folding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e14052dee825fe52f4ccdf0bcb1f7193f0dc6ba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpp_builder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9292ec7f24b8b0d4c0b6eb865ac12bae6a3e839 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cpu_vec_isa.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a4134e004130bb6b1bef175704904e3c7e5a998 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_trees.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2a444c44da4e339493ce6c165e83c63bbeefd3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/cudagraph_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5c5072431430b5c5b2725e398ef13e895bfad3d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/custom_graph_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c39d8fabb2c0e5d928a96cc6cbdb6b470787ffeb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/debug.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75ca403a5b892519771eed9e09e72b3b7e22639a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/decomposition.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab33ea10cbee37ae308594ca13727013cc00c035 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dependencies.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b0814ea442e4d45e4b5518d8d7810c29573c619 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/dtype_propagation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7410f931e5e9639ca3ffcf6104d8f3559500ed10 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/exc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-39.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..d1c25bb67639366f1d24e8137ef4087b8c6a90e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/extern_node_serializer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3b8599b3970ba9b2d17bd75e292f75971e98409 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdd7f9a9482bbd220f215ba94e45ea824177bea4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/freezing_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b8c4088ce1cd16baba259f913d9bc65bb3b464 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fuzzer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..872e937655ee76cf8a43dea9ef0c55e3e754ed0c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/fx_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71e57f526d092b5ef5bacd35ef48a04f3906b772 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba15036a6ab56d71efdc70538c938f8e89179c0c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/hooks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2c1be0281d080fc742e5aa22571327ff5469a5a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/index_propagation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7502828fee960108e6db4efeef4420df318195f6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/inductor_prims.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7ba43dd7111874d3f02f4802ddcca43830554a9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/jagged_lowerings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/loop_body.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/loop_body.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f05b83993fe166f2acfb7feb7c5ba39f4604ac2c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/loop_body.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/memory.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/memory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec3ce3f0ede300edcd1f83eaa0792997df596c8e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/memory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2de0cc9cfca1b62e483779ea9b733d22c5cb8224 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/metrics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cdc84667075df222b3aa1602e068218dfc9e4c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_ir.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..456d5e5b9684d8b1e0719b326327ff7f6feaefa0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mkldnn_lowerings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deff9494696aff61551d0dbb3cc754c5d9bbb861 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/mock_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3294aeea66fd96e94c6da2eee610673c9c487dc4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/ops_handler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c933eec06da52fdb1ae01b68723bfce2d4bfca80 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/optimize_indexing.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/output_code.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/output_code.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21f63c092f89f59f8ef84bdc3ad5a6b37de20776 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/output_code.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c3fbc03c70622bd165da1649b6151c13a0095b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/pattern_matcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fd98b3dc2d5dcf7245ae85431b55c7c43ee6285 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/quantized_lowerings.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a3b3cb5d1adbde58d9e68a32defa841a17dfd5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/remote_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4402fba4533ab6ea686403de5dcb6965d130076c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/select_algorithm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62f2e2b93083df178486c6e958c627c8e51eefd5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/sizevars.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dec857aa1dae1b0e015e35e72cf68f371dd7804c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/standalone_compile.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bca3353d5450f7e0d14191d0e923121e3bbffeb5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/subgraph_lowering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/template_heuristics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/template_heuristics.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3a0b9f0c384018075cbb9888aba3ca7d546324cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/template_heuristics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83388e0ba38363ef85c0eedef124c6cfb10dd2c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_case.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3302fe8abbb04dd3b16416deaa8543e6cbbf1167 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/test_operators.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ec33069ef4aadf41adaa91d3c830d330f1229b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/tiling_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa43b55b6681225f9a18ac2cd1c4b63bfe7fa648 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/triton_bundler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27a5278a15c0bf8673f0932d593f6df356f37e43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca7ce4fb7cad6291b620b24a087012689d36a34b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/virtualized.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57876d81151e735285f673006a596f0b873dfe09 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/__pycache__/wrapper_benchmark.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..7ef69af8c5241df62169399151f95f0ecb4a0237 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ab975834a9c5d87ef206b2b1eb9e52e6064a5e1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/aoti_hipify_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a49b292bcd73dc9a50d661d5e4f5909f233096a8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/block_analysis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ed80742603def782ac4fad49b858e92a7c6950e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58f9773a8c9103e08747151c8140e93e1291b84e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_bmm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7436f49daadc638c82fcaf714d3b49510708c97d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_flex_attention_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68c93b6619d36259dc75a481ff9421a3860f59ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_gemm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38180d1b747aedd6c31c6bcb1632af3ade528ea1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_grouped_gemm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..cdbb2df670d28c4c9a5b40dc1067209a342e148f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_micro_gemm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49bf61df698bed72373430f760001f37e4dea228 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89931284f7726ac8cd1bb442d98214891a0813c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_template_kernel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60bc417d906417453494cb6e0dba75da235ad42e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58b3a5a2eb7d989a863eaf11d1a450697e25e4bc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7981b933013d61312e4b6587ee4c72991367d9fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_cpu_array_ref.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0beb010c07b8730756889da5de6bddbe4ec08ce0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_gpu.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..610f1e5bccf501bb953086db5831fec2a9dda86d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpp_wrapper_mps.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-39.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..ffa7f287c945661357f766ab46afe77447f52025 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cpu_device_op_overrides.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0b65e40d0a31cfb7e87b0cfb9e5c1eb001dcde2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/cuda_combined_scheduling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dae3e6e5ac373fdc0d607fc6bc0b9d473b170447 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/debug_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efa75675564d03e61e1ac6333baf7fadab8565bc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/halide.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..861e36e7d1abc46a6d7338ab6a7fefef61c7267d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/memory_planning.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec983d1075696db03b5d64e6a83e8919d4bc8fd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..142462ba588d093c23a3ec45dfeff6d6636a07e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/mps_device_op_overrides.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f2c26112a6382f7d9c1e097c92c08fea870f573 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/multi_kernel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..3b5a62c33665c5111b27407c46aa6999461052b1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6dd02a7466016d2601bd4669d77487d0a0faf38 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/simd_kernel_features.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26a658214722b00669d1153255effeea6d287833 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/subgraph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b99fffb0ab41e1a5f4ccdce2703d6e72db80ea28 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_combo_kernel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1660cbcb493f66b69affc2b085672fb95fe1a183 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_split_scan.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..694680f22355fe0348c8f130eda975ce3af28743 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/triton_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..406818e206e6bd2d8fcb5a323749086d38742192 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/__pycache__/wrapper_fxir.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp b/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6676f7ee362bdbaab4807eda4c56d673750349c7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,443 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception& e) { \ + std::cerr << "Error: " << e.what() << '\n'; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) 
{ \ + std::cerr << "Unknown exception occurred.\n"; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK( \ + actual_size == expected_size, \ + "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(false); + } + AOTINoGradGuard(const AOTINoGradGuard&) = delete; + AOTINoGradGuard(AOTINoGradGuard&&) noexcept = delete; + ~AOTINoGradGuard() { + aoti_torch_grad_mode_set_enabled(prev_mode); + } + AOTINoGradGuard& operator=(const AOTINoGradGuard&) = delete; + AOTINoGradGuard& operator=(AOTINoGradGuard&&) noexcept = delete; + bool prev_mode{aoti_torch_grad_mode_is_enabled()}; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + bool is_cpu, + const char* cubin_dir) { + return AOTInductorModelContainerCreateWithDevice( + container_handle, + num_models, + is_cpu ? "cpu" : "cuda", + cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle* container_handle, + size_t num_models, + const char* device_str, + const char* cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0\n"; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto* container = new torch::aot_inductor::AOTInductorModelContainer( + num_models, std::string(device_str), cubin_dir_opt); + *container_handle = + reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto* container = + reinterpret_cast( + container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle* input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array 
itself is borrowed + size_t num_inputs, + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded( + input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, + size_t* num_constants) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** name) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, + size_t idx, + const char** original_fqn) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, + size_t idx, + bool* from_folded) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* type) { + auto* container = + reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, + size_t idx, + int32_t* dtype) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDataSize( + AOTInductorModelContainerHandle container_handle, + size_t idx, + size_t* data_size) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *data_size = container->constant_data_size(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerExtractConstantsMap( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive) { + auto* container = + reinterpret_cast( + container_handle); + auto constants_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { const auto ret = container->extract_constants_map(use_inactive); + for (const auto& pair: ret) { + constants_map->emplace(pair.first, pair.second); + } + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateUserManagedConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle 
constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update, /* user_managed = */ true); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, + bool validate_full_update) { + auto* container = + reinterpret_cast( + container_handle); + auto input_map = reinterpret_cast*>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->update_constant_buffer( + *input_map, use_inactive, validate_full_update); + }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, + constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerFreeInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->free_inactive_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, + bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto* container = + reinterpret_cast( + container_handle); + auto stream = + reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + container->swap_constant_buffer(); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_inputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, + size_t input_idx, + const char** ret_input_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, + size_t* ret_num_outputs) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, + size_t output_idx, + const char** ret_output_names) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, + const char** 
in_spec, + const char** out_spec) { + auto* container = + reinterpret_cast( + container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle* model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast*>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, + constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + "" + ); + + if (input_map) { + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun( + AOTInductorModelHandle model_handle, + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl( + input_handles, + output_handles, + (torch::aot_inductor::DeviceStreamType) nullptr, + nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast( + model_handle); + delete model; + })} + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, + size_t* ret_num_outputs) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); + }) +} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = + reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = + reinterpret_cast*>( + constant_map_handle); + + for (auto const& kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template.py new file mode 100644 index 0000000000000000000000000000000000000000..430a504ed4bda3880d6bf05dddf848abdf08e93e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template.py @@ -0,0 +1,138 @@ +# mypy: allow-untyped-defs +import ctypes +import functools +import itertools +import logging +import sys +from collections.abc import Iterable +from typing import Callable, Optional, Union +from unittest.mock import patch + +import sympy + +from .. 
import config, ir +from ..autotune_process import CppBenchmarkRequest, TensorMeta +from ..utils import IndentedBuffer, Placeholder, unique +from ..virtualized import V +from .common import KernelTemplate +from .cpp_template_kernel import CppTemplateCaller, CppTemplateKernel + + +log = logging.getLogger(__name__) + + +class CppTemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes, + layout: ir.Layout, + num_threads: int, + epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None, + ) -> None: + super().__init__(name) + self.input_nodes = input_nodes + self.index = next(self.index_counter) + self.output_node: Union[ir.Buffer, list[ir.Buffer]] = ir.Buffer( + name=f"buf_out{self.index}", layout=layout + ) + self.layout = layout + self.num_threads = num_threads + self.epilogue_creator = epilogue_creator + + def generate(self, **kwargs): + kernel_name = f"cpp_{self.name}" + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + patch.object(ir.FlexibleLayout, "allow_indexing", True), + V.graph.set_current_device(self.layout.device), + CppTemplateKernel( + kernel_name=kernel_name, num_threads=self.num_threads + ) as kernel, + ): + code = kernel.render(self, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + log.debug("Generated Code:\n%s", code) + log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(), + kernel.args.python_argdefs(), + ) + + expected_args = list( + unique(input_node.get_name() for input_node in self.input_nodes) + ) + if isinstance(self.output_node, Iterable): + expected_args.extend([node.get_name() for node in self.output_node]) + else: + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + extra_args = V.graph.sizevars.size_hints( + map(sympy.expand, call_args[len(expected_args) :]) + ) + # Cast the size hint from int to ctypes.c_ulonglong explicitly + # since in cpp kernel, we bind it to C long + extra_args = tuple(ctypes.c_ulonglong(x) for x in extra_args) + + kernel_hash_name = f"cpp_{self.name}_{self.index}" + + # Create the BenchmarkRequest for CPP + bmreq = CppBenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + def make_kernel_render( + template_node: ir.CppTemplateBuffer, + flag_template_buffer_has_other_users: bool, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + ): + kernel = CppTemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), num_threads=self.num_threads + ) + render = functools.partial( + kernel.render, + self, + template_buffer_node=template_node, + flag_template_buffer_has_other_users=flag_template_buffer_has_other_users, + epilogue_nodes=epilogue_nodes, + **kwargs, + ) + return kernel, render + + return CppTemplateCaller( + kernel_hash_name, + self.name, + self.input_nodes, + self.output_node[0].get_layout() + if isinstance(self.output_node, Iterable) + else self.output_node.get_layout(), + make_kernel_render, + bmreq, + self, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.writeline("#include ") + # TODO: add c10::ForcedUnroll test to test_aoti_abi_check + res.splice("""#include """) + res.splice("""#include """) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + 
"linux", + "win32", + ] + if enable_kernel_profile: + res.writelines(["#include "]) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template_kernel.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..03e81f43a415fd23602897cd485abff2ec3e8814 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_template_kernel.py @@ -0,0 +1,597 @@ +# mypy: allow-untyped-defs +import itertools +from collections.abc import Iterable +from typing import Any, Callable, Optional, Union + +import sympy +from sympy.parsing.sympy_parser import parse_expr + +import torch +from torch._inductor.utils import do_bench_using_profiling +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import SymT + +from .. import config, cpp_builder, ir, lowering as L +from ..autotune_process import CppBenchmarkRequest +from ..loop_body import LoopBody +from ..select_algorithm import PartialRender +from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix +from ..virtualized import V +from .common import REMOVED +from .cpp import CppKernel, CppKernelProxy, KernelGroup +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext + + +def parse_expr_with_index_symbols(expr): + if isinstance(expr, sympy.Expr): + return expr + elif isinstance(expr, (list, tuple)): + return [parse_expr_with_index_symbols(e) for e in expr] + else: + expr = parse_expr(str(expr)) + int_symbols = {sym: sympy_index_symbol(sym.name) for sym in expr.free_symbols} + return expr.subs(int_symbols) + + +def wrap_with_tensorbox(node) -> ir.TensorBox: + return ( + ir.TensorBox.create(node) if isinstance(node, ir.Buffer) else ir.TensorBox(node) + ) + + +class CppTemplateKernel(CppKernel): + def __init__(self, kernel_name, num_threads): + super().__init__(None, num_threads) + self.kernel_name = kernel_name + self.render_hooks = {} + self.local_buffers = {} + + def render(self, template, **kwargs): + return PartialRender( + template.render(kernel=self, **kwargs), self.render_hooks + ).finalize_all() + + def def_kernel( + self, + inputs: dict[str, ir.Buffer], + outputs: dict[str, ir.Buffer], + aliases: Optional[dict[str, str]] = None, + function_name: str = "", + extra_sizevars: Optional[list[sympy.Expr]] = None, + placeholder: str = "", + ) -> str: + if len(function_name) == 0: + function_name = str(self.kernel_name) + for name, inp in inputs.items(): + if inp is not None: + self.args.input_buffers[inp.get_name()] = name + for name, out in outputs.items(): + self.args.output_buffers[out.get_name()] = name + if aliases is not None: + for alias, orig in aliases.items(): + if orig in self.args.input_buffers: + self.args.input_buffers[alias] = self.args.input_buffers[orig] + if orig in self.args.output_buffers: + self.args.output_buffers[alias] = self.args.output_buffers[orig] + + unique_sizevars = OrderedSet( + s + for input in inputs.values() + if input is not None + for sym in itertools.chain(input.get_size(), input.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for sym in extra_sizevars or [] + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + unique_sizevars.update( + s + for output in outputs.values() + for sym in itertools.chain(output.get_size(), output.get_stride()) + if isinstance(sym, sympy.Expr) + for s in sym.free_symbols + ) + 
sizevars = sorted(unique_sizevars, key=str) + for sizevar in sizevars: + self.args.sizevars[sizevar] = f"k{sizevar}" + + def hook(): + # remove all aliases before generate function definition + if aliases is not None: + for alias in aliases: + if alias in self.args.input_buffers: + raise AssertionError( + f"input_buffers cannot be removed: {alias}" + ) + if alias in self.args.output_buffers: + self.args.output_buffers[alias] = REMOVED + cpp_argdefs, _, _ = self.args.cpp_argdefs() + return f"void {function_name}({', '.join(cpp_argdefs)})" + + assert placeholder not in self.render_hooks + self.render_hooks[placeholder] = hook + return placeholder + + def call_kernel(self, name: str, node: ir.CppTemplateBuffer): + wrapper = V.graph.wrapper_code + _, call_args, arg_types = self.args.cpp_argdefs() + wrapper.generate_kernel_call(name, call_args, triton=False, arg_types=arg_types) + + def dtype(self, node: ir.Buffer) -> str: + return DTYPE_TO_CPP[node.get_dtype()] + + def acc_dtype(self, node: ir.Buffer) -> str: + if node.get_dtype() in [torch.float32, torch.bfloat16, torch.half]: + return "float" + else: + raise NotImplementedError(f"Unsupported dtype: {node.get_dtype()}") + + def size(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_size()[dim])) + + def stride(self, node: ir.Buffer, dim: int) -> str: + return cexpr_index(self.rename_indexing(node.get_stride()[dim])) + + def index(self, node: ir.Buffer, indices: list[Any]) -> str: + indexer = node.get_layout().as_fixed().make_indexer() + index = indexer(parse_expr_with_index_symbols(indices)) + index = self.rename_indexing(index) + outer_name = node.get_name() + inner_name = ( + outer_name + if outer_name in self.local_buffers + else self.args.input(node.get_name()) + ) + return f"{inner_name}[{cexpr_index(index)}]" + + def slice_nd(self, node, ranges: list[tuple[Any, Any]]) -> ir.ReinterpretView: + """ + Slice the given node with a list of ranges (start and end) corresponding to its dims. + The dim is not sliced if the corresponding range is empty. + """ + assert len(ranges) == len(node.get_size()), f"{ranges=}, {node=}" + sliced = wrap_with_tensorbox(node) + for dim, _range in enumerate(ranges): + if len(_range) == 0: + continue + assert len(_range) == 2 + start, end = parse_expr_with_index_symbols(_range) + sliced = L.slice_(sliced, dim, start, end, clamp=False) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def select(self, node, dim: int, idx: int) -> ir.ReinterpretView: + # We avoid using L.select here because we need clamp=False so the dim after slicing + # is 1 instead of a sympy expression of symbol - dim_size. 
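Editor's note on the comment above: with clamping, the post-slice extent of the selected dimension stays a symbolic expression of the index and the dimension size, whereas clamp=False lets it collapse to the literal 1 that the subsequent squeeze relies on. The Min-based expression in this sketch is only a stand-in for whatever clamped form would actually be produced.

import sympy

dim_size, idx = sympy.symbols("dim_size idx", integer=True, nonnegative=True)
clamped_extent = sympy.Min(idx + 1, dim_size) - sympy.Min(idx, dim_size)  # stays symbolic
unclamped_extent = (idx + 1) - idx                                        # simplifies to 1
print(clamped_extent, unclamped_extent)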
+ node = wrap_with_tensorbox(node) + idx = ir.View.handle_negative_index(idx, node.get_size()[dim]) + sliced = L.squeeze(L.slice_(node, dim, idx, idx + 1, clamp=False), dim) + assert isinstance(sliced.data, ir.ReinterpretView), sliced.data + return sliced.data + + def view(self, node, sizes: list[Any]) -> ir.View: + node = wrap_with_tensorbox(node) + sizes = parse_expr_with_index_symbols(sizes) + return L.view(node, sizes).data + + def permute(self, node, dims): + node = wrap_with_tensorbox(node) + permuted = L.permute(node, dims).data + assert isinstance(permuted, ir.ReinterpretView) + return permuted + + def maybe_codegen_profile(self) -> str: + if config.cpp.enable_kernel_profile: + graph_id = V.graph.graph_id + prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" + return f'RECORD_FUNCTION("{prefix}{self.kernel_name}", c10::ArrayRef({{}}));' + else: + return "" + + def unroll_pragma(self, unroll): + if cpp_builder.is_gcc(): + return f"#pragma GCC unroll {unroll}" + else: + return f"#pragma unroll {unroll}" + + def define_buffer(self, name, sizes: list[Any], dtype=torch.float) -> str: + """Define kernel local buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"auto _{name} = std::make_unique<{ctype}[]>({numel}); auto {name} = _{name}.get();" + + def define_stack_allocated_buffer( + self, name, sizes: list[Any], dtype=torch.float + ) -> str: + """Define stack-allocated buffer""" + sizes = parse_expr_with_index_symbols(sizes) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) + self.local_buffers[name] = buf + ctype = f"{DTYPE_TO_CPP[dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"alignas(64) {ctype} _{name}[{numel}]; {ctype}* {name} = _{name};" + + def reinit_buffer_if_null(self, name): + """Reinit the previously defined local buffer if it is null""" + assert name in self.local_buffers + buf = self.local_buffers[name] + ctype = f"{DTYPE_TO_CPP[buf.layout.dtype]}" + numel = f"{cexpr_index(buf.get_numel())}" + return f"if (_{name} == nullptr) {{ _{name} = std::make_unique<{ctype}[]>({numel}); {name} = _{name}.get(); }}" + + def release_buffer(self, name): + """Codegen the code to release the ownership of a local buffer to others""" + assert name in self.local_buffers + return f"_{name}.release()" + + def store_pointwise_nodes( + self, + dst: ir.Buffer, + nodes: list[ir.IRNode], + offsets: Optional[list[sympy.Expr]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ) -> str: + var_sizes = (tuple(dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + if not offsets: + offsets = [sympy.S.Zero] * len(var_sizes[0]) + if not reindexers: + reindexers = [None] * len(nodes) + assert len(offsets) == len(var_sizes[0]) + output_index = dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = node.get_name() if i < len(nodes) - 1 else dst.get_name() + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) 
== 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_grouped_gemm_pointwise_nodes( + self, + dst: tuple[ir.Buffer], + nodes: list[ir.IRNode], + offsets: list[sympy.Expr], + reindexers: list[Optional[Callable[[list[Any]], list[Any]]]], + output_names: list[str], + ) -> str: + ref_dst = dst[0] + var_sizes = (tuple(ref_dst.get_size()), ()) + var_ranges = { + sympy_index_symbol_with_prefix(SymT.INDEX, i): sz + for i, sz in enumerate(var_sizes[0]) + } + assert offsets, "offsets should be set outside" + assert all(len(offset) == len(var_sizes[0]) for offset in offsets) + output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()]) + kernel_group = KernelGroup() + kernel_group.args = self.args + cpp_kernel_proxy = CppKernelProxy(kernel_group) + bodies = [] + var_sizes_list = [] + for i, node in enumerate(nodes): + output_name = output_names[i] + node = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(node, ir.Pointwise), node + + def fn(*args): + assert len(args) == 2 + assert len(args[0]) == len(var_sizes[0]) + assert len(args[1]) == 0 + new_args = [arg + offset for arg, offset in zip(args[0], offsets[i])] # type: ignore[arg-type] + if reindexers[i] is not None: + new_args = reindexers[i](new_args) # type: ignore[misc] + V.ops.store( + output_name, + output_index, + node.make_loader()(new_args).value, + ) + + body = LoopBody( + fn, + (list(var_ranges.keys()), ()), + var_ranges, + list(var_ranges.keys()), + tuple(), + ) + bodies.append(body) + var_sizes_list.append(var_sizes) + + cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list) + kernel_group.finalize_kernel(cpp_kernel_proxy, []) + return kernel_group.loops_code.getvalue() + + def store_output( + self, + dst: ir.Buffer, + src: ir.Buffer, + orig_src: Optional[ir.Buffer] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + ): + """ + Store the `src` buffer to the `dst` buffer. The size of `src` and `dst` should match. + If `epilogue_nodes` is provided, the `src` buffer is firstly computed with the epilogues + before stored to `dst`. The `epilogues_nodes` are all pointwise. + + Notes: + 1. `src` and `dst` buffer could be the same buffer in which case we are doing in-place compute + and stores. In case `epilogue_nodes` are not provided, we do nothing. + 2. The `epilogue_nodes`, if exist, have computations on `src` before storing to `dst` but since + they come form the original Inductor IR, they might need to be adjusted before working with + `src` and `dst` as outlined below: + a) `src` or `dst` buffer could be a sub-slice of the ranges the `epilogue_nodes`work on. 
+ In this case, the `offsets` could be provided to adjust the indices passed to + `epilogue_nodes` during codegen and the data ranges are also configured according to + the sizes of `src` and `dst`. + b) `dst` might be indexed in a different way as the `epilogue_nodes`, hence a `reindexer` is + needed on the indices to `epilogue_nodes` to match the indexing of `dst`. + c) If `src` is local, we need to add a local buffer for it and localize the `orig_src` buffer + in `epilogue_nodes` with `src`. + """ + assert isinstance(dst, (ir.Buffer, ir.ReinterpretView)) + assert dst.get_size() == src.get_size(), f"{dst=}, {src=}" + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + if epilogue_nodes: + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + if orig_src.get_name() != src.get_name(): + scope.add_local_buffer( + src, + [ + orig_src, + ], + ) + epilogue_nodes = scope.localize_nodes(epilogue_nodes) + return self.store_pointwise_nodes( + dst, + epilogue_nodes, # type: ignore[arg-type] + offsets, + reindexers, + ) + else: + if dst.get_name() != src.get_name(): + # src is local + copy = L.copy(dst, src).data.data + with LocalBufferContext(self.args) as scope: + scope.add_local_buffer(src) + return self.store_pointwise_nodes(dst, [copy]) + else: + assert dst.layout == src.layout, f"{dst=}, {src=}" + return "" + + def store_outputs( + self, + dst: tuple[ir.Buffer], + src: tuple[ir.IRNode], + orig_src: Optional[tuple[ir.IRNode]] = None, + epilogue_nodes: Optional[list[ir.IRNode]] = None, + offsets: Optional[list[Any]] = None, + reindexers: Optional[list[Optional[Callable[[list[Any]], list[Any]]]]] = None, + multi_output_buffers: Optional[tuple[ir.MultiOutput]] = None, + ): + assert isinstance(dst, Iterable) + assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst)) + if offsets: + offsets = parse_expr_with_index_symbols(offsets) + gemm_num = len(src) + final_offsets = [] + output_names = [] + if epilogue_nodes: + if not reindexers: + reindexers = [None] * len(epilogue_nodes) + with LocalBufferContext(self.args) as scope: + assert orig_src is not None + localize_epilogue_nodes = [] + all_read_names = [] + for epilogue in epilogue_nodes: + all_read_names.extend(list(epilogue.get_read_names())) + localize_epilogue_nodes.extend(scope.localize_nodes(epilogue_nodes)) + final_offsets.extend([offsets] * len(localize_epilogue_nodes)) + output_names.extend( + [node.get_name() for node in localize_epilogue_nodes] + ) + for gemm_idx in range(gemm_num): + if orig_src[gemm_idx].get_name() != src[gemm_idx].get_name(): + if orig_src[gemm_idx].get_name() in all_read_names or ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + ): + # If any of the Epilogue nodes use this GEMM output, let's localize the GEMM output + global_buffers = [orig_src[gemm_idx]] + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() + in all_read_names + and orig_src[gemm_idx].get_name() not in all_read_names + ): + # Epilogue might directly read the MultiOutput, Locallize MultiOutput to the local Buffer + # if this MultiOutput has not been stored by in-template epilogue + # otherwise, use the cse store cache if it will be stored before used + global_buffers.append(multi_output_buffers[gemm_idx]) + scope.add_local_buffer( + src[gemm_idx], + global_buffers, + ) + else: + scope.add_local_buffer(src[gemm_idx]) + localize_epilogue_nodes.extend( + [L.copy(dst[gemm_idx], src[gemm_idx]).data.data] + ) + 
reindexers.append(None) + output_names.append(dst[gemm_idx].get_name()) + final_offsets.append( + [sympy.S.Zero] * len(dst[gemm_idx].get_size()) + ) + res = self.store_grouped_gemm_pointwise_nodes( + dst, + localize_epilogue_nodes, + final_offsets, + reindexers, + output_names=output_names, + ) + for gemm_idx in range(gemm_num): + if ( + multi_output_buffers + and multi_output_buffers[gemm_idx].get_name() in all_read_names + ): + # If the MultiOutput is used in the Epilogue, let's remove it from args + multi_output_name = multi_output_buffers[gemm_idx].get_name() + if ( + multi_output_name in self.args.output_buffers + and self.args.output_buffers[multi_output_name] + is not REMOVED + ): + self.remove_buffer(multi_output_name) + return res + else: + if dst[0].get_name() != src[0].get_name(): + copy_list = [] + with LocalBufferContext(self.args) as scope: + for _src, _dst in zip(src, dst): + copy_list.extend([L.copy(_dst, _src).data.data]) + scope.add_local_buffer(_src) + output_names.append(_dst.get_name()) + final_offsets.append([sympy.S.Zero] * len(_dst.get_size())) + reindexers = [None] * len(copy_list) + return self.store_grouped_gemm_pointwise_nodes( + dst, + nodes=copy_list, + offsets=final_offsets, + reindexers=reindexers, + output_names=output_names, + ) + else: + assert all( + _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst) + ) + assert all( + _src.get_layout() == _dst.get_layout() + for _src, _dst in zip(src, dst) + ) + return "" + + def check_bounds(self, expr, size, lower, upper): + # CppTemplateKernel does not need codegen related operations + return + + +class CppTemplateCaller(ir.ChoiceCaller): + """ + CppTemplateCaller + + This class represents a caller for CPP template kernels. It is a subclass of ir.ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CppBenchmarkRequest): The benchmark request for the caller. + template_buffer (ir.CppTemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[ir.Buffer], + layout: ir.Layout, + make_kernel_render: Callable[ + [ + ir.CppTemplateBuffer, + bool, + Optional[list[ir.IRNode]], + ], + str, + ], + bmreq: CppBenchmarkRequest, + template: "CppTemplate", # type: ignore[name-defined] # noqa: F821 + info_kwargs: Optional[ + dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]] + ] = None, + ): + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict( + self, + ) -> dict[str, Union[ir.PrimitiveInfoType, list[ir.PrimitiveInfoType]]]: + return {"backend": "CPP", "op_type": "unknown"} + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.CppTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + template=self.template, + choice=self, + ) + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_utils.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfc669c90d042eab7e72e512ead10a91e6697ea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_utils.py @@ -0,0 +1,776 @@ +# mypy: allow-untyped-defs +import contextlib +import dataclasses +import functools +import math +import sys +from collections import namedtuple +from collections.abc import Sequence +from typing import Any, Callable, Optional +from unittest.mock import patch + +import sympy + +import torch +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter as _CppPrinter +from torch.utils._sympy.symbol import symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ValueRanges + +from .. 
import ir +from ..dependencies import Dep +from ..loop_body import LoopBody +from ..scheduler import BaseSchedulerNode, SchedulerBuffer +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs +from ..virtualized import ops, OpsValue, V +from .common import CSEVariable, Kernel, KernelArgs, OptimizationContext + + +DTYPE_TO_CPP = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "at::Half", + torch.int64: "int64_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint64: "uint64_t", + torch.uint32: "uint32_t", + torch.uint16: "uint16_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "at::BFloat16", + torch.complex32: "at::complex", + torch.complex64: "at::complex", + torch.complex128: "at::complex", + torch.float8_e4m3fn: "at::Float8_e4m3fn", + torch.float8_e5m2: "at::Float8_e5m2", + torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz", +} + +DTYPE_TO_ATEN = { + torch.float32: "at::kFloat", + torch.float64: "at::kDouble", + torch.float16: "at::kHalf", + torch.int64: "at::kLong", + torch.int32: "at::kInt", + torch.int16: "at::kShort", + torch.int8: "at::kChar", + torch.uint64: "at::kUInt64", + torch.uint32: "at::kUInt32", + torch.uint16: "at::kUInt16", + torch.uint8: "at::kByte", + torch.uint32: "at::kUInt32", + torch.uint64: "at::kUInt64", + torch.bool: "at::kBool", + torch.bfloat16: "at::kBFloat16", + torch.complex32: "at::kComplexHalf", + torch.complex64: "at::kComplexFloat", + torch.complex128: "at::kComplexDouble", + torch.float8_e4m3fn: "at::kFloat8_e4m3fn", + torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", +} + +DEVICE_TO_ATEN = { + "meta": "at::kMeta", + "cpu": "at::kCPU", + "cuda": "at::kCUDA", + "xpu": "at::kXPU", + "mps": "at::kMPS", +} + +LAYOUT_TO_ATEN = { + torch.strided: "at::kStrided", + torch._mkldnn: "at::kMkldnn", # type: ignore[attr-defined] +} + +# matches c10/core/DeviceType.h +DEVICE_TO_INT = {"cpu": 0, "cuda": 1} + +_IS_WINDOWS = sys.platform == "win32" + +INDEX_TYPE = "int64_t" + +GemmBlocking = namedtuple("GemmBlocking", ["block_m", "block_n", "block_k"]) + + +def get_promote_dtype(args): + return ( + functools.reduce( + torch.promote_types, # type: ignore[arg-type] + [n.dtype for n in args if isinstance(n, CppCSEVariable)], + ) + if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) + else None # not enough info to calculate the promote dtype + ) + + +def promote_args(new_args): + def promote_arg(arg, promote_type): + if ( + isinstance(arg, CppCSEVariable) + and arg.dtype + and promote_type + and arg.dtype != promote_type + ): + arg = ops.to_dtype(arg, promote_type) + arg = arg.value if isinstance(arg, OpsValue) else arg + arg.dtype = promote_type + return arg + + promote_type = get_promote_dtype(new_args) + promote_fn = functools.partial( + promote_arg, + promote_type=promote_type, + ) + if ( + all( + new_arg.dtype is not None + for new_arg in new_args + if isinstance(new_arg, CppCSEVariable) + ) + and promote_type + ): + new_args = list(map(promote_fn, new_args)) + return new_args + + +class CppCSEVariable(CSEVariable): + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) + self.is_vec = False + self.dependent_itervars = OrderedSet[sympy.Symbol]() + + def __repr__(self) -> str: + return ( + f"CppCSEVariable(name: 
{self.name}, bounds: {self.bounds}, is_vec: {self.is_vec}, dtype: {self.dtype}, " + f"dependent_itervars: {self.dependent_itervars})" + ) + + def update_on_args(self, name, args, kwargs): + if name == "load": + # args[2] is index + self._set_dependent_itervars(args[2]) + else: + # propagate relevant itervars and is_vec from args + self.dependent_itervars.update( + *[ + arg.dependent_itervars + for arg in args + if isinstance(arg, CppCSEVariable) + ] + ) + if name == "index_expr": + self._set_dependent_itervars(args[0]) + if any(arg.is_vec for arg in args if isinstance(arg, CppCSEVariable)): + self.is_vec = True + + def _set_dependent_itervars(self, index: sympy.Expr): + """ + Set the relevant itervars for this variable based on the `index` expression. + This includes the itervars directly used in the `index` as well as relevant itervars + of other cse variables used in the `index`. + """ + for s in index.free_symbols: + if s in V.kernel.itervars: + self.dependent_itervars.add(s) # type: ignore[arg-type] + elif s.name in V.kernel.cse.varname_map: # type: ignore[attr-defined] + self.dependent_itervars.update( + V.kernel.cse.varname_map[s.name].dependent_itervars # type: ignore[attr-defined] + ) + + def depends_on(self, itervar: sympy.Symbol): + return itervar in self.dependent_itervars + + +class CppPrinter(_CppPrinter): + def doprint(self, expr, *, simplify: bool = True, p=True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +# A function to print, useful for printing sympy symbols. +cexpr = CppPrinter().doprint + + +def cexpr_index(index): + return f"static_cast<{INDEX_TYPE}>({cexpr(index)})" + + +def value_to_cpp(value, cpp_type): + if value == float("-inf"): + return f"-std::numeric_limits<{cpp_type}>::infinity()" + elif value == float("inf"): + return f"std::numeric_limits<{cpp_type}>::infinity()" + elif isinstance(value, bool): + return f"static_cast<{cpp_type}>({str(value).lower()})" + elif math.isnan(value): + return f"std::numeric_limits<{cpp_type}>::quiet_NaN()" + else: + return f"static_cast<{cpp_type}>({repr(value)})" + + +def rewrite_index_for_function( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + # Local buffer at the inner dimensions + snode = V.graph.scheduler.name_to_buf[global_buf_name].defining_op + assert snode is not None + local_buf = localize_buffer_handler.global_to_local[global_buf_name] + scheduler_nodes = snode.get_nodes() + _, (group, reduction_group) = max( + scheduler_nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + indices_to_keep = [ + f"x{len(call_ranges) - (idx + 1)}" + for idx in range(len(local_buf.get_layout().size)) + ] + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] + replacements = {} + for x in sorted_symbols: + if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] + # Only keep index used by local buffer + replacements[x] = sympy.core.numbers.Zero() + index = sympy_subs(index, replacements) # type: ignore[arg-type] + return index + + +def rewrite_index_for_nodes( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, + global_buf_name: str, +): + used_vars = OrderedSet( + s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX) + ) + index_vars = [] + 
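Editor's note: an illustrative, self-contained sketch of the index rewrite performed here, using made-up shapes and strides. Iteration symbols the local buffer does not span are sent to zero, and the surviving symbols are re-linearized with the local buffer's contiguous strides.

import sympy

x0, x1 = sympy.symbols("x0 x1", integer=True, nonnegative=True)
global_index = 4096 * x1                     # the global access only varies with x1
used_vars = global_index.free_symbols

local_strides = (64, 1)                      # hypothetical contiguous local layout
index_vars = [v if v in used_vars else 0 for v in (x0, x1)]   # -> [0, x1]
local_index = sum(st * v for st, v in zip(local_strides, index_vars))
print(local_index)                           # x1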
local_buf = localize_buffer_handler.global_to_local[global_buf_name] + for i in range(len(local_buf.get_size())): + var = sympy_index_symbol_with_prefix(SymT.INDEX, i) + index_vars.append(var if var in used_vars else 0) + index = local_buf.get_layout().make_indexer()(index_vars) + return index + + +class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] + def __init__( + self, + inner, + global_to_local: dict[str, ir.Buffer], + rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr], + ) -> None: + super().__init__(inner) + self.global_to_local = global_to_local + self.rewrite_index = rewrite_index + + def localize(self, name: str, index: sympy.Expr): + if self.global_to_local and name in self.global_to_local: + assert self.rewrite_index is not None + index = self.rewrite_index(self, index, name) + name = self.global_to_local[name].get_name() + return name, index + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(*self.localize(name, index)) + + def store(self, name, index, value, mode=None): + local_buffer_name, local_buffer_index = self.localize(name, index) + res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) + if ( + self.global_to_local + and name in self.global_to_local + and isinstance(V.kernel, Kernel) + ): + # Remove name of local buffer from Kernel.store_buffer_names + # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. + V.kernel.store_buffer_names.discard(local_buffer_name) + return res + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(*self.localize(name, index), value) + + +class LocalBufferContext: + """ + This class creates a context that helps to generate code involving Inductor IR with + function local buffers. These buffers are constructed during the codegen process and + are used to store intermediate results such as local accumulators. We do not want to + add them to `V.graph` since they are not global and we do not want to add them as + function arguments either. So we patch the codegen processes under this scope to support + these buffers without exposure to the outside world. 
+ """ + + def __init__(self, kernel_args: KernelArgs) -> None: + self.kernel_args = kernel_args + self.exit_stack = contextlib.ExitStack() + # map local buffer name to local buffer + self.local_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to global buffer + self.global_buffers: dict[str, ir.Buffer] = {} + # map global buffer name to local buffer + self.global_to_local: dict[str, ir.Buffer] = {} + # record the global buffers that are removed by this LocalBufferContext + self.removed_buffers: OrderedSet[str] = OrderedSet() + + def __enter__(self): + self.exit_stack.__enter__() + original_get_dtype = V.graph.get_dtype + + def get_dtype(name): + if name in self.local_buffers: + return self.local_buffers[name].get_dtype() + return original_get_dtype(name) + + self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) + + original_input = self.kernel_args.input + + def input(name): + if name in self.local_buffers: + return name + return original_input(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) + + original_output = self.kernel_args.output + + def output(name): + if name in self.local_buffers: + return name + return original_output(name) + + self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) + + # Set current LocalBufferContext into V + self.exit_stack.enter_context(V.set_local_buffer_context(self)) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.local_buffers.clear() + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) + + def add_local_buffer( + self, local_buffer: ir.Buffer, global_buffers: Optional[list[ir.Buffer]] = None + ): + assert local_buffer.get_name() not in self.local_buffers + self.local_buffers[local_buffer.get_name()] = local_buffer + if global_buffers: + for global_buffer in global_buffers: + global_buffer_name = global_buffer.get_name() + assert ( + global_buffer_name not in self.global_buffers + and global_buffer_name not in self.global_to_local + ) + self.global_buffers[global_buffer_name] = global_buffer + self.global_to_local[global_buffer_name] = local_buffer + if global_buffer_name not in V.graph.removed_buffers: + # Record the global buffers that are removed by this LocalBufferContext + # since which may need to restore. Refer to issue: + # https://github.com/pytorch/pytorch/issues/144186 + self.removed_buffers.add(global_buffer_name) + V.graph.removed_buffers.add(global_buffer_name) + + def localize_function( + self, + fn: Callable[..., Any], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_function, + ): + def inner(*args, **kwargs): + with V.set_ops_handler( + LocalizeBufferHandler( + V.get_ops_handler(), + global_to_local=self.global_to_local, + rewrite_index=rewrite_index, + ) + ): + return fn(*args, **kwargs) + + return inner + + def localize_nodes( + self, + nodes: list[ir.IRNode], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr, str], sympy.Expr + ] = rewrite_index_for_nodes, + ) -> list[ir.IRNode]: + """ + Given `local_buf` and `global_buf` registered in current `LocalBufferContext` + though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` + for the given `nodes` and returns a new list of IR nodes that work on `local_buf` + instead of `global_buf`, i.e., all the loads and stores are redirected to + `local_buf`. This helps the fused loops to work on smaller-sized local buffers + for better data locality. 
+ + The the data access of `local_buf` is assumed to be contiguous with the + same order as the `global_buf`. + """ + assert len(nodes) > 0 + + def wrap_inner_fn_for_node(node: ir.IRNode): + loops = node.data if isinstance(node, ir.ComputedBuffer) else node + assert isinstance(loops, ir.Loops) + new_inner_fn = self.localize_function( + loops.inner_fn, + rewrite_index, + ) + + new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn) + if isinstance(node, ir.ComputedBuffer): + new_node = ir.ComputedBuffer( + name=node.get_name(), layout=node.get_layout(), data=new_loops + ) + else: + new_node = new_loops # type: ignore[assignment] + + return new_node + + return [wrap_inner_fn_for_node(node) for node in nodes] + + +def unify_mask_base_type( + buffer: IndentedBuffer, + vars: tuple[CSEVariable, ...], + dtype=torch.float, +): + """ + Given list of cse variables, + Cast each to new mask base dtype and return casted cse variable. + """ + new_vars = ( + V.kernel.cse.generate( + buffer, + f"{V.kernel._get_mask_cast(var, dtype)}", + ) + for var in vars + ) + return new_vars + + +def may_unify_binary_op_mask_type(a, b): + """ + Given two cse variables, when dtype is bool, unify them to the same mask dtype and return casted cse variable. + """ + if a.dtype == torch.bool: + assert b.dtype == torch.bool + mask_dtype = torch.int32 + return unify_mask_base_type(V.kernel.compute, (a, b), mask_dtype) + return a, b + + +def codegen_rand(offset, code, rand_function, dst_dtype=torch.float32): + assert is_integer_dtype(offset.dtype) + code.writeline("[&]()") + with code.indent(): + code.writeline( + f"{DTYPE_TO_CPP[offset.dtype]} offset[{V.kernel.tiling_factor}];" + ) + code.writeline(f"{DTYPE_TO_CPP[dst_dtype]} result[{V.kernel.tiling_factor}];") + code.writeline(f"{offset}.store(offset);") + code.writeline( + f"for( {DTYPE_TO_CPP[offset.dtype]} offset_idx = 0; offset_idx < {V.kernel.tiling_factor}; offset_idx++ )" + ) + with code.indent(): + code.writeline(rand_function) + num_vectors = V.kernel._get_num_vectors(dtype=dst_dtype) + if num_vectors == 1: + code.writeline( + f"return at::vec::Vectorized<{DTYPE_TO_CPP[dst_dtype]}>::loadu(result);" + ) + else: + code.writeline( + f"return at::vec::VectorizedN<{DTYPE_TO_CPP[dst_dtype]}, {num_vectors}>::loadu(result);" + ) + code.writeline("()") + return code + + +def get_gemm_template_output_and_compute_dtype(input_dtype): + if input_dtype in [torch.uint8, torch.int8]: + return (torch.int32, torch.int32) + else: + return (torch.float32, torch.float32) + + +def create_epilogue_with_attr(input_buffer, attr, **kwargs): + input_loader = input_buffer.make_loader() + dtype = input_buffer.get_dtype() + if attr == "relu": + + def inner_fn(index): + input = input_loader(index) + zero = ops.constant(0, dtype) + return ops.maximum(input, zero) + + elif attr == "gelu": + assert "algorithm" in kwargs + if kwargs["algorithm"] == "none": + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) + const = ops.constant(0.7071067811865476, torch.float) + result = input * half * (ops.erf(input * const) + one) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + else: + assert kwargs["algorithm"] == "tanh" + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + half = ops.constant(0.5, torch.float) + one = ops.constant(1.0, torch.float) 
+ const1 = ops.constant(0.7978845608028654, torch.float) + const2 = ops.constant(0.044715, torch.float) + result = ( + half + * input + * ( + one + + ops.tanh(const1 * (input + const2 * input * input * input)) + ) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "swish": + + def inner_fn(index): + input = input_loader(index) + result = input * ops.sigmoid(input) + return result + + elif attr == "sigmoid": + + def inner_fn(index): + return ops.sigmoid(input_loader(index)) + + elif attr == "tanh": + + def inner_fn(index): + return ops.tanh(input_loader(index)) + + elif attr == "hardswish" or attr == "hardsigmoid": + + def hardsigmoid_float(input): + zero = ops.constant(0, torch.float) + six = ops.constant(6, torch.float) + three = ops.constant(3, torch.float) + one_over_six = ops.constant(0.16666666666666666, torch.float) + max = ops.maximum(input + three, zero) + min = ops.minimum(max, six) + return min * one_over_six + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = hardsigmoid_float(input) + if attr == "hardswish": + result = input * result + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "leaky_relu": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 1 + negative_slope = kwargs["scalars"][0] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + zero = ops.constant(0, torch.float) + result = ops.where( + input > zero, input, input * ops.constant(negative_slope, torch.float) + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr == "hardtanh": + assert "scalars" in kwargs + assert len(kwargs["scalars"]) == 2 + min_value = kwargs["scalars"][0] + max_value = kwargs["scalars"][1] + + def inner_fn(index): + input = input_loader(index) + if dtype != torch.float: + input = ops.to_dtype(input, torch.float) + result = ops.minimum( + ops.maximum(input, ops.constant(min_value, torch.float)), + ops.constant(max_value, torch.float), + ) + if dtype != torch.float: + result = ops.to_dtype(result, dtype) + return result + + elif attr in ["add", "sub", "mul"]: + assert "other" in kwargs + other = kwargs["other"] + num_input_dims = len(input_buffer.get_size()) + num_other_dims = len(other.get_size()) + dims_diff = num_input_dims - num_other_dims + other_loader = other.make_loader() + + def inner_fn(index): + op = getattr(ops, attr) + if dims_diff != 0: + return op(input_loader(index), other_loader(index[dims_diff:])) + else: + return op(input_loader(index), other_loader(index)) + + elif attr == "bias_add": + assert "other" in kwargs + assert "beta" in kwargs + assert "dtype" in kwargs + beta = kwargs["beta"] + other = kwargs["other"] + dtype = kwargs["dtype"] + bias_loader = other.make_loader() + + def inner_fn(index): + bias = bias_loader(index) + input = input_loader(index) + if beta != 1: + result = ops.constant(beta, torch.float) * bias + input + else: + result = bias + input + return result + + else: + raise ValueError(f"Unsupported epilogue attribute: {attr}") + return ir.Pointwise( + device=input_buffer.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=input_buffer.get_size(), + ) + + +def _get_loop_body(fn_list): + if all(isinstance(fn, LoopBody) for fn in fn_list): + loop_bodies = fn_list + else: + if hasattr(fn_list[0], "original_fn"): + # For the case of local buffer, we 
wrap the fn with localize_function + assert all(hasattr(fn, "original_fn") for fn in fn_list) + assert all( + isinstance(fn.original_fn.args[0]._body, LoopBody) for fn in fn_list + ) + loop_bodies = [fn.original_fn.args[0]._body for fn in fn_list] + else: + assert all(isinstance(fn, functools.partial) for fn in fn_list) + assert all(isinstance(fn.args[0]._body, LoopBody) for fn in fn_list) + loop_bodies = [fn.args[0]._body for fn in fn_list] + assert loop_bodies is not None + return loop_bodies + + +def _get_dtype_from_loopbodies(loop_bodies): + dtypes = OrderedSet[torch.dtype]() + for loop_body in loop_bodies: + graphs = [loop_body.root_block.graph] + [ + body.graph for body in list(loop_body.subblocks.values()) + ] + for graph in graphs: + for node in graph.nodes: + if node.op != "call_method": + continue + dtypes.add(node.meta[OptimizationContext.key].dtype) + return dtypes + + +def template_fusion_with_epilogues_supported( + template: BaseSchedulerNode, epilogues: list[BaseSchedulerNode] +) -> tuple[bool, bool]: + def _get_indexes_of_template_buf_read( + epilogue_node: ir.Operation, template_buf_names: list[str] + ) -> list[sympy.Expr]: + return [ + read.index + for read in epilogue_node.get_reads() + if read.name in template_buf_names + ] + + def _check_supported_and_same_indexes( + index_of_template_buf_read: Sequence[sympy.Expr], + epilogue_writes: OrderedSet[Dep], + ) -> tuple[bool, bool]: + num_indexes = len(OrderedSet(index_of_template_buf_read)) + + if num_indexes > 1: + same_index = False + supported = False # Different read indexes not supported + elif num_indexes == 0: + same_index = True + supported = True # No reads, automatically supported + elif num_indexes == 1: + iotbr = index_of_template_buf_read[0] + same_index = all(write.index == iotbr for write in epilogue_writes) + # TODO: Add support of fusion when the read of template buffer and the write of epilogue output + # in the epilogue node don't have the same index and change supported to True + supported = same_index + else: + raise AssertionError("Should not reach here") + + return supported, same_index + + def _template_fusion_supported( + template_outputs: Sequence[SchedulerBuffer], epilogue_nodes: list[ir.Operation] + ) -> tuple[bool, bool]: + template_buf_names = [x.get_name() for x in template_outputs] + indexes_of_template_buf_reads = [ + _get_indexes_of_template_buf_read(epilogue_node, template_buf_names) + for epilogue_node in epilogue_nodes + ] + epilogue_nodes_writes = [ + epilogue_node.get_read_writes().writes for epilogue_node in epilogue_nodes + ] + + results = [ + _check_supported_and_same_indexes(reads, writes) + for reads, writes in zip( + indexes_of_template_buf_reads, epilogue_nodes_writes + ) + ] + supported, same_indexes = zip(*results) + return all(supported), all(same_indexes) + + assert template.is_template() + template_outputs = template.get_outputs() + + epilogue_nodes = [ + n.node + for epilogue in epilogues + for n in epilogue.get_nodes() + if n.node is not None + ] + return _template_fusion_supported(template_outputs, epilogue_nodes) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py new file mode 100644 index 0000000000000000000000000000000000000000..769fcfeadc791830816669739071380c5e2e64d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -0,0 +1,2747 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import ctypes +import 
functools +import math +import os +import sys +import textwrap +from itertools import chain, count +from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._higher_order_ops.torchbind +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops +from torch._inductor.runtime.runtime_utils import dynamo_timed +from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. import config, ir +from ..utils import _align, DeferredLineBase, LineContext, normalize_name +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, IndentedBuffer, Kernel +from .cpp_utils import cexpr, DEVICE_TO_ATEN, DEVICE_TO_INT, DTYPE_TO_ATEN, DTYPE_TO_CPP +from .wrapper import ( + EnterSubgraphLine, + ExitSubgraphLine, + PythonWrapperCodegen, + SymbolicCallArg, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..graph import GraphLowering + + # At most, the list nesting can go one layer deep. + _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]] + + +class HasWriteLine(Protocol): + def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ... + + +class CppWrapperCpu(PythonWrapperCodegen): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + """ + + def __init__(self): + if not hasattr(self, "device"): + self.device = "cpu" + # must be initialized prior to calling super().__init__() + self.included_devices: OrderedSet[str] = OrderedSet() + super().__init__() + self.declare = "auto " + self.declare_maybe_reference = "decltype(auto) " + self.ending = ";" + self.comment = "//" + self.none_str = "nullptr" + self.supports_intermediate_hooks = False + self.kernel_callsite_id = count() + self.int_array_id = count() # for int array local variable declarations + self.declared_int_array_vars: OrderedSet[str] = OrderedSet() + self.tmp_tensor_id = count() # for tmp tensor local variable declarations + self.arg_var_id = count() + self.used_cached_devices: OrderedSet[str] = OrderedSet() + self.used_cached_dtypes: OrderedSet[str] = OrderedSet() + self.used_cached_layouts: OrderedSet[str] = OrderedSet() + self.used_cached_memory_formats: OrderedSet[str] = OrderedSet() + self.used_cond_predicate: OrderedSet[str] = OrderedSet() + self.cached_output_id = count() + self.scalar_to_tensor_id = count() + self.custom_op_wrapper_loaded = False + # For GEMM kernels that must be initialized and are resolved at linking. + self.initialized_kernels: dict[str, Kernel] = {} + self.device_codegen = get_device_op_overrides(self.device) + # only need to include each header once + self.include_extra_header = functools.lru_cache(None)( # type: ignore[method-assign] + self._include_extra_header + ) + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. 
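+        # Note: is_subgraph / subgraph_name / parent_wrapper are intentionally unused here;
+        # cpp wrapper currently inlines subgraphs into the parent wrapper (see the
+        # codegen_subgraph method further down) rather than lifting them into functions.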
+ return CppWrapperCpu() + + @staticmethod + def _generate_temporary_array_pointer( + c_type: str, elements: Sequence[str], *, force_mutable: bool = False + ) -> str: + """Get a pointer to an array that only exists for the duration of the C++ + statement it's used in.""" + # If the c_type is already a pointer, return a mutable pointer to the array. + # Otherwise, return a const pointer. In the C-shim API, pointer types are only + # const-qualified with respect to the underlying value, not any nested pointers. + # e.g. const double** is possible, but not const double* const*. This means + # that an array containing pointers must _already_ be properly const-qualified + # by the c_type, and not add additional const-ness. + ptr_call = "data()" if force_mutable or c_type.endswith("*") else "cbegin()" + return ( + f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call:\n" + f"call_args: {call_args}\n" + f"arg_types: {arg_types}" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + new_args.append(f"({arg_types[idx]})({arg}.data_ptr())") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_constant(self, name, hashed): + # include a hash so our code cache gives different constants different files + self.header.writeline(f"// {name} {hashed}") + + @staticmethod + def get_device_include_path(device: str) -> str: + if V.graph.aot_mode: + return f"#include " + return f"#include " + + def add_device_include(self, device: str) -> None: + if device in self.included_devices: + return + + self.included_devices.add(device) + + # Add the default header for this device, plus any C-shim extensions that are + # present. + self.header.splice(self.get_device_include_path(device)) + extend_aoti_c_shim_include = ( + f"torch/csrc/inductor/aoti_torch/generated/extend/c_shim_{self.device}.h" + ) + extend_aoti_c_shim_path = os.path.join( + os.path.dirname(torch.__file__), + "include", + extend_aoti_c_shim_include, + ) + if os.path.exists(extend_aoti_c_shim_path): + self.header.splice(f"#include <{extend_aoti_c_shim_include}>") + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
+ return + + if not V.graph.aot_mode: + self.header.splice( + """ + import torch + from torch._inductor.codecache import CppWrapperCodeCache + + cpp_wrapper_src = ( + r''' + """ + ) + + self.add_device_include(self.device) + + if V.graph.aot_mode: + with open( + os.path.join(os.path.dirname(__file__), "aoti_runtime", "interface.cpp") + ) as f: + self.header.splice(f.read()) + self.header.splice("\n") + + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + if config.profiler_mark_wrapper_call or enable_kernel_profile: + # No C shim for profiling APIs, assuming profiling is a debugging feature which + # does not provide any ABI compatibility promise. + self.header.splice("#include ") + + def _include_extra_header(self, header: str): + # This is needed for cpp to python dtype conversion + self.header.splice(f"#include <{header}>") + + def mark_output_type(self): + # mark output type to unwrap tensor back to python scalar + from ..ir import ShapeAsConstantBuffer + + output_is_tensor = {} + for idx, x in enumerate(V.graph.graph_outputs): + if isinstance(x, ShapeAsConstantBuffer): + output_is_tensor[idx] = False + else: + output_is_tensor[idx] = True + + self.output_is_tensor = output_is_tensor + + def write_prefix(self): + if V.graph.is_const_graph: + # We do not write prefix for constant graph, it will be written by main module. + return + if config.aot_inductor.custom_ops_to_c_shims: + # custom_ops_to_c_shims contains declaration of custom ops with C shim. + # TODO: this could be auto-generated from a passed-in custom op schema + custom_c_shims = list( + chain(*config.aot_inductor.custom_ops_to_c_shims.values()) + ) + declarations = "\n".join( + [f"extern {textwrap.dedent(shim)};" for shim in custom_c_shims] + ) + self.prefix.splice( + f""" + extern "C" {{ + {declarations} + }} + """ + ) + if V.graph.aot_mode: + self.prefix.writeline("namespace torch::aot_inductor {") + + def write_input_output_info( + self, + info_kind: str, + idx: int, + name: str, + ): + self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + self.codegen_input_size_var_decl(code, name) + return f"{name}_size" + + @functools.cache + def strideof(name): + self.codegen_input_stride_var_decl(code, name) + return f"{name}_stride" + + def codegen_symbol( + sym_or_exp: Union[sympy.Symbol, sympy.Expr], + base_name: str, + name_fn: Callable[[str], str], + dim: int, + ): + if isinstance(sym_or_exp, sympy.Symbol): + if sym_or_exp in bound_vars: + return + code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];") + bound_vars.add(sym_or_exp) + elif isinstance(sym_or_exp, sympy.Expr): + undefined_symbols = [ + sym for sym in sym_or_exp.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) != 1: + # Skip if expression contains no symbols or if multiple + # symbols exists since we assume each base symbol is defined + # by other codegen_symbol calls. 
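+                    # (When exactly one symbol is unbound, e.g. a dim of symbolic size
+                    # 2*s0 with s0 not yet defined, the solver below recovers it from the
+                    # runtime size, emitting roughly: int64_t s0 = arg0_1_size_0 / 2;
+                    # -- names here are illustrative only.)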
+ return + + from torch.utils._sympy.solve import try_solve + + free_symbol = undefined_symbols.pop() + base_name = name_fn(base_name) + # Use a size symbol to solve the free symbol + size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True) + code.writeline(f"int64_t {size_symbol} = {base_name}[{dim}];") + solution = try_solve(sympy.Eq(sym_or_exp, size_symbol), free_symbol) + if solution is not None: + code.writeline(f"int64_t {free_symbol} = {cexpr(solution[1])};") + bound_vars.add(free_symbol) + else: + raise AssertionError( + str(sympy.Eq(sym_or_exp, size_symbol)) + " is not solvable" + ) + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + if value.is_integer: + decl = "int64_t" + elif value.is_float: + decl = "double" + else: + raise AssertionError("Unexpected symbol type") + code.writeline(f"{decl} {value} = {name};") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + codegen_symbol(size, name, sizeof, dim) + for dim, stride in enumerate(value.get_stride()): + codegen_symbol(stride, name, strideof, dim) + elif isinstance(value, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def generate_input_output_runtime_checks(self): + """ + In debug_compile mode, we generate checks to ensure the dtype/shape/stride/device of each + real input/output tensor match ones provided at compile time via sample + input/output. + """ + + def gen_check(handle_kind, idx, name, tensor): + # Wrap AtenTensorHandle with ConstantHandle for cleaner utility function access + self.prefix.writeline( + f"ConstantHandle {name} = ConstantHandle({handle_kind}[{idx}]);" + ) + self.codegen_tensor_dtype_var_decl(self.prefix, name) + expected_dtype_name = DTYPE_TO_ATEN[tensor.dtype] + dtype_str = str(tensor.dtype).split(".")[-1] + self.prefix.splice( + f""" + int32_t {name}_expected_dtype = aoti_torch_dtype_{dtype_str}(); + if ({name}_expected_dtype != {name}_dtype) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dtype, " + << "expected: " << {name}_expected_dtype << "({expected_dtype_name}), " + << "but got: " << {name}_dtype << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + self.codegen_input_size_var_decl(self.prefix, name) + for dim_idx, d in enumerate(tensor.get_size()): + if isinstance(d, (int, sympy.Integer)): + self.prefix.splice( + f""" + if ({d} != {name}_size[{dim_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched dim value at {dim_idx}, " + << "expected: {d}, " << "but got: " << {name}_size[{dim_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + else: + from torch.utils._sympy.value_ranges import bound_sympy + + sym_range = bound_sympy(d, V.graph.sizevars.shape_env.var_to_range) + if not math.isinf(sym_range.lower): + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] < {sym_range.lower}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too small at {dim_idx}, " + << "expected it to be >= {sym_range.lower}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + if not math.isinf(sym_range.upper): + # Limit upper bound to max C long long value (2^63 - 1) + max_long_long = ctypes.c_longlong(2**63 - 1).value + upper_bound = min(sym_range.upper, max_long_long) + self.prefix.splice( + f""" + if ({name}_size[{dim_idx}] > 
{upper_bound}) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " + << "expected to be <= {upper_bound}, " << "but got: " + << {name}_size[{dim_idx}] << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + self.codegen_input_stride_var_decl(self.prefix, name) + for stride_idx, s in enumerate(tensor.get_stride()): + if not isinstance(s, (int, sympy.Integer)): + continue + self.prefix.splice( + f""" + if ({s} != {name}_stride[{stride_idx}]) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched stride value at {stride_idx}, " + << "expected: {s}, " << "but got: " << {name}_stride[{stride_idx}] + << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # check input device type + if isinstance(tensor, ir.TensorBox): + tensor_device = tensor.get_device() + if tensor_device is not None: + expected_device_type = DEVICE_TO_INT.get(tensor_device.type) + if expected_device_type is not None: + self.codegen_input_device_type_var_decl(self.prefix, name) + device_type_str = str(tensor_device.type) + self.prefix.splice( + f""" + int32_t {name}_expected_device_type = {expected_device_type}; + if ({name}_expected_device_type != {name}_device_type) {{ + std::stringstream ss; + ss << "{handle_kind}[{idx}]: unmatched device type, " + << "expected: " << {name}_expected_device_type << "{expected_device_type}({device_type_str}), " + << "but got: " << {name}_device_type << "\\n"; + throw std::runtime_error(ss.str()); + }} + """ + ) + + # Create a separate function for each input check to avoid "too big to optimize" error + for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): + self.prefix.splice( + f""" + AOTI_NOINLINE static void check_input_{idx}( + AtenTensorHandle* input_handles + ) {{ + """ + ) + with self.prefix.indent(): + gen_check("input_handles", idx, name, tensor) + self.prefix.writeline("}") + + # force noinline to avoid any potential compilation slowdown due to aggressive + # inline done by the host compiler + self.prefix.splice( + """ + static bool _check_aoti_runtime_check_inputs_env() { + const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS"); + const static bool result = env_var_value != nullptr && env_var_value[0] != '0'; + return result; + } + + AOTI_NOINLINE static void __check_inputs_outputs( + AtenTensorHandle* input_handles, + AtenTensorHandle* output_handles) { + if (!_check_aoti_runtime_check_inputs_env()){ + return; + } + """ + ) + with self.prefix.indent(): + for idx in range(len(V.graph.graph_inputs)): + self.prefix.writeline(f"check_input_{idx}(input_handles);") + self.prefix.writeline("}") + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. 
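+                    # (A stub is still required: the generated AOTInductorModel class always
+                    # declares _const_run_impl, so an empty body is emitted when runtime
+                    # constant folding is disabled.)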
+ self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + __check_inputs_outputs(input_handles, output_handles); + """ + + self.generate_input_output_runtime_checks() + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + # debug printing for all input args to AOTI model + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.codegen_model_inputs_value_print( + input_args_to_print=[ + input_key + for input_key in V.graph.graph_inputs.keys() + if input_key.startswith("arg") + ] + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. 
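+                    # e.g. this emits:
+                    #   [[maybe_unused]] auto L__self___weight = constants_->at(0);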
+ self.prefix.writeline( + f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"[[maybe_unused]] auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_dtype;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));" + ) + + def codegen_input_size_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_size = {name}.sizes();") + + def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"auto {name}_stride = {name}.strides();") + + def codegen_input_device_type_var_decl(self, code: IndentedBuffer, name): + code.writeline(f"int32_t {name}_device_type;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type({name}, &{name}_device_type));" + ) + + def codegen_model_kernels(self): + self.prefix.writeline("namespace {") + + # Tell compiler we need to link with the non-mangled symbols + for kernel in self.initialized_kernels.values(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + signature = kernel.get_signature() + self.prefix.writeline(f'extern "C" {signature};') + + self.prefix.writeline( + "class AOTInductorModelKernels : public AOTInductorModelKernelsBase {" + ) + self.prefix.writeline(" public:") + declare_kernel = OrderedSet(self.src_to_kernel.values()) - OrderedSet( + self.initialized_kernels.keys() + ) + declare_kernel.update( + entry[0] for entry in self.user_defined_kernel_cache.values() + ) + if V.graph.const_module: + declare_kernel.update( + V.graph.const_module.wrapper_code.src_to_kernel.values() + ) + for kernel in sorted(declare_kernel): + self.prefix.writeline( + maybe_hipify_code_wrapper( + f" {self.device_codegen.cpp_kernel_type()} {kernel}{{nullptr}};" + ) + ) + for name, kernel in self.initialized_kernels.items(): + assert hasattr(kernel, "get_signature"), ( + f"{kernel} must have get_signature implemented" + ) + kernel_ptr = f"(*{name})" + signature = kernel.get_signature().replace(name, kernel_ptr) + self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};") + self.prefix.writeline("};") + self.prefix.writeline("} // namespace\n\n") + + if config.aot_inductor.embed_kernel_binary: + self.prefix.writeline('extern "C" {') + for name in sorted(declare_kernel): + self.prefix.writeline( + f" extern const unsigned char __{name}_start[];" + ) + if torch.xpu.is_available(): + self.prefix.writeline( + f" extern const unsigned char __{name}_end[];" + ) + self.prefix.writeline("}") + + def codegen_model_constructor(self): + """ + // Generated code example + AOTInductorModel::AOTInductorModel() + : AOTInductorModelBase(4, 1) { + inputs_info_[0].name = "input0"; + inputs_info_[0].dtype = "torch.float16"; + ... + constants_info_[0].name = "L__self___weight"; + constants_info_[0].dtype = at::kFloat; + constants_info_[0].offset = 0; + constants_info_[0].data_size = 8192; + constants_info_[0].shape = {64, 32}; + constants_info_[0].stride = {32, 1}; + ... 
+ outputs_info_[0].name = "output0"; + outputs_info_[0].dtype = "torch.float16"; + } + """ + + num_inputs = len(V.graph.graph_inputs) + num_outputs = len(V.graph.graph_outputs) + num_constants = len(V.graph.constants) + include_weights = ( + "true" if config.aot_inductor.package_constants_in_so else "false" + ) + self.prefix.splice( + f""" + AOTInductorModel::AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string& device_str, + std::optional cubin_dir) + : AOTInductorModelBase({num_inputs}, + {num_outputs}, + {num_constants}, + device_str, + std::move(cubin_dir), + {include_weights}) {{ + """ + ) + + with self.prefix.indent(): + for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()): + assert not isinstance(inp, sympy.Expr), ( + f"input {name=} cannot be symbolic" + ) + self.write_input_output_info("inputs_info_", idx, name) + + all_cuda = all( + V.graph.get_original_value_of_constant(name).is_cuda + for name in V.graph.constants.keys() + if name not in V.graph.folded_constants + ) + for idx, name in enumerate(V.graph.constants.keys()): + tensor = V.graph.get_original_value_of_constant(name) + assert isinstance(tensor, torch.Tensor) + self.prefix.writeline(f"""constants_info_[{idx}].name = "{name}";""") + self.prefix.writeline( + f"constants_info_[{idx}].dtype = static_cast({self.codegen_dtype(tensor.dtype)});" + ) + self.prefix.writeline( + f"constants_info_[{idx}].offset = {tensor.storage_offset()};" + ) + + # If constants to serialize contain cpu tensors, we always align data_size it to 64. + # When loading the constants, the valid data will depends on the size + # not the data_size so there won't be correctness issue. + data_size = ( + torch.ops.mkldnn._nbytes(tensor) + if tensor.is_mkldnn + else tensor.untyped_storage().nbytes() + ) + self.prefix.writeline( + f"constants_info_[{idx}].data_size = {data_size if all_cuda else _align(data_size)};" + ) + + from_folded = "true" if name in V.graph.folded_constants else "false" + self.prefix.writeline( + f"constants_info_[{idx}].from_folded = {from_folded};" + ) + + if name in V.graph.folded_constants: + constant_type_str = "FoldedConstant" + elif name.startswith("_tensor_constant"): + constant_type_str = "TensorConstant" + elif any( + name == normalize_name(parameter_name) + for parameter_name in V.graph.named_parameters + ): + constant_type_str = "Parameter" + elif any( + name == normalize_name(buffer_name) + for buffer_name in V.graph.named_buffers + ): + constant_type_str = "Buffer" + else: + constant_type_str = "Unknown" + self.prefix.writeline( + f"constants_info_[{idx}].type = static_cast(torch::aot_inductor::ConstantType::{constant_type_str});" + ) + + size_str = ", ".join([str(s) for s in tensor.size()]) + self.prefix.writeline(f"constants_info_[{idx}].shape = {{{size_str}}};") + + stride_str = ", ".join([str(s) for s in tensor.stride()]) + self.prefix.writeline( + f"constants_info_[{idx}].stride = {{{stride_str}}};" + ) + self.prefix.writeline( + f"constants_info_[{idx}].layout = static_cast({self.codegen_layout(tensor.layout)});" + ) + + if tensor.is_mkldnn: + opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md( + tensor + ) + assert opaque_metadata_tensor.dim() == 1, ( + "Expect opaque_metadata_tensor to be 1-D" + ) + + opaque_metadata_list = opaque_metadata_tensor.tolist() + opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list) + self.prefix.writeline( + f"constants_info_[{idx}].opaque_metadata = {opaque_metadata_str};" + ) + if name in 
V.graph.dynamo_flat_name_to_original_fqn: + original_fqn = V.graph.dynamo_flat_name_to_original_fqn.get( + name, name + ) + elif name in V.graph.allocated_constant_name: + original_fqn = V.graph.allocated_constant_name[name] + else: + raise AssertionError("original_fqn must be set for constant") + self.prefix.writeline( + f"""constants_info_[{idx}].original_fqn = "{original_fqn}";""" + ) + self.prefix.writeline("update_constants_map(std::move(constants_map));") + self.prefix.writeline("update_constants_array(std::move(constants_array));") + + def escape_string(x): + return ( + x.replace("\\", "\\\\") + .replace('"', '\\"') + .replace("\n", "\\n") + .replace("\t", "\\t") + ) + + self.prefix.writeline( + f'in_spec_ = R"({config.aot_inductor.serialized_in_spec})";' + ) + self.prefix.writeline( + f'out_spec_ = R"({config.aot_inductor.serialized_out_spec})";' + ) + + for idx, output in enumerate(V.graph.graph_outputs): + assert not isinstance(output, sympy.Expr), ( + f"output {name=} cannot be symbolic" + ) + name = f"output{idx}" + self.write_input_output_info("outputs_info_", idx, name) + + self.prefix.writeline( + "this->kernels_ = std::make_unique();" + ) + + self.prefix.writeline("}") + + def codegen_const_run_driver(self): + """ + // Generated code example + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + std::unordered_map folded_constants_map; + std::vector output_handles; + // build up output_handles over here. + _const_run_impl(output_handles, stream, proxy_executor); + // build up folded_constants_map + return folded_constants_map; + } + """ + + self.prefix.splice( + """ + std::unordered_map AOTInductorModel::const_run_impl( + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor, + bool initialization + ) { + """ + ) + if not config.aot_inductor.use_runtime_constant_folding: + self.prefix.splice( + """ + if (!initialization) { + std::cerr << "[WARNING] Calling constant_folding in model, but compiled with config: " + << "aot_inductor.use_runtime_constant_folding=False\\n"; + } + return {}; + } + """ + ) + return + + with self.prefix.indent(): + # This is a mapping to the index of constant folding graph's output + const_index_mapping: list[Optional[tuple[int, str]]] = [None] * len( + V.graph.const_output_index + ) + for idx, (name, _) in enumerate(V.graph.constants.items()): + if name in V.graph.const_output_index: + const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload] + assert None not in const_index_mapping, ( + "Not all constant gets mapped for constant folding graph." + ) + + self.prefix.writeline( + f""" + std::unordered_map folded_constants_map; + folded_constants_map.reserve({len(const_index_mapping)}); + std::vector output_handles({len(const_index_mapping)}); + """ + ) + + self.prefix.splice( + """ + // The below assignment of output_handles to constants is not used directly. + // It's only used to memo the correspondence of handle and constants. 
+ """ + ) + + for output_idx, (const_idx, _) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f"output_handles[{output_idx}] = constants_->at({const_idx});" + ) + + self.prefix.writeline( + "_const_run_impl(output_handles, stream, proxy_executor);" + ) + + for output_idx, (_, const_name) in enumerate(const_index_mapping): # type: ignore[misc] + self.prefix.writeline( + f'folded_constants_map["{const_name}"] = output_handles[{output_idx}];' + ) + self.prefix.writeline("return folded_constants_map;") + + self.prefix.writeline("}") + + def generate(self, is_inference): + with dynamo_timed("CppWrapperCpu.generate", log_pt2_compile_event=True): + self.write_wrapper_decl() + return super().generate(is_inference) + + def finalize_prefix(self): + prior = self.prefix + self.prefix = aot_mode_decls = IndentedBuffer() + if V.graph.aot_mode and not V.graph.is_const_graph: + aot_mode_decls.writeline("namespace torch::aot_inductor {") + self.codegen_model_kernels() + self.codegen_model_constructor() + self.codegen_const_run_driver() + aot_mode_decls.writeline("} // namespace torch::aot_inductor") + aot_mode_decls.writeline("using namespace torch::aot_inductor;") + + self.prefix = cache_decls = IndentedBuffer() + for dtype in self.used_cached_dtypes: + cache_decls.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cache_decls.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cache_decls.writeline(f"CACHE_TORCH_LAYOUT({layout});") + for memory_format in self.used_cached_memory_formats: + cache_decls.writeline(f"CACHE_TORCH_MEMORY_FORMAT({memory_format});") + + self.prefix.splice(aot_mode_decls) + self.prefix.splice(prior) + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = False, + cpp_definition: Optional[str] = None, + ): + if cpp_definition is not None: + self.header.splice(cpp_definition) + self.kernel_declarations.splice(f"\n{kernel_body}\n") + else: + self.header.splice(f"\n{kernel_body}\n") + + def codegen_scalar_to_tensor(self, output: str): + name = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + self.wrapper_call.writeline( + f"RAIIAtenTensorHandle {name} = scalar_to_tensor_handle({output});" + ) + return name + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + output2idx: dict[str, int] = {} + + # If any output ref represents an rvalue tensor, materialize it to an lvalue + # RAIIAtenTensorHandle first. This prevents situations where the code for the + # rvalue tensor references tensor handles whose contents are modified below. 
+ output_refs = [ + self.create_tmp_raii_handle_var_if_needed(o, self.wrapper_call) + for o in output_refs + ] + + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + + if output not in output2idx: + output2idx[output] = idx + + def generate_before_suffix(self, result): + if not V.graph.is_const_graph: + if V.graph.aot_mode: + result.writeline("} // AOTInductorModel::run_impl") + else: + result.writeline("} // inductor_entry_impl") + + def generate_end(self, result): + """Generates the end of the code block, and any code needed to call it.""" + if V.graph.aot_mode: + if V.graph.is_const_graph: + result.writeline("} // AOTInductorModel::_const_run_impl") + else: + result.writeline("} // namespace torch::aot_inductor\n\n\n") + return + + if config.cpp_wrapper_build_separate: + # Close the wrapper code block, then write any kernel definitions. + result.splice("'''\n)") + if self.kernel_declarations: + result.splice("\nkernel_src = (\nr'''") + result.splice(self.kernel_declarations.getvalue()) + result.splice("'''\n)") + else: + result.splice( + """ + kernel_src = '' + """ + ) + else: + # Merge main code and kernel code + result.splice(self.kernel_declarations.getvalue()) + self.kernel_declarations.clear() + # Close the wrapper code block + result.splice("'''\n)") + + kernel_code = "kernel_src" if config.cpp_wrapper_build_separate else "None" + # Cpp entry function for JIT with cpp wrapper + result.splice( + f""" + inductor_entry = CppWrapperCodeCache.load_pybinding( + argtypes=["std::vector"], + main_code=cpp_wrapper_src, + device_type="{self.device}", + num_outputs={len(V.graph.graph_outputs)}, + kernel_code={kernel_code}, + ) + """ + ) + + wrapper_body = "input_tensors = [arg if isinstance(arg, torch.Tensor) else torch.tensor(arg) for arg in args]" + if V.graph.constants: + # Append constants to the input args for cpp wrapper. + # Python wrapper directly gets the value inside the wrapper call + # as a global variable passed when calling exec(code, mod.__dict__, mod.__dict__). + # For cpp wrapper, we need to pass this python value to the inductor_entry_impl function explicitly. + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + constants_str = f"[{', '.join(V.graph.constants.keys())}]" + wrapper_body += f""" + constants_tensor = {constants_str} + input_tensors.extend(constants_tensor) + """ + # Convert vector of at::Tensor to vector of AtenTensorHandle. + # If we pass at::Tensor, the compilation will be too slow. 
+ wrapper_body += """ + input_handles = torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(input_tensors) + """ + # Release the inputs for memory reuse. + wrapper_body += """ + args.clear() + del input_tensors + """ + + # unwrap output tensor back to python scalar + if all(x for x in self.output_is_tensor.values()): + # If no ShapeAsConstantBuffer in the output, directly return the output as tensors + outputs_str = "output_tensors" + else: + outputs = [ + ( + f"output_tensors[{i}]" + if self.output_is_tensor[i] + else f"output_tensors[{i}].item()" + ) + for i in range(len(V.graph.graph_outputs)) + ] + outputs_str = f"[{', '.join(outputs)}]" + wrapper_body += f""" + output_handles = f(input_handles) + output_tensors = torch._C._aoti.alloc_tensors_by_stealing_from_void_ptrs(output_handles) + return {outputs_str} + """ + + # Wrap the func to support setting result._boxed_call = True + result.splice( + f""" + def _wrap_func(f): + def g(args): + {wrapper_body} + return g + + call = _wrap_func(inductor_entry) + """ + ) + + @staticmethod + def get_c_shim_func_name(kernel: str, device: str) -> str: + if kernel.startswith("aoti_torch_"): + return kernel + + assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" + kernel_tokens = kernel.split("::") + kernel_suffix = kernel_tokens[-1] + if kernel_suffix == "call": + kernel_suffix = kernel_tokens[-2] + + shim_fn = f"aoti_torch_{device}_{kernel_suffix}" + return shim_fn + + def generate_c_shim_extern_kernel_call( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + ) -> None: + """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in + place of args while preserving debug printer output.""" + # We can do this unconditionally, since we cache this call. 
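+        # (add_device_include is cached via self.included_devices, so the c_shim header for
+        # a given device is only included once, no matter how many fallback calls use it.)
+        # The call emitted below has the form, e.g.:
+        #   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_nonzero(buf0, &buf1_handle));
+        # optionally wrapped in a RECORD_FUNCTION block when cpp.enable_kernel_profile is set.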
+ self.add_device_include(device) + + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + debug_args if debug_args is not None else args, kernel, None, None, "extern" + ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] + with debug_printer_manager: + shim_fn = self.get_c_shim_func_name(kernel, device) + shim_fn_codes = ( + f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" + ) + if enable_kernel_profile: + shim_fn_codes = textwrap.dedent( + f""" + {{ + RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); + {shim_fn_codes} + }} + """ + ) + self.writeline(shim_fn_codes) + + def generate_c_shim_extern_kernel_alloc( + self, extern_kernel: ir.ExternKernelAlloc, args: list[str] + ) -> None: + # registered output buffer name + name = extern_kernel.name + output_handle_name = f"{name}_handle" + is_inplace = ( + isinstance(extern_kernel.op_overload, torch._ops.OpOverload) + and torch.Tag.inplace_view in extern_kernel.op_overload.tags + ) + + if not is_inplace: + self.writeline(f"AtenTensorHandle {output_handle_name};") + args = [*args, f"&{output_handle_name}"] + + device = d.type if (d := extern_kernel.get_device()) else self.device + self.generate_c_shim_extern_kernel_call( + extern_kernel.get_kernel_name(), args, device + ) + + if not is_inplace: + self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + if getattr(extern_kernel, "outputs", None): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) + else: + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + + def generate_c_shim_fallback_kernel( + self, fallback_kernel: ir.FallbackKernel, args: list[str] + ) -> None: + output_args = [] + output_raii_handles = [] + output_name_base = fallback_kernel.get_name() + for idx, output in enumerate(fallback_kernel.outputs): + if isinstance(output, ir.MultiOutput): + # TODO: handle integer output (e.g., as in attention) + name = f"{output.get_name()}" + output_handle_name = f"{name}_handle" + if output.indices: + assert output.indices[0][1] == idx, ( + f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}" + ) + self.writeline(f"AtenTensorHandle {output_handle_name};") + output_args.append(f"&{output_handle_name}") + output_raii_handles.append( + f"RAIIAtenTensorHandle {name}({output_handle_name});" + ) + elif isinstance(output, int): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"int64_t {output_name} = {output};") + output_args.append(f"&{output_name}") + elif isinstance(output, sympy.Expr): + output_name = f"{output_name_base}_{idx}" + self.writeline(f"auto {output_name} = {cexpr(output)};") + output_args.append(f"&{output_name}") + elif output is None: + output_args.append("nullptr") + else: + raise NotImplementedError(f"unsupported type of {output=}") + args = args + output_args + device = d.type if (d := fallback_kernel.get_device()) else self.device + self.generate_c_shim_extern_kernel_call( + fallback_kernel.cpp_kernel_name, # type: ignore[arg-type] + args, + device, + ) + for raii_handle in output_raii_handles: + self.writeline(raii_handle) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + ) -> None: + if out_view: + out_name = f"{out}_as_strided" + self.writeline(f"auto {out_name} = 
{out_view};") + args.insert(0, out_name) + else: + args.insert(0, out) + + self.generate_c_shim_extern_kernel_call(kernel, args, device) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + inputs_wrapped = [str(x) for x in inputs] + line = f"{cpp_kernel_name}({output}, {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", indices + ) + args = [ + x, + indices_str, + str(len(indices)), + values, + accumulate, + ] + args.insert(0, x) # set x as the output tensor, this fallback mutates x. + self.writeline(self.wrap_kernel_call(kernel, args)) + + def add_benchmark_harness(self, output): + if V.graph.aot_mode: + return + super().add_benchmark_harness(output) + + def codegen_cpp_sizevar(self, x: sympy.Expr, *, simplify: bool = True) -> str: + return cexpr(V.graph.sizevars.simplify(x) if simplify else x) + + def codegen_sizevar(self, x: sympy.Expr) -> str: + return self.codegen_cpp_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + # in the abi_compatible mode, outputs are returned via arguments + return name + + def codegen_shape_tuple(self, shape: Sequence[sympy.Expr]) -> str: + parts = [*map(self.codegen_sizevar, shape)] + if len(parts) == 0: + return "{}" + if len(parts) == 1: + return f"{{{parts[0]}, }}" + return f"{{{', '.join(parts)}}}" + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"int64_t {sym} = {cexpr(expr)};") + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + if (arg.inner, graph) not in self.kernel_numel_expr: + # declare expr once in each graph (scope) + self.kernel_numel_expr.add((arg.inner, graph)) + self.writeline(f"int64_t {arg.inner} = {cexpr(arg.inner_expr)};") + else: + self.writeline(f"{arg.inner} = {cexpr(arg.inner_expr)};") + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw") + + if len(node.keypath) == 0: + self.writeline(f"auto {node.sym} = {node.sym}_raw;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"int64_t 
{node.sym} = {node.sym}_raw ? 1 : 0;") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + # TODO: assert divisibility here + self.writeline( + f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + + # record in unbacked_symbol_decls so we won't generate a declaration of the symbol again + self.unbacked_symbol_decls.add(str(node.sym)) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or isinstance(buffer, ir.TMADescriptor) + else f"{buffer.get_name()}.reset();" + ) + + def make_free_by_names(self, names_to_del: list[str]): + return " ".join(f"{name}.reset();" for name in names_to_del) + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"auto {new_name} = std::move({old_name}); // reuse" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline( + 'RECORD_FUNCTION("inductor_wrapper_call", c10::ArrayRef());' + ) + + def generate_start_graph(self): + pass + + def generate_end_graph(self): + pass + + def generate_inf_and_nan_checker(self, nodes): + for buf in nodes.get_names(): + # TODO: Add buf name directly into check_inf_and_nan. + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan({buf}));" + ) + + def codegen_device(self, device): + assert device.type in DEVICE_TO_ATEN, ( + device.type + " not found in DEVICE_TO_ATEN" + ) + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + self.used_cached_devices.add(device_str) + return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" + + def codegen_dtype(self, dtype): + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" + + def codegen_layout(self, layout): + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" + + def codegen_memory_format(self, memory_format): + memory_format_str = str(memory_format).split(".")[-1] + self.used_cached_memory_formats.add(memory_format_str) + return f"cached_torch_memory_format_{memory_format_str}" + + @functools.cache # noqa: B019 + def codegen_int_array_var( + self, + int_array: str, + writeline: Callable[..., None], + known_statically=False, + graph=None, # for per-graph caching + ): + # Used for size/stride declaration + # + # Because the memory planning is done in two passes (see the implementation + # of self.generate), the writeline behavior is different in the two passes. + # As a result, the emitted int array declarations may appear in a later + # position of the generated code, so the second pass codegen should not + # reuse int array declarations generated in the first pass. + # This is why writeline needs to explicitly passed in as a parameter. 
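+        # Typical emitted declarations (variable names illustrative):
+        #   static constexpr int64_t int_array_0[] = {64L, 32L};  // all extents known statically
+        #   const int64_t int_array_1[] = {s0, 32L};              // contains a runtime symbol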
+ var = f"int_array_{next(self.int_array_id)}" + ctype = "int64_t" + if var not in self.declared_int_array_vars: + self.declared_int_array_vars.add(var) + if known_statically: + writeline(f"static constexpr {ctype} {var}[] = {int_array};") + else: + writeline(f"const {ctype} {var}[] = {int_array};") + return var + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + buffer.get_stride(), + V.graph.get_allocation_size(buffer), + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None + ): + if allocation_shape is None: + allocation_shape = shape + + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + allocation_size = self.codegen_shape_tuple(allocation_shape) + stride = self.codegen_shape_tuple(orig_stride) + + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + + if allocation_size != size: + allocation_size_array_var = self.codegen_int_array_var( + allocation_size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints( + allocation_shape + ), + graph=self.get_codegened_graph(), + ) + else: + allocation_size_array_var = size_array_var + + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + + handle_name = f"{name}_handle" + args = [ + str(len(shape)), + allocation_size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{handle_name}", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + if allocation_size != size: + old_handle_name, handle_name = handle_name, f"{name}_handle_restrided" + self.wrapper_call.writeline(f"AtenTensorHandle {handle_name};") + args = [ + old_handle_name, + size_array_var, + stride_array_var, + f"&{handle_name}", + ] + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_as_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({handle_name});" + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + cexpr(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call.writeline, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" + ) + return f"RAIIAtenTensorHandle({tmp_name})" + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary 
RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + dim = str(len(size)) + original_offset = offset + offset = self.codegen_sizevar(offset) + call_strs = [] + final_tensor_str = None + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + device_name = data.layout.device.type + dtypeview_function = f"aoti_torch_{device_name}_view_dtype" + tmp_call_strs.append( + f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" + f"({reinterpret_call}, {self.codegen_dtype(dtype)}, &{tmp_AtenTensorHandle}));" + ) + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + def create_new_tensor_handle() -> tuple[str, list[str]]: + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + if ( + size == data.layout.size + and stride == data.layout.stride + and original_offset == data.layout.offset + ): + # pure dtypeview + if dtype is not None and dtype != data.dtype: + final_tensor_str, tmp_call_strs = create_dtypeview_call(data.get_name()) + else: + final_tensor_str, tmp_call_strs = create_new_tensor_handle() + call_strs.extend(tmp_call_strs) + else: + # firstly create reinterpretview + final_tensor_str = create_reinterpret_call() + + if dtype is not None and dtype != data.dtype: + # wrap it with dtypeview + final_tensor_str, tmp_call_strs = create_dtypeview_call( + final_tensor_str + ) + call_strs.extend(tmp_call_strs) + + for line in call_strs: + writeline(line) + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. 
+ # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::array{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.cbegin() + # ); + # ``` + return final_tensor_str + + def codegen_device_copy(self, src, dst, non_blocking: bool): + """This function is overridden by cpp_wrapper_cpu_array_ref, so we don't need to + handle cases where dst is not an AtenTensorHandle.""" + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_({dst}, {src}, {non_blocking}));" + ) + + def codegen_multi_output(self, node: ir.MultiOutput): + # in the abi_compatible mode, outputs are retrieved by passing + # output pointers, so we skip its codegen here. + pass + + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + + for (inner_input, inner_input_val), outer_input in zip( + subgraph.graph.graph_inputs.items(), outer_inputs + ): + if not isinstance(inner_input_val, ir.TensorBox): + continue + + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + src = inner_output.codegen_reference() + if not isinstance(inner_output, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. 
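+                # As an illustration, the emitted C++ for one output pair looks
+                # roughly like the following (placeholder buffer names):
+                # ```
+                # outer_out.reset();
+                # outer_out = std::move(inner_out);
+                # ```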
+ src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") + self.writeline(f"{outer_output} = {src};") + + def codegen_invoke_subgraph(self, invoke_subgraph): + raise NotImplementedError( + "codegen invoke_subgraph is not implemented for cpp wrapper" + ) + + def codegen_conditional(self, conditional): + outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the underlying scalar bool Tensor + predicate = f"{conditional.predicate.get_name()}_scalar" + if predicate not in self.used_cond_predicate: + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, + ) + self.used_cond_predicate.add(predicate) + else: + # the predicate is not a Tensor: SymBool or Python bool + predicate = conditional.predicate.codegen_reference() + + self.writeline(f"if ({predicate}) {{") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("} else {") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # PythonWrapperCode `codegen_subgraph` function. We should perhaps + # support lifting of subgraphs as functions for cpp wrapper as well. + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"// subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + + def codegen_while_loop(self, while_loop): + is_bool_pred = isinstance( + while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer + ) + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + cond_result_name = f"{name}_cond_result" + if is_bool_pred: + self.writeline(f"bool {cond_result_name};") + else: + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + # in ABI-compatible mode, the carried inputs are codegened + # as buffers outside the while loop and set to the initial + # values. at the end of each while_loop iteration, they + # will be assigned the carried values. 
+ out_name = out.get_name() + self.writeline(f"AtenTensorHandle {out_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") + cond_outer_inputs.append(out_name) + + # additional inputs will be assigned within the while_loop + # iteration directly from the corresponding outer graph buffers + cond_outer_inputs.extend(outer_additional_inputs) + + cond_outer_outputs = [cond_result_name] + body_outer_inputs = list(cond_outer_inputs) + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while (1) {") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + self.codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + + if is_bool_pred: + cond_result = f"{cond_result_name}" + else: + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) + self.writeline(f"if (!{cond_result}) break;") + + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + + def generate_extern_kernel_args_decl_if_needed( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ): + """ + Generates declarations for external kernel arguments if needed, based on the provided + operator and its arguments. It processes both input and output arguments, categorizing + them into tensor and integer arguments for further code generation. 
+ """ + schema = None + if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind): + obj = raw_args[0] + method = raw_args[1] + schema = op_overload.schema(obj, method) + else: + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + schema = op_overload._schema + assert schema is not None + arg_types = [x.real_type for x in schema.arguments] + return_types = [x.type for x in schema.returns] + + new_tensor_args = [] + new_int_args = [] + + def fill_args(arg, arg_type): + static_arg_types = ( + torch.FloatType, + torch.BoolType, + torch.StringType, + torch.Type, + torch.DeviceObjType, + ) + inductor_tensor_buffers = ( + ir.Buffer, + ir.ReinterpretView, + ) + + if isinstance(arg_type, torch.TensorType): + assert isinstance(arg, inductor_tensor_buffers), f"got {type(arg)}" + new_tensor_args.append(f"{arg.codegen_reference()}") + elif isinstance(arg_type, torch.IntType): + # int + new_int_args.append(str(arg)) + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = arg.node.expr if isinstance(arg, torch.SymInt) else arg + new_int_args.append(cexpr(expr)) + elif isinstance(arg_type, torch.NumberType): + # Scalar of type int + assert isinstance(arg, (int, float, bool)) + # Only treat int Scalar as dynamic + if isinstance(arg, int): + new_int_args.append(str(arg)) + elif isinstance(arg, ir.TorchBindObject): + # torchbind objects are loaded in proxy executor + pass + elif isinstance(arg_type, torch.ListType): + assert isinstance(arg, (list, tuple)) + + # List[Tensor] + if isinstance(arg_type.getElementType(), torch.TensorType): + new_tensor_args.extend([f"{a.codegen_reference()}" for a in arg]) + # List[Optional[Tensor]] + elif isinstance( + arg_type.getElementType(), torch.OptionalType + ) and isinstance( + arg_type.getElementType().getElementType(), torch.TensorType + ): + new_tensor_args.extend( + [f"{a.codegen_reference()}" for a in arg if a is not None] + ) + # List[int] + elif isinstance(arg_type.getElementType(), torch.IntType): + new_int_args.extend([str(a) for a in arg]) + # List[SymInt] + elif isinstance(arg_type.getElementType(), torch.SymIntType): + expressions = [ + a.node.expr if isinstance(a, torch.SymInt) else a for a in arg + ] + new_int_args.extend([cexpr(expr) for expr in expressions]) + # List[Scalar] + elif isinstance(arg_type.getElementType(), torch.NumberType): + # Only treat int Scalar as dynamic + is_int_type = [isinstance(a, int) for a in arg] + if any(is_int_type): + assert all(is_int_type), ( + "AOTInductor only supports int scalars of the same type" + ) + new_int_args.extend([str(a) for a in arg]) + else: + assert isinstance( + arg_type.getElementType(), + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + else: + assert isinstance( + arg_type, + static_arg_types, # type: ignore[arg-type] + ), ( + f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}" + ) + + for arg, arg_type in zip(raw_args, arg_types): + if arg is not None: + if isinstance(arg_type, torch.OptionalType): + fill_args(arg, arg_type.getElementType()) + else: + fill_args(arg, arg_type) + + def fill_output_arg( + arg: str, return_type: torch.JitType, is_mutated_output: bool + ) -> None: + if isinstance(return_type, torch.TensorType): + if not is_mutated_output: + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + 
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + new_tensor_args.append(f"{arg}") + elif isinstance(return_type, torch.SymIntType): + raise NotImplementedError("NYI support for return type: SymInt") + elif isinstance(return_type, torch.ListType) and isinstance( + return_type.getElementType(), torch.SymIntType + ): + raise NotImplementedError("NYI support for return type: List[SymInt]") + else: + raise AssertionError(f"Unsupported return type found: {return_type}") + + # TODO: Only support None and tensor(s) returns for now, SymInt is not implemented yet + for return_type in return_types: + if isinstance( + return_type, (torch.TensorType, torch.NoneType, torch.IntType) + ): + pass + elif isinstance(return_type, torch.OptionalType): + assert isinstance(return_type.getElementType(), torch.TensorType) + elif isinstance(return_type, torch.ListType): + assert isinstance(return_type.getElementType(), torch.TensorType) + else: + raise NotImplementedError( + f"return type {return_type} is not yet supported." + ) + + for output_arg, raw_output_arg in zip(output_args, raw_outputs): # type: ignore[arg-type] + # None output is supported, but Optional return types are not yet supported + if output_arg is None: + continue + elif isinstance(raw_output_arg, int): + new_int_args.append(str(raw_output_arg)) + elif isinstance(output_arg, list): + for out in output_arg: + assert out is not None, out + fill_output_arg( + out, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + else: + fill_output_arg( + output_arg, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) + + return new_tensor_args, new_int_args + + @staticmethod + def _compatible_with_stableivalue(op: torch._ops.OpOverload) -> bool: + """Returns true if op_overload._schema only utilizes types supported by the AOT + C-shim *internal* function to_ivalue. to_ivalue is an implementation detail, so + these types are not guaranteed to be supported long-term. When generating code + for cpp_wrapper mode, we don't have to be forward-compatible, so changing this + function's implementation in future is fine.""" + supported_types = ( + torch.BoolType, + torch.DeviceObjType, + torch.FloatType, + # ScalarTypeType, LayoutType, and MemoryFormatType are seen as IntType + # when queried via torch.JitType.type. + torch.IntType, + torch.TensorType, + ) + + def type_supported(t: torch.JitType) -> bool: + if isinstance(t, torch.OptionalType): + return type_supported(t.getElementType()) + return isinstance(t, supported_types) + + return all( + type_supported(a.type) + for a in chain(op._schema.arguments, op._schema.returns) + ) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + """Generate a call to a kernel not contained in the C-shim. 
This results in + different code paths for AOT Inductor vs cpp_wrapper Inductor mode.""" + + def extract_output_name( + out: Optional[Union[ir.Buffer, Sequence[ir.Buffer]]], + ) -> Union[Optional[str], _OUTPUT_ARGS_TYPE]: + if out is None: + return None + if isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): + return out.get_name() + if isinstance(out, ir.MutationOutput): + mutated_buf_names = out.get_mutation_names() + assert ( + isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1 + ), "Expect only one mutated buffer in MutationOutput" + return mutated_buf_names[0] + if isinstance(out, (list, tuple)): + return [extract_output_name(o) for o in out] # type: ignore[misc] + if isinstance(out, int): + return str(out) + raise AssertionError(f"Unexpected output: {type(out)}") + + if isinstance(op_overload, torch._ops.HigherOrderOperator): + assert isinstance( + op_overload, torch._higher_order_ops.torchbind.CallTorchBind + ), type(op_overload) + assert len(raw_args) > 1 + obj = raw_args[0] + method = raw_args[1] + return_schema = op_overload.schema(obj, method).returns + else: + return_schema = op_overload._schema.returns + + # output_args has the same pytree structure as outputs + if not return_schema: + # kernel does not return a value + output_args: _OUTPUT_ARGS_TYPE = [] + elif isinstance(output_name := extract_output_name(outputs), str): + output_args = [output_name] + else: + # If the schema indicates a return value, we should have a non-None value by + # this point. + assert isinstance(output_name, list), type(output_name) + output_args = output_name + + # In AOT mode, we use a ProxyExecutor to run fallback kernels. + if V.graph.aot_mode: + self.generate_fallback_kernel_with_runtime_lookup_aot( + op_overload, + raw_args, + output_args, + outputs, + ) + return + + assert isinstance(op_overload, torch._ops.OpOverload), type(op_overload) + for output in output_args: + assert output is None or isinstance(output, str), ( + "fallback kernels with runtime lookup currently only support tensor " + "returns, not more complicated types (such as list-of-list-of-tensor)" + ) + + # In non-AOT mode, we use aoti_torch_call_dispatcher if all the inputs and + # outputs of the op can be represented with StableIValue. This avoids the + # overhead of calling back into Python, and covers most remaining fallback ops. + if self._compatible_with_stableivalue(op_overload): + self.generate_fallback_kernel_with_runtime_lookup_nopython( + get_args, + op_overload, + output_args, # type: ignore[arg-type] + outputs, + ) + return + + # Otherwise, we call back into Python, which has some extra runtime overhead, + # but handles situations like list[Tensor] (currently unrepresentable via + # StableIValue). 
+ self.generate_fallback_kernel_with_runtime_lookup_python( + buf_name, + python_kernel_name, + op_overload, + raw_args, + output_args, # type: ignore[arg-type] + outputs, + ) + + def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): + scoped_lines = IndentedBuffer() + for declaration in declarations_before_scope: + scoped_lines.writeline(declaration) + + scoped_lines.writeline("{") + with scoped_lines.indent(): + scoped_lines.writeline("py::gil_scoped_acquire acquire;") + scoped_lines.writelines(lines_in_scope.split("\n")) + scoped_lines.writelines("}") + return scoped_lines._lines + + def load_custom_op_wrapper(self): + # TODO: need to support control flow + if self.custom_op_wrapper_loaded: + return + + lines = """ +RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); +if (!codecache_module) { + throw std::runtime_error("Failed to load torch._inductor.codecache"); +} +custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); +if (!custom_op_wrapper) { + throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); +}""" + + declarations_before_scope = ["RAIIPyObject custom_op_wrapper;"] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + self.custom_op_wrapper_loaded = True + + def generate_float_value(self, val): + assert isinstance(val, float) + if val == float("inf"): + return "std::numeric_limits::infinity()" + elif val == float("-inf"): + return "-std::numeric_limits::infinity()" + elif math.isnan(val): + return "std::numeric_limits::quiet_NaN()" + else: + return f"{val}" + + def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): + def generate_py_arg_inner(lines, raw_arg, arg_type): + def handle_scalar(scalar): + if isinstance(scalar, int): + return f"PyLong_FromLongLong({scalar})" + if isinstance(scalar, float): + return f"PyFloat_FromDouble({self.generate_float_value(scalar)})" + if isinstance(scalar, bool): + return f"PyBool_FromLong({1 if scalar else 0})" + if isinstance(scalar, complex): + real = self.generate_float_value(scalar.real) + imag = self.generate_float_value(scalar.imag) + return f"PyComplex_FromDoubles({real}, {imag})" + if isinstance(scalar, SymTypes): + scalar_var = cexpr(scalar.node.expr) + if isinstance(scalar, torch.SymBool): + return f"PyBool_FromLong({scalar_var})" + if isinstance(scalar, torch.SymFloat): + return f"PyFloat_FromDouble({scalar_var})" + return f"PyLong_FromLongLong({scalar_var})" + raise NotImplementedError( + f"scalar {scalar}, {type(scalar)} cannot be handled by handle_scalar" + ) + + if raw_arg is None: + # Py_None is a singleton, so we have to explicitly incref it here + lines.append("Py_INCREF(Py_None);\n") + return "Py_None" + elif isinstance(arg_type, torch.TensorType): + # In some cases, scalar arguments may be passed in place of tensors. + if not hasattr(raw_arg, "codegen_reference"): + return handle_scalar(raw_arg) + + # Store AtenTensorHandle as void*. All Python args are constructed in a + # nested scope, so this handle will self-destruct after the function + # call. 
+ base_handle = self.create_tmp_raii_handle_var_if_needed( + raw_arg.codegen_reference(), lines + ) + return f"PyCapsule_New(reinterpret_cast({base_handle}.get()), NULL, NULL)" + elif isinstance(arg_type, torch.OptionalType): + return generate_py_arg_inner(lines, raw_arg, arg_type.getElementType()) + elif isinstance(arg_type, torch.IntType): + # int + return f"PyLong_FromLongLong({raw_arg})" + elif isinstance(arg_type, torch.SymIntType): + # SymInt + expr = ( + raw_arg.node.expr if isinstance(raw_arg, torch.SymInt) else raw_arg + ) + return f"PyLong_FromLongLong({cexpr(expr)})" + elif isinstance(arg_type, torch.FloatType): + return f"PyFloat_FromDouble({self.generate_float_value(raw_arg)})" + elif isinstance(arg_type, torch.BoolType): + return f"PyBool_FromLong({1 if raw_arg else 0})" + elif isinstance(arg_type, torch.StringType): + return f'PyUnicode_FromString("{raw_arg}")' + elif isinstance(arg_type, torch.NumberType): + # Union[bool, int, float, complex] + # torch/_prims_common/__init__.py + return handle_scalar(raw_arg) + elif isinstance(raw_arg, torch.device): + device_str, device_index = self.codegen_device(raw_arg).split(", ") + return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))" + elif isinstance(raw_arg, torch.dtype): + return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" + elif isinstance(raw_arg, torch.layout): + return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))" + elif isinstance(raw_arg, torch.memory_format): + return ( + "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast(" + f"{self.codegen_memory_format(raw_arg)})))" + ) + else: + raise NotImplementedError( + f"arg type {arg_type} is not yet supported by custom_op_wrapper" + ) + + lines = [] + if isinstance(arg_type, torch.ListType): + assert isinstance(raw_arg, (list, tuple)), str(raw_arg) + " is not a list" + lines.append( + f"PyObject* {py_args_var}_{idx} = PyList_New({len(raw_arg)});\n" + ) + for i, elem in enumerate(raw_arg): + lines.append( + f"PyList_SetItem({py_args_var}_{idx}, {i}, {generate_py_arg_inner(lines, elem, arg_type.getElementType())});\n" + ) + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {py_args_var}_{idx});\n" + ) + else: + lines.append( + f"PyTuple_SetItem({py_args_var}, {idx}, {generate_py_arg_inner(lines, raw_arg, arg_type)});\n" + ) + return "".join(lines) + + def generate_fallback_kernel_with_runtime_lookup_nopython( + self, + get_args: Callable[[], Sequence[str]], + op_overload: torch._ops.OpOverload, + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. 
+ + In the future, we may switch over to directly calling c10::Dispatcher if we need + to support more datatypes.""" + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + + dispatch_lines = IndentedBuffer() + dispatch_lines.writelines(declarations_before_scope) + dispatch_lines.writeline("{") + + with dispatch_lines.indent(): + tmp_var_number = count() + + def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: + # Strip off any temporary references; we're in an indented context, so + # any saved-off variables will be auto-destroyed. + new_codegen_arg = codegen_arg.removeprefix("&temporary_reference(") + if new_codegen_arg != codegen_arg: + # If we removed temporary_reference, there's a good chance the + # variable ends with get() (which would retrieve an ATenTensorHandle + # from a temporary RAII handle). Strip that off too, since we're + # going to save this in a temporary RAII handle. + if codegen_arg.endswith(".get())"): + codegen_arg = new_codegen_arg.removesuffix(".get())") + else: + codegen_arg = new_codegen_arg.removesuffix(")") + + if isinstance(arg_type, torch.OptionalType): + # If we have a pointer to a variable, strip it off and let + # from handle any internal pointers. + codegen_arg = codegen_arg.removeprefix("&") + + if codegen_arg == "nullptr": + return "from(std::nullopt)" + + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline( + f"std::optional {var_name}{{{parse_arg(arg_type.getElementType(), codegen_arg)}}};" + ) + return f"from({var_name})" + + raii_var = self.create_tmp_raii_handle_var_if_needed( + codegen_arg, dispatch_lines + ) + temp_handle = raii_var != codegen_arg + + if isinstance(arg_type, torch.TensorType): + if not temp_handle: + # If the RAII tensor being referenced _isn't_ a temporary, + # scoped to this fallback call, then create a new handle + # referencing it which from can steal. + var_name = f"tmp_var_{next(tmp_var_number)}" + dispatch_lines.writeline(f"AtenTensorHandle {var_name};") + dispatch_lines.writeline( + f"aoti_torch_new_tensor_handle({raii_var}, &{var_name});" + ) + return f"from({var_name})" + # If the RAII tensor _is_ a temporary scoped to this fallback call, + # simply release and steal the handle. 
+ return f"from({raii_var}.release())" + return f"from({codegen_arg})" + + codegen_args = get_args() + ivalue_args = ( + parse_arg(a.type, c) + for a, c in zip(op_overload._schema.arguments, codegen_args) + ) + array_len = max(len(codegen_args), len(output_args)) + dispatch_lines.writeline( + f"std::array dispatch_vars{{{', '.join(ivalue_args)}}};" + ) + dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(") + with dispatch_lines.indent(): + dispatch_lines.writeline( + f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' # noqa: B950 + ) + dispatch_lines.writeline(");") + + if len(output_args) == 1 and (output := output_args[0]) is not None: + # result is a single tensor + dispatch_lines.writeline( + f"{output} = to(dispatch_vars[0]);" + ) + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + dispatch_lines.writeline( + f"{output_arg} = to(dispatch_vars[{idx}]);" + ) + + dispatch_lines.writeline("}") + self.writelines(dispatch_lines.getvalue().splitlines()) + + def generate_fallback_kernel_with_runtime_lookup_python( + self, + buf_name: str, + python_kernel_name: str, + op_overload: torch._ops.OpOverload, + raw_args: Sequence[Any], + output_args: Sequence[Optional[str]], + raw_outputs: Sequence[ir.Buffer], + ) -> None: + """Generate fallback kernel calls with runtime (non-AOT) dispatch. This can + only be called in cpp_wrapper mode, and assumes that the input is a non-None + OpOverload. + + This function calls into Python to dispatch, which allows it to handle datatypes + that cannot be contained in StableIValue, at the cost of some performance.""" + self.load_custom_op_wrapper() + + num_args = len(raw_args) + py_args_var = f"py_args_{next(self.arg_var_id)}" + # First arg is always the python op name + lines = textwrap.dedent( + f""" + RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1})); + if (!{py_args_var}) {{ + throw std::runtime_error("PyTuple_New {py_args_var} failed"); + }} + PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); + """ + ) + + for idx, (raw_arg, schema_arg) in enumerate( + zip(raw_args, op_overload._schema.arguments) + ): + lines += self.generate_py_arg( + py_args_var, idx + 1, raw_arg, schema_arg.real_type + ) + + lines += textwrap.dedent( + f""" + // Call the custom op in Python + RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); + if (!py_{buf_name}) {{ + if (PyErr_Occurred()) {{ + return; + }} + throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed"); + }} + """ + ) + + if len(output_args) == 1 and (output := output_args[0]) is not None: + # result is a single tensor + lines += f"{output} = reinterpret_cast(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));\n" + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + lines += f"{output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" # noqa: B950 + + if raw_outputs: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] + if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) + ] + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not 
None + ] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) + + def generate_fallback_kernel_with_runtime_lookup_aot( + self, + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + output_args: _OUTPUT_ARGS_TYPE, + raw_outputs: Sequence[ir.Buffer], + ) -> None: + ( + tensor_call_args, + int_call_args, + ) = self.generate_extern_kernel_args_decl_if_needed( + op_overload, + raw_args, + output_args, + raw_outputs, + ) + # force both temporary arrays to generate mutable data pointers, since the proxy + # executor signature requires that datatype + int_call_str = self._generate_temporary_array_pointer( + "int64_t", int_call_args, force_mutable=True + ) + tensor_call_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", tensor_call_args, force_mutable=True + ) + + extern_kernel_node_index = len(V.graph.extern_kernel_nodes) - 1 + self.writeline( + f"aoti_torch_proxy_executor_call_function(proxy_executor, " + f"{extern_kernel_node_index}, " + f"{len(int_call_args)}, " + f"{int_call_str}, " + f"{len(tensor_call_args)}, " + f"{tensor_call_str});" + ) + + def generate_reset_kernel_saved_flags(self): + pass + + def generate_save_uncompiled_kernels(self): + pass + + def c_type_for_prim_type(self, val, type_) -> str: + if isinstance(type_, torch.OptionalType): + return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" + elif isinstance(type_, torch.TensorType): + return "AtenTensorHandle" + elif isinstance(type_, (torch.IntType, torch.SymIntType)): + return "int64_t" + elif isinstance( + type_, (torch.BoolType, torch.SymBoolType, torch.EnumType) + ) or repr(type_) in ("Layout", "MemoryFormat", "ScalarType"): + return "int32_t" + elif isinstance(type_, torch.FloatType): + return "double" + elif isinstance(type_, torch.NumberType): + if isinstance(val, bool): + return "int32_t" + elif isinstance(val, (int, float)): + return "double" + elif val is None: + # This could happen when val is an optional value + return "double" + else: + raise AssertionError( + f"Unexpected type in c_type_for_prim_type: {type_=}" + ) + elif isinstance(type_, torch.StringType): + return "const char*" + else: + raise AssertionError(f"Unexpected type in c_type_for_prim_type: {type_=}") + + def val_to_arg_str_for_prim_type(self, val, type_) -> str: + # TODO: not using type_ as the first step of refactoring. Will update this later. 
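+        # Representative mappings implemented below (illustrative only; the integer
+        # suffix depends on the host platform, see the comment on the int branch):
+        #   True      -> "1"
+        #   3         -> "3L" (Linux) / "3LL" (macOS, Windows)
+        #   2.5       -> "2.5"
+        #   "out"     -> "\"out\""
+        #   [1, 2]    -> "{1L, 2L}"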
+ if isinstance(val, bool): + return "1" if val else "0" + elif isinstance(val, int): + # uint64_t is long on Linux, but long long on MacOS and Windows + return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" + elif isinstance(val, complex): + return f"c10::complex{{ {self.generate_float_value(val.real)}, {self.generate_float_value(val.imag)} }}" + elif isinstance(val, str): + return f'"{val}"' + elif isinstance( + val, (ir.Buffer, ir.ReinterpretView, ir.StorageBox, ir.TensorBox) + ): + return val.codegen_reference() + elif isinstance(val, torch.device): + return self.codegen_device(val) + elif isinstance(val, torch.dtype): + return self.codegen_dtype(val) + elif isinstance(val, torch.layout): + return self.codegen_layout(val) + elif isinstance(val, torch.memory_format): + return self.codegen_memory_format(val) + elif isinstance(val, float): + return self.generate_float_value(val) + elif isinstance(val, (list, tuple)): + # FIXME: This happens because type_ is not always properly set to torch.ListType + return f"{{{', '.join(self.val_to_arg_str(x, None) for x in val)}}}" + elif isinstance(val, SymTypes): + return cexpr(val.node.expr) + elif isinstance(val, sympy.Expr): + return cexpr(val) + else: + return repr(val) + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.DeviceObjType, + torch.ListType, + torch.TupleType, + ), + ): + return "nullptr, 0" + return "nullptr" + + if isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);") + return var_name + + raise AssertionError("Can not map None to a known data type") + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + arg_str = self.val_to_arg_str(val, element_type) + # Handle optional iterables as a special case. Utilize the + # temporary_reference function to avoid saving them off and increasing + # memory usage. + if isinstance(element_type, (torch.ListType, torch.TupleType)): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + return f"&temporary_reference({main_value}), {aux}" + + # Handle optional tensors as a special case, as above. + if isinstance(element_type, torch.TensorType): + base_handle = self.val_to_arg_str(val, element_type) + return f"&temporary_reference({base_handle}.get())" + + var_name = f"var_{next(self.arg_var_id)}" + if isinstance(element_type, torch.DeviceObjType): + main_value, aux = arg_str.rsplit(", ", maxsplit=1) + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + + self.writeline( + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {arg_str};" + ) + return f"&{var_name}" + + if isinstance(type_, (torch.ListType, torch.TupleType)): + assert isinstance(val, (list, tuple)), ( + f"{val} does not match with arg type {type_}" + ) + element_type = type_.getElementType() + + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so return a + # nullptr. 
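+            # list arguments are passed to the C shim as a (data pointer, length)
+            # pair, so an empty list becomes a null pointer with length 0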
+ return "nullptr, 0" + + result = [self.val_to_arg_str(x, element_type) for x in val] + if isinstance(element_type, torch.TensorType): + result = [f"{t}.get()" for t in result] + + c_type = self.c_type_for_prim_type(val[0], element_type) + # see the comment in self._generate_temporary_array_pointer for an + # explanation of why this c_type gets modified + if isinstance(element_type, torch.OptionalType) and not c_type.startswith( + "const" + ): + c_type = f"const {c_type}" + + # need to pass the array length, because we can't use the std::array member + # function + return ( + f"{self._generate_temporary_array_pointer(c_type, result)}, {len(val)}" + ) + + val_is_scalar = isinstance(val, (bool, complex, float, int, *SymTypes)) + if isinstance(type_, torch.TensorType) and val_is_scalar: + val_str = self.val_to_arg_str_for_prim_type(val, None) + return self.codegen_scalar_to_tensor(val_str) + + return self.val_to_arg_str_for_prim_type(val, type_) + + def create_tmp_raii_handle_var_if_needed( + self, handle: str, writer: Optional[Union[HasWriteLine, list[str]]] = None + ) -> str: + """If the input handle is an rvalue RAII tensor, creates an lvalue variable for + it in writer. Returns a variable name that can be used to access handle.""" + if not handle.startswith( + ( + "borrow_arrayref_tensor_as_tensor(", + "copy_arrayref_tensor_to_tensor(", + "wrap_with_raii_handle_if_needed(", + "RAIIAtenTensorHandle(", + ) + ): + return handle + + tmp_var_name = f"var_{next(self.arg_var_id)}" + call_str = f"auto {tmp_var_name} = {handle};" + + writer = writer if writer is not None else self + if isinstance(writer, list): + writer.append(call_str) + else: + writer.writeline(call_str) + + return tmp_var_name diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py new file mode 100644 index 0000000000000000000000000000000000000000..17fd1724e7cc27044b28736219ef8f5daad07ab2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -0,0 +1,878 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._ops + +from .. import config, ir +from ..utils import sympy_product +from ..virtualized import V +from .cpp_utils import DTYPE_TO_CPP +from .cpp_wrapper_cpu import CppWrapperCpu +from .wrapper import ( + BufferLike, + EnterSubgraphLine, + ExitSubgraphLine, + MemoryPlanningLine, + MemoryPlanningState, + PythonWrapperCodegen, +) + + +BufferName = str + +# Default thread stack sizes vary by platform: +# - Linux: 8 MB +# - macOS: 512 KB +# - Windows: 1 MB +# Just pick something comfortably smaller than the smallest for now. +MAX_STACK_ALLOCATION_SIZE = 1024 * 100 + + +class CppWrapperCpuArrayRef(CppWrapperCpu): + """ + Generates cpp wrapper for running on CPU and calls cpp kernels + + This class is forked from CppWrapperCpu, with a difference that tensors may be + represented as ArrayRef, see torch/csrc/inductor/aoti_runtime/arrayref_tensor.h + """ + + def __init__(self): + super().__init__() + assert self.device == "cpu", "ArrayRefTensor only supported on CPU!" 
+ self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation + self.stack_allocated_buffers: dict[BufferName, BufferLike] = {} + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpuArrayRef() + + @staticmethod + def get_input_cpp_type(input): + assert config.aot_inductor.use_minimal_arrayref_interface + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + @staticmethod + def get_device_include_path(device: str) -> str: + assert device == "cpu", "ArrayRef only supported on CPU!" + if V.graph.aot_mode: + return "#include " + return "#include " + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def generate_extern_kernel_alloc(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_alloc(*args, **kwargs) + + def generate_extern_kernel_out(self, *args, **kwargs): + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_extern_kernel_out(*args, **kwargs) + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + # Disable stack allocation for extern kernels. + self.allow_stack_allocation = False + super().generate_fallback_kernel(node) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the GPU backend uses Triton for codegen. + Otherwise it uses the CUDA language for codegen. + Only valid when cuda == True. + """ + assert not triton, ( + "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU" + ) + assert arg_types is not None and len(call_args) == len(arg_types), ( + "Mismatch call_args and arg_types in generate_kernel_call" + ) + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});") + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. 
+ debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if ( + config.aot_inductor.use_minimal_arrayref_interface + and not V.graph.is_const_graph + ): + input_cpp_types = ", ".join( + f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + + assert V.graph.const_wrapper_code is not None + self.prefix.splice(V.graph.const_wrapper_code) + + assert V.graph.const_kernel_code is not None + self.kernel_declarations.splice(V.graph.const_kernel_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + + self.generate_input_output_runtime_checks() + run_impl_proto += """ + __check_inputs_outputs(input_handles, output_handles); + """ + + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! 
+ convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.aot_inductor.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.aot_inductor.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert dtype is not None, ( + "Fails to get the dtype of the sympy.Expr" + ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs() + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.aot_inductor.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? 
+ self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def generate_return(self, output_refs: list[str]): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph + and config.aot_inductor.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + + output2idx: dict[str, int] = {} + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if is_constant_buffer: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. 
+ self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + if output not in output2idx: + output2idx[output] = idx + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + # TODO: integrate memory planning & stack allocation? + self.allow_stack_allocation = False + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + self.allow_stack_allocation = ( + self.allow_stack_allocation is not False + and config.aot_inductor.allow_stack_allocation + and total_allocated_buffer_size <= MAX_STACK_ALLOCATION_SIZE + ) + + def can_stack_allocate_buffer(self, buffer): + return ( + self.allow_stack_allocation + and buffer.get_device().type == "cpu" + and self.can_prove_buffer_has_static_shape(buffer) + and ir.is_contiguous_strides_for_shape( + buffer.get_stride(), buffer.get_size() + ) + ) + + def make_buffer_free(self, buffer): + return ( + "" + if isinstance(buffer.get_output_spec(), ir.MultiOutputLayout) + or (V.graph.aot_mode and buffer.get_name() in self.stack_allocated_buffers) + or ( + config.aot_inductor.use_minimal_arrayref_interface + and V.graph.aot_mode + and buffer.get_name() in V.graph.graph_inputs + ) + else f"{buffer.get_name()}.reset();" + ) + + def make_buffer_allocation(self, buffer): + return self.make_allocation( + buffer.get_name(), + buffer.get_device(), + buffer.get_dtype(), + buffer.get_size(), + 
buffer.get_stride(), + buffer if self.can_stack_allocate_buffer(buffer) else None, + ) + + def make_allocation( + self, name, device, dtype, shape, stride, buffer_if_can_stack_allocate=None + ): + orig_stride = stride + device_str = self.codegen_device(device) + dtype_code = self.codegen_dtype(dtype) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(orig_stride) + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call.writeline, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. + self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") + args = [ + f"{name}_storage", + size_array_var, + stride_array_var, + device_type, + device_idx, + ] + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" + ) + + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + if old_name in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + if reinterpret_view in self.stack_allocated_buffers: + self.stack_allocated_buffers[new_name] = new + # The only way to get into this case is via an exact buffer reuse, since all + # other options result in a new tensor handle. + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} // reuse" + + def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + # Borrowing arguments to shim functions is only safe because we know + # that the arguments can't be stack-allocated. Otherwise, to be sure + # we can't return a dangling pointer, we need to either 1) be + # certain that the shim function cannot return an alias of a + # borrowed argument, or 2) be certain that the returned Tensor from + # the shim function cannot escape. + assert self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(), ( + "borrowing arguments to shim functions is unsafe with " + "stack allocation on! 
(see comment above this assertion)" + ) + + def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self): + return not self.allow_stack_allocation and not self.stack_allocated_buffers + + def generate_c_shim_extern_kernel_call( + self, kernel: str, args: list[str], device: str, **_ + ) -> None: + # In the abi_compatible mode, we call fallback aten ops through a C shim layer + # Setting self.allow_stack_allocation to False because the exchange between + # ArrayRefTensor and at::Tensor is still fragile. + self.allow_stack_allocation = False + + wrapped_args = [] + for arg in args: + # We only really *need* borrow_arrayref_tensor_as_tensor for + # ArrayRefTensors. The code flowing into here uses `0` for nullptr, which + # borrow_arrayref_tensor_as_tensor would blindly coerce to int, so just + # avoid wrapping integers. Name matching is to find tensor is hacky, but + # fixing all the ArrayRefTensor issues is not a priority for now. + if isinstance(arg, str) and arg.startswith( + ("buf", "arg", "wrap_with_raii_handle_if_needed") + ): + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + arg = f"borrow_arrayref_tensor_as_tensor({arg})" + wrapped_args.append(arg) + + super().generate_c_shim_extern_kernel_call( + kernel, wrapped_args, device, debug_args=args + ) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name, self.device) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + inputs_wrapped = [ + (f"borrow_arrayref_tensor_as_tensor({x})" if isinstance(x, str) else str(x)) + for x in inputs + ] + line = f"{cpp_kernel_name}(borrow_arrayref_tensor_as_tensor({output}), {','.join(inputs_wrapped)}" + + if python_kernel_name.startswith("aten.scatter_reduce"): + line += f", {','.join(kwargs)}" + else: + if src_is_tensor: + if reduce: + line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}" + else: + assert reduce is None, ( + "Expect reduce to be None for aten.scatter_ with scalar src" + ) + line += ");" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + + self._assert_safe_to_use_borrow_arrayref_tensor_as_tensor() + # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the corresponding + # tensor prematurely deallocated, thus the temporary array trick here. + indices_str = self._generate_temporary_array_pointer( + "AtenTensorHandle", + [f"borrow_arrayref_tensor_as_tensor({i})" for i in indices], + ) + args = [ + f"borrow_arrayref_tensor_as_tensor({x})", + indices_str, + str(len(indices)), + f"borrow_arrayref_tensor_as_tensor({values})", + accumulate, + ] + args.insert( + 0, f"borrow_arrayref_tensor_as_tensor({x})" + ) # set x as the output tensor, this fallback mutates x. 
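The borrow-or-copy rule applied to ArrayRefTensor arguments in the fallback paths above can be summarized as a minimal sketch, assuming the hypothetical helper name wrap_fallback_args (the real logic lives in generate_c_shim_extern_kernel_call above):

def wrap_fallback_args(args, allow_stack_allocation, stack_allocated_buffers):
    """Minimal sketch: tensor-like argument names are borrowed as ATen handles,
    and borrowing is only legal once stack allocation is off and nothing was
    stack-allocated; otherwise a borrowed handle could dangle."""
    borrow_is_safe = not allow_stack_allocation and not stack_allocated_buffers
    wrapped = []
    for arg in args:
        # Name matching mirrors the heuristic above: "buf*", "arg*", or an
        # RAII-wrapped temporary; "0" (nullptr) and other non-strings pass through.
        if isinstance(arg, str) and arg.startswith(
            ("buf", "arg", "wrap_with_raii_handle_if_needed")
        ):
            assert borrow_is_safe, "borrowing is unsafe with stack allocation on"
            arg = f"borrow_arrayref_tensor_as_tensor({arg})"
        wrapped.append(arg)
    return wrapped

# wrap_fallback_args(["buf0", "0"], False, {}) ->
#   ["borrow_arrayref_tensor_as_tensor(buf0)", "0"]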
+ self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + # No stack allocation when there is a fallback op + self.allow_stack_allocation = False + super().generate_fallback_kernel_with_runtime_lookup( + buf_name, python_kernel_name, get_args, op_overload, raw_args, outputs + ) + + def codegen_device_copy(self, src, dst, non_blocking: bool): + # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, + # while stack-allocation results in ArrayRefTensor + # so disable stack allocation here + self.allow_stack_allocation = False + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" + ) + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + """Returns a newly-created, temporary RAII tensor handle containing the + reinterpreted tensor data. Callers of this function are responsible for saving + the handle if persistent access is needed.""" + dim = str(len(size)) + + def create_reinterpret_call() -> str: + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + self.codegen_shape_tuple(size), + writeline, + known_statically=self.is_statically_known_list_of_ints(size), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + self.codegen_shape_tuple(stride), + writeline, + known_statically=self.is_statically_known_list_of_ints(stride), + graph=self.get_codegened_graph(), + ), + offset, + ] + return f"wrap_with_raii_handle_if_needed(reinterpret_tensor_wrapper({', '.join(args)}))" + + def create_new_tensor_handle() -> tuple[str, list[str]]: + # Calling reset() on ArrayRefTensor does nothing, since the array is + # const-allocated on the stack. Thus, it's safe to return a reference to + # the original array. + if (name := data.get_name()) in self.stack_allocated_buffers: + return name, [] + + tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs + + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + and (dtype is None or dtype == data.dtype) + ): + final_tensor_str, call_strs = create_new_tensor_handle() + for line in call_strs: + writeline(line) + return final_tensor_str + + return super().codegen_reinterpret_view( + data, size, stride, offset, writeline, dtype + ) + + def val_to_arg_str(self, val, type_=None) -> str: + if ( + val is not None + and isinstance(type_, torch.OptionalType) + and isinstance(type_.getElementType(), torch.TensorType) + ): + # Handle optional tensors as a special case, as in the parent class. 
+ base_handle = self.val_to_arg_str(val, torch.TensorType) + if config.aot_inductor.use_minimal_arrayref_interface: + if self.is_safe_to_use_borrow_arrayref_tensor_as_tensor(): + base_handle = f"borrow_arrayref_tensor_as_tensor({base_handle})" + else: + base_handle = f"copy_arrayref_tensor_to_tensor({base_handle})" + return f"&temporary_reference({base_handle}.get())" + + return super().val_to_arg_str(val, type_) + + def codegen_tensor_item( + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + ): + dtype_str = str(dtype).split(".")[-1] + writer = indented_buffer or self + + if dtype == torch.float16 or dtype == torch.bfloat16: + scalar_tmp = f"{scalar}_tmp" + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar_tmp};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" + ) + writer.writeline(f"float {scalar} = float({scalar_tmp});") + else: + writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") + + # We know that item_ doesn't alias the input, so borrowing should be safe. + tensor = f"borrow_arrayref_tensor_as_tensor({tensor})" + + writer.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..410b7ea4f5c0f6a78a778f08655633f46c65e013 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -0,0 +1,717 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import re +from itertools import count, zip_longest +from typing import Any, Optional, Union +from typing_extensions import Self + +import sympy + +import torch +from torch import dtype as torch_dtype +from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name +from torch._inductor.runtime.runtime_utils import dynamo_timed + +from .. import config +from ..codecache import CudaKernelParamCache +from ..ir import ( + GraphPartitionSignature, + TensorBox, + TMADescriptorExperimental, + TMADescriptorStable, +) +from ..utils import cache_on_self, get_gpu_type, GPU_ALIGN_BYTES, IndentedBuffer +from ..virtualized import V +from .aoti_hipify_utils import maybe_hipify_code_wrapper +from .common import get_device_op_overrides, TritonScratchWorkspace +from .cpp_utils import cexpr +from .cpp_wrapper_cpu import CppWrapperCpu +from .multi_kernel import MultiKernelCall +from .triton_utils import should_unwrap_unspec_arg +from .wrapper import PythonWrapperCodegen, SymbolicCallArg + + +_cpp_string_literal_escapes = { + "\\": "\\\\", + '"': '\\"', + "\n": "\\n", + "\t": "\\t", + "\r": "\\r", +} +_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]') + + +def cpp_string_literal(s: str) -> str: + escaped = _cpp_string_literal_pattern.sub( + lambda match: _cpp_string_literal_escapes[match.group(0)], s + ) + return f'"{escaped}"' + + +@dataclasses.dataclass +class DeferredTritonCallWrapper: + """ + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred generating the final wrapper around + the triton kernel until right before the prefix is written. 
+ """ + + wrapper_name: str + kernel_name: str + kernel_name_to_body: dict[str, str] + arg_types: list[Any] + + def generate(self, wrapper: CppWrapperGpu): + """ + Generate the GPU kernel definition, as well as load and launch code. + """ + prefix = wrapper.prefix + if self.kernel_name.startswith("multi_kernel_"): + # MultiKernel will select one kernel after running the autotune block + self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + params = CudaKernelParamCache.get(self.kernel_name) + assert params, f"CudaKernelParamCache not populated for {self.kernel_name}" + def_args = params["def_args"] + arg_types = self.arg_types + inductor_meta = params["inductor_meta"] + + if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types): + # extra_launcher_args should already be in def_args + assert len(def_args) == len(arg_types) - len( + inductor_meta["extra_launcher_args"] + ) + arg_types = arg_types + [SymbolicCallArg] * len( + inductor_meta["extra_launcher_args"] + ) + + if not V.graph.aot_mode: + prefix.writeline( + maybe_hipify_code_wrapper( + f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;" + ) + ) + kernel_var_name = self.kernel_name + else: + kernel_var_name = f"kernels_.{self.kernel_name}" + + # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types + template_types = [ + f"typename {name}_type_" + for name, arg_type in zip(def_args, arg_types) + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)) + ] + if V.graph.aot_mode: + template_types.append("typename kernels_type_") + if template_types: + prefix.writeline(f"template <{', '.join(template_types)}>") + prefix.writeline(f"static inline void {self.wrapper_name}(") + with prefix.indent(): + assert len(def_args) == len(arg_types), (def_args, arg_types) + for name, arg_type in zip(def_args, arg_types): + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)): + prefix.writeline(f"const {name}_type_& {name},") + elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)): + prefix.writeline(f"int64_t {name},") + elif arg_type is float: + prefix.writeline(f"float {name},") + elif arg_type is bool: + prefix.writeline(f"bool {name},") + else: + raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline("int32_t device_idx_,") + prefix.writeline( + maybe_hipify_code_wrapper( + f"{wrapper.device_codegen.cpp_stream_type()} stream_," + ) + ) + if V.graph.aot_mode: + prefix.writeline("kernels_type_& kernels_,") + prefix.writeline( + "const std::optional& cubin_dir_ = std::nullopt" + ) + prefix.writeline("){") + with prefix.indent(): + if V.graph.aot_mode: + # Emit the original Triton kernel for debugging purposes + prefix.writeline("/*") + prefix.splice(self.kernel_name_to_body[self.kernel_name]) + prefix.writeline("*/") + self.generate_grid(prefix, inductor_meta, params) + self.generate_load_kernel(prefix, kernel_var_name, params) + self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) + prefix.writeline("}") + + if not config.aot_inductor.embed_kernel_binary: + # Ensure the cubin file is included in the package + V.graph.wrapper_code.additional_files.append( + params[get_cpp_wrapper_cubin_path_name()] + ) + + def generate_grid( + self, + prefix: IndentedBuffer, + inductor_meta: dict[str, Any], + params: dict[str, Any], + ): + from ..runtime.triton_heuristics import GridExpr + + grid = GridExpr.from_meta(inductor_meta, params["config"], mode="cpp") + for line in grid.prefix: + prefix.writeline(line) + prefix.splice( 
+ f"""\ + uint32_t grid_0 = {grid.x_grid}; + uint32_t grid_1 = {grid.y_grid}; + uint32_t grid_2 = {grid.z_grid}; + """ + ) + prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;") + + def generate_load_kernel(self, prefix, kernel_var_name, params): + prefix.writeline(f"if ({kernel_var_name} == nullptr) {{") + with prefix.indent(): + embed_kernel_args = [f"__{params['inductor_meta']['kernel_name']}_start"] + if torch.xpu.is_available(): + # XPU needs the end address of the kernel to calculate the size of the kernel binary. + embed_kernel_args.append( + f"__{params['inductor_meta']['kernel_name']}_end" + ) + + load_kernel_args = ( + [ + *embed_kernel_args, + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + ] + if V.graph.aot_mode and config.aot_inductor.embed_kernel_binary + else [ + cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]), + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + "cubin_dir_", + ] + ) + prefix.writeline( + f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); " + ) + prefix.writeline("}") + + def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): + triton_meta = params["triton_meta"] + assert len(self.arg_types) == len(params["def_args"]), ( + self.arg_types, + params["def_args"], + ) + arg_type_loookup = dict(zip(params["def_args"], self.arg_types)) + # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants + call_args = [ + name for name in params["call_args"] if name not in triton_meta["constants"] + ] + arg_types = [arg_type_loookup[name] for name in call_args] + arg_signatures = [triton_meta["signature"][name] for name in call_args] + call_args_str = wrapper.generate_args_decl( + prefix, + call_args, + arg_types, + arg_signatures, + workspace_size=params.get("global_scratch") or 0, + ) + prefix.writeline(f"void* kernel_args_[] = {{{call_args_str}}};") + launch_kernel_args = [ + kernel_var_name, + "grid_0", + "grid_1", + "grid_2", + str(params["num_warps"]), + str(params["shared_mem"]), + "kernel_args_", + "stream_", + ] + prefix.writeline(f"launchKernel({', '.join(launch_kernel_args)});") + + +class CppWrapperGpu(CppWrapperCpu): + """ + Generates cpp wrapper for running on GPU and calls CUDA kernels + """ + + def __init__(self) -> None: + self.device = get_gpu_type() + self.device_codegen = get_device_op_overrides(self.device) + super().__init__() + self.grid_id = count() + self._kernel_name_to_body: dict[str, str] = {} + self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} + self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperGpu() + + def write_header(self): + if V.graph.is_const_graph: + # We do not write header for constant graph, it will be written by main module. 
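The load-and-launch code above can only be generated once autotuning has populated CudaKernelParamCache, which is why kernel call sites are emitted eagerly while their wrapper definitions are deferred. A minimal sketch of that pattern, with hypothetical names:

class _DeferredWrapper:
    """Sketch of the DeferredTritonCallWrapper idea: remember which kernel a call
    site refers to, and only render the wrapper once the params are cached."""

    def __init__(self, wrapper_name: str, kernel_name: str) -> None:
        self.wrapper_name = wrapper_name
        self.kernel_name = kernel_name

    def generate(self, param_cache: dict) -> str:
        params = param_cache[self.kernel_name]  # populated only after autotuning
        return f"// definition of {self.wrapper_name}, loading {params['mangled_name']}"


def emit_call(kernel_name, call_args, lines, deferred):
    # Called while the wrapper body is being written: only the call is emitted.
    wrapper_name = f"call_{kernel_name}"
    deferred.setdefault(wrapper_name, _DeferredWrapper(wrapper_name, kernel_name))
    lines.append(f"{wrapper_name}({', '.join(call_args)});")


def finalize(deferred, param_cache, prefix_lines):
    # Runs at finalize_prefix time, after the autotune block has executed.
    for wrapper in deferred.values():
        prefix_lines.append(wrapper.generate(param_cache))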
+ return + + super().write_header() + self.header.splice( + maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) + ) + + @cache_on_self + def write_tma_descriptor_helpers_once(self): + self.header.splice(self.device_codegen.tma_descriptor_helpers()) + + def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + name = f"stream{device_idx}" + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_stream_type()} {name};" + ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK({self.device_codegen.aoti_get_stream()}({device_idx}, (void**)&{name}));" + ) + return name + + def get_autotuning_input_name(self, idx): + return f"{self.autotune_input_prefix}_{idx}" + + def codegen_inputs(self): + # See Note: [Input Alignment handling in Inductor] + # + # JIT Inductor does not guard on input alignment. It relies on copy_misaligned_inputs to + # copy misaligned inputs to aligned buffers. For AOTInductor, we need to do the same in cpp. + + if config.is_fbcode(): + # TODO: This is added because FC. Remove this once the newly added shim symbols, + # e.g. aoti_torch_clone_preserve_strides, have landed + return super().codegen_inputs() + + if V.graph.aot_mode and V.graph.inputs_to_check: + for idx in V.graph.inputs_to_check: + input_name = V.graph.graph_input_names[idx] + assert input_name in V.graph.graph_inputs, ( + f"{input_name} not found in graph inputs" + ) + value = V.graph.graph_inputs[input_name] + assert isinstance(value, TensorBox), ( + f"{input_name} is expected to be tensor but found as {type(value)}" + ) + warn_msg = ( + f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, " + "but it is not aligned at run time. Copying to an aligned tensor " + "to guarantee correctness, but expect a performance hit." + ) + self.prefix.splice( + f""" + if ((long({input_name}.data_ptr()) & ({GPU_ALIGN_BYTES} -1)) != 0) {{ + AOTI_TORCH_WARN("{warn_msg}"); + AtenTensorHandle {input_name}_aligned; + aoti_torch_clone_preserve_strides({input_name}, &{input_name}_aligned); + {input_name} = std::move(RAIIAtenTensorHandle({input_name}_aligned)); + }} + """ + ) + + super().codegen_inputs() + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if gpu: + self._kernel_name_to_body[kernel_name] = kernel_body + if config.triton.autotune_at_compile_time: + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + else: + return CppWrapperCpu._define_kernel_helper( + self, kernel_name, kernel_body, metadata, gpu, cpp_definition + ) + + def generate(self, is_inference): + with dynamo_timed("CppWrapperGpu.generate", log_pt2_compile_event=True): + return super().generate(is_inference) + + def finalize_prefix(self): + """Define the triton kernels now that autotuning is finished""" + old_prefix = self.prefix # new content should go at start of prefix + + # Generating triton kernel callers can modify the prefix (cached dtypes), + # so do this before running finalize_prefix(), but put the generated code + # after the finalize_prefix() code. 
+ self.prefix = IndentedBuffer() + for kernel in self._triton_call_wrappers.values(): + self.prefix.writeline("\n") + kernel.generate(self) + triton_prefix = self.prefix + + self.prefix = IndentedBuffer() + super().finalize_prefix() + + self.prefix.splice(triton_prefix) + + self.prefix.writeline("\n") + self.prefix.splice(old_prefix) + + def generate_tma_descriptor(self, desc): + self.write_tma_descriptor_helpers_once() + + if isinstance(desc, TMADescriptorExperimental): + self._generate_experimental_tma_descriptor(desc) + else: + assert isinstance(desc, TMADescriptorStable) + self._generate_stable_tma_descriptor(desc) + + def _generate_experimental_tma_descriptor(self, desc): + # generate data pointer for the source tensor + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + self.writeline(f"alignas(64) CUtensorMap {desc_name};") + + # `source` is in the form of `&var_x`, where `var_x` is the data pointer + # (CUdeviceptr); we dereference `source` and cast to `void*` to pass to + # the data pointer of the source tensor to the helper function + # `init{1,2}DTMADescriptor` + ptr = f"reinterpret_cast(*({source}))" + dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims) + block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims) + element_size = self.val_to_arg_str(desc.element_size) + fn = f"init{desc.rank}DTMADescriptor" + args = f"&{desc_name}, {ptr}, {dims}, {block_dims}, {element_size}" + self.writeline(f"{fn}({args});") + + def _generate_stable_tma_descriptor(self, desc): + source = self.generate_args_decl( + code=self, + call_args=[self.val_to_arg_str(desc.tensor)], + arg_types=[desc.tensor.get_dtype()], + arg_signatures=[None], + # these args are passed to initNDTMADescriptor, which is NOT a triton kernel + is_triton_kernel=False, + ) + + desc_name = desc.name + # Pack the relevant information into a StableTMADescriptor struct. + # See [Note: AOTI TMA Stable handling] for more details. + self.writeline(f"alignas(64) StableTMADescriptor {desc_name};") + + def fill_array(name, values): + for i, val in enumerate(values): + self.writeline(f"{name}[{i}] = {val};") + + ptr = f"reinterpret_cast(*({source}))" + rank = len(desc.tensor.get_size()) + + fill_array(f"{desc_name}.block_shape", desc.block_shape) + fill_array(f"{desc_name}.global_shape", desc.tensor.get_size()) + fill_array(f"{desc_name}.strides", desc.tensor.get_stride()) + + element_size = self.val_to_arg_str(desc.tensor.get_dtype().itemsize) + fn = "initTMADescriptor" + args = ", ".join( + str(x) + for x in [ + f"&{desc_name}.m", + ptr, + element_size, + rank, + f"{desc_name}.block_shape", + f"{desc_name}.global_shape", + f"{desc_name}.strides", + ] + ) + self.writeline(f"{fn}({args});") + + def generate_args_decl( + self, + code: Union[IndentedBuffer, Self], + call_args, + arg_types, + arg_signatures, + is_triton_kernel=True, + workspace_size=0, + ): + """ + Generates any declarations of args to pass into a kernel call, and then returns the arg names. + + In more detail: + * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;` + * returns: a string with the list of args, e.g. 
"var_0, var_1" + + call_args: list of call arguments + arg_types: list of argument types + arg_signatures: list with signatures of all the args + is_triton_kernel: whether these are passed into a triton kernel or not. In particular, + calls to triton kernels will have an additional global scratch space + arg injected at the front of the arg list. + """ + new_args: list[str] = [] + + # Add more cases for other types as needed + signature2dtype = { + "i32": "int32_t", + "i64": "int64_t", + "fp32": "float", + } + + def signature_is_tma_desc(sig): + if not sig: + return False + if sig == "nvTmaDesc": + return True + if sig.startswith("tensordesc<"): + return True + return False + + def process_tma_stable_arg(arg, arg_type, arg_signature, var_name): + # [Note: AOTI TMA Stable handling] + # For most args, a single arg passed to the python triton interface + # maps to a single arg in the cubin interface. However, for host-side + # TMA descriptors, a single python arg turns into 1 + 2 * N args in the + # cubin interface (where N is the rank). + # + # To do this: at TMA codegen time (for aoti), we generate a struct + # (StableTMADescriptor) containing the necessary information; and then + # when we call the function (i.e. here), we unpack the struct members. + code.writeline(f"auto {var_name} = {cexpr(arg)};") + + result = [] + result.append(f"&{var_name}.m") + + # from https://github.com/triton-lang/triton/blob/16961b79bdac1b774b42d44e52fd55a266ec2866/third_party/nvidia/backend/driver.py#L111 # noqa: B950 + match = re.match("tensordesc<([^[>]*)\\[([^]]*)\\]", arg_signature) + assert match is not None + shape = match.group(2) + ndim = shape.count(",") + 1 + + for i in range(ndim): + result.append(f"&{var_name}.block_shape[{i}]") + + for i in range(ndim): + result.append(f"&{var_name}.strides[{i}]") + + return result + + def process_args(arg, arg_type, arg_signature=None): + var_name = f"var_{next(self.arg_var_id)}" + # ignore tma descriptors, as host-side TMA descriptors need + # to be passed to the compiled Triton kernel by value + if isinstance(arg_type, UnwrapUnspecArg) and not signature_is_tma_desc( + arg_signature + ): + self.codegen_tensor_item( + arg_type.dtype, + arg, + var_name, + indented_buffer=code, + ) + new_args.append(f"&{var_name}") + elif isinstance(arg_type, torch_dtype) and not signature_is_tma_desc( + arg_signature + ): + device_ptr_type = self.device_codegen.cpp_device_ptr() + code.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" + ) + ) + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Integer, int): + code.writeline(f"int {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + elif arg_type in (sympy.Float, float): + code.writeline(f"float {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. 
+ elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() + ): + code.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" + ) + new_args.append(f"&{var_name}") + elif arg_signature and arg_signature.startswith("tensordesc<"): + new_args.extend( + process_tma_stable_arg(arg, arg_type, arg_signature, var_name) + ) + else: + code.writeline(f"auto {var_name} = {cexpr(arg)};") + new_args.append(f"&{var_name}") + + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + + if ( + is_triton_kernel + and ( + global_scratch := self.device_codegen.cpp_global_scratch( + next(self.arg_var_id), + workspace=TritonScratchWorkspace( + size=workspace_size, + generate_dtype_str=(lambda: self.codegen_dtype(torch.uint8)), + ), + ) + ) + is not None + ): + global_scratch_def, global_scratch_var = global_scratch + code.writelines([maybe_hipify_code_wrapper(x) for x in global_scratch_def]) + new_args.append(f"&{global_scratch_var}") + + return ", ".join(new_args) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + """ + Override the default value of argument 'gpu' to True here. + generate_kernel_call can still be called with gpu=False because of + a mix of cpu kernels and gpu kernels. + """ + device = device or V.graph.get_current_device_or_throw() + if device.type == "cpu": + # Even in CppWrapperGpu, we may see cpp kernels + return CppWrapperCpu._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + ) + + if ( + triton + and config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Call PythonWrapperCodegen to create the autotune code block + PythonWrapperCodegen._generate_kernel_call_helper( + self, + kernel_name, + call_args, + device=device, + triton=triton, + arg_types=arg_types, + raw_keys=raw_keys, + raw_args=raw_args, + triton_meta=triton_meta, + original_fxnode_name=original_fxnode_name, + ) + + stream = ( + "stream" + if V.graph.aot_mode + else self.write_get_raw_stream(device.index, graph_name) + ) + + if triton: + call_args, arg_types = self.prepare_triton_wrapper_args( + call_args, arg_types + ) + wrapper_name = f"call_{kernel_name}" + if wrapper_name not in self._triton_call_wrappers: + self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper( + wrapper_name, + kernel_name, + self._kernel_name_to_body, + arg_types, + ) + device_idx = "this->device_idx_" if V.graph.aot_mode else str(device.index) + call_args.append(device_idx) + call_args.append(stream) + if V.graph.aot_mode: + call_args.append("kernels") + call_args.append("this->cubin_dir_") + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[: len(arg_types)], kernel_name, arg_types, None + ) + with debug_printer_manager: + self.writeline(f"{wrapper_name}({', '.join(call_args)});") + else: + casted = [] + for arg_type, arg in zip(arg_types, call_args): + new_arg = arg + if arg_type.endswith("*") and arg != "nullptr": + new_arg = f"{arg}.data_ptr()" + casted.append(f"({arg_type}){cexpr(new_arg)}") + call_args_str = ", ".join(casted) + 
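As a worked example of the casting rule explained in process_args above (hypothetical variable and symbol names): a symbolic argument whose Triton signature is "i64" is declared with an explicit 64-bit type rather than `auto`, so the kernel receives exactly the width it was compiled for.

SIGNATURE2DTYPE = {"i32": "int32_t", "i64": "int64_t", "fp32": "float"}

def declare_symbolic_arg(var_name: str, expr: str, arg_signature: str) -> str:
    # Sketch: fall back to `auto` only when the signature names no scalar type.
    cpp_type = SIGNATURE2DTYPE.get(arg_signature, "auto")
    return f"{cpp_type} {var_name} = {expr};"

# declare_symbolic_arg("var_3", "s0*s1", "i64") -> "int64_t var_3 = s0*s1;"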
self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") + + @staticmethod + def prepare_triton_wrapper_args( + call_args: list[Any], arg_types: list[Any] + ) -> tuple[list[Any], list[Any]]: + assert len(call_args) == len(arg_types), (call_args, arg_types) + new_args = [] + new_args_types = [] + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg, str): + if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + arg_type = UnwrapUnspecArg(dtype=arg_type) + new_args.append(arg) + elif isinstance(arg, bool): + new_args.append(str(arg).lower()) + elif isinstance(arg, (int, float, SymbolicCallArg)): + new_args.append(str(arg)) + else: + new_args.append(cexpr(V.graph.sizevars.simplify(arg))) + new_args_types.append(arg_type) + return new_args, new_args_types + + def make_zero_buffer(self, name): + return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" + + +@dataclasses.dataclass +class UnwrapUnspecArg: + """Marker that we need to call .item() on the tensor""" + + dtype: torch_dtype diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py new file mode 100644 index 0000000000000000000000000000000000000000..c28d1fbc3bc52f32310504971920ce5717aa6cab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -0,0 +1,99 @@ +from typing import Any, Optional + +import sympy + +import torch + +from ..ir import GraphPartitionSignature +from ..virtualized import V +from .cpp_wrapper_gpu import CppWrapperGpu +from .wrapper import PythonWrapperCodegen + + +class CppWrapperMps(CppWrapperGpu): + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ) -> "CppWrapperMps": + return CppWrapperMps() + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args: list[str], + arg_types: Optional[list[type]] = None, + **kwargs: dict[str, Any], + ) -> None: + """ + Generates MPS kernel call code. It should look something like: + ``` + auto mps_lib_0_func = mps_lib_0.getKernelFunction("generated_kernel"); + auto mps_lib_0_func_handle = AOTIMetalKernelFunctionHandle(mps_lib_0_func.get()); + mps_lib_0_func->runCommandBlock([&] { + mps_lib_0_func->startEncoding(); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 0, buf0); + aoti_torch_mps_set_arg(mps_lib_0_func_handle, 1, arg0_1); + ... 
+ mps_lib_0_func->dispatch(9); + }); + ``` + """ + assert arg_types is not None + + new_args = [] + for idx, (arg, arg_type) in enumerate(zip(call_args[:-2], arg_types[:-2])): + if isinstance(arg_type, torch.dtype): + new_args.append( + f"aoti_torch_mps_set_arg_tensor({kernel_name}_handle, {idx}, {arg});\n" + ) + elif arg_type in (int, sympy.core.symbol.Symbol): + new_args.append( + f"aoti_torch_mps_set_arg_int({kernel_name}_handle, {idx}, {arg});\n" + ) + else: + raise NotImplementedError( + f"Unsupported arg type {arg_type} for arg {arg} for kernel {kernel_name}" + ) + + threads, group_size = call_args[-2], call_args[-1] + if threads is None: + raise NotImplementedError("No threads or group_size provided") + elif group_size is None: + new_args.append(f"{kernel_name}->dispatch({threads});\n") + else: + new_args.append(f"{kernel_name}->dispatch({threads}, {group_size});\n") + + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args[:-2], + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) + + def wrap_kernel_call(self, name: str, call_args: list[str]) -> str: + lib_name = name[: -len("_func")] + calling_args = " ".join(call_args) + return f""" + auto {name} = {lib_name}.getKernelFunction("generated_kernel"); + auto {name}_handle = AOTIMetalKernelFunctionHandle({name}.get()); + {name}->runCommandBlock([&] {{ + {name}->startEncoding(); + {calling_args} + }}); + """ + + @staticmethod + def get_device_include_path(device: str) -> str: + assert V.graph.aot_mode + return ( + "#include \n" + "#include " + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..13e7e6973a31d1ccaa143fe2b76819543bbecf4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cpu_device_op_overrides.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from textwrap import dedent + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class CpuDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + + +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb040e028077a45f9ac0dc9b49ac56f3532a030 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9379dc226ec3ac2b1ed15160f114c5f7316d7702 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_cpp_scheduling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0a550ddf450f1629118a817de0a3047f5f36dfc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_env.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeac6c55a50f5d8ade4cff9ab9b886a47de18ec5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_kernel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d29e67acfcbfcf56dfa757e5b562b96b2a71ced0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cuda_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24651cdf69d7d97b112dc38411decfc1bb45acfd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_presets.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_presets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54132383d8e7609d7428877cc2ed70456347a1c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_presets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5818bff32bbbde6b14d68756021d64a277b5c7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_python_evt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05140371b1f0dbe091f31f5c514da4699ff55076 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/cutlass_utils.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79a473eb405519bfc67528764dc9fd05cd1df902 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/device_op_overrides.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..690389758028c985625307af1ddcef7f80d9dec6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/gemm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6588689cba6adb8bc7b84f849ffa7a1350837051 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/__pycache__/serialization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..f6129398da0728ce103560657d77e6849083866d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_cpp_scheduling.py @@ -0,0 +1,293 @@ +# mypy: allow-untyped-defs +import hashlib +import logging +from collections.abc import Sequence +from typing import cast + +from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + MockCutlassHandler, +) +from torch._inductor.utils import Placeholder +from torch.utils._ordered_set import OrderedSet + +from ...._dynamo.utils import counters +from ... import config +from ...codecache import code_hash, get_path +from ...ir import Buffer, ComputedBuffer, CUDATemplateBuffer, Pointwise +from ...scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + SchedulerNode, + WhyNoFuse, +) +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import BackendFeature, IndentedBuffer + + +log = logging.getLogger(__name__) + + +class WhyNoFuseNames(WhyNoFuse): + def __init__(self, name1: str, name2: str) -> None: + self.name1 = name1 + self.name2 = name2 + + +class CUDACPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for CUDA C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and CUDA C++ specific template code generation. 
+ """ + + @classmethod + def get_backend_features(cls, device) -> OrderedSet[BackendFeature]: + return OrderedSet() + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_cuda_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, CUDATemplateBuffer + ) + + def is_cuda_cpp_fused_template(self, node: BaseSchedulerNode) -> bool: + return isinstance(node, FusedSchedulerNode) and self.is_cuda_cpp_template(node) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self.is_cuda_cpp_template(node1) and isinstance(node2, BaseSchedulerNode): + assert node1.node, "node1.node should not be None" + return self._can_fuse_epilogue_impl( + cast(CUDATemplateBuffer, node1.node), + [], + node2, # type: ignore[arg-type] + ) + elif self.is_cuda_cpp_fused_template(node1) and isinstance( + node2, BaseSchedulerNode + ): + assert node1.node, "node1.node should not be None" + assert node2.node, "node2.node should not be None" + fnode1 = cast(FusedSchedulerNode, node1) + return self._can_fuse_epilogue_impl( + fnode1.get_template_node(), # type: ignore[arg-type] + self._unwrap_epilogue_nodes(fnode1), + node2, # type: ignore[arg-type] + ) + + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + + # use the original src_code as the key + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + # no EVT kernel, use the original kernel name + kernel_name = f"cutlass_{kernel_hash}" + else: + kernel_name = f"cutlass_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a CUDA template, possibly with fused epilogues + """ + counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes) + assert self.is_cuda_cpp_template(template_node), ( + "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: CUDATemplateBuffer = cast(CUDATemplateBuffer, template_node.node) + epilogue_ir_nodes: list[Buffer] = [n.node for n in epilogue_nodes] # type: ignore[misc] + assert all(isinstance(n, ComputedBuffer) for n in epilogue_ir_nodes), ( + "Epilogue nodes must all be 
instances of ir.ComputedBuffer" + ) + kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_nodes) + + with kernel: + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + # typically there is a codegen pass which runs after mark_run + # for this kernel we've already generated the C++ code, but we still + # need to let the kernel know about loads/stores that occur in the fused + # kernel for memory planning to properly optimize allocations + ctb.emulate_store_fn() + for node in epilogue_ir_nodes: + with V.set_ops_handler(MockCutlassHandler(V.get_ops_handler())): + assert isinstance( + node, ComputedBuffer + ) # Not sure why we need to do this again + node.get_store_function()(CutlassEVTCodegen.get_index_vars(node)) + + with V.set_kernel_handler(kernel): + src_code = render() + node_schedule = [template_node, *epilogue_nodes] + kernel_name = self.define_kernel(src_code, node_schedule) + + # debug printing values of intermediate tensors + _, call_args, arg_signatures, _ = kernel.args.python_argdefs() + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, kernel_name, arg_signatures, kernel + ) + with debug_printer_manager: + kernel.call_kernel(kernel_name, ctb) + + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() + + @staticmethod + def _unwrap_epilogue_nodes( + fused_node: FusedSchedulerNode, + ) -> list[BaseSchedulerNode]: + nodes = fused_node.get_nodes() + template_node = fused_node.get_template_node() + assert all(n.node is not None for n in nodes), ( + "All epilogue nodes should have an IRNode" + ) + return cast( + list[BaseSchedulerNode], [n for n in nodes if n.node is not template_node] + ) + + def _can_fuse_epilogue_impl( + self, + cuda_template_buffer: CUDATemplateBuffer, + existing_epilogue_nodes: list[BaseSchedulerNode], + node_to_fuse: BaseSchedulerNode, + ) -> bool: + """ + Check if the given node can be fused with the epilogue. At the moment, Kernels + support fusion with Pointwise operations, wrapped in (named) ComputedBuffer nodes. + + Args: + cuda_template_buffer : A CUDATemplateBuffer object representing the CUDA template and it's result buffer + existing_epilogue_nodes : List[SchedulerNode]: The list of already fused epilogue nodes. + node_to_fuse: The SchedulerNode node to be checked if it can be fused with the epilogue. + Returns: + - bool: True if the given node can be fused with the epilogue, False otherwise. 
+ + """ + why = WhyNoFuseNames(cuda_template_buffer.get_name(), node_to_fuse.get_name()) + + scheduler_nodes_to_fuse = node_to_fuse.get_nodes() + + assert isinstance(cuda_template_buffer, CUDATemplateBuffer) + + # Checks on constituent nodes + for s_node in scheduler_nodes_to_fuse: + node = s_node.node + + if not isinstance(node, ComputedBuffer): + why(f"{node} is not a ComputedBuffer") + return False + elif not isinstance(node.data, Pointwise): + why(f"{node} is not a Pointwise op") + return False + elif not node.get_computed_buffer_name(): # type: ignore[attr-defined] + why(f"{node} does not have a computed buffer name") + return False + + name = node.get_computed_buffer_name() # type: ignore[attr-defined] + # dtype can differ, and strides can differ as long as they are broadcastable + if node.get_size() != cuda_template_buffer.get_size(): + why( + f"{name}'s size: {node.get_size()} differs from {cuda_template_buffer.get_name()}'s \ +size: {cuda_template_buffer.get_size()}" + ) + return False + + assert len( + existing_epilogue_nodes + ) or cuda_template_buffer.get_name() in OrderedSet( + [rd.name for rd in node_to_fuse.read_writes.reads] + ), "First epilogue node must read from cuda template buffer" + + if node_to_fuse.has_aliasing_or_mutation(): + why(f"{node_to_fuse.get_name()} has aliasing or mutation") + return False + elif node_to_fuse.is_reduction(): + why( + f"{node_to_fuse.get_name()} is a reduction which is not yet supported by EVT" + ) + return False + elif ( + not config.cuda.cutlass_epilogue_fusion_enabled + or not config.epilogue_fusion + ): + why("cutlass epilogue fusion is not enabled") + return False + elif not cuda_template_buffer.supports_epilogue_fusion: + why("epilogue fusion is only supported for TMA-enabled gemm ops") + return False + + try: + from torch._inductor.codegen.cuda.cutlass_python_evt import ( + CutlassEVTCodegen, + ) + + CutlassEVTCodegen.ir_to_evt_python_code( + cuda_template_buffer.get_name(), + existing_epilogue_nodes + list(node_to_fuse.get_nodes()), + OrderedSet(), + ) + + except NotImplementedError as e: + not_implemented_op = str(e) + if not_implemented_op.startswith("_op_"): + not_implemented_op = not_implemented_op[4:] + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}, \ +likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 + ) + return False + else: # Likely due to unsupported dtype. + why( + f"Cannot fuse epilogue node {node_to_fuse} into {cuda_template_buffer.name}. \ +Reason: {not_implemented_op}" # noqa: G004, B950 + ) + return False + + return True diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py new file mode 100644 index 0000000000000000000000000000000000000000..73f1742fd13cb241b79ea4d6c2575ef722c122bb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_env.py @@ -0,0 +1,45 @@ +import functools +import logging +import shutil +from typing import Optional + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... 
import config + + +log = logging.getLogger(__name__) + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_arch() -> Optional[str]: + try: + cuda_arch = config.cuda.arch + if cuda_arch is None: + # Get Compute Capability of the first Visible device + major, minor = torch.cuda.get_device_capability(0) + return str(major * 10 + minor) + return str(cuda_arch) + except Exception as e: + log.error("Error getting cuda arch: %s", e) + return None + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_cuda_version() -> Optional[str]: + try: + cuda_version = config.cuda.version + if cuda_version is None: + cuda_version = torch.version.cuda + return cuda_version + except Exception as e: + log.error("Error getting cuda version: %s", e) + return None + + +@functools.cache +def nvcc_exist(nvcc_path: Optional[str] = "nvcc") -> bool: + return nvcc_path is not None and shutil.which(nvcc_path) is not None diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..2b461844161162c89a2ee6c1ac9d2f55fb15a75d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -0,0 +1,674 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union + +from sympy import Expr, symbols + +import torch._inductor.config as config +from torch import dtype as torch_dtype +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder +from torch.utils._sympy.value_ranges import ValueRanges + +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from .cuda_template import ArgInfo + +from ...autotune_process import CUDABenchmarkRequest +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + IRNode, + Layout, + PrimitiveInfoType, + TensorBox, +) +from ...utils import sympy_product +from ...virtualized import V +from ..common import ( + CSEVariable, + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) +from ..cpp_utils import CppPrinter, DTYPE_TO_CPP + + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda.cuda_template import CUDATemplate + +log = logging.getLogger(__name__) + +cexpr = CppPrinter().doprint + + +def _normalize_idx(index: int, total_length: int) -> int: + return index if index >= 0 else index + total_length + + +ValidLayoutSymbols = Literal["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] +ValidLayoutAttrs = Literal["size", "stride"] + + +@dataclass(frozen=True) +class LayoutArg: + node: IRNode + symbol: ValidLayoutSymbols + attr: ValidLayoutAttrs + dim: int + + def matches(self, node, attr, dim) -> bool: + return self.node == node and self.attr == attr and self.dim == dim + + +class CUDAKernel(Kernel): + """ + Baseclass for CUDA / Cutlass based Kernels + """ + + overrides = OpOverrides # type: ignore[assignment] + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list) + self.size_args: list[Union[Expr, int]] = [] + # Mapping from arg name to IRNode. 
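The arch string consumed by the Cutlass backend is simply the flattened compute capability returned by get_cuda_arch earlier in this file; a one-line sketch (hypothetical helper name):

def arch_from_capability(major: int, minor: int) -> str:
    # Mirrors get_cuda_arch when config.cuda.arch is unset: (major, minor) -> SM number string.
    return str(major * 10 + minor)

# arch_from_capability(8, 6) -> "86"
# arch_from_capability(9, 0) -> "90"  (Hopper)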
+ self.named_nodes: dict[str, IRNode] = {} + + def find_symbol( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[str]: + arg = self.find_layout_arg(node, attr, dim) + return arg.symbol if arg else None + + def find_layout_arg( + self, node: IRNode, attr: ValidLayoutAttrs, dim: int + ) -> Optional[LayoutArg]: + matches = [ + arg + for arg in itertools.chain.from_iterable(self.layout_args.values()) + if arg.matches(node, attr, dim) + ] + if len(matches) >= 1: + # Verify all matches have the same node, attribute, and dimension + # And if they come from the same node, whichever symbol we use is fine. + # if in runtime the logic changes, this would trigger guard + first_match = matches[0] + if not all( + match.node == first_match.node + and match.attr == first_match.attr + and match.dim == first_match.dim + for match in matches + ): + raise AssertionError("All matching layout args should be identical") + return first_match + return None + + def add_layout_arg( + self, symbol: ValidLayoutSymbols, node: IRNode, attr: ValidLayoutAttrs, dim: int + ): + arg = LayoutArg(node, symbol, attr, dim) + self.layout_args[symbol].append(arg) + + def init_layout_args(self) -> None: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + x_mdim = _normalize_idx(-2, len(X.get_size())) + x_kdim = _normalize_idx(-1, len(X.get_size())) + w_kdim = _normalize_idx(-2, len(W.get_size())) + w_ndim = _normalize_idx(-1, len(W.get_size())) + y_mdim = _normalize_idx(-2, len(Y.get_size())) + y_ndim = _normalize_idx(-1, len(Y.get_size())) + self.add_layout_arg("M", X, "size", x_mdim) + self.add_layout_arg("K", X, "size", x_kdim) + self.add_layout_arg("K", W, "size", w_kdim) + self.add_layout_arg("N", W, "size", w_ndim) + self.add_layout_arg("M", Y, "size", y_mdim) + self.add_layout_arg("N", Y, "size", y_ndim) + if len(X.get_size()) > 2: + self.add_layout_arg("B", X, "size", 0) + + lda_dim = self.find_ld_idx(X) + ldb_dim = self.find_ld_idx(W) + ldc_dim = self.find_ld_idx(Bias) if Bias else None + ldd_dim = self.find_ld_idx(Y) + self.add_layout_arg("lda", X, "stride", lda_dim) + self.add_layout_arg("ldb", W, "stride", ldb_dim) + if Bias is not None and ldc_dim is not None: + self.add_layout_arg("ldc", Bias, "stride", ldc_dim) + self.add_layout_arg("ldd", Y, "stride", ldd_dim) + + def get_layout_args(self) -> tuple[Union[Expr, int], ...]: + X = self.named_nodes["X"] + W = self.named_nodes["W"] + Y = self.named_nodes["Y"] + Bias = self.named_nodes.get("Bias", None) + mdim = _normalize_idx(-2, len(X.get_size())) + ndim = _normalize_idx(-1, len(W.get_size())) + kdim = _normalize_idx(-1, len(X.get_size())) + + def get_ld(node) -> Union[Expr, int]: + dim = self.find_ld_idx(node) + return node.get_stride()[dim] + + M = X.get_size()[mdim] + N = W.get_size()[ndim] + K = X.get_size()[kdim] + B = X.get_size()[0] if len(X.get_size()) > 2 else 1 + LDA = get_ld(X) + LDB = get_ld(W) + LDC = get_ld(Bias) if Bias else 0 + LDD = get_ld(Y) + return (M, N, K, B, LDA, LDB, LDC, LDD) + + def get_dynamic_shape_args(self) -> list[Union[Expr, int]]: + return [*self.get_layout_args(), *self.size_args] + + @staticmethod + def find_ld_idx(node: IRNode) -> int: + strides = node.get_stride() + # Handle 1D tensor case + if V.graph.sizevars.statically_known_equals(strides[-1], 1): + return _normalize_idx(-2, len(strides)) + + assert V.graph.sizevars.statically_known_equals(strides[-2], 1), strides[-2] + return _normalize_idx(-1, len(strides)) + + +class 
CUDATemplateKernel(CUDAKernel): + """ + Template kernels defined by CUDA / Cutlass in C++. + """ + + _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" + + def __init__( + self, + kernel_name: str, + runtime_arg_info: list["ArgInfo"], + runtime_arg_values: list[Any], + ) -> None: + """ + Initializes a new instance of the CUDATemplateKernel class. + + Args: + kernel_name (str): The name of the kernel. + """ + super().__init__() + self.kernel_name = kernel_name + self.runtime_arg_info = runtime_arg_info + self.runtime_arg_values = runtime_arg_values + + def check_not_null(self, node: IRNode) -> str: + """ + Generates code to check that a node is not null. + """ + if node is None: + return "" + + size_str = self.size(node, 0, -1) + name_str = self.arg_name(node) + if name_str is None: + return "" + + res = IndentedBuffer(initial_indent=2) + res.tabwidth = 1 + res.splice( + f""" + {{ + if (!{name_str}) {{ + int64_t {name_str}_size = {size_str}; + if ({name_str}_size > 0) {{ + throw std::runtime_error("input {name_str} is null but size is not 0!"); + }} + }} + }} + """ + ) + return res.getvalue() + + def get_signature(self) -> str: + return self.signature + + def def_kernel( + self, + inputs: list[IRNode], + outputs: list[IRNode], + names_str: str = "", + input_reorder: Optional[list[int]] = None, + ) -> str: + """ + Hook called from template code to generate function definition and + needed args. + + Args: + inputs: List of input IRNodes + outputs: List of output IRNodes + names_str: Comma separated list of input + output argument names. + input_reorder: The actual order of input nodes. + e.g. The template might have input argument defined as [X, W, Bias], + and the actual input passed into this template could be [Bias, X, W]. + In this case, the `input_reorder` would be [2, 0, 1]. 
+ additional_size_args: Additional size arguments for epilogue inputs + """ + names = [x.strip() for x in names_str.strip().split(",")] + if len(inputs) + len(outputs) != len(names): + raise RuntimeError( + f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}" + ) + + if input_reorder is not None: + assert len(inputs) == len(input_reorder) + else: + input_reorder = list(range(len(inputs))) + + for idx in input_reorder: + name = names[idx] + node = inputs[idx] + if node is not None: + self.named_nodes[name] = node + self.args.input_buffers[node.get_name()] = name + + free_symbols: OrderedSet[Expr] = OrderedSet() + for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): + if node is not None: + self.named_nodes[name] = node + self.args.output_buffers[node.get_name()] = name + + if name not in ( + "X", + "W", + "Bias", + "Y", + ): # we handle these symbolic shapes explicitly + for expr in itertools.chain(node.get_size(), node.get_stride()): + if isinstance(expr, Expr): + for s in expr.free_symbols: + free_symbols.add(s) # type: ignore[arg-type] + + arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + + self.init_layout_args() + size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] + size_vars.extend(str(s) for s in free_symbols) + self.size_args.extend(free_symbols) + size_args = [f"const int {s}" for s in size_vars] + + runtime_arg_decls = ",".join( + [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info] + ) + if runtime_arg_decls: + runtime_arg_decls += ", " + + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {runtime_arg_decls}{self._EXTRA_CPP_ARGS})" + self.signature = signature + return signature + + def call_kernel( + self, + name: str, + node: "CUDATemplateBuffer", # type: ignore[name-defined] + ) -> None: + """ + Generates code to call the kernel through V.graph.wrapper_code. + used from within torch._inductor.wrapper.PythonWrapperCodegen + + name: Name of kernel function. + node: The CUDATemplateBuffer node which contains information about the kernel, it's fused epilogue nodes + as well as all required inputs and outputs. + """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # We always originally initialize name with "KERNEL_NAME". So, we + # we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace(str(Placeholder.KERNEL_NAME), name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + dynamic_shape_args = self.get_dynamic_shape_args() + call_args.extend(dynamic_shape_args) # type: ignore[arg-type] + for arg in self.runtime_arg_values: + call_args.append(arg) + arg_types.extend("int" for _ in dynamic_shape_args) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + for i in range(len(call_args)): + if V.graph.is_unspec_arg(call_args[i]): + call_args[i] = call_args[i] + ".item()" + elif isinstance(arg_types[i], torch_dtype): + call_args[i] = ( + call_args[i] + if V.graph.cpp_wrapper + else f"c_void_p({call_args[i]}.data_ptr())" + ) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. 
+ # workspace_size should have already been retrieved prior to this call. + # workspace_size is here. + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + workspace = str(ws.outer_name) + call_args.append( + workspace + if V.graph.cpp_wrapper + else f"c_void_p({workspace}.data_ptr())" + ) + else: + ws = None + call_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + + wrapper.generate_kernel_call( + name, + call_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + def dtype(self, node: IRNode) -> Optional[str]: + """ + Generates code which represents dtype of a given node. + """ + + if node is None: + return "void" + return DTYPE_TO_CPP.get(node.get_layout().dtype) + + def cutlass_dtype(self, node: IRNode, default_dtype="void") -> Optional[str]: + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default_dtype + from torch._inductor.codegen.cuda.cuda_template import CUTLASSTemplate + + return CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype] + + def max_valid_index(self, node: IRNode, default=-1): + # Helper method, called into from CUTLASSGemmTemplate + if node is None: + return default + max_valid_offset = 0 + for i in range(len(node.get_size())): + max_valid_offset += (node.get_size()[i] - 1) * node.get_stride()[i] + return max_valid_offset + + def offset(self, node: IRNode) -> str: + """ + Generates code which represents offset of a given node. + """ + + if node is None: + return "0" + return str(node.get_layout().offset) # type: ignore[union-attr] + + def ptr(self, node: IRNode) -> str: + """ + Generates code which represents pointer of a given node. + """ + + if node is None: + return "nullptr" + arg_name = self.arg_name(node) + if arg_name is None: + return "nullptr" + offset = self.offset(node) + return arg_name if offset == "0" else f"{arg_name} + {offset}" + + def size( + self, + node: IRNode, + start_index: int, + end_index: Optional[int] = None, + default_value: int = 0, + ) -> str: + """ + Hook called from template code to get the size of an arg. + Generates code which represents size of a given node in [start_index, end_index). + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None: + return str(default_value) + + start_index = _normalize_idx(start_index, len(node.get_size())) + if end_index is None: + end_index = start_index + end_index = _normalize_idx(end_index, len(node.get_size())) + sizes = [ + self.find_symbol(node, "size", dim=i) or node.get_size()[i] + for i in range(start_index, end_index + 1) + ] + if len(sizes) == 0: + return str(default_value) + + sizes = [symbols(v) if isinstance(v, str) else v for v in sizes] + val = sympy_product(sizes) + return val + + def stride(self, node: IRNode, index: int, default_value: int = 0) -> str: + """ + Hook called from template code to get the stride of an arg. + Generates code which represents stride of a given node at index. + If node is None, returns default_value. + + TODO: Will add needed args to pass it in if it is dynamic. 
+ """ + + if node is None: + return str(default_value) + + index = _normalize_idx(index, len(node.get_size())) + if index < 0: + return str(default_value) + + stride = node.get_stride()[index] + if V.graph.sizevars.statically_known_leq(stride, 1): + return str(stride) + return self.find_symbol(node, "stride", dim=index) or str(stride) + + def batch_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the batch stride of an arg. + Returns 0 if batch dim is not present. + + This method assumes that batch stride is the largest stride. + """ + + if node is None: + return str(default_value) + + if len(node.get_size()) < 3: + return str(default_value) + + batch_stride = node.get_stride()[0] + if V.graph.sizevars.statically_known_leq(batch_stride, 1): + return str(batch_stride) + + return "{}*{}".format( + self.find_symbol(node, "size", dim=1) or node.get_size()[1], + self.find_symbol(node, "size", dim=2) or node.get_size()[2], + ) + + def row_or_column_stride(self, node: IRNode, default_value: int = 0) -> str: + """ + Hook called from template code to get the row or column stride of an arg. + This is required by some CUTLASS 2.X APIs. + If the node is in row_major, it returns stride[-2]. + If the node is in column_major, it returns stride[-1]. + + TODO: Will add needed args to pass it in if it is dynamic. + """ + + if node is None or len(node.get_stride()) < 2: + return str(default_value) + + stride0 = node.get_stride()[-1] + stride1 = node.get_stride()[-2] + if stride0 == 1: + return cexpr(self.rename_indexing(stride1)) + elif stride1 == 1: + return cexpr(self.rename_indexing(stride0)) + else: + raise RuntimeError( + f"At least 1 stride should be 1. Strides: {node.get_stride()=}" + ) + + def load(self, name: str, index: Expr, mode: Any = None) -> CSEVariable: + """ + Mock load function for memory planning to optimize allocations properly. + """ + return self.create_cse_var(name, bounds=ValueRanges.unknown()) + + def store(self, name: str, index: Expr, value: Any, mode: Any = None) -> None: + """ + Mock store function for memory planning to optimize allocations properly. + """ + self.store_buffer_names.add(name) + + +class CUDATemplateCaller(ChoiceCaller): + """ + CUDATemplateCaller + + This class represents a caller for CUDA template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (CUDABenchmarkRequest): The benchmark request for the caller. + template_buffer (CUDATemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [CUDATemplateBuffer, Optional[list[BaseSchedulerNode]]], + tuple[CUDATemplateKernel, functools.partial[str]], + ], + bmreq: CUDABenchmarkRequest, + supports_epilogue_fusion: bool, + template: "CUDATemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + description: str, + ) -> None: + super().__init__(name, input_nodes, layout, description) + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.supports_epilogue_fusion = supports_epilogue_fusion + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"CUDATemplateCaller(source_file={self.bmreq.source_file})" + + def call_name(self) -> str: + return f"cuda_template_kernels.{self.name}" + + def kernel_hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def hash_key(self) -> str: + """ + Return kernel hash key that does not depend on swizzle. + """ + return "-".join( + [ + self.category, + self.bmreq.hash_key, + str(self.info_dict().get("swizzle")), + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + if self.info_kwargs is not None and "op" in self.info_kwargs: + op: Any = self.info_kwargs["op"] + return { + "backend": "CUDA", + "op_type": type(op).__name__, + "op_conf_name": str(op.configuration_name()), + "op_arch": str(op.arch), + "tile_shape": str(op.tile_description.tile_shape), + "epilogue_schedule": str(op.epilogue_schedule), + "kernel_schedule": str(op.kernel_schedule), + "element_accumulator": str(op.accumulator_type()), + "op_name": str(op.procedural_name()), + "instruction_shape": str( + op.tile_description.math_instruction.instruction_shape + ), + "swizzle": str(self.info_kwargs["swizzle"]), + } + else: + return {"backend": "CUDA", "op_type": "unknown"} + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + CUDATemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + supports_epilogue_fusion=self.supports_epilogue_fusion, + template=self.template, + ) + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3045f5b40e30d6a95d9fdcb1e488799edabac1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cuda_template.py @@ -0,0 +1,318 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import itertools +from dataclasses import dataclass +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import override +from 
unittest.mock import patch + +import sympy + +import torch +from torch._inductor.utils import Placeholder +from torch._logging import getArtifactLogger + +from ...autotune_process import CUDABenchmarkRequest, TensorMeta +from ...ir import Buffer, CUDATemplateBuffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .cuda_kernel import CUDATemplateCaller, CUDATemplateKernel +from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE + + +if TYPE_CHECKING: + from ...scheduler import BaseSchedulerNode # noqa: TC004 +else: + BaseSchedulerNode = Any + +GemmOperation = Any + +autotuning_log = getArtifactLogger(__name__, "autotuning") + + +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class CUDATemplate(KernelTemplate): + index_counter = itertools.count() + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for CUDA C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the CUDATemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. + layout (Layout): The layout of the output buffer / tensor. + input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes. + + """ + super().__init__(name) + self.input_nodes = input_nodes + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) + self.input_reorder = input_reorder + self.layout = layout + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return False + + def generate( # type: ignore[override] + self, + description, + **kwargs, + ) -> CUDATemplateCaller: + """ + Generates the CUDA template caller object for the given GEMM template and operation. This CUDATemplateCaller + may be used to call and benchmark the generated CUDA kernel in a standalone manner to enable Autotuning. + + Args: + kwargs: Additional keyword arguments. + + Returns: + A CUDATemplateCaller object representing the generated CUDA template caller. 
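+        Note: for CUTLASS GEMM templates, kwargs is expected to contain an "op" entry
+        (the cutlass_library operation being instantiated); it is read below and also
+        forwarded to render() via make_kernel_render.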
+ """ + kernel_name = str(Placeholder.KERNEL_NAME) + with ( + patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)), + CUDATemplateKernel( + kernel_name=kernel_name, + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) as kernel, + ): + code = self.render(kernel=kernel, **kwargs) + _, call_args, _, _ = kernel.args.python_argdefs() + autotuning_log.debug("Generated Code:\n%s", code) + autotuning_log.debug( + "Args: cpp_argdefs: %s, python_argdefs: %s", + kernel.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE), + kernel.args.python_argdefs(), + ) + + input_reorder = ( + self.input_reorder + if self.input_reorder is not None + else list(range(len(self.input_nodes))) + ) + expected_args = list( + unique(self.input_nodes[idx].get_name() for idx in input_reorder) + ) + expected_args.extend([self.output_node.get_name()]) + assert list(call_args)[: len(expected_args)] == expected_args, ( + call_args, + expected_args, + ) + V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :])) + size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args()) + extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs)) + + kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8] + kernel_name = f"cutlass_{kernel_hash}" + code = code.replace(self.name, kernel_name) + + # create the BenchmarkRequest + bmreq = CUDABenchmarkRequest( + kernel_name=kernel_name, + input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes), + output_tensor_meta=TensorMeta.from_irnodes(self.output_node), + extra_args=extra_args, + source_code=code, + ) + + # kwargs has "op" argument in case of CUTLASSGemmTemplate + op = kwargs["op"] + if not op: + supports_epilogue_fusion = False + else: + # epilogue fusion is only supported for TMA kernels + supports_epilogue_fusion = self.supports_epilogue_fusion(op) + + def make_kernel_render( + template_node: CUDATemplateBuffer, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + ) -> tuple[CUDATemplateKernel, functools.partial[str]]: + assert supports_epilogue_fusion or not epilogue_nodes, ( + "epilogue fusion is not supported for this kernel" + ) + kernel = CUDATemplateKernel( + kernel_name=str(Placeholder.KERNEL_NAME), + runtime_arg_info=self.get_runtime_arg_info(), + runtime_arg_values=self.get_runtime_arg_values(**kwargs), + ) + render = functools.partial( + self.render, + kernel=kernel, + template_buffer_node=template_node, + epilogue_nodes=epilogue_nodes, + **kwargs, # includes "op" argument in case of CUTLASSGemmTemplate + ) + return kernel, render + + return CUDATemplateCaller( + kernel_name, + "cutlass_gemm", + self.input_nodes, + self.output_node.get_layout(), + make_kernel_render, + bmreq, + supports_epilogue_fusion, + self, + kwargs, + description, + ) + + def header(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + #include + #include + #include + #include + #include + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = IndentedBuffer() + res.splice( + """ + // We compile all models with -fvisibility=hidden. Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. 
+ #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] + + +class CUTLASSTemplate(CUDATemplate): + """ + CUTLASSTemplate is a class that provides a template for generating CUTLASS Templates. Used as a baseclass for the + CUTLASSGemmTemplate, providing functionality that might also be relevant for non-GEMM CUTLASS Kernels. + """ + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + #include "cute/tensor.hpp" + #include "cutlass/cutlass.h" + #include "cutlass/numeric_types.h" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/host_tensor.h" + #include "cutlass/util/reference/host/tensor_fill.h" + #include "cutlass/util/reference/device/tensor_fill.h" + #include "cutlass/util/device_memory.h" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using namespace cute; + #define CUTLASS_CHECK(status) \\ + { \\ + cutlass::Status error = status; \\ + if (error != cutlass::Status::kSuccess) { \\ + auto msg = std::string("[") + __FILE__ + "] Got cutlass error: " + \\ + cutlassGetStatusString(error) + " at: " + std::to_string(__LINE__); \\ + throw std::runtime_error(msg); \\ + } \\ + } + + // Used as pass-through functor in EVT just for type casting / rounding + template + struct identity_op { + CUTLASS_HOST_DEVICE + T operator()(T val) const { return val; } + }; + + """ + ) + return res + + def cute_int(self, int_str: str, var_name: str) -> str: + res = "" + if int_str in ("1", "1L"): + res = "cute::Int<1>{}" + else: + res = int_str + + return f"{res} /* {var_name} */" + + _DTYPE_TO_CUTLASS = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "cutlass::half_t", + torch.int32: "int32_t", + torch.int16: "int16_t", + torch.int8: "int8_t", + torch.uint8: "uint8_t", + torch.bool: "bool", + torch.bfloat16: "cutlass::bfloat16_t", + torch.float8_e4m3fn: "cutlass::float_e4m3_t", + } + + _DTYPE_TO_CUTLASS_SPARSE_META = { + torch.int32: "uint32_t", + torch.int16: "uint16_t", + } + + def cutlass_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._DTYPE_TO_CUTLASS.get(node.get_dtype())}*)({ptr})" + + def cutlass_sparse_meta_type_cast(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return ( + f"({self._DTYPE_TO_CUTLASS_SPARSE_META.get(node.get_dtype())}*)({ptr})" + ) + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("swizzle", "const uint8_t")] + + @override + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + """ + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfc69f3902ac0bcb764892fc8e1ab35ba5cf5f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_cache.py @@ -0,0 +1,105 @@ +# mypy: allow-untyped-defs 
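+# This module caches pre-serialized CUTLASS op configurations on disk. The cache key
+# combines the CUTLASS source hash, CUDA arch, CUDA version, and the configured
+# instantiation level; entries are stored as JSON files under cache_dir().
+# maybe_fetch_ops() returns the deserialized ops, or None when nothing usable is cached
+# (or when caching is force-disabled), in which case callers regenerate ops themselves.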
+import functools +import hashlib +import json +import logging +import os +import time +from typing import Any, Optional + +import torch._inductor.config as config +from torch._inductor.codecache import cutlass_key +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version +from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer +from torch._inductor.runtime.cache_dir_utils import cache_dir +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +CONFIG_PREFIX: str = "configs" + + +def get_config_request_key( + arch: str, + cuda_version: str, + instantiation_level: str, +) -> str: + """ + Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level. + """ + hash_target = "-".join( + [ + cutlass_key().hex(), + arch, + cuda_version, + instantiation_level, + ] + ) + return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8] + + +def _generate_config_filename(request_key: str) -> str: + """ + Generate a filename for the full ops. + """ + return f"{CONFIG_PREFIX}_{request_key}.json" + + +@clear_on_fresh_cache +@functools.cache +def maybe_fetch_ops() -> Optional[list[Any]]: + """ + Fetch ops from databases. + """ + if config.force_disable_caches: + return None + + # setup + arch: str = get_cuda_arch() + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version: str = ".".join(get_cuda_version().split(".")[:2]) + instantiation_level: str = config.cuda.cutlass_instantiation_level + + # filename and filepath + request_key: str = get_config_request_key(arch, version, instantiation_level) + filename: str = _generate_config_filename(request_key) + filepath: str = os.path.join(cache_dir(), filename) + + # try fetch + serialized_ops: Optional[list[str]] = None + start_time = time.time() + if os.path.isfile(filepath): + # locally + try: + with open(filepath) as f: + serialized_ops = json.load(f) + + assert isinstance(serialized_ops, list), ( + f"Expected serialized ops is a list, got {type(serialized_ops)}" + ) + except Exception as e: + log.warning( + "Failed to load CUTLASS config %s from local cache: %s", + filename, + e, + ) + serialized_ops = None + elif config.is_fbcode(): + from torch._inductor.fb.cutlass_remote_cache import ( + maybe_fetch_cutlass_configs_from_remote, + ) + + # from remote + serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath) + + if serialized_ops is None: + return None + + # deserialize + serializer = get_cutlass_operation_serializer() + full_ops = [serializer.deserialize(x) for x in serialized_ops] # type: ignore[union-attr] + log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time) + return full_ops diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5d5753355611f8c092c41e35d152e81059700db Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a68a25a61e609bebf05e481456647aee42fd24 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/evt_extensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23005ce68d1db357410437f3312efa14fa2f9b47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/__pycache__/gemm_operation_extensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..719d5a6811584792a3542f998483e4deb15b12c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py @@ -0,0 +1,240 @@ +from typing import Any, Callable, Union + +from sympy import Expr + +from torch._inductor.ir import ( + ComputedBuffer, + InputBuffer, + is_contiguous_strides_for_shape, +) +from torch.utils._ordered_set import OrderedSet + +from ..cutlass_utils import torch_dtype_to_cutlass_type, try_import_cutlass + + +EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace +Buffer = Union[ComputedBuffer, InputBuffer] +CutlassTupleType = Any # cutlass.backend.c_types.tuple_factory_..TupleType +CutlassVisitorType = Any # cutlass.backend.c_types.visitor_factory..VisitorType +CutlassArgType = ( + Any # Can be a CutlassTupleType, CutlassVisitorType, EmptyByte, or ctype.c_void_p +) + + +if try_import_cutlass(): + import ast + import ctypes + import textwrap + from typing import Union + + from cutlass.backend.c_types import ( # type: ignore[import-untyped, import-not-found] + EmptyByte, + ) + from cutlass.backend.epilogue import ( # type: ignore[import-untyped, import-not-found] + dtype2ctype, + ) + from cutlass.backend.evt import ( # type: ignore[import-untyped, import-not-found] + EpilogueFunctorVisitor, + ) + from cutlass.backend.evt.backend.emitter_base import ( # type: ignore[import-untyped, import-not-found] + FusionCallbacks, + ) + from cutlass.backend.evt.backend.sm90_emitter import ( # type: ignore[import-untyped, import-not-found] + CollectiveEpilogue, + ) + from cutlass.backend.evt.frontend import ( # type: ignore[import-untyped, import-not-found] + PythonASTFrontend, + ) + from cutlass.backend.evt.ir.tensor import ( # type: ignore[import-untyped, import-not-found] + Tensor as CutlassTensor, + ) + from cutlass_library import ( + DataType, + EpilogueScheduleType, + LayoutType, + TileDescription, + ) + + from torch._inductor.codegen.cuda import cuda_env + from torch._inductor.utils import IndentedBuffer + + _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: 
ignore[var-annotated] + + def create_example_tensors( + var_name_to_buffer_name: dict[str, str], + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> dict[str, CutlassTensor]: + def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor: + shape = buffer.get_layout().size + stride = buffer.get_layout().stride + shape = tuple(size_hint_fn(x) for x in shape) + stride = tuple(size_hint_fn(x) for x in stride) + + is_row_major = is_contiguous_strides_for_shape(stride, shape) + is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1]) + + if not is_row_major and not is_column_major: + raise RuntimeError( + f"Cannot create example tensor for {buffer.get_name()} with \ +non-contiguous layout, received stride: {stride} and shape: {shape}" + ) + + return CutlassTensor( + shape=shape, + layout_tag=LayoutType.RowMajor + if is_row_major + else LayoutType.ColumnMajor, + element=torch_dtype_to_cutlass_type(buffer.get_layout().dtype), + ) + + return { + key: cutlass_tensor_from_buffer(name_to_buffer[name]) + for key, name in var_name_to_buffer_name.items() + } + + def trace( + fn_src: str, + example_tensors: dict[str, CutlassTensor], + accum_type: DataType, + output_type: DataType, + tile_description: TileDescription, + epilogue_schedule: EpilogueScheduleType, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + **kwargs: dict[str, Any], + ) -> tuple[str, str, str]: + cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] + assert cuda_arch >= 90, "Only SM90+ is supported for EVT" + epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) + visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) + fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) + collective_epilogue = CollectiveEpilogue( + tile_description, + epilogue_schedule, + accum_type, + output_type, + fusion_callbacks, + ) + evt_name, evt_code = collective_epilogue.emit() + evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn) + return evt_name, evt_args, evt_code + + # Based off of + # https://github.com/NVIDIA/cutlass/blob/df18f5e4f5de76bed8be1de8e4c245f2f5ec3020/python/cutlass/epilogue/epilogue.py#L117 + # This is modified to enable directly passing the source code of the epilogue vs getting it from a bona-fide python function + # The reason for this is that inspect.getsource does not work with functions defined at runtime via exec/eval + def _trace( + fn_src: str, example_tensors: dict[str, CutlassTensor], cc: int, **kwargs: Any + ) -> EpilogueFunctor: + class EpilogueFunctor(PythonASTFrontend): + def __init__(self, cc: int, **kwargs: Any): + self.source = textwrap.dedent(fn_src) + super().__init__(cc, **kwargs) + + def parse(self, example_inputs: dict[str, CutlassTensor]) -> None: + self.example_inputs = example_inputs + self.ast = ast.parse(self.source) + self.visit(self.ast) + + cc = int(cuda_env.get_cuda_arch()) + epilogue_functor = EpilogueFunctor(cc=cc, **kwargs) + epilogue_functor.trace(example_tensors) + return epilogue_functor + + def _render_argument_type( + epilogue_functor: EpilogueFunctor, + name_to_buffer: dict[str, Buffer], + size_hint_fn: Callable[[Union[Expr, int]], int], + ) -> str: + epilogue_thread_type = epilogue_functor.epilogue_thread_type + + # Fragile, but this is the only way to guarantee t is expected type because t is a local class + def is_nested_visitor_type(t: type) -> bool: + return ( + ".".join([t.__module__, 
t.__qualname__]) + == "cutlass.backend.c_types.visitor_factory..VisitorType" + ) + + buffer = IndentedBuffer() + with buffer.set_tabwidth(2): + + def render_argument_type(name: str, t: CutlassArgType) -> None: + if issubclass(t, ctypes.c_byte): + buffer.writeline(f"{{}}, /* {name} */") + else: + fields = [ + ( + fname, + _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn), + ) + for fname, ty in t._fields_ + ] + field_strs = [ + f"/* {fname} */ {str(field)}" for fname, field in fields + ] + buffer.writeline(f"{{{', '.join(field_strs)}}}, /* {name} */") + + def render_thread_type(name: str, t: CutlassArgType) -> None: + if is_nested_visitor_type(t): + buffer.writeline(f"{{ /* {name} */") + with buffer.indent(): + for name, inner_t in t._fields_: + render_thread_type(name, inner_t) + buffer.writeline("},") + else: + render_argument_type(name, t) + + # unroll the recursion once to address special case formatting + # namely, no ending comma and no indentation for the outermost thread type + buffer.writeline("{ /* thread */") + with buffer.indent(3): + if is_nested_visitor_type(epilogue_thread_type): + with buffer.indent(): + for name, inner_t in epilogue_thread_type._fields_: + render_thread_type(name, inner_t) + else: + render_argument_type("thread", epilogue_thread_type) + buffer.writeline("}") + + return buffer.getvalue() + + def _get_arg_from_node( + arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int] + ) -> str: + from ..cuda_template import CUTLASSTemplate + + # Today, arguments are either a pointer to the + # node's memory, a stride tuple, the datatype + # Once again, need to check for local class type for stride tuple + if ( + str(arg_ty) + == ".TupleType'>" + ): + DEFAULT_STRIDE_LEN = 3 + assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN + stride = [size_hint_fn(x) for x in node.get_layout().stride] + for _ in range(DEFAULT_STRIDE_LEN - len(stride)): + stride.append(0) + + def render_stride(x: int) -> str: + # Handle EBO for 0 and 1 + if x == 0: + return "_0{}" + elif x == 1: + return "_1{}" + else: + return str(x) + + return f"{{{', '.join([render_stride(x) for x in stride])}}}" + + elif issubclass(arg_ty, ctypes.c_void_p): + return f"({CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}*) {node.get_name()}" + elif ( + arg_ty in _CUTLASS_C_DTYPES + ): # Assumption: this is the element dtype, this holds for all cutlass ir nodes currently + return f"{CUTLASSTemplate._DTYPE_TO_CUTLASS[node.get_layout().dtype]}(0)" + elif issubclass(arg_ty, EmptyByte): + return "{}" + + raise NotImplementedError(f"Unsupported arg type: {arg_ty}") diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..ff36ea0dc11dc7a60553f025a2421069f1db7517 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_lib_extensions/gemm_operation_extensions.py @@ -0,0 +1,411 @@ +# mypy: ignore-errors +from ..cutlass_utils import try_import_cutlass + + +# copied / modified from original at +# https://github.com/NVIDIA/cutlass/blob/8783c41851cd3582490e04e69e0cd756a8c1db7f/tools/library/scripts/gemm_operation.py#L658 + +if try_import_cutlass(): + import enum + + from cutlass_library.gemm_operation import * # noqa: F401, F403 + from cutlass_library.library import * # noqa: F401, F403 + + _LOGGER = 
logging.getLogger(__name__) + + class EmitGemmUniversal3xInstanceWithEVT: + """Responsible for emitting a CUTLASS 3.x template definition""" + + def __init__(self, operation_suffix="", evt_name=None): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + ] + self.builtin_epilogue_functor_template = """${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.evt_name = evt_name + self.gemm_template = """ +using ${operation_name}_epilogue = +typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} +>::CollectiveOp; + +${mixed_dtype_prepare_code} + +using ${operation_name}_mainloop = +typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} +>::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : +public ${operation_name}_base { }; + + """ + + # + def instance_template(self): + return """ +${compile_guard_start} +{ + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); +} +${compile_guard_end} + """ + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + "epi_vs": str(operation.ScaleFactorVectorSize), + "element_d": str(DataTypeTag[operation.D.element]), + "element_sfd": str(DataTypeTag[operation.ScaleFactorD.element]), + "layout_sfd": LayoutTag[operation.ScaleFactorD.layout], + "epilogue_functor": EpilogueFunctor3xTag[ + EpilogueFunctor3x.LinearCombinationBlockScaleFactor + ], + "element_accumulator": str(DataTypeTag[operation.accumulator_type()]), + "element_scalar": str(DataTypeTag[operation.accumulator_type()]), + "element_c": str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = ( + "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + ) + + return ( + gemm_shape_type + if not is_grouped(operation.gemm_kind) + else grouped_gemm_shape_type + ) + + def emit(self, operation): + """Given a gem operation, emits a template definition of the 
operation""" + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = ( + operation.tile_description.math_instruction.instruction_shape + ) + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + tile_shape_m, tile_shape_n, tile_shape_k = tile_shape + + # account for static/dynamic cluster shapes + cta_m = tile_shape[0] // cluster_m if cluster_m > 0 else tile_shape[0] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + + # Shape passed to epilogue builder + is_sm100_kernel = operation.arch == 100 + if is_sm100_kernel: + cta_m_per_mma_instruction = ( + 2 if "2sm" in operation.procedural_name() else 1 + ) + if cluster_m <= 0: + cta_m = cta_m // cta_m_per_mma_instruction + + if opcode_class_main in [ + OpcodeClass.TensorOp, + OpcodeClass.BlockScaledTensorOp, + ]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<\ +{str(operation.tile_description.stages)}>" + else: + stage_count_string = ( + f"cutlass::gemm::collective::StageCountAutoCarveout(\ +sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + ) + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + + ( + instance_layout_A, + instance_layout_B, + instance_layout_C, + instance_layout_D, + ) = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + operation.D.layout, + ) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "epilogue_functor": EpilogueFunctor3xTag[ + operation.epilogue_functor + ], + } + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values + ) + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if ( + is_block_scaled(operation.gemm_kind) + and operation.ScaleFactorD.element != DataType.void + ): + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, + # e.g. cute::tuple, Transform : cute::identity / cute::conjugate. 
+ element_a = ( + DataTypeTag[operation.A.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.A.element])},\ +{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + ) + element_b = ( + DataTypeTag[operation.B.element] + if not operation.is_complex() + else f"cute::tuple<{str(DataTypeTag[operation.B.element])},\ +{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + ) + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + is_no_smem_epilogue = operation.epilogue_schedule in [ + EpilogueScheduleType.NoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm, + ] + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped + ) + ] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule( + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped + ): + epi_tile_mn = "cute::Shape" + if not is_no_smem_epilogue: + epilogue_schedule_type = EpilogueScheduleTag[ + to_grouped_schedule( + EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped + ) + ] + element_a = f"cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>" + element_b = f"cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>" + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode is not None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, \ +cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + ) + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = ( + f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + ) + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" + using {operation_name_str}_StrideNarrow = {stride_narrow_str}; + using {operation_name_str}_ValueShuffle = {value_shuffle_str}; + static constexpr int {operation_name_str}_NumShuffleAtoms = 1; + using 
{operation_name_str}_MmaAtomShape = \ +cute::Layout>>; + using {operation_name_str}_LayoutAtomQuant = \ +decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, \ +{operation_name_str}_ValueShuffle>()); + using {operation_name_str}_LayoutNarrowReordered = \ +decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, \ +cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>", + } + narrow_element = mixed_input_modes_to_element.get( + operation.mixed_input_mode, narrow_tag + ) + + if narrow_dtype == DataType.s4 and ( + wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2 + ): + narrow_element = ( + f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + ) + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + if self.evt_name: + epilogue_functor = self.evt_name + + values = { + "operation_name": operation_name_str, + "operation_suffix": self.operation_suffix, + "problem_shape": self.problem_shape(operation), + "element_a": element_a, + "layout_a": self.pointerize_if_grouped(operation, layout_a_str), + "element_b": element_b, + "layout_b": self.pointerize_if_grouped(operation, layout_b_str), + "element_c": DataTypeTag[operation.C.element], + "layout_c": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_C] + ), + "element_d": DataTypeTag[operation.D.element], + "layout_d": self.pointerize_if_grouped( + operation, LayoutTag[instance_layout_D] + ), + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class_main": OpcodeClassTag[opcode_class_main], + "opcode_class_epi": OpcodeClassTag[opcode_class_epi], + "arch": f"cutlass::arch::Sm{operation.arch}", + "tile_shape_m": str(tile_shape_m), + "tile_shape_n": str(tile_shape_n), + "tile_shape_k": str(tile_shape_k), + "cluster_shape_m": "cute::_" + + str(operation.tile_description.cluster_shape[0]) + if operation.tile_description.cluster_shape[0] > 0 + else "int", + "cluster_shape_n": "cute::_" + + str(operation.tile_description.cluster_shape[1]) + if operation.tile_description.cluster_shape[1] > 0 + else "int", + "cluster_shape_k": "cute::_" + + str(operation.tile_description.cluster_shape[2]) + if operation.tile_description.cluster_shape[2] > 0 + else "int", + "instruction_shape_m": str(instruction_shape[0]), + "instruction_shape_n": str(instruction_shape[1]), + "instruction_shape_k": str(instruction_shape[2]), + "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), + "epilogue_schedule": str(epilogue_schedule_type), + "epi_tile_mn": epi_tile_mn, + "epilogue_functor": epilogue_functor, + "stages": stage_count_string, + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + "align_c": str(operation.C.alignment), + "align_d": str(operation.C.alignment), + "transform_a": ComplexTransformTag[operation.A.complex_transform], + "transform_b": ComplexTransformTag[operation.B.complex_transform], + "math_operation": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "epilogue_vector_length": str(epilogue_vector_length), + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "tile_scheduler": str(TileSchedulerTag[operation.tile_scheduler]), + "mixed_dtype_prepare_code": 
mixed_dtype_prepare_code, + } + + return SubstituteTemplate(self.gemm_template, values) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py new file mode 100644 index 0000000000000000000000000000000000000000..7d60f40c4bc3ffbd8ccd625648f6d66cf407b776 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_presets.py @@ -0,0 +1,239 @@ +import functools +from collections import defaultdict + +import torch +from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch + + +@functools.cache +def gen_cutlass_presets() -> dict[int, dict[str, list[str]]]: + """ + Generate cutlass presets for the given CUDA arch. + """ + presets: dict[int, dict[str, list[str]]] = {} + + if not torch._C._has_cuda: + return presets + + presets[0] = defaultdict(list) + arch = get_cuda_arch() + if arch == "90": + preset = presets[0] + preset["0"] = [ + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_64x256x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["1111"] = [ + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["2222"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + ] + preset["3333"] = [ + r"cutlass3x_sm90_tensorop_s64x48x16gemm_.*_64x48x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_4x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_4x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + ] + preset["4444"] = [ + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_64x192x64_4x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + ] + preset["5555"] = [ + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + 
r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_128x32x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x4x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_2x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x2x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_256x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_1x8x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x64x128_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x128_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + 
r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_64x32x64_1x4x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_128x192x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x128_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x64_1x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x1x1_0_.*_align.*_cpasync_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_2x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x2x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x64_2x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_2x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x256x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_64x128x128_4x1x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x256x64_1x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x128x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x32x16gemm_.*_128x64x64_1x2x1_0_.*_align.*_warpspecialized_pingpong_epi_tma", + r"cutlass3x_sm90_tensorop_s64x64x16gemm_.*_128x128x64_2x1x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x128_1x8x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x160x16gemm_.*_256x160x64_1x1x1_0_.*_align.*_stream_k_warpspecialized_cooperative_epi_tma", + r"cutlass3x_sm90_tensorop_s64x16x16gemm_.*_64x16x256_1x1x1_0_.*_align.*_warpspecialized_epi_nosmem", + r"cutlass3x_sm90_tensorop_s64x192x16gemm_.*_256x192x64_1x2x1_0_.*_align.*_warpspecialized_cooperative_epi_tma", + ] + + return presets diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py new file mode 100644 index 0000000000000000000000000000000000000000..96400763ddb1416af1b52432b203e6039b2a7381 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_python_evt.py @@ -0,0 +1,322 @@ +import itertools +from collections.abc import Generator, Iterable, Iterator, Sequence +from contextlib import contextmanager +from os import linesep +from typing import Any, Optional + +import sympy + +import torch +import torch._inductor.virtualized as virtualized +from torch._inductor.ir import ComputedBuffer, Pointwise +from torch._inductor.ops_handler import DefaultHandler, 
WrapperHandler +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.utils import DelayReplaceLine, IndentedBuffer, OrderedSet +from torch._inductor.virtualized import OpsValue + +from ...virtualized import V + + +_ACCUMULATOR_ARG_NAME = "accum" + + +def scaled_mm_evt( + scale_A_name: str, scale_B_name: str, bias_name: Optional[str], output_name: str +) -> tuple[list[str], dict[str, Any], str]: + evt_read_names = [scale_A_name, scale_B_name] + var_name_to_buffer_name = {n: n for n in [scale_A_name, scale_B_name]} + var_name_to_buffer_name["D"] = output_name + var_name_to_buffer_name[_ACCUMULATOR_ARG_NAME] = output_name + expr = f"accum * {scale_A_name} * {scale_B_name}{linesep}" + if bias_name: + expr = f"({expr}) + {bias_name}" + evt_read_names.append(bias_name) + var_name_to_buffer_name[bias_name] = bias_name + + evt_py_code = f"def fn(accum, {','.join(evt_read_names)}):{linesep}\ + D = {expr}{linesep}\ + return D{linesep}" + + return evt_read_names, var_name_to_buffer_name, evt_py_code + + +class CutlassEVTOpsMixIn: + @staticmethod + def _infix_bin_op(op: str, a: str, b: str) -> str: + return f"{a} {op} {b}" + + @staticmethod + def _prefix_bin_op(op: str, a: str, b: str) -> str: + return f"{op}({a}, {b})" + + @staticmethod + def _prefix_un_op(op: str, a: str) -> str: + return f"{op}({a})" + + @staticmethod + def to_dtype( + x: str, + dtype: Any, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = False, + ) -> str: + return x + + @staticmethod + def constant(value: Any, dtype: Any) -> str: + raise NotImplementedError + + @staticmethod + def mul(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("*", x0, x1) + + @staticmethod + def truediv(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("/", x0, x1) + + @staticmethod + def ge(x0: str, x1: str) -> str: + raise NotImplementedError + + @staticmethod + def add(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("+", x0, x1) + + @staticmethod + def relu(x0: str) -> str: + return CutlassEVTOpsMixIn._prefix_un_op("relu", x0) + + @staticmethod + def sigmoid(x0: str) -> str: + raise NotImplementedError("sigmoid is not supported in CUTLASS python evt") + + @staticmethod + def sub(x0: str, x1: str) -> str: + return CutlassEVTOpsMixIn._infix_bin_op("-", x0, x1) + + @staticmethod + def tanh(x0: str) -> str: + raise NotImplementedError("tanh is not supported in CUTLASS python evt") + + +class MockCutlassHandler(CutlassEVTOpsMixIn, WrapperHandler): + """Passthrough handler for cutlass ops, used for running epilogue nodes for memory planning""" + + +class _AssignmentFormatter(DefaultHandler): + def __init__(self, parent_handler: "CutlassEVTCodegen"): + self.parent_handler = parent_handler + + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + # Handle op dispatch here + if hasattr(self.parent_handler, name): + fn = getattr(self.parent_handler, name) + line = fn(*args, **kwargs) + if name in ("load", "store"): + return OpsValue(line) + else: + var = self.parent_handler._tmp_var() + line = DelayReplaceLine( + var, + lambda: "D" + if var == self.parent_handler.last_stored_var_name + else var, + f"{var} = {line}", + ) + self.parent_handler.body.writeline(line) + return OpsValue(var) + else: + raise NotImplementedError(name) + + +class CutlassEVTCodegen(CutlassEVTOpsMixIn): + """ + Notes: + * Used by CUTLASSGemmTemplate. 
+ * This class should not be instantiated by users, it is intended to be used + by calling CutlassEVTCodegen.ir_to_evt_python_code(...) + which instantiates this class as an ops handler for virtualized.V.ops.[op-name] + * Extend this with more _op_ nodes to add support for new pointwise operations. + """ + + def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]): + """ + + Initializes a CutlassEVTEpilogueArgumentFormatter object. Do not instantiate directly. + Use the CutlassEVTCodegen.ir_to_evt_python_code static method. + + Args: + accumulator_node_name: The name of the accumulator node which should contain + the Matmul result before fusion according to the IR graph. + epilogue_nodes: The list of scheduler nodes to be fused into the epilogue + """ + self.accumulator_node_name: str = accumulator_node_name # + self.body: IndentedBuffer = IndentedBuffer(1) # The body buffer for codegen + self.var_counter: Iterator[int] = itertools.count() + self.store_name_to_value: dict[str, OpsValue] = ( + dict() + ) # Aliases for subexpression functors + self.reads: OrderedSet[str] = OrderedSet([]) + # Used for creating example tensors + self.var_name_to_buffer_name: dict[str, str] = { + _ACCUMULATOR_ARG_NAME: accumulator_node_name + } + self.removed_buffers: OrderedSet[str] = removed_buffers + self.cur_node: Optional[ComputedBuffer] = None + self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + self.name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + self.is_D_assigned = False + self.D_var_name = None + + if accumulator_node_name not in removed_buffers: + # cannot return accumulator directly, so alias it + var = self._tmp_var() + self.body.writeline(f"{var} = {_ACCUMULATOR_ARG_NAME}") + self.store(accumulator_node_name, value=OpsValue(var)) + + @staticmethod + def ir_to_evt_python_code( + cuda_template_node_name: str, + epilogue_nodes: list[BaseSchedulerNode], + removed_buffers: OrderedSet[str], + ) -> tuple[list[str], list[str], dict[str, Any], str]: + codegen = CutlassEVTCodegen(cuda_template_node_name, removed_buffers) + handler = _AssignmentFormatter(codegen) + + with virtualized.V.set_ops_handler(handler): + for s_node in epilogue_nodes: + node = s_node.node + assert isinstance(node, ComputedBuffer) + with codegen.set_cur_node(node): + index_vars = CutlassEVTCodegen.get_index_vars(node) + node.get_store_function()(index_vars) + + codegen.finalize() + + return ( + codegen.get_reads(), + codegen.get_writes(), + codegen.get_renames(), + codegen.get_value(), + ) + + def get_value(self) -> str: + return linesep.join( + [ + self._render_input_signature(), + self.body.getvalue(), + self._render_return_statement(), + ] + ) + + def finalize(self) -> None: + # Rename the last store to D + # no other code references this store + # to workaround https://github.com/NVIDIA/cutlass/issues/2288 + # Note: the delayed line will automatically rewrite the last assignment to + # be to D + buffer_name = self.var_name_to_buffer_name[self.last_stored_var_name] + self.var_name_to_buffer_name.pop(self.last_stored_var_name) + self.var_name_to_buffer_name["D"] = buffer_name + self.store_name_to_value[buffer_name] = OpsValue("D") + + @contextmanager + def set_cur_node(self, node: ComputedBuffer) -> Generator[None, Any, Any]: + prev_node = self.cur_node + try: + self.cur_node = node + yield + finally: + self.cur_node = prev_node + + def get_renames(self) -> dict[str, str]: + return 
dict(self.var_name_to_buffer_name) + + def get_reads(self) -> list[str]: + return list(self.reads.difference(self.store_name_to_value.keys())) + + def get_writes(self) -> list[str]: + return list(self.store_name_to_value.keys()) + + def load(self, name: str, index: Any) -> str: + self._check_indexing(name, index) + if name in self.store_name_to_value: + return self.store_name_to_value[name].value + elif name == self.accumulator_node_name: + return _ACCUMULATOR_ARG_NAME + else: + self.reads.add(name) + self.var_name_to_buffer_name[name] = name + return name + + def store( + self, name: Any, index: Any = None, value: Any = None, mode: Any = None + ) -> None: + if name not in self.removed_buffers: + if index: + self._check_indexing(name, index) + assert value.value != _ACCUMULATOR_ARG_NAME, ( + "Cannot store accumulator arg name" + ) + self.var_name_to_buffer_name[value.value] = name + self.store_name_to_value[name] = value + self.last_stored_var_name = value.value + return None + + def _get_cur_node(self) -> ComputedBuffer: + assert self.cur_node + return self.cur_node + + @staticmethod + def get_index_vars(node: ComputedBuffer) -> Sequence[sympy.Expr]: + data = node.data + # TODO mlazos: relax this, cutlass supports reductions and other ops + assert isinstance(data, Pointwise) + return data._index(data.ranges) + + def _get_current_index_vars(self) -> Sequence[sympy.Expr]: + return self.get_index_vars(self._get_cur_node()) + + def _check_indexing(self, name: str, index: sympy.Expr) -> None: + # We only support indexing that matches the layout today because + # CUTLASS doesn't support arbitrary indexing + buffer_name = ( + self.accumulator_node_name if name == _ACCUMULATOR_ARG_NAME else name + ) + buffer = self.name_to_buffer[buffer_name] + index_strides = V.graph.sizevars.stride_vars( + index, self._get_current_index_vars() + ) + stride = buffer.get_layout().stride + if not self._stride_compatible(stride, index_strides): + raise NotImplementedError( + f"Unsupported indexing for {name} with index {index}, index strides {index_strides}, and layout stride {stride}" + ) + + def _stride_compatible( + self, left: Iterable[sympy.Expr], right: Iterable[sympy.Expr] + ) -> bool: + return all( + sympy.Eq(l, r) or sympy.Eq(l, 0) or sympy.Eq(r, 0) + for l, r in (zip(left, right)) + ) + + def _render_input_signature(self) -> str: + arguments = ", ".join( + [_ACCUMULATOR_ARG_NAME] + + [name for name in self.reads if name != self.accumulator_node_name] + ) + return f"def fn({arguments}):" + + def _render_return_statement(self) -> str: + return_vars = OrderedSet( + op_v.value for op_v in self.store_name_to_value.values() + ) + assert "D" in return_vars + return f"return {', '.join(return_vars)}" + + def _tmp_var(self) -> str: + return f"tmp_{next(self.var_counter)}" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5a601125577a243dbcff5fab034133e4895f439a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/cutlass_utils.py @@ -0,0 +1,497 @@ +# mypy: allow-untyped-defs +import atexit +import functools +import logging +import os +import shutil +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional + +import sympy + +import torch +from torch._inductor.utils import clear_on_fresh_cache + +from ... 
import config +from ...ir import Layout +from ...runtime.runtime_utils import cache_dir +from ...virtualized import V +from ..cpp_utils import DTYPE_TO_CPP +from .cuda_env import get_cuda_arch, get_cuda_version + + +log = logging.getLogger(__name__) + +CUTLASS_OPERATION_KIND: str = "gemm" + + +@atexit.register +def move_cutlass_compiled_cache() -> None: + """Move CUTLASS compiled cache file to the cache directory if it exists.""" + if "cutlass" not in sys.modules: + return + + import cutlass # type: ignore[import-not-found] + + if not os.path.exists(cutlass.CACHE_FILE): + return + + try: + filename = os.path.basename(cutlass.CACHE_FILE) + shutil.move(cutlass.CACHE_FILE, os.path.join(cache_dir(), filename)) + log.debug("Moved CUTLASS compiled cache file to %s", cache_dir()) + except OSError as e: + log.warning("Failed to move CUTLASS compiled cache file: %s", str(e)) + + +def _rename_cutlass_import(content: str, cutlass_modules: list[str]) -> str: + for cutlass_module in cutlass_modules: + content = content.replace( + f"from {cutlass_module} import ", + f"from cutlass_library.{cutlass_module} import ", + ) + return content + + +@functools.cache +def try_import_cutlass() -> bool: + """ + We want to support three ways of passing in CUTLASS: + 1. fbcode, handled by the internal build system. + 2. pip install nvidia-cutlass, which provides the cutlass_library package + and the header files in the cutlass_library/source directory. + 3. User specifies cutlass_dir. The default is ../third_party/cutlass/, + which is the directory when developers build from source. + """ + if config.is_fbcode(): + try: + import cutlass # type: ignore[import-not-found] + import cutlass_library # type: ignore[import-not-found] + except ImportError as e: + log.warning( + "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", + str(e), + ) + return False + + return True + + try: + import cutlass # type: ignore[import-not-found] # noqa: F811 + import cutlass_library # type: ignore[import-not-found] # noqa: F811 + + cutlass_minor_vesion = int(cutlass.__version__.split(".")[1]) + if cutlass_minor_vesion < 7: + log.warning("CUTLASS version < 3.7 is not recommended.") + + log.debug( + "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir" + ) + cutlass_library_dir = os.path.dirname(cutlass_library.__file__) + assert os.path.isdir(cutlass_library_dir), ( + f"{cutlass_library_dir} is not a directory" + ) + config.cuda.cutlass_dir = os.path.abspath( + os.path.join( + cutlass_library_dir, + "source", + ) + ) + + return True + except ModuleNotFoundError: + log.debug( + "cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir" + ) + + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. + # This is a temporary hack to avoid CUTLASS module naming conflicts. + # TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues. 
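+    # Rough sketch of the layout this hack creates under cache_dir() (directory
+    # names as used below; the temp dir is appended to sys.path so that
+    # `import cutlass`, `import cutlass_library` and `import pycute` resolve):
+    #
+    #   <cache_dir>/torch_cutlass/
+    #       cutlass_library -> <cutlass_dir>/python/cutlass_library
+    #       cutlass         -> <cutlass_dir>/python/cutlass
+    #       pycute          -> <cutlass_dir>/python/pycute
+    #       cuda, scipy, pydot -> mock packages bundled under
+    #                             torch/_inductor/codegen/cuda/cutlass_lib_extensions/cutlass_mock_imports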
+ + # TODO(mlazos): epilogue visitor tree currently lives in python/cutlass, + # but will be moved to python/cutlass_library in the future (later 2025) + def path_join(path0, path1): + return os.path.abspath(os.path.join(path0, path1)) + + # contains both cutlass and cutlass_library + # we need cutlass for eVT + cutlass_python_path = path_join(config.cuda.cutlass_dir, "python") + torch_root = os.path.abspath(os.path.dirname(torch.__file__)) + mock_src_path = os.path.join( + torch_root, + "_inductor", + "codegen", + "cuda", + "cutlass_lib_extensions", + "cutlass_mock_imports", + ) + + cutlass_library_src_path = path_join(cutlass_python_path, "cutlass_library") + cutlass_src_path = path_join(cutlass_python_path, "cutlass") + pycute_src_path = path_join(cutlass_python_path, "pycute") + + tmp_cutlass_full_path = os.path.abspath(os.path.join(cache_dir(), "torch_cutlass")) + + dst_link_library = path_join(tmp_cutlass_full_path, "cutlass_library") + dst_link_cutlass = path_join(tmp_cutlass_full_path, "cutlass") + dst_link_pycute = path_join(tmp_cutlass_full_path, "pycute") + + # mock modules to import cutlass + mock_modules = ["cuda", "scipy", "pydot"] + + if os.path.isdir(cutlass_python_path): + if tmp_cutlass_full_path not in sys.path: + + def link_and_append(dst_link, src_path, parent_dir): + if os.path.exists(dst_link): + assert os.path.islink(dst_link), ( + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." + ) + assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( + src_path, + ), f"Symlink at {dst_link} does not point to {src_path}" + else: + os.makedirs(parent_dir, exist_ok=True) + os.symlink(src_path, dst_link) + + if parent_dir not in sys.path: + sys.path.append(parent_dir) + + link_and_append( + dst_link_library, cutlass_library_src_path, tmp_cutlass_full_path + ) + link_and_append(dst_link_cutlass, cutlass_src_path, tmp_cutlass_full_path) + link_and_append(dst_link_pycute, pycute_src_path, tmp_cutlass_full_path) + + for module in mock_modules: + link_and_append( + path_join(tmp_cutlass_full_path, module), # dst_link + path_join(mock_src_path, module), # src_path + tmp_cutlass_full_path, # parent + ) + + try: + import cutlass # noqa: F401 + import cutlass_library.generator # noqa: F401 + import cutlass_library.library # noqa: F401 + import cutlass_library.manifest # noqa: F401 + import pycute # type: ignore[import-not-found] # noqa: F401 + + return True + except ImportError as e: + log.debug( + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", + str(e), + ) + else: + log.debug( + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", + cutlass_python_path, + ) + return False + + +@functools.lru_cache(8) +def _normalize_cuda_arch(arch: str) -> str: + if int(arch) >= 100: + log.warning( + "Detected CUDA architecture >= 100: %s. We will generate operations with " + "GenerateSM100 (if available) and GenerateSM90. Please file an " + "issue for any problems and feedback. ", + arch, + ) + + if int(arch) >= 100: + return "100" + elif int(arch) >= 90: + return "90" + elif int(arch) >= 80: + return "80" + elif int(arch) >= 75: + return "75" + elif int(arch) >= 70: + return "70" + else: + raise NotImplementedError(f"Unsupported cuda arch: {arch}") + + +@dataclass +class CUTLASSArgs: + """ + CUTLASS args used to initialize a CUTLASS Manifest. 
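+
+    The fields stand in for the command-line options that cutlass_library's
+    generator scripts normally parse; instances are consumed by
+    cutlass_manifest.Manifest and the GenerateSM* calls in _gen_ops_cached below.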
+ """ + + architectures: Optional[str] = None + cuda_version: Optional[str] = None + instantiation_level: Optional[str] = None + operations: Optional[str] = None + + build_dir = "" + curr_build_dir = "" + generator_target = "" + kernels = "all" + ignore_kernels = "" + exclude_kernels = "" + # TODO: these three look dead? + kernel_filter_file: None = None + selected_kernel_list: None = None + interface_dir: None = None + filter_by_cc = True + disable_full_archs_compilation = False + + def __post_init__(self): + if self.architectures is None or self.cuda_version is None: + raise RuntimeError( + f"{self.architectures=} or {self.cuda_version=} is None!" + ) + self.architectures = _normalize_cuda_arch(self.architectures) + + +@clear_on_fresh_cache +@functools.cache +def _gen_ops_cached(arch, version) -> dict[Any, Any]: + # Note: Cache needs to be specific for cuda architecture and version + + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library.generator as cutlass_generator + import cutlass_library.manifest as cutlass_manifest + + if arch is None or version is None: + log.error( + "Cannot detect cuda arch %s or cuda version %s. " + "Will discard all cutlass ops. " + "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", + arch, + version, + ) + return {} + arch = _normalize_cuda_arch(arch) + instantiation_level: str = config.cuda.cutlass_instantiation_level + args = CUTLASSArgs( + architectures=arch, + cuda_version=version, + instantiation_level=instantiation_level, + operations=CUTLASS_OPERATION_KIND, + ) + manifest = cutlass_manifest.Manifest(args) + + start_time = time.time() + if arch == "100": + if hasattr(cutlass_generator, "GenerateSM100"): + cutlass_generator.GenerateSM100(manifest, args.cuda_version) + cutlass_generator.GenerateSM90(manifest, args.cuda_version) + else: + try: + func = getattr(cutlass_generator, "GenerateSM" + arch) + func(manifest, args.cuda_version) + except AttributeError as e: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) from e + + log.info( + "CUTLASS library generated a dict of %d operation kinds in %.2f seconds", + len(manifest.operations), + time.time() - start_time, + ) + return manifest.operations + + +def gen_ops() -> dict[Any, Any]: + """ + Generates all supported CUTLASS operations. + """ + arch = get_cuda_arch() + version = get_cuda_version() + return _gen_ops_cached(arch, version) + + +DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", +} + + +@functools.lru_cache(32) +def torch_dtype_to_cutlass_type( + torch_dtype: torch.dtype, +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 + # Import cutlass python scripts. + assert try_import_cutlass() + import cutlass_library # type: ignore[import] + + if torch_dtype == torch.float: + return cutlass_library.library.DataType.f32 + elif torch_dtype == torch.half: + return cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_library.library.DataType.bf16 + else: + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") + + +@functools.lru_cache(32) +def dtype_match( + torch_dtype: Optional[torch.dtype], + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 +) -> bool: + # Import cutlass python scripts. 
+ assert try_import_cutlass() + import cutlass_library + + if torch_dtype == torch.float: + return ( + cutlass_dtype == cutlass_library.library.DataType.f32 + or cutlass_dtype == cutlass_library.library.DataType.tf32 + ) + elif torch_dtype == torch.half: + return cutlass_dtype == cutlass_library.library.DataType.f16 + elif torch_dtype == torch.bfloat16: + return cutlass_dtype == cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.int8: + return cutlass_dtype == cutlass_library.library.DataType.s8 + elif torch_dtype == torch.uint8: + return cutlass_dtype == cutlass_library.library.DataType.u8 + elif torch_dtype == torch.int32: + return cutlass_dtype == cutlass_library.library.DataType.s32 + elif torch_dtype == torch.float8_e4m3fn: + return cutlass_dtype == cutlass_library.library.DataType.e4m3 + else: + return False + + +def get_accumulator_dtype( + input_torch_dtypes: list[torch.dtype], +) -> Optional[torch.dtype]: + """ + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. + """ + + if len(input_torch_dtypes) != 2: + return None + + torch_dtype = None + if input_torch_dtypes[0] == input_torch_dtypes[1]: + torch_dtype = input_torch_dtypes[0] + else: + size0 = torch.tensor([], dtype=input_torch_dtypes[0]).element_size() + size1 = torch.tensor([], dtype=input_torch_dtypes[1]).element_size() + if size0 > size1: + dtype0, dtype1 = input_torch_dtypes + else: + dtype1, dtype0 = input_torch_dtypes + if dtype0 in [torch.half, torch.bfloat16] and dtype1 in [ + torch.int8, + torch.uint8, + ]: + torch_dtype = dtype0 + + if torch_dtype in (torch.float16, torch.bfloat16, torch.float, torch.float8_e4m3fn): + return torch.float + if torch_dtype == torch.int8: + return torch.int32 + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes=}") + + +@functools.lru_cache(32) +def get_alignments(torch_dtype: torch.dtype) -> list[int]: + """ + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. + CUTLASS gemm / conv SM80 APIs support 16 bytes max alignment, and 2 bytes min alignment. + """ + + if torch_dtype in (torch.half, torch.bfloat16): + return [8, 4, 2, 1] + elif torch_dtype == torch.float: + return [4, 2, 1] + elif torch_dtype in (torch.uint8, torch.int8, torch.float8_e4m3fn): + return [16, 8, 4, 2] + elif torch_dtype == torch.int32: + return [4, 2, 1] + else: + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") + + +def get_max_alignment(inductor_layout: Layout) -> int: + """ + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. 
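+
+    Example (hypothetical layout): a contiguous row-major fp16 tensor whose
+    innermost dimension and storage offset are both divisible by 8 yields
+    alignment 8, i.e. the 16-byte maximum that the SM80 GEMM / conv APIs accept.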
+ """ + + dtype = inductor_layout.dtype + size = inductor_layout.size + offset = inductor_layout.offset + + def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + def a_factor_of(x, alignment): + if is_static_int(x) and is_static_int(alignment): + return x % alignment == 0 + rem = sympy.Mod(x, alignment) + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) + + try: + contiguous_dim = inductor_layout.stride.index(1) + except ValueError: + # No dim with stride 1 found, return 1 + return 1 + alignments = get_alignments(dtype) + for alignment in alignments: + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( + offset, alignment + ): + continue + if all( + (dim == contiguous_dim) + or a_factor_of(inductor_layout.stride[dim], alignment) + for dim in range(len(size)) + ): + return alignment + return 1 + + +class CUDACompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self): + self.sources = [] + self._compile_patch = None + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile + + def my_compile(source_code, dst_file_ext): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + self._compile_patch = mock.patch( + "torch._inductor.codecache.CUDACodeCache.compile", my_compile + ) + self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + return self + + def __exit__(self, *args, **kwargs): + self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. 
+ from torch._inductor.codecache import cuda_compile_command + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] + compile_command = cuda_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..4f862fe8d9996256157db437115a2f4a10e10438 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -0,0 +1,366 @@ +from __future__ import annotations + +from typing import Optional + +import torch + +from ...utils import triton_version_uses_attrs_dict +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class CUDADeviceOpOverrides(DeviceOpOverrides): + """ + CUDA-specific codegen functions, see DeviceOpOverrides for details + """ + + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _cuda_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.cuda.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.cuda.synchronize()" + + def device_guard(self, device_idx: int) -> str: + return f"torch.cuda._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::cuda::CUDAGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTICudaGuard" + + def cpp_stream_guard(self) -> str: + return "at::cuda::CUDAStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTICudaStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::cuda::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + #include + #include + """ + return source_codes + + def kernel_driver(self) -> str: + source_codes = """ + #define CUDA_DRIVER_CHECK(EXPR) \\ + do { \\ + CUresult code = EXPR; \\ + const char *msg; \\ + CUresult code_get_error = cuGetErrorString(code, &msg); \\ + if (code_get_error != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ + if (code != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string(msg)); \\ + } \\ + } while (0); + + static inline CUfunction loadKernel( + std::string filePath, + const std::string &funcName, + uint32_t sharedMemBytes, + const std::optional &cubinDir = std::nullopt) { + if (cubinDir) { + std::filesystem::path p1{*cubinDir}; + std::filesystem::path p2{filePath}; + filePath = (p1 / p2.filename()).string(); + } + + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str())); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes + )) + } + return func; + } + + static inline CUfunction loadKernel(const void* start, const std::string &funcName, uint32_t sharedMemBytes) { + CUmodule mod; + CUfunction func; + CUDA_DRIVER_CHECK(cuModuleLoadData(&mod, start)); + CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str())); + if (sharedMemBytes > 0) { + CUDA_DRIVER_CHECK(cuFuncSetAttribute( + func, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + sharedMemBytes 
+ )) + } + return func; + } + + static inline void launchKernel( + CUfunction func, + uint32_t gridX, + uint32_t gridY, + uint32_t gridZ, + uint32_t numWarps, + uint32_t sharedMemBytes, + void* args[], + cudaStream_t stream) { + CUDA_DRIVER_CHECK(cuLaunchKernel( + func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr + )); + } + """ + if torch.version.hip is not None: + # Adjusting the warp size to GPU supported wavefront size on AMD GPU + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + source_codes = source_codes.replace( + "32*numWarps", str(prop.warp_size) + "*numWarps" + ) + return source_codes + + def tma_descriptor_helpers(self) -> str: + """ + CUDA helper functions for initializing TMA Descriptors on host side + """ + if torch.version.hip is not None: + raise RuntimeError("Host-side TMA descriptors not supported on HIP.") + + # helper functions for initializing 1D and 2D TMA descriptors in C++. borrowed from the Triton code here: + # Old APIs (fill(1|2)DTMADescriptor): + # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283 + # New APIs (fillTMADescriptor): + # https://github.com/triton-lang/triton/blob/main/third_party/nvidia/backend/driver.c#L283 + return """ + #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 + [[maybe_unused]] static void init1DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim, + uint32_t blockDim, + uint32_t elementSize) { + uint64_t dims[1] = {dim}; + uint64_t globalStrides[1] = {dim * elementSize}; + uint32_t tensorDims[1] = {blockDim}; + uint32_t elementStrides[1] = {1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + if (elementSize * blockDim < 32) { + throw std::runtime_error("block size too small"); + } + + int rank = 1; + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void init2DTMADescriptor( + CUtensorMap* m, + void* globalAddress, + uint64_t dim1, + uint64_t dim0, + uint32_t blockDim1, + uint32_t blockDim0, + uint32_t elementSize) { + uint64_t dims[2] = {dim0, dim1}; + uint32_t tensorDims[2] = {blockDim0, blockDim1}; + uint64_t globalStrides[2] = {dims[0] * elementSize, + dims[0] * dims[1] * elementSize}; + uint32_t elementStrides[2] = {1, 1}; + + CUtensorMapDataType type; + switch (elementSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elementSize must be 1, 2, or 4"); + } + + int rank = 2; + + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elementSize * tensorDims[0]; + if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); 
+ } + + if (contigDimSizeInByte > 128) { + tensorDims[0] = 128 / elementSize; + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, dims, + globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + [[maybe_unused]] static void initTMADescriptor( + CUtensorMap* m, + void* globalAddress, + int elemSize, + int rank, + uint32_t* blockSize, + uint64_t* shape, + uint64_t* stride + ) { + uint32_t elementStrides[5] = {1, 1, 1, 1, 1}; + uint32_t blockSizeInt[5]; + uint64_t shapeInt[5]; + uint64_t stridesLL[5]; + + // Reorder blockSize (reverse the order) + for (int i = 0; i < rank; ++i) { + blockSizeInt[rank - i - 1] = blockSize[i]; + } + + // Reorder shape (reverse the order) + for (int i = 0; i < rank; ++i) { + shapeInt[rank - i - 1] = shape[i]; + } + + // Reorder and calculate strides + for (int i = 0; i + 1 < rank; ++i) { + stridesLL[rank - i - 2] = elemSize * stride[i]; + } + stridesLL[rank - 1] = + shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]); + + CUtensorMapDataType type; + // In Triton this is computed ahead of time; but for simplicity + // in the PyTorch version we copied this code from the old + // TMA API handling (i.e. init2DTMADescriptor) + switch (elemSize) { + case 1: + type = CU_TENSOR_MAP_DATA_TYPE_UINT8; + break; + case 2: + type = CU_TENSOR_MAP_DATA_TYPE_UINT16; + break; + case 4: + type = CU_TENSOR_MAP_DATA_TYPE_UINT32; + break; + default: + throw std::runtime_error("elemSize must be 1, 2, or 4"); + } + + // Calculate the size of the most contiguous dimension in bytes + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + uint32_t contigDimSizeInByte = elemSize * blockSizeInt[0]; + if (rank == 1) { + // rank 1 should not be swizzled + swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + } else if (contigDimSizeInByte >= 128) { + swizzle = CU_TENSOR_MAP_SWIZZLE_128B; + } else if (contigDimSizeInByte >= 64) { + swizzle = CU_TENSOR_MAP_SWIZZLE_64B; + } else if (contigDimSizeInByte >= 32) { + swizzle = CU_TENSOR_MAP_SWIZZLE_32B; + } else { + throw std::runtime_error("block size too small"); + } + + CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled( + m, type, rank, globalAddress, + shapeInt, stridesLL, blockSizeInt, elementStrides, + CU_TENSOR_MAP_INTERLEAVE_NONE, (CUtensorMapSwizzle)swizzle, + CU_TENSOR_MAP_L2_PROMOTION_L2_128B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + } + + struct StableTMADescriptor { + CUtensorMap m; + uint32_t block_shape[5]; + uint64_t global_shape[5]; + uint64_t strides[5]; + }; + #endif + """ + + def cpp_stream_type(self) -> str: + return "cudaStream_t" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_cuda_stream" + + def cpp_kernel_type(self) -> str: + return "CUfunction" + + def cpp_device_ptr(self) -> str: + return "CUdeviceptr" + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + if triton_version_uses_attrs_dict(): + var_name = f"global_scratch_{idx}" + if workspace.size > 0: + size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + stride_array = f"int64_t {var_name}_stride[] = {{1}};" + device_type = "cached_torch_device_type_cuda" + device_idx = "device_idx_" + + return ( + [ + f"{size_array}", + f"{stride_array}", + f"AtenTensorHandle {var_name}_handle;", + ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(1, {var_name}_size, {var_name}_stride, " + f"{workspace.generate_dtype_str()}, {device_type}, 
{device_idx}, &{var_name}_handle));" + ), + f"RAIIAtenTensorHandle {var_name}_tensor({var_name}_handle);", + f"CUdeviceptr {var_name} = reinterpret_cast({var_name}_tensor.data_ptr());", + ], + var_name, + ) + else: + return [f"CUdeviceptr {var_name} = 0;"], var_name + return None + + +register_device_op_overrides("cuda", CUDADeviceOpOverrides()) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..b7dbb0ae5154342e8a8a30699d79ed36d090505a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/gemm_template.py @@ -0,0 +1,1905 @@ +# mypy: allow-untyped-defs +import copy +import enum +import functools +import logging +import re +import time +from abc import ABC, abstractmethod +from typing import Any, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops +from torch._inductor.scheduler import BaseSchedulerNode +from torch._inductor.select_algorithm import create_inputs_key +from torch._inductor.utils import clear_on_fresh_cache + +from ... import ir +from ...config import cuda as inductor_cuda_config +from ...ir import ( + Buffer, + ChoiceCaller, + CUDATemplateBuffer, + FixedLayout, + IRNode, + Layout, + ReinterpretView, +) +from ...utils import is_dynamic, Placeholder +from ...virtualized import V +from ..common import IndentedBuffer +from . import cutlass_utils +from .cuda_kernel import CUDATemplateKernel +from .cuda_template import CUTLASSTemplate +from .cutlass_presets import gen_cutlass_presets +from .cutlass_python_evt import CutlassEVTCodegen, scaled_mm_evt +from .cutlass_utils import torch_dtype_to_cutlass_type + + +GemmOperation = Any + +log = logging.getLogger(__name__) + +# Jinja template for GEMM Kernel, used by the CUTLASSGemm3xTemplate class below. +GEMM_TEMPLATE_CUTLASS_3X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{epilogue_visitor_tree}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. +extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. 
+ gemm_op.maximum_active_blocks(); + } +#endif +#endif + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} + +// configuration name: {{op_conf_name}} +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments, used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X = r""" + // Initialize GemmUniversal3xInstance arguments. + arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast({{M}}), + static_cast({{N}}), + static_cast(K), + static_cast(B) + }, // ProblemShape problem_shape + { + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // ElementA const* ptr_A + { + {{template.cute_int(kernel.stride(X, -2), "stride_x0")}}, + {{template.cute_int(kernel.stride(X, -1), "stride_x1")}}, + {{template.cute_int(kernel.batch_stride(X), "batch_stride_x")}} + }, // StrideA dA + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // ElementB const* ptr_B + { + {{template.cute_int(kernel.stride(W, -1), "stride_w1")}}, + {{template.cute_int(kernel.stride(W, -2), "stride_w0")}}, + {{template.cute_int(kernel.batch_stride(W), "batch_stride_w")}} + }, // StrideB dB + }, // MainloopArguments mainloop + {{epilogue_arguments}}, + hw_info + }; + arguments.scheduler.max_swizzle_size = swizzle; +""" + +# Jinja template for Cutlass 3.x GEMM Kernel arguments if epilogue fusion is applied, +# used by the CUTLASSGemmTemplate class below. +GEMM_ARGS_CUTLASS_3X_EPILOGUE = r""" + // see https://tinyurl.com/4rk89z48 + { + {{epilogue_args}}, // thread, typename FusionCallbacks::Arguments ( EVT ) or ThreadEpilogueOp::Params (non-EVT ) + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // ElementC const* ptr_C + { + {{template.cute_int(kernel.stride(Bias, -2, 1), "stride_bias0")}}, + {{template.cute_int(kernel.stride(Bias, -1, 1), "stride_bias1")}}, + {{template.cute_int(kernel.batch_stride(Bias), "batch_stride_bias")}} + }, // StrideC dC + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // ElementD const* ptr_D + { + {{template.cute_int(kernel.stride(Y, -2), "stride_y0")}}, + {{template.cute_int(kernel.stride(Y, -1), "stride_y1")}}, + {{template.cute_int(kernel.batch_stride(Y), "batch_stride_y")}} + }, // StrideD dD + }, // EpilogueArguments epilogue +""" + +# Jinja template for GEMM Kernel, used by the CUTLASS2xGemmTemplate class below. +GEMM_TEMPLATE_CUTLASS_2X = r""" +{{template.header().getvalue()}} +{{template.globals().getvalue()}} +{{instance_definition}} +// When workspace_size is not a nullptr, populates requested workspace_size and returns. +// Otherwise, computes the Gemm kernel using the given workspace ptr. 
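+// Expected call sequence from the host side (sketch; the exact signature comes from
+// kernel_call_signature above, argument names here are illustrative):
+//   size_t ws = 0;
+//   kernel(..., /*workspace_size=*/&ws, /*workspace=*/nullptr, stream);  // query only
+//   ... allocate ws bytes of device memory ...
+//   kernel(..., /*workspace_size=*/nullptr, /*workspace=*/buf, stream);  // run GEMM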
+extern "C" { +PT_EXPORT {{kernel_call_signature}} { + try { + int B = {{kernel.size(Y, 0, -3, default_value=1)}}; + using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator; + using coord_t = cutlass::gemm::GemmCoord::Index; + static cutlass::KernelHardwareInfo hw_info; + if (hw_info.sm_count == 0) { + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + CUTLASS_TRACE_HOST("Query result for SM count per device: " << hw_info.sm_count); + } + {{instance_type}}::Arguments arguments; + {{template.render_gemm_arguments(instance_type, argument_template, epilogue_template, should_swap_xw, + X, W, Bias, Meta, Y, alpha, beta, kernel, epilogue_args)}} + {{instance_type}} gemm_op; + if (workspace_size) { + *workspace_size = gemm_op.get_workspace_size(arguments); + return 0; + } + + // check for null pointers after workspace size, since querying workspace size doesn't require valid data pointers +#ifndef CUTLASS_BACKEND_DISABLE_CHECKS + { + auto status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + } +#endif +#ifdef CUTLASS_DEBUG_TRACE_LEVEL +#if CUTLASS_DEBUG_TRACE_LEVEL == 1 + { + // Print the maximum number of active blocks per SM for the kernel if CUTLASS_DEBUG_TRACE_LEVEL == 1 + // we don't need a print statement, it's happening inside the function. + gemm_op.maximum_active_blocks(); + } +#endif +#endif + + { + auto status = gemm_op.initialize(arguments, workspace, stream); + CUTLASS_CHECK(status); + } + { + auto status = gemm_op(stream); + CUTLASS_CHECK(status); + } + } + catch (std::exception& e) { + std::cerr << "Runtime error: " << e.what() << std::endl; + return -1; + } + catch (...) { + return -1; + } + return 0; +} +} +""" + +# Jinja template for Cutlass 2.x GEMM Kernel arguments, used by the CUTLASS2xGemmTemplate class below. +GEMM_ARGS_CUTLASS_2X = r""" + int64_t batch_stride_x = {{kernel.stride(X, -3)}}; + int64_t row_stride_x = {{kernel.row_or_column_stride(X)}}; + int64_t batch_stride_w = {{kernel.stride(W, -3)}}; + int64_t row_stride_w = {{kernel.row_or_column_stride(W)}}; + int64_t batch_stride_bias = {{kernel.stride(Bias, -3)}}; + int64_t row_stride_bias = {{kernel.row_or_column_stride(Bias)}}; + int64_t batch_stride_y = {{kernel.stride(Y, -3)}}; + int64_t row_stride_y = {{kernel.row_or_column_stride(Y)}}; + // Initialize GemmUniversalInstance arguments. 
+ arguments = { + {{template.gemm_mode()}}, // GemmUniversalMode mode + { + static_cast(M), + static_cast(N), + static_cast(K) + }, // GemmCoord problem_size + {{split_k if split_k > 1 else 'B'}}, // int batch_count + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue + {{template.cutlass_type_cast(X, kernel.ptr(X))}}, // void const * ptr_A + {{template.cutlass_type_cast(W, kernel.ptr(W))}}, // void const * ptr_B + {{template.cutlass_type_cast(Bias, kernel.ptr(Bias))}}, // void const * ptr_C + {{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, // void * ptr_D + batch_stride_x, // int64_t batch_stride_A + batch_stride_w, // int64_t batch_stride_B + batch_stride_bias, // int64_t batch_stride_C + batch_stride_y, // int64_t batch_stride_D + row_stride_x, // typename LayoutA::Stride::LongIndex lda + row_stride_w, // typename LayoutB::Stride::LongIndex ldb + row_stride_bias, // typename LayoutC::Stride::LongIndex ldc + row_stride_y, // typename LayoutC::Stride::LongIndex ldd + }; +""" + +GEMM_ARGS_SPARSE_CUTLASS_2X = r""" + using TensorRefA = cutlass::TensorRef<{{instance_type}}::ElementA, + {{instance_type}}::LayoutA>; + using TensorRefB = cutlass::TensorRef<{{instance_type}}::ElementB, + {{instance_type}}::LayoutB>; + using TensorRefC = cutlass::TensorRef<{{instance_type}}::ElementC, + {{instance_type}}::LayoutC>; + using TensorRefE = cutlass::TensorRef<{{instance_type}}::ElementE, + {{instance_type}}::LayoutE>; + // Note that "X" and "W" names may be misleading here. Namely, for + // sparse GEMM, the first argument is always sparse, while typically + // weight matrix, implied by name "W" will be sparse in + // applications. Thus, just remember that here: "X" refers to first + // argument, that is sparse, and "W" to second, that is dense. + TensorRefA X_ref({{template.cutlass_type_cast(X, kernel.ptr(X))}}, {{kernel.row_or_column_stride(X)}}); + TensorRefB W_ref({{template.cutlass_type_cast(W, kernel.ptr(W))}}, {{kernel.row_or_column_stride(W)}}); + TensorRefC Y_ref({{template.cutlass_type_cast(Y, kernel.ptr(Y))}}, {{kernel.row_or_column_stride(Y)}}); + TensorRefE Meta_ref({{template.cutlass_sparse_meta_type_cast(Meta, kernel.ptr(Meta))}}, + TensorRefE::Layout::packed({ {{kernel.size(Meta, 0)}}, {{kernel.size(Meta, 1)}} })); + // Initialize GemmSparse arguments. + arguments = { + { + static_cast(M), + static_cast(N), + static_cast(2 * K), + }, // GemmCoord problem_size + X_ref, // TensorRef ref_A + W_ref, // TensorRef ref_B + Y_ref, // TensorRef ref_C + Y_ref, // TensorRef ref_D + Meta_ref, // TensorRef ref_E + {ElementComputeEpilogue({{alpha}}), ElementComputeEpilogue({{beta}})}, // typename EpilogueOutputOp::Params epilogue, + }; +""" + +# Additional includes which are necessary if the standalone test / debug runner is generated as well +GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES = r""" +#ifdef GENERATE_STANDALONE_RUNNER +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include +#endif +""" + +# Jinja template for the standalone runner that may be generated as part of the code. 
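+# The runner below is only emitted when config.cuda.generate_test_runner is enabled and all shapes are
+# static (see render() further below). It stays compiled out unless the translation unit is built with
+# the GENERATE_STANDALONE_RUNNER preprocessor define, in which case it exposes
+# extern "C" int run_standalone(uint64_t seed, int repetitions), which fills the operands with random
+# data, queries and allocates the workspace, and then runs the kernel.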
+GEMM_STANDALONE_RUNNER_TEMPLATE = r""" +#ifdef GENERATE_STANDALONE_RUNNER +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed, float max=1.0, float min=-1.0) { + if (block.size()<=0) return false; + Element scope_max(static_cast(max)), scope_min(static_cast(min)); + cutlass::reference::device::BlockFillRandomUniform( + (Element*)block.get(), block.size(), seed, scope_max, scope_min); + + return true; +} + +{% if Meta is defined and Meta is not none %} +template +bool initialize_block_meta( + cutlass::DeviceAllocation& block, + uint64_t seed) { + if (block.size()<=0) return false; + cutlass::reference::device::BlockFillRandomSparseMeta( + (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits); + return true; +} +{% endif %} + +extern "C" int run_standalone(uint64_t seed, int repetitions) { + std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl; + size_t workspace_size = 0; + size_t* workspace_size_ptr = &workspace_size; + + int M = {{kernel.get_layout_args()[0]}}; + int N = {{kernel.get_layout_args()[1]}}; + int K = {{kernel.get_layout_args()[2]}}; + int B = {{kernel.get_layout_args()[3]}}; + int lda = {{kernel.get_layout_args()[4]}}; + int ldb = {{kernel.get_layout_args()[5]}}; + int ldc = {{kernel.get_layout_args()[6]}}; + int ldd = {{kernel.get_layout_args()[7]}}; + uint8_t swizzle = {{kernel.runtime_arg_values[0]}}; + + using ElementA = {{kernel.cutlass_dtype(X)}}; + using ElementB = {{kernel.cutlass_dtype(W)}}; + using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void + using ElementD = {{kernel.cutlass_dtype(Y)}}; + {% if Meta is defined and Meta is not none %} + using ElementE = {{kernel.cutlass_dtype(Meta)}}; + {% endif %} + + cutlass::DeviceAllocation X_data({{kernel.max_valid_index(X)+1}}); + initialize_block(X_data, seed++); + cutlass::DeviceAllocation W_data({{kernel.max_valid_index(W)+1}}); + initialize_block(W_data, seed++); + cutlass::DeviceAllocation Bias_data({{kernel.max_valid_index(Bias)+1}}); + initialize_block(Bias_data, seed++); + cutlass::DeviceAllocation Y_data({{kernel.max_valid_index(Y)+1}}); + {% if Meta is defined and Meta is not none %} + cutlass::DeviceAllocation Meta_data({{kernel.max_valid_index(Meta)+1}}); + initialize_block_meta(Meta_data, seed++); + {% endif %} + + cutlass::DeviceAllocation workspace_data; + // Call once with workspace_size_ptr set to get workspace size + + std::cout << "Calling once to get workspace size" << std::endl; + {{test_call_statement}}; + // Allocate workspace if necessary + if (workspace_size > 0) { + workspace_data.reset(workspace_size); + std::cout << "Allocated workspace size of " << workspace_size << " bytes" << std::endl; + } + std::cout << "Calling Kernel as {{test_call_statement}};" << std::endl; + workspace_size_ptr = nullptr; + for (int i=0; i None: + """ + Args: + input_nodes (List[Buffer]): List of input nodes of the GEMM kernel. + layout (Layout): Layout type of the resulting output node. + alpha (float): The scaling factor for the product of the inputs in the GEMM operation. + beta (float): The scaling factor applied to the output matrix. + input_reorder (Optional[List[int]]): Specifies the reordering of the input nodes. If not provided, + no reordering is performed. Defaults to None. 
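+ use_fast_accum (Optional[bool]): When set, op filtering keeps only kernels whose configuration
+ name does (True) or does not (False) use fast accumulation ("fastaccum"); see filter_op().
+ Defaults to None (no restriction).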
+ """ + super().__init__( + str(Placeholder.KERNEL_NAME), input_nodes, layout, input_reorder + ) + self.alpha = alpha + self.beta = beta + self.use_fast_accum = use_fast_accum + assert 2 <= len(input_nodes) <= 5 + assert self._are_inputs_layout_compatible( + [node.get_layout() for node in input_nodes] + ) + + self.cache_key: str = create_inputs_key(self.input_nodes) + + @staticmethod + @abstractmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError + + @staticmethod + @abstractmethod + def _has_tma_epilogue(self) -> bool: + raise NotImplementedError + + @abstractmethod + def _get_template(self) -> str: + raise NotImplementedError + + @abstractmethod + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + raise NotImplementedError + + @abstractmethod + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + raise NotImplementedError + + @abstractmethod + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + raise NotImplementedError + + @abstractmethod + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + raise NotImplementedError + + @abstractmethod + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + raise NotImplementedError + + @abstractmethod + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + raise NotImplementedError + + def _add_cutlass_gemm_choices( + self, + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + **extra_kwargs, + ) -> None: + """ + Adds Cutlass GEMM configurations choices to the auto-tuning list. + + This function mutates the passed list of choices by appending the choices for Cutlass GEMM configs to it. + + Args: + choices (list): The list to which choices are appended. + layout (ir.Layout): The layout configuration. + input_nodes (list): The list of input nodes. + alpha (float,int): Scaling factor, defaults to 1. + beta (float,int): Offset, defaults to 0. + input_reorder (list, optional): Order of the inputs, defaults to None. + **extra_kwargs: Additional keyword arguments. 
+ + """ + + ops = self.gen_ops() + for name, op in ops: + for swizzle in inductor_cuda_config.cutlass_max_profiling_swizzle_options: + description = f"{name} swizzle={swizzle}" + self.maybe_append_choice( + choices, description=description, op=op, swizzle=swizzle + ) + + if len(ops) == 0: + input_layouts = [node.get_layout() for node in input_nodes] + input_strides = [node.get_stride() for node in input_nodes] + output_layout = layout + warning_msg = f"No suitable Cutlass GEMM configs found, fallbacks used ( {len(ops)=}, {output_layout=}, {input_layouts=}, {input_strides=} )" # noqa: B950 + log.warning(warning_msg) + log.debug( + "Added %d Cutlass gemm configs.", + len(ops), + ) + + def header(self) -> IndentedBuffer: + """ + Returns a buffer containing CUDA C++ code for the header section of the CUTLASS GEMM template. + This section primarily includes the necessary header files. + + Returns: + IndentedBuffer: An instance of IndentedBuffer that contains the generated CUDA C++ header code. + """ + res = super().header() + res.splice( + """ + #include "cutlass/gemm/gemm.h" + #include "cutlass/gemm/device/gemm_universal.h" + #include "cutlass/gemm/device/gemm_universal_adapter.h" + #include "cutlass/gemm/kernel/gemm_universal.hpp" + #include "cutlass/gemm/device/gemm_sparse.h" + #include "cutlass/gemm/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/collective_builder.hpp" + #include "cutlass/epilogue/collective/default_epilogue.hpp" + #include "cutlass/epilogue/thread/linear_combination.h" + #include "cutlass/epilogue/thread/activation.h" + #include "cutlass/gemm/dispatch_policy.hpp" + #include "cutlass/gemm/kernel/tile_scheduler.hpp" + #include "cutlass/tensor_ref.h" + #include "cutlass/util/distribution.h" + #include "cutlass/util/packed_stride.hpp" + #include "cutlass/util/tensor_view_io.h" + """ + ) + if inductor_cuda_config.generate_test_runner and not is_dynamic( + *self.input_nodes, self.output_node + ): + res.splice(GEMM_STANDALONE_RUNNER_ADDITIONAL_INCLUDES) + return res + + @staticmethod + def cutlass_layout(torch_layout: ir.Layout) -> "Optional[cutlass_lib.LayoutType]": # type: ignore[name-defined] # noqa: F821 + """ + Converts an ir.Layout instance into the corresponding cutlass_library.LayoutType enum value + (RowMajor, ColumnMajor, or None if no matching value is found ). + + Args: + torch_layout (ir.Layout): The layout that needs to be looked up. + + Returns: + cutlass_lib.LayoutType: The converted layout corresponding to the `torch_layout` or None if no matching + value is found. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return cutlass_lib.LayoutType.RowMajor + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-2], 1): + return cutlass_lib.LayoutType.ColumnMajor + else: + return None + + @staticmethod + def flip_cutlass_layout( + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_lib.LayoutType": # type: ignore[name-defined] # noqa: F821 + """Helper method: Flips a given cutlass layout (cutlass_lib.LayoutType) from RowMajor + to ColumnMajor or vice versa""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + if cutlass_layout == cutlass_lib.LayoutType.RowMajor: + return cutlass_lib.LayoutType.ColumnMajor + else: + return cutlass_lib.LayoutType.RowMajor + + @staticmethod + @functools.lru_cache(32) + def layout_match( + torch_layout: ir.Layout, + cutlass_layout: "cutlass_lib.LayoutType", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """Helper Method: Determines whether a given torch layout matches a given Cutlass layout""" + return CUTLASSGemmTemplate.cutlass_layout(torch_layout) == cutlass_layout + + @staticmethod + def set_alignment(torch_layout, op_element) -> bool: + """ + Helper method to update the alignment of a given CUTLASS GEMM op operand's element. + + This method modifies the alignment of the given Cutlass GEMM op operand's element to match the + layout of the corresponding ir.Buffer node. + + Args: + torch_layout: The layout of the corresponding ir.Buffer node. + op_element: The Cutlass GEMM op operand's element whose alignment is to be updated. + + Returns: + bool: True if the alignment was successfully updated, False otherwise. + """ + alignment = cutlass_utils.get_max_alignment(torch_layout) + cuda_arch = cutlass_utils.get_cuda_arch() + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + else: + op_element.alignment = alignment + return True + + @staticmethod + def should_swap_XW( + bias: IRNode, + ) -> bool: + """ + Helper method to determine whether we should do an explicit transpose by switching the order of the + matmul operands. This might be necessary when we can't otherwise arrive at the right memory + layout for the given Bias operand. + + Note: This method is a workaround for CUDA Errors that seemingly non-deterministically + occurred in practice in some CUTLASS GEMM Kernels with Linear epilogues that have a bias term. + it might make sense to check on newer Cutlass releases whether it makes sense to keep + returning True in certain cases or whether it becomes unnecessary. + """ + # If bias is row major, swap all M and N dimensions + if ( + bias is not None + and len(bias.get_stride()) >= 2 + and bias.get_stride()[-1] in (0, 1) + ): + log.debug("GEMM Layout swapped X and W -> explicit transpose") + return True + return False + + @staticmethod + def swap_XW( + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Swap operands X and W (aka operans A and B) of the GEMM operation. This + requires transposing the operands, which is done by swapping the strides. + Note that we don't change the apparent external layout, just the operand layout. + this is intentional. 
+ """ + new_op = copy.deepcopy(op) + new_op.A.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.A.layout) + new_op.B.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.B.layout) + new_op.A, new_op.B = new_op.B, new_op.A + new_op.C.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.C.layout) + new_op.D.layout = CUTLASSGemmTemplate.flip_cutlass_layout(new_op.D.layout) + return new_op + + def fix_op_layout( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + X: Buffer, + W: Buffer, + Bias: Optional[Buffer], + Y: Union[Buffer, ReinterpretView], + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + # This is a workaround to deal with cases where the input layouts have changed + # between autotuning and rendering. This happens if the inputs layout + # are FlexibleLayout instances. In this case, we need to update the + # op's input layouts. It is a hack, because now the op + # we benchmarked is not the same as the op we render, + # but there is no simple way to fix this in the autotuner, since that would + # potentially disable other optimizations. + a_layout = X.get_layout() + b_layout = W.get_layout() + c_layout = Bias.get_layout() if Bias is not None else None + + d_layout = copy.deepcopy(Y.get_layout()) + match_list = [ + CUTLASSGemmTemplate.layout_match(buf.get_layout(), op_layout) + for buf, op_layout in zip( + (X, W, Bias, Y), + (op.A.layout, op.B.layout, op.C.layout, op.D.layout), + ) + if buf is not None + ] + all_match = all(match_list) + if all_match: + return op + log.warning( + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + ) + new_op = copy.deepcopy(op) + + if a_layout is not None: + new_op.A.layout = CUTLASSGemmTemplate.cutlass_layout(a_layout) + if b_layout is not None: + new_op.B.layout = CUTLASSGemmTemplate.cutlass_layout(b_layout) + if c_layout is not None: + new_op.C.layout = CUTLASSGemmTemplate.cutlass_layout(c_layout) + new_op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(c_layout.dtype) + if d_layout is not None: + new_op.D.layout = CUTLASSGemmTemplate.cutlass_layout(d_layout) + return new_op + + def _dtype_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + """ + Checking dtypes of A, B, acc, D here. + + Empirically speaking, CUTLASS2x ops have same dtype for C and D. + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + + accumulator_torch_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()], + ) + if not ( + cutlass_utils.dtype_match(X.get_dtype(), op.A.element) + and cutlass_utils.dtype_match(W.get_dtype(), op.B.element) + and cutlass_utils.dtype_match( + self.output_node.get_layout().dtype, op.D.element + ) + and cutlass_utils.dtype_match( + accumulator_torch_dtype, op.accumulator_type() + ) + ): + return False + + return True + + def filter_op( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> "cutlass_library.gemm_op.GemmOperation": # type: ignore[name-defined] # noqa: F821 + """ + Helper method: + + Determines whether a given Cutlass GEMM op definition is suitable for the current + input / output of the operation that this template is supposed to implement. 
+ + Takes memory layout, dtype and support for EVT operations into account, + and filters potentially problematic ops. + + Returns None if the op is not suitable, otherwise returns the op to be used, which might + have been mutated. + """ + + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib # type: ignore[import] + + # Skip simt kernels + if ( + op.tile_description.math_instruction.opcode_class + == cutlass_lib.OpcodeClass.Simt + ): + return None + + if op.gemm_kind not in self._get_supported_ops(): + return None + + X = self.input_nodes[0] + W = self.input_nodes[1] + + # Filter ops according to the shape match. + if not self._shape_match(op): + return None + + # Filter ops by dtypes. + if not self._dtype_match(op): + return None + + # Filter ops by input layouts. + if not ( + self.layout_match(X.get_layout(), op.A.layout) + and self.layout_match(W.get_layout(), op.B.layout) + ): + return None + + # Filter ops by alignment. + if not self._alignment_match(op): + log.debug( + "Skipping due to alignment mismatch. op: %s", op.configuration_name() + ) + return None + + # Update op. + op = copy.deepcopy(op) + + # Set output layout. + op.D.layout = CUTLASSGemmTemplate.cutlass_layout(self.output_node.get_layout()) + + # Filter ops by alignments and set alignments. + status = ( + self.set_alignment(X.get_layout(), op.A) + and self.set_alignment(W.get_layout(), op.B) + and self.set_alignment(self.output_node.get_layout(), op.D) + ) + if not status: + log.debug( + "Skipping due to alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + if inductor_cuda_config.cutlass_tma_only and not self._has_tma_epilogue(op): + return None + + # Set epilogue. + # TODO: update epilogue functor according to epilogues. + op.element_epilogue = op.accumulator_type() + + if self.use_fast_accum is not None: + is_op_fast_accum = "fastaccum" in op.configuration_name() + if self.use_fast_accum ^ is_op_fast_accum: + return None + + # Set bias layout and alignment. + status = self._set_bias_layout_and_alignment(op) + if not status: + log.debug( + "Skipping due to bias layout and alignment setting failure. op: %s", + op.configuration_name(), + ) + return None + + # Apply regex filters at the end when configuration name doesn't change anymore + if ( + inductor_cuda_config.cutlass_op_allowlist_regex + or inductor_cuda_config.cutlass_presets + ): + patterns = [] + if inductor_cuda_config.cutlass_op_allowlist_regex: + patterns.append(inductor_cuda_config.cutlass_op_allowlist_regex) + if inductor_cuda_config.cutlass_presets: + presets = gen_cutlass_presets() + preset_nums = [ + int(x) for x in inductor_cuda_config.cutlass_presets.split(",") + ] + for preset_num in preset_nums: + preset = presets.get(preset_num, {}).get( + inductor_cuda_config.cutlass_instantiation_level, [] + ) + + patterns.extend(preset) + + pattern = "|".join(patterns) + if pattern and not re.search(pattern, op.configuration_name()): + return None + if inductor_cuda_config.cutlass_op_denylist_regex is not None: + if re.search( + inductor_cuda_config.cutlass_op_denylist_regex, op.configuration_name() + ): + return None + + return op + + def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 + """ + Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. + The matching is carried out with respect to the input and output specifications of the operation. + + No function arguments. 
+ + Returns: + List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + + if self.cache_key in self.filtered_ops_cache: + log.debug("Using cached ops for %s", self.cache_key) + return self.filtered_ops_cache[self.cache_key] + + maybe_ops = maybe_fetch_ops() + if maybe_ops is None: + log.debug("Cannot fetch ops from cache, generating ops from scratch") + full_ops = cutlass_utils.gen_ops() + ops = pytree.tree_flatten(full_ops)[0] + else: + log.debug("Using cached ops from cache") + ops = maybe_ops + + res: dict[str, cutlass_gemm_op.GemmOperation] = {} + start_time = time.time() + for op in ops: + # if changed, need to also change CUTLASS_OPERATION_KIND + assert isinstance(op, cutlass_gemm_op.GemmOperation) + filter_res = self.filter_op(op) + if ( + filter_res is not None + and res.get(filter_res.configuration_name(), None) is None + ): + res[filter_res.configuration_name()] = filter_res + log.info( + "Got cutlass configs: total number of ops: %d. Filtering took %.2f seconds", + len(res), + time.time() - start_time, + ) + sorted_res = sorted(res.items()) + ret_res = sorted_res[: inductor_cuda_config.cutlass_max_profiling_configs] + if len(self.filtered_ops_cache) < 50: + self.filtered_ops_cache[self.cache_key] = ret_res + else: + log.debug("Not caching ops since filtered_ops_cache has reached size 50.") + return ret_res + + def gemm_mode(self) -> str: + """ + Returns a Cutlass GEMM mode string for the current operation, dependent on whether this op implements + a batched GEMM or a simple GEMM without batch dimension. + + Returns: + str: A string indicating the Cutlass GEMM mode. If the output node has more than two dimensions, + "cutlass::gemm::GemmUniversalMode::kBatched" is returned, otherwise + "cutlass::gemm::GemmUniversalMode::kGemm" is returned. + """ + sizes = self.output_node.get_size() + if len(sizes) > 2: + return "cutlass::gemm::GemmUniversalMode::kBatched" + else: + return "cutlass::gemm::GemmUniversalMode::kGemm" + + def render( # type: ignore[override] + self, + kernel: CUDATemplateKernel, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + template_buffer_node: Optional[CUDATemplateBuffer] = None, + epilogue_nodes: Optional[list[BaseSchedulerNode]] = None, + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + including potentially fused epilogues. + + Args: + kernel (CUDATemplateKernel): The kernel to be rendered. + op (cutlass_gemm_op.GemmOperation, optional): A GEMM operation that is required to be compatible with the + input and output definitions as well as a possible epilogue. Defaults to None. + **kwargs: Additional keyword arguments. Currently unused. + + Returns: + str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + CUDATemplateKernel or autotuning code. + + Note: + All inputs and their corresponding buffer addresses and names take precedence over previously + passed inputs to the template at construction time. However, they should be layout compatible. 
+ """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + assert isinstance(op, cutlass_gemm_op.GemmOperation), ( + "op argument is required and has to be an instance of GemmOperation" + ) + + if epilogue_nodes and not self._has_tma_epilogue(op): + raise NotImplementedError( + "Non-TMA epilogue visitor tree is not supported in Cutlass." + ) + + assert len(self.input_nodes) >= 2 and self.output_node is not None + X, W = self.input_nodes[0], self.input_nodes[1] + for input_node in self.input_nodes: + if not isinstance(X.layout, FixedLayout): + input_node.freeze_layout() + + Y = self.output_node + if template_buffer_node is not None: + Y = template_buffer_node + + Bias, extra_inputs, extra_names = self._get_extra_inputs_and_names(op) + + # Define Kernel call signature + # Important: This step also populates Kernel name to node mapping data structures, + # which are required further below ( for example by the template renderer ) + inputs = [X, W, Bias, *extra_inputs] + names = ["X", "W", "Bias", *extra_names] + ["Y"] + names_str = ",".join(names) + if self.input_reorder is not None: + input_reorder = self.input_reorder + else: + input_reorder = None + + # The layouts might have changed between autotuning and this call if they were FlexibleLayout + # we need to adapt, which might lead to suboptimal performance. + op = self.fix_op_layout(op, X, W, Bias, Y) + + # to make op mutable without affecting others + op = copy.deepcopy(op) + is_scaled_mm = len(self.input_nodes) in (4, 5) + if Bias is not None and not is_scaled_mm: + assert Bias.get_dtype() == X.get_dtype() + # This might have been set to void during filtering, when the assumption was still that there's no C + # operand + op.C.element = op.A.element + + assert op.C.element == op.D.element, ( + f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}" + ) + + argument_template, epilogue_template = self._get_template_args(op) + should_swap_xw: bool = False + if Bias is not None and self._has_tma_epilogue(op): + if ( + op.epilogue_schedule + != cutlass_lib.EpilogueScheduleType.EpilogueTransposed + and self.should_swap_XW(Bias) + ): + # TMA epilogue requires bias vector in column major to get best perf. + op = self.swap_XW(op) + should_swap_xw = True + + if epilogue_nodes or is_scaled_mm: + if epilogue_nodes: + ( + input_names, + output_names, + var_name_to_buffer_name, + evt_py_code, + ) = CutlassEVTCodegen.ir_to_evt_python_code( + Y.get_name(), epilogue_nodes, V.kernel.removed_buffers + ) + + D_output_name = var_name_to_buffer_name["D"] + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + D_output_buffer = name_to_buffer[D_output_name] + Y = D_output_buffer # type: ignore[assignment] + # Interestingly, I don't think the rest of the layout matters here since we + # use the properties of the Y buffer to fill in D's properties in the epilogue + # args. This is needed though because it defines types expected in the epilogue args. 
+ op.D.element = cutlass_utils.torch_dtype_to_cutlass_type( + D_output_buffer.get_dtype() + ) + + assert output_names, "There should be at least one write" + + epilogue_inputs = [name_to_buffer[name] for name in input_names] + outputs = [name_to_buffer[name] for name in output_names] + else: # Scaled MM, we read the two scale matrices (and optional bias) and write a single output + bias = None if len(self.input_nodes) < 5 else self.input_nodes[4] + bias_name = bias.get_name() if bias else None + + ( + evt_read_names, + var_name_to_buffer_name, + evt_py_code, + ) = scaled_mm_evt( + self.input_nodes[2].get_name(), # scale_A + self.input_nodes[3].get_name(), # scale_B + bias_name, + Y.get_name(), + ) + + input_names = list(evt_read_names) + output_names = [] # We only need Y + epilogue_inputs = [self.input_nodes[2], self.input_nodes[3]] + if bias: + epilogue_inputs.append(bias) + outputs = [] + + acc_dtype = cutlass_utils.get_accumulator_dtype( + [X.get_dtype(), W.get_dtype()] + ) + assert acc_dtype, "Could not determine accumulator dtype" + + evt_name, evt_args, evt_code = self._render_evt( + op, + evt_py_code, + var_name_to_buffer_name, + Y.get_dtype(), + acc_dtype, + ) + + inputs = [ + X, + W, + Bias, + *epilogue_inputs, # type: ignore[list-item] + Y, + *extra_inputs, + ] + names_str = ",".join( + ["X", "W", "Bias", *input_names, "Y", *output_names, *extra_names] + ) + else: + evt_name = None + outputs = [Y] + evt_args = f"{{ElementComputeEpilogue({self.alpha}), ElementComputeEpilogue({self.beta})}}" + evt_code = "" + + kernel_call_signature = kernel.def_kernel( + inputs=inputs, # type: ignore[arg-type] + outputs=outputs, # type: ignore[arg-type] + names_str=names_str, + input_reorder=input_reorder, + ) + + test_call_statement = self.test_call_statement(kernel, inputs, names_str) + + instance_definition, instance_type = self._define_gemm_instance(op, evt_name) + + options = dict( + alpha=self.alpha, + beta=self.beta, + X=X, + W=W, + Y=Y, + kernel_call_signature=kernel_call_signature, + Bias=Bias, + epilogue_template=epilogue_template, + argument_template=argument_template, + should_swap_xw=should_swap_xw, + template=self, + kernel=kernel, + instance_definition=instance_definition, + instance_type=instance_type, + input_reorder=self.input_reorder, + epilogue_args=evt_args, + test_call_statement=test_call_statement, + op_conf_name=op.configuration_name(), + epilogue_visitor_tree=evt_code, + ) + options.update(dict(zip(extra_names, extra_inputs))) + res = self._template_from_string(self._get_template()).render(**options) + if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias): + test_runner_code = self._template_from_string( + GEMM_STANDALONE_RUNNER_TEMPLATE + ).render(**options) + res += "\n\n" + test_runner_code + + # splice to remove trailing spaces in each line + buf = IndentedBuffer() + buf.splice(res) + return buf.getvalue() + + def test_call_statement( + self, + kernel, + input_nodes, + names_str: str = "", + ) -> str: + """ + Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + test runner that might also be generated along with the rest of the code, if the corresponding config is + enabled. + + Returns a C++ statement that calls the GEMM operation with the correct arguments. 
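+
+ Example of the generated statement (illustrative; argument names and element types depend on the kernel):
+ kernel_name(((ElementA*)X_data.get()), ((ElementB*)W_data.get()), ((ElementD*)Y_data.get()),
+ M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);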
+ """ + _, __, arg_types = kernel.args.cpp_argdefs(cutlass_utils.DTYPE_TO_CUTLASS_TYPE) + arg_names = [name.strip() for name in names_str.strip().split(",")] + arg_names = self._update_arg_names_for_test_call_statement( + arg_names, input_nodes + ) + arguments = [ + f"(({arg_type}){arg_name}_data.get())" + for arg_type, arg_name in zip(arg_types, arg_names) + ] + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + buffer_renames: dict[str, str], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 + raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") + + +class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate): + """ + CUTLASS 3x GEMM Template, which is used to generate CUTLASS GEMM kernels + including those which allow flexible fusions with epilogues. + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + ): + super().__init__( + input_nodes, layout, alpha, beta, input_reorder, use_fast_accum + ) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = None, + **extra_kwargs, + ) -> None: + template = CUTLASS3xGemmTemplate( + input_nodes, + layout, + alpha, + beta, + input_reorder, + use_fast_accum, + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + @functools.lru_cache(1) + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal3x] + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_3X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) + + @staticmethod + def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 + ) -> bool: # type: ignore[name-defined] + """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + result = False + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + epilogue_schedule_str = str(op.epilogue_schedule).split(".")[-1] + result = epilogue_schedule_str.lower().startswith("tma") + return result + + @staticmethod + def supports_epilogue_fusion(op: GemmOperation) -> bool: + return CUTLASS3xGemmTemplate._has_tma_epilogue(op) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for General Matrix Multiply (GEMM). + + This function checks compatibility of A, B, and possibly C operand layouts for + a General Matrix Multiply (GEMM) operation, expressed as 'alpha * matmul(A, B) + beta * C'. 
+ It verifies requirements such as matching data types, minimum rank, and suitability + for broadcasting, as defined by PyTorch operations like `torch.matmul`, `torch.aten.mm`, + `addmm`, `bmm`, `baddbmm`, etc. + + Args: + layouts (List[Layout]): List containing 2 or 3 Layout objects representing + the input matrices A, B, and possibly C. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert 2 <= len(layouts) <= 5 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) < 1: + return False + if len(B_layout.size) < 1: + return False + A_size = list(V.graph.sizevars.size_hints(A_layout.size)) + B_size = list(V.graph.sizevars.size_hints(B_layout.size)) + if len(A_size) < 2: + A_size.insert(0, 1) + if len(B_size) < 2: + A_size.insert(1, 1) + # Are batch dims broadcastable? + while len(A_size) < len(B_size): + A_size.insert(0, 1) + while len(B_size) < len(A_size): + B_size.insert(0, 1) + K = max(A_size[-1], B_size[-2]) + M = A_size[-2] + N = B_size[-1] + if K != A_size[-1] and A_size[-1] != 1: + return False + if K != B_size[-2] and B_size[-1] != 1: + return False + # check batch dim broadcastable + for i in range(len(A_size) - 2): + if A_size[i] != B_size[i] and A_size[i] != 1 and B_size[i] != 1: + return False + if len(layouts) == 3: + C_layout = layouts[2] + C_size = [V.graph.sizevars.size_hint(i) for i in C_layout.size] + while len(C_size) < len(A_size): + C_size.insert(0, 1) + # check batch dims + for i in range(len(A_size) - 2): + bd = max(A_size[i], B_size[i]) + if bd != C_size[i] and C_size[i] != 1: + return False + if len(C_size) > len(A_size): + # This may happen if the last elements of C are contiguous and + # their multiplied size equals the last dim size of B + if M != C_size[len(A_size) - 2] and C_size[len(A_size) - 2] != 1: + return False + remaining_size = 1 + for i in range(len(A_size) - 1, len(C_size)): + remaining_size *= C_size[i] + if N != remaining_size and remaining_size != 1: + return False + return True + assert len(C_size) == len(A_size) + if M != C_size[-2] and C_size[-2] != 1: + return False + if N != C_size[-1] and C_size[-1] != 1: + return False + return True + + def _render_evt( + self, + op: GemmOperation, + evt_py_code: str, + var_name_to_buffer_name: dict[str, str], + output_dtype: torch.dtype, + accumulator_dtype: torch.dtype, + ) -> tuple[str, str, str]: # type: ignore[name-defined] # noqa: F821 + from .cutlass_lib_extensions.evt_extensions import create_example_tensors, trace + + name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs + + for name in V.graph.constants.keys(): + name_to_buffer[name] = V.graph.add_tensor_constant( + V.graph.constants[name], name + ) + + # handle the fake output buffer during lowering + name_to_buffer[self.output_node.get_name()] = self.output_node # type: ignore[assignment] + + acc_dtype = torch_dtype_to_cutlass_type(accumulator_dtype) + output_dtype = torch_dtype_to_cutlass_type(output_dtype) + examples = create_example_tensors( + var_name_to_buffer_name, + name_to_buffer, # type: ignore[arg-type] + V.graph.sizevars.size_hint, + ) + evt_name, evt_args, evt_code = trace( + evt_py_code, + examples, + acc_dtype, + output_dtype, + op.tile_description, # type: ignore[attr-defined] + op.epilogue_schedule, # type: ignore[attr-defined] + {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] + V.graph.sizevars.size_hint, + ) + + return ( + evt_name, + evt_args, + evt_code, + ) + + def _shape_match( + self, + 
op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + return True + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + has_bias = len(self.input_nodes) == 3 and self.input_nodes[2] is not None + if has_bias: + Bias = self.input_nodes[2] + # bias dtype + op.C.element = cutlass_utils.torch_dtype_to_cutlass_type( + Bias.get_layout().dtype + ) + + # Bias layout + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + op.C.layout = bias_layout + + # Bias alignment + status = self.set_alignment(Bias.get_layout(), op.C) + if not status: + return False + else: + op.C.element = cutlass_lib.DataType.void + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.library as cutlass_lib + + from .cutlass_lib_extensions import gemm_operation_extensions as gemm_extensions + + emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] + + if not hasattr(op, "epilogue_functor") or not isinstance( + op.epilogue_functor, enum.Enum + ): + op = copy.deepcopy(op) + op.epilogue_functor = cutlass_lib.EpilogueFunctor.LinearCombination + + op_def = emitter.emit(op) + pattern = re.compile(r"\s*struct\s(.*?)\s:") + decl = [line for line in op_def.split("\n") if "struct " in line][-1] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + if op.gemm_kind == cutlass_lib.GemmKind.Universal3x: + op_def += f"\n using {op_type}_device_type = cutlass::gemm::device::GemmUniversalAdapter<{op_type}>;\n" + op_type = f"{op_type}_device_type" + + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + Bias = self.input_nodes[2] if len(self.input_nodes) == 3 else None + inputs: list[Optional[Buffer]] = [] + names: list[str] = [] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[2] is None: + del arg_names[2] + else: + # Reorder them as Bias, A, B + if self.input_reorder is not None: + arg_names[0 : len(self.input_reorder)] = [ + arg_names[i] for i in self.input_reorder + ] + return arg_names + + def render_gemm_arguments( + self, + argument_template: str, + 
epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. + """ + options = dict( + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + assert epilogue_template is not None + + if should_swap_xw: + # Swap + def clone_with_transposed_stride(node: IRNode) -> IRNode: + old_layout = node.get_layout() + new_stride = list(old_layout.stride) # type: ignore[union-attr] + new_stride[-2], new_stride[-1] = new_stride[-1], new_stride[-2] + assert old_layout.device is not None + new_layout = FixedLayout( + old_layout.device, + old_layout.dtype, + list(old_layout.size), # type: ignore[union-attr] + new_stride, + old_layout.offset, # type: ignore[union-attr] + ) + return Buffer(name=node.get_name(), layout=new_layout) + + new_X = clone_with_transposed_stride(X) + new_W = clone_with_transposed_stride(W) + new_Bias = clone_with_transposed_stride(Bias) + new_Y = clone_with_transposed_stride(Y) + options["X"], options["W"], options["Bias"], options["Y"] = ( + new_W, + new_X, + new_Bias, + new_Y, + ) + options["M"], options["N"] = "N", "M" + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments + + +class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate): + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ): + super().__init__(input_nodes, layout, alpha, beta, input_reorder) + + @staticmethod + def add_cutlass_gemm_choices( + choices: list[ChoiceCaller], + layout: ir.Layout, + input_nodes: list[Buffer], + alpha: Union[float, int] = 1, + beta: Union[float, int] = 0, + input_reorder: Optional[list[int]] = None, + use_fast_accum: Optional[bool] = False, + **extra_kwargs, + ) -> None: + template = CUTLASS2xGemmTemplate( + input_nodes, layout, alpha, beta, input_reorder + ) + template._add_cutlass_gemm_choices( + choices, layout, input_nodes, alpha, beta, input_reorder, **extra_kwargs + ) + + @staticmethod + def _get_supported_ops() -> "list[cutlass_library.gemm_operation.GemmOperation]": # type: ignore[name-defined] # noqa: 
F821 + import cutlass_library.library as cutlass_lib + + return [cutlass_lib.GemmKind.Universal, cutlass_lib.GemmKind.Sparse] + + @staticmethod + def _has_tma_epilogue(self) -> bool: + return False + + def _get_template(self) -> str: + return GEMM_TEMPLATE_CUTLASS_2X + + def _get_template_args( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, Optional[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return (GEMM_ARGS_SPARSE_CUTLASS_2X, None) + + return (GEMM_ARGS_CUTLASS_2X, None) + + def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: + """ + Evaluates whether input layouts are compatible for set of operations supported by this class. + + Args: + layouts (List[Layout]): List containing Layout objects representing + the input matrices. + + Returns: + bool: True if layouts are GEMM compatible, otherwise False. + """ + assert len(layouts) == 2 or len(layouts) == 3 + # Check if A and B are compatible + A_layout, B_layout = layouts[:2] + if len(A_layout.size) != 2: + return False + if len(B_layout.size) != 2: + return False + A_size = [int(i) for i in A_layout.size] + B_size = [int(i) for i in B_layout.size] + K = max(A_size[1], B_size[0]) + return (K == A_size[1] or K == 2 * A_size[1]) and K == B_size[0] + + def _shape_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + X, W = self.input_nodes[0], self.input_nodes[1] + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + return X.get_size()[1] * 2 == W.get_size()[0] + + return X.get_size()[1] == W.get_size()[0] + + def _alignment_match( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + return True + + # SparseGemm in CUTLASS has specific alignment check that for + # small k could make some of the choices throw kMisalignedOperand + # CUTLASS error when run, see: + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # So, let's skip these choices if that would be the case. + X = self.input_nodes[0] + return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0 + + def _set_bias_layout_and_alignment( + self, + op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined] # noqa: F821 + ) -> bool: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + op.C.layout = op.D.layout + return True + + if len(self.input_nodes) >= 3 and self.input_nodes[2] is not None: + Bias = self.input_nodes[2] + bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout()) + if bias_layout != op.D.layout: + # For cutlass2, bias and output layout must match + return False + if not self.set_alignment(Bias.get_layout(), op.C): + return False + else: + op.C.layout = op.D.layout + return True + + def _define_gemm_instance( + self, + op: GemmOperation, + evt_name: Optional[str] = None, + ) -> tuple[str, str]: + """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + + This function uses the Cutlass library to generate key parts of the codegen process. 
General Matrix Multiply + forms a core part of a number of scientific applications, so this efficient and adaptable implementation is + crucial. + + Args: + op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. + + Returns: + Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + code (render) and the second part is the string that specifies the operation type. + """ + assert cutlass_utils.try_import_cutlass() + import cutlass_library.gemm_operation as cutlass_gemm_op + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + emitter = cutlass_gemm_op.EmitSparseGemmInstance() + else: + emitter = cutlass_gemm_op.EmitGemmInstance() + op_def = emitter.emit(op) + op_def = op_def.replace( + "cutlass::gemm::device::Gemm", "cutlass::gemm::device::GemmUniversal" + ) + if op.gemm_kind != cutlass_lib.GemmKind.Sparse: + op_def = op_def.replace("false,", "") + pattern = re.compile(r"\s*using\s(.*?)\s=") + decl = op_def.split("\n")[2] + + match = pattern.match(decl) + if match is None: + raise RuntimeError("Invalid Gemm config: \n" + op_def) + op_type = match.groups()[0] + return op_def, op_type + + def _get_extra_inputs_and_names( + self, + op: "cutlass_gemm_op.GemmOperation" = None, # type: ignore[name-defined] # noqa: F821 + ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]: + import cutlass_library.library as cutlass_lib + + if op.gemm_kind == cutlass_lib.GemmKind.Sparse: + Bias = None + Meta = self.input_nodes[2] + else: + Bias = None if len(self.input_nodes) == 2 else self.input_nodes[2] + Meta = None + inputs = [Meta] + names = ["Meta"] + return (Bias, inputs, names) + + def _update_arg_names_for_test_call_statement( + self, + arg_names: list[str], + input_nodes: list[Buffer], + ) -> list[str]: + if input_nodes[3] is None: + del arg_names[3] + if input_nodes[2] is None: + del arg_names[2] + return arg_names + + def render_gemm_arguments( + self, + instance_type: str, + argument_template: str, + epilogue_template: str, + should_swap_xw: bool, + X: IRNode, + W: IRNode, + Bias: IRNode, + Meta: IRNode, + Y: IRNode, + alpha: float, + beta: float, + kernel: CUDATemplateKernel, + epilogue_args, + ) -> str: + """ + Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + + Args: + instance_type (str): GEMM instance type. + argument_template (str): Template for the GEMM operation arguments. + epilogue_template (str): Template for the epilogue arguments. + should_swap_xw (bool): Determines whether X, W operands should be swapped. If True, applies an explicit + transpose operation to X and W. + X (IRNode): The X input tensor. + W (IRNode): The W input tensor. + Bias (IRNode): The bias tensor. + Meta (IRNode): The meta tensor. + Y (IRNode): The output tensor. + alpha (float): Scaling factor for the product of the inputs. + beta (float): Scaling factor for the output tensor. + kernel (CUDATemplateKernel): CUDA Template kernel for the operation. + epilogue_args (any): Additional arguments for the epilogue state. + + Returns: + str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + + Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y + tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped + before the function call. 
+ """ + options = dict( + instance_type=instance_type, + alpha=alpha, + beta=beta, + X=X, + W=W, + Y=Y, + Bias=Bias, + Meta=Meta, + template=self, + kernel=kernel, + M="M", + N="N", + epilogue_args=epilogue_args, + ) + + if epilogue_template is None: + arguments = self._template_from_string(argument_template).render( + split_k=1, **options + ) + return arguments + + epilogue_arguments = self._template_from_string(epilogue_template).render( + **options + ) + arguments = self._template_from_string(argument_template).render( + epilogue_arguments=epilogue_arguments, **options + ) + + return arguments diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/serialization.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bca77f0a80df7842cb8a9efc9fbeaeccd1ab62 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda/serialization.py @@ -0,0 +1,466 @@ +# mypy: allow-untyped-defs +import enum +import functools +import json +from enum import Enum +from typing import Optional + +from torch._inductor.codegen.cuda.cutlass_utils import try_import_cutlass + + +class CUTLASSOperationSerializer: + """Serializes and deserializes CUTLASS GEMM operations to/from JSON. + + Handles GemmOperation objects and their nested components (TileDescription, TensorDescription). + """ + + # not used, but keeping in case we want to generalize the serializer + _SUPPORTED_CLASSES: list[str] = [ + "GemmOperation", + "GemmKind", + "TileDescription", + "TensorDescription", + "DataType", + "EpilogueFunctor", + "EpilogueFunctor3x", + "SwizzlingFunctor", + "KernelScheduleType", + "EpilogueScheduleType", + "TileSchedulerType", + ] + + @classmethod + def serialize(cls, operation: "GemmOperation"): # type: ignore[name-defined] # noqa: F821 + """Serialize a GEMM operation to JSON string. + + Args: + operation: GemmOperation object + indent: JSON indentation spaces + + Returns: + str: JSON representation of the operation + """ + assert operation.__class__.__qualname__ == "GemmOperation", ( + "Only GemmOperation objects are supported via the main API" + ) + return json.dumps(cls._gemm_operation_to_json(operation)) + + @classmethod + def deserialize(cls, json_str: str) -> "GemmOperation": # type: ignore[name-defined] # noqa: F821 + """Deserialize JSON string to a GEMM operation. + + Args: + json_str: JSON string of a GEMM operation + + Returns: + GemmOperation: Reconstructed operation + """ + json_dict = json.loads(json_str) + return cls._json_to_gemm_operation(json_dict) + + @classmethod + def _gemm_operation_to_json(cls, operation): + """Convert GemmOperation to JSON-serializable dict. 
+ + Args: + operation: GemmOperation object + + Returns: + dict: Dictionary representation + """ + from cutlass_library.library import TensorDescription + + # Create the main dictionary with required and optional parameters + result = { + # Required parameters + "gemm_kind": cls._enum_to_json(operation.gemm_kind), + "arch": operation.arch, + "tile_description": cls._tile_description_to_json( + operation.tile_description + ), + "A": cls._tensor_description_to_json(operation.A), + "B": cls._tensor_description_to_json(operation.B), + "C": cls._tensor_description_to_json(operation.C), + "element_epilogue": cls._enum_to_json(operation.element_epilogue), + # Optional parameters + "epilogue_functor": cls._enum_to_json(operation.epilogue_functor), + "swizzling_functor": cls._enum_to_json(operation.swizzling_functor), + "D": cls._tensor_description_to_json(operation.D) if operation.D else None, + "kernel_schedule": cls._enum_to_json(operation.kernel_schedule), + "epilogue_schedule": cls._enum_to_json(operation.epilogue_schedule), + "tile_scheduler": cls._enum_to_json(operation.tile_scheduler), + } + + # Process optional attributes + optional_attrs = [ + "mixed_input_mode", + "mixed_input_shuffle", + "ScaleFactorA", + "ScaleFactorB", + "ScaleFactorD", + "ScaleFactorMVecSize", + "ScaleFactorNVecSize", + "ScaleFactorKVecSize", + "ScaleFactorVectorSize", + "is_3x", + ] + + for attr in optional_attrs: + if not hasattr(operation, attr): + continue + + value = getattr(operation, attr) + + if isinstance(value, TensorDescription): + result[attr] = cls._tensor_description_to_json(value) + elif isinstance(value, Enum): + result[attr] = cls._enum_to_json(value) + else: + result[attr] = value + + return result + + @classmethod + def _json_to_gemm_operation(cls, json_dict): + """Convert JSON dict to GemmOperation object. 
+ + Args: + json_dict: Dictionary representation + + Returns: + GemmOperation: Reconstructed object + """ + from cutlass_library import DataType + from cutlass_library.gemm_operation import GemmKind, GemmOperation + from cutlass_library.library import ( + EpilogueFunctor, + EpilogueFunctor3x, + EpilogueScheduleType, + KernelScheduleType, + MixedInputMode, + SwizzlingFunctor, + TileSchedulerType, + ) + + # Extract constructor parameters from the JSON dictionary + gemm_kind = cls._json_to_enum(json_dict["gemm_kind"], GemmKind) + arch = json_dict["arch"] + tile_description = cls._json_to_tile_description(json_dict["tile_description"]) + A = cls._json_to_tensor_description(json_dict.get("A")) + B = cls._json_to_tensor_description(json_dict.get("B")) + C = cls._json_to_tensor_description(json_dict.get("C")) + element_epilogue = cls._json_to_enum(json_dict["element_epilogue"], DataType) + + # Get optional parameters with defaults + epilogue_functor = cls._json_to_enum( + json_dict.get("epilogue_functor"), + EpilogueFunctor3x if json_dict.get("is_3x") else EpilogueFunctor, + ) + swizzling_functor = cls._json_to_enum( + json_dict.get("swizzling_functor"), SwizzlingFunctor + ) + D = cls._json_to_tensor_description(json_dict.get("D")) + kernel_schedule = cls._json_to_enum( + json_dict.get("kernel_schedule"), KernelScheduleType + ) + epilogue_schedule = cls._json_to_enum( + json_dict.get("epilogue_schedule"), EpilogueScheduleType + ) + tile_scheduler = cls._json_to_enum( + json_dict.get("tile_scheduler"), TileSchedulerType + ) + + mixed_input_mode = cls._json_to_enum( + json_dict.get("mixed_input_mode"), MixedInputMode + ) + mixed_input_shuffle = json_dict.get("mixed_input_shuffle", False) + + # Scale factors + ScaleFactorA = cls._json_to_enum(json_dict.get("ScaleFactorA"), DataType) + ScaleFactorB = cls._json_to_enum(json_dict.get("ScaleFactorB"), DataType) + + ScaleFactorD = None + if "ScaleFactorD" in json_dict and "ScaleFactorVectorSize" in json_dict: + ScaleFactorD = { + "tensor": cls._json_to_tensor_description( + json_dict.get("ScaleFactorD") + ), + "vector_size": json_dict.get("ScaleFactorVectorSize"), + } + + ScaleFactorMVecSize = json_dict.get("ScaleFactorMVecSize") + ScaleFactorNVecSize = json_dict.get("ScaleFactorNVecSize") + ScaleFactorKVecSize = json_dict.get("ScaleFactorKVecSize") + + # Create the GemmOperation with the extracted parameters + operation = GemmOperation( + gemm_kind=gemm_kind, + arch=arch, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_epilogue=element_epilogue, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + mixed_input_mode=mixed_input_mode, + mixed_input_shuffle=mixed_input_shuffle, + ScaleFactorA=ScaleFactorA, + ScaleFactorB=ScaleFactorB, + ScaleFactorD=ScaleFactorD, + ScaleFactorMVecSize=ScaleFactorMVecSize, + ScaleFactorNVecSize=ScaleFactorNVecSize, + ScaleFactorKVecSize=ScaleFactorKVecSize, + ) + + return operation + + @classmethod + def _tile_description_to_json(cls, tile_desc): + """ + Convert TileDescription to JSON dict. 
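+
+        For illustration (placeholder values), a serialized tile description looks roughly like
+        {"threadblock_shape": [128, 128, 32], "stages": 3, "warp_count": [2, 2, 1],
+        "math_instruction": {...}, "min_compute": 80, "max_compute": 90,
+        "cluster_shape": [1, 1, 1], "explicit_vector_sizes": None}.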
+ + Args: + tile_desc: TileDescription object + + Returns: + dict: Dictionary representation + """ + if tile_desc is None: + return None + + # Create a dictionary for math_instruction if it exists + math_instruction_dict = None + if ( + hasattr(tile_desc, "math_instruction") + and tile_desc.math_instruction is not None + ): + math_instruction = tile_desc.math_instruction + math_instruction_dict = { + "instruction_shape": math_instruction.instruction_shape, + "element_a": cls._enum_to_json(math_instruction.element_a), + "element_b": cls._enum_to_json(math_instruction.element_b), + "element_accumulator": cls._enum_to_json( + math_instruction.element_accumulator + ), + "opcode_class": cls._enum_to_json(math_instruction.opcode_class), + "math_operation": cls._enum_to_json(math_instruction.math_operation), + } + + # Add element_scale_factor if it exists + if ( + hasattr(math_instruction, "element_scale_factor") + and math_instruction.element_scale_factor is not None + ): + math_instruction_dict["element_scale_factor"] = cls._enum_to_json( + math_instruction.element_scale_factor + ) + + # Create the main dictionary with field names matching TileDescription constructor parameters + result = { + "threadblock_shape": tile_desc.threadblock_shape, + "stages": tile_desc.stages, + "warp_count": tile_desc.warp_count, + "math_instruction": math_instruction_dict, + "min_compute": tile_desc.minimum_compute_capability, # Store as min_compute for constructor + "max_compute": tile_desc.maximum_compute_capability, # Store as max_compute for constructor + "cluster_shape": tile_desc.cluster_shape, + "explicit_vector_sizes": tile_desc.explicit_vector_sizes, + } + + # Add tile_shape if it exists and differs from threadblock_shape + if ( + hasattr(tile_desc, "tile_shape") + and tile_desc.tile_shape != tile_desc.threadblock_shape + ): + result["tile_shape"] = tile_desc.tile_shape + + return result + + @classmethod + def _json_to_tile_description(cls, json_dict): + """ + Convert JSON dict to TileDescription object. 
+ + Args: + json_dict: Dictionary representation + + Returns: + TileDescription: Reconstructed object + """ + if json_dict is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import ( + MathInstruction, + MathOperation, + OpcodeClass, + TileDescription, + ) + + # First, reconstruct the math_instruction if it exists + math_instruction_obj = None + if ( + "math_instruction" in json_dict + and json_dict["math_instruction"] is not None + ): + mi_dict = json_dict["math_instruction"] + + # Convert string enum names back to enum values + element_a = cls._json_to_enum(mi_dict["element_a"], DataType) + element_b = cls._json_to_enum(mi_dict["element_b"], DataType) + element_acc = cls._json_to_enum(mi_dict["element_accumulator"], DataType) + + # Get the opcode_class enum + opcode_class = cls._json_to_enum(mi_dict["opcode_class"], OpcodeClass) + + # Get the math_operation enum + math_op = cls._json_to_enum(mi_dict["math_operation"], MathOperation) + + # Create the MathInstruction object + math_instruction_obj = MathInstruction( + instruction_shape=mi_dict["instruction_shape"], + element_a=element_a, + element_b=element_b, + element_accumulator=element_acc, + opcode_class=opcode_class, + math_operation=math_op, + ) + + # Add element_scale_factor if it exists + if ( + "element_scale_factor" in mi_dict + and mi_dict["element_scale_factor"] is not None + ): + math_instruction_obj.element_scale_factor = cls._json_to_enum( + mi_dict["element_scale_factor"], DataType + ) + + # Get compute capability values, checking both naming conventions + min_compute = json_dict.get( + "min_compute", json_dict.get("minimum_compute_capability") + ) + max_compute = json_dict.get( + "max_compute", json_dict.get("maximum_compute_capability") + ) + + # Get cluster shape with default value + cluster_shape = json_dict.get("cluster_shape", [1, 1, 1]) + + # Create the TileDescription object + tile_desc = TileDescription( + threadblock_shape=json_dict["threadblock_shape"], + stages=json_dict["stages"], + warp_count=json_dict["warp_count"], + math_instruction=math_instruction_obj, + min_compute=min_compute, + max_compute=max_compute, + cluster_shape=cluster_shape, + explicit_vector_sizes=json_dict.get("explicit_vector_sizes"), + ) + + # Set tile_shape if it exists and differs from threadblock_shape + if ( + "tile_shape" in json_dict + and json_dict["tile_shape"] != json_dict["threadblock_shape"] + ): + tile_desc.tile_shape = json_dict["tile_shape"] + + return tile_desc + + @classmethod + def _tensor_description_to_json(cls, tensor_desc): + """Convert TensorDescription to JSON dict. + + Args: + tensor_desc: TensorDescription object + + Returns: + dict: Dictionary representation + """ + if tensor_desc is None: + return None + + return { + "element": cls._enum_to_json(tensor_desc.element), + "layout": cls._enum_to_json(tensor_desc.layout), + "alignment": tensor_desc.alignment, + "complex_transform": cls._enum_to_json(tensor_desc.complex_transform), + } + + @classmethod + def _json_to_tensor_description(cls, tensor_json): + """Convert JSON dict to TensorDescription object. 
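+
+        For illustration (placeholder values), the expected input mirrors the output of
+        `_tensor_description_to_json` above, e.g.
+        {"element": {"type": "DataType", "name": "f16"}, "layout": {"type": "LayoutType", "name": "RowMajor"},
+        "alignment": 8, "complex_transform": {"type": "ComplexTransform", "name": "none"}}.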
+ + Args: + tensor_json: Dictionary representation + + Returns: + TensorDescription: Reconstructed object + """ + if tensor_json is None: + return None + + from cutlass_library import DataType + from cutlass_library.library import ( + ComplexTransform, + LayoutType, + TensorDescription, + ) + + element = cls._json_to_enum(tensor_json["element"], DataType) + layout = cls._json_to_enum(tensor_json["layout"], LayoutType) + alignment = tensor_json["alignment"] + complex_transform = cls._json_to_enum( + tensor_json["complex_transform"], ComplexTransform + ) + + return TensorDescription(element, layout, alignment, complex_transform) + + @classmethod + def _enum_to_json(cls, enum_value): + """Convert enum value to JSON dict. + + Args: + enum_value: Enum value + + Returns: + dict: Dictionary representation + """ + if enum_value is None: + return None + + assert isinstance(enum_value, enum.Enum) + return { + "type": enum_value.__class__.__name__, + "name": enum_value.name, + } + + @classmethod + def _json_to_enum(cls, json_dict, enum_class): + """Convert JSON dict to enum value. + + Format: {name: "EnumName", value: 1} + + Args: + json_dict: Dictionary representation + enum_class: Target enum class + + Returns: + Reconstructed enum value + """ + if json_dict is None or json_dict.get("name", "None") == "None": + return None + return enum_class[json_dict["name"]] + + +@functools.lru_cache(1) +def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]: + if not try_import_cutlass(): + return None + return CUTLASSOperationSerializer() diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..bd594e8e5395e2b75838eac51a2bfd109daf20ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py @@ -0,0 +1,136 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, Optional, TYPE_CHECKING, Union + +from ..scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) +from .cuda.cuda_cpp_scheduling import CUDACPPScheduling +from .rocm.rocm_cpp_scheduling import ROCmCPPScheduling +from .triton import TritonScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing_extensions import TypeAlias + + from sympy import Expr + + import torch + from torch.utils._ordered_set import OrderedSet + + from .common import BackendFeature + + _IntLike: TypeAlias = Union[int, Expr] + + +class CUDACombinedScheduling(BaseScheduling): + """ + Scheduler for CUDA Kernels, which delegates calls as appropriate + to the CUDA-C++ and Triton Schedulers, which both work for CUDA devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / CUDA C++ code, + this would also be the place to do it. 
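+
+    Dispatch is centralized in `choose_node_backend` below: CUTLASS (CUDA C++) template nodes are routed to
+    `CUDACPPScheduling`, ROCm C++ template nodes to `ROCmCPPScheduling`, and all remaining nodes to
+    `TritonScheduling`.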
+ """ + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + self._triton_scheduling = TritonScheduling(scheduler) + self._cuda_cpp_scheduling = CUDACPPScheduling(scheduler) + self._rocm_cpp_scheduling = ROCmCPPScheduling(scheduler) + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling + if self._rocm_cpp_scheduling.is_rocm_cpp_template(node): + return self._rocm_cpp_scheduling + return self._triton_scheduling + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self._cuda_cpp_scheduling.can_fuse_vertical(node1, node2): + return True + elif self._cuda_cpp_scheduling.is_cuda_cpp_template( + node1 + ) or self._cuda_cpp_scheduling.is_cuda_cpp_template(node2): + return False + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + for node in (node1, node2): + if self._cuda_cpp_scheduling.is_cuda_cpp_template(node): + return self._cuda_cpp_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[_IntLike]] + ) -> tuple[tuple[_IntLike, ...], ...]: + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> Optional[str]: + if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node): + assert not prologue_nodes + return self._cuda_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + elif self._rocm_cpp_scheduling.is_rocm_cpp_template(template_node): + assert not epilogue_nodes + assert not prologue_nodes + return self._rocm_cpp_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + + def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None: + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self) -> None: + return self._triton_scheduling.codegen_sync() + + def flush(self) -> None: + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None: + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def benchmark_codegened_module(self, module): + return self._triton_scheduling.benchmark_codegened_module(module) + + def generate_kernel_code_from_nodes( + self, nodes: Sequence[Any], benchmark_kernel: bool = False + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel + ) + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode] + ) -> tuple[float, float, list[Optional[str]]]: + return self._triton_scheduling.benchmark_combo_kernel(node_list) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/debug_utils.py 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/debug_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe752ed4f7d313eae830917b6f9224a5146075a0
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/debug_utils.py
@@ -0,0 +1,284 @@
+# mypy: allow-untyped-defs
+from __future__ import annotations
+
+import functools
+import logging
+import os
+from enum import Enum
+from typing import Callable, Optional
+
+import torch
+from torch import dtype as torch_dtype
+
+from .. import config
+from ..virtualized import V
+from .multi_kernel import MultiKernel
+
+
+log = logging.getLogger(__name__)
+
+
+def _print_debugging_tensor_value_info(msg, arg):
+    # helper for printing debugging stats for intermediate tensor values
+    # at jit inductor level codegen
+    max_numel_to_print = 64
+    print(msg)
+    if not isinstance(arg, torch.Tensor):
+        print("Value: ", arg)
+        return
+    numel = arg.float().numel()
+    # print the debug printing stats
+    if numel <= max_numel_to_print:
+        print(arg)
+    print("Number of elements: ", numel)
+    print("Size: ", arg.float().size())
+    print("Dtype: ", arg.dtype)
+    print("Mean: ", arg.float().mean().item())
+    print("Min: ", arg.float().min().item())
+    print("Max: ", arg.float().max().item())
+    print("Std: ", arg.float().std().item())
+
+
+# AOTI debug printing related configs
+class IntermediateValueDebuggingLevel(Enum):
+    # OFF: No intermediate tensor value debug info will be printed or saved.
+    OFF = "0"
+    # LEVEL 1: Save all intermediate tensor values to individual `.pt` files. No debug printing will be displayed.
+    SAVE_ONLY = "1"
+    # LEVEL 2: Print all intermediate tensor values by default to the console. No debug saving will be performed.
+    PRINT_ONLY = "2"
+    # LEVEL 3: Print all kernel names to the console only. No debug saving/printing for input tensor value info will be performed.
+    # This mode can be helpful when you just want to pinpoint which kernel is running into a CUDA IMA issue, etc.
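+    # Note: DebugPrinterManager below constructs IntermediateValueDebuggingLevel(debug_printer_level)
+    # from the raw string, so callers pass one of the string values "0"-"3" defined on this enum.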
+ PRINT_KERNEL_NAMES_ONLY = "3" + + +class DebugPrinterManager: + def __init__( + self, + debug_printer_level, + use_array_ref: bool, + writeline: Optional[Callable[..., None]] = None, + args_to_print_or_save: Optional[list[str]] = None, + kernel_name: str = "", + kernel=None, + arg_signatures: Optional[list[type]] = None, + kernel_type=None, + ): + self.debug_printer_level = IntermediateValueDebuggingLevel(debug_printer_level) + self.use_array_ref = use_array_ref + if args_to_print_or_save is None: + args_to_print_or_save = [] + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures: Optional[list[type]] = None + self.kernel = kernel + self.filtered_kernel_names_to_print = self._get_debug_filtered_kernel_names() + self.kernel_type = None + + def __enter__(self): + self._perform_debug_print_or_save_helper( + self.args_to_print_or_save, + self.kernel_name, + before_launch=True, + arg_signatures=self.arg_signatures, + ) + + def __exit__(self, args_to_print_or_save, kernel_name, arg_signatures): + self._perform_debug_print_or_save_helper( + args_to_print_or_save, + kernel_name, + before_launch=False, + arg_signatures=arg_signatures, + ) + + def _perform_debug_print_or_save_helper( + self, + args_to_print_or_save, + kernel_name, + before_launch, + arg_signatures: Optional[list[type]] = None, + ): + if self.debug_printer_level == IntermediateValueDebuggingLevel.OFF: + return + if self.debug_printer_level == IntermediateValueDebuggingLevel.SAVE_ONLY: + # by default save all the tensor values before launch + self.codegen_intermediate_tensor_value_save( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if self.debug_printer_level == IntermediateValueDebuggingLevel.PRINT_ONLY: + # by default print all the tensor values before launch + self.codegen_intermediate_tensor_value_print( + self.args_to_print_or_save, + self.kernel_name, + before_launch, + arg_signatures=self.arg_signatures, + ) + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + # Print all kernel names to the console only + self.codegen_intermediate_tensor_value_print( + [], + self.kernel_name, + before_launch, + ) + + @functools.lru_cache # noqa: B019 + def _get_debug_filtered_kernel_names(self) -> list[str]: + if config.aot_inductor.filtered_kernel_names is None: + return [] + return [ + x.strip() + for x in config.aot_inductor.filtered_kernel_names.lower().split(",") + ] + + def set_printer_args( + self, + args_to_print_or_save: list[str], + kernel_name: str, + arg_signatures: Optional[list[type]], + kernel, + kernel_type=None, + ): + # Note: MultiKernel debug printing is not supported for now + if isinstance(kernel, MultiKernel): + log.info( + "MultiKernel type is not supported in AOTI debug printer tool yet." 
+ ) + self.debug_printer_level = IntermediateValueDebuggingLevel.OFF + + self.kernel_type = kernel_type + # Note: if the kernel type is an extern kernel (or cpp kernel), we do a special handling to + # get the list of args_to_print_or_save + # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls + if kernel_type == "extern": + args_to_print_or_save_extern = [ + arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg")) + ] + self.args_to_print_or_save = args_to_print_or_save_extern + elif kernel_type == "cpp": + self.args_to_print_or_save = [ + ( + f"copy_arrayref_tensor_to_tensor({arg})" + if self.use_array_ref + else arg + ) + for arg in args_to_print_or_save + if arg.startswith(("buf", "arg")) + ] + else: + self.args_to_print_or_save = args_to_print_or_save + self.kernel_name = kernel_name + self.arg_signatures = arg_signatures + self.kernel = kernel + + def codegen_model_inputs_value_print(self, input_args_to_print: list[str]) -> None: + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for arg in input_args_to_print: + if V.graph.cpp_wrapper: + V.graph.wrapper_code.prefix.writeline( + f'aoti_torch_print_tensor_handle({arg}, "aoti_model_inputs - {arg}");' + ) + + def codegen_intermediate_tensor_value_save( + self, + args_to_save, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + for i, arg in enumerate(args_to_save): + if arg_signatures is not None and not isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + continue + launch_prefix = "before_launch" if before_launch else "after_launch" + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) + else: + cwd = os.getcwd() + saved_dir = cwd + "/tmp/jit_inductor/" + if not os.path.exists(saved_dir): + log.info( + "Creating directory to save inductor intermediate tensor values." + ) + os.makedirs(saved_dir) + # Save the model to the directory + saved_path = saved_dir + f"{launch_prefix}_{kernel_name}_{arg}.pt" + log.info( + "Saved intermediate tensor %s for %s to %s", + arg, + kernel_name, + saved_path, + ) + line = f"torch.save({arg}, '{saved_path}')" + V.graph.wrapper_code.writeline(line) + + def codegen_intermediate_tensor_value_print( + self, + args_to_print, + kernel_name, + before_launch=True, + arg_signatures: Optional[list[type]] = None, + ) -> None: + launch_prefix = "before_launch" if before_launch else "after_launch" + + # if the debug printing level is PRINT_KERNEL_NAMES_ONLY + # we only print the kernel name to the console + if ( + self.debug_printer_level + == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY + ): + if V.graph.cpp_wrapper: + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]\\n");' + ) + return + + if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: + return + for i, arg in enumerate(args_to_print): + # when debug printing is enabled i.e. 
IntermediateValueDebuggingLevel.PRINT_ONLY, + # check if filtered kernel name list is provided + if ( + len(self.filtered_kernel_names_to_print) > 0 + and kernel_name.lower() not in self.filtered_kernel_names_to_print + ): + continue + if V.graph.cpp_wrapper: + if arg_signatures is not None and isinstance( + arg_signatures[i], torch_dtype + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + elif arg_signatures is not None and isinstance( + arg_signatures[i], + ( + type(torch._inductor.codegen.wrapper.SymbolicCallArg), + type(int), + type(float), + type(bool), + ), + ): + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\\\n");' + ) + else: + if arg_signatures is None and self.kernel_type == "cpp" or "extern": + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + else: + V.graph.wrapper_code.writeline( + f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/halide.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/halide.py new file mode 100644 index 0000000000000000000000000000000000000000..28531864fa5693dcc60244d65ea74606669b0ae1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/halide.py @@ -0,0 +1,1699 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import dataclasses +import functools +import itertools +import logging +import re +from collections import defaultdict +from math import inf +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import sympy + +import torch +import torch._logging + +from ..._prims_common import is_integer_dtype +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir +from ..codecache import HalideCodeCache +from ..ir import get_reduction_combine_fn +from ..metrics import is_metric_table_enabled, log_kernel_metadata +from ..ops_handler import AddParenHandler +from ..runtime.hints import HalideInputSpec, HalideMeta +from ..utils import ( + get_bounds_index_expr, + get_kernel_metadata, + parallel_num_threads, + sympy_index_symbol, + sympy_subs, +) +from ..virtualized import _ops as ops, V +from .common import ( + BackendFeature, + CSEVariable, + DeferredLine, + IndentedBuffer, + KernelArgType, + OpOverrides, + PythonPrinter, + SizeArg, + TensorArg, +) +from .cpp import DTYPE_TO_CPP +from .cpp_utils import cexpr +from .simd import constant_repr, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..ops_handler import ReductionType, StoreMode + +log = logging.getLogger(__name__) + + +def halide_constant(val): + if isinstance(val, int) and not (-2147483648 <= val <= 2147483647): + info = torch.iinfo(torch.int64) + if val == info.min: + return "hl.Int(64).min()" + if val == info.max: + return "hl.Int(64).max()" + return f"hl.i64({val!r})" + if isinstance(val, float): + return f"hl.f64({constant_repr(val)})" + return repr(val) + + +class Unsupported(RuntimeError): + def __init__(self, thing) -> None: + super().__init__(f"halide backend does not support: {thing}") + + +class HalidePrinter(PythonPrinter): + @staticmethod + def cast_index(expr): + return f"hl.cast({V.kernel.index_dtype}, {expr})" + + @staticmethod + def cast_float(expr): + return f"hl.cast(hl.Float(32), {expr})" + + def _print_Float(self, expr): + return f"hl.f32({expr})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"hl.f32({self._print(expr.args[0])})" + + def _print_floor(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + + _print_FloorToInt = _print_floor + + def _print_Trunc(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") + + _print_TruncToInt = _print_Trunc + + def _print_ceiling(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.ceil({self._print(expr.args[0])})") + + def _helper_sqrt(self, expr): + return f"hl.sqrt({self.cast_float(self._print(expr))})" + + def _print_Where(self, expr): + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"hl.select({c}, {p}, {q})" + + def _print_Min(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Min(*expr.args[:mid])) + b = self._print(sympy.Min(*expr.args[mid:])) + return f"hl.min({a}, {b})" + + def _print_Max(self, expr): + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + a = self._print(sympy.Max(*expr.args[:mid])) + b = self._print(sympy.Max(*expr.args[mid:])) + + return f"hl.max({a}, {b})" + + def _print_Abs(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.abs({self._print(expr.args[0])})") + + def _print_OpaqueUnaryFn_cos(self, expr): + assert len(expr.args) == 1 + return f"hl.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr): + assert len(expr.args) == 1 + return f"hl.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr): + assert len(expr.args) == 1 + return f"hl.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr): + assert 
len(expr.args) == 1 + return f"hl.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr): + assert len(expr.args) == 1 + return f"hl.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr): + assert len(expr.args) == 1 + return f"hl.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr): + assert len(expr.args) == 1 + return f"hl.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr): + assert len(expr.args) == 1 + return f"hl.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr): + assert len(expr.args) == 1 + return f"hl.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr): + raise NotImplementedError("log2") + + def _print_FloorDiv(self, expr): + if expr.is_integer: + return super()._print_FloorDiv(expr) + + x, div = expr.args + x = self.cast_float(self.doprint(x)) + div = self.cast_float(self.doprint(div)) + return self.cast_index(f"hl.floor({x} / {div})") + + def _print_Round(self, expr): + assert len(expr.args) == 1 + return self.cast_index(f"hl.round({self._print(expr.args[0])})") + + _print_RoundToInt = _print_Round + + def _print_IntTrueDiv(self, expr): + a, b = expr.args + # force a cast to float + return f"({a}) / ({b}+hl.f32(0))" + + def _print_RoundDecimal(self, expr): + val, n = expr.args + val = self._print(val) + n = int(n) + return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))" + + +texpr = HalidePrinter().doprint +pexpr = PythonPrinter().doprint + + +_halide_type = { + torch.bool: "hl.Bool()", + torch.bfloat16: "hl.BFloat(16)", + torch.float16: "hl.Float(16)", + torch.float32: "hl.Float(32)", + torch.float64: "hl.Float(64)", + torch.int8: "hl.Int(8)", + torch.int16: "hl.Int(16)", + torch.int32: "hl.Int(32)", + torch.int64: "hl.Int(64)", + torch.uint8: "hl.UInt(8)", + torch.uint16: "hl.UInt(16)", + torch.uint32: "hl.UInt(32)", + torch.uint64: "hl.UInt(64)", +} + + +def halide_type(dtype): + return _halide_type[dtype] + + +def halide_acc_type(dtype): + if is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64: + dtype = torch.int32 + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + return halide_type(dtype) + + +class HalideOverrides(OpOverrides): + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + if dtype == torch.bool: + return f"({x} != 0)" + return f"hl.cast({halide_type(dtype)}, {x})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + if src_dtype in (torch.float16, torch.bfloat16): + x = f"hl.cast({halide_type(src_dtype)}, {x})" # body compute is upcast to fp32 + line = f"hl.reinterpret({halide_type(dtype)}, {x})" + if dtype in (torch.float16, torch.bfloat16): + line = f"hl.cast(hl.Float(32), {line})" + return line + + @classmethod + def constant(cls, value, dtype): + return cls.to_dtype(halide_constant(value), dtype) + + @staticmethod + def abs(x): + return f"hl.abs({x})" + + @staticmethod + def exp(x): + if not hasattr(x, "name"): + return f"hl.exp({x})" + return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" + + @staticmethod + def sqrt(x): + return f"hl.sqrt({x})" + + @staticmethod + def minimum(a, b): + # return f"hl.min({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.min({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return 
f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" + + @staticmethod + def maximum(a, b): + # return f"hl.max({a}, {b})" <== handles nan wrong + if not hasattr(a, "name"): + return f"hl.max({a}, {b})" + b = f"hl.cast({a.name}.type(), {b})" + return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" + + @staticmethod + def where(a, b, c): + if hasattr(b, "name"): + c = f"hl.cast({b.name}.type(), {c})" + return f"hl.select({a}, {b}, {c})" + + @staticmethod + def cos(x): + return f"hl.cos({x})" + + @staticmethod + def sin(x): + return f"hl.sin({x})" + + @staticmethod + def lgamma(x): + raise Unsupported("lgamma") + + @staticmethod + def erf(x): + return f"hl.erf({x})" + + @staticmethod + def cosh(x): + return f"hl.cosh({x})" + + @staticmethod + def sinh(x): + return f"hl.sinh({x})" + + @staticmethod + def acos(x): + return f"hl.acos({x})" + + @staticmethod + def acosh(x): + return f"hl.acosh({x})" + + @staticmethod + def asin(x): + return f"hl.asin({x})" + + @staticmethod + def asinh(x): + return f"hl.asinh({x})" + + @staticmethod + def atan2(x, y): + return f"hl.atan2({x}, {y})" + + @staticmethod + def atan(x): + return f"hl.atan({x})" + + @staticmethod + def atanh(x): + return f"hl.atanh({x})" + + @staticmethod + def copysign(x, y): + raise Unsupported("copysign") + + @staticmethod + def erfinv(x): + raise Unsupported("erfinv") + + @staticmethod + def hypot(x, y): + return f"hl.hypot({x}, {y})" + + @staticmethod + def nextafter(x, y): + raise Unsupported("nextafter") + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + return f"halide_helpers.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + return f"halide_helpers.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" + + @staticmethod + def rsqrt(x): + # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues + return f"1./hl.sqrt({x})" + + @staticmethod + def tan(x): + return f"hl.tan({x})" + + @staticmethod + def tanh(x): + return f"hl.tanh({x})" + + @staticmethod + def signbit(x): + return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" + + @staticmethod + def fmod(a, b): + # TODO(jansel): find a better way to do this, builtin % has wrong sign + return f"{a} - hl.trunc({a}/{b})*{b}" + + @staticmethod + def pow(a, b): + return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy + + @staticmethod + def log(x): + return f"hl.log({x})" # hl.fast_log fails accuracy + + @staticmethod + def log2(x): + raise NotImplementedError("log2") + + @staticmethod + def isinf(x): + # workaround https://github.com/halide/Halide/issues/8309 + 
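+        # the explicit upcast to hl.Float(32) below is the workaround: the inf check is performed on a
+        # 32-bit float rather than directly on the original (possibly 16-bit) expression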
return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def isnan(x): + # workaround https://github.com/halide/Halide/issues/8309 + return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" + + @staticmethod + def round(x): + return f"hl.round({x})" + + @staticmethod + def floor(x): + return f"hl.floor({x})" + + @staticmethod + def int_truediv(a, b): + return f"({a}) / ({b} + hl.f32(0))" + + @staticmethod + def floordiv(a, b): + # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work + return ( + f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @classmethod + def sign(cls, x): + left = ops.to_dtype(ops.lt("0", x), torch.int8) + right = ops.to_dtype(ops.lt(x, "0"), torch.int8) + sub = ops.sub(left, right) + return f"hl.cast({x.name}.type(), {sub})" + + @staticmethod + def trunc(x): + return f"hl.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # this causes crashes with floating point exception, see test_div_zero_dim_cpu + # return f"hl.div_round_to_zero({a}, {b})" + return ( + f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})" + ) + + @staticmethod + def ceil(x): + return f"hl.ceil({x})" + + @staticmethod + def relu(x): + return f"hl.max({x}, 0)" + + @classmethod + def index_expr(cls, expr, dtype): + index = V.kernel.prepare_indexing(expr) + var = V.kernel.genfunc( + V.kernel.index_to_str(index), + V.kernel.used_dims_from_index(index), + bounds=get_bounds_index_expr(expr), + ) + if dtype not in (torch.int32, torch.int64): + return ops.to_dtype(var, dtype) + return var + + @classmethod + def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True): + # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow + index_var = ops.to_dtype(index_var, torch.int32) + index_var = ops.halide_clamp(index_var, size, check) + index_var.indirect_indexing_size = size + return sympy_index_symbol(str(index_var)) + + @classmethod + def halide_clamp(cls, value, size, check): + end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1) + if not isinstance(size, (int, sympy.Integer)): + end = f"hl.cast({value.name}.type(), {end})" + # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692 + # return f"hl.unsafe_promise_clamped({value}, 0, {end})" + return f"hl.clamp({value}, 0, {end})" + + @staticmethod + def masked(mask, body, other): + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) + + # Take dtype from result to prevent accidental promotion + other = V.kernel.genfunc( + f"hl.cast({result.name}.type(), {halide_constant(other)})", + [], + bounds=ValueRanges.wrap(other), + ) + # TODO(jansel): look into removing the where in the same places triton does + return ops.where(new_mask, result, other) + + @staticmethod + def frexp(x): + raise NotImplementedError("frexp") + + +HalideOverrides._initialize_pointwise_overrides("halide") + + +class HalideCSEVariable(CSEVariable): + undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") + + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) + self.used_dims: Optional[list[sympy.Symbol]] = None + + def update_on_args(self, name, args, kwargs): + used = OrderedSet(self.used_dims or ()) + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, HalideCSEVariable): + assert arg.used_dims is not None, (name, 
arg, args) + used.update(arg.used_dims) + self.used_dims = V.kernel.sort_used_dims(used) + + def index_str(self, dims): + if len(dims) == 0: + return f"{self.name}[()]" + # Reversed since Halide is column major + return f"{self.name}[{', '.join(map(str, dims))}]" + + def __str__(self) -> str: + if self.used_dims is None: + # This will get recomputed and replaced in codegen_kernel() + return f"{self.name}[?]" + return self.index_str(self.used_dims) + + def subs_str(self, replacements): + assert self.used_dims is not None and all( + isinstance(x, sympy.Expr) for x in self.used_dims + ) + return self.index_str([replacements.get(n, n) for n in self.used_dims]) + + +@dataclasses.dataclass +class DimensionInfo: + expr: Optional[sympy.Expr] + size: sympy.Expr + stride: sympy.Expr + + def __init__(self, expr, size, stride) -> None: + super().__init__() + if V.graph.sizevars.statically_known_lt(stride, 0): + stride = -stride + expr = -expr + self.expr = expr + self.size = size + self.stride = stride + + def index_str(self, replacements=None, zero_vars=False): + assert self.expr is not None + expr = self.expr + if zero_vars and expr == 0: + return "hl.Var()" + if replacements: + replacements = {**replacements} + for sym in expr.free_symbols: + if symbol_is_type(sym, SymT.TMP): + assert isinstance(sym, sympy.Symbol) + var = V.kernel.lookup_cse_var(sym.name) + assert isinstance(var, HalideCSEVariable) + replacements[sym] = sympy_index_symbol(var.subs_str(replacements)) + expr = sympy_subs(expr, replacements) + return V.kernel.index_to_str(expr) + + +def eq(left, right): + if V.graph.sizevars.statically_known_equals(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + return False + if a == b: + V.graph.sizevars.guard_equals(left, right) + return a == b + + +def lt(left, right): + if V.graph.sizevars.statically_known_lt(left, right): + return True + try: + a = V.graph.sizevars.size_hint(left) + b = V.graph.sizevars.size_hint(right) + except TypeError: # unbacked symints + gcd = sympy.gcd(left, right) + if gcd == left: + return left != right + return False + if a < b: + V.graph.sizevars.guard_lt(left, right) + return a < b + + +class HalideKernel(SIMDKernel): + overrides = HalideOverrides # type: ignore[assignment] + kexpr: Callable[[sympy.Expr], str] = texpr + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs, + ) -> None: + super().__init__(tiling, **kwargs) + # For halide, we just write directly to the body + self.compute = self.body + self.loads = self.body + self.stores = self.body + self.indexing_code_dom = IndentedBuffer() + self.needs_dom_indexing = self.inside_reduction + self.has_reduction = self.inside_reduction + self.buffer_dimensions: dict[str, list[DimensionInfo]] = {} + self.buffer_offsets: dict[str, sympy.Expr] = {} + # {h0: size1, h1: size2, ...} + self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {} + # {x0: h0, x1: h1+10*h2, ...} + self.index_replacements: dict[sympy.Expr, sympy.Expr] = {} + # {h1: hr1, ...} + self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {} + # {"i": {h0: hi0}, "o": ...} + self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {} + # {"in_ptr0": ["in_ptr0_view0"], ...} + self.buffer_aliases: dict[str, list[str]] = defaultdict(list) + self.has_indirect_indexing = False + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return halide_type(dtype) + + def create_cse_var(self, name, bounds=None, dtype=None): + 
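+        # Each CSE value is materialized as a named hl.Func in the generated pipeline; the returned
+        # HalideCSEVariable later records which Halide loop vars (used_dims) index it.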
self.body.writeline(f"{name} = hl.Func({name!r})") + return HalideCSEVariable(name, bounds, dtype) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]): + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + + This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing + scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex + we base indexing on a larger number of vars whose product combines to those. + + This function populates self.halide_vars, self.index_replacements, and self.reduction_renames + """ + assert not ( + self.index_replacements or self.halide_vars or self.reduction_renames + ) + size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf) # type: ignore[arg-type] + indices = dict.fromkeys(map(super().prepare_indexing, indices)) + all_used_symbols = OrderedSet[Any]() + sym_to_node = { + n.symbol(): n + for n in itertools.chain.from_iterable( + [tree.nodes.values() for tree in self.range_trees] + ) + } + + def simplify(expr): + return sympy.simplify( + V.graph.sizevars.remove_precomputed_replacements(expr) + ) + + def visit_modular_indexing(base, divisor, modulus): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + V.graph.sizevars.evaluate_min( + modulus, FloorDiv(node.length, divisor) + ), + ).symbol() + ) + + def visit_floor_div(base, divisor): + if base in sym_to_node: + node = sym_to_node[base] + all_used_symbols.add( + node.root.lookup( + node.divisor * divisor, + FloorDiv(node.length, divisor), + ).symbol() + ) + + # first figure out all_used_symbols to do dead symbol elimination + for index in indices: + if index.has(ModularIndexing): + index.replace( + ModularIndexing( + sympy.Wild("base"), + sympy.Wild("divisor"), + sympy.Wild("modulus"), + ), + visit_modular_indexing, + ) + if index.has(FloorDiv): + index.replace( + FloorDiv( + sympy.Wild("base"), + sympy.Wild("divisor"), + ), + visit_floor_div, + ) + all_used_symbols.update(super().prepare_indexing(index).free_symbols) + + self.has_indirect_indexing = any( + symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols + ) + + had_fallback = False + for tree in reversed(self.range_trees): + nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols] + nodes.sort(key=lambda n: size_hint(n.divisor)) + if not nodes: + nodes.append(tree.lookup(1, tree.numel)) + handled_count = 0 + divisor = sympy.S.One + added_sym_size = [] + # decide on a minimal set of symbols and put them in self.halide_vars + while handled_count < len(nodes) and not eq(tree.numel, divisor): + sizes_to_add = [ + simplify(n.length) for n in nodes if eq(n.divisor, divisor) + ] + handled_count += len(sizes_to_add) + assert sizes_to_add, nodes + end = divisor * functools.reduce( + V.graph.sizevars.evaluate_max, sizes_to_add + ) + sizes_to_add.extend( + [ + simplify(n.divisor / divisor) + for n in nodes + if lt(divisor, n.divisor) and lt(n.divisor, end) + ] + ) + while sizes_to_add: + next_size = functools.reduce(sympy.gcd, sizes_to_add) + if eq(next_size, 1): + # sizes share no common factors, e.g [2, 21, 42, 441, 889056] + # TODO(jansel): we should just prevent fusion in cases that hit this + next_size = simplify(tree.numel / divisor) + assert not eq(next_size, 1) + sizes_to_add = [] + handled_count = len(nodes) + had_fallback = True + sym = sympy_index_symbol(f"h{len(self.halide_vars)}") + if tree.is_reduction: + 
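+                    # reduction symbols get a parallel "hr<i>" name so they can later be bound to an RDom
+                    # (see the codegen_rdom("rdom", ...) call at the end of finalize_indexing)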
self.reduction_renames[sym] = sympy_index_symbol( + f"hr{len(self.halide_vars)}" + ) + self.halide_vars[sym] = next_size + added_sym_size.append((sym, next_size)) + divisor *= next_size + new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)] + handled_count += len(new_sizes) + prior_len = len(sizes_to_add) + sizes_to_add = [ + sympy.simplify(s / next_size) + for s in sizes_to_add + if not eq(s, next_size) + ] + assert len(sizes_to_add) < prior_len or prior_len == 0 + sizes_to_add.extend(new_sizes) + + # create a mapping to the new set of symbols in self.index_replacements + for node in nodes: + try: + idx = 0 + divisor = 1 + while not eq(node.divisor, divisor): + sym, size = added_sym_size[idx] + idx += 1 + divisor *= size + length = 1 + expr = sympy.S.Zero + while not eq(node.length, length): + sym, size = added_sym_size[idx] + idx += 1 + expr += length * sym + length *= size + self.index_replacements[node.symbol()] = expr + except IndexError: + assert had_fallback + full_index = sympy.S.Zero + stride = sympy.S.One + for sym, size in added_sym_size: + full_index += stride * sym + stride *= size + self.index_replacements[node.symbol()] = ( + V.graph.sizevars.simplify_with_ranges( + ModularIndexing(full_index, node.divisor, node.length), + self.halide_vars, # type: ignore[arg-type] + ) + ) + + # codegen the variable definitions + for sym in self.halide_vars: + self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})") + if self.reduction_renames: + self.codegen_rdom( + "rdom", + {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()}, + ) + + def setup_dom_indexing(self): + """RDom based indexing uses explicit iteration ranges for Func updates""" + prefix = "i" if self.inside_reduction else "o" + if prefix in self.dom_renames: + return self.dom_renames[prefix] + + renames = {} + for var in self.halide_vars.keys(): + if not self.inside_reduction and var in self.reduction_renames: + continue + m = re.match(r"^h(\d+)$", var.name) + assert m + renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}") + + self.codegen_rdom( + f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()} + ) + + self.dom_renames[prefix] = renames + return renames + + def codegen_rdom(self, name, vars): + rsizes = [ + f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})" + for size in vars.values() + ] + self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])") + for i, rsym in enumerate(vars.keys()): + self.indexing_code.writeline(f"{rsym} = {name}[{i}]") + + def prepare_indexing( + self, + index: sympy.Expr, + ): + index = super().prepare_indexing(index) + index = sympy_subs(index, self.index_replacements) + return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars) # type: ignore[arg-type] + + def sym_size(self, sym): + """The size of an index symbol""" + if symbol_is_type(sym, SymT.TMP): + return self.lookup_cse_var(sym.name).indirect_indexing_size + return self.halide_vars[sym] + + def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool): + """Convert address-based indexing into dimensions using self.halide_vars""" + symbols = [] + for sym in sorted(index.free_symbols, key=lambda x: x.name): # type: ignore[attr-defined] + if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)): + symbols.append(sym) + else: + assert symbol_is_type( + sym, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + ), + ), sym + + # group the expression by variables used + offset = sympy.S.Zero + split_expr = dict.fromkeys(symbols, 
sympy.S.Zero) + split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = [] + index = sympy.expand(self.rename_indexing(index)) + for part in index.args if isinstance(index, sympy.Add) else [index]: + part_vars = [v for v in part.free_symbols if v in split_expr] + if len(part_vars) == 0: + offset += part + elif len(part_vars) == 1: + split_expr[part_vars[0]] += part + else: + new_split_failed = [] + for i in range(len(split_failed)): + assert split_failed[i] is not None + other_vars, other_part = split_failed[i] + if OrderedSet(other_vars) & OrderedSet(part_vars): + part_vars.extend([v for v in other_vars if v not in part_vars]) + part += other_part + else: + new_split_failed.append((other_vars, other_part)) + split_failed = [*new_split_failed, (part_vars, part)] + + def expr_to_dimension(expr, syms): + expr = sympy.factor(expr) + if len(syms) == 1: + stride_wild = sympy.Wild("wild", exclude=symbols) + m = expr.match(stride_wild * syms[0]) + if m: + return DimensionInfo( + syms[0], self.sym_size(syms[0]), m[stride_wild] + ) + assert not is_store, expr + length = sympy.simplify( + sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1 + ) + stride = sympy.S.One + if isinstance(expr, sympy.Mul): + for term in expr.args: + if isinstance(term, sympy.Integer): + stride *= term + expr = sympy.simplify(expr / term) + length = sympy.simplify(sympy.ceiling(length / term)) + return DimensionInfo(expr, length, stride) + + # try to turn each group into a strided access + dims = [] + for syms, expr in split_failed: + for v in syms: + expr += split_expr.pop(v) + dims.append(expr_to_dimension(expr, syms)) + for sym, expr in split_expr.items(): + dims.append(expr_to_dimension(expr, [sym])) + dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf)) # type: ignore[arg-type] + + if not dims: # scalar load/store + if self.has_indirect_indexing: + # workaround https://github.com/halide/Halide/issues/8338 + dims.append(DimensionInfo(sympy.S.Zero, 1, 1)) + elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1): + # Halide assumes dimension 0 is stride == 1, so add a dummy dimension + dims.insert( + 0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1) + ) + + if dims and not is_store: + if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq( + offset, self.buffer_offsets[var] + ): + # reuse the existing offset to avoid needing an input alias + self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var]) + offset = self.buffer_offsets[var] + elif V.graph.sizevars.statically_known_gt( + offset, 0 + ): # TODO(jansel): negative offsets + # roll the offset into the dimensions for cleaner indexing + self.apply_offset_to_dimension(dims, offset) + offset = 0 + + orig_var = var + for i in itertools.count(): + if self.install_dims(var, dims, offset, is_store): + return var, dims + assert not is_store + var = f"{orig_var}_view{i}" + if var not in self.buffer_aliases[orig_var]: + self.buffer_aliases[orig_var].append(var) + + def install_dims(self, var, dims, offset, is_store): + """Try to set self.buffer_dimensions[var], return True on success""" + if var not in self.buffer_dimensions: + self.buffer_dimensions[var] = dims + self.buffer_offsets[var] = offset + return True + if self.buffer_offsets[var] != offset or len( + self.buffer_dimensions[var] + ) != len(dims): + return False + if is_store: + return self.buffer_dimensions[var] == dims + for old, new in zip(self.buffer_dimensions[var], dims): + if old.stride != new.stride: + 
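+                # strides must match exactly for the cached dimensions to be reused; size/expr mismatches
+                # are reconciled below by widening the size and dropping the expr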
return False + if old.size != new.size or old.expr != new.expr: + old.size = V.graph.sizevars.evaluate_max(old.size, new.size) + old.expr = None + return True + + def apply_offset_to_dimension(self, dims, offset): + if offset == 0: + return + for i in reversed(range(len(dims))): + if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq( + offset, dims[i].stride + ): + part = FloorDiv(offset, dims[i].stride) + offset -= part * dims[i].stride + dims[i].expr += part + assert offset == 0 + + def used_dims_from_index(self, index: sympy.Expr): + """Detect which range trees are used to populate HalideCSEVariable.used_dims""" + used_dims = OrderedSet[sympy.Symbol]() + for sym in index.free_symbols: + assert isinstance(sym, sympy.Symbol) + if symbol_is_type(sym, SymT.TMP): + # indirect indexing + cse_var = self.lookup_cse_var(sym.name) + assert ( + isinstance(cse_var, HalideCSEVariable) + and cse_var.used_dims is not None + ) + used_dims.update(cse_var.used_dims) + elif symbol_is_type(sym, SymT.HALIDE): + used_dims.add(sym) + elif symbol_is_type( + sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX) + ): + pass + else: + raise NotImplementedError(f"unhandled symbol {sym}") + return self.sort_used_dims(used_dims) + + def sort_used_dims(self, used_dims): + assert all(isinstance(x, sympy.Expr) for x in used_dims) + ordered = [ + sym + for sym in itertools.chain( + self.halide_vars, self.reduction_renames.values() + ) + if sym in used_dims + ] + assert len(ordered) == len(used_dims) + return ordered + + def make_index_str(self, dims, replacements=None, zero_vars=False): + index_str = ", ".join(d.index_str(replacements, zero_vars) for d in dims) + if len(dims) == 0: + index_str = "()" + elif len(dims) == 1: + # workaround for https://github.com/halide/Halide/issues/8299 + index_str = f"{index_str}," + return index_str + + def load(self, name: str, index: sympy.Expr): + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, False) + line = f"{var}[{self.make_index_str(dims)}]" + dtype = V.graph.get_dtype(name) + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + line = f"hl.cast(hl.Float(32), {line})" + + if self._load_mask: + assert ( + isinstance(self._load_mask, HalideCSEVariable) + and self._load_mask.used_dims is not None + ) + used_dims = OrderedSet( + (*self.used_dims_from_index(index), *self._load_mask.used_dims) + ) + result = self.newfunc(self.sort_used_dims(used_dims)) + if result.used_dims: + self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])") + self.body.writeline(f"{result.name}_mask.where({self._load_mask})") + other = self.kexpr(self._load_other or 0) # type: ignore[arg-type] + self.body.writeline( + f"{result} = hl.cast({halide_type(dtype)}, {other})" + ) + self.body.writeline( + f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)" + ) + else: + # scalar case + self.body.writeline( + f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))" + ) + return result + else: + return self.genfunc(line, self.used_dims_from_index(index)) + + def lookup_cse_var(self, name: str): + return self.cse.varname_map[re.sub(r"\[.*", "", name)] + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + """Codegen a store to an OutputBuffer""" + assert isinstance(value, HalideCSEVariable) + var = self.args.output(name) + index = 
self.prepare_indexing(index) + var, dims = self.indexing_to_dimensions(var, index, True) + if self.is_indirect_indexing(index) or mode is not None: + replacements = self.setup_dom_indexing() + index_str = self.make_index_str(dims, replacements) + value_str = value.subs_str(replacements) + undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()" + self.body.writeline( + DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())") + ) + else: + index_str = self.make_index_str(dims, zero_vars=True) + value_str = str(value) + + dtype = V.graph.get_dtype(name) + if mode is None: + line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})" + elif mode == "atomic_add": + line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})" + else: + raise NotImplementedError(f"store mode={mode}") + self.body.writeline(DeferredLine(name, line)) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation""" + assert self.inside_reduction + assert not self._load_mask + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + if isinstance(value, tuple): + assert reduction_type == "welford_combine" + self.cse.reduction_cache[cache_key] = result_tuple = ( + self.welford_combine_impl(*value) + ) + return result_tuple + + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + reduction_vars = OrderedSet(self.reduction_renames) + result_var = self.newfunc( + [v for v in value.used_dims if v not in reduction_vars] + ) + if reduction_vars - OrderedSet(value.used_dims): + value = self.genfunc( + f"{value}", + self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))), + ) + value_str = value.subs_str(self.reduction_renames) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + acc_type = halide_acc_type(dtype) + + if reduction_type in ("argmax", "argmin"): + index = f"{result_var.name}_{reduction_type}" + self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})") + # turn the N-D argmax index into a 1-D one + parts = [] + stride = 1 + for i, sym in enumerate(self.reduction_renames): + parts.append(f"{index}[{i}]") + if stride != 1: + parts[-1] += f"*{stride}" + stride *= self.halide_vars[sym] + self.body.writeline(f"{result_var} = {' + '.join(parts)}") + elif reduction_type == "welford_reduce": + # TODO(jansel): implement welford_reduce without fallback + result_var = self.welford_reduce_fallback(dtype, value) + else: + combine_fn = get_reduction_combine_fn(reduction_type, acc_type) + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] + default_str = f"hl.cast({acc_type}, {halide_constant(default)})" + self.body.writeline(f"{result_var} = {default_str}") + self.body.writeline(f"{result_var} = {combine_str}") + + self.cse.reduction_cache[cache_key] = result_var + return result_var + + def welford_combine_impl(self, mean, m2, weight): + assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None + assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None + assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None + used_dims = OrderedSet( + (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars + ) + 
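+        # Parallel Welford combine (Chan et al.): given two partial aggregates
+        # (mean_1, m2_1, w_1) and (mean_2, m2_2, w_2):
+        #   delta = mean_2 - mean_1
+        #   w     = w_1 + w_2
+        #   mean  = mean_1 + delta * w_2 / w
+        #   m2    = m2_1 + m2_2 + delta**2 * w_1 * w_2 / w
+        # The hl.select below guards the w_2 / w term against division by zero.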
used_dims -= OrderedSet(self.reduction_renames) + result_var = self.newfunc(self.sort_used_dims(used_dims)) + default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)] + pfx = result_var.name + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])") + self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]") + self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]") + self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]") + self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}") + self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}") + self.body.writeline( + f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}" + ) + self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1") + self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2") + self.body.writeline( + f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)" + ) + update = [ + f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w", + f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w", + f"{pfx}_new_weight", + ] + self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])") + + unpacked = [] + for i in range(3): + unpacked.append(self.newfunc(result_var.used_dims)) + self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]") + return tuple(unpacked) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] + ], + values_orig: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert len(dtypes) == len(values_orig) + values: list[HalideCSEVariable] = [] + all_used_dims = OrderedSet[sympy.Symbol]() + + for value in values_orig: + assert isinstance(value, HalideCSEVariable) and value.used_dims is not None + if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames): + values.append(value) + else: + values.append( + self.genfunc( + f"{value}", [*value.used_dims, [*self.reduction_renames][:1]] + ) + ) + all_used_dims.update(value.used_dims) + result_var = self.newfunc(self.sort_used_dims(all_used_dims)) + assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet( + self.reduction_renames + ) + initial = [ + f"hl.cast({halide_acc_type(dtype)}, {value})" + for dtype, value in zip(dtypes, values) + ] + + length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel)) + scan_dom = f"{result_var.name}_rdom" + scan = f"{scan_dom}.x" + self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])") + + assert len(self.reduction_renames) == 1, ( + "multi-dimensional scan not implemented" + ) + (scan_var,) = [*self.reduction_renames] # type: ignore[misc] + scan_renames_cur = {scan_var: sympy_index_symbol(scan)} + scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1} + + if len(values) == 1: + + def maybe_tuple(x): + return x[0] + + read_left = [result_var.subs_str(scan_renames_pri)] + read_right = [result_var.subs_str(scan_renames_cur)] + else: + + def maybe_tuple(x): + return f"hl.Tuple([{', '.join(x)}])" + + read_left = [ + result_var.subs_str(scan_renames_pri) + f"[{i}]" + for i in range(len(values)) + ] + read_right = [ + result_var.subs_str(scan_renames_cur) + f"[{i}]" + for i in range(len(values)) + ] + + self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") + + # Disable CSE for update fn + with V.set_ops_handler(AddParenHandler(HalideOverrides())): + combine_str = 
combine_fn(read_left, read_right) # type: ignore[arg-type] + self.body.writeline( + f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" + ) + + if len(values) == 1: + return (result_var,) + + unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values] + for i, v in enumerate(unpack_vars): + self.body.writeline(f"{v} = {result_var}[{i}]") + return tuple(unpack_vars) + + def genfunc( + self, line, used_dims, *, bounds=ValueRanges.unknown() + ) -> HalideCSEVariable: + var = self.cse.generate(self.body, line, bounds=bounds) + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def newfunc(self, used_dims) -> HalideCSEVariable: + var = self.cse.newvar() + assert isinstance(var, HalideCSEVariable) + var.used_dims = used_dims + return var + + def halide_buffer_numel(self, name: str): + """ + We map all tensors to 1D buffers in Halide since Halide has trouble representing some strides that PyTorch + supports. If there are gaps in the underlying layout the numel we pass to Halide includes the gaps while + PyTorch's numel excludes them. + """ + return V.graph.get_buffer(name).get_layout().storage_size() + + def halide_argdefs(self): + """ + Halide requires scalar inputs before outputs, so need to reorder args. + """ + + def arg_order(arg_tuple): + _call_str, arg = arg_tuple + if isinstance(arg, SizeArg): + return 1 # this would normally be at the end, move it to middle + elif "out_ptr" in arg.name: + return 2 + else: + assert "in_ptr" in arg.name + return 0 + + result: list[tuple[Optional[str], KernelArgType]] = [] + _, a, b, _ = self.args.python_argdefs() + for call_str, arg in sorted(zip(a, b), key=arg_order): + result.append((call_str, arg)) + if isinstance(arg, TensorArg): + assert arg.offset == 0 and arg.alias_of is None + result.extend( + ( + None, + TensorArg( + alias, + arg.buffer, + arg.dtype, + arg.offset, + alias_of=arg.name, + ), + ) + for alias in self.buffer_aliases.get(arg.name, ()) + ) + return result + + def halide_kernel_meta(self) -> HalideMeta: + """Compute metadata required by codecache.py""" + argtypes = [] + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + shape = None + stride = None + offset = None + dtype = "long" + else: + shape = [ + cexpr(self.rename_indexing(x.size)) + for x in self.buffer_dimensions[arg.name] + ] + stride = [ + cexpr(self.rename_indexing(x.stride)) + for x in self.buffer_dimensions[arg.name] + ] + assert len(shape) == len(stride) + offset = cexpr(self.buffer_offsets[arg.name]) + dtype = f"{DTYPE_TO_CPP[arg.dtype]}*" + argtypes.append( + HalideInputSpec( + dtype, + arg.name, + shape=shape, + stride=stride, + offset=offset, + alias_of=arg.alias_of, + ) + ) + + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cpu": + target = [config.halide.cpu_target] + scheduler = config.halide.scheduler_cpu + scheduler_flags = { + "parallelism": parallel_num_threads(), + } + cuda_device = None + else: + assert current_device.type == "cuda", "only cpu/cuda supported" + assert current_device.index <= 0, "only default device supported" + target = [config.halide.gpu_target] + scheduler = config.halide.scheduler_cuda + capability = torch.cuda.get_device_properties(current_device) + if "cuda_capability" not in target[0]: + for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]: + if capability.major >= major and capability.minor >= minor: + target.append(f"cuda_capability_{major}{minor}") + break + target.append("user_context") + 
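+            # scheduler_flags are forwarded to the Halide autoscheduler;
+            # for CUDA the parallelism hint is the device's SM count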
scheduler_flags = { + "parallelism": capability.multi_processor_count, + # TODO(jansel): explore other flags, see: + # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp + } + cuda_device = max(0, current_device.index) + + # strict_float is requires for correctness + target.append("strict_float") + + # without this we will initialize cuda once per kernel and hit errors + target.append("no_runtime") + + if not config.halide.asserts: + target.append("no_asserts") + + if config.halide.debug: + target.append("debug") + + if "64" in self.index_dtype: + # TODO(jansel): it is unclear if this does anything, since input sizes are still int32 + target.append("large_buffers") + + return HalideMeta( + argtypes, + target="-".join(target), + scheduler=scheduler, + scheduler_flags=scheduler_flags, # type: ignore[arg-type] + cuda_device=cuda_device, + ) + + def codegen_kernel(self, name=None): + """Called at the end to generate a final kernel string""" + if self.args.inplace_buffers: + raise Unsupported("inplace_buffers") + meta = self.halide_kernel_meta() # ensure needed args are added early + code = IndentedBuffer() + code.splice( + """ + import halide as hl + from torch._inductor.runtime import halide_helpers + from math import inf, nan + + @hl.generator(name="kernel") + class Kernel: + """, + strip=True, + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + if isinstance(arg, SizeArg): + code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})") + else: + assert arg.buffer, arg + argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer" + argtype = halide_type(arg.dtype) + ndim = len(self.buffer_dimensions[arg.name]) + code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})") + code.splice( + """ + def generate(g): + """ + ) + code.do_indent() + for _, arg in self.halide_argdefs(): + code.writeline(f"{arg.name} = g.{arg.name}") + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.indexing_code) + + def update_index(m): + var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)]) + assert var.used_dims is not None, var + return str(var) + + for line in self.body._lines: + if isinstance(line, str): + # fill in missing indices + line = HalideCSEVariable.undefined_re.sub(update_index, line) + code.writeline(line) + code.writeline("") + code.writeline("assert g.using_autoscheduler()") + + for _, arg in self.halide_argdefs(): + # fallback=1 below because halide requires buffers to be at least as large as the estimates + # This causes crashes if our estimate is greater than the vector length + # https://github.com/halide/Halide/issues/3103 + if isinstance(arg, SizeArg): + hint = V.graph.sizevars.size_hint(arg.expr, fallback=1) + code.writeline(f"{arg.name}.set_estimate({hint})") + else: + dims = self.buffer_dimensions[arg.name] + range_hints = [] + for i, dim in enumerate(dims): + hint = self._autoscheduler_workarounds( + V.graph.sizevars.size_hint(dim.size, fallback=1), dims + ) + range_hints.append(f"hl.Range(0, {hint})") + if "out" not in arg.name: + code.writeline(f"{arg.name}.dim({i}).set_min(0)") + try: + code.writeline( + f"{arg.name}.dim({i}).set_stride({int(dim.stride)})" + ) + except TypeError: + pass # not integer + try: + code.writeline( + f"{arg.name}.dim({i}).set_extent({int(dim.size)})" + ) + except TypeError: + pass # not integer + code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])") + + code.do_unindent(2) + code.splice( + """ + if __name__ == "__main__": + hl.main() + 
""".rstrip(), + ) + if meta.scheduler: + code.splice( + f""" + else: + hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r}) + target = hl.Target({meta.target!r}) + autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r}) + with hl.GeneratorContext(target, autoscheduler): + gen = Kernel() + pipeline = gen._build_pipeline() + # gen.compile_to_callable() does not run the autoscheduler + pipeline.apply_autoscheduler(target, autoscheduler) + kernel = pipeline.compile_to_callable([ + gen._get_input_parameter(a.name)._to_argument() + for a in gen._get_arginfos() + if a.dir == hl.ArgInfoDirection.Input + ], target) + """, + strip=True, + ) + else: + code.splice( + f""" + else: + with hl.GeneratorContext(hl.Target({meta.target!r})): + kernel = Kernel().compile_to_callable() + """, + strip=True, + ) + return code.getvalue() + + @staticmethod + def _autoscheduler_workarounds(n, dims): + if ( + len(dims) == 1 + and config.halide.scheduler_cuda == "Anderson2021" + and V.graph.get_current_device_or_throw().type == "cuda" + ): + # workaround https://github.com/halide/Halide/issues/8246 + n = max(2, n) + return n + + def call_kernel(self, name: str, node=None): + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] + current_device = V.graph.get_current_device_or_throw() + if current_device.type == "cuda": + stream_name = wrapper.write_get_raw_stream( + current_device.index, V.graph.name + ) + call_args.append(stream_name) + wrapper.generate_kernel_call( + name, + call_args, + device=current_device, + triton=False, + ) + + def generate_assert(self, check): + return False # TODO(jansel): support asserts + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ): + pass # TODO(jansel): support asserts + + +class HalideScheduling(SIMDScheduling): + kernel_type = HalideKernel # type: ignore[arg-type,assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + result = OrderedSet( + [ + BackendFeature.TUPLE_REDUCTION, + BackendFeature.PREFER_STORE_LOOP_ORDER, + BackendFeature.REDUCE_TO_SINGLE_ELEMENT, + ] + ) + if config.halide.scan_kernels: + result.add(BackendFeature.SCAN) + return result + + def define_kernel(self, src_code, node_schedule, kernel): + """Codegen kernel definition to go in output wrapper code""" + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}" + wrapper.src_to_kernel[src_code] = kernel_name + wrapper.add_import_once( + "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec" + ) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline( + f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''" + ) + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + if is_metric_table_enabled("kernel_metadata"): + log_kernel_metadata(kernel_name, "", src_code) + + return kernel_name diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/memory_planning.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/memory_planning.py new file mode 100644 index 
0000000000000000000000000000000000000000..c43eab95aaee0a631c7232b978609af21a293fa2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/memory_planning.py @@ -0,0 +1,775 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import dataclasses +import itertools +import pprint +from typing import Any, Optional, Protocol, TYPE_CHECKING + +import sympy + +import torch +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer +from ..virtualized import V +from .wrapper import ( + AllocateLine, + BufferLike, + FreeIfNotReusedLine, + MemoryPlanningLine, + NullLine, + ReuseLine, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +@dataclasses.dataclass +class LiveRange: + """ + A range where a given tensor is live. Begin and end are both counters + representing points in the program of grouped memory operations. + Begin is inclusive, end is exclusive. + + Invariant: begin <= end + """ + + begin: float # int | +/-inf + end: float # int | +/-inf + + def contains(self, other: LiveRange): + """Is other entirely within self""" + return self.begin <= other.begin and other.end <= self.end + + def join(self, other: LiveRange): + """Combine two ranges using a union operation""" + return LiveRange(min(self.begin, other.begin), max(self.end, other.end)) + + def __len__(self): + return self.end - self.begin + + +class LiveRanges: + """ + A collection of LiveRange regions, allowing for non-contiguous + live regions. + + Invariant: LiveRanges.ranges is in sorted order and non-overlapping + """ + + def __init__(self, ranges: Iterable[LiveRange]): + ranges = [*sorted(ranges, key=lambda x: x.begin)] + self.ranges = ranges[:1] + for r in ranges[1:]: + assert self.ranges[-1].begin <= r.begin + if self.ranges[-1].end >= r.begin: + self.ranges[-1] = LiveRange.join(self.ranges[-1], r) + else: + self.ranges.append(r) + + def overlaps(self, other: LiveRanges): + """Check if any pair of ranges in self and other overlap""" + left = collections.deque(self.ranges) + right = collections.deque(other.ranges) + while left and right: + if left[0].begin > right[0].begin: + left, right = right, left + assert left[0].begin <= right[0].begin + if left[0].end > right[0].begin: + return True + left.popleft() + return False + + @property + def begin(self): + return self.ranges[0].begin + + @property + def end(self): + return self.ranges[-1].end + + def __repr__(self): + return f"{self.__class__.__name__}([{', '.join(map(repr, self.ranges))}])" + + +class AllocationTreeNode: + """ + Abstract base class for nodes in allocation pool. + """ + + def allocate(self, block: Allocation, is_last: bool) -> bool: + """ + Try to assign block to a memory location in this bool. Return True if + an assignment was made. 
+ """ + return False + + def get_live_ranges(self) -> LiveRanges: + """Aggregate LiveRanges for all objects below this in tree""" + raise NotImplementedError + + def get_size_hint(self) -> int: + """Number of bytes used for example inputs""" + raise NotImplementedError + + def get_symbolic_size(self) -> sympy.Expr: + """Number of bytes needed at runtime""" + raise NotImplementedError + + def finalize(self, pool, offset) -> AllocationTreeNode: + """Called after all allocations have been made""" + return self + + def is_empty(self): + return False + + +@dataclasses.dataclass +class Allocation(AllocationTreeNode): + """ + Represents memory allocated to a given node in the allocation pool. + """ + + node: BufferLike + live_range: LiveRange + size_hint: int + symbolic_size: sympy.Expr + allocated: bool = False + pool: Optional[AllocationPool] = None + offset: Optional[sympy.Expr] = None + + @property + def device(self): + return self.node.get_device() + + def get_live_ranges(self): + return LiveRanges([self.live_range]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return self.symbolic_size + + def mark_allocated(self): + assert not self.allocated + self.allocated = True + + def finalize(self, pool, offset): + assert self.pool is None and self.offset is None + self.pool = pool + self.offset = offset + return self + + def codegen_alloc_from_pool(self, wrapper): + assert self.pool + node = self.node + shape = tuple(node.get_size()) + stride = tuple(node.get_stride()) + return wrapper.codegen_alloc_from_pool( + self.pool.name, self.offset, node.get_dtype(), shape, stride + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"node={self.node.get_name()}, " + f"live_range={self.live_range}, " + f"size_hint={self.size_hint}, " + f"symbolic_size={self.symbolic_size}, " + f"pool={self.pool.name if self.pool else None}, " + f"offset={self.offset})" + ) + + +@dataclasses.dataclass +class Empty(AllocationTreeNode): + """ + Placeholder to represent empty space in the allocation pool. + Only exists to get the size_hint correct in parent nodes. + """ + + size_hint: int + + def get_live_ranges(self): + return LiveRanges([]) + + def get_size_hint(self): + return self.size_hint + + def get_symbolic_size(self): + return 0 + + def is_empty(self): + return True + + +class MemorySplitProtocol(Protocol): + get_live_ranges: CachedMethod[[], LiveRanges] + get_size_hint: CachedMethod[[], int] + get_symbolic_size: CachedMethod[[], sympy.Expr] + + def _allocate(self, block: Allocation, is_last: bool) -> bool: ... + + +class ClearCacheOnAllocateMixin(MemorySplitProtocol): + """ + Helper to assist in caching get_live_ranges, get_size_hint, and + get_symbolic_size. + """ + + def allocate(self, block: Allocation, is_last: bool): + is_allocated = self._allocate(block, is_last) + if is_allocated: + self.clear_cache() + return is_allocated + + def clear_cache(self): + self.get_live_ranges.clear_cache(self) + self.get_size_hint.clear_cache(self) + self.get_symbolic_size.clear_cache(self) + + +@dataclasses.dataclass +class TemporalSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains a list of allocations not overlapping in LiveRanges. 
+ + Invariant: no pair (a,b) in self.allocations will have: + a.get_live_ranges().overlaps(b.get_live_ranges()) + """ + + allocations: list[AllocationTreeNode] + + def _allocate(self, block: Allocation, is_last: bool): + slot_size = self.get_size_hint() + block_size = block.get_size_hint() + if not is_last and block_size > slot_size: + return False # doesn't fit + + block_live = block.get_live_ranges() + overlapping = [ + s for s in self.allocations if s.get_live_ranges().overlaps(block_live) + ] + if len(overlapping) > 1: + # TODO(jansel): we could try harder here by merging overlapping in space + return False + elif len(overlapping) == 1: + return overlapping[0].allocate(block, is_last) + else: + block.mark_allocated() + + if len(self.allocations) == 1 and isinstance(self.allocations[-1], Empty): + self.allocations.pop() + + if slot_size == block_size: + # perfect fit + self.allocations.append(block) + elif slot_size > block_size: + self.allocations.append( + SpatialSplit.create(block, slot_size - block_size) + ) + else: # grow this allocation + assert is_last + self.allocations = [ + *( + SpatialSplit.create(a, block_size - slot_size) + for a in self.allocations + ), + block, + ] + return True + + @cache_on_self + def get_live_ranges(self) -> LiveRanges: + return LiveRanges( + itertools.chain.from_iterable( + x.get_live_ranges().ranges for x in self.allocations + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + if not self.allocations: + return 0 + return max(x.get_size_hint() for x in self.allocations) + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + if not self.allocations: + return 0 # type: ignore[return-value] + return sympy.Max(*[x.get_symbolic_size() for x in self.allocations]) + + def is_empty(self): + return len(self.allocations) == 1 and self.allocations[0].is_empty() + + def finalize(self, pool, offset): + self.allocations = [block.finalize(pool, offset) for block in self.allocations] + self.clear_cache() + if len(self.allocations) == 1: + return self.allocations[0] + return self + + +@dataclasses.dataclass +class SpatialSplit(ClearCacheOnAllocateMixin, AllocationTreeNode): + """ + Contains two allocations, left and right, that do not overlap in space. + Right will be allocated immediately after left in memory. 
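+    Left's size is rounded up to the alignment boundary, so right starts at
+    align(left.get_symbolic_size()) bytes past the start of left.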
+ """ + + left: TemporalSplit + right: TemporalSplit + + @staticmethod + def create(left, extra_space): + assert isinstance(left, AllocationTreeNode) + assert isinstance(extra_space, int) and extra_space >= 1 + return SpatialSplit(TemporalSplit([left]), TemporalSplit([Empty(extra_space)])) + + def _allocate(self, block: Allocation, is_last: bool): + return self.left.allocate(block, False) or self.right.allocate(block, is_last) + + @cache_on_self + def get_live_ranges(self): + return LiveRanges( + itertools.chain( + self.left.get_live_ranges().ranges, self.right.get_live_ranges().ranges + ) + ) + + @cache_on_self + def get_size_hint(self) -> int: + return _align(self.left.get_size_hint()) + self.right.get_size_hint() + + @cache_on_self + def get_symbolic_size(self) -> sympy.Expr: + return align(self.left.get_symbolic_size()) + self.right.get_symbolic_size() + + def finalize(self, pool, offset): + self.left = self.left.finalize(pool, offset) + self.right = self.right.finalize( + pool, offset + align(self.left.get_symbolic_size()) + ) + self.clear_cache() + if self.right.is_empty(): + return self.left + return self + + +@dataclasses.dataclass +class AllocationPool: + """ + Represents a pool of allocations that will be generated by a single + call to torch.empty. + """ + + device: torch.device + root: TemporalSplit + can_expand: bool = True + restrict_live_range: Optional[LiveRange] = None + name: Optional[str] = None + names_to_del: list[str] = dataclasses.field(default_factory=list) + creation_cache: dict[str, str] = dataclasses.field(default_factory=dict) + + def allocate(self, block: Allocation, is_last: bool): + if self.restrict_live_range and not self.restrict_live_range.contains( + block.live_range + ): + return False + + is_last = self.can_expand and is_last + if self.root.allocate(block, is_last): + return True + + if is_last: + return self.allocate_at_end(block) + + return False + + def allocate_at_end(self, block): + block.mark_allocated() + self.root = TemporalSplit([SpatialSplit(self.root, TemporalSplit([block]))]) + return True + + def finalize(self, name): + assert not self.name + self.name = name + self.names_to_del.append(name) + self.root.finalize(self, 0) + + def codegen_create(self, wrapper, code: IndentedBuffer): + assert self.name + nbytes = self.root.get_symbolic_size() + for block in self.root.allocations: + if isinstance(block, Allocation) and nbytes == block.get_symbolic_size(): + # optimization: fuse first allocation and pool creation + node = block.node + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=node.get_dtype(), + shape=tuple(node.get_size()), + stride=tuple(node.get_stride()), + ) + ) + self.creation_cache[block.codegen_alloc_from_pool(wrapper)] = self.name + return + else: + code.writeline( + wrapper.make_allocation( + self.name, + device=self.device, + dtype=torch.uint8, + shape=(nbytes,), + stride=(1,), + ) + ) + + def codegen_destroy(self, wrapper, code: IndentedBuffer): + code.writeline(wrapper.make_free_by_names(self.names_to_del)) + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +@dataclasses.dataclass +class AllocationPools: + """ + Collection of many AllocationPool objects grouped by device. 
+ """ + + device_to_pools: dict[torch.device, list[AllocationPool]] = dataclasses.field( + default_factory=dict + ) + + def get_pools(self, block): + if block.device not in self.device_to_pools: + self.device_to_pools[block.device] = [] + return self.device_to_pools[block.device] + + def allocate(self, block: Allocation): + pools = self.get_pools(block) + + for pool in pools: + if pool.allocate(block, is_last=pool is pools[-1]): + return + + # everything is full, make a new pool + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool != "none", + ) + ) + block.mark_allocated() + + def allocate_output(self, block: Allocation): + """Outputs get different pools so memory gets freed properly""" + pools = self.get_pools(block) + if pools and config.memory_pool in ("outputs", "combined"): + pools[-1].allocate_at_end(block) + else: + # create a new pool + block.mark_allocated() + pools.append( + AllocationPool( + block.device, + TemporalSplit([block]), + can_expand=config.memory_pool == "combined", + ) + ) + + def finalize(self): + """Called at the end of allocation process""" + for i, pool in enumerate( + itertools.chain.from_iterable(self.device_to_pools.values()) + ): + pool.finalize(f"pool{i}") + + def pprint(self): + for pool in itertools.chain.from_iterable(self.device_to_pools.values()): + print() + print(pool.name) + print(pool.root.get_live_ranges()) + pprint.pprint(pool.root) + + +class BufferGroup: + """ + Due to inplace reuse an allocated buffer can have many names. + This tracks these collections of buffers sharing underlying memory. + """ + + def __init__(self, node: BufferLike): + self.node = node + self.names = [node.get_name()] + self.is_output = False + self.allocation: Optional[Allocation] = None + self.live_range = LiveRange(float("inf"), -float("inf")) + + def update_usage(self, timestep: int): + """Expand self.live_range to include timestep""" + self.live_range = LiveRange( + min(timestep, self.live_range.begin), + max(timestep, self.live_range.end), + ) + + def sym_nbytes(self): + return self.node.get_layout().storage_size() * self.node.get_dtype().itemsize + + def make_allocation(self): + assert not self.allocation, "multiple allocations" + assert isinstance(self.live_range.begin, int), "live ranges not computed" + nbytes = self.sym_nbytes() + # For now, fallback value will be used if we encounter an unbacked SymInt. The longer-term plan is to have + # size_hint() use better heuristics for unbackeds, at which point the fallback value will be ignored. 
+ size_hint = V.graph.sizevars.size_hint(nbytes, fallback=64) + self.allocation = Allocation( + self.node, + self.live_range, + size_hint=size_hint, + symbolic_size=nbytes, + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.names!r}, is_output={self.is_output}, " + f"live_range={self.live_range}" + ) + + +@dataclasses.dataclass +class PoolMemoryPlanningLine(MemoryPlanningLine): + """Abstract base class for {Alloc,Dealloc}FromPoolLine""" + + group: BufferGroup + timestep: Optional[int] = None + + @property + def node(self): + return self.group.node + + +@dataclasses.dataclass +class AllocFromPoolLine(PoolMemoryPlanningLine): + """Similar to AllocationLine, but takes memory from a pool""" + + is_first_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + allocation = self.group.allocation + assert allocation and allocation.pool + pool = allocation.pool + name = self.node.get_name() + + if self.is_first_pool_usage: + pool.codegen_create(self.wrapper, code) + + pool.names_to_del.extend(self.group.names) + alloc_from_pool = allocation.codegen_alloc_from_pool(self.wrapper) + if alloc_from_pool in pool.creation_cache: + code.writeline( + self.wrapper.make_tensor_alias( + name, pool.creation_cache[alloc_from_pool], "alloc" + ) + ) + else: + pool.creation_cache[alloc_from_pool] = name + code.writeline( + f"{self.wrapper.declare}{name} = {alloc_from_pool}{self.wrapper.ending}" + ) + + +@dataclasses.dataclass +class DeallocFromPoolLine(PoolMemoryPlanningLine): + """Similar to FreeIfNotReusedLine, but takes memory from a pool""" + + is_last_pool_usage: bool = False + + def codegen(self, code: IndentedBuffer): + if self.is_last_pool_usage: + assert self.group.allocation and self.group.allocation.pool + self.group.allocation.pool.codegen_destroy(self.wrapper, code) + + +@dataclasses.dataclass +class MemoryPlanner: + """ + Coordination object to run memory planning passes during wrapper + codegen. + """ + + wrapper: Any + pools: AllocationPools = dataclasses.field(default_factory=AllocationPools) + buffer_groups: Optional[list[BufferGroup]] = None + + def plan(self, lines: list[Any]) -> list[Any]: + """Call all the memory planning passes in sequence""" + lines = [*lines] + self.drop_removed_buffers(lines) + self.convert_to_pool_lines(lines) + self.compute_live_ranges(lines) + self.allocate_groups() + self.mark_first_last_usage(lines) + return lines + + def drop_removed_buffers(self, lines): + """ + Replace any memory planning lines in V.graph.removed_buffers with NullLine + """ + # drop any removed buffers + for i, line in enumerate(lines): + if isinstance(line, (AllocateLine, FreeIfNotReusedLine, ReuseLine)): + if line.node.get_name() in V.graph.removed_buffers: + lines[i] = NullLine(self.wrapper) + + def compute_buffer_groups(self, lines): + """ + Populates self.buffer_groups with BufferGroup objects that join + allocations with common storage (due to inplace reuse) into a + single object. 
+ """ + name_to_group = {} + for line in lines: + if isinstance(line, AllocateLine): + name = line.node.get_name() + assert name not in name_to_group + name_to_group[name] = BufferGroup(line.node) + elif isinstance(line, ReuseLine): + old_name = line.node.get_name() + new_name = line.reused_as.get_name() + assert new_name not in name_to_group + # TODO(jansel): we should support reusing buffers created via ExternKernelAlloc + if old_name in name_to_group: + name_to_group[old_name].names.append(new_name) + name_to_group[new_name] = name_to_group[old_name] + + outputs = OrderedSet(V.graph.get_output_names()) + unique_groups = [*{id(g): g for g in name_to_group.values()}.values()] + for group in unique_groups: + group.is_output = any(x in outputs for x in group.names) + + assert self.buffer_groups is None + self.buffer_groups = unique_groups + return name_to_group + + def convert_to_pool_lines(self, lines): + """ + Convert AllocateLine/FreeIfNotReusedLine/ReuseLine into their + pool-based counterparts. + """ + name_to_group = self.compute_buffer_groups(lines) + for i, line in enumerate(lines): + if isinstance(line, AllocateLine): + if line.node.get_name() in name_to_group: + lines[i] = AllocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, FreeIfNotReusedLine): + assert not line.is_reused + if line.node.get_name() in name_to_group: + lines[i] = DeallocFromPoolLine( + self.wrapper, name_to_group[line.node.get_name()] + ) + elif isinstance(line, ReuseLine): + if line.node.get_name() in name_to_group: + line.delete_old = False + + def compute_live_ranges(self, lines): + """Populate every BufferGroup.live_ranges field based on first/last usage""" + timestep = 0 + worklist = collections.deque(lines) + while worklist: + if isinstance(worklist[0], MemoryPlanningLine): + timestep += 1 + while worklist and isinstance(worklist[0], MemoryPlanningLine): + line = worklist.popleft() + if isinstance(line, PoolMemoryPlanningLine): + line.group.update_usage(timestep) + line.timestep = timestep + else: + worklist.popleft() + + timestep += 1 + assert self.buffer_groups is not None + for group in self.buffer_groups: + if group.is_output: + group.update_usage(timestep) + + def allocate_groups(self): + """ + Assign every allocation to a specific location in a specific AllocationPool. + """ + assert config.memory_pool in ("none", "intermediates", "outputs", "combined") + assert self.buffer_groups is not None + + for group in self.buffer_groups: + group.make_allocation() + + outputs: list[Allocation] = [] + intermediates: list[Allocation] = [] + for group in self.buffer_groups: + assert group.allocation + if group.is_output and config.memory_pool != "combined": + outputs.append(group.allocation) + else: + intermediates.append(group.allocation) + + for block in sorted( + outputs, + key=lambda x: ( + x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate_output(block) + + for block in sorted( + intermediates, + key=lambda x: ( + -x.size_hint, + -len(x.live_range), + ), + ): + self.pools.allocate(block) + + self.pools.finalize() + + def mark_first_last_usage(self, lines): + """ + Populate the AllocFromPoolLine.is_first_pool_usage and + DeallocFromPoolLine.is_last_pool_usage fields so that pools + are created/destroyed. 
+ """ + seen = OrderedSet[AllocationPool]() + for line in lines: + if isinstance(line, AllocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_first_pool_usage = True + seen.add(pool) + + seen = OrderedSet[AllocationPool]() + for line in reversed(lines): + if isinstance(line, DeallocFromPoolLine): + assert line.group.allocation + pool = line.group.allocation.pool + assert pool is not None + if pool not in seen: + line.is_last_pool_usage = ( + pool.root.get_live_ranges().end <= line.timestep + ) + seen.add(pool) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/mps.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/mps.py new file mode 100644 index 0000000000000000000000000000000000000000..d57b410194ad62a6621eb4de78913ef6a932f98d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/mps.py @@ -0,0 +1,988 @@ +# This is not a feature-complete compiler backend +# Just an early prototype that shows that one can compile elementwise ops into a Metal shader +from __future__ import annotations + +import functools +import itertools +import logging +import math +from pathlib import Path +from typing import Any, Optional, TYPE_CHECKING + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +from torch.utils._cpp_embed_headers import _embed_headers +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter, ExprPrinter as ExprPrinter_ +from torch.utils._sympy.value_ranges import ValueRanges + +from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata +from ..virtualized import ops, OpsWrapper, V +from .common import ( + CSEVariable, + DeferredLine, + DTYPE_TO_COMPUTATION_DTYPE, + IndentedBuffer, + OpOverrides, + PythonPrinter, +) +from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from typing import Union + + from ..ops_handler import ReductionType, StoreMode + from ..scheduler import Scheduler, SchedulerNode + from .common import OpVarT + +log = logging.getLogger(__name__) + +DTYPE_TO_METAL = { + torch.bool: "bool", + torch.int8: "char", + torch.int16: "short", + torch.int32: "int", + torch.int64: "long", + torch.uint8: "uchar", + torch.float: "float", + torch.half: "half", + torch.bfloat16: "bfloat", +} + + +def value_to_metal(val: Union[float, int, bool, str, CSEVariable]) -> str: + if isinstance(val, float): + if val == torch.inf: + return "HUGE_VALF" + elif val == -torch.inf: + return "-HUGE_VALF" + elif val != val: # Only float that not equal to self is nan + return "NAN" + return str(val) + elif isinstance(val, bool): + return "true" if val else "false" + return str(val) + + +class MetalExprPrinter(ExprPrinter_): + """Converts sympy expression to Metal code snippet""" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::metal::floor_divide({x}, {div})" + return f"metal::floor({x}) / ({div})" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"({x}) / ({div})" + else: + x = f"metal::floor({x}) / ({div})" + mod = self.doprint(mod) + return f"({x}) % ({mod})" + + def _print_Min(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::min only supported for 2 args") + a, b = map(self._print, 
expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::min({typecast_a}, {typecast_b})" + + def _print_Max(self, expr: sympy.Expr) -> str: + if len(expr.args) != 2: + raise RuntimeError("metal::max only supported for 2 args") + a, b = map(self._print, expr.args) + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::max({typecast_a}, {typecast_b})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"metal::abs({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"static_cast(metal::rint({self._print(expr.args[0])}))" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast(metal::rint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**23 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + x, y = map(self.doprint, expr.args) + return f"metal::pow(static_cast({x}), static_cast({y}))" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast({x})" + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::floor(static_cast({x})))" + + _print_floor = _print_FloorToInt + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"static_cast(metal::trunc({x}))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + x = self.doprint(expr.args[0]) + return f"metal::log2({x})" + + +class MetalOverrides(OpOverrides): + """Implements Metal-specific overrides for ops. 
Base class emits Python-friendly overrides.""" + + @staticmethod + def to_dtype( + x: CSEVariable, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types: bool = True, + ) -> str: + if dtype == torch.double: + log.warning( + "float64 cast requested, probably from tensorify_python_scalars" + ) + return f"static_cast({x})" + return f"static_cast<{DTYPE_TO_METAL[dtype]}>({x})" + + @staticmethod + def to_dtype_bitcast( + x: CSEVariable, dtype: torch.dtype, src_dtype: torch.dtype + ) -> str: + return f"as_type<{DTYPE_TO_METAL[dtype]}>(static_cast<{DTYPE_TO_METAL[src_dtype]}>({x}))" + + @staticmethod + def constant(val: Union[bool, float, int], dtype: torch.dtype) -> str: + return value_to_metal(val) + + @staticmethod + def index_expr(expr: sympy.Expr, dtype: torch.dtype) -> str: + idx_str = V.kernel.index_to_str(V.kernel.prepare_indexing(expr)) + var = V.kernel.cse.generate( + V.kernel.compute, idx_str, bounds=get_bounds_index_expr(expr) + ) + return ops.to_dtype(var, dtype) + + @staticmethod + def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str: + # TODO: Type annotation for other is wrong, it's often float or int + with V.kernel.mask_loads(mask, other) as new_mask: + result = body() + + if result.bounds.is_bool: + other = bool(other) # type: ignore[assignment] + + return ops.where(new_mask, result, other) + + @staticmethod + def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str: + return f"{a} ? {b} : {value_to_metal(c)}" + + @staticmethod + def remainder(a: OpVarT, b: OpVarT) -> str: + return f"c10::metal::remainder({a}, {b})" + + @staticmethod + def maximum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::max({typecast_a}, {typecast_b})" + + @staticmethod + def minimum(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"c10::metal::min({typecast_a}, {typecast_b})" + + @staticmethod + def logical_or(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} || {b}" + + @staticmethod + def logical_and(a: CSEVariable, b: CSEVariable) -> str: + return f"{a} && {b}" + + @staticmethod + def isnan(x: CSEVariable) -> str: + return f"metal::isnan({x})" + + @staticmethod + def isinf(x: CSEVariable) -> str: + return f"metal::isinf({x})" + + @staticmethod + def log(x: CSEVariable) -> str: + return f"metal::log({x})" + + @staticmethod + def exp(x: CSEVariable) -> str: + return f"metal::exp({x})" + + @staticmethod + def abs(x: CSEVariable) -> str: + return f"metal::abs({x})" + + @staticmethod + def signbit(x: CSEVariable) -> str: + return f"metal::signbit({x})" + + @staticmethod + def sin(x: CSEVariable) -> str: + return f"metal::precise::sin({x})" + + @staticmethod + def sinc(x: CSEVariable) -> str: + return f"c10::metal::sinc({x})" + + @staticmethod + def cos(x: CSEVariable) -> str: + return f"metal::precise::cos({x})" + + @staticmethod + def tan(x: CSEVariable) -> str: + return f"metal::tan({x})" + + @staticmethod + def asin(x: CSEVariable) -> str: + return f"metal::asin({x})" + + @staticmethod + def acos(x: CSEVariable) -> str: + return f"metal::acos({x})" + + @staticmethod + def atan(x: CSEVariable) -> str: + return f"metal::atan({x})" + + @staticmethod + def atan2(x: CSEVariable, y: CSEVariable) -> str: + return f"::metal::atan2({x}, {y})" + + @staticmethod + def sqrt(x: CSEVariable) -> str: + return f"metal::sqrt({x})" + + @staticmethod + def neg(x: CSEVariable) -> str: + # TODO: Does it rely on 
undefined behavior? + # If so, add special logic for unsigned types + return f"static_cast(-{x})" + + @staticmethod + def rsqrt(x: CSEVariable) -> str: + return f"metal::rsqrt({x})" + + @staticmethod + def tanh(x: CSEVariable) -> str: + return f"metal::tanh({x})" + + @staticmethod + def atanh(x: CSEVariable) -> str: + return f"metal::atanh({x})" + + @staticmethod + def floordiv(a: CSEVariable, b: CSEVariable) -> str: + # a and b must be of integer type + return f"c10::metal::floor_divide({a}, {b})" + + @staticmethod + def floor(x: CSEVariable) -> str: + return f"metal::floor({x})" + + @staticmethod + def sign(x: CSEVariable) -> str: + return f"metal::sign({x})" + + @staticmethod + def fmod(a: CSEVariable, b: CSEVariable) -> str: + typecast_a = f"static_cast({a})" + typecast_b = f"static_cast({b})" + return f"metal::fmod({typecast_a}, {typecast_b})" + + @staticmethod + def trunc(x: CSEVariable) -> str: + return f"metal::trunc({x})" + + @staticmethod + def truncdiv(a: CSEVariable, b: CSEVariable) -> str: + quot = f"{a} / {b}" + if (a.dtype is not None and a.dtype.is_floating_point) or ( + b.dtype is not None and b.dtype.is_floating_point + ): + return f"metal::trunc({quot})" + return quot + + @staticmethod + def ceil(x: CSEVariable) -> str: + return f"metal::ceil({x})" + + @staticmethod + def rand(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::rand({seed}, {offset})" + + @staticmethod + def randn(seed: CSEVariable, offset: CSEVariable) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randn({seed}, {offset})" + + @staticmethod + def randint64( + seed: CSEVariable, offset: CSEVariable, low: CSEVariable, high: CSEVariable + ) -> str: + V.kernel.headers.add("random") + return f"c10::metal::randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def round(x: CSEVariable) -> str: + return f"metal::round({x})" + + @staticmethod + def pow(a: CSEVariable, b: CSEVariable) -> str: + cast_a = f"static_cast({a})" + cast_b = f"static_cast({b})" + return f"metal::pow({cast_a}, {cast_b})" + + def _special_unary(self, a: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a})" + + def _special_binary(self, a: CSEVariable, b: CSEVariable, name: str) -> str: + V.kernel.headers.add("special_math") + return f"c10::metal::{name}({a}, {b})" + + @classmethod + def _initialize_special_ops(cls) -> None: + # Unary special ops + for name in [ + "erf", + "erfinv", + "i0", + "i0e", + "i1", + "i1e", + "digamma", + "spherical_bessel_j0", + ]: + setattr(cls, name, functools.partialmethod(cls._special_unary, name=name)) + + cls.lgamma = functools.partialmethod(cls._special_unary, name="log_gamma") # type: ignore[assignment] + + # Unary special ops with forward in method name + for name in [ + "bessel_j0", + "bessel_j1", + "bessel_y0", + "bessel_y1", + "modified_bessel_i0", + "modified_bessel_i1", + "modified_bessel_k0", + "modified_bessel_k1", + "scaled_modified_bessel_k0", + "scaled_modified_bessel_k1", + ]: + setattr( + cls, + name, + functools.partialmethod(cls._special_unary, name=name + "_forward"), + ) + + # Binary special ops + for name in [ + "polygamma", + "zeta", + ]: + setattr(cls, name, functools.partialmethod(cls._special_binary, name=name)) + + # Binary special ops with forward in method name + for name in [ + "chebyshev_polynomial_t", + "chebyshev_polynomial_u", + "chebyshev_polynomial_v", + "chebyshev_polynomial_w", + "hermite_polynomial_h", + "hermite_polynomial_he", + ]: + 
setattr( + cls, + name, + functools.partialmethod(cls._special_binary, name=name + "_forward"), + ) + + +MetalOverrides._initialize_pointwise_overrides("mps") +MetalOverrides._initialize_special_ops() + + +class MetalKernel(SIMDKernel): + """Implement Metal codegen based on the SIMDKernel abstraction""" + + overrides = MetalOverrides # type: ignore[assignment] + suffix = ";" + newvar_prefix = "auto " + max_threadgroup_size = 1024 + simd_group_size = 32 + pexpr = PythonPrinter().doprint + cexpr = CppPrinter().doprint + sexpr = MetalExprPrinter().doprint + kexpr = sexpr + headers: OrderedSet[str] = OrderedSet(["utils"]) + multistage_reduction_entry: list[IterationRangesEntry] = [] + + def __init__( + self, + tiling: dict[str, sympy.Expr], + **kwargs: Any, + ) -> None: + super().__init__(tiling, **kwargs) + self.acc_var_ids = itertools.count() + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return DTYPE_TO_METAL[dtype] + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: + """Codegen a load from an InputBuffer""" + var = self.args.input(name) + index = self.prepare_indexing(index) + dtype = V.graph.get_dtype(name) + line = f"{var}[{self.index_to_str(index)}]" + if dtype in [torch.float16, torch.bfloat16]: + # TODO(NS): Figure out the right balance between optype casts + # op_math_t for half-precision floats should be float32 + # Otherwise it can lead to a correctness issues with eager + line = f"static_cast({line})" + dtype = torch.float32 + return self.cse.generate(self.loads, line, dtype=dtype) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + cast_val = f"static_cast<{dtype_str}>({value})" + if mode is None: + line = f"{var}[{self.index_to_str(index)}] = {cast_val};" + elif mode == "atomic_add": + self.headers.add("atomic") + atomic_type = f"c10::metal::AtomicType<{dtype_str}>" + cast_var = f"reinterpret_cast({var})" + line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});" + else: + raise RuntimeError(f"Unimplemented store mode {mode}") + if self.inside_reduction: + self.compute.writeline(DeferredLine(name, line)) + else: + self.stores.writeline(DeferredLine(name, line)) + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + reduction_dim = next(t for t in self.range_trees if t.is_reduction) + # Only one thread in the reduction group needs to store the results + line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});" + line = f"if ({reduction_dim.name} == 0) {line}" + self.stores.writeline(DeferredLine(name, line)) + + def _new_idxvar( + self, + dtype: Union[str | torch.dtype], + elem_count: Optional[int] = None, + default_value: Optional[Any] = None, + is_threadgroup: bool = True, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + ) -> CSEVariable: + if isinstance(dtype, torch.dtype): + dtype = self.dtype_to_str(dtype) + var_name = f"tmp_acc_{next(self.acc_var_ids)}" + var = V.kernel.create_cse_var(var_name, bounds, dtype) + var_def = "threadgroup " if is_threadgroup else "" + var_def += f"{dtype} {var_name}" + if elem_count: + var_def += f"[{elem_count}]" + if default_value is not None: + assert not is_threadgroup, "Thread group var can not have default value" + 
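+            # Metal does not allow initializers on threadgroup-scope variables,
+            # so only thread-local accumulators get an inline default value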
var_def += f" = {default_value}" + self.indexing_code.writeline(var_def + self.suffix) + return var + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + "Caching wrapper around _reduction_nocache" + cache_key = (src_dtype, reduction_type, value) + # Return cached reduction + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + result = self._reduction_nocache(dtype, src_dtype, reduction_type, value) + self.cse.reduction_cache[cache_key] = result # type: ignore[assignment] + return result + + def _reduction_nocache( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + """Codegen a reduction operation. + Only sum and prod operations are somewhat reasonable optimized""" + assert self.inside_reduction + assert not self._load_mask + + def _unwrap_helper(res3: CSEVariable) -> tuple[CSEVariable, ...]: + # Uwraps vec3 dtype into individual components + return OpsWrapper._unwrap( + [CSEVariable(f"{res3}.{t}", res3.bounds, res3.dtype) for t in "xyz"] + ) + + # Establish reduction buffer size and index expression + reduction_idx = "" + acc_buf_size = 1 + for rd in self.range_trees: + if not rd.is_reduction: + continue + if reduction_idx: + reduction_idx += " + " + reduction_idx += f"{rd.name} * {acc_buf_size}" + acc_buf_size *= rd.numel + acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) + + if reduction_type == "any": + acc = self._new_idxvar(dtype) + self.indexing_code.writeline(f"{acc} = false;") + self.indexing_code.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + self.compute.splice( + f""" + if ({value}) {{ + {acc} = true; + }} + """ + ) + self.stores.writeline( + "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" + ) + return acc + + self.headers.add("reduction_utils") + + if reduction_type in ["prod", "sum"]: + acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] + acc_buf = self._new_idxvar( + acc_dtype, ceildiv(acc_buf_size, self.simd_group_size) + ) + if not self.multistage_reduction_entry: + val = value + else: + default_val, reduction_op = ( + (0, "+") if reduction_type == "sum" else (1, "*") + ) + val = self._new_idxvar( + acc_dtype, default_value=default_val, is_threadgroup=False + ) + self.compute.splice(f"{val} {reduction_op}= {value};") + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", + dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], + ) + if reduction_type in ["max", "min", "argmin", "argmax"]: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + src_metal_type = DTYPE_TO_METAL[src_dtype] + if not self.multistage_reduction_entry: + self.compute.splice( + f"{acc_thread_var} = static_cast<{src_metal_type}>({value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=dtype, + ) + lim_fn = "lowest" if reduction_type.endswith("max") else "max" + self.indexing_code.writeline( + f"{acc_thread_var} = ::metal::numeric_limits<{src_metal_type}>::{lim_fn}();" + ) + if reduction_type.startswith("arg"): + idx_var = next( + t for t in self.range_tree_nodes.values() if t.is_reduction + ) + idx_acc_buf = 
self._new_idxvar(torch.long, acc_buf_size) + cmp_op = ">" if reduction_type == "argmax" else "<" + idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{idx_thread_var} = -1;") + self.compute.splice(f""" + if ({value} {cmp_op} {acc_thread_var}) {{ + {acc_thread_var} = {value}; + {idx_thread_var} = {idx_var.name}; + }} + """) + return self.cse.generate( + self.stores, + f"{idx_acc_buf}[c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})]", + dtype=dtype, + ) + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::{reduction_type}({acc_thread_var}, {value});" + ) + return self.cse.generate( + self.stores, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=dtype, + ) + if reduction_type == "welford_reduce": + if not self.multistage_reduction_entry: + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") + wf_res = self.cse.generate( + self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, float3({value}, 0.0, 1.0));" + ) + wf_res = self.cse.generate( + self.stores, + f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + if reduction_type == "welford_combine": + assert isinstance(value, tuple), "Input to welford combine must be tuple" + acc_buf = self._new_idxvar("float3", acc_buf_size) + acc_thread_var = f"{acc_buf}[{reduction_idx}]" + inp_value = f"float3({value[0]}, {value[1]}, {value[2]})" + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + if self.multistage_reduction_entry: + self.indexing_code.splice(f"{acc_thread_var} = 0.0;") + self.compute.writeline( + f"{acc_thread_var} = ::c10::metal::welford_combine({acc_thread_var}, {inp_value});" + ) + else: + self.compute.writeline(f"{acc_thread_var} = {inp_value};") + wf_res = self.cse.generate( + self.stores if self.multistage_reduction_entry else self.compute, + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + dtype=torch.float32, + ) + return _unwrap_helper(wf_res) + raise NotImplementedError(reduction_type) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None: + index_expr = self.rename_indexing(entry.expr) + index_str = self.sexpr(index_expr) # type: ignore[misc] + + if not entry.is_reduction or entry.root.numel <= self.max_threadgroup_size: + self.indexing_code.writeline( + f"{self.index_dtype} {entry.name} = {index_str};" + ) + return + self.multistage_reduction_entry.append(entry) + # When reducing the tensor whose size exceeds max threadgroup size + # loop over extra indices per reduction thread and perform part of the operation + # using values in the shared memory + loop_size = ( + entry.root.numel + self.max_threadgroup_size - 1 + ) // self.max_threadgroup_size + self.body.writeline( + f"for(auto {entry.name}_cnt = 0; {entry.name}_cnt < {loop_size}; ++{entry.name}_cnt) {{" + ) + with self.body.indent(): + self.body.writeline( + f"{self.index_dtype} {entry.name} = {loop_size} * {index_str} + {entry.name}_cnt;" + ) + # Check that reduction is performed only within tensor boundary + if loop_size * 
self.max_threadgroup_size != entry.root.numel: + self.body.writeline(f"if ({entry.name} >= {entry.root.numel}) break;") + + def codegen_body(self) -> None: + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if self.multistage_reduction_entry: + with self.body.indent(): + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.writeline("}" * len(self.multistage_reduction_entry)) + # Invalidate variables instantiated inside loop + # But results of reduction alive. Reduction cache values can be + # either CSEVariable or tuple of CSEVariables, in which case all + # variables in the tuple must be preserved + self.cse.invalidate( + OrderedSet( + v + for item in self.cse.reduction_cache.values() + for v in (item if isinstance(item, tuple) else (item,)) + ) + ) + # And loop codegen + while self.multistage_reduction_entry: + self.multistage_reduction_entry.pop().cache_clear() + else: + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.loads.clear() + self.compute.clear() + self.stores.clear() + + def codegen_kernel(self, name: Optional[str] = None) -> str: + """Called at the end to generate a final kernel string""" + self.codegen_body() + code = IndentedBuffer() + + if V.graph.cpp_wrapper: + code.writeline('(R"MTL(') + else: + code.writeline("compile_mps_shader('''") + + idx_vars = self.active_range_trees() + with code.indent(): + if not V.graph.cpp_wrapper: + for header in self.headers: + code.writeline(f"#include ") + else: + headers = [ + f"#include " for header in self.headers + ] + header_contents = _embed_headers( + headers, + [Path(__file__).parent.parent.parent / "include"], + OrderedSet(), # type: ignore[arg-type] + ) + code.writeline(header_contents) + + if self.inside_reduction: + total_reduction_size = math.prod( + t.numel for t in self.range_trees if t.is_reduction + ) + threadgroup_size = min(total_reduction_size, self.max_threadgroup_size) + code.writeline( + f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" + ) + code.writeline("kernel void generated_kernel(") + with code.indent(): + for outer, inner in self.args.output_buffers.items(): + if outer in self.removed_buffers: + continue + dtype_str = self.dtype_to_str(V.graph.get_dtype(outer)) + code.writeline(f"device {dtype_str}* {inner},") + for outer, inner in self.args.input_buffers.items(): + dtype = V.graph.get_dtype(outer) + # MPS does not support float64, but scalar inputs are fine + if dtype == torch.float64: + outer_buf = V.graph.try_get_buffer(outer) + if outer_buf is None or outer_buf.get_size() != []: + raise RuntimeError("float64 is not supported by MPS") + dtype_str = "float" + else: + dtype_str = self.dtype_to_str(dtype) + code.writeline(f"constant {dtype_str}* {inner},") + for outer, inner in self.args.sizevars.items(): + code.writeline(f"constant long& {inner},") + assert len(idx_vars) < 4, "Up to 3 index variables are supported" + thread_pos_dtype = ( + f"uint{len(idx_vars)}" if len(idx_vars) > 1 else "uint" + ) + thread_pos_var_name = ( + idx_vars[0].name if len(idx_vars) == 1 else "thread_pos" + ) + thread_pos_suffix = "," if self.inside_reduction else "" + code.writeline( + f"{thread_pos_dtype} {thread_pos_var_name} [[thread_position_in_grid]]{thread_pos_suffix}" + ) + if self.inside_reduction: + code.writeline( + f"{thread_pos_dtype} group_pos 
[[thread_position_in_threadgroup]]" + ) + code.writeline(") {") + with code.indent(): + if len(idx_vars) > 1: + for idx, var in enumerate(idx_vars): + code.writeline( + f"auto {var.name} = thread_pos.{chr(120 + idx)};" + ) + code.splice(self.indexing_code) + code.splice(self.body) + code.writeline("}") + + if V.graph.cpp_wrapper: + code.writeline(')MTL");') + else: + code.writeline("''')") + + return code.getvalue() + + def call_kernel(self, name: str, node: Any = None) -> None: + """Codegen a call to this kernel""" + wrapper = V.graph.wrapper_code + # Make sure sizevars has been computed + for v in self.args.sizevars.keys(): + wrapper.ensure_size_computed(v) + + _, call_args, _, arg_types = self.args.python_argdefs() + arg_name_to_type = { + str(call_arg): arg_type for call_arg, arg_type in zip(call_args, arg_types) + } + + args = [*self.args.output_buffers.keys(), *self.args.input_buffers.keys()] + args = [arg for arg in args if arg not in self.removed_buffers] + args += [str(v) for v in self.args.sizevars.keys()] + + arg_types = [arg_name_to_type[arg] for arg in args] + expr_printer = self.cexpr if V.graph.cpp_wrapper else self.pexpr + + def format_threads(threads: list[str], kwarg: str) -> str: + if V.graph.cpp_wrapper: + threads = [f"static_cast({t})" for t in threads] + return f"{{{', '.join(threads)}}}" + else: + return f"{kwarg}=[{', '.join(threads)}]" + + # For reduction kernels, limit the maximum size over reduction dimensions to + # a maximum threadgroup size + if len(self.active_range_trees()) > 0: + threads = [ + expr_printer( + sympy.Min(v.numel, self.max_threadgroup_size) # type: ignore[misc] + if v.is_reduction + else v.numel + ) + for v in self.active_range_trees() + ] + + args.append(format_threads(threads, "threads")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + raise RuntimeError("We should always have threads?") + + if self.inside_reduction: + threads = [ + expr_printer(sympy.Min(v.numel, self.max_threadgroup_size)) # type: ignore[misc] + if v.is_reduction + else "1" + for v in self.active_range_trees() + ] + args.append(format_threads(threads, "group_size")) + arg_types.append(list) + else: + if V.graph.cpp_wrapper: + # Add a None so that we always have a group_size in the + # arguments. We won't use it if the value is None. + args += [None] # type: ignore[list-item] + arg_types.append(None) + + wrapper.generate_kernel_call( + name, + args, + device=torch.device("cpu"), # TODO: Fix me, MPS does not expose streams now + triton=False, + arg_types=arg_types, + ) + + def check_bounds( + self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool + ) -> None: + if not (lower or upper): + return + # TODO(malfet): support asserts + # See https://github.com/pytorch/pytorch/issues/144634 + expr_str = self.index_to_str(expr) + lower_expr = f"{expr_str} < 0" if lower else "" + # TODO(malfet): Is upper bound inclusive or exclusive? 
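+        # As written, the generated guard uses a strict comparison, so an index
+        # equal to `size` is not flagged as out of bounds.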
+ upper_expr = f"{expr_str} > {self.index_to_str(size)}" if upper else "" + if lower and upper: + line = f"if (({lower_expr}) && ({upper_expr})) return" + else: + line = f"if ({lower_expr}{upper_expr}) return" + self.cse.generate(self.compute, line, assignment=False) + + +class MetalScheduling(SIMDScheduling): + kernel_type = MetalKernel # type: ignore[assignment] + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + wrapper = V.graph.wrapper_code + if wrapper is not None: + if not V.graph.cpp_wrapper: + wrapper.header.splice( + "from torch._inductor.runtime.runtime_utils import compile_mps_shader" + ) + + def define_kernel( + self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel + ) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + # TODO: Merge multiple kernels into a single library + # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling + mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" + + if V.graph.cpp_wrapper: + src_code = ( + f"at::native::mps::DynamicMetalShaderLibrary {mps_lib_name}" + + src_code + ) + kernel_name = f"{mps_lib_name}_func" + else: + kernel_name = f"{mps_lib_name}.generated_kernel" + + wrapper.src_to_kernel[src_code] = kernel_name + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) + + return kernel_name diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..a97df5f32a72aa0b2bd41aa940d4d768dc0d1e3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/mps_device_op_overrides.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from .common import DeviceOpOverrides, register_device_op_overrides + + +class MPSDeviceOpOverrides(DeviceOpOverrides): + def device_guard(self, device_idx: int) -> str: + assert device_idx == 0 + return "torch._ops.contextlib.nullcontext()" + + def set_device(self, device_idx: int) -> str: + assert device_idx == 0 + return "pass # MPS set device" + + def kernel_driver(self) -> str: + return """ + #include + """ + + def cpp_kernel_type(self) -> str: + return "MTLFunction_t" + + +register_device_op_overrides("mps", MPSDeviceOpOverrides()) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..ae89bc53eb316fe4c7193781ddf7cfe11cbcf064 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/multi_kernel.py @@ -0,0 +1,379 @@ +# mypy: allow-untyped-defs +import functools +import logging +import os +import pathlib + +from torch._inductor.metrics import get_metric_table, is_metric_table_enabled +from torch.utils._ordered_set import OrderedSet + +from .. 
import config +from ..codecache import code_hash, CodeCacheFuture, get_path, write_atomic +from ..runtime.benchmarking import benchmarker +from ..utils import cache_on_self, IndentedBuffer +from ..virtualized import V +from .common import TensorArg, WorkspaceArg + + +log = logging.getLogger(__name__) + + +class MultiKernelState: + """ + Maintain state of multi-kernel compilation so we don't define duplicated + multi-kernel for the same set of sub-kernels. + + V.graph.wrapper_code has a reference to MultiKernelState instance. + """ + + def __init__(self): + self.subkernel_to_kernel_name = {} + self.kernel_defs = IndentedBuffer() + + def define_kernel(self, kernels): + """ + Previously we name the multi kernel as "multi_kernel_{kernel_names[0]}". + This has some minor issue. + + E.g. for persistent reduction https://gist.github.com/shunting314/39e7c00ff8bb2055942ed5a3255d61ca , + there are 2 flavors of non-persistent reduction: + https://gist.github.com/shunting314/056d43d35907e87efb883970b35c17d4 + and + https://gist.github.com/shunting314/02ee753b65c513c54e695626afe682bd + + The only different is cache eviction policy. + + We should name the multi-kernel differently in these 2 cases. + """ + kernel_names = tuple(k.kernel_name for k in kernels) + if kernel_names in self.subkernel_to_kernel_name: + return self.subkernel_to_kernel_name[kernel_names] + + # name the multi kernel based on the first kernel + multi_kernel_name = f"multi_kernel_{len(self.subkernel_to_kernel_name)}" + self.subkernel_to_kernel_name[kernel_names] = multi_kernel_name + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # we should not generate any python code for multi-kernel during + # the second pass of cpp-wrapper. + return multi_kernel_name + + buf = self.kernel_defs + buf.writeline("") + buf.writeline( + f"{multi_kernel_name} = async_compile.multi_kernel({multi_kernel_name!r}, [" + ) + with buf.indent(): + for name in kernel_names: + buf.writeline(f"{name},") + buf.writeline("])") + + if config.triton.autotune_at_compile_time: + V.graph.wrapper_code.src_to_kernel["\n".join(kernel_names)] = ( + multi_kernel_name + ) + + return multi_kernel_name + + +class MultiKernel: + """ + This class maintains the compile time state for multi kernels. + + Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2. + The generated definition for the multi-kernel will looks like: + ``` + multi_kernel_kernel1 = MultiKernelCall( + [kernel1, kernel2], multi_kernel_definition_code + ) + ``` + + Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39 + """ + + def __init__(self, kernels): + assert len(kernels) >= 2 + + self.kernels = kernels + self.kernel_name = V.graph.wrapper_code.multi_kernel_state.define_kernel( + kernels + ) + + # need this since some code in inductor check if the kernel object has an args + # attribute to decide if it's a non-null kernel. 
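+        # The placeholder only needs to exist; MultiKernel itself never inspects it.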
+ self.args = object() + + @staticmethod + def _merge_workspace_args(left: list[WorkspaceArg], right: list[WorkspaceArg]): + if left == right: + return left + result = {x.inner_name: x for x in left} + for arg in right: + if arg.inner_name in result: + result[arg.inner_name] = WorkspaceArg.maximum( + result[arg.inner_name], arg + ) + else: + result[arg.inner_name] = arg + return [*result.values()] + + @staticmethod + def merge_workspaces_inplace(kernels): + if len(kernels) < 2: + return + # All kernels must share the same workspace + workspace_args = functools.reduce( + MultiKernel._merge_workspace_args, + [kernel.args.workspace_args for kernel in kernels], + ) + for kernel in kernels: + kernel.args.workspace_args = workspace_args + return workspace_args + + def call_kernel(self, kernel_name): + """ + Collect the union of arguments from all subkernels as the arguments + for the multi-kernel. + """ + assert kernel_name == self.kernel_name + V.graph.wrapper_code.write_triton_header_once() + _, call_args, _, arg_types = self.kernels[0].args.python_argdefs() + for kernel in self.kernels[1:]: + _, other_call_args, _, other_arg_types = kernel.args.python_argdefs() + assert call_args == other_call_args, (call_args, other_call_args) + assert arg_types == other_arg_types + + if V.graph.cpp_wrapper and not config.triton.autotune_at_compile_time: + # for the second pass of cpp-wrapper codegen, we should call + # the fast kernel directly + kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) + + # numels for all subkernels should be the same. Use kernels[0] here + self.kernels[0].add_numel_to_call_args(kernel_name, call_args, arg_types) + + for ws in self.kernels[0].args.workspace_args: + V.graph.wrapper_code.generate_workspace_allocation(ws) + + V.graph.wrapper_code.generate_kernel_call( + kernel_name, + call_args, + arg_types=arg_types, + ) + + for ws in reversed(self.kernels[0].args.workspace_args): + V.graph.wrapper_code.generate_workspace_deallocation(ws) + + def codegen_nan_check(self): + wrapper = V.graph.wrapper_code + seen: OrderedSet[str] = OrderedSet() + for k in self.kernels: + _, call_args, precompile_args, _ = k.args.python_argdefs() + for arg, precompile_arg in zip(call_args, precompile_args): + if arg in seen: + continue + seen.add(arg) + if isinstance(precompile_arg, TensorArg): + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + @property + def removed_buffers(self): + return OrderedSet.intersection(*[k.removed_buffers for k in self.kernels]) + + @property + def inplaced_to_remove(self): + return OrderedSet.intersection(*[k.inplaced_to_remove for k in self.kernels]) + + @property + @cache_on_self + def inplace_update_buffers(self): + """ + Make sure all kernels have the same inplace update mappings. 
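+
+        All sub-kernels are generated from the same fused nodes, so their
+        inplace-update mappings are expected to match; the assert below checks
+        this before returning the first kernel's mapping.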
+ """ + for k in self.kernels[1:]: + assert k.inplace_update_buffers == self.kernels[0].inplace_update_buffers + return self.kernels[0].inplace_update_buffers + + def warn_mix_layout(self, kernel_name: str): + pass + + +class MultiKernelCall: + """ + This class is called at run time to actually run the kernel + """ + + def __init__(self, multi_kernel_name, kernels): + assert len(kernels) >= 2 + self._kernels = kernels + self.multi_kernel_name = multi_kernel_name + + self.disable_cache = os.environ.get( + "TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE" + ) == "1" or is_metric_table_enabled("persistent_red_perf") + + self.picked_kernel = None + if config.triton.multi_kernel > 1: + # manually force a subkernel to ease perf testing + picked_by_config = config.triton.multi_kernel - 2 + assert picked_by_config < len(self._kernels) + self.picked_kernel = picked_by_config + elif not self.disable_cache: + self.load_cache() + + self._recorded = False + + def cache_file_path(self): + key = code_hash( + ",".join( + [ + f"{k.fn.cache_key}{k.size_hints!r}{k.triton_meta!r}" + for k in self.kernels + ] + ) + ) + _, _, path = get_path(key, "picked_kernel") + return pathlib.Path(path) + + def load_cache(self): + assert self.picked_kernel is None + path = self.cache_file_path() + if path.exists(): + with path.open() as fd: + self.picked_kernel = int(fd.read()) + assert self.picked_kernel >= 0 and self.picked_kernel < len( + self._kernels + ) + log.debug( + "Load picked kernel %d from cache file %s", self.picked_kernel, path + ) + + def store_cache(self): + assert self.picked_kernel is not None + path = self.cache_file_path() + path.parent.mkdir(parents=True, exist_ok=True) + + write_atomic(path, str(self.picked_kernel)) + log.debug("Store picked kernel %d to cache file %s", self.picked_kernel, path) + + @property + def kernels(self): + """ + Read results from future. + + This should be called after parallel compilation is done. + In case you call this before compilation is done, + it may slow down the parallel compilation. + """ + for i, kernel in enumerate(self._kernels): + if isinstance(kernel, CodeCacheFuture): + self._kernels[i] = kernel.result() + + return self._kernels + + def benchmark_sub_kernels(self, *args, **kwargs): + """ + Benchmark all the sub kernels and return the execution time + (in milliseconds) for each of time. + + Unit test may mock this method to force a specific kernel to + be picked. + """ + + def wrap_fn(kernel): + def inner(): + args_clone, kwargs_clone = kernel.clone_args(*args, **kwargs) + return kernel.run(*args_clone, **kwargs_clone) + + return inner + + return [ + benchmarker.benchmark_gpu(wrap_fn(kernel), rep=40) + for kernel in self.kernels + ] + + # record_choice and lookup_choice are helper functions for cpp-wrapper + # codegen. The first pass use record_choice to keep the choice and + # the second pass do lookup by calling lookup_choice. + # + # An alternative that reused the multi-kernel cache does not work well + # since during codegen of the second pass, it's very hard to know the + # path for the cache file. Also reading the cache file need do some IO + # which can be slower. + @staticmethod + def record_choice(multi_kernel_name: str, picked_kernel_name: str): + """ + Record the multi-kernel choice for cpp-wrapper after autotuning + + We should do nothing if this function is not called during codegen. 
+ """ + from torch._inductor.graph import GraphLowering + + if not isinstance(V.graph, GraphLowering): + return + + if not V.graph.record_multi_kernel_choice: + return + + V.graph.multi_kernel_to_choice[multi_kernel_name] = picked_kernel_name + + @staticmethod + def lookup_choice(multi_kernel_name: str) -> str: + # this should always been done during cpp-wrapper codegen + assert ( + V.graph.record_multi_kernel_choice + and multi_kernel_name in V.graph.multi_kernel_to_choice + ) + # there should be no miss + return V.graph.multi_kernel_to_choice[multi_kernel_name] + + def run(self, *args, **kwargs): + if self.picked_kernel is None: + timings = self.benchmark_sub_kernels(*args, **kwargs) + self.picked_kernel = timings.index(min(timings)) + k0 = self.kernels[0] + log.debug( + "pick %dth sub-kernel in %s. Size hints %s. Reduction hint %s. Timings %s", + self.picked_kernel, + [k.inductor_meta.get("kernel_name") for k in self.kernels], + k0.size_hints, + k0.inductor_meta.get("reduction_hint"), + timings, + ) + get_metric_table("persistent_red_perf").add_row( + functools.partial(self._metrics_table_row, timings) + ) + if not self.disable_cache: + self.store_cache() + + if not self._recorded: + self._recorded = True + picked_kernel_name = self.kernels[self.picked_kernel].inductor_meta.get( + "kernel_name" + ) + assert picked_kernel_name is not None + self.record_choice(self.multi_kernel_name, picked_kernel_name) + self.run = self.kernels[self.picked_kernel].run # type: ignore[method-assign] + self.run(*args, **kwargs) + + def _metrics_table_row(self, timings): + def get_kernel_path(k): + return k.fn.fn.__code__.co_filename + + k0 = self.kernels[0] + row = { + "size_hints": k0.size_hints, + "reduction_hint": k0.inductor_meta.get("reduction_hint"), + } + max_kernels = 4 + assert len(timings) <= max_kernels + for i in range(max_kernels): + if i < len(self.kernels): + row[f"kernel{i}_path"] = get_kernel_path(self.kernels[i]) + row[f"kernel{i}_latency"] = timings[i] + else: + row[f"kernel{i}_path"] = "" + row[f"kernel{i}_latency"] = "" + return row diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7a3616f93ede0caaae097b25a380e99274ecaaa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9de21823a3645e4e53ca942a71cdb0c41ed4173d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_conv_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e043f237290ea9c31dfcddcee6d6f55bd733f1b5 Binary files 
/dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea54293f8b22bcdc21a73e2dec22aeee89c077f2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8ce40f39d3ea34e7075f29b3fc58efee35880b4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_tile_universal_gemm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..060b6a24f389346a32eadc4409bf0f9a6c9aa582 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/ck_universal_gemm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0b62248417a0edd9dbca3407a353a8ee81bcc3e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/compile_command.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac0363c762997f59f7a3bf5f527ed58b24e253dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_benchmark_request.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db06105dff1f8631a3c47fcb4362fbede209ed18 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_cpp_scheduling.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4ce4c2b4993a4c30eb91f204dd50be6534f569b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_kernel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7106b20173857f63c50e523996bf245a3ffa135 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c18af5dcc62eaa4ab56dfb4ad76b3a66336a1ccb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_template_buffer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd405070942a9135a74afa458c663ace82f8a971 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/__pycache__/rocm_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py new file mode 100644 index 0000000000000000000000000000000000000000..76c947132bdc34e229662e142a07aa20fa98c804 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -0,0 +1,608 @@ +# mypy: allow-untyped-defs +import copy +import logging +import random + +from torch._inductor.virtualized import V + + +try: + import ck4inductor # type: ignore[import] +except ImportError: + ck4inductor = None + +if ck4inductor is not None: + from ck4inductor.grouped_conv_fwd.gen_instances import ( # type: ignore[import] + gen_conv_ops_library, + ) + from ck4inductor.grouped_conv_fwd.op import ( # type: ignore[import] # noqa: TCH002 + CKGroupedConvFwdOp, + ) +else: + + def gen_conv_ops_library(): + return [] + + +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def torch_layout_to_ck_layouts(torch_layout): + # logically, torch tensors are always NCHW, + # and channels-last memory layout is visible in the strides + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + # when input or output is NCHW + # NB: torch.conv2d result is always NCHW + return ["NGCHW", "GKCYX", "NGKHW"] + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + # when input or output or weight is channels-last + return ["NHWGC", "GKYXC", "NHWGK"] + else: + return None + + +def torch_layout_to_ck_input_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGCHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGC" + else: + return None + + +def torch_layout_to_ck_weight_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "GKCYX" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "GKYXC" + else: + return None + + +def torch_layout_to_ck_output_layout(torch_layout): + if 
V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGKHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGK" + else: + return None + + +class CKGroupedConvFwdTemplate(CKTemplate): + conv_template = r""" + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto conv = {{instance_type}} {}; + auto invoker = conv.MakeInvoker(); + + using ck::index_t; + + constexpr index_t NumDTensor = {{n_d_tensors}}; + constexpr index_t NDimSpatial = {{n_dim_spatial}}; + const std::vector FilterSize = { FilterSize_0, FilterSize_1 }; + const std::vector InputSize = { InputSize_0, InputSize_1 }; + const std::vector ConvolutionStrides = { ConvolutionStrides_0, ConvolutionStrides_1 }; + const std::vector Dilations = { Dilations_0, Dilations_1 }; + const std::vector LeftPads = { LeftPads_0, LeftPads_1 }; + const std::vector RightPads = { RightPads_0, RightPads_1 }; + + + auto conv_param = ck::utils::conv::ConvParam { + NDimSpatial, + GroupCount, + NBatch, + NOutChannels, + NInChannels, + FilterSize, + InputSize, + ConvolutionStrides, + Dilations, + LeftPads, + RightPads, + }; + + using InLayout = ck::tensor_layout::convolution::{{input_layout}}; + using WeiLayout = ck::tensor_layout::convolution::{{weight_layout}}; + using OutLayout = ck::tensor_layout::convolution::{{output_layout}}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const void* p_a = input; + const void* p_b = weight; + const std::array p_ds; + void* p_e = output; + std::array a_g_n_c_wis_lengths; + std::array a_g_n_c_wis_strides; + std::array b_g_k_c_xs_lengths; + std::array b_g_k_c_xs_strides; + std::array, NumDTensor> ds_g_n_k_wos_lengths; + std::array, NumDTensor> ds_g_n_k_wos_strides; + std::array e_g_n_k_wos_lengths; + std::array e_g_n_k_wos_strides; + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + const auto a_element_op = PassThrough {}; + const auto b_element_op = PassThrough {}; + const auto cde_element_op = PassThrough {}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto argument = conv.MakeArgument( + p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op + ); + if (!conv.IsSupportedArgument(argument)) { + // we 
do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for conv instance " << conv.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = conv.GetWorkSpaceSize(&argument); + return 0; + } + + if (p_a == nullptr) { + std::cerr << "p_a is nullptr" << std::endl; + return -1; + } + if (p_b == nullptr) { + std::cerr << "p_b is nullptr" << std::endl; + return -1; + } + if (p_e == nullptr) { + std::cerr << "p_e is nullptr" << std::endl; + return -1; + } + + // when debugging, do time kernel to serialize launches + auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + + if (workspace != nullptr) { + conv.SetWorkSpacePointer(&argument, workspace, stream_config); + } + + // run the kernel + float elapsed_time = invoker.Run(argument, stream_config); + return 0; + } // kernel definition + } // extern C + + #ifdef GENERATE_CK_STANDALONE_RUNNER + int main(int argc, char** argv) { + (void) argc; + (void) argv; + return 0; + } + #endif // GENERATE_CK_STANDALONE_RUNNER +""" + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK conv globals + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GKCX = ck::tensor_layout::convolution::GKCX; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + using NGKW = ck::tensor_layout::convolution::NGKW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; + + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using KXGC = ck::tensor_layout::convolution::KXGC; + using KYXGC = ck::tensor_layout::convolution::KYXGC; + using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using NGCW = ck::tensor_layout::convolution::NGCW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; + + using G_K = ck::tensor_layout::convolution::G_K; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + using ConvolutionForwardSpecialization = 
ck::tensor_operation::device::ConvolutionForwardSpecialization; + + namespace ck { + namespace utils { + namespace conv { + + ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + } // namespace conv + } // namespace utils + } // namespace ck + + const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + void HostTensorDescriptor::CalculateStrides() { + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); + } + """ + ) + return res + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK conv headers + + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + + #include "ck/library/utility/convolution_parameter.hpp" + #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + """ + ) + return res + + @staticmethod + def add_ck_conv_choices( + choices, + layout, + input_nodes, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + template = 
CKGroupedConvFwdTemplate( + input_nodes, + layout, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=n_spatial_dimensions, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def __init__( + self, + input_nodes, + layout, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + super().__init__( + "ck_conv_template", + input_nodes, + layout, + ) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.n_spatial_dimensions = n_spatial_dimensions + + def filter_op(self, op: "CKGroupedConvFwdOp"): # type: ignore[name-defined] + metas = [ + T.get_layout() + for T in [*self.input_nodes, self.output_node] + if T is not None + ] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.e_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_input_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_weight_layout(W_meta): + return None + if op.e_layout != torch_layout_to_ck_output_layout(Y_meta): + return None + # disable the instance if number of spatial dimensions doesn't match + if op.n_dim_spatial != self.n_spatial_dimensions: + return None + # disable 1x1 and odd-channels conv specializations for now + if "Default" not in op.conv_forward_specialization: + return None + return op + + def gen_ops(self): + unfiltered_instances = gen_conv_ops_library() + + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. 
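+        # A fixed seed keeps the sampled subset deterministic across builds
+        # while still spreading the picks over the full filtered list.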
+ random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + def emit_ck_instance(self, op: "CKGroupedConvFwdOp") -> tuple[str, str]: # type: ignore[name-defined] + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not None: + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), self._template_from_string(template_type).render(operation_name=op.name()) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGroupedConvFwdOp", # type: ignore[name-defined] + **kwargs, + ) -> str: + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + instance_definition, instance_type = self.emit_ck_instance(op) + + size_arg_strs = [ + "GroupCount", + "NBatch", + "NOutChannels", + "NInChannels", + "FilterSize_0", + "FilterSize_1", + "InputSize_0", + "InputSize_1", + "ConvolutionStrides_0", + "ConvolutionStrides_1", + "Dilations_0", + "Dilations_1", + "LeftPads_0", + "LeftPads_1", + "RightPads_0", + "RightPads_1", + ] + + return self._template_from_string(self.conv_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + instance_type=instance_type, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias] if Bias is not None else [X, W], + outputs=[Y], + names_str="input, weight, bias, output" + if Bias is not None + else "input, weight, output", + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + n_d_tensors=1 if Bias is not None else 0, + n_dim_spatial=self.n_spatial_dimensions, + input_layout=op.a_layout, + weight_layout=op.b_layout, + output_layout=op.e_layout, + ) + + def size_args(self): + x, w = self.input_nodes[0], self.input_nodes[1] + y = self.output_node + + group_count = self.groups + n_batch = x.shape[0] # type: ignore[index] + n_out_channels = y.shape[1] # type: ignore[index] + n_in_channels = x.shape[1] # type: ignore[index] + + filter_size_0, filter_size_1 = w.shape[2:4] # type: ignore[index] + input_size_0, input_size_1 = x.shape[2:4] 
# type: ignore[index] + convolution_strides_0, convolution_strides_1 = self.stride + dilations_0, dilations_1 = self.dilation + left_pads_0, left_pads_1 = self.padding + right_pads_0, right_pads_1 = self.padding + + return ( + group_count, + n_batch, + n_out_channels, + n_in_channels, + filter_size_0, + filter_size_1, + input_size_0, + input_size_1, + convolution_strides_0, + convolution_strides_1, + dilations_0, + dilations_1, + left_pads_0, + left_pads_1, + right_pads_0, + right_pads_1, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_template.py new file mode 100644 index 0000000000000000000000000000000000000000..4c14aec11a8023e1b4b0d79dcb3b8b687930c777 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_template.py @@ -0,0 +1,108 @@ +from typing import Any +from typing_extensions import override + +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + +from .rocm_template import ArgInfo + + +class CKTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + + #ifdef DEBUG_LOG + #define DEBUG_LOG_TMP DEBUG_LOG + #undef DEBUG_LOG + #else + #define DEBUG_LOG_TMP 0 + #endif + #include "ck/ck.hpp" + #undef DEBUG_LOG + #define DEBUG_LOG DEBUG_LOG_TMP + + #include "ck/utility/data_type.hpp" + #include "ck/library/utility/check_err.hpp" + #include "ck/library/utility/device_memory.hpp" + #include "ck/library/utility/fill.hpp" + #include "ck/library/utility/host_tensor.hpp" + #include "ck/library/utility/host_tensor_generator.hpp" + #include "ck/library/utility/literals.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK globals + + template + using S = ck::Sequence; + + template + using Tuple = ck::Tuple; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Bilinear = ck::tensor_operation::element_wise::Bilinear; + using Scale = ck::tensor_operation::element_wise::Scale; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply; + + // see "composable_kernel/include/ck/utility/data_type.hpp" + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using F16 = ck::half_t; + using F32 = float; + // using F64 = double; + using BF16 = ck::bhalf_t; + // using I32 = int32_t; + // using I8 = int8_t; + // using I4 = ck::int4_t; + + #if DEBUG_LOG + static constexpr auto kDEBUG_LOG = 1; + #else + static constexpr auto kDEBUG_LOG = 0; + #endif + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" + + @override + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + @override + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + """ + Helper method to retrieve runtime args from generate kwargs + 
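+
+        Values are looked up by the argument names declared in
+        get_runtime_arg_info() (currently just "kBatch"), so callers must
+        supply them as keyword arguments.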
""" + return [kwargs[arg.name] for arg in self.get_runtime_arg_info()] diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py new file mode 100644 index 0000000000000000000000000000000000000000..cf990e8f9dffa71a329cb33dbd43b20f5f84ee05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_template.py @@ -0,0 +1,56 @@ +import torch +from torch._inductor.codegen.rocm.rocm_template import ROCmTemplate +from torch._inductor.ir import IRNode +from torch._inductor.utils import IndentedBuffer + + +class CKTileTemplate(ROCmTemplate): + """ + Base class for generating CK templates, has common, i.e. non-gemm-specific, code generation logic + """ + + _TORCH_DTYPE_TO_CK = { + torch.float32: "F32", + torch.float64: "F64", + torch.float16: "F16", + torch.bfloat16: "BF16", + torch.int32: "I32", + torch.int8: "I8", + torch.float8_e4m3fnuz: "F8", + torch.float8_e5m2fnuz: "BF8", + } + + ck_dtype_to_size = { + "FP16": 2, + "BF16": 2, + } + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK headers + #include "ck_tile/core.hpp" + + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + using F8 = ck_tile::fp8_t; + using BF8 = ck_tile::bf8_t; + using F16 = ck_tile::half_t; + using F32 = float; + using BF16 = ck_tile::bfloat16_t; + """ + ) + return res + + def torch_type_to_ck(self, node: IRNode, ptr: str) -> str: + if node is None: + return ptr + else: + return f"({self._TORCH_DTYPE_TO_CK.get(node.get_dtype())}*)({ptr})" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1ea5f2f1168c4d6ddb1b7d1d41edbe5d19018d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_tile_universal_gemm_template.py @@ -0,0 +1,967 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import functools +import logging +import random +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_tile_template import CKTileTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.codegen.rocm.rocm_template import ArgInfo +from torch._inductor.ir import Buffer, Layout +from torch.utils._ordered_set import OrderedSet + +from ...utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def is_static_int(number): + import sympy + + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return 
f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, 
warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +class CKTileGemmTemplate(CKTileTemplate): + """ + This class is used for rendering CK-Tile Universal GEMM kernels + """ + + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + + using {{instance_namespace}}::BaseGemmPipeline; + using {{instance_namespace}}::TilePartitioner; + + constexpr auto TileK = {{instance_namespace}}::TileK; + constexpr auto kPrefetchStages = BaseGemmPipeline::PrefetchStages; + + auto kargs = ck_tile::GemmKernelArgs { + X, + W, + Y, + M, + N, + K, + LDA, + LDB, + LDC, + kBatch + }; + + if (workspace_size) { + *workspace_size = 0; + return 0; + } + + // run the kernel + const auto dispatch = [&](const auto has_hot_loop_, const auto tail_number_) constexpr { + using Kernel = {{instance_namespace}}::Kernel; + + if (!Kernel::IsSupportedArgument(kargs)) { + // we do our best to statically avoid this case in `filter_op` + throw std::runtime_error("invalid argument"); + } + auto stream_config = ck_tile::stream_config{stream}; + auto grid_size = Kernel::GridSize(M, N, kBatch); + constexpr auto block_size = Kernel::BlockSize(); + constexpr auto lds_bytes = 0; + constexpr auto kBlockPerCU = 1; + auto gemm = ck_tile::make_kernel(Kernel{}, grid_size, block_size, lds_bytes, kargs); + float elapsed_time = ck_tile::launch_kernel(stream_config, gemm); + }; + + const ck_tile::index_t k_grain = kBatch * TileK; + const ck_tile::index_t K_split = (K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + {{rendered_dispatch}} + + return 0; + } // kernel definition + } // extern C + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + ) -> None: + super().__init__( + "ck_tile_gemm_template", + input_nodes=input_nodes, + layout=layout, + ) + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK GEMM header(s) + + #include "ck_tile/ops/gemm.hpp" + #include "ck_tile/ops/epilogue.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + template + void dispatch_memory_pipeline_hot_loop(const ck_tile::TailNumber tail_num, Dispatcher dispatch) + { + if(tail_num == ck_tile::TailNumber::One) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Full) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + + if constexpr(PrefetchStages > 2) + { + if(tail_num == ck_tile::TailNumber::Two) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 3) + { + if(tail_num == ck_tile::TailNumber::Three) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if 
constexpr(PrefetchStages > 4) + { + if(tail_num == ck_tile::TailNumber::Four) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 5) + { + if(tail_num == ck_tile::TailNumber::Five) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 6) + { + if(tail_num == ck_tile::TailNumber::Six) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + if constexpr(PrefetchStages > 7) + { + if(tail_num == ck_tile::TailNumber::Seven) + { + dispatch(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } + } + """ + ) + return res + + def check_dtypes(self, op: "CKTileGemmOperation"): + X_dtype, W_dtype, out_dtype = [ + T.get_layout().dtype for T in [*self.input_nodes, self.output_node] + ] + if op.datatype_a != self._TORCH_DTYPE_TO_CK[X_dtype]: + return False + if op.datatype_b != self._TORCH_DTYPE_TO_CK[W_dtype]: + return False + if op.datatype_c != self._TORCH_DTYPE_TO_CK[out_dtype]: + return False + return True + + def check_layouts(self, op: "CKTileGemmOperation"): + X_layout, W_layout, out_layout = [ + torch_layout_to_ck_layout(T.get_layout()) + for T in [*self.input_nodes, self.output_node] + ] + if op.layout_a != X_layout: + return False + if op.layout_b != W_layout: + return False + if op.layout_c != out_layout: + return False + return True + + def get_gemm_problem_size(self): + X_size, W_size = [T.get_layout().size for T in [*self.input_nodes]] + + M, K = X_size + _, N = W_size + + return M, N, K + + def check_block_tiles(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the block tile size + This helper function enforces it for the inputs and the output. + """ + M, N, K = self.get_gemm_problem_size() + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + if op.layout_a == "Row": + # handle in kBatch check + return True + elif op.layout_a == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_b == "Col": + # handle in kBatch check + return True + else: + raise AssertionError(f"Invalid {op.layout_b=}") + + if op.layout_c == "Row": + if not check(N, op.tile_n, op.n_is_padded): + return False + elif op.layout_c == "Col": + if not check(M, op.tile_m, op.m_is_padded): + return False + else: + raise AssertionError(f"Invalid layout {op.layout_c=}") + + return True + + def check_alignments(self, op: "CKTileGemmOperation"): + """ + The contiguous dimension of a tensor must be divisible by the vector load size. 
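+        In practice the vector load size is the largest of 16/8/4/2/1 bytes whose element count
+        (bytes divided by the dtype size) evenly divides both the contiguous elements per tile
+        and the elements handled per thread; see `max_alignment` below.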
+ """ + M, N, K = self.get_gemm_problem_size() + + def max_alignment(contiguous_elements_per_tile, elements_per_thread, ck_dtype): + for vector_load_bytes in (16, 8, 4, 2, 1): + alignment = vector_load_bytes // self.ck_dtype_to_size[ck_dtype] + if ( + alignment > 0 + and contiguous_elements_per_tile % alignment == 0 + and elements_per_thread % alignment == 0 + ): + return alignment + + threads_per_block = ( + op.warp_m * op.warp_n * op.warp_k * self.gfx9_threads_per_warp + ) + a_elements_per_thread = op.tile_m * op.tile_k / threads_per_block + b_elements_per_thread = op.tile_n * op.tile_k / threads_per_block + + if op.layout_a == "Row": + # K is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_k, a_elements_per_thread, op.datatype_a + ) + if is_static_int(K) and K % a_max_vector_size != 0: + return False + elif op.layout_a == "Col": + # M is contiguous tensor dimension + a_max_vector_size = max_alignment( + op.tile_m, a_elements_per_thread, op.datatype_a + ) + if is_static_int(M) and M % a_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_a=}") + + if op.layout_b == "Row": + # N is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_n, b_elements_per_thread, op.datatype_b + ) + if is_static_int(N) and N % b_max_vector_size != 0: + return False + elif op.layout_b == "Col": + # K is contiguous tensor dimension + b_max_vector_size = max_alignment( + op.tile_k, b_elements_per_thread, op.datatype_b + ) + if is_static_int(K) and K % b_max_vector_size != 0: + return False + else: + raise AssertionError(f"Invalid layout {op.layout_b=}") + + # the `default` epilogue writes C to memory by 1 tensor element + # (divisibility check not necessary) + # the `cshuffle` epilogue writes C to memory by 16 bytes + # (so the contiguous C dimension size must be divisible by the number of tensor elements in 16 bytes) + if op.epilogue == "CShuffle": + if ( + op.layout_c == "Row" + and is_static_int(N) + and N % (16 / self.ck_dtype_to_size[op.datatype_c]) != 0 + ): + return False + + return True + + def check_warp_tiles(self, op: "CKTileGemmOperation"): + if op.tile_m % (op.warp_m * op.warp_tile_m) != 0: + return False + if op.tile_n % (op.warp_n * op.warp_tile_n) != 0: + return False + if op.tile_k % (op.warp_k * op.warp_tile_k) != 0: + return False + return True + + def check_block_tile_size(self, op: "CKTileGemmOperation"): + # assuming LDS size is 64KB + if op.pipeline == "CompV4": + max_block_tile_size = 2**15 + else: + max_block_tile_size = 2**16 + + block_tile_size = ( + self.ck_dtype_to_size[op.datatype_a] * op.tile_m * op.tile_k + + self.ck_dtype_to_size[op.datatype_b] * op.tile_n * op.tile_k + ) + if block_tile_size > max_block_tile_size: + return False + return True + + def filter_op(self, op: "CKTileGemmOperation"): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. 
+ """ + if not self.check_dtypes(op): + return None + if not self.check_layouts(op): + return None + if not self.check_block_tiles(op): + return None + if not self.check_alignments(op): + return None + + return op + + def emit_ck_instance(self, op: "CKTileGemmOperation"): + """ + This method is used to generate code which defines the type alias for the generated kernel class + """ + template_definition = r""" + // Gemm operator {{operation_name}} + + namespace {{operation_name}} { + // block tile + constexpr int32_t TileM = {{tile_m}}; + constexpr int32_t TileN = {{tile_n}}; + constexpr int32_t TileK = {{tile_k}}; + // warps per block + constexpr int32_t WarpM = {{warp_m}}; + constexpr int32_t WarpN = {{warp_n}}; + constexpr int32_t WarpK = {{warp_k}}; + // xdl tile + constexpr int32_t WarpTileM = {{warp_tile_m}}; + constexpr int32_t WarpTileN = {{warp_tile_n}}; + constexpr int32_t WarpTileK = {{warp_tile_k}}; + + constexpr bool kPadM = {{m_is_padded}}; + constexpr bool kPadN = {{n_is_padded}}; + constexpr bool kPadK = {{k_is_padded}}; + + using ALayout = {{layout_a}}; + using BLayout = {{layout_b}}; + using CLayout = {{layout_c}}; + + using ADataType = {{datatype_a}}; + using BDataType = {{datatype_b}}; + using CDataType = {{datatype_c}}; + using AccDataType = F32; + + constexpr bool permuteA = false; + constexpr bool permuteB = false; + constexpr bool DoubleSmemBuffer = {{has_double_smem_buffer}}; + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + constexpr ck_tile::index_t TilePartitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + {{rendered_scheduler}} + + template + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + {{rendered_pipeline}} + + {{rendered_epilogue}} + + template + using Kernel = ck_tile::GemmKernel, GemmEpilogue>; + } + +""" + + def render_epilogue(epilogue_type): + if epilogue_type == "Default": + return r""" + using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem; + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue; + """ + elif epilogue_type == "CShuffle": + return r""" + constexpr auto kMemoryOperation = ck_tile::memory_operation_enum::set; + using EpilogueProblem = ck_tile::CShuffleEpilogueProblem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue; + """ + else: + raise AssertionError("Epilogue must be set") + + def render_pipeline(pipeline_type): + return rf""" + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCr{pipeline_type}; + + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCr{pipeline_type}>; + """ + + def render_scheduler(scheduler_type): + return rf""" + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::{scheduler_type}; + """ + + rendered_definition = self._template_from_string(template_definition).render( + operation_name=op.name(), + **asdict(op), + rendered_scheduler=render_scheduler(op.scheduler), + rendered_pipeline=render_pipeline(op.pipeline), + rendered_epilogue=render_epilogue(op.epilogue), + has_double_smem_buffer=("true" if op.pipeline == "CompV4" else "false"), + ) + return rendered_definition + + def render( # type: ignore[override] + self, kernel: ROCmTemplateKernel, op: 
"CKTileGemmOperation", **kwargs + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes", None) + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + assert 2 == len(self.input_nodes) + X, W = self.input_nodes + Y = self.output_node + + instance_definition = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + + def render_dispatch(pipeline_type, op_name): + switch_tailnum_template = r""" + switch (tail_num) { + {% for tail_num in valid_tailnums %} + case ck_tile::TailNumber::{{tail_num}}: + dispatch({{has_hot_loop}}, + ck_tile::integral_constant{}); + break; + {% endfor %} + default: + std::ostringstream err; + err << "Unsupported dispatch: " + << "Pipeline: " << "{{pipeline}}" + << "Prefetch stages: " << kPrefetchStages + << "Tail num: " << tail_num; + throw std::runtime_error(err.str()); + } // switch tail_num + """ + dispatch_template = r""" + if (has_hot_loop) { + {{rendered_with_hot_loop}} + } + else { // has_hot_loop == false + {{rendered_without_hot_loop}} + } // if has_hot_loop + """ + if pipeline_type == "CompV3": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "Mem": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop="dispatch_memory_pipeline_hot_loop(tail_num, dispatch);", + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + elif pipeline_type == "CompV4": + return self._template_from_string(dispatch_template).render( + rendered_with_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Two", "Three"), + pipeline=pipeline_type, + ), + rendered_without_hot_loop=self._template_from_string( + switch_tailnum_template + ).render( + has_hot_loop="ck_tile::integral_constant{}", + valid_tailnums=("Full", "Odd", "Even"), + pipeline=pipeline_type, + ), + ) + else: + raise AssertionError(f"Pipeline {pipeline_type} is not supported") + + return self._template_from_string(self.gemm_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, Y", + size_args=[ + f"int32_t {arg}" for arg in ["M", "N", "K", "LDA", "LDB", "LDC"] + ], + ), + instance_namespace=op.name(), + version_comment=version_comment, + rendered_dispatch=render_dispatch(op.pipeline, op.name()), + ) + + def 
gen_ops(self): + """ + Creates a list of `CKTileGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. + """ + instances = ops() + if not instances: + raise AssertionError( + "No Composable Kernel Universal GEMM instances found. " + "Please check if the library is installed." + ) + filtered_instances = list(filter(self.filter_op, instances)) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_tile_max_profiling_configs), + ) + if config.rocm.ck_tile_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after sample: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_choices( + choices, + layout, + input_nodes, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKTileGemmTemplate( + input_nodes, + layout, + ) + ops = template.gen_ops() + for op in ops: + for k_batch in template.k_batch_choices(op): + template.maybe_append_choice( + choices, + op=op, + kBatch=k_batch, + ) + + def k_batch_choices(self, op: "CKTileGemmOperation") -> tuple[int, ...]: + """ + Returns a list of k_batch choices for the template. + """ + default_choices = (1, 2, 4, 8, 16, 32) + + def check(dim_size, tile_size, is_padded): + if ( + is_static_int(dim_size) + and dim_size % tile_size != 0 + and is_padded == "false" + ): + return False + return True + + _, _, K, _, _, _ = self.size_args() + if op.layout_a == "Row" or op.layout_b == "Col": + choices = tuple( + filter( + lambda k_batch: check(K, op.tile_k * k_batch, op.k_is_padded), + default_choices, + ) + ) + else: + choices = default_choices + + if op.epilogue == "Default": + choices = (1,) + + return choices + + def size_args(self): + """ + Sizes and strides to be used for the kernel call + """ + X = self.input_nodes[0] + W = self.input_nodes[1] + Y = self.output_node + + M = X.get_size()[0] + K = X.get_size()[1] + N = W.get_size()[1] + LDA = X.get_stride()[0 if X.get_stride()[1] == 1 else 1] + LDB = W.get_stride()[0 if W.get_stride()[1] == 1 else 1] + LDC = Y.get_stride()[0 if Y.get_stride()[1] == 1 else 1] + + return M, N, K, LDA, LDB, LDC + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [ArgInfo("kBatch", "int32_t")] + + def get_runtime_arg_values(self, **kwargs: Any) -> list[Any]: + # maybe_append_choice kwarg for k_batch must match the name of the argument + arg_names = OrderedSet([arg.name for arg in self.get_runtime_arg_info()]) + if not arg_names.issubset(kwargs): + raise ValueError( + "Missing runtime arguments: " + ", ".join(arg_names - kwargs.keys()) + ) + return [kwargs[k] for k in arg_names] diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..2b0482ad652b99cc406c29452c5f43466e740bf6 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -0,0 +1,1016 @@ +# mypy: allow-untyped-defs, disable-error-code="attr-defined, valid-type" +import copy +import logging +import math +import random +from collections import namedtuple +from typing import Optional + +import sympy + +import torch +from torch._inductor import config +from torch._inductor.codegen.cpp_utils import DTYPE_TO_CPP +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.compile_command import rocm_compile_command +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.ir import Buffer, Layout +from torch._inductor.runtime.runtime_utils import next_power_of_2 + +from ...utils import IndentedBuffer, is_dynamic, try_import_ck_lib + + +_, gen_ops_library, gen_ops_preselected, CKGemmOperation = try_import_ck_lib() + + +log = logging.getLogger(__name__) + +# lightweight collection of information about a single op +InductorROCmOp = namedtuple("InductorROCmOp", ["op", "kBatch"]) + +padding_lookup = { + "M": { + "GemmSpecialization::MPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "N": { + "GemmSpecialization::NPadding": True, + "GemmSpecialization::MNPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, + "K": { + "GemmSpecialization::KPadding": True, + "GemmSpecialization::MKPadding": True, + "GemmSpecialization::NKPadding": True, + "GemmSpecialization::MNKPadding": True, + }, +} + + +def is_static_int(number): + return isinstance(number, (int, sympy.Integer)) + + +def torch_layout_to_ck_layout(torch_layout): + if torch_layout.stride[-1] == 1: + return "Row" + elif torch_layout.stride[-2] == 1: + return "Col" + else: + return None + + +class CKGemmTemplate(CKTemplate): + # the JINJA template for rendering CK Universal GEMMs + gemm_template = r"""{{version_comment}} + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto gemm = {{instance_type}} {}; + auto invoker = gemm.MakeInvoker(); + {% if is_batched %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + B, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + M * K, // batch_stride_A + N * K, // batch_stride_B + std::array{ {{ds_batch_strides}} }, + M * N, // batch_stride_C + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% else %} + auto argument = gemm.MakeArgument( + reinterpret_cast(X), + reinterpret_cast(W), + std::array{ {{ds_names}} }, + reinterpret_cast<{{c_element_dtype}}*>(Y), + M, + N, + K, + LDA, + LDB, + std::array{ {{ds_strides}} }, + LDC, + kBatch, // kBatch + {{a_elementwise_op}}, + {{b_elementwise_op}}, + {{epilogue}} // c_elementwise_op + ); + {% endif %} + if (!gemm.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for gemm instance " << gemm.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = gemm.GetWorkSpaceSize(&argument); + return 0; + } + // run the kernel + #ifdef GENERATE_CK_STANDALONE_RUNNER + const auto stream_config = StreamConfig{ + stream, + /* time kernel */ 1, + /* log level */ 1, + /* n_cold_iter */ 100, + /* 
n_hot_iter */ 100, + /* flush_l2_cache */ 1, + /* rotate_count */ 5}; + #else + const auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + #endif + + const float elapsed_time = invoker.Run(argument, stream_config); + + #ifdef GENERATE_CK_STANDALONE_RUNNER + std::cout << "elapsed time: " << elapsed_time << " ms" << std::endl; + #else + (void)elapsed_time; + #endif + return 0; + } // kernel definition + } // extern C + """ + + standalone_runner_template = r""" + #ifdef GENERATE_CK_STANDALONE_RUNNER + // standalone runner for the generated CK GEMM kernel + + {{inline_utils}} + + extern "C" { + int run_main(int argc, char** argv) { + {% if is_batched %} + const int32_t B = {{B}}; + {% endif %} + const int32_t M = {{M}}; + const int32_t N = {{N}}; + const int32_t K = {{K}}; + const int32_t LDA = {{LDA}}; + const int32_t LDB = {{LDB}}; + const int32_t LDC = {{LDC}}; + const int32_t LDD = {{LDD}}; + const int32_t kBatch = {{kBatch}}; + + using AElementType = {{a_ck_dtype}}; + using BElementType = {{b_ck_dtype}}; + using CElementType = {{c_ck_dtype}}; + {% if has_bias %} + using BiasElementType = {{bias_ck_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAElementType = {{scale_a_ck_dtype}}; + using ScaleBElementType = {{scale_b_ck_dtype}}; + {% endif %} + + using AArgType = {{a_torch_dtype}}; + using BArgType = {{b_torch_dtype}}; + using CArgType = {{c_torch_dtype}}; + {% if has_bias %} + using BiasArgType = {{bias_torch_dtype}}; + {% endif %} + {% if has_scale %} + using ScaleAArgType = {{scale_a_torch_dtype}}; + using ScaleBArgType = {{scale_b_torch_dtype}}; + {% endif %} + + using ALayout = {{a_layout}}; + using BLayout = {{b_layout}}; + using CLayout = {{c_layout}}; + {% if has_bias %} + using BiasLayout = {{bias_layout}}; + {% endif %} + + {% if is_batched %} + using strides_t = std::array; + auto get_strides = [](int32_t batch_stride, int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {batch_stride, leading_dimension, 1}; + } + return {batch_stride, 1, leading_dimension}; + }; + auto a_size = strides_t{B, M, K}; + auto a_stride = get_strides(M * K, LDA, ALayout{}); + auto b_size = strides_t{B, N, K}; + auto b_stride = get_strides(N * K, LDB, BLayout{}); + auto c_size = strides_t{B, M, N}; + auto c_stride = get_strides(M * N, LDC, CLayout{}); + {% else %} + using strides_t = std::array; + auto get_strides = [](int32_t leading_dimension, auto layout) constexpr -> strides_t { + if constexpr (std::is_same_v) { + return {leading_dimension, 1}; + } + return {1, leading_dimension}; + }; + auto a_size = strides_t{M, K}; + auto a_stride = get_strides(LDA, ALayout{}); + auto b_size = strides_t{N, K}; + auto b_stride = get_strides(LDB, BLayout{}); + auto c_size = strides_t{M, N}; + auto c_stride = get_strides(LDC, CLayout{}); + {% endif %} + + Tensor a_m_k ( HostTensorDescriptor ( a_size, a_stride ) ); + Tensor b_k_n ( HostTensorDescriptor ( b_size, b_stride ) ); + {% if has_bias %} + Tensor d_m_n ( HostTensorDescriptor ( c_size, get_strides(LDD, BiasLayout{}) ) ); + {% endif %} + {% if has_scale %} + // NB: these are hardcoded + Tensor s_a_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Row{}) )); + Tensor s_b_m_n ( HostTensorDescriptor ( strides_t{M, N}, get_strides(0, Col{}) )); + {% endif %} + + Tensor c_m_n_host ( HostTensorDescriptor ( c_size, c_stride ) ); + Tensor c_m_n_device ( HostTensorDescriptor ( c_size, c_stride ) ); + + a_m_k.GenerateTensorValue(GeneratorTensor_2()); + 
b_k_n.GenerateTensorValue(GeneratorTensor_2()); + {% if has_bias %} + d_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + {% if has_scale %} + s_a_m_n.GenerateTensorValue(GeneratorTensor_2()); + s_b_m_n.GenerateTensorValue(GeneratorTensor_2()); + {% endif %} + DeviceMem a_m_k_device_buf(sizeof(AElementType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BElementType) * b_k_n.mDesc.GetElementSpaceSize()); + {% if has_bias %} + DeviceMem d_m_n_device_buf(sizeof(BiasElementType) * d_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + {% if has_scale %} + DeviceMem s_a_m_n_device_buf(sizeof(ScaleAElementType) * s_a_m_n.mDesc.GetElementSpaceSize()); + DeviceMem s_b_m_n_device_buf(sizeof(ScaleBElementType) * s_b_m_n.mDesc.GetElementSpaceSize()); + {% endif %} + DeviceMem c_m_n_device_buf(sizeof(CElementType) * c_m_n_device.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + {% if has_bias %} + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); + {% endif %} + {% if has_scale %} + s_a_m_n_device_buf.ToDevice(s_a_m_n.mData.data()); + s_b_m_n_device_buf.ToDevice(s_b_m_n.mData.data()); + {% endif %} + + {{kernel_name}}( + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + {% if has_scale %} + static_cast(s_a_m_n_device_buf.GetDeviceBuffer()), + static_cast(s_b_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + {% if has_bias %} + static_cast(d_m_n_device_buf.GetDeviceBuffer()), + {% endif %} + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + {% if is_batched %} + B, + {% endif %} + M, + N, + K, + LDA, + LDB, + LDC, + LDD, + nullptr, // workspace_size + nullptr, // workspace + nullptr); // stream + + hip_check_error(hipDeviceSynchronize()); + + return 0; + } // run_main + } // extern C + + int main(int argc, char** argv) { + return run_main(argc, argv); + } + // compile with: {{compile_cmd}} + #endif // GENERATE_CK_STANDALONE_RUNNER + """ + + def __init__( + self, + input_nodes: list[Buffer], + layout: Layout, + alpha: float, + beta: float, + input_reorder: Optional[list[int]] = None, + ) -> None: + is_batched = len(layout.size) == 3 + name = "ck_batched_gemm_template" if is_batched else "ck_gemm_template" + super().__init__( + name=name, + input_nodes=input_nodes, + layout=layout, + input_reorder=input_reorder, + ) + self.alpha = alpha + self.beta = beta + self.is_batched = is_batched + + def header(self) -> IndentedBuffer: + res = super().header() + if self.is_batched: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + else: + res.splice( + """ + // CK GEMM header(s) + + #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" + """ + ) + return res + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK GEMM globals + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + struct MultiplyMultiplyAdd { + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const { + e = ck::type_convert( + ck::type_convert(c) + * ck::type_convert(d0) + * 
ck::type_convert(d1) + + ck::type_convert(d2) + ); + } + }; + """ + ) + return res + + def inline_utils(self): + res = IndentedBuffer() + res.splice( + """ + #include "host_tensor.cpp" + #include "device_memory.cpp" + """ + ) + return res + + def _has_padding(self, dimension, gemm_specialization): + # Get the relevant padding map for the given dimension + dimension_padding = padding_lookup.get(dimension, {}) + + # Check if the specialization is in the dimension's padding map + return dimension_padding.get(gemm_specialization, False) + + def filter_op(self, op_info: InductorROCmOp): + """ + Determines whether a given op definition is suitable for the current + input / output of the operation that this template implements. + + Filter is based on inputs' dtype, layout and statically inferred size. + + Returns None if the op is not suitable, otherwise returns the op to be used. + """ + op, kBatch = op_info.op, op_info.kBatch + metas = [T.get_layout() for T in [*self.input_nodes, self.output_node]] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.c_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_layout(W_meta): + return None + if op.c_layout != torch_layout_to_ck_layout(Y_meta): + return None + # try to avoid launching the instance with invalid problem size + # see GridwiseGemm_xdl_cshuffle_v3::CheckValidity + + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + + if is_static_int(M): + if not self._has_padding("M", op.gemm_specialization): + if M % op.m_per_block != 0: + return None + if is_static_int(N): + if not self._has_padding("N", op.gemm_specialization): + if N % op.n_per_block != 0: + return None + if is_static_int(K): + if not self._has_padding("K", op.gemm_specialization): + if K % op.k_per_block != 0: + return None + K_t = kBatch * op.k_per_block + if K % K_t != 0: + return None + else: + # need another kBatch check here + lcm = abs(op.a_k1 * op.b_k1) // math.gcd(op.a_k1, op.b_k1) + K_t = kBatch * lcm + k_read_pad_splited = math.ceil(K / K_t) * lcm + if (k_read_pad_splited * (kBatch - 1)) >= K: + return None + + a_contig_size = ( + K if op.a_layout == "Row" else M if op.a_layout == "Col" else None + ) + if ( + is_static_int(a_contig_size) + and a_contig_size % op.a_block_transfer_src_scalar_per_vector != 0 + ): + return None + b_contig_size = ( + N if op.b_layout == "Row" else K if op.b_layout == "Col" else None + ) + if ( + is_static_int(b_contig_size) + and b_contig_size % op.b_block_transfer_src_scalar_per_vector != 0 + ): + return None + c_contig_size = ( + N if op.c_layout == "Row" else M if op.c_layout == "Col" else None + ) + c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block[0] + if isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ) + else op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + if ( + is_static_int(c_contig_size) + and c_contig_size % c_shuffle_block_transfer_scalar_per_vector_n_per_block + != 0 + ): + return None + if not self._check_num_k_loops(op, kBatch): + return None + # TBD disable instances with invalid number of pipeline prefetch 
stages + # It will avoid compiling a small percentage of unrunnable instances which fail the gemm argument check + + return op + + def _check_num_k_loops(self, op, kBatch): + # Additional splitK scenario check + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + K = X_meta.size[-1] + if kBatch > 1: + if op.block_gemm_pipeline_version != "BlockGemmPipelineVersion::v1": + try: + prefetch_stages = self._prefetch_stages( + op, + torch.empty((), dtype=X_meta.dtype).element_size(), + torch.empty((), dtype=W_meta.dtype).element_size(), + torch.cuda.get_device_properties(X_meta.device).warp_size, + ) + except Exception as e: + log.debug( + "Failed to prefetch_stages for %s with exception %s", op.name, e + ) + # be conservative here and disable the op + return False + + K_t = op.k_per_block * kBatch + ak0 = (K + K_t - 1) // K_t * (op.k_per_block // op.a_k1) + num_k_loop = ak0 // (op.k_per_block // op.a_k1) + if num_k_loop <= prefetch_stages: + log.debug( + "Op %s is not compatible due to invalid number of pipeline prefetch stages. " + "Parameters: kBatch=%s, block_gemm_pipeline_version=%s, prefetch_stages=%s, num_k_loop=%s", + op.name(), + kBatch, + op.block_gemm_pipeline_version, + prefetch_stages, + num_k_loop, + ) + return False + + return True + + # small helper to figure out the prefetch stages on AMD + def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64): + version_str = op.block_gemm_pipeline_version.split("::")[-1] + try: + version = int(version_str[1:]) # Assuming the format is always 'vX' + except ValueError as e: + raise ValueError(f"Invalid version string: {version_str}") from e + if version not in [1, 2, 3, 4, 5]: + raise ValueError( + f"unknown prefetch stages for {op.block_gemm_pipeline_version}" + ) + # Define the mapping of versions to stages + version_to_stages = {1: 1, 3: 2, 4: 4, 5: 3} + # Get the stages for the given version + stages = version_to_stages.get(version, None) + if stages is None: + # This means we're at stage 2, and this requires computation + # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950 + wgp_per_cu = max(4 * warp_size // op.block_size, 1) + full_mem_band_prefetch_stages = math.ceil( + 32768 + / wgp_per_cu + / ( + (op.m_per_block * a_dtype_size + op.n_per_block * b_dtype_size) + * op.k_per_block + ) + ) + stages = min(max(full_mem_band_prefetch_stages, 2), 8) + + return stages + + def emit_ck_instance(self, op: "CKGemmOperation"): + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + struct_name = ( + "DeviceBatchedGemmMultiD_Xdl_CShuffle_V3" + if self.is_batched + else "DeviceGemmMultiD_Xdl_CShuffle_V3" + ) + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::{{struct_name}}< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not 
None: + template_params.append(f"/* {field_name} */ {field_value}") + operation_name = op.name().replace("(", "").replace(",", "").replace(")", "") + return self._template_from_string(template_definition).render( + operation_name=operation_name, + template_params=(",\n" + 12 * " ").join(template_params), + struct_name=struct_name, + ), self._template_from_string(template_type).render( + operation_name=operation_name + ) + + def render( # type: ignore[override] + self, + kernel: ROCmTemplateKernel, + op: "CKGemmOperation", + **kwargs, + ) -> str: + """ + The primary entry point for the code rendering process used in this template. + """ + epilogue_nodes = kwargs.get("epilogue_nodes", None) + assert epilogue_nodes is None or 0 == len(epilogue_nodes) + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + # input nodes: + # * X, W for matmul + # * X, W, Bias for addmm + # * X, W, inv_scale_x, inv_scale_w for scaled_mm + # * X, W, inv_scale_x, inv_scale_w, Bias for scaled_mm with bias + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = ( + self.input_nodes[2] + if 3 == len(self.input_nodes) + else self.input_nodes[4] + if 5 == len(self.input_nodes) + else None + ) + has_bias = Bias is not None + has_scale = len(self.input_nodes) in (4, 5) + op = copy.deepcopy(op) + + # This parameter is converted into tuple because of change + # from DeviceGemm_Xdl_CShuffleV3 to DeviceGemmMultiD_Xdl_CShuffle_V3. + # The first tuple element corresponds to matmul result... + if not isinstance( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, tuple + ): + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block, + ) + + if has_scale: + scale_x = self.input_nodes[2] + scale_w = self.input_nodes[3] + if 1 == scale_x.get_numel() and 1 == scale_w.get_numel(): + # tensorwise scale for both X, W + if has_bias: + op.c_elementwise_op = "ScaleAdd" + else: + op.c_elementwise_op = "Scale" + else: + # rowwise scale for both X, W + if has_bias: + op.c_elementwise_op = "MultiplyMultiplyAdd" + else: + op.c_elementwise_op = "MultiplyMultiply" + op.c_shuffle_dtype = "F32" + op.ds_layouts = ( + torch_layout_to_ck_layout(scale_x.get_layout()), + torch_layout_to_ck_layout(scale_w.get_layout()), + ) + op.ds_element_dtypes = ( + self._TORCH_DTYPE_TO_CK[scale_x.get_layout().dtype], + self._TORCH_DTYPE_TO_CK[scale_w.get_layout().dtype], + ) + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += (1, 1) + else: + scale_x = None + scale_w = None + + bias_dtype = "" + if Bias is not None: + bias_layout = torch_layout_to_ck_layout(Bias.get_layout()) + bias_dtype = self._TORCH_DTYPE_TO_CK[Bias.get_layout().dtype] + op.ds_layouts += (bias_layout,) + op.ds_element_dtypes += (bias_dtype,) + if not has_scale: + op.c_elementwise_op = "Bilinear" + # c_shuffle_dtype is also used for adding bias to matmul result + # before converting down to the result dtype + op.c_shuffle_dtype = op.acc_dtype + # this parameter needs to be set accordingly to bias stride for correct accumulation + if bias_layout == "Row": + # bias has (N, ) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = ( + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + elif bias_layout == "Col": + # bias has (M, 1) shape + bias_shuffle_block_transfer_scalar_per_vector_n_per_block = (1,) + else: + raise AssertionError( + "Bias layout is neither row-major nor 
column-major" + ) + # ...and the second tuple element corresponds to the bias + op.c_shuffle_block_transfer_scalar_per_vector_n_per_block += ( + bias_shuffle_block_transfer_scalar_per_vector_n_per_block + ) + + instance_definition, instance_type = self.emit_ck_instance(op) + + version_comment = rf"""/** +* Generated code for CK inductor backend +* See {type(self).__module__}.{type(self).__qualname__} +* +* Template instance {op} +* +* {torch.__version__=} +* torch.version.git_version={getattr(torch.version, "git_version", "None")} +*/ +""" + epilogue = None + + if op.c_elementwise_op == "Bilinear" and scale_w is None: + epilogue = f"Bilinear {{ {self.alpha}, {self.beta} }}" + + elif op.c_elementwise_op == "Scale": + epilogue = "Scale { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "ScaleAdd": + epilogue = "ScaleAdd { (inv_scale_w && inv_scale_x) ? (*inv_scale_w * *inv_scale_x) : 1.0f }" + + elif op.c_elementwise_op == "MultiplyMultiply": + epilogue = "MultiplyMultiply {}" + + elif op.c_elementwise_op == "MultiplyMultiplyAdd": + epilogue = "MultiplyMultiplyAdd {}" + + elif op.c_elementwise_op == "PassThrough": + epilogue = "PassThrough {}" + + assert epilogue is not None, "CK GEMM epilogue is not set" + + size_arg_strs = ["M", "N", "K", "LDA", "LDB", "LDC", "LDD"] + if self.is_batched: + size_arg_strs.insert(0, "B") + + res = self._template_from_string(self.gemm_template).render( + inline_utils=self.inline_utils(), + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + kernel_definition=kernel.def_kernel( + inputs=[X, W, scale_x, scale_w, Bias], # type: ignore[list-item] + outputs=[Y], + names_str="X, W, inv_scale_x, inv_scale_w, Bias, Y", + input_reorder=self.input_reorder, + size_args=[f"int32_t {arg}" for arg in size_arg_strs], + ), + instance_type=instance_type, + a_element_dtype=op.a_element_dtype, + b_element_dtype=op.b_element_dtype, + c_element_dtype=op.c_element_dtype, + bias_element_dtype=bias_dtype, + alpha=self.alpha, + beta=self.beta, + a_elementwise_op="PassThrough {}", + b_elementwise_op="PassThrough {}", + epilogue=epilogue, + has_bias=has_bias, + ds_size=1 + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else 2 + if op.c_elementwise_op == "MultiplyMultiply" + else 3 + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else 0, + ds_names=", ".join( + ["Bias"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["inv_scale_x", "inv_scale_w"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["inv_scale_x", "inv_scale_w", "Bias"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + ds_strides=", ".join( + ["LDD"] + if op.c_elementwise_op in ("Bilinear", "ScaleAdd") + else ["0", "0"] + if op.c_elementwise_op == "MultiplyMultiply" + else ["0", "0", "LDD"] + if op.c_elementwise_op == "MultiplyMultiplyAdd" + else [] + ), + version_comment=version_comment, + is_batched=self.is_batched, + ds_batch_strides=", ".join([]), # FIXME when supporting baddbmm + ) + + if config.rocm.generate_test_runner: + is_static_problem = all(is_static_int(arg) for arg in self.size_args()) + # NOTE: size_arg_strs is defined above + size_arg_vals = ( + self.size_args() + if is_static_problem + else ( + f"std::stoi(argv[{k}])" for k, _ in enumerate(self.size_args(), 1) + ) + ) + size_args = dict(zip(size_arg_strs, size_arg_vals, strict=True)) + runtime_args = dict( + zip( + [a.name for a in self.get_runtime_arg_info()], + self.get_runtime_arg_values(), + ) + 
) + runner_code = self._template_from_string( + self.standalone_runner_template + ).render( + inline_utils=self.inline_utils().getvalue(), + kernel_name=kernel.kernel_name, + has_bias=has_bias, + has_scale=has_scale, + is_batched=self.is_batched, + a_ck_dtype=op.a_element_dtype, + b_ck_dtype=op.b_element_dtype, + c_ck_dtype=op.c_element_dtype, + bias_ck_dtype=op.ds_element_dtypes[0] if has_bias else "", + scale_a_ck_dtype=op.ds_element_dtypes[0] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + scale_b_ck_dtype=op.ds_element_dtypes[1] + if has_scale and 2 == len(op.ds_element_dtypes) + else "BF16", + a_torch_dtype=DTYPE_TO_CPP[X.get_layout().dtype], + b_torch_dtype=DTYPE_TO_CPP[W.get_layout().dtype], + c_torch_dtype=DTYPE_TO_CPP[Y.get_layout().dtype], + bias_torch_dtype=DTYPE_TO_CPP[Bias.get_layout().dtype] + if Bias is not None + else "", + scale_a_torch_dtype=DTYPE_TO_CPP[scale_x.get_layout().dtype] + if scale_x is not None + else "", + scale_b_torch_dtype=DTYPE_TO_CPP[scale_w.get_layout().dtype] + if scale_w is not None + else "", + a_layout=torch_layout_to_ck_layout(X.get_layout()), + b_layout=torch_layout_to_ck_layout(W.get_layout()), + c_layout=torch_layout_to_ck_layout(Y.get_layout()), + bias_layout=torch_layout_to_ck_layout(Bias.get_layout()) + if Bias is not None + else "", + compile_cmd=rocm_compile_command( + [""], "", "exe" + ), + **size_args, + **runtime_args, + ) + res += runner_code + + return res + + def _is_rcr_f16(self): + X_meta, W_meta, Y_meta = ( + T.get_layout() for T in [*self.input_nodes, self.output_node] + ) + X_dtype, W_dtype, Y_dtype = ( + self._TORCH_DTYPE_TO_CK[m.dtype] for m in (X_meta, W_meta, Y_meta) + ) + X_layout, W_layout, Y_layout = ( + torch_layout_to_ck_layout(m) for m in (X_meta, W_meta, Y_meta) + ) + + return ( + X_dtype == "F16" + and W_dtype == "F16" + and Y_dtype == "F16" + and X_layout == "Row" + and W_layout == "Col" + and Y_layout == "Row" + ) + + # helper to calculate a potentially optimal kBatch(es) for a problem + def _get_kBatch(self, op): + # we only set a higher kBatch if K > 16 * the larger of M and N + # this is a hand-tuned heuristic to start + metas = [T.get_layout() for T in [*self.input_nodes]] + X_meta = metas[0] + W_meta = metas[1] + M = X_meta.size[-2] + K = X_meta.size[-1] + N = W_meta.size[-1] + if is_dynamic(*self.input_nodes): + return [1] + if K // max(M, N) < config.rocm.split_k_threshold: + return [1] + # if the user is telling us which kBatches to sweep, just use those + if config.rocm.kBatch_sweep is not None: + return config.rocm.kBatch_sweep + # Calculate the number of blocks needed for each dimension + total_k_blocks = math.ceil(K / op.k_per_block) + # we want to calculate how many blocks we need to fit per CU + cus = torch.cuda.get_device_properties(X_meta.device).multi_processor_count + # again, manual heuristics as much larger kBatch are significantly worse in + # initial testing + kBatch = min(max(next_power_of_2(total_k_blocks // cus), 1), 128) + return [kBatch] + + def gen_ops(self) -> list[InductorROCmOp]: + """ + Creates a list of `CKGemmOperation` instances that match the GEMM operation this template represents. + The instances are guaranteed to have the correct layout, dtype and dimension padding for the GEMM input arguments. + + An instance may invalidate the GEMM configuration at runtime. + Such instances will be assigned +inf runtime by the autotune process. 
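+        Instances come from the ck4inductor generators, are filtered through `filter_op`,
+        and are then randomly sub-sampled down to `config.rocm.ck_max_profiling_configs`
+        when that limit is set.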
+ """ + try: + from ck4inductor.batched_universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_batched_gemm_ops_library, + ) + from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] + gen_ops_library as gen_gemm_ops_library, + gen_ops_preselected as gen_gemm_ops_preselected, + ) + except ImportError: + return [] + + generator = None + if self.is_batched: + generator = gen_batched_gemm_ops_library + else: + generator = gen_gemm_ops_library + if config.rocm.use_preselected_instances and self._is_rcr_f16(): + generator = gen_gemm_ops_preselected + + assert generator is not None + + rops = generator() + ops = [] + for o in rops: + kBatches = self._get_kBatch(o) + for kBatch in kBatches: + ops.append(InductorROCmOp(op=o, kBatch=kBatch)) + + filtered_instances = list(filter(lambda op: self.filter_op(op), ops)) + + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.ck_max_profiling_configs), + ) + if config.rocm.ck_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + @staticmethod + def add_ck_gemm_choices( + choices, + layout, + input_nodes, + alpha=1, + beta=0, + input_reorder=None, + ): + """ + Add Composable Kernel Universal GEMM instance choices to the auto-tuning list. + """ + template = CKGemmTemplate( + input_nodes, + layout, + alpha=alpha, + beta=beta, + input_reorder=input_reorder, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op.op, + kBatch=op.kBatch, + ) + + def size_args(self): + X = self.input_nodes[0] + W = self.input_nodes[1] + Bias = ( + self.input_nodes[2] + if len(self.input_nodes) == 3 + else self.input_nodes[4] + if len(self.input_nodes) == 5 + else None + ) + Y = self.output_node + + M = X.get_size()[-2] + K = X.get_size()[-1] + N = W.get_size()[-1] + LDA = X.get_stride()[-2 if X.get_stride()[-1] == 1 else -1] + LDB = W.get_stride()[-2 if W.get_stride()[-1] == 1 else -1] + LDC = Y.get_stride()[-2 if Y.get_stride()[-1] == 1 else -1] + LDD = ( + 0 + if (Bias is None or len(Bias.get_size()) == 1) + else Bias.get_stride()[-2 if Bias.get_stride()[-1] == 1 else -1] + ) + if self.is_batched: + B = X.get_size()[0] + return B, M, N, K, LDA, LDB, LDC, LDD + else: + return M, N, K, LDA, LDB, LDC, LDD diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/compile_command.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/compile_command.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c278bfd72ce5e12fa2fbdd0325e90607776aeb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/compile_command.py @@ -0,0 +1,148 @@ +# mypy: allow-untyped-defs +import logging +import os +from typing import Optional + +from torch._inductor import config +from torch._inductor.utils import is_linux + + +log = logging.getLogger(__name__) + + +def _rocm_include_paths(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_include = ( + os.path.join(config.rocm.rocm_home, "include") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("include") + ) + if not config.rocm.ck_dir: + log.warning("Unspecified Composable Kernel include 
dir") + + if config.is_fbcode(): + from libfb.py import parutil + + ck_path = parutil.get_dir_path("composable-kernel-headers") + else: + ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home( + "composable_kernel" + ) + + ck_include = os.path.join(ck_path, "include") + ck_library_include = os.path.join(ck_path, "library", "include") + + # CK has to take priority over ROCm include paths + # Since CK is potentially more up-to-date + paths = [ + os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include) + ] + if dst_file_ext == "exe": + ck_utility_include = os.path.join(ck_path, "library", "src", "utility") + paths.append(os.path.realpath(ck_utility_include)) + return paths + + +def _rocm_lib_options(dst_file_ext: str) -> list[str]: + from torch.utils import cpp_extension + + rocm_lib_dir = ( + os.path.join(config.rocm.rocm_home, "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("lib") + ) + hip_lib_dir = ( + os.path.join(config.rocm.rocm_home, "hip", "lib") + if config.rocm.rocm_home + else cpp_extension._join_rocm_home("hip", "lib") + ) + + opts = [ + "-include __clang_hip_runtime_wrapper.h", + f"-L{os.path.realpath(rocm_lib_dir)}", + f"-L{os.path.realpath(hip_lib_dir)}", + "-lamdhip64", + ] + if dst_file_ext == "exe": + opts += ["-lpthread", "-lstdc++"] + return opts + + +def _rocm_compiler_options() -> list[str]: + arch_list = config.rocm.arch or ["native"] + gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] + opts = [ + config.rocm.compile_opt_level, + "-x", + "hip", + "-std=c++17", + *gpu_arch_flags, + "-fno-gpu-rdc", + "-fPIC", + "-fvisibility=hidden", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + "-mllvm", + "-enable-post-misched=0", + ] + if config.rocm.is_debug: + opts += ["-DDEBUG_LOG=1", "-g"] + if config.rocm.save_temps: + opts += ["--save-temps=obj"] + if config.rocm.print_kernel_resource_usage: + opts += ["-Rpass-analysis=kernel-resource-usage"] + if config.rocm.flush_denormals: + opts += ["-fgpu-flush-denormals-to-zero"] + if config.rocm.use_fast_math: + opts += ["-ffast-math"] + return opts + + +def rocm_compiler() -> Optional[str]: + if is_linux(): + if config.rocm.rocm_home: + return os.path.realpath( + os.path.join(config.rocm.rocm_home, "llvm", "bin", "clang") + ) + try: + from torch.utils import cpp_extension + + return os.path.realpath( + cpp_extension._join_rocm_home("llvm", "bin", "clang") + ) + except OSError: + # neither config.rocm.rocm_home nor env variable ROCM_HOME are set + return "clang" + return None + + +def rocm_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: Optional[list[str]] = None, +) -> str: + include_paths = _rocm_include_paths(dst_file_ext) + lib_options = _rocm_lib_options(dst_file_ext) + compiler_options = _rocm_compiler_options() + compiler = rocm_compiler() + options = ( + compiler_options + + (extra_args or []) + + [f"-I{path}" for path in include_paths] + + lib_options + ) + src_file = " ".join(src_files) + # supported extensions: .o, .so, .exe + if dst_file_ext == "o": + options.append("-c") + elif dst_file_ext == "so": + options.append("-shared") + elif dst_file_ext == "exe": + options.append("-DGENERATE_CK_STANDALONE_RUNNER") + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + return f"{compiler} {' '.join(options)} -o {dst_file} {src_file}" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py 
b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py new file mode 100644 index 0000000000000000000000000000000000000000..f0006331ac58a0287f27fa37d1039cc7ccc5ef00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import functools +import logging +from ctypes import byref, c_int, c_size_t, c_void_p +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.autotune_process import ( + BenchmarkRequest, + GPUDeviceBenchmarkMixin, + TensorMeta, +) +from torch._inductor.codecache import DLLWrapper, ROCmCodeCache + + +if TYPE_CHECKING: + from collections.abc import Iterable + + +log = logging.getLogger(__name__) + + +class ROCmBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): + # Important: Instances of this class have to be serializable + # across process boundaries. Do not put CUDA Tensors in here! + + def __init__( + self, + kernel_name: str, + input_tensor_meta: Union[TensorMeta, list[TensorMeta]], + output_tensor_meta: Union[TensorMeta, list[TensorMeta]], + extra_args: Iterable[Any], + source_code: str, + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.source_code = source_code + self.workspace_size: int = 0 + self.workspace: Optional[torch.Tensor] = None + self.DLL: Optional[DLLWrapper] = None + self._workspace_size_updated = False + self.hash_key: str = "" + self.source_file: str = "" + self.hash_key, self.source_file = ROCmCodeCache.write(self.source_code, "so") + + def precompile(self): + # Prepopulate code cache + # may happen in separate Threadpool + log.debug("Precompiling %s", self) + ROCmCodeCache.compile(self.source_code, "so") + if config.rocm.generate_test_runner: + ROCmCodeCache.compile(self.source_code, "exe") + log.debug("Done precompiling %s", self) + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + self.ensure_dll_loaded() + self.update_workspace_size() + args = [c_void_p(tensor.data_ptr()) for tensor in list(input_tensors) + [out]] + size_args = [c_int(arg) for arg in self.extra_args] + log.debug( + "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + run_method = getattr(self.DLL, self.kernel_name) + workspace_ptr = c_void_p(0) + if self.workspace_size > 0: + self.workspace = torch.zeros( + (self.workspace_size + 7) // 8, + dtype=torch.float64, + device=out.device, + ) + workspace_ptr = c_void_p(self.workspace.data_ptr()) + + # Generate partial function. + return functools.partial( + run_method, + *args, + *size_args, + None, # null workspace size ptr + workspace_ptr, # set workspace ptr, + stream_ptr, + ) + + def update_workspace_size(self) -> None: + if self._workspace_size_updated: + return + self.ensure_dll_loaded() + unique_input_count = len( + {meta.name for meta in self.input_tensor_meta} # noqa: set_linter + ) + args = [c_void_p(None) for _ in range(unique_input_count + 1)] + stream_ptr = c_void_p(torch.cuda.current_stream().cuda_stream) + + run_method = getattr(self.DLL, self.kernel_name) + # Retrieve workspace_size and initialize workspace. 
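+        # The generated kernel treats a non-null workspace_size pointer as a size query:
+        # it writes the required scratch size and returns without launching anything.
+        # The actual workspace tensor is then allocated lazily in make_run_fn().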
+ c_workspace_size = c_size_t() + size_args = [c_int(arg) for arg in self.extra_args] + run_method( + *args, # input ptrs and output ptrs + *size_args, + byref( + c_workspace_size + ), # set workspace size ptr to retrieve workspace size + None, # null workspace ptr + stream_ptr, + ) + torch.cuda.synchronize() # shake out any CUDA errors + self.workspace_size = c_workspace_size.value + log.debug( + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + self.workspace_size, + self.kernel_name, + self.source_file, + self.hash_key, + self.DLL, + args, + self.extra_args, + ) + self._workspace_size_updated = True + + def ensure_dll_loaded(self): + if self.DLL is None: + self.DLL, self.hash_key, self.source_file = ROCmCodeCache.load( + self.source_code, "so" + ) + + def cleanup_run_fn(self) -> None: + if self.DLL is not None: + self.DLL.close() + self.workspace = None + + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc0eaf7ce199f2e1178f4afca118bd43e336ed4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import cast + +from ... import config +from ...codecache import code_hash, get_path +from ...scheduler import BaseSchedulerNode, BaseScheduling, SchedulerNode +from ...utils import get_fused_kernel_name, get_kernel_metadata, sympy_product +from ...virtualized import V +from ..common import IndentedBuffer +from .rocm_template_buffer import ROCmTemplateBuffer + + +log = logging.getLogger(__name__) + + +class ROCmCPPScheduling(BaseScheduling): + """ + Partial Scheduling implementation for ROCm C++ Kernels. + This class is intended to be used in combination with TritonScheduling, + and delegated to by CUDACombinedScheduling. + + It handles fusion decisions and ROCm C++ specific template code generation. 
+ """ + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + @staticmethod + def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + return isinstance(node, SchedulerNode) and isinstance( + node.node, ROCmTemplateBuffer + ) + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + return False + + def define_kernel(self, src_code: str, node_schedule) -> str: + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_name = "_".join(["rocm", fused_name, wrapper.next_kernel_suffix()]) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + src_code = src_code.replace("KERNEL_NAME", kernel_name) + + _, _, kernel_path = get_path(code_hash(src_code), "py") + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline("async_compile.rocm(r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline( + f"''', 'so', aot_compile={str(V.graph.aot_mode)})" + ) + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + return kernel_name + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ): + """ + Codegen a ROCm template, possibly with fused epilogues + """ + assert self.is_rocm_cpp_template(template_node), ( + "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" + ) + template_node = cast(SchedulerNode, template_node) + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) + kernel, render = ctb.make_kernel_render(ctb) + with kernel: + template_node.mark_run() + src_code = render() + + with V.set_kernel_handler(kernel): + node_schedule = [template_node] + kernel_name = self.define_kernel(src_code, node_schedule) + kernel.call_kernel(kernel_name, ctb) + V.graph.removed_buffers |= kernel.removed_buffers + self.free_buffers_in_scheduler() diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..93b61a551621949aaed126c5641f485fe2aa8973 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -0,0 +1,289 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch._inductor.config as config +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch._inductor.utils import do_bench_using_profiling + +from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox +from ...virtualized import V +from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode +from ..cpp_utils import CppPrinter +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_template_buffer import ROCmTemplateBuffer +from 
.rocm_utils import DTYPE_TO_ROCM_TYPE
+
+
+if TYPE_CHECKING:
+    from torch._inductor.codegen.rocm.rocm_template import ArgInfo, ROCmTemplate
+
+log = logging.getLogger(__name__)
+
+cexpr = CppPrinter().doprint
+
+
+def _normalize_idx(index: int, total_length: int) -> int:
+    return index if index >= 0 else index + total_length
+
+
+class ROCmKernel(Kernel):
+    """
+    Baseclass for ROCm based Kernels
+    """
+
+    overrides = OpOverrides  # type: ignore[assignment]
+
+
+class ROCmTemplateKernel(ROCmKernel):
+    """
+    Template kernels defined by ROCm in C++.
+    """
+
+    _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, hipStream_t stream"
+
+    def __init__(
+        self,
+        kernel_name: str,
+        runtime_arg_info: list["ArgInfo"],
+        runtime_arg_values: list[Any],
+    ) -> None:
+        """
+        Initializes a new instance of the ROCmTemplateKernel class.
+
+        Args:
+            kernel_name (str): The name of the kernel.
+        """
+        super().__init__()
+        self.kernel_name = kernel_name
+        # Mapping from arg name to IRNode.
+        self.named_nodes: dict[str, IRNode] = {}
+        self.runtime_arg_info = runtime_arg_info
+        self.runtime_arg_values = runtime_arg_values
+
+    def get_signature(self):
+        return self.signature
+
+    def def_kernel(
+        self,
+        inputs: list[IRNode],
+        outputs: list[IRNode],
+        size_args: list[str],
+        names_str: str = "",
+        input_reorder: Optional[list[int]] = None,
+    ) -> str:
+        """
+        Hook called from template code to generate function definition and
+        needed args.
+
+        Args:
+            inputs: List of input IRNodes
+            outputs: List of output IRNodes
+            names_str: Comma separated list of input + output argument names.
+            input_reorder: The actual order of input nodes.
+                e.g. The template might have input argument defined as [X, W, Bias],
+                and the actual input passed into this template could be [Bias, X, W].
+                In this case, the `input_reorder` would be [2, 0, 1].
+        """
+        names = [x.strip() for x in names_str.strip().split(",")]
+        if len(inputs) + len(outputs) != len(names):
+            raise RuntimeError(
+                f"{len(inputs) + len(outputs)=} != {len(names)=}, {inputs=}, {outputs=}, {names=}"
+            )
+
+        if input_reorder is not None:
+            assert len(inputs) == len(input_reorder)
+        else:
+            input_reorder = list(range(len(inputs)))
+
+        for idx in input_reorder:
+            name = names[idx]
+            node = inputs[idx]
+            if node is not None:
+                self.named_nodes[name] = node
+                self.args.input_buffers[node.get_name()] = name
+
+        for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
+            if node is not None:
+                self.named_nodes[name] = node
+                self.args.output_buffers[node.get_name()] = name
+
+        arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE)
+
+        runtime_arg_defs = [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
+
+        signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args + runtime_arg_defs)},{self._EXTRA_CPP_ARGS})"
+        self.signature = signature
+        return signature
+
+    def call_kernel(
+        self,
+        name: str,
+        node: "ROCmTemplateBuffer",  # type: ignore[name-defined]
+    ) -> None:
+        """
+        Generates code to call the kernel through V.graph.wrapper_code.
+        used from within torch._inductor.wrapper.PythonWrapperCodegen
+
+        name: Name of kernel function.
+        node: The ROCmTemplateBuffer node which contains information about the kernel, its fused epilogue nodes
+            as well as all required inputs and outputs.
+ """ + wrapper = V.graph.wrapper_code + + arg_types: list[Any] + if V.graph.cpp_wrapper: + # Make sure we initialize these kernels since they're exported as + # C-style symbol names. + assert isinstance(wrapper, CppWrapperCpu) + wrapper.initialized_kernels[name] = self + # Kinda hacky because we always originally initialize name with "KERNEL_NAME" + # So, we replace with the real kernel name passed as an arg to this function. + self.signature = self.signature.replace("KERNEL_NAME", name) + _, call_args, arg_types = self.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE) + else: + _, call_args, _, arg_types = self.args.python_argdefs() + + kernel_args = [] + for arg in call_args: + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + if V.graph.is_unspec_arg(arg): + arg = arg + ".item()" + else: + if not V.graph.cpp_wrapper: + arg = f"c_void_p({arg}.data_ptr())" + kernel_args.append(arg) + + # add size args + size_args = [ + f"{V.graph.sizevars.simplify(sarg)}" for sarg in node.template.size_args() + ] + + if V.graph.cpp_wrapper: + kernel_args.extend(size_args) + else: + kernel_args.extend(f"c_int({sarg})" for sarg in size_args) + + if V.graph.cpp_wrapper: + arg_types.extend(["int"] * len(node.template.size_args())) + + # the runtime args come right after the size args + kernel_args.extend(self.runtime_arg_values) + for arg in self.runtime_arg_info: + arg_types.append(arg.ty) + + # workspace_size ptr is NULL to mark this call is not intended for retrieving workspace_size. + # workspace_size should have already been retrieved prior to this call. + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("size_t*") + + if node.get_workspace_size() > 0: + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), + ) + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" + kernel_args.append( + data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" + ) + else: + ws = None + kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") + if V.graph.cpp_wrapper: + arg_types.append("uint8_t*") + wrapper.generate_kernel_call( + name, + kernel_args, + triton=False, + arg_types=arg_types, + ) + if ws: + wrapper.generate_workspace_deallocation(ws) + + +class ROCmTemplateCaller(ChoiceCaller): + """ + ROCmTemplateCaller + + This class represents a caller for ROCm template kernels. It is a subclass of ChoiceCaller. + Attributes: + name (str): The name of the caller. + category (str): The category of the caller. + bmreq (ROCmBenchmarkRequest): The benchmark request for the caller. + template_buffer (ROCmTemplateBuffer): The template buffer for the caller. 
+ """ + + def __init__( + self, + name: str, + category: str, + input_nodes: list[Buffer], + layout: Layout, + make_kernel_render: Callable[ + [ROCmTemplateBuffer, Optional[Sequence[IRNode]]], str + ], + bmreq: ROCmBenchmarkRequest, + template: "ROCmTemplate", # type: ignore[name-defined] + info_kwargs: Optional[ + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] + ], # type: ignore[type-arg] + ) -> None: + super().__init__(name, input_nodes, layout, description="") + self.category = category + self.make_kernel_render = make_kernel_render + self.bmreq = bmreq + self.template = template + self.info_kwargs = info_kwargs + + def precompile(self) -> None: + assert self.bmreq is not None + self.bmreq.precompile() + + def benchmark(self, *args, out) -> float: + assert self.bmreq is not None + if config.profile_bandwidth_with_do_bench_using_profiling: + algo = self.bmreq.make_run_fn(*args, out=out) + return do_bench_using_profiling(algo) + return self.bmreq.benchmark(*args, out=out) + + def __str__(self) -> str: + return f"ROCmTemplateCaller(source_file={self.bmreq.source_file}, {self.info_dict()})" + + def call_name(self) -> str: + return f"rocm_template_kernels.{self.name}" + + def hash_key(self) -> str: + return "-".join( + [ + self.category, + self.bmreq.hash_key, + ] + ) + + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "ROCm", + "name": self.name, + **dict(self.info_kwargs["op"].dict_items()), # type: ignore[union-attr, index] + } + + def output_node(self) -> TensorBox: + self.bmreq.update_workspace_size() + return TensorBox.create( + ROCmTemplateBuffer( + layout=self.layout, + inputs=self.input_nodes, + make_kernel_render=self.make_kernel_render, + workspace_size=self.bmreq.workspace_size, + template=self.template, + ) + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template.py new file mode 100644 index 0000000000000000000000000000000000000000..06a61b4ed629f2b8c68da5c98ca565baa5dbebeb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template.py @@ -0,0 +1,192 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional +from unittest.mock import patch + +from ...autotune_process import TensorMeta +from ...ir import Buffer, IRNode, Layout +from ...utils import IndentedBuffer, unique +from ...virtualized import V +from ..common import KernelTemplate +from .rocm_benchmark_request import ROCmBenchmarkRequest +from .rocm_kernel import ROCmTemplateCaller, ROCmTemplateKernel +from .rocm_template_buffer import ROCmTemplateBuffer +from .rocm_utils import DTYPE_TO_ROCM_TYPE + + +log = logging.getLogger(__name__) + + +# FIXME: unify with the CUDA version +@dataclass(frozen=True) +class ArgInfo: + name: str + ty: str + + +class ROCmTemplate(KernelTemplate): + index_counter = itertools.count() + gfx9_threads_per_warp = 64 + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + input_reorder: Optional[list[int]] = None, + ) -> None: + """ + + Baseclass for ROCm C++ Templates, derived from KernelTemplate. Not to be instantiated directly. + + Args: + name (str): The name of the ROCmTemplate object. + input_nodes (List[IRNode]): A list of input IRNodes. 
+            layout (Layout): The layout of the output buffer / tensor.
+            input_reorder (Optional[List[int]]): An optional list that specifies the order of the input nodes.
+
+        """
+        super().__init__(name)
+        self.input_nodes = input_nodes
+        self.output_node: Buffer = Buffer(name="buf_out", layout=layout)
+        self.input_reorder = input_reorder
+        self.layout = layout
+
+    def generate(  # type: ignore[override]
+        self,
+        **kwargs,
+    ) -> ROCmTemplateCaller:
+        """
+        Generates the ROCm template caller object for the given GEMM template and operation. This ROCmTemplateCaller
+        may be used to call and benchmark the generated ROCm kernel in a standalone manner to enable Autotuning.
+
+        Args:
+            kwargs: Additional keyword arguments.
+
+        Returns:
+            A ROCmTemplateCaller object representing the generated ROCm template caller.
+        """
+        kernel_name = f"rocm_{self.name}"
+        kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            ROCmTemplateKernel(
+                kernel_name=kernel_name,
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
+            ) as kernel,
+        ):
+            code = self.render(kernel=kernel, **kwargs)
+            _, call_args, _, _ = kernel.args.python_argdefs()
+            log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)
+            log.debug(
+                "Args: cpp_argdefs: %s, python_argdefs: %s",
+                kernel.args.cpp_argdefs(DTYPE_TO_ROCM_TYPE),
+                kernel.args.python_argdefs(),
+            )
+
+        input_reorder = (
+            self.input_reorder
+            if self.input_reorder is not None
+            else list(range(len(self.input_nodes)))
+        )
+        expected_args = list(
+            unique(self.input_nodes[idx].get_name() for idx in input_reorder)
+        )
+        expected_args.extend([self.output_node.get_name()])
+        assert list(call_args)[: len(expected_args)] == expected_args, (
+            call_args,
+            expected_args,
+        )
+
+        size_args = (
+            self.size_args() if hasattr(self, "size_args") else ()
+        )  # subclass should define def size_args()
+        size_args_ints = [
+            V.graph.sizevars.size_hint(arg) for arg in size_args
+        ]  # resolve to ints for benchmarking
+        # The runtime args come right after the size args
+        runtime_args = self.get_runtime_arg_values(**kwargs)
+        extra_args = size_args_ints + runtime_args
+        bmreq = ROCmBenchmarkRequest(
+            kernel_name=kernel_name,
+            input_tensor_meta=TensorMeta.from_irnodes(self.input_nodes),
+            output_tensor_meta=TensorMeta.from_irnodes(self.output_node),
+            extra_args=extra_args,
+            source_code=code,
+        )
+
+        def make_kernel_render(
+            template_node: ROCmTemplateBuffer,
+            epilogue_nodes: Optional[Sequence[IRNode]] = None,
+        ):
+            kernel = ROCmTemplateKernel(
+                kernel_name="KERNEL_NAME",
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
+            )
+            render = functools.partial(
+                self.render,
+                kernel=kernel,
+                template_buffer_node=template_node,
+                epilogue_nodes=epilogue_nodes,
+                **kwargs,  # includes "op" argument in case of CUTLASSGemmTemplate
+            )
+            return kernel, render
+
+        return ROCmTemplateCaller(
+            kernel_hash_name,
+            self.name,
+            self.input_nodes,
+            self.output_node.get_layout(),
+            make_kernel_render,
+            bmreq,
+            self,
+            kwargs,
+        )
+
+    def header(self) -> IndentedBuffer:
+        res = IndentedBuffer()
+        res.splice(
+            """
+                #include <exception>
+                #include <iostream>
+                #include <memory>
+                #include <random>
+                #include <vector>
+            """
+        )
+        return res
+
+    def globals(self) -> IndentedBuffer:
+        res = IndentedBuffer()
+        res.splice(
+            """
+                // We compile all models with -fvisibility=hidden.
Any symbols that need to be + // exposed in the final shared library must be declared with PT_EXPORT to make + // them visible. + #ifdef __GNUC__ // Applies to any compiler with GNU extensions (clang and g++) + #define PT_EXPORT __attribute__((__visibility__("default"))) + #else + #ifdef _WIN32 + #define PT_EXPORT __declspec(dllexport) + #else + #define PT_EXPORT + #endif + #endif + """ + ) + return res + + def render(self, **kwargs) -> str: + raise NotImplementedError + + def get_runtime_arg_info(self) -> list[ArgInfo]: + return [] + + def get_runtime_arg_values(self, **kwargs) -> list[Any]: + return [] diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf76937c0b4bcb7e13754ad5c5d9b2e51be339d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_template_buffer.py @@ -0,0 +1,27 @@ +from collections.abc import Sequence +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +from ...ir import Buffer, Layout, TemplateBuffer + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +class ROCmTemplateBuffer(TemplateBuffer): + def __init__( + self, + layout: Layout, + inputs: Sequence[Buffer], + make_kernel_render: Callable[_P, _T], + workspace_size: int, + template: "ROCmTemplate", # type: ignore[name-defined] # noqa: F821 + ) -> None: + super().__init__(layout, inputs, make_kernel_render) + # Global memory (in bytes) needed for this template. + self.workspace_size = workspace_size + self.template = template + + def get_workspace_size(self) -> int: + return self.workspace_size if self.workspace_size is not None else 0 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a24d67eb734d5ee72dcc4dfc8929eaa584f428a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/rocm/rocm_utils.py @@ -0,0 +1,15 @@ +# mypy: allow-untyped-defs + + +import torch + +from ..cpp_utils import DTYPE_TO_CPP + + +DTYPE_TO_ROCM_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "uint16_t", + torch.float8_e4m3fnuz: "uint8_t", + torch.float8_e5m2fnuz: "uint8_t", + torch.bfloat16: "uint16_t", +} diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/simd.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/simd.py new file mode 100644 index 0000000000000000000000000000000000000000..23e42431f770ed6a6ef3411fb3ed84d55d3f2b2f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/simd.py @@ -0,0 +1,2462 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import textwrap +from collections import Counter +from typing import Any, Callable, Generic, no_type_check, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeVar + +import sympy + +import torch +import torch._logging +from torch._inductor.tiling_utils import analyze_memory_coalescing +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols +from torch.fx.immutable_collections import immutable_dict +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing +from torch.utils._sympy.symbol 
import ( + free_symbol_is_type, + prefix_str, + symbol_is_type, + SymT, +) + +from ..._dynamo.utils import counters +from .. import config, ir, scheduler +from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask +from ..codecache import code_hash +from ..dependencies import MemoryDep, StarDep, WeakDep + + +if TYPE_CHECKING: + from ..ir import IRNode + +from ..optimize_indexing import indexing_dtype_strength_reduction +from ..runtime.runtime_utils import green_text, yellow_text +from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse +from ..utils import ( + cache_on_self, + expr_fits_within_32bit, + get_dtype_size, + IndentedBuffer, + Placeholder, + prefix_is_reduction, + set_kernel_post_grad_provenance_tracing, + sympy_index_symbol, + sympy_product, + sympy_subs, + unique, +) +from ..virtualized import ops, OpsWrapper, V +from .block_analysis import BlockPatternMatcher +from .common import CSEVariable, index_prevent_reordering, Kernel, PythonPrinter +from .multi_kernel import MultiKernel +from .simd_kernel_features import ( + DisableReduction, + EnableReduction, + NodeScheduleEntry, + NodeScheduleMarker, + SIMDKernelFeatures, +) + + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") + + +pexpr = PythonPrinter().doprint + +all_prefixes = OrderedSet(["z", "y", "x", "r0_", "r1_"]) + + +def get_max_tiles(default: int = 2) -> int: + max_tiles = torch._inductor.config.triton.max_tiles + return max_tiles if max_tiles is not None else default + + +@dataclasses.dataclass +class IterationRanges: + """ + Each range tree represents multiple sets of iteration indexing + in a single tiled dimension in the output kernel. + + If you have two loops ranges one (4, 3, 2) and another (4, 6), + then the range tree will be: + 4 (i0) + 3 (i1) 6 (i3) + 2 (i2) + Where i0 is shared between both loops, but then the split into + different indexing vars. All loop ranges must iterate over + the same number of elements. + """ + + def __init__( + self, + name: str, + var_list: list[sympy.Symbol], + var_ranges: dict[sympy.Symbol, sympy.Expr], + numel: sympy.Expr, + prefix: str, + *, + kernel: SIMDKernel, + divisor=sympy.S.One, + length=sympy.S.One, + root: IterationRangesRoot, + ) -> None: + super().__init__() + self.name = name + self.var_list = var_list + self.var_ranges = var_ranges + self.numel = numel + self.prefix = prefix + self.divisor = divisor + self.length = length + self.kernel = kernel + self.root = root + + @property + @cache_on_self + @no_type_check # https://github.com/python/mypy/issues/17184 + def is_reduction(self) -> bool: + return prefix_is_reduction(self.prefix) + + def symbol(self) -> sympy.Symbol: + return sympy_index_symbol(self.name) + + @property + @cache_on_self + @no_type_check + def symt(self) -> SymT: + prefix_to_symt = {prefix: symt for symt, prefix in prefix_str.items()} + return prefix_to_symt[self.prefix] + + +class IterationRangesRoot(IterationRanges): + """ + Root of a iteration range tree that represents a single + tiled dimension in the output kernel. It contains multiple + sets of iteration represented with IterationRangesEntry. 
+ """ + + def __init__( + self, + name: str, + numel: sympy.Expr, + prefix: str, + index: int, + kernel: SIMDKernel, + pid_cache: Optional[dict[str, str]] = None, + *, + is_loop: bool, + tensor_dim: Optional[int], + grid_dim: Optional[int], + has_zdim: bool, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__( + name=name, + var_list=[], + var_ranges={}, + numel=numel, + prefix=prefix, + kernel=kernel, + root=self, + ) + self.index = index + # Store all the nodes in one flat list + self.nodes: dict[sympy.Expr, IterationRangesEntry] = {} + # This is for re-ordering program ID in triton mm template + # pid_cache["tl.program_id(0)"] = pid_m + self.pid_cache: dict[str, str] = pid_cache + + # True if the dimension is implemented as a single program looping over + # the full dimension (currently only used for non-persistent reduction) + assert not is_loop or (self.is_reduction and grid_dim is None) + self.is_loop = is_loop + # Index of corresponding dimension on triton tensors + self.tensor_dim = tensor_dim + # Index of corresponding dimension in the triton grid + self.grid_dim = grid_dim + self.has_zdim = has_zdim + + def __repr__(self) -> str: + return f"IterationRangesRoot({self.name!r}, {self.numel}, ...)" + + def cache_clear(self) -> None: + for node in self.nodes.values(): + node.cache_clear() + + def index_sym(self) -> sympy.Symbol: + return sympy_index_symbol(f"{self.prefix}index") + + def lookup(self, divisor: sympy.Expr, length: sympy.Expr) -> IterationRangesEntry: + """ + Lookup a given RangeTreeEntry, creating it if needed + """ + if V.graph.sizevars.statically_known_equals(divisor * length, self.numel): + expr = FloorDiv(self.index_sym(), divisor) + else: + expr = ModularIndexing(self.index_sym(), divisor, length) + + if expr not in self.nodes: + node = IterationRangesEntry( + f"{self.prefix}{next(V.kernel.iter_vars_count)}", + divisor, + length, + expr, + self, + ) + V.kernel.range_tree_nodes[node.symbol()] = node + self.var_list.append(node.symbol()) + self.var_ranges[node.symbol()] = length + self.nodes[expr] = node + return self.nodes[expr] + + def construct_entries( + self, lengths: list[sympy.Expr] + ) -> list[IterationRangesEntry]: + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + itervars.append(self.lookup(divisor, length)) + divisor = divisor * length + return [*reversed(itervars)] + + def construct(self, lengths: list[sympy.Expr]) -> list[sympy.Symbol]: + return [e.symbol() for e in self.construct_entries(lengths)] + + def vars_and_sizes( + self, index: sympy.Expr + ) -> tuple[list[sympy.Symbol], list[sympy.Expr]]: + """Figure out vars from this tree used in index""" + + def get_sort_key(x: IterationRangesEntry) -> tuple[int, bool]: + """ + Gets the key for sorting nodes. When two nodes have the + same divisor, the node with length as 1 should be handled + first so the current divisor is not changed after multiplied + node.length. Returns `not length_is_one_hint` for ascending + sort. 
+ """ + divisor_hint = V.graph.sizevars.size_hint( + x.divisor, fallback=config.unbacked_symint_fallback + ) + length_is_one_hint = ( + V.graph.sizevars.size_hint( + x.length, fallback=config.unbacked_symint_fallback + ) + == 1 + ) + return (divisor_hint, not length_is_one_hint) + + nodes = [V.kernel.range_tree_nodes.get(s) for s in index.free_symbols] + nodes = [n for n in nodes if n and n.prefix == self.prefix] + nodes.sort(key=lambda x: get_sort_key(x)) + divisor = sympy.S.One + index_vars = [] + sizes = [] + + def add(node): + nonlocal divisor + index_vars.append(node.symbol()) + sizes.append(node.length) + divisor = divisor * node.length + + for node in nodes: + if not V.graph.sizevars.statically_known_equals(node.divisor, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(node.divisor, divisor))) + divisor = node.divisor + add(node) + if not V.graph.sizevars.statically_known_equals(self.numel, divisor): + # fill in unused index var + add(self.lookup(divisor, FloorDiv(self.numel, divisor))) + + return [*reversed(index_vars)], [*reversed(sizes)] + + +class IterationRangesEntry(IterationRanges): + def __init__( + self, + name: str, + divisor: sympy.Expr, + length: sympy.Expr, + expr: sympy.Expr, + parent: IterationRanges, + ) -> None: + super().__init__( + name=name, + numel=parent.numel / length, + var_list=parent.var_list, + var_ranges=parent.var_ranges, + prefix=parent.prefix, + divisor=divisor, + length=length, + kernel=parent.kernel, + root=parent.root, + ) + self.parent = parent + self.codegen = functools.lru_cache(None)(self._codegen) + self.expr = expr + + def __repr__(self) -> str: + return f"IterationRangesEntry({self.name}, {self.divisor}, {self.length}, {self.expr}, {self.var_ranges})" + + def set_name(self, name: str) -> None: + self.codegen = lambda: name # type: ignore[assignment] + self.codegen.cache_clear = lambda: None # type: ignore[method-assign] + self.name = name + + def cache_clear(self) -> None: + self.codegen.cache_clear() + + def _codegen(self) -> str: + V.kernel.codegen_iteration_ranges_entry(self) + return self.name + + def precomputed_args(self) -> list[sympy.Expr]: + # for dynamic shapes, find parts of indexing expressions that have to be precomputed + precomputed_args: list[sympy.Expr] = [] + if isinstance(self.expr, sympy.Symbol): + return precomputed_args + assert isinstance(self.expr, (FloorDiv, ModularIndexing)), type(self.expr) + for arg in self.expr.args[1:]: + if not isinstance(arg, (sympy.Integer, sympy.Symbol)): + symbols = arg.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, SymT.SIZE) for s in symbols + ): + precomputed_args.append(arg) + return precomputed_args + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, IterationRangesEntry) + return self.name == other.name + + +def constant_repr(value: Union[int, float]) -> str: + if value == float("inf"): + return 'float("inf")' + elif value == float("-inf"): + return 'float("-inf")' + elif math.isnan(value): + return 'float("nan")' + return repr(value) + + +CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable) + + +class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]): + """ + Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests. 
+ """ + + sexpr: Callable[[sympy.Expr], str] = pexpr + kexpr: Callable[[sympy.Expr], str] + allow_block_ptr: bool = False + kernel_name: str + + def __init__( + self, + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + pid_cache: Optional[dict[str, str]] = None, + override_persistent_reduction: Optional[bool] = None, + override_cooperative_reduction: Optional[bool] = None, + tiling_scores: Optional[dict[str, sympy.Expr]] = None, + ) -> None: + if pid_cache is None: + pid_cache = {} + super().__init__() + self.features = features + self.mutations = features.get_mutations() + self.body = IndentedBuffer() + self.indexing_code = IndentedBuffer() + self.numels = { + prefix: V.graph.sizevars.simplify(val) for prefix, val in tiling.items() + } + self.range_trees: list[IterationRangesRoot] = [] + self.range_tree_nodes: dict[sympy.Symbol, IterationRangesEntry] = {} + self.iter_vars_count = itertools.count() + self.inside_reduction = features.is_reduction() + self.cooperative_reduction: bool = ( + override_cooperative_reduction + if override_cooperative_reduction is not None + else self.should_use_cooperative_reduction() + ) + self.tiling_scores: Optional[dict[str, sympy.Expr]] = tiling_scores + self.persistent_reduction: bool = ( + override_persistent_reduction + if override_persistent_reduction is not None + else self.should_use_persistent_reduction() + ) + self.no_x_dim = self.want_no_x_dim() + self.code_hash: Optional[str] = None + + # define this in a closure to make cache local to object + @functools.cache + def simplify_indexing(index: sympy.Expr): + index = V.graph.sizevars.simplify_with_ranges(index, self.var_ranges()) + for tree in self.range_trees: + index = self.combine_contiguous_dims(index, tree) + + return self.combine_modular_indexing_pairs(index) + + self.simplify_indexing = simplify_indexing + self.initialize_range_tree(pid_cache) + + @property + @cache_on_self + @no_type_check # https://github.com/python/mypy/issues/17184 + def num_reduction_dims(self) -> int: + return sum(prefix_is_reduction(prefix) for prefix in self.numels) + + def dtype_to_str(self, dtype: torch.dtype) -> str: + raise NotImplementedError + + def get_index_dtype_as_torch_dtype(self) -> torch.dtype: + return self.features.select_index_dtype() + + @property + def index_dtype(self) -> str: + return self.dtype_to_str(self.get_index_dtype_as_torch_dtype()) + + def want_no_x_dim(self) -> bool: + return False + + def construct_range_trees( + self, + pid_cache: Optional[dict[str, str]], + inside_reduction: bool, + is_reduction: bool, + numels: dict[str, sympy.Expr], + no_x_dim: bool, + ) -> list[IterationRangesRoot]: + active_prefixes = OrderedSet( + prefix for prefix in all_prefixes if prefix in numels + ) + no_r_dim = not inside_reduction or not is_reduction + + def filtered_index_map(seq, mask) -> dict[Any, int]: + return { + val: idx for idx, val in enumerate(val for val in seq if val in mask) + } + + grid_dims = ["x", "y", "z"] + pointwise_tensor_dims = list(reversed(grid_dims)) + reduction_dims = ["r0_", "r1_"] + if no_x_dim: + tensor_dims = reduction_dims + elif no_r_dim: + tensor_dims = pointwise_tensor_dims + else: + tensor_dims = pointwise_tensor_dims + reduction_dims + + # Filter out unused tensor dims. + # Convert to dicts for O(1) index lookup. 
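+        # Note (editor comment): tensor_dim_map maps an iteration prefix ("x", "y",
+        # "z", "r0_", "r1_") to that dimension's position in the generated kernel
+        # tensor, keeping only prefixes that actually appear in `numels`;
+        # grid_dim_map maps the pointwise prefixes to their launch-grid axis.
+        # Prefixes filtered out here end up with tensor_dim / grid_dim of None below.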
+ tensor_dim_map = filtered_index_map(tensor_dims, active_prefixes) + grid_dim_map = filtered_index_map(grid_dims, all_prefixes) + + range_trees = [] + for i, prefix in enumerate(active_prefixes): + is_reduction = prefix_is_reduction(prefix) + tensor_dim = tensor_dim_map.get(prefix) + grid_dim = grid_dim_map.get(prefix) + index = i if grid_dim is None else grid_dim + range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numels[prefix], + prefix, + index, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=is_reduction and not self.persistent_reduction, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim="z" in numels, + ) + ) + return range_trees + + def initialize_range_tree(self, pid_cache: dict[str, str]) -> None: + range_trees = self.construct_range_trees( + pid_cache, + self.inside_reduction, + self.features.is_reduction(), + self.numels, + self.no_x_dim, + ) + self.range_trees.extend(range_trees) + + def finalize_indexing(self, indices: Sequence[sympy.Expr]) -> None: + """ + Hook called right before codegen with every index that will be + used in the fused kernel. + """ + + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> None: + prior = self.inside_reduction + self.inside_reduction = False + try: + return self.store(name, index, value) + finally: + self.inside_reduction = prior + + def should_use_cooperative_reduction(self) -> bool: + return False # defined in subclass + + def should_use_persistent_reduction(self) -> bool: + return False # defined in subclass + + def var_ranges(self) -> dict[sympy.Symbol, sympy.Expr]: + return dict( + itertools.chain.from_iterable( + tree.var_ranges.items() for tree in self.range_trees + ) + ) + + def triton_tensor_ndim(self) -> int: + return sum(int(tree.tensor_dim is not None) for tree in self.range_trees) + + def indexing_size_str(self, i: int) -> str: + sizes = ["None"] * self.triton_tensor_ndim() + sizes[i] = ":" + return f"[{', '.join(sizes)}]" + + def dense_size_list(self) -> list[str]: + sizes = ["1"] * self.triton_tensor_ndim() + for tree in self.range_trees: + if tree.tensor_dim is None: + continue + + if not tree.is_reduction or self.inside_reduction: + sizes[tree.tensor_dim] = f"{tree.prefix.upper()}BLOCK" + return sizes + + def dense_size_str(self) -> str: + sizes = self.dense_size_list() + return f"[{', '.join(sizes)}]" + + def combine_modular_indexing_pairs(self, index: sympy.Expr) -> sympy.Expr: + if not isinstance(index, ModularIndexing): + return index + x = index.args[0] + if (tree_node := self.range_tree_nodes.get(x)) is None: + return index + new_index = sympy_subs(index, {x: tree_node.expr}) + new_index = V.graph.sizevars.combine_modular_indexing_pairs(new_index) + # the index now contains xindex/etc, which is nonstandard, fix it up + return sympy_subs( + new_index, + { + tree_node.root.index_sym(): tree_node.root.lookup( + sympy.S.One, tree_node.root.numel + ).symbol() + }, + ) + + def combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + if expand_res := V.graph.sizevars.expand_floor_div(index): + new_index, denominator = expand_res # type: ignore[misc] + return FloorDiv(self._combine_contiguous_dims(new_index, tree), denominator) + else: + return self._combine_contiguous_dims(index, tree) + + def _combine_contiguous_dims( + self, index: sympy.Expr, tree: IterationRangesRoot + ) -> sympy.Expr: + """ + More aggressive simplification to merge contiguous dims + """ + if isinstance(index, (sympy.Integer, sympy.Symbol)): + 
return index + index_vars, sizes = tree.vars_and_sizes(index) + if len(sizes) <= 1: + return index + new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops( + index_vars, sizes, index_prevent_reordering([index], index_vars, sizes) + ) + if new_sizes == sizes: + return index + new_index_vars = tree.construct(new_sizes) + new_index = sympy_subs(index, dict(zip(index_vars, reindex(new_index_vars)))) + return new_index + + def disable_reduction(self) -> contextlib.AbstractContextManager[None]: + should_flush = self.range_trees[-1].is_loop or self.cooperative_reduction + + @contextlib.contextmanager + def ctx(): + if not self.features.is_reduction(): + assert not self.inside_reduction + yield + return + if should_flush: + # calling codegen_body() will flush all the pending buffers + # and write out a reduction loop + self.codegen_body() + self.inside_reduction = False + try: + yield + if should_flush: + # flush out any code before opening the next loop + self.codegen_body() + finally: + self.inside_reduction = True + + return ctx() + + def set_ranges(self, *lengths: sympy.Expr) -> list[sympy.Symbol]: + assert len(lengths) == len(self.range_trees) + return [ + ranges.construct(length) + for length, ranges in zip(lengths, self.range_trees) + ] + + @staticmethod + def _split_iteration_ranges( + groups: Iterable[sympy.Expr], lengths: Sequence[Sequence[sympy.Expr]] + ) -> tuple[ + list[list[sympy.Expr]], list[list[Callable[[list[sympy.Expr]], sympy.Expr]]] + ]: + # Special case: if a node's sizes are ([], []), there's nothing to split. + if all(len(length) == 0 for length in lengths): + return [[] for group in groups], [] + + sv = V.graph.sizevars + new_ranges: list[list[sympy.Expr]] = [[] for _ in groups] + remaining = [sv.simplify(g) for g in groups] + var_count = itertools.count() + + def add_range(i: int, expr: sympy.Expr) -> int: + expr = sv.simplify(expr) + if not sv.statically_known_multiple_of(remaining[i], expr): + raise CantSplit + # guard on the last item out + remaining[i] = FloorDiv(remaining[i], expr) + new_ranges[i].append(expr) + return next(var_count) + + def make_combined( + size: sympy.Expr, idx1: int, idx2: int + ) -> Callable[[list[sympy.Expr]], sympy.Expr]: + def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: + return size * flat_vars[idx1] + flat_vars[idx2] + + return getter + + return_getters_groups = [] + current_group = 0 + for length_group in lengths: + return_getters = [] + for size in length_group: + if sv.statically_known_equals(size, 1): # type: ignore[arg-type] + return_getters.append(lambda _: sympy.S.Zero) + continue + + while current_group < len(remaining) and sv.statically_known_equals( + remaining[current_group], + 1, # type: ignore[arg-type] + ): + # scroll to next group with remaining elements + current_group += 1 + + if current_group + 1 < len(remaining) and sv.statically_known_gt( + size, remaining[current_group] + ): + # need to break size in two + if not sv.statically_known_multiple_of( + size, remaining[current_group] + ): + raise CantSplit + + size1 = remaining[current_group] + size2 = FloorDiv(size, remaining[current_group]) + return_getters.append( + make_combined( + size2, + add_range(current_group, size1), + add_range(current_group + 1, size2), + ) + ) + else: + if current_group < len(remaining): + return_getters.append( + operator.itemgetter(add_range(current_group, size)) + ) + return_getters_groups.append(return_getters) + + assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), ( + f"failed to set ranges {remaining} 
{lengths}" + ) + + return new_ranges, return_getters_groups + + @classmethod + def prepare_split_iteration_lengths( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> Sequence[Sequence[sympy.Expr]]: + "Fill in the reduction numel of lengths if missing" + sizevars = V.graph.sizevars + if len(lengths[1]) == 0 and ( + not sizevars.statically_known_equals(reduction_numel, sympy.S.One) + and sizevars.statically_known_equals( + sympy_product(groups), + sympy_product(lengths[0]) * reduction_numel, + ) + ): + return (lengths[0], [reduction_numel]) + + return lengths + + @classmethod + def is_compatible( + cls, + groups: Iterable[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + reduction_numel: sympy.Expr = sympy.S.One, + ) -> bool: + lengths = cls.prepare_split_iteration_lengths(groups, lengths, reduction_numel) + + try: + cls._split_iteration_ranges(groups, lengths) + return True + except CantSplit: + return False + + def split_and_set_ranges( + self, lengths: Sequence[Sequence[sympy.Expr]] + ) -> list[list[sympy.Expr]]: + tiling = {rt.prefix: rt.numel for rt in self.range_trees} + if not self.inside_reduction: + for prefix in tiling: + if prefix_is_reduction(prefix): + tiling[prefix] = sympy.S.One + + groups = [*tiling.values()] + return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges) + + @classmethod + def map_kernel_groups_to_node_sizes( + cls, + groups: Sequence[sympy.Expr], + lengths: Sequence[Sequence[sympy.Expr]], + set_ranges, + ) -> list[list[sympy.Expr]]: + """ + We may want to fuse `for i0 in s0*s1` into a tiled kernel with groups (s0, s1). + + To do this we need to split up the iteration space of i0 into something like: + for i1 in s0: + for i2 in s1: + i0 = i1*s1 + i2 + .... + + This function matches and resplits lengths to the groups of + this kernel to enable tiled + non-tiled fusions. + """ + if len(lengths) == len(groups) and all( + V.graph.sizevars.simplify(sympy_product(x) - g) == 0 + for x, g in zip(lengths, groups) + ): + return set_ranges(*lengths) + + new_ranges, return_getters_groups = cls._split_iteration_ranges(groups, lengths) + itervars = [*itertools.chain.from_iterable(set_ranges(*new_ranges))] + return [[fn(itervars) for fn in fns] for fns in return_getters_groups] + + def is_indirect_indexing(self, index: sympy.Expr) -> bool: + # tmpX means indirect indexing + return free_symbol_is_type(index, SymT.TMP) + + def is_broadcasted(self, index: sympy.Expr) -> bool: + # Note. This may not be correct when there is indirect indexing + if self.is_indirect_indexing(index): + return False + + index_numels = [1] * len(self.numels) + for symbol in index.free_symbols: + if symbol not in self.range_tree_nodes: + # Non-iterated variables, e.g. strides + continue + entry = self.range_tree_nodes[symbol] # type: ignore[index] + assert isinstance(entry.parent, IterationRangesRoot) + index_numels[entry.parent.index] *= entry.length + + # If the index variables only iterate over a subset of the kernel + # numels, then it must be broadcasted. + simplify = V.graph.sizevars.simplify + return any( + simplify(idx_range) != simplify(iter_range) # type: ignore[arg-type] + for idx_range, iter_range in zip(index_numels, self.numels.values()) + ) + + def index_to_str(self, index: sympy.Expr) -> str: + """ + Convert an index expr to a string that can be used in output code. + e.g. a sympy expression "s2" may actually appear as "ks1" in the generated kernel. 
+ + Index expressions often need to be passed in as arguments to the triton kernel. + Rename_indexing and codegen_indexing keep track of the needed indices and add + new parameters to the function signature. + """ + if isinstance(index, list): + return f"[{', '.join(map(self.index_to_str, index))}]" + return self.kexpr(self.rename_indexing(index)) # type: ignore[call-arg] + + def prepare_indexing( + self, + index: sympy.Expr, + ) -> sympy.Expr: + index = self.simplify_indexing(index) + index = sympy_subs(index, V.graph.sizevars.precomputed_replacements) + # if simple replacements didn't get rid of floor/ceil, try full subs + if len(index.atoms(sympy.floor)) or len(index.atoms(sympy.ceiling)): + index = index.subs(V.graph.sizevars.precomputed_replacements) + # last resort, if no range vars are in the expr, hoist it + # TODO instead of trying to blindly find complicated exprs, we should hoist the + # inputs/outputs sizes and strides, but at the time indexing is generated + # kernel inputs and outputs are not set yet, we'd need a deeper refactor + # to do it this way + + if len(index.atoms(sympy.ceiling)): + for a in index.atoms(sympy.ceiling): + # for nested exprs, atoms yields top level first (?) + # so if everything goes fine, lower level replacements will come up empty + symbols = a.free_symbols + if len(symbols) > 0 and all( + symbol_is_type(s, (SymT.SIZE, SymT.PRECOMPUTED_SIZE)) + for s in symbols + ): + replacements = {a: V.graph.sizevars.lookup_precomputed_size(a)} + index = sympy_subs(index, replacements) + + simp_index = self.simplify_indexing(index) + + # Now that we are done simplifying we can unwrap Identity so that downstream handling + # for its contained expression will work. previously, tl.full wrapping of sympy.Integer + # would not occur + simp_index = ( + simp_index if not isinstance(simp_index, Identity) else simp_index.args[0] + ) + + return self.codegen_indexing(simp_index) + + def active_range_trees(self) -> list[IterationRangesRoot]: + return [ + t for t in self.range_trees if not t.is_reduction or self.inside_reduction + ] + + def codegen_indexing(self, expr: sympy.Expr) -> sympy.Expr: + expr = V.graph.sizevars.simplify_with_ranges(expr, self.var_ranges()) + for sym in sorted(expr.free_symbols, key=str): + if sym in self.range_tree_nodes: + # if indexing expression is complicated, we precompute it on the host side + # and send the result as a kernel argument + replacements = {} + for ps in self.range_tree_nodes[sym].precomputed_args(): # type: ignore[index] + replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps) + if len(replacements) > 0: + self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index] + self.range_tree_nodes[sym].expr, + replacements, # type: ignore[index] + ) + self.range_tree_nodes[sym].codegen() # type: ignore[index] + return expr + + def codegen_nan_check(self) -> None: + raise NotImplementedError("NYI: codegen_nan_check") + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: + raise NotImplementedError("NYI: call_kernel") + + @contextlib.contextmanager + def mask_loads( + self, mask: Union[str, OpsWrapper], value: Union[int, float] + ) -> Iterator[str]: + """Context manager to add an additional mask to tl.load/store""" + prior = self._load_mask + prior_val = self._load_other + if prior: + mask = ops.logical_and(mask, prior) + + mask = OpsWrapper._unwrap(mask) + self._load_mask = mask + self._load_other = value + try: + # TODO(jansel): do we need a reshape here? 
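+            # Note (editor comment): the combined mask and fill value are only active
+            # while this context is open; the previous values are restored in the
+            # finally block below once the caller leaves the `with` body.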
+ yield mask + finally: + self._load_mask = prior + self._load_other = prior_val + + def get_strides_of_load(self, index: sympy.Expr) -> dict[sympy.Symbol, sympy.Expr]: + """ + This gets the stride of the index for each of the tiling variables + (technically, it does it at index 0) + + For example, if + xindex = x0 + 512*x1 + 1024*r0 + x0 = (xindex//512) + x1 = (xindex % 512) + r0 = rindex // 1024 + + this function would return + {xindex: 512, rindex: 1024} + """ + index_to_tile_indexes = {k: v.expr for k, v in self.range_tree_nodes.items()} + index_in_tile_vars = sympy_subs(index, index_to_tile_indexes) # type: ignore[arg-type] + strides = {} + for range_tree in self.range_trees: + s = sympy_index_symbol(range_tree.name) + strides[s] = sympy_subs(index_in_tile_vars, {s: 1}) - sympy_subs( + index_in_tile_vars, {s: 0} + ) + return strides + + @staticmethod + def _map_tuple_or_scalar(fn, value): + if isinstance(value, tuple): + return tuple(map(fn, value)) + return fn(value) + + def estimate_kernel_num_bytes(self): + """ + Try the best to estimate the total size (in bytes) of the + kernel's inputs and outputs, which is used for estimating the memory + throughput of this kernel. This information is used for checking how + far we are from the peak memory bandwidth. It's important that + we want to avoid overestimating the sizes of the inputs and outputs, + because it can wrongfully give us a very large memory traffic value, + which may be even larger than the theoretical bandwidth and thus + become very misleading. This is particularly problematic for cases + where we slice some inputs. In those cases, we should only count + the size of the "slices" instead of the original inputs, because + only the slices contribute to the real memory traffic. + """ + nbytes = [] + ninplace_args = len(unique(self.args.inplace_buffers.values())) + _, call_args, _, _ = self.args.python_argdefs() + buf_accesses = self.features.buf_accesses() + + # For pointwise and reduction kernels, this is the upper-bound numels + # for the output buffer. + # FIXME: This is not exactly right for cases like below: + # def foo(tensor0, tensor1): + # x0 = narrow(tensor0) + # return cat(x0, tensor1) + # For this example, we will end up overestimate the size for the + # slice s0. Potentially, we could have precise inputs information + # if we maintained the original inputs of the Pointwise kernel created + # for the "cat". However, I think it might be a bit overwhelming that + # we add such complexity only for handling some particular cases for + # benchmarking. + out_numel = V.graph.sizevars.size_hint(sympy_product(self.numels.values())) + for i, arg in enumerate(call_args): + # "buf" may be narrowed. In this case, the number of memory accesses + # should be estimated based on the reinterpreted layout. + # On the other hand, buf may be broadcasted. In this case, + # counting the size of the underline storage would give us + # a better estimation in terms of memory accesses. + if arg not in buf_accesses: + nbytes.append(0) + continue + arg_numel = V.graph.get_numel(arg) + buf_size = V.graph.sizevars.size_hint(arg_numel) + if buf_size > out_numel: + # This arg points to a buf that has been sliced. + # We need to count each individual slice to have + # a better estimation. 
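+                # Note (editor comment): approximate the traffic for a sliced buffer
+                # by counting distinct indexing expressions into it; each distinct
+                # index is assumed to touch at most out_numel elements, and
+                # StarDep/WeakDep accesses are counted as separate placeholder slices.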
+ indices = OrderedSet[Any]() + no_index_dep_count = 0 + for dep in buf_accesses[arg]: + if isinstance(dep, (StarDep, WeakDep)): + indices.add(f"no_index_dep_{no_index_dep_count}") + no_index_dep_count += 1 + else: + indices.add(dep.index) + numel = len(indices) * out_numel + else: + numel = buf_size + dtype = V.graph.get_dtype(arg) + dtype_size = get_dtype_size(dtype) + nbytes.append(numel * dtype_size * (1 + int(i < ninplace_args))) + return sum(nbytes) + + def warn_mix_layout(self, kernel_name): + """ + Print message if the kernel have mixed layout inputs. + Only care about 4D tensor for now. + """ + if ( + len(self.args.input_buffers) == 1 + and len(self.args.output_buffers) == 1 + and len(self.args.inplace_buffers) == 0 + ): + # even if input buffer and output buffer have different layout, + # this can be a layout conversion kernel. No need to warn for + # the mix layouts. + return + + argdefs, call_args, _signature, _ = self.args.python_argdefs() + uniform_stride_order = None + for arg_name in call_args: + buf = V.graph.try_get_buffer(arg_name) + if not buf: + continue + layout = buf.get_layout() + if len(layout.size) == 4: + # ignore the tensor if only 1 dimension is non-zero + if len([x for x in layout.size if x == 1]) == 3: + continue + stride_order = ir.get_stride_order(layout.stride) + if uniform_stride_order is None: + uniform_stride_order = stride_order + elif uniform_stride_order != stride_order: + msg = yellow_text( + f"Expected stride order {uniform_stride_order}, but found stride order" + + f" {stride_order} for kernel {kernel_name}" + ) + log.warning(msg) + + stride_order_list = [ + ir.get_stride_order( + V.graph.get_buffer(name).get_layout().stride + ) + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + size_list = [ + V.graph.get_buffer(name).get_layout().size + if V.graph.try_get_buffer(name) + else None + for name in call_args + ] + source_list = [ + "GraphInput" + if name in V.graph.graph_inputs + else "IntermediateBuffer" + if name in V.graph.name_to_buffer + else None + for name in call_args + ] + + argdef_names = [x.name for x in argdefs] + msg = yellow_text( + f" param names {argdef_names}\n buf names {call_args}\n strides {stride_order_list}" + + f"\n sizes {size_list}\n sources {source_list}\n" + ) + log.warning(msg) + return + msg = green_text( + f"All the inputs for the triton kernel {kernel_name} have uniform layout" + ) + log.warning(msg) + + def welford_reduce_fallback(self, dtype, value): + sum_ = ops.reduction(dtype, dtype, "sum", value) + self.inside_reduction = False + rnumel = ops.index_expr(self.features.reduction_numel, dtype) + mean = ops.truediv(sum_, rnumel) + + self.inside_reduction = True + dx = ops.sub(value, mean) + dx2 = ops.mul(dx, dx) + m2 = ops.reduction(dtype, dtype, "sum", dx2) + return OpsWrapper._unwrap((mean, m2, rnumel)) + + def prepare_softmax_twopass_fallback(self, dtype, value): + vmax = ops.reduction(dtype, dtype, "max", value) + sub = ops.sub(value, vmax) + exp = ops.exp(sub) + vsum = ops.reduction(dtype, dtype, "sum", exp) + return OpsWrapper._unwrap((vmax, vsum)) + + def codegen_kernel(self): + raise NotImplementedError + + def codegen_body(self): + pass + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + pass + + +class SIMDScheduling(BaseScheduling): + """ + Single Instruction Multiple Data parent class used for fusion across + multiple different backends. 
+ """ + + kernel_type: type[Any] = SIMDKernel # override in subclass + + def group_fn(self, sizes): + return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) + + def can_fuse(self, node1, node2): + """ + Hook called by Scheduler to determine if the Triton backend + can fuse node1 and node2. These nodes might already be + FusedSchedulerNodes. + """ + if isinstance(node1, scheduler.ForeachKernelSchedulerNode) or isinstance( + node2, scheduler.ForeachKernelSchedulerNode + ): + return scheduler.ForeachKernelSchedulerNode.can_fuse(node1, node2) + + _, (numel1, rnumel1) = node1.group + _, (numel2, rnumel2) = node2.group + why = WhyNoFuse(node1, node2) + + if node1.is_split_scan() and not node2.is_split_scan(): + if node2.is_reduction(): + why("Split scan cannot fuse with reductions") + elif node2.is_split_scan() and not node1.is_split_scan(): + if node1.is_reduction(): + why("Split scan cannot fuse with reductions") + + if node1.is_reduction() and node2.is_reduction(): + reduction_can_fuse = numel1 == numel2 and rnumel1 == rnumel2 + if not reduction_can_fuse: + why( + "numel/rnumel mismatch (reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return reduction_can_fuse + + if not node1.is_reduction() and not node2.is_reduction(): + if not (numel1 == numel2 and rnumel1 == rnumel2): + if not node2.is_template(): + why( + "numel/rnumel mismatch (non-reduce) (%s, %s), (%s, %s)", + numel1, + numel2, + rnumel1, + rnumel2, + ) + return False + else: + # prologue fusion input sizes differ from output group + # fuse so long as this node matches the group of existing prologue nodes + for node in node2.get_nodes(): + # dont need to check epilogue nodes for prologue fusion, break after template + if node.is_template(): + break + # we would have already restricted prologue from fusing if it had multiple + # uses, so it must be fusing into this node + if not node.used_buffer_names() & node1.get_buffer_names(): + continue + _, (pro_numel, pro_rnumel) = node.group + if not (numel1 == pro_numel and rnumel1 == pro_rnumel): + why( + "numel/rnumel mismatch prologue mismatch (%s, %s), (%s, %s)", + numel1, + pro_numel, + rnumel1, + pro_rnumel, + ) + return False + + for n in (node1, node2): + if n.is_template(): + return True + + # check for a bad combined tiling + tiling1 = self.select_tiling(node1.get_nodes(), numel1, rnumel1) + tiling2 = self.select_tiling(node2.get_nodes(), numel1, rnumel1) + tiling3 = self.select_tiling( + node1.get_nodes() + node2.get_nodes(), numel1, rnumel1 + ) + if config.triton.tiling_prevents_pointwise_fusion: + cond = True + if len(tiling1) > 2: + if len(tiling2) > 2: + cond = tiling1 == tiling2 == tiling3 + else: + cond = tiling1 == tiling3 + elif len(tiling2) > 2: + cond = tiling2 == tiling3 + if not cond: + why( + "tiling mismatch (%s, %s, %s)", + tiling1, + tiling2, + tiling3, + ) + return False + + return True + + if not node1.is_reduction() and node2.is_reduction(): + assert rnumel1 == 1 and rnumel2 != 1 + if numel1 == numel2 * rnumel2: + if not all( + SIMDKernel.is_compatible((numel2, rnumel2), n.get_ranges()) + for n in node1.get_nodes() + ): + why("nodes numel/rnumel incompatibility") + return False + if ( + config.triton.tiling_prevents_reduction_fusion + and not node1.is_template() + ): + is_reduction_tiling_valid = tuple( + self.select_tiling(node1.get_nodes(), numel1).values() + ) in ( + (numel1, 1), + (numel2, rnumel2, 1), + ) + if not is_reduction_tiling_valid: + why("invalid tiling for reduction") + return is_reduction_tiling_valid 
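+                # Note (editor comment): the pointwise node's iteration space matches
+                # the reduction's (numel, rnumel) split and the tiling check above did
+                # not object, so fusing node1 into the reduction kernel is allowed.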
+ return True + + if numel1 != numel2: + why("nodes numel incompatibility") + return numel1 == numel2 + + assert node1.is_reduction() and not node2.is_reduction() + # swap args to hit the case above + return self.can_fuse_horizontal(node2, node1) + + can_fuse_vertical = can_fuse + can_fuse_horizontal = can_fuse + + def generate_node_schedule(self, nodes, numel, rnumel): + node_schedule: list[Any] = [] + done = OrderedSet[scheduler.BaseSchedulerNode]() + # Writes with a reduced shape, meaning they are only present once the + # reduction loop has ended + not_ready_yet_nodes: OrderedSet[str] = OrderedSet() + current_loop_buffer_usage: OrderedSet[str] = OrderedSet() + maybe_split_index: Optional[int] = None + + def fits_in_main_body(n): + _, (node_numel, node_rnumel) = n.group + return (node_numel == numel and node_rnumel == rnumel) or ( + node_numel == numel * rnumel and node_rnumel == 1 + ) + + def fits_outside_reduction(n): + _, (node_numel, node_rnumel) = n.group + return node_numel == numel and node_rnumel == 1 and rnumel != 1 + + def expect_improved_memory_usage(n): + for read in n.read_writes.reads: + if read.name in current_loop_buffer_usage: + return True + return False + + def schedule_node_in_loop(n): + done.add(n) + node_schedule.append(n) + current_loop_buffer_usage.update([x.name for x in n.read_writes.reads]) + + # A scan is modelled as a reduction in the scheduler but has a + # full sized output that can be used inside the loop body + if ( + n.is_reduction() + and isinstance(n, scheduler.SchedulerNode) + and isinstance(n.node, ir.ComputedBuffer) + and not isinstance(n.node.data, ir.Scan) + ): + not_ready_yet_nodes.add(n.get_name()) + else: # this node is available within the loop + current_loop_buffer_usage.update([x.name for x in n.read_writes.writes]) + + @contextlib.contextmanager + def end_current_reduction_loop(): + nonlocal maybe_split_index + if node_schedule and node_schedule[-1] is EnableReduction: + node_schedule.pop() + else: + node_schedule.append(DisableReduction) + if maybe_split_index: + node_schedule.insert(maybe_split_index, DisableReduction) + node_schedule.insert(maybe_split_index + 1, EnableReduction) + maybe_split_index = None + yield + node_schedule.append(EnableReduction) + not_ready_yet_nodes.clear() + current_loop_buffer_usage.clear() + + def requires_closing_previous_reduction(node, node_schedule): + if rnumel == 1: + return False + if not not_ready_yet_nodes & node.ancestors: + return False + assert node_schedule and not isinstance( + node_schedule[-1], (EnableReduction, DisableReduction) + ) + return bool(not_ready_yet_nodes) + + for node in nodes: + if node in done: + continue + done.add(node) + + if fits_in_main_body(node): + if requires_closing_previous_reduction(node, node_schedule): + with end_current_reduction_loop(): + pass # need to start a new reduction loop + + if current_loop_buffer_usage and not expect_improved_memory_usage(node): + # If we don't improve memory usage, then it is better to split into two loops + maybe_split_index = maybe_split_index or len(node_schedule) + else: + # Memory usage got improved, cancel the loop split + maybe_split_index = None + + schedule_node_in_loop(node) + elif fits_outside_reduction(node): + with end_current_reduction_loop(): + node_schedule.append(node) + else: + raise NotImplementedError( + f"unexpected group: ({numel}, {rnumel}) != {node.group[1]}" + ) + + return node_schedule + + def codegen_node( + self, node: Union[scheduler.FusedSchedulerNode, scheduler.SchedulerNode] + ): + """ + Given a set of 
pre-fused nodes, generate a Triton kernel. + """ + + nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] + + if torch._inductor.config.triton.coalesce_tiling_analysis: + coalesce_analysis = analyze_memory_coalescing(node) + else: + coalesce_analysis = None + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + schedule_log.debug("Schedule:\n %s", node_schedule) + + return self.codegen_node_schedule( + SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis) + ) + + @staticmethod + def can_use_32bit_indexing( + numel: sympy.Expr, + buffers: Iterable[ + Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject, ir.IRNode] + ], + ) -> bool: + int_max = torch.iinfo(torch.int32).max + + if not expr_fits_within_32bit(numel): + return False + + # Any use of a MultiOutputLayout will create a buffer with a + # Layout whose sizes are accounted for + buf_sizes = [ + buf.get_layout().storage_size() + for buf in buffers + if buf.has_tensor_output() + ] + + if not all(expr_fits_within_32bit(size) for size in buf_sizes): + return False + + # Only install guards for 32-bit indexing as there is no correctness + # issue with using 64-bit for everything + V.graph.sizevars.guard_leq(numel, int_max) # type: ignore[arg-type] + for size in buf_sizes: + V.graph.sizevars.guard_leq(size, int_max) # type: ignore[arg-type] + return True + + def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + node_schedule = kernel_features.node_schedule + + tiling, tiling_score = self.get_tiling_and_scores( + node_schedule, + kernel_features.numel, + kernel_features.reduction_numel, + kernel_features.coalesce_analysis, + ) + kernels = self.create_kernel_choices( + kernel_features, + [tiling], + {"features": kernel_features, "tiling_scores": tiling_score}, + ) + for kernel in kernels: + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + MultiKernel.merge_workspaces_inplace(kernels) + for kernel in kernels: + with V.set_kernel_handler(kernel): + src_code = kernel.codegen_kernel() + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing( + node_schedule, # type: ignore[arg-type] + kernel_name, + ) + log.debug("Generating kernel code with kernel_name: %s", kernel_name) + kernel.kernel_name = kernel_name + kernel.code_hash = code_hash(src_code) + del kernel + + final_kernel: Union[SIMDKernel, MultiKernel] + if len(kernels) > 1: + final_kernel = MultiKernel(kernels) + else: + (final_kernel,) = kernels + + with V.set_kernel_handler(final_kernel): + for node in kernel_features.scheduler_nodes(): + node.mark_run() + + self.codegen_comment(node_schedule) + final_kernel.call_kernel(final_kernel.kernel_name) + + if config.nan_asserts: + final_kernel.codegen_nan_check() + if config.warn_mix_layout: + final_kernel.warn_mix_layout(kernels[0].kernel_name) + + V.graph.removed_buffers |= final_kernel.removed_buffers + V.graph.inplaced_to_remove |= final_kernel.inplaced_to_remove + + if ( + V.graph.wrapper_code.supports_intermediate_hooks + and config.generate_intermediate_hooks + ): + # Not every node in the schedule will actually be live on output; + # we can't check dead buffers. 
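            # For each scheduler node whose buffer is still a live output of the
            # generated kernel and which traces back to an FX origin node, a
            # `run_intermediate_hooks(origin_name, buffer_name)` call is written
            # into the wrapper so registered hooks can observe the intermediate value.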
+ live_outs = kernels[0].args.live_output_buffers() + for node in kernel_features.scheduler_nodes(): + name = node.get_name() + if name not in live_outs: + continue + assert node.node is not None + origin_node = node.node.get_origin_node() + if origin_node is not None: + counters["inductor"]["intermediate_hooks"] += 1 + V.graph.wrapper_code.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {name})" + ) + + self.free_buffers_in_scheduler() + + def create_kernel_choices( + self, kernel_features: SIMDKernelFeatures, kernel_args, kernel_kwargs + ) -> list[SIMDKernel]: + return [ + self.kernel_type( + *kernel_args, + **kernel_kwargs, + ) + ] + + def codegen_node_schedule_with_kernel(self, node_schedule, kernel): + with kernel: + stack = contextlib.ExitStack() + all_indexing = {} + + # First pass to collect indexing and decide inplace updates + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + node.decide_inplace_update() + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + all_indexing.update( + dict.fromkeys( + node._body.indexing_from_args(index_vars).values() + ) + ) + + kernel.finalize_indexing(all_indexing.keys()) + + # Second pass to do codegen + for node in node_schedule: + if node is DisableReduction: + stack.enter_context(kernel.disable_reduction()) + elif node is EnableReduction: + stack.close() + else: + # TODO - use split ranges ? + indexing_dtype_strength_reduction(node._body) + index_vars = kernel.split_and_set_ranges(node.get_ranges()) + node.codegen(index_vars) + + def codegen_template( + self, template_node, epilogue_nodes, prologue_nodes, *, only_gen_src_code=False + ) -> Optional[str]: + """ + Codegen a triton template + + If `only_gen_src_code` the src code will be returned instead of codegen'd into the wrapper + """ + _, (_numel, rnumel) = template_node.group + assert rnumel == 1 + kernel, render = template_node.node.make_kernel_render(template_node.node) + + buf_name_to_prologue_group = {} + template_reads = template_node.used_buffer_names() + prologue_group = [] + for prologue in prologue_nodes: + names = prologue.get_buffer_names() + prologue_group.append(prologue) + # this must be the end of a prologue group + if names & template_reads: + assert len(names) == 1 + buf_name_to_prologue_group[next(iter(names))] = prologue_group + kernel.prologue_fused_inputs.add(next(iter(names))) + prologue_group = [] + + # all prologue groups should have finalized with use in template + assert len(prologue_group) == 0 + + with kernel: + if not only_gen_src_code: + # prologue nodes can only be fused if their only use is in the template, + # so they are necessarily not allocated + for node in [template_node, *epilogue_nodes]: + node.mark_run() + + partial_code = render() + + with kernel.set_subgraph_body(""): + for node in epilogue_nodes: + node.codegen(kernel.split_and_set_ranges(node.get_ranges())) + kernel.cse.invalidate(OrderedSet()) + + for input_name, buffer in kernel.named_input_nodes.items(): + subgraph_name = f"" + if prologue_group := buf_name_to_prologue_group.get( + buffer.get_name(), [] + ): + can_codegen_without_upcast = all( + p_n.can_codegen_without_upcasts() for p_n in prologue_group + ) + + # TODO - this doesn't work with libdevice calls, potentially other bugs + # upcasting to fp32 and downcasting gives large slowdown + with config.patch( + "triton.codegen_upcast_to_fp32", not can_codegen_without_upcast + ): + with 
kernel.set_subgraph_body(subgraph_name): + for prologue_node in prologue_group: + if ( + len(prologue_node.get_buffer_names()) == 1 + and len(prologue_group) == 1 + ): + if prologue_preserves_zero_mask(prologue_node): + kernel.prologue_fused_inputs_preserve_zero |= ( + prologue_node.get_buffer_names() + ) + + prologue_node.codegen( + kernel.split_and_set_ranges( + prologue_node.get_ranges() + ) + ) + kernel.cse.invalidate(OrderedSet()) + + if not isinstance(partial_code, str): + partial_code.finalize_hook("") + partial_code.finalize_hook("", strict=False) + # finalize must be called after adding epilogue above + + with V.set_kernel_handler(kernel): + # TODO: Maybe unify CUDATemplateKernel to also use PartialRender for flexible epilogue fusion. + + for input_name in kernel.named_input_nodes.keys(): + subgraph_name = f"" + partial_code.finalize_hook(subgraph_name, strict=False) + + with kernel.set_subgraph_body(""): + if isinstance(partial_code, str): + src_code = partial_code + else: + partial_code.finalize_hook("") + src_code = partial_code.code + node_schedule = [*prologue_nodes, template_node, *epilogue_nodes] + + if config.benchmark_kernel: + num_gb = kernel.estimate_kernel_num_bytes() / 1e9 + src_code = ( + f"{kernel.imports_for_benchmark_kernel()}\n" + f"{src_code}\n" + f"{kernel.codegen_kernel_benchmark(num_gb).getvalue()}" + ) + + if only_gen_src_code: + return src_code + + kernel_name = self.define_kernel(src_code, node_schedule, kernel) + + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node_schedule, kernel_name) + + self.codegen_comment(node_schedule) + kernel.call_kernel(kernel_name, template_node.node) + + V.graph.removed_buffers |= kernel.removed_buffers + V.graph.inplaced_to_remove |= kernel.inplaced_to_remove + self.free_buffers_in_scheduler() + return None + + def codegen_sync(self): + V.graph.wrapper_code.writeline(V.graph.device_ops.synchronize()) + + def generate_combo_kernel_code( + self, + subkernel_nodes: list[BaseSchedulerNode], + custom_part_algorithm: bool, + enable_autotune: bool, + mixed_sizes: bool, + only_gen_src_code: bool = False, + ) -> list[tuple[str, Any, Any]]: + from .triton_combo_kernel import ComboKernel + + fused_node_lists = [node.get_nodes() for node in subkernel_nodes] + subkernel_map, node_schedule_map = {}, {} + for pn, nodes in zip(subkernel_nodes, fused_node_lists): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + node_schedule_map[pn] = node_schedule, tiling, numel, rnumel + subkernel_map[pn] = ComboKernel.create_triton_kernel( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + optimize_mask=not mixed_sizes, + ) + + partitions = ComboKernel.horizontal_partition( + nodes=subkernel_nodes, + triton_scheduling=self, + custom_algorithm=custom_part_algorithm, + kernel_map=subkernel_map, + node_info_map=node_schedule_map, + ) + log.debug( + "ComboKernels: %d nodes partitioned into %s groups", + len(subkernel_nodes), + [len(p) for p in partitions], + ) + kernel_code_list = [] + for node_group in partitions: + fused_node_lists = [node.get_nodes() for node in node_group] + kernel = ComboKernel( + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + ) + + for pn, nodes in zip(node_group, fused_node_lists): + self.codegen_node_schedule_with_kernel( + node_schedule_map[pn][0], + kernel.create_sub_kernel(subkernel_map[pn]), + ) + subkernel = 
subkernel_map[pn] + node_schedule = node_schedule_map[pn][0] + if not only_gen_src_code: + with V.set_kernel_handler(subkernel): # type: ignore[call-arg] + for node in NodeScheduleMarker.only_nodes(node_schedule): + node.mark_run() + V.graph.removed_buffers |= subkernel.removed_buffers + V.graph.inplaced_to_remove |= subkernel.inplaced_to_remove + + src_code = kernel.codegen_kernel() + kernel_code_list.append((src_code, kernel, node_group)) + return kernel_code_list + + def codegen_combo_kernel(self, combo_kernel_node): + subkernel_nodes = combo_kernel_node.get_subkernel_nodes() + custom_part_algorithm = combo_kernel_node.use_custom_partition_algo + enable_autotune = combo_kernel_node.enable_autotune + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_part_algorithm + ) + + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes, custom_part_algorithm, enable_autotune, mixed_sizes + ) + + for src_code, kernel, _ in kernel_code_list: + kernel_name = self.define_kernel(src_code, [combo_kernel_node], kernel) + # dump provenance node info for ComboKernelNode/ForeachKernel type + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing( + combo_kernel_node.snodes, kernel_name + ) + self.codegen_comment([combo_kernel_node]) + log.debug("ComboKernels: generated kernel %s.", kernel_name) + kernel.call_kernel(V.graph.wrapper_code, kernel_name) + + self.free_buffers_in_scheduler() + + @classmethod + @functools.lru_cache(32) + def candidate_tilings(cls, node, numel, reduction_numel) -> list[CandidateTiling]: + is_pointwise = reduction_numel == 1 + + def tile_ranges(is_pointwise: bool, ranges, rw) -> list[CandidateTiling]: + """ + Compute tiling candidates by dividing up the iteration ranges. + """ + assert len(rw.range_vars) == len(ranges), f"{rw.range_vars=} {ranges=}" + + # isinstance(dep, MemoryDep): this filters out StarDeps. StarDeps refer to reads + # that need to access the entire tensor; they don't contribute read indexing + # information (and practically, they don't have dep.index so they can't be used + # for stride_hints below + dep_sources = [rw.reads, rw.writes] + assert all( + isinstance(dep, (MemoryDep, StarDep)) + for dep in itertools.chain.from_iterable(dep_sources) + ) + deps = [ + dep + for dep in itertools.chain.from_iterable(dep_sources) + if dep.name not in V.graph.removed_buffers + and isinstance(dep, MemoryDep) + ] + write_names = OrderedSet([dep.name for dep in rw.writes]) + + def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: + return V.graph.sizevars.simplify(sympy_product(ranges)) + + # Default to no tiling. + tilings = [ + CandidateTiling( + tiling=cls.create_partial_tiling( + [collapse_ranges(ranges)], is_pointwise + ), + name="none", + score=0, + ) + ] + + # Find non-trivial tiling candidates. 
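            # Each memory dependency suggests a split point: the position just after
            # its stride-1 (contiguous) dimension.  As a rough illustration with
            # hypothetical sizes, for ranges (s0, s1) and a dep whose strides over
            # those ranges are (1, s0) -- a transposed access -- strides.index(1) + 1
            # is 1, so the candidate splits the ranges into
            # (prod(ranges[:1]), prod(ranges[1:])) == (s0, s1).  The score counts the
            # elements the dep actually indexes, doubled for writes and doubled again
            # for each tile whose size is a "good" multiple of 32.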
+ for dep in deps: + strides = V.graph.sizevars.stride_hints(dep.index, rw.range_vars) + assert len(strides) == len(ranges) + try: + split = strides.index(1) + 1 + if split == len(ranges): + continue + if all(s == 0 for s in strides[split:]): + # if this is a broadcasted tensor and all dimensions after split are broadcast, + # this is not a real split + continue + + except ValueError: + continue + + tiled_groups = ( + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ) + + # score by number of elements + score = V.graph.sizevars.size_hint( + sympy_product( + size for size, stride in zip(ranges, strides) if stride != 0 + ) + ) + if dep.name in write_names: + # ngimel said contiguous writes is more important than reads + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[0]): + score *= 2 + if CandidateTiling.is_good_size(tiled_groups[1]): + score *= 2 + + if ( + V.graph.sizevars.size_hint( + score - sympy_product(itertools.chain(ranges, reduction_ranges)) + ) + >= 0 + ): + tilings.append( + CandidateTiling( + tiling=cls.create_partial_tiling( + [ + collapse_ranges(ranges[:split]), + collapse_ranges(ranges[split:]), + ], + reduction_numel, + ), + score=score, + name=dep.name, + ) + ) + + return tilings + + pointwise_ranges, reduction_ranges = node.get_ranges() + if ( + len(pointwise_ranges) <= 1 + and len(reduction_ranges) <= 1 + or free_unbacked_symbols(pointwise_ranges + reduction_ranges) + ): + return [] + + # Tile either pointwise or reduction dims. + pointwise_ranges, reduction_ranges = node.get_ranges() + partial_tilings = tile_ranges( + is_pointwise, + pointwise_ranges if is_pointwise else reduction_ranges, + node.pointwise_or_reduction_read_writes(is_pointwise), + ) + + # Fill in the missing ranges. + full_tilings = [ + CandidateTiling( + tiling=cls.complete_partial_tiling( + tiling.tiling, numel, reduction_numel + ), + score=tiling.score, + name=tiling.name, + ) + for tiling in partial_tilings + ] + + return full_tilings + + @classmethod + def create_tiling( + cls, pw_tiling: Sequence[sympy.Expr], reduction_tiling: Sequence[sympy.Expr] + ) -> dict[str, sympy.Expr]: + """ + Create a tiling dict from pointwise and reduction splits. + """ + pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :] + reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)] + return immutable_dict( + [*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)] + ) + + @classmethod + def create_partial_tiling( + cls, + tiling: Sequence[sympy.Expr], + is_pointwise: bool, + ) -> dict[str, sympy.Expr]: + return cls.create_tiling( + tiling if is_pointwise else [], + tiling if not is_pointwise else [], + ) + + @classmethod + def complete_partial_tiling( + cls, + tiling: dict[str, sympy.Expr], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ) -> dict[str, sympy.Expr]: + """ + Given a tiling for only pointwise or reduction dimensions, adds the missing one. + """ + splits = list(tiling.values()) + is_pointwise = "x" in tiling + + total_numel = numel * reduction_numel + missing_tiling = [total_numel / sympy_product(splits)] + + tiling_args = ( + (splits, missing_tiling) if is_pointwise else (missing_tiling, splits) + ) + return cls.create_tiling(*tiling_args) + + @classmethod + def get_nd_tilings( + cls, + node_schedule, + pointwise_numel, + reduction_numel, + ) -> list[dict[str, tuple[sympy.Expr]]]: + """ + Creates N-dimensional tiling candidates, attempting to simplify loads/stores + by tiling the kernel into higher dimensions. 
+ + Returns a list of tilings ranked by dimensionality. + """ + is_pointwise = reduction_numel == 1 + tilings = OrderedSet[dict[str, sympy.Expr]]() + for node in EnableReduction.filter(node_schedule): + if not isinstance(node, scheduler.SchedulerNode): + continue + + # If this is a reduction schedule, skip nodes which are missing their + # reduction ranges. + node_ranges = node.get_ranges() + if not is_pointwise and len(node_ranges[1]) == 0: + continue + + # Use the node ranges as the default tiling candidate. + ranges_to_tile = node_ranges[0 if is_pointwise else 1] + node_tilings = [ranges_to_tile] + + # Search the indexing expressions for more candidates. + # If we see modular indexing, try to subdivide ranges into their implied + # block shape. + memory_deps = [ + dep + for dep in node.read_writes.reads_and_writes() + if isinstance(dep, MemoryDep) and len(dep.ranges) > 0 + ] + for dep in memory_deps: + # Attempt to partition variable ranges into pointwise and reduction groups. + # To achieve this, merge the leading ranges until we reach the pointwise numel. + all_var_ranges = [*dep.ranges.items()] + pointwise_vars_numel = sympy.S.One + sizevars = V.graph.sizevars + for pointwise_end_idx, (var, numel) in enumerate(all_var_ranges): + pointwise_vars_numel *= numel + if sizevars.statically_known_geq( + pointwise_vars_numel, pointwise_numel + ): + break + + # Reject the split if it does not match the total pointwise numel. + if not sizevars.statically_known_equals( + pointwise_vars_numel, pointwise_numel + ): + continue + + # Partition var ranges into pointwise and reduction splits. + reduction_start_idx = pointwise_end_idx + 1 + var_ranges = ( + all_var_ranges[:reduction_start_idx] + if is_pointwise + else all_var_ranges[reduction_start_idx:] + ) + + # Pattern match the subexpression pertaining to each index variable. + index_tiling = [] + for var, numel in var_ranges: + index = BlockPatternMatcher.get_subexpr_involving_symbol( + dep.index, var + ) + + # Heuristic to bound the maximum dimensionality of the block. + num_dims = max( + 2, + index.count(FloorDiv) + index.count(ModularIndexing), + len(ranges_to_tile), + ) + + # Attempt to pattern match the index expr. + # Failed matches default to the full range. + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, var, numel, num_dims + ) + dims = match_result[0] if match_result is not None else [numel] + index_tiling.extend(dims) + + # Prune dimensions of size 1. + index_tiling = [ + dim + for dim in index_tiling + if not V.graph.sizevars.statically_known_equals(dim, sympy.S.One) + ] + + if len(index_tiling) > 0: + node_tilings.append(index_tiling) + + # Flatten leading dimensions, assigning labels to each dim. + for node_tiling in node_tilings: + num_leading_dims = max(0, len(node_tiling) - get_max_tiles(2)) + first_trailing_dim = num_leading_dims + 1 + collapsed_leading_dim = sympy_product(node_tiling[:first_trailing_dim]) + collapsed_splits = (collapsed_leading_dim,) + tuple( + node_tiling[first_trailing_dim:] + ) + tilings.add( + cls.complete_partial_tiling( + cls.create_partial_tiling(collapsed_splits, is_pointwise), + pointwise_numel, + reduction_numel, + ) + ) + + # Rank tilings by the number of dimensions. E.g., prefer 2D to 1D. + # Since this is a stable sort, ties are broken by schedule order. 
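        # len() here counts the entries of each tiling dict, so e.g. a
        # {"y", "x", "r0_"} candidate sorts ahead of {"x", "r0_"}.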
+ ranked_tilings = sorted( + tilings, + key=len, + reverse=True, + ) + + return ranked_tilings + + @classmethod + def compute_tiling_strategy( + cls, + node_schedule: list[NodeScheduleEntry], + pointwise_numel: sympy.Expr, + reduction_numel: sympy.Expr, + coalesce_analysis: CoalesceVarAnalysis, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses. + """ + tiling_var: Optional[sympy.Expr] = ( + None + if not coalesce_analysis.suggested_split + else coalesce_analysis.suggested_split.var + ) + + all_iter_vars = coalesce_analysis.norm_read_writes.index_vars + all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars + ranges = coalesce_analysis.norm_read_writes.var_ranges + + pw_ranges = [ranges[v] for v in all_iter_vars] + red_ranges = [ranges[v] for v in all_red_vars] + + torch._check( + sympy_product(pw_ranges) == pointwise_numel, + lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}", + ) + torch._check( + sympy_product(red_ranges) == reduction_numel, + lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}", + ) + + # score of a pointwise or reduction split + scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {} + + score_split: list[ + tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]] + ] = [] + + def process_node_vars( + vars_to_use: tuple[sympy.Expr, ...] = (), + use_split_var: bool = False, + is_pointwise: bool = False, + ) -> tuple[list[int], list[int]]: + """ + Generate a tiling, and a tiling score, given vars to use as splits. + """ + + ranges = pw_ranges if is_pointwise else red_ranges + target_numel = pointwise_numel if is_pointwise else reduction_numel + # Some kernels have no reduction ranges, and a reduction numel of 1 + if not ranges: + if target_numel: + return ([target_numel], []) + else: + return ([], []) + + key = (repr(vars_to_use), use_split_var, is_pointwise) + if out := scored_sub_split.get(key, None): + return out + + splitting_vars = all_iter_vars if is_pointwise else all_red_vars + + splits = [] + split_scores = [] + prod = 1 + prev_var_coalesced_score = 0 + + # iterate from non-dense to dense + for v, v_range in zip(splitting_vars, ranges): + if v not in vars_to_use: + prod *= v_range + prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get( + v, 0 + ) + continue + + if use_split_var and v == tiling_var: + var_tiling = coalesce_analysis.suggested_split + assert var_tiling is not None + + tile = var_tiling.tiling_factor + remainder = FloorDiv(v_range, var_tiling.tiling_factor) + + splits.append(prod * remainder) + split_scores.append(var_tiling.score) + + splits.append(tile) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + + prod = 1 + prev_var_coalesced_score = 0 + + continue + + prod *= v_range + splits.append(prod) + split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0)) + prod = 1 + + if prod != 1 or (is_pointwise and len(splits) == 0): + splits.append(prod) + split_scores.append(prev_var_coalesced_score) + + # penalize splits that leave small blocks + # where we can't fully utilize full memory transaction + # TODO: incorporate exact bitwidth, and read/write + # coalesced write is 2x more important + for i in range(len(splits)): + s = V.graph.sizevars.size_hint(splits[i], fallback=32) + s = min(s, 8) + split_scores[i] = int(split_scores[i] * s / 8) + + scored_sub_split[key] = (splits, split_scores) + return (splits, split_scores) + + # add the default tiling + 
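        # (i.e. the unsplit pointwise and reduction ranges).  Each later entry of
        # score_split pairs (pointwise splits, per-split coalescing scores) with
        # (reduction splits, per-split scores): one candidate for the
        # analysis-suggested tiling variable and one per coalesced iteration
        # variable (pairs of them when three tiles are allowed).  The winner is
        # chosen below by summed score, with a small penalty for extra or
        # poorly-sized tiling dimensions.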
score_split.append( + ( + process_node_vars(is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if tiling_var: + score_split.append( + ( + process_node_vars( + (tiling_var,), use_split_var=True, is_pointwise=True + ), + process_node_vars(is_pointwise=False), + ) + ) + + # TODO, add tests, reduction splits if config.triton.tile_reductions + # TODO: we should ignore tiny increases in score for extra splits + overlapping_iter_vars = ( + all_iter_vars & coalesce_analysis.coalesced_by_var.keys() + ) + for v in overlapping_iter_vars: + score_split.append( + ( + process_node_vars((v,), is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + if get_max_tiles(default=3) == 3 and reduction_numel == 1: + for vars_to_use in itertools.combinations(overlapping_iter_vars, 2): + score_split.append( + ( + process_node_vars(vars_to_use, is_pointwise=True), + process_node_vars(is_pointwise=False), + ) + ) + + tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [] + for (pw_split, pw_score), (red_split, red_score) in score_split: + candidate = CandidateTiling( + cls.create_tiling(pw_split, red_split), + score=sum(pw_score) + sum(red_score), + ) + tiling_score = cls.create_tiling(pw_score, red_score) + tilings.append((candidate, tiling_score)) + + default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel]) + + # add a slight penalty for longer tilings that dont increase score much, + # and are poor sizes + bad_size_additional_tiling_penalty = 1.025 + good_size_tiling_penalty = 1.005 + + def score_mod(t): + score_factor = 1.0 + for tile_size in t[0].tiling.values(): + if not CandidateTiling.is_good_size(tile_size): + score_factor = score_factor / bad_size_additional_tiling_penalty + else: + score_factor = score_factor / good_size_tiling_penalty + + return -t[0].score * score_factor + + # apply penalty for longer tilings that dont increase score much + for cand, tiling_score in sorted(tilings, key=score_mod): + if cls.tiling_is_compatible( + node_schedule, pointwise_numel, reduction_numel, cand.tiling + ): + # we always include default reduction numel == 1, dont include + tiling_len = len(cand.tiling) - (1 if reduction_numel == 1 else 0) + if tiling_len > get_max_tiles(default=3): + perf_hint_log.info( + "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles " + "set to %s. 
Consider increasing", + tiling_len, + torch._inductor.config.triton.max_tiles, + ) + continue + + return cand.tiling, tiling_score + + # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible` + # TODO - look into, occurs with dynamic shapes often + if cand.tiling == default_tiling: + return cand.tiling, tiling_score + + return default_tiling, None + + @classmethod + def tiling_is_compatible( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + tiling: dict[str, sympy.Expr], + ): + assert isinstance(tiling, dict) + return all( + SIMDKernel.is_compatible( + tiling.values(), node.get_ranges(), reduction_numel=reduction_numel + ) + for node in node_schedule + if isinstance(node, scheduler.SchedulerNode) + ) + + @classmethod + def get_first_compatible_tiling( + cls, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr, + ranked_tilings: list[dict[str, sympy.Expr]], + ): + for tiling in ranked_tilings: + if cls.tiling_is_compatible(node_schedule, numel, reduction_numel, tiling): + return tiling + + return None + + @classmethod + def select_tiling( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> dict[str, sympy.Expr]: + return cls.get_tiling_and_scores( + node_schedule, numel, reduction_numel, coalesce_analysis + )[0] + + @classmethod + def get_tiling_and_scores( + cls, + node_schedule, + numel, + reduction_numel=sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]: + """ + Heuristics to decide how to tile kernels. + Currently, we tile based on stride-1 dimensions. + + Returns: + `(tile1, tile2, reduction_numel)` s.t. `tile1 * tile2 == numel` + + """ + # If this is a reduction, only tile reduction dims. + is_pointwise = reduction_numel == 1 + + # Tiled reductions are gated by a config flag. + default_tiling = cls.create_tiling([numel], [reduction_numel]) + + # # TODO: enable by default + if ( + torch._inductor.config.triton.coalesce_tiling_analysis + and coalesce_analysis + and not config.triton.prefer_nd_tiling + ): + return cls.compute_tiling_strategy( + node_schedule, numel, reduction_numel, coalesce_analysis + ) + + if (not is_pointwise and not config.triton.tile_reductions) or get_max_tiles( + default=2 + ) <= 1: + # Emit a perf hint in case we miss an opportunity to tile a reduction. + if perf_hint_log.level <= logging.WARNING: + for node in EnableReduction.filter(node_schedule): + if ( + not config.triton.tile_reductions + and len(cls.candidate_tilings(node, numel, reduction_numel)) > 0 + ): + perf_hint_log.info( + textwrap.dedent( + """ + Reduction over non-contiguous dims. + Consider setting config.triton.tile_reductions to True. 
+ """ + ) + ) + break + + return default_tiling, None + + seen_names: OrderedSet[str] = OrderedSet() + candidate_tiles: Counter[CandidateTiling] = collections.Counter() + for node in EnableReduction.filter(node_schedule): + for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel): + if candidate_tiling.name in seen_names: + continue + elif candidate_tiling.name is not None: + seen_names.add(candidate_tiling.name) + candidate_tiles[candidate_tiling] += candidate_tiling.score + + ranked_tilings: list[dict[str, sympy.Expr]] = [ + candidate_tiling.tiling + for candidate_tiling, score in candidate_tiles.most_common() + ] + + if get_max_tiles(default=2) >= 3 and is_pointwise: + # Consider adding a third dimension of tiling, but only + # when a1 is a multiple of b1; otherwise, you have a lot + # of stragglers which is annoying to generate code for. + # + # NB: More than three max tiles is not enabled by default. + + def convert_tiling_to_3d( + tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr] + ) -> Optional[dict[str, sympy.Expr]]: + a0, a1 = tiling0["x"], tiling0.get("y", 1) + b0, b1 = tiling1["x"], tiling1.get("y", 1) + + if ( + free_unbacked_symbols([a1, b1]) + or V.graph.sizevars.size_hint(a1 - b1) == 0 + ): + return None + if V.graph.sizevars.size_hint(a1 - b1) < 0: + # swap so a0 is bigger + (a0, a1), (b0, b1) = (b0, b1), (a0, a1) + + assert V.graph.sizevars.size_hint(a1 - b1) > 0 + if not V.graph.sizevars.statically_known_multiple_of(a1, b1): + return None + + new_tiling = { + "z": a0, + "y": FloorDiv(a1, b1), + "x": b1, + "r0_": tiling0["r0_"], + } + + return new_tiling + + for i in range(1, len(ranked_tilings)): + new_3d_tiling = convert_tiling_to_3d( + ranked_tilings[0], ranked_tilings[i] + ) + if new_3d_tiling is not None: + ranked_tilings = [new_3d_tiling] + ranked_tilings + break # only 1 choice for now + + if len(ranked_tilings) > 1: + perf_hint_log.info("possibly bad tiling: %s", ranked_tilings) + + # Optionally, prefer tiling into as many dimensions as possible. 
+ if config.triton.prefer_nd_tiling: + ranked_tilings = ( + cls.get_nd_tilings(node_schedule, numel, reduction_numel) + + ranked_tilings + ) + + if tiling := cls.get_first_compatible_tiling( + node_schedule, numel, reduction_numel, ranked_tilings + ): + return tiling, None + + return default_tiling, None + + def flush(self): + pass + + def ready_to_flush(self) -> bool: + return False + + def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False): + if not any(n.is_template() for n in nodes): + _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group + node_schedule = self.generate_node_schedule(nodes, numel, rnumel) + tiling = self.select_tiling(node_schedule, numel, rnumel) + kernel = self.kernel_type( + tiling, + features=SIMDKernelFeatures(node_schedule, numel, rnumel), + ) + self.codegen_node_schedule_with_kernel(node_schedule, kernel) + with ( + config.patch("benchmark_kernel", benchmark_kernel), + V.set_kernel_handler(kernel), + ): + src_code = kernel.codegen_kernel() + else: + prologue, template, epilogue = nodes[0].get_prologue_template_epilogue( + nodes + ) + with config.patch("benchmark_kernel", benchmark_kernel): + src_code = self.codegen_template( + template, + epilogue, + prologue, + only_gen_src_code=True, + ) + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return src_code + + def codegen_comment(self, node_schedule): + pass + + def define_kernel(self, src_code, node_schedule, kernel): + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class CandidateTiling: + tiling: dict[str, sympy.Expr] + score: int # higher is better + name: Optional[str] = None + + @staticmethod + def is_good_size(s): + """Somewhat arbitrary heuristic used to boost scores for some sizes""" + s = V.graph.sizevars.size_hint(s) + return s >= 32 and (s % 32 == 0) + + +class CantSplit(Exception): + pass diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/simd_kernel_features.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/simd_kernel_features.py new file mode 100644 index 0000000000000000000000000000000000000000..50042d9fafc83810dae8b77b6c42d0e9f7e06ee7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/simd_kernel_features.py @@ -0,0 +1,618 @@ +from __future__ import annotations + +import collections +import dataclasses +import functools +import itertools +import typing +from typing import Any, Optional, Union + +import sympy + +import torch + +from ...utils._ordered_set import OrderedSet +from ...utils._sympy.functions import FloorDiv, ModularIndexing +from ...utils._sympy.symbol import make_symbol, SymT +from ..dependencies import Dep, extract_loop_body_with_args, MemoryDep +from ..runtime.hints import ReductionHint +from ..scheduler import SchedulerNode +from ..utils import cache_on_self +from ..virtualized import V + + +if typing.TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from torch._inductor.tiling_utils import CoalesceVarAnalysis + + +class NodeScheduleMarker: + @staticmethod + def only_nodes(it: Iterable[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + for item in it: + if not (item is DisableReduction or item is EnableReduction): + yield item # type: ignore[misc] + + @staticmethod + def is_reduction() -> bool: + return False + + +NodeScheduleEntry = Union[SchedulerNode, type[NodeScheduleMarker]] + + +class DisableReduction(NodeScheduleMarker): + """ + Marker to invoke `kernel.disable_reduction()`. 
This closes a + reduction loop and allows for pointwise ops to occur on the output + of a reduction. + """ + + +class EnableReduction(NodeScheduleMarker): + """ + Marker to end a DisableReduction block. + """ + + @staticmethod + def filter(node_schedule: list[NodeScheduleEntry]) -> Iterable[SchedulerNode]: + """ + Get the nodes from node_schedule skipping those in a + DisableReduction block. + """ + disabled = False + for node in node_schedule: + if node in (EnableReduction, DisableReduction): + # Don't tile stuff outside the main reduction loop + disabled = node is DisableReduction + elif disabled: + pass + else: + yield node # type: ignore[misc] + + +class SIMDKernelFeatures: + """ + An ordered schedule of nodes that will become a single kernel. + """ + + def __init__( + self, + node_schedule: list[NodeScheduleEntry], + numel: sympy.Expr, + reduction_numel: sympy.Expr = sympy.S.One, + coalesce_analysis: Optional[CoalesceVarAnalysis] = None, + ): + self.node_schedule = node_schedule + # numel excludes reduction_numel + self.numel: sympy.Expr = V.graph.sizevars.simplify(numel) + self.reduction_numel: sympy.Expr = V.graph.sizevars.simplify(reduction_numel) + self._stats_cache: dict[tuple[sympy.Expr, ...], MemoryStats] = {} + self.coalesce_analysis = coalesce_analysis + + @cache_on_self + def is_reduction(self) -> bool: + return self.reduction_numel != 1 + + @cache_on_self + def scheduler_nodes(self) -> Iterable[SchedulerNode]: + return tuple(NodeScheduleMarker.only_nodes(self.node_schedule)) + + def reduction_nodes(self) -> list[SchedulerNode]: + return [n for n in self.scheduler_nodes() if n.is_reduction()] + + @cache_on_self + def buf_accesses(self) -> dict[str, list[Dep]]: + """only needed for config.benchmark_kernel""" + buf_accesses = collections.defaultdict(list) + for node in self.scheduler_nodes(): + for access in node.read_writes.reads | node.read_writes.writes: + buf_accesses[access.name].append(access) + return buf_accesses + + @cache_on_self + def op_counts(self) -> collections.Counter[str]: + counts: collections.Counter[str] = collections.Counter() + for node in self.scheduler_nodes(): + counts.update(node._body.op_counts) + return counts + + def contains_op(self, op_name: str) -> bool: + """True if V.ops.{op_name} is used in node_schedule""" + return bool(self.op_counts().get(op_name)) + + def get_mutations(self) -> OrderedSet[str]: + mutations: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + for buf in node.get_outputs(): + mutations.update(buf.get_mutations()) + return mutations + + @cache_on_self + def select_index_dtype(self) -> torch.dtype: + # Gather all used buffer names + buffer_names: OrderedSet[str] = OrderedSet() + for node in self.scheduler_nodes(): + buffer_names.update(node.get_buffer_names()) + buffer_names.update(node.used_buffer_names()) + buffers = [V.graph.get_buffer(name) for name in buffer_names] + + # In theory we can separately check xnumel and rnumel are <= int_max + # but some indexers do use the full linear index so we need to be + # conservative here. 
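        # int32 is only chosen when both the combined iteration space and the
        # storage size of every buffer the kernel touches provably fit in int32;
        # can_use_32bit_indexing() also installs guards on those sizes so the
        # int32 choice remains valid for the shapes this compilation covers.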
+ total_numel = self.numel * self.reduction_numel + + from .simd import SIMDScheduling + + if SIMDScheduling.can_use_32bit_indexing(total_numel, buffers): + return torch.int32 + return torch.int64 + + @cache_on_self + def get_reduction_hint(self) -> ReductionHint: + reductions = self.reduction_nodes() + if len(reductions) > 0: + hints = [self.reduction_hint(n) for n in reductions] + if hints.count(hints[0]) == len(hints): + reduction_hint_val = hints[0] + else: + reduction_hint_val = ReductionHint.DEFAULT + + if ( + reduction_hint_val == ReductionHint.INNER + and self.has_non_contiguous_pw_in_reduction_kernel() + ): + reduction_hint_val = ReductionHint.DEFAULT + else: + reduction_hint_val = ReductionHint.DEFAULT + return reduction_hint_val + + @cache_on_self + def buffer_read_counts(self) -> dict[str, int]: + """Counts how many times each buffer is read within the kernel""" + read_counts: dict[str, int] = collections.defaultdict(int) + + for node in self.scheduler_nodes(): + # node.read_writes.reads contains MemoryDep objects for each read + for read_dep in node.read_writes.reads: + read_counts[read_dep.name] += 1 + + return dict(read_counts) # Convert defaultdict to regular dict + + def has_non_contiguous_pw_in_reduction_kernel(self) -> bool: + pointwise_nodes = [ + n + for n in self.scheduler_nodes() + if not n.is_reduction() + and n.group[1][0] == self.numel * self.reduction_numel + ] + for node in pointwise_nodes: + # An index can be an integer when loading a random seed. + if not all( + not isinstance(dep, MemoryDep) + or dep.is_contiguous() + or isinstance(dep.index, (sympy.Integer, int)) + or dep.stride1_for_last_dim() + for dep in itertools.chain( + node.read_writes.reads, node.read_writes.writes + ) + ): + return True + return False + + @staticmethod + def reduction_hint(node: Any) -> ReductionHint: + assert node.is_reduction() + if node.node.data.reduction_hint != ReductionHint.INNER and all( + dep.is_contiguous() + for dep in itertools.chain(node.read_writes.reads, node.read_writes.writes) + ): + return ReductionHint.INNER + else: + return node.node.data.reduction_hint + + def memory_stats( + self, groups_dict: Optional[dict[str, sympy.Expr]] = None + ) -> MemoryStats: + """Analysis to generate features that can be used in heuristics""" + if groups_dict is None: + groups = (self.numel, self.reduction_numel) + elif groups_dict.keys() == OrderedSet(["x", "r0_"]): + groups = (groups_dict["x"], groups_dict["r0_"]) + else: + raise NotImplementedError(f"groups_dict={groups_dict!r}") + result = self._stats_cache.get(groups) + if result is None: + self._stats_cache[groups] = result = MemoryStats.compute( + MemoryEstimator(self, groups) + ) + return result + + +class MemoryEstimator: + """ + Estimate various properties of the kernel for use in heuristics. + We simulate the memory effects of CSE/buffer elimination in codegen. + """ + + kernel_sizes: tuple[sympy.Expr, ...] 
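    # kernel_sizes is (numel, reduction_numel) while simulating code inside the
    # reduction loop, or (numel, 1) while simulating code hoisted outside of it.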
+ outside_loop: MemoryEstimate + loops: list[MemoryEstimate] + persistent: MemoryEstimate + symbols: list[sympy.Symbol] + + def __init__(self, features: SIMDKernelFeatures, groups: Sequence[sympy.Expr]): + self.features = features + self.inside_reduction = features.is_reduction() + self.store_buffer_names: OrderedSet[str] = OrderedSet() + self.must_keep_buffers: OrderedSet[str] = OrderedSet() + self.num_reductions_dims = 1 + self.groups = groups + self.symbols = [make_symbol(SymT.INDEX, i) for i in range(len(groups))] + # We are doing two estimates simultaneously: + # 1) the first is a for a non-persistent (aka looped) reduction, using self.outside_loop/self.loops + # we add an item to loops each corresponding to each reduction loop in the kernel + # outside_loop is only used for broadcasting or point-wise ops that don't use the reduction dimension + # 2) the second is for a persistent kernel, using self.persistent + # persistent kernels don't have loops, so we only have one MemoryEstimate() + # for point-wise ops the two estimates will be the same, they matter for reductions only + self.outside_loop = MemoryEstimate() + self.loops = [MemoryEstimate()] + self.persistent = MemoryEstimate() + self.simulate_codegen() + self.remove_kernel_local() + + def simulate_codegen(self) -> None: + from .simd import SIMDKernel + + kernel_size_outside_loop = (*self.groups[:-1], sympy.S.One) + kernel_size_inside_loop = tuple(self.groups) + self.kernel_sizes = kernel_size_inside_loop + + for node in self.features.node_schedule: + if node is DisableReduction: + self.inside_reduction = False + self.kernel_sizes = kernel_size_outside_loop + continue + elif node is EnableReduction: + self.inside_reduction = True + self.kernel_sizes = kernel_size_inside_loop + self.loops.append(MemoryEstimate()) + continue + assert isinstance(node, SchedulerNode) + rw = extract_loop_body_with_args( + node._body, + SIMDKernel.map_kernel_groups_to_node_sizes( + self.kernel_sizes, node.get_ranges(), self.set_ranges + ), + dict(zip(self.symbols, self.kernel_sizes)), + ) + + for dep in rw._reads: + assert isinstance(dep, MemoryDep) + dep = dep.simplify_with_ranges() + if not self.persistent.writes.get(dep.name): # cache miss? 
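                    # A read of a buffer this kernel has already written is assumed
                    # to be served from registers (CSE) in the persistent-kernel
                    # estimate, so only genuine cache misses count as global reads.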
+ self.persistent.reads[dep.name].add(dep) + # the cache behavior of looped kernels is more complex than the persistent case above + # some operations are lifted outside the loop (if they don't use the reduction dimension) + # other operations are inside the loop, and can only be reused within the same loop + if not ( + self.outside_loop.writes.get(dep.name) + or self.loops[-1].writes.get(dep.name) + ): + self.scope(dep).reads[dep.name].add(dep) + if dep.name in self.store_buffer_names and self.loops[-1].reads.get( + dep.name + ): + self.must_keep_buffers.add(dep.name) + + for dep in rw._writes: + assert isinstance(dep, MemoryDep) + dep = dep.simplify_with_ranges() + self.store_buffer_names.add(dep.name) + self.persistent.writes[dep.name].add(dep) + self.scope(dep).writes[dep.name].add(dep) + + def remove_kernel_local(self) -> None: + # Remove any kernel-local buffers + fused_node_names = OrderedSet( + [n.get_name() for n in self.features.scheduler_nodes()] + ) + for name in self.store_buffer_names: + if not self.persistent.reads.get( + name + ) and V.graph.scheduler.can_buffer_be_removed_through_fusion( + name, fused_node_names + ): + self.persistent.remove(name) + if name not in self.must_keep_buffers: + # we can also remove this from the looped kernel + self.outside_loop.remove(name) + for loop in self.loops: + loop.remove(name) + + if not self.loops[-1]: + self.loops.pop() # for pointwise ops + + def scope(self, dep: MemoryDep) -> MemoryEstimate: + """Determine how a read/write should be categorized""" + if self.inside_reduction and ( + self.has_reduction_var(dep.index) or dep.is_indirect() + ): + return self.loops[-1] + return self.outside_loop + + def has_reduction_var(self, index: sympy.Expr) -> bool: + for sym in self.symbols[-self.num_reductions_dims :]: + if isinstance(sym, sympy.Symbol) and sym in index.free_symbols: + return True + return False + + def set_ranges(self, *lengths: list[list[sympy.Expr]]) -> list[list[sympy.Expr]]: + assert len(self.kernel_sizes) == len(lengths) + return [ + self.make_flat_range(sym, numel, length) + for sym, numel, length in zip(self.symbols, self.kernel_sizes, lengths) + ] + + @staticmethod + def make_flat_range( + sym: sympy.Symbol, numel: sympy.Expr, lengths: list[sympy.Expr] + ) -> list[sympy.Expr]: + if len(lengths) == 1 and numel == lengths[0]: + return [sym] + divisor = sympy.S.One + itervars = [] + for length in reversed(lengths): + if V.graph.sizevars.statically_known_equals(divisor * length, numel): + expr = FloorDiv(sym, divisor) + else: + expr = ModularIndexing(sym, divisor, length) + itervars.append(expr) + divisor = divisor * length + return [*reversed(itervars)] + + +@dataclasses.dataclass +class MemoryEstimate: + """Tracks the memory usage of a single loop in the generated kernel""" + + reads: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + writes: dict[str, OrderedSet[MemoryDep]] = dataclasses.field( + default_factory=functools.partial(collections.defaultdict, OrderedSet) + ) + + def remove(self, name: str) -> None: + self.reads.pop(name, None) + self.writes.pop(name, None) + + def __bool__(self) -> bool: + return bool(self.reads or self.writes) + + def __repr__(self) -> str: + return f"""MemoryEstimate( + reads={[*itertools.chain.from_iterable(self.reads.values())]!r}, + writes={[*itertools.chain.from_iterable(self.writes.values())]!r} + )""" + + +@dataclasses.dataclass +class StatsForDim: + """Memory usage stats for a block dimension in the 
generated kernel (different from user dimensions)""" + + # the number of load/store ops + count_per_thread_contiguous: int = 0 + count_per_thread_broadcast: int = 0 + count_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes in each load/store op for a single element + bytes_per_thread_contiguous: int = 0 + bytes_per_thread_broadcast: int = 0 + bytes_per_thread_non_contiguous: int = 0 # excludes broadcast + + # total bytes read by entire kernel + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForDim: + return StatsForDim( + count_per_thread_contiguous=self.count_per_thread_contiguous + + other.count_per_thread_contiguous, + count_per_thread_broadcast=self.count_per_thread_broadcast + + other.count_per_thread_broadcast, + count_per_thread_non_contiguous=self.count_per_thread_non_contiguous + + other.count_per_thread_non_contiguous, + bytes_per_thread_contiguous=self.bytes_per_thread_contiguous + + other.bytes_per_thread_contiguous, + bytes_per_thread_broadcast=self.bytes_per_thread_broadcast + + other.bytes_per_thread_broadcast, + bytes_per_thread_non_contiguous=self.bytes_per_thread_non_contiguous + + other.bytes_per_thread_non_contiguous, + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + other.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return ( + self.count_per_thread_contiguous + + self.count_per_thread_broadcast + + self.count_per_thread_non_contiguous + ) + + @property + def bytes_per_thread(self) -> int: + return ( + self.bytes_per_thread_contiguous + + self.bytes_per_thread_broadcast + + self.bytes_per_thread_non_contiguous + ) + + @property + def bytes(self) -> sympy.Expr: + return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @property + def contiguous_score(self) -> float: + return 1.0 - self.count_per_thread_non_contiguous / max( + self.count_per_thread, 1 + ) + + +@dataclasses.dataclass +class StatsForLoop: + """Memory usage stats for single loop in the generated kernel""" + + # load/store ops + count_per_thread: int = 0 + bytes_per_thread: int = 0 + + def __add__(self, other: typing.Self) -> StatsForLoop: + return StatsForLoop( + count_per_thread=self.count_per_thread + other.count_per_thread, + bytes_per_thread=self.bytes_per_thread + other.bytes_per_thread, + ) + + +@dataclasses.dataclass +class StatsForReadsOrWrites: + """Memory usage stats that are collected for reads/writes/both""" + + dim: list[StatsForDim] + loop: list[StatsForLoop] + # total bytes contiguous in any dimension + bytes_contiguous_or_broadcast: sympy.Expr = sympy.S.Zero + bytes_non_contiguous: sympy.Expr = sympy.S.Zero + + def __add__(self, other: typing.Self) -> StatsForReadsOrWrites: + assert len(self.dim) == len(other.dim) + assert len(self.loop) == len(other.loop) + return StatsForReadsOrWrites( + dim=[a + b for a, b in zip(self.dim, other.dim)], + loop=[a + b for a, b in zip(self.loop, other.loop)], + bytes_contiguous_or_broadcast=self.bytes_contiguous_or_broadcast + + self.bytes_contiguous_or_broadcast, + bytes_non_contiguous=self.bytes_non_contiguous + other.bytes_non_contiguous, + ) + + @property + def count_per_thread(self) -> int: + return self.dim[0].count_per_thread + + @property + def bytes_per_thread(self) -> int: + return self.dim[0].bytes_per_thread + + @property + def bytes(self) -> sympy.Expr: + 
return self.bytes_contiguous_or_broadcast + self.bytes_non_contiguous + + @classmethod + def compute( + cls, + loop_deps: list[dict[str, OrderedSet[MemoryDep]]], + index_symbols: list[sympy.Symbol], + ) -> typing.Self: + ndim = len(index_symbols) + result = cls(dim := [StatsForDim() for _ in range(ndim)], []) + for dep_group in loop_deps: + result.loop.append(loop_stats := StatsForLoop()) + for name, deps in dep_group.items(): + assert deps + contiguous_or_broadcast = [True] * ndim + numel = sympy.S.Zero + itemsize = V.graph.get_dtype(name).itemsize + loop_stats.count_per_thread += len(deps) + loop_stats.bytes_per_thread += itemsize * len(deps) + for dep in deps: + strides: list[sympy.Expr] = V.graph.sizevars.stride_vars( + dep.index, index_symbols + ) + for i in range(ndim): + if V.graph.sizevars.statically_known_equals(strides[i], 1): + dim[i].count_per_thread_contiguous += 1 + dim[i].bytes_per_thread_contiguous += itemsize + elif ( + V.graph.sizevars.statically_known_equals(strides[i], 0) + and not dep.is_indirect() + ): + dim[i].count_per_thread_broadcast += 1 + dim[i].bytes_per_thread_broadcast += itemsize + else: + dim[i].count_per_thread_non_contiguous += 1 + dim[i].bytes_per_thread_non_contiguous += itemsize + contiguous_or_broadcast[i] = False + numel += dep.get_numel() + if len(deps) > 1: + # can't read more elements than exist in the buffer + numel = sympy.Min(numel, V.graph.get_numel(name)) + nbytes = numel * itemsize + for i in range(ndim): + if contiguous_or_broadcast[i]: + dim[i].bytes_contiguous_or_broadcast += nbytes + else: + dim[i].bytes_non_contiguous += nbytes + if any(contiguous_or_broadcast): + result.bytes_contiguous_or_broadcast += nbytes + else: + result.bytes_non_contiguous += nbytes + if len(result.loop) > 1: + # the first loop represent the "outside of the loop" compute which could be long lived + result.loop = [result.loop[0] + x for x in result.loop[1:]] + return result + + +@dataclasses.dataclass +class StatsForKernelType: + """Memory usage stats that are collected for both persistent and looped kernels""" + + reads: StatsForReadsOrWrites + writes: StatsForReadsOrWrites + memory: StatsForReadsOrWrites + + @classmethod + def compute( + cls, loops: list[MemoryEstimate], estimator: MemoryEstimator + ) -> typing.Self: + reads = StatsForReadsOrWrites.compute( + [loop.reads for loop in loops], estimator.symbols + ) + writes = StatsForReadsOrWrites.compute( + [loop.writes for loop in loops], estimator.symbols + ) + return cls( + reads=reads, + writes=writes, + memory=reads + writes, + ) + + +@dataclasses.dataclass +class MemoryStats: + """Memory usage stats collected for each generated kernel""" + + persistent: StatsForKernelType + looped: StatsForKernelType + + def get(self, persistent: bool) -> StatsForKernelType: + return self.persistent if persistent else self.looped + + @classmethod + def compute(cls, estimator: MemoryEstimator) -> typing.Self: + persistent = StatsForKernelType.compute([estimator.persistent], estimator) + if len(estimator.loops) == 1 and not ( + estimator.outside_loop and estimator.loops[0] + ): + looped = persistent # loops/persistent is the same in this common case + else: + looped = StatsForKernelType.compute( + [estimator.outside_loop, *estimator.loops], estimator + ) + return cls( + persistent=persistent, + looped=looped, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/subgraph.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/subgraph.py new file mode 100644 index 
0000000000000000000000000000000000000000..7840744918d9d2db49ba3d1bf4346482bfd2b327 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/subgraph.py @@ -0,0 +1,208 @@ +import itertools +import logging +from typing import Any, Callable + +import torch +import torch._inductor.config as config +from torch._inductor import ir +from torch._inductor.codegen.common import KernelTemplate +from torch._inductor.ir import ( + Buffer, + get_free_symbols, + get_symbolic_inputs, + gm_original_output_strides, + ir_node_to_tensor, + Layout, +) +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.utils import do_bench_using_profiling +from torch._inductor.virtualized import V + + +log = logging.getLogger(__name__) + + +class SubgraphChoiceCaller(ir.ChoiceCaller): + """ + Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary + GraphModule. Compiles the Subgraph down to a module for benchmarking. + """ + + def __init__( + self, + name: str, + input_nodes: list[Buffer], + layout: Layout, + description: str, + make_fx_graph: Callable[..., Any], + ) -> None: + super().__init__(name, input_nodes, layout, description) + + self.example_inputs = [] + with V.fake_mode: + for inp in self.input_nodes: + # Here there will be no unbacked symbols, as SubgraphBuffer does not support them + assert len(get_free_symbols(inp.get_size(), unbacked_only=True)) == 0 + assert len(get_free_symbols(inp.get_stride(), unbacked_only=True)) == 0 + + inp.data.freeze_layout() # type: ignore[attr-defined] + self.example_inputs.append(ir_node_to_tensor(inp)) + + self.gm = make_fx_graph(*self.example_inputs) + gm_original_output_strides(self.gm) + + self.sym_inputs = get_symbolic_inputs(self.input_nodes) + + def __str__(self) -> str: + return f"SubgraphCaller({self.name})" + + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + # Codegen Subgraph for benchmarking + # Need GraphLowering instead of SubgraphLowering to generate + # fully callable module + import torch._inductor.config as inductor_config + from torch._inductor.graph import GraphLowering + + bm_graph_lowering = GraphLowering( + gm=self.gm, + example_inputs=self.example_inputs, + shape_env=V.graph._shape_env, + cpp_wrapper=V.graph.cpp_wrapper, + aot_mode=V.graph.aot_mode, + extern_node_serializer=V.graph.extern_node_serializer, + is_inference=V.graph.is_inference, + is_backward=V.graph.is_backward, + name=f"benchmark_{self.name}", + ) + + for sym_inp in self.sym_inputs: + bm_graph_lowering.graph_inputs[sym_inp.name] = sym_inp + bm_graph_lowering.graph_input_names.append(sym_inp.name) + + sym_inputs = [ + int(V.graph.sizevars.shape_env.size_hint(sym_var)) + for sym_var in self.sym_inputs + ] + + if len(sym_inputs) == 0: + # Sanity check that args are same layout as example inputs + # Only do it if there are no symbolic inputs, otherwise + # the dynamic dim will be realized to the same size as args + for ar, example_inp in zip(args, self.example_inputs): + # Sanity check that args are same layout as example inputs + if isinstance(ar, torch.Tensor): + assert isinstance(example_inp, torch.Tensor) + assert ar.shape == example_inp.shape + assert ar.stride() == example_inp.stride() + + if len(sym_inputs) == 0: + # Sanity check that args are same layout as example inputs + # Only do it if there are no symbolic inputs, otherwise + # the dynamic dim will be realized to the same size as args + for ar, example_inp in zip(args, self.example_inputs): + # Sanity check that args are same layout as example 
inputs + if isinstance(ar, torch.Tensor): + assert isinstance(example_inp, torch.Tensor) + assert ar.shape == example_inp.shape + assert ar.stride() == example_inp.stride() + + with V.set_graph_handler(bm_graph_lowering): + # Don't bother autotuning on Triton here + with inductor_config.patch( + max_autotune=False, + max_autotune_gemm=False, + max_autotune_gemm_backends="ATEN", + ): + bm_graph_lowering.run(*self.example_inputs) + mod = bm_graph_lowering.compile_to_module() + bm_func = mod.call + + bm_func([*sym_inputs, *args]) + if config.profile_bandwidth_with_do_bench_using_profiling: + return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) + return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args])) + + def hash_key(self) -> str: + return "-".join( + [ + self.name.rsplit("_", 1)[0], + *[str(inp.get_size()) for inp in self.input_nodes], + *[str(inp.get_stride()) for inp in self.input_nodes], + str(self.gm.graph), + ] + ) + + def output_node(self) -> ir.TensorBox: + return ir.TensorBox.create( + ir.SubgraphBuffer( + layout=self.layout, + input_nodes=self.input_nodes, + gm=self.gm, + example_inputs=self.example_inputs, + subgraph_name=self.name, + ) + ) + + def info_dict(self) -> dict[str, Any]: + """Information returned here is logged to the autotune log file when that is enabled.""" + return { + "backend": "subgraph", + "kernel_name": self.name, + } + + def autoheuristic_id(self) -> str: + return f"subgraph_{self.name}" + + +class SubgraphTemplate(KernelTemplate): + """ + A template for subgraph evaluation to be used in autotuning. + + This class allows creating customized subgraphs that can be appended + as choices during the autotuning process, enabling the selection of + optimal implementations for complex operations. + """ + + index_counter = itertools.count() + + def __init__( + self, + name: str, + make_fx_graph: Callable[..., Any], + ): + """ + Initialize a subgraph template. + + Args: + name: The name of this template + graph: The FX graph + """ + self.name = f"{name}_{next(SubgraphTemplate.index_counter)}" + self.make_fx_graph = make_fx_graph + + def generate( # type: ignore[override] + self, + input_nodes: list[Buffer], + layout: Layout, + **kwargs: Any, + ) -> SubgraphChoiceCaller: + """ + Generate a SubgraphChoiceCaller instance for autotuning. 
+ + Args: + input_nodes: List of input nodes to the subgraph + layout: Memory layout information for the output + example_inputs: Example tensor inputs used to trace and benchmark the subgraph + **kwargs: Additional keyword arguments + + Returns: + SubgraphChoiceCaller: A callable object that can be used for autotuning + """ + + return SubgraphChoiceCaller( + name=self.name, + input_nodes=input_nodes, + layout=layout, + description="", + make_fx_graph=self.make_fx_graph, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/triton.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..519bd3b2a030fb8aef4fd3fafda686b040d90fbc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton.py @@ -0,0 +1,4526 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import functools +import itertools +import logging +import math +import operator +import os +import textwrap +from collections.abc import Iterable, Sequence +from functools import lru_cache +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union + +import sympy +from sympy.printing.precedence import PRECEDENCE + +import torch +import torch._logging +import torch.utils._pytree as pytree +from torch._dynamo.device_interface import get_interface_for_device +from torch._dynamo.utils import identity, preserve_rng_state +from torch._prims_common import is_integer_dtype +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._triton import has_triton_package + +from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT +from ...utils._sympy.value_ranges import ValueRanges +from .. 
import config, ir, metrics +from ..async_compile import AsyncCompile +from ..codecache import code_hash, get_path, PyCodeCache, write_atomic +from ..ops_handler import DefaultHandler +from ..runtime import triton_heuristics +from ..runtime.benchmarking import benchmarker +from ..runtime.hints import ( + AutotuneHint, + DeviceProperties, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2 +from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode +from ..utils import ( + cache_on_self, + DelayReplaceLine, + get_bounds_index_expr, + get_fused_kernel_name, + get_kernel_metadata, + is_welford_reduction, + Placeholder, + prefix_is_reduction, + sympy_dot, + sympy_product, + sympy_subs, + triton_type, + triton_version_uses_attrs_dict, + upcast_compute_type, +) +from ..virtualized import _ops as ops, ReductionType, StoreMode, V +from ..wrapper_benchmark import get_kernel_category_by_source_code +from .block_analysis import BlockPatternMatcher +from .common import ( + ArgName, + BackendFeature, + ConstexprArg, + CSE, + CSEVariable, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + OpOverrides, + PythonPrinter, + RemovedArg, + SizeArg, + TensorArg, + WorkspaceArg, + WorkspaceZeroMode, +) +from .simd import ( + constant_repr, + IterationRanges, + IterationRangesEntry, + IterationRangesRoot, + SIMDKernel, + SIMDScheduling, +) +from .triton_utils import ( + config_of, + equal_1_arg_indices, + non_constexpr_signature, + should_unwrap_unspec_arg, + signature_to_meta, +) +from .wrapper import SymbolicCallArg + + +if TYPE_CHECKING: + from types import ModuleType + from typing import TypeVar + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + from ..ir import IRNode + from .simd_kernel_features import SIMDKernelFeatures + + _T = TypeVar("_T") + +log = logging.getLogger(__name__) +perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") +schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +async_compile = AsyncCompile() + + +class OpDtypeSupport: + """ + Some Triton ops such as libdevice and tl.math only support float32 and float64. + This class records which dtypes are supported by specific IR ops. + """ + + supported_dtypes: dict[str, OrderedSet[torch.dtype]] = {} + convert_outputs: dict[str, bool] = {} + + @classmethod + def register_upcast(cls, func: Callable[..., str], convert_output: bool) -> None: + op_name = func.__name__ + cls.supported_dtypes[op_name] = OrderedSet([torch.float32, torch.float64]) + cls.convert_outputs[op_name] = convert_output + + +@lru_cache(None) +def gen_attr_descriptor_import() -> str: + """ + import AttrsDescriptor if the triton version is new enough to have this + class defined. + """ + if not has_triton_package(): + return "" + + import triton.compiler.compiler + + # Note: this works because triton.compiler.compiler imports AttrsDescriptor from triton.backends.compiler + # When support for the legacy AttrsDescriptor is removed then this import path should be changed. 
+ if hasattr(triton.compiler.compiler, "AttrsDescriptor"): + return "from triton.compiler.compiler import AttrsDescriptor" + else: + return "" + + +@lru_cache(None) +def gen_common_triton_imports() -> str: + imports = IndentedBuffer() + imports.splice( + """ + import triton + import triton.language as tl + """ + ) + if attr_desc := gen_attr_descriptor_import(): + imports.writeline(attr_desc) + + imports.splice( + """ + from torch._inductor.runtime import triton_helpers, triton_heuristics + from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties + """ + ) + return imports.getvalue() + + +class TritonSymbols: + """ + Stores sympy.Symbol instances and constants associated with triton codegen. + """ + + reduction_types = OrderedSet([SymT.R0_INDEX, SymT.R1_INDEX]) + block_types = OrderedSet([SymT.XBLOCK, SymT.YBLOCK, SymT.ZBLOCK, *reduction_types]) + + block_offsets = { + symt: sympy.Symbol(f"{prefix_str[symt]}offset", integer=True, nonnegative=True) + for symt in block_types + } + + block_sizes = { + symt: sympy.Symbol( + f"{prefix_str[symt].upper()}BLOCK", integer=True, positive=True + ) + for symt in block_types + } + + @classmethod + def get_block_size(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_sizes[tree.symt] + + @classmethod + def get_block_offset(cls, tree: IterationRanges) -> sympy.Symbol: + return cls.block_offsets[tree.symt] + + +@dataclasses.dataclass +class IndexingOptions: + index_str: str + mask_vars: OrderedSet[str] + expand_str: Optional[str] + _has_rindex: bool + index: sympy.Expr + + def has_mask(self) -> bool: + return bool(self.mask_vars) + + def has_indirect(self) -> bool: + return free_symbol_is_type(self.index, SymT.TMP) + + def has_rindex(self) -> bool: + return self._has_rindex + + def has_tmpmask(self) -> bool: + return any(str(mask).startswith("tmp") for mask in self.mask_vars) + + def has_rmask(self) -> bool: + return any(str(mask).startswith("r") for mask in self.mask_vars) + + @property + def mask_str(self) -> str: + # The sorted call is added to make sure the order is still + # deterministic if self.mask_vars contains mix of string + # and TritonCSEVariable + return ( + " & ".join(sorted(map(str, self.mask_vars))) if self.mask_vars else "None" + ) + + +@dataclasses.dataclass +class BlockPtrOptions: + params: BlockParameters + constant_offset: sympy.Expr + order: list[int] + mask_vars: OrderedSet[str] + broadcast_shape: Sequence[sympy.Expr] + broadcasting_dims: list[bool] + final_shape: Sequence[sympy.Expr] + _boundary_check: Optional[list[int]] = None + + @property + def shape(self) -> list[sympy.Expr]: + return self.params.shape + + @property + def block_shape(self) -> list[sympy.Expr]: + return self.params.block_shape + + @property + def strides(self) -> list[sympy.Expr]: + return self.params.strides + + @property + def offsets(self) -> list[sympy.Expr]: + return self.params.offsets + + def codegen_broadcast_and_reshape( + self, + value: str, + initial_shape: Sequence[sympy.Expr], + final_shape: Sequence[sympy.Expr], + allow_implicit: bool, + ) -> str: + """ + Generate a broadcast and a reshape for the block pointer. + This restores stride-0 dimensions which were removed from the block pointer. + """ + + # Reshape to add singletons. 
+ pre_broadcast_shape = [ + sympy.S.One if is_broadcasting else dim + for dim, is_broadcasting in zip( + self.broadcast_shape, self.broadcasting_dims + ) + ] + value = triton_reshape(value, initial_shape, pre_broadcast_shape) + + # Broadcast singletons. + # For loads, we can often implicitly broadcast singleton dimensions. + # We need an explicit broadcast for stores, or if the final reshape does more + # than add singletons. + sizevars = V.graph.sizevars + supports_implicit_broadcast = allow_implicit and ( + len(pre_broadcast_shape) == len(final_shape) + and all( + sizevars.statically_known_equals(pre_dim, 1) + or sizevars.statically_known_equals(pre_dim, post_dim) + for pre_dim, post_dim in zip(pre_broadcast_shape, final_shape) + ) + ) + + if any(self.broadcasting_dims) and not supports_implicit_broadcast: + value = f"tl.broadcast_to({value}, {V.kernel.index_to_str(self.broadcast_shape)})" + + # Reshape to the final shape. + value = triton_reshape(value, self.broadcast_shape, final_shape) + + return value + + @staticmethod + def create( + *, + params: BlockParameters, + constant_offset: sympy.Expr, + range_trees: list[IterationRangesRoot], + mask_vars: OrderedSet[str], + get_max_block: Callable[[str], int], + ) -> BlockPtrOptions: + """Helper to create a BlockPtrOptions instance""" + + sizevars = V.graph.sizevars + + def lookup_size(exprs: Iterable[sympy.Expr]) -> list[sympy.Expr]: + return [sizevars.lookup_precomputed_size(expr) for expr in exprs] + + # Look up precomputed sizes + params.shape = lookup_size(params.shape) + params.strides = lookup_size(params.strides) + + # Strip out dimensions of stride 0. + # These will be restored with tl.broadcast_to. + broadcasting_dims = [ + sizevars.statically_known_equals(stride, 0) for stride in params.strides + ] + + # Strip out dimensions of size 1. + # These will be restored by tl.reshape. + singleton_dims = [ + sizevars.statically_known_equals(dim, 1) for dim in params.block_shape + ] + if all(singleton_dims): + # Handle a pure singletons, e.g. [1, 1] + singleton_dims[-1] = False + + # Record the post-broadcast shape before broadcasting dims are removed. + # The pre-broadcast shape is identical to this, except broadcasting dims are + # replaced with 1. + broadcast_shape = [ + dim + for dim, is_singleton in zip(params.block_shape, singleton_dims) + if not is_singleton + ] + + # Combine all removable dims. + removable_dims = [any(dims) for dims in zip(singleton_dims, broadcasting_dims)] + + def remove_dims(it): + """Removes any broadcasting or singleton dims from a given sequence""" + return [ + item + for item, is_removable in zip(it, removable_dims) + if not is_removable + ] + + # Drop removable dimensions from the input. + params = BlockParameters( + **{key: remove_dims(val) for key, val in dataclasses.asdict(params).items()} + ) + + # Compute the final shape, adjusting for special kernel types. 
+ final_shape = [TritonSymbols.get_block_size(tree) for tree in range_trees] + if V.kernel.no_x_dim: + assert range_trees[0].prefix == "x" + final_shape.pop(0) + + reduction_ndim = V.kernel.num_reduction_dims + if ( + not V.kernel.inside_reduction + and len(params.strides) == len(V.kernel.numels) - reduction_ndim + and V.kernel.features.is_reduction() + ): + # Need to expand rank to match the rank used inside the reduction loop + final_shape += [sympy.S.One] * reduction_ndim + + result = BlockPtrOptions( + params=params, + constant_offset=V.graph.sizevars.lookup_precomputed_size(constant_offset), + order=list(reversed(range(len(params.shape)))), + mask_vars=mask_vars, + final_shape=final_shape, + broadcast_shape=broadcast_shape, + broadcasting_dims=broadcasting_dims, + ) + result.compute_boundary_check(get_max_block, range_trees) + return result + + def replace_offset( + self, expr: sympy.Expr, replacement: sympy.Expr, symt: SymT + ) -> sympy.Expr: + """ + Replaces instances of {symt}_offset with the new expression. + """ + roffset = TritonSymbols.block_offsets[symt] + return sympy_subs(expr, {roffset: replacement}) + + def format(self, name: str, roffset=True) -> str: + """ + Codegen a call to tl.make_block_ptr() + + Args: + name: variable name for pointer + roffset: should rn_offset be included in offsets=..., for use with tl.advance() + + Returns: + "tl.make_block_ptr(...)" + """ + + def remove_roffsets(expr: sympy.Expr) -> sympy.Expr: + for symt in TritonSymbols.reduction_types: + expr = self.replace_offset(expr, sympy.Integer(0), symt) + return expr + + f = V.kernel.index_to_str + offsets = [*self.offsets] + if not roffset: + offsets = [remove_roffsets(offset) for offset in offsets] + args = [ + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), + f"shape={f(self.shape)}", + f"strides={f(self.strides)}", + f"block_shape={f(self.block_shape)}", + f"order={f(self.order)}", + f"offsets={f(offsets)}", + ] + return f"tl.make_block_ptr({', '.join(args)})" + + def compute_boundary_check( + self, + get_max_block: Callable[[str], int], + range_trees: list[IterationRangesRoot], + ) -> None: + """List of indices to pass to tl.load(boundary_check=...)""" + sizevars = V.graph.sizevars + + # Substitute maximum block sizes in shape expressions. + # This works in multiple_of checks because block sizes are powers of 2. + block_to_max: dict[sympy.Expr, Any] = { + TritonSymbols.block_sizes[t.symt]: get_max_block(prefix_str[t.symt]) + for t in range_trees + } + + # Also see Note: Constant mask optimisation + # if ynumel / YBLOCK > max_ygrid, then the z dimension is used to handle + # the remaining programs that cannot fit into the y dimension. This means + # it's possible that more than the required number of programs are launched, + # possibly leading to out-of-bounds accesses. So even if ynumel divides YBLOCK, + # boundary checking is required in the dimensions that are based on YBLOCK + # e.g. for [YBLOCK // 16, YBLOCK, XBLOCK] dimensions 0 and 1 need boundary + # checks when max_ygrid is exceeded. 
+ needs_overflow_grid = any(map(V.kernel.needs_yz_grid_overflow, range_trees)) + self._boundary_check = [ + idx + for idx in range(len(self.shape)) + if ( + not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero) + and ( + ( + needs_overflow_grid + and TritonSymbols.block_sizes[SymT.YBLOCK] + in self.block_shape[idx].free_symbols + ) + or ( + not sizevars.statically_known_multiple_of( + self.shape[idx], self.block_shape[idx] + ) + and not sizevars.statically_known_multiple_of( + self.shape[idx], + sympy_subs(self.block_shape[idx], block_to_max), + ) + ) + ) + and not ( + V.kernel.no_x_dim + and self.block_shape[idx] == TritonSymbols.block_sizes[SymT.XBLOCK] + ) + ) + ] + + def boundary_check(self) -> list[int]: + assert self._boundary_check is not None + return self._boundary_check + + def advance_roffset(self, symt: SymT) -> sympy.Expr: + """ + Codegen string to pass to tl.advance(name, ...). + + Advance is the difference between offsets in each loop iteration. + To compute it, we replace rN_offset with multiples of RN_BLOCK. + Since we expect rN_offset to vary in range(0, rN_numel, RN_BLOCK), the first + iteration has rN_offset=0, while the second has rN_offset=RN_BLOCK. + """ + rblock = TritonSymbols.block_sizes[symt] + advance = [ + ( + self.replace_offset(offset, rblock, symt) + - self.replace_offset(offset, sympy.S.Zero, symt) + ) + for offset in self.offsets + ] + return advance + + def has_indirect(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_rindex(self) -> bool: + return any( + free_symbol_is_type(expr, TritonSymbols.reduction_types) + for expr in self.block_shape + ) + + def has_rmask(self) -> bool: + return self.has_rindex() + + def has_tmpmask(self) -> bool: + return False # block_ptr can't do indirect indexing + + def has_mask(self) -> bool: + return bool(self.boundary_check()) + + +def triton_reshape( + value: str, old_shape: Sequence[sympy.Expr], new_shape: Sequence[sympy.Expr] +) -> str: + """Workaround https://github.com/triton-lang/triton/issues/2836""" + assert isinstance(old_shape, list) and isinstance(new_shape, list) + + old_shape_str = [V.kernel.index_to_str(shape) for shape in old_shape] + new_shape_str = [V.kernel.index_to_str(shape) for shape in new_shape] + + if old_shape_str == new_shape_str: + return value + if [s for s in new_shape_str if s != "1"] != old_shape_str: + return f"tl.reshape({value}, [{', '.join(new_shape_str)}])" + # rewrite to [:, None] syntax, which is less buggy + idx = 0 + expand = [] + for size in new_shape_str: + if idx < len(old_shape_str) and size == old_shape_str[idx]: + expand.append(":") + idx += 1 + else: + assert size == "1" + expand.append("None") + assert idx == len(old_shape_str) + return f"{value}[{', '.join(expand)}]" + + +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). 
We +# must override all of these, or it is potential silent correctness problem +class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_Float(self, expr: sympy.Expr) -> str: + if config.is_fbcode() and torch.version.hip: + ret = f"{expr}" + else: + ret = f"tl.full([], {expr}, tl.float64)" + return ret + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5) + return f"{s}.to(tl.float64)" + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.remainder_integer({quot_s}, {div_s})" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + assert expr.is_integer + quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5) + quot_s = self._print(quot) + div_s = self._print(div) + return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))" + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + if expr.args[0].is_Integer: + return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})" + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) + + def _print_Where(self, expr: sympy.Expr) -> str: + c = self.doprint(expr.args[0]) + p = self.doprint(expr.args[1]) + q = self.doprint(expr.args[2]) + return f"tl.where({c}, {p}, {q})" + + def _print_min_max_helper(self, expr: sympy.Expr, cmp: str) -> str: + """ + Helper for max/min code generation. + cmp: > or < + """ + if len(expr.args) == 1: + return self._print(expr.args[0]) + + mid = len(expr.args) // 2 + cls = type(expr) + a = self._print(cls(*expr.args[:mid])) + b = self._print(cls(*expr.args[mid:])) + + # Use a macro so we can propagate constexprs. 
+ # https://github.com/triton-lang/triton/issues/3815 + a, b = tuple(f"({x})" for x in (a, b)) + assert cmp in (">", "<"), f"Unexpected comparator: '{cmp}'" + return f"({a} * ({a} {cmp}= {b}) + {b} * ({b} {cmp} {a}))" + + def _print_Min(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, "<") + + def _print_Max(self, expr: sympy.Expr) -> str: + return self._print_min_max_helper(expr, ">") + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"tl_math.abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.cosh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.acos(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.sinh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.asin(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.tanh(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return ( + f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." 
+ ) + + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}" + + +texpr = TritonPrinter().doprint + + +def triton_compute_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type and upcast [b]float16 to float32""" + return triton_type(upcast_compute_type(dtype)) + + +def triton_store_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with fix for storing tl.bool""" + if dtype == torch.bool: + dtype = torch.int8 + return triton_type(dtype) + + +def upcast_acc_dtype(dtype: torch.dtype) -> torch.dtype: + """Implicit upcasts used for Triton reduction types""" + if is_integer_dtype(dtype) and dtype.is_signed and dtype.itemsize <= 4: + return torch.int32 + return upcast_compute_type(dtype) + + +def triton_acc_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with reduction upcasts""" + return triton_compute_type(upcast_acc_dtype(dtype)) + + +def low_precision_fp(dtype: torch.dtype) -> bool: + return dtype.itemsize <= 2 and dtype.is_floating_point + + +def low_precision_fp_var(var: Union[CSEVariable, Any]) -> bool: + if not isinstance(var, CSEVariable): + return False + + dtype = var.dtype + return low_precision_fp(dtype) if isinstance(dtype, torch.dtype) else False + + +class TritonCSEVariable(CSEVariable): + def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: + super().__init__(name, bounds, dtype) + # We'll use this to track which masks the variable needs when used for indirect indexing + self.mask_vars: OrderedSet[str] = OrderedSet() + assert dtype is not None, "TritonCSEVariable must have dtype" + + def update_on_args(self, name, args, kwargs): + for arg in args: + if isinstance(arg, TritonCSEVariable): + self.mask_vars.update(arg.mask_vars) + elif isinstance(arg, sympy.Symbol): + # most of the time index vars don't need masks associated with them + # however, when index vars are used to compute indices for indirect reads + # those reads should subsequently be masked, + for symt in TritonSymbols.block_types: + if symbol_is_type(arg, symt): + self.mask_vars.update([f"{prefix_str[symt]}mask"]) + break + + +def get_dtype_handler() -> DtypePropagationOpsHandler: + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + return DtypePropagationOpsHandler() + + +def maybe_upcast_float32(convert_output: bool = True) -> Callable[[_T], _T]: + """ + Codegen helper to upcast arguments to float32, depending on the config and dtype. + This decorates tl.math/libdevice codegen functions. + """ + + def needs_upcast(var) -> bool: + return ( + not config.triton.codegen_upcast_to_fp32 + and isinstance(var, CSEVariable) + and var.dtype in (torch.float16, torch.bfloat16) + ) + + def maybe_upcast_arg(var) -> str: + upcast_string = ".to(tl.float32)" if needs_upcast(var) else "" + return f"{var}{upcast_string}" + + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + # Record that this function only supports float32 and float64. + OpDtypeSupport.register_upcast(func, convert_output) + + def wrapped(*args, **kwargs) -> str: + # Optionally upcast args to float32. + upcast_args = [maybe_upcast_arg(arg) for arg in args] + upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()} + + # Call the decorated function, optionally downcasting the result. 
+ result = func(*upcast_args, **upcast_kwargs) + any_needs_upcast = convert_output and any( + needs_upcast(var) for var in itertools.chain(args, kwargs.values()) + ) + result_dtype = ( + None + if not any_needs_upcast + else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs) + ) + needs_downcast = result_dtype not in (torch.float32, None) + downcast_string = ( + f".to({triton_type(result_dtype)})" + if needs_downcast and result_dtype is not None + else "" + ) + return f"{result}{downcast_string}" + + return wrapped + + return decorator # type: ignore[return-value] + + +class TritonOverrides(OpOverrides): + """Map element-wise ops to Triton""" + + _LOG_2_E = math.log2(math.e) + + @staticmethod + def to_dtype( + x, + dtype: torch.dtype, + src_dtype: Optional[torch.dtype] = None, + use_compute_types=True, + ): + def _get_min_elements_per_thread( + src_dtype: torch.dtype, dst_dtype: torch.dtype + ) -> int: + if src_dtype == dst_dtype: + # No data type conversion is needed. No requirements on min_elem_per_thread. + return 0 + + # fp8 data type conversions has min_elem_per_thread requirements. + # Refer to Triton implementations here: + # https://github.com/triton-lang/triton/blob/10f59d8ce04052521c1bc0cb3a3f8b98918fc7e3/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp#L10. + fp8_dtypes = ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ) + # Triton doesn't support type conversions between fp8_e4m3 and fp8_e5m2. + assert not ( + src_dtype in fp8_dtypes + and dst_dtype in fp8_dtypes + and src_dtype != dst_dtype + ), "Conversions between float8_e5m2 and float8_e4m3fn is not supported!" + if src_dtype == torch.float8_e5m2 or dst_dtype == torch.float8_e5m2: + return 4 + if src_dtype == torch.float8_e4m3fn or dst_dtype == torch.float8_e4m3fn: + return 2 + # No requirements on min_elem_per_thread. + return 0 + + if src_dtype is not None: + # Both dtype and src_dtype are set. This is used by torch to(dtype=dtype). + # It takes the maximum min_elem_per_thread if there are multiple fp8 conversions + # in the same kernel. + V.kernel.min_elem_per_thread = max( + _get_min_elements_per_thread(src_dtype, dtype), + V.kernel.min_elem_per_thread, + ) + + if dtype == torch.bool: + return f"({x} != 0)" + elif dtype == torch.uint8: + # to work around llvm uint conversion semantics + # that produces 0's for negative values + return f"{x}.to(tl.int8).to(tl.uint8)" + + if use_compute_types: + out_dtype = triton_compute_type(dtype) + else: + out_dtype = triton_store_type(dtype) + + return f"{x}.to({out_dtype})" + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + assert src_dtype.itemsize == dtype.itemsize + # We may promote float16 or bfloat16 to float32 and cause the + # bitwidth of dtype to be different from the input tensor (i.e. float32). + # In such as case, we will have to convert the input tensor to + # its src_type, perform bitcast, and then convert the bit-casted + # tensor back to float to ensure we use values with the right precision. 
+ if x.dtype != src_dtype: + x = f"{x}.to({triton_type(src_dtype)})" + + out = f"{x}.to({triton_type(dtype)}, bitcast=True)" + if upcast_compute_type(dtype) != dtype: + out = f"{out}.to({triton_type(upcast_compute_type(dtype))})" + + return out + + @staticmethod + def _shaped_constant(value, dtype, shape): + type_ = torch._prims_common.dtype_to_type(dtype) + triton_val = constant_repr(type_(value)) + triton_type = triton_compute_type(dtype) + + if triton_type == "tl.float32": + # Float constants are always f32 in triton + return triton_val + + # NOTE: We use a tensor here in order to get the expected type. + # Otherwise, e.g. float64 constants would be truncated to float32. + if value < 0 and not dtype.is_signed: + triton_signed_type = f"tl.{triton_type[4:]}" + return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" + else: + return f"tl.full({shape}, {triton_val}, {triton_type})" + + @classmethod + def constant(cls, value, dtype): + return cls._shaped_constant(value, dtype, shape=[]) + + @staticmethod + @maybe_upcast_float32() + def abs(x): + return f"tl_math.abs({x})" + + # TODO - register these ops as having divergent dtype + # output if doing graph pass to remove consecutive casts + + @staticmethod + def truediv(x, y): + out = f"({x} / {y})" + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().truediv(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + + return out + + @staticmethod + def mod(x, y): + out = f"({x} % {y})" + if low_precision_fp_var(x) or low_precision_fp_var(y): + out_dtype = get_dtype_handler().mod(x, y) + if out_dtype in (torch.float16, torch.float32): + out = f"{out}.to({triton_type(out_dtype)})" + return out + + @staticmethod + @maybe_upcast_float32() + def exp(x): + """ + When use_fast_math, use the ftz (flushing to zero) variant + of exponent computation. + + Check https://github.com/triton-lang/triton/issues/5735 for + more details. + """ + if config.use_fast_math: + return f"libdevice.exp2({x} * {TritonOverrides._LOG_2_E})" + else: + return f"tl_math.exp({x})" + + @staticmethod + @maybe_upcast_float32() + def exp2(x): + return f"libdevice.exp2({x})" + + @staticmethod + @maybe_upcast_float32() + def expm1(x): + return f"libdevice.expm1({x})" + + @staticmethod + @maybe_upcast_float32() + def sqrt(x): + return f"libdevice.sqrt({x})" + + @staticmethod + def relu(x): + bug = config.triton.inject_relu_bug_TESTING_ONLY + if bug == "compile_error": + return "compile error!" 
+ elif bug == "runtime_error": + # NB: this only triggers runtime error as long as input + # is not all zero + return f'triton_helpers.device_assert_then({x} == 0, "injected assert fail", {x})' + elif bug == "accuracy": + return f"{x} + 1" + elif bug is None: + return ops.maximum(ops.constant(0, torch.int32), x) + else: + raise AssertionError( + f"unrecognized config triton.inject_relu_bug_TESTING_ONLY = {bug!r}" + ) + + @staticmethod + def minimum(a, b): + return f"triton_helpers.minimum({a}, {b})" + + @staticmethod + def maximum(a, b): + return f"triton_helpers.maximum({a}, {b})" + + @staticmethod + def where(a, b, c): + return f"tl.where({a}, {b}, {c})" + + @staticmethod + def inline_asm_elementwise( + *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + ): + triton_type = triton_compute_type(dtype) + input_refs = ", ".join([str(i) for i in inputs]) + if constraints is None: + constraints = ", ".join(["=r"] + ["r" for _ in inputs]) + return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 + + @staticmethod + @maybe_upcast_float32() + def cos(x): + return f"tl_math.cos({x})" + + @staticmethod + @maybe_upcast_float32() + def sin(x): + return f"tl_math.sin({x})" + + @classmethod + def index_expr(cls, expr, dtype): + raise NotImplementedError("ops.index_expr not implemented outside a kernel") + + @staticmethod + def masked(mask, body, other): + raise NotImplementedError("ops.masked not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def lgamma(x): + return f"libdevice.lgamma({x})" + + @staticmethod + @maybe_upcast_float32() + def erf(x): + return f"libdevice.erf({x})" + + @staticmethod + @maybe_upcast_float32() + def cosh(x): + return f"libdevice.cosh({x})" + + @staticmethod + @maybe_upcast_float32() + def sinh(x): + return f"libdevice.sinh({x})" + + @staticmethod + @maybe_upcast_float32() + def acos(x): + return f"libdevice.acos({x})" + + @staticmethod + @maybe_upcast_float32() + def acosh(x): + return f"libdevice.acosh({x})" + + @staticmethod + @maybe_upcast_float32() + def asin(x): + return f"libdevice.asin({x})" + + @staticmethod + @maybe_upcast_float32() + def asinh(x): + return f"libdevice.asinh({x})" + + @staticmethod + @maybe_upcast_float32() + def atan2(x, y): + return f"libdevice.atan2({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def atan(x): + return f"libdevice.atan({x})" + + @staticmethod + @maybe_upcast_float32() + def atanh(x): + return f"libdevice.atanh({x})" + + @staticmethod + @maybe_upcast_float32() + def copysign(x, y): + return f"libdevice.copysign({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def erfc(x): + return f"libdevice.erfc({x})" + + @staticmethod + @maybe_upcast_float32() + def erfinv(x): + return f"libdevice.erfinv({x})" + + @staticmethod + @maybe_upcast_float32() + def hypot(x, y): + return f"libdevice.hypot({x}, {y})" + + @staticmethod + @maybe_upcast_float32() + def log10(x): + return f"libdevice.log10({x})" + + @staticmethod + @maybe_upcast_float32() + def log2(x): + return f"libdevice.log2({x})" + + @staticmethod + @maybe_upcast_float32() + def nextafter(x, y): + return f"libdevice.nextafter({x}, {y})" + + @staticmethod + def logical_and(a, b): + return f"{a} & {b}" + + @staticmethod + def logical_not(a): + return f"{a} == 0" + + @staticmethod + def logical_or(a, b): + return f"{a} | {b}" + + @staticmethod + def logical_xor(a, b): + return f"({a} ^ {b})" + + @staticmethod + def bitwise_and(a, 
b): + return f"{a} & {b}" + + @staticmethod + def bitwise_not(a): + return f"~{a}" + + @staticmethod + def bitwise_or(a, b): + return f"{a} | {b}" + + @staticmethod + def bitwise_xor(a, b): + return f"{a} ^ {b}" + + @staticmethod + def bitwise_left_shift(a, b): + return f"{a} << {b}" + + @staticmethod + def bitwise_right_shift(a, b): + return f"{a} >> {b}" + + @staticmethod + def rand(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.rand({seed}, {offset})" + + @staticmethod + def randn(seed, offset): + offset = f"({offset}).to(tl.uint32)" + return f"tl.randn({seed}, {offset})" + + @staticmethod + def randint64(seed, offset, low, high): + offset = f"({offset}).to(tl.uint32)" + return f"triton_helpers.randint64({seed}, {offset}, {low}, {high})" + + @staticmethod + def load_seed(name, offset): + raise NotImplementedError("ops.load_seed not implemented outside a kernel") + + @staticmethod + @maybe_upcast_float32() + def rsqrt(x): + return f"libdevice.rsqrt({x})" + + @staticmethod + @maybe_upcast_float32() + def log1p(x): + return f"libdevice.log1p({x})" + + @staticmethod + @maybe_upcast_float32() + def tan(x): + return f"libdevice.tan({x})" + + @staticmethod + @maybe_upcast_float32() + def tanh(x): + return f"libdevice.tanh({x})" + + @staticmethod + @maybe_upcast_float32() + def sigmoid(x): + return f"tl.sigmoid({x})" + + @staticmethod + def signbit(x): + # XX: This is wrong for the value -0.0 in floating point + return ( + f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0" + ) + + @staticmethod + @maybe_upcast_float32() + def fmod(a, b): + return f"libdevice.fmod({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def pow(a, b): + return f"libdevice.pow({a}, {b})" + + @staticmethod + @maybe_upcast_float32() + def log(x): + return f"tl_math.log({x})" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isinf(x): + return f"libdevice.isinf({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32(convert_output=False) + def isnan(x): + return f"libdevice.isnan({x}).to(tl.int1)" + + @staticmethod + @maybe_upcast_float32() + def round(x): + return f"libdevice.nearbyint({x})" + + @staticmethod + @maybe_upcast_float32() + def floor(x): + return f"libdevice.floor({x})" + + @staticmethod + def floordiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. + # Similar to div_floor_kernel_cuda in pytorch core. + # Notice that // in triton behaves as truncdiv instead of floordiv + quot = f"{a} // {b}" + rem = f"{a} % {b}" + return f"tl.where(({a} < 0) != ({b} < 0), tl.where({rem} != 0, {quot} - 1, {quot}), {quot})" + + @staticmethod + def sign(x): + z = ops.constant(0, torch.int32) + left = ops.to_dtype((ops.lt(z, x)), torch.int8) + right = ops.to_dtype((ops.lt(x, z)), torch.int8) + sub = ops.sub(left, right) + return f"{sub}.to({x}.dtype)" + + @staticmethod + @maybe_upcast_float32() + def trunc(x): + return f"libdevice.trunc({x})" + + @staticmethod + def truncdiv(a, b): + # See the comment in lowering.div_mode. a and b are integer type. 
+ # Notice that // in triton behaves as truncdiv instead of floordiv + return f"{a} // {b}" + + @staticmethod + @maybe_upcast_float32() + def ceil(x): + return f"libdevice.ceil({x})" + + +TritonOverrides._initialize_pointwise_overrides("triton") + + +class TritonKernelOverrides(TritonOverrides): + """Map element-wise ops to Triton within a TritonKernel + + Unlike TritonOverrides, these assume the code is going to be inserted into + the body of the main triton kernel and so it may use indexing and mask + variables which are assumed to already be defined in the current scope. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # happens in __init__ unlike _initialize_pointwise_overrides + # because the libdevice registrations are populated during lowerings + self._setup_libdevice_routing() + + @classmethod + @functools.cache + def _setup_libdevice_routing(cls): + """Set up routing to libdevice implementations for fp64 inputs.""" + + from torch._inductor.codegen.common import OpDecompositions + + for fn_name in torch._inductor.utils.op_requires_libdevice_fp64: + assert hasattr(cls, fn_name) + original_impl = getattr(cls, fn_name) + + def decomposition_router(x, _original_impl, _fn_name): + if x.dtype != torch.float64: + return _original_impl(x) + else: + return getattr(OpDecompositions, _fn_name)(x).value + + if fn_name == "sigmoid": + assert hasattr(OpDecompositions, "sigmoid") + fn = functools.partial( + decomposition_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + continue + + def dtype_router(x, _original_impl, _fn_name): + if x.dtype == torch.float64: + return f"libdevice.{_fn_name}({x})" + else: + return _original_impl(x) + + fn = functools.partial( + dtype_router, _original_impl=original_impl, _fn_name=fn_name + ) + fn.__name__ = fn_name # type: ignore[attr-defined] + setattr(cls, fn_name, staticmethod(fn)) + + @classmethod + def constant(cls, value, dtype): + # NOTE: Cannot use shape=[] as it's not supported by triton-rocm + # We could use shape=[1] instead but starting with the correct + # ndim avoids extra `tt.expand_dim` ops appearing in the triton IR. + ndim = V.kernel.triton_tensor_ndim() + shape = [1] * ndim + return cls._shaped_constant(value, dtype, shape=shape) + + @classmethod + def index_expr(cls, expr, dtype): + indexing = V.kernel.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + + # Our sympy expr printing casts to the current kernel index dtype. + # we only respect non int32-int64 dtypes and otherwise use current kernel indexing dtype + index_dtype = V.kernel.get_index_dtype_as_torch_dtype() + dtype = dtype if dtype not in (torch.int32, torch.int64) else index_dtype + + # after we emit this var we cast it to the correct dtype + orig = config.test_configs.runtime_triton_dtype_assert + try: + config.test_configs.runtime_triton_dtype_assert = False + var = V.kernel.cse.generate( + V.kernel.compute, + indexing.index_str, + bounds=get_bounds_index_expr(expr), + dtype=dtype, + ) + finally: + config.test_configs.runtime_triton_dtype_assert = orig + + if dtype not in (torch.int32, torch.int64): + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, dtype), + dtype=upcast_compute_type(dtype), + ) + else: + # TODO: we are not always consistent in enforcing that the output of the index expr printing + # results in the indexing dtype. 
So if we detect that we have an input which might type promote + # to a dtype other than indexing dtype, add a cast. + # Trying to avoid + dtype = index_dtype + for index_var in expr.free_symbols: + if symbol_is_type(index_var, SymT.TMP): + dtype = torch.promote_types( + dtype, V.kernel.cse.varname_map[index_var.name].dtype + ) + + if dtype != index_dtype: + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, index_dtype), + dtype=index_dtype, + ) + + var.mask_vars = indexing.mask_vars + return var + + @staticmethod + def masked(mask, body, other): + if mask is not None and torch.version.hip is not None: + mask = V.kernel.cse.generate( + V.kernel.compute, + f"{mask}.to(tl.int1)", + dtype=torch.bool, + ) + + nodes = body.graph.find_nodes(op="output") + assert nodes, "graph for body does not contain an output" + + need_where = False + # If we have a tl.load with a masking operator and no other value + # we can add the mask here and the other value to the tl.load + # operator to save the branching cost. + for node in nodes: + for arg in node.args: + if arg.target != "load" or should_unwrap_unspec_arg(arg.args[1]): + need_where = True + break + + value = None if need_where else other + + with V.kernel.mask_loads(mask, value=value) as new_mask: + result = body() + + if need_where: + # Remove once CSEVariables track the dtype + if result.bounds.is_bool: + other = bool(other) + # Take dtype from result to prevent accidental promotion + other = V.kernel.cse.generate( + V.kernel.compute, + f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", + bounds=ValueRanges.wrap(other), + dtype=result.dtype, + ) + ret = ops.where(new_mask, result, other) + else: + ret = result + + ret.mask_vars.discard(new_mask) + return ret + + @staticmethod + def load_seed(name, offset): + var = V.kernel.args.input(name) + return ( + f"tl.load({var} + {V.kernel.args.seed_offset('load_seed_offset', offset)})" + ) + + @staticmethod + def frexp(x): + cache_key = f"frexp({x})" + if cse_val := V.kernel.cse.try_get(cache_key): + return cse_val + + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent = V.kernel.cse.newvar(dtype=torch.int32) + V.kernel.compute.writeline( + f"{mantissa}, {exponent} = triton_helpers.frexp({x})" + ) + V.kernel.cse.put(cache_key, (mantissa, exponent)) + return (mantissa, exponent) + + +class HelperFunctions: + """An ordered set of helper functions.""" + + _templates_seen: dict[str, str] # Template code to function name + finalized_helpers: list[str] + + def __init__(self) -> None: + self._templates_seen = {} + self.finalized_helpers = [] + + def add(self, template_code: str, *, base_name="_triton_helper_fn") -> str: + """This accepts a function definition with the function name + left as a format specifier e.g. + + @triton.jit + def {name}(arg0, arg1): + return arg0 + arg1 + + We add the templated code to the function set and return the name + assigned to that function. + + """ + existing_name = self._templates_seen.get(template_code) + if existing_name is not None: + # Don't duplicate existing helpers + return existing_name + + name = f"{base_name}{len(self.finalized_helpers)}" + self._templates_seen[template_code] = name + self.finalized_helpers.append(template_code.format(name=name)) + return name + + def __iter__(self): + return iter(self.finalized_helpers) + + def __getitem__(self, idx): + return self.finalized_helpers[idx] + + +@dataclasses.dataclass +class BlockParameters: + """ + Class representing ND block dimensions, for block pointer analysis. 
+ """ + + shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + block_shape: list[sympy.Expr] = dataclasses.field(default_factory=list) + strides: list[sympy.Expr] = dataclasses.field(default_factory=list) + offsets: list[sympy.Expr] = dataclasses.field(default_factory=list) + + def __add__(self, other: BlockParameters) -> BlockParameters: + """ + Concatenates block parameters. + """ + cls = type(self) + a, b = tuple(dataclasses.asdict(x) for x in (self, other)) + return cls(**{key: a[key] + b[key] for key in a}) + + +class CooperativeReductionWorkspaceCache: + """ + The scratch space used for cooperative reductions can be reused + after two reduction loops. This keeps track of what can be reused. + """ + + def __init__(self, args): + self.args = args + self.current_loop = [] + self.prior_loop = [] + self.ready_for_reuse = collections.defaultdict(collections.deque) + self.loop_count = 0 + self.store_count = 0 + + def allocate(self, nbytes: sympy.Expr): + cached = self.ready_for_reuse.get(nbytes) + if cached: + return cached.popleft() + ws_name, ws_offset = self.args.workspace(nbytes, False) + self.current_loop.append((nbytes, ws_name, ws_offset)) + return (ws_name, ws_offset) + + def on_loop_end(self): + # Buffers can be reused after 2 loop ends + for nbytes, ws_name, ws_offset in self.prior_loop: + self.ready_for_reuse[nbytes].append((ws_name, ws_offset)) + self.prior_loop = self.current_loop + self.current_loop = [] + self.loop_count += 1 + + def increment_store_count(self): + prior = self.store_count + self.store_count += 1 + return prior + + +@dataclasses.dataclass +class FixedTritonConfig: + config: dict[str, int] + + def __getitem__(self, item): + return self.config[item] + + def __contains__(self, item): + return item in self.config + + +class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]): + """ + Subclasses CSE to apply the current load mask to the cache key to avoid CSEing + variables across separate masked blocks. 
+ """ + + def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]: + if mask := V.kernel._load_mask: + return (cache_key, mask.name) + else: + return cache_key + + +class TritonKernel(SIMDKernel[TritonCSEVariable]): + overrides = TritonKernelOverrides # type: ignore[assignment] + helper_functions: HelperFunctions + kexpr: Callable[[sympy.Expr], str] = texpr + allow_block_ptr = True + + def __init__( + self, + tiling: dict[str, sympy.Expr], + min_elem_per_thread=0, + optimize_mask=True, + fixed_config: Optional[FixedTritonConfig] = None, + **kwargs, + ) -> None: + self.optimize_mask: bool = optimize_mask + self.fixed_config = fixed_config + super().__init__(tiling, **kwargs) + self.cse = TritonCSE(self.newvar_prefix, self.suffix) + self.post_loop_combine: IndentedBuffer = IndentedBuffer() + self.post_loop_store: IndentedBuffer = IndentedBuffer() + self.outside_loop_vars = OrderedSet[Any]() + self.min_elem_per_thread = min_elem_per_thread + self.block_ptr_id = itertools.count() + self.block_ptr_to_buffer = dict[str, str]() + self.helper_functions = HelperFunctions() + self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = ( + collections.defaultdict(dict) + ) + self._load_counts: collections.Counter[str] = collections.Counter() + + # A set of autotuning hints to pass as part of triton_meta + self.autotune_hints = OrderedSet[AutotuneHint]() + self.triton_meta: Optional[dict[str, Any]] = None + + if self.inside_reduction: + self.codegen_reduction_numels(self.body) + + if self.cooperative_reduction: + self.init_cooperative_reduction() + + self.codegen_range_tree() + + if self.cooperative_reduction: + self.init_cooperative_reduction_mask() + + def dtype_to_str(self, dtype: torch.dtype) -> str: + return triton_type(dtype) + + def should_use_cooperative_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_cooperative_reduction( + self.features + ) + + def init_cooperative_reduction(self): + """One time setup code for cooperative reductions.""" + assert self.cooperative_reduction + + # shift all the grids over since tl.program_id(0) is for rsplit + for tree in self.range_trees: + if tree.grid_dim is not None: + tree.grid_dim += 1 + + sem_count = self.numels["x"] + if self.fixed_config: + sem_count = CeilDiv(sem_count, self.fixed_config["XBLOCK"]) + self.semaphores_name = self.args.semaphores(sem_count) + self.cooperative_reduction_workspace_cache = CooperativeReductionWorkspaceCache( + self.args + ) + self.body.splice( + """\ + RSPLIT_NEXT_POWER_OF_2: tl.constexpr = triton_helpers.constexpr_next_power_of_2(RSPLIT) + RSPLIT_IS_POWER_OF_2: tl.constexpr = RSPLIT == RSPLIT_NEXT_POWER_OF_2 + HAS_RSPLIT: tl.constexpr = RSPLIT > 1 + rsplit_id = tl.program_id(0) + num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK + rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK + rsplit_start = rsplit_chunk * rsplit_id + rsplit_end = rsplit_chunk * (rsplit_id + 1) + """, + ) + if any( + not self._has_constant_mask(tree) + for tree in self.range_trees + if tree.is_reduction + ): + self.body.writeline( + "rsplit_end = tl.where(rsplit_end < rnumel, rsplit_end, rnumel)" + ) + + def init_cooperative_reduction_mask(self): + rsplit_arange = "tl.arange(0, RSPLIT_NEXT_POWER_OF_2)" + if not self.no_x_dim: + rsplit_arange = f"{rsplit_arange}[None, :]" + self.body.writeline(f"rsplit_arange = {rsplit_arange}") + + if self._has_constant_xmask(): + self.body.splice( + """\ + if RSPLIT_IS_POWER_OF_2: + rsplit_mask: tl.constexpr = None + else: + rsplit_mask = rsplit_arange < 
RSPLIT + """ + ) + else: + assert not self.no_x_dim + self.body.writeline( + "rsplit_mask = xmask if RSPLIT_IS_POWER_OF_2 else ((rsplit_arange < RSPLIT) & xmask)" + ) + + def codegen_range_tree(self): + for tree in self.range_trees: + # reduction indexing goes inside a loop + if not tree.is_loop: + self.iteration_ranges_codegen_header(tree, self.body) + elif self.inside_reduction: + # workaround for this issue: + # https://gist.github.com/jansel/6527126f781559095c5531f98a4235a7 + self.body.writeline( + f"{tree.prefix}base = {self.iteration_ranges_ranges_code(tree)}" + ) + + if self.inside_reduction: + if any(tree.is_loop for tree in self.range_trees): + # If the kernel contains loops, compute rbase. + rn_bases = self._get_reduction_symbols( + "base", integer=True, nonnegative=True + ) + rbase = self._flatten_reduction_indices(rn_bases) + self.body.splice(f"rbase = {self.index_to_str(rbase)}") + else: + # For looped reductions, indexing is deferred to the innermost loop. + self.codegen_reduction_indices(self.body) + + def need_numel_args(self): + """ + Indicate whether we need provide numel as arguments for the generated + kernel calls in the benchmark. + + Should be true for pointwise/reduction kernels but false for triton + matmul kernels. + """ + return True + + def should_use_persistent_reduction(self) -> bool: + return self.inside_reduction and V.choices.should_use_persistent_reduction( + self.features, self.cooperative_reduction + ) + + def want_no_x_dim(self): + if ( + self.persistent_reduction + and len(self.numels) == self.num_reduction_dims + 1 + ): + if self.fixed_config: + return self.fixed_config["XBLOCK"] == 1 + return V.choices.want_no_x_dim(self.features) + return False + + @property + def assert_function(self) -> str: + return "tl.device_assert" + + def indexing( + self, + index: sympy.Expr, + *, + copy_shape=None, + dense_indexing=False, + override_mask=None, + block_ptr=False, + ): + """ + Compute the index and mask to pass to tl.load() or tl.store() + """ + index = self.prepare_indexing(index) + index_vars = index.free_symbols + has_rindex = False + + mask_vars: OrderedSet[str] = OrderedSet() + for var in sorted(index_vars, key=operator.attrgetter("name")): + assert isinstance(var, sympy.Symbol) + has_rindex = has_rindex or symbol_is_type( + var, TritonSymbols.reduction_types + ) + if override_mask: + pass + elif symbol_is_type(var, SymT.TMP): + # indirect indexing + cse_var = self.cse.varname_map[var.name] + mask_vars.update(cse_var.mask_vars) + elif symbol_is_type( + var, + ( + SymT.UNBACKED_INT, + SymT.SIZE, + SymT.PRECOMPUTED_SIZE, + SymT.INDEX, + SymT.FLOAT, + SymT.UNBACKED_FLOAT, + ), + ): + pass + else: + # var is one of xN, yN, r0_N or r1_N + prefix_matches = [ + prefix_str[symt] + for symt in TritonSymbols.block_types + if symbol_is_type(var, symt) + ] + assert len(prefix_matches) == 1, f"Ambiguous type: {var.name}" + mask_vars.add(f"{prefix_matches[0]}mask") + + need_dense = ( + config.triton.dense_indexing + or dense_indexing + or self._load_mask is not None + ) and index != 0 + + have_dense = True + have_loop_vars = False + dense_mask_vars: OrderedSet[str] = OrderedSet() + + for tree in self.active_range_trees(): + if index_vars.intersection(tree.var_list): + have_loop_vars = True + else: + have_dense = False + dense_mask_vars.add(f"{tree.prefix}mask") + + if ( + block_ptr + and self.allow_block_ptr + and config.triton.use_block_ptr + and not override_mask + and not self._load_mask + and len(mask_vars - dense_mask_vars) == 0 + and not 
self.is_indirect_indexing(index) + and have_loop_vars + # workaround https://github.com/triton-lang/triton/issues/2821 + and self.index_dtype == "tl.int32" + ): + + def match_affine_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches expressions of the form: + idx = s * xindex + + This implies stride (s,), and shape (XBLOCK,). + """ + stride = BlockPatternMatcher.match_affine_block_expr( + index, range_tree.symbol() + ) + if stride is None: + return None + + return BlockParameters( + shape=[range_tree.numel], + block_shape=[TritonSymbols.get_block_size(range_tree)], + strides=[stride], + offsets=[TritonSymbols.get_block_offset(range_tree)], + ) + + def match_mod_div_block( + index: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Matches higher-dimensional blocks coming from FloorDiv and ModularIndexing. + + Example expression to match: + sN * ((rindex//(d1 * ... * d(N-1)))) + + s1 * ModularIndexing(rindex, 1, d1) + + ... + + s(N-1) * ModularIndexing(rindex, d1 * ... * d(N-2), d(N-1)) + + This iterates over a block of shape (dN, ..., d1) and stride + (sN, ..., s1). (d1,...,d(N-1)) and (s1,...,sN) are + wildcards that we match. + + Note that dN does not appear in the expression, but we solve for it + using range tree numels and the other dims. + """ + + index_var = range_tree.symbol() + + # Bound the possible number of dims. We use the following heuristics: + # - At least one dim for each range tree node. + # - At least one dim for every FloorDiv or ModularIndexing op. + # - At least 2 dims to pattern match. + denom, modulo = sympy.symbols( + "denom modulo", + cls=functools.partial(sympy.Wild, exclude=[index_var]), + ) + num_dims = max( + 2, + len(self.range_tree_nodes), + ( + index.count(FloorDiv(index_var, denom)) + + index.count(ModularIndexing(index_var, denom, modulo)) + ), + ) + + match_result = BlockPatternMatcher.match_mod_div_block_expr( + index, index_var, range_tree.numel, num_dims + ) + if match_result is None: + return None + + ( + dims, + strides, + block_index_exprs, + ) = match_result + slice_numels = BlockPatternMatcher.get_slice_numels(dims) + + # Check for applicable iteration range sizes. + # When mapping a 1D block into an ND one, we need to know that + # the number of elements is not changed. This means the slice numels of + # the ND iteration range must evenly divide the length of the 1D block. + # There are two cases where we can guarantee this: + # 1. Numels are powers of 2. If numel == 2 ** n, and we know XBLOCK == 2 ** m, + # with n and m integers, then either numel is a multiple of XBLOCK, or numel + # is less than XBLOCK. (If numel is less than XBLOCK, we round up to 1 below.) + # 2. Numels are multiples of the maximum possible block size. + sizevars = V.graph.sizevars + max_block = self.max_block(range_tree.prefix) + if any( + not sizevars.statically_known_multiple_of(numel, max_block) + and not sizevars.statically_known_power_of_2(numel) + for numel in slice_numels + ): + return None + + # Compute the ND block shape from the linear block size. + # Use CielDiv to round leading dimensions up to 1. + # Non-leading dimensions are clamped to the size of the iteration range, + # while the leading dimension can exceed this to accommodate a larger + # block size. 
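# Editor's note: an illustrative sketch (not part of the original kernel code) of the
# indexing identity that match_mod_div_block inverts, using hypothetical sizes/strides.
# A flat index r over a 3-D view with sizes (d3, d2, d1), d1 innermost, and strides
# (s3, s2, s1) addresses offset s1*(r % d1) + s2*((r // d1) % d2) + s3*(r // (d1*d2)),
# i.e. exactly the s*FloorDiv(...) + s*ModularIndexing(...) form described above.
def _example_mod_div_decomposition(r: int, d1: int = 4, d2: int = 8,
                                   s1: int = 1, s2: int = 4, s3: int = 32) -> int:
    i1 = r % d1            # innermost coordinate
    i2 = (r // d1) % d2    # middle coordinate
    i3 = r // (d1 * d2)    # outermost coordinate (dN above; absent from the expression)
    return s1 * i1 + s2 * i2 + s3 * i3  # strided offset; equals r for the contiguous defaults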
+ linear_block_size = TritonSymbols.get_block_size(range_tree) + block_shape: list[sympy.Expr] = [ + CeilDiv(linear_block_size, slice_numels[0]) + ] + [ + sympy.Min(CeilDiv(linear_block_size, numel), dim) + for numel, dim in zip(slice_numels[1:], dims[1:]) + ] + + # Compute block offsets from {xyzr}offset and the matched expressions. + block_offsets: list[sympy.Expr] = [ + sympy_subs( + expr, {index_var: TritonSymbols.get_block_offset(range_tree)} + ) + for expr in block_index_exprs + ] + + return BlockParameters( + shape=dims, + block_shape=block_shape, + strides=strides, + offsets=block_offsets, + ) + + def match_block_pointer_subexpr( + expr: sympy.Expr, range_tree: IterationRangesRoot + ) -> Optional[BlockParameters]: + """ + Match a block indexing subexpression involving a single range tree. + """ + for match_func in ( + match_affine_block, + match_mod_div_block, + ): + match = match_func(expr, range_tree) + if match is not None: + return match + + return None + + def match_block_pointer() -> Optional[BlockPtrOptions]: + index_relative_to_xyr_index = sympy_subs( + index, {v: t.expr for v, t in self.range_tree_nodes.items()} + ) + range_trees = self.active_range_trees() + + # Partition the index into subexpressions pertaining to each range tree. + # For example xindex * 5 + r0_index * 3 is partitioned to + # (xindex * 5, r0_index * 3). + index_subexprs = [ + BlockPatternMatcher.get_subexpr_involving_symbol( + index_relative_to_xyr_index, tree.symbol() + ) + for tree in range_trees + ] + + # Match each range tree's subexpression separately. + range_symbols = OrderedSet(tree.symbol() for tree in range_trees) + block_params = BlockParameters() + for tree, subexpr in zip(range_trees, index_subexprs): + # Reject mixed terms, e.g. xindex * r0_index. + # NB: the zero expression is allowed, for broadcasting. + if len(range_symbols.intersection(subexpr.free_symbols)) > 1: + return None + + # Match the subexpression for this range tree. + params = match_block_pointer_subexpr(subexpr, tree) + if params is None: + return None + block_params += params + + # Collect leftover terms as a constant offset. + offset = index_relative_to_xyr_index - sum(index_subexprs) + + # Form the block pointer. + self.filter_masks(mask_vars) + return BlockPtrOptions.create( + params=block_params, + constant_offset=offset, + range_trees=range_trees, + mask_vars=mask_vars, + get_max_block=self.max_block, + ) + + # Return a block pointer, if indexing matches the pattern. 
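# Editor's note: a hypothetical sketch (names invented, not emitted verbatim here) of the
# kind of block pointer a successful match yields in the generated Triton source. For an
# index like xindex * 5 + r0_index * 3 over an input in_ptr0, it is roughly:
#   tl.make_block_ptr(in_ptr0, shape=[xnumel, r0_numel], strides=[5, 3],
#                     offsets=[xoffset, r0_offset], block_shape=[XBLOCK, R0_BLOCK],
#                     order=[1, 0])
# The exact text, including the order, is produced by BlockPtrOptions.create()/format().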
+ options = match_block_pointer() + if options is not None: + return options + + expand_str = None + index_str = self.index_to_str(index) + if isinstance(index, sympy.Integer): + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" + if self.fixed_config and not self._has_constant_xmask(): + mask_vars = OrderedSet(["xmask"]) + else: + mask_vars = OrderedSet() + if self._load_mask: + mask_vars.add(self._load_mask) + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + if need_dense and not have_dense: + expand_str = f"{copy_shape}.shape" if copy_shape else self.dense_size_str() + index_str = f"tl.broadcast_to({index_str}, {expand_str})" + mask_vars = dense_mask_vars + elif not have_loop_vars and copy_shape: + index_str = f"tl.broadcast_to({index_str}, {copy_shape}.shape)" + mask_vars = dense_mask_vars + + if override_mask: + mask_vars = OrderedSet([override_mask]) + + if self._load_mask: + mask_vars.add(self._load_mask) + + self.filter_masks(mask_vars) + + return IndexingOptions(index_str, mask_vars, expand_str, has_rindex, index) + + def codegen_block_ptr( + self, name: str, var: str, indexing: BlockPtrOptions, other="" + ) -> tuple[str, str]: + check = indexing.boundary_check() + if not check: + # workaround https://github.com/triton-lang/triton/issues/2813 + other = "" + elif other: + assert other == ", other=0.0" + other = f", boundary_check={check!r}, padding_option='zero'" + else: + other = f", boundary_check={check!r}" + if ( + self.inside_reduction + and self.range_trees[-1].is_loop + and indexing.has_rindex() + ): + block_ptr = f"block_ptr{next(self.block_ptr_id)}" + self.body.writeline( + DeferredLine( + name, f"{block_ptr} = {indexing.format(var, roffset=False)}" + ) + ) + # Store for later use. If the buffer is removed the below advancements + # are no longer necessary + self.block_ptr_to_buffer[block_ptr] = name + + # Generate block pointer advancements, for later use. + for symt in TritonSymbols.reduction_types: + advance_offsets = indexing.advance_roffset(symt) + + # Ignore identity advancements. + if all( + V.graph.sizevars.statically_known_equals(offset, sympy.Integer(0)) + for offset in advance_offsets + ): + continue + + advancements = self.pointer_advancements[symt] + assert block_ptr not in advancements, ( + "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" + ) + advancements[block_ptr] = advance_offsets + else: + block_ptr = indexing.format(var) + return block_ptr, other + + def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=""): + # Stores require an explicit broadcast. We do this in two phases: + # 1. Broadcast the operand to the final shape of the range trees, e.g. [ZBLOCK, + # YBLOCK, XBLOCK]. This protects against implicit broadcasting from loads. + # 2. In case the block pointer has different dimensionality, broadcast/reshape the + # result to the shape of the pointer. + value = f"tl.broadcast_to({value}, {indexing.final_shape})" + + # These dims no longer need broadcasting. 
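# Editor's note: an illustrative (hypothetical) example of the store line this method
# ultimately returns for a float16 output, after the broadcast/reshape steps below:
#   tl.store(block_ptr0, tl.broadcast_to(tmp7, [XBLOCK, R0_BLOCK]).to(tl.float16),
#            boundary_check=[0, 1])
# Variable names and the boundary_check list are made up for the example.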
+ for idx, (dim, broadcast_dim) in enumerate( + zip(indexing.final_shape, indexing.broadcast_shape) + ): + if V.graph.sizevars.statically_known_equals(dim, broadcast_dim): + indexing.broadcasting_dims[idx] = False + + value = indexing.codegen_broadcast_and_reshape( + value, indexing.final_shape, indexing.block_shape, False + ) + + # workaround https://github.com/triton-lang/triton/issues/2814 + value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + return f"tl.store({block_ptr}, {value}{other})" + + def check_bounds( + self, + expr: sympy.Expr, + size: sympy.Expr, + lower: bool, + upper: bool, + ): + if not (lower or upper): + return + + assert isinstance(expr, sympy.Expr) + indexing = self.indexing(expr, block_ptr=False) + assert isinstance(indexing, IndexingOptions) + + index_str = indexing.index_str + mask_str = indexing.mask_str if indexing.has_mask() else None + size_str = texpr(self.rename_indexing(size)) if upper else None + + # expr is already wrapped + line = self.indirect_assert( + index_str, "0" if lower else None, size_str, mask_str + ) + + buffer = self.get_load_buffer(indexing) + self.cse.generate(buffer, line, assignment=False, dtype=torch.int32) + + def get_load_buffer(self, indexing): + if indexing.has_indirect() or indexing.has_tmpmask(): + # Masked loads must come after the mask is computed + return self.compute + elif ( + self.inside_reduction + and self.range_trees[-1].is_loop + and not indexing.has_rindex() + ): + # can lift a common load outside of reduction loop + # One exception is when this is an indirect_load. + return self.body + else: + return self.loads + + def load(self, name: str, index: sympy.Expr): + var = self.args.input(name) + load_counts = self._load_counts + load_counts[name] += 1 + make_line: Callable[[str], Union[str, DelayReplaceLine]] = identity + indirect_indexing = self.is_indirect_indexing(index) + original_index = index + indexing = self.indexing(index, block_ptr=True) + has_rindex = indexing.has_rindex() + has_tmpmask = indexing.has_tmpmask() + + # Keep the variable in cache if were going to reuse it. Equiv., if any of the following hold + # 1) We are doing broadcasting + # 2) It is a non-coalesced load. The intuition is that if it's + # non-coalesced, we will likely load each element multiple times in + # practice. + # 3) It will be used later and it won't be CSE'd. 
Equiv., if all the following hold + # 3.1) We are in a reduction loop + # 3.2) Its not its last use + # 3.3) This load will not be lifted to the body + # + is_coalesced = any( + i == 1 for i in self.get_strides_of_load(original_index).values() + ) + if self.is_broadcasted(original_index): + ep = ", eviction_policy='evict_last'" + elif not is_coalesced: + ep = ", eviction_policy='evict_last'" + elif self.inside_reduction and self.range_trees[-1].is_loop: + + def decide_later(): + if load_counts[name] > expected_count and ( + has_rindex or indirect_indexing + ): + return "evict_last" + return "evict_first" + + expected_count = load_counts[name] + ep = ", eviction_policy=''" + make_line = functools.partial(DelayReplaceLine, "", decide_later) + else: + ep = "" + + if (has_tmpmask or has_rindex) and indexing.has_mask(): + if self._load_other: + other = f", other={constant_repr(self._load_other)}" + else: + other = ", other=0.0" + else: + other = "" + + """Check if the buffer we're about to load, has + more than one read dependency + NOTE: enabled with env variable TORCHINDUCTOR_SKIP_L1 + """ + has_read_deps = True + if config.triton.skip_l1_cache: + buffer_read_counts = self.features.buffer_read_counts() + has_read_deps = buffer_read_counts[name] > 1 + """Skip L1 cache if we're (pretty?) sure the data is used only once + """ + skip_l1_cache = ( + not self.is_broadcasted(original_index) + and not self.inside_reduction + and not has_read_deps + and is_coalesced # for indirect loads is_coalesced is False? + ) + cachemod = "" + if skip_l1_cache: + cachemod = ", cache_modifier='.cg'" + + append_broadcast = None + dtype = V.graph.get_dtype(name) + + if should_unwrap_unspec_arg(name): + line = var + # unwrapped bf16/fp16 0d tensors are passed in as float32 scalars + # see triton_utils.py:signature_of + if dtype in (torch.float16, torch.bfloat16): + dtype = torch.float32 + + else: + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing, other) + line = f"tl.load({block_ptr}{other}{ep}{cachemod})" + line = indexing.codegen_broadcast_and_reshape( + line, indexing.block_shape, indexing.final_shape, True + ) + elif isinstance(original_index, sympy.Integer): + line = f"tl.load({var} + ({original_index}))" + append_broadcast = indexing.expand_str + else: + line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other}{cachemod})" + + if ( + dtype in (torch.float16, torch.bfloat16) + and config.triton.codegen_upcast_to_fp32 + ): + line += ".to(tl.float32)" + dtype = torch.float32 + if dtype == torch.bool and torch.version.hip is None: + # Workaround for https://github.com/triton-lang/triton/issues/2151 + # tl.load returns int8 when loading from pointer to int1 + # NOTE: Currently causes hangs on bool UTs for ROCm + line += ".to(tl.int1)" + dtype = torch.bool + + load_buffer = self.get_load_buffer(indexing) + result_var = self.cse.generate(load_buffer, make_line(line), dtype=dtype) + if result_var.use_count > 1: + load_counts[name] -= 1 # don't double count cache hit + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] + + if append_broadcast: + line = f"tl.broadcast_to({result_var}, {append_broadcast})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + if indexing.mask_vars: + if dtype.is_floating_point: + zero = "0.0" + elif dtype == torch.bool: + zero = "True" + else: + zero = "0" + other_val = ( + constant_repr(self._load_other) if self._load_other else 
zero + ) + line = f"tl.where({indexing.mask_str}, {result_var}, {other_val})" + result_var = self.cse.generate(load_buffer, line, dtype=dtype) + + if not self.inside_reduction or (not indexing.has_rmask() and not has_rindex): + self.outside_loop_vars.add(result_var) + + return result_var + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None + ) -> None: + var = self.args.output(name) + original_index = index + indexing = self.indexing(index, dense_indexing=True, block_ptr=mode is None) + + # Guard against write-after-read corruption in triton. + # See # https://github.com/triton-lang/triton/issues/1615 + # This triton bug means that a load which is broadcasted over multiple + # warps may see the result of a store that happens later in the triton + # program. The workaround is to add a barrier before storing, which + # enforces that all warps have already read the data. + is_inplace = name in self.args.inplace_buffers + is_broadcasted = self.is_broadcasted(original_index) + if is_inplace and is_broadcasted: + self.stores.writeline(DeferredLine(name, "tl.debug_barrier()")) + + if isinstance(indexing, BlockPtrOptions): + block_ptr, other = self.codegen_block_ptr(name, var, indexing) + # block_ptr stores don't do implicit casting + line = self.codegen_block_ptr_store_line( + name, indexing, block_ptr, value, other + ) + elif mode is None: + line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})" + elif mode == "atomic_add": + line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')" + else: + raise NotImplementedError(f"store mode={mode}") + + exit_stack = contextlib.ExitStack() + if not self.inside_reduction and self.cooperative_reduction: + exit_stack.enter_context(self.guard_cooperative_store(name, self.stores)) + + self.stores.writeline(DeferredLine(name, line)) + + if not self.inside_reduction: + self.outside_loop_vars.add(value) + + exit_stack.close() + + def guard_cooperative_store(self, name, buffer): + """ + For cooperative reductions only one thread block should write out the result. + We rotate which thread block does each write for better parallelism + """ + idx = self.cooperative_reduction_workspace_cache.increment_store_count() + buffer.writeline(DeferredLine(name, f"if rsplit_id == ({idx} % RSPLIT):")) + return buffer.indent() + + def bucketize( + self, + values: CSEVariable, + boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], + boundary_indices: CSEVariable, + indexing_dtype: torch.dtype, + right: bool, + sorter: Optional[tuple[str, sympy.Expr]] = None, + sorter_indices: Optional[CSEVariable] = None, + ) -> CSEVariable: + """ + See [Note: Inductor bucketize op] + """ + + # Triton performance for bucketize_binary_search is much better when the number + # of threads equals the number of elements. + # If we're trying to use a bucketize kernel, we should make sure that an + # autotuning config with num_elements_per_warp=(warp_size) exists. 
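# Editor's note: a minimal scalar reference (not part of the kernel) for the semantics
# being vectorized here; torch.bucketize's right=False/True corresponds to Python's
# bisect_left/bisect_right over a sorted boundaries list (sorter and striding aside):
#   import bisect
#   def bucketize_reference(value, boundaries, right):
#       fn = bisect.bisect_right if right else bisect.bisect_left
#       return fn(boundaries, value)
# e.g. bucketize_reference(3, [1, 3, 5], right=False) == 1, while right=True gives 2.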
+ self.autotune_hints.add(AutotuneHint.ONE_ELEMENT_PER_THREAD) + + boundaries_ptr = self.args.input(boundaries[0]) + boundary_size = self.index_to_str(boundaries[1]) + boundaries_underlying_numel = self.index_to_str(boundaries[2]) + boundary_stride = self.index_to_str(boundaries[3]) + sorter_ptr = self.args.input(sorter[0]) if sorter else "None" + sorter_stride = self.index_to_str(sorter[1]) if sorter else "None" + + if indexing_dtype == torch.int32: + triton_dtype = "tl.int32" + elif indexing_dtype == torch.int64: + triton_dtype = "tl.int64" + else: + raise NotImplementedError( + "Bucketize only supports indexing with int32 and int64" + ) + + result = self.cse.generate( + self.compute, + f"triton_helpers.bucketize_binary_search({values}, " + f"{boundaries_ptr}, {boundary_size}, {boundaries_underlying_numel}, {boundary_stride}, " + f"{boundary_indices}, " + f"{triton_dtype}, " + f"{right}, " + f"{sorter_ptr}, {sorter_stride}, " + f"{sorter_indices}, " + ")", + dtype=indexing_dtype, # type: ignore[attr-defined] + ) + + return result + + def reduction_resize(self, value) -> str: + ndims = self.triton_tensor_ndim() + if ndims == 1: + return f"triton_helpers.promote_to_tensor({value})" + + nreduce = self.num_reduction_dims + sizes = [":"] * (ndims - nreduce) + ["None"] * nreduce + return f"{value}[{', '.join(sizes)}]" + + def reduction_collapse_dims(self, buffer, value: str, dtype: torch.dtype) -> str: + """ + Reshape to RBLOCK, collapsing all reduction dims. + """ + # This is not needed for 1D reductions. + if self.num_reduction_dims == 1: + return value + + target_ndim = self.triton_tensor_ndim() - self.num_reduction_dims + initial_shape = self.dense_size_list() + target_shape = initial_shape[:target_ndim] + ["RBLOCK"] + return str( + self.cse.generate( + buffer, triton_reshape(value, initial_shape, target_shape), dtype=dtype + ) + ) + + def reduction( + self, + dtype: torch.dtype, + src_dtype: torch.dtype, + reduction_type: ReductionType, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: + def maybe_upcast(value: CSEVariable) -> CSEVariable: + # Math reductions in FP16/BF16 are less accurate because the Triton compiler does not + # automatically promote to FP32 for accumulation. Additionally, max/min reductions + # do not support FP16/BF16. We manually promote to FP32 here. + return ( + ops.to_dtype(value, torch.float32) + if value.dtype + in [ + torch.float16, + torch.bfloat16, + ] + else value + ) + + original_dtypes = [val.dtype for val in pytree.tree_leaves(value)] + value = pytree.tree_map(maybe_upcast, value) + if any(x in [torch.float16, torch.bfloat16] for x in original_dtypes): + # Only promote FB16/BF16; do not promote other integer/boolean dtypes + src_dtype = torch.promote_types(src_dtype, torch.float32) + dtype = torch.promote_types(dtype, torch.float32) + + assert self.inside_reduction + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + if self._load_mask: + masks.append(self._load_mask) + reduction_range_prefix = self.range_trees[-1].prefix[0] + + # Say we have + # tmp0 = ops.constant(1, torch.int64) + # tmp1 = ops.reduction(torch.int64, torch.int64, "sum", tmp0) + # tmp0 in the triton code is either a scalar, or single-element tensor + # so if we emit tl.sum directly, it will only give 1 instead of RBLOCK * 1 + # To avoid this, we broadcast to the expected shape first. 
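# Editor's note: illustrative generated-code sketch (variable names hypothetical) of why
# the broadcast matters; without it, reducing a constant would collapse to a single value:
#   tmp0 = tl.full([1, 1], 1, tl.int64)               # scalar-like constant
#   tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])  # expand to the dense kernel shape
#   tmp2 = tl.sum(tmp1, 1)[:, None]                   # sums R0_BLOCK ones per row, not 1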
+ dense_size_str = self.dense_size_str() + value = self._map_tuple_or_scalar( + lambda v: self.cse.generate( + self.compute, + f"tl.broadcast_to({v}, {dense_size_str})", + dtype=v.dtype, + ), + value, + ) + + dim = self.triton_tensor_ndim() - self.num_reduction_dims + root_op: str + + def final_reduction( + buffer, + value: str, + result_type: Optional[str], + ) -> str: + """ + Helper to generate a reduction call, e.g. tl.sum. + """ + use_helper = reduction_type in ("any", "max", "min", "prod") + module = "triton_helpers" if use_helper else "tl" + + value = self.reduction_collapse_dims(buffer, value, dtype) + if reduction_type in ("max", "min"): + value = self.reduction_resize( + f"{module}.{reduction_type}2({value}, {dim})" + ) + else: + value = self.reduction_resize( + f"{module}.{reduction_type}({value}, {dim})" + ) + + if result_type is not None: + value = f"{value}.to({result_type})" + + return value + + def final_reduction_define( + buffer, + result_var: str, + value: str, + result_type: Optional[str], + ) -> None: + """ + Generate a reduction and assign it to an existing variable. + """ + value = final_reduction(buffer, value, result_type) + buffer.splice(f"{result_var} = {value}") + + def final_argreduce(buffer, result_var, value, index): + value = self.reduction_collapse_dims(buffer, value, dtype) + index = self.reduction_collapse_dims(buffer, index, dtype) + buffer.splice( + f"""\ + {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) + {result_var} = {self.reduction_resize(f"{result_var}_idx")} + """ + ) + + cache_key = (src_dtype, reduction_type, value) + if cache_key in self.cse.reduction_cache: + return self.cse.reduction_cache[cache_key] + + acc_type = triton_acc_type(src_dtype) + torch_acc_type = upcast_acc_dtype(src_dtype) + result_var: Any = self.cse.newvar(dtype=torch_acc_type) + result_var.mask_vars = OrderedSet( + var for var in masks if not prefix_is_reduction(var[0]) + ) + cond = " & ".join(masks) + + def where_cond(tval, fval): + if not cond: + return tval + return TritonKernelOverrides.where(cond, tval, fval) + + if self.persistent_reduction: + default = ir.Reduction.default_value(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + + def _mask_value(value, default) -> CSEVariable: + return self.cse.generate( + self.compute, where_cond(value, default), dtype=value.dtype + ) + + masked_value: Union[CSEVariable, Sequence[CSEVariable]] + if reduction_type == "online_softmax_reduce": + # Don't generate mask value for online_softmax since we + # will fallback below + pass + elif isinstance(value, tuple): + masked_value = [_mask_value(v, d) for v, d in zip(value, default)] + else: + masked_value = _mask_value(value, default) + + if reduction_type in ("argmax", "argmin"): + accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() + accumulator_index = str( + self.cse.generate( + self.compute, + f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + dtype=accumulator_dtype, + ) + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + final_argreduce( + self.compute, result_var, masked_value, accumulator_index + ) + result_var.dtype = accumulator_dtype + elif reduction_type == "welford_reduce": + if self.cooperative_reduction: + # cooperative reductions require full welford for correctness + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + else: + # For persistent reductions, don't bother with + # 
welford's algorithm since it uses more registers, and + # taking two reductions doesn't increase memory usage. + result_var = self.welford_reduce_fallback(dtype, value) + elif reduction_type == "welford_combine": + assert isinstance(masked_value, Sequence) + (mean, m2, weight) = masked_value + result_var = tuple( + self.cse.generate(self.compute, value, dtype=dtype) + for value in self._welford( + self.compute, mean, m2, weight, dim, dtype + ) + ) + elif reduction_type == "online_softmax_reduce": + # All data is loaded to register anyway, no need to do + # online softmax + result_var = self.prepare_softmax_twopass_fallback(dtype, value) + else: + assert isinstance(masked_value, CSEVariable) + result_var = self.cse.generate( + self.compute, + final_reduction(self.compute, str(masked_value), None), + dtype=masked_value.dtype, + ) + else: + accumulator = self.cse.namedvar(f"_{result_var}", dtype=torch_acc_type) + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + default = self._map_tuple_or_scalar(constant_repr, default) + if not isinstance(default, tuple): + self.body.writeline( + f"{accumulator} = tl.full({self.dense_size_str()}, {default}, {acc_type})" + ) + + if reduction_type in ("argmax", "argmin"): + accumulator_index = f"_{result_var}_index" + index_dtype = self.features.select_index_dtype() + self.body.writeline( + f"{accumulator_index} = tl.full({self.dense_size_str()}, " + f"{torch.iinfo(index_dtype).max}, {self.dtype_to_str(index_dtype)})" + ) + root_op = {"argmax": "max", "argmin": "min"}[reduction_type] + + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( + {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index + ) + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)} + """ + ) + final_argreduce( + self.post_loop_combine, result_var, accumulator, accumulator_index + ) + elif is_welford_reduction(reduction_type): + result_var = self.welford_reduce( + result_var, reduction_type, value, where_cond, acc_type, dtype + ) + elif reduction_type == "online_softmax_reduce": + accumulator_max = f"_{result_var}_max" + accumulator_sum = f"_{result_var}_sum" + + # setup accumulator + self.body.writeline( + f"{accumulator_max} = tl.full({self.dense_size_str()}, float('-inf'), {acc_type})" + ) + self.body.writeline( + f"{accumulator_sum} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + + # combine + # Note, we pass config.use_fast_math to the JITFunction + # since a triton kernel can not access a config. + self.compute.splice( + f""" + {accumulator_max}_next, {accumulator_sum}_next = triton_helpers.online_softmax_combine( + {accumulator_max}, {accumulator_sum}, {value}, {config.use_fast_math} + ) + """ + ) + + # mask + self.compute.splice( + f""" + {accumulator_max} = {where_cond(f"{accumulator_max}_next", accumulator_max)} + {accumulator_sum} = {where_cond(f"{accumulator_sum}_next", accumulator_sum)} + """ + ) + + # reduce. 
Similar to the final reduction for coopereative + # reduction + result_max = result_var + result_sum = self.cse.newvar(dtype=dtype) + + result_var = self.online_softmax_reduce_final_reduction( + self.post_loop_combine, + result_max, + result_sum, + accumulator_max, + accumulator_sum, + dim, + dtype, + ) + else: + combine_fn = ir.get_reduction_combine_fn(reduction_type, src_dtype) + updated = combine_fn(accumulator, value) + self.compute.writeline( + f"{accumulator} = {where_cond(updated, accumulator)}" + ) + + if src_dtype == torch.bool: + # This is only really used for aten.any. It changes the + # final reduction of a non-persistent reduction from + # tmp5 = triton_helpers.max(_tmp5, 1)[:, None] + # to + # tmp5 = triton_helpers.max(_tmp5.to(tl.int8), 1)[:, None].to(tl.int1) + # which is needed because tl.reduce doesn't support tl.int1 + accumulator_casted_str = f"{accumulator}.to(tl.int8)" + result_type = triton_compute_type(dtype) + final_reduction_define( + self.post_loop_combine, + str(result_var), + accumulator_casted_str, + result_type, + ) + else: + final_reduction_define( + self.post_loop_combine, str(result_var), str(accumulator), None + ) + + if self.cooperative_reduction: + default = ir.Reduction.default_accumulator(reduction_type, src_dtype) + exit_stack = contextlib.ExitStack() + for buf in (self.post_loop_combine, self.post_loop_store): + # only do cooperative reduction combines if we have more than one thread block + buf.writeline("if HAS_RSPLIT:") + exit_stack.enter_context(buf.indent()) + + if reduction_type in ("argmax", "argmin"): + self.post_loop_combine.writeline( + f"{result_var}_bval = {self.reduction_resize(f'{result_var}_val')}" + ) + peer_val = self.codegen_cooperative_reduction_peer_combine( + f"{result_var}_bval", src_dtype, default + ) + index_dtype = self.features.select_index_dtype() + peer_idx = self.codegen_cooperative_reduction_peer_combine( + result_var, index_dtype, torch.iinfo(index_dtype).max + ) + final_argreduce(self.post_loop_store, result_var, peer_val, peer_idx) + elif is_welford_reduction(reduction_type): + assert reduction_type == "welford_reduce" + result_mean, result_m2, result_weight = result_var + peer_mean = self.codegen_cooperative_reduction_peer_combine( + result_mean, + upcast_acc_dtype(src_dtype), + default[0], # type: ignore[index] + ) + peer_m2 = self.codegen_cooperative_reduction_peer_combine( + result_m2, + upcast_acc_dtype(src_dtype), + default[1], # type: ignore[index] + ) + peer_weight = self.codegen_cooperative_reduction_peer_combine( + result_weight, + upcast_acc_dtype(src_dtype), + default[2], # type: ignore[index] + ) + self.welford_reduce_final_reduction( + self.post_loop_store, + result_mean, + result_m2, + result_weight, + peer_mean, + peer_m2, + peer_weight, + dim, + dtype, + ) + elif reduction_type == "online_softmax_reduce": + result_max, result_sum = result_var + peer_max = self.codegen_cooperative_reduction_peer_combine( + result_max, upcast_acc_dtype(src_dtype), default[0] + ) + peer_sum = self.codegen_cooperative_reduction_peer_combine( + result_sum, upcast_acc_dtype(src_dtype), default[1] + ) + self.online_softmax_reduce_final_reduction( + self.post_loop_store, + result_max, + result_sum, + peer_max, + peer_sum, + dim, + dtype, + ) + else: + peers = self.codegen_cooperative_reduction_peer_combine( + result_var, upcast_acc_dtype(src_dtype), default + ) + final_reduction_define( + self.post_loop_store, str(result_var), peers, None + ) + exit_stack.close() + + self.cse.reduction_cache[cache_key] = result_var + + if 
isinstance(result_var, tuple): + assert all(isinstance(x, TritonCSEVariable) for x in result_var) + self.outside_loop_vars.update(result_var) + + # Match output dtype with input dtype + if reduction_type in ("welford_reduce", "online_softmax_reduce"): + assert len(original_dtypes) == 1 + original_dtypes = len(result_var) * original_dtypes + + assert len(result_var) == len(original_dtypes) + for var, orig_dtype in zip(result_var, original_dtypes): + assert orig_dtype is not None + if var.dtype != orig_dtype: + self.post_loop_combine.writeline( + f"{var} = {var}.to({triton_compute_type(orig_dtype)})" + ) + else: + assert isinstance(result_var, TritonCSEVariable) + self.outside_loop_vars.add(result_var) + + # Match output dtype with input dtype + if result_var.dtype != original_dtypes[0]: + assert original_dtypes[0] is not None + self.post_loop_combine.writeline( + f"{result_var} = {result_var}.to({triton_compute_type(original_dtypes[0])})" + ) + + return result_var + + def _online_softmax_reduce( + self, buffer, accumulator_max, accumulator_sum, dim, dtype: torch.dtype + ): + accumulator_max = self.reduction_collapse_dims(buffer, accumulator_max, dtype) + accumulator_sum = self.reduction_collapse_dims(buffer, accumulator_sum, dtype) + result_max, result_sum = [str(self.cse.newvar(dtype=dtype)) for _ in range(2)] + buffer.splice( + f""" + {result_max}, {result_sum} = triton_helpers.online_softmax_reduce( + {accumulator_max}, {accumulator_sum}, {dim}, {config.use_fast_math}) + {result_max} = {self.reduction_resize(f"{result_max}")} + {result_sum} = {self.reduction_resize(f"{result_sum}")} + """ + ) + + return result_max, result_sum + + def _welford(self, buffer, mean, m2, weight, dim, dtype: torch.dtype): + """ + Helper to codegen triton_helpers.welford. 
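        Editor's note, for reference: the standard pairwise Welford combination of two
        partial aggregates (mean_a, m2_a, w_a) and (mean_b, m2_b, w_b) is

            w     = w_a + w_b
            delta = mean_b - mean_a
            mean  = mean_a + delta * w_b / w
            m2    = m2_a + m2_b + delta * delta * w_a * w_b / w

        (Chan et al.); the helper reduces the collapsed RBLOCK lanes along `dim` with a
        combine step of this form.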
+ """ + mean, m2, weight = ( + self.reduction_collapse_dims(buffer, value, dtype) + for value in (mean, m2, weight) + ) + welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" + welford_results = [str(self.cse.newvar(dtype=dtype)) for _ in range(3)] + buffer.writeline(f"{', '.join(welford_results)} = {welford}") + + result_values = tuple(self.reduction_resize(value) for value in welford_results) + return result_values + + def welford_reduce( + self, result_var, reduction_type, value, where_cond, acc_type, dtype + ): + """Helper to codegen a welford reduction""" + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + accumulator = f"{result_var}_mean" + accumulator_m2 = f"{result_var}_m2" + accumulator_weight = f"{result_var}_weight" + self.body.writeline( + f"{accumulator} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_m2} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + self.body.writeline( + f"{accumulator_weight} = tl.zeros({self.dense_size_str()}, {acc_type})" + ) + if reduction_type == "welford_combine": + mean, m2, weight = value + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_combine( + {accumulator}, {accumulator_m2}, {accumulator_weight}, + {mean}, {m2}, {weight} + ) + """ + ) + else: + assert reduction_type == "welford_reduce" + self.compute.splice( + f"""\ + {accumulator}_next, {accumulator_m2}_next, {accumulator_weight}_next = triton_helpers.welford_reduce( + {value}, {accumulator}, {accumulator_m2}, {accumulator_weight}, roffset == 0 + ) + """ + ) + self.compute.splice( + f"""\ + {accumulator} = {where_cond(f"{accumulator}_next", accumulator)} + {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)} + {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)} + """ + ) + result_mean = result_var + result_m2 = self.cse.newvar(dtype=dtype) + result_weight = self.cse.newvar(dtype=dtype) + return self.welford_reduce_final_reduction( + self.post_loop_combine, + result_mean, + result_m2, + result_weight, + accumulator, + accumulator_m2, + accumulator_weight, + dim, + dtype, + ) + + def welford_reduce_final_reduction( + self, + buffer, + result_mean, + result_m2, + result_weight, + mean, + m2, + weight, + dim, + dtype, + ): + """Helper to codegen call to triton_helpers.welford""" + values = self._welford(buffer, mean, m2, weight, dim, dtype) + result_exprs = [result_mean, result_m2, result_weight] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + + return result_mean, result_m2, result_weight + + def online_softmax_reduce_final_reduction( + self, buffer, result_max, result_sum, peer_max, peer_sum, dim, dtype + ): + values = self._online_softmax_reduce(buffer, peer_max, peer_sum, dim, dtype) + result_exprs = [result_max, result_sum] + for result_expr, value in zip(result_exprs, values): + buffer.splice(f"{result_expr} = {value}") + + return result_max, result_sum + + def max_rsplit(self): + if self.fixed_config: + return self.fixed_config["RSPLIT"] + return TRITON_MAX_RSPLIT + + def codegen_cooperative_reduction_peer_combine( + self, result_var, dtype, default_val + ): + """ + Generate code to save a [XBLOCK, RSPLIT] temporary workspace, where each thread block writes a different + column. After the barrier, every thread block loads the completed value so that it can compute the final + value independently. 
+ """ + xnumel = self.numels["x"] + mask = "xindex < xnumel" if not self._has_constant_xmask() else None + + nbytes = xnumel * dtype.itemsize * self.max_rsplit() + ws_name, ws_offset = self.cooperative_reduction_workspace_cache.allocate(nbytes) + + self.post_loop_combine.splice( + f""" + {result_var}_ws = ({ws_name} + {self.index_to_str(ws_offset)}).to(tl.pointer_type({triton_type(dtype)})) + tl.store({result_var}_ws + (xindex * RSPLIT + rsplit_id), {result_var}, {mask}) + """, + strip=True, + ) + self.post_loop_store.writeline( + f"{result_var}_peers = tl.load({result_var}_ws + (xindex * RSPLIT + rsplit_arange), " + f"rsplit_mask, eviction_policy='evict_first', other=triton_helpers.if_mask(rsplit_mask, {constant_repr(default_val)}))" + ) + return f"{result_var}_peers" + + def store_reduction( + self, + name: str, + index: sympy.Expr, + value: Union[CSEVariable, tuple[CSEVariable, ...]], + ): + assert self.inside_reduction + self.inside_reduction = False + indexing = self.indexing(index, block_ptr=True) + self.inside_reduction = True + var = self.args.output(name) + + exit_stack = contextlib.ExitStack() + if self.cooperative_reduction: + exit_stack.enter_context( + self.guard_cooperative_store(name, self.post_loop_store) + ) + + if isinstance(indexing, BlockPtrOptions): + self.post_loop_store.writeline( + DeferredLine( + name, + self.codegen_block_ptr_store_line( + name, + indexing, + indexing.format(var), + value, + f", boundary_check={indexing.boundary_check()!r}", + ), + ) + ) + else: + assert isinstance(indexing, IndexingOptions) + self.post_loop_store.writeline( + DeferredLine( + name, + f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})", + ) + ) + + exit_stack.close() + + def _lift_helper(self, fn, num_args, dtypes: tuple[torch.dtype, ...]) -> str: + # Lift IR function for scan operations into a triton function + # in the global namespace + helper = IndentedBuffer() + helper.writeline("@triton.jit") + cse = CSE() + + args = [ + tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args)) + for i in range(2) + ] + signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) + helper.writeline(f"def {{name}}({signature}):") + + overrides = TritonOverrides() + + # Build a name that changes depending on fn to workaround a triton bug + # where the combine_fn to reduce and scan is not hashed, and so different + # scan ops may collide in the triton cache. + # This is fixed with the latest triton pin, but not the triton-rocm pin. + helper_name = "_triton_helper_fn" + + from torch._inductor.dtype_propagation import DtypePropagationOpsHandler + + dtype_handler = DtypePropagationOpsHandler() + + class CSEProxy(DefaultHandler): + def _default( + self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + nonlocal helper_name + helper_name += f"_{name}" + + output_dtype = getattr( + dtype_handler, + name, + )(*args, **kwargs) + + return cse.generate( + helper, + getattr(overrides, name)(*args, **kwargs), + dtype=output_dtype, + ) + + with helper.indent(), V.set_ops_handler(CSEProxy()): + outputs = fn(*args) + outputs = ", ".join(str(output) for output in outputs) + helper.writeline(f"return {outputs}") + + return self.helper_functions.add(helper.getvalue(), base_name=helper_name) + + def scan( + self, + dtypes: tuple[torch.dtype, ...], + combine_fn: Callable[ + [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...] 
+ ], + values: tuple[CSEVariable, ...], + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + broadcasted_values = [] + accumulators = [] + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + cse_compute = functools.partial(self.cse.generate, self.compute) + combine_helper_fn = self._lift_helper(combine_fn, len(values), dtypes) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + for value, dtype in zip(values, dtypes): + value_dtype = self.cse.generate( + self.compute, + f"{value}.to({triton_compute_type(dtype)})", + dtype=dtype, + ) + value = self.cse.generate( + self.compute, + f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", + dtype=dtype, + ) + broadcasted_values.append(value) + + acc_type = triton_acc_type(dtype) + + if not self.persistent_reduction: + accumulator = self.cse.newvar(dtype=dtype) + reduced_size = self.dense_size_list() + reduced_size[-1] = "1" + reduced_size = f"[{', '.join(reduced_size)}]" + + default = "float('nan')" if dtype.is_floating_point else "-1" + self.body.writeline( + f"{accumulator} = tl.full({reduced_size}, {default}, {acc_type})" + ) + + accumulators.append(accumulator) + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, values, masks, dtypes): + n = len(values) + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in dtypes] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + partial_scan_vars = cse_multiple( + f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", + values, + masks, + dtypes, + ) + + if not self.persistent_reduction: + # tl.reduce doesn't work for non-commutative operators, so instead + # of repeating the scan op as a reduction, we use sum to select the + # last scan value + partial_reduce_vars = [ + cse_compute( + f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", + dtype=upcast_compute_type(partial_scan_var.dtype), + ) + for partial_scan_var in partial_scan_vars + ] + accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars)) + full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) + result_vars = [ + cse_compute( + f"tl.where(roffset > 0, {full_scan}, {partial_scan})", + dtype=partial_scan.dtype, + ) + for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) + ] + for acc_next, accumulator, partial_reduce in zip( + accs_next, accumulators, partial_reduce_vars + ): + self.compute.writeline( + f"{accumulator} = tl.where(roffset > 0, {acc_next}, {partial_reduce})" + ) + else: + result_vars = partial_scan_vars + + for result_var in result_vars: + assert isinstance(result_var, TritonCSEVariable) + result_var.mask_vars = OrderedSet(masks) + + return tuple(result_vars) + + def sort( + self, + dtypes: tuple[torch.dtype, ...], + values: tuple[CSEVariable, ...], + stable: bool, + descending: 
bool, + ) -> tuple[CSEVariable, ...]: + assert self.inside_reduction + assert not self.cooperative_reduction, "TODO" + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + masks = sorted(masks) + assert not self._load_mask, "ops.sort not supported inside ops.masked" + assert self.persistent_reduction, ( + "ops.sort is only supported in persistent reductions" + ) + + cse_compute = functools.partial(self.cse.generate, self.compute) + dim = self.triton_tensor_ndim() - self.num_reduction_dims + + dtypes = tuple(upcast_compute_type(dtype) for dtype in dtypes) + assert len(dtypes) == len(values) + broadcasted_values = [ + cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i] + ) + for i, value in enumerate(values) + ] + + def csv(values): + return " ".join(f"{value}," for value in values) + + def cse_multiple(line, n, masks, dtypes): + cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] + if all(self.cse.contains(cache_key) for cache_key in cache_keys): + return [self.cse.get(cache_key) for cache_key in cache_keys] + result_vars = [self.cse.newvar(dtype=dtypes[i]) for i in range(n)] # type: ignore[attr-defined] + self.compute.writeline( + f"{csv(result_vars)} = {line}", + ) + for result_var, cache_key in zip(result_vars, cache_keys): + if masks: + result_var.mask_vars = masks # type: ignore[attr-defined] + self.cse.put(cache_key, result_var) + return tuple(result_vars) + + assert self.range_trees[-1].is_reduction + rnumel = "None" if self._has_constant_mask(self.range_trees[-1]) else "rnumel" + + if len(values) == 2: + line = ( + f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," + f" {rnumel}, {dim}, stable={stable}, descending={descending})" + ) + result_vars = cse_multiple(line, len(values), masks, dtypes) + else: + raise AssertionError("Unhandled sort") + + for result_var, input_var in zip(result_vars, values): + result_var.mask_vars = masks # type: ignore[attr-defined] + result_var.bounds = input_var.bounds + + return tuple(result_vars) + + def codegen_body(self): + """ + Concat output code from index_code, loads, compute, stores, + suffix into self.body. + + For pointwise kernels, this is called just once at the end. + + For reduction kernels, this generates a loop over the reduction + axis. + """ + if not ( + self.indexing_code + or self.loads + or self.stores + or self.compute + or self.post_loop_combine + or self.post_loop_store + ): + return + + loop_trees = [tree for tree in self.range_trees if tree.is_loop] + if self.inside_reduction and len(loop_trees) > 0: + # Write the loop headers. + for level, tree in enumerate(loop_trees): + with self.body.indent(offset=level): + prefix = tree.prefix + loop_start = "rsplit_start" if self.cooperative_reduction else "0" + loop_end = ( + "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" + ) + self.body.writeline( + f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + ) + with self.body.indent(offset=level + 1): + self.iteration_ranges_codegen_header(tree, self.body) + + # The innermost loop performs the reduction. + with self.body.indent(offset=len(loop_trees)): + self.codegen_reduction_indices(self.body) + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + + # Write loop suffixes. 
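# Editor's note on the subtraction below: by the time an outer loop iteration ends, the
# inner loop has already advanced each block pointer by its own per-iteration step times
# ceil(inner_numel / inner_BLOCK) iterations, so the outer advancement subtracts that
# accumulated amount and the net movement per outer iteration is only the outer step
# (see `cur - prev * prev_num_iter` below).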
+ for level, tree in reversed([*enumerate(loop_trees)]): + with self.body.indent(offset=level + 1): + # Advance pointers at the end of each loop. + for block_ptr, advancement in self.pointer_advancements[ + tree.symt + ].items(): + # Subtract any advancements made in the previous loop level. + if level < len(loop_trees) - 1: + prev_tree = loop_trees[level + 1] + prev_advancement = self.pointer_advancements[ + prev_tree.symt + ][block_ptr] + prev_block = TritonSymbols.get_block_size(prev_tree) + prev_num_iter = CeilDiv(prev_tree.numel, prev_block) + advancement = [ + cur - prev * prev_num_iter + for cur, prev in zip(advancement, prev_advancement) + ] + + self.body.writeline( + DeferredLine( + self.block_ptr_to_buffer[block_ptr], + f"{block_ptr} = tl.advance({block_ptr}, {V.kernel.index_to_str(advancement)})", + ) + ) + + # Invalidate any cache entries that came from inside the loop. + self.cse.invalidate(self.outside_loop_vars) + tree.cache_clear() + else: + self.body.splice(self.indexing_code) + self.body.splice(self.loads) + self.body.splice(self.compute) + self.body.splice(self.stores) + self.body.splice(self.post_loop_combine) + if self.cooperative_reduction and ( + self.post_loop_combine or self.post_loop_store + ): + sem_ptr = f"{self.semaphores_name} + tl.program_id(1)" + self.body.splice( + f""" + if HAS_RSPLIT: + triton_helpers.x_grid_barrier({sem_ptr}) + """, + strip=True, + ) + self.cooperative_reduction_workspace_cache.on_loop_end() + self.body.splice(self.post_loop_store) + self.indexing_code.clear() + self.loads.clear() + self.compute.clear() + self.stores.clear() + self.post_loop_combine.clear() + self.post_loop_store.clear() + + def kernel_benchmark_extra_args(self) -> list[str]: + args = [] + if self.need_numel_args(): + numel_args: list[sympy.Expr] = [] + self.add_numel_to_call_args("", numel_args, []) + for arg in numel_args: + if isinstance(arg, int): + args.append(str(arg)) + elif isinstance(arg, SymbolicCallArg): + args.append(str(V.graph.sizevars.size_hint(arg.inner_expr))) + elif isinstance(arg, sympy.Expr): + args.append(str(V.graph.sizevars.size_hint(arg))) + else: + raise ValueError(f"Unsupported numel argument type: {type(arg)}") + return args + + def codegen_kernel_benchmark(self, num_gb): + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
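# Editor's note: a hypothetical sketch (argument names and shapes invented) of the
# standalone harness this method emits:
#   def get_args():
#       arg_0 = rand_strided((64, 128), (128, 1), device='cuda:0', dtype=torch.float32)
#       arg_1 = 8192          # size hint for a SizeArg (seed offsets forced to 0 below)
#       return arg_0, arg_1,
#   def call(args):
#       with <device guard for device 0>:
#           stream0 = get_raw_stream(0)
#           KERNEL_NAME.run(*args, stream=stream0)
# which can then be driven by benchmark_all_configs or the __main__ block added below.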
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + current_device = V.graph.get_current_device_or_throw() + index = current_device.index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self): + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def _get_heuristic(self): + if self.fixed_config: + return "fixed_config" + elif self.cooperative_reduction: + return "cooperative_reduction" + elif self.persistent_reduction: + assert self.inside_reduction + return "persistent_reduction" + elif self.inside_reduction: + return "reduction" + return "pointwise" + + @staticmethod + def inductor_meta_common(): + inductor_meta = { + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(), + "assert_indirect_indexing": config.assert_indirect_indexing, + "autotune_local_cache": config.autotune_local_cache, + "autotune_pointwise": config.triton.autotune_pointwise, + "autotune_remote_cache": config.autotune_remote_cache, + "force_disable_caches": config.force_disable_caches, + "dynamic_scale_rblock": config.dynamic_scale_rblock, + "max_autotune": config.max_autotune, + "max_autotune_pointwise": config.max_autotune_pointwise, + "min_split_scan_rblock": config.triton.min_split_scan_rblock, + "spill_threshold": config.triton.spill_threshold, + "store_cubin": config.triton.store_cubin, + } + if torch.version.hip is not None: + inductor_meta["is_hip"] = True + if config.is_fbcode(): + inductor_meta["is_fbcode"] = True + if 
config.profile_bandwidth: + inductor_meta["profile_bandwidth"] = config.profile_bandwidth + inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex + inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output + inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = ( + config.profile_bandwidth_with_do_bench_using_profiling + ) + if config.coordinate_descent_tuning: + inductor_meta["coordinate_descent_tuning"] = ( + config.coordinate_descent_tuning + ) + inductor_meta["coordinate_descent_search_radius"] = ( + config.coordinate_descent_search_radius + ) + inductor_meta["coordinate_descent_check_all_directions"] = ( + config.coordinate_descent_check_all_directions + ) + return inductor_meta + + def codegen_kernel(self, name=None): + code = IndentedBuffer() + + size_hints = {} + for prefix, numel in self.numels.items(): + if prefix_is_reduction(prefix) and not self.inside_reduction: + continue + + numel_hint = V.graph.sizevars.symbolic_hint(numel) + if not isinstance(numel_hint, (int, sympy.Integer)): + # This default heuristic hint was picked carefully: it is + # large, to ensure that we don't shrink the block size (since + # if you don't have many elements, it'd be wasteful to pick a + # large block size). Since we don't know how many elements we + # might have, we should be OK with some inefficiency to make + # sure we handle the large case well. 8192 is the largest + # block size we support, so we pick that. + # + # If we have a better hint for unbacked SymInts (e.g., because + # a user told us, or we are tracking upper bounds) we could + # use that here. + size_hint = 8192 + else: + size_hint = next_power_of_2(int(numel_hint)) + size_hints[prefix] = size_hint + + if name is None: + code.splice(gen_common_triton_imports()) + device_type = V.graph.get_current_device_or_throw().type + if device_type == "cpu": + code.splice("triton_helpers.set_driver_to_cpu()") + else: + code.splice("triton_helpers.set_driver_to_gpu()") + + if config.benchmark_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + # maps actual expression to SizeArg if it is in sizevars replacements + for i, arg in enumerate(signature): + if isinstance(arg, SizeArg): + # mypy is unhappy about the sympy.Expr + # type for the key of the dict below + symbol = cast(sympy.Symbol, arg.expr) + if symbol in V.graph.sizevars.inv_precomputed_replacements: + signature[i] = SizeArg( + arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol] + ) + + mutated_args: OrderedSet[str] = OrderedSet() + for mutation in self.mutations: + if mutation in self.args.input_buffers: + mutated_args.add(self.args.input_buffers[mutation]) + if ( + mutation in self.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in self.removed_buffers + ): + mutated_args.add( + cast(InplacedBuffer, self.args.inplace_buffers[mutation]).inner_name + ) + if mutation in self.args.output_buffers: + mutation_arg = self.args.output_buffers[mutation] + assert not isinstance(mutation_arg, RemovedArg) + mutated_args.add(mutation_arg) + + # Note: [Workspace Mutation] + # workspace arguments are mutated, but are not marked as mutations in self.mutations + # because their buffers are added during codegen, and aren't tracked during + # lowering/scheduling. So we add them as mutated_args explicitly below. 
+ # + # In the logic below, we only mark the workspaces a mutated if they are marked with + # zero_fill: that's because, if we don't expect the buffer to be pre-filled with + # zeros, then, although we still mutate the data, we don't care about those + # mutations because we don't make any assumptions about the contents of the + # workspace buffer. Similarly, ZERO_PER_GRAPH requires the kernel to return + # the buffer back to its original state. + for argname, arg in zip(argdefs, signature): + if ( + isinstance(arg, WorkspaceArg) + and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL + ): + mutated_args.add(argname.name) + + mutated_args = sorted(mutated_args) + + for tree in self.active_range_trees(): + sizearg = SizeArg(f"{tree.prefix}numel", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(sizearg.name)) + # constexpr version causes issues, see + # https://github.com/pytorch/torchdynamo/pull/1362 + # triton_meta["constants"][len(argdefs)] = V.graph.sizevars.size_hint( + # tree.numel + # ) + # argdefs.append(f"{tree.prefix}numel: tl.constexpr") + + def add_constexpr_arg(arg_name): + # new versions (but not old versions) of Triton need constexprs included in the signature + if triton_version_uses_attrs_dict(): + signature.append(ConstexprArg(arg_name)) + argdefs.append(ArgName(arg_name, is_constexpr=True)) + + for tree in self.range_trees: + if tree.is_reduction and self.persistent_reduction: + # Rn_BLOCK for persistent_reduction is defined in codegen_static_numels + continue + if tree.tensor_dim is None: + continue + + add_constexpr_arg(f"{tree.prefix.upper()}BLOCK") + + if self.cooperative_reduction: + add_constexpr_arg("RSPLIT") + + triton_meta_signature = signature_to_meta( + signature, size_dtype=self.index_dtype, argdefs=argdefs + ) + triton_meta: dict[str, Any] = { + "signature": triton_meta_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + } + + # Skip memory optimization for forward of the training loop where we expect + # every new node will increase the peak memory and our greedy approach would + # introduce a lot of unnecessary cpu copies. + optimize_mem = V.graph.is_inference or V.graph.is_backward + + inductor_meta = { + # Triton will not accept an OrderedSet for autotune_hints + "grid_type": self._get_grid_type().__name__, + "autotune_hints": set(self.autotune_hints), # noqa: set_linter + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + "optimize_mem": optimize_mem, + "no_x_dim": self.no_x_dim, + "num_load": self.num_load, + "num_reduction": self.num_reduction, + **self.inductor_meta_common(), + } + if self.tiling_scores: + inductor_meta["tiling_scores"] = self.tiling_scores + + if self.cooperative_reduction: + inductor_meta["persistent_reduction"] = self.persistent_reduction + + num_gb = None + if config.benchmark_kernel or config.profile_bandwidth: + num_gb = self.estimate_kernel_num_bytes() / 1e9 + inductor_meta["kernel_num_gb"] = num_gb + + triton_meta["configs"] = [config_of(signature)] + + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + for arg_num in equal_1_arg_indices(signature): # type: ignore[index] + triton_meta["constants"][signature[arg_num].name] = 1 # type: ignore[index,union-attr] + + self.triton_meta = triton_meta + + self.codegen_body() + + for helper in self.helper_functions: + code.writeline("") + code.splice(helper) + + if self.fixed_config: + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + config={self.fixed_config.config!r}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + elif self.inside_reduction: + reduction_hint = self.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + if ( + len(non_constexpr_signature(signature)) == 4 + ): # input, output and 2 args + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{self._get_heuristic()}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + min_elem_per_thread={self.min_elem_per_thread} + ) + @triton.jit + """ + code.splice(heuristics_line) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + with code.indent(): + self.codegen_static_numels(code) + for old, new in self.args.aliases(): + code.writeline(f"{old} = {new}") + code.splice(self.body) + + if config.benchmark_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb)) + + return code.getvalue() + + @staticmethod + def _get_persistent_RBLOCK(rnumel): + rnumel = V.graph.sizevars.simplify(rnumel) + if isinstance(rnumel, (sympy.Integer, int)): + val = int(rnumel) + val = next_power_of_2(val) + else: + val = 128 + while not V.graph.sizevars.statically_known_leq(rnumel, val): + if val > 16 * 1024: + raise ValueError(f"Failed to find static RBLOCK for {rnumel}") + val *= 2 + return val + + @staticmethod + def has_persistent_RBLOCK(rnumel): + try: + TritonKernel._get_persistent_RBLOCK(rnumel) + return True + except ValueError: + return False + + def codegen_static_numels(self, code): + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + r0_numel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. 
+ """ + + def is_static_integer(expr: sympy.Expr) -> bool: + return isinstance(expr, (sympy.Integer, int)) + + for tree in self.range_trees: + if not tree.is_reduction or self.inside_reduction: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if is_static_integer(simplified_tree_numel): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + + if tree.is_reduction and self.persistent_reduction: + if self.cooperative_reduction: + numel = self.kexpr(self.rename_indexing(tree.numel)) + val = f"triton_helpers.constexpr_next_power_of_2(({numel} + RSPLIT - 1) // RSPLIT)" + else: + val = self._get_persistent_RBLOCK(tree.numel) + code.writeline(f"{tree.prefix.upper()}BLOCK: tl.constexpr = {val}") + + if tree.prefix == "x" and self.no_x_dim: + code.writeline("XBLOCK: tl.constexpr = 1") + + def _get_grid_type(self) -> type[triton_heuristics.GridExpr]: + n = sum([int(not tree.is_reduction) for tree in self.range_trees]) + if self.cooperative_reduction: + assert n == 1 + return triton_heuristics.CooperativeReductionGrid + elif n == 1: + return triton_heuristics.Grid1D + elif n == 2: + if any(map(self.needs_yz_grid_overflow, self.range_trees)): + return triton_heuristics.Grid2DWithYZOverflow + return triton_heuristics.Grid2D + elif n == 3: + return triton_heuristics.Grid3D + raise ValueError(f"Unsupported number of dimensions: {n}") + + def add_numel_to_call_args(self, name, call_args, arg_types): + # TODO(jansel): if there are constants, we shouldn't bother passing them as args + for tree in self.range_trees: + if isinstance(tree.numel, (sympy.Integer, sympy.Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr(name, tree) + + if not tree.is_reduction or self.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def call_kernel(self, name: str, node: Optional[IRNode] = None): + wrapper = V.graph.wrapper_code + wrapper.write_triton_header_once() + _, call_args, _, arg_types = self.args.python_argdefs() + self.add_numel_to_call_args(name, call_args, arg_types) + + for ws in self.args.workspace_args: + wrapper.generate_workspace_allocation(ws) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + arg_types=arg_types, + triton_meta=self.triton_meta, + ) + + for ws in reversed(self.args.workspace_args): + wrapper.generate_workspace_deallocation(ws) + + def codegen_nan_check(self) -> None: + wrapper = V.graph.wrapper_code + _, call_args, arg_signatures, _ = self.args.python_argdefs() + for arg, arg_signature in zip(call_args, arg_signatures): + if isinstance(arg_signature, TensorArg): + if V.graph.cpp_wrapper: + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) + else: + line = f"assert not {arg}.isnan().any().item()" + wrapper.writeline(line) + line = f"assert not {arg}.isinf().any().item()" + wrapper.writeline(line) + + def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable: + return TritonCSEVariable(*args, **kwargs) + + def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): + line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" + if entry.root.is_loop: + self.indexing_code.writeline(line) + else: + # lift non-reduction stores outside loop + self.body.writeline(line) + + def iteration_ranges_ranges_code(self, entry: IterationRangesRoot) -> str: + assert entry.tensor_dim is not None + size = self.indexing_size_str(entry.tensor_dim) + index_dtype = self.index_dtype + suffix = f".to({index_dtype})" if 
index_dtype != "tl.int32" else "" + if ( + self.cooperative_reduction + and self.persistent_reduction + and entry.is_reduction + ): + suffix = f"{suffix} + rsplit_start" + return f"tl.arange(0, {entry.prefix.upper()}BLOCK){size}{suffix}" + + def iteration_ranges_scalar_code( + self, entry: IterationRangesRoot, value: Any + ) -> str: + index_dtype = self.index_dtype + ndim = self.triton_tensor_ndim() + size = [1] * ndim + return f"tl.full({size}, {value}, {index_dtype})" + + def iteration_ranges_get_pid(self, entry: IterationRangesRoot) -> str: + assert entry.grid_dim is not None + key = f"tl.program_id({entry.grid_dim})" + # y_grid has a limit, so express it in terms of y and z in case of overflow. + # z grid is only exercised when max_tiles == 3 (off by default). + if self.needs_yz_grid_overflow(entry): + # For ynumel larger than max_ygrid, we need to use zdim. + # For each z dimension, there are tl.num_programs(1) yblocks which is passed by grad(x,y,z). + # So, we need to add tl.program_id(z) * tl.num_programs(y) *YBLOCK to get the correct yoffset. + key = f"({key} + tl.program_id({entry.grid_dim + 1}) * tl.num_programs({entry.grid_dim}))" + pid = entry.pid_cache.get(key, key) + if self.index_dtype != "tl.int32": + return f"{pid}.to({self.index_dtype})" + return pid + + def needs_yz_grid_overflow(self, entry: IterationRangesRoot) -> bool: + return ( + entry.grid_dim == 1 + and not entry.has_zdim + and not self.cooperative_reduction + and not V.graph.sizevars.statically_known_leq(entry.numel, get_max_y_grid()) + ) + + def max_block(self, prefix: str) -> int: + if self.fixed_config: + return self.fixed_config[f"{prefix.upper()}BLOCK"] + return TRITON_MAX_BLOCK[prefix.upper()] + + def _has_constant_mask(self, tree: IterationRangesRoot) -> bool: + if not self.optimize_mask: + return False + + if self.fixed_config and f"{tree.prefix.upper()}BLOCK" in self.fixed_config: + if self.fixed_config[f"{tree.prefix.upper()}BLOCK"] == 1: + return True + else: + if V.graph.sizevars.statically_known_equals(tree.numel, 1): + return True + + # Masks are superfluous if numel is a multiple of BLOCK + # (We use the fact that BLOCK is required by triton to be a power of 2) + if tree.is_reduction and self.persistent_reduction: + max_block = self._get_persistent_RBLOCK(tree.numel) + elif tree.prefix == "x" and self.no_x_dim: + max_block = 1 + else: + max_block = self.max_block(tree.prefix) + + if tree.is_reduction and self.cooperative_reduction: + max_block = max_block * self.max_rsplit() + + # [Note: Constant mask optimisation] + # Optional optimization: if block divides numel exactly, we will + # never need to do a masked load to handle stragglers at the end. + # If this tree is for the y dimension, we should only use a constant + # mask if it can be guaranteed that: + # 1. (ynumel / YBLOCK) < max_ygrid or + # 2. (ynumel / YBLOCK) % max_ygrid == 0 + # Because YBLOCK is not constant, use a conservative heuristic: + # only use a constant mask if ynumel < max_ygrid. + # It's faster to avoid masking at all. But it is sound to always + # mask. 
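# Illustrative sketch (not part of the torch file above): the divisibility
# reasoning behind the constant-mask optimization, with plain integers. The
# helper name is made up; it only assumes block sizes are powers of two that
# never exceed max_block, which itself is a power of two.
def mask_is_superfluous(numel: int, max_block: int) -> bool:
    # A multiple of max_block is a multiple of every power-of-two BLOCK <= max_block,
    # so no launched block ever runs past the end and the mask is always true.
    return numel % max_block == 0

assert mask_is_superfluous(8192, 1024)       # 8 full blocks, no straggler
assert not mask_is_superfluous(1000, 1024)   # partial last block, mask still needed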
+ if V.graph.sizevars.statically_known_multiple_of(tree.numel, max_block): + return ( + tree.grid_dim != 1 + or tree.has_zdim + or V.graph.sizevars.statically_known_leq(tree.numel, get_max_y_grid()) + ) + + return False + + def _has_constant_xmask(self) -> bool: + xtree = self.range_trees[0] + assert xtree.prefix == "x" + return self._has_constant_mask(xtree) + + def filter_masks(self, mask_vars: OrderedSet[str]) -> None: + for tree in self.range_trees: + if self._has_constant_mask(tree): + mask_vars.discard(f"{tree.prefix}mask") + + # can be added as an override_mask + mask_vars.discard("None") + + @cache_on_self + def get_reduction_prefixes(self) -> list[str]: + return [ + prefix_str[symt] + for symt in list(TritonSymbols.reduction_types)[: self.num_reduction_dims] + ] + + def codegen_reduction_numels(self, buffer: IndentedBuffer) -> None: + """ + Generates code that flattens ND reduction numels, block sizes, etc. into 1D. + """ + # rnumel = r0_numel * ... * r(n-1)_numel + reduction_trees = [tree for tree in self.range_trees if tree.is_reduction] + rnumel = " * ".join(sorted(f"{tree.prefix}numel" for tree in reduction_trees)) + buffer.splice(f"rnumel = {self.kexpr(rnumel)}") + + # RBLOCK = R0_BLOCK * ... * R(N-1)_BLOCK + rn_blocks = [ + TritonSymbols.block_sizes[tree.symt] + for tree in self.range_trees + if tree.is_reduction + ] + rblock = sympy_product(rn_blocks) + buffer.splice(f"RBLOCK: tl.constexpr = {self.kexpr(rblock)}") + + def _get_reduction_symbols(self, suffix: str, **kwargs) -> list[sympy.Symbol]: + """ + Helper to initialize symbols like rn_numel, rn_base, etc. + """ + rn_prefixes = self.get_reduction_prefixes() + return [sympy.Symbol(f"{prefix}{suffix}", **kwargs) for prefix in rn_prefixes] + + @cache_on_self + def _get_reduction_index_coeffs(self) -> list[sympy.Expr]: + """ + Compute coefficients to convert ND reduction indices to linear indices. + For example: + rindex = r0_index * r1_numel * ... * rn_numel + ... + rn_index. + """ + rn_prefixes = self.get_reduction_prefixes() + rn_numels = self._get_reduction_symbols("numel", integer=True, positive=True) + return [ + sympy_product(rn_numels[idx + 1 :]) for idx in range(len(rn_prefixes) - 1) + ] + [sympy.Integer(1)] + + def _flatten_reduction_indices(self, multi_inds: list[sympy.Expr]) -> sympy.Expr: + """ + Compute linear reduction indices from N dimensional ones. + """ + coeffs = self._get_reduction_index_coeffs() + return sympy_dot(coeffs, multi_inds) + + def codegen_reduction_indices(self, buffer: IndentedBuffer) -> None: + """ + Generates code that converts ND reduction indices into linear indices. + """ + # Gather relevant numels, indices, etc. + rn_offsets = self._get_reduction_symbols( + "offset", integer=True, nonnegative=True + ) + rn_inds = self._get_reduction_symbols("index", integer=True, nonnegative=True) + + # Compute roffset and rindex. 
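# Illustrative sketch (not part of the torch file above): the row-major
# flattening computed next, written out with concrete integers. The helper
# name and the example sizes are ad hoc.
def flatten_indices(multi_inds, numels):
    # The coefficient of dimension i is the product of all later numels;
    # the innermost dimension gets coefficient 1 (row-major linearization).
    coeffs = [1] * len(numels)
    for i in range(len(numels) - 2, -1, -1):
        coeffs[i] = coeffs[i + 1] * numels[i + 1]
    return sum(c * ind for c, ind in zip(coeffs, multi_inds))

# With r0_numel=4, r1_numel=8: (r0_index, r1_index) = (2, 3) -> 2 * 8 + 3 = 19
assert flatten_indices([2, 3], [4, 8]) == 19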
+ roffset = self._flatten_reduction_indices(rn_offsets) + buffer.splice(f"roffset = {self.index_to_str(roffset)}") + rindex = self._flatten_reduction_indices(rn_inds) + buffer.splice(f"rindex = {self.index_to_str(rindex)}") + + def iteration_ranges_codegen_header( + self, entry: IterationRangesRoot, code: IndentedBuffer + ) -> None: + x = entry.prefix + if entry.is_loop: + code.writeline(f"{entry.name} = {x}offset + {x}base") + elif entry.grid_dim is None: + # no need to "{x}offset = " + code.writeline(f"{entry.name} = {self.iteration_ranges_ranges_code(entry)}") + code.writeline(f"{x}offset = 0") + else: + if entry.tensor_dim is not None: + line = f"{x}offset + {self.iteration_ranges_ranges_code(entry)}" + else: + line = self.iteration_ranges_scalar_code(entry, f"{x}offset") + code.writelines( + [ + f"{x}offset = {self.iteration_ranges_get_pid(entry)} * {x.upper()}BLOCK", + f"{entry.name} = {line}", + ] + ) + + if self._has_constant_mask(entry): + sizes = self.dense_size_str() + code.writeline(f"{x}mask = tl.full({sizes}, True, tl.int1)") + else: + code.writeline(f"{x}mask = {entry.name} < {x}numel") + + +class TritonScheduling(SIMDScheduling): + kernel_type: type[Any] = TritonKernel + backend_features = OrderedSet( + [ + BackendFeature.FOREACH, + BackendFeature.BUCKETIZE, + BackendFeature.INPLACE_BUFFERS, + BackendFeature.MASKED_SCATTER_WITH_INDEX, + BackendFeature.SCAN, + BackendFeature.SORT, + BackendFeature.TRITON_TEMPLATES, + BackendFeature.TUPLE_REDUCTION, + ] + ) + + def __init__(self, scheduler: Optional[Scheduler]) -> None: + super().__init__(scheduler) + if scheduler is None or not hasattr(scheduler, "nodes"): + return + for node in scheduler.nodes: + if isinstance(node, (SchedulerNode, FusedSchedulerNode)): + node.debug_device_str = debug_triton_code + + @classmethod + def get_backend_features(cls, device: torch.device): + if ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ): + return OrderedSet( + [*cls.backend_features, BackendFeature.REDUCE_TO_SINGLE_ELEMENT] + ) + return cls.backend_features + + def codegen_comment(self, node_schedule): + wrapper = V.graph.wrapper_code + origins, _detailed_origins = get_kernel_metadata(node_schedule, wrapper) + if origins: + wrapper.make_comment(origins) + + if config.debug_fusion: + from torch._inductor.scheduler import ( + BaseSchedulerNode, + ForeachKernelSchedulerNode, + ) + + if not any( + isinstance(n, ForeachKernelSchedulerNode) for n in node_schedule + ): + # We probably should look what are the nodes inside a foreach + # schedule node + node_names = [ + n.get_name() + for n in node_schedule + if isinstance(n, BaseSchedulerNode) + ] + wrapper.make_comment( + f"{wrapper.comment} Fused node name list: {', '.join(node_names)}" + ) + + def define_kernel(self, src_code, node_schedule, kernel): + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + kernel_name = wrapper.src_to_kernel[src_code] + else: + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_category = get_kernel_category_by_source_code(src_code)[:3] + kernel_name = "_".join( + ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] + ) + # use the original src_code as the key + wrapper.src_to_kernel[src_code] = kernel_name + subs_name = kernel_name if config.triton.unique_kernel_names else "triton_" + + # DESCRIPTIVE_NAME is used for profiling purposes; it shows the full kernel name + # even when 
unique_kernel_names is turned off. Meanwhile, KERNEL_NAME is sometimes set + # to "triton_" to maximize caching opportunities (when unique_kernel_names = False). + src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name) + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), subs_name) + + # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does + # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. + src_code = src_code.replace("#pragma CMT", "#") + + _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") + compile_wrapper = IndentedBuffer() + + if async_compile.use_process_pool(): + # The process pool is warm, we can shell out to workers right away. This + # allows us to save the result in async_compile.CompiledTritonKernels, + # so that the second time we call async_compile.triton, we do no work. + async_compile.triton(subs_name, src_code) + + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + + compile_wrapper.splice(src_code, strip=True) + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + + metadata_comment = f"# kernel path: {kernel_path}" + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment += "\n" + origins + "\n" + detailed_origins + wrapper.define_kernel( + kernel_name, compile_wrapper.getvalue(), metadata_comment + ) + + # log kernel metadata for offline analysis. + # E.g. one can find all unaligned inner reduction and check if + # padding helps with the perf kernel by kernel. + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) + + return kernel_name + + def benchmark_fused_nodes(self, nodes, n_spills_threshold=8) -> tuple[float, str]: + """ + Benchmark fused list of nodes and return the execution time + in milliseconds on randomly generated inputs. 
+ """ + src_code = self.generate_kernel_code_from_nodes(nodes, benchmark_kernel=True) + mod = PyCodeCache.load(src_code) + return self.benchmark_codegened_module( + mod, n_spills_threshold, node_names=OrderedSet(n.get_name() for n in nodes) + ) + + def benchmark_codegened_module( + self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None + ) -> tuple[float, str]: + """Benchmark an already compiled module""" + device_interface = get_interface_for_device(V.graph.device_type) + with ( + preserve_rng_state(), + device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined] + ): + ms = None + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms)) + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return float(fd.read()) + return None + + node_names = ( + node_names if node_names is not None else OrderedSet(["unknown"]) + ) + log.debug( + "kernel src code for %s written to: %s", + node_names, + mod.__file__, + ) + ms = load_cache() + if ms is not None: + return ms, mod.__file__ + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + # call once to trigger the compilation + try: + call(wrapped_jit_function.clone_args(*args)[0]) + except Exception as e: + if config.triton.disallow_failing_autotune_kernels_TESTING_ONLY: + raise + log.debug( + "Exception (%s) in compiling fused nodes %s", + e, + node_names, + ) + ms = float("inf") + store_cache() + return ms, mod.__file__ + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + # n_spills does not necessarily mean it's not profitable to fuse, + # and sometimes it can be inaccurate + if launchers[0].n_spills > n_spills_threshold: + # skip benchmarking the kernel if there are register spills + ms = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. 
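# Illustrative sketch (not part of the torch file above): every timed call runs
# on freshly cloned inputs, so the measurement below includes the clone cost;
# when the kernel mutates its arguments, that cost is timed separately and
# subtracted. The helper is a made-up model of that correction, not inductor API.
def corrected_kernel_ms(clone_plus_kernel_ms: float, clone_only_ms: float,
                        mutates_args: bool) -> float:
    ms = clone_plus_kernel_ms
    if mutates_args:
        ms -= clone_only_ms  # remove the cloning overhead from the estimate
    return ms

assert corrected_kernel_ms(0.5, 0.125, mutates_args=True) == 0.375
assert corrected_kernel_ms(0.5, 0.125, mutates_args=False) == 0.5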
+ ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + # overhead of cloning args gives bias for fusing the kernel + # in the case of mutating/in-placeable second fusion + # TODO - would be better as a hook in triton do_bench that reset + # the input values between benchmarking + if len(wrapped_jit_function.mutated_arg_names) > 0: + ms = ms - benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args) + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run", + node_names, + ms, + ) + store_cache() + return ms, mod.__file__ + + def create_kernel_choices( # type: ignore[override] + self, + kernel_features: SIMDKernelFeatures, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + is_scan = kernel_features.contains_op("scan") + is_split_scan = is_scan and any( + node.is_split_scan() for node in kernel_features.scheduler_nodes() + ) + kernel_type: type[TritonKernel] = self.kernel_type + if is_split_scan: + from .triton_split_scan import TritonSplitScanKernel + + kernel_type = TritonSplitScanKernel + + if is_scan: + # TODO(jansel): scan does not yet work with cooperative reductions + kernel_kwargs["override_cooperative_reduction"] = False + + # ops.sort only works with persistent reduction, and is not bandwidth bound anyway + # so taking the hit of non-coalesced loads is okay + if kernel_features.contains_op("sort"): + kernel_kwargs["override_persistent_reduction"] = True + kernel_kwargs["override_cooperative_reduction"] = False + + if not TritonKernel.has_persistent_RBLOCK(kernel_features.reduction_numel): + # Cannot use persistent reduction with unknown dynamic rnumel + assert not kernel_kwargs.get("override_persistent_reduction") + kernel_kwargs["override_persistent_reduction"] = False + + kernel_kwargs = V.choices.triton_kernel_kwargs( + kernel_type, kernel_features, kernel_args, kernel_kwargs + ) + kernel = kernel_type(*kernel_args, **kernel_kwargs) + return self.add_multi_kernel_choices(kernel, kernel_args, kernel_kwargs) + + def add_multi_kernel_choices( + self, + kernel: TritonKernel, + kernel_args: list[Any], + kernel_kwargs: dict[str, Any], + ) -> list[TritonKernel]: + kernels: list[TritonKernel] = [kernel] + if not config.triton.multi_kernel: + return kernels + + optional_persistent = kernel.persistent_reduction and not kernel_kwargs.get( + "override_persistent_reduction" + ) + optional_cooperative = kernel.cooperative_reduction and not kernel_kwargs.get( + "override_cooperative_reduction" + ) + if optional_persistent: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_persistent_reduction=False, + ) + ) + if optional_cooperative: + rnumel = kernel.features.reduction_numel + # for larger sizes non-cooperative gets very slow + if V.graph.sizevars.statically_known_leq(rnumel, 65536): + kernels.append( + other := self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + ) + ) + if optional_persistent and other.persistent_reduction: + kernels.append( + self.kernel_type( + *kernel_args, + **kernel_kwargs, + override_cooperative_reduction=False, + override_persistent_reduction=False, + ) + ) + + if len(kernels) > 1: + for kernel2 in kernels[1:]: + # Keep buffers needed by the non-persistent reduction so both kernels have the same arguments + kernel2.must_keep_buffers = kernel.must_keep_buffers + # persistent kernels must be generated last so must_keep_buffers works right + kernels.sort(key=lambda k: 
k.persistent_reduction) + return kernels + + def benchmark_combo_kernel(self, node_list): + mod: ModuleType + ms: float + ms_clone: float + + def cache_file_path(): + assert mod.__file__ is not None + return os.path.splitext(mod.__file__)[0] + ".kernel_perf" + + def load_cache(): + path = cache_file_path() + if os.path.exists(path): + with open(path) as fd: + return tuple(float(e) for e in fd.read().split()) + return (None, None) + + def store_cache(): + path = cache_file_path() + write_atomic(path, str(ms) + " " + str(ms_clone)) + + total_ms, file_list = 0, [] + total_clone_ms: float = 0.0 + removed_buffers_orig = V.graph.removed_buffers + V.graph.removed_buffers = OrderedSet(removed_buffers_orig) + inplaced_to_remove_orig = V.graph.inplaced_to_remove + V.graph.inplaced_to_remove = OrderedSet(inplaced_to_remove_orig) + enable_autotune = config.combo_kernels_autotune > 0 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 0 + kernel_code_list = self.generate_combo_kernel_code( + subkernel_nodes=node_list, + custom_part_algorithm=True, + enable_autotune=enable_autotune, + mixed_sizes=mixed_sizes, + only_gen_src_code=True, + ) + + for src_code, _, node_group in kernel_code_list: + fused_node_lists = [node.get_nodes() for node in node_group] + names = [n.get_name() for nodes in fused_node_lists for n in nodes] + + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + mod = PyCodeCache.load(src_code) + + log.debug( + "kernel src code for %s written to: %s", + names, + mod.__file__, + ) + ms, ms_clone = load_cache() + if ms is not None: + total_ms += ms # type: ignore[assignment] + total_clone_ms += ms_clone + file_list.append(mod.__file__) + continue + + args = mod.get_args() + call = mod.call + wrapped_jit_function = mod.triton_ + + # call once to trigger the compilation + call(wrapped_jit_function.clone_args(*args)[0]) + + launchers = wrapped_jit_function.launchers + assert len(launchers) == 1 + if launchers[0].n_spills > 0: + # skip benchmarking the kernel if there are register spills + ms = ms_clone = float("inf") + else: + # We have to clone the inplace updated arguments to avoid earlier calls + # generating out of range indices for later calls. 
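# Illustrative sketch (not part of the torch file above): the load_cache /
# store_cache helpers above persist two floats, "<ms> <ms_clone>", in a
# ".kernel_perf" sidecar next to the generated module. A round-trip of that
# format looks like this (values are made up):
def encode_perf(ms: float, ms_clone: float) -> str:
    return str(ms) + " " + str(ms_clone)

def decode_perf(text: str) -> tuple[float, ...]:
    return tuple(float(e) for e in text.split())

assert decode_perf(encode_perf(0.123, 0.045)) == (0.123, 0.045)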
+ ms = benchmarker.benchmark_gpu( + lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ) + ms_clone = benchmarker.benchmark_gpu( + lambda: wrapped_jit_function.clone_args(*args)[0] + ) + + log.debug( + "The fused kernel for %s took %.3f ms to run, %.3f ms to clone inputs", + OrderedSet(n.get_name() for n in node_group), + ms, + ms_clone, + ) + store_cache() + total_ms += ms + total_clone_ms += ms_clone + file_list.append(mod.__file__) + V.graph.removed_buffers = removed_buffers_orig + V.graph.inplaced_to_remove = inplaced_to_remove_orig + return total_ms, total_clone_ms, file_list + + +def debug_triton_code(node: BaseSchedulerNode) -> list[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + + device = node.get_device() + assert device is not None + backend = node.scheduler.get_backend(device) + assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), ( + f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" + ) + + with V.graph.set_current_device(device): + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes( + node.get_nodes() + ).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_combo_kernel.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_combo_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..0dc96e3031c1660ec1c5ce09926e3cb675cf815e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_combo_kernel.py @@ -0,0 +1,978 @@ +import itertools +import logging +import textwrap +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, cast, Optional, Union + +import sympy +from sympy import Integer, Symbol + +from torch.utils._ordered_set import OrderedSet + +from .. 
import config, metrics +from ..runtime.hints import DeviceProperties +from ..runtime.runtime_utils import next_power_of_2 +from ..runtime.triton_heuristics import ( + RoundRobinComboKernelGrid, + SequentialComboKernelGrid, +) +from ..scheduler import BaseSchedulerNode +from ..utils import Placeholder, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + DeferredLine, + IndentedBuffer, + InplacedBuffer, + Kernel, + PythonPrinter, + RemovedArg, + SizeArg, + WorkspaceArg, +) +from .simd import prefix_is_reduction, SIMDScheduling +from .simd_kernel_features import SIMDKernelFeatures +from .triton import gen_common_triton_imports, TritonKernel +from .triton_utils import config_of, signature_to_meta + + +log = logging.getLogger(__name__) +pexpr = PythonPrinter().doprint +LARGE_NUMELS = 512e5 +BLOCK_UTILIZATION = 0.8 + + +def _default_custom_combo_kernel_horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], +) -> list[list[BaseSchedulerNode]]: + """Horizontally partition the given list of nodes into a list of list of nodes where each sublist + represents a partition. Nodes in different partitions are implemented in different combo kernels. + Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restrictions like CUDA limits for number of args. + + Input arguments: + nodes: a list of fused scheduler nodes to partition. + triton_scheduling: TritonScheduling instance. + kernel_map: a map from node to its kernel. + node_info_map: a map from node to (node_schedule, tiled_groups, numel, rnumel). + Output: + a list of list of nodes with each sublist representing a partition. + + The default algorithm is to partition nodes based on the following rules: + 1) nodes with the same number of block dimensions are grouped together. + 2) large pointwise nodes (numels greater than LARGE_NUMELS) are separated from other nodes. + 3) large reduce nodes are separated from other nodes. 
+ """ + + assert len(nodes) >= 1 + + # first partition nodes based on number of block dimensions + tilings = [node_info_map[n][1] for n in nodes] + + max_dims = max(len(t) for t in tilings) + nodes_per_ndim: list[list[BaseSchedulerNode]] = [] + for i in range(2, max_dims + 1): + group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i] + reduction = [ + n + for n in group_per_dim + if kernel_map[n].inside_reduction + and not (kernel_map[n].persistent_reduction and kernel_map[n].no_x_dim) + ] + not_reduction = [n for n in group_per_dim if n not in reduction] + # rnumel > 2048 usually has long execution time + # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes + long_reduction = [ + n + for n in reduction + if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] + ] + short_reduction = [n for n in reduction if n not in long_reduction] + if long_reduction: + log.warning( + "ComboKernels: %d long reduction nodes are separated", + len(long_reduction), + ) + large_pointwise = [ + n + for n in not_reduction + if not kernel_map[n].inside_reduction + and len(kernel_map[n].numels) == 2 + and V.graph.sizevars.size_hint(kernel_map[n].numels["x"]) > LARGE_NUMELS + ] + if large_pointwise: + # TODO benchmark the performance when large pointwise nodes combining with others + log.warning( + "ComboKernels: %d large pointwise nodes are separated", + len(large_pointwise), + ) + not_reduction = [n for n in not_reduction if n not in large_pointwise] + nodes_per_ndim.extend([node] for node in large_pointwise) + + nodes_per_ndim.extend( + g for g in (not_reduction, short_reduction, long_reduction) if g + ) + + assert sum(len(p) for p in nodes_per_ndim) == len(nodes) + return nodes_per_ndim + + +_custom_combo_kernel_horizontal_partition_algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], +] = _default_custom_combo_kernel_horizontal_partition + + +def set_custom_combo_kernel_horizontal_partition( + algorithm: Callable[ + [ + list[BaseSchedulerNode], + SIMDScheduling, + dict[BaseSchedulerNode, TritonKernel], + dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + ], + list[list[BaseSchedulerNode]], + ], +) -> None: + """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions + are implemented in different combo kernels. Nodes in the same partition are likely to be implemented + in the same combo kernel, but subject to subsequent restricts like CUDA limits for number of args. + + The algorithm should take a list of nodes and return a list of list of nodes. + + The default algorithm is to partition nodes based on number of block dimensions. 
+ """ + global _custom_combo_kernel_horizontal_partition_algorithm + _custom_combo_kernel_horizontal_partition_algorithm = algorithm + + +@dataclass +class PartitionState: + partitions: list[list[BaseSchedulerNode]] + cur_partition: list[BaseSchedulerNode] + cur_count: int + + def finalize(self) -> None: + if self.cur_partition: + self.partitions.append(self.cur_partition) + + +class ComboKernel(Kernel): + MAX_NUM_ARGS = 250 # number where I would no longer get triton errors + + @staticmethod + def _update_partition( + partition_state: PartitionState, + node_rw_count: int, + node_info: BaseSchedulerNode, + ) -> None: + if partition_state.cur_count + node_rw_count > ComboKernel.MAX_NUM_ARGS: + partition_state.partitions.append(partition_state.cur_partition) + partition_state.cur_partition = [node_info] + partition_state.cur_count = node_rw_count + else: + partition_state.cur_count += node_rw_count + partition_state.cur_partition.append(node_info) + + @staticmethod + def _base_horizontal_partition( + subkernel_nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + # TODO support combination of kernels with different block dimensions + assert len(subkernel_nodes) >= 1 + mixed_sizes = config.combo_kernel_allow_mixed_sizes > 1 or ( + config.combo_kernel_allow_mixed_sizes == 1 and custom_algorithm + ) + + ndim_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + yelem_to_partition_state: dict[int, PartitionState] = defaultdict( + lambda: PartitionState([], [], 0) + ) + + for node in subkernel_nodes: + _node_schedule, tiled_groups, _numel, _rnumel = node_info_map[node] + node_info = node + + read_writes = node.read_writes + read_write_count = len(read_writes.reads) + len(read_writes.writes) + + ndim = len(tiled_groups) + assert ndim >= 2, f"Combokernel not support tile {tiled_groups}" + if not mixed_sizes and ndim == 3: + y_elem = tiled_groups["y"] + partition_state = yelem_to_partition_state[y_elem] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + else: + assert mixed_sizes or ndim <= 3, f"No mixed sizes: tile {tiled_groups}" + partition_state = ndim_to_partition_state[ndim] + ComboKernel._update_partition( + partition_state, read_write_count, node_info + ) + + all_partitions = [] + for partition_state in ndim_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + for partition_state in yelem_to_partition_state.values(): + partition_state.finalize() + all_partitions.extend(partition_state.partitions) + + return all_partitions + + @staticmethod + def horizontal_partition( + nodes: list[BaseSchedulerNode], + triton_scheduling: SIMDScheduling, + kernel_map: dict[BaseSchedulerNode, TritonKernel], + node_info_map: dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], + custom_algorithm: bool = False, + ) -> list[list[BaseSchedulerNode]]: + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnum) + for each subkernel node where each sublist forms a ComboKernel. 
It horizontally partitions nodes into + sublists in the following way: + 1) call _custom_combo_kernel_horizontal_partition_algorithm() if custom_algorithm is True + 2) then, call _base_horizontal_partition() to partition nodes into sublists, each sublist is + guaranteed to not exceed CUDA limits for number of args (read/writes) and to have the same + 2D or 1D blocking strategy. + """ + if custom_algorithm: + raw_partitions = _custom_combo_kernel_horizontal_partition_algorithm( + nodes, triton_scheduling, kernel_map, node_info_map + ) + else: + raw_partitions = [nodes] + + """Generates a list of lists of node info tuples which consist of (fused_nodes, tiling, numel, rnumel) + for each subkernel node where each sublist is guaranteed to not exceed CUDA limits for number of args + (read/writes) and to have the same 2D or 1D blocking strategy.""" + all_partitions = [] + for raw_partition in raw_partitions: + all_partitions.extend( + ComboKernel._base_horizontal_partition( + raw_partition, triton_scheduling, node_info_map, custom_algorithm + ) + ) + return all_partitions + + class SequentialDispatch: + """ + The dispatcher which dispatches the subkernels in a sequential manner: + the blocks are first dispatched to the 1st subkernel (until it is filled), + then to the 2nd subkernel, and so on. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. + """ + + grid_expr = SequentialComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + if num == 0: + cls._calculate_xblocks(kernel, code) + code.splice(f"if pid < num_xblocks_{num}:") + with code.indent(): + code.splice("pid_offset = pid") + else: + code.splice(f"elif pid < num_xblocks_{num}:") + with code.indent(): + code.splice(f"pid_offset = pid - num_xblocks_{num - 1}") + + @classmethod + def _calculate_xblocks( + cls, kernel: "ComboKernel", code: IndentedBuffer + ) -> None: + x_numels_list = kernel.x_numels_list + for i in range(len(x_numels_list)): + xnumels, no_x_dim = ( + (x_numels_list[i], False) + if isinstance(x_numels_list[i], str) + and cast(str, x_numels_list[i])[0] != "-" + or ( + isinstance(x_numels_list[i], int) + and cast(int, x_numels_list[i]) > 0 + ) + else (kernel.min_x_blocks_list[i], True) + ) + xblock_str = ( + f"tl.cdiv({xnumels}, XBLOCK)" if not no_x_dim else f"{xnumels}" + ) + if i == 0: + code.splice(f"num_xblocks_{i} = {xblock_str}") + else: + code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}") + + class RoundRobinDispatch: + """ + The dispatcher which dispatches the subkernels in a round robin manner: + the blocks are interleavedly dispatched to each subkernel to execute them + in parallel. + The class defines the methods specific to the dispatch algorithm. + Methods: + codegen_pid_range(...): codegen the pid range for each subkernel. + grid(...): codegen the grid size for launching the combo kernel. 
+ """ + + grid_expr = RoundRobinComboKernelGrid + + @classmethod + def codegen_pid_range( + cls, kernel: "ComboKernel", num: int, code: IndentedBuffer + ) -> None: + num_kernels = len(kernel.sub_kernels) + if num == 0: + cond = "if" + else: + cond = "elif" + code.splice(f"{cond} pid % {num_kernels} == {num}:") + with code.indent(): + code.splice(f"pid_offset = pid // {num_kernels}") + + def __init__( + self, enable_autotune: bool = False, mixed_sizes: bool = False + ) -> None: + super().__init__() + self.sub_kernels: list[TritonKernel] = [] + self.iter_vars_count = itertools.count() + self.grids: list[list[int]] = [] + self.min_x_blocks_list: list[Union[int, str]] = [] + self.x_numels_list: list[Union[int, str]] = [] + self.enable_autotune = enable_autotune + self.mixed_sizes = mixed_sizes + self.dispatch_class: Optional[ + type[Union[ComboKernel.SequentialDispatch, ComboKernel.RoundRobinDispatch]] + ] = None + self.block_args: list[str] = [] + # there following are used when autotuning is disabled + self.block_size_1d = 1024 # Try tuning this value + self.block_size_2d = 32 + self.num_warps = 8 + self.block_size_reduce = 256 + self.dynamic_shape_args: list[str] = [] + + def create_sub_kernel(self, triton_kernel: TritonKernel) -> TritonKernel: + sub_kernel = triton_kernel + metrics.generated_kernel_count -= 1 + sub_kernel.args = self.args + sub_kernel.iter_vars_count = self.iter_vars_count + sub_kernel.cse.iter_buffer_ids = self.cse.iter_buffer_ids + self.sub_kernels.append(sub_kernel) + return sub_kernel + + @staticmethod + def create_triton_kernel( + tiling: dict[str, sympy.Expr], + features: SIMDKernelFeatures, + optimize_mask: bool, + ) -> TritonKernel: + """ + Only allow optimize_mask=True when 1) sequential dispatch is used, + 2) numels except x dimension are the same for each sub kernel. + """ + return TritonKernel( + tiling, + features=features, + pid_cache={"tl.program_id(0)": "pid_offset"}, + optimize_mask=optimize_mask, + # foreach kernels don't work with cooperative reductions + override_cooperative_reduction=False, + ) + + def codegen_static_numels_sub_kernel( + self, code: IndentedBuffer, sub_kernel: TritonKernel, num: int + ) -> list[str]: + """ + We get a small speedup from hard coding numels if they are static. + + This code stomps on the passed-in values by writing an constant to the top of the kernel. + + In a kernel like: + def KERNEL_NAME(in_ptr0, in_ptr1, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): + + We would add + xnumel = 4096 + rnumel = 768 + + After the signature, before the kernel code, if we decided to make these static. As its hardcoded, it becomes + a better signal to triton on how to unroll and do some static indexing. So, it's not so much that downstream + knows that its a static numel, as that you just plop a constant into the kernel. 
+ """ + grid = [] + uniquify_block_sizes = [] + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if isinstance(simplified_tree_numel, (Integer, int)): + code.writeline(f"{tree.prefix}numel = {int(simplified_tree_numel)}") + else: + assert f"{tree.prefix}numel_{num}" in self.dynamic_shape_args + uniquify_block_sizes.append(f"{tree.prefix}numel") + + if not tree.is_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + grid.append(int(simplified_tree_numel)) + else: + grid.append(f"{tree.prefix}numel_{num}") + + if tree.is_reduction and sub_kernel.persistent_reduction: + if isinstance(simplified_tree_numel, (Integer, int)): + val = int(simplified_tree_numel) + else: + raise RuntimeError( + "Dynamic shape on reduction dimension is not supported" + ) + val = next_power_of_2(val) + code.writeline(f"RBLOCK_{num}: tl.constexpr = {val}") + code.writeline(f"R0_BLOCK_{num}: tl.constexpr = {val}") + uniquify_block_sizes.append("R0_BLOCK") + + if tree.prefix == "x" and sub_kernel.no_x_dim: + code.writeline(f"XBLOCK_{num}: tl.constexpr = 1") + uniquify_block_sizes.append("XBLOCK") + self.grids.append(grid) + return uniquify_block_sizes + + def min_x_blocks_sub_kernel(self, sub_kernel: TritonKernel, num: int) -> None: + """ + Kernels with no_x_dim being true has no tunable XBLOCK. They have a fixed number of X blocks. + Grid calculation needs to make sure that they are assigned with enough number of blocks. + """ + min_x_blocks: Union[int, str] = 0 + x_numels: Union[int, str] = 0 + for tree in sub_kernel.range_trees: + simplified_tree_numel = V.graph.sizevars.simplify(tree.numel) + if tree.prefix == "x": + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + if sub_kernel.no_x_dim: + min_x_blocks = x_numels + x_numels = ( + -min_x_blocks + if isinstance(x_numels, int) + else "-" + cast(str, x_numels) + ) + else: + if isinstance(simplified_tree_numel, (Integer, int)): + x_numels = int(simplified_tree_numel) + else: + x_numels = f"{tree.prefix}numel_{num}" + self.min_x_blocks_list.append(min_x_blocks) + self.x_numels_list.append(x_numels) + + def select_heuristics(self, sub_kernel: TritonKernel) -> tuple[str, dict[str, int]]: + size_hints = { + prefix: next_power_of_2(V.graph.sizevars.size_hint(numel)) + for prefix, numel in sub_kernel.numels.items() + if not prefix_is_reduction(prefix) or sub_kernel.inside_reduction + } + if sub_kernel.persistent_reduction: + assert sub_kernel.inside_reduction + heuristics = "persistent_reduction" + elif sub_kernel.inside_reduction: + heuristics = "reduction" + else: + heuristics = "pointwise" + return heuristics, size_hints + + def select_combo_heuristics( + self, heuristics_list: list[str], size_hints_list: list[dict[str, int]] + ) -> tuple[str, dict[str, int], TritonKernel]: + if not self.enable_autotune: + return "foreach", size_hints_list[0], self.sub_kernels[0] + if "reduction" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "reduction" else 0, + ) + return heuristics_list[i], size_hints_list[i], self.sub_kernels[i] + elif "pointwise" in heuristics_list: + i, _ = max( + enumerate(size_hints_list), + key=lambda x: x[1]["x"] if heuristics_list[x[0]] == "pointwise" else 0, + ) + # modify size_hint to avoid oom check fail (may be a false alarm) + num_pointwise = len([e for e in heuristics_list if e == "pointwise"]) + num_reduction = len([e for 
e in heuristics_list if e == "reduction"]) + num_persistent_reduction = len( + [e for e in heuristics_list if e == "persistent_reduction"] + ) + assert num_reduction == 0, ( + "combining pointwise and reduction are not supported yet." + ) + heuristics = ( + "pointwise_with_reduction" + if num_persistent_reduction > 0 + else "pointwise" + ) + if len(heuristics_list) - num_pointwise >= 4: + size_hints = size_hints_list[i] + size_hints["x"] = min(128, size_hints["x"]) + return heuristics, size_hints_list[i], self.sub_kernels[i] + else: + return heuristics_list[0], size_hints_list[0], self.sub_kernels[0] + + def get_mutated_args_sub_kernels(self) -> list[str]: + mutated_args: OrderedSet[str] = OrderedSet() + for sub_kernel in self.sub_kernels: + for mutation in sub_kernel.mutations: + if mutation in sub_kernel.args.input_buffers: + mutated_args.add(sub_kernel.args.input_buffers[mutation]) + if ( + mutation in sub_kernel.args.inplace_buffers + and mutation not in V.graph.removed_buffers + and mutation not in sub_kernel.removed_buffers + ): + mutated_args.add( + cast( + InplacedBuffer, sub_kernel.args.inplace_buffers[mutation] + ).inner_name + ) + if mutation in sub_kernel.args.output_buffers: + arg = sub_kernel.args.output_buffers[mutation] + assert not isinstance(arg, RemovedArg) + mutated_args.add(arg) + return sorted(mutated_args) + + def select_dispatch_strategy(self) -> None: + if self.dispatch_class is not None: + return + # mixed_sizes is used for optimize_mask, so it only allows sequential dispatch + # Not mixed sizes on y dim technically is ok to use round robin as wells. + if not self.mixed_sizes or any(isinstance(e, str) for e in self.x_numels_list): + # str in x_numels_list means a dynamic shape + self.dispatch_class = ComboKernel.SequentialDispatch + return + # A negative x_blocks_list element means the kernel is not tunable, + # i.e., no_x_dim = True + x_numels_list = [abs(cast(int, e)) for e in self.x_numels_list] + total = max(x_numels_list) * len(x_numels_list) + needed = sum(x_numels_list) + if needed / total > BLOCK_UTILIZATION: + # Introduced overhead (masked blocks) is less than 20% + self.dispatch_class = ComboKernel.RoundRobinDispatch + else: + self.dispatch_class = ComboKernel.SequentialDispatch + + def jit_line( + self, + heuristics: str, + size_hints: dict[str, int], + selected_kernel: TritonKernel, + signature: list[Any], + argdefs: list[ArgName], + pointwise_with_reduce: bool = False, + ) -> str: + can_use_32bit = all(k.index_dtype == "tl.int32" for k in self.sub_kernels) + size_dtype = "tl.int32" if can_use_32bit else "tl.int64" + for i, sub in enumerate(self.sub_kernels): + self.min_x_blocks_sub_kernel(sub, i) + self.select_dispatch_strategy() + triton_meta = { + "signature": signature_to_meta( + signature, size_dtype=size_dtype, argdefs=argdefs + ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + "constants": {}, + } + triton_meta["configs"] = [config_of(signature)] + mutated_args = self.get_mutated_args_sub_kernels() + dispatch = self.dispatch_class + assert dispatch is not None + inductor_meta = { + "grid_type": dispatch.grid_expr.__name__, + "combo_grid_meta": self.combo_grid_meta(), + "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), + "mutated_arg_names": mutated_args, + **TritonKernel.inductor_meta_common(), + } + + sub_kernel = selected_kernel + if heuristics == "foreach": + heuristics_line = f""" + @triton_heuristics.foreach( + num_warps={self.num_warps}, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r}, + ) + 
@triton.jit + """ + elif sub_kernel.inside_reduction: + reduction_hint = sub_kernel.features.get_reduction_hint() + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, + reduction_hint={reduction_hint}, + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + else: + tile_hint = "" + if len(size_hints) == 2: + tile_hint = "tile_hint=TileHint.SQUARE," + else: + tile_hint = "tile_hint=TileHint.DEFAULT," + heuristics_line = f""" + @triton_heuristics.{heuristics}( + size_hints={size_hints!r}, {tile_hint} + filename=__file__, + triton_meta={triton_meta!r}, + inductor_meta={inductor_meta!r} + ) + @triton.jit + """ + + return heuristics_line + + def codegen_blocks(self, code: IndentedBuffer) -> None: + for block in self.block_args: + assert block in ( + "XBLOCK", + "YBLOCK", + "R0_BLOCK", + ), f"{block} is not supported without autotuning" + if "YBLOCK" in self.block_args: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_2d}") + code.splice(f"YBLOCK: tl.constexpr = {self.block_size_2d}") + else: + code.splice(f"XBLOCK: tl.constexpr = {self.block_size_1d}") + if "R0_BLOCK" in self.block_args: + code.splice(f"R0_BLOCK: tl.constexpr = {self.block_size_reduce}") + code.splice(f"RBLOCK: tl.constexpr = {self.block_size_reduce}") + + def get_block_args(self) -> list[ConstexprArg]: + """ + Calculate blocks from sub_kernels and range_trees. + **Update self.block_args** + Return the block args + """ + block_names = {} + for sub_kernel in self.sub_kernels: + # TODO: we assume all sub_kernels have the same block size + for tree in sub_kernel.range_trees: + if tree.is_reduction and ( + not sub_kernel.inside_reduction or sub_kernel.persistent_reduction + ): + continue + if tree.prefix == "x" and sub_kernel.no_x_dim: + continue + block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix + self.block_args = list(block_names.keys()) + + return [ConstexprArg(x) for x in block_names.keys()] + + def add_numel_to_args( + self, argdefs: list[ArgName], signature: list[Any] + ) -> list[ArgName]: + for num, sub_kernel in enumerate(self.sub_kernels): + for tree in sub_kernel.active_range_trees(): + if not isinstance(tree.numel, (Integer, int)): + # only if it is a dynamic shape + sizearg = SizeArg(f"{tree.prefix}numel_{num}", tree.numel) + signature.append(sizearg) + argdefs.append(ArgName(f"{tree.prefix}numel_{num}")) + self.dynamic_shape_args.append(f"{tree.prefix}numel_{num}") + return argdefs + + def add_numel_to_call_args( + self, name: str, call_args: list[Any], arg_types: list[Any] + ) -> None: + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if isinstance(tree.numel, (Integer, Symbol)): + expr = tree.numel + else: + expr = V.graph.wrapper_code.generate_numel_expr( + name, tree, suffix=str(num) + ) + if not tree.is_reduction or sub_kernel.inside_reduction: + call_args.append(expr) + arg_types.append(type(expr)) + + def kernel_benchmark_extra_args(self) -> list[str]: + extra_args = [] + for num, sub_kernel in enumerate(self.sub_kernels): + for i, tree in enumerate(sub_kernel.range_trees): + numel_name = f"{tree.prefix}numel_{num}" + if numel_name not in self.dynamic_shape_args: + continue + if not tree.is_reduction or sub_kernel.inside_reduction: + extra_args.append(str(V.graph.sizevars.size_hint(tree.numel))) + return extra_args + + def codegen_kernel(self, name: 
Optional[str] = None) -> str: + # TODO: is it correct to use the first sub kernel's heuristics? + heuristics_list, size_hints_list = [], [] + for subkernel in self.sub_kernels: + h, s = self.select_heuristics(subkernel) + heuristics_list.append(h) + size_hints_list.append(s) + heuristics, size_hints, selected_kernel = self.select_combo_heuristics( + heuristics_list, size_hints_list + ) + pointwise_with_reduction, heuristics = ( + (True, "pointwise") + if heuristics == "pointwise_with_reduction" + else (False, heuristics) + ) + code = IndentedBuffer() + + code.splice(gen_common_triton_imports()) + if config.benchmark_combo_kernel: + code.splice(self.imports_for_benchmark_kernel()) + + argdefs, _, signature, _ = self.args.python_argdefs() + argdefs = self.add_numel_to_args(argdefs, signature) + block_args = self.get_block_args() + if self.enable_autotune: + argdefs.extend([ArgName(x.name, is_constexpr=True) for x in block_args]) + if triton_version_uses_attrs_dict(): + signature.extend(block_args) + + code.splice( + self.jit_line( + heuristics, + size_hints, + selected_kernel, + pointwise_with_reduce=pointwise_with_reduction, + signature=signature, + argdefs=argdefs, + ) + ) + code.writeline( + f"def {name or str(Placeholder.KERNEL_NAME)}({', '.join(x.full_name() for x in argdefs)}):" + ) + + with code.indent(): + code.splice("pid = tl.program_id(0)") + if not self.enable_autotune: + self.codegen_blocks(code) + + for num, sub_kernel in enumerate(self.sub_kernels): + assert self.dispatch_class is not None + self.dispatch_class.codegen_pid_range(self, num, code) + with code.indent(): + uniquify = self.codegen_static_numels_sub_kernel( + code, sub_kernel, num + ) + sub_kernel.codegen_body() + uniquified_body = self.uniquify_block_sizes( + sub_kernel.body, num, uniquify + ) + code.splice(uniquified_body) + + code.splice("else:") + with code.indent(): + code.splice("pass") + + if config.benchmark_combo_kernel: + code.splice(self.codegen_kernel_benchmark(num_gb=0)) + + return code.getvalue() + + def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: + result = IndentedBuffer() + _argdefs, call_args, signature, _ = self.args.python_argdefs() + result.writelines(["", "", "def get_args():"]) + with result.indent(): + name_cnt = itertools.count() + var_names = [] + for arg_name, arg_sig in zip(call_args, signature): + var_name = f"arg_{next(name_cnt)}" + buf = V.graph.try_get_buffer(arg_name) + if buf: + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(buf.get_size())}, {V.graph.sizevars.size_hints(buf.get_stride())}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + ) + elif arg_name in V.graph.constants: + # note that random seed is put in V.graph.constants + const_tensor = V.graph.constants[arg_name] + result.writeline( + f"{var_name} = rand_strided({V.graph.sizevars.size_hints(const_tensor.size())}, {V.graph.sizevars.size_hints(const_tensor.stride())}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + ) + elif isinstance(arg_sig, SizeArg): + symval_hint = V.graph.sizevars.size_hint(arg_sig.expr) + + # Force the seed_offset to be 0 so calls to the same kernel + # using different seed offset will have the same benchmark harness. + # We can dedup kernel definitions in this case. 
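+                    # For example (hypothetical values), two launches of the same kernel
+                    # whose SizeArgs differ only in seed_offset, say 42 versus 97, would
+                    # otherwise emit different get_args() bodies; pinning the hint to 0
+                    # keeps those harnesses identical so they can be deduplicated.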
+ if "seed_offset" in arg_sig.name: + symval_hint = 0 + result.writeline(f"{var_name} = {symval_hint}") + elif isinstance(arg_sig, WorkspaceArg): + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + # for benchmark harness, we ignore arg_sig.zero_mode and always zero it + result.writeline( + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" + ) + else: + raise KeyError( + f"Don't find the buffer or const tensor for {arg_name}" + ) + var_names.append(var_name) + if self.dynamic_shape_args: + var_names.extend(self.kernel_benchmark_extra_args()) + result.writeline(f"return {', '.join(var_names)},") + + result.writelines(["\n", "\n", "def call(args):"]) + index = V.graph.get_current_device_or_throw().index + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + stream_name = f"stream{index}" + result.writeline(f"{stream_name} = get_raw_stream({index})") + result.writeline( + f"{str(Placeholder.KERNEL_NAME)}.run(*args, stream={stream_name})" + ) + + # benchmark all configs + result.writelines(["\n", "\n", "def benchmark_all_configs(args):"]) + with result.indent(): + result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") + with result.indent(): + result.writeline( + V.graph.device_ops.set_device(index) + ) # no-op to ensure context + result.writeline( + f"return {str(Placeholder.KERNEL_NAME)}.benchmark_all_configs(*args)" + ) + + result.writelines(["\n", "\n", "if __name__ == '__main__':"]) + with result.indent(): + result.writeline( + "from torch._inductor.runtime.benchmarking import benchmarker" + ) + result.writeline("") + + result.writeline("args = get_args()") + result.writeline( + "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + ) + result.writeline(f"num_gb = {num_gb}") + result.writeline("gb_per_s = num_gb / (ms / 1e3)") + result.writeline( + 'print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")' + ) + + return result + + def imports_for_benchmark_kernel(self) -> str: + return textwrap.dedent( + """ + from torch._dynamo.testing import rand_strided + {} + import torch + """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")) + ) + + def uniquify_block_sizes( + self, code: IndentedBuffer, num_kernel: int, uniquify: list[str] + ) -> IndentedBuffer: + if not uniquify: + return code + modified = IndentedBuffer(initial_indent=code._indent) + for line in code._lines: + if isinstance(line, str) and (blocks := [e for e in uniquify if e in line]): + modified_line = line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + modified.writeline(modified_line) + elif isinstance(line, DeferredLine) and ( + blocks := [e for e in uniquify if e in line.line] + ): + modified_line = line.line + for block in blocks: + modified_line = modified_line.replace( + block, f"{block}_{num_kernel}" + ) + new_line = DeferredLine(line.name, modified_line) + modified.writeline(new_line) + else: + modified.writeline(line) + return modified + + def call_kernel(self, code: IndentedBuffer, name: str) -> None: + _, call_args, _, arg_types = self.args.python_argdefs() + + wrapper = V.graph.wrapper_code + assert self.dispatch_class is not None + if self.dynamic_shape_args: + self.add_numel_to_call_args(name, call_args, arg_types) + + wrapper.generate_kernel_call( + name, + call_args, + triton=True, + 
arg_types=arg_types, + ) + + def combo_grid_meta(self) -> dict[str, Any]: + dynamic_shape = bool(self.dynamic_shape_args) + num_kernels = len(self.sub_kernels) + min_blocks = ( + max(self.min_x_blocks_list) * num_kernels if not dynamic_shape else None + ) + + if not self.enable_autotune: + if "YBLOCK" in self.block_args: + default_config = { + "XBLOCK": self.block_size_2d, + "YBLOCK": self.block_size_2d, + } + else: + default_config = {"XBLOCK": self.block_size_1d} + else: + default_config = None + + meta = { + "num_kernels": num_kernels, + "min_blocks": min_blocks, + "default_config": default_config, + } + + for num, sub_kernel in enumerate(self.sub_kernels): + meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim + for i, tree in enumerate(sub_kernel.range_trees): + if not tree.is_reduction: + numel_name = f"{tree.prefix}numel_{num}" + if numel_name in self.dynamic_shape_args: + meta[numel_name] = None + else: + meta[numel_name] = int(V.graph.sizevars.simplify(tree.numel)) + + return meta diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py new file mode 100644 index 0000000000000000000000000000000000000000..e1b400c37c08fbeb0257515d59df4277c56df095 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_split_scan.py @@ -0,0 +1,207 @@ +# mypy: allow-untyped-defs +import functools +from typing import Union + +import sympy + +from torch._inductor import config +from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction +from torch._inductor.codegen.triton import ( + triton_compute_type, + TritonCSEVariable, + TritonKernel, +) +from torch._inductor.runtime.triton_heuristics import SplitScanGrid +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import CeilDiv + +from ..utils import sympy_product + + +class TritonSplitScanKernel(TritonKernel): + """Generates a triton kernel that supports ops.scan calls while also splitting + the reduction dimension over multiple triton programs. + + For this kernel, loop numels will always take the form ``(xdim, rdim)`` + and the grid has the shape ``(CeilDiv(rdim, RBLOCK), xdim)``. Communication + between blocks occurs within a global memory workspace buffer, which + must be zero-filled before launching the kernel. + + Note that generation for ``ops.reduction`` is not supported. 
+ + For details of the communication strategy, see + https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + """ + + def __init__( + self, + tiling: dict[str, sympy.Expr], + pid_cache=None, + fixed_config=None, + **kwargs, + ) -> None: + assert pid_cache is None, "not supported" + assert fixed_config is None, "not supported" + super().__init__( + tiling, + **kwargs, + ) + self.no_x_dim = True + + def should_use_persistent_reduction(self) -> bool: + return False + + def should_use_cooperative_reduction(self) -> bool: + return False + + def initialize_range_tree(self, pid_cache): + prefixes = ["y", "x", "r0_"] + assert len(self.numels) <= len(prefixes), ( + "z dimension not supported for split scan" + ) + active_prefixes = prefixes[len(prefixes) - len(self.numels) :] + + grid_dims = {"r0_": 0, "x": 1, "y": 2} + for prefix in active_prefixes: + numel = self.numels[prefix] + tensor_dim = 0 if prefix_is_reduction(prefix) else None + grid_dim = grid_dims[prefix] + self.range_trees.append( + IterationRangesRoot( + f"{prefix}index", + numel, + prefix, + grid_dim, + self, # type: ignore[arg-type] + pid_cache=pid_cache, + is_loop=False, + tensor_dim=tensor_dim, + grid_dim=grid_dim, + has_zdim=False, + ) + ) + + def reduction(self, dtype, src_dtype, reduction_type, value): + raise NotImplementedError("NYI TritonSplitDimKernel reductions") + + def scan(self, dtypes, combine_fn, values): + import triton.language as tl + + (dtype,) = dtypes + (value,) = values + + compute_type = triton_compute_type(dtype) + compute_type_triton = getattr(tl, compute_type[3:]) + + element_nbits = compute_type_triton.primitive_bitwidth + + scratch_type = "tl.uint32" if element_nbits <= 16 else "tl.uint64" + scratch_type_triton = getattr(tl, scratch_type[3:]) + scratch_elems_per_block = 3 if element_nbits == 64 else 1 + scratch_nbytes_per_block = scratch_elems_per_block * ( + scratch_type_triton.primitive_bitwidth // 8 + ) + + cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype) + cse_compute = functools.partial(self.cse.generate, self.compute) + + assert len(self.numels) == 2, "Unexpected tiling" + min_rblock = config.triton.min_split_scan_rblock + reduction_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if prefix_is_reduction(prefix) + ) + pointwise_numel = sympy_product( + numel + for prefix, numel in self.numels.items() + if not prefix_is_reduction(prefix) + ) + max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock) + nbytes = scratch_nbytes_per_block * max_blocks + scratch_base: Union[str, TritonCSEVariable] + scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True) + if offset != 0: + scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}") + runtime_rblocks = cse_load(f"tl.num_programs({self.range_trees[-1].index})") + scratch_base = cse_load( + f"{scratch_base}.to(tl.pointer_type({scratch_type})) + xoffset * " + f"{scratch_elems_per_block} * {runtime_rblocks}" + ) + + masks = OrderedSet(f"{tree.prefix}mask" for tree in self.range_trees) + self.filter_masks(masks) + assert not self._load_mask, "ops.scan not supported inside ops.masked" + + value = cse_compute( + f"{value}.to({compute_type})", + dtype=dtype, + ) + value = cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtype, + ) + + combine_helper_fn = self._lift_helper(combine_fn, 1, (dtype,)) + dim = self.triton_tensor_ndim() - 1 + assert dim == 0, "" + + block_sum = cse_compute( + 
f"tl.reduce({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + ) + exclusive_prefix = self.cse.newvar( + dtype=dtype, + ) + if element_nbits == 64: + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback_64( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + ) + """, + strip=True, + ) + + else: + assert element_nbits <= 32 + value_as_uint_dtype = f"tl.uint{element_nbits}" + + self.compute.splice( + f""" + {exclusive_prefix} = triton_helpers.exclusive_scan_decoupled_lookback( + {scratch_base}, + {block_sum}, + {self.iteration_ranges_get_pid(self.range_trees[-1])}, + {combine_helper_fn}, + DTYPE_VALUE_AS_UINT={value_as_uint_dtype}, + DTYPE_PACK={scratch_type}, + ) + """, + strip=True, + ) + # Compute final cumsum + block_scan = cse_compute( + f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + ) + combined_result = cse_compute( + f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", + dtype=dtype, + ) + return ( + cse_compute( + f"tl.where(roffset == 0, {block_scan}, {combined_result})", + dtype=dtype, + ), + ) + + def _get_heuristic(self): + return "split_scan" + + def _get_grid_type(self) -> type[SplitScanGrid]: + return SplitScanGrid diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_utils.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6feff92d3ade8290f53b2a307365edd80fe320eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/triton_utils.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import sympy + +import torch + +from .. import config +from ..runtime.hints import AttrsDescriptorWrapper +from ..utils import _type_of, expr_fits_within_32bit, triton_version_uses_attrs_dict +from ..virtualized import V +from .common import ( + ArgName, + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + WorkspaceArg, +) + + +def should_unwrap_unspec_arg(name: str): + if V.graph.is_unspec_arg(name): + # Unwrap on all devices except CPU + if V.graph.get_current_device_or_throw().type != "cpu": + return True + # Only unwrap on CPU if the input is not used as an output + if name not in V.graph.mutated_buffers: + return True + return False + + +def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: + if isinstance(arg, TensorArg): + # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. + # Related PR: https://github.com/triton-lang/triton/pull/2279/ + if arg.dtype == torch.float8_e4m3fn: + typ = "*fp8e4nv" + elif arg.dtype == torch.float8_e5m2: + typ = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + typ = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + typ = "*fp8e5b16" + else: + typ = _type_of(arg.dtype) + if should_unwrap_unspec_arg(arg.buffer): + # had unwrapped 0d tensor as scalar + new_typ = typ.lstrip("*") + if new_typ in ["fp16", "bf16"]: + return "fp32" + else: + return new_typ + else: + return typ + if isinstance(arg, SizeArg): + if arg.expr is None: + if triton_version_uses_attrs_dict(): + # In newer versions of Triton, the signature includes "None" args + # and their type is marked as "constexpr" + return "constexpr" + else: + # In older versions of Triton... + # From triton/runtime/jit.py + # `None` is nullptr. Implicitly convert to *i8. 
+ return "*i8" + elif _arg_equals_1(arg) and triton_version_uses_attrs_dict(): + # In new versions of Triton, if we have an equal-to-1 arg that's marked as a constant, + # it should be marked as "constexpr" in the signature. + return "constexpr" + elif isinstance(arg.expr, (float, sympy.Float)): + return "fp32" + + # if this is a integer + if size_dtype == "tl.int32": + return "i32" + elif size_dtype == "tl.int64": + return "i64" + elif size_dtype is None: + # no hint: we'll see if we know that this is a 32-bit int, and guard if possible. + int_max = torch.iinfo(torch.int32).max + if expr_fits_within_32bit(arg.expr): + V.graph.sizevars.guard_leq(arg.expr, int_max) + return "i32" + else: + return "i64" + else: + raise NotImplementedError(f"unhandled size_dtype {size_dtype}") + if isinstance(arg, WorkspaceArg): + return _type_of(arg.dtype) + if isinstance(arg, TMADescriptorArg): + if arg.api_type == "experimental": + return "nvTmaDesc" + else: + # https://github.com/triton-lang/triton/blob/9695baed9b46cf957e08b157bb4133f4a4b331c5/python/triton/runtime/jit.py#L360-L363 + assert arg.api_type == "stable" + assert arg.block_shape is not None + assert arg.dtype is not None + inner = _type_of(arg.dtype)[1:] # strip the `*`: *fp32 -> fp32 + return f"tensordesc<{inner}{list(arg.block_shape)}>" + if isinstance(arg, ConstexprArg): + return "constexpr" + raise NotImplementedError(f"unhandled {type(arg)}: {arg}") + + +def non_constexpr_signature(signature): + new_signature = [] + for arg in signature: + if not isinstance(arg, ConstexprArg): + new_signature.append(arg) + + return new_signature + + +def signature_to_meta( + signature: list[KernelArgType], + *, + size_dtype: Optional[str], + argdefs: list[ArgName], + indices: Optional[list[int]] = None, + is_template: bool = False, +) -> dict[str, str]: + if indices is None: + indices = list(range(len(signature))) + + def _decide_tl_dtype(arg): + # Even if the ks0 symbol itself is within tl.int32 range, it's + # risky to use tl.int32 dtype since we may have ks0*ks1 later + # for kernels like torch.mean when dynamic shape is enabled. + # + # Check config.triton.use_block_ptr, since Triton block pointer + # does not support 64bit indexing: + # https://gist.github.com/shunting314/6a41c776171720ce4561f202dcde0ad6 + # + # If the triton metadata is for a template, don't use tl.int64 index. + # Templates like flex attention/decoding uses block pointers which + # does not support 64 bit indexing. + if ( + not config.triton.use_block_ptr + and not is_template + and isinstance(arg, SizeArg) + and arg.name.startswith("ks") + ): + return "tl.int64" + return size_dtype + + return { + argdefs[i].name: signature_of(arg, size_dtype=_decide_tl_dtype(arg)) + for i, arg in zip(indices, signature) + } + + +def is_unaligned_buffer(arg: TensorArg): + buf_name = arg.buffer + if buf_name in V.graph.unaligned_buffers: + return True + + if buf_name in V.graph.graph_inputs: + # See Note: [Input Alignment handling in Inductor] + # For graph inputs that is not recorded in V.graph.unaligned_buffers, + # we know for sure the tensor is aligned. 
+ return False + + if buf_name in V.graph.constants: + # all constants are assumed to be aligned + return False + + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + + if isinstance(layout, torch._inductor.ir.NonOwningLayout): + return not layout.maybe_guard_aligned() + else: + return False + + +def _arg_equals_1(arg: KernelArgType) -> bool: + return ( + isinstance(arg, SizeArg) + and isinstance(arg.expr, (int, sympy.Integer)) + and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] + ) + + +def equal_1_arg_indices( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> tuple[int, ...]: + if indices is None: + indices = list(range(len(args))) + + equal_to_1 = tuple(i for i, arg in zip(indices, args) if _arg_equals_1(arg)) + + return equal_to_1 + + +def config_of( + args: list[KernelArgType], + *, + indices: Optional[list[int]] = None, +) -> Any: + if indices is None: + indices = list(range(len(args))) + + def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: + """ + Roughly follow triton code here: + https://github.com/triton-lang/triton/blob/5282ed890d453e10b9ee30076ef89115dd197761/python/triton/runtime/jit.py#L208-L222 + """ + if isinstance(x, TensorArg): + if include_tensor: + offset_aligned = V.graph.sizevars.statically_known_multiple_of( + x.offset * x.dtype.itemsize, + alignment, # type: ignore[arg-type] + ) + return offset_aligned and not is_unaligned_buffer(x) + else: + return False + if isinstance(x, SizeArg): + # TODO(voz): These are kinda redundant, if we can solve out statically_known_multiple_of with + # _maybe_evaluate_static... 
+ if x.name.startswith("load_seed_offset"): + return False + if x.expr is None: + return False + if isinstance(x.expr, float): + return False + return V.graph.sizevars.statically_known_multiple_of(x.expr, alignment) # type: ignore[arg-type] + if isinstance(x, WorkspaceArg): + # We allocate the workspace ourselves, so it is always aligned + return True + if isinstance(x, (TMADescriptorArg, ConstexprArg)): + return False + raise NotImplementedError(f"unhandled {type(x)}: {x}") + + if config.triton.divisible_by_16: + divisible_by_16 = tuple( + i + for i, arg in zip(indices, args) + if is_aligned(arg, alignment=16, include_tensor=True) + ) + else: + divisible_by_16 = () + + equal_to_1 = equal_1_arg_indices(args, indices=indices) + + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc63ed0320528e0b5bbc59d4c3e1756ca49480d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper.py @@ -0,0 +1,3413 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import contextlib +import dataclasses +import dis +import functools +import inspect +import logging +import operator +import random +import re +import tempfile +from itertools import chain, count +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import sympy +from sympy import Expr + +import torch +import torch._ops +import torch.utils._pytree as pytree +from torch import dtype as torch_dtype +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.codegen.debug_utils import DebugPrinterManager +from torch._inductor.codegen.multi_kernel import MultiKernelState +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.fx.experimental.symbolic_shapes import ( + CallMethodKey, + ConvertIntKey, + DivideByKey, + resolve_unbacked_bindings, + SymTypes, +) +from torch.fx.node import _get_qualified_name +from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.symbol import symbol_is_type, SymT + +from .. 
import async_compile, config, ir +from ..codecache import output_code_log +from ..ir import IRNode, ReinterpretView +from ..runtime import triton_heuristics +from ..runtime.hints import DeviceProperties +from ..utils import ( + cache_on_self, + DelayReplaceLine, + get_benchmark_name, + IndentedBuffer, + is_codegen_graph_partition_subgraph, + LineContext, + set_kernel_post_grad_provenance_tracing, + sympy_product, + sympy_str, + sympy_subs, + triton_version_uses_attrs_dict, +) +from ..virtualized import V +from .common import ( + ArgName, + CodeGen, + DeferredLine, + PythonPrinter, + WorkspaceArg, + WorkspaceZeroMode, +) +from .cpp_utils import cexpr +from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta + + +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + import triton + + from ..graph import GraphLowering + from .wrapper_fxir import FxConverter + + +log = logging.getLogger(__name__) + +pexpr = PythonPrinter().doprint + + +ReuseKey = tuple[torch.device, torch.dtype, str, bool] +BufferLike = Union[ir.Buffer, WorkspaceArg] +FxConversionFunc = Callable[["WrapperLine"], None] + + +def buffer_reuse_key(node: BufferLike) -> ReuseKey: + storage_size = V.graph.get_allocation_storage_size(node) + alignment = node.get_name() not in V.graph.unaligned_buffers + return ( + node.get_device_or_error(), + node.get_dtype(), + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(V.graph.sizevars.simplify(storage_size)), + alignment, + ) + + +def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike): + # Return True if input_buf can be re-inplaced for output_buf. + # This differs from `buffer_reuse_key` for general buffer reuse. 
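+    # In short, the check below requires matching devices and dtypes, and then either
+    # symbolically identical storage sizes or an output size that is statically known
+    # to satisfy 0.95 * input_size <= output_size <= input_size.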
+ if input_buf.get_device_or_error() != output_buf.get_device_or_error(): + return False + + if input_buf.get_dtype() != output_buf.get_dtype(): + return False + + input_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(input_buf) + ) + output_size = V.graph.sizevars.simplify( + V.graph.get_allocation_storage_size(output_buf) + ) + + if ( + # NB: this is symbolic so that we don't try to reuse a buffer + # for s0 for s1, just because they happen to share the same + # size hint + sympy_str(input_size) == sympy_str(output_size) + ) or ( + # statically known that 0.95 * input_size <= output_size <= input_size + V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size) + and V.graph.sizevars.statically_known_leq(output_size, input_size) + ): + return True + + return False + + +# TODO: Move to a well known place +TritonMetaParams = dict[str, int] +TritonGrid = Union[ + tuple[Union[int, sympy.Expr], ...], Callable[[TritonMetaParams], tuple[int, ...]] +] + + +def user_defined_kernel_grid_fn_code( + name: str, + configs: list[triton.Config], # type: ignore[name-defined] + grids: list[TritonGrid], + wrapper: Optional[PythonWrapperCodegen] = None, + original_fxnode_name: Optional[str] = None, +) -> tuple[str, str]: + output = IndentedBuffer() + + def _convert_to_sympy_expr(item: Union[int, sympy.Expr]) -> sympy.Expr: + return item if isinstance(item, sympy.Expr) else sympy.Integer(item) + + def determine_grid( + grid: TritonGrid, + example_grid: Optional[TritonGrid] = None, + ): + """ + This function returns a tuple of two values: the first one is for the real grid + which is used in the generated code; the second one is an example grid with + concrete values which is used in the autotune block to run the generated + kernels at compile time. 
+ """ + if wrapper is None or callable(grid): + # return as-is when used in eager mode or when grid is callable + return grid, grid + # Grid contains ints/Expr, so utilize wrapper's expr printer for codegen + sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) + if not example_grid: + example_grid = sympy_grid + return ( + wrapper.codegen_python_shape_tuple(sympy_grid), + ( + wrapper.codegen_python_shape_tuple( + tuple( + wrapper.generate_example_arg_value(g, type(g)) + for g in example_grid # type: ignore[union-attr] + ) + ) + if config.triton.autotune_at_compile_time + else None + ), + ) + + def writeline(line: str, example_grid: Optional[str] = None): + output.writeline(line) + if ( + wrapper + and config.triton.autotune_at_compile_time + and name not in wrapper.kernel_autotune_names + ): + wrapper.kernel_autotune_calls.writeline(example_grid or line) + + fn_name = f"grid_wrapper_for_{name}" + writeline(f"def {fn_name}(meta):") + kernel_autotune_calls_indent = ( + wrapper.kernel_autotune_calls.indent() + if wrapper and config.triton.autotune_at_compile_time + else contextlib.nullcontext() + ) + with output.indent(), kernel_autotune_calls_indent: + if ( + config.triton.autotune_at_compile_time + and original_fxnode_name + and V.graph.autotuning_grids + and original_fxnode_name in V.graph.autotuning_grids + ): + example_grids = V.graph.autotuning_grids[original_fxnode_name] + else: + example_grids = [None] * len(grids) + if len(grids) == 1: + grid, example_grid = determine_grid(grids[0], example_grids[0]) + writeline(f"return {grid}", f"return {example_grid}") + else: + assert len(grids) > 1 + assert len(grids) == len(configs) + seen: OrderedSet[str] = OrderedSet() + # sort the configs from the largest # of kwargs to the smallest to + # emit the grids in the order of (approximately) decreasing specificity + # TODO(aakhundov): the sorting below is generally not sufficient, so + # maybe we'll need to restrict the supported cases to identical kwarg + # names in all autotuning configs. 
+ for grid, c, example_grid in sorted( + zip(grids, configs, example_grids), + key=lambda x: len(x[1].kwargs), + reverse=True, + ): + if c.kwargs: + guards = [ + f"meta['{name}'] == {val}" for name, val in c.kwargs.items() + ] + guards = " and ".join(guards) + else: + guards = "True" # for configs with empty kwargs + grid, example_grid = determine_grid(grid, example_grid) + statement = f"if {guards}: return {grid}" + if statement in seen: + continue + seen.add(statement) + writeline(statement, f"if {guards}: return {example_grid}") + + return fn_name, output.getvalue() + + +def user_defined_triton_kernel_transitive_closure_source_code(kernel) -> str: + """ + Given a triton kernel function pointer collect the transitive closure of + its dependencies + """ + compile_wrapper = IndentedBuffer() + compile_wrapper.splice(kernel.src, strip=True) + + # Also include any possible kernel being called indirectly + from triton import JITFunction # type: ignore[name-defined, attr-defined] + from triton.language import constexpr # type: ignore[name-defined] + + # global constexpr vars handled above + symbols_included = OrderedSet([kernel.__name__]) + + def traverse(cur_kernel): + # here we extract the unqualified names (i.e., not attributes and + # without prepended module name) loaded in the kernel code, which + # are matched with the co_names and __globals__ below to codegen + # the respective imports necessary for the kernel compilation + unqualified_loads = OrderedSet( + inst.argval + for inst in dis.Bytecode(cur_kernel.fn) + if inst.opname == "LOAD_GLOBAL" + ) + global_annotations = cur_kernel.fn.__globals__.get("__annotations__", {}) + for symbol_name in cur_kernel.fn.__code__.co_names: + if symbol_name in symbols_included: + continue + if symbol_name in cur_kernel.fn.__globals__: + symbol = cur_kernel.fn.__globals__[symbol_name] + if isinstance(symbol, JITFunction): + compile_wrapper.newline() + compile_wrapper.writeline("@triton.jit") + compile_wrapper.splice(symbol.src, strip=True) + symbols_included.add(symbol_name) + traverse(symbol) + elif isinstance(symbol, (int, str, bool, constexpr)): + compile_wrapper.newline() + if isinstance(symbol, constexpr): + symbol_str = f"tl.constexpr({symbol.value!r})" + else: + symbol_str = f"{symbol!r}" + if annotation := global_annotations.get(symbol_name): + if isinstance(annotation, type): + annotation_code = ( + f": {annotation.__module__}.{annotation.__name__}" + ) + else: + annotation_code = f": {annotation!r}" + compile_wrapper.writeline( + f"{symbol_name}{annotation_code} = {symbol_str}" + ) + else: + compile_wrapper.writeline(f"{symbol_name} = {symbol_str}") + symbols_included.add(symbol_name) + elif ( + symbol_name in unqualified_loads + and symbol_name != "tl" # already imported + and hasattr(symbol, "__module__") + # only codegen imports from triton; JITFunctions + # imported from other modules will be codegened + # in the separate branch above + and symbol.__module__.startswith("triton") + ): + # a global symbol imported from triton is referenced + # without module qualification (i.e., `store` instead + # of `tl.store`): need to codegen an import + compile_wrapper.writeline( + f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" + ) + symbols_included.add(symbol_name) + + traverse(kernel) + return compile_wrapper.getvalue() + + +@dataclasses.dataclass +class SymbolicCallArg: + inner: str + # the original symbolic expression represented by inner + inner_expr: sympy.Expr + + def __str__(self): + return str(self.inner) + + +class 
MemoryPlanningState: + def __init__(self): + super().__init__() + self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = ( + collections.defaultdict(list) + ) + self.total_allocated_buffer_size: int = 0 + + def __contains__(self, key: ReuseKey) -> bool: + return bool(self.reuse_pool.get(key, None)) + + def pop(self, key: ReuseKey) -> FreeIfNotReusedLine: + item = self.reuse_pool[key].pop() + assert not item.is_reused + return item + + def push(self, key: ReuseKey, item: FreeIfNotReusedLine) -> None: + assert not item.is_reused + self.reuse_pool[key].append(item) + + +class WrapperLine: + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + raise NotImplementedError("FX codegen not yet supported for type {type(self)}") + + +@dataclasses.dataclass +class EnterSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + graph: GraphLowering + + def __post_init__(self) -> None: + self.wrapper.push_computed_sizes(self.wrapper.computed_sizes) + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.push_codegened_graph(self.graph) + code.do_indent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_subgraph + + +@dataclasses.dataclass +class CommentLine(WrapperLine): + line: LineContext + + def codegen(self, code: IndentedBuffer) -> None: + code.writeline(self.line) + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_comment + + +@dataclasses.dataclass +class ExitSubgraphLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def __post_init__(self) -> None: + self.wrapper.computed_sizes = self.wrapper.pop_computed_sizes() + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper.pop_codegened_graph() + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_subgraph + + +@dataclasses.dataclass +class EnterDeviceContextManagerLine(WrapperLine): + device_idx: int + last_seen_device_guard_index: Optional[int] + + def codegen(self, code: IndentedBuffer) -> None: + if V.graph.cpp_wrapper: + code.writeline("\n") + if V.graph.aot_mode: + # In AOT mode, we have a stream provided as a param. A stream is + # associated with a device, so we never expect the device to change. + # CUDAStreamGuard sets the stream and the device. 
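+                # Consequently, the stream_guard below is declared at most once per
+                # generated wrapper; any later device guard in AOT mode only asserts
+                # that the device index has not changed.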
+ if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" + ) + else: + assert self.last_seen_device_guard_index == self.device_idx, ( + "AOTInductor only supports running on one CUDA device" + ) + else: + if self.last_seen_device_guard_index is None: + code.writeline( + f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" + ) + else: + code.writeline(f"device_guard.set_index({self.device_idx});") + else: + # Note _DeviceGuard has less overhead than device, but only accepts + # integers + code.writeline(f"with {V.graph.device_ops.device_guard(self.device_idx)}:") + code.do_indent() + code.writeline(V.graph.device_ops.set_device(self.device_idx)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_enter_device_context_manager + + +class ExitDeviceContextManagerLine(WrapperLine): + def codegen(self, code: IndentedBuffer) -> None: + if not V.graph.cpp_wrapper: + code.do_unindent() + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_exit_device_context_manager + + +@dataclasses.dataclass +class ExternKernelAllocLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelAlloc + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs()] + self.wrapper._generate_extern_kernel_alloc_helper(self.node, args) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_alloc + + +@dataclasses.dataclass +class ExternKernelOutLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: ir.ExternKernelOut + + def codegen(self, code: IndentedBuffer) -> None: + node = self.node + args = [*node.codegen_args(), *node.codegen_kwargs(skip_out=True)] + kernel_name = node.get_kernel_name() + if ( + V.graph.cpp_wrapper + and node.cpp_kernel_name == "torch::inductor::_mm_plus_mm" + ): + # For https://github.com/pytorch/pytorch/issues/128474 + kernel_name = "aoti_torch__mm_plus_mm_out" + else: + kernel_name = node.get_kernel_name() + device = d.type if (d := node.get_device()) else V.graph.device_type + # set provenance tracing kernel mapping for ExternKernel types + if config.trace.enabled: + set_kernel_post_grad_provenance_tracing(node, kernel_name, is_extern=True) + self.wrapper._generate_extern_kernel_out_helper( + kernel_name, + node.codegen_reference(), + node.output_view.codegen_reference() if node.output_view else None, + args, + device, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_extern_kernel_out + + +@dataclasses.dataclass +class FreeLine(WrapperLine): + wrapper: PythonWrapperCodegen + node: Union[BufferLike, ir.TorchBindObject] + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free + + +@dataclasses.dataclass +class KernelCallLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + call_args: tuple[Any, ...] + raw_keys: tuple[Any, ...] + raw_args: tuple[Any, ...] 
+ arg_types: list[str] + triton: bool + triton_meta: dict[str, Any] + device: torch.device + graph_name: str + original_fxnode_name: str + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_kernel_call_helper( + self.kernel_name, + self.call_args, + triton=self.triton, + arg_types=self.arg_types, + raw_keys=self.raw_keys, + raw_args=self.raw_args, + triton_meta=self.triton_meta, + device=self.device, + graph_name=self.graph_name, + original_fxnode_name=self.original_fxnode_name, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_call + + +@dataclasses.dataclass +class KernelDefinitionLine(WrapperLine): + wrapper: PythonWrapperCodegen + kernel_name: str + kernel_body: str + metadata: Optional[str] = None + gpu: bool = True + cpp_definition: Optional[str] = None + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._define_kernel_helper( + self.kernel_name, + self.kernel_body, + metadata=self.metadata, + gpu=self.gpu, + cpp_definition=self.cpp_definition, + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_kernel_definition + + +@dataclasses.dataclass +class MemoryPlanningLine(WrapperLine): + wrapper: PythonWrapperCodegen + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + """First pass to find reuse""" + return self + + def codegen(self, code: IndentedBuffer) -> None: + """Second pass to output code""" + + def __str__(self) -> str: + """ + Emits a string representation that fits on one line. + """ + args: list[str] = [] + for field in dataclasses.fields(self): + if field.name == "wrapper": + continue + val = getattr(self, field.name) + args.append( + f"{field.name}={val.get_name() if field.type is ir.Buffer else val}" + ) + return f"{type(self).__name__}({', '.join(args)})" + + +@dataclasses.dataclass +class AllocateLine(MemoryPlanningLine): + node: BufferLike + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + + # try to reuse a recently freed buffer + key = buffer_reuse_key(self.node) + if config.allow_buffer_reuse and key in state: + free_line = state.pop(key) + free_line.is_reused = True + return ReuseLine(self.wrapper, free_line.node, self.node) + + if self.node.get_device_or_error().type == "cpu": + static_shape = self.wrapper.static_shape_for_buffer_or_none(self.node) + if static_shape is not None: + state.total_allocated_buffer_size += int( + functools.reduce(operator.mul, static_shape, 1) + ) + + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + line = self.wrapper.make_buffer_allocation(self.node) + code.writeline(line) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_allocate + + +@dataclasses.dataclass +class FreeIfNotReusedLine(MemoryPlanningLine): + node: BufferLike + is_reused: bool = False + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if len(self.node.get_inputs_that_alias_output()) > 0: + return self + if isinstance(self.node.layout, ir.MultiOutputLayout): + return self + assert not self.is_reused + if self.node.get_name() in V.graph.removed_buffers: + return NullLine(self.wrapper) + if config.allow_buffer_reuse: + state.push(buffer_reuse_key(self.node), self) + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert 
self.node.get_name() not in V.graph.removed_buffers + if not self.is_reused: + code.writeline(self.wrapper.make_buffer_free(self.node)) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_free_if_not_reused + + +@dataclasses.dataclass +class ReinterpretLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + layout: ir.Layout + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert isinstance(self.layout, ir.NonOwningLayout) + assert isinstance(self.layout.view, ir.ReinterpretView) + self.wrapper.codegen_deferred_allocation( + self.reused_as.get_name(), self.layout.view + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reinterpret + + +@dataclasses.dataclass +class ReuseLine(MemoryPlanningLine): + node: BufferLike + reused_as: BufferLike + delete_old: bool = True + + def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: + if self.node.get_name() in V.graph.removed_buffers: + assert self.reused_as.get_name() in V.graph.removed_buffers + return NullLine(self.wrapper) + assert self.reused_as.get_name() not in V.graph.removed_buffers + return self + + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + assert self.reused_as.get_name() not in V.graph.removed_buffers + code.writeline( + self.wrapper.make_buffer_reuse(self.node, self.reused_as, self.delete_old) + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_reuse + + +class NullLine(MemoryPlanningLine): + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_null + + +@dataclasses.dataclass +class CommBufferLine(WrapperLine): + wrapper: PythonWrapperCodegen # type: ignore[name-defined] # noqa: F821 + node: ir.Buffer + + @property + def size(self) -> int: + from torch._inductor.utils import is_symbolic + + numel = self.node.get_numel() + dtype = self.node.get_dtype() + if is_symbolic(numel): + raise AssertionError( + f"The size of a comm buffer can't be symbolic: {self.node}" + ) + return int(numel) * dtype.itemsize + + @property + def comm_buffer_type(self) -> ir.CommBufferType: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.comm_buffer_type + + @property + def group_name(self) -> str: + layout = self.node.get_output_spec() + assert isinstance(layout, ir.CommBufferLayout) + return layout.group_name + + +@dataclasses.dataclass +class CommBufferAllocateLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + assert self.node.get_name() not in V.graph.removed_buffers + name = self.node.get_name() + device = self.node.get_device() + dtype = self.node.get_dtype() + shape = tuple(self.node.get_size()) + stride = tuple(self.node.get_stride()) + code.writeline( + self.make_allocation_line( + self.comm_buffer_type, + self.group_name, + self.wrapper, + name, + device, + dtype, + shape, + stride, + ) + ) + + @staticmethod + def make_allocation_line( + comm_buffer_type, group_name, wrapper, name, device, dtype, shape, stride + ): + if comm_buffer_type == ir.CommBufferType.SYMM_MEM: + return ( + f"{name} = empty_strided_p2p(" + f"{wrapper.codegen_shape_tuple(shape)}, " + f"{wrapper.codegen_shape_tuple(stride)}, " + f"{dtype}, " + f'torch.device("cuda:{device.index}"), ' + f'group_name="{group_name}", ' + 
f"alloc_id={random.randint(0, 2**64 - 1)})" + ) + else: + raise NotImplementedError( + f"Unsupported comm buffer type: {comm_buffer_type}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_allocate + + +@dataclasses.dataclass +class CommBufferFreeLine(CommBufferLine): + def codegen(self, code: IndentedBuffer) -> None: + line = self.wrapper.make_buffer_free(self.node) + code.writeline(f"{line} # {self.comm_buffer_type.value} buffer free") + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_comm_buffer_free + + +@dataclasses.dataclass +class MultiOutputLine(WrapperLine): + """ + Given a MultiOutputLayout buffer, indexes actual buffer(s) from the result. + """ + + wrapper: PythonWrapperCodegen + result_name: str + arg_name: str + indices: Sequence[Any] + + def codegen(self, code: IndentedBuffer) -> None: + def codegen_list_tuple_access(basename, indices): # type: ignore[no-untyped-def] + if len(indices) > 0: + itype, i = indices[0] + if issubclass(itype, list): + return codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) + elif issubclass(itype, tuple): + # cpp wrapper code needs to use std::get<> to access a tuple + tuple_access = self.wrapper.codegen_tuple_access( + basename, self.result_name, str(i) + ) + return codegen_list_tuple_access(tuple_access, indices[1:]) + elif issubclass(itype, dict): + return codegen_list_tuple_access(f"{basename}['{i}']", indices[1:]) + else: + raise AssertionError("non supported index type: ", itype) + else: + return basename + + value = codegen_list_tuple_access(self.arg_name, self.indices) + code.writeline( + f"{self.wrapper.declare}{self.result_name} = {value}{self.wrapper.ending}" + ) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_multi_output + + +@dataclasses.dataclass +class SymbolicCallArgLine(WrapperLine): + wrapper: PythonWrapperCodegen + arg: SymbolicCallArg + graph: GraphLowering + + def codegen(self, code: IndentedBuffer) -> None: + self.wrapper._generate_symbolic_call_arg_helper(self.arg, self.graph) + + def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: + return converter._generate_symbolic_call_arg + + +BufferName = str +Line = Union[MemoryPlanningLine, LineContext] + + +class PythonWrapperCodegen(CodeGen): + """ + Generate outer wrapper in Python that calls the kernels. + """ + + supports_caching = True # Whether the output code is cacheable. + + def __init__(self): + super().__init__() + self._names_iter: Iterator[int] = count() + self.args_to_buffers: dict[ + str, Union[None, ir.TensorBox, ir.Buffer, ir.TorchBindObject] + ] = {} + self.imports = IndentedBuffer() + self.header = IndentedBuffer() + self.prefix = IndentedBuffer() + self.suffix = IndentedBuffer() + self.kernel_declarations = IndentedBuffer() + self.wrapper_call = IndentedBuffer() + self.kernel_autotune_defs = IndentedBuffer() + self.kernel_autotune_calls = IndentedBuffer() + self.subgraph_definitions = IndentedBuffer() + self.kernel_autotune_names: OrderedSet[str] = OrderedSet() + # Map key is the kernel argument name; value is a tuple of the resulting example + # tensor name with the kernel where that tensor was most recently used. 
+ self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} + self.kernel_autotune_tmp_arg_idx: int = 0 + # If the generated source code is exactly the same, reuse the + # pre-existing kernel for it + self.src_to_kernel: dict[str, str] = {} + self.kernel_numel_expr: OrderedSet[tuple[str, GraphLowering]] = OrderedSet() + self.lines: list[Line] = [] + self.declare = "" + self.declare_maybe_reference = "" + self.ending = "" + self.comment = "#" + self.none_str = "None" + self.move_begin = "std::move(" if V.graph.cpp_wrapper else "" + self.move_end = ")" if V.graph.cpp_wrapper else "" + self.last_seen_device_guard_index: Optional[int] = None + self.supports_intermediate_hooks = True + self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {} + self.unbacked_symbol_decls: OrderedSet[str] = ( + OrderedSet() + ) # str of sympy.Symbol + self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet() + self.launcher_fn_name = None + # This function can be overridden to change the launcher name + self.set_launcher_fn_name() + + # this is used for tracking which GraphLowering instance---parent graph + # or (nested) subgraph---is currently codegened; the primary use case is + # including the graph instance into a cache key to avoid cross-graph + # caching during lowering of nested subgraphs + self.codegened_graph_stack = [] + self.computed_sizes_stack = [] + + self.write_header() + + if not is_codegen_graph_partition_subgraph(self): + # See [Note: Removed Graph Partition Arguments] + self.write_prefix() + + self.write_kernel_autotune_defs_header() + + if not V.graph.aot_mode: + for name, hashed in V.graph.constant_reprs.items(): + # include a hash so our code cache puts different constants into different files + self.write_constant(name, hashed) + + self.allocated = OrderedSet[BufferName]() + self.freed = OrderedSet[BufferName]() + + # maps from reusing buffer to reused buffer + self.reuses: dict[BufferName, BufferName] = {} + + self.write_get_raw_stream = functools.lru_cache(None)( # type: ignore[assignment] + self.write_get_raw_stream + ) + + @functools.cache + def add_import_once(line: str) -> None: + self.imports.writeline(line) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(line) + + self.add_import_once = add_import_once + self._metas: dict[str, str] = {} + self._meta_vars: OrderedSet[str] = OrderedSet() + self.multi_kernel_state = MultiKernelState() + self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet() + self.allocated_workspaces: dict[str, Any] = {} + + # intermediate tensor value printing utility + self.debug_printer = DebugPrinterManager( + debug_printer_level=config.aot_inductor.debug_intermediate_value_printer, + use_array_ref=config.aot_inductor.allow_stack_allocation, + ) + + # Additional files that are dependent to the wrapper (ex. 
cubin files) + self.additional_files = [] + + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None + assert parent_wrapper is not None + return SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return PythonWrapperCodegen() + + def set_launcher_fn_name(self) -> None: + self.launcher_fn_name = "call" + + def write_constant(self, name: str, hashed: str) -> None: + self.header.writeline(f"{name} = None # {hashed}") + + def write_header(self) -> None: + context = torch._guards.TracingContext.try_get() + aot_config_comment = "" + if context is not None and context.aot_graph_name is not None: + aot_config_comment = f"# AOT ID: {context.aot_graph_name}" + aot_inductor_debug_utils = "" + if int(config.aot_inductor.debug_intermediate_value_printer) > 0: + aot_inductor_debug_utils = "from torch._inductor.codegen.debug_utils import _print_debugging_tensor_value_info" + self.imports.splice( + f""" + {aot_config_comment} + from ctypes import c_void_p, c_long, c_int + import torch + import math + import random + import os + import tempfile + from math import inf, nan + from cmath import nanj + from torch._inductor.hooks import run_intermediate_hooks + from torch._inductor.utils import maybe_profile + from torch._inductor.codegen.memory_planning import _align as align + from torch import device, empty_strided + from {async_compile.__name__} import AsyncCompile + from torch._inductor.select_algorithm import extern_kernels + {aot_inductor_debug_utils} + """, + strip=True, + ) + self.header.splice( + """ + aten = torch.ops.aten + inductor_ops = torch.ops.inductor + _quantized = torch.ops._quantized + assert_size_stride = torch._C._dynamo.guards.assert_size_stride + assert_alignment = torch._C._dynamo.guards.assert_alignment + empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor + alloc_from_pool = torch.ops.inductor._alloc_from_pool + async_compile = AsyncCompile() + """, + strip=True, + ) + try: + # Only add empty_strided_p2p() if distributed and SymmetricMemory + # is available + from torch._C._distributed_c10d import _SymmetricMemory # noqa: F401 + + self.header.splice( + """ + empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p + """, + strip=True, + ) + except (AttributeError, ImportError): + pass + if config.annotate_training: + self.header.writeline("from torch.cuda import nvtx") + + def include_extra_header(self, header: str): + pass + + def write_kernel_autotune_defs_header(self) -> None: + self.kernel_autotune_defs.splice( + f""" + import torch + from torch._dynamo.testing import rand_strided + from torch._dynamo.utils import preserve_rng_state + from torch._inductor.select_algorithm import AlgorithmSelectorCache + from {async_compile.__name__} import AsyncCompile + + async_compile = AsyncCompile() + generate_example_value = AlgorithmSelectorCache.generate_example_value + empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda + empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu + """ + ) + + @cache_on_self + def write_triton_header_once(self) -> None: + import_str = f""" + import triton + import 
triton.language as tl + from {triton_heuristics.__name__} import start_graph, end_graph + """ + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.splice(import_str) + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.splice(import_str, strip=True) + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + def write_get_raw_stream_header(self) -> None: + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + if not V.graph.cpp_wrapper: + self.imports.writeline( + V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + ) + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + self.write_get_raw_stream_header() + + def add_meta_once(self, meta: TritonMetaParams) -> str: + meta = repr(meta) + if meta not in self._metas: + var = f"meta{len(self._metas)}" + self._metas[meta] = var + self.header.writeline(f"{var} = {meta}") + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline(f"{var} = {meta}") + self._meta_vars.add(var) + return self._metas[meta] + + @cache_on_self + def get_output_refs(self) -> list[str]: + return [ + x.codegen_reference(self.wrapper_call) for x in self.get_graph_outputs() + ] + + def mark_output_type(self) -> None: + return + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: + return V.graph.graph_inputs + + def get_graph_outputs(self) -> list[IRNode]: + return V.graph.graph_outputs + + def codegen_input_size_asserts(self) -> None: + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + # a graph partition may take an IRNode output from a previous partition + if name not in V.graph.graph_input_names or isinstance( + buf, ir.GeneratorState + ): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. 
+ if sympy_product(buf.get_size()) == 0: + continue + size = self.codegen_python_shape_tuple(buf.get_size()) + stride = self.codegen_python_shape_tuple(buf.get_stride()) + self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + + def codegen_input_nan_asserts(self) -> None: + self.prefix.writeline("# make sure graph inputs are not nan/inf") + for name, buf in self.get_graph_inputs().items(): + if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + continue + + line = f"assert not {name}.isnan().any().item()" + self.prefix.writeline(line) + line = f"assert not {name}.isinf().any().item()" + self.prefix.writeline(line) + + def write_async_compile_wait(self) -> None: + self.prefix.splice( + """ + + async_compile.wait(globals()) + del async_compile + """ + ) + + def write_args(self, input_names: list[str]): + lhs = ", ".join(input_names) + if len(input_names) == 1: + lhs += "," + self.prefix.writeline(f"{lhs} = args") + self.prefix.writeline("args.clear()") + + def write_launcher_fn_call_get_indent(self) -> int: + if config.graph_partition: + self.prefix.splice( + """ + class Runner: + def __init__(self, partitions): + self.partitions = partitions + + def recursively_apply_fns(self, fns): + new_callables = [] + for fn, c in zip(fns, self.partitions): + new_callables.append(fn(c)) + self.partitions = new_callables + + def call(self, args): + """ + ) + prefix_indent = 2 + else: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + + return prefix_indent + + def get_graph_input_names(self) -> list[str]: + return V.graph.graph_input_names + + def write_prefix(self) -> None: + assert self.launcher_fn_name is not None + self.write_async_compile_wait() + prefix_indent = self.write_launcher_fn_call_get_indent() + + with self.prefix.indent(prefix_indent): + if config.triton.debug_sync_graph: + self.prefix.writeline(V.graph.device_ops.synchronize()) + phase = V.graph.get_training_phase() + if config.annotate_training: + self.prefix.writeline( + f"training_annotation = nvtx._device_range_start('{phase}')" + ) + + if graph_input_names := self.get_graph_input_names(): + self.write_args(graph_input_names) + + self.codegen_inputs() + self.codegen_input_size_and_nan_asserts() + + def codegen_input_size_and_nan_asserts(self) -> None: + if config.size_asserts: + self.codegen_input_size_asserts() + if config.nan_asserts: + self.codegen_input_nan_asserts() + + # this function (and below) takes the graph name as input so + # that stream caching happens per graph instance. this + # is important for nested subgraph codegening. 
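+            # Emits `stream{device_idx} = get_raw_stream(device_idx)`; with cpp_wrapper the line is only
+            # needed in the compile-time autotuning block, so main-body codegen returns early.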
+ def write_get_raw_stream(self, device_idx: int, graph_name: str) -> str: + self.write_get_raw_stream_header_once() + name = f"stream{device_idx}" + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + f"{name} = get_raw_stream({device_idx})" + ) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return name + self.writeline(f"{name} = get_raw_stream({device_idx})") + return name + + def get_codegened_graph(self): + return self.codegened_graph_stack[-1] + + def push_codegened_graph(self, graph): + self.codegened_graph_stack.append(graph) + + def pop_codegened_graph(self): + return self.codegened_graph_stack.pop() + + def push_computed_sizes(self, computed_sizes): + from copy import deepcopy + + return self.computed_sizes_stack.append(deepcopy(computed_sizes)) + + def pop_computed_sizes(self): + return self.computed_sizes_stack.pop() + + def next_kernel_suffix(self) -> str: + return f"{next(self._names_iter)}" + + def codegen_device_guard_enter(self, device_idx: int) -> None: + self.writeline( + EnterDeviceContextManagerLine(device_idx, self.last_seen_device_guard_index) + ) + if config.triton.autotune_at_compile_time: + # mimic logic of EnterDeviceContextManagerLine.codegen for the autotune code block + self.write_triton_header_once() + self.kernel_autotune_calls.writeline( + f"with {V.graph.device_ops.device_guard(device_idx)}:" + ) + self.kernel_autotune_calls.do_indent() + self.kernel_autotune_calls.writeline( + V.graph.device_ops.set_device(device_idx) + ) + if is_codegen_graph_partition_subgraph(self): + # Need get_raw_stream for subgraph + self.write_get_raw_stream_header() + self.kernel_autotune_calls.writeline( + f"stream{device_idx} = get_raw_stream({device_idx})" + ) + self.last_seen_device_guard_index = device_idx + + def codegen_device_guard_exit(self) -> None: + self.writeline(ExitDeviceContextManagerLine()) + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.do_unindent() + + def generate_return(self, output_refs: list[str]) -> None: + if output_refs: + if config.nan_asserts: + self.wrapper_call.writeline( + "return_vars = (" + ", ".join(output_refs) + ", )" + ) + self.wrapper_call.writeline("for var in return_vars:") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("if isinstance(var, torch.Tensor):") + self.wrapper_call.do_indent() + self.wrapper_call.writeline("assert not var.isnan().any().item()") + self.wrapper_call.writeline("assert not var.isinf().any().item()") + self.wrapper_call.do_unindent(2) + + self.wrapper_call.writeline("return (" + ", ".join(output_refs) + ", )") + else: + self.wrapper_call.writeline("return ()") + + def generate_before_suffix(self, result: IndentedBuffer) -> None: + return + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + if config.graph_partition: + all_partition_name_list = ", ".join(self.all_partition_names) + ( + "," if len(self.all_partition_names) == 1 else "" + ) + + result.splice( + f""" + runner = Runner(partitions=[{all_partition_name_list}]) + call = runner.call + recursively_apply_fns = runner.recursively_apply_fns + """ + ) + + def generate_end(self, result: IndentedBuffer) -> None: + return + + def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None: + self.writeline(ExternKernelAllocLine(self, node)) + + def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc): + node.codegen_comment(self) + self.writeline(ExternKernelAllocLine(self, node)) + if isinstance(node.layout, 
ir.Layout): + node.codegen_size_asserts(self) + + def _generate_extern_kernel_alloc_helper(self, extern_kernel, args): + # If it's a NoneLayout then the extern_kernel should essentially be + # treated as if it doesn't return anything + no_return = isinstance(extern_kernel.layout, ir.NoneLayout) + output_name = extern_kernel.get_name() + origin_node = extern_kernel.get_origin_node() + kernel_name = extern_kernel.get_kernel_name() + ending = self.ending + if config.memory_planning and "view_as_complex" in kernel_name: + # view operation fallbacks cause issues since inductor + # doesn't know the memory is still needed and might reuse it. + ending = f".clone(){ending}" + + if no_return: + self.writeline(f"{self.declare}{kernel_name}({', '.join(args)}){ending}") + else: + self.writeline( + f"{self.declare}{output_name} = {kernel_name}({', '.join(args)}){ending}" + ) + if ( + self.supports_intermediate_hooks + and config.generate_intermediate_hooks + and origin_node is not None + ): + counters["inductor"]["intermediate_hooks"] += 1 + self.writeline( + f"run_intermediate_hooks({origin_node.name!r}, {output_name})" + ) + + def generate_extern_kernel_out( + self, + node: ir.ExternKernelOut, + ) -> None: + node.codegen_comment(self) + self.writeline(ExternKernelOutLine(self, node)) + + def _generate_extern_kernel_out_helper( + self, + kernel: str, + out: str, + out_view: Optional[str], + args: list[str], + device: str, + ) -> None: + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(args, kernel, None, None, "extern") + args.append(f"out={out_view if out_view else out}") + with debug_printer_manager: + self.writeline(f"{kernel}({', '.join(args)})") + + def _generate_tma_descriptor_call_experimental(self, desc, apply_size_hints=False): + dims = desc.dims + block_dims = desc.block_dims + if apply_size_hints: + dims = tuple(V.graph.sizevars.atomically_apply_size_hint(d) for d in dims) + block_dims = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_dims + ) + + ptr = f"{desc.tensor.codegen_reference()}.data_ptr()" + # Explicitly call the Python version of val_to_arg_str + dims = ", ".join(PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in dims) + block_dims = ", ".join( + PythonWrapperCodegen.val_to_arg_str(self, dim) for dim in block_dims + ) + element_size = PythonWrapperCodegen.val_to_arg_str(self, desc.element_size) + prefix = "triton.tools.experimental_descriptor" + fn = f"{prefix}.create_{desc.rank}d_tma_descriptor" + args = f"{ptr}, {dims}, {block_dims}, {element_size}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call_stable(self, desc, apply_size_hints=False): + block_shape = desc.block_shape + if apply_size_hints: + block_shape = tuple( + V.graph.sizevars.atomically_apply_size_hint(d) for d in block_shape + ) + + prefix = "triton.tools.tensor_descriptor.TensorDescriptor" + fn = f"{prefix}.from_tensor" + args = f"{desc.tensor.codegen_reference()}, {block_shape}" + call = f"{fn}({args})" + return call + + def _generate_tma_descriptor_call(self, desc, apply_size_hints=False): + if isinstance(desc, ir.TMADescriptorExperimental): + return self._generate_tma_descriptor_call_experimental( + desc, apply_size_hints + ) + else: + assert isinstance(desc, ir.TMADescriptorStable) + return self._generate_tma_descriptor_call_stable(desc, apply_size_hints) + + def generate_tma_descriptor(self, desc): + call = 
self._generate_tma_descriptor_call(desc) + line = f"{desc.name} = {call}{self.ending}" + self.writeline(line) + + def generate_scatter_fallback( + self, + output, + inputs, + cpp_kernel_name, + python_kernel_name, + src_is_tensor, + reduce, + kwargs, + ): + line = f"{python_kernel_name}({','.join(map(str, inputs))}" + if python_kernel_name.startswith("aten.scatter_reduce"): + line += ", ".join([""] + kwargs) + else: + if reduce: + line += f", reduce={repr(reduce)}" + line += ")" + self.writeline(line) + + def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): + indices_str = f"[{', '.join(indices)}]" + args = [x, indices_str, values, accumulate] + self.writeline(self.wrap_kernel_call(kernel, args)) + + def generate_fallback_kernel_with_runtime_lookup( + self, + buf_name: str, + python_kernel_name: str, + get_args: Callable[[], Sequence[str]], + op_overload: Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator], + raw_args: Sequence[Any], + outputs: Sequence[ir.Buffer], + ) -> None: + self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})") + + def generate(self, is_inference): + with dynamo_timed("PythonWrapperCodegen.generate"): + return self._generate(is_inference) + + def get_wrapper_call_indent(self) -> int: + if config.graph_partition: + return 2 + else: + return 1 + + @contextlib.contextmanager + def set_writeline(self, new: Callable[..., None]) -> Iterator[Callable[..., None]]: + old = self.writeline + try: + self.writeline = new # type: ignore[method-assign] + yield new + finally: + self.writeline = old # type: ignore[method-assign] + + def _write_multi_kernel_defs(self) -> None: + kernel_defs = self.multi_kernel_state.kernel_defs + if config.triton.autotune_at_compile_time: + self.kernel_autotune_defs.splice(kernel_defs) + else: + self.header.splice(kernel_defs) + + def _generate(self, is_inference): + if config.profile_bandwidth: + self.write_triton_header_once() + + with contextlib.ExitStack() as stack: + stack.enter_context(self.wrapper_call.indent()) + if config.profiler_mark_wrapper_call: + self.generate_profiler_mark_wrapper_call(stack) + if config.profile_bandwidth: + self.generate_start_graph() + + self.run_wrapper_ir_passes(is_inference) + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_reset_kernel_saved_flags() + + # At this point, we shouldn't generate any new memory planning lines. + # Override writeline to point at the wrapper call, in case it gets called. + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + + self._write_multi_kernel_defs() + + output_refs = self.get_output_refs() + self.mark_output_type() + if config.triton.debug_sync_graph: + self.wrapper_call.writeline(V.graph.device_ops.synchronize()) + + if config.profile_bandwidth: + self.generate_end_graph() + + if config.triton.store_cubin and not config.triton.autotune_at_compile_time: + self.generate_save_uncompiled_kernels() + + if config.triton.autotune_at_compile_time: + self.generate_and_run_autotune_block() + + # cpp_wrapper currently doesn't support nvtx + if config.annotate_training and not config.cpp_wrapper: + self.wrapper_call.writeline( + "nvtx._device_range_end(training_annotation)" + ) + self.generate_return(output_refs) + + # Assemble the final code from sections. 
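+        # Sections are spliced in order: imports, header, subgraph definitions, the prefix
+        # (input unpacking and asserts), the indented wrapper call body, then the suffix and
+        # optional benchmark harness.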
+ result = IndentedBuffer() + result.splice(self.imports) + result.writeline("") + result.splice(self.header) + # We do not want the cpp header for intermediate const graph. Headers would be + # rendered by the main module instead. + if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: + result = IndentedBuffer() + + # Add subgraph definitions to the result + result.splice(self.subgraph_definitions) + self.finalize_prefix() + result.splice(self.prefix) + + wrapper_call_indent = self.get_wrapper_call_indent() + + with result.indent(wrapper_call_indent): + result.splice(self.wrapper_call) + + self.generate_before_suffix(result) + result.splice(self.suffix) + self.generate_after_suffix(result) + + self.generate_end(result) + + self.add_benchmark_harness(result) + + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) + + def generate_and_run_autotune_block(self): + """ + Compose self.kernel_autotune_defs and self.kernel_autotune_calls into a single block of + code and execute it to trigger Triton kernel compilation and auto-tuning + """ + self.kernel_autotune_defs.splice( + """ + async_compile.wait(globals()) + del async_compile + """ + ) + scope = {} # type: ignore[var-annotated] + if config.triton.autotune_at_compile_time and V.graph.autotuning_inputs: + scope = { + self.get_autotuning_input_name(idx): v # type: ignore[attr-defined] + for idx, v in enumerate(V.graph.autotuning_inputs) + } + tuning_code = ( + self.kernel_autotune_defs.getvalue() + + "\n" + + self.kernel_autotune_calls.getvalue() + ) + if output_code_log.level == logging.DEBUG: + # Save the autotuning code block into a file + # Create a temporary file + with tempfile.NamedTemporaryFile( + dir=cache_dir(), suffix=".py", delete=False + ) as f: + f.write(tuning_code.encode("utf-8")) + file_path = f.name + output_code_log.debug( + "Auto-tuning code written to %s", + file_path, + ) + # Execute the code to autotune kernels + try: + exec(tuning_code, scope) + except Exception as e: + raise RuntimeError(f"Failed to run autotuning code block: {e}") from e + + def memory_plan(self): + from .memory_planning import MemoryPlanner + + self.lines = MemoryPlanner(self).plan(self.lines) + + def memory_plan_reuse(self): + out_names = V.graph.get_output_names() + + while ( + self.lines + and isinstance(self.lines[-1], MemoryPlanningLine) + # TODO: this seems legit, NullLine has no node + and self.lines[-1].node.name not in out_names # type: ignore[attr-defined] + ): + # these lines will be pointless + self.lines.pop() + + # codegen allocations in two passes + planning_states = [MemoryPlanningState()] + past_planning_states = [] + for i in range(len(self.lines)): + line = self.lines[i] + if isinstance(line, MemoryPlanningLine): + self.lines[i] = line.plan(planning_states[-1]) + elif isinstance(line, EnterSubgraphLine): + planning_states.append(MemoryPlanningState()) + elif isinstance(line, ExitSubgraphLine): + past_planning_states.append(planning_states.pop()) + past_planning_states.append(planning_states.pop()) + assert len(planning_states) == 0 + + # conservatively use the sum of all allocated buffer sizes + # in potentially nested scopes as the total allocated size + # FIXME(rec): not used + _total_allocated_buffer_size = sum( + s.total_allocated_buffer_size for s in past_planning_states + ) + + def run_wrapper_ir_passes(self, is_inference: bool): + # We disable planning during training because it presently increases peak memory consumption. 
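+        # Full memory planning only runs for inference graphs with config.memory_planning enabled;
+        # otherwise we fall back to the simpler reuse-based planning pass.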
+ if is_inference and config.memory_planning: + self.memory_plan() + else: + self.memory_plan_reuse() + + def codegen_input_symbol_assignment( + self, + name: str, + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + code = self.prefix + + @functools.cache + def sizeof(name): + code.writeline(f"{name}_size = {name}.size()") + return f"{name}_size" + + @functools.cache + def strideof(name): + code.writeline(f"{name}_stride = {name}.stride()") + return f"{name}_stride" + + if isinstance(value, sympy.Expr): + if not isinstance(value, sympy.Symbol) or value in bound_vars: + return + code.writeline(f"{value} = {name}") + bound_vars.add(value) + elif isinstance(value, ir.TensorBox): + for dim, size in enumerate(value.get_size()): + if isinstance(size, sympy.Symbol) and size not in bound_vars: + code.writeline(f"{size} = {sizeof(name)}[{dim}]") + bound_vars.add(size) + for dim, stride in enumerate(value.get_stride()): + if isinstance(stride, sympy.Symbol) and stride not in bound_vars: + code.writeline(f"{stride} = {strideof(name)}[{dim}]") + bound_vars.add(stride) + elif isinstance(value, ir.TorchBindObject): + return + elif isinstance(value, ir.GeneratorState): + return + else: + if torch._inductor.config.graph_partition: + pass + else: + raise AssertionError(f"Unknown value type: {type(value)}") + + def codegen_inputs(self): + """Assign all symbolic shapes to locals""" + bound_vars = OrderedSet[sympy.Symbol]() + # There is a subtle case in the cpp wrapper codegen which requires generating + # symbol inputs first followed by non-symbol ones. + # + # When a dynamic size constraint specified at the Export time is an expression, + # we need to solve that expression to proper define a symbol in cpp. Thus we + # are enforcing this iterating order here to make sure all plain size symbols + # are defined first. + graph_inputs = self.get_graph_inputs() + inputs = [ + (k, v) for k, v in graph_inputs.items() if isinstance(v, sympy.Symbol) + ] + [(k, v) for k, v in graph_inputs.items() if not isinstance(v, sympy.Symbol)] + for name, value in inputs: + self.codegen_input_symbol_assignment(name, value, bound_vars) + + def _verify_input_symbol_assignment( + value: ir.TensorBox, + bound_vars: OrderedSet[sympy.Symbol], + ): + for expr in chain.from_iterable([value.get_size(), value.get_stride()]): + if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol): + continue + + undefined_symbols = [ + sym for sym in expr.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) > 0: + raise AssertionError( + f"For {expr}, expected {undefined_symbols} to have been codegen-ed." + ) + + # For inputs with size/strides which contain sympy expressions, we can + # encounter symbols that weren't defined yet. Now, let's check each + # symbol is defined. 
+ for _, value in inputs: + if not isinstance(value, ir.TensorBox): + continue + _verify_input_symbol_assignment(value, bound_vars) + + def ensure_size_computed(self, sym: sympy.Symbol): + if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE): + if sym in self.computed_sizes: + return + self.computed_sizes.add(sym) + expr = V.graph.sizevars.inv_precomputed_replacements[sym] + self.writeline(f"{sym} = {pexpr(expr)}") + + def finalize_prefix(self): + pass + + def codegen_cpp_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + raise RuntimeError("codegen_cpp_sizevar is only implemented for cpp_wrapper!") + + def codegen_python_sizevar(self, x: Expr, *, simplify: bool = True) -> str: + return pexpr(x, simplify=simplify) + + def codegen_sizevar(self, x: Expr) -> str: + return self.codegen_python_sizevar(x) + + def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: + return f"{basename}[{index}]" + + def codegen_python_shape_tuple(self, shape: Sequence[Expr]) -> str: + parts = [*map(self.codegen_python_sizevar, shape)] + if len(parts) == 0: + return "()" + if len(parts) == 1: + return f"({parts[0]}, )" + return f"({', '.join(parts)})" + + def codegen_shape_tuple(self, shape: Sequence[Expr]) -> str: + return self.codegen_python_shape_tuple(shape) + + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: + return "alloc_from_pool({})".format( + ", ".join( + [ + name, + pexpr(offset), # bytes not numel + str(dtype), + self.codegen_python_shape_tuple(shape), + self.codegen_python_shape_tuple(stride), + ] + ) + ) + + def codegen_reinterpret_view( + self, + data, + size, + stride, + offset, + writeline: Callable[..., None], + dtype=None, + ) -> str: + if ( + size == data.layout.size + and stride == data.layout.stride + and offset == data.layout.offset + ): + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype({data.get_name()}, {dtype})" + else: + return f"{data.get_name()}" + else: + size = self.codegen_python_shape_tuple(size) + stride = self.codegen_python_shape_tuple(stride) + offset = self.codegen_sizevar(offset) + if dtype is not None and dtype != data.dtype: + return f"aten.view.dtype(reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset}), {dtype})" + else: + return ( + f"reinterpret_tensor({data.get_name()}, {size}, {stride}, {offset})" + ) + + def codegen_device_copy(self, src, dst, non_blocking: bool): + self.writeline(f"{dst}.copy_({src}, {non_blocking})") + + def codegen_multi_output(self, node: ir.MultiOutput): + result_name = node.get_name() + arg_name = node.inputs[0].get_name() + self.writeline(MultiOutputLine(self, result_name, arg_name, node.indices)) + + def codegen_dynamic_scalar(self, node): + (data,) = (t.codegen_reference() for t in node.inputs) + if len(node.keypath) == 0: + self.writeline(f"{node.sym} = {data}.item()") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): + self.writeline(f"{node.sym} = 1 if {data}.item() else 0") + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): + self.writeline(f"{node.sym}_undivided = {data}.item()") + self.writeline( + f"assert {node.sym}_undivided % {node.keypath[0].divisor} == 0, " + f"f'{{{node.sym}_undivided}} not divisible by {node.keypath[0].divisor}'" + ) + self.writeline( + f"{node.sym} = {node.sym}_undivided // {node.keypath[0].divisor}" + ) + else: + raise AssertionError(f"unrecognized keypath {node.keypath}") + # No one should ever use this buffer, but for uniformity + # 
define the variable and assign it None + self.writeline(f"{node.get_name()} = None") + + def benchmark_compiled_module(self, output): + def add_fake_input(name, shape, stride, device, dtype): + output.writeline( + f"{name} = rand_strided(" + f"{self.codegen_python_shape_tuple(shape)}, " + f"{self.codegen_python_shape_tuple(stride)}, " + f"device='{device}', dtype={dtype})" + ) + + def add_expr_input(name, val): + output.writeline(f"{name} = {val}") + + def add_torchbind_input(name, value): + import pickle + + assert isinstance(value, torch.ScriptObject) + + output.writeline(f"{name} = pickle.loads({pickle.dumps(value)!r})") + + output.writelines( + ["", "", "def benchmark_compiled_module(times=10, repeat=10):"] + ) + with output.indent(): + output.splice( + """ + from torch._dynamo.testing import rand_strided + from torch._inductor.utils import print_performance + """, + strip=True, + ) + + for name, value in V.graph.constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_fake_input( + name, value.size(), value.stride(), value.device, value.dtype + ) + + if len(V.graph.torchbind_constants) > 0: + output.writeline("import pickle") + for name, torchbind_obj in V.graph.torchbind_constants.items(): + # all the constants are global variables, that's why we need + # these 'global var_name' lines + output.writeline(f"global {name}") + add_torchbind_input(name, torchbind_obj) + + for name, value in V.graph.graph_inputs.items(): + if isinstance(value, sympy.Symbol) and isinstance( + V.graph.sizevars.var_to_val.get(value, None), SingletonInt + ): + # Inductor should only work with dense -> dense graph, and + # SingletonInts belong to metadata that should only live on + # the subclass. + continue + if isinstance(value, ir.TorchBindObject): + if len(V.graph.torchbind_constants) == 0: + # otherwise we have already imported the pickle package + output.writeline("import pickle") + output.writeline(f"global {name}") + add_torchbind_input(name, value.get_real_obj()) + elif isinstance(value, sympy.Expr): # Don't need to add symbolic + # TODO: this fallback and those below actually will generate possibly + # invalid benchmark code, because it's not guaranteed 42 + # is actually a valid value for the kernel in question. 
+ # See https://github.com/pytorch/pytorch/issues/124686 + add_expr_input(name, V.graph.sizevars.size_hint(value, fallback=42)) + elif isinstance(value, ir.GeneratorState): + add_expr_input( + name, + f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", + ) + else: + shape = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_size() + ] + stride = [ + V.graph.sizevars.size_hint(x, fallback=42) + for x in value.get_stride() + ] + add_fake_input( + name, + shape, + stride, + value.get_device(), + value.get_dtype(), + ) + + call_str = f"call([{', '.join(V.graph.graph_inputs.keys())}])" + output.writeline(f"fn = lambda: {call_str}") + output.writeline("return print_performance(fn, times=times, repeat=repeat)") + + def add_benchmark_harness(self, output): + """ + Append a benchmark harness to generated code for debugging + """ + if not config.benchmark_harness: + return + + self.benchmark_compiled_module(output) + + output.writelines(["", "", 'if __name__ == "__main__":']) + with output.indent(): + output.writelines( + [ + "from torch._inductor.wrapper_benchmark import compiled_module_main", + f"compiled_module_main('{get_benchmark_name()}', benchmark_compiled_module)", + ] + ) + + def define_kernel( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + self.writeline( + KernelDefinitionLine( + self, + kernel_name, + kernel_body, + metadata=metadata, + gpu=gpu, + cpp_definition=cpp_definition, + ) + ) + + @staticmethod + def _format_kernel_definition( + kernel_name: str, kernel_body: str, metadata: Optional[str] = None + ): + metadata_comment = f"{metadata}\n" if metadata else "" + body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}" + return body + + def _define_kernel_helper( + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu: bool = True, + cpp_definition: Optional[str] = None, + ): + if config.triton.autotune_at_compile_time: + # Skip inserting comments for the autotune block as they may contain cpp style comments + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=None + ) + self.kernel_autotune_defs.splice(body) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + body = self._format_kernel_definition( + kernel_name, kernel_body, metadata=metadata + ) + self.header.splice(body) + + def define_subgraph_launcher_fn(self, fn_code: str): + self.subgraph_definitions.splice(fn_code) + + def define_user_defined_triton_kernel( + self, + kernel, + configs, + kwargs, + restore_value_args, + reset_to_zero_args, + grids: list[list[Union[int, sympy.Expr]]], + ): + from ..runtime.triton_heuristics import ( + config_to_dict, + FixedGrid, + PrecomputedGrid, + ) + from .common import ( + ConstexprArg, + KernelArgType, + SizeArg, + TensorArg, + TMADescriptorArg, + ) + from .triton import gen_common_triton_imports, TritonKernel + + original_name = kernel.__name__ + signature: list[KernelArgType] = [] + constants: dict[str, Any] = {} + arg_indices: list[int] = [] + equal_to_1_args: list[str] = [] + + def add_to_signature(idx, arg): + signature.append(arg) + arg_indices.append(idx) + + def add_arg(idx, arg, is_constexpr=False, equals_1=False, equals_none=False): + if is_constexpr: + if triton_version_uses_attrs_dict(): + # tl.constexpr args appear in the signature in new versions of triton, + # but not in old versions of triton. 
+ add_to_signature(idx, arg) + + if arg.name in kwargs: + # the arg may not appear in kwargs if it is an autotuned arg. + # in this case, it will be added in triton_heuristics after autotuning. + constants[arg.name] = kwargs[arg.name] + + else: + # the only case where arg name isn't in kwargs, should be + # when the arg is a constexpr. + assert arg.name in kwargs + + if equals_1: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the equal-to-1 arg in the signature (labeled as "constexpr"), + # and add the arg as a constant. + # new versions of triton: add the equal-to-1 arg in the signature (labeled as, e.g., "i32"), + # and add the arg as a constant. + add_to_signature(idx, ConstexprArg(name=arg.name)) + else: + add_to_signature(idx, arg) + constants[arg.name] = 1 + elif equals_none: + if triton_version_uses_attrs_dict(): + # new versions of triton: add the none arg in the signature (as a constexpr arg) and as a constant + # old versions of triton: include the none arg as a constant (but not in the signature) + add_to_signature(idx, ConstexprArg(name=arg.name)) + constants[arg.name] = None + else: + add_to_signature(idx, arg) + + for idx, key in enumerate(kernel.arg_names): + if idx in kernel.constexprs: + add_arg(idx, ConstexprArg(name=key), is_constexpr=True) + continue + + if key not in kwargs: + continue + + arg = kwargs[key] + + if kwargs[key] is None: + add_arg(idx, ConstexprArg(name=key), equals_none=True) + else: + if isinstance(arg, ir.TMADescriptor): + api_type, block_shape, dtype = ( + ("stable", arg.block_shape, arg.tensor.get_dtype()) + if isinstance(arg, ir.TMADescriptorStable) + else ("experimental", None, None) + ) + add_arg( + idx, + TMADescriptorArg( + name=key, + api_type=api_type, + block_shape=block_shape, + dtype=dtype, + ), + ) + elif isinstance(arg, ir.Buffer): + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.get_name(), + dtype=arg.get_dtype(), + ), + ) + elif isinstance(arg, ir.ReinterpretView): + # for ReinterpretView we use the underlying + # buffer name and note the (possibly non-zero) + # offset relative to the underlying buffer + add_arg( + idx, + TensorArg( + name=key, + buffer=arg.data.get_name(), + dtype=arg.get_dtype(), + offset=arg.layout.offset, + ), + ) + else: + equals_1 = isinstance( + arg, (int, sympy.Integer) + ) and V.graph.sizevars.statically_known_equals( + arg, + 1, # type: ignore[arg-type] + ) + add_arg(idx, SizeArg(key, arg), equals_1=equals_1) + + triton_signature = signature_to_meta( + signature, + size_dtype=None, # try to infer based on symints + indices=arg_indices, + argdefs=[ArgName(x) for x in kernel.arg_names], + ) + triton_meta: dict[str, Any] = { + "signature": triton_signature, + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), + # Triton compiler includes equal_to_1 args into constants even + # when they are not constexpr. otherwise there may be a segfault + # during launching the Inductor-compiled Triton kernel. + # TODO(aakhundov): add None args to constants, too. currently, this + # causes CUDA errors in test_aot_inductor.test_triton_kernel_with_none_input. 
+ # https://github.com/pytorch/pytorch/issues/120478#issuecomment-1962822307 + # https://github.com/triton-lang/triton/blob/231efe9ed2d200be0f69a07c298e4342b08efe3d/python/triton/runtime/jit.py#L384 + "constants": { + **constants, + **dict.fromkeys(equal_to_1_args, 1), + }, + "configs": [ + config_of( + signature, + indices=arg_indices, + ) + ], + } + + if restore_value_args: + triton_meta["restore_value"] = tuple(restore_value_args) + + if reset_to_zero_args: + triton_meta["reset_to_zero"] = tuple(reset_to_zero_args) + + if len(grids) == 1: + # compute the grid in the wrapper and pass it in as an arg + inductor_meta: dict[str, Any] = FixedGrid.setup_grid_as_args() + extra_launcher_call_args = [*map(sympy.sympify, grids[0])] + else: + + def rename_sizes_for_launcher(expr: Union[int, sympy.Expr]) -> sympy.Expr: + if isinstance(expr, sympy.Expr): + symbols = [*expr.free_symbols] + if not symbols: + return expr + symbols.sort(key=str) + for sym in symbols: + if sym in extra_launcher_args: + continue + extra_launcher_args[sym] = sympy.Symbol( + f"_launcher_s{len(extra_launcher_args)}" + ) + return sympy_subs(expr, extra_launcher_args) + assert isinstance(expr, int) + return sympy.Integer(expr) + + extra_launcher_args: dict[sympy.Symbol, sympy.Symbol] = {} + grids = [[*map(rename_sizes_for_launcher, grid)] for grid in grids] + + assert grids and len(grids) == len(configs) + precomputed_grids = [] + for grid, cfg in sorted( + zip(grids, configs), key=lambda x: len(x[1].kwargs), reverse=True + ): + precomputed_grids.append( + { + "config": config_to_dict(cfg), + "python": [*map(pexpr, grid)], + "cpp": [*map(cexpr, grid)], + } + ) + inductor_meta = { + "grid_type": PrecomputedGrid.__name__, + "precomputed_grids": precomputed_grids, + "extra_launcher_args": [*map(str, extra_launcher_args.values())], + } + extra_launcher_call_args = [*extra_launcher_args.keys()] + + # Distinguish between different functions using function id + cache_key: Any = [id(kernel.fn)] + if len(configs) > 0: + for arg in kwargs.values(): + # We need to key on non tensor arg only in autotune mode + if not isinstance(arg, (ir.Buffer, ir.ReinterpretView)): + cache_key.append(arg) + cache_key.append(str(triton_meta)) + cache_key.extend(str(inductor_meta)) + cache_key = tuple(cache_key) + if cache_key in self.user_defined_kernel_cache: + return ( + *self.user_defined_kernel_cache[cache_key], + extra_launcher_call_args, + ) + + name = f"{original_name}_{len(self.user_defined_kernel_cache)}" + + compile_wrapper = IndentedBuffer() + if config.triton.unique_user_kernel_names: + compile_wrapper.writeline(f"async_compile.triton({name!r}, '''") + else: + compile_wrapper.writeline(f"async_compile.triton({original_name!r}, '''") + + inductor_meta["kernel_name"] = name + inductor_meta.update(TritonKernel.inductor_meta_common()) + + compile_wrapper.splice(gen_common_triton_imports()) + compile_wrapper.splice( + f""" + @triton_heuristics.user_autotune( + configs={[*map(config_to_dict, configs)]!r}, + inductor_meta={inductor_meta!r}, + triton_meta={triton_meta!r}, + filename=__file__, + custom_kernel=True, + ) + @triton.jit + """ + ) + kernel_src = user_defined_triton_kernel_transitive_closure_source_code(kernel) + if config.triton.unique_user_kernel_names: + # We replace the original_name with the unique name. 
+ kernel_src = kernel_src.replace(f"def {original_name}(", f"def {name}(") + kernel_src = kernel_src.replace("'''", "\\'\\'\\'") + compile_wrapper.splice(kernel_src) + + current_device = V.graph.get_current_device_or_throw() + compile_wrapper.writeline(f"''', device_str='{current_device.type}')") + _, lineno = inspect.getsourcelines(kernel.fn) + srcfile = inspect.getsourcefile(kernel.fn) + metadata = f"# Original path: {srcfile}:{lineno}" + self.define_kernel( + name, + compile_wrapper.getvalue(), + metadata, + ) + # Add to the cache for the next use + self.user_defined_kernel_cache[cache_key] = (name, triton_meta) + return name, triton_meta, extra_launcher_call_args + + def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = None): + expr = f"{kernel_name}_{tree.prefix}numel" + if suffix is not None: + expr += f"_{suffix}" + + # We can get symbolic expressions here, like s0*64 + # It is fine to have them here, but we need to handle them correctly as their own type + # This is tricky to do, so we wrap in a custom type, distinct from scalars, but also from sympy* + # scalars as well. + # This is handled in `generate_args_decl` which has a correct comment of: TODO: only works for + # constant now, need type info. I agree, this needs type info, and while this is not true type info + # it suffices as a type hint for the purposes of producing the correct code for this type. + arg = SymbolicCallArg(expr, tree.numel) + self.writeline(SymbolicCallArgLine(self, arg, V.graph)) + + return arg + + def _generate_symbolic_call_arg_helper( + self, arg: SymbolicCallArg, graph: GraphLowering + ) -> None: + self.writeline(f"{arg.inner} = {pexpr(arg.inner_expr)}") + + def generate_workspace_allocation(self, ws: WorkspaceArg): + name = ws.get_name() + line = AllocateLine(self, ws) + if ws.zero_mode == WorkspaceZeroMode.UNINITIALIZED: + self.writeline(line) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH: + prior = self.allocated_workspaces.get(name) + if prior: + assert isinstance(prior, AllocateLine) and isinstance( + prior.node, WorkspaceArg + ) + # expand existing allocation + prior.node = WorkspaceArg.maximum(prior.node, ws) + else: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + self.allocated_workspaces[name] = line + else: + raise AssertionError(ws.zero_mode) + + if config.triton.autotune_at_compile_time: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_allocation( + self, + name, + ws.device, + ws.dtype, + shape=(V.graph.sizevars.size_hint(ws.count),), + stride=(1,), + ) + ) + if ws.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + self.kernel_autotune_calls.writeline( + PythonWrapperCodegen.make_zero_buffer(self, name) + ) + + def generate_workspace_deallocation(self, ws: WorkspaceArg): + if ws.zero_mode != WorkspaceZeroMode.ZERO_PER_GRAPH: + self.writeline(FreeIfNotReusedLine(self, ws)) + + def make_zero_buffer(self, name): + return f"{name}.zero_(){self.ending}" + + def wrap_kernel_call(self, name, call_args): + return f"{name}({', '.join(call_args)}){self.ending}" + + def generate_profiler_mark_wrapper_call(self, stack): + self.wrapper_call.writeline("from torch.profiler import record_function") + self.wrapper_call.writeline( + f"with record_function('graph_{V.graph.graph_id}_inductor_wrapper_call'):" + ) + stack.enter_context(self.wrapper_call.indent()) + + def generate_start_graph(self): + 
self.wrapper_call.writeline("start_graph()") + + def generate_end_graph(self): + self.wrapper_call.writeline(f"end_graph({config.profile_bandwidth_output!r})") + + def generate_reset_kernel_saved_flags(self): + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + kernel.cuda_kernel_saved = False + """ + ) + + def generate_save_uncompiled_kernels(self): + """ + Precompile and save the CUBINs of the Triton kernels that haven't + been precompiled and saved as a side effect of running the generated + JIT model (Python wrapper). This can happen when the model contains + control flow: only one pass through the control flow operators covers + the kernels that are saved, the remaining kernels are not launched, + hence not saved. The main purpose of this codegen is to compile and + save the Triton kernels outside the active control flow path for + subsequent AOTInductor code generation and compilation. + """ + self.wrapper_call.splice( + f""" + for kernel in globals().values(): + if isinstance(kernel, {triton_heuristics.__name__}.CachingAutotuner): + if not kernel.cuda_kernel_saved: + if len(kernel.launchers) == 0: + kernel.precompile() + kernel.save_gpu_kernel( + grid=(0, 0, 0), # use dummy grid + stream="stream", # use dummy stream + launcher=kernel.launchers[0], + ) + """ + ) + + def prepare_triton_kernel_call(self, call_args): + def wrap_arg(arg): + if isinstance(arg, str): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + return arg + ".item()" if should_unwrap_unspec_arg(arg) else arg + elif isinstance(arg, (int, float, bool, SymbolicCallArg)): + return str(arg) + else: + return pexpr(V.graph.sizevars.simplify(arg)) + + return [wrap_arg(arg) for arg in call_args] + + def generate_example_arg_value(self, arg, arg_type, raw_arg=None): + if isinstance(arg_type, torch_dtype): + if isinstance(raw_arg, ir.TMADescriptor): + # first we generate the underlying buffer + buf_name = raw_arg.get_tensor().get_name() + buf = self.args_to_buffers[arg] + elif self.args_to_buffers.get(arg): + buf_name = arg + buf = self.args_to_buffers[arg] + else: + assert raw_arg is not None, ( + "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" + ) + buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}" + buf = raw_arg + self.kernel_autotune_tmp_arg_idx += 1 + + assert buf is not None, f"Failed to find a buffer for arg {arg}" + size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_size() + ) + allocation_size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in V.graph.get_allocation_size(buf) + ) + stride = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_stride() + ) + device = buf.get_device() + dtype = buf.get_dtype() + offset = V.graph.sizevars.size_hint( + buf.get_layout().offset, + fallback=config.unbacked_symint_fallback, + ) + value = f"generate_example_value({size}, {stride}, '{device}', {dtype}, {offset}, {allocation_size})" + self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + if isinstance(raw_arg, ir.TMADescriptor): + # generate another line initializing a host-side TMA + # descriptor from the underlying buffer created above + value = self._generate_tma_descriptor_call( + desc=raw_arg, + apply_size_hints=True, + ) + buf_name = arg + 
self.kernel_autotune_calls.writeline(f"{buf_name} = {value}") + + return buf_name + elif issubclass(arg_type, sympy.Basic) or isinstance(arg, SymbolicCallArg): + # arg is a symbol or symbolic expression + if isinstance(arg, str): + if arg in self._meta_vars: + return arg + if raw_arg is None: + return "None" + arg = raw_arg + if isinstance(arg, SymbolicCallArg): + arg = arg.inner_expr + if arg in V.graph.sizevars.inv_precomputed_replacements: + arg = V.graph.sizevars.inv_precomputed_replacements[arg] + + return str( + V.graph.sizevars.atomically_apply_size_hint( + arg, fallback=config.unbacked_symint_fallback + ) + ) + + elif isinstance(arg, (str, int, float, bool)): + return str(arg) + elif isinstance(arg, list): + return f"[{', '.join(self.generate_example_arg_value(a, type(a)) for a in arg)}]" + else: + raise NotImplementedError(f"Unsupported type {type(arg)}") + + def _grid_dim_str(self, grid_per_dim): + if isinstance(grid_per_dim, list): + return ( + "[" + ", ".join(self._grid_dim_str(item) for item in grid_per_dim) + "]" + ) + else: + return pexpr(grid_per_dim) + + def generate_kernel_call( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + original_fxnode_name=None, + ): + """ + Generates kernel call code. + + triton: Defines whether the backend uses Triton for codegen. Otherwise it uses the CUDA language when gpu=True, + and C++ when gpu=False. + """ + + # Store buffers corresponding to each call arg. + # This is used to generate example args for autotuning later on. + self.args_to_buffers.update( + { + arg: V.graph.try_get_buffer(arg) + for arg in call_args + if isinstance(arg, str) + } + ) + + device = device or V.graph.get_current_device_or_throw() + self.writeline( + KernelCallLine( + self, + kernel_name=kernel_name, + call_args=call_args, + raw_keys=raw_keys, + raw_args=raw_args, + arg_types=arg_types, + triton=triton, + triton_meta=triton_meta, + device=device, + graph_name=V.graph.name, + original_fxnode_name=original_fxnode_name, + ) + ) + + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + if not (triton or device.type != "cpu"): + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + + call_args_str = self.prepare_triton_kernel_call(call_args) + call_args_str = ", ".join(call_args_str) + stream_name = PythonWrapperCodegen.write_get_raw_stream( + self, device.index, graph_name + ) + if not triton: + stream_ptr = f"c_void_p({stream_name})" + self.writeline( + f"{kernel_name}.{kernel_name}({call_args_str}, {stream_ptr})" + ) + return + + self.write_triton_header_once() + + if ( + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names + ): + # Create example args for autotune in a separate epilogue + assert arg_types is not None and len(call_args) == len(arg_types), ( + "call_args and arg_types do not match" + ) + + autotune_args = None + if original_fxnode_name and V.graph.autotuning_mapping: + autotune_args = V.graph.autotuning_mapping.get( + original_fxnode_name, None + ) + + def get_autotune_deletion_call() -> str: + """After all the autotune kernel calls have been written (i.e. 
+ self.kernel_autotune_example_args is complete), returns a deletion call + for all autotune example tensors that are unnecessary after kernel_name + is called.""" + tensors_to_delete = [ + tensor + for tensor, kn in self.kernel_autotune_example_args.values() + if kn == kernel_name + ] + if tensors_to_delete: + return f"del {', '.join(tensors_to_delete)}\n" + return "" + + def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): + """We try to infer raw_arg (i.e. raw_args[idx]) from remaining raw_args. + This is particularly useful for jagged cases, where the dimension is often + being passed in as an input.""" + + target_arg = raw_args[idx] + if target_arg in reused_args: + return True + + for i, (raw_key, raw_arg) in enumerate(zip(raw_keys, raw_args)): + if i == idx or not isinstance(raw_arg, IRNode): + continue + + triton_input = "" + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + if triton_input == "": + continue + + try: + layout = raw_arg.get_layout() + for dim, s in enumerate(layout.size): + if s == target_arg: + reused_args[target_arg] = f"{triton_input}.shape[{dim}]" + return True + except NotImplementedError: + # If layout for this IRNode is not implemented, we could just skip. + # Only raise for other Error cases. + continue + return False + + all_args = [] + if raw_args is None: + # create a dummy raw_args for uniform behavior in the following loop + assert raw_keys is None, "keys are not None but args are" + raw_keys = [None] * len(call_args) + raw_args = [None] * len(call_args) + else: + assert len(raw_args) == len(call_args), ( + "call_args and raw_args do not match" + ) + + reused_args = {} + for i, (arg, arg_type, raw_key, raw_arg) in enumerate( + zip(call_args, arg_types, raw_keys, raw_args) + ): + key = None + if isinstance(arg, str) and "=" in str(arg): + # arg may be passed in a kwarg style, and then we need to extract its value + key, arg = arg.split("=") + + triton_input: Optional[str] = None + if autotune_args and raw_key in autotune_args: + triton_input = self.get_autotuning_input_name( # type: ignore[attr-defined] + autotune_args[raw_key] + ) + + if triton_input: + arg_str = triton_input + if not isinstance(arg_type, torch_dtype) and ( + issubclass(arg_type, sympy.Basic) + or isinstance(arg, SymbolicCallArg) + ): + reused_args[raw_arg] = arg_str + elif raw_key == "" and infer_arg_by_inputs( + raw_keys, raw_args, i, reused_args + ): + # Empty raw_key means this is a arg that's not native to the triton kernel, + # and is being added by inductor. + arg_str = reused_args[raw_arg] + elif isinstance(arg_type, torch_dtype): + # workspace allocation is already generated by `generate_workspace_allocation()` + # in `TritonKernel.call_kernel()`. 
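+                    # so workspace/semaphore buffers are passed through by name rather than
+                    # materializing a fresh example tensor.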
+ if re.match(r"^(workspace|semaphore)", arg): + arg_str = arg + elif arg not in self.kernel_autotune_example_args: + arg_str = self.generate_example_arg_value( + arg, arg_type, raw_arg + ) + else: + arg_str = self.kernel_autotune_example_args[arg][0] + self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) + else: + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) + all_args.append(arg_str if key is None else f"{key}={arg_str}") + + self.kernel_autotune_calls.writeline( + f"{kernel_name}.run({', '.join(all_args)}, stream={stream_name})" + ) + self.kernel_autotune_calls.writeline( + DelayReplaceLine("", get_autotune_deletion_call, "") + ) + self.kernel_autotune_names.add(kernel_name) + if V.graph.cpp_wrapper: + # For cpp wrapper, no need to continue codegen for the main body + return + + # add debug printer code for triton kernel calls at (jit) inductor level + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args(call_args, kernel_name, arg_types, None) + with debug_printer_manager: + self.writeline(f"{kernel_name}.run({call_args_str}, stream={stream_name})") + self.write_triton_header_once() + + def writeline(self, line): + self.lines.append(line) + + def writelines(self, lines): + for line in lines: + self.writeline(line) + + def enter_context(self, ctx): + self.lines.append(LineContext(ctx)) + + def val_to_arg_str(self, s, type_=None): + from torch.utils._triton import has_triton_package + + if has_triton_package(): + import triton + + if isinstance(s, SymTypes): + return pexpr(s.node.expr) + elif isinstance(s, sympy.Expr): + return pexpr(s) + elif isinstance(s, (tuple, list)): + + @dataclasses.dataclass + class Shim: + ref: Any + + def __repr__(self): + return self.ref + + # Explicitly call the Python version of val_to_arg_str + return repr( + type(s)(Shim(PythonWrapperCodegen.val_to_arg_str(self, a)) for a in s) + ) + elif isinstance(s, torch._ops.OpOverload): + return _get_qualified_name(s) + elif isinstance(s, (ir.Buffer, ir.MutableBox, ReinterpretView)): + return s.codegen_reference() + elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] + return repr(s) + elif isinstance(s, ir.GeneratorState): + return s.codegen_reference() + else: + return repr(s) + + # The following methods are for memory management + def make_buffer_allocation(self, buffer: BufferLike): + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = tuple(buffer.get_size()) + allocation_shape = tuple(V.graph.get_allocation_size(buffer)) + stride = tuple(buffer.get_stride()) + return self.make_allocation( + buffer.get_name(), device, dtype, shape, stride, allocation_shape + ) + + def make_allocation( + self, name, device, dtype, shape, stride, allocation_shape=None + ): + if allocation_shape is None: + allocation_shape = shape + + codegen_shape_tuple = self.codegen_python_shape_tuple(shape) + codegen_allocation_shape_tuple = self.codegen_python_shape_tuple( + allocation_shape + ) + codegen_stride_tuple = self.codegen_python_shape_tuple(stride) + if device.type in ("cpu", "cuda", "xpu"): + # optimized path for faster allocations, saving ~2us versus the stuff below + out = ( + f"{name} = empty_strided_{device.type}(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"{dtype})" + ) + # all other devices: + else: + out = ( + f"{name} = empty_strided(" + f"{codegen_allocation_shape_tuple}, " + f"{codegen_stride_tuple}, " + f"device='{device.type}', dtype={dtype})" + 
) + if codegen_shape_tuple != codegen_allocation_shape_tuple: + # need an extra as_strided call + out = out + f".as_strided({codegen_shape_tuple}, {codegen_stride_tuple})" + return out + + def make_comment(self, line): + self.writeline(CommentLine(line)) + + def make_tensor_alias(self, new_name, old_name, comment=""): + return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" + + def make_buffer_free(self, buffer: Union[BufferLike, ir.TorchBindObject]): + return f"del {buffer.get_name()}" + + def make_free_by_names(self, names_to_del: list[str]): + return f"del {', '.join(name for name in names_to_del)}" + + def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): + return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): + assert old.get_dtype() == new.get_dtype() + old_name = old.get_name() + new_name = new.get_name() + del_line = ";" + if old_name not in V.graph.get_output_names() and delete_old: + del_line = f"; {self.make_buffer_free(old)}" + + if old.get_size() == new.get_size() and old.get_stride() == new.get_stride(): + return self.codegen_exact_buffer_reuse(old_name, new_name, del_line) + + reinterpret_view = self.codegen_reinterpret_view( + old, new.get_size(), new.get_stride(), 0, self.wrapper_call.writeline + ) + return f"{self.declare}{new_name} = {reinterpret_view}{del_line} {self.comment} reuse" + + def codegen_deferred_allocation(self, name: str, view: ir.ReinterpretView) -> None: + self.writeline( + DeferredLine( + name, + f"{self.declare}{name} = {view.codegen_reference()}{self.ending} {self.comment} alias", + ) + ) + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + + if ( + name in V.graph.removed_buffers + or name in self.allocated + or isinstance(buffer, (ir.DonatedBuffer, ir.SubgraphBuffer)) + ): + return + self.allocated.add(name) + if ( + isinstance( + buffer.get_defining_op(), + (ir.ExternKernelAlloc, ir.MultiOutput), + ) + and not buffer.should_allocate() + ): + return + + layout = buffer.get_output_spec() + if isinstance(layout, ir.MutationLayoutSHOULDREMOVE): + return + if isinstance(layout, ir.NoneLayout): + return + if isinstance(layout, ir.NonOwningLayout): + assert isinstance(layout.view, ir.ReinterpretView), ( + f"unexpected {type(layout.view)}: {layout.view}" + ) + box = layout.view.data + assert isinstance(box, ir.StorageBox), type(box) + input_buffer = box.data + assert isinstance(input_buffer, ir.Buffer), type(box) + self.codegen_allocation(input_buffer) + self.writeline(ReinterpretLine(self, input_buffer, buffer, layout)) + return + + if isinstance(layout, ir.CommBufferLayout): + self.writeline(CommBufferAllocateLine(self, buffer)) + return + + self.writeline(AllocateLine(self, buffer)) + + def codegen_free(self, buffer): + name = buffer.get_name() + + # can be freed but not reused + if isinstance(buffer, (ir.InputBuffer, ir.TorchBindObject)): + self.writeline(FreeLine(self, buffer)) + return + + if isinstance(buffer.get_output_spec(), ir.CommBufferLayout): + # Comm buffers are not eligible for in-place reuse. Their reuse is + # achieved exclusively via buffer planning. 
+ self.writeline(CommBufferFreeLine(self, buffer)) + return + + if not self.can_reuse(buffer): + return + self.freed.add(name) + + self.writeline(FreeIfNotReusedLine(self, buffer)) + + def can_reuse(self, input_buffer, output_buffer=None): + name = input_buffer.get_name() + return not ( + name in V.graph.removed_buffers + or ( + name in V.graph.graph_inputs + and not isinstance( + V.graph.graph_inputs_original[name], ir.DonatedBuffer + ) + ) + or name in V.graph.constants + or name in V.graph.torchbind_constants + or name in V.graph.never_reuse_buffers + or name in self.freed + ) + + def did_reuse(self, buffer, reused_buffer): + # Check whether a given buffer was reused by a possible reuser in the wrapper codegen + # Can be consulted from inside ir codegen, e.g. to determine whether a copy is needed + return ( + buffer.get_name() in self.reuses + and self.reuses[buffer.get_name()] == reused_buffer.get_name() + ) + + def codegen_inplace_reuse(self, input_buffer: ir.Buffer, output_buffer: ir.Buffer): + assert can_match_buffer_size(input_buffer, output_buffer) + self.codegen_allocation(input_buffer) + self.freed.add(input_buffer.get_name()) + self.allocated.add(output_buffer.get_name()) + self.reuses[output_buffer.get_name()] = input_buffer.get_name() + self.writeline(ReuseLine(self, input_buffer, output_buffer)) + + def codegen_unbacked_symbol_decl(self, symbol): + name = str(symbol) + if name in self.unbacked_symbol_decls: + return name + else: + # When in CppWrapperCpu, we should only generate the declaration once + self.unbacked_symbol_decls.add(name) + return self.declare + name + + def codegen_unbacked_symbol_defs_for_outputs( + self, + output_name: str, + outputs: Any, + unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], + ) -> None: + unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, unbacked_bindings + ) + + if not unbacked_bindings: + return + + # This code is designed to generate code expressions from symbolic paths (keypaths) + # associated with certain symbols (unbacked bindings). These keypaths describe how + # to access the unbacked symbol in a structured way. + # For example, we might want to generate "u0 = outs[0].stride(1)"", where s = u0, and the keypath + # describes the structure of "outs[0].stride(1)", like [SequenceKey(0), CallMethodKey("stride"), SequenceKey[1]]. + for s, keypath in unbacked_bindings.items(): + # `go` recursively constructs a code expression by processing each element of + # the keypath and construct the expression incrementally. 
+            # For example, given output name outs and keypath
+            # [SequenceKey(0), CallMethodKey("stride"), SequenceKey(1)],
+            # it generates "outs[0]" based on SequenceKey(0), then recursively
+            # go("outs[0]", [CallMethodKey("stride"), ...])
+            def go(expr: str, keypath: pytree.KeyPath):
+                if keypath == ():
+                    return expr
+
+                if (
+                    len(keypath) >= 2
+                    and isinstance(keypath[0], CallMethodKey)
+                    and isinstance(keypath[1], pytree.SequenceKey)
+                ):
+                    return go(
+                        f"{expr}.{keypath[0].name}({keypath[1].idx})", keypath[2:]
+                    )
+                elif isinstance(keypath[0], CallMethodKey):
+                    return go(f"{expr}.{keypath[0].name}()", keypath[1:])
+                elif isinstance(keypath[0], pytree.SequenceKey):
+                    return (
+                        go(f"std::get<{keypath[0].idx}>({expr})", keypath[1:])
+                        if V.graph.cpp_wrapper
+                        else go(f"{expr}[{keypath[0].idx}]", keypath[1:])
+                    )
+                elif isinstance(keypath[0], DivideByKey):
+                    # TODO: need to assert divisibility
+                    # TODO: this is invalid C++ codegen
+                    return go(f"{expr}.__floordiv__({keypath[0].divisor})", keypath[1:])
+                else:
+                    raise AssertionError(f"unrecognized keypath {keypath}")
+
+            # `go_outer` manages the top-level logic for generating the final expression.
+            # It handles special cases for C++ code generation and adjusts
+            # the keypath based on the context (e.g., single vs. multiple outputs).
+            def go_outer():  # type: ignore[no-untyped-def]
+                if V.graph.cpp_wrapper:
+                    # Special handling for the top-level buffer access,
+                    # because self.get_name() is actually never bound; the
+                    # individual output arguments are bound by
+                    # generate_c_shim_fallback_kernel
+                    if len(outputs) == 1:
+                        out = outputs[0]
+                        # When a fallback kernel returns a list consisting of a single tensor,
+                        # the output is represented as a MultiOutput with non-empty indices.
+                        # In this case, we strip the first key path away.
+                        return go(
+                            outputs[0].get_name(),
+                            keypath[1:]
+                            if isinstance(out, ir.MultiOutput) and len(out.indices) != 0
+                            else keypath,
+                        )
+                    else:
+                        assert isinstance(keypath[0], pytree.SequenceKey)
+                        return go(outputs[keypath[0].idx].get_name(), keypath[1:])
+                else:
+                    return go(output_name, keypath)
+
+            self.writeline(
+                f"{self.codegen_unbacked_symbol_decl(s)} = {go_outer()}{self.ending}"
+            )
+
+    def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs):
+        # TODO (desertfire) - This function is the old way of supporting
+        # subgraph codegen by inlining subgraphs in the output code. For the Python
+        # wrapper, we have moved to lifting subgraphs as functions, supported by
+        # the `codegen_subgraph` function.
+        #
+        # However, this does not work with cpp wrapper. With cpp wrapper, we make
+        # two passes and the kernels are shared from the first pass to the next.
+        # Therefore, both the Python and CppWrapper need to share the same
+        # codegen infra. For now, CppWrapperCpu has not been updated to lift
+        # subgraphs as functions. Therefore, for the cpp_wrapper first pass with
+        # PythonWrapper, we still fall back to the old way of inlining subgraphs
+        # in the output code. Once we update CppWrapperCpu, we can remove this
+        # function.
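+        # Roughly, the inlined form emitted by the helpers below looks like
+        # (buffer names here are purely illustrative):
+        #     # subgraph: <subgraph.name>
+        #     <inner_input> = <outer_input>        # prefix: bind outer inputs to inner names
+        #     ... recursively generated subgraph body ...
+        #     <outer_output> = <inner_output>      # suffix: expose inner outputs to the outer graph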
+ def _codegen_subgraph_prefix(): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + for inner_input, outer_input in zip( + subgraph.graph.graph_inputs, outer_inputs + ): + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) + + def _codegen_subgraph_suffix(): + assert len(subgraph.graph.graph_outputs) == len(outer_outputs) + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) + + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + _codegen_subgraph_prefix() + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + _codegen_subgraph_suffix() + finally: + self.pop_codegened_graph() + + def codegen_partition_call( + self, + partition_id: int, + partition_signatures: ir.GraphPartitionSignature, + ): + """Generate code to call a graph partition""" + input_deallocation = partition_signatures.input_deallocation + output_nodes = partition_signatures.output_nodes + + input_names = list(input_deallocation.keys()) + [ + symbol_input.name for symbol_input in partition_signatures.symbol_inputs + ] + + inputs = ", ".join(input_names) + ("," if len(input_names) == 1 else "") + + output_names = [node.get_name() for node in output_nodes] + outputs = ", ".join(output_names) + ("," if len(output_nodes) == 1 else "") + + # Create a list of inputs for the subgraph call + self.writeline(f"partition{partition_id}_args = [{inputs}]") + + names_to_del = [ + name for name, deallocate in input_deallocation.items() if deallocate + ] + if names_to_del: + self.writeline(f"del {', '.join(names_to_del)}") + + # Call the subgraph launcher function + self.writeline( + f"({outputs}) = self.partitions[{partition_id}](partition{partition_id}_args)" + ) + self.writeline(f"del partition{partition_id}_args") + + def set_all_partition_names(self, num_partitions: int): + self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)] + + def codegen_subgraph_call_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + # Get the input and output names of the subgraph + outer_output_names = ", ".join(outer_flattened_outputs) + ( + "," if len(outer_flattened_outputs) == 1 else "" + ) + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Call the subgraph launcher function + self.writeline( + f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name): + # Get the input and output names of the subgraph + outer_input_names = ", ".join(outer_inputs) + ( + "," if len(outer_inputs) == 1 else "" + ) + + self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]") + + # Since the buffers are already put into the args list, we can free the + # buffers here. 
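+        # Note: the generated `<subgraph>_args` list keeps the input tensors alive, so
+        # any `del` the scheduler emits here only removes the outer Python names; the
+        # underlying storage is not deallocated early.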
+ V.graph.scheduler.free_buffers() + + # Call the subgraph launcher function + self.writeline( + f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph_common(self, subgraph): + self.push_codegened_graph(subgraph.graph) + self.writeline("") + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + + parent_graph = V.graph + subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper + + if subgraph.graph.name not in self.already_codegened_subgraphs: + # If it is already codegened, the parent wrapper already has + # subgraph fn by name subgraph.graph.name + with V.set_graph_handler(subgraph.graph): + # do not graph partition for subgraph + with config.patch("graph_partition", False): + # Call the codegen of subgraph recursively + subgraph_code, _ = subgraph.graph.codegen() + self.already_codegened_subgraphs.add(subgraph.graph.name) + self.define_subgraph_launcher_fn(subgraph_code.value) + + def codegen_subgraph_with_flattened_outputs( + self, subgraph, outer_inputs, outer_flattened_outputs + ): + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call_with_flattened_outputs( + subgraph, outer_inputs, outer_flattened_outputs + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name): + # Codegen subgraph by recursively calling the codegen for the subgraph. + # This lifts the subgraph as a function in the output code. + self.codegen_subgraph_common(subgraph) + self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name) + + def codegen_invoke_subgraph(self, invoke_subgraph): + name = invoke_subgraph.get_name() + + self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}") + outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs] + + if V.graph.aot_mode: + outer_outputs = [ + f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs)) + ] + self.codegen_subgraph_by_inlining( + invoke_subgraph.subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name) + + def codegen_conditional(self, conditional): + name = conditional.get_name() + + outer_inputs = [buf.codegen_reference() for buf in conditional.operands] + + predicate = conditional.predicate.codegen_reference() + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # move the Tensor predicate to host + predicate = f"{predicate}.item()" + + self.writeline(f"{name} = [None] * {len(conditional.outputs)}") + self.writeline(f"if {predicate}:") + self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.true_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name) + + self.writeline(ExitSubgraphLine(self)) + self.writeline("else:") + self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph)) + if V.graph.aot_mode: + outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] + self.codegen_subgraph_by_inlining( + conditional.false_subgraph, outer_inputs, outer_outputs + ) + else: + self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name) + self.writeline(ExitSubgraphLine(self)) + + def codegen_while_loop(self, while_loop): + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + 
buf.codegen_reference() for buf in while_loop.additional_inputs + ] + + self.writeline(f"{name} = [None] * {len(outer_carried_inputs)}") + for i, inp in enumerate(outer_carried_inputs): + # set the initial state before the loop + self.writeline(f"{name}[{i}] = {inp}") + + cond_outer_inputs = [ + *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))], + *outer_additional_inputs, + ] + cond_outer_outputs = [f"{name}_cond_result"] + body_outer_inputs = list( + cond_outer_inputs + ) # same inputs for cond_fn and body_fn + # Carry over the state from body_fn. Note: We only carry over + # the carried_inputs part of the inputs, the additional ones + # are passed in as they're before. + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while True:") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + self.writeline( + f"if not {cond_outer_outputs[0]}: break" + ) # condition doesn't hold + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + else: + self.codegen_subgraph_with_flattened_outputs( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + + @staticmethod + def statically_known_int_or_none(x): + try: + if getattr(x, "free_symbols", None): + # _maybe_evaluate_static will return (s0 // (2 // s0)) as 2, but + # the actual codegen will still generate the full expression here. + return None + if isinstance(x, int): + return x + val = V.graph._shape_env._maybe_evaluate_static(x) + if val is None: + return val + return int(val) # type: ignore[call-overload] + except Exception: + return None + + @staticmethod + def statically_known_list_of_ints_or_none(lst): + result = [] + for x in lst: + num = PythonWrapperCodegen.statically_known_int_or_none(x) + if num is None: + return None + result.append(num) + return result + + @staticmethod + def is_statically_known_list_of_ints(lst): + return ( + PythonWrapperCodegen.statically_known_list_of_ints_or_none(lst) is not None + ) + + @staticmethod + def static_shape_for_buffer_or_none(buffer): + return PythonWrapperCodegen.statically_known_list_of_ints_or_none( + buffer.get_size() + ) + + @staticmethod + def can_prove_buffer_has_static_shape(buffer): + return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + + +class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): + """ + A wrapper codegen that generates code for a subgraph. For most of the + methods, we rely on the implementation in the PythonWrapperCodegen. 
But we + override a few functions to produce cleaner code (like avoiding writing + imports twice in the output code) + """ + + def __init__( + self, + subgraph_name: str, + parent_wrapper: PythonWrapperCodegen, + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ): + # It is necessary to set the subgraph_name before calling super __init__ + # because __init__ calls set_launcher_fn_name + self.subgraph_name = subgraph_name + self.parent_wrapper = parent_wrapper + self.partition_signatures = partition_signatures + + super().__init__() + + def set_launcher_fn_name(self) -> None: + # This sets up the name of the function containing the launcher code of + # the subgraph. + self.launcher_fn_name = self.subgraph_name + + def write_header(self) -> None: + pass + + def add_benchmark_harness(self, output): + pass + + def benchmark_compiled_module(self, output): + pass + + def write_async_compile_wait(self): + pass + + def next_kernel_suffix(self) -> str: + # Ensures that subgraphs kernels do not clash with each other + return self.parent_wrapper.next_kernel_suffix() + + def generate_after_suffix(self, result: IndentedBuffer) -> None: + return + + def write_launcher_fn_call_get_indent(self) -> int: + self.prefix.splice( + f""" + def {self.launcher_fn_name}(args): + """ + ) + prefix_indent = 1 + return prefix_indent + + def get_wrapper_call_indent(self) -> int: + return 1 + + def get_graph_inputs( + self, + ) -> dict[str, Union[ir.TensorBox, ir.TorchBindObject, sympy.Expr]]: + if signature := self.partition_signatures: + inputs = signature.input_nodes | { + str(s): s for s in signature.symbol_inputs + } + else: + inputs = V.graph.graph_inputs + return inputs + + def get_graph_input_names(self) -> list[str]: + if signature := self.partition_signatures: + names = list(signature.input_nodes.keys()) + [ + symbol_input.name for symbol_input in signature.symbol_inputs + ] + else: + names = V.graph.graph_input_names + return names + + def get_graph_outputs(self) -> list[IRNode]: + if signature := self.partition_signatures: + outputs = signature.output_nodes + else: + outputs = V.graph.graph_outputs + return outputs + + def codegen_allocation(self, buffer: ir.Buffer): + name = buffer.get_name() + if (signature := self.partition_signatures) and name in signature.input_nodes: + # skip allocation if buffer is a subgraph input. + # This allows reusing an input buffer in graph partition, + # although this is not allowed in general. + return + + super().codegen_allocation(buffer) + + @cache_on_self + def write_triton_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # import_str = self.triton_header_str() + # self.kernel_autotune_calls.splice(import_str) + self.parent_wrapper.write_triton_header_once() + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. 
+ # if config.triton.autotune_at_compile_time: + # self.kernel_autotune_calls.writeline( + # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + # ) + self.parent_wrapper.write_get_raw_stream_header_once() diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper_fxir.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper_fxir.py new file mode 100644 index 0000000000000000000000000000000000000000..438203f911f994d7b9b6f831aaa58c4246187702 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/wrapper_fxir.py @@ -0,0 +1,693 @@ +import dataclasses +import functools +import logging +import operator +import textwrap +from collections import Counter +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + tracing_triton_hopifier_singleton, + triton_kernel_wrapper_mutation, +) +from torch._inductor.codecache import PyCodeCache +from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.select_algorithm import extern_kernels # noqa: F401 +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import V +from torch._library.triton import wrap_triton +from torch.fx import GraphModule +from torch.utils import _pytree as pytree +from torch.utils._sympy.functions import FloorDiv + +from .. import config, ir +from ..utils import convert_shape_to_symint, convert_to_symint, LineContext +from .common import ( + CodegenSymbol, + FileBackedGraphModule, + WorkspaceArg, + WorkspaceZeroMode, +) +from .wrapper import ( + AllocateLine, + BufferLike, + CommBufferAllocateLine, + CommBufferFreeLine, + CommentLine, + EnterDeviceContextManagerLine, + EnterSubgraphLine, + ExitDeviceContextManagerLine, + ExitSubgraphLine, + ExternKernelAllocLine, + ExternKernelOutLine, + FreeIfNotReusedLine, + FreeLine, + KernelCallLine, + KernelDefinitionLine, + Line, + MultiOutputLine, + NullLine, + PythonWrapperCodegen, + ReinterpretLine, + ReuseLine, + SymbolicCallArg, + SymbolicCallArgLine, + WrapperLine, +) + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class SymbolBuffer(CodegenSymbol): + """ + Represents a sympy.Symbol graph input. + """ + + symbol: sympy.Symbol + + def get_name(self) -> str: + return str(self.symbol) + + def get_example(self) -> Union[torch.Tensor, sympy.Symbol]: + return self.symbol + + +CodegenBuffer = Union[BufferLike, SymbolBuffer] + + +@dataclasses.dataclass +class TritonKernel: + """ + Stores metadata about Triton kernels for use in FX. + """ + + tuner: CachingAutotuner + wrapped: TraceableTritonKernelWrapper + + +class WrapperFxCodegen(PythonWrapperCodegen): + """ + Backend to generate wrapper code as an FX IR graph. + """ + + supports_caching = False + + def _generate(self, is_inference: bool) -> tuple[FileBackedGraphModule, None]: + self.run_wrapper_ir_passes(is_inference) + + prologue = "\n".join( + [ + self.imports.getvalue(), + self.header.getvalue(), + ] + ) + gm = FxConverter(lines=self.lines, prologue=prologue).generate() + compiled_fn = self.compile_graph(gm) + + return FileBackedGraphModule(gm, compiled_fn), None + + def compile_graph(self, gm: GraphModule) -> Callable[..., Any]: + """ + Converts the graph module into a runnable function. The default implementation + is simply an interpreter calling kernels in eager mode. Derived backends can + override this to do further compilation. 
+ """ + return gm.forward + + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[ir.GraphPartitionSignature] = None, + ) -> "WrapperFxCodegen": + if is_subgraph: + raise NotImplementedError( + "Subgraphs are not yet supported by FX conversion" + ) + + # For derived backends, this could be a subclass. + return cls() + + +@dataclasses.dataclass +class FxConverter: + """ + Generates FX IR from Wrapper IR. As each instance is only meant to be used once, the + input and output code are stored as attributes. + """ + + lines: list[Line] + prologue: str = "" + + def __post_init__(self) -> None: + graph = torch.fx.Graph() + self.gm = GraphModule({}, graph) # Wrapper FX IR. + self.buffer_to_node: dict[ + Optional[str], torch.fx.Node + ] = {} # Symbol table for codegen. + self.kernels: dict[str, TritonKernel] = {} # Table to store Triton kernels. + self._unique_symbol_ids: Counter[str] = Counter() + + def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: + """ + Imports a kernel from source, possibly autotuning block parameters. + """ + module_code = "\n".join([self.prologue, code]) + mod = PyCodeCache.load(module_code) + kernel = getattr(mod, kernel_name) + + if not isinstance(kernel, CachingAutotuner): + raise NotImplementedError( + textwrap.dedent(f""" + Unsupported type for kernel {kernel_name}: {type(kernel)}. + FX conversion only supports Triton kernels. + """) + ) + + return kernel + + def _fake_tensor( + self, + size: tuple[Any, ...], + stride: tuple[Any, ...], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + with V.fake_mode: + return torch.empty_strided( + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + dtype=dtype, + device=device, + ) + + def _create_meta_from_buffer( + self, node: torch.fx.Node, buffer: CodegenBuffer + ) -> None: + name = buffer.get_name() + assert name + node.name = name + node.meta["val"] = buffer.get_example() + + def _create_as_strided( + self, + input_node: torch.fx.Node, + size: tuple[Any, ...], + stride: tuple[Any, ...], + offset: Union[int, sympy.Expr], + ) -> torch.fx.Node: + return self.gm.graph.call_function( + torch.as_strided, + args=( + input_node, + convert_shape_to_symint(size), + convert_shape_to_symint(stride), + convert_to_symint(offset), + ), + ) + + def _record_allocation(self, buffer: CodegenBuffer, node: torch.fx.Node) -> None: + """ + Updates the symbol table to record that an Inductor buffer maps to the result of + an FX node. + """ + assert node not in self.buffer_to_node + self.buffer_to_node[buffer.get_name()] = node + + def _free(self, buffer: Union[CodegenBuffer, ir.TorchBindObject]) -> None: + """ + Removes the buffer from the symbol table. + """ + name = buffer.get_name() + del self.buffer_to_node[name] + + def _lookup_args(self, args: tuple[Any, ...]) -> tuple[Any, ...]: + """ + Maps call args back to FX nodes. + """ + return tuple( + self.buffer_to_node[arg] + if isinstance(arg, str) + else arg.inner_expr + if isinstance(arg, SymbolicCallArg) + else arg + for arg in args + ) + + def _get_buffer(self, node: ir.IRNode) -> CodegenBuffer: + """ + Extract buffer data from an IR node. 
+ """ + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, (ir.BaseView, ir.MutableBox)): + return self._get_buffer(node.data) + elif isinstance(node, sympy.Symbol): + return SymbolBuffer(node) + else: + raise NotImplementedError(f"Unable to extract buffer from node: {node}") + + def _generate_graph_inputs(self) -> None: + """ + Converts graph inputs to FX placeholders. + """ + for name, ir_node in V.graph.graph_inputs.items(): + # Introduce a new symbol for constant inputs. + buffer = ( + SymbolBuffer(sympy.Symbol(name, is_integer=True)) + if isinstance(ir_node, (int, float, sympy.Integer, sympy.Float)) + else self._get_buffer(ir_node) + ) + node = self.gm.graph.placeholder(buffer.get_name()) + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_buffer(self, node: ir.IRNode) -> Optional[torch.fx.Node]: + """ + Generates FX IR for transformations on a buffer, such as ReinterpretView. + Does nothing if no such transformations are present. + """ + + def generate_to_buffer(node: ir.IRNode) -> Optional[BufferLike]: + if isinstance(node, (ir.Buffer, WorkspaceArg)): + return node + elif isinstance(node, ir.NoneAsConstantBuffer): + return None + elif isinstance(node, ir.StorageBox): + return generate_to_buffer(node.data) + elif isinstance(node, ir.ReinterpretView): + # We need to introduce a new symbol if the output is a ReinterpretView. + # Use a WorkspaceArg for this. + buffer = self._get_buffer(node.data) + assert isinstance(buffer, (ir.Buffer, WorkspaceArg)) + unique_name = self.gm.graph._graph_namespace.create_name( + f"{buffer.get_name()}_view", None + ) + device = buffer.get_device() + assert device + reused_as = WorkspaceArg( + count=buffer.get_size(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + device=device, + outer_name=unique_name, + dtype=buffer.get_dtype(), + ) + + # Generate FX IR for the view. + self._generate_reinterpret_helper(buffer, reused_as, node.layout) + + return reused_as + else: + raise NotImplementedError(f"Unrecognized buffer/view node: {node}") + + buffer = generate_to_buffer(node) + return self.buffer_to_node[buffer.get_name()] if buffer is not None else None + + def _generate_output(self) -> None: + """ + Generate FX IR for graph outputs. + """ + output_nodes = [ + self._generate_buffer(node) + for idx, node in enumerate(V.graph.graph_outputs) + ] + + # Single return elements don't use a tuple. + output_value = output_nodes[0] if len(output_nodes) == 1 else output_nodes + + self.gm.graph.output(output_value) + + def generate(self) -> torch.fx.GraphModule: + """ + Main entrypoint for FX codegen. + """ + self._generate_graph_inputs() + + # Generate FX IR from Wrapper IR lines. + for line in self.lines: + if isinstance(line, WrapperLine): + line.codegen_fx(self)(line) + elif isinstance(line, LineContext): + # Ignore line context in FX IR. + pass + else: + raise NotImplementedError( + textwrap.dedent( + f""" + Found line of unrecognized type '{type(line)}': + '{line}' + + FX conversion only supports Wrapper IR lines. 
+ """ + ) + ) + + self._generate_output() + self.gm.recompile() + return self.gm + + def _generate_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, AllocateLine) + buffer = line.node + name = buffer.get_name() + assert name not in V.graph.removed_buffers + + device = buffer.get_device() + dtype = buffer.get_dtype() + shape = convert_shape_to_symint(buffer.get_size()) + stride = convert_shape_to_symint(buffer.get_stride()) + + node = self.gm.graph.call_function( + torch.empty_strided, + args=(shape, stride), + kwargs={"dtype": dtype, "device": device}, + ) + assert name + node.name = name + self._create_meta_from_buffer(node, buffer) + self._record_allocation(buffer, node) + + def _generate_comment(self, line: WrapperLine) -> None: + assert isinstance(line, CommentLine) + # We ignore comments in FX IR. + + def _generate_enter_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, EnterDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_exit_device_context_manager(self, line: WrapperLine) -> None: + assert isinstance(line, ExitDeviceContextManagerLine) + # We ignore the device context in FX IR. + + def _generate_enter_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, EnterSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_exit_subgraph(self, line: WrapperLine) -> None: + assert isinstance(line, ExitSubgraphLine) + raise NotImplementedError("Subgraphs are not yet supported by FX conversion") + + def _generate_free(self, line: WrapperLine) -> None: + assert isinstance(line, FreeLine) + + buf = line.node + + # No need to free placeholders. + if self.buffer_to_node[buf.get_name()].op == "placeholder": + return + + self._free(buf) + + def _generate_free_if_not_reused(self, line: WrapperLine) -> None: + assert isinstance(line, FreeIfNotReusedLine) + buf = line.node + assert buf.get_name() not in V.graph.removed_buffers + if not line.is_reused: + self._free(buf) + + def _generate_line_context(self, line: WrapperLine) -> None: + assert isinstance(line, LineContext) + # We ignore line context in FX IR. + + def _generate_reinterpret(self, line: WrapperLine) -> None: + assert isinstance(line, ReinterpretLine) + self._generate_reinterpret_helper(line.node, line.reused_as, line.layout) + + def _generate_reinterpret_helper( + self, input_buffer: BufferLike, result_buffer: BufferLike, layout: ir.Layout + ) -> None: + input_node = self.buffer_to_node[input_buffer.get_name()] + + # Look up output metadata. + name = result_buffer.get_name() + assert name + size = tuple(layout.size) + stride = tuple(layout.stride) + if isinstance(layout, ir.NonOwningLayout): + # Look up the view's layout. + view = layout.view + assert isinstance(view, ir.ReinterpretView), ( + f"unexpected type: {type(view)}" + ) + layout = view.layout + offset = input_buffer.get_offset() + layout.offset + + # Map ReinterpretView to as_strided. 
+ result_node = self._create_as_strided(input_node, size, stride, offset) + result_node.name = name + result_node.meta["val"] = layout.get_example() + self._record_allocation(result_buffer, result_node) + + def _generate_reuse(self, line: WrapperLine) -> None: + assert isinstance(line, ReuseLine) + old = line.node + new = line.reused_as + assert not any(buf.get_name() in V.graph.removed_buffers for buf in (old, new)) + assert old.get_dtype() == new.get_dtype() + + old_node = self.buffer_to_node[old.get_name()] + result_node = old_node + + # Change shape and stride. + size = tuple(new.get_size()) + stride = tuple(new.get_stride()) + offset = new.get_offset() + if ( + tuple(old.get_size()) != size + or tuple(old.get_stride()) != stride + or old.get_offset() != offset + ): + result_node = self._create_as_strided(old_node, size, stride, offset) + self._create_meta_from_buffer(result_node, new) + + self._record_allocation(new, result_node) + + # Free the old buffer, if we allocated a new tensor. + if ( + old.get_name() not in V.graph.get_output_names() + and line.delete_old + and result_node is not old_node + ): + self._free(old) + + def _generate_multi_output(self, line: WrapperLine) -> None: + assert isinstance(line, MultiOutputLine) + + # Extract the index for tuple access. + inds = line.indices[0][1:] + assert len(inds) == 1, f"Cannot convert {inds} to an index." + idx = inds[0] + + arg_node = self.buffer_to_node[line.arg_name] + node = self.gm.graph.call_function(operator.getitem, args=(arg_node, idx)) + node.meta["val"] = arg_node.meta["val"][idx] + node.name = line.result_name + self.buffer_to_node[line.result_name] = node + + def _generate_null(self, line: WrapperLine) -> None: + assert isinstance(line, NullLine) + # Does nothing. + + def _generate_comm_buffer_allocate(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferAllocateLine) + raise NotImplementedError("Comm buffer allocation is not yet supported") + + def _generate_comm_buffer_free(self, line: WrapperLine) -> None: + assert isinstance(line, CommBufferFreeLine) + self._free(line.node) + + def _generate_triton_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + + # Collect all kwargs, including autotuned block sizes. + call_args = self._lookup_args(line.call_args) + kernel = self.kernels[line.kernel_name] + tuner = kernel.tuner + + # Optionally autotune the kernels. + # The FX backend currently only supports compile-time tuning. + kernel_name = tuner.fn.__name__ + if config.triton.autotune_at_compile_time: + from triton.runtime import driver + + log.info("Autotuning Triton kernel %s at compile time.", kernel_name) + device = driver.active.get_current_device() + stream = driver.active.get_current_stream(device) + + def node_to_tuning_arg(arg: Any) -> Any: + """ + Create real tensors for autotuning arguments, substituting size hints + for dynamic shapes. + """ + to_size_hint = functools.partial( + pytree.tree_map, V.graph.sizevars.size_hint + ) + if not isinstance(arg, torch.fx.Node): + return to_size_hint(arg) + + fake = arg.meta["val"] + return torch.empty_strided( + to_size_hint(fake.shape), + to_size_hint(fake.stride()), + device=device, + ).zero_() + + arg_values = [node_to_tuning_arg(arg) for arg in call_args] + tuner.run(*arg_values, stream=stream) + else: + log.info( + "Skipping autotuning for kernel %s. 
Set config.triton.autotune_at_compile_time = True to enable.", + kernel_name, + ) + + kernel_config = tuner.compile_results[0].config + call_args, grid = tuner._interpret_args_grid(call_args, kernel_config) + call_kwargs = dict(zip(tuner.triton_meta["signature"], call_args)) + call_kwargs.update(kernel_config.kwargs) + + def replace_floor_div(expr: sympy.Expr) -> sympy.Expr: + """ + Converts floor(x / c) to x // c. + """ + if isinstance(expr, sympy.core.mul.Mul) and isinstance( + expr.args[0], sympy.Rational + ): + # Only the first argument of a Mul can be a Rational. + frac = expr.args[0] + numerator = sympy_product(expr.args[1:]) * frac.numerator + denominator = frac.denominator + + # Sanity check the results. + new_expr = numerator / denominator + assert V.graph.sizevars.statically_known_equals(new_expr, expr), ( + f"Unsound replacement: '{new_expr}' != '{expr}'" + ) + + return FloorDiv(numerator, denominator) + else: + return sympy.floor(expr) + + def expr_to_symint(expr: Union[int, sympy.Expr]) -> Union[int, sympy.Expr]: + return ( + convert_to_symint(expr.replace(sympy.floor, replace_floor_div)) + if isinstance(expr, sympy.Expr) + else expr + ) + + # Convert sympy expressions to symints. + # Use FloorDiv over sympy.floor, so we can get nicer Python code from FX. + wrapper_grid = [tuple(expr_to_symint(dim) for dim in grid)] + call_kwargs = {name: expr_to_symint(val) for name, val in call_kwargs.items()} + + # Store non-graphable kwargs in the side table. + ( + call_kwargs, + constant_args_idx, + ) = tracing_triton_hopifier_singleton.store_non_graphable_args(call_kwargs) + + self.gm.graph.call_function( + triton_kernel_wrapper_mutation, + kwargs={ + "kernel_idx": kernel.wrapped.kernel_idx, + "constant_args_idx": constant_args_idx, + "grid": wrapper_grid, + "tma_descriptor_metadata": {}, + "kwargs": call_kwargs, + }, + ) + + def _generate_extern_kernel_alloc(self, line: WrapperLine) -> None: + assert isinstance(line, ExternKernelAllocLine) + node = line.node + self._generate_extern_kernel_common(node, node) + + def _generate_extern_kernel_out( + self, + line: WrapperLine, + ) -> None: + assert isinstance(line, ExternKernelOutLine) + node = line.node + out_node = node.output_view if node.output_view else node + self._generate_extern_kernel_common(node, out_node) + + def _generate_extern_kernel_common( + self, kernel: ir.ExternKernel, out_ir_node: ir.IRNode + ) -> None: + """ + Generates FX IR from either ExternKernelAlloc or ExternKernelOut. + """ + + # Get FX nodes corresponding to the call args. + tensor_nodes = tuple(self._generate_buffer(arg) for arg in kernel.inputs) + args = tensor_nodes + tuple(kernel.constant_args) + + # Get the result buffer. + # Some kernels write to a pre-existing output tensor via the "out" kwarg. + kwargs = kernel.kwargs.copy() + result_buffer: Optional[str] = None + if isinstance(kernel, ir.ExternKernelOut): + kwargs["out"] = self.buffer_to_node[out_ir_node.codegen_reference()] + elif isinstance(kernel.layout, (ir.Layout, ir.MultiOutputLayout)): + result_buffer = kernel.get_name() + elif isinstance(kernel.layout, ir.NoneLayout): + pass + else: + raise NotImplementedError(f"Unrecognized output layout: {kernel.layout}") + + # Look up the kernel function from its name. + kernel_name = kernel.get_kernel_name() + module_name, kernel_name = kernel_name.split(".", 1) + op = globals()[module_name] # E.g. extern_kernels, aten, etc. + for subname in kernel_name.split("."): + op = getattr(op, subname) # E.g. 
extern_kernels.addmm + + fx_node = self.gm.graph.call_function(op, args=args, kwargs=kwargs) + + # Assign the result to the given name. + if result_buffer: + assert "out" not in kwargs, ( + f"Extern kernel '{kernel}' has both result and out kwarg. Expected only one." + ) + fx_node.name = result_buffer + self.buffer_to_node[result_buffer] = fx_node + + arg_tensors = [ + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in args + ] + + # Run the operation to propagate metadata. + fx_node.meta["val"] = op(*arg_tensors, **kwargs) + + def _generate_kernel_call(self, line: WrapperLine) -> None: + assert isinstance(line, KernelCallLine) + if not line.triton: + raise NotImplementedError("FX conversion only supports Triton kernels.") + + self._generate_triton_call(line) + + def _generate_kernel_definition(self, line: WrapperLine) -> None: + assert isinstance(line, KernelDefinitionLine) + + # Generate code for the kernel. + kernel_code = PythonWrapperCodegen._format_kernel_definition( + line.kernel_name, line.kernel_body, metadata=line.metadata + ) + + # Import the module and store the JIT kernel. + tuner = self._import_kernel(kernel_code, line.kernel_name) + wrapped = wrap_triton(tuner.fn) + self.kernels[line.kernel_name] = TritonKernel(tuner, wrapped) + + def _generate_symbolic_call_arg(self, line: WrapperLine) -> None: + assert isinstance(line, SymbolicCallArgLine) + # No need for an FX node, as we will pass the arg to kernels via a SymInt. diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0647c3ce64699d0c4fffa53b3c47bd5f22613bb1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..546c97c692f9aa0544a892302a58c542e75522b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/__pycache__/device_op_overrides.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..bd927b6ee3b8195706d3bd791451bdeb9c3ea91b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/codegen/xpu/device_op_overrides.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Optional + +from ..common import ( + DeviceOpOverrides, + register_device_op_overrides, + TritonScratchWorkspace, +) + + +class XPUDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return f"from torch._C import _xpu_getCurrentRawStream as {name}" + + def set_device(self, device_idx: int) -> str: + return f"torch.xpu.set_device({device_idx})" + + def synchronize(self) -> str: + return "torch.xpu.synchronize()" + + def 
device_guard(self, device_idx: int) -> str: + return f"torch.xpu._DeviceGuard({device_idx})" + + def cpp_device_guard(self) -> str: + return "at::DeviceGuard" + + def cpp_aoti_device_guard(self) -> str: + return "AOTIXpuGuard" + + def cpp_stream_guard(self) -> str: + return "at::xpu::XPUStreamGuard" + + def cpp_aoti_stream_guard(self) -> str: + return "AOTIXpuStreamGuard" + + def cpp_getStreamFromExternal(self) -> str: + return "at::xpu::getStreamFromExternal" + + def kernel_header(self) -> str: + source_codes = """ + #include + """ + return source_codes + + def kernel_driver(self) -> str: + return "" + + def cpp_stream_type(self) -> str: + return "sycl::queue*" + + def aoti_get_stream(self) -> str: + return "aoti_torch_get_current_xpu_stream" + + def cpp_kernel_type(self) -> str: + return "std::unique_ptr" + + def cpp_device_ptr(self) -> str: + return "void *" + + def cpp_global_scratch( + self, idx: int, workspace: TritonScratchWorkspace + ) -> Optional[tuple[list[str], str]]: + return None + + +register_device_op_overrides("xpu", XPUDeviceOpOverrides()) diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__main__.py b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..d996f8de000fcc422d9e7d7a9aff021d9ee431e3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__main__.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +import argparse +import base64 +import functools +import importlib +import logging +import os +import sys +from typing import TypeVar + +from torch._inductor.async_compile import pre_fork_setup +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.subproc_pool import ( + SubprocKind, + SubprocMain, + SubprocPickler, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path + + +_T = TypeVar("_T") + + +log = logging.getLogger(__name__) + +_set_triton_ptxas_path() + +try: + import triton + + assert triton is not None # preload in parent +except ImportError: + pass + + +def _lookup_and_create_type(base: type[_T], qname: str) -> _T: + """ + Given a base type and qualified name: import & lookup that name, check + that it's of the given type and then instantiate it. 
+ """ + pkg, name = qname.rsplit(".", 1) + mod = importlib.import_module(pkg) + ty = getattr(mod, name) + if not issubclass(ty, base): + raise TypeError(f"Type {ty} is not a subtype of {base}") + return ty() + + +def main(): + try: + parser = argparse.ArgumentParser() + parser.add_argument( + "--pickler", type=functools.partial(_lookup_and_create_type, SubprocPickler) + ) + parser.add_argument("--kind", type=SubprocKind) + parser.add_argument("--workers", type=int) + parser.add_argument("--parent", type=int) + parser.add_argument("--read-fd", type=int) + parser.add_argument("--write-fd", type=int) + parser.add_argument("--torch-key", type=str) + args = parser.parse_args() + if os.getppid() != args.parent: + sys.exit(0) + read_fd = os.fdopen(args.read_fd, "rb") + write_fd = os.fdopen(args.write_fd, "wb") + + pre_fork_setup() + + torch_key.set(base64.b64decode(args.torch_key.encode("utf-8"))) # type: ignore[attr-defined] + + _async_compile_initializer(args.parent) + + SubprocMain(args.pickler, args.kind, args.workers, read_fd, write_fd).main() + except Exception: + log.exception("Uncaught exception in compile_worker subprocess") + + +if __name__ == "__main__": + main() diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e92e42c98372d86491de72c029578c763904ae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..042e9c7cdac6c255aa4688d5e5ce2ba67926d35a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/__main__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cddcef133917a7df0d9cf8e4a0e1f0655fd88e07 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/subproc_pool.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fc646df80f57e775a9b73a3eab9b9f2e563260a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/tracked_process_pool.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bec52907af48c8cb0957e0e86cdd43dfaad2d08c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/subproc_pool.py 
b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/subproc_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..210601241e6bc99b7ae0d8a1c238e838f2e4f859 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/subproc_pool.py @@ -0,0 +1,379 @@ +import base64 +import functools +import itertools +import logging +import multiprocessing +import os +import pickle +import struct +import subprocess +import sys +import threading +import traceback +import typing +from concurrent.futures import Future, ProcessPoolExecutor +from concurrent.futures.process import BrokenProcessPool +from enum import Enum +from typing import Any, Callable, IO, Optional, TypeVar +from typing_extensions import Never, ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. +import torch._thread_safe_fork # noqa: F401 +from torch._inductor import config +from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.tracked_process_pool import ( + TrackedProcessPoolExecutor, +) +from torch._inductor.compile_worker.utils import _async_compile_initializer +from torch._inductor.utils import get_ld_library_path + + +log = logging.getLogger(__name__) + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _pack_msg(job_id: int, length: int) -> bytes: + return struct.pack("nn", job_id, length) + + +def _unpack_msg(data: bytes) -> tuple[int, int]: + if not data: + return -1, -1 + return struct.unpack("nn", data) + + +msg_bytes = len(_pack_msg(0, 0)) + + +def _send_msg(write_pipe: IO[bytes], job_id: int, job_data: bytes = b"") -> None: + length = len(job_data) + write_pipe.write(_pack_msg(job_id, length)) + if length > 0: + write_pipe.write(job_data) + write_pipe.flush() + + +def _recv_msg(read_pipe: IO[bytes]) -> tuple[int, bytes]: + job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) + data = read_pipe.read(length) if length > 0 else b"" + return job_id, data + + +class _SubprocExceptionInfo: + """ + Carries exception info from subprocesses across the wire. traceback + objects are not pickleable, so we store the trace as a string and + use it for the message in the exception thrown in the main process. + """ + + def __init__(self, details: str) -> None: + self.details = details + + +class SubprocException(Exception): + """ + Thrown when a job in a subprocess raises an Exception. + """ + + def __init__(self, details: str) -> None: + super().__init__(f"An exception occurred in a subprocess:\n\n{details}") + + +class SubprocPickler: + """ + Allows a caller to provide a custom pickler for passing data with the + subprocess. 
+ """ + + def dumps(self, obj: object) -> bytes: + return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL) + + def loads(self, data: bytes) -> object: + return pickle.loads(data) + + +class SubprocKind(Enum): + FORK = "fork" + SPAWN = "spawn" + + +class SubprocPool: + """ + Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in + a subprocess.Popen() to try to avoid issues with forking/spawning + """ + + def __init__( + self, + nprocs: int, + pickler: Optional[SubprocPickler] = None, + kind: SubprocKind = SubprocKind.FORK, + ) -> None: + entry = os.path.join(os.path.dirname(__file__), "__main__.py") + self.pickler = pickler or SubprocPickler() + self.kind = kind + + subproc_read_fd, write_fd = os.pipe() + read_fd, subproc_write_fd = os.pipe() + self.write_pipe = os.fdopen(write_fd, "wb") + self.read_pipe = os.fdopen(read_fd, "rb") + torch_key_str = base64.b64encode(torch_key()).decode("utf-8") + + cmd = [ + sys.executable, + entry, + f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}", + f"--kind={self.kind.value}", + f"--workers={nprocs}", + f"--parent={os.getpid()}", + f"--read-fd={str(subproc_read_fd)}", + f"--write-fd={str(subproc_write_fd)}", + f"--torch-key={torch_key_str}", + ] + local = False + if config.worker_suppress_logging: + log.info("Suppressing compile worker output due to config") + local = True + + self.process = subprocess.Popen( + cmd, + env={ + **os.environ, + # We need to set the PYTHONPATH so the subprocess can find torch. + "PYTHONPATH": os.environ.get( + "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path) + ), + # We don't want to re-warm the pool when the subprocess imports + # torch._inductor.codecache since the warming process is what + # creates the SubprocPool in the first place. + "TORCH_WARM_POOL": "0", + # Some internal usages need a modified LD_LIBRARY_PATH. + "LD_LIBRARY_PATH": get_ld_library_path(), + }, + pass_fds=(subproc_read_fd, subproc_write_fd), + stdout=subprocess.DEVNULL if local else None, + stderr=subprocess.DEVNULL if local else None, + ) + self.write_lock = threading.Lock() + self.read_thread = threading.Thread(target=self._read_thread, daemon=True) + + self.futures_lock = threading.Lock() + self.pending_futures: dict[int, Future[Any]] = {} + self.job_id_count = itertools.count() + + self.running = True + + # Start thread last to ensure all member variables are initialized + # before any access. + self.read_thread.start() + + def submit( + self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_T]: + if args or kwargs: + job_fn = functools.partial(job_fn, *args, **kwargs) + job_data = self.pickler.dumps(job_fn) + future: Future[_T] + with self.futures_lock: + job_id = next(self.job_id_count) + self.pending_futures[job_id] = future = Future() + future.set_running_or_notify_cancel() + with self.write_lock: + if not self.running: + raise RuntimeError("submit() on closed pool") + _send_msg(self.write_pipe, job_id, job_data) + return future + + def _read_thread(self) -> None: + while True: + data = b"" + try: + job_id, data = _recv_msg(self.read_pipe) + except Exception: + # Something went wrong during the read. There's no way we have a + # valid job_id. + log.exception("failure in subproc_pool._recv_msg") + job_id = -1 + + if job_id < 0: + # read_pipe returned None or got exception + if self.running: + log.warning("SubprocPool unclean exit") + self.running = False + self.read_pipe.close() + # Cancel all the pending futures. 
+ self.shutdown() + return + + try: + result = self.pickler.loads(data) + except Exception as e: + # Something went wrong unpickling. We have a job_id so just + # notify that particular future and continue on. + log.exception("unpickle failure in SubprocPool._read_thread") + result = e + + with self.futures_lock: + if not self.running: + return + if isinstance(result, _SubprocExceptionInfo): + # An exception occurred in the submitted job + self.pending_futures[job_id].set_exception( + SubprocException(result.details) + ) + elif isinstance(result, Exception): + # An exception occurred in some of our subprocess machinery. + self.pending_futures[job_id].set_exception(result) + else: + self.pending_futures[job_id].set_result(result) + del self.pending_futures[job_id] + + def shutdown(self) -> None: + try: + with self.write_lock: + if not self.running: + return + self.running = False + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + self.process.wait(300) + except OSError as e: + log.warning("Ignored OSError in pool shutdown: %s", e) + finally: + with self.futures_lock: + for future in self.pending_futures.values(): + if not future.cancel(): + future.set_exception(RuntimeError("SubprocPool closed")) + self.pending_futures.clear() + + +class SubprocMain: + """Communicates with a SubprocPool in the parent process, called by __main__.py""" + + def __init__( + self, + pickler: SubprocPickler, + kind: SubprocKind, + nprocs: int, + read_pipe: IO[bytes], + write_pipe: IO[bytes], + ) -> None: + self.pickler = pickler + self.kind = kind + self.read_pipe = read_pipe + self.write_pipe = write_pipe + self.write_lock = threading.Lock() + self.nprocs = nprocs + self.pool = self._new_pool(nprocs, True) + self.running = True + + def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor: + pool = TrackedProcessPoolExecutor( + nprocs, + mp_context=multiprocessing.get_context(self.kind.value), + initializer=functools.partial(_async_compile_initializer, os.getpid()), + ) + multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize) + if warm: + _warm_process_pool(pool, nprocs) + return pool + + def main(self) -> None: + while True: + job_id, data = _recv_msg(self.read_pipe) + if job_id < 0: + return self._shutdown() + self.submit(job_id, data) + + def _shutdown(self) -> None: + with self.write_lock: + self.running = False + try: + _send_msg(self.write_pipe, -1) + self.write_pipe.close() + except BrokenPipeError: + pass # parent process already shutdown + self.read_pipe.close() + self.pool.shutdown() + + def submit(self, job_id: int, data: bytes) -> None: + while self.running: + try: + self._submit_inner(job_id, data) + return + except BrokenProcessPool: + # If any subprocess in the pool crashes, we get a BrokenProcessPool + # exception and the whole pool becomes unusable. Handle crashes by + # recreating the pool and resubmitting. 
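+                # The replacement pool is created with warm=False: its worker processes
+                # are started on demand, unlike the initial pool, which is pre-warmed in
+                # __init__ via _new_pool(nprocs, True).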
+ self.pool = self._new_pool(self.nprocs, False) + + def _submit_inner(self, job_id: int, data: bytes) -> None: + def callback(fut: Future[Any]) -> None: + if not self.running: + return + try: + result = fut.result() + except Exception as e: + log.exception("Error in subprocess") + result = self.pickler.dumps(e) + assert isinstance(result, bytes) + with self.write_lock: + if self.running: + _send_msg(self.write_pipe, job_id, result) + return + + future = self.pool.submit( + functools.partial(SubprocMain.do_job, self.pickler, data) + ) + future.add_done_callback(callback) + + @staticmethod + def do_job(pickler: SubprocPickler, data: bytes) -> bytes: + # do the pickle/unpickle in the sub-subproc + job = typing.cast(Callable[[], object], pickler.loads(data)) + + try: + result = job() + except Exception: + result = _SubprocExceptionInfo(traceback.format_exc()) + return pickler.dumps(result) + + +AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] + + +def _warm_process_pool(pool: ProcessPoolExecutor, n: int) -> None: + # We have to fork processes for compiler workers, but the more memory and other resources that are loaded, the + # slower the os.fork time is, quite drastically. It also holds the GIL so we can't put it on another thread. + + # Examples: + # A simple x + x + x script: 10ms seconds in the middle of the program, 2ms at startup + # tf_efficientnet_b0 benchmark: 50ms! in the middle of the program , 3ms at startup + + # So we want to start the workers early when it is still cheap, and also to allow the workers to get + # ready before we have work for them. + + # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers are idle. + # But if we waited until then fork time will be long and we will be waiting for the processes to initialize. + + # We force them to start here with some YOLOing of the internal methods. + + if hasattr(pool, "_start_queue_management_thread"): + pool._start_queue_management_thread() + else: + for _ in range(n): + pool._adjust_process_count() + if hasattr(pool, "_start_executor_manager_thread"): + pool._start_executor_manager_thread() + + +class TestException(RuntimeError): + pass + + +def raise_testexc() -> Never: + raise TestException diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..585fe8829870ab9e9bf834abbf5f044bb6627a6e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/tracked_process_pool.py @@ -0,0 +1,111 @@ +import atexit +import concurrent +import dataclasses +import logging +import threading +from concurrent.futures import Future, ProcessPoolExecutor +from dataclasses import dataclass +from multiprocessing.context import BaseContext +from time import time +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +# _thread_safe_fork is needed because the subprocesses in the pool can read +# justknobs, e.g., in the Triton compiler. For internal, the import installs +# functionality to destroy singletons before forking and re-enable them after. 
+import torch._thread_safe_fork # noqa: F401 + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +log = logging.getLogger(__name__) + + +@dataclass +class _QueueStats: + # Mapping from id(future) -> start time + pending: dict[int, float] = dataclasses.field(default_factory=dict) + timing: list[float] = dataclasses.field(default_factory=list) + enqueue_count: int = 0 + dequeue_count: int = 0 + max_queue_depth: int = 0 + pool_count: int = 0 + + +# The queue statistics tracked by TrackedProcessPoolExecutor. Always grab +# _queue_stats_lock before touching. +_queue_stats = _QueueStats() +_queue_stats_lock = threading.Lock() + + +class TrackedProcessPoolExecutor(ProcessPoolExecutor): + def __init__( + self, + max_workers: Optional[int] = None, + mp_context: Optional[BaseContext] = None, + initializer: Optional[Callable[[], object]] = None, + ) -> None: + with _queue_stats_lock: + _queue_stats.pool_count += 1 + super().__init__(max_workers, mp_context, initializer) + + def _record_dequeue(self, f: Future[Any]) -> None: + now = time() + with _queue_stats_lock: + stats = _queue_stats + if (start_time := stats.pending.pop(id(f), None)) is None: + return + stats.dequeue_count += 1 + duration = now - start_time + stats.timing.append(duration) + + def _record_enqueue(self, f: Future[Any]) -> None: + # Monkeypatch the set_running_or_notify_cancel so we can track when the Future moves out of PENDING. + saved_running_or_notify_cancel = f.set_running_or_notify_cancel + + def set_running_or_notify_cancel() -> Any: + self._record_dequeue(f) + return saved_running_or_notify_cancel() + + now = time() + with _queue_stats_lock: + stats = _queue_stats + stats.pending[id(f)] = now + stats.enqueue_count += 1 + stats.max_queue_depth = max(stats.max_queue_depth, len(stats.pending)) + f.set_running_or_notify_cancel = set_running_or_notify_cancel # type: ignore[method-assign] + + if f._state != concurrent.futures._base.PENDING: + self._record_dequeue(f) + + def submit( + self, fn: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_R]: + f = super().submit(fn, *args, **kwargs) + self._record_enqueue(f) + return f + + +@atexit.register +def _queue_stats_report() -> None: + stats = _queue_stats + if stats.pool_count == 0: + return + + timing = stats.timing + timing.sort() + + log.info("AsyncCompile Metrics:") + log.info(" Pools %s", stats.pool_count) + log.info( + " Items %d enqueued / %d dequeued", stats.enqueue_count, stats.dequeue_count + ) + log.info(" Max Queue Depth: %d", stats.max_queue_depth) + n = len(timing) + if n > 0: + log.info(" Longest queue time: %0.2fs", timing[-1]) + log.info(" P50: %0.2fs", timing[n // 2]) + if n >= 20: + log.info(" P95: %0.2fs", timing[n * 95 // 100]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/compile_worker/utils.py b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c41107d17a546eadf27920934789e0cd38c633 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/compile_worker/utils.py @@ -0,0 +1,49 @@ +import os +import signal +from threading import Thread +from time import sleep +from typing import Optional + + +_IN_TOPLEVEL_PROCESS = True + + +def in_toplevel_process() -> bool: + global _IN_TOPLEVEL_PROCESS + return _IN_TOPLEVEL_PROCESS + + +# If this process dies abnormally (e.g. segfault) +# it will not shut down the workers. Instead, +# the workers will have their parent reassigned to the +# init process. 
This launches a separate thread to +# watch for the worker getting reassigned, +# and cleans it up in this case. +# +# This function cannot be an inner function since otherwise mp_context="spawn" would +# not work for ProcessPoolExecutor since inner functions cannot be pickled. +def _async_compile_initializer(orig_ppid: int) -> None: + def run() -> None: + while True: + sleep(1) + if orig_ppid != os.getppid(): + os.kill(os.getpid(), signal.SIGKILL) + + global _watchdog_thread, _original_parent + _original_parent = orig_ppid + _watchdog_thread = Thread(target=run, daemon=True) + _watchdog_thread.start() + # Ignore Ctrl-C (i.e. SIGINT) sent to pool workers to avoid meaningless log spam. + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Set a bit to distinguish async_compile subprocesses from the toplevel process. + global _IN_TOPLEVEL_PROCESS + _IN_TOPLEVEL_PROCESS = False + + +_watchdog_thread: Optional[Thread] = None +_original_parent: Optional[int] = None + + +def has_parent_changed() -> bool: + return _original_parent != os.getppid() diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d548ae4ce12c8432a7981d93655dc721eb707712 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ff6c1fb97ffa4471abc291aaad3bb300ef1721c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/b2b_gemm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f4dcd703b2bf5f38e3691dd10b6916fafca289f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/binary_folding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c2cdb6ea51bb77938272c7635169d9acf22bd0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/ddp_fusion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f5217c072ccd674448591af236cd891f88c1380 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/decompose_mem_bound_mm.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c28cfdb2913560056d932f2675b3767bcd7e8222 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/dedupe_symint_uses.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1e5e0ecc1ee2508797901aaf5d4d389e47421d7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/efficient_conv_bn_eval.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74dd957fbb2ecd500640a4bdb3a1b130de0b6d6c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/freezing_patterns.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..968c1514d0a1ff40bd2cf697faa96c11d2a24071 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/fuse_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d9bd7e8c5c14aa46ab1448c7f1709a0130c2de4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/group_batch_fusion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c3f6eb96c58b5bd15d9fb5711532ff4b46652e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/joint_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e482bf17dbd293ae266a140663de0425d937d32 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/micro_pipeline_tp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fe0d3611395904e5bbb0b03d2f7953d12ad469d Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/misc_patterns.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..805b148d835f65ec2d3acbf6f58832dbb391ccac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/mkldnn_fusion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08115b7be229cf0e4bc08fe2bcb35387d33a91d3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/numeric_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7aa49a361844a0e7b9bd263c0e8e2c73c5d0d4ad Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pad_mm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dc79d16dc2179f496de5026a823b3819bf73af6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/post_grad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..345931589fd3e23196f91b5191592fc42c1be436 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/pre_grad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c343b8f5622f364a65cfe44bb99fec6a82ad4fc6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/quantization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4b11065667fbe05beee284dc72dee305c35b7bd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/reinplace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc237b051bdbb1ded22ec321ff2e3f71bdf55899 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/replace_random.cpython-39.pyc 
differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66d4dcee7c58f82b2749c81ba36f66c7d778cb7d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/__pycache__/split_cat.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/b2b_gemm.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/b2b_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..5be1e10b81aa0e6fd119b27b142d4cf5c02265d3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/b2b_gemm.py @@ -0,0 +1,760 @@ +# mypy: allow-untyped-defs +import functools +from collections import deque + +import torch +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map + +from ..._dynamo.utils import counters +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import lowerings +from ..pattern_matcher import ( + Arg, + CallFunction, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, + TritonTemplateCaller, +) +from ..utils import ceildiv + + +B2B_GEMM_PASS = PatternMatcherPass( + pass_name="b2b_gemm_pass", +) + + +@SymbolicGridFn +def b2b_gemm_grid(M, P, meta, *, cdiv): + return (cdiv(M, meta["BLOCK_SIZE_M"]) * cdiv(P, meta["BLOCK_SIZE_P"]), 1, 1) + + +b2b_gemm_left_template = TritonTemplate( + name="b2b_gemm_left", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_LEFT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (subgraph(A @ B) @ C) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for _ in range(num_o_block): + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_ab = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_O), dtype=tl.float32) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for __ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * 
stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + acc_ab += tl.dot(a, b, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_ab", + inner_mm="acc_ab" + ) | indent_except_first(2) }} + acc += tl.dot(post_subgraph_acc_ab, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +b2b_gemm_right_template = TritonTemplate( + name="b2b_gemm_right", + grid=b2b_gemm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C")}} + + + # B2B_GEMM_RIGHT_TRITON_ENTRANCE + + # dynamic shapes + M = {{size("A", 0)}} + N = {{size("A", 1)}} + O = {{size("C", 0)}} + P = {{size("C", 1)}} + + # dynamic strides + stride_am = {{stride("A", 0)}} + stride_an = {{stride("A", 1)}} + stride_bn = {{stride("B", 0)}} + stride_bo = {{stride("B", 1)}} + stride_co = {{stride("C", 0)}} + stride_cp = {{stride("C", 1)}} + + # output block counts + num_m_block = tl.cdiv(M, BLOCK_SIZE_M) + num_p_block = tl.cdiv(P, BLOCK_SIZE_P) + + # internal block counts + num_n_block = tl.cdiv(N, BLOCK_SIZE_N) + num_o_block = tl.cdiv(O, BLOCK_SIZE_O) + + # output block ids + pid = tl.program_id(axis=0) + m_block_id = pid // num_p_block + p_block_id = pid % num_p_block + + # accumulator + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_P), dtype=tl.float32) + + # main loop (two cases) + offs_m = (m_block_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_p = (p_block_id * BLOCK_SIZE_P + tl.arange(0, BLOCK_SIZE_P)) + # (A @ subgraph(B @ C)) + offs_n = tl.arange(0, BLOCK_SIZE_N) + for _ in range(num_n_block): + a_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + a_ptrs = A + (offs_m[:, None] * stride_am + offs_n[None, :] * stride_an) + a = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_M * BLOCK_SIZE_N + acc_bc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_P), dtype=tl.float32) + offs_o = tl.arange(0, BLOCK_SIZE_O) + for __ in range(num_o_block): + b_mask = (offs_n[:, None] < N) & (offs_o[None, :] < O) + b_ptrs = B + (offs_n[:, None] * stride_bn + offs_o[None, :] * stride_bo) + b = tl.load(b_ptrs, mask=b_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_N * BLOCK_SIZE_O + c_mask = (offs_o[:, None] < O) & (offs_p[None, :] < P) + c_ptrs = C + (offs_o[:, None] * stride_co + offs_p[None, :] * stride_cp) + c = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32) # BLOCK_SIZE_O * BLOCK_SIZE_P + acc_bc += tl.dot(b, c, out_dtype=tl.float32) + offs_o += BLOCK_SIZE_O + # apply the subgraph + {{ modification( + subgraph_number=0, + output_name="post_subgraph_acc_bc", + inner_mm="acc_bc" + ) | indent_except_first(2) }} + acc += tl.dot(a, post_subgraph_acc_bc, out_dtype=tl.float32) + offs_n += BLOCK_SIZE_N + + # type conversion + acc = acc.to(tl.float16) + + # store preparation + idx_m = offs_m[:, None] + idx_p = offs_p[None, :] + out_mask = (idx_m < M) & (idx_p < P) + + {{store_output(("idx_m", "idx_p"), "acc", "out_mask")}} +""", +) + + +# Note: load_ratio_left and load_ratio_right are only calculating numbers +# in the trivial subgraph case; i.e. 
(A @ (B @ C)) or ((A @ B) @ C) + + +def load_ratio_left( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | M * N + N * O + M * O + O * P | M / m * P / p * O / o * (o * p + N / n * (m * n + n * o)) + | store | M * O + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = M * N + N * O + M * O + O * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(O, o) + * (o * p + ceildiv(N, n) * (m * n + n * o)) + ) + return base / gemm + + +def load_ratio_right( + M: int, N: int, O: int, P: int, m: int, n: int, o: int, p: int +) -> float: + """ + compute the ratio of estimated numbers of loads in baseline and b2bgemm + M, N, O, P are matrix sizes + m, n, o, p are block sizes + | | baseline (lower bound) | b2bgemm + | load | N * O + O * P + M * N + N * P | M / m * P / p * N / n * (m * n + O / o * (n * o + o * p)) + | store | N * P + M * P | M * P + b2bgemm is always better on stores, but for loads we need to find out beneficial cases using this function + """ + base = N * O + O * P + M * N + N * P + gemm = ( + ceildiv(M, m) + * ceildiv(P, p) + * ceildiv(N, n) + * (m * n + ceildiv(O, o) * (n * o + o * p)) + ) + return base / gemm + + +# the block sizes are limited by hardware (the shared memory) +# intuitively, the optimization works when the intermediate matrix is large +# and we assign large block sizes to large dimensions +b2b_gemm_configs = [ + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 16, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 32, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 128, + "BLOCK_SIZE_P": 64, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 16, + "BLOCK_SIZE_P": 128, + "num_stages": 4, + "num_warps": 8, + }, + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 32, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_O": 64, + "BLOCK_SIZE_P": 128, + "num_stages": 2, + "num_warps": 4, + }, +] + + +def is_b2b_gemm_good_on( + is_left_assoc: bool, + A_node: torch.fx.Node, + B_node: torch.fx.Node, + C_node: torch.fx.Node, +) -> bool: + """ + checks whether the sizes are good for b2b_gemm + """ + # basic checks + if not all(["val" in A_node.meta, "val" in 
B_node.meta, "val" in C_node.meta]): + return False + fake_tensors = ( + A_node.meta["val"], + B_node.meta["val"], + C_node.meta["val"], + ) # torch._subclasses.fake_tensor.FakeTensor + + A, B, C = fake_tensors + + def check_all_attr_true(objects, attr): + return all(hasattr(obj, attr) and getattr(obj, attr) for obj in objects) + + if not check_all_attr_true(fake_tensors, "is_cuda") and not check_all_attr_true( + fake_tensors, "is_xpu" + ): + return False + if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]): + return False + if not ((A.shape[1] == B.shape[0]) and (B.shape[1] == C.shape[0])): + return False + # size checks: we only dispatch to B2B-GEMM when the average load ratio is > 1 + M, N = A.shape + O, P = C.shape + ratios = [] + if is_left_assoc: + for config in b2b_gemm_configs: + ratio = load_ratio_left( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + else: + for config in b2b_gemm_configs: + ratio = load_ratio_right( + M, + N, + O, + P, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_O"], + config["BLOCK_SIZE_P"], + ) + ratios.append(ratio) + ratios.sort(reverse=True) + average_ratio = 1.0 + for r in ratios[:3]: # top 3 choices + average_ratio *= r + average_ratio = average_ratio ** (1 / 3) + return ( + average_ratio > 1 + ) # even if average_ratio is close to 1, the number of stores is always better + + +def unoptimized_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + *, + out: torch.Tensor, +) -> torch.Tensor: + """ + The unoptimized version is used as a fallback when the b2b_gemm kernel is not beneficial. + """ + if is_left_assoc: + torch.mm(subgraph.graph_module(torch.mm(A, B)), C, out=out) + else: + torch.mm(A, subgraph.graph_module(torch.mm(B, C)), out=out) + return out + + +unoptimized_choice = ExternKernelChoice(unoptimized_b2b_gemm) + + +def build_subgraph_buffer( + args: list[TensorBox], + subgraph: Subgraph, +): + """ + This function is adapted from ../kernel/flex_attention.py. 
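# A worked instance of the load_ratio_left estimate defined earlier in this
# file, with shapes chosen purely for illustration: large M and O and small
# N and P, paired with the matching 128/16/128/16 block config from the list
# above.
import math


def _cdiv(a: int, b: int) -> int:
    return math.ceil(a / b)


M, N, O, P = 4096, 16, 4096, 16  # matrix sizes; the intermediate A @ B is M x O
m, n, o, p = 128, 16, 128, 16    # block sizes
baseline = M * N + N * O + M * O + O * P  # 16,973,824 element loads
fused = _cdiv(M, m) * _cdiv(P, p) * _cdiv(O, o) * (
    o * p + _cdiv(N, n) * (m * n + n * o)
)  # 6,291,456 element loads
print(baseline / fused)  # roughly 2.7 > 1, so B2B-GEMM is expected to help here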
+ The goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph + subgraph: The Subgraph ir for which to produce the output node + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + if node.op == "placeholder": + env[node] = args[cnt] + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + args, kwargs = tree_map( + lambda x: env[x] if x in env else x, (node.args, node.kwargs) + ) + env[node] = lowerings[node.target](*args, **kwargs) + elif node.op == "output": + + def convert_output_node_to_buffer(output): + if output is None: + return None + output_node = output + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for B2B-GEMM's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for B2B-GEMM's subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + # node.args[0] should be a single element representing the output of the subgraph + return tree_map(convert_output_node_to_buffer, node.args[0]) + + raise ValueError("B2B-GEMM was passed a subgraph with no output node!") + + +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """ + Creates a placeholder input buffers for producing subgraph_output + """ + input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], [])) + return TensorBox.create(input_buffer) + + +def tuned_b2b_gemm( + is_left_assoc: bool, + subgraph: Subgraph, + A: torch._inductor.ir.TensorBox, + B: torch._inductor.ir.TensorBox, + C: torch._inductor.ir.TensorBox, + *, + layout=None, +) -> torch._inductor.ir.TensorBox: + # call .realize() to get rid of Pointwise + A.realize() + B.realize() + C.realize() + layout = FixedLayout( + A.get_device_or_error(), + A.get_dtype(), + [A.shape[0], C.shape[1]], # type: ignore[index] + ) + subgraph_buffer = build_subgraph_buffer( + [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], + subgraph, + ) + choices: list[TritonTemplateCaller] = [] + for config in b2b_gemm_configs: + if is_left_assoc: + b2b_gemm_left_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + else: + b2b_gemm_right_template.maybe_append_choice( + choices, + input_nodes=(A, B, C), + layout=layout, + subgraphs=[subgraph_buffer], + **config, + ) + # add the unoptimized choice to mitigate performance degradation + choices.append( + unoptimized_choice.bind( + (A, B, C), layout, is_left_assoc=is_left_assoc, subgraph=subgraph + ) + ) + # autotune + return autotune_select_algorithm("b2b_gemm", choices, [A, B, C], layout) + + +# match the inner mm of a potential b2b_gemm +@register_graph_pattern( + CallFunction(torch.ops.aten.mm, Arg(), Arg()), + pass_dict=B2B_GEMM_PASS, +) +def b2b_gemm_handler(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node) -> None: + # match.args: list[torch.fx.Node] + + def 
is_pointwise_node(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and (torch.Tag.pointwise in node.target.tags) + ) + + def is_mm(node: torch.fx.Node) -> bool: + return node.target == torch.ops.aten.mm.default + + # the inner MM + inner_mm = match.nodes[-1] + + # find the (candidate) outer MM, which will be re-checked below to ensure every path reaches it + # In a real (A @ f(B @ C)), every path starting from (B @ C) must reach (A @ _). + outer_mm = None + node = inner_mm + while len(node.users) > 0: + node = next(iter(node.users)) + if is_mm(node): + outer_mm = node + break + elif is_pointwise_node(node): + continue + else: + break + if not outer_mm: + return + + # find the unique input node for outer_mm representing f(B @ C) in (A @ f(B @ C)) + # we call it the "f_node" + # when the pattern is simply (A @ (B @ C)), f_node is just inner_mm + f_node = inner_mm + while next(iter(f_node.users)) is not outer_mm: + f_node = next(iter(f_node.users)) + + def all_reach_via_pointwise_with_no_other_inputs( + src: torch.fx.Node, + dst: torch.fx.Node, + ) -> tuple[bool, OrderedSet[torch.fx.Node]]: + """ + check whether every user path from src reaches dst via pointwise nodes, + with no other input nodes for the intermediates and dst; + return + (1) the Boolean value + (2) the subgraph node set including src and dst (which only makes sense when the Boolean value is True) + """ + visited = OrderedSet[torch.fx.Node]() + input_counter: dict[torch.fx.Node, int] = {} + + all_reachable = True + queue = deque([src]) + while queue: + node = queue.popleft() + if node not in visited: + if node is dst: + visited.add(node) + elif (node is src) or is_pointwise_node(node): + for user in node.users.keys(): + # for nodes other than dst, bookkeep their users' input counts + if user not in input_counter: + input_counter[user] = len(user.all_input_nodes) + input_counter[user] -= 1 + # continue BFS + queue.append(user) + visited.add(node) + else: + all_reachable = False + break + + return ( + all_reachable and all(count == 0 for count in input_counter.values()), + visited, + ) + + # check inner_mm reaches f_node on every user path via pointwise nodes with no outside input_nodes + ok, subgraph_node_set = all_reach_via_pointwise_with_no_other_inputs( + inner_mm, f_node + ) + if not ok: + return + + # check inner_mm's inputs and f_node's outputs + if not (len(inner_mm.all_input_nodes) == 2 and len(f_node.users) == 1): + return + + # at this point, the nodes between inner_mm and f_node (both included) + # are all used internally inside (A @ subgraph(B @ C)) + # i.e. 
they neither have other users nor have other inputs + + # original graph and module + graph, module = inner_mm.graph, inner_mm.graph.owning_module + + # construct the new (sub)graph + subgraph_node_list: list[ + torch.fx.Node + ] = [] # ordered list of nodes used for node removal later + new_graph: torch.fx.Graph = torch.fx.Graph() + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + new_input_anchor: torch.fx.Node # inner_mm, to be changed to an input node + new_output_anchor: torch.fx.Node # f_node, to be used to construct an output node + new_input_node: torch.fx.Node + new_output_node: torch.fx.Node + for node in graph.nodes: # preserve the order of nodes + if node in subgraph_node_set: + subgraph_node_list.append(node) + new_node = new_graph.node_copy( + node, lambda x: node_remapping[x] if x in node_remapping else x + ) + node_remapping[node] = new_node + if node is inner_mm: + new_input_anchor = new_node + if node is f_node: + new_output_anchor = new_node + if new_input_anchor is not new_output_anchor: # subgraph is non-trivial + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # add the output node + new_output_node = new_graph.output(new_output_anchor) + new_output_node.meta.update(new_output_anchor.meta) + else: # subgraph is trivial, e.g. (A @ (B @ C)) + # update the input node + with new_graph.inserting_before(new_input_anchor): + new_input_node = new_graph.placeholder(name="subgraph_input") + new_input_node.meta.update(new_input_anchor.meta) + new_input_anchor.replace_all_uses_with(new_input_node) + new_graph.erase_node(new_input_anchor) + # update the output node (don't use new_output_anchor since it has been erased) + new_output_node = new_graph.output(new_input_node) + new_output_node.meta.update(new_input_node.meta) + new_graph.lint() + + # construct the subgraph + subgraph = Subgraph( + name="subgraph", graph_module=torch.fx.GraphModule(module, new_graph) + ) + + # two cases + # (1) (subgraph(A @ B) @ C), called "left_assoc" + # (2) (A @ subgraph(B @ C)), called "right_assoc" + is_left_assoc = outer_mm.args[0] is f_node + + # find the nodes A, B, C and check the sizes + A: torch.fx.Node + B: torch.fx.Node + C: torch.fx.Node + if is_left_assoc: + A = inner_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[1] # type: ignore[assignment] + C = outer_mm.args[1] # type: ignore[assignment] + else: + A = outer_mm.args[0] # type: ignore[assignment] + B = inner_mm.args[0] # type: ignore[assignment] + C = inner_mm.args[1] # type: ignore[assignment] + if not is_b2b_gemm_good_on(is_left_assoc, A, B, C): + return + + # finally update the original graph + counters["inductor"]["b2b_gemm"] += 1 + graph = match.graph + with graph.inserting_before(outer_mm): + function = functools.partial(tuned_b2b_gemm, is_left_assoc, subgraph) + function.__name__ = tuned_b2b_gemm.__name__ # type: ignore[attr-defined] + function._inductor_lowering_function = True # type: ignore[attr-defined] + replacement: torch.fx.Node = graph.call_function( + function, + (A, B, C), + match.kwargs, + ) + replacement.meta.update(outer_mm.meta) + outer_mm.replace_all_uses_with(replacement) + # erase unnecessary nodes + graph.erase_node(outer_mm) + for node in reversed(subgraph_node_list): + graph.erase_node(node) + graph.lint() diff --git 
a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py new file mode 100644 index 0000000000000000000000000000000000000000..2002bbe9f7dabfc9697eb89f23202b4331813e47 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/binary_folding.py @@ -0,0 +1,503 @@ +# mypy: allow-untyped-defs +import functools +import itertools + +import torch + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import Arg, CallFunction, KeywordArg +from .freezing_patterns import register_binary_folding_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims + + +def mark_mixed_dtype(computation_node): + computation_node_dtype = computation_node.meta["val"].dtype + if computation_node_dtype not in (torch.float16, torch.bfloat16): + return + + if not len(computation_node.users) == 1: + return + + computation_node_user = next(iter(computation_node.users.keys())) + if not isinstance(computation_node_user.meta["val"], torch.Tensor): + return + + if not computation_node_user.meta["val"].dtype == torch.float32: + return + + while computation_node_user.target in _binary_ops: + if not len(computation_node_user.users) == 1: + return + + computation_node_user = next(iter(computation_node_user.users.keys())) + + if computation_node_user.target != prims.convert_element_type.default: + return + + computation_node.meta["_allow_mixed_dtype_folding"] = computation_node_dtype + + +def mark_mixed_dtype_allowed_computation_ops(gm): + """ + Mark convolutions/linear which we will binary fold even with mixed precision constants. We constant fold in the higher precision + for better accuracy and then recover the original precision after. + """ + for target in [aten.convolution.default, aten.addmm.default, aten.mm.default]: + for node in gm.graph.find_nodes(op="call_function", target=target): + mark_mixed_dtype(node) + + +def recover_original_precision_folded_computation_ops(gm): + """ + After binary folding conv/linear weights and biases to a higher dtype, recover the original precision they were in. 
+ """ + graph = gm.graph + for target, idx in ( + (aten.convolution.default, (1, 2)), + (aten.addmm.default, (0, 2)), + (aten.mm.default, (1,)), + ): + for node in graph.find_nodes(op="call_function", target=target): + orig_dtype = node.meta.get("_allow_mixed_dtype_folding", None) + if orig_dtype is None: + continue + + with graph.inserting_before(node): + for i in idx: + old_input = node.args[i] + if old_input is None: + continue + + new_input = graph.create_node( + "call_function", + prims.convert_element_type.default, + (old_input, orig_dtype), + ) + node.replace_input_with(old_input, new_input) + + +_binary_ops = [aten.add.Tensor, aten.sub.Tensor, aten.mul.Tensor, aten.div.Tensor] + + +@functools.cache +def binary_folding_init(): + _conv_args = [Arg() for _ in range(9)] + _addmm_args = [Arg() for _ in range(3)] + _mm_args = [Arg() for _ in range(2)] + _computation_ops = [aten.convolution.default, aten.addmm.default, aten.mm.default] + _computation_calls = [ + CallFunction(aten.convolution.default, *_conv_args, _users=1), + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.addmm.default, *_addmm_args, _users=1), + Arg(), + _users=1, + ), + CallFunction(aten.mm.default, *_mm_args, _users=1), + CallFunction( + aten.reshape.default, + CallFunction(aten.mm.default, *_mm_args, _users=1), + Arg(), + _users=1, + ), + ] + + """ + In order to fuse add/sub/mul/div with conv/linear, the dimensions of its + constant tensor must satisfy the following: + - with resizing, broadcastable to the weight/bias tensor shape + - broadcastable to the conv/linear output shape + It needs to have a shape that can resize to the weight/bias + tensor shape because we need to run the op with the conv/linear + weights/bias without changing their sizes. + It needs to broadcast to the conv/linear output shape so that we do + not accidentally change the shape of the op output by pre-fusing it + compared to eager. + The only dimension value shared by weight, bias, and conv/linear output + is that they all contain a dim with value = channels-out. In the + conv/linear output tensor, this is in the second dimension, + so the pointwise op tensor may have a second dimension of + value == channels-out, but all the other dimensions have to be 1 + """ + + def _op_not_broadcasting_with_conv(weight_tensor, other_tensor): + # According to opDoesNotBroadCastWithConv of frozen_conv_folding.cpp + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + if len(weight_shape) < len(other_shape): + return False + if len(weight_shape) == len(other_shape) + 1: + # weight shape is [o, i, *], other_shape is [o, 1...].
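# The algebra this shape check is protecting: a constant whose only non-1
# dimension is channels-out can be folded into the convolution weights ahead
# of time. A small numerical sketch with shapes invented for illustration:
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)    # [channels_out, channels_in, kH, kW]
scale = torch.randn(4, 1, 1)   # broadcasts against the [N, 4, H, W] conv output

eager = F.conv2d(x, w) * scale
folded = F.conv2d(x, w * scale.reshape(4, 1, 1, 1))
torch.testing.assert_close(eager, folded)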
+ for i in reversed(range(len(other_shape))): + if i == 0 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + else: + # weight shape is [o, i, *], other_shape is [1, i, *] + for i in reversed(range(len(other_shape))): + if i == 1 and weight_shape[0] == other_shape[i]: + continue + if other_shape[i] != 1: + return False + return True + + def _op_not_broadcasting_with_linear(weight_tensor, other_tensor, has_reshape): + weight_shape = weight_tensor.shape + other_shape = other_tensor.shape + other_shapes = [ + torch.Size( + [ + weight_shape[1], + ] + ), + torch.Size([1, weight_shape[1]]), + torch.Size( + [ + 1, + ] + ), + torch.Size([1, 1]), + ] + if has_reshape: + other_shapes.extend( + [ + torch.Size([1, 1, weight_shape[1]]), + torch.Size([1, 1, 1]), + ] + ) + return other_shape in other_shapes + + def _check_conv_and_broadcast_op(conv_node, other): + # According to checkConvAndBroadcastingOpPreConditions of frozen_conv_folding.cpp. + # conv.weight + if conv_node.args[1].op != "get_attr": + return False + # conv.bias + if conv_node.args[1] is not None and conv_node.args[1].op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(conv_node.args[1].users) == 1: + return False + + weight_meta_value = conv_node.args[1].meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + != weight_meta_value.dtype + ): + if not conv_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_conv(weight_meta_value, other_meta_value): + return False + elif not isinstance(other, float): + return False + + return True + + def _check_linear_and_broadcast_op(linear_node, other, has_reshape): + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + if weight_node.op != "get_attr": + return False + if bias_node is not None and bias_node.op != "get_attr": + return False + if ( + not isinstance(other, int) + and not isinstance(other, float) + and other.op != "get_attr" + ): + return False + + if not len(weight_node.users) == 1: + return False + + weight_meta_value = weight_node.meta.get("val") + if weight_meta_value is None: + return False + # Avoid fusing op that causes type promotion + # restricting to float avoids int/float difficulties with scalar overload + if not weight_meta_value.is_floating_point(): + return False + if isinstance(other, torch.fx.Node) and other.op == "get_attr": + other_meta_value = other.meta.get("val") + if not other_meta_value.is_floating_point(): # type: ignore[union-attr] + return False + if ( + torch.promote_types(other_meta_value.dtype, weight_meta_value.dtype) # type: ignore[union-attr] + 
!= weight_meta_value.dtype + ): + if not linear_node.meta.get("_allow_mixed_dtype_folding", False): + return False + + if ( + other_meta_value.dtype != torch.float # type: ignore[union-attr] + and weight_meta_value.dtype not in (torch.float16, torch.bfloat16) + ): + return False + + if not _op_not_broadcasting_with_linear( + weight_meta_value, other_meta_value, has_reshape + ): + return False + elif not isinstance(other, float): + return False + + return True + + def _is_foldable_pattern(match): + binary_node = match.output_node() + has_reshape = False + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + other = binary_node.args[1] + elif binary_node.args[0].target == aten.reshape.default: + computation_node = binary_node.args[0].args[0] + other = binary_node.args[1] + has_reshape = True + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + other = binary_node.args[0] + else: + computation_node = binary_node.args[1].args[0] + other = binary_node.args[0] + has_reshape = False + if computation_node.target == aten.convolution.default: + return _check_conv_and_broadcast_op(computation_node, other) + elif computation_node.target in [aten.addmm.default, aten.mm.default]: + return ( + config.enable_linear_binary_folding + and _check_linear_and_broadcast_op(computation_node, other, has_reshape) + ) + + return False + + def resize_scalar_or_tensor_to_shape(graph, other, shape, weight): + if isinstance(other, float): + with torch.utils._python_dispatch._disable_current_modes(): + other_tensor = torch.tensor( + other, dtype=weight.dtype, device=weight.device + ) + graph.owning_module.register_buffer("other_tensor", other_tensor) + res = graph.create_node("get_attr", "other_tensor") + res = graph.create_node( + "call_function", + aten.reshape.default, + (res, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + elif other.meta.get("val").numel() == 1: + # expand errors if the shape input has less # dims than the tensor input + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, (1,)), + ) + res = graph.create_node( + "call_function", + aten.expand.default, + (res, shape), + ) + else: + res = graph.create_node( + "call_function", + aten.reshape.default, + (other, shape), + ) + return res + + def _create_new_conv_node(graph, conv_node, binary_node, other): + assert conv_node.target == aten.convolution.default + conv_args = list(conv_node.args) + weight_meta_value = conv_node.args[1].meta.get("val") + bias = conv_args[2] + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", + binary_node.target, + (0 if bias is None else bias, other_reshape), + ) + conv_args[2] = new_bias + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1 for _ in range(len(weight_meta_value.shape))] + weight_broadcast_shape[0] = weight_meta_value.size(0) + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight = graph.create_node( + "call_function", binary_node.target, (conv_args[1], other_reshape1) + ) + new_weight.meta.update(conv_args[1].meta) + conv_args[1] = new_weight + if bias is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + 
other, + (weight_meta_value.size(0),), + weight_meta_value, + ) + new_bias = graph.create_node( + "call_function", binary_node.target, (bias, other_reshape) + ) + new_bias.meta.update(bias.meta) + conv_args[2] = new_bias + return graph.create_node("call_function", conv_node.target, tuple(conv_args)) + + def _create_new_linear_node(graph, linear_node, binary_node, other): + assert linear_node.target in [aten.addmm.default, aten.mm.default] + input_node = ( + linear_node.args[1] + if linear_node.target is aten.addmm.default + else linear_node.args[0] + ) + weight_node = ( + linear_node.args[2] + if linear_node.target is aten.addmm.default + else linear_node.args[1] + ) + bias_node = ( + linear_node.args[0] if linear_node.target is aten.addmm.default else None + ) + weight_meta_value = weight_node.meta.get("val") + if binary_node.target in [aten.add.Tensor, aten.sub.Tensor]: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", + binary_node.target, + (0 if bias_node is None else bias_node, other_reshape), + ) + return graph.create_node( + "call_function", + aten.addmm.default, + (new_bias_node, input_node, weight_node), + ) + else: + assert binary_node.target in [aten.mul.Tensor, aten.div.Tensor] + weight_broadcast_shape = [1, weight_meta_value.size(1)] + other_reshape1 = resize_scalar_or_tensor_to_shape( + graph, + other, + tuple(weight_broadcast_shape), + weight_meta_value, + ) + new_weight_node = graph.create_node( + "call_function", binary_node.target, (weight_node, other_reshape1) + ) + new_weight_node.meta.update(weight_node.meta) + if bias_node is not None: + other_reshape = resize_scalar_or_tensor_to_shape( + graph, + other, + (weight_meta_value.size(1),), + weight_meta_value, + ) + new_bias_node = graph.create_node( + "call_function", binary_node.target, (bias_node, other_reshape) + ) + new_bias_node.meta.update(bias_node.meta) + return graph.create_node( + "call_function", + linear_node.target, + (new_bias_node, input_node, new_weight_node), + ) + else: + return graph.create_node( + "call_function", linear_node.target, (input_node, new_weight_node) + ) + + for _computation_call, binary_op in itertools.product( + _computation_calls, _binary_ops + ): + + @register_binary_folding_pattern( + CallFunction(binary_op, _computation_call, KeywordArg("other")), + extra_check=_is_foldable_pattern, + ) + def folded_op(match, *args, **kwargs): + counters["inductor"]["binary_folding"] += 1 + other = kwargs.get("other") + binary_node = match.output_node() + reshape_node = None + if binary_node.args[0].target in _computation_ops: + computation_node = binary_node.args[0] + elif binary_node.args[0].target == aten.reshape.default: + computation_node = binary_node.args[0].args[0] + reshape_node = binary_node.args[0] + elif binary_node.args[1].target in _computation_ops: + computation_node = binary_node.args[1] + else: + computation_node = binary_node.args[1].args[0] + reshape_node = binary_node.args[1] + graph = match.graph + with graph.inserting_before(reshape_node if reshape_node else binary_node): + assert computation_node.target in _computation_ops + if computation_node.target == aten.convolution.default: + counters["inductor"]["binary_folding_conv"] += 1 + new_computation_node = _create_new_conv_node( + graph, computation_node, binary_node, other + ) + else: + new_computation_node = _create_new_linear_node( + graph, computation_node, binary_node, other + ) + 
new_computation_node.meta.update(computation_node.meta) + if reshape_node: + assert reshape_node.target == aten.reshape.default + computation_node.replace_all_uses_with(new_computation_node) + binary_node.replace_all_uses_with(reshape_node) + else: + binary_node.replace_all_uses_with(new_computation_node) + graph.erase_node(binary_node) + graph.erase_node(computation_node) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/ddp_fusion.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/ddp_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8db82c3732bb722b3ec3dc4173f41e0b3d1c8962 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/ddp_fusion.py @@ -0,0 +1,586 @@ +# Owner(s): ["oncall: distributed"] +import collections +import inspect +import logging +import math +import operator +from collections.abc import Generator +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch.fx as fx +from torch._dynamo.utils import counters +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + +from ..fx_utils import get_fake_args_kwargs +from ..virtualized import V + + +aten = torch.ops.aten +logger: logging.Logger = logging.getLogger("comm_fusion") + + +def move_block_after(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.append(node) + target_node = node + + +def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None: + for node in block: + target_node.prepend(node) + target_node = node + + +def call_function( + graph: fx.Graph, + target: Union[str, Callable[..., Any]], + args: Optional[tuple[fx.node.Argument, ...]] = None, + kwargs: Optional[dict[str, fx.node.Argument]] = None, +) -> fx.Node: + # We accept target as a str to avoid typing error as the type of + # a node.target is Union[str, Callable[..., Any]]. + # This also allows us to avoid writing check for every call. + if isinstance(target, str): + raise RuntimeError(f"Call function should not get a str target {target=}") + node = graph.call_function(target, args, kwargs) + _, args, kwargs = get_fake_args_kwargs(node) + with V.fake_mode: + node.meta["val"] = target(*args, **kwargs) + # node.meta["val"] may be a container. So we use tree_map here + # to recursively extract the tensor metadata. + node.meta["tensor_meta"] = tree_map( + _extract_tensor_metadata, (node.meta["val"],) + )[0] + return node + + +@dataclass(unsafe_hash=True) +class CommBlock: + shape: Union[torch.Size, list[torch.Size]] + node_list: list[fx.Node] + inputs: list[fx.Node] + wait_nodes: list[fx.Node] + comm_node: fx.Node + outputs: OrderedSet[fx.Node] + + +def get_comm_block(comm_node: fx.Node) -> Optional[CommBlock]: + """ + Given a collective node (e.g., allreduce), find out all the nodes belong to + this communication. + + Args: + comm_node(fx.Node): The target communication/collective node. + Returns: + The CommBlock that encapsulates the related nodes (e.g., wait_node) of + the given comm_node. 
+ """ + node_list = [] + wait_nodes = [] + inputs, _ = tree_flatten((comm_node.args, comm_node.kwargs)) + input_nodes = [inp for inp in inputs if isinstance(inp, fx.Node)] + # If the users of the wait node are among the following items, we consider them + # to be a part of the output. + intermediate_outputs = ("split", "reshape", "getitem", "detach", "alias") + + first_user = next(iter(comm_node.users)) + if ( + len(comm_node.users) == 1 + and first_user.target == torch.ops._c10d_functional.wait_tensor.default + ): + # Collective with only one output + node_list = [comm_node, first_user] + wait_nodes.append(first_user) + elif len(comm_node.users) > 1 and first_user.target == operator.getitem: + # Collective with more than one output + node_list.append(comm_node) + for user in comm_node.users: + if user.target != operator.getitem: + return None + if len(user.users) != 1: + return None + wait_node = next(iter(user.users)) + if wait_node.target != torch.ops._c10d_functional.wait_tensor.default: + return None + wait_nodes.append(wait_node) + node_list.append(user) + node_list.extend(wait_nodes) + else: + return None + + # Identify all the outputs of this collective block. + outputs = OrderedSet[fx.Node]() + nodes = collections.deque(wait_nodes) + while nodes: + node = nodes.popleft() + for user in node.users: + if isinstance(user, fx.Node) and user.name.startswith(intermediate_outputs): + nodes.append(user) + node_list.append(user) + else: + outputs.add(node) + break + + tensor_meta = input_nodes[0].meta["tensor_meta"] + shape: Union[torch.Size, list[torch.Size]] + if isinstance(tensor_meta, TensorMetadata): + shape = tensor_meta.shape + elif isinstance(tensor_meta, (list, tuple)): + shape = [tm.shape for tm in tensor_meta] + else: + logger.warning("Unexpected type of tensor_meta %s", type(tensor_meta)) + return None + + return CommBlock( + shape=shape, + node_list=node_list, + wait_nodes=wait_nodes, + comm_node=comm_node, + inputs=input_nodes, + outputs=outputs, + ) + + +def get_all_comm_blocks( + graph: fx.Graph, + comm_ops: tuple[torch._ops.OpOverload, ...], + comm_filter: Optional[Callable[..., bool]] = None, +) -> list[CommBlock]: + if comm_filter is None: + + def always_true(comm_block: CommBlock) -> bool: + return True + + comm_filter = always_true + + blocks = [] + for node in graph.nodes: + if node.target not in comm_ops: + continue + comm_block = get_comm_block(node) + if comm_block is not None and comm_filter(comm_block): + blocks.append(comm_block) + return blocks + + +def _fuse_allreduce_by_concat( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce using concat.""" + # Flatten all the inputs to the all_reduce nodes. + with graph.inserting_after(last_input_node): + cat_inputs = [] + for input_node in all_input_nodes: + assert isinstance(input_node.args[0], fx.Node) + input_node = input_node.args[0] + cat_inputs.append( + call_function(graph, aten.flatten.using_ints, (input_node,)) + ) + + # Concat all the flattened nodes. + with graph.inserting_after(cat_inputs[0]): + cat_node = call_function(graph, aten.cat, (cat_inputs,)) + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion.
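# Why a single fused division is sufficient: the divisors of the original
# all_reduce inputs are asserted to be identical below, and dividing the
# concatenated buffer once is elementwise-equivalent to dividing each gradient
# separately. A small eager sketch (world_size=4 is an illustrative value):
import torch

world_size = 4
grads = [torch.randn(6), torch.randn(3)]
per_tensor = torch.cat([g / world_size for g in grads])
fused = torch.cat(grads) / world_size
torch.testing.assert_close(per_tensor, fused)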
+ divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_after(cat_node): + div_node = call_function(graph, last_input_node.target, (cat_node, divisors[0])) + + # Create a new Comm/all_reduce node. + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + with graph.inserting_after(div_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = div_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function(graph, last_comm_node.target, args, kwargs) + + # Create a new Wait node. + with graph.inserting_after(fused_comm_node): + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + flatten_args[0] = fused_comm_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_wait_node = call_function(graph, last_wait_node.target, args, kwargs) + + # Move the fused all_reduce and its args to right after the input node + nodes_to_move = cat_inputs + [cat_node, div_node, fused_comm_node, fused_wait_node] + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=cast(TensorMetadata, cat_node.meta.get("tensor_meta")).shape, + node_list=[fused_comm_node, fused_wait_node], + wait_nodes=[fused_wait_node], + comm_node=fused_comm_node, + inputs=[div_node], + outputs=OrderedSet([fused_wait_node]), + ) + + +def _fuse_with_coalesced_op( + graph: fx.Graph, + last_input_node: fx.Node, + all_input_nodes: list[fx.Node], + last_comm_block: CommBlock, +) -> CommBlock: + """Given a list of inputs in order, create a fused allreduce by coalesced.""" + last_comm_node = last_comm_block.comm_node + last_wait_node = last_comm_block.wait_nodes[0] + + # Insert the fused div node and remove the input div nodes. + # This is an optimization and is not mandatory for fusion. + dividends = [div.args[0] for div in all_input_nodes] + divisors = [div.args[1] for div in all_input_nodes] + assert all(divisor == divisors[0] for divisor in divisors) + with graph.inserting_before(last_input_node): + last_input_node = call_function( + graph, aten._foreach_div.Scalar, (dividends, divisors[0]) + ) + input_node = last_input_node + + # Create a new Comm/all_reduce_coalesced node. + with graph.inserting_after(last_comm_node): + flatten_args, spec = tree_flatten((last_comm_node.args, last_comm_node.kwargs)) + flatten_args[0] = input_node + args, kwargs = tree_unflatten(flatten_args, spec) + fused_comm_node = call_function( + graph, torch.ops._c10d_functional.all_reduce_coalesced.default, args, kwargs + ) + + # Create a new wait node. 
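+    # all_reduce_coalesced returns one result per input tensor, so each
+    # original all_reduce gets its own getitem on the fused node followed by
+    # a wait, mirroring the (comm, wait) pair it replaces.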
+ getitem_nodes = [] + wait_nodes = [] + flatten_args, spec = tree_flatten((last_wait_node.args, last_wait_node.kwargs)) + for idx in range(len(all_input_nodes)): + with graph.inserting_after(fused_comm_node): + gi_node = call_function(graph, operator.getitem, (fused_comm_node, idx)) + getitem_nodes.append(gi_node) + flatten_args[0] = gi_node + args, kwargs = tree_unflatten(flatten_args, spec) + with graph.inserting_after(gi_node): + wait_nodes.append(call_function(graph, last_wait_node.target, args, kwargs)) + + # Move the new all_reduce_coalesced and its args to right after the input node + nodes_to_move = [fused_comm_node] + getitem_nodes + wait_nodes + move_block_after(nodes_to_move, last_input_node) + + return CommBlock( + shape=[ + tm.shape + for tm in cast( + list[TensorMetadata], fused_comm_node.meta.get("tensor_meta") + ) + ], + node_list=[fused_comm_node] + getitem_nodes + wait_nodes, + wait_nodes=wait_nodes, + comm_node=fused_comm_node, + inputs=[input_node], + outputs=OrderedSet(wait_nodes), + ) + + +def _scatter_fused_allreduce_waits( + graph: fx.Graph, + fused_comm_block: CommBlock, + orig_comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + split_and_reshape: bool = True, +) -> None: + """ + Scatters the result of the fused communication node to the original users. + If the fused method is concat splitting the output and reshape will be inserted, + before inserting getitem. Otherwise getitem will be used as the users of the + wait node. + """ + + # Before we mass up the order, we need to get the index of the last wait node + # in orig_comm_blocks. This index will be later used to determine what users + # nodes need to be move to maintain a correct topological sort order. + last_wait_node_idx = 0 + for node in graph.nodes: + last_wait_node_idx = max( + node_indices.get(node, last_wait_node_idx), last_wait_node_idx + ) + if node == orig_comm_blocks[-1].wait_nodes[0]: + break + + if split_and_reshape: + fused_wait_node = fused_comm_block.wait_nodes[0] + with graph.inserting_after(fused_wait_node): + split_node = call_function( + graph, + aten.split, + ( + fused_wait_node, + [math.prod(cast(list[int], cb.shape)) for cb in orig_comm_blocks], + ), + ) + with graph.inserting_after(split_node): + fused_outputs = [] + for idx, comm_block in enumerate(orig_comm_blocks): + split_idx_node = call_function( + graph, operator.getitem, (split_node, idx) + ) + with graph.inserting_after(split_idx_node): + fused_outputs.append( + call_function( + graph, aten.reshape, (split_idx_node, comm_block.shape) + ) + ) + else: + fused_outputs = fused_comm_block.wait_nodes + + # Scatter the fused outputs. + incorrect_order_nodes = [] + for comm_block, fused_output in zip(orig_comm_blocks, fused_outputs): + # Some descendant users of the orig_comm_blocks may be scheduled before + # the fused all_reduce. For example, the user nodes of the very first + # all_reduce may be scheduled before the second all_reduce. Since the + # fused all_reduce is inserted right after the last all_reudce, the + # order can be wrong. + # `incorrect_order_nodes` records these nodes. 
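+        # Walk the users of the original wait node breadth-first and record
+        # every user that was scheduled before the last original wait node;
+        # these nodes are moved to after the fused result at the end of this
+        # function to keep the graph topologically sorted.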
+ + orig_wait = comm_block.wait_nodes[0] + nodes = collections.deque(list(orig_wait.users)) + while nodes: + user_node = nodes.popleft() + if not isinstance(user_node, fx.Node): + continue + if node_indices[user_node] < last_wait_node_idx: + incorrect_order_nodes.append(user_node) + nodes.extend(list(user_node.users)) + + orig_wait.replace_all_uses_with(fused_output) + + last_fused_result = fused_outputs[0] + fused_outputs_set = OrderedSet(fused_outputs) + for node in graph.nodes: + if node in fused_outputs_set: + last_fused_result = node + + # Move the incorrect_order_nodes to right after the last fused_result. + incorrect_order_nodes = sorted( + incorrect_order_nodes, key=lambda node: node_indices[node] + ) + move_block_after(incorrect_order_nodes, last_fused_result) + + +def _fuse_allreduce( + graph: fx.Graph, + comm_blocks: list[CommBlock], + node_indices: dict[fx.Node, int], + use_concat: bool, +) -> CommBlock: + """Given a list of allreduce CommBlock, fuse the CommBlocks into one CommBlock.""" + + if len(comm_blocks) == 1: + return comm_blocks[0] + + # Find the last input node of all the CommBlocks. This node will be served + # as the inserting point of the new collective op. + last_input_node = comm_blocks[0].inputs[0] + last_input_index = -1 + all_input_nodes = [] + for comm_block in comm_blocks: + input_node = comm_block.inputs[0] + all_input_nodes.append(input_node) + index = node_indices[input_node] + if index >= last_input_index: + assert index != last_input_index + last_input_node = input_node + last_input_index = index + + if use_concat: + fused_comm_block = _fuse_allreduce_by_concat( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + else: + fused_comm_block = _fuse_with_coalesced_op( + graph, last_input_node, all_input_nodes, comm_blocks[-1] + ) + + _scatter_fused_allreduce_waits( + graph, fused_comm_block, comm_blocks, node_indices, split_and_reshape=use_concat + ) + + for comm_block in comm_blocks: + for wait in comm_block.wait_nodes: + graph.erase_node(wait) + graph.erase_node(comm_block.comm_node) + graph.eliminate_dead_code() + + return fused_comm_block + + +def _bucket_size_fusion( + graph: fx.Graph, comm_blocks: list[CommBlock], bucket_size_mb: int +) -> Generator[list[CommBlock], None, None]: + MB = 1024**2 + bucket_size = 1 * MB + bucket_cap_size = bucket_size_mb * MB + curr_size = 0 + curr_blocks = [] + + count = 0 + fuse_count = 0 + for i, block in enumerate(comm_blocks): + curr_blocks.append(block) + itemsize = block.comm_node.meta["tensor_meta"].dtype.itemsize + curr_size += cast(torch.Size, block.shape).numel() * itemsize + count += 1 + if curr_size < bucket_size and i != len(comm_blocks) - 1: + continue + + fuse_count += 1 + if torch.distributed.get_rank() == 0: + logger.info( + "DDP bucketing: block%d, count=%d, curr_size=%d, bucket_size=%d", + fuse_count, + count, + curr_size, + bucket_size, + ) + + # Set the debug counters + counters["inductor"]["ddp_buckets"] = fuse_count + yield curr_blocks + + bucket_size = bucket_cap_size + curr_blocks = [] + curr_size = 0 + count = 0 + + +def _fuse_ddp_communication( + graph: fx.Graph, algorithm_fn: Callable[..., Any], fusion_fn: Callable[..., Any] +) -> None: + for output in reversed(graph.nodes): + if output.op == "output": + break + + def ddp_reducer_filter(block: CommBlock) -> bool: + if ( + not isinstance(block.comm_node.args[0], fx.Node) + or block.comm_node.args[0].target != aten.div.Tensor + ): + return False + + if len(block.wait_nodes[0].users) != 1: + # gradient/wait node should only be used 
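+        # Find the wait node inside the block's node_list; the wait and every
+        # node after it are moved down to just before target_node, delaying
+        # the wait as late as possible so the collective can overlap with the
+        # compute scheduled in between.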
by one user + return False + + # Two cases: + # 1. gradient/wait node should be directly used by the output + # if gradient is None before bwd. + # 2. gradient/wait node should be directly used by copy_. + if ( + output not in block.wait_nodes[0].users + and next(iter(block.wait_nodes[0].users)).target != aten.copy_.default + ): + return False + + return True + + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops, comm_filter=ddp_reducer_filter) + node_indices = {node: i for i, node in enumerate(graph.nodes)} + + for block in algorithm_fn(graph, comm_blocks): + fusion_fn(graph, block, node_indices) + + +def fuse_ddp_with_coalesced_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=False), + ) + + +def fuse_ddp_with_concat_op(graph: fx.Graph, bucket_size_mb: int) -> None: + _fuse_ddp_communication( + graph, + partial(_bucket_size_fusion, bucket_size_mb=bucket_size_mb), + partial(_fuse_allreduce, use_concat=True), + ) + + +def schedule_comm_wait(graph: fx.Graph) -> None: + """ + Delay the execution of wait tensors of allreduce until its first user. + + This algorithm considers the intermediate users, like split, getitem, + of the wait node and schedule those intermediate users as well. + This will result in a better overlapping result. + """ + ops = ( + torch.ops._c10d_functional.all_reduce_.default, + torch.ops._c10d_functional.all_reduce.default, + torch.ops._c10d_functional.all_reduce_coalesced.default, + torch.ops._c10d_functional.all_reduce_coalesced_.default, + ) + comm_blocks = get_all_comm_blocks(graph, ops) + if not comm_blocks: + return + + # Find all the end users. + allreduce_users = OrderedSet[fx.Node]() + for allreduce in comm_blocks: + for output in allreduce.outputs: + allreduce_users.update(output.users) + + node_indices = {node: i for i, node in enumerate(graph.nodes)} + for allreduce in comm_blocks: + # Find the earliest/first user -- target_node. + assert len(allreduce.outputs) >= 1, ( + f"Found a allreduce that has zero outputs/users -- {allreduce}." + ) + # Initialize the target node to avoid typing issues. + target_node = next(iter(next(iter(allreduce.outputs)).users)) + target_node_index = 2**31 + for user in (user for output in allreduce.outputs for user in output.users): + index = node_indices[user] + if index < target_node_index: + target_node = user + target_node_index = index + + # Move wait nodes and all the subsequent nodes in the comm_block to + # before the first user -- target_node. 
+ wait_idx = -1 + for wait_idx, node in enumerate(allreduce.node_list): + if node == allreduce.wait_nodes[0]: + break + assert wait_idx >= 0 + move_block_before(allreduce.node_list[wait_idx:], target_node) + + +def fuse_ddp_communication( + graph: fx.Graph, passes: list[Union[Callable[..., None], str]], bucket_size_mb: int +) -> None: + for i, pa in enumerate(passes): + with GraphTransformObserver( + graph.owning_module, f"fuse_ddp_communication_pass_{i}" + ): + if isinstance(pa, str): + func = globals()[pa] + else: + func = pa + if "bucket_size_mb" in OrderedSet( + v.name for v in inspect.signature(func).parameters.values() + ): + func(graph, bucket_size_mb=bucket_size_mb) + else: + func(graph) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..caf384ec6c717f550d858e151d1034d7dea3d39f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/decompose_mem_bound_mm.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch import Tensor +from torch._dynamo.utils import counters, is_node_meta_valid +from torch.fx.experimental.symbolic_shapes import statically_known_true + +from .. import config +from ..pattern_matcher import Arg, CallFunction, Match, register_graph_pattern +from .split_cat import construct_pattern_matcher_pass + + +aten = torch.ops.aten +log = logging.getLogger(__name__) + +# TODO: need a better strategy for decomposing mm +MIN_FIRST_DIMENSION_DECOMPOSITION = 10240 +MAX_OTHER_DIMENSION_DECOMPOSITION = 32 + +min_first_dimension_decomposition = MIN_FIRST_DIMENSION_DECOMPOSITION +max_other_dimension_decomposition = MAX_OTHER_DIMENSION_DECOMPOSITION +if "decompose_mm_pass" in config.post_grad_fusion_options: + min_first_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("min_first_dimension_decomposition", MIN_FIRST_DIMENSION_DECOMPOSITION) + max_other_dimension_decomposition = config.post_grad_fusion_options[ + "decompose_mm_pass" + ].get("max_other_dimension_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION) + + +def check_device(a: Tensor, b: Tensor, device="cuda") -> bool: + return (a.device.type == b.device.type) and (b.device.type == device) + + +def realize_inputs(inputs: list[torch.fx.Node]): + for inp in inputs: + if isinstance(inp, torch.fx.node.Node): + inp.meta["inductor_realize_to_strides"] = True + + +def should_decompose_bmm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 3 or len(mat2.shape) != 3: + return False + if check_device(mat1, mat2, device="cuda"): + if mat1.shape[0] < min_first_dimension_decomposition: + return False + # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION + if (mat1.shape[1] < max_other_dimension_decomposition) + ( + mat1.shape[2] < max_other_dimension_decomposition + ) + (mat2.shape[2] < max_other_dimension_decomposition) < 2: + return False + return True + elif check_device(mat1, mat2, device="cpu"): + if mat1.shape[0] == 1 and mat2.shape[0] == 1: + return True + return False + + +def should_decompose_mm(mat1, mat2) -> bool: + if is_node_meta_valid(mat1) and is_node_meta_valid(mat2): + mat1 = mat1.meta["val"] + mat2 = mat2.meta["val"] + else: + return False + if len(mat1.shape) != 2 or len(mat2.shape) != 2: + return False + 
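+    # Heuristic (tunable via config.post_grad_fusion_options["decompose_mm_pass"]):
+    # only decompose memory-bound, tall-and-skinny matmuls. On CUDA the first
+    # dimension must be large (>= 10240 by default) and both of mat2's
+    # dimensions small (< 32 by default); on CPU only single-row matmuls with
+    # mat2 at most 128 x 512 are decomposed.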
return ( + check_device(mat1, mat2, device="cuda") + and statically_known_true(mat1.shape[0] >= min_first_dimension_decomposition) + and statically_known_true(mat2.shape[0] < max_other_dimension_decomposition) + and statically_known_true(mat2.shape[1] < max_other_dimension_decomposition) + ) or ( + check_device(mat1, mat2, device="cpu") + and statically_known_true(mat1.shape[0] == 1) + and statically_known_true(mat2.shape[0] <= 128) + and statically_known_true(mat2.shape[1] <= 512) + ) + + +def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]): + node = match.nodes[-1] + log.debug( + "Decompose %s with input shape: %s", + node.target, + ", ".join( + str(input.meta["val"].shape) if "val" in input.meta else "None" + for input in inputs + ), + ) + + +@register_graph_pattern( + CallFunction(aten.bmm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_bmm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, :, None] * mat2[:, None, :, :], dim=-2).to( + mat1.dtype + ) + + if should_decompose_bmm(mat1, mat2): + counters["inductor"]["decompose_bmm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return + + +@register_graph_pattern( + CallFunction(aten.addmm, Arg(), Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_addmm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, + mat3: torch.fx.Node, +): + def repl(mat1, mat2, mat3): + return ( + torch.sum(mat2[:, :, None] * mat3[None, :, :], dim=-2).to(mat2.dtype) + mat1 + ) + + if should_decompose_mm(mat2, mat3): + counters["inductor"]["decompose_addmm"] += 1 + match.replace_by_example(repl, [mat1, mat2, mat3]) + print_decompose_pattern(match, [mat1, mat2, mat3]) + realize_inputs([mat1, mat2, mat3]) + return + + +@register_graph_pattern( + CallFunction(aten.mm, Arg(), Arg()), + pass_dict=construct_pattern_matcher_pass("decompose_mm_pass"), +) +def decompose_mm( + match: Match, + mat1: torch.fx.Node, + mat2: torch.fx.Node, +): + def repl(mat1, mat2): + return torch.sum(mat1[:, :, None] * mat2[None, :, :], dim=-2).to(mat1.dtype) + + if should_decompose_mm(mat1, mat2): + counters["inductor"]["decompose_mm"] += 1 + match.replace_by_example(repl, [mat1, mat2]) + print_decompose_pattern(match, [mat1, mat2]) + realize_inputs([mat1, mat2]) + return diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py new file mode 100644 index 0000000000000000000000000000000000000000..fdafd6abec50eaccffc73bc83ff03afddd454d0f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/dedupe_symint_uses.py @@ -0,0 +1,81 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Any, Union + +import torch +from torch import SymBool, SymFloat, SymInt +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet + + +@dataclass +class _SymExprHash: + """ + Hash for a py_sym_types that will use the underlying sympy expression + """ + + sym_obj: Union[SymInt, SymFloat, SymBool] + + def __hash__(self) -> int: + return hash((type(self.sym_obj), self.sym_obj.node.expr)) + + def __eq__(self, value) -> bool: + if not isinstance(value, _SymExprHash): + return False + return self.sym_obj.node.expr == value.sym_obj.node.expr + + +class _SymHashingDict: + """ 
+ Wrapper around a dictionary that will convert sym types to hash with _SymExprHash and reuse + existing sym proxies. + + SymPy hash is not always reliable so optimistically hash sympy expression, and if those fail, + fallback to symnodes. + """ + + def __init__(self): + self.sym_hash_dict = {} + + def __setitem__(self, key, value): + self.sym_hash_dict.__setitem__(self._wrap_to_sym_expr_hash(key), value) + + def __getitem__(self, key): + return self.sym_hash_dict[self._wrap_to_sym_expr_hash(key)] + + def __contains__(self, key): + return self._wrap_to_sym_expr_hash(key) in self.sym_hash_dict + + def get(self, key, default=None): + return self.sym_hash_dict.get(self._wrap_to_sym_expr_hash(key), default) + + def _wrap_to_sym_expr_hash(self, key): + return _SymExprHash(key) if isinstance(key, py_sym_types) else key + + +def dedupe_symints(graph: torch.fx.Graph): + """ + Dedupes sym ints in the graph to nodes are resolvable to symint graph inputs. + + We only dedupe from graph inputs to avoid adding a potential dependency in the forward + from the backward. + + """ + + sym_dict = _SymHashingDict() + resolvable_from_input_symints = OrderedSet[Any]() + + for node in graph.nodes: + val = node.meta.get("val", None) + if val is None or not isinstance(val, py_sym_types): + continue + + if node.op == "placeholder": + resolvable_from_input_symints.add(node) + sym_dict[val] = node + elif existing_node := sym_dict.get(val): + node.replace_all_uses_with(existing_node) + graph.erase_node(node) + elif all(n in resolvable_from_input_symints for n in node.all_input_nodes): + sym_dict[val] = node + resolvable_from_input_symints.add(node) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..b5bef51346b0554fa50caf65747e16d77b49b57e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/efficient_conv_bn_eval.py @@ -0,0 +1,409 @@ +# mypy: allow-untyped-defs +import torch +import torch.nn as nn +from torch._dynamo.utils import counters +from torch._inductor import config as inductor_config +from torch.func import functional_call + +from ..pattern_matcher import ( + CallFunctionVarArgs, + CallModuleVarArgs, + Match, + register_graph_pattern, +) +from .pre_grad import efficient_conv_bn_eval_pass + + +def efficient_conv_bn_eval( + bn: nn.modules.batchnorm._BatchNorm, conv: nn.modules.conv._ConvNd, x: torch.Tensor +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + bn (nn.modules.batchnorm._BatchNorm): a BatchNorm module. + conv (nn.modules.conv._ConvNd): a conv module + x (torch.Tensor): Input feature map. 
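+
+    Per output channel, the folding below computes::
+
+        coeff   = bn.weight / sqrt(bn.running_var + bn.eps)
+        weight' = conv.weight * coeff      # coeff broadcast along C_out
+        bias'   = bn.bias + coeff * (conv.bias - bn.running_mean)
+
+    so ``bn(conv(x))`` equals a single conv with ``weight'`` and ``bias'``
+    (a missing conv bias or missing bn affine parameters fall back to
+    zeros/ones, as handled below).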
+ """ + + assert bn.running_var is not None + assert bn.running_mean is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv.weight + if conv.bias is not None: + bias_on_the_fly = conv.bias + else: + bias_on_the_fly = torch.zeros_like(bn.running_var) + + if bn.weight is not None: + bn_weight = bn.weight + else: + bn_weight = torch.ones_like(bn.running_var) + + if bn.bias is not None: + bn_bias = bn.bias + else: + bn_bias = torch.zeros_like(bn.running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv.weight.ndim - 1) + if isinstance(conv, nn.modules.conv._ConvTransposeNd): + # for transposed conv, the C_out dimension should at index 1. + target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn.running_var + bn.eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn.running_mean + ) + + input = x + params = {"weight": weight_on_the_fly, "bias": bias_on_the_fly} + output = functional_call(conv, params, input) + return output + + +def efficient_conv_bn_eval_decomposed( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv: torch._ops.OpOverload, + conv_weight, + conv_bias, + x, + conv_remainging_args, +): + """ + Implementation based on https://arxiv.org/abs/2305.11624 + "Efficient ConvBN Blocks for Transfer Learning and Beyond" + It leverages the associative law between convolution and affine transform, + i.e., normalize (weight conv feature) = (normalize weight) conv feature. + It works for Eval mode of ConvBN blocks during validation, and can be used + for **training** as well, but only if one sets `bn.training=False`. It + reduces memory footprint and computation cost, at the cost of slightly + reduced numerical stability. + Args: + """ + assert bn_running_var is not None + + # These lines of code are designed to deal with various cases + # like bn without affine transform, and conv without bias + weight_on_the_fly = conv_weight + if conv_bias is not None: + bias_on_the_fly = conv_bias + else: + bias_on_the_fly = torch.zeros_like(bn_running_var) + + if bn_weight is not None: + bn_weight = bn_weight + else: + bn_weight = torch.ones_like(bn_running_var) + + if bn_bias is not None: + bn_bias = bn_bias + else: + bn_bias = torch.zeros_like(bn_running_var) + + # shape of [C_out, 1, 1, 1] in Conv2d + target_shape = [-1] + [1] * (conv_weight.ndim - 1) + if "conv_transpose" in conv.__str__(): + # for transposed conv, the C_out dimension should at index 1. 
+ target_shape[:2] = [target_shape[1], target_shape[0]] + weight_coeff = torch.rsqrt(bn_running_var + bn_eps).reshape(target_shape) + # shape of [C_out, 1, 1, 1] in Conv2d + coefff_on_the_fly = bn_weight.view_as(weight_coeff) * weight_coeff + + # shape of [C_out, C_in, k, k] in Conv2d + weight_on_the_fly = weight_on_the_fly * coefff_on_the_fly + # shape of [C_out] in Conv2d + bias_on_the_fly = bn_bias + coefff_on_the_fly.flatten() * ( + bias_on_the_fly - bn_running_mean + ) + + input = x + return conv(*((input, weight_on_the_fly, bias_on_the_fly) + conv_remainging_args)) + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.nn.functional.batch_norm, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_inlined(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 8 + + # We can only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-3]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch._C._nn.linear, + torch.conv1d, + torch.conv2d, + torch.conv3d, + torch.conv_transpose1d, + torch.conv_transpose2d, + torch.conv_transpose3d, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_running_mean = bn_node.args[1] + bn_running_var = bn_node.args[2] + bn_weight = bn_node.args[3] + bn_bias = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallFunctionVarArgs( + [ + torch.ops.aten.batch_norm.default, + ] + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform_decomposed(match: Match, *args, **kwargs): + bn_node = match.nodes[0] + graph = match.graph + assert len(bn_node.args) == 9 + + # We can 
only use efficient conv-bn for eval mode with track_running_stats + # bn_node.args is `training` + if bn_node.args[-4]: + return + + # Check if the input is Conv + input_node = bn_node.args[0] + + if input_node.op != "call_function": # type: ignore[union-attr] + return + + input_fn = input_node.target # type: ignore[arg-type, union-attr] + supported_convs = [ + torch.ops.aten.linear.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv_transpose1d.default, + torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, + ] + + if not any(input_fn is cls for cls in supported_convs): + return + + conv_node = input_node + # Output of conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(bn_node): + # prepare args for the fused function + bn_weight = bn_node.args[1] + bn_bias = bn_node.args[2] + bn_running_mean = bn_node.args[3] + bn_running_var = bn_node.args[4] + bn_eps = bn_node.args[7] + assert len(conv_node.args) >= 2 # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + conv_weight = conv_node.args[1] # type: ignore[union-attr] + conv_bias = conv_node.args[2] if len(conv_node.args) >= 3 else None # type: ignore[union-attr] + conv_remainging_args = conv_node.args[3:] # type: ignore[union-attr] + args = ( + bn_weight, + bn_bias, + bn_running_mean, + bn_running_var, + bn_eps, + conv_node.target, # type: ignore[union-attr] + conv_weight, + conv_bias, + conv_input, + conv_remainging_args, + ) + + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval_decomposed, + args=args, # type: ignore[arg-type] + name="efficient_conv_bn_eval", + ) + + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] + + return + + +@register_graph_pattern( + CallModuleVarArgs( + [ + nn.modules.batchnorm._BatchNorm, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.SyncBatchNorm, + ], + ), + pass_dict=efficient_conv_bn_eval_pass, + extra_check=lambda match: not inductor_config.freezing + and inductor_config.efficient_conv_bn_eval_fx_passes, +) +def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs): + # We matched a BN node + bn_node = match.nodes[0] + graph = match.graph + gm = graph.owning_module + bn_mod = getattr(gm, bn_node.target) # type: ignore[arg-type] + + # We can only use efficient conv-bn for eval mode with track_running_stats + if not bn_mod.track_running_stats or bn_mod.training: + return + + # Check if the input is Conv + if bn_node.args: + input_node = bn_node.args[0] + else: + input_node = bn_node.kwargs["input"] + if input_node.op != "call_module": # type: ignore[union-attr] + return + if not hasattr(gm, input_node.target): # type: ignore[arg-type, union-attr] + return + input_mod = getattr(gm, input_node.target) # type: ignore[arg-type, union-attr] + supported_convs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + ] + if not any(isinstance(input_mod, cls) for cls in supported_convs): + return + conv_node = input_node + # Output of 
conv is used by other nodes, cannot optimize + if len(conv_node.users) > 1: # type: ignore[union-attr] + return + + # Find a pair of conv and bn computation nodes to optimize. + counters["inductor"]["efficient_conv_bn_eval"] += 1 + + with graph.inserting_before(conv_node): # type: ignore[arg-type] + # create `get_attr` node to access modules + # note that we directly call `create_node` to fill the `name` + # argument. `graph.get_attr` and + # `graph.call_function` does not allow the `name` argument. + conv_get_node = graph.create_node( + op="get_attr", + target=conv_node.target, # type: ignore[union-attr] + name="get_conv", + ) + bn_get_node = graph.create_node( + op="get_attr", target=bn_node.target, name="get_bn" + ) + if conv_node.args: # type: ignore[union-attr] + conv_input = conv_node.args[0] # type: ignore[union-attr] + else: + conv_input = conv_node.kwargs["input"] # type: ignore[union-attr] + # prepare args for the fused function + args = (bn_get_node, conv_get_node, conv_input) + # create a new node + new_node = graph.create_node( + op="call_function", + target=efficient_conv_bn_eval, + args=args, + name="efficient_conv_bn_eval", + ) + # this node replaces the original conv + bn, and therefore + # should replace the uses of bn_node + bn_node.replace_all_uses_with(new_node) + # take care of the deletion order: + # delete bn_node first, and then conv_node + graph.erase_node(bn_node) + graph.erase_node(conv_node) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..e19d29b91ff2920145b3007a3209fe373570e361 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/freezing_patterns.py @@ -0,0 +1,297 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._inductor.compile_fx import fake_tensor_prop +from torch._inductor.utils import GPU_TYPES + +from ..._dynamo.utils import counters +from .. import config +from ..pattern_matcher import ( + _return_true, + CallFunction, + fwd_only, + Ignored, + init_once_fakemode, + KeywordArg, + Match, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) + + +aten = torch.ops.aten + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + +binary_folding_pass = PatternMatcherPass() + + +def freezing_passes(gm: torch.fx.GraphModule, aot_example_inputs): + """ + Passes that are applied to the graph to freeze pass. + """ + + from ..freezing import constant_fold + + lazy_init() + # We need a few rounds of binary folding to get rid of all the + # unnecessary nodes, but may need a good method to chose the rounds number. + # works like: conv+binary+binary. + binary_folding = counters["inductor"]["binary_folding"] + fake_tensor_prop(gm, aot_example_inputs, True) + + torch._inductor.fx_passes.binary_folding.mark_mixed_dtype_allowed_computation_ops( + gm + ) + for _ in range(4): + constant_fold(gm) + # Make sure meta['val'] is properly set for all nodes + fake_tensor_prop(gm, aot_example_inputs, True) + binary_folding_pass.apply(gm.graph) # type: ignore[arg-type] + # If we don't have binary folding, we don't need to run the pass again. + # TODO: remove the need to run fake_tensor_prop on the whole model. 
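+        # The binary_folding counter advances whenever the pass folds
+        # something; if it is unchanged after this round, a fixed point has
+        # been reached and the loop can stop early instead of running all
+        # four iterations.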
+ if counters["inductor"]["binary_folding"] == binary_folding: + break + binary_folding = counters["inductor"]["binary_folding"] + + torch._inductor.fx_passes.binary_folding.recover_original_precision_folded_computation_ops( + gm + ) + + constant_fold(gm) + fake_tensor_prop(gm, aot_example_inputs, True) + + for pattern in pass_patterns: + pattern.apply(gm.graph) # type: ignore[arg-type] + + # The CPU weight packing always assume the conv's weight is channels last, + # So make sure the layout_optimization is on when doing it. + if ( + torch._C._has_mkldnn + and config.cpp.weight_prepack + and config.layout_optimization + ): + from .mkldnn_fusion import _eliminate_duplicate_packed_nodes + + _eliminate_duplicate_packed_nodes(gm) + + stable_topological_sort(gm.graph) + gm.recompile() + gm.graph.lint() + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn and config.cpp.weight_prepack: + from .mkldnn_fusion import _mkldnn_weight_pack_init + + _mkldnn_weight_pack_init() + + from .binary_folding import binary_folding_init + + addmm_patterns_init() + binary_folding_init() + + +def register_freezing_graph_pattern(pattern, extra_check=_return_true, pass_number=0): + while pass_number > len(pass_patterns) - 1: + pass_patterns.append(PatternMatcherPass()) + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=pass_patterns[pass_number], + ) + + +def register_binary_folding_pattern(pattern, extra_check=_return_true): + return register_graph_pattern( + pattern, + extra_check=extra_check, + pass_dict=binary_folding_pass, + ) + + +@functools.cache +def addmm_patterns_init(): + """ + addmm related patterns. + To avoid duplication, also includes int8 WoQ GEMM pattern without bias. + """ + device = next( + (gpu for gpu in GPU_TYPES if getattr(torch, gpu).is_available()), "cpu" + ) + val = functools.partial(torch.empty, (10, 10), device=device, requires_grad=False) + scale = functools.partial(torch.empty, (10,), device=device, requires_grad=False) + + def check_int8_woq_concat_linear_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if not is_cpu or not config.cpp.enable_concat_linear: + # Currently, this pattern is only supported on CPU + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + if not all( + match.kwargs[wgt].target == torch.ops.prims.convert_element_type.default + for wgt in weight_inputs + ): + return False + + if not all( + next(iter(match.kwargs[wgt]._input_nodes.keys())).meta["val"].dtype + is torch.int8 + for wgt in weight_inputs + ): + return False + + if not all( + match.kwargs[wgt].meta["val"].dtype is torch.bfloat16 + for wgt in weight_inputs + ): + return False + + equal_shape_inputs = [weight_inputs] + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + if not all( + inp.meta["val"].shape == inps[0].meta["val"].shape for inp in inps + ): + return False + return True + + def check_concat_weights(match): + is_cpu = match.kwargs["inp"].meta["val"].is_cpu + if is_cpu and not config.cpp.enable_concat_linear: + return False + + weight_inputs = ["w1", "w2"] + if "w3" in match.kwargs: + weight_inputs.append("w3") + + equal_shape_inputs = [weight_inputs] + + if "b1" in match.kwargs: + bias_inputs = ["b1", "b2"] + if "b3" in match.kwargs: + bias_inputs.append("b3") + + equal_shape_inputs.append(bias_inputs) + + for equal_shape_group in equal_shape_inputs: + inps = [match.kwargs[name] for name in equal_shape_group] + + if not 
all( + inp.op == "get_attr" + and inp.meta["val"].shape == inps[0].meta["val"].shape + for inp in inps + ): + return False + return True + + def int8_woq_fusion_pattern(inp, w1, w2, w3, s1, s2, s3): + return ((inp @ w1) * s1, (inp @ w2) * s2, (inp @ w3) * s3) + + def int8_woq_fusion_replacement(inp, w1, w2, w3, s1, s2, s3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_s = torch.cat((s1, s2, s3), dim=0) + mm = (inp @ cat_w).mul(cat_s) + return mm.chunk(3, dim=1) + + register_replacement( + int8_woq_fusion_pattern, + int8_woq_fusion_replacement, + [val(), val(), val(), val(), scale(), scale(), scale()], + fwd_only, + pass_patterns[0], + extra_check=check_int8_woq_concat_linear_weights, + exclusive_arg_names=("w1", "w2", "w3", "s1", "s2", "s3"), + ) + + def matmul_fuse_pattern(inp, w1, w2, w3): + return (inp @ w1, inp @ w2, inp @ w3) + + def matmul_replacement(inp, w1, w2, w3): + cat_t = torch.cat((w1, w2, w3), dim=1) + mm = inp @ cat_t + return mm.chunk(3, dim=1) + + register_replacement( + matmul_fuse_pattern, + matmul_replacement, + [val(), val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3"), + ) + + def matmul_fuse_pattern_two(inp, w1, w2): + return (inp @ w1, inp @ w2) + + def matmul_replacement_two(inp, w1, w2): + cat_t = torch.cat((w1, w2), dim=1) + mm = inp @ cat_t + return mm.chunk(2, dim=1) + + register_replacement( + matmul_fuse_pattern_two, + matmul_replacement_two, + [val(), val(), val()], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2"), + ) + + def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): + return ( + aten.addmm(b1, inp, w1), + aten.addmm(b2, inp, w2), + aten.addmm(b3, inp, w3), + ) + + def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): + cat_w = torch.cat((w1, w2, w3), dim=1) + cat_b = torch.cat((b1, b2, b3)) + return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + + register_replacement( + addmm_fuse_pattern_second, + addmm_fuse_replacement_second, + [val() for _ in range(7)], + fwd_only, + pass_patterns[0], + extra_check=check_concat_weights, + exclusive_arg_names=("w1", "w2", "w3", "b1", "b2", "b3"), + ) + + +def same_dtype(match): + return match.output_node().args[0].meta["val"].dtype == match.kwargs["dtype"] + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + Ignored(), + KeywordArg("dtype"), + ), + pass_dict=pass_patterns[0], + extra_check=same_dtype, +) +def unnecessary_dtype_convert(match: Match, **kwargs): + """Remove unnecessary dtype conversion op, probably left as a result of Conv-Bn folding""" + graph = match.graph + node = match.output_node() + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + graph.erase_node(node) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ebcf5bc3d7c0a6eed1064b96c94c6733f015ef5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/fuse_attention.py @@ -0,0 +1,1089 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import logging +import math + +import torch + +from ..._dynamo.utils import counters +from ..pattern_matcher import ( + filter_nodes, + fwd_only, + gen_register_replacement, + joint_fwd_bwd, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +_scaled_dot_product_attention = 
aten.scaled_dot_product_attention + + +def _sfdp_pattern_1(query, key, value, inv_scale): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_1(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_2(query, key, value, scale_factor): + return ( + torch.matmul(query, key.transpose(-2, -1)) + .mul(scale_factor) + .softmax(dim=-1) + .matmul(value) + ) + + +def _sfdp_replacement_2(query, key, value, scale_factor): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)) + .div(inv_scale_factor) + .softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): + return torch.nn.functional.dropout( + torch.matmul(query, key.transpose(-2, -1)).mul(scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(value) + + +def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + +def _sfdp_pattern_5(query, key, value, attn_mask): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + # attn_weight = torch.dropout(attn_weight, dropout_p) + return attn_weight @ value + + +def _sfdp_replacement_5(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + ) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + return attn_weight @ value + + +def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_7(query, key, value, dropout_p): + # in real workloads inputs to matmul are permuted + # causing matmul to expand to a series of expand and clone calls + # we want the same to happen during pattern tracing + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def 
_sfdp_replacement_7(query, key, value, dropout_p): + # sdpa prefers inputs in permuted format + # it makes a copy to put them in this format + # if they aren't already + # to make replacement efficient ensure that inputs to sdpa + # are in required order + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_8(query, key, value): + # no dropout version of pattern 7 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_8(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_9(query, key, value, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, True) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_9(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=dropout_p, + is_causal=False, + ) + + +def _sfdp_pattern_10(query, key, value): + # no dropout version of 9 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + + +def _sfdp_replacement_10(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, # attn_mask, + dropout_p=0.0, + is_causal=False, + ) + + +def _sfdp_pattern_11(query, key, value, inv_scale): + # Mainly for huggingface models + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.matmul(q, k.transpose(-2, -1)).div(inv_scale).softmax(dim=-1).matmul(v) + + +def _sfdp_replacement_11(query, key, value, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_12(query, key, value, inv_scale_factor, dropout_p): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + return torch.nn.functional.dropout( + torch.matmul(q, k.transpose(-2, -1)).div(inv_scale_factor).softmax(dim=-1), + p=dropout_p, + ).matmul(v) + + 
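+# _sfdp_pattern_12 scales the attention scores by dividing by inv_scale_factor,
+# so the replacement below passes scale=1.0 / inv_scale_factor to the fused
+# scaled_dot_product_attention call; the permute(0, 2, 1, 3) calls in the
+# pattern are expressed as transpose(1, 2) on query/key/value, which is the
+# same layout change for 4-D inputs.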
+def _sfdp_replacement_12(query, key, value, inv_scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale_factor, + ) + + +def _sfdp_pattern_13(query, key, value, dropout_p): + attn_weight = torch.bmm(query, key.transpose(1, 2)).softmax(dim=-1) + attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p) + return torch.bmm(attn_weight, value) + + +def _sfdp_replacement_13(query, key, value, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + dropout_p=dropout_p, + scale=1.0, + ).squeeze(0) + + +def _sfdp_pattern_14(query, key, value, attn_mask, inv_scale): + # for BertLarge + # Permutations are needed to create clones in graph. + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask) + .softmax(dim=-1) + .matmul(v) + ) + + +def _sfdp_replacement_14(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_15(query, key, value, attn_mask, inv_scale): + # for DistilBert + # Permutations are needed to create clones in graph. + # Ref: https://github.com/pytorch/pytorch/issues/119911 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1) @ v + + +def _sfdp_replacement_15(query, key, value, attn_mask, inv_scale): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=0.0, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_16(query, key, value, attn_mask, inv_scale, dropout_p): + # for BertLarge with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + return ( + torch.nn.functional.dropout( + (torch.matmul(q, k.transpose(-2, -1)).div(inv_scale) + attn_mask).softmax( + dim=-1 + ), + dropout_p, + ) + .to(dtype=query.dtype) + .matmul(v) + ) + + +def _sfdp_replacement_16(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=query.dtype), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_17(query, 
key, value, attn_mask, inv_scale, dropout_p): + # for DistilBert with dropout + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + scores = q @ k.transpose(-2, -1) + scores = scores.div(inv_scale) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / inv_scale, + ) + + +def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): + # for hf_GPT2 with dropout (introduces clone node) for inference + # it also returns permuted key & value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + return ( + ( + torch.nn.functional.dropout(attn_weights.softmax(dim=-1), dropout_p).matmul( + value + ) + ), + key, + value, + ) + + +def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + permuted_key = key.transpose(1, 2) + permuted_value = value.transpose(1, 2) + return ( + _scaled_dot_product_attention( + query.transpose(1, 2), + permuted_key, + permuted_value, + attn_mask=causal_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ), + permuted_key, + permuted_value, + ) + + +def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): + # for token-classification+gpt2 / text-generation+gpt2 + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + [], + value.size(-1) ** 0.5, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + attn_weights = attn_weights + attn_mask + attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) + return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value) + + +def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = torch.where(causal_mask, attn_mask, fill_value) + return _scaled_dot_product_attention( + query, + key, + value, + 
attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(value.size(-1)), + ) + + +def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p): + # for DistilBert with dropout transformers==4.44.2 + q = query.permute([0, 2, 1, 3]) + k = key.permute([0, 2, 1, 3]) + v = value.permute([0, 2, 1, 3]) + bs = q.size(0) + k_len = k.size(-2) + q = q.div(math.sqrt(q.size(-1))) + scores = q @ k.transpose(-2, -1) + fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) + attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) + return ( + torch.nn.functional.dropout( + torch.softmax(scores.masked_fill(attn_mask, fill_value), dim=-1), dropout_p + ) + @ v + ) + + +def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + bs = query.size(0) + n_head = query.size(2) + q_len = query.size(1) + k_len = key.size(1) + # do attn_mask->logical_not() in _scaled_dot_product_attention + attn_mask = ( + (attn_mask == 1).view((bs, 1, 1, k_len)).expand((bs, n_head, q_len, k_len)) + ) + return _scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + attn_mask=attn_mask.to(dtype=torch.bool), + dropout_p=dropout_p, + is_causal=False, + scale=1.0 / math.sqrt(query.size(-1)), + ) + + +def _sfdp_pattern_21(query, key, value, attn_mask): + # for T5 with inplace add + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value) + + +def _sfdp_replacement_21(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ) + + +def _sfdp_pattern_22(query, key, value, attn_mask): + # for T5 with inplace add and return key and value + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = torch.matmul(query, key.permute(0, 1, 3, 2)) + masked_score = score + attn_mask + score = masked_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_22(query, key, value, attn_mask): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=attn_mask, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_pattern_23(query, key, value): + # for T5 with inplace add and + # return key and value and + # attn_mask is generated by atem.full(..., 0) + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + score = 
torch.matmul(query, key.permute(0, 1, 3, 2)) + fp32_score = score.float() + score = fp32_score.type_as(query) + viewd_score1 = score.view( + score.size(0) * score.size(1), score.size(2), score.size(3) + ) + viewd_score2 = viewd_score1.view( + score.size(0), score.size(1), score.size(2), score.size(3) + ) + return viewd_score2.float().softmax(dim=-1).type_as(query).matmul(value), key, value + + +def _sfdp_replacement_23(query, key, value): + counters["inductor"]["fuse_attention"] += 1 + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + return ( + _scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=False, + scale=1.0, + ), + key, + value, + ) + + +def _sfdp_params_check(match): + assert all(k in match.kwargs for k in ("query", "key", "value")) + query = match.kwargs["query"].meta["val"] + key = match.kwargs["key"].meta["val"] + value = match.kwargs["value"].meta["val"] + if not (query.dtype == key.dtype == value.dtype) or not ( + query.device == key.device == value.device + ): + return False + add_mask_node = filter_nodes(match.nodes, aten.add.Tensor) + # Has attn_mask add. + if len(add_mask_node) > 0: + attn_mask_node = add_mask_node[0].args[1] + # attn_mask_node may be a float/int number. + if not hasattr(attn_mask_node, "meta"): + return False + attn_mask = attn_mask_node.meta["val"] # type: ignore[union-attr] + # Make sure attn_mask.dtype == query.dtype or attn_mask.dtype == torch.bool + # attn_mask.dtype == torch.float for models like albert. + if ( + not isinstance(attn_mask, torch.Tensor) + or not ( + attn_mask.dtype == query.dtype + or attn_mask.dtype == torch.bool + or attn_mask.dtype == torch.float + ) + or query.device != attn_mask.device + # When we tensorify floats we end up turning floats + # into 0d scalar tensors. It doesn't make any sense + # to have a 0d scalar tensor attention mask so + # conveniently we can insert this check to get + # tests that erroneously passing in a float + # attention mask to fail as expected. + or attn_mask.dim() == 0 + ): + return False + return True + + +def _sfdp_extra_check(scale_factor_op=None, disable_cuda=False): + def fn(match): + if ( + disable_cuda + and "query" in match.kwargs + and "cuda" in str(match.kwargs["query"].meta["val"].device) + ): + return False + if scale_factor_op is not None: + scale_factor_node = filter_nodes(match.nodes, scale_factor_op)[0] + # Note: args[1] of the scale_factor_node is always the scale_factor for the current patterns. + scale_factor = scale_factor_node.args[1] + # make sure the scale_factor a float/int. SymInt? 
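+            # Rough illustration (assumed, not asserted by the patterns above): in a
+            # match of _sfdp_pattern_14 the scale_factor_node is the aten.div.Tensor
+            # produced by ``.div(inv_scale)``, so args[1] holds the model's inv_scale
+            # constant; a SymInt/SymFloat there would not pass the isinstance check below.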
+ if not isinstance(scale_factor, (float, int)): + return False + return _sfdp_params_check(match) + + return fn + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def _get_sfdp_patterns(): + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values don't actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + g_inp = functools.partial( + torch.empty, (2, 4, 8, 16), device=device, requires_grad=True + ) + # attn_mask + b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) + m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) + # need 2d attn_mask to generate patterns with view op + m_inp_2d = functools.partial(torch.empty, (2, 4), device=device) + # inv_scale + c_inp = functools.partial(torch.tensor, 2.0, device=device) + # workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + d = {"dropout_p": 0.113377} + + # we could also generate all these patterns in 3d.. TODO + g_3d_inp = functools.partial( + torch.empty, (1024, 128, 128), device=device, requires_grad=True + ) + + # reshape in matmul decomposition generates a clone when batch_size>1 due to the memory layout change. + # however when batch_size=1, reshape does not change the memory layout, so clone would not be generated. + # here we need to trace with input of batch_size=1 to generate a pattern graph without clone. 
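+    # Rough illustration of the clone/no-clone difference described above (assumed shapes):
+    # collapsing the first two dims of a permuted (2, 4, 8, 16) tensor, e.g.
+    # x.permute(0, 2, 1, 3).reshape(2 * 8, 4, 16), has to copy because the permuted layout
+    # is no longer contiguous, while the same reshape on a (1, 4, 8, 16) input can remain
+    # a view, so the two traced pattern graphs differ and both input shapes are used below.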
+ g_bs1_inp = functools.partial( + torch.empty, (1, 4, 8, 16), device=device, requires_grad=True + ) + m_bs1_inp = functools.partial(torch.empty, (1, 1, 1, 4), device=device) + + # softmax will generate a dtype conversion on inputs if they are in half, + # but will not in float, so we generate a pattern for both + for dtype in [torch.float, torch.half]: + g = functools.partial(g_inp, dtype=dtype) + b = functools.partial(b_inp, dtype=dtype) + b_float = functools.partial(b_inp, dtype=torch.float) + b_bool = functools.partial(b_inp, dtype=torch.bool) + m = functools.partial(m_inp, dtype=dtype) + m_float = functools.partial(m_inp, dtype=torch.float) + m_bool = functools.partial(m_inp, dtype=torch.bool) + m_2d = functools.partial(m_inp_2d, dtype=dtype) + c = functools.partial(c_inp, dtype=dtype) + g_3d = functools.partial(g_3d_inp, dtype=dtype) + g_bs1 = functools.partial(g_bs1_inp, dtype=dtype) + m_bs1 = functools.partial(m_bs1_inp, dtype=dtype) + m_bs1_float = functools.partial(m_bs1_inp, dtype=torch.float) + m_bs1_bool = functools.partial(m_bs1_inp, dtype=torch.bool) + + candidates = [ + ( + _sfdp_pattern_1, + _sfdp_replacement_1, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_2, + _sfdp_replacement_2, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_3, + _sfdp_replacement_3, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_4, + _sfdp_replacement_4, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), + ( + _sfdp_pattern_5, + _sfdp_replacement_5, + [g(), g(), g(), b()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_6, + _sfdp_replacement_6, + [g(), g(), g(), b()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_7, + _sfdp_replacement_7, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_8, + _sfdp_replacement_8, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_9, + _sfdp_replacement_9, + [g(), g(), g()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_10, + _sfdp_replacement_10, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_11, + _sfdp_replacement_11, + [g(), g(), g(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_12, + _sfdp_replacement_12, + [g(), g(), g(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_13, + _sfdp_replacement_13, + [g_3d(), g_3d(), g_3d()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_14, + _sfdp_replacement_14, + [g(), g(), g(), m(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_15, + _sfdp_replacement_15, + [g(), g(), g(), m_2d(), c()], + {}, + _sfdp_extra_check(aten.div.Tensor), + ), + # TODO: Enable CUDA after solving Bert accuracy issue of calling efficient attention + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ), + ( + _sfdp_pattern_17, + _sfdp_replacement_17, + [g(), g(), g(), m_2d(), c()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g(), g(), g(), m_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_18, + _sfdp_replacement_18, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_19, + 
_sfdp_replacement_19, + [g(), g(), g(), b_bool(), b_float()], + d, + _sfdp_params_check, + ), + ( + _sfdp_pattern_20, + _sfdp_replacement_20, + [g(), g(), g(), m_2d()], + d, + _sfdp_extra_check(aten.div.Tensor), + ), + ( + _sfdp_pattern_21, + _sfdp_replacement_21, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_22, + _sfdp_replacement_22, + [g(), g(), g(), m_float()], + {}, + _sfdp_params_check, + ), + ( + _sfdp_pattern_23, + _sfdp_replacement_23, + [g(), g(), g()], + {}, + _sfdp_params_check, + ), + ] + mask_fp32_patterns = ["pattern_16"] + if dtype == torch.half: + # Add inputs of bf16 q/k/v and fp32 mask, for models like albert. + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g(), g(), g(), m_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + candidates.append( + ( + _sfdp_pattern_16, + _sfdp_replacement_16, + [g_bs1(), g_bs1(), g_bs1(), m_bs1_float(), c()], + d, + _sfdp_extra_check(aten.div.Tensor, disable_cuda=True), + ) + ) + + for pattern, replacement, args, workaround, extra_check in candidates: + # XXX: when adding a new pattern, re-run `gen_attention_patterns` so the pattern + # gets serialized to a python file and does not require tracing at runtime. + assert isinstance(workaround, dict) + name = pattern.__name__ + + if dtype != torch.float: + name += "_half" + if ( + any(p in name for p in mask_fp32_patterns) + and args[3].dtype == torch.float32 + ): + name += "_mask_fp32" + if args[0].size(0) == 1: + name += "_bs1" + + training_name = name + "_training" + yield ( + training_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": joint_fwd_bwd, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + }, + ) + + if workaround: + assert len(workaround) == 1 and "dropout_p" in workaround + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) + workaround = {} + + inference_name = name + "_inference" + yield ( + inference_name, + { + "search_fn": pattern, + "replace_fn": replacement, + "example_inputs": args, + "trace_fn": fwd_only, + "pass_dicts": patterns, + "extra_check": extra_check, + "scalar_workaround": workaround, + # with dropout turned into clone, we end up with a number of + # semantically identical graphs + "skip_duplicates": True, + }, + ) + + +@functools.cache +def _sfdp_init(): + for key, register_replacement_kwargs in _get_sfdp_patterns(): + gen_register_replacement(key, **register_replacement_kwargs) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8b18221f422e9e800bfab959eb4dec271d764191 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/group_batch_fusion.py @@ -0,0 +1,1426 @@ +# mypy: allow-untyped-defs +import collections +import logging +import operator +from collections import OrderedDict +from collections.abc import Iterable, Iterator +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters, is_node_meta_valid +from torch._logging import trace_structured +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.utils._ordered_set import OrderedSet + 
+from .. import config +from ..pattern_matcher import ( + CallFunctionVarArgs, + get_arg_value, + stable_topological_sort, +) +from ..utils import OPTIMUS_EXCLUDE_POST_GRAD + + +try: + # importing this will register fbgemm lowerings for inductor + import deeplearning.fbgemm.fbgemm_gpu.fb.inductor_lowerings # noqa: F401 + + has_fbgemm = True +except Exception: + has_fbgemm = False + +aten = torch.ops.aten + +log = logging.getLogger(__name__) + +DEFAULT_BETA = 1 +DEFAULT_ALPHA = 1 + +MIN_FUSE_SET_SIZE = 5 +MAX_FUSE_SET_SIZE = 300 +MAX_FUSE_SEARCH_DEPTH = 5 +# The maximum tensor size that can go into the fusion group +MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR = 4096 +# Whether we only fuse nodes with same parent node +FUSE_NODES_WITH_SAME_PARENT = False +# Whether we enable the add broadcast in batch linear +SHAPE_BROADCAST_BATCH_LINEAR = False +# Whether we enable the fuse nodes with same users +Fuse_NODES_WITH_SAME_USERS = False + +# exclude these nodes from BFS +# excluding get item improves optimizer compilation time by 60s +SEARCH_EXCLUSIONS = OrderedSet([operator.getitem]) + + +default_graph_search_options = { + "min_fuse_set_size": MIN_FUSE_SET_SIZE, + "max_fuse_set_size": MAX_FUSE_SET_SIZE, + "max_fuse_search_depth": MAX_FUSE_SEARCH_DEPTH, + "max_fuse_tensor_size_group_linear": MAX_FUSE_TENSOR_SIZE_GROUP_LINEAR, + "fuse_nodes_with_same_parent": FUSE_NODES_WITH_SAME_PARENT, + "shape_broadcast_batch_linear": SHAPE_BROADCAST_BATCH_LINEAR, + "fuse_nodes_with_same_users": Fuse_NODES_WITH_SAME_USERS, +} + +graph_search_options = default_graph_search_options + + +def update_stack_example_value(node, metadata, dim=0, op=torch.stack): + """ + Update the example value of the node in the graph to enable followup split cat opt. + """ + if node is not None and hasattr(node, "meta"): + if op == torch.stack: + example_value = torch.stack(metadata, dim=dim) + elif op == torch.unbind: + example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] + else: + return + node.meta["example_value"] = example_value + + +def update_pointwise_example_value(pointwise_node, input, other, op): + """ + Update the example value of the add node in the graph to enable followup split cat opt. 
+ """ + if pointwise_node is not None and hasattr(pointwise_node, "meta"): + if op == torch.add: + example_value = torch.add(input, other) + elif op == torch.mul: + example_value = torch.mul(input, other) + else: + return + pointwise_node.meta["example_value"] = example_value + + +class GroupBatchFusionBase: + def __init__(self, **kwargs) -> None: + self.graph_search_options = kwargs.pop( + "graph_search_options", default_graph_search_options + ) + + def match(self, node): + raise NotImplementedError("match called on base") + + def fuse(self, graph, subset): + raise NotImplementedError("fuse called on base") + + +PRE_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} +POST_GRAD_FUSIONS: dict[str, GroupBatchFusionBase] = {} + + +def register_fusion(name: str, pre_grad=True): + def decorator(fusion_cls: GroupBatchFusionBase): + if pre_grad: + PRE_GRAD_FUSIONS[name] = fusion_cls + else: + POST_GRAD_FUSIONS[name] = fusion_cls + return fusion_cls + + return decorator + + +def list_group_batch_fusions(pre_grad=True) -> list[str]: + if pre_grad: + return list(PRE_GRAD_FUSIONS.keys()) + else: + return list(POST_GRAD_FUSIONS.keys()) + + +def decompose_stack(graph: torch.fx.GraphModule, input_tensors: list[Any]) -> Any: + unsqueezed_inputs = [] + unsqueezed_inputs_meta = [] + for input_tensor in input_tensors: + unsqueezed_input = graph.call_function( # type: ignore[operator] + aten.unsqueeze, args=(input_tensor,), kwargs={"dim": 0} + ) + unsqueezed_inputs.append(unsqueezed_input) + unsqueezed_input.meta["val"] = aten.unsqueeze(input_tensor.meta["val"], dim=0) # type: ignore[assignment] + unsqueezed_inputs_meta.append(unsqueezed_input.meta["val"]) + stacked_inputs = graph.call_function( # type: ignore[operator] + aten.cat, args=(unsqueezed_inputs,), kwargs={"dim": 0} + ) + stacked_inputs.meta["val"] = aten.cat(unsqueezed_inputs_meta, dim=0) # type: ignore[assignment] + return stacked_inputs + + +class GroupFusion(GroupBatchFusionBase): + """ + Fuse ops in a group way, e.g, fuse mm/addmm of arbitrary input shapes with fbgemm.gmm. + """ + + +class BatchFusion(GroupBatchFusionBase): + """ + Fuse ops in a batch way, e.g, fuse mm/addmm of same input shapes with bmm. + """ + + +class BatchPointwiseOpsFusionFactory(BatchFusion): + def __init__(self, op, **kwargs) -> None: + super().__init__(**kwargs) + self.op = op + + +@register_fusion("batch_linear_post_grad", pre_grad=False) +class PostGradBatchLinearFusion(BatchFusion): + """ + Fuse ops in a batch way in post grad (aten level). 
+ """ + + def _addmm_node_can_be_fused(self, node: torch.fx.Node) -> bool: + # pyre-fixme[7]: Incompatible return type + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA # type: ignore[return-value] + ) + + def _is_input_2d(self, input: torch.fx.Node) -> bool: + input_shapes = input.meta["val"].shape + return ( + len(input_shapes) == 2 + and isinstance(input_shapes[0], int) + and isinstance(input_shapes[1], int) + ) + + def match( + self, node: torch.fx.Node + ) -> Optional[tuple[str, int, int, int, bool, str]]: + if CallFunctionVarArgs(aten.mm).match(node): + input_m, weight_m = node.args + bias_m = None + + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias_m, input_m, weight_m = node.args + else: + return None + # get the user of the node + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + # only handle the cases where inputs are 2D tensors + if not self._is_input_2d(input_m) or not self._is_input_2d(weight_m): # type: ignore[arg-type] + return None + m, k = input_m.meta["val"].shape # type: ignore[union-attr] + n = weight_m.meta["val"].shape[1] # type: ignore[union-attr] + batch_key = ("batch_linear_post_grad", m, k, n, bias_m is not None, str(users)) + return batch_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_nodes = [] + batch_inputs_meta = [] + batch_weights_meta = [] + batch_biases_meta = [] + + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + elif CallFunctionVarArgs(aten.mm.default).match(node): + input, weight = node.args + bias = None + batch_nodes.append(node) + batch_inputs.append(input) # type: ignore[possibly-undefined] + batch_weights.append(weight) # type: ignore[possibly-undefined] + batch_biases.append(bias) # type: ignore[possibly-undefined] + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_weights_meta.append(weight.meta) # type: ignore[possibly-undefined, union-attr] + if bias is not None: # type: ignore[possibly-undefined] + batch_biases_meta.append(bias.meta) # type: ignore[possibly-undefined, union-attr] + else: + batch_biases_meta.append(None) + + with graph.inserting_before(subset[-1]): # type: ignore[operator] + fused_inputs = decompose_stack(graph, batch_inputs) + fused_weights = decompose_stack(graph, batch_weights) + fused_inputs_meta_val = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + fused_weights_meta_val = torch.stack( + [weight["val"] for weight in batch_weights_meta] + ) + fused_bmm = graph.call_function( # type: ignore[operator] + aten.bmm, + args=(fused_inputs, fused_weights), + ) + fused_bmm.meta["val"] = aten.bmm( + fused_inputs_meta_val, fused_weights_meta_val + ) + for i, original_mm in enumerate(batch_nodes): + has_bias = False + with graph.inserting_after(fused_bmm): # type: ignore[operator] + new_mm = graph.call_function(aten.select, args=((fused_bmm, 0, i))) # type: ignore[operator] + new_mm.meta["val"] = aten.select(fused_bmm.meta["val"], 0, i) + if batch_biases[i]: + has_bias = True + # broadcast the bias to the same shape as the mm output + if self.graph_search_options.get( + "shape_broadcast_batch_linear", False + ): + broadcast_shape = 
torch.broadcast_shapes( + batch_biases_meta[i]["val"].shape, new_mm.meta["val"].shape + ) + broadcast_bias = graph.call_function( # type: ignore[operator] + aten.broadcast_to.default, + args=(batch_biases[i],), + kwargs={"size": broadcast_shape}, + ) + broadcast_bias.meta["val"] = aten.broadcast_to( + batch_biases_meta[i]["val"], broadcast_shape + ) # type: ignore[assignment] + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add.Tensor, args=((broadcast_bias, new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + broadcast_bias.meta["val"], new_mm.meta["val"] + ) + else: + new_bias_add = graph.call_function( # type: ignore[operator] + aten.add, args=((batch_biases[i], new_mm)) + ) + new_bias_add.meta["val"] = aten.add.Tensor( + batch_biases_meta[i]["val"], new_mm.meta["val"] + ) + new_mm_cont = new_bias_add if has_bias else new_mm # type: ignore[possibly-undefined] + original_mm.replace_all_uses_with(new_mm_cont) + new_mm_cont.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["batch_linear_post_grad"] += 1 + + +@register_fusion("group_linear", pre_grad=False) +class GroupLinearFusion(GroupFusion): + def _addmm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[2].meta["val"].shape # type: ignore[union-attr] + return ( + node.kwargs.get("beta", DEFAULT_BETA) == DEFAULT_BETA + and node.kwargs.get("alpha", DEFAULT_ALPHA) == DEFAULT_ALPHA + and len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def _mm_node_can_be_fused(self, node: torch.fx.Node): + input_shape = node.args[0].meta["val"].shape # type: ignore[union-attr] + weight_shape = node.args[1].meta["val"].shape # type: ignore[union-attr] + return ( + len(input_shape) == 2 + and len(weight_shape) == 2 + and all(x % 2 == 0 for x in input_shape + weight_shape) + and all( + shape <= self.graph_search_options["max_fuse_tensor_size_group_linear"] + for shape in input_shape + weight_shape + ) + ) + + def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool]]: + if CallFunctionVarArgs(aten.mm.default).match( + node + ) and self._mm_node_can_be_fused(node): + group_key = ("group_linear", True) + elif CallFunctionVarArgs(aten.addmm.default).match( + node + ) and self._addmm_node_can_be_fused(node): + bias = node.args[0] + group_key = ("group_linear", bias is None) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_weights = [] + group_biases = [] + group_nodes = [] + for node in subset: + if CallFunctionVarArgs(aten.addmm.default).match(node): + bias, input, weight = node.args + else: + assert CallFunctionVarArgs(aten.mm.default).match(node) + input, weight = node.args + bias = None + + group_nodes.append(node) + group_inputs.append(input) + group_weights.append(weight) + group_biases.append(bias) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + fused_mm = graph.call_function( # type: ignore[operator] + torch.ops.fbgemm.gmm.default, + args=(group_inputs, group_weights, group_biases), + kwargs={"smart_fused": True}, + ) + + for i, original_mm in 
enumerate(group_nodes): + with graph.inserting_after(fused_mm): # type: ignore[operator] + new_mm = graph.call_function(operator.getitem, args=(fused_mm, i)) # type: ignore[operator] + original_mm.replace_all_uses_with(new_mm) + new_mm.meta.update(original_mm.meta) + graph.erase_node(original_mm) # type: ignore[operator] + counters["inductor"]["group_linear"] += 1 + + +class BatchPointwiseMathOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise math operator (e.g., add, mul) in post grad pass. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def _pointwise_node_can_be_fused(self, node: torch.fx.Node): + # note: we only consider the case where the inputs are tensors + # for mixed precision training, we need to make sure the inputs + # of the aten.cat when do the stack should be the same dtype + # otherwise, the output of the aten.cat may be not the same as + # its inputs, and cause dtype not same error in mm or addmm + input, other = node.args + return ( + input.meta["val"].shape == other.meta["val"].shape # type: ignore[union-attr] + # input and other can be scalars, where they have no attribute 'meta' + if hasattr(input, "meta") + and hasattr(other, "meta") + and is_node_meta_valid(input) # type: ignore[arg-type, union-attr] + and is_node_meta_valid(other) # type: ignore[arg-type, union-attr] + # torch.SymInt or torch.SymFloat object has no attribute 'shape' + and isinstance(input.meta["val"], torch.Tensor) # type: ignore[union-attr] + and isinstance(other.meta["val"], torch.Tensor) # type: ignore[union-attr] + else False + ) + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(self.op).match( + node + ) and self._pointwise_node_can_be_fused(node): + alpha = node.kwargs.get("alpha", DEFAULT_ALPHA) + rounding_mode = node.kwargs.get("rounding_mode", None) + input, other = node.args + shape = list(input.meta["val"].shape) # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # only consider the linear case so far + # pyre-fixme[16] + if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + parent = ( + # pyre-fixme[16] + input.args[0] # type: ignore[union-attr] + # pyre-fixme[16] + if input.target == aten.select # type: ignore[union-attr] + else other.args[0] # type: ignore[union-attr] + ) + else: + parent = "" + else: + parent = "" + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(shape), + str(input.meta["val"].dtype), # type: ignore[union-attr] + str(other.meta["val"].dtype), # type: ignore[union-attr] + str(alpha), + str(rounding_mode), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_inputs, batch_others = [], [] + alpha = subset[0].kwargs.get("alpha", DEFAULT_ALPHA) + batch_inputs_meta, batch_others_meta = [], [] + + for node in subset: + input, other = node.args + batch_inputs.append(input) + batch_others.append(other) + batch_inputs_meta.append(input.meta) # type: ignore[possibly-undefined, union-attr] + batch_others_meta.append(other.meta) # type: ignore[possibly-undefined, union-attr] + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + stack_others = decompose_stack(graph, batch_others) + stack_inputs_meta = torch.stack( + [input["val"] for input in batch_inputs_meta] + ) + stack_others_meta = torch.stack( + 
[other["val"] for other in batch_others_meta] + ) + + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs, stack_others), + kwargs={"alpha": alpha} if self.op == aten.add.Tensor else {}, + ) + batch_op.meta["val"] = self.op(stack_inputs_meta, stack_others_meta) + for i, original_add in enumerate(subset): + with graph.inserting_after(batch_op): # type: ignore[operator] + new_add = graph.call_function( # type: ignore[operator] + torch.ops.aten.select, args=((batch_op, 0, i)) + ) + original_add.replace_all_uses_with(new_add) + new_add.meta.update(original_add.meta) + graph.erase_node(original_add) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +@register_fusion("batch_linear_lhs") +class BatchLinearLHSFusion(BatchFusion): + """ + Batch linear left-hand side fusion. This pass tries to fuse the following patterns: + + torch.nn.functional.linear(x, w1), linear(x, w2),... * linear(x, wn) + -> torch.mm(x, torch.cat([w1, w2,... * wn]).transpose(0, 1)) + + We have a separate pass to eliminate contiguous transpose in a generic way. + """ + + def match(self, node: torch.fx.Node) -> Optional[tuple[str, bool, Any]]: + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + bias = get_arg_value(node, 2, "bias") + group_key = ("batch_linear_lhs", bias is None, input) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_input = None + batch_weights, batch_weights_meta = [], [] + batch_biases, batch_biases_meta = [], [] + split_sections = [] + for node in subset: + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + batch_nodes.append(node) + if batch_input is None: + batch_input = input + else: + assert batch_input is input + batch_weights.append(weight) + batch_weights_meta.append(weight.meta["example_value"]) + if bias: + batch_biases.append(bias) + batch_biases_meta.append(bias.meta["example_value"]) + split_sections.append(weight.meta["example_value"].shape[0]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + cat_weights = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_weights,), kwargs={"dim": 0} + ) + cat_weights.meta["example_value"] = torch.cat(batch_weights_meta, dim=0) + transposed_weights = graph.call_function( # type: ignore[operator] + torch.transpose, args=(cat_weights, 0, 1) + ) + transposed_weights.meta["example_value"] = torch.transpose( + cat_weights.meta["example_value"], 0, 1 + ) + if len(batch_biases) > 0: + cat_biases = graph.call_function( # type: ignore[operator] + torch.cat, args=(batch_biases,), kwargs={"dim": 0} + ) + cat_biases.meta["example_value"] = torch.cat(batch_biases_meta, dim=0) + fused_lhs = graph.call_function( # type: ignore[operator] + torch.addmm, + args=(cat_biases, batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.addmm( + cat_biases.meta["example_value"], + batch_input.meta["example_value"], # type: ignore[union-attr] + transposed_weights.meta["example_value"], + ) + else: + fused_lhs = graph.call_function( # type: ignore[operator] + torch.mm, + args=(batch_input, transposed_weights), + ) + fused_lhs.meta["example_value"] = torch.mm( + batch_input.meta["example_value"], # type: ignore[union-attr] + 
transposed_weights.meta["example_value"], + ) + fused_lhs_list = graph.call_function( # type: ignore[operator] + torch.split, args=(fused_lhs, split_sections), kwargs={"dim": 1} + ) + + for i, node in enumerate(batch_nodes): + with graph.inserting_after(fused_lhs_list): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(fused_lhs_list, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_linear_lhs"] += 1 + + +# Poor person's check for if a node in the graph mutates its input. +# (the graph is torch IR, so we will see torch fns and python operators) +def _is_mutable_node(tgt): + if str(tgt).endswith("_"): + # e.g. torch.mul_, torch.Tensor.mul_ + return True + if ( + hasattr(tgt, "__module__") + and tgt.__module__ == "_operator" + and tgt.__name__.startswith("i") + ): + # e.g. operator.iand, operator.imul + return True + return False + + +def is_linear_node_can_be_fused(node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + return ( + is_node_meta_valid(node) + and is_node_meta_valid(input) + and is_node_meta_valid(weight) + and len(input.meta["example_value"].shape) == 2 + and len(weight.meta["example_value"].shape) == 2 + # the mm -> bmm transform adds an unbind() op, + # which is not safe for autograd when the output of the mm is mutated. + # don't pattern match if any users of the mm mutate the input. + and not any(_is_mutable_node(user.target) for user in node.users) + ) + + +@register_fusion("batch_linear") +class PreGradBatchLinearFusion(BatchFusion): + """ + Batch linear fusion in pre grad pass. + Fuse linear with same size with torch.baddmm + """ + + def _getitem_args(self, getitem_node: torch.fx.Node): + if getitem_node.target != operator.__getitem__ or ( + getitem_node.op != "call_function" + ): + return None + return getitem_node.args[0] + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.linear).match( + node + ) and is_linear_node_can_be_fused(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 1, "weight") + bias = get_arg_value(node, 2, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + "batch_linear", + self._getitem_args(input), + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape), + bias is None, + str(users), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_weights = [] + batch_biases = [] + batch_inputs_metadata = [] + batch_weights_metadata = [] + batch_biases_metadata = [] + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + weight = get_arg_value(node, 1, "weight") + batch_weights.append(weight) + batch_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 2, "bias") + batch_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + batch_biases_metadata.append(bias.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + 
torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + stack_weights = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weights, batch_weights_metadata) + transpose_weight = graph.call_function( # type: ignore[operator] + torch.transpose, args=(stack_weights, 1, 2) + ) + transpose_weight.meta["example_value"] = torch.transpose( + stack_weights.meta["example_value"], 1, 2 + ) + if all(bias is None for bias in batch_biases): + bmm = graph.call_function( # type: ignore[operator] + torch.bmm, + args=(stack_inputs, transpose_weight), + ) + bmm.meta["example_value"] = torch.bmm( + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + else: + stack_biases = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_biases, batch_biases_metadata) + unsqueeze_biases = graph.call_function( # type: ignore[operator] + torch.unsqueeze, args=(stack_biases, 1) + ) + unsqueeze_biases.meta["example_value"] = torch.unsqueeze( + stack_biases.meta["example_value"], 1 + ) + bmm = graph.call_function( # type: ignore[operator] + torch.baddbmm, + args=(unsqueeze_biases, stack_inputs, transpose_weight), + ) + try: + # it will have runtime error to broadcast when it has dynamic shape included + # in the meta data, so we need to skip the update meta data + bmm.meta["example_value"] = torch.baddbmm( + unsqueeze_biases.meta["example_value"], + stack_inputs.meta["example_value"], + transpose_weight.meta["example_value"], + ) + bmm_meta = bmm.meta["example_value"] + except Exception as e: + log.debug( + f" exception when update bmm meta data with stack error tracekey {e}" # noqa: G004 + ) + bmm_meta = None + + bmm = graph.call_function(torch.unbind, args=(bmm,), kwargs={"dim": 0}) # type: ignore[operator] + if bmm_meta is not None: + bmm.meta["example_value"] = torch.unbind(bmm_meta, dim=0) + for i, linear in enumerate(batch_nodes): + with graph.inserting_after(bmm): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(bmm, i)) # type: ignore[operator] + linear.replace_all_uses_with(getitem) + getitem.meta.update(linear.meta) + graph.erase_node(linear) # type: ignore[operator] + counters["inductor"]["batch_linear"] += 1 + + +@register_fusion("batch_layernorm") +class BatchLayernormFusion(BatchFusion): + """ + Batch layer norm fusion in pre grad pass + """ + + def match(self, node: torch.fx.Node): + if CallFunctionVarArgs(torch.nn.functional.layer_norm).match(node): + input = get_arg_value(node, 0, "input") + weight = get_arg_value(node, 2, "weight") + bias = get_arg_value(node, 3, "bias") + if self.graph_search_options.get("fuse_nodes_with_same_users", False): + users = [user.target for user in node.users.keys()] + else: + users = "" # type: ignore[assignment] + group_key = ( + ( + "batch_layernorm", + str(input.meta["example_value"].shape), + str(weight.meta["example_value"].shape) + if weight is not None + else "", + str(bias.meta["example_value"].shape) if bias is not None else "", + str(get_arg_value(node, 1, "normalized_shape")), + str(get_arg_value(node, 4, "eps")), + str(users), + ) + if "example_value" in input.meta + and is_node_meta_valid(weight) + and is_node_meta_valid(bias) + else None + ) + else: + group_key = None + return group_key + + def fuse(self, graph: 
torch.fx.GraphModule, subset: list[torch.fx.Node]): + group_inputs = [] + group_shapes = [] + group_weights = [] + group_biases = [] + group_epss = [] + group_nodes = [] + group_inputs_metadata = [] + group_biases_metadata = [] + group_weights_metadata = [] + for node in subset: + group_nodes.append(node) + input = get_arg_value(node, 0, "input") + group_inputs.append(input) + group_inputs_metadata.append(input.meta["example_value"]) + group_shapes.append(get_arg_value(node, 1, "normalized_shape")) + weight = get_arg_value(node, 2, "weight") + group_weights.append(weight) + if weight is not None and hasattr(weight, "meta"): + group_weights_metadata.append(weight.meta["example_value"]) + bias = get_arg_value(node, 3, "bias") + group_biases.append(bias) + if bias is not None and hasattr(bias, "meta"): + group_biases_metadata.append(bias.meta["example_value"]) + eps = get_arg_value(node, 4, "eps") + if eps is None: + eps = 1e-5 + group_epss.append(eps) + stack_dim = -1 - len(group_shapes[-1]) + + if all(bias is None for bias in group_biases): + group_biases = None # type: ignore[assignment] + if all(weight is None for weight in group_weights): + group_weights = None # type: ignore[assignment] + assert all(eps == group_epss[0] for eps in group_epss), ( + "all epsilon values must be equal" + ) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_input = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_inputs,), kwargs={"dim": stack_dim} + ) + update_stack_example_value(stack_input, group_inputs_metadata, stack_dim) + if group_weights is not None: + stack_weight = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_weights,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_weight, group_weights_metadata) + else: + stack_weight = None + if group_biases is not None: + stack_bias = graph.call_function( # type: ignore[operator] + torch.stack, args=(group_biases,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_bias, group_biases_metadata) + else: + stack_bias = None + + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.nn.functional.layer_norm, + args=(stack_input, group_shapes[-1]), + kwargs={"eps": group_epss[-1]}, + ) + batch_layer_norm.meta["example_value"] = stack_input.meta["example_value"] + + if group_weights is not None and group_biases is not None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( # type: ignore[operator] + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + elif group_weights is not None and group_biases is None: + previous_batch_layer_norm_meta = batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.mul, args=(stack_weight, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_weight.meta["example_value"], + previous_batch_layer_norm_meta, + torch.mul, + ) + elif group_weights is None and group_biases is not None: + previous_batch_layer_norm_meta = 
batch_layer_norm.meta["example_value"] + batch_layer_norm = graph.call_function( + torch.add, args=(stack_bias, batch_layer_norm) + ) + update_pointwise_example_value( + batch_layer_norm, + stack_bias.meta["example_value"], + previous_batch_layer_norm_meta, + torch.add, + ) + + batch_layer_norm_unbind = graph.call_function( # type: ignore[operator] + torch.unbind, + args=(batch_layer_norm,), + kwargs={"dim": stack_dim}, + ) + update_stack_example_value( + batch_layer_norm_unbind, + batch_layer_norm.meta["example_value"], + op=torch.unbind, + dim=stack_dim, + ) + + for i, node in enumerate(group_nodes): + with graph.inserting_after(batch_layer_norm_unbind): # type: ignore[operator] + new_node = graph.call_function( # type: ignore[operator] + operator.getitem, args=(batch_layer_norm_unbind, i) + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_layernorm"] += 1 + + +class BatchPointwiseOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in pre grad pass. + We fuse it in random place, and the introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self.graph_search_options.get("fuse_nodes_with_same_parent", False): + # pyre-fixme[16] + parent = node.args[0] + parent = parent.target if parent is not None else "" # type: ignore[union-attr] + else: + parent = "" + # for relu op, we also use the inplace to construct the key + group_key = ( + "batch_" + self.op.__name__.lower().split(".")[0], + str(input.meta["example_value"].shape), + str(node.kwargs.get("inplace", False)), + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + if self.op == torch.nn.functional.relu: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs={"inplace": subset[0].kwargs.get("inplace", False)}, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], + inplace=subset[0].kwargs.get("inplace", False), + ) + else: + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"] + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + 
getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch pointwise ops (e.g., sigmoid, relu, tanh) fusion in post grad pass. + The introduced stack node may be merged in split cat. + """ + + def __init__(self, op, **kwargs) -> None: + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # for relu op, we also use the inplace to construct the key + # we batch the ops with same parent to enable followup split cat + parent = node.args[0] + parent = ( + parent.target # type: ignore[union-attr] + if self.graph_search_options.get("fuse_nodes_with_same_parent", False) + else "" + ) + group_key = ( + "batch_aten_" + self.op.__name__.lower().split(".")[0], + str(input.meta["val"].shape), + str(node.kwargs.get("inplace", False)), + # pyre-fixme[16] + str(parent), + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["val"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = decompose_stack(graph, batch_inputs) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(batch_op): # type: ignore[operator] + getitem = graph.call_function(aten.select, args=(batch_op, 0, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"][ + "batch_aten_" + self.op.__name__.lower().split(".")[0] + ] += 1 + + +class BatchMathOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch simple match related ops such as nan_to_num in pre grad pass. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # check the input has the same shape and its users have the same target + # check all clamp operators have the same min and max values, and + # nan_to_num operators use the same default value. 
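+            # Rough illustration (assumed ops): two torch.clamp(x, min=0.0, max=1.0) calls
+            # whose inputs have the same example shape and whose users share the same target
+            # get identical group keys below, so they can later be batched into one clamp on
+            # a stacked input.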
+ child = next(iter(node.users.keys())) + group_key = ( + str(input.meta["example_value"].shape) + + str(node.kwargs) + + str(child.target) + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + kwargs = subset[0].kwargs + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): # type: ignore[operator] + stack_inputs = graph.call_function( # type: ignore[operator] + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( # type: ignore[operator] + self.op, + args=(stack_inputs,), + kwargs=kwargs, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], **kwargs + ) + unbind_op = graph.call_function( # type: ignore[operator] + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): # type: ignore[operator] + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) # type: ignore[operator] + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) # type: ignore[operator] + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + +@register_fusion("batch_tanh") +class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.tanh, **kwargs) + + +@register_fusion("batch_sigmoid") +class BatchSigmoidPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.sigmoid, **kwargs) + + +@register_fusion("batch_relu") +class BatchReLuPreGradFusion(BatchPointwiseOpsPreGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(torch.nn.functional.relu, **kwargs) + + +@register_fusion("batch_detach") +class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.detach, **kwargs) + + +@register_fusion("batch_nan_to_num") +class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nan_to_num, **kwargs) + + +@register_fusion("batch_clamp") +class BatchClampPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.clamp, **kwargs) + + +@register_fusion("batch_aten_tanh", pre_grad=False) +class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.tanh.default, **kwargs) + + +@register_fusion("batch_aten_sigmoid", pre_grad=False) +class BatchSigmoidPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sigmoid.default, **kwargs) + + +@register_fusion("batch_aten_relu", pre_grad=False) +class BatchReLuPostGradFusion(BatchPointwiseOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.relu.default, **kwargs) + + +@register_fusion("batch_aten_add", pre_grad=False) +class BatchAddPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.add.Tensor, 
**kwargs) + + +@register_fusion("batch_aten_sub", pre_grad=False) +class BatchSubPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.sub.Tensor, **kwargs) + + +@register_fusion("batch_aten_div", pre_grad=False) +class BatchDivPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.div.Tensor, **kwargs) + + +@register_fusion("batch_aten_mul", pre_grad=False) +class BatchMulPostGradFusion(BatchPointwiseMathOpsPostGradFusion): + def __init__(self, **kwargs) -> None: + super().__init__(aten.mul.Tensor, **kwargs) + + +class _OrderedSet: + def __init__(self, param=None) -> None: + if param: + self.rep = OrderedDict(dict.fromkeys(param)) + else: + self.rep = OrderedDict() + + def __contains__(self, o) -> bool: + return o in self.rep + + def __len__(self) -> int: + return self.rep.__len__() + + def append(self, o): + self.rep[o] = None + + def __iter__(self): + return self.rep.keys().__iter__() + + +def find_independent_subset_greedy( + node_list: Iterable[torch.fx.Node], + graph_search_options: dict[str, Any], +) -> Iterator[Iterable[torch.fx.Node]]: + """ + Yields a list of subsets of `node_list` where no element in the subset + depends on any other element in the subset. This results in a set of + independent nodes which can be fused together. + + The order of `node_list` is preserved within each subset so we can benefit + from split-cat elimination in later passes. + + During iteration it is only safe to mutate the graph by changing the nodes + that have been returned. + + graph_search_options: + - min_fuse_set_size: Minimum size of the subset to consider. Subsets below + this size will be ignored. + - max_fuse_set_size: Maximum size of the subset to consider. Subsets will + be broken to be at most this size. + """ + + # Compute all the children of `node` which are members of + # `interesting_nodes`. + def find_dependent_nodes(node, interesting_nodes): + visited_node_set = OrderedSet[torch.fx.Node]() + dep_set = OrderedSet[torch.fx.Node]() + + work = [node] + while work: + node = work.pop() + for input_node in node.all_input_nodes: + if input_node in interesting_nodes: + dep_set.add(input_node) + + if input_node not in visited_node_set: + visited_node_set.add(input_node) + work.append(input_node) + + return dep_set + + min_fuse_set_size = graph_search_options["min_fuse_set_size"] + max_fuse_set_size = graph_search_options["max_fuse_set_size"] + + # node_list needs to be a set because we only track the nodes that are left + # in it (and we want to do the `in` on a set, not a list). But we want to + # keep the correct order. 
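+    # Rough illustration: for candidates [a, b, c, d] where c consumes a's output, the
+    # first yielded subset can contain a, b and d but not c; c is deferred to a later
+    # round so that no element of a yielded subset depends on another element of it.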
+ node_list = _OrderedSet(node_list) + + cache: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = {} + while node_list: + subset: list[torch.fx.Node] = [] + subset_deps = OrderedSet[torch.fx.Node]() + + next_round_node_list = _OrderedSet() + for node in node_list: + if len(subset) >= max_fuse_set_size or node in subset_deps: + next_round_node_list.append(node) + continue + + dep_set = cache.pop(node, None) + if dep_set is None: + dep_set = find_dependent_nodes(node, node_list) + + if not dep_set.intersection(subset): + subset.append(node) + subset_deps.update(dep_set) + else: + next_round_node_list.append(node) + cache[node] = dep_set + + if len(subset) >= min_fuse_set_size: + # Careful here - the caller uses the subsets to fuse nodes together + # so we need to clear any cache entry that contains one of the + # returned nodes because the dependency list could be different + # (larger) after the merge. + cache = {k: v for k, v in cache.items() if v.isdisjoint(subset)} + yield subset + + node_list = next_round_node_list + + +def get_fusion_candidates( + rule: GroupBatchFusionBase, + root_node: torch.fx.Node, + fused_set: OrderedSet[torch.fx.Node], +) -> collections.defaultdict[Any, list[torch.fx.Node]]: + """ + Search fusion candidates for a specific rule using BFS starting from the root node. + We only search the subgraph within graph_search_options["max_fuse_search_depth"]. + """ + q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque() + + candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = ( + collections.defaultdict(list) + ) + + if root_node.target in SEARCH_EXCLUSIONS: + return candidate_dict + + visited_set = OrderedSet[torch.fx.Node]() + + for next_node in root_node.all_input_nodes: + q.append((1, next_node)) + visited_set.add(next_node) + + while len(q) > 0: + depth, node = q.popleft() + + if node in fused_set: + continue + + key = rule.match(node) + if key is not None: + candidate_nodes = candidate_dict[key] + if node not in candidate_nodes: + candidate_nodes.append(node) + else: + if depth < rule.graph_search_options["max_fuse_search_depth"]: + for next_node in node.all_input_nodes: + if next_node not in visited_set: + visited_set.add(next_node) + q.append((depth + 1, next_node)) + + return candidate_dict + + +def apply_group_batch_fusion(graph: torch.fx.GraphModule, rule: GroupBatchFusionBase): + stable_topological_sort(graph) # type: ignore[arg-type] + fused_set = OrderedSet[torch.fx.Node]() + log_to_scuba = False + + for node in reversed(graph.nodes): # type: ignore[arg-type] + candidates = get_fusion_candidates(rule, node, fused_set) + + for key, candidate_nodes in candidates.items(): + if len(candidate_nodes) < rule.graph_search_options["min_fuse_set_size"]: + continue + + for subset in find_independent_subset_greedy( + candidate_nodes, rule.graph_search_options + ): + rule.fuse(graph, subset) + fused_set.update(subset) + log.debug( + f"{rule.__class__.__name__}: key = {key}; subset size = {len(list(subset))}" # noqa: G004 + ) + log_to_scuba = True + if log_to_scuba: + from torch.fx._lazy_graph_module import _LazyGraphModule + + # Force graph to re-compile otherwise the output python code may be broken + gm = graph._owning_module + if isinstance(gm, _LazyGraphModule): + _LazyGraphModule.recompile() + else: + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + graph_str = gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": 
f"optimus_{str(rule.__class__.__name__)}", + "encoding": "string", + }, + payload_fn=lambda: graph_str, + ) + + +def generate_fusion_from_config(config_options: dict[str, Any], pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + for name, options in config_options.items(): + # we skip all patterns from pattern_matcher passes (e.g., split_cat) + if name not in PRE_GRAD_FUSIONS and name not in POST_GRAD_FUSIONS: + continue + fusion_cls = PRE_GRAD_FUSIONS[name] if pre_grad else POST_GRAD_FUSIONS[name] + _options = graph_search_options.copy() + _options.update(options) + fusions.append(fusion_cls(graph_search_options=_options)) # type: ignore[operator] + return fusions + + +def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True): + fusions: list[GroupBatchFusionBase] = [] + # we keep all current pre grad fusions to keep + # current implementation, will remove this later + if pre_grad: + fusions += generate_fusion_from_config( + config.pre_grad_fusion_options, pre_grad=True + ) + else: + fbgemm_fusion_keys = [ + x + for x in config.post_grad_fusion_options + if ( + x not in OPTIMUS_EXCLUDE_POST_GRAD + and config.post_grad_fusion_options[x].get("require_fbgemm", False) + ) + ] + fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in fbgemm_fusion_keys + } + non_fbgemm_fusions = { + fusion: config.post_grad_fusion_options[fusion] + for fusion in config.post_grad_fusion_options.keys() + if fusion not in fbgemm_fusion_keys + } + fusions += generate_fusion_from_config(non_fbgemm_fusions, pre_grad=False) + if has_fbgemm: + fusions += generate_fusion_from_config(fbgemm_fusions, pre_grad=False) + + for i, rule in enumerate(fusions): + with GraphTransformObserver( + graph.owning_module, + f"group_batch_fusion_{i}", + ): + apply_group_batch_fusion(graph, rule) # type: ignore[arg-type] diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fe1d5404c4e0df521f8ed7b0fa79bcb0713037 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/joint_graph.py @@ -0,0 +1,943 @@ +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +import typing +from collections import Counter +from collections.abc import Sequence +from typing import Any, Union + +import torch +import torch._guards +import torch.utils._pytree as pytree +from torch._dynamo.utils import counters +from torch._inductor.constant_folding import ConstantFolder +from torch._inductor.fx_passes.dedupe_symint_uses import _SymHashingDict +from torch._inductor.utils import get_gpu_type +from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + statically_known_true, +) +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils._ordered_set import OrderedSet + +from .. 
import config +from ..pattern_matcher import ( + Arg, + CallFunction, + init_once_fakemode, + KeywordArg, + Match, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + stable_topological_sort, +) +from .decompose_mem_bound_mm import check_device +from .replace_random import replace_random_passes + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten +prims = torch.ops.prims + +pass_patterns = [ + patterns, + PatternMatcherPass(), +] + + +@init_once_fakemode +def lazy_init(): + from .fuse_attention import _sfdp_init + from .misc_patterns import _misc_patterns_init + from .pad_mm import _pad_mm_init + + _pad_mm_init() + _sfdp_init() + _misc_patterns_init() + + +def remove_no_ops( + gm: torch.fx.GraphModule, + zeros: OrderedSet[torch.fx.Node], + ones: OrderedSet[torch.fx.Node], +): + with torch.utils._python_dispatch._disable_current_modes(): + "Removes no-ops: (+ 0, - 0, * 1, / 1)" + graph = gm.graph + + def fake_tensors_eq(t1, t2, fields=("shape", "dtype", "device")): + if any(not isinstance(t, torch.Tensor) for t in (t1, t2)): + return False + for field in fields: + if getattr(t1, field) != getattr(t2, field): + return False + return True + + def replace_no_op(node, replace_input_index): + replacement = node.args[replace_input_index] + + # https://github.com/pytorch/pytorch/issues/86128 causes + # non-Tensor inputs even for ops with only Tensor inputs. + # TODO - decompose/type promote to avoid this + if not all(isinstance(arg, torch.fx.Node) for arg in node.args): + return + + if not fake_tensors_eq(node.meta["val"], replacement.meta["val"]): + if fake_tensors_eq( + node.meta["val"], + replacement.meta["val"], + ("shape", "device"), + ): + with graph.inserting_after(node): + replacement = graph.call_function( + torch.ops.prims.convert_element_type.default, + args=(replacement, node.meta["val"].dtype), + ) + else: + return + + node.replace_all_uses_with(replacement) + replacement.meta.update(node.meta) + graph.erase_node(node) + + for node in graph.find_nodes(op="call_function", target=aten.add.Tensor): + # TODO handle Tensor-Scalar adds, it's a different schema + if len(node.args) == 2: + if ( + not any(e in zeros for e in node.args) + or node.kwargs.get("alpha", 1) != 1 + ): + continue + + replace_index = 1 if node.args[0] in zeros else 0 + replace_no_op(node, replace_index) + + for node in graph.find_nodes(op="call_function", target=aten.sub.Tensor): + if len(node.args) == 2: + if node.args[1] not in zeros or node.kwargs.get("alpha", 1) != 1: + continue + + replace_no_op(node, 0) + + for node in graph.find_nodes(op="call_function", target=aten.mul.Tensor): + if len(node.args) == 2: + if not any(e in ones for e in node.args): + continue + + replace_input_index = 1 if node.args[0] in ones else 0 + replace_no_op(node, replace_input_index) + + for node in graph.find_nodes(op="call_function", target=aten.div.Tensor): + if len(node.args) == 2 and node.args[1] in ones: + replace_no_op(node, 0) + + # meta tensors returned from the graph have no data and can be replaced with empty_strided + for output_node in graph.find_nodes(op="output"): + had_meta_return = False + + def visit(n): + nonlocal had_meta_return + val = n.meta.get("val") + if isinstance(val, torch.Tensor) and val.device.type == "meta": + with graph.inserting_before(output_node): + n.replace_all_uses_with( + graph.call_function( + torch.ops.aten.empty_strided.default, + args=(val.size(), val.stride()), + kwargs={"dtype": val.dtype, "device": val.device}, + ) + ) + had_meta_return = 
True + + torch.fx.map_arg(output_node.args, visit) + if had_meta_return: + graph.eliminate_dead_code() + + +def remove_redundant_views(gm: torch.fx.GraphModule): + """ + Removes redundant views by reusing existing ones. + """ + with torch.utils._python_dispatch._disable_current_modes(): + # A dictionary mapping a tensor to all aliased views. + views: dict[torch.fx.Node, dict[torch.dtype, torch.fx.Node]] = {} + graph = gm.graph + + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten.view.dtype + ): + src = node.args[0] + to_type = node.args[1] + existing_views = views.get(src) + is_needed = True + + if existing_views: + # Replace the view with the an existing view if available. + alias = existing_views.get(to_type) + if alias: + is_needed = False + node.replace_all_uses_with(alias) + alias.meta.update(node.meta) + graph.erase_node(node) + else: + from_type = src.meta["val"].dtype + existing_views = {from_type: src} + views[src] = existing_views + + if is_needed: + # Save the new alias but do not replace existing one. + existing_views.setdefault(to_type, node) + views[node] = existing_views + + # Clean up unused views. + while True: + unused_views = [alias for alias in views if not alias.users] + if len(unused_views) == 0: + break + for unused in unused_views: + views.pop(unused) + graph.erase_node(unused) + + +class UniformValueConstantFolder(ConstantFolder): + """ + Runs constant folding and replaces tensors that have a uniform value + with a tensor constructor call: aten.full([shape], value, ...) + """ + + def __init__(self, gm, skip_constructors=False) -> None: + super().__init__(gm, skip_constructors) + self.node_storages_ptrs: dict[torch.fx.Node, int] = {} + self.constant_data_ptrs: dict[torch.fx.Node, StorageWeakRef] = {} + # we may constant fold a tensor which in the graph has a sym size + # see: [constant folding refining of symints] + self.node_replacements_shapes: dict[torch.fx.Node, list[int]] = {} + + # initialize symint -> node mapping so that we can + # use symint nodes in full constructors + self.symint_nodes = _SymHashingDict() + for n in self.module.graph.nodes: # type: ignore[union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + self.symint_nodes[n.meta["val"]] = n + + # reference from torch/_funtorch/partitioners.py:get_default_op_list + self.view_op_packets = [ + aten.squeeze, + aten.unsqueeze, + aten.alias, + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + + self.indexing_op_packets = OrderedSet( + [ + aten.slice, + ] + ) + + self._add_peephole_patterns() + + def _add_peephole_patterns(self) -> None: + """ + Add peephole patterns for nodes where we can infer constant value even if some inputs + of the node are unknown. 
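+ For example, aten.mul.Tensor(x, 0) with a literal integer 0 is known to be
+ all zeros regardless of x, so it can be folded even though x is unknown.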
+ """ + for op in itertools.chain( + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Tensor + ), + self.module.graph.find_nodes( # type: ignore[operator, union-attr] + op="call_function", target=torch.ops.aten.mul.Scalar + ), + ): + tensor_val = op.meta.get("val", None) + if not isinstance(tensor_val, torch.Tensor): + continue + + def is_zero_int(arg: Any) -> bool: + return isinstance(arg, int) and arg == 0 + + if not any(is_zero_int(a) for a in op.args): + continue + + t = torch.full( + [1], # shape + 0, # value + dtype=tensor_val.dtype, + device=tensor_val.device, + pin_memory=False, + ) + self.add_node_replacement(op, t) + + def _support_dynamic_shape(self): + return True + + def insertable_tensor_check(self, t: torch.Tensor) -> bool: + return True + + def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None: + self.node_replacements[node] = tensor.flatten()[0].item() + self.node_replacements_shapes[node] = node.meta["val"].shape + self.constant_data_ptrs[node] = StorageWeakRef(tensor.untyped_storage()) + + def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None: + for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] + if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): + env[n] = n.meta["val"] + else: + env[n] = self.unknown_value + + def _deduce_value(self, node: torch.fx.Node): + # deduce value for full-like nodes + # 1. for constructors, substitute value is a tensor of size [1] + # 2. for view ops/indexing, substitute value is the same as the input + # 3. for pointwise ops, run node to get the substitute value + # 4. deal with some special ops + # otherwise, stop deduce value and return unknown value + + # TODO: cat, more indexing + # TODO - do on cpu to avoid syncs + + # single-elem attrs + if node.op == "get_attr" or ( + node.op == "call_function" + and node.target == torch.ops.aten.lift_fresh_copy.default + ): + out = super(ConstantFolder, self).run_node(node) + if isinstance(out, torch.Tensor) and out.numel() == 1: + return out + + # handle device_put op + if node.target == prims.device_put.default: + return super(ConstantFolder, self).run_node(node) + + # constructors ops + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + value = args[1] + # Don't specialize symbolic value. 
+ if not isinstance(value, (torch.SymInt, torch.SymFloat, torch.SymBool)): + new_args = [[1], value] + return aten.full.default(*new_args, **node.kwargs) + + # handle before view ops because this changes value + if node.target == aten.view.dtype: + return super(ConstantFolder, self).run_node(node) + + # view ops, return input tensor, the first argument + if hasattr(node.target, "overloadpacket") and ( + node.target.overloadpacket in self.view_op_packets + or node.target.overloadpacket in self.indexing_op_packets + ): + assert isinstance(node.args[0], torch.fx.Node) + return self.env[node.args[0]] + + # we don't want to return unknown value for symints so that we can + # still constant fold through their use in constructors or views + # if we see them in a pointwise node (e.g., tensor * symint) + # we will bail + if "val" in node.meta and isinstance(node.meta["val"], torch.SymInt): + return node.meta["val"] + + # pointwise ops + if isinstance(node.target, torch._ops.OpOverload) and ( + torch.Tag.pointwise in node.target.tags + or node.target is torch.ops.aten.scalar_tensor.default + ): + args, kwargs = self.fetch_args_kwargs_from_env(node) + flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) + + if any(isinstance(inp, torch.SymInt) for inp in flattened_inputs): + return self.unknown_value + + # we run the ops with dim 1, so remove memory_format to avoid error + kwargs = dict(kwargs) + kwargs.pop("memory_format", None) + + return node.target(*args, **kwargs) + + return self.unknown_value + + +def constant_fold_uniform_value(gm: torch.fx.GraphModule): + with torch.utils._python_dispatch._disable_current_modes(): + "Runs constant folding and replaces constants which can be constructed with a single `full` call. Calls into remove_no_ops." + aten = torch.ops.aten + + # Constant folding can leak memory, especially with repeated compilation, so we are only going to + # remove constants which can be replaced with a constructor. + cf = UniformValueConstantFolder(gm) + cf.run() + + node_replacements = cf.node_replacements + + # note: [constant folding refining of symints] + # constant folding will partially evaluate a graph such that values which have dependencies which + # are entirely known at compile time may also become compile time constants. in some cases, + # this will include symints which we had not yet previously deduced are guaranteed a + # constant value and is then deduced in constant folding. 
an example is: + # unbacked_symint_eq_11 = torch.full((), 11).item() + # torch.full((unbacked_symint_eq_11,), 0) + node_replacements_shapes = cf.node_replacements_shapes + + graph = gm.graph + + zeros = OrderedSet[Any]() + ones = OrderedSet[Any]() + + # Got failures in `test_is_set_to_cuda` if we change aliasing on constants, + # so just constant-ify if a Tensor is unaliased + constant_data_ptr_count: typing.Counter[StorageWeakRef] = Counter() + + for node in cf.node_replacements: + constant_data_ptr_count[cf.constant_data_ptrs[node]] += 1 + + for node, value in node_replacements.items(): + # we dont have a functional way right now of instantiating a non-contiguous tensor with full/zeros/ones right now + # hasn't shown up to be important yet + if "val" not in node.meta: + # This can only happen in AOTI + continue + + fake_tensor = node.meta["val"] + if not fake_tensor.is_contiguous(memory_format=torch.contiguous_format): + continue + + # TODO - not sure about lossy uint->python value->uint conversions + if fake_tensor.dtype in ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + ): + continue + + if constant_data_ptr_count[cf.constant_data_ptrs[node]] > 1: + continue + + with graph.inserting_after(node): + # the conversion from tensor and back to value can be lossy, just use the original full ctor value + if ( + node.op == "call_function" + and node.target == aten.full.default + and len(node.args) == 2 + ): + value = node.args[1] + + # refines symints, see [constant folding refining of symints] above + for runtime_size, compile_time_size in zip( + node_replacements_shapes[node], fake_tensor.shape + ): + torch._check(runtime_size == compile_time_size) + + # replace SymInt as Node before creating a new full node + # e.g. (1, s0) -> (1, arg0_1) + node_shape = node_replacements_shapes[node] + if not all( + not isinstance(s, torch.SymInt) or s in cf.symint_nodes + for s in node_shape + ): + continue + + shapes = [ + cf.symint_nodes[s] if isinstance(s, torch.SymInt) else s + for s in node_replacements_shapes[node] + ] + + # zeros and ones just get traced into full, so we insert those + new_node = graph.call_function( + aten.full.default, + args=(shapes, value), + kwargs={ + "dtype": fake_tensor.dtype, + "layout": torch.strided, + "device": fake_tensor.device, + "pin_memory": False, + }, + ) + + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if value == 0: + zeros.add(new_node) + elif value == 1: + ones.add(new_node) + + remove_no_ops(gm, zeros, ones) + remove_redundant_views(gm) + + +def canonicalize_quant_mapping(gm: torch.fx.GraphModule): + """ + + + torch.ops.higher_order.invoke_quant_packed(repeated_subgraph0, 'quant_invoke_0_0', (arg0_1, arg1_1)); + -> + torch.ops.higher_order.invoke_quant(repeated_subgraph0, arg0_1, arg1_1, scheme = 'nf4'); + """ + graph = gm.graph + invoke_quant_invocations = graph.find_nodes( + op="call_function", target=torch.ops.higher_order.invoke_quant_packed + ) + for invoke_quant in invoke_quant_invocations: + kwargs = dict(invoke_quant.kwargs) + + quant_options_node = kwargs.pop("quant_options", None) + if quant_options_node is not None: + assert isinstance(quant_options_node, torch.fx.Node) + quant_options = torch._higher_order_ops.InvokeQuant( + *invoke_quant.kwargs["quant_options"].args, + **invoke_quant.kwargs["quant_options"].kwargs, + ) + else: + quant_options = torch._higher_order_ops.InvokeQuant() + + subgraph, *args = invoke_quant.args + with gm.graph.inserting_before(invoke_quant): + 
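+ # Re-emit the packed call as the unpacked invoke_quant higher-order op,
+ # reusing the same subgraph, args and kwargs; the parsed quant_options are
+ # attached to the new node's meta below.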
invoke_quant_replacement = graph.call_function( + torch._higher_order_ops.invoke_quant, + (subgraph, *args), + kwargs, + ) + invoke_quant_replacement.meta.update(subgraph.meta) + invoke_quant_replacement.meta["quant_options"] = quant_options + + invoke_quant.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(invoke_quant) + + if quant_options_node and len(quant_options_node.users) == 0: + graph.erase_node(quant_options_node) + + first_user = next(iter(invoke_quant_replacement.users)) + + if ( + len(invoke_quant_replacement.users) == 1 + and len(subgraph.users) == 1 + and first_user.target == operator.getitem + and first_user.args[1] == 0 + ): + subgraph_graph = getattr(gm, subgraph.target) + output_node = torch._inductor.utils.output_node(subgraph_graph) + assert ( + isinstance(output_node.args[0], (list, tuple)) + and len(output_node.args[0]) == 1 + ) + + unpacked_output = output_node.args[0][0] + output_node.args = (unpacked_output,) + if "val" in output_node.meta: + output_node.meta["val"] = output_node.meta["val"][0] + subgraph_graph.recompile() + + invoke_quant_replacement.meta.update(first_user.meta) + first_user.replace_all_uses_with(invoke_quant_replacement) + graph.erase_node(first_user) + + +def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): + """ + Canonicalization passes that will run immediately after aot autograd + tracing. Thsis must be run before all other graph passes. + """ + canonicalize_quant_mapping(gm) + + +def joint_graph_passes(graph: torch.fx.GraphModule): + """ + Run FX transformations on the joint forwards+backwards graph. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="joint_graph_passes", + ) + + lazy_init() + count = 0 + + # must occur before other passes + canonicalize_aten_ir_passes(graph) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + + from .post_grad import remove_noop_ops + + GraphTransformObserver(graph, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + + if config.joint_graph_constant_folding: + GraphTransformObserver(graph, "constant_fold_uniform_value").apply_gm_pass( + constant_fold_uniform_value + ) + + if config.joint_custom_pre_pass is not None: + GraphTransformObserver(graph, "joint_custom_pre_pass").apply_graph_pass( + config.joint_custom_pre_pass + ) + count += 1 + + if config.pattern_matcher: + for i, patterns in enumerate(pass_patterns): + maybe_count = GraphTransformObserver( + graph, f"pass_pattern_{i}" + ).apply_graph_pass(patterns.apply) + count += maybe_count if maybe_count is not None else 0 + + if not config.fallback_random: + # not trying into the bisector because decomps may have already affected rng reproducibility + # we'll instead explicitly turn off the config + count += replace_random_passes(graph) + + if config.joint_custom_post_pass is not None: + GraphTransformObserver(graph, "joint_custom_post_pass").apply_graph_pass( + config.joint_custom_post_pass + ) + count += 1 + + if count: + stable_topological_sort(graph.graph) + graph.graph.lint() + graph.recompile() + return graph + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.iota.default, + KeywordArg("length"), + start=KeywordArg("start"), + step=KeywordArg("step"), + dtype=KeywordArg("dtype"), + device=KeywordArg("device"), + requires_grad=KeywordArg("requires_grad"), + ), + pass_dict=patterns, +) +def fix_iota_device(match: 
Match, length, start, step, dtype, device, requires_grad): + """ + Eager supports: + + aten.index(cuda_tensor, torch.arange(..., device="cpu")) + + But this results in an implicit host-device-copy and breaks cudagraphs. + Rewrite the arange to use CUDA. + """ + (node,) = match.nodes + user_devices = OrderedSet[torch.device]() + for user in node.users: + if ( + user.op == "call_function" + and user.target in (aten.index.Tensor, aten.index_put.default) + and hasattr(user.meta.get("val"), "device") + ): + user_devices.add(user.meta["val"].device) # type: ignore[union-attr] + else: + return # bail out + + if len(user_devices) == 1 and "val" in node.meta: + (user_device,) = user_devices + if device.type != user_device.type: + repl = match.graph.call_function( + torch.ops.prims.iota.default, + (length,), + { + "start": start, + "step": step, + "dtype": dtype, + "device": user_device, + "requires_grad": requires_grad, + }, + ) + repl.meta.update(node.meta) + repl.meta["val"] = repl.meta["val"].to(user_device) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + torch.ops.prims.convert_element_type.default, + CallFunction( + torch.ops.prims.convert_element_type.default, + KeywordArg("arg"), + KeywordArg("dtype1"), + ), + KeywordArg("dtype2"), + ), + pass_dict=patterns, +) +def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype): + """Remove chain of dtype conversions often created by AMP""" + graph = match.graph + node = match.output_node() + allowed = torch.float16, torch.bfloat16, torch.float32, torch.float64 + if dtype1 in allowed and dtype2 in allowed: + repl = graph.call_function( + torch.ops.prims.convert_element_type.default, (arg, dtype2) + ) + repl.meta.update(node.meta) + node.replace_all_uses_with(repl) + match.erase_nodes() + + +def definitely_equal( + old_sizes: Sequence[Union[torch.SymInt, int]], + new_sizes: Sequence[Union[torch.SymInt, torch.fx.Node, int]], +) -> bool: + """ + Leverage guard_or_true/false to compare if two lists of int/symint are equal. + Useful to compare sizes, strides etc. + + Can handle -1 in new_sizes which happens in the size arguments of a + view op. old_sizes is supposed to be the tensor shape and should not + contain -1. + + new_sizes can contains fx.Node when dynamic shape is enabled. In that + case new_sizes[i].meta['val'] contains the real torch.SymInt. + """ + + num_neg1 = 0 + + if len(old_sizes) != len(new_sizes): + return False + + for lhs_item, rhs_item in zip(old_sizes, new_sizes): + if isinstance(rhs_item, torch.fx.Node): + rhs_item = rhs_item.meta["val"] + + assert isinstance(lhs_item, (int, torch.SymInt)), type(lhs_item) + assert isinstance(rhs_item, (int, torch.SymInt)), type(rhs_item) + + # It still makes sense to call guard_or_true/false since lhs_item + # rhs_item are torch.SymInt rather than sympy expressions when + # dynamic shape is enabled. 
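+ # E.g. for hypothetical sizes old=(s0, 3) vs new=(s0, -1): the first pair
+ # is guarded equal, and the -1 falls through to the wildcard counting below.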
+ if guard_or_false(lhs_item == rhs_item): + continue + + if guard_or_true(rhs_item != -1): + return False + + num_neg1 += 1 + + if num_neg1 > 1: + return False + return True + + +@register_graph_pattern( + CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")), + pass_dict=patterns, +) +def pointless_view(match: Match, arg, size): + """Remove no-op view""" + node = match.output_node() + arg_size = list(node.args[0].meta["val"].shape) # type: ignore[union-attr] + if definitely_equal(arg_size, size): + node.replace_all_uses_with(node.args[0]) # type: ignore[arg-type] + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.view.default, + CallFunction(aten.view.default, KeywordArg("arg"), KeywordArg("size1")), + KeywordArg("size2"), + ), + pass_dict=patterns, +) +def pointless_view_pair(match: Match, arg, size1, size2): + """ + Remove a pair of views that are pointless. + """ + node = match.output_node() + arg_size = list(arg.meta["val"].shape) + if definitely_equal(arg_size, size2): + node.replace_all_uses_with(arg) + match.erase_nodes() + counters["inductor"]["removed_pointless_view_pair"] += 1 + + +@register_graph_pattern( + CallFunction( + aten.permute.default, + CallFunction(aten.permute.default, KeywordArg("arg"), KeywordArg("perm1")), + KeywordArg("perm2"), + ), + pass_dict=patterns, +) +def pointless_permute_pair(match: Match, arg, perm1, perm2): + rank = len(perm1) + assert len(perm2) == rank + + for i in range(rank): + if perm1[perm2[i]] != i: + return # bail out + node = match.output_node() + node.replace_all_uses_with(arg) + match.erase_nodes() + + +@register_graph_pattern( + CallFunction( + aten.bmm, + Arg(), + Arg(), + ), + pass_dict=patterns, +) +def bmm_to_mm(match: Match, mat1: torch.fx.Node, mat2: torch.fx.Node): + """Convert bmm to mm when batch size is 1""" + + def repl(a, b): + return torch.mm(a.squeeze(0), b.squeeze(0)).unsqueeze(0) + + if ( + check_device(mat1.meta["val"], mat2.meta["val"], get_gpu_type()) + and statically_known_true(mat1.meta["val"].shape[0] == 1) + and statically_known_true(mat2.meta["val"].shape[0] == 1) + ): + match.replace_by_example(repl, [mat1, mat2]) + + +# When softmax is used with temperature or other scaling, we get the pattern +# +# scale(x) - scale(x).amax(dim, keepdim=True) +# +# which is expected to be at most zero, but we may end up with numerical +# discrepancies # between the recomputed values of scale(x) inside and out +# of the reduction, # depending on compiler optimizations, e.g. use of fma +# instructions. +# +# Here we replace it with the mathematically equivalent, +# +# scale(x - x.amax(dim, keepdim=True)) +# +# which is more stable as we only compute the scaling once. +# +# NOTE: This pattern must come after fused attention matching! 
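+ #
+ # A rough sketch of the rewrite with a hypothetical input x and scale t
+ # (t constant across the reduction dim):
+ #
+ #   before: y = x * t; out = y - y.amax(dim, keepdim=True)
+ #   after:  out = (x - x.amax(dim, keepdim=True)) * t
+ #
+ # For negative or tensor-valued scales the replacement below also folds in
+ # sign(t) first so that amax still selects the same element.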
+ + +def _partial_softmax_pattern(linear_func, reverse=False, to_dtype=False): + # Allow matching inp * other and other * input + if reverse: + scaled = CallFunction( + linear_func, KeywordArg("other"), KeywordArg("inp"), _users=MULTIPLE + ) + else: + scaled = CallFunction( + linear_func, KeywordArg("inp"), KeywordArg("other"), _users=MULTIPLE + ) + if to_dtype: + scaled = CallFunction( + prims.convert_element_type, scaled, KeywordArg("dtype"), _users=MULTIPLE + ) + amax = CallFunction( + aten.amax.default, scaled, KeywordArg("dim"), KeywordArg("keepdim") + ) + return CallFunction(aten.sub.Tensor, scaled, amax) + + +def _other_is_broadcasted_in_dim(match): + # Check that the scaling factor is constant across the reduction dim, + # so scaling doesn't change which index corresponds to the maximum value + other = match.kwargs["other"] + if isinstance(other, (int, float)): + return True + + inp = match.kwargs["inp"] + if not all(isinstance(x, torch.fx.Node) for x in (inp, other)): + return False + + inp_example = inp.meta["val"] + other_example = other.meta["val"] + if isinstance(other_example, (torch.SymInt, torch.SymFloat)): + return True + + if not all(isinstance(x, torch.Tensor) for x in (inp_example, other_example)): + return False + + inp_ndim = inp_example.ndim + other_shape = other_example.shape + if inp_ndim < len(other_shape): + return False + + # Pad other_shape to the same ndim as inp + other_shape = [1] * (inp_ndim - len(other_shape)) + list(other_shape) + + dim = match.kwargs["dim"] + if isinstance(dim, int): + dim = (dim,) + + return all(statically_known_true(other_shape[d] == 1) for d in dim) + + +def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) * (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for reverse, to_dtype in itertools.product((False, True), repeat=2): + register_graph_pattern( + _partial_softmax_pattern(aten.mul.Tensor, reverse=reverse, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(mul_softmax_pattern) + + +def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None): + def repl(inp, other): + if dtype is not None: + inp = inp.to(dtype) + + sign: Union[int, float, torch.Tensor] + if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)): + sign = 1 if other >= 0 else -1 + else: + one = torch.scalar_tensor(1, dtype=inp.dtype, device=inp.device) + sign = torch.where(other >= 0, one, -one) + + inp = inp * sign + max_ = torch.amax(inp, dim=dim, keepdim=keepdim) + return (inp - max_) / (sign * other) + + match.replace_by_example(repl, [inp, other]) + + +for to_dtype in (False, True): + register_graph_pattern( + _partial_softmax_pattern(aten.div.Tensor, to_dtype=to_dtype), + pass_dict=pass_patterns[1], + extra_check=_other_is_broadcasted_in_dim, + )(div_softmax_pattern) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py new file mode 100644 index 
0000000000000000000000000000000000000000..1437d4eb612d531e2262b74a0a0ed8f2b5f4291c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -0,0 +1,1079 @@ +# mypy: allow-untyped-defs +import logging +import operator +from collections import defaultdict +from dataclasses import dataclass, field +from math import prod +from typing import Any, cast, Optional + +import torch +from torch.utils._ordered_set import OrderedSet + +from .. import config, inductor_prims +from ..pattern_matcher import ( + CallFunction, + Ignored, + KeywordArg, + ListOf, + Match, + MULTIPLE, + PatternExpr, + PatternMatcherPass, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +patterns = PatternMatcherPass() + + +def _is_backward(graph: torch.fx.Graph) -> bool: + placeholders = [] + for node in graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return not all(node.name.startswith("primal") for node in placeholders) + + +def _compute_mm_arithmetic_intensity(M: int, N: int, K: int) -> float: + return M * N * K / (M * K + N * K + M * N) + + +def _filter_nodes_by_target(nodes: list[torch.fx.Node], target) -> list[torch.fx.Node]: + return [x for x in nodes if x.target == target] + + +def _find_ancestors(node: torch.fx.Node) -> OrderedSet[torch.fx.Node]: + ancestors = OrderedSet[torch.fx.Node]() + ancestors.add(node) + cur_nodes = [node] + while len(cur_nodes) > 0: + new_nodes = [] + for node in cur_nodes: + for inp in node.all_input_nodes: + if inp not in ancestors: + ancestors.add(inp) + new_nodes.append(inp) + cur_nodes = new_nodes + return OrderedSet(node for node in ancestors if node.op != "placeholder") + + +def _get_tensor(node: torch.fx.Node) -> torch.Tensor: + val = node.meta["val"] + assert isinstance(val, torch.Tensor) + return val + + +@dataclass +class _AllGatherMatch: + match: Match + shard_node: torch.fx.Node + ag_node: torch.fx.Node + res_node: torch.fx.Node + gather_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + self.res_node.replace_all_uses_with(new_node) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_all_gather_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def make_zero_dim_all_gather_pattern(shard): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.all_gather_into_tensor.default, + shard, + Ignored(), + KeywordArg("group_name"), + ), + ) + + # Matches funcol.all_gather_tensor with gather_dim == 0 + zero_dim_all_gather_pattern = make_zero_dim_all_gather_pattern(KeywordArg("shard")) + + def make_all_gather_split_pattern(shard): + return CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + make_zero_dim_all_gather_pattern(shard), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ) + + def make_cat_pattern(splits): + return CallFunction( + aten.cat.default, + ListOf(splits), + KeywordArg("gather_dim"), + ) + + # Matches funcol.all_gather_tensor with gather_dim > 0 + non_zero_dim_all_gather_pattern = make_cat_pattern( + make_all_gather_split_pattern(KeywordArg("shard")), + ) + + # Match a zero-dim all-gather in which the data is transferred as uint8 and + # viewed back as the original dtype. 
+ zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_zero_dim_all_gather_pattern( + KeywordArg("shard"), + ), + Ignored(), + ) + + # Match a non-zero dim all-gather in which the data is transferred as uint8 + # and viewed back as the original dtype. + non_zero_dim_type_erased_all_gather_pattern = CallFunction( + aten.view.dtype, + make_cat_pattern( + CallFunction( + aten.view.dtype, + make_all_gather_split_pattern( + KeywordArg("shard"), + ), + Ignored(), + ), + ), + Ignored(), + ) + + # If two patterns with the same res_node_target have the same suffix, the + # longer pattern should appear first in the list. + # e.g. supposed we have (1) A -> B -> C -> D and (2) B -> C -> D, (1) + # should appear before (2) in the list. + res_node_target_to_patterns = { + aten.cat.default: [ + (non_zero_dim_all_gather_pattern, 0), + ], + aten.view.dtype: [ + (non_zero_dim_type_erased_all_gather_pattern, 0), + (zero_dim_type_erased_all_gather_pattern, 0), + ], + c10d.wait_tensor.default: [ + (zero_dim_all_gather_pattern, 0), + ], + } + + # Match in reverse to ensure longer patterns is prioritized + all_gathers = [] + visited_ag_nodes = OrderedSet[torch.fx.Node]() + for node in reversed(graph.nodes): + for target, patterns in res_node_target_to_patterns.items(): + if node.target != target: + continue + for pattern, ag_node_idx in patterns: + match = pattern.match(node) + if not match: + continue + + assert isinstance(match, Match) + ag_node = match.nodes[ag_node_idx] + assert ag_node.target == c10d.all_gather_into_tensor.default + + if ag_node in visited_ag_nodes: + continue + visited_ag_nodes.add(ag_node) + + ag_match = _AllGatherMatch( + match=match, + shard_node=match.kwargs["shard"], + ag_node=ag_node, + res_node=node, + gather_dim=match.kwargs.get("gather_dim", 0), + group_name=match.kwargs["group_name"], + ) + all_gathers.append(ag_match) + + return list(reversed(all_gathers)) + + +@dataclass +class _ReduceScatterMatch: + match: Match + input_node: torch.fx.Node + reduce_scatter_node: torch.fx.Node + wait_tensor_node: torch.fx.Node + reduce_op: str + scatter_dim: int + group_name: str + + def replace_with(self, new_node: torch.fx.Node) -> None: + # Replace all uses of the result node (wait_tensor) with the fused node. + self.wait_tensor_node.replace_all_uses_with(new_node) + + # If the reduce-scatter result is saved for backward, save the fused node for backward instead. + self._update_save_for_backward(new_node) + + def _update_save_for_backward(self, new_node: torch.fx.Node) -> None: + """ + If the output node is a user of the reduce_scatter node (indicating the reduce_scatter + result is saved for backward), this method will update the output node to use the fused node instead. + """ + output_node = None + for user in self.reduce_scatter_node.users: + if user.target == "output": + output_node = user + break + if output_node is not None: + output_node.replace_input_with(self.reduce_scatter_node, new_node) + + # Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not + # saved for backward anymore. 
+ assert len(self.reduce_scatter_node.users) == 1, ( + "Reduce scatter node has multiple users, this is not expected" + ) + + def erase(self) -> None: + for node in reversed(self.match.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + +def find_reduce_scatter_patterns(graph: torch.fx.Graph): + c10d = torch.ops._c10d_functional + + def reduce_scatter_template(inp: PatternExpr, users: int): + return CallFunction( + c10d.wait_tensor.default, + CallFunction( + c10d.reduce_scatter_tensor.default, + inp, + KeywordArg("reduce_op"), + Ignored(), + KeywordArg("group_name"), + _users=users, + ), + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim == 0 + zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + KeywordArg("input"), users=1 + ) + + # Two users will occur when the reduce-scatter result is saved for backward + zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + KeywordArg("input"), users=2 + ) + + # Matches funcol.reduce_scatter_tensor with scatter_dim > 0 + non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=1, + ) + + # Two users will occur when the reduce-scatter result is saved for backward + non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template( + CallFunction( + aten.cat.default, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split.Tensor, + KeywordArg("input"), + Ignored(), + KeywordArg("scatter_dim"), + _users=MULTIPLE, + ), + Ignored(), + ) + ), + ), + users=2, + ) + + reduce_scatters = [] + for node in reversed(graph.nodes): + if node.target == c10d.wait_tensor.default: + if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_single_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[-2], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=match.kwargs["scatter_dim"], + group_name=match.kwargs["group_name"], + ) + ) + elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node): + assert isinstance(match, Match) + reduce_scatters.append( + _ReduceScatterMatch( + match=match, + input_node=match.kwargs["input"], + reduce_scatter_node=match.nodes[0], + wait_tensor_node=node, + reduce_op=match.kwargs["reduce_op"], + scatter_dim=0, + group_name=match.kwargs["group_name"], + ) + ) + return list(reversed(reduce_scatters)) + + +@dataclass +class _Matmul: + nodes: list[torch.fx.Node] + 
arg_ancestor_nodes: OrderedSet[torch.fx.Node] = field(init=False) + A_node: torch.fx.Node + B_node: torch.fx.Node + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + + def __post_init__(self): + assert len(self.nodes) in (1, 3) + if len(self.nodes) == 1: + assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) + else: + assert self.nodes[0].target == aten.reshape.default + assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) + assert self.nodes[2].target == aten.reshape.default + self.arg_ancestor_nodes = _find_ancestors(self.B_node) + + def replace_with(self, new_node: torch.fx.Node) -> None: + """ + Replace the matmul with the new node. + """ + graph = new_node.graph + + # For 2D-matmuls, we simply replace the mm node with `new_node`. + if len(self.nodes) == 1: + mm_node = self.nodes[0] + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + mm_node.replace_all_uses_with(new_node) + graph.erase_node(mm_node) + return + + # An ND-matmul is reshape -> mm -> reshape sequence. We first replace + # the second reshape node with `new_node`. Then, we ensure that the + # original mm node in the sequence ends up with zero users by replacing + # it with a reverse reshape of `new_node`. + graph = new_node.graph + assert len(self.nodes) == 3 + mm_node = self.nodes[1] + output_reshape_node = self.nodes[2] + + assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) + assert output_reshape_node.target == aten.reshape.default + + output_reshape_node.replace_all_uses_with(new_node) + if len(mm_node.users) > 1: + with graph.inserting_after(new_node): + new_mm_node = graph.call_function( + aten.reshape.default, + args=(new_node, list(_get_tensor(mm_node).shape)), + ) + mm_node.replace_all_uses_with(new_mm_node) + + def erase(self) -> None: + for node in reversed(self.nodes): + if len(node.users) == 0: + node.graph.erase_node(node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_Matmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten.mm.default, + aten.reshape.default, + ) + mm_node = match[0] if len(match) == 1 else match[1] + return _Matmul( + nodes=match, + A_node=cast("torch.fx.Node", match[0].args[0]), + B_node=cast("torch.fx.Node", mm_node.args[1]), + # _Matmul handles reshapes via custom graph manipulation logic, see `replace_with()` method. + # TODO: explore unifying the _Matmul and _ScaledMatmul approaches to handling reshapes. + pre_mm_reshape=None, + post_mm_reshape=None, + ) + + +@dataclass +class _ScaledMatmul(_Matmul): + A_scale_node: torch.fx.Node + B_scale_node: torch.fx.Node + bias_node: Optional[torch.fx.Node] + result_scale_node: Optional[torch.fx.Node] + out_dtype: Optional[torch.dtype] + use_fast_accum: bool + pre_mm_reshape: Optional[torch.fx.Node] + post_mm_reshape: Optional[torch.fx.Node] + + def __post_init__(self): + super().__post_init__() + self.arg_ancestor_nodes |= _find_ancestors(self.A_scale_node) + self.arg_ancestor_nodes |= _find_ancestors(self.B_scale_node) + + @classmethod + def from_match(cls, match: list[torch.fx.Node]) -> "_ScaledMatmul": + assert len(match) in (1, 3) + assert match[0].target in ( + aten._scaled_mm.default, + aten.reshape.default, + ) + + def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: + if idx >= len(node.args): + return default + return node.args[idx] + + # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern. 
+ # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to + # produce the correct output shapes, reduce-scatter along the correct dimensions, etc. + is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default + mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0] + pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None + post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None + A_node = cast("torch.fx.Node", mm_node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + A_scale_node = cast("torch.fx.Node", mm_node.args[2]) + B_scale_node = cast("torch.fx.Node", mm_node.args[3]) + + return _ScaledMatmul( + nodes=match, + A_node=A_node, + B_node=B_node, + A_scale_node=A_scale_node, + B_scale_node=B_scale_node, + bias_node=get_arg(mm_node, 4, None), + result_scale_node=get_arg(mm_node, 5, None), + out_dtype=get_arg(mm_node, 6, None), + use_fast_accum=get_arg(mm_node, 7, False), + pre_mm_reshape=pre_mm_reshape, + post_mm_reshape=post_mm_reshape, + ) + + +def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]: + if node.target != aten.reshape.default: + return [] + + matches = [] + for mm_node in node.users: + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + continue + for reshape_node in mm_node.users: + if reshape_node.target != aten.reshape.default: + continue + + # Since the reshape -> mm -> reshape pattern would be subsumed into + # the fused op, we only match the patterns where the shape of the + # second reshape is matches the mm result produced by the fused op. + matmul_input_node = cast("torch.fx.Node", node.args[0]) + B_node = cast("torch.fx.Node", mm_node.args[1]) + matmul_out_shape = torch.Size( + [ + *_get_tensor(matmul_input_node).shape[:-1], + _get_tensor(B_node).shape[-1], + ] + ) + if _get_tensor(reshape_node).shape != matmul_out_shape: + continue + matches.append([node, mm_node, reshape_node]) + # If for some rare reason mm_node is being reshaped by two + # different reshape nodes, we only include mm_node once in the + # parsing result. + break + + matmuls = [] + for match in matches: + mm_node = match[1] + if mm_node.target == aten.mm.default: + matmul = _Matmul.from_match(match) + matmuls.append(matmul) + elif mm_node.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match(match) + matmuls.append(matmul) + else: + raise AssertionError( + "Expect the node's target to be either aten.mm.default or " + f"aten._scaled_mm.default. Got {mm_node.target}." + ) + return matmuls + + +def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]: + """ + Find the matmuls that use `node` as the lhs argument. 
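+ Both plain 2D matmuls and reshape -> mm -> reshape (ND) sequences are
+ recognized, for aten.mm as well as aten._scaled_mm.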
+ """ + matmuls = [] + for user in node.users: + # ND matmuls + if user.target == aten.reshape.default: + matmuls.extend(_find_reshape_mm_reshape(user)) + # 2D matmuls + elif user.target == aten.mm.default: + matmul = _Matmul.from_match(match=[user]) + matmuls.append(matmul) + elif user.target == aten._scaled_mm.default: + matmul = _ScaledMatmul.from_match([user]) + matmuls.append(matmul) + return matmuls + + +def _insert_fused_all_gather_matmul( + graph: torch.fx.Graph, + matmuls: list[_Matmul], + shard_node: torch.fx.Node, + gather_dim: int, + group_name: str, +) -> torch.fx.Node: + mm_types = OrderedSet(map(type, matmuls)) + assert len(mm_types) == 1 + mm_type = next(iter(mm_types)) + if mm_type == _Matmul: + B_nodes = [matmul.B_node for matmul in matmuls] + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_matmul.default, + args=(shard_node, B_nodes, gather_dim, group_name), + kwargs={"return_A": True}, + ) + elif mm_type == _ScaledMatmul: + scaled_matmuls = cast("list[_ScaledMatmul]", matmuls) + return graph.call_function( + torch.ops.symm_mem.fused_all_gather_scaled_matmul.default, + args=( + shard_node, + [matmul.B_node for matmul in scaled_matmuls], + scaled_matmuls[0].A_scale_node, + [matmul.B_scale_node for matmul in scaled_matmuls], + gather_dim, + group_name, + [matmul.bias_node for matmul in scaled_matmuls], + [matmul.result_scale_node for matmul in scaled_matmuls], + [matmul.out_dtype for matmul in scaled_matmuls], + [matmul.use_fast_accum for matmul in scaled_matmuls], + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: {mm_type}") + + +def fuse_all_gather_matmul(all_gather: _AllGatherMatch) -> None: + """ + Fused the pattern + + A = all_gather_tensor(A_shard, gather_dim, group_name) + C_0 = torch.matmul(A, B_0) + C_1 = torch.matmul(A, B_1) + C_2 = torch.matmul(A, B_2) + ... 
+ + into + + A, Cs = torch.ops.symm_mem.fused_all_gather_matmul( + A_shard, [B_0, B_1, B_2, ...], gather_dim, group_name, + ) + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_shard_for_fused_all_gather_matmul, + ) + + shard_node, ag_node, ag_res_node, gather_dim, group_name = ( + all_gather.shard_node, + all_gather.ag_node, + all_gather.res_node, + all_gather.gather_dim, + all_gather.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + if gather_dim >= len(_get_tensor(shard_node).shape) - 1: + # Decomposing the matmul on the K dimension is not supported + return + + # Find consumer matmuls + matmuls = _find_consumer_matmuls(ag_res_node) + + # The matmuls are only fusible if non-A args don't depend on the all-gather + # result node + matmuls = [ + matmul + for matmul in matmuls + if all_gather.res_node not in matmul.arg_ancestor_nodes + ] + + if len(matmuls) == 0 or len(OrderedSet(map(type, matmuls))) != 1: + return + + # Fuse the all_gather_tensor with the eligible matmuls + graph = ag_node.graph + with graph.inserting_before(ag_node): + if "val" in shard_node.meta: + restrided = restride_A_shard_for_fused_all_gather_matmul( + _get_tensor(shard_node), + gather_dim, + ) + shard_node = graph.call_function( + inductor_prims.force_stride_order, + args=(shard_node, restrided.stride()), + ) + + fused_node = _insert_fused_all_gather_matmul( + graph, matmuls, shard_node, gather_dim, group_name + ) + new_ag_node = graph.call_function( + operator.getitem, + args=(fused_node, 0), + ) + new_out_nodes = graph.call_function( + operator.getitem, + args=(fused_node, 1), + ) + for idx, matmul in enumerate(matmuls): + new_out_node = graph.call_function( + operator.getitem, + args=(new_out_nodes, idx), + ) + matmul.replace_with(new_out_node) + matmul.erase() + all_gather.replace_with(new_ag_node) + all_gather.erase() + + # If the new_ag_node has no users, we tell the fused op to not return + # it. This creates more optimization opportunities. + if len(new_ag_node.users) == 0: + graph.erase_node(new_ag_node) + kwargs = dict(fused_node.kwargs) + if "return_A" in kwargs: + kwargs["return_A"] = False + fused_node.kwargs = kwargs + + # Raise ancestors of non-A args that are topologically ordered between + # ag_res_node and the matmul above fused_node. + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + OrderedSet(x for matmul in matmuls for x in matmul.arg_ancestor_nodes), + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + +def _scatter_dim_after_reshape( + reshape_node: torch.fx.Node, orig_scatter_dim: int +) -> int: + """ + Given a reshape node and the original scatter dim for the target tensor, + returns the new scatter dim for the reshaped tensor. + """ + # if there was no pre-mm reshape, scatter dim will not change. 
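+ # Otherwise, e.g. reshaping a hypothetical (a, b, c) tensor to (a*b, c)
+ # collapses the leading dims: an original scatter dim of 1 maps to 0, while
+ # the last dim (2) maps to 1.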
+ if not reshape_node: + return orig_scatter_dim + + reshape_op_output_tensor = _get_tensor(reshape_node) + assert reshape_op_output_tensor.ndim == 2, ( + "reshape must produce 2D tensor for scaled_mm" + ) + + assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg" + input_tensor_node = cast(torch.fx.Node, reshape_node.args[0]) + reshape_op_input_tensor = _get_tensor(input_tensor_node) + assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, ( + "reshape must be from 3D+ to 2D" + ) + + # Note: for a N-D tensor to be reshaped into 2D, either the leading dims or ending dims must + # be collapsed to a single dim. First determine which of these happened. + input_shape = reshape_op_input_tensor.shape + output_shape = reshape_op_output_tensor.shape + leading_dims_collapsed = output_shape[0] == prod(input_shape[:-1]) + + # Case 1: scatter dim 0 always maps to 0 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == 0: + return 0 + + # Case 2: scatter dim "ndim-1" always maps to 1 after any reshape from 3D+ to 2D, regardless if + # leading dims or ending dims were collapsed. + if orig_scatter_dim == reshape_op_input_tensor.ndim - 1: + return 1 + + # Case 3: scatter dim was one of the middle dims (between 0 and ndim-1). + # if the leading dims were collapsed, the new scatter dim will be 0. + # if the ending dims were collapsed, the new scatter dim will be 1. + return 0 if leading_dims_collapsed else 1 + + +def _find_producer_matmul(node: torch.fx.Node) -> Optional[_Matmul]: + """ + Returns producer matmul node if found, otherwise returns None. + """ + if node.target == aten.mm.default: + return _Matmul.from_match(match=[node]) + elif node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match(match=[node]) + elif node.target == aten.reshape.default: + reshape_node_1 = node + + mm_node = reshape_node_1.args[0] + assert isinstance(mm_node, torch.fx.Node) + if mm_node.target not in (aten.mm.default, aten._scaled_mm.default): + return None + + reshape_node_0 = mm_node.args[0] + assert isinstance(reshape_node_0, torch.fx.Node) + if reshape_node_0.target != aten.reshape.default: + return None + + if mm_node.target == aten.mm.default: + return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) + elif mm_node.target == aten._scaled_mm.default: + return _ScaledMatmul.from_match( + match=[reshape_node_0, mm_node, reshape_node_1] + ) + return None + + +def _insert_fused_matmul_reduce_scatter( + graph: torch.fx.Graph, + matmul: _Matmul, + reduce_op: str, + orig_scatter_dim: int, + group_name: str, + scatter_dim_after_reshape: int, # only used for reshape -> scaled_mm -> reshape pattern + output_shape: list[int], # only used for reshape -> scaled_mm -> reshape pattern +) -> torch.fx.Node: + if type(matmul) == _Matmul: + return graph.call_function( + torch.ops.symm_mem.fused_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + reduce_op, + orig_scatter_dim, + group_name, + ), + ) + elif type(matmul) == _ScaledMatmul: + return graph.call_function( + torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default, + args=( + matmul.A_node, + matmul.B_node, + matmul.A_scale_node, + matmul.B_scale_node, + reduce_op, + orig_scatter_dim, + scatter_dim_after_reshape, + group_name, + output_shape, + matmul.bias_node, + matmul.result_scale_node, + matmul.out_dtype, + matmul.use_fast_accum, + ), + ) + else: + raise AssertionError(f"Unexpected matmul match type: 
{type(matmul)}") + + +def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None: + """ + Fused the pattern + + reduce_scatter_tensor(A @ B, scatter_dim, group_name) + + into + + torch.ops.symm_mem.fused_matmul_reduce_scatter( + A, B, scatter_dim, group_name, + ) + + Returns boolean indicating if fusion was successful or not. + """ + if ( + not torch.distributed.is_available() + or not torch.distributed.is_nccl_available() + ): + return + + from torch.distributed._symmetric_memory import ( + is_symm_mem_enabled_for_group, + restride_A_for_fused_matmul_reduce_scatter, + ) + + ( + input_node, + _reduce_scatter_node, + rs_wait_tensor_node, + reduce_op, + orig_scatter_dim, + group_name, + ) = ( + reduce_scatter.input_node, + reduce_scatter.reduce_scatter_node, + reduce_scatter.wait_tensor_node, + reduce_scatter.reduce_op, + reduce_scatter.scatter_dim, + reduce_scatter.group_name, + ) + + if not is_symm_mem_enabled_for_group(group_name): + return + + # Currently fused_matmul_reduce_scatter doesn't return the matmul result, + # so we can't apply the fusion if the matmul result is used by multiple + # users. This is not a fundamental limitation of the fused op and can be + # addressed if needed. + if len(input_node.users) != 1: + log.warning( + "matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion." + ) + return + + matmul = _find_producer_matmul(input_node) + if matmul is None: + log.warning( + "no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + if rs_wait_tensor_node in matmul.arg_ancestor_nodes: + log.warning( + "reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion" + ) + return + + # We need to track 3 values for the fused scaled mm reduce scatter implementation: + # 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims. + # 2. The scatter dim after the reshape, to use when we are doing the 2D (a*b,c) @ (c,d) = (a,b,d) scaled mm op. + # 3. Store expected potentially 3D+ mm output shape, so we can reshape the 2D mm output to the intended + # 3D+ shape before applying reduce-scatter, and to prevent shape errors with subsequent ops. + + # If 'A' was reshaped from 3D+ -> 2D for the mm, we need to determine the new scattter dim after the reshape + # for the fused matmul reduce scatter implementation to use. + if matmul.pre_mm_reshape: + scatter_dim_after_maybe_reshape = _scatter_dim_after_reshape( + matmul.pre_mm_reshape, orig_scatter_dim + ) + else: + scatter_dim_after_maybe_reshape = orig_scatter_dim + + # If the 2D mm output was reshaped from 2D -> 3D+, we need to store the intended output shape for the + # fused matmul reduce scatter implementation to use. 
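+    # Worked example (illustrative): for A of shape (a, b, c) reshaped to
+    # (a*b, c) before the mm, B of shape (c, d), and an original scatter dim
+    # of 1, the three tracked values are orig_scatter_dim=1,
+    # scatter_dim_after_maybe_reshape=0, and output_shape=[a, b, d].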
+ if matmul.post_mm_reshape: + output_shape = list(_get_tensor(matmul.post_mm_reshape).shape) + else: + A_orig_shape = list(_get_tensor(matmul.A_node).shape) + B_shape = list(_get_tensor(matmul.B_node).shape) + output_shape = [*A_orig_shape[:-1], B_shape[-1]] + + graph = rs_wait_tensor_node.graph + with graph.inserting_before(rs_wait_tensor_node): + # Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter + if "val" in matmul.A_node.meta: + restrided = restride_A_for_fused_matmul_reduce_scatter( + _get_tensor(matmul.A_node), + scatter_dim_after_maybe_reshape, + ) + matmul.A_node = graph.call_function( + inductor_prims.force_stride_order, + args=(matmul.A_node, restrided.stride()), + ) + + # Replace matched subgraph with fused matmul reduce scatter node + fused_node = _insert_fused_matmul_reduce_scatter( + graph, + matmul, + reduce_op, + orig_scatter_dim, + group_name, + scatter_dim_after_maybe_reshape, + output_shape, + ) + reduce_scatter.replace_with(fused_node) + reduce_scatter.erase() + matmul.erase() + + order = {node: idx for idx, node in enumerate(graph.nodes)} + nodes_to_raise = sorted( + matmul.arg_ancestor_nodes, + key=lambda x: order[x], + ) + for node in nodes_to_raise: + if order[node] > order[fused_node]: + fused_node.prepend(node) + + log.debug("successfully fused matmul reduce scatter") + + +def _get_node_to_ancestors( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: + """ + Compute the ancestors for all nodes in a graph. + """ + node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated] + for node in graph.nodes: + node_to_ancestors[node] = OrderedSet(node.all_input_nodes) + for dep in node.all_input_nodes: + node_to_ancestors[node] |= node_to_ancestors[dep] + + return node_to_ancestors + + +def _get_collective_to_overlappable_nodes( + graph: torch.fx.Graph, +) -> dict[torch.fx.Node, list[torch.fx.Node]]: + """ + For each collective in the graph, find nodes that are neither ancestors nor + descendants of the collective. + """ + + def is_collective(node) -> bool: + # Only consider all-gather and reduce-scatter in the context of + # micro-pipeline TP. + return node.target in [ + torch.ops._c10d_functional.all_gather_into_tensor.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + ] + + node_to_ancestors = _get_node_to_ancestors(graph) + collective_to_overlappable_nodes = defaultdict(list) + for node in graph.nodes: + if not is_collective(node): + continue + for x in graph.nodes: + if ( + node not in node_to_ancestors[x] + and x not in node_to_ancestors[node] + and x.op == "call_function" + ): + collective_to_overlappable_nodes[node].append(x) + + return collective_to_overlappable_nodes + + +def _get_unexposed_collectives(graph: torch.fx.Graph) -> list[torch.fx.Node]: + """ + Find all unexposed collectives in the graph. + + Because we don't have the runtime estimate, this function is a rough + estimation using the following strong/hand-wavy assumptions: + + - Only a predefined set of "compute intensive" operation can hide a collective. + - Any "compute intensive" operation can hide exactly one collective. 
+ """ + + def _is_compute_intensive(node: torch.fx.Node) -> bool: + return node.target in [torch.ops.aten.mm.default] + + collective_to_overlapping_candidates = defaultdict(list) + available_nodes = OrderedSet[torch.fx.Node]() + collective_to_overlappable_nodes = _get_collective_to_overlappable_nodes(graph) + for collective, overlappable_nodes in collective_to_overlappable_nodes.items(): + candidates = [x for x in overlappable_nodes if _is_compute_intensive(x)] + collective_to_overlapping_candidates[collective] = candidates + available_nodes.update(candidates) + + unexposed_collectives = [] + for ( + collective, + overlapping_candidates, + ) in collective_to_overlapping_candidates.items(): + # Each collective consumes exactly one overlapping candidate + for x in overlapping_candidates: + if x in available_nodes: + unexposed_collectives.append(collective) + available_nodes.remove(x) + break + return unexposed_collectives + + +def micro_pipeline_tp_pass(graph: torch.fx.Graph): + all_gathers = find_all_gather_patterns(graph) + reduce_scatters = find_reduce_scatter_patterns(graph) + + # When a collective can be hidden through either simple overlapping or + # micro-pipeline TP, we prefer simple overlapping to avoid the overhead + # associated with decomposition. If reorder_for_compute_comm_overlap is + # enabled, we identify collectives that can be hidden through simple + # overlapping and exclude them from micro-pipeline TP candidates. + if config.reorder_for_compute_comm_overlap: + unexposed_collectives = _get_unexposed_collectives(graph) + all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives] + reduce_scatters = [ + x + for x in reduce_scatters + if x.reduce_scatter_node not in unexposed_collectives + ] + + if not all_gathers and not reduce_scatters: + log.warning( + "async TP found no matching all-gather/reduce-scatter patterns for fusion" + ) + + for all_gather in all_gathers: + fuse_all_gather_matmul(all_gather) + + for reduce_scatter in reduce_scatters: + fuse_matmul_reduce_scatter(reduce_scatter) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5a0e56b3d03e1e07457eeb398dc655bad5f5f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/misc_patterns.py @@ -0,0 +1,131 @@ +# mypy: allow-untyped-defs +import functools + +import torch +from torch._dynamo.utils import counters +from torch._ops import OpOverload, OpOverloadPacket +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import fwd_only, register_replacement + + +aten = torch.ops.aten + + +@functools.cache +def _misc_patterns_init(): + from .joint_graph import patterns as joint_graph_patterns + from .post_grad import pass_patterns as post_grad_patterns_all + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # These patterns do 2 things + # 1. Since we know that index is completely unique, we can codegen it using + # stores instead of atomic adds, which is quite a bit faster. + # 2. 
Also, since we are guaranteed that they are completely within bounds, + # we can use unsafe indexing and skip debug asserts + def randperm_index_add_pattern(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return torch.index_add(x, dim=0, source=y, index=index), index + + def randperm_index_add_replacement(x, y): + index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]] + return ( + torch.ops.aten._unsafe_index_put( + x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False + ), + index, + ) + + register_replacement( + randperm_index_add_pattern, + randperm_index_add_replacement, + [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + ) + + def randperm_index_pattern(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten.index(x, (index,)), index + + def randperm_index_replacement(x, slice_shape): + index = torch.randperm(x.shape[0], device=x.device)[:slice_shape] + return torch.ops.aten._unsafe_index(x, (index,)), index + + register_replacement( + randperm_index_pattern, + randperm_index_replacement, + [torch.empty(4, 8, device=device)], + fwd_only, + [post_grad_patterns, joint_graph_patterns], + scalar_workaround={"slice_shape": 42}, + ) + + +class NumpyCompatNormalization: + numpy_compat: dict[str, tuple[str, ...]] = { + "dim": ("axis",), + "keepdim": ("keepdims",), + "input": ("x", "a", "x1"), + "other": ("x2",), + } + inverse_mapping: dict[str, str] + cache: dict["torch.fx.graph.Target", OrderedSet[str]] + + def __init__(self) -> None: + self.cache = {} # callable -> tuple of replaceable args e.g. ["axis"] + self.inverse_mapping = {} + for actual_kwarg, numpy_kwargs in self.numpy_compat.items(): + for numpy_kwarg in numpy_kwargs: + assert numpy_kwarg not in self.inverse_mapping + self.inverse_mapping[numpy_kwarg] = actual_kwarg + + def __call__(self, graph: torch.fx.Graph): + for node in graph.nodes: + if node.op != "call_function": + continue + if isinstance(node.target, (OpOverload, OpOverloadPacket)): + # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't. 
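+                # For instance, a call_function node for torch.sum carrying
+                # kwargs {"axis": 1, "keepdims": True} is rewritten below to
+                # {"dim": 1, "keepdim": True}; OpOverload/OpOverloadPacket
+                # targets never accept the numpy spellings, so they are
+                # skipped here.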
+ continue + kwargs = node.kwargs + + if node.target in self.cache: + replaceable_kwargs = self.cache[node.target] + else: + signatures = torch.fx.operator_schemas.get_signature_for_torch_op( + node.target + ) + signatures = () if signatures is None else signatures + replaceable_kwargs = OrderedSet() + for sig in signatures: + for param_name in sig.parameters.keys(): + if param_name in self.numpy_compat: + replaceable_kwargs.update(self.numpy_compat[param_name]) + + self.cache[node.target] = replaceable_kwargs + + if not replaceable_kwargs: + continue + + new_kwargs = {} + kwargs_changed = False + for k, v in kwargs.items(): + if k in replaceable_kwargs: + kwargs_changed = True + new_kwargs[self.inverse_mapping[k]] = v + else: + new_kwargs[k] = v + + if kwargs_changed: + node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs) + counters["inductor"]["numpy_compat_normalization"] += 1 + + +numpy_compat_normalization = NumpyCompatNormalization() diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..4c9f10b4e2b2c795f61935664f87cf1f54e41cf9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -0,0 +1,1527 @@ +# mypy: allow-untyped-defs +import functools +import operator +from functools import reduce +from typing import Any, Callable + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.utils._ordered_set import OrderedSet + +from .. import ir +from ..lowering import lowerings as L +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + get_arg_value, + KeywordArg, + MULTIPLE, +) +from ..utils import ( + is_mkldnn_bf16_supported, + is_mkldnn_fp16_supported, + SUPPORTED_MKLDNN_DEVICES, +) +from ..virtualized import ops, V +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern +from .quantization import ( + _register_int8_woq_concat_linear_pattern, + _register_quantization_lowerings, + _register_quantization_weight_pack_pass, + _register_woq_lowerings, +) + + +if torch._C._has_mkldnn: + aten = torch.ops.aten + mkldnn = torch.ops.mkldnn + prims = torch.ops.prims + + _conv_args = [Arg() for _ in range(10)] + _linear_args = [Arg() for _ in range(6)] + _conv_transpose_args = [Arg() for _ in range(11)] + + class MkldnnDeviceOpBase: + def get_linear_transpose_weight(self, weight_node): + raise NotImplementedError + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + raise NotImplementedError + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + raise NotImplementedError + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + raise NotImplementedError + + class CpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def get_linear_transpose_weight(self, weight_node): + packed_weight_node = weight_node + assert packed_weight_node.target == mkldnn._reorder_linear_weight + transpose_weight_node = packed_weight_node.args[0] + assert transpose_weight_node.target == aten.permute.default + return transpose_weight_node + + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + packed_weight_op = mkldnn._reorder_convolution_weight + if is_transposed: + 
packed_weight_op = mkldnn._reorder_convolution_transpose_weight + + # mkldnn_reorder_conv_weight(self, padding, stride, dilation, groups, input_size) + packed_weight_inputs = (weight,) + tuple(constant_args) + (input_size,) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear_weight( + self, graph, is_lp_weight, transpose_weight_node, batch_size + ): + # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. + packed_weight_inputs = ( + transpose_weight_node, + batch_size.node.shape_env.size_hint(batch_size.node.expr) + if has_free_symbols(batch_size) + else batch_size, + ) + + # MKL packed matrix can't be copied to a different address because the internal implementation + # depends on the alignment of internally-stored metadata. + # In aot mode, we need to firstly save the packed weight, when loading it, + # it will be in a different address which doesn't work. + # Disable MKL prepack linear in AOT mode + packed_weight_op = ( + mkldnn._reorder_linear_weight + if ( + is_lp_weight + or mkldnn._is_mkldnn_acl_supported() + or V.aot_compilation + ) + else torch.ops.mkl._mkl_reorder_linear_weight + ) + return graph.create_node( + "call_function", packed_weight_op, args=packed_weight_inputs + ) + + def pack_linear( + self, graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ): + packed_linear_inputs: tuple[Any, ...] = (input, packed_weight_node) + transpose_weight_node = packed_weight_node.args[0] + if is_lp_weight or mkldnn._is_mkldnn_acl_supported() or V.aot_compilation: + packed_linear_inputs += (bias, "none", [], "") + packed_linear_op: Callable[..., Any] = mkldnn._linear_pointwise.default + else: + packed_linear_inputs += (transpose_weight_node, bias, batch_size) + packed_linear_op = torch.ops.mkl._mkl_linear + + return graph.create_node( + "call_function", packed_linear_op, packed_linear_inputs + ) + + class XpuMkldnnDeviceOp(MkldnnDeviceOpBase): + def pack_conv_weight( + self, + graph, + is_transposed, + weight, + constant_args, + input_size, + ): + assert not is_transposed, ( + "'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device." + ) + return weight + + def _get_mkldnn_device_op(device_type: str) -> MkldnnDeviceOpBase: + """ + Returns the MKLDNN device operation class based on the current device type. + """ + if device_type == "cpu": + return CpuMkldnnDeviceOp() + elif device_type == "xpu": + return XpuMkldnnDeviceOp() + else: + raise RuntimeError(f"MKLDNN is not supported on {device_type} device.") + + def _is_valid_grouped_gemm_fusion(computation_nodes): + """ + Here we check: + 1. More than 1 GEMM nodes has been found. + 2. All the GEMM nodes share the same activation. + 3. All the GEMM nodes have same weight size but different wgt node. + """ + computation_op = mkldnn._linear_pointwise.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + wgt_size = wgt.meta.get("val").size() # type: ignore[union-attr] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act + and (node.args[1].meta.get("val").size() == wgt_size) + and (node.args[1] != wgt or gemm_idx == 0) + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + def grouped_gemm_pass(graph: torch.fx.Graph): + """ + Group GEMM has multi output nodes which is complicated to define a Pattern. + Use below way to connect the pattern to the lowering. 
+ TODO: Use MultiOutputPattern, current limitation is the pattern requires + fixed number of output nodes. Extend to support Group GEMM for pattern matcher. + """ + computation_op = mkldnn._linear_pointwise.default + from ..mkldnn_lowerings import grouped_gemm_lowering + + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_grouped_gemm_fusion(users): + with graph.inserting_before(node): + grouped_gemm_node = graph.create_node( + "call_function", + grouped_gemm_lowering, + ( + act, + [user.args[1] for user in users], + [user.args[2] for user in users], + ), + ) + grouped_gemm_node.meta["val"] = [ + user.meta["val"] for user in users + ] + with graph.inserting_after(grouped_gemm_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + grouped_gemm_node, + gemm_idx, + ), + ) + user.replace_all_uses_with(get_item) + graph.erase_node(user) + return + + def _conv_call(users=1): + return CallFunction( + mkldnn._convolution_pointwise.default, *_conv_args, _users=users + ) + + def _linear_call(users=1): + return CallFunction( + mkldnn._linear_pointwise.default, *_linear_args, _users=users + ) + + def _conv_transpose_call(users=1): + return CallFunction( + mkldnn._convolution_transpose_pointwise.default, + *_conv_transpose_args, + _users=users, + ) + + def _to_float(input_call, users=1): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_float"), + _users=users, + ) + + def _to_bf16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_bf16"), + _users=1, + ) + + def _to_fp16(input_call): + return CallFunction( + prims.convert_element_type.default, + input_call, + KeywordArg("to_fp16"), + _users=1, + ) + + def _unary_fusion_pattern(unary_fusion, call_fn, users, lowp_dtype): + # only insert to_dtype if lowp_dtype is True + computation_call = ( + _to_float(call_fn(), users=users) if lowp_dtype else call_fn(users=users) + ) + out = unary_fusion(computation_call) + if lowp_dtype == torch.bfloat16: + return _to_bf16(out) + elif lowp_dtype == torch.float16: + return _to_fp16(out) + else: + return out + + def _gelu_fusion_1(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.erf, + CallFunction(aten.mul, computation_call, 0.7071067811865476), + ), + 1, + ), + ) + + def _gelu_fusion_2(computation_call): + return CallFunction( + aten.mul, + CallFunction(aten.mul, computation_call, 0.5), + CallFunction( + aten.add, + CallFunction( + aten.tanh, + CallFunction( + aten.mul, + CallFunction( + aten.add, + computation_call, + CallFunction( + aten.mul, + CallFunction( + aten.mul, + CallFunction( + aten.mul, computation_call, computation_call + ), + computation_call, + ), + 0.044715, + ), + ), + 0.7978845608028654, + ), + ), + 1, + ), + ) + + def _hardswish_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.mul, + computation_call, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + ), + 6, + ) + + def _silu_fusion(computation_call): + return CallFunction( + aten.mul, computation_call, 
CallFunction(aten.sigmoid, computation_call) + ) + + def _hardsigmoid_fusion(computation_call): + return CallFunction( + aten.div, + CallFunction( + aten.clamp_max, + CallFunction( + aten.clamp_min, CallFunction(aten.add, computation_call, 3), 0 + ), + 6, + ), + 6, + ) + + def _leaky_relu_fusion(computation_call): + return CallFunction( + aten.where, + CallFunction(aten.gt, computation_call, 0), + computation_call, + CallFunction(aten.mul, computation_call, KeywordArg("negative_slope")), + ) + + def _hardtanh_fusion(computation_call): + return CallFunction( + aten.clamp_max, + CallFunction(aten.clamp_min, computation_call, KeywordArg("min_value")), + KeywordArg("max_value"), + ) + + def _combined_fusion(computation_call, elementwise_op): + return CallFunction(elementwise_op, computation_call) + + # binary_op(other, computation_op) + def _binary_fusion_v1(computation_call, binary_fn): + return CallFunction(binary_fn, KeywordArg("other"), computation_call) + + # binary_op(computation_op, other) + def _binary_fusion_v2(computation_call, binary_fn): + return CallFunction(binary_fn, computation_call, KeywordArg("other")) + + def _is_single_computation_op(computation_op, lowp_dtype=None): + def fn(match): + computation_nodes = filter_nodes(match.nodes, computation_op) + + if lowp_dtype: + output_node_meta = match.output_node().meta.get("val") + if output_node_meta.dtype != lowp_dtype: + return False + + if len(computation_nodes) < 1: + return False + if any(n.args[-3] != "none" for n in computation_nodes): + return False + return True + + return fn + + def _is_valid_computation_unary_fusion(computation_op, lowp_dtype=None): + def fn(match): + matched = _is_single_computation_op(computation_op, lowp_dtype)(match) + computation_node = filter_nodes(match.nodes, computation_op)[0] + if lowp_dtype: + conversion_dtype_nodes = filter_nodes( + match.nodes, prims.convert_element_type.default + ) + if len(conversion_dtype_nodes) != 2: + return False + # fusion pattern is always in the form of computation_op + to_float32 + unary_op + to_bfloat16 + if computation_node == conversion_dtype_nodes[0].args[0]: + to_float = conversion_dtype_nodes[0].args[1] + to_lp = conversion_dtype_nodes[1].args[1] + else: + to_float = conversion_dtype_nodes[1].args[1] + to_lp = conversion_dtype_nodes[0].args[1] + matched = matched and to_float == torch.float and to_lp == lowp_dtype + return matched + + return fn + + def _register_unary_fusion_lowering( + pattern, unary_attr, computation_op, lowp_dtype=None + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_unary_fusion(computation_op, lowp_dtype), + ) + def fn(match, *args, **kwargs): + computation_args = list(args)[:-3] + [ + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + return L[computation_op](*computation_args) + + return fn + + def _register_leaky_relu_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + negative_slope = kwargs.get("negative_slope") + if isinstance(negative_slope, ir.TensorBox): + matched = False + else: # inp is a Number + matched = True + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) 
+ matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "leaky_relu", + [negative_slope], + "", + ] + return L[computation_op](*computation_args) + else: + # computation_args += ["none", [], ""] + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.where]( + L[aten.gt](out, 0), + out, + L[aten.mul](out, negative_slope), + ) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + def _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype=None): + @register_lowering_pattern( + pattern, extra_check=_is_single_computation_op(computation_op, lowp_dtype) + ) + def fn(match, *args, **kwargs): + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + if isinstance(min_value, ir.TensorBox) or isinstance( + max_value, ir.TensorBox + ): + matched = False + else: # inp is a Number + assert max_value is not None + matched = min_value <= max_value + if lowp_dtype: + dtype1 = kwargs.get("to_float") + dtype2 = ( + kwargs.get("to_bf16") + if lowp_dtype == torch.bfloat16 + else kwargs.get("to_fp16") + ) + matched = matched and dtype1 == torch.float and dtype2 == lowp_dtype + computation_args = list(args) + counters["inductor"]["mkldnn_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_unary_fusion_matcher_nodes"] += len( + match.nodes + ) + if matched: + computation_args = computation_args[:-3] + [ + "hardtanh", + [min_value, max_value], + "", + ] + return L[computation_op](*computation_args) + else: + out = L[computation_op](*computation_args) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=torch.float) + out = L[aten.clamp_max](L[aten.clamp_min](out, min_value), max_value) + if lowp_dtype: + out = L[prims.convert_element_type.default](out, dtype=dtype2) # type: ignore[possibly-undefined] + return out + + return fn + + _binary_attr = { + aten.add: "add", + ops.add: "add", + aten.sub: "sub", + ops.sub: "sub", + } + + def _is_valid_binary(match, computation_op, binary_op): + binary_nodes = filter_nodes(match.nodes, binary_op) + if len(binary_nodes) < 1: + return False + + def get_meta_value(argument: torch.fx.node.Argument): + # Only torch.fx.Node is expected to have meta. + if isinstance(argument, torch.fx.Node): + return argument.meta.get("val", None) + return None + + if any( + not isinstance(get_meta_value(n.args[0]), torch.Tensor) + or not isinstance(get_meta_value(n.args[1]), torch.Tensor) + for n in binary_nodes + ): + return False + # check alpha is one. + if any( + get_arg_value(n, 2, kwarg_name="alpha") != 1.0 + and get_arg_value(n, 2, kwarg_name="alpha") is not None + for n in binary_nodes + ): + return False + + def _check_input_sizes(n, computation_op): + # Check if the tensor shape of the 'other' node is the same as or + # can be broadcasted to the tensor shape of the computation node. 
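+            # For example, a linear output of shape (B, T, N) can only be fused
+            # with an 'other' of shape (B, T, N) or (1, 1, N), while a conv
+            # output of shape (N, C, H, W) accepts 'other' shapes (N, C, H, W),
+            # (N, C, 1, 1), (1, C, 1, 1) and (1, 1, 1, 1).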
+ computation_node = ( + n.args[0] if n.args[1] is match.kwargs["other"] else n.args[1] + ) + assert computation_node.target == computation_op + computation_node_size = get_meta_value(computation_node).size() + if computation_op is mkldnn._linear_pointwise.default: + broadcast_sizes = [] + if len(computation_node_size) >= 2: + broadcast_sizes = [ + torch.Size( + [1 for _ in range(len(computation_node_size) - 1)] + + [computation_node_size[-1]] + ), + ] + else: + assert len(computation_node_size) > 2 + broadcast_sizes = [ + torch.Size( + [computation_node_size[0], computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size( + [1, computation_node_size[1]] + + [1 for _ in range(len(computation_node_size) - 2)] + ), + torch.Size([1 for _ in range(len(computation_node_size))]), + ] + return ( + get_meta_value(match.kwargs["other"]).size() + in [ + computation_node_size, + ] + + broadcast_sizes + ) + + if any( + not _check_input_sizes(n, computation_op) + or get_meta_value(n.args[0]).device != get_meta_value(n.args[1]).device + or get_meta_value(n.args[0]).dtype != get_meta_value(n.args[1]).dtype + for n in binary_nodes + ): + return False + # check args[0] and args[1] is not same + if any(n.args[0] == n.args[1] for n in binary_nodes): + return False + return True + + def _is_valid_computation_binary(computation_op, binary_op, other_index=None): + def fn(match): + if not _is_single_computation_op(computation_op)(match): + return False + if not _is_valid_binary(match, computation_op, binary_op): + return False + return True + + return fn + + def _get_remaining_users(extra_input_node, compute_node): + # Think about this pattern: + # ReLU + # / \ + # Conv1 + # / \ + # Conv2 + # \ / + # Add + # Although, the extra input node (ReLU) has more than 1 users: Conv1 and Add. + # The Conv1 is the ancestor node of the current compute node (Conv2). + # This indicates that the buffer of ReLU has completed all its usage, + # So we can safely make changes to it now by doing Conv2->Add inplace fusion. + # Take above case as example: + # * extra_input_node: ReLU + # * compute_node: Conv2 + # _get_remaining_users will return the users of extra_input_node which are not + # ancestor node of compute_node. + def _is_ancestor_node(_current_node, _ancestor_node): + # Check whether _ancestor_node is the ancestor node of _current_node + _node_list = [_current_node] + _visited_nodes = OrderedSet[torch.fx.Node]() + while len(_node_list) != 0: + _current_node = _node_list.pop(0) + if _current_node not in _visited_nodes: + _visited_nodes.add(_current_node) + if _current_node == _ancestor_node: + return True + elif isinstance( + _current_node, torch.fx.Node + ) and _current_node.op not in ["placeholder", "output", "get_attr"]: + for input in _current_node.all_input_nodes: + _node_list.append(input) # noqa: PERF402 + return False + + return [ + user + for user in list(extra_input_node.users) + if not _is_ancestor_node(compute_node, user) + ] + + def _is_valid_computation_binary_inplace(computation_op, binary_op, other_index): + def fn(match): + if not _is_valid_computation_binary(computation_op, binary_op)(match): + return False + binary_nodes = filter_nodes(match.nodes, binary_op) + + def _get_compute_node(_binary_node, _other_index): + assert len(_binary_node.all_input_nodes) == 2, ( + "Binary node should have 2 input nodes." 
+ ) + _compute_index = 1 if (_other_index == 0) else 0 + return _binary_node.args[_compute_index] + + def _other_input_not_inplaceable(_binary_node, _other_index): + _compute_node = _get_compute_node(_binary_node, _other_index) + return ( + len( + _get_remaining_users( + _binary_node.args[_other_index], _compute_node + ) + ) + > 1 + or _binary_node.args[_other_index] == _compute_node.args[0] + ) + + if any(_other_input_not_inplaceable(n, other_index) for n in binary_nodes): + return False + if any( + n.args[other_index].op in ["placeholder", "output"] + for n in binary_nodes + ): + return False + return True + + return fn + + def _register_binary_unary_fusion_lowering( + pattern, + computation_op, + binary_op, + fusion_op, + unary_attr=None, + ): + @register_lowering_pattern( + pattern, extra_check=_is_valid_computation_binary(computation_op, binary_op) + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + return L[fusion_op](*computation_args) + + return fn + + def _can_be_inplace(_other): + return not ( + isinstance(_other.data, ir.BaseView) + or len(_other.get_inputs_that_alias_output()) > 0 + ) + + def _register_binary_unary_maybe_inplace_fusion_lowering( + pattern, + computation_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + unary_attr=None, + other_index=None, + ): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_computation_binary_inplace( + computation_op, binary_op, other_index + ), + ) + def fn(match, *args, **kwargs): + other = kwargs.get("other") + assert isinstance(other, ir.TensorBox) + binary_attr = _binary_attr[binary_op] + args_list = list(args) + computation_args = [args_list[0], other] + args_list[1:-3] + [binary_attr] + if len(args_list) > 6: + if unary_attr is not None: + computation_args += [ + 1.0, + unary_attr.op_name, + unary_attr.scalars_attr, + unary_attr.algorithm_attr, + ] + else: + computation_args += [1.0, None, [], None] + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += ( + len(match.nodes) + ) + # Make sure the other is not an alias or mutation(fx side doesn't has such info). 
+ other.realize() + if not _can_be_inplace(other) or other.data.shape != list( + match.nodes[0].meta["val"].size() + ): + return L[outplace_fusion_op](*computation_args) + return L[inplace_fusion_op](*computation_args) + + return fn + + computation_ops = [ + mkldnn._convolution_pointwise.default, + mkldnn._linear_pointwise.default, + mkldnn._convolution_transpose_pointwise.default, + ] + + class UnaryAttr: + def __init__( + self, op_name: str, scalars_attr=None, algorithm_attr=None + ) -> None: + self.op_name = op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + def _register_unary_fusion(): + computation_call_fns = [_conv_call, _linear_call, _conv_transpose_call] + + def _unary_fusion_patterns(lowp_dtype): + replacement_unary_fusion_patterns = { + UnaryAttr("gelu", algorithm_attr="tanh"): [ + _unary_fusion_pattern(_gelu_fusion_2, call_fn, 4, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("gelu", algorithm_attr="none"): [ + _unary_fusion_pattern(_gelu_fusion_1, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardswish"): [ + _unary_fusion_pattern(_hardswish_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("hardsigmoid"): [ + _unary_fusion_pattern(_hardsigmoid_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ], + UnaryAttr("swish"): [ + _unary_fusion_pattern(_silu_fusion, call_fn, 2, lowp_dtype) + for call_fn in computation_call_fns + ], + } + if not lowp_dtype: + call_user1 = [call_fn(users=1) for call_fn in computation_call_fns] + replacement_unary_fusion_patterns.update( + { + UnaryAttr("relu"): [ + _combined_fusion(u, aten.relu) for u in call_user1 + ], + UnaryAttr("sigmoid"): [ + _combined_fusion(u, aten.sigmoid) for u in call_user1 + ], + UnaryAttr("tanh"): [ + _combined_fusion(u, aten.tanh) for u in call_user1 + ], + } + ) + + return replacement_unary_fusion_patterns + + for lowp_dtype in [torch.bfloat16, torch.float16, None]: + replace_patterns = _unary_fusion_patterns(lowp_dtype) + for unary_attr, patterns in replace_patterns.items(): + _register_unary_fusion_lowering( + patterns[0], unary_attr, computation_ops[0], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[1], unary_attr, computation_ops[1], lowp_dtype + ) + _register_unary_fusion_lowering( + patterns[2], unary_attr, computation_ops[2], lowp_dtype + ) + _leaky_relu_patterns = [ + _unary_fusion_pattern(_leaky_relu_fusion, call_fn, 3, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(_leaky_relu_patterns, computation_ops): + _register_leaky_relu_fusion_lowering( + pattern, computation_op, lowp_dtype + ) + hardtanh_patterns = [ + _unary_fusion_pattern(_hardtanh_fusion, call_fn, 1, lowp_dtype) + for call_fn in computation_call_fns + ] + for pattern, computation_op in zip(hardtanh_patterns, computation_ops): + _register_hardtanh_fusion_lowering(pattern, computation_op, lowp_dtype) + + def _register_inplace_fusion(): + binary_ops = [aten.add, ops.add] + inplace_fusion_op = mkldnn._convolution_pointwise_.binary + outplace_fusion_op = mkldnn._convolution_pointwise.binary + conv_call = _conv_call(users=1) + conv_op = computation_ops[0] + for binary_op in binary_ops: + binary_v1 = _binary_fusion_v1(conv_call, binary_op) + binary_unary_v1 = _combined_fusion(binary_v1, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v1, + conv_op, + binary_op, + 
inplace_fusion_op, + outplace_fusion_op, + other_index=0, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v1, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=0, + ) + binary_v2 = _binary_fusion_v2(conv_call, binary_op) + binary_unary_v2 = _combined_fusion(binary_v2, aten.relu) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_unary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + unary_attr=UnaryAttr("relu"), + ) + _register_binary_unary_maybe_inplace_fusion_lowering( + binary_v2, + conv_op, + binary_op, + inplace_fusion_op, + outplace_fusion_op, + other_index=1, + ) + + def _register_binary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [ + mkldnn._convolution_pointwise.binary, + mkldnn._linear_pointwise.binary, + ] + _computation_user_1 = [_conv_call(users=1), _linear_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern = _binary_fusion_v2(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + for binary_op in [aten.add, ops.add]: + pattern = _binary_fusion_v1(computation_call, binary_op) + _register_binary_unary_fusion_lowering( + pattern, computation_op, binary_op, fusion_op + ) + + def _register_binary_unary_fusion(): + binary_ops = [aten.add, ops.add, aten.sub, ops.sub] + fusion_ops = [mkldnn._convolution_pointwise.binary] + _computation_user_1 = [_conv_call(users=1)] + for computation_call, computation_op, fusion_op in zip( + _computation_user_1, computation_ops[:-1], fusion_ops + ): + for binary_op in binary_ops: + pattern_v1 = _combined_fusion( + _binary_fusion_v2(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v1, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + for binary_op in [aten.add, ops.add]: + pattern_v2 = _combined_fusion( + _binary_fusion_v1(computation_call, binary_op), aten.relu + ) + _register_binary_unary_fusion_lowering( + pattern_v2, + computation_op, + binary_op, + fusion_op, + unary_attr=UnaryAttr("relu"), + ) + + def _recover_linear(): + # convert reshape+linear+reshape to a single linear for applying fusion path. 
+ # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_numer=1) -> _recover_linear(pass_number=2) + @register_freezing_graph_pattern( + CallFunction( + aten.reshape.default, + CallFunction( + mkldnn._linear_pointwise.default, + CallFunction( + aten.reshape.default, + Arg(), + KeywordArg("reshape_1"), + _users=MULTIPLE, + ), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + ), + KeywordArg("reshape_2"), + ), + pass_number=2, + ) + def reshape_linear_reshape_pattern(match, *args, **kwargs): + def get_val(val): + return val if isinstance(val, int) else val.meta.get("val") + + reshape_1 = kwargs.get("reshape_1") + reshape_2 = kwargs.get("reshape_2") + assert isinstance(reshape_1, list) + assert isinstance(reshape_2, list) + assert len(reshape_1) == 2 + + graph = match.graph + reshape_2_node = match.output_node() + linear_input_node = reshape_2_node.args[0].args[0].args[0] + # check linear's input's shape[:-1] == reshape_2[:-1] + # and check product(reshape_2[:-1]) == reshape_1[0] + can_remove_reshape = linear_input_node.meta.get("val").shape[ + :-1 + ] == torch.Size([get_val(val) for val in reshape_2[:-1]]) + can_remove_reshape = can_remove_reshape and ( + reduce( + operator.mul, + [get_val(val) for val in reshape_2[:-1]], + ) + == get_val(reshape_1[0]) + ) + + if can_remove_reshape: + repl = graph.call_function(mkldnn._linear_pointwise.default, args) + repl.meta.update(reshape_2_node.meta) + reshape_2_node.replace_all_uses_with(repl) + old_linear_node = reshape_2_node.args[0] + reshape_1_node = old_linear_node.args[0] + graph.erase_node(reshape_2_node) + graph.erase_node(old_linear_node) + if len(reshape_1_node.users) == 0: + graph.erase_node(reshape_1_node) + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_count"] += 1 + counters["inductor"]["mkldnn_reshape_linear_reshape_matcher_nodes"] += len( + match.nodes + ) + + def is_linear_add_bias(match): + add_node = match.output_node() + linear_node = add_node.args[0] + device_type = add_node.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + transpose_weight_node = mkldnn_device_op.get_linear_transpose_weight( + linear_node.args[1] + ) + weight_meta = transpose_weight_node.args[0].meta.get("val") + bias_node = add_node.args[1] + if isinstance(bias_node, int): + # we only folding bias if it is a constant + return False + bias_meta = add_node.args[1].meta.get("val") + if weight_meta is None or bias_meta is None: + return False + + if bias_meta.dtype != weight_meta.dtype: + return False + return ( + linear_node.args[2] is None + and bias_meta.dim() == 1 + and bias_meta.size(0) == weight_meta.size(1) + ) + + # convert linear+bias to a single linear for applying fusion path. 
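+        # Sketch of the rewrite (illustrative): the pattern matches
+        #     add.Tensor(_linear_pointwise(x, packed_w, None, ...), bias)
+        # and, once is_linear_add_bias confirms bias is a 1-D tensor whose
+        # dtype matches the weight and whose length equals the weight's
+        # output features, replaces it with
+        #     _linear_pointwise(x, packed_w, bias, ...)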
+ @register_freezing_graph_pattern( + CallFunction( + aten.add.Tensor, + CallFunction(mkldnn._linear_pointwise.default, *_linear_args), + Arg(), + ), + pass_number=2, + extra_check=is_linear_add_bias, + ) + def linear_bias_pattern(match, *args): + graph = match.graph + add_node = match.output_node() + linear_node = add_node.args[0] + new_args = list(linear_node.args) + new_args[2] = add_node.args[1] + repl = graph.call_function( + mkldnn._linear_pointwise.default, tuple(new_args) + ) + repl.meta.update(add_node.meta) + add_node.replace_all_uses_with(repl) + match.erase_nodes() + counters["inductor"]["mkldnn_linear_bias_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_bias_matcher_nodes"] += len(match.nodes) + + def _is_packable_mkldnn_rnn_layer(match): + lstm_node = match.output_node() + POS_WEIGHTS = [1, 2] + POS_INPUTS = [0, 5, 6] + POS_ARGS = POS_WEIGHTS + POS_INPUTS + # Weights should be Constant + if any( + lstm_node.args[POS_WEIGHT].op != "get_attr" for POS_WEIGHT in POS_WEIGHTS + ): + return False + + # Meta info for weights and inputs should be available + if any(lstm_node.args[POS_ARG].meta.get("val") is None for POS_ARG in POS_ARGS): + return False + + # Check device + if any( + lstm_node.args[POS_ARG].meta.get("val").device.type != "cpu" + for POS_ARG in POS_ARGS + ): + return False + + # Check dtype + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.bfloat16 + and not is_mkldnn_bf16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + if any( + lstm_node.args[POS_ARG].meta.get("val").dtype == torch.float16 + and not is_mkldnn_fp16_supported("cpu") + for POS_ARG in POS_ARGS + ): + return False + + return True + + def _is_packable_convolution(match): + """ + Check if the node is supported for MKLDNN convolution. + """ + conv_node = match.output_node() + device_type = conv_node.meta.get("val").device.type + # The operator 'mkldnn::_convolution_transpose_pointwise' is not currently implemented for the XPU device. + if match.kwargs["is_transposed"] and device_type == "xpu": + return False + + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + input_size = input_meta_value.shape + if conv_node.args[1].op != "get_attr": + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type not in SUPPORTED_MKLDNN_DEVICES + or (meta_value.dim() != 4 and meta_value.dim() != 5) + ): + return False + + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + is_transposed = conv_node.args[-3] + if is_transposed: + # TODO: Support dynamic shape case for MKLDNN conv transpose. + if has_free_symbols(input_size): + return False + groups = conv_node.args[-1] + in_channels = weight_meta_value.size(0) + # doesn't support group_depthwise_conv_transpose. 
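+        # (e.g. a depthwise ConvTranspose2d(C, C, kernel_size, groups=C) has
+        # weight.size(0) == groups == C and is therefore left unpacked)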
+ if groups > 1 and groups == in_channels: + return False + # Port from: aten/src/ATen/native/Convolution.cpp:is_output_padding_big + output_paddings = conv_node.args[-2] + strides = conv_node.args[3] + if any( + output_padding >= stride + for output_padding, stride in zip(output_paddings, strides) + ): + return False + return True + + def _is_packable_linear(match): + """ + Check if the node is supported for MKLDNN linear. + """ + + def is_const_or_cat_by_const(weight): + if weight.op == "get_attr": + return True + if weight.target != aten.cat.default: + return False + return all(arg.op == "get_attr" for arg in weight.args[0]) + + linear_node = match.output_node() + # mkldnn linear only supports beta=1or0 and alpha=1 + if linear_node.target == aten.addmm.default: + alpha = linear_node.kwargs.get("alpha", 1.0) + beta = linear_node.kwargs.get("beta", 1.0) + if (beta != 0.0 and beta != 1.0) or alpha != 1.0: + return False + # weight_idx is 1 for aten.mm and is 2 for aten.addmm + weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + if not is_const_or_cat_by_const(linear_node.args[weight_idx]): + return False + input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") + weight_meta_value = linear_node.args[weight_idx].meta.get("val") + if input_meta_value is None or weight_meta_value is None: + return False + batch_size = input_meta_value.shape[0] + if ( + input_meta_value.dtype == torch.float64 + or weight_meta_value.dtype == torch.float64 + ): + return False + is_lp_weight = weight_meta_value.dtype in ( + torch.bfloat16, + torch.float16, + ) + # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol. + # on aarch64, use mkldnn op for fp32 as well if acl is enabled + if ( + not is_lp_weight + and not mkldnn._is_mkldnn_acl_supported() + and ((not torch._C.has_mkl) or has_free_symbols(batch_size)) + ): + return False + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or meta_value.device.type != "cpu" + or meta_value.dim() != 2 + ): + return False + if weight_idx == 2: + bias_meta_value = linear_node.args[0].meta.get("val") + if ( + bias_meta_value is None + or meta_value.device.type != "cpu" + or bias_meta_value.dim() != 1 + or bias_meta_value.size(0) != weight_meta_value.size(1) + ): + return False + + device_type = input_meta_value.device.type + if ( + input_meta_value.dtype == torch.bfloat16 + or weight_meta_value.dtype == torch.bfloat16 + ): + if not is_mkldnn_bf16_supported(device_type): + return False + if ( + input_meta_value.dtype == torch.float16 + or weight_meta_value.dtype == torch.float16 + ): + if not is_mkldnn_fp16_supported(device_type): + return False + return True + + _aten_conv_args = ( + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + KeywordArg("is_transposed"), + Arg(), + Arg(), + ) + + _aten_mkldnn_rnn_layer_args = ( + Arg(), # input + Arg(), # weight0 + Arg(), # weight1 + Arg(), # weight2 + Arg(), # weight3 + Arg(), # hx_ + Arg(), # cx_ + KeywordArg("reverse"), # reverse + Arg(), # batch_sizes + Arg(), # mode + Arg(), # hidden_size + Arg(), # num_layers + Arg(), # has_biases + Arg(), # bidirectional + Arg(), # batch_first + Arg(), # train + ) + + def _register_weight_pack_pass(): + @register_freezing_graph_pattern( + CallFunction(aten.convolution.default, *_aten_conv_args), + extra_check=_is_packable_convolution, + ) + def convolution(match, *args, **kwargs): + is_transposed = kwargs.get("is_transposed") + assert isinstance(is_transposed, bool) + graph = match.graph + 
conv_node = match.output_node() + device_type = conv_node.args[0].meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + input_size = conv_node.args[0].meta.get("val").shape + with graph.inserting_before(conv_node): + constant_args = [args[4], args[3], args[5], args[-1]] + packed_conv_op = mkldnn._convolution_pointwise.default + if is_transposed: + constant_args.insert(1, args[-2]) # output_padding + packed_conv_op = mkldnn._convolution_transpose_pointwise.default + + if not has_free_symbols(input_size): + packed_weight_node = mkldnn_device_op.pack_conv_weight( + graph, + is_transposed, + args[1], + constant_args, + input_size, + ) + else: + assert not is_transposed + # For dynamic shape case, we need to pack weight in runtime. + packed_weight_node = args[1] + + packed_conv_inputs = ( + (args[0], packed_weight_node, args[2]) + + tuple(constant_args) + + ("none", [], "") + ) + packed_conv_node = graph.create_node( + "call_function", packed_conv_op, tuple(packed_conv_inputs) + ) + conv_node.replace_all_uses_with(packed_conv_node) + packed_conv_node.meta.update(conv_node.meta) + graph.erase_node(conv_node) + counters["inductor"]["mkldnn_conv_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_conv_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction(aten.mkldnn_rnn_layer.default, *_aten_mkldnn_rnn_layer_args), + extra_check=_is_packable_mkldnn_rnn_layer, + ) + def mkldnn_rnn_layer(match, *args, **kwargs): + def get_item(graph, node, index): + return graph.call_function(operator.getitem, (node, index)) + + graph = match.graph + lstm_node = match.output_node() + weight0, weight1 = args[1:3] + reverse = kwargs.get("reverse") + packed_lstm_op = aten.mkldnn_rnn_layer.default + hidden_size = args[9] + has_biases = args[11] + batch_first = args[13] + with graph.inserting_before(lstm_node): + packed_weight_op = mkldnn._reorder_mkldnn_rnn_layer_weight.default + packed_weight_inputs = ( + weight0, + weight1, + hidden_size, + reverse, + has_biases, + batch_first, + ) + packed_weight_node = graph.create_node( + "call_function", packed_weight_op, packed_weight_inputs, {}, "name" + ) + packed_weight_items = [ + get_item(graph, packed_weight_node, i) for i in range(2) + ] + pack_lstm_inputs = ( + args[0], + *packed_weight_items, + args[3], + args[4], + args[5], + args[6], + reverse, + *args[7:], + ) + + packed_lstm_node = graph.create_node( + "call_function", packed_lstm_op, args=pack_lstm_inputs + ) + lstm_node.replace_all_uses_with(packed_lstm_node) + packed_lstm_node.meta.update(lstm_node.meta) + graph.erase_node(lstm_node) + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_rnn_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + @register_freezing_graph_pattern( + CallFunction( + aten.addmm.default, + Arg(), + Arg(), + Arg(), + beta=KeywordArg("beta"), + alpha=KeywordArg("alpha"), + ), + extra_check=_is_packable_linear, + pass_number=1, + ) + @register_freezing_graph_pattern( + CallFunction(aten.mm.default, Arg(), Arg()), + extra_check=_is_packable_linear, + pass_number=1, + ) + def linear(match, *args, **kwargs): + graph = match.graph + linear_node = match.output_node() + input = args[0] if linear_node.target == aten.mm.default else args[1] + bias = ( + None + if linear_node.target == aten.mm.default + or ( + linear_node.target == aten.addmm.default + and linear_node.kwargs.get("beta", 1.0) == 0.0 + ) + else args[0] + ) + weight = args[1] if 
linear_node.target == aten.mm.default else args[2] + device_type = input.meta.get("val").device.type + mkldnn_device_op = _get_mkldnn_device_op(device_type) + with graph.inserting_before(linear_node): + transpose_weight_node = graph.create_node( + "call_function", aten.permute.default, (weight, (1, 0)) + ) + weight_dtype = weight.meta.get("val").dtype + is_lp_weight = weight_dtype in ( + torch.bfloat16, + torch.float16, + ) + batch_size = input.meta.get("val").shape[0] + if has_free_symbols(batch_size): + assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), ( + f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" + ) + packed_weight_node = mkldnn_device_op.pack_linear_weight( + graph, is_lp_weight, transpose_weight_node, batch_size + ) + packed_linear_node = mkldnn_device_op.pack_linear( + graph, is_lp_weight, batch_size, input, packed_weight_node, bias + ) + + linear_node.replace_all_uses_with(packed_linear_node) + packed_linear_node.meta.update(linear_node.meta) + graph.erase_node(linear_node) + counters["inductor"]["mkldnn_linear_weight_pack_matcher_count"] += 1 + counters["inductor"]["mkldnn_linear_weight_pack_matcher_nodes"] += len( + match.nodes + ) + + def _eliminate_duplicate_packed_nodes(gm): + """ + Combine packed weight nodes with the same inputs to reduce memory usage. + for example: + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(32, 32, bias=True) + + def forward(self, x): + return self.linear(self.linear(x)) + + the above's packed weight nodes are duplicate if two linear calls have same input size. + """ + if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): + return gm + + packed_weight_ops = [ + torch._C._nn.mkldnn_reorder_conv2d_weight, + torch._C._nn.mkldnn_reorder_conv3d_weight, + mkldnn._reorder_convolution_transpose_weight, + mkldnn._reorder_linear_weight, + mkldnn._reorder_mkldnn_rnn_layer_weight, + ] + if torch._C.has_mkl: + packed_weight_ops.append(torch.ops.mkl._mkl_reorder_linear_weight) + + for node in gm.graph.nodes: + if node.target in packed_weight_ops and len(node.args[0].users) > 1: + for user_node in list(node.args[0].users.keys()): + if ( + user_node.target == node.target + and user_node != node + and user_node.args == node.args + ): + user_node.replace_all_uses_with(node) + gm.graph.erase_node(user_node) + + @functools.cache + def _mkldnn_fusion_init(): + # TODO: aarch64: enable op fusion for acl once it supports fused operators. Disabling it for now. 
+ # Otherwise even the matmul or innerproduct can not be accelerated with acl + if ( + torch.backends.mkldnn.enabled + and torch.backends.mkldnn.is_available() + and not torch.ops.mkldnn._is_mkldnn_acl_supported() + ): + _register_unary_fusion() + _register_inplace_fusion() + _register_binary_unary_fusion() + _register_binary_fusion() + _register_quantization_lowerings() + _register_woq_lowerings() + + @functools.cache + def _mkldnn_weight_pack_init(): + if torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available(): + _register_weight_pack_pass() + _recover_linear() + _register_quantization_weight_pack_pass() + _register_int8_woq_concat_linear_pattern() diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cd4822f3b48c2d94e7b6318bfaaa1371889f85c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/numeric_utils.py @@ -0,0 +1,213 @@ +# mypy: allow-untyped-defs +import gc +import logging +import os +import random +import traceback + +import numpy + +import torch +import torch.optim as optim +from torch.utils._ordered_set import OrderedSet + +from .. import config + + +logger: logging.Logger = logging.getLogger(__name__) + +MAIN_RANDOM_SEED = 1337 + +# Set the CUBLAS_WORKSPACE_CONFIG environment variable +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + + +# If the two forward functions involve any non-deterministic operations, +# such as certain types of parallelism or asynchronous execution, +# this can also lead to different outputs. +def set_deterministic() -> None: + """Make torch manual seed deterministic.""" + + torch.manual_seed(MAIN_RANDOM_SEED) + random.seed(MAIN_RANDOM_SEED) + numpy.random.seed(MAIN_RANDOM_SEED) + torch.use_deterministic_algorithms(True) + + +def clean_memory() -> None: + """Clean memory to avoid OOM.""" + gc.collect() + torch.cuda.empty_cache() + + +# We compare the numerical results before and after pre/post grad fx passes +# transformation to make sure the numerical results are the same. +def compare_dict_tensors(dict_base, dict_control, precision): + if len(OrderedSet(dict_base.keys())) != len(OrderedSet(dict_control.keys())): + logger.warning("Mismatch keys found before and after pre/post grad fx passes.") + logger.debug("keys before pre/post grad fx passes %s", dict_base.keys()) + logger.debug("keys after pre/post grad fx passes %s", dict_control.keys()) + return False + is_allclose = True + for key in dict_base.keys(): + if key not in dict_control: + logger.warning( + "Mismatch parameter name %s does not exist after pre/post grad fx passes", + key, + ) + # Some parameters have `None`, and not every param has a valid .grad field, we skip them + if dict_base[key] is None or dict_control[key] is None: + continue + if not torch.allclose( + dict_base[key], + dict_control[key], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.warning( + "Mismatch parameter values found before and after pre/post grad fx passes." + ) + logger.debug("value before pre/post grad fx passes %s", dict_base[key]) + logger.debug("value after pre/post grad fx passes %s", dict_control[key]) + is_allclose = False + return is_allclose + + +def compare_tuple_tensors(tuple_base, tuple_control, precision): + if len(tuple_base) != len(tuple_control): + logger.warning( + "Mismatch fw output length. 
before transformation: %s, after transformation: %s", + len(tuple_base), + len(tuple_control), + ) + return False + is_allclose = True + for i in range(len(tuple_base)): + # Some parameters have `None`, we skip them + if tuple_base[i] is None or tuple_control[i] is None: + continue + if not torch.allclose( + tuple_base[i], + tuple_control[i], + rtol=precision, + atol=precision, + equal_nan=True, + ): + logger.debug( + "forward output before pre/post grad fx passes %s", tuple_base[i] + ) + logger.debug( + "forward output after pre/post grad fx passes %s", tuple_control[i] + ) + is_allclose = False + return is_allclose + + +def compare_parameters(model_base, model_control, precision): + return compare_dict_tensors( + dict(model_base.named_parameters()), + dict(model_control.named_parameters()), + precision, + ) + + +def compare_forward_output(pred_base, pred_control, precision): + return compare_tuple_tensors( + pred_base, + pred_control, + precision, + ) + + +def compare_gradients(model_base, model_control, precision): + grad_base = {key: param.grad for key, param in model_base.named_parameters()} + grad_pt2 = {key: param.grad for key, param in model_control.named_parameters()} + return compare_dict_tensors( + grad_base, + grad_pt2, + precision, + ) + + +def run_model( + model_base, model_control, model_input, num_iterations=10, precision=1e-4 +): + clean_memory() + for i in range(num_iterations): + logger.info("start %s iteration", i) + set_deterministic() + pred_base = model_base(*model_input) + set_deterministic() + pred_control = model_control(*model_input) + + res = compare_parameters(model_base, model_control, precision) + logger.info("compare parameters. Numerical result : %s", res) + + res = compare_forward_output(pred_base, pred_control, precision) + logger.info("compare loss/predict. Numerical result : %s", res) + # tensor may not have a grad_fn + try: + _ = pred_base[0].sum().backward(retain_graph=True) + _ = pred_control[0].sum().backward(retain_graph=True) + res = compare_gradients(model_base, model_control, precision) + logger.info("compare param grad. Numerical result : %s", res) + except Exception: + logger.exception("Exception when comparing gradients") + traceback.print_exc() + + if config.fx_passes_numeric_check["requires_optimizer"]: + try: + optimizer_base = optim.SGD( + [param for name, param in model_base.named_parameters()], lr=0.01 + ) + optimizer_base.step() + + optimizer_control = optim.SGD( + [param for name, param in model_control.named_parameters()], lr=0.01 + ) + optimizer_control.step() + + res = compare_parameters(model_base, model_control, precision) + logger.info( + "compare parameters with optimizer added. 
Numerical result : %s", + res, + ) + except Exception: + logger.exception( + "Exception when optimizer is added to check parameter names" + ) + traceback.print_exc() + else: + logger.warning( + "no parameter with optimizer to compare with length %s before transformation" + " and the length %s after transformation", + len(dict(model_base.named_parameters())), + len(dict(model_control.named_parameters())), + ) + + +def numeric_check_if_enabled( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations, + precision, +): + # need to topo-sort graphmodule before we run the model, + # otherwise it may fail as refer before def + # fail silently in order not to block the model run + try: + with torch.autograd.set_detect_anomaly(True): + run_model( + gm_before_fx_passes, + gm_after_fx_passes, + example_inputs, + num_iterations=num_iterations, + precision=precision, + ) + except Exception as e: + logger.warning( + "Runtime numeric check failed in pre grad fx passes with error: %s", e + ) + traceback.print_exc() diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..8561614a5137fa9397eff59abde1417bccc86012 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pad_mm.py @@ -0,0 +1,925 @@ +import functools +import itertools +import operator +import typing +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union + +import torch +import torch._inductor.runtime.runtime_utils +from torch import Tensor +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor import utils +from torch._inductor.autoheuristic.autoheuristic import ( + AHContext, + AutoHeuristic, + LocalFeedback, +) +from torch._inductor.autoheuristic.autoheuristic_utils import ( + context_add_strides, + context_add_using_tf32, + pad_mm_operations, + pad_mm_precondition, +) +from torch._subclasses.fake_tensor import FakeTensor +from torch.utils._mode_utils import no_dispatch + +from ...utils._triton import has_triton +from ..pattern_matcher import ( + fwd_only, + gen_register_replacement, + joint_fwd_bwd, + Match, + ReplaceFn, + SearchFn, +) + + +aten = torch.ops.aten + + +# This flag is only used for testing purpose. +# Changing it to True will ignore comparing do_bench times +# between original pattern and padded one. 
+_skip_do_bench_times = False + + +def fetch_fake_tensors(match: Match, kwarg_names: Sequence[str]) -> list[Tensor]: + kwargs = match.kwargs + return [kwargs[name].meta["val"] for name in kwarg_names] + + +def unwrap_fake_args( + *arg_names: str, +) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]: + def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]: + def wrapper(match: Match) -> Any: + fake_tensors = fetch_fake_tensors(match, arg_names) + return func(*fake_tensors) + + return wrapper + + return decorator + + +def get_alignment_size(x: Tensor) -> int: + return get_alignment_size_dtype(x.dtype) + + +def get_alignment_size_dtype(dtype: torch.dtype) -> int: + if dtype == torch.float16 or dtype == torch.half or dtype == torch.bfloat16: + return 8 + elif dtype == torch.float32 or dtype == torch.float: + return 4 + else: + return 0 + + +def check_device(a: Tensor, b: Tensor) -> bool: + return a.is_cuda and b.is_cuda + + +def check_dtype(a: Tensor, b: Tensor) -> bool: + return a.is_floating_point() and b.is_floating_point() + + +def should_pad_common( + mat1: Tensor, mat2: Tensor, input: Optional[Tensor] = None +) -> bool: + # It's fine we have symbolic shapes or strides as long as they + # have hints. Later, we will make sure we only pad non-symbolic dimensions. + def valid_shape_and_stride(t: Optional[Tensor]) -> bool: + if t is None: + return True + + symbolic_cnt = 0 + for x in t.size(): + if isinstance(x, int): + continue + elif utils.is_symbolic(x): + if not x.node.has_hint(): + return False + symbolic_cnt += 1 + else: + return False + # filter out cases where all dimensions are symbolic + if symbolic_cnt == len(t.size()): + return False + return all( + isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) + for x in t.stride() + ) + + return ( + torch._inductor.config.shape_padding + and check_device(mat1, mat2) + and check_dtype(mat1, mat2) + and all(valid_shape_and_stride(t) for t in (mat1, mat2, input)) + ) + + +def get_padded_length(x: Union[int, torch.SymInt], alignment_size: int) -> int: + # we don't pad x if it is symbolic + if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0: + return 0 + + # ignore dim that can be squeezed away + if x == 1: + return 0 + + return int((x // alignment_size + 1) * alignment_size) - x + + +def pad_dim(x: Tensor, padded_length: int, dim: int) -> Tensor: + if padded_length == 0: + return x + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def addmm_pattern( + input: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float +) -> Tensor: + return aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + +def should_pad_addmm(match: Match) -> bool: + mat1, mat2, input = fetch_fake_tensors(match, ("mat1", "mat2", "input")) + return should_pad_common(mat1, mat2, input) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.addmm, input=input + ) + + +def pad_addmm( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + beta: float = 1.0, + alpha: float = 1.0, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + # for paddings, dim order is reversed for some reasons + # and for every dim, we need to specify left and right padding + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, 
k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + + # the add broadcasts, so we only pad if the dimension != 1 + if input is not None: + if n_padded_length != 0: + if input.dim() == 2 and input.shape[1] != 1: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1 and input.shape[0] != 1: + input = pad_dim(input, n_padded_length, 0) + if m_padded_length != 0 and input.dim() == 2 and input.shape[0] != 1: + input = pad_dim(input, m_padded_length, 0) + + res = aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def addmm_replace( + input: Optional[Tensor], + mat1: Tensor, + mat2: Tensor, + beta: float = 1.0, + alpha: float = 1.0, +) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + return pad_addmm( + input, + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + beta, + alpha, + ) + + +def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool: + denominator = M * K + N * K + M * N + if denominator == 0: + return False + arithmetic_intensity = (M * N * K) / denominator + + # we have experienced some large perf hits in this case, even in bandwidth bound regimes + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and torch.cuda.get_device_capability() < (9, 0) + ): # doesn't repro on h100s: + return True + + # Fails with AMD + try: + machine_balance = ( + 1000 * utils.get_device_tflops(dtype) + ) / utils.get_gpu_dram_gbps() + except Exception: + return True + + # dram_gbps might be underestimating bandwidth because of cache. + # if we estimate machine balance too low we might miss some speedups, + # if we estimate too high there will be unnecessary compilation time increase. + # TODO - finetune coefficient here. 
As a reference point, Triton mm model assumes + # 80% of reads are in cache and cache is 4x faster than dram_gbps + machine_balance = machine_balance * 0.5 + + return arithmetic_intensity > machine_balance + + +@functools.cache +def get_pad_cache() -> torch._inductor.codecache.LocalCache: + return torch._inductor.codecache.LocalCache() + + +def get_cached_should_pad(key: str) -> bool: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_should_pad(key: str, value: bool) -> None: + return get_pad_cache().set_value(key, value=value) + + +def get_cached_base_mm_benchmark_time(key: str) -> float: + return get_pad_cache().lookup(key) # type: ignore[return-value] + + +def set_cached_base_mm_benchmark_time(key: str, value: float) -> None: + return get_pad_cache().set_value(key, value=value) + + +def should_pad_bench_key( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Optional[Tensor] = None, + is_base_time_key: bool = False, +) -> str: + def tensor_key(t: Tensor) -> tuple[torch.Size, tuple[int, ...], torch.dtype]: + return (t.shape, t.stride(), t.dtype) + + tf32_key = ( + None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32 + ) + + def fmt_pad(name: str) -> Optional[str]: + if is_base_time_key: + return None + return f"exclude_pad:{should_exclude_padding_time(match, name)}" + + key = ( + tensor_key(mat1), + tensor_key(mat2), + fmt_pad("mat1"), + fmt_pad("mat2"), + op, + input if input is None else tensor_key(input), + tf32_key, + ) + + key = str(key) + if is_base_time_key: + key = f"base mm time: {key}" + return key + + +def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node: + if node.op == operator.getitem: + return get_non_view_def(node.args[0]) # type: ignore[arg-type] + + if ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and utils.is_view(node.target) + ): + return get_non_view_def(node.all_input_nodes[0]) + + return node + + +def should_exclude_padding_time(match: Match, arg_name: str) -> bool: + node_def = get_non_view_def(match.kwargs[arg_name]) + + # constant padding converts tensors to contiguous so even if the input tensor + # can be planned layout transform is not free. TODO - way to pad and preserve layout ? + if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous(): + return False + + # TODO - see issue https://github.com/pytorch/pytorch/issues/128889 + # We would only able to completely plan these out if we were only doing + # first dimension padding. non-first we would still need a copy + # because these outputs are fixed dense. + cannot_plan_output = [ + aten.mm.default, + aten.convolution.default, + aten.convolution_backward.default, + aten.bmm.default, + aten.addmm.default, + aten._scaled_dot_product_flash_attention.default, + aten._scaled_dot_product_efficient_attention.default, + ] + + if node_def.target in cannot_plan_output: + return False + + if ( + node_def.target == aten.cat.default + and len(node_def.all_input_nodes) + > torch._inductor.config.max_pointwise_cat_inputs + ): + return False + + # optimistically assume we should be able to memory plan away + # all non inputs + return node_def.op != "placeholder" + + +def should_pad(key: str, ori_time: float, pad_time: float) -> bool: + multiplier = 1.1 + # Shape padding introduces additional memory ops. 
Based on microbenchmarks, 1.1x represents a reasonable + # tradeoff between performance improvement from shape padding and overhead from additional memory ops + # TODO: Build a learned model which would be better than this heuristic + if "shape_padding_multiplier" in torch._inductor.config.post_grad_fusion_options: + multiplier = torch._inductor.config.post_grad_fusion_options[ + "shape_padding_multiplier" + ].get("value", 1.1) + counters["inductor"]["shape_padding_multiplier"] += 1 + should_pad = _skip_do_bench_times or ori_time > pad_time * multiplier + set_cached_should_pad(key, should_pad) + return should_pad + + +def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool: + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and torch.cuda.get_device_capability() < (9, 0) + ): # doesn't repro on h100s: + return True + return False + + +def should_pad_bench(*args: Any, **kwargs: Any) -> bool: + with dynamo_timed( + "pad_mm_benchmark", + log_pt2_compile_event=False, + dynamo_compile_column_us="compile_time_autotune_time_us", + ): + return _should_pad_bench(*args, **kwargs) + + +def get_do_bench() -> Callable[[Callable[[], Any]], float]: + with dynamo_timed("pad_mm_benchmark_get_do_bench"): + return functools.partial( + torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu, + warmup=5, + ) + + +def _should_pad_bench( + match: Match, + mat1: Tensor, + mat2: Tensor, + op: torch._ops.OpOverloadPacket, + input: Optional[Tensor] = None, +) -> bool: + do_bench = get_do_bench() + + m_padded_length = 0 + n_padded_length = 0 + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m = mat1.shape[0] + k = mat1.shape[1] + n = mat2.shape[1] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + elif op is torch.ops.aten.bmm: + m = mat1.shape[1] + k = mat1.shape[2] + n = mat2.shape[2] + k_padded_length = get_padded_length(k, get_alignment_size(mat1)) + m_padded_length = get_padded_length(m, get_alignment_size(mat1)) + n_padded_length = get_padded_length(n, get_alignment_size(mat2)) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + def realize_symbols( + ds: Union[torch.Size, tuple[torch.SymInt, ...]], + ) -> list[int]: + return [d if isinstance(d, int) else d.node.hint for d in ds] + + if any( + dim == 0 + for dim in itertools.chain( + realize_symbols(mat1.shape), realize_symbols(mat2.shape) + ) + ): + return False + + if torch._inductor.config.force_shape_pad: + return True + + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + + if not has_triton(): + return False + + if not is_mm_compute_bound(m, k, n, mat1.dtype): + return False + + # We don't want to look up the cache for cases that are trivially false + # since it does file io + key = should_pad_bench_key(match, mat1, mat2, op, input) + + cached_pad = get_cached_should_pad(key) + if cached_pad is not None: + return cached_pad + + def realize_tensor(t): + if isinstance(t, FakeTensor): + size_hints = 
realize_symbols(t.size()) + stride_hint = realize_symbols(t.stride()) + real_size = ( + sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 + ) + real_t = torch.randn(real_size, dtype=t.dtype, device=t.device) + return torch.as_strided(real_t, size_hints, stride_hint) + else: + return torch.randn_like(t) + + mat1 = realize_tensor(mat1) + mat2 = realize_tensor(mat2) + + # since we key on whether or not the inputs can be memory planned, set cache for the + # original time which is unaffected by whether or not the input can be planned + ori_time_key = should_pad_bench_key( + match, mat1, mat2, op, input, is_base_time_key=True + ) + ori_time = get_cached_base_mm_benchmark_time(ori_time_key) + if ori_time is None and op is torch.ops.aten.addmm and input is not None: + # realize bias for addmm + input = realize_tensor(input) + + mat1_pad = mat1 + mat2_pad = mat2 + + is_bmm = op is torch.ops.aten.bmm + + mat1_pre_padded = should_exclude_padding_time(match, "mat1") + fns = [] + if mat1_pre_padded and (m_padded_length or k_padded_length): + mat1_pad = pad_mat1( + mat1_pad, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat1_pad[:, -m_padded_length:, -k_padded_length:].fill_(0) + else: + mat1_pad[-m_padded_length:, -k_padded_length:].fill_(0) + + fns.append(write_pad) + + mat2_pre_padded = should_exclude_padding_time(match, "mat2") + if mat2_pre_padded and (k_padded_length or n_padded_length): + mat2_pad = pad_mat2( + mat2_pad, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=is_bmm, + ) + + def write_pad(): + if is_bmm: + mat2_pad[:, -k_padded_length:, -n_padded_length:].fill_(0) + else: + mat2_pad[-k_padded_length:, -n_padded_length:].fill_(0) + + fns.append(write_pad) + + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda: + input_pad = torch.randn_like(input) + fns.append( + lambda: pad_addmm( + input_pad, + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + elif op is torch.ops.aten.mm: + fns.append( + lambda: pad_mm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + else: + fns.append( + lambda: pad_bmm( + mat1_pad, + mat2_pad, + m_padded_length, + k_padded_length, + n_padded_length, + mat1_pre_padded=mat1_pre_padded, + mat2_pre_padded=mat2_pre_padded, + ) + ) + + def orig_bench_fn(): + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + op(mat1, mat2) + else: + op(input, mat1, mat2) + + def pad_bench_fn(): + for fn in fns: + fn() + + if ( + torch._inductor.config.run_autoheuristic("pad_mm") + and op is torch.ops.aten.mm + ): + ah_should_pad = run_autoheuristic( + mat1, + mat2, + orig_bench_fn, + pad_bench_fn, + m_padded_length, + k_padded_length, + n_padded_length, + do_bench, + mat1_pre_padded, + mat2_pre_padded, + ori_time, + ori_time_key, + key, + ) + if ah_should_pad is not None: + return ah_should_pad + + if ori_time is None: + ori_time = do_bench(orig_bench_fn) + set_cached_base_mm_benchmark_time(ori_time_key, ori_time) + + pad_time = do_bench(pad_bench_fn) + return should_pad(key, ori_time, pad_time) + + +def get_context( + mat1: Tensor, + mat2: Tensor, + mat1_pre_padded: bool, + mat2_pre_padded: bool, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, +) -> AHContext: + context = 
AHContext() + + context.add_feature("m", mat1.shape[0]) + context.add_feature("k", mat1.shape[1]) + context.add_feature("n", mat2.shape[1]) + + context_add_strides(context, "mat1", mat1.stride()) + context_add_strides(context, "mat2", mat2.stride()) + + context.add_feature("m_padded_length", m_padded_length) + context.add_feature("k_padded_length", k_padded_length) + context.add_feature("n_padded_length", n_padded_length) + + context.add_feature("mat1_align_size", get_alignment_size(mat1)) + context.add_feature("mat2_align_size", get_alignment_size(mat2)) + + context.add_feature("mat1_dtype", mat1.dtype, is_categorical=True) + context.add_feature("mat2_dtype", mat2.dtype, is_categorical=True) + + context.add_feature("prepadded_mat1", mat1_pre_padded, is_categorical=True) + context.add_feature("prepadded_mat2", mat2_pre_padded, is_categorical=True) + + context_add_using_tf32(context, mat1.dtype) + return context + + +def run_autoheuristic( + mat1: Tensor, + mat2: Tensor, + orig_bench_fn: Callable[[], None], + pad_bench_fn: Callable[[], None], + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + do_bench: Callable[[Callable[[], Any]], float], + mat1_pre_padded: bool, + mat2_pre_padded: bool, + ori_time: float, + ori_time_key: str, + key: str, +) -> Optional[bool]: + def feedback_fn( + choice: str, + ) -> Optional[float]: + if choice == orig_choice: + return do_bench(orig_bench_fn) + elif choice == pad_choice: + return do_bench(pad_bench_fn) + return None + + def fallback() -> str: + return "autotune" + + orig_choice = "orig" + pad_choice = "pad" + choices = [orig_choice, pad_choice] + feedback = LocalFeedback(feedback_fn) # type: ignore[arg-type] + context = get_context( + mat1, + mat2, + mat1_pre_padded, + mat2_pre_padded, + m_padded_length, + k_padded_length, + n_padded_length, + ) + name = "pad_mm" + autoheuristic = AutoHeuristic( + fallback=fallback, + choices=choices, + feedback=feedback, + context=context, + name=name, + augment_context=pad_mm_operations(), + precondition=pad_mm_precondition, + ) + choice = autoheuristic.get_choice() + choice2should_pad = {orig_choice: False, pad_choice: True, "autotune": None} + ah_should_pad = choice2should_pad.get(choice, None) + + if torch._inductor.config.collect_autoheuristic(name): + ah_ori_time = autoheuristic.get_collected_feedback(orig_choice) + ah_pad_time = autoheuristic.get_collected_feedback(pad_choice) + + # if precondition is not satisfied, autoheuristic does not collect data + if ah_ori_time is not None and ah_pad_time is not None: + if ori_time is None: + set_cached_base_mm_benchmark_time(ori_time_key, ah_ori_time) + return should_pad(key, ah_ori_time, ah_pad_time) + if ah_should_pad is not None: + set_cached_should_pad(key, ah_should_pad) + return ah_should_pad + + +def mm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.mm(mat1, mat2) + + +def should_pad_mm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.mm + ) + + +def pad_mat1( + mat1: Tensor, *, m_padded_length: int, k_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or m_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, k_padded_length, 0, m_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat1, pad_arg) + else: + return mat1 + + +def pad_mat2( + mat2: Tensor, *, 
k_padded_length: int, n_padded_length: int, is_bmm: bool = False +) -> Tensor: + if k_padded_length != 0 or n_padded_length != 0: + # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding + pad_arg = [0, n_padded_length, 0, k_padded_length] + if is_bmm: + pad_arg.extend((0, 0)) + return aten.constant_pad_nd(mat2, pad_arg) + else: + return mat2 + + +def pad_mm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, m_padded_length=m_padded_length, k_padded_length=k_padded_length + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, k_padded_length=k_padded_length, n_padded_length=n_padded_length + ) + res = aten.mm(mat1, mat2) + if m_padded_length != 0: + res = res[:-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :-n_padded_length] + return res + + +def mm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2)) + return pad_mm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +def bmm_pattern(mat1: Tensor, mat2: Tensor) -> Tensor: + return aten.bmm(mat1, mat2) + + +def should_pad_bmm(match: Match) -> bool: + mat1, mat2 = fetch_fake_tensors(match, ("mat1", "mat2")) + return should_pad_common(mat1, mat2) and should_pad_bench( + match, mat1, mat2, torch.ops.aten.bmm + ) + + +def pad_bmm( + mat1: Tensor, + mat2: Tensor, + m_padded_length: int, + k_padded_length: int, + n_padded_length: int, + mat1_pre_padded: bool = False, + mat2_pre_padded: bool = False, +) -> Tensor: + if not mat1_pre_padded: + mat1 = pad_mat1( + mat1, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + is_bmm=True, + ) + if not mat2_pre_padded: + mat2 = pad_mat2( + mat2, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + is_bmm=True, + ) + res = aten.bmm(mat1, mat2) + if m_padded_length != 0: + res = res[:, :-m_padded_length, :] + if n_padded_length != 0: + res = res[:, :, :-n_padded_length] + return res + + +def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor: + k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1)) + n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2)) + m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1)) + return pad_bmm( + mat1, + mat2, + m_padded_length, + k_padded_length, + n_padded_length, + ) + + +@functools.cache +def _pad_mm_init() -> None: + from .joint_graph import patterns + + if torch.cuda.is_available(): + # workaround https://github.com/pytorch/pytorch/issues/97894 + device = "cuda" + else: + device = "cpu" + + # sizes/values dont actually matter for initial trace + # once we get a possible match we re-trace with the actual values and verify the match still holds + + dim2a = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + dim2b = functools.partial(torch.empty, (4, 4), device=device, requires_grad=True) + + dim3a = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + dim3b = functools.partial(torch.empty, (4, 4, 4), device=device, requires_grad=True) + + dim1a = functools.partial(torch.empty, (4), device=device, requires_grad=True) + + # 
workaround https://github.com/pytorch/pytorch/issues/97894 + # 0.113377 is a "magic" value that lets us recover the lost input arg relationship + rep = {"beta": 0.213377, "alpha": 0.113377} + + for pattern, replacement, args, workaround, extra_check in [ + ( + typing.cast(SearchFn, mm_pattern), + typing.cast(ReplaceFn, mm_replace), + [dim2a(), dim2b()], + {}, + should_pad_mm, + ), + ( + typing.cast(SearchFn, bmm_pattern), + typing.cast(ReplaceFn, bmm_replace), + [dim3a(), dim3b()], + {}, + should_pad_bmm, + ), + ( + typing.cast(SearchFn, addmm_pattern), + typing.cast(ReplaceFn, addmm_replace), + [dim1a(), dim2a(), dim2b()], + rep, + should_pad_addmm, + ), + ]: + assert isinstance(workaround, dict) # mypy is unable to infer the type properly + name = pattern.__name__ + + gen_register_replacement( + f"{name}_training", + pattern, + replacement, + args, + joint_fwd_bwd, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) + + gen_register_replacement( + f"{name}_inference", + pattern, + replacement, + args, + fwd_only, + patterns, + extra_check=extra_check, + scalar_workaround=workaround, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..ecf5a8e87e0fa03ca4e47d8c27d57ebcd7f73015 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/post_grad.py @@ -0,0 +1,1793 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import itertools +import logging +import operator +from collections import Counter, defaultdict +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._inductor as inductor +import torch.utils._pytree as pytree +from torch import fx +from torch._decomp import register_decomposition +from torch._dynamo.utils import counters +from torch._inductor import comms +from torch._inductor.virtualized import ops +from torch._logging import trace_structured +from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype +from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq +from torch.utils._ordered_set import OrderedSet + +from .. 
import config, ir, pattern_matcher +from ..codegen.common import custom_backend_passes +from ..comms import remove_fsdp2_unsharded_param_graph_input_usage +from ..fx_utils import FakeTensorUpdater, get_fake_args_kwargs, get_node_storage +from ..lowering import lowerings as L +from ..pattern_matcher import ( + _return_true, + Arg, + CallFunction, + CallFunctionVarArgs, + filter_nodes, + fwd_only, + get_arg_value, + get_mutation_region_id, + Ignored, + init_once_fakemode, + KeywordArg, + ListOf, + Match, + MultiOutputPattern, + MULTIPLE, + PatternMatcherPass, + register_graph_pattern, + register_replacement, + stable_topological_sort, +) +from ..utils import ( + decode_device, + get_all_devices, + get_gpu_type, + is_gpu, + is_pointwise_use, + OPTIMUS_EXCLUDE_POST_GRAD, +) +from ..virtualized import V +from .b2b_gemm import B2B_GEMM_PASS +from .ddp_fusion import fuse_ddp_communication +from .group_batch_fusion import group_batch_fusion_passes, POST_GRAD_FUSIONS +from .micro_pipeline_tp import micro_pipeline_tp_pass +from .pre_grad import is_same_dict, save_inductor_dict +from .reinplace import reinplace_inplaceable_ops +from .split_cat import POST_GRAD_PATTERNS + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +# First pass_patterns[0] are applied, then [1], then [2] +pass_patterns = [ + PatternMatcherPass(), + PatternMatcherPass(), + PatternMatcherPass(), +] + + +def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): + """ + Passes that run on after grad. This is called once on the forwards + graph and once on the backwards graph. + + The IR here has been normalized and functionalized. + """ + GraphTransformObserver = functools.partial( + torch.fx.passes.graph_transform_observer.GraphTransformObserver, + subsystem="post_grad_passes", + ) + + if not torch._dynamo.config.skip_fsdp_hooks: + remove_fsdp2_unsharded_param_graph_input_usage(gm.graph) + + if config.dce: + # has some issues with mutation in inference mode + gm.graph.eliminate_dead_code() + + if is_inference and config.reorder_for_locality: + GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass( + reorder_for_locality + ) + + fake_tensor_updater = FakeTensorUpdater(gm.graph) + + if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass: + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + post_grad_custom_pre_pass + ) + + if torch._C._has_mkldnn: + if ( + config.cpp.enable_grouped_gemm_template + and config.max_autotune + and "CPP" in config.max_autotune_gemm_backends + ): + from .mkldnn_fusion import grouped_gemm_pass + + grouped_gemm_pass(gm.graph) + + if config.cpp.enable_concat_linear: + from .quantization import concat_linear_woq_int4 + + # Concat linear optimization for WOQ int4 + concat_linear_woq_int4(gm) + + if config.pattern_matcher: + lazy_init() + GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass( + functools.partial(group_batch_fusion_passes, pre_grad=False) + ) + GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops) + GraphTransformObserver(gm, "remove_assert_ops").apply_graph_pass( + remove_assert_ops + ) + for i, patterns in enumerate(pass_patterns): + GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass( + patterns.apply + ) + for pass_name in config.post_grad_fusion_options: + # skip all patterns for group batch fusions or quantization patterns + if pass_name in POST_GRAD_FUSIONS or pass_name in 
OPTIMUS_EXCLUDE_POST_GRAD: + continue + pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + GraphTransformObserver(gm, pass_name).apply_graph_pass( + pattern_matcher_pass.apply + ) + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_post_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + if config.b2b_gemm_pass: + B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] + + if config._micro_pipeline_tp: + micro_pipeline_tp_pass(gm.graph) + + if config._fuse_ddp_communication: + GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass( + lambda graph: fuse_ddp_communication( + graph, + config._fuse_ddp_communication_passes, + config._fuse_ddp_bucket_size, + ) + ) + + if post_grad_custom_post_pass := config.post_grad_custom_post_pass: + GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass( + post_grad_custom_post_pass + ) + + GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort) + + GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass( + move_constructors_to_gpu + ) + + fake_tensor_updater.incremental_update() + + for device, custom_backend_pass in custom_backend_passes.items(): + if custom_backend_pass is not None: + gm_devices = [d.type for d in get_all_devices(gm)] + if device in gm_devices: + pass_name = "custom_backend_passes_" + device + GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) + + # Keep these last, since they introduces mutation. Look at + # ./fx_passes/README.md for a discussion of mutation invariants. + GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( + reinplace_inplaceable_ops + ) + GraphTransformObserver( + gm, "decompose_triton_kernel_wrapper_functional" + ).apply_graph_pass(decompose_triton_kernel_wrapper_functional) + GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass( + decompose_auto_functionalized + ) + if not torch._dynamo.config.skip_fsdp_hooks: + GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( + comms.reinplace_fsdp_all_gather + ) + GraphTransformObserver(gm, "decompose_scan_to_while_loop").apply_gm_pass( + decompose_scan_to_while_loop + ) + GraphTransformObserver(gm, "decompose_map_to_while_loop").apply_gm_pass( + decompose_map_to_while_loop + ) + + gm.recompile() + gm.graph.lint() + + +def prepare_softmax_pattern(x, dim): + xmax = x.amax(dim=dim, keepdim=True) + xsub = x - xmax + xexp = xsub.exp() + xsum = xexp.sum(dim=dim, keepdim=True) + return xmax, xsum, xsub, xexp + + +def prepare_softmax_replacement(x, dim): + """ + Return xsub since otherwise log-softmax can not be matched + due to a use of this intermediate node. Same reason to return + xsub.exp() for softmax. + """ + from torch._inductor.inductor_prims import prepare_softmax_online + + xmax, xsum = prepare_softmax_online(x, dim) + xsub = x - xmax + return xmax, xsum, xsub, xsub.exp() + + +def prepare_softmax_extra_check(match): + """ + We only have triton online softmax kernels currently. 
+ """ + return ( + config.online_softmax + and match.kwargs["x"].meta["val"].device.type == "cuda" + and config.cuda_backend == "triton" + ) + + +def decompose_map_to_while_loop(gm: torch.fx.GraphModule): + """This is similar to decompose_scan_to_while_loop.""" + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.map_impl), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + assert len(kwargs) == 0, ( + "kwargs of map are not merged into args before entering decompose_map_to_while_loop_pass" + ) + subgraph, fx_xs, fx_additional_inputs = args + sub_gm: torch.fx.GraphModule = getattr(gm, subgraph.target) + cur_node = match.nodes[0] + mapped_outputs = cur_node.meta["val"] + + def lower_to_while_loop(*args, **kwargs): + assert len(kwargs) == 0 + xs, additional_inputs = pytree.tree_unflatten(args, tree_spec) + assert isinstance(xs, (tuple, list)) and isinstance( + additional_inputs, (tuple, list) + ), (xs, additional_inputs) + map_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # Similar to NOTE [Pre-allocate scan's output buffer] + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, map_length)) + if isinstance(arg, torch.SymInt) + } + out_buffers = [ + torch.empty_strided( + resolve_shape_to_proxy(out.size(), bound_symbols), + resolve_shape_to_proxy(out.stride(), bound_symbols), + device=out.device, + dtype=out.dtype, + layout=out.layout, + requires_grad=out.requires_grad, + ) + for out in mapped_outputs + ] + + while_loop_operands = (loop_idx, out_buffers, xs) + while_loop_flat_operands, operands_spec = pytree.tree_flatten( + while_loop_operands + ) + while_loop_additional_inputs = additional_inputs + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + def cond_fn(*flat_args): + loop_idx, _, _, _ = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + return loop_idx < map_length + + def body_fn(*flat_args): + loop_idx, out_bufs, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < map_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + outs = sub_gm(*sub_xs, *additional_inputs) + + for out, buffer in zip(outs, out_bufs): + buffer_slice = torch.ops.aten.select.int(buffer, 0, idx_int) + buffer_slice.copy_(out) + return loop_idx + 1, *out_bufs, *xs + + _, final_out, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(while_loop_flat_operands), + tuple(while_loop_additional_inputs), + ), + operands_spec, + ) + return (final_out,) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + (fx_xs, fx_additional_inputs) + ) + match.replace_by_example( + lower_to_while_loop, lower_to_while_loop_args, run_functional_passes=False + ) + + graph_pass.apply(gm) + + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.map_impl + ): + raise AssertionError("map is not lowered to while_loop") + + +def resolve_shape_to_proxy( + shape: list[Union[int, torch.SymInt]], bound_symbols: dict[Any, Any] +): + """ + Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values. 
+ When we trace this function, we'll get a graph with call_function nodes that describes how the shape expr is + computed from bound_symbols' values. + + Suppose shape = (s1*s2, s1+s2) and bound_symbols = {s1: arg0, s2: arg1}, the result will be + (arg0 * arg1, arg0 + arg1). + """ + from torch.utils._sympy.interp import sympy_interp + from torch.utils._sympy.reference import PythonReferenceAnalysis + + ret = [] + for s in shape: + if isinstance(s, torch.SymInt): + ret.append( + sympy_interp( + PythonReferenceAnalysis, + bound_symbols, + s.node.expr, + ), + ) + else: + assert isinstance(s, int) + ret.append(s) + return ret + + +def decompose_scan_to_while_loop(gm: torch.fx.GraphModule): + """ + NOTE [decompose scan to while_loop] + This pass decomposes `scan` to `while_loop` by replacing the scan fx_node with a while_loop hop. + + Suppose we have a function f: + + def f(): + init = torch.zeros([]) + xs = torch.arange(4) + ys = [] + for i in range(xs.size(0)): + init = xs[i] + init + ys.append(init) + + # Return the final carry and stack the intermediates + return init, torch.stack(ys) + + We could rewrite it with a scan with the benefits of reducing compilation time/binary size, reducing + memory usage, supporting loops over unbacked shapes and cudagraph etc. + + def g(): + def step_fn(init: torch.Tensor, x: torch.Tensor): + next_init = x + init + return next_init, next_init + + init = torch.zeros([]) + xs = torch.arange(4) + final_carry, ys = torch._higher_order.scan(step_fn, init, xs) + return final_carry, ys + + This pass will rewrite scan into: + + def k(): + init = torch.zeros([]) + xs = torch.arange(4) + + # we create a loop_idx and loop through xs.shape[0] + loop_idx = torch.zeros([]) + ys = torch.empty_strided(_shape_stride_of_ys) + def cond_fn(loop_idx, ys, init, xs): + return loop_idx < xs.shape[0] + + # we pre-allocate the output buffer ys and inplace + # copy the y of each intermediate into a slice. + # NOTE [Pre-allocate scan's output buffer]. + def body_fn(loop_idx, ys, init, xs): + int_idx = loop_idx.item() + next_init, y = step_fn(init, xs[int_idx]) + ys[int_idx].copy_(y) + return loop_idx + 1, ys, next_init, xs + + final_carry, _, _, ys = torch._higher_order.while_loop(cond_fn, body_fn, (loop_idx, ys, init, xs)) + return final_carry, ys + """ + + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.scan), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.scan import _extract_carry_and_out + + assert len(kwargs) == 0, ( + "kwargs of scan are not merged into args before entering decompose_scan_to_while_loop_pass" + ) + + combine_subgraph, fx_init, fx_xs, fx_additional_inputs = args + assert combine_subgraph.op == "get_attr", "first arg is not combine_subgraph" + sub_gm: torch.fx.GraphModule = getattr(gm, combine_subgraph.target) + cur_node = match.nodes[0] + num_init_leaves = len(fx_init) + _, ys_outputs = _extract_carry_and_out(cur_node.meta["val"], num_init_leaves) + + def lower_to_while_loop(*args, **kwargs): + """ + The traced graph of this function will be used to replace the original scan fx_node. + """ + assert len(kwargs) == 0 + + # Step 1: construct necessary inputs to while_loop based on scan's input. 
+ ( + init, + xs, + additional_inputs, + ) = pytree.tree_unflatten(args, tree_spec) + scan_length = xs[0].size(0) + loop_idx = torch.zeros([], dtype=torch.int64, device=torch.device("cpu")) + + # NOTE [Pre-allocate scan's output buffer] + # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node. + # However, the meta consists of concrete symints, we need to bind those symints with + # proxies in order to trace the torch.empyt_strided call correctly. + # + # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs + # in dynamo so we can always re-construct the sym expression from placeholders. + # See Note [Auto lift basic free symbols when create_graph_input] for how this is done. + bound_symbols = { + arg.node.expr: arg + for arg in pytree.tree_leaves((args, scan_length)) + if isinstance(arg, torch.SymInt) + } + ys_outs = [ + torch.empty_strided( + resolve_shape_to_proxy(ys_out.size(), bound_symbols), + resolve_shape_to_proxy(ys_out.stride(), bound_symbols), + device=ys_out.device, + dtype=ys_out.dtype, + layout=ys_out.layout, + requires_grad=ys_out.requires_grad, + ) + for ys_out in ys_outputs + ] + + while_loop_operands = (loop_idx, ys_outs, init, xs) + flat_operands, operands_spec = pytree.tree_flatten(while_loop_operands) + _, operands_and_additional_inputs_spec = pytree.tree_flatten( + (*while_loop_operands, additional_inputs) + ) + + # Step 2: create the cond_fn and body_fn for while_loop + def cond_fn(*flat_args): + loop_idx, _, _, _, _ = pytree.tree_unflatten( + flat_args, operands_and_additional_inputs_spec + ) # type: ignore[has-type] + return loop_idx < scan_length # type: ignore[has-type] + + def body_fn(*flat_args): + loop_idx, ys_outs, carry, xs, additional_inputs = pytree.tree_unflatten( + flat_args, + operands_and_additional_inputs_spec, # type: ignore[has-type] + ) + + idx_int = loop_idx.item() + torch.ops.aten._assert_scalar.default(idx_int >= 0, "") + torch.ops.aten._assert_scalar.default(idx_int < scan_length, "") + sub_xs = [torch.ops.aten.select.int(x, 0, idx_int) for x in xs] + next_carry, ys = _extract_carry_and_out( + sub_gm(*(list(carry) + sub_xs + list(additional_inputs))), + num_init_leaves, + ) + for y, y_out in zip(ys, ys_outs): + y_out_slice = torch.ops.aten.select.int(y_out, 0, idx_int) + y_out_slice.copy_(y) + return loop_idx + 1, *ys_outs, *next_carry, *xs + + # Step 3: call the while_loop operator + _, ys_outs, last_carry, _ = pytree.tree_unflatten( + torch.ops.higher_order.while_loop( + cond_fn, + body_fn, + tuple(flat_operands), + tuple(additional_inputs), + ), + operands_spec, + ) + return list(last_carry) + list(ys_outs) + + lower_to_while_loop_args, tree_spec = pytree.tree_flatten( + ( + fx_init, + fx_xs, + fx_additional_inputs, + ) + ) + match.replace_by_example( + lower_to_while_loop, + lower_to_while_loop_args, + run_functional_passes=False, + ) + + graph_pass.apply(gm) + + for node in gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.scan + ): + raise AssertionError("scan is not lowered to while_loop") + + +@init_once_fakemode +def lazy_init(): + if torch._C._has_mkldnn: + from . 
import decompose_mem_bound_mm # noqa: F401 + from .mkldnn_fusion import _mkldnn_fusion_init + + _mkldnn_fusion_init() + + # Put this patterns in post-grad pass rather than joint-graph + # pass since otherwise there will be perf/peak-memory regression: + # https://github.com/pytorch/pytorch/issues/148141 + register_replacement( + prepare_softmax_pattern, + prepare_softmax_replacement, + [torch.empty(4, 8)], + scalar_workaround=dict(dim=-1), + trace_fn=fwd_only, + pass_dicts=pass_patterns[1], + extra_check=prepare_softmax_extra_check, + ) + + +def reorder_for_locality(graph: torch.fx.Graph): + if torch.distributed.is_available(): + + def check(): + # This is a wait node, and `other_node`` is some collective node. + # Eager semantics allow waits to be issued in a different order than + # the collectives. Reordering this wait node might reorder collectives + # which cause hangs. Once we have SPMD mode, we can safely reorder them. + # However, increasing the locality between a collective and its wait node + # is generally worse for performance. + return node.target != torch.ops._c10d_functional.wait_tensor.default + else: + + def check(): + return True + + def visit(other_node): + if ( + other_node.op == "call_function" + and other_node.target != operator.getitem + and all((n in seen_nodes) for n in other_node.users) + and get_mutation_region_id(graph, node) + == get_mutation_region_id(graph, other_node) + and check() + ): + # move node's producers right before it + node.prepend(other_node) + + seen_nodes = OrderedSet[torch.fx.Node]() + + # only reorder nodes before the first copy_ in the graph. + # copy_ will appear at the end of functionalized graphs when there is mutation on inputs, + # and this reordering doesn't work well with mutation + first_copy = next( + iter(graph.find_nodes(op="call_function", target=torch.ops.aten.copy_.default)), + None, + ) + past_mutating_epilogue = True if first_copy is None else False + + for node in reversed(graph.nodes): + seen_nodes.add(node) + if not past_mutating_epilogue: + past_mutating_epilogue = node is first_copy + continue + + torch.fx.map_arg((node.args, node.kwargs), visit) + + +def register_lowering_pattern( + pattern, extra_check=_return_true, pass_number=1 +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + """ + Register an aten to inductor IR replacement pattern + """ + return pattern_matcher.register_lowering_pattern( + pattern, extra_check, pass_dict=pass_patterns[pass_number] + ) + + +################################################################################ +# Actual patterns below this point. 
+# Priority of patterns is: +# - later output nodes first +# - order patterns are defined in +################################################################################ + + +def is_valid_mm_plus_mm(match: Match): + if not (config.max_autotune or config.max_autotune_gemm): + return False + + *_b1, m1, k1 = match.kwargs["mat1"].meta.get("tensor_meta").shape + *_b2, k2, n1 = match.kwargs["mat2"].meta.get("tensor_meta").shape + if k1 != k2: + return False + + *_b1, m2, k3 = match.kwargs["mat3"].meta.get("tensor_meta").shape + *_b2, k4, n2 = match.kwargs["mat4"].meta.get("tensor_meta").shape + if k3 != k4: + return False + + if m1 != m2 or n1 != n2: + return False + + return True + + +def scatter_upon_const_tensor_extra_check(m): + if not config.optimize_scatter_upon_const_tensor: + return False + full_shape = m.kwargs["shape"] + selector = m.kwargs["selector"] + dim = m.kwargs["dim"] + if dim < 0: + dim += len(full_shape) + + selector_ft = selector.meta["val"] + assert selector_ft.dim() == len(full_shape) + + for idx, select_sz, full_sz in zip( + itertools.count(), selector_ft.shape, full_shape + ): + if idx == dim: + continue + + # TODO: the pattern can be updated to support the case that index tensor + # is shorter. But that will need a more complex condition expression + # especially for multi-dimensional tensors. + # Skip it for now. + if isinstance(full_sz, fx.Node): + full_sz = full_sz.meta["val"] + if select_sz < full_sz: + return False + + # Actually we can support small size larger than 1. It would be a bit + # tedius. E.g., we load all the index values (not many) and compare + # them with the position in tensor to decide what value to return. + return selector_ft.size(dim) == 1 + + +@register_lowering_pattern( + CallFunction( + aten.scatter.value, + CallFunction( + aten.full, + KeywordArg("shape"), + KeywordArg("background_val"), + dtype=KeywordArg("dtype"), + ), + KeywordArg("dim"), + KeywordArg("selector"), + KeywordArg("val"), # scalar value + ), + extra_check=scatter_upon_const_tensor_extra_check, +) +def scatter_upon_const_tensor( + match: Match, shape, background_val, dtype, dim, selector, val +): + """ + Match the pattern of full+scatter into a pointwise. + + TODO: Right now the scatter value must be a scalar. But we could support it + when it is a tensor as well. 
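+
+    Illustrative sketch (shapes and values are made up, not taken from a real
+    workload): with dim=1 and a (4, 1) int64 selector,
+
+        background = torch.full((4, 4), 0.0)
+        out = background.scatter(1, selector, 1.0)
+
+    is lowered to a single pointwise kernel that computes, for each (i, j),
+    1.0 where j == selector[i, 0] and 0.0 elsewhere, so the constant background
+    tensor need not be materialized on its own.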
+ """ + from torch._inductor import metrics + + metrics.num_matches_for_scatter_upon_const_tensor += 1 + + selector_loader = selector.make_loader() + + def inner_fn(idx): + selector_idx = list(idx) + selector_idx[dim] = 0 + + selector = selector_loader(selector_idx) + return ops.where( + selector == ops.index_expr(idx[dim], torch.int64), + ops.constant(val, dtype), + ops.constant(background_val, dtype), + ) + + return ir.Pointwise.create( + device=selector.get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=shape, + ) + + +@register_lowering_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, KeywordArg("mat1"), KeywordArg("mat2")), + CallFunction(aten.mm, KeywordArg("mat3"), KeywordArg("mat4")), + ), + extra_check=is_valid_mm_plus_mm, +) +def mm_plus_mm(match: Match, mat1, mat2, mat3, mat4): + return inductor.kernel.mm_plus_mm.tuned_mm_plus_mm(mat1, mat2, mat3, mat4) + + +@register_graph_pattern( + CallFunction( + aten.cumsum.default, + CallFunction( + torch.ops.aten.full.default, + KeywordArg("shape"), + KeywordArg("fill_value"), + dtype=KeywordArg("dtype"), + layout=Ignored(), + device=KeywordArg("device"), + pin_memory=False, + _users=MULTIPLE, + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + pass_dict=pass_patterns[1], +) +def pointless_cumsum_replacement(match: Match, shape, fill_value, device, dtype, dim): + """Based on a pattern in OPTForCausalLM""" + + if is_integer_dtype(dtype) or is_boolean_dtype(dtype): + # cumsum promotes all integral types to int64 + dtype = torch.int64 + + def repl(*shape): + dim_size = shape[dim] + idx = torch.arange(1, dim_size + 1, device=device, dtype=dtype) + + inter_shape = [1] * len(shape) + inter_shape[dim] = dim_size + return (idx * fill_value).view(inter_shape).expand(shape) + + # only replace the output node, not all nodes + match.nodes = [match.output_node()] + match.replace_by_example(repl, list(shape)) + + +_cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) + + +@register_lowering_pattern( + CallFunction( + aten.cat, + [ + _cat_1, + CallFunction( + aten.slice, + _cat_1, + 1, + 0, + KeywordArg("size"), + ), + ], + 1, + ) +) +def cat_slice_cat(match, cat_input, size, dim=1): + """ + This is an example of a more complex pattern where cat_1 is used + multiple times inside the pattern. We fold 2 calls to cat into one. + + Matches: + cat_1: f32[1024, 4077] = torch.ops.aten.cat.default([add_26, primals_217], 1) + slice_1: f32[1024, 4077] = torch.ops.aten.slice.Tensor(cat_1, 0, 0, 9223372036854775807) + slice_2: f32[1024, 19] = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 19) + cat_2: f32[1024, 4096] = torch.ops.aten.cat.default([cat_1, slice_2], 1) + + + Rewrite to: + slice_2 = torch.ops.aten.slice.Tensor(add_26, 1, 0, 19) + cat_2 = torch.ops.aten.cat.default([add_26, primals_217, slice2], 1) + """ + first, *rest = cat_input + # Optimization is optional, because we can just not fold the cat + # size should be within first.get_size()[dim] such that the optimization is valid. + # For negative `end`, we currently fallback to not optimizing. 
+ if size >= 0 and V.graph.sizevars.statically_known_leq(size, first.get_size()[dim]): + # fold 2 cats into 1 cat + return L[aten.cat]( + [ + first, + *rest, + L[aten.slice](first, dim, 0, size), + ], + dim, + ) + else: + # don't expect to hit this case, just fall back + tmp = L[aten.cat](cat_input, dim) + return L[aten.cat]( + [ + tmp, + L[aten.slice](tmp, dim, 0, size), + ], + dim, + ) + + +def is_valid_splitwithsizes_cat(match): + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + cat_nodes = filter_nodes(match.nodes, aten.cat) + get_item_nodes = filter_nodes(match.nodes, operator.getitem) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + # The dim of split and cat should match for passthrough + if get_arg_value(split_node, 2, "dim") != get_arg_value(cat_node, 1, "dim"): + return False + get_item_args = OrderedSet( + get_arg_value(get_item_node, 1) for get_item_node in get_item_nodes + ) + assert None not in get_item_args + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # All parts of split should be included in the cat + if get_item_args != OrderedSet(range(len(split_sizes))): + return False + # The order of get_item_args should same with cat_node used. + # For example, if the split_node like split_with_sizes(input, [2, 2, 3], 1), + # the cat node should be like cat([get_item(0), get_item(1), get_item(2)], 1). + cat_items_args_order = [ + get_arg_value(item_node, 1) for item_node in get_arg_value(cat_node, 0) + ] + if cat_items_args_order != list(range(len(split_sizes))): + return False + + return True + + +def same_meta(node1: torch.fx.Node, node2: torch.fx.Node): + """True if two nodes have the same metadata""" + val1 = node1.meta.get("val") + val2 = node2.meta.get("val") + return ( + val1 is not None + and val2 is not None + and statically_known_true(sym_eq(val1.size(), val2.size())) + and val1.layout == val2.layout + and val1.dtype == val2.dtype + and val1.device == val2.device + and ( + val1.layout != torch.strided + or statically_known_true(sym_eq(val1.stride(), val2.stride())) + ) + ) + + +noop_registry: dict[Any, Any] = {} + + +def register_noop_decomp(targets, nop_arg=0): + def register_fun(cond): + register_decomposition(targets, registry=noop_registry, unsafe=True)( + (cond, nop_arg) # type: ignore[arg-type] + ) + return cond + + return register_fun + + +@register_noop_decomp(aten.slice) +def slice_noop(self, dim=0, start=None, end=None, step=1): + if start is None or end is None: + return False + + slice_dim_size = self.shape[dim] + if ( + statically_known_true(sym_eq(start, 0)) + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_dim_size) + ) + and statically_known_true(sym_eq(step, 1)) + ): + return True + return False + + +@register_noop_decomp(aten.slice_scatter, 1) +def slice_scatter_noop(self, src, dim=0, start=None, end=None, step=1): + if start is None: + start = 0 + if end is None: + end = 2**63 - 1 + slice_scatter_dim_size = self.shape[dim] + if ( + self.shape == src.shape + and start == 0 + and ( + statically_known_true(end >= 2**63 - 1) + or statically_known_true(end >= slice_scatter_dim_size) + ) + and step == 1 + ): + return True + return False + + +@register_noop_decomp(aten.repeat) +def repeat_noop(self, repeats): + return all(r == 1 for r in repeats) + + +@register_noop_decomp(aten.constant_pad_nd) +def constant_pad_nd(x, padding, fill_value=0): + return all(p == 0 for p in padding) + + 
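+# A minimal sketch of what a no-op condition looks like (illustrative only and
+# deliberately not registered here): the condition receives the node's fake
+# args/kwargs and returns True when the op provably returns its input
+# unchanged; remove_noop_ops() then redirects the node's users to
+# args[nop_arg] (0 by default), additionally guarded by the same_meta() check.
+def _example_expand_noop(x, sizes):
+    # expanding a tensor to its own shape changes nothing
+    return statically_known_true(sym_eq(x.shape, tuple(sizes)))
+
+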
+@register_noop_decomp(torch.ops.prims.convert_element_type) +def convert_element_type_noop(x, dtype: torch.dtype): + return x.dtype == dtype + + +@register_noop_decomp(torch.ops.prims.device_put) +def device_put_noop(x, device, non_blocking=True): + return x.device == decode_device(device) + + +@register_noop_decomp([aten.ceil, aten.floor, aten.round, aten.trunc]) +def int_noop(x): + return is_integer_dtype(x.dtype) + + +@register_noop_decomp([aten.pow]) +def pow_noop(a, b): + return isinstance(b, int) and b == 1 + + +@register_noop_decomp([aten.cat], lambda args: args[0][0]) +def cat_noop(inputs, dim=0): + return len(inputs) == 1 + + +@register_noop_decomp(aten.view.default) +def view_default_noop(arg, size): + return statically_known_true(sym_eq(arg.shape, tuple(size))) + + +@register_noop_decomp(aten.view.dtype) +def view_dtype_noop(arg, dtype): + return arg.dtype == dtype + + +# Note, we also always have a check for identical metadata, which is why these +# are safe +@register_noop_decomp([aten.copy], nop_arg=1) +@register_noop_decomp([aten.alias, aten.clone]) +def true_noop(*args, **kwargs): + return True + + +def remove_noop_ops(graph: torch.fx.Graph): + """ + Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph. + """ + inputs = OrderedSet[torch.fx.Node]() + input_storages = OrderedSet[Union[int, None]]() + output_storages = OrderedSet[Union[int, None]]() + + for node in graph.find_nodes(op="placeholder"): + inputs.add(node) + input_storages.add(get_node_storage(node)) + + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" + outputs = output_node.args[0] + if not isinstance(outputs, (list, tuple)): + # nested subgraphs can have singleton outputs + outputs = (outputs,) + for out in outputs: + if isinstance(out, torch.fx.Node): + output_storages.add(get_node_storage(out)) + + for node in graph.nodes: + if node.target in noop_registry: + cond, src_index = noop_registry[node.target] + if isinstance(src_index, int): + src = node.args[src_index] + else: + src = src_index(node.args) + if not isinstance(src, torch.fx.Node): + continue + # Don't introduce new aliasing between inputs and outputs. + # See fx_passes/README.md for a discussion of why this is + # necessary. + node_storage = get_node_storage(node) + src_storage = get_node_storage(src) + node_is_view = node_storage == src_storage + if ( + not node_is_view + and node_storage in output_storages + and (src_storage in input_storages or src_storage in output_storages) + ): + continue + + # Even if input and outputs are expected to alias, + # don't make "node is src" True + if ( + node_is_view + and node in output_node.args + and (src in inputs or src in output_node.args) + ): + continue + + is_valid, args, kwargs = get_fake_args_kwargs(node) + if not is_valid: + continue + if same_meta(node, src) and cond(*args, **kwargs): + node.replace_all_uses_with(src) + graph.erase_node(node) + + +def remove_assert_ops(graph: torch.fx.Graph): + """ + Removes aten._assert_tensor_metadata.default op because + 1) it will be lowered to a no-op in inductor + 2) it can block fusion, such as unfuse_bias_add_to_pointwise fusion. + + This op could come from aten.to functionalization in export. 
+ + For example, if we have a graph like below + + %addmm = aten.addmm.default(%linear_bias, %arg3_1, %permute) + %_assert_tensor_metadata = aten._assert_tensor_metadata.default(%addmm, None, None, torch.float16) + %convert_element_type_3 = prims.convert_element_type.default(%addmm, torch.float32) + %pow_1 = aten.pow.Tensor_Scalar(%convert_element_type_3, 2) + + We still want to fuse add from addmm with pow, instead of fusing add with mm, according to unfuse_bias_add_to_pointwise fusion. + + However, aten._assert_tensor_metadata.default is not a pointwise op, and would fail the should_prefer_unfused_addmm check. + + We remove this op so it doesn't block fusion decisions. It's safe because this op is lowered to a no-op with @register_lowering. + + """ + for node in graph.find_nodes( + op="call_function", target=torch.ops.aten._assert_tensor_metadata.default + ): + graph.erase_node(node) + + +def decompose_triton_kernel_wrapper_functional(graph): + """Decomposes triton_kernel_wrapper_functional nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.triton_kernel_wrapper_functional), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.triton_kernel_wrap import ( + triton_kernel_wrapper_functional_dense, + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + return (triton_kernel_wrapper_functional_dense(*args, **kwargs),) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + for node in graph.find_nodes( + op="call_function", + target=torch.ops.higher_order.triton_kernel_wrapper_functional, + ): + raise AssertionError("triton_kernel_wrapper_functional was not removed") + + +def decompose_auto_functionalized(graph): + """Decomposes auto_functionalized nodes into clones and the underlying + mutation node. + + We assume that the reinplacing pass runs before this; the reinplacing pass + tells us (via rewriting the arguments or .meta to those nodes) which + Tensors we should clone and which Tensors are safe to reinplace. + """ + graph_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import auto_functionalized_dense + + only_clone_these_tensors = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. 
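+        # decomp() below unflattens flat_args with the saved TreeSpec and calls the
+        # dense (mutating) implementation, forwarding only_clone_these_tensors.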
+ def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mode = args[0] + return auto_functionalized_dense(mode, only_clone_these_tensors, **kwargs) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + @register_graph_pattern( + CallFunctionVarArgs(torch.ops.higher_order.auto_functionalized_v2), + pass_dict=graph_pass, + ) + def _(match: Match, *args, **kwargs): + from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized_v2_dense, + ) + + only_clone_these_bases = tuple( + match.nodes[0].meta.get("only_clone_these_tensors", []) + ) + + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def _maybe_resolve_constant_get_attr(node): + # Resolve getattr node to its value because they don't always have meta["val"] + if ( + isinstance(node, torch.fx.Node) + and node.op == "get_attr" + and "val" not in node.meta + ): + const_attr = getattr(graph.owning_module, node.target) # type: ignore[arg-type] + assert isinstance( + const_attr, (torch.fx.GraphModule, pytree.TreeSpec) + ), (type(const_attr), const_attr) + return const_attr + return node + + flat_args = [_maybe_resolve_constant_get_attr(arg) for arg in flat_args] + + # NB: we combine (args, kwargs) into flat args for replacing. + # This is replace_by_example uses make_fx which does not support + # tracing a function with kwargs. + def decomp(*flat_args): + args, kwargs = pytree.tree_unflatten(flat_args, spec) + assert len(args) == 1 + mutable_op = args[0] + return auto_functionalized_v2_dense( + mutable_op, only_clone_these_bases, **kwargs + ) + + match.replace_by_example(decomp, flat_args, run_functional_passes=False) + + graph_pass.apply(graph) + + # We need to remove the get_attr registered for _constant_schema and the + # auto_functioanlized's graph module (it's replaced with original ) when auto_functionalize a hop. 
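+    # Only get_attr nodes left without users are collected; when the attribute is a
+    # GraphModule it is also deleted from the owning module, and the dead nodes are
+    # erased afterwards.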
+ _to_remove = [] + for node in graph.nodes: + if node.op == "get_attr" and len(node.users) == 0: + _to_remove.append(node) + if hasattr(graph.owning_module, node.target) and isinstance( + getattr(graph.owning_module, node.target), torch.fx.GraphModule + ): + delattr(graph.owning_module, node.target) + for node in _to_remove: + graph.erase_node(node) + + graph.lint() + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized + ): + raise AssertionError("auto_functionalized was not removed") + + for _ in graph.find_nodes( + op="call_function", target=torch.ops.higher_order.auto_functionalized_v2 + ): + raise AssertionError("auto_functionalized_v2 was not removed") + + +@register_lowering_pattern( + CallFunction( + aten.cat, + ListOf( + CallFunction( + operator.getitem, + CallFunction( + aten.split_with_sizes, + KeywordArg("input_"), + Ignored(), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + ), + ), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_splitwithsizes_cat, +) +def splitwithsizes_cat_replace(match, input_): + return input_ + + +def is_valid_cat_splitwithsizes(match): + cat_nodes = filter_nodes(match.nodes, aten.cat) + split_nodes = filter_nodes(match.nodes, aten.split_with_sizes) + if len(split_nodes) != 1 or len(cat_nodes) != 1: + return False + split_node, cat_node = split_nodes[0], cat_nodes[0] + + # the cat node has other users: can't eliminate + if len(cat_node.users) > 1: + return False + + # the dim of the cat and split should match + dim = get_arg_value(split_node, 2, "dim") + if dim != get_arg_value(cat_node, 1, "dim"): + return False + + cat_inputs = list(get_arg_value(cat_node, 0)) + split_sizes = get_arg_value(split_node, 1, "split_sizes") + # the number of input tensors in cat and the + # length of the split sizes should match + if len(cat_inputs) != len(split_sizes): + return False + + for cat_input, split_size in zip(cat_inputs, split_sizes): + # each cat input tensor's size along dim + # should match the corresponding split size + if "val" not in cat_input.meta: + return False + cat_input_size = cat_input.meta["val"].size(dim) + if cat_input_size != split_size: + return False + + return True + + +@register_lowering_pattern( + CallFunction( + aten.split_with_sizes, + CallFunction( + aten.cat, + KeywordArg("input_"), + Ignored(), + _users=MULTIPLE, + ), + Ignored(), + Ignored(), + ), + pass_number=2, + extra_check=is_valid_cat_splitwithsizes, +) +def cat_splitwithsizes_replace(match, input_): + return input_ + + +def view_to_reshape(gm): + """ + Replace view ops in the GraphModule to reshape ops. 
+ """ + subgraph_names: OrderedSet[str] = OrderedSet( + x.target for x in gm.graph.find_nodes(op="get_attr") + ) + + for child_name, child_mod in gm.named_children(): + if child_name in subgraph_names and isinstance(child_mod, torch.fx.GraphModule): + view_to_reshape(child_mod) + + for nd in gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.view.default + ): + nd.target = torch.ops.aten.reshape.default + + +def should_prefer_unfused_addmm(match): + inp = match.kwargs["inp"] + if not is_gpu(inp.meta["val"].device.type): + return False + + output = match.output_node() + return all(is_pointwise_use(use) for use in output.users) + + +@register_graph_pattern( + CallFunction(aten.addmm, KeywordArg("inp"), Arg(), Arg()), + pass_dict=pass_patterns[2], + extra_check=should_prefer_unfused_addmm, +) +def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp): + def repl(inp, x1, x2): + return x1 @ x2 + inp + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def is_valid_addmm_fusion(match): + mat1, mat2 = match.args + inp = match.kwargs["inp"] + + if not ( + isinstance(inp, torch.fx.Node) and isinstance(inp.meta["val"], torch.Tensor) + ): + return False # Input is a number + + in_shape = inp.meta["val"].shape + mm_shape = mat1.meta["val"].shape[0], mat2.meta["val"].shape[1] + matched = is_expandable_to(in_shape, mm_shape) + if not matched: + return False # Shape mismatch + + return not should_prefer_unfused_addmm(match) + + +@register_graph_pattern( + CallFunction( + aten.add, + CallFunction(aten.mm, Arg(), Arg()), + KeywordArg("inp"), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +@register_graph_pattern( + CallFunction( + aten.add, + KeywordArg("inp"), + CallFunction(aten.mm, Arg(), Arg()), + ), + pass_dict=pass_patterns[2], + extra_check=is_valid_addmm_fusion, +) +def addmm(match, mat1, mat2, *, inp): + def repl(inp, mat1, mat2): + return aten.addmm(inp, mat1, mat2) + + match.replace_by_example(repl, [inp, mat1, mat2]) + + +def register_partial_reduction_pattern(): + "Reuse partial reductions in complete reductions" + + # post grad equivalents + equiv_red = { + aten.amax.default: aten.max.default, + aten.amin.default: aten.min.default, + } + + # TODO: to support other reductions like sum, would need to skip + # lower precision reductions since partial output would need to be kept at fp32. 
+ for red_op in (aten.amax.default, aten.amin.default): + inp = KeywordArg("input") + partial_reduc = CallFunction( + red_op, inp, KeywordArg("reduced_dims"), KeywordArg("keepdim") + ) + full_reduc = CallFunction([red_op, equiv_red[red_op]], inp) + + @register_graph_pattern( + MultiOutputPattern([partial_reduc, full_reduc]), pass_dict=pass_patterns[2] + ) + def reuse_partial(match, input, reduced_dims, keepdim): + partial_red, full_red = match.output_nodes() + + # if they're small, reuse not worth it + if not statically_known_true(input.meta["val"].numel() >= 4096): + return True + + def replacement(inp: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + partial = partial_red.target(inp, reduced_dims, keepdim) + complete = full_red.target(partial) + return (partial, complete) + + counters["inductor"]["partial_reduction_reuse"] += 1 + match.replace_by_example(replacement, [input]) + + +register_partial_reduction_pattern() + + +def check_shape_cuda_and_fused_int_mm_mul_enabled(match): + return ( + config.force_fuse_int_mm_with_mul + and len(getattr(match.args[2].meta.get("val"), "shape", [])) == 2 + and getattr(match.args[2].meta.get("val"), "is_cuda", False) + ) + + +def is_index_put_and_requires_h2d_sync_for_gpu_value(node): + from torch.fx.operator_schemas import normalize_function + + if node.target not in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: + return False + # Inductor falls back to aten.index_put_. + # index_put_ will will call nonzero() and perform a H2D sync if + # any of its indices are bool/byte tensors + # However, it will short-circuit this H2D sync and run mask_fill_ + # if the value we are putting is a cpu scalar. + # Therefore, when inductor sees an index_put_ with byte tensor indices, + # it should *not* convert the cpu scalar value into a gpu tensor. + args_, _kwargs = normalize_function(node.target, node.args, node.kwargs) # type: ignore[misc] + any_byte_bool_indices = False + indices = args_[1] + for i in indices: + if i is not None and i.meta["val"].dtype in [torch.bool, torch.int8]: + any_byte_bool_indices = True + + val = args_[2].meta["val"] + val_is_cpu_scalar = val.device.type == "cpu" and val.numel() == 1 + # If both these conditions hold, then converting the val + # to a gpu tensor will incur a H2D sync when inductor calls aten.index_put_ + return any_byte_bool_indices and val_is_cpu_scalar + + +class ConstructorMoverPass: + def __init__( + self, target: str, allow_outputs: bool = False, allow_inputs: bool = False + ) -> None: + """ + Move constructors from cpu to the target_device. + + Sweeps through the module, looking for constructor nodes that can be moved + to the target_device. + + A constructor node can be moved to the target_device iff all of its users + can also be moved (tested by cannot_be_moved). Otherwise, all dependent + constructor nodes won't be moved. + + - target: target device type + - allow_outputs: allow outputs to be moved + - allow_inputs: allow inputs to be moved + """ + + self.target = target + self.allow_inputs = allow_inputs + self.allow_outputs = allow_outputs + + assert isinstance(target, str), ( + "target should be a string representing the device type. " + f"Got: {type(target).__name__}" + ) + + def allow_cpu_device(self, node: fx.Node) -> bool: + """ + Returns whether a node that returns a tensor on the target device may have + cpu tensors as input. 
+ """ + return node.target in ( + torch.ops.aten.index.Tensor, + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + torch.ops.aten.copy.default, + torch.ops.aten.copy_.default, + torch.ops.aten.slice_scatter.default, + ) + + def is_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node is on the target device. + """ + node_device = self.get_node_device(node) + return node_device is not None and node_device.type == self.target + + def is_cpu_scalar_tensor(self, node: fx.Node) -> bool: + """ + Returns whether a node is a cpu scalar tensor. + """ + device = self.get_node_device(node) + is_cpu = device is not None and device.type == "cpu" + ten = node.meta.get("val") + is_scalar = isinstance(ten, torch.Tensor) and len(ten.size()) == 0 + return is_cpu and is_scalar + + def all_inputs_are_cpu_scalar_or_on_target_device(self, node: fx.Node) -> bool: + """ + Returns whether a node's inputs are either cpu scalar tensors or + on the target device. + """ + inputs = ( + inp + for inp in itertools.chain(node.args, node.kwargs.values()) + if isinstance(inp, fx.Node) + ) + return all( + self.is_cpu_scalar_tensor(inp) or self.is_on_target_device(inp) + for inp in inputs + ) + + def cannot_be_moved(self, node: fx.Node) -> bool: + """ + Returns whether a node can be moved to the target device. + + If this function returns False, it means that this node and all of its users + won't be moved into the target device. + """ + if node.target == "output": + return not self.allow_outputs + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + return True + + if is_index_put_and_requires_h2d_sync_for_gpu_value(node): + return True + + return False + + def get_node_device(self, node: fx.Node) -> Optional[torch.device]: + """ + Get the device of a node. 
+ """ + ten = node.meta.get("val") + return None if not isinstance(ten, torch.Tensor) else ten.device + + def get_cpu_indeg_count(self, graph: fx.Graph) -> dict[fx.Node, int]: + """ + Get the number of cpu inputs to a node + """ + cpu_indeg: dict[fx.Node, int] = Counter() + + for node in graph.nodes: + cpu_count = 0 + + def add_cpu_inp(node): + nonlocal cpu_count + device = self.get_node_device(node) + cpu_count += device is not None and device.type == "cpu" + + pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs)) + + if cpu_count: + cpu_indeg[node] = cpu_count + + return cpu_indeg + + def __call__(self, graph: fx.Graph) -> None: + target_devices = OrderedSet[torch.device]() + constructors = [] + cpu_placeholders: OrderedSet[fx.Node] = OrderedSet() + + for node in graph.nodes: + device = self.get_node_device(node) + if device and device.type == self.target: + target_devices.add(device) + + if ( + self.allow_inputs + and node.op == "placeholder" + and self.is_cpu_scalar_tensor(node) + ): + cpu_placeholders.add(node) + constructors.append(node) + continue + + if not ( + isinstance(node.target, torch._ops.OpOverload) + and node.target.namespace in ("prims", "aten") + ): + continue + + if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target): + continue + + if not node.kwargs.get("device") == torch.device("cpu"): + continue + + constructors.append(node) + + # not handling multiple target devices initially + if not constructors or len(target_devices) != 1: + return + + movable_constructors = self.find_movable_constructors(graph, constructors) + + target_device = next(iter(target_devices)) + for node in movable_constructors: + if node in cpu_placeholders: + with graph.inserting_after(node): + gpu_node = graph.call_function( + torch.ops.prims.device_put.default, (node, target_device) + ) + node.replace_all_uses_with( + gpu_node, + lambda x: x != gpu_node + and x.target != torch.ops.aten.copy_.default, + ) + + # noop elimination if there are other device_put for gpu_node to + # target device. Alternatively, we could just move the other device_put + # earlier in the graph, but that is not supported in fx graph yet. + noop_device_puts = [ + user + for user in gpu_node.users + if user.target == torch.ops.prims.device_put.default + and user.args[1] == target_device + ] + for noop in noop_device_puts: + noop.replace_all_uses_with(gpu_node) + graph.erase_node(noop) + else: + kwargs = node.kwargs.copy() + kwargs["device"] = target_device + node.kwargs = kwargs + + def find_movable_constructors( + self, graph: fx.Graph, constructors: list[fx.Node] + ) -> OrderedSet[fx.Node]: + """ + Starting from the cpu constructors, iterate through the graph and test that all of their + downstream uses can safely be moved to cpu. + """ + cpu_indeg: dict[fx.Node, int] = self.get_cpu_indeg_count(graph) + + # which constructors cannot be moved to gpu + cannot_move_to_gpu = OrderedSet[fx.Node]() + + # For any node in the graph, which constructors does it have a dependency on + constructor_dependencies: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + + # if a cpu node has a dependency on two different cpu constructors, + # then if either constructor cannot be moved to gpu, the other cannot as well. 
+ # In this case any node with a dependency on one will have a dependency on the other + equal_constructor_sets: dict[fx.Node, OrderedSet[fx.Node]] = { + c: OrderedSet([c]) for c in constructors + } + + def make_dependencies_equivalent( + set1: OrderedSet[fx.Node], set2: OrderedSet[fx.Node] + ) -> OrderedSet[fx.Node]: + # could use union find but not worth complexity here + set1.update(set2) + for obj in set1: + equal_constructor_sets[obj] = set1 + return set1 + + queue: list[fx.Node] = list(constructors) + + for c in queue: + constructor_dependencies[c].add(c) + + while queue: + node = queue.pop() + dependencies = constructor_dependencies[node] + + for user in node.users: + if self.cannot_be_moved(user): + cannot_move_to_gpu.update(dependencies) + break + + # this node was used on a op which takes in multiple devices and output a gpu + # tensor. we can convert its cpu input to gpu without making further changes + if self.allow_cpu_device(user) and self.is_on_target_device(user): + del cpu_indeg[user] + elif ( + self.allow_inputs + and self.all_inputs_are_cpu_scalar_or_on_target_device(user) + ): + # this node takes only cpu scalar tensors or gpu tensors as inputs + # and outputs a gpu tensor. we can convert its cpu scalar inputs to gpu + # without making further changes + del cpu_indeg[user] + else: + # otherwise, we should continue look at its downstream uses + cpu_indeg[user] -= 1 + if cpu_indeg[user] == 0: + del cpu_indeg[user] + queue.append(user) + + unioned_set = make_dependencies_equivalent( + dependencies, constructor_dependencies[user] + ) + constructor_dependencies[user] = unioned_set + + for node in cpu_indeg: + if constructor_dependencies[node]: + cannot_move_to_gpu.update(constructor_dependencies[node]) + + all_cannot_move_to_gpu = cannot_move_to_gpu.copy() + for constructor in cannot_move_to_gpu: + all_cannot_move_to_gpu.update(equal_constructor_sets[constructor]) + + return OrderedSet(constructors) - all_cannot_move_to_gpu + + +def move_constructors_to_gpu(graph: fx.Graph) -> None: + """ + Moves intermediary tensors which are constructed on the cpu to gpu when safe + """ + + # cudagraph does not support cpu tensors. In this pass, we update the graph + # by explicitly moving cpu scalar tensors to gpu when profitable, relying on + # graph partition to split off this data copy, and cudagraphifying + # the remaining gpu ops. 
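+    # Inputs and outputs are only eligible when graph partitioning for cudagraphs
+    # is enabled, since the inserted device_put copies rely on being split into a
+    # separate, non-cudagraphed partition.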
+ allow_inputs_outputs = ( + True + if ( + torch._inductor.config.triton.cudagraphs + and torch._inductor.config.graph_partition + ) + else False + ) + ConstructorMoverPass( + get_gpu_type(), + allow_inputs=allow_inputs_outputs, + allow_outputs=allow_inputs_outputs, + )(graph) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..142d62a39ca4e3a92542f3e77de9c081f9fd48b5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/pre_grad.py @@ -0,0 +1,862 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +import types +from collections.abc import Sequence +from typing import Optional + +import torch +import torch.nn as nn +from torch._dynamo.utils import counters, detect_fake_mode +from torch._logging import trace_structured +from torch.fx.experimental.optimization import ( + matches_module_pattern, + replace_node_module, +) +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn import functional as F +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_conv_bn_weights + +from .. import config +from ..fx_utils import matches_module_function_pattern +from ..pattern_matcher import ( + init_once_fakemode, + PatternMatcherPass, + stable_topological_sort, +) +from ..utils import is_cpu_device, pass_execution_and_save +from .group_batch_fusion import group_batch_fusion_passes, PRE_GRAD_FUSIONS +from .misc_patterns import numpy_compat_normalization +from .split_cat import PRE_GRAD_PATTERNS + + +log = logging.getLogger(__name__) + +efficient_conv_bn_eval_pass = PatternMatcherPass( + pass_name="efficient_conv_bn_eval_pass" +) + +fuse_split_linear_add_pass = PatternMatcherPass( + pass_name="fuse_split_linear_add_pass", +) +fuse_chunk_squeeze_cat_pass = PatternMatcherPass( + pass_name="fuse_chunk_squeeze_cat_pass", +) +remove_reshape_pass = PatternMatcherPass( + pass_name="remove_reshape_pass", +) + +# based on predispatch aten IR +normalization_pass_aten = PatternMatcherPass(pass_name="normalization_pass_aten") +merge_splits_pass_aten = PatternMatcherPass(pass_name="merge_splits_pass_aten") +split_cat_pass_aten = PatternMatcherPass(pass_name="split_cat_pass_aten") +unbind_stack_pass_aten = PatternMatcherPass(pass_name="unbind_stack_pass_aten") +merge_getitem_cat_pass_aten = PatternMatcherPass( + pass_name="merge_getitem_cat_pass_aten" +) +merge_stack_tahn_unbind_pass_aten = PatternMatcherPass( + pass_name="merge_stack_tahn_unbind_pass_aten" +) +mutate_cat_pass_aten = PatternMatcherPass(pass_name="mutate_cat_pass_aten") +remove_split_with_size_one_pass_aten = PatternMatcherPass( + pass_name="remove_split_with_size_one_pass_aten" +) + + +def save_inductor_dict(pass_to_compare=None): + if not pass_to_compare: + pass_to_compare = list(config.pre_grad_fusion_options.keys()) + list( + config.post_grad_fusion_options.keys() + ) + return {p: dict(counters["inductor"]).get(p, 0) for p in pass_to_compare} + + +def is_same_dict(inductor_dict, optimus_dict): + for pass_name, count in optimus_dict.items(): + if count != dict(inductor_dict).get(pass_name, 0): + return False + return True + + +def shape_prop(mod) -> None: + return None + + +def normalize_node_kwargs_pass(graph): + return None + + +def fuse_parallel_linear_pass(graph): + return None + + +def remove_split_ops(graph, shape_prop): + return None + + +def 
remove_split_ops_pass(graph): + remove_split_ops(graph.owning_module, shape_prop) + + +def fuse_chunk_reshape_unsqueeze_concat_pass(graph): + return None + + +def fuse_chunk_reshape_concat_pass(graph): + return None + + +def remove_noop_pass(graph): + return None + + +def stack_to_unsqueeze_pass(graph): + return None + + +def merge_concats_pass(graph): + return None + + +def relu_nan_to_num(graph): + return None + + +def fuse_split_getitem_squeeze_cat(graph): + return None + + +def use_triton_dot_compress(graph): + return None + + +def use_triton_lce_replace_simple_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_simple_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_triton_lce_replace_normal_LCE_helper(gm, shape_prop): + return None + + +def use_triton_lce_replace_normal_LCE(graph): + return use_triton_lce_replace_simple_LCE_helper(graph.owning_module, shape_prop) + + +def use_matmul_lce_replace_normal_LCE(graph): + return None + + +def use_matmul_fuse_lce_replace_first_LCE(graph): + return None + + +@init_once_fakemode +def lazy_init(): + from . import efficient_conv_bn_eval, split_cat # noqa: F401 + + if config.is_fbcode(): + from . import fb # type: ignore[attr-defined] # noqa: F401 + + +def _get_pass_name_func(p): + if isinstance(p, PatternMatcherPass): + pass_name = p.pass_name + pass_func = p.apply + elif isinstance(p, types.FunctionType): + pass_name = p.__name__.lstrip("_") + pass_func = p + else: + pass_name = None + pass_func = None + + return pass_name, pass_func + + +def _run_pre_dispatch_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: Optional[str] = None, + remove_passes: Optional[str] = None, +) -> None: + # order matters + default_pass_list = [ + # normalize passes, must be called as the first passes + normalization_pass_aten, + normalize_node_kwargs_pass, + remove_noop_pass, + relu_nan_to_num, + fuse_chunk_reshape_concat_pass, + group_batch_fusion_passes, + normalize_node_kwargs_pass, + fuse_chunk_squeeze_cat_pass, + merge_concats_pass, + fuse_split_linear_add_pass, + remove_reshape_pass, + fuse_parallel_linear_pass, + remove_split_ops_pass, + stack_to_unsqueeze_pass, # run before fuse_chunk_reshape_unsqueeze_concat_pass + fuse_chunk_reshape_unsqueeze_concat_pass, + ] + + full_pass_list = default_pass_list + [ + fuse_split_getitem_squeeze_cat, + use_triton_dot_compress, + use_triton_lce_replace_simple_LCE, + use_triton_lce_replace_normal_LCE, + use_matmul_fuse_lce_replace_first_LCE, + use_matmul_lce_replace_normal_LCE, + ] + + log.info( + f"pre_grad_passes: add_passes: {add_passes}, remove_pass: {remove_passes}" # noqa: G004 + ) + add_passes_list = [] + remove_passes_list = [] + if add_passes: + add_passes_list = add_passes.split(",") + if remove_passes: + remove_passes_list = remove_passes.split(",") + + shape_prop = lambda mod: ShapeProp( # noqa: E731 + gm=mod, + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode=detect_fake_mode(example_inputs), + ).propagate(*tuple(example_inputs)) + + for p in default_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + # should not happen + if pass_name is None or pass_func is None: + continue + if pass_name in remove_passes_list: + continue + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + for p in full_pass_list: + pass_name, pass_func = _get_pass_name_func(p) + if pass_name is 
None or pass_func is None: + continue + if pass_name in add_passes_list: + pass_execution_and_save( + pass_func, + gm, + example_inputs, + f"[Pre grad(predispatch IR)] Apply {pass_name} pass", + ) + + # Remove noops at the end, which may be generated other passes. + pass_execution_and_save( + remove_noop_pass, + gm, + example_inputs, + "[Pre grad(predispatch IR)]Apply remove_noop pass", + ) + shape_prop(gm) + + +def pre_grad_passes( + gm: torch.fx.GraphModule, + example_inputs: Sequence[object] = (), + add_passes: Optional[str] = None, + remove_passes: Optional[str] = None, +) -> torch.fx.GraphModule: + """ + Apply passes on the input FX graph using Torch IR. + + WARNING: + The IR before grad is not functional or normalized, so it is harder + to write passes on this IR. Passes must be safe with respect to + aliasing and mutation and need to handle all possible arg schemas. + + Consider adding a new pass to post_grad.py or joint_graph.py which + are after functionalization and normalization. + """ + if config.pattern_matcher: + lazy_init() + if hasattr( + config, "fx_passes_numeric_check" + ) and config.fx_passes_numeric_check.get("pre_grad", False): + gm_before_fx_passes = gm.__copy__() + # explicitly run with predispatch atenIR based passes + if config.is_predispatch: + _run_pre_dispatch_passes(gm, example_inputs, add_passes, remove_passes) + else: + # We only log the graph with changes to avoid the excessive compilation time + # https://fb.workplace.com/groups/257735836456307/permalink/633533465543207/ + if example_inputs is not None: + gm = fuse_fx(gm, example_inputs) + numpy_compat_normalization(gm.graph) + # We should always do the normalization_pass first + if "normalization_pass" in config.pre_grad_fusion_options: + pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"] + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + group_batch_fusion_passes(gm.graph, pre_grad=True) + for pass_name in config.pre_grad_fusion_options: + # skip all patterns for group batch fusions + if pass_name in PRE_GRAD_FUSIONS or pass_name == "normalization_pass": + continue + pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name] + inductor_before_change = save_inductor_dict( + [pattern_matcher_pass.pass_name] + ) + # we support run same pattern multiple times, the default is to run only once + counter = config.pre_grad_fusion_options[pass_name].get("counter", 1) + for _ in range(counter): + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + if not is_same_dict(counters["inductor"], inductor_before_change): + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"{pattern_matcher_pass.pass_name}_pre_grad", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + # TODO: move efficient_conv_bn_eval_pass to the fusions dict too. 
+ efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] + + if config.pre_grad_custom_pass is not None: + with GraphTransformObserver(gm, "pre_grad_custom_pass"): + config.pre_grad_custom_pass(gm.graph) + stable_topological_sort(gm.graph) + + from .quantization import quant_lift_up + + quant_lift_up(gm) + + gm.graph.lint() + gm.recompile() + + if ( + config.pattern_matcher + and hasattr(config, "fx_passes_numeric_check") + and config.fx_passes_numeric_check.get("pre_grad", False) + and example_inputs is not None + ): + from .numeric_utils import numeric_check_if_enabled + + gm_after_fx_passes = gm.__copy__() + numeric_check_if_enabled( + gm_before_fx_passes, # type: ignore[possibly-undefined] + gm_after_fx_passes, + example_inputs, + config.fx_passes_numeric_check.get("num_iterations", 1), + config.fx_passes_numeric_check.get("precision", 1e-4), + ) + + return gm + + +def fuse_fx(gm: torch.fx.GraphModule, example_inputs) -> torch.fx.GraphModule: + is_cpu = is_cpu_device(example_inputs) + # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` + fake_mode = detect_fake_mode(example_inputs) + + gm = sink_cat_after_pointwise(gm) + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm, fake_mode=fake_mode).propagate(*example_inputs) + with GraphTransformObserver(gm, "linear_permute_fusion"): + gm = linear_permute_fusion(gm) + with GraphTransformObserver(gm, "permute_linear_fusion"): + gm = permute_linear_fusion(gm) + with GraphTransformObserver(gm, "permute_matmul_fusion"): + gm = permute_matmul_fusion(gm) + + # make sure the autograd is disabled. + if torch.is_grad_enabled() or not is_cpu: + return gm + if config.freezing: + with GraphTransformObserver(gm, "remove_identity"): + gm = remove_identity(gm) + with GraphTransformObserver(gm, "fuse_conv_bn"): + gm = fuse_conv_bn(gm) + return gm + + +def fetch_attr(target: str, mod): + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Removes all identity layers from the module. + """ + + class IdentityRemover(torch.fx.Transformer): + def call_module(self, target, args, kwargs): + if isinstance(self.submodules[target], nn.Identity): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return IdentityRemover(gm).transform() + + +def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: + """ + Fuses Convolution/BN layers for inference purposes. 
+ """ + modules_patterns = [ + (torch.nn.Conv1d, torch.nn.BatchNorm1d), + (torch.nn.Conv2d, torch.nn.BatchNorm2d), + (torch.nn.Conv3d, torch.nn.BatchNorm3d), + ] + module_function_patterns = [ + (torch.nn.Conv1d, F.batch_norm), + (torch.nn.Conv2d, F.batch_norm), + (torch.nn.Conv3d, F.batch_norm), + ] + modules = dict(gm.named_modules()) + + class ConvBNFusion: + def __init__( + self, + bn_node, + conv_module, + bn_module=None, # For BN Module + bn_running_mean=None, # For Functional BN + bn_running_var=None, + bn_eps=None, + bn_weight=None, + bn_bias=None, + ) -> None: + self.bn_nodes = [ + bn_node, + ] + self.conv_module = conv_module + self.bn_module = bn_module + self.bn_running_mean = bn_running_mean + self.bn_running_var = bn_running_var + self.bn_eps = bn_eps + self.bn_weight = bn_weight + self.bn_bias = bn_bias + self.fusion_enabled = True + + def add_bn_node(self, bn_node): + self.bn_nodes.append(bn_node) + + def disable_fusion(self): + self.fusion_enabled = False + + def is_fusion_enabled(self): + return self.fusion_enabled + + conv_bn_to_fuse: dict[int, ConvBNFusion] = {} + for pattern in modules_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: # Output of conv is used by other nodes + continue + conv = modules[node.args[0].target] + bn = modules[node.target] + eval_mode = all(not n.training for n in [conv, bn]) + if not eval_mode: + continue + if not bn.track_running_stats: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion(node, conv, bn) + else: + if bn == conv_bn_to_fuse[hash_id].bn_module: + # Do fusion if same bn module + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn = conv_bn_fusion.bn_module + + fused_conv = fuse_conv_bn_eval(conv, bn) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + + gm.graph.lint() + for pattern in module_function_patterns: + conv_bn_to_fuse.clear() + for node in gm.graph.nodes: + if matches_module_function_pattern(pattern, node, modules): + # TODO: support kwargs. 
+ if len(node.args) != 8: + continue + conv = modules[node.args[0].target] + bn_training = node.args[5] + bn_eps = node.args[7] + if conv.training or bn_training: + continue + if type(bn_eps) is not float: + continue + + def _used_by_same_conv_module(users): + conv_module_name = users[0].args[0].target + return all( + conv_module_name == user.args[0].target for user in users + ) + + bn_args_is_constant = all( + n.op == "get_attr" + and (len(n.users) == 1 or _used_by_same_conv_module(list(n.users))) + for n in node.args[1:5] + ) + if not bn_args_is_constant: + continue + bn_running_mean = fetch_attr(node.args[1].target, gm) + bn_running_var = fetch_attr(node.args[2].target, gm) + bn_weight = fetch_attr(node.args[3].target, gm) + bn_bias = fetch_attr(node.args[4].target, gm) + if bn_running_mean is None or bn_running_var is None: + continue + + # Do hash based on the module name of conv + hash_id = hash(node.args[0].target) + if hash_id not in conv_bn_to_fuse: + conv_bn_to_fuse[hash_id] = ConvBNFusion( + node, + conv, + bn_running_mean=bn_running_mean, + bn_running_var=bn_running_var, + bn_eps=bn_eps, + bn_weight=bn_weight, + bn_bias=bn_bias, + ) + else: + if ( + hash(bn_running_mean) + == hash(conv_bn_to_fuse[hash_id].bn_running_mean) + and hash(bn_running_var) + == hash(conv_bn_to_fuse[hash_id].bn_running_var) + and torch.allclose( + torch.tensor(bn_eps), + torch.tensor(conv_bn_to_fuse[hash_id].bn_eps), + ) + and hash(bn_weight) == hash(conv_bn_to_fuse[hash_id].bn_weight) + and hash(bn_bias) == hash(conv_bn_to_fuse[hash_id].bn_bias) + ): + # Do fusion if same functional bn + conv_bn_to_fuse[hash_id].add_bn_node(node) + else: + # Disable the conv bn folding if conv shared by different bn + conv_bn_to_fuse[hash_id].disable_fusion() + + for conv_bn_fusion in conv_bn_to_fuse.values(): + if conv_bn_fusion.is_fusion_enabled(): + bn_nodes = conv_bn_fusion.bn_nodes + conv = conv_bn_fusion.conv_module + bn_running_mean = conv_bn_fusion.bn_running_mean + bn_running_var = conv_bn_fusion.bn_running_var + bn_eps = conv_bn_fusion.bn_eps + bn_weight = conv_bn_fusion.bn_weight + bn_bias = conv_bn_fusion.bn_bias + + fused_conv = copy.deepcopy(conv) + fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( + fused_conv.weight, + fused_conv.bias, + bn_running_mean, + bn_running_var, + bn_eps, + bn_weight, + bn_bias, + ) + for bn_node in bn_nodes: + replace_node_module(bn_node.args[0], modules, fused_conv) + bn_node.replace_all_uses_with(bn_node.args[0]) + gm.graph.erase_node(bn_node) + gm.graph.lint() + gm.recompile() + + return gm + + +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["weight"] # type: ignore[return-value] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] # type: ignore[return-value] + else: + return self.node.kwargs["bias"] if "bias" in self.node.kwargs else None # type: ignore[return-value] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in 
[torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] # type: ignore[return-value] + else: + return self.node.kwargs["input"] # type: ignore[return-value] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] # type: ignore[return-value] + else: + return self.node.kwargs["other"] # type: ignore[return-value] + + +def check_permute(node: torch.fx.Node) -> bool: + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] # type: ignore[operator] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 # type: ignore[arg-type] + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] # type: ignore[operator, union-attr] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def sink_cat_after_pointwise(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + def one_user(node): + users = list(node.users) + return users[0] if len(users) == 1 else None + + def is_view(node): + return node.op == "call_method" and node.target == "view" + + def is_pointwise_unary(node): + ops = "call_function", "call_method" + pointwise = torch.relu, torch.tanh, "relu", "tanh" + return node.op in ops and node.target in pointwise + + g = module.graph + for node in g.nodes: + if node.op != "call_function" or node.target != torch.cat: + continue + + cat_or_view = node + while True: + user = one_user(cat_or_view) + if not user or not is_view(user): + break + cat_or_view = user + + if user and is_pointwise_unary(user): + with g.inserting_before(node): + + def cat_args(tensors, dim=0): + return tensors, dim + + tensors, dim = cat_args(*node.args, **node.kwargs) + new_kwargs = { + name: val for name, val in user.kwargs.items() if name != "input" + } + new_tensors = [ + g.create_node(user.op, user.target, args=(arg,), kwargs=new_kwargs) + for arg in tensors + ] + new_cat = g.create_node( + "call_function", torch.cat, args=(new_tensors, dim) + ) + user.replace_all_uses_with(cat_or_view) + node.replace_all_uses_with(new_cat) + g.erase_node(user) + g.erase_node(node) + g.lint() + module.recompile() + return module + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes(op="call_method", target="permute"): + if check_permute(node): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> 
torch.Tensor: + if bias is None: + return torch.matmul(weight, input.transpose(-1, -2)) + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.find_nodes( + op="call_function", target=torch.nn.functional.linear + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if len(input_node.users) == 0: + module.graph.erase_node(input_node) + + module.graph.lint() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in itertools.chain( + module.graph.find_nodes(op="call_function", target=torch.bmm), + module.graph.find_nodes(op="call_function", target=torch.matmul), + ): + normalized = NormalizedMatmulNode(node) + input_A_node = normalized.get_input() + input_B_node = normalized.get_other() + input_A = input_A_node + input_B = input_B_node + Atrans = Btrans = False + if ( + input_A_node.op == "call_method" + and input_A_node.target == "permute" + and check_permute(input_A_node) + ): + Atrans = True + if len(input_A_node.args) > 0: + input_A = input_A_node.args[0] # type: ignore[assignment] + else: + input_A = input_A_node.kwargs["input"] # type: ignore[assignment] + + if ( + input_B_node.op == "call_method" + and input_B_node.target == "permute" + and check_permute(input_B_node) + ): + Btrans = True + if len(input_B_node.args) > 0: + input_B = input_B_node.args[0] # type: ignore[assignment] + else: + input_B = input_B_node.kwargs["input"] # type: ignore[assignment] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(input_A, input_B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + module.graph.erase_node(node) + if Atrans and len(input_A_node.users) == 0: + module.graph.erase_node(input_A_node) + if Btrans and len(input_B_node.users) == 0: + module.graph.erase_node(input_B_node) + + module.graph.lint() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + if bias is None: + return torch.matmul(input.transpose(-1, -2), weight.t()) + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul( + A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool +) -> torch.Tensor: + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/quantization.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..23a9a5679d279c279f941a2dc742459acafa8191 --- /dev/null +++ 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/quantization.py @@ -0,0 +1,3891 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import copy +import functools +import itertools +import math +import operator +from typing import Any + +import torch +from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.node import map_arg + +from .. import config +from ..lowering import lowerings as L, require_channels_last +from ..pattern_matcher import ( + Arg, + CallFunction, + filter_nodes, + KeywordArg, + ListOf, + Match, + stable_topological_sort, +) +from ..utils import pad_listlike +from .freezing_patterns import register_freezing_graph_pattern +from .post_grad import register_lowering_pattern + + +aten = torch.ops.aten +prims = torch.ops.prims +quantized_decomposed = torch.ops.quantized_decomposed +quantized = torch.ops.quantized + +# Only for per tensor quant since permute may changes the channel idx +_PER_TENSOR_QUANTIZE_OPS = [ + quantized_decomposed.quantize_per_tensor.default, + quantized_decomposed.quantize_per_tensor.tensor, +] + +_VIEW_OPS = [ + aten.transpose.int, + aten.permute.default, + aten.view.default, +] + +""" +The quantization.py file primarily incorporates passes related to quantization fusion +in inductor, includes: +1. Dequant Promotion; +2. Conv/GEMM weight prepack with oneDNN Library; +3. Conv/GEMM quantization fusion with output quant node (if have); +4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more; + +It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference +of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is +1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM. +2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node. +Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16 +quantization. +""" + + +def _get_pattern_output_dtype(match: Match): + """ + Get the pattern's output dtype from node's meta + Assume only 1 output node in this matched pattern. 
+ """ + pattern_output_nodes = match.output_nodes() + assert len(pattern_output_nodes) == 1 + output_node = pattern_output_nodes[0] + assert isinstance(output_node, torch.fx.Node) + output_dtype = output_node.meta["val"].dtype + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + return output_dtype + + +def _may_generate_pattern_with_dtype_convert( + pattern, dtype=Arg(), with_dtype_convert=True, users=1 +): + if with_dtype_convert: + return CallFunction( + prims.convert_element_type.default, + pattern, + dtype, + _users=users, + ) + else: + return pattern + + +def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True): + if with_reshape: + return CallFunction( + torch.ops.aten.reshape.default, + pattern, + reshape_size, + ) + else: + return pattern + + +def _generate_linear_t_pattern( + _dequant_per_channel_pattern, + dtype, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = CallFunction( + aten.permute.default, + _may_generate_pattern_with_dtype_convert( + _dequant_per_channel_pattern, + KeywordArg("autocast_wgt_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("permute_axes"), + ) + return t_pattern + + +def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16): + # only insert to_dtype if is_bf16 is True + computation_call = _may_generate_pattern_with_dtype_convert( + call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users + ) + return unary_fusion(computation_call) + + +def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): + dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.tensor + if is_tensor_overload + else quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("x_quant_min"), + KeywordArg("x_quant_max"), + KeywordArg("x_dq_dtype"), + ) + return dequantize_per_tensor_activation_pattern + + +dequantize_per_channel_weight_pattern = CallFunction( + quantized_decomposed.dequantize_per_channel.default, + KeywordArg("q_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("w_axis"), + KeywordArg("w_quant_min"), + KeywordArg("w_quant_max"), + KeywordArg("w_dtype"), +) + +dequantize_per_channel_to_bf16_weight_pattern = ( + _may_generate_pattern_with_dtype_convert( + dequantize_per_channel_weight_pattern, + KeywordArg("autocast_wgt_dtype"), + ) +) + +dequantize_per_channel_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + +dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction( + aten.clone.default, + dequantize_per_channel_to_bf16_weight_pattern, + memory_format=KeywordArg("memory_format"), +) + + +def get_qconv_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv_pointwise.default, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qconv2d_binary_pt2e_pattern(users=1): + return CallFunction( + torch.ops.onednn.qconv2d_pointwise.binary, + KeywordArg("x"), + 
KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("accum"), + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("groups"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("accum_scale"), + KeywordArg("accum_zero_point"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("postop_name"), + KeywordArg("postop_args"), + KeywordArg("postop_algorithm"), + _users=users, + ) + + +def get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors, users=1): + qlinear_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + return CallFunction( + qlinear_op, + KeywordArg("x"), + KeywordArg("x_scale"), + KeywordArg("x_zp"), + KeywordArg("packed_weight"), + KeywordArg("w_scale"), + KeywordArg("w_zp"), + KeywordArg("x_2"), + KeywordArg("b"), + KeywordArg("output_scale"), + KeywordArg("output_zero_point"), + KeywordArg("output_dtype"), + KeywordArg("x2_scale"), + KeywordArg("x2_zp"), + KeywordArg("binary_op_name"), + KeywordArg("alpha"), + KeywordArg("unary_op_name"), + KeywordArg("unary_op_args"), + KeywordArg("unary_op_algorithm"), + _users=users, + ) + + +dequantize_accum_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + KeywordArg("accum"), + KeywordArg("accum_scale"), + KeywordArg("accum_zp"), + Arg(), + Arg(), + KeywordArg("accum_dq_dtype"), +) + + +def generate_pattern_with_binary( + binary_post_op, + computation_call, + extra_input_pattern, + dtype_convert=False, + swap_inputs=False, +): + binary_pattern = ( + CallFunction( + binary_post_op, + extra_input_pattern, + computation_call, + ) + if swap_inputs + else CallFunction( + binary_post_op, + computation_call, + extra_input_pattern, + ) + ) + return _may_generate_pattern_with_dtype_convert( + binary_pattern, + KeywordArg("convert_dtype_after_inplace_add"), + dtype_convert, + ) + + +def generate_pattern_with_unary(computation_call, unary_post_op): + if unary_post_op is not None: + return CallFunction( + unary_post_op, + computation_call, + ) + return computation_call + + +def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False): + quantized_op_output_pattern_pt2e = CallFunction( + quantized_decomposed.quantize_per_tensor.default, + _may_generate_pattern_with_dtype_convert( + computation_call, + Arg(), + with_dtype_convert, + ), + KeywordArg("o_inv_scale"), + KeywordArg("o_zp"), + KeywordArg("o_qmin"), + KeywordArg("o_qmax"), + KeywordArg("o_dtype"), + ) + return quantized_op_output_pattern_pt2e + + +def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value): + if kwarg_name in check_node.kwargs: + actual_value = check_node.kwargs[kwarg_name] + return actual_value 
== expected_value + else: + assert len(check_node.args) >= (args_index + 1) + actual_value = check_node.args[args_index] + return actual_value == expected_value + + +def _is_valid_quantized_conv_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qconv_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qconv_pointwise + )[0] + return _check_node_kwarg_arg_value( + qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype + ) + return True + + return fn + + +def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qconv_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_conv_optimization_pattern() + ) + + +def _is_valid_qconv_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv_pointwise.tensor, + torch.ops.onednn.qconv2d_pointwise.binary, + torch.ops.onednn.qconv2d_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_conv_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + output_dtype = kwargs["output_dtype"] + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qconv_unary_lower_count"] += 1 + counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv + + +def _is_valid_quantized_linear_optimization_pattern(): + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + if output_dtype in [torch.float32, torch.bfloat16]: + # Only keep matched pattern with same output_dtype + qlinear_node_after_weight_prepack = filter_nodes( + match.nodes, torch.ops.onednn.qlinear_pointwise + )[0] + return _check_node_kwarg_arg_value( + qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype + ) + return True + + return fn + + +def _is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op=False): + return ( + _is_valid_qlinear_binary_optimization_pattern() + if has_binary_post_op + else _is_valid_quantized_linear_optimization_pattern() + ) + + +def _is_valid_qlinear_lowering_pattern(): + def fn(match): + if len(match.nodes) != 1: + return False + return match.nodes[0].target in ( + 
torch.ops.onednn.qlinear_pointwise.default, + torch.ops.onednn.qlinear_pointwise.tensor, + torch.ops.onednn.qlinear_pointwise.binary, + torch.ops.onednn.qlinear_pointwise.binary_tensor, + ) + + return fn + + +def _register_quantized_linear_unary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + # post op + postop_name = kwargs["postop_name"] + postop_args = kwargs["postop_args"] + postop_algorithm = kwargs["postop_algorithm"] + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + postop_name, + postop_args, + postop_algorithm, + ) + counters["inductor"]["qlinear_unary_lower_count"] += 1 + counters["inductor"]["qlinear_unary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear + + +def _register_quantized_linear_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qlinear_lowering_pattern(), + pass_number=pass_number, + ) + def qlinear_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + x2 = kwargs["x_2"] + x2_scale = kwargs["x2_scale"] + x2_zp = kwargs["x2_zp"] + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # bias + b = kwargs["b"] if "b" in kwargs else None + # Output QParams + o_inv_scale = kwargs["output_scale"] + o_zero_point = kwargs["output_zero_point"] + + x2.realize() + from .mkldnn_fusion import _can_be_inplace + + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + + if binary_op_name == "sum" and not _can_be_inplace(x2): + # When we enable the GEMM Template, the output of QLinear + # will be reshaped from 2D back to 3D if the input is 3D. + # This causes _can_be_inplace(x2) to return False if x2 happens + # to be the output of QLinear in this scenario. + # Change the post op from sum to binary add for this case. 
+ # Refer to test case: + # test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2 + binary_op_name = "add" + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + x2, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qlinear_binary_lower_count"] += 1 + counters["inductor"]["qlinear_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qlinear_binary + + +def _is_valid_qconv_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qconv_pointwise + ) + + +def _is_valid_qlinear_binary_optimization_pattern(): + return _is_valid_quantized_op_binary_optimization_pattern( + torch.ops.onednn.qlinear_pointwise, + # we don't insert q-dq for extra input due to accuracy issues + extra_input_from_dequant=False, + ) + + +def _is_valid_quantized_op_binary_optimization_pattern( + qop, extra_input_from_dequant=True +): + # Check if it's a valid Binary Pattern for qconv2d and qlinear: + # * qop_pointwise should only has one users + # * If extra_input_from_dequant is True, extra input of binary node should come from dequant pattern + # * the two inputs of binary node should have attribute "meta" and should be tensors + # * the two inputs of binary node should have the same shape + # * All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. + def fn(match): + output_dtype = _get_pattern_output_dtype(match) + compute_node = filter_nodes(match.nodes, qop)[0] + # qop_pointwise should only have one user + if len(compute_node.users) != 1: + return False + binary_node_inputs = next(iter(compute_node.users)).args + assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs" + if output_dtype in [torch.float32, torch.bfloat16]: + extra_input_of_binary_node = None + for arg in binary_node_inputs: + if arg != compute_node: + extra_input_of_binary_node = arg + break + assert extra_input_of_binary_node is not None + # Extra input of binary node comes from dequant pattern + if extra_input_from_dequant and ( + (not isinstance(extra_input_of_binary_node, torch.fx.Node)) + or ( + extra_input_of_binary_node.target + != quantized_decomposed.dequantize_per_tensor.default + ) + ): + return False + + # the two inputs of binary node should have attribute "meta" and should be tensors + if not ( + hasattr(binary_node_inputs[0], "meta") + and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ) or not ( + hasattr(binary_node_inputs[1], "meta") + and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor) # type: ignore[union-attr] + ): + return False + # the two inputs of binary node should have the same shape + if ( + binary_node_inputs[0].meta["val"].size() # type: ignore[union-attr] + != binary_node_inputs[1].meta["val"].size() # type: ignore[union-attr] + ): + return False + + # All users of the extra input in this pattern should be + # ancestor nodes of the compute node, except for the binary node + # connected to the compute node. 
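For intuition, a minimal standalone sketch (toy ops and shapes, not the actual quantized pattern) of the two structural checks listed above: the compute node must have exactly one user, and both inputs of that binary user must carry tensor "val" metadata of equal size.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def _toy(x, other):
    y = torch.mm(x, x)   # stand-in for the qconv/qlinear compute node
    return y + other     # the single binary user

gm = make_fx(_toy, tracing_mode="fake")(torch.randn(4, 4), torch.randn(4, 4))
mm_node = next(n for n in gm.graph.nodes if n.target == torch.ops.aten.mm.default)
assert len(mm_node.users) == 1
binary_node = next(iter(mm_node.users))
lhs, rhs = binary_node.args
assert lhs.meta["val"].size() == rhs.meta["val"].size()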
+ + from .mkldnn_fusion import _get_remaining_users + + extra_input_of_pattern = ( + match.kwargs["other"] + if "other" in match.kwargs + else ( + match.kwargs["accum"] + if (output_dtype in [torch.uint8, torch.int8]) + or (not extra_input_from_dequant) + else match.kwargs["accum_after_dequant"] + ) + ) + if ( + len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1 + or extra_input_of_pattern == compute_node.args[0] + ): + return False + return True + + return fn + + +def _register_quantized_conv_binary_lowering( + pattern, + pass_number, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_qconv_lowering_pattern(), + pass_number=pass_number, + ) + def qconv_binary(match: Match, *args, **kwargs): + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype is not None + x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"] + accum = kwargs["accum"] + accum_scale = kwargs["accum_scale"] + accum_zp = kwargs["accum_zero_point"] + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + # Output QParams + output_scale = kwargs["output_scale"] + output_zero_point = kwargs["output_zero_point"] + + # post ops + binary_op_name = kwargs["binary_op_name"] + alpha = kwargs["alpha"] + unary_op_name = kwargs["unary_op_name"] + unary_op_args = kwargs["unary_op_args"] + unary_op_algorithm = kwargs["unary_op_algorithm"] + + accum.realize() + from .mkldnn_fusion import _can_be_inplace + + assert _can_be_inplace(accum), ( + "QConv Binary Inplace Fusion requires accum is not an alias or mutation." + ) + + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point, + output_dtype, + accum_scale, + accum_zp, + binary_op_name, + alpha, + unary_op_name, + unary_op_args, + unary_op_algorithm, + ) + counters["inductor"]["qconv2d_binary_lower_count"] += 1 + counters["inductor"]["qconv2d_binary_lower_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qconv_binary + + +def _register_quantization_unary_lowering(): + # QConv2d + for users in [1, 2]: + qconv_pattern = get_qconv_pt2e_pattern(users) + _register_quantized_conv_lowering( + qconv_pattern, + 2, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + _register_quantized_linear_unary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _register_quantization_binary_lowering(): + # QConv2d + for users in (1, 2): + qconv_pattern = get_qconv2d_binary_pt2e_pattern(users) + _register_quantized_conv_binary_lowering( + qconv_pattern, + 2, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + ) + + # QLinear + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_binary_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + 
_register_quantized_linear_binary_lowering( + qlinear_pattern, + 2, # pass_number + computation_op, + ) + + +def _is_valid_quantized_maxpool2d_optimization_pattern(): + def fn(match): + # Only match the pattern which max_pool2d_with_indices returns value + # instead of indices. + get_item_node = filter_nodes(match.nodes, operator.getitem)[0] + return get_item_node.args[1] == 0 + + return fn + + +def _register_quantized_maxpool2d_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(), + ) + def qmaxpool2d(match: Match, *args, **kwargs): + x = kwargs["x"] + kernel_size = kwargs["kernel_size"] + stride = kwargs["stride"] if ("stride" in kwargs) else None + padding = kwargs["padding"] if ("padding" in kwargs) else 0 + dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1 + ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False + + if padding == 0: + padding = [0, 0] + if dilation == 1: + dilation = [1, 1] + if not stride: + stride = kernel_size + kernel_size = pad_listlike(kernel_size, 2) + stride = pad_listlike(stride, 2) + padding = pad_listlike(padding, 2) + dilation = pad_listlike(dilation, 2) + + assert len(kernel_size) == 2 + assert len(stride) == 2 + assert len(padding) == 2 + assert len(dilation) == 2 + + computation_args = ( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ) + computation_args, _ = require_channels_last(computation_op, *computation_args) + counters["inductor"]["qmaxpool2d_matcher_count"] += 1 + counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes) + return L[computation_op](*computation_args) + + return qmaxpool2d + + +def _register_quantization_maxpool2d(): + # Currently, the default parameters are not in FX Graph generated by Dynamo export. + # So, if user defines nn.MaxPool2d with different assignment of default parameter, + # it will generate graph with different number of input nodes and hence + # different pattern to be matched. 
+ # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901 + max_pool2d_args_list = [ + [ + KeywordArg("stride"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + ], + [ + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("ceil_mode"), + ], + ] + for max_pool2d_args in max_pool2d_args_list: + dequantize_maxpool2d_pattern = CallFunction( + aten.max_pool2d_with_indices.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + ) + dequantize_lowmem_maxpool2d_pattern = CallFunction( + prims._low_memory_max_pool_with_offsets.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("kernel_size"), + *max_pool2d_args, + KeywordArg("offset_dtype"), + ) + dequantize_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_maxpool2d_pattern, + Arg(), + ) + dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction( + operator.getitem, + dequantize_lowmem_maxpool2d_pattern, + Arg(), + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern), + quantized.max_pool2d.default, + ) + _register_quantized_maxpool2d_lowering( + generate_pattern_with_output_quant( + dequantize_lowmem_maxpool2d_get_item_pattern + ), + quantized.max_pool2d.default, + ) + + +def _is_input_output_same_scale_zp(check_node): + def fn(match): + # Ensure all the inputs and output has same scale and zero point + # Step 1: Check inputs/output zero point + # Get dequant nodes at input + dequant_nodes = filter_nodes( + match.nodes, quantized_decomposed.dequantize_per_tensor.default + ) + zero_points = [node.args[2] for node in dequant_nodes] + # Get quant nodes at output + quant_nodes = filter_nodes( + match.nodes, quantized_decomposed.quantize_per_tensor.default + ) + assert len(quant_nodes) == 1, "expect only 1 add node at output quant pattern" + zero_points.append(quant_nodes[0].args[2]) + if not all(zero_point == zero_points[0] for zero_point in zero_points): + return False + + # Step 2: Check inputs/output scale + scales = [node.args[1] for node in dequant_nodes] + scales.append(quant_nodes[0].args[1]) + if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales): # type: ignore[arg-type] + return False + + return True + + return fn + + +def _register_quantized_cat_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.cat.default), + ) + def qcat(match: Match, inputs, dim, **kwargs): + # inputs is with format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...] 
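The extra check above only admits graphs in which every input and the output share one scale and zero point; under that assumption, concatenating the raw uint8 tensors is equivalent to dequantize -> cat -> quantize, which is what the lowering below exploits. A toy numeric sketch with made-up quantization parameters:

import torch

scale, zp = 0.05, 128
a = torch.randint(0, 256, (2, 3), dtype=torch.uint8)
b = torch.randint(0, 256, (2, 3), dtype=torch.uint8)

def dq(q):
    return (q.to(torch.float32) - zp) * scale

def q(x):
    return torch.clamp(torch.round(x / scale) + zp, 0, 255).to(torch.uint8)

reference = q(torch.cat([dq(a), dq(b)], dim=0))
assert torch.equal(reference, torch.cat([a, b], dim=0))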
+ uint8_inputs = [input[0] for input in inputs] + counters["inductor"]["qcat_matcher_count"] += 1 + counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes) + return L[computation_op](uint8_inputs, dim) + + return qcat + + +_raw_dequantize_per_tensor_activation_pattern = CallFunction( + quantized_decomposed.dequantize_per_tensor.default, + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), + Arg(), +) + + +def _register_quantization_cat(): + dequantize_cat_pattern = CallFunction( + aten.cat.default, + ListOf(_raw_dequantize_per_tensor_activation_pattern), + KeywordArg("dim"), + ) + _register_quantized_cat_lowering( + generate_pattern_with_output_quant(dequantize_cat_pattern), + aten.cat, + ) + + +def _register_quantized_reshape_lowering( + pattern, + computation_op, +): + @register_lowering_pattern( + pattern, + extra_check=_is_input_output_same_scale_zp(aten.reshape.default), + ) + def qreshape(match: Match, *args, **kwargs): + qx = kwargs["x"] + shape = kwargs["shape"] + counters["inductor"]["qreshape_matcher_count"] += 1 + counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes) + return L[computation_op](qx, shape) + + return qreshape + + +def _register_quantization_reshape(): + dequantize_reshape_pattern = CallFunction( + torch.ops.aten.reshape.default, + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("shape"), + ) + _register_quantized_reshape_lowering( + generate_pattern_with_output_quant(dequantize_reshape_pattern), + aten.reshape, + ) + + +def _is_valid_concat_linear_int8_woq_optimization_pattern(): + def fn(match): + if not config.cpp.enable_concat_linear: + return False + assert all(k in match.kwargs for k in ("x", "w1", "w2", "w3", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") + for key in ["x", "w1", "w2", "w3", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + w1 = match.kwargs["w1"].meta["val"] + w2 = match.kwargs["w2"].meta["val"] + w3 = match.kwargs["w3"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + if len(match.kwargs["scales"].meta["val"].size()) > 1: + return False + num_scales = match.kwargs["scales"].meta["val"].numel() + w1_cols = match.kwargs["w1"].meta["val"].size()[0] + w2_cols = match.kwargs["w2"].meta["val"].size()[0] + w3_cols = match.kwargs["w3"].meta["val"].size()[0] + # Technically, the shapes of the three weights need not be equal. + # But currently, we only enable replacement in this case. 
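In eager terms, a graph accepted by this validator computes one bf16 x int8 matmul over the three concatenated weights, scaled per output column. A sketch with assumed sizes (N output rows per weight, hence 3*N scales):

import torch

M, K, N = 4, 16, 8
x = torch.randn(M, K, dtype=torch.bfloat16)
w1, w2, w3 = (torch.randint(-128, 128, (N, K), dtype=torch.int8) for _ in range(3))
scales = torch.rand(3 * N, dtype=torch.bfloat16)

w_cat = torch.cat([w1, w2, w3], dim=0)                 # (3N, K), int8
out = (x @ w_cat.to(torch.bfloat16).t()) * scales      # what the matched graph computes
print(out.shape)                                       # torch.Size([4, 24])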
+ if w1_cols != w2_cols or w2_cols != w3_cols: + return False + if 3 * w1_cols != num_scales: + return False + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and w1.dtype == torch.int8 + and w2.dtype == torch.int8 + and w3.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == w1.device + and w1.device == w2.device + and w2.device == w3.device + and x.device == scales.device + ) + + return fn + + +def _is_valid_woq_optimization_pattern(): + def fn(match): + assert all(k in match.kwargs for k in ("x", "weight", "scales")) + if not all( + hasattr(match.kwargs[key], "meta") for key in ["x", "weight", "scales"] + ): + return False + x = match.kwargs["x"].meta["val"] + weight = match.kwargs["weight"].meta["val"] + scales = match.kwargs["scales"].meta["val"] + return ( + # For now, we only support woq mm kernels + # with x.type=bfloat16 and w.type=int8 + x.dtype == torch.bfloat16 + and weight.dtype == torch.int8 + and scales.dtype == torch.bfloat16 + # _weight_int8pack_mm kernel only supports cpu now + # TODO: add cuda kernel support instead of calling mul+sum + and x.device.type == "cpu" + and x.device == weight.device + and x.device == scales.device + ) + + return fn + + +def _register_concat_linear_int8_woq_lowering( + pattern, computation_woq, computation_reshape +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_concat_linear_int8_woq_optimization_pattern(), + pass_number=4, + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + w1 = kwargs["w1"] + w2 = kwargs["w2"] + w3 = kwargs["w3"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = ( + w1.meta["val"].size()[0] + + w2.meta["val"].size()[0] + + w3.meta["val"].size()[0] + ) + origin_x_size = tuple(x.meta["val"].size()) + x_shape = [-1, origin_x_size[-1]] + out_shape = list(origin_x_size[:-1] + (out_features,)) + mm_node_of_x = None + for candidate in iter(x.users.keys()): + if ( + candidate.target == aten.mm.default + and list(candidate._input_nodes)[1].target == aten.cat.default + ): + mm_node_of_x = candidate + break + assert mm_node_of_x is not None, "unable to find mm node" + _, cat_wgt_node = mm_node_of_x._input_nodes + scaling_node = next(iter(mm_node_of_x.users.keys())) + user_of_scaling_node = next(iter(scaling_node.users.keys())) + # Some other pass is making some changes that entails + # adding a node before it's used, but it can only be found when + # lint is run. stable_topological_sort() is being run before lint, + # so that error was not being being discovered. + # We call stable_topological_sort here as a workaround. 
+ stable_topological_sort(match.graph) + with match.graph.inserting_before(user_of_scaling_node): + new_cat_node = match.graph.call_function( + aten.cat.default, + args=([w1, w2, w3], 0), + ) + x_reshape_node = match.graph.call_function( + computation_reshape, args=(x, x_shape) + ) + new_woq_node = match.graph.call_function( + computation_woq, + args=(x_reshape_node, new_cat_node, scales), + ) + new_woq_node.meta = copy.copy(x.meta) + output_reshape_node = match.graph.call_function( + computation_reshape, args=(new_woq_node, out_shape) + ) + scaling_node.replace_all_uses_with(output_reshape_node) + match.graph.erase_node(scaling_node) + match.graph.erase_node(mm_node_of_x) + match.graph.erase_node(cat_wgt_node) + match.graph.lint() + + return woq + + +def _register_woq_lowering(pattern, computation_woq, computation_reshape): + @register_lowering_pattern( + pattern, + extra_check=_is_valid_woq_optimization_pattern(), + ) + def woq(match: Match, *args, **kwargs): + x = kwargs["x"] + weight = kwargs["weight"] + scales = kwargs["scales"] + counters["inductor"]["woq_matcher_count"] += 1 + counters["inductor"]["woq_matcher_nodes"] += len(match.nodes) + out_features = weight.get_size()[0] + origin_x_size = x.get_size() + x_shape = [-1, origin_x_size[-1]] + out_shape = origin_x_size[:-1] + [ + out_features, + ] + func1 = L[computation_reshape](x, x_shape) + func2 = L[computation_woq](func1, weight, scales) + return L[computation_reshape](func2, out_shape) + + return woq + + +def _register_woq_mm_int8_pattern1(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, with x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + CallFunction(aten.reshape.default, KeywordArg("x"), Arg()), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern2(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to mm, w/o x reshape + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.reshape.default, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + ), + Arg(), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern3(): + # F.linear(x, weight.to(dtype=x.dtype)) * scales + # case of dispatching to bmm + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.bmm.default, + CallFunction(aten.expand.default, KeywordArg("x"), Arg()), + CallFunction( + aten.expand.default, + CallFunction( + aten.permute.default, + CallFunction( + prims.convert_element_type.default, KeywordArg("weight"), Arg() + ), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), + ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_woq_mm_int8_pattern4(): + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mm.default, + KeywordArg("x"), + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg("weight"), + Arg(), + ), + Arg(), + ), + ), + KeywordArg("scales"), 
+ ) + _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape) + + +def _register_int8_woq_concat_linear_pattern(): + def _create_wgt_node(wgt_node_name: str): + return CallFunction( + prims.convert_element_type.default, + CallFunction( + aten.permute.default, + KeywordArg(wgt_node_name), + Arg(), + ), + Arg(), + ) + + cat_wgt = CallFunction( + aten.cat.default, [_create_wgt_node(wgt) for wgt in ["w1", "w2", "w3"]], 1 + ) + + _woq_pattern = CallFunction( + aten.mul.Tensor, + CallFunction(aten.mm.default, KeywordArg("x"), cat_wgt), + KeywordArg("scales"), + ) + _register_concat_linear_int8_woq_lowering( + _woq_pattern, aten._weight_int8pack_mm.default, aten.reshape + ) + + +def _register_quantization_lowerings(): + _register_quantization_unary_lowering() + _register_quantization_binary_lowering() + _register_quantization_maxpool2d() + _register_quantization_cat() + _register_quantization_reshape() + + +def _register_woq_lowerings(): + _register_woq_mm_int8_pattern1() + _register_woq_mm_int8_pattern2() + _register_woq_mm_int8_pattern3() + _register_woq_mm_int8_pattern4() + + +def _is_valid_dequant_promotion_pattern(dtype=torch.float32): + def _inner(match): + assert dtype in [torch.float32, torch.bfloat16] + dequant_pattern_end_node = match.output_node() + if dequant_pattern_end_node.target not in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ]: + return False + + if dequant_pattern_end_node.target is aten.reshape.default: + dequant_node = ( + dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- reshape <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[0].args[ + 0 + ] # pattern: linear <- reshape <- to_bf16 <- dequant + ) + else: + dequant_node = ( + dequant_pattern_end_node # pattern: linear <- dequant + if dtype == torch.float32 + else dequant_pattern_end_node.args[ + 0 + ] # pattern: linear <- to_bf16 <- dequant + ) + + if ( + dequant_node.target + in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + and len(list(dequant_pattern_end_node.users)) > 1 + ): + # If dequant pattern has more than 1 users, then do dequant promoted + return True + return False + + return _inner + + +def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_promotion_pattern(dtype), + pass_number=pass_number, + ) + def dequant_promotion(match: Match, *args, **kwargs): + # Dequant_promotion will transform + # graph 1: + # quant + # + - - - | - - - + + # | dequant | + # | / \ | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # into: + # graph 2: + # quant + # + - - / - \ - - + + # |dequant dequant| + # | | | | + # | node1 node2 | + # + - | - - - | - + + # quant quant + # In graph 1, the dequant node is shared by node1 and node2, + # as a result, neither node1 nor node2 could form an int8 + # fusion pattern. + # After this transformation, the graph 2 could hit the int8 + # fusion pattern: dequant-node-quant, respectively for + # node1 and node2. 
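A minimal standalone illustration of the duplication described above, using ordinary ops instead of the quantized_decomposed ops and symbolic_trace instead of the freezing pass: the shared node is cloned once per extra user, so each consumer owns a private copy and can later be fused independently.

import torch
from torch.fx import symbolic_trace

def _toy(x):
    d = x * 2.0              # stand-in for the shared dequant node
    return d + 1.0, d - 1.0  # two users that should each get their own copy

gm = symbolic_trace(_toy)
shared = next(n for n in gm.graph.nodes if n.op == "call_function" and len(n.users) > 1)
for user in list(shared.users)[1:]:
    with gm.graph.inserting_before(user):
        clone = gm.graph.call_function(shared.target, shared.args, shared.kwargs)
        clone.meta = dict(shared.meta)
        user.replace_input_with(shared, clone)
gm.recompile()
print(gm.code)  # the multiply now appears twice, once per consumer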
+ assert dtype in [torch.float32, torch.bfloat16] + + def clone_to_new_node(graph, source_node, user_node): + # Clone the source_node to a new node + # Replace user_node's input from source_node to new_node + assert source_node.op == "call_function", ( + "clone_to_new_node only support node.op call_function" + ) + with graph.inserting_before(user_node): + new_node = graph.call_function( + source_node.target, + args=source_node.args, + kwargs=source_node.kwargs, + ) + new_node.meta = copy.copy(source_node.meta) + user_node.replace_input_with(source_node, new_node) + return new_node + + # Find the start node and end node of a dequant pattern + # * End node should be the match.output_node() + # * Start node should be the node of dequantize_per_tensor + dequant_pattern_end_node = match.output_node() + assert dequant_pattern_end_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + prims.convert_element_type.default, + aten.reshape.default, + ] + + # For a dequant pattern, we should expect see the node list as: + # * OPT(aten.reshape.default) + # * OPT(prims.convert_element_type.default) (to_bf16) + # * dequantize_per_tensor + def _find_first_node_in_dequant_pattern(_node): + if _node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For a dequant pattern, we expect the start node is a dequantize_per_tensor node + return _node + else: + assert len(_node.args) >= 1, ( + "In in dequant pattern, each node should have more than 1 arg." + ) + return _find_first_node_in_dequant_pattern(_node.args[0]) + + dequant_pattern_start_node = _find_first_node_in_dequant_pattern( + dequant_pattern_end_node + ) + + assert dequant_pattern_start_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + # Clone the dequant pattern for each user node + graph = match.graph + user_node_list = list(dequant_pattern_end_node.users) + for user_node in user_node_list[1:]: + _source_node = dequant_pattern_end_node + _user_node = user_node + while _source_node != dequant_pattern_start_node.args[0]: + _user_node = clone_to_new_node(graph, _source_node, _user_node) + _source_node = _source_node.args[0] # type: ignore[assignment] + + counters["inductor"]["dequant_promotion_matcher_count"] += 1 + counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) + + +def _is_valid_dequant_conv_pattern(dtype): + def _inner(match): + # Here we do some further check to ensure: + # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. + # 2. The dequant pattern has only 1 user of conv2d node. + # If these conditions don't meet, we will not + # insert weight prepack node into the matched pattern. 
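As a reference for the dequant -> convolution pattern handled here, a small eager sketch (toy sizes; symmetric per-channel weight quantization assumed) of what the matched subgraph computes before it is folded into qconv_pointwise:

import torch
import torch.nn.functional as F

x_q = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8)
w_q = torch.randint(-128, 128, (4, 3, 3, 3), dtype=torch.int8)
x_scale, x_zp = 0.02, 128
w_scales = torch.rand(4) * 0.01                          # one scale per output channel

x = (x_q.to(torch.float32) - x_zp) * x_scale             # dequantize_per_tensor
w = w_q.to(torch.float32) * w_scales.view(-1, 1, 1, 1)   # dequantize_per_channel (zero points = 0)
y = F.conv2d(x, w, bias=None, stride=1, padding=1)
print(y.shape)  # torch.Size([1, 4, 8, 8])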
+ conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + input_meta_value = conv_node.args[0].meta.get("val") + weight_meta_value = conv_node.args[1].meta.get("val") + for meta_value in [input_meta_value, weight_meta_value]: + if ( + meta_value is None + or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") + or meta_value.dim() not in [3, 4] + ): + # Only support conv1d/2d now + return False + + assert dtype in [torch.float32, torch.bfloat16] + + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + return True + + return _inner + + +def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_conv_pattern(dtype), + pass_number=pass_number, + ) + def qconv_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + conv_node = match.output_node() + assert conv_node.target is aten.convolution.default + if dtype == torch.float32: + dequant_node = conv_node.args[0] + else: + convert_to_bf16 = conv_node.args[0] + dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr] + has_clone_to_channel_last_node_in_pattern = ( + conv_node.args[1].target is aten.clone.default # type: ignore[union-attr] + ) + clone_node = ( + conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None + ) + + if dtype == torch.float32: + dequant_per_channel = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + else: + weight_to_bf16_node = ( + clone_node.args[0] # type: ignore[union-attr] + if has_clone_to_channel_last_node_in_pattern + else conv_node.args[1] + ) + dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr] + + assert ( + dequant_per_channel.target # type: ignore[union-attr] + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Conv Params + bias, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(conv_node): + # Insert weight prepack node and the QConv node + packed_weight_inputs = ( + qw, + w_scale, + x_scale, + x_zp, + stride, + padding, + dilation, + groups, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qconv_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] 
= ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + stride, + padding, + dilation, + groups, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # attr + [], # scalars + "", # algorithm + ) + new_conv_node = graph.call_function( + torch.ops.onednn.qconv_pointwise.default, args=new_args + ) + conv_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(conv_node.meta) + + # Erase the original conv node + graph.erase_node(conv_node) + # Erase the dequant pattern + if dtype == torch.bfloat16: + graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_node) # type: ignore[arg-type] + # Erase the dequant per channel pattern + if clone_node is not None: + graph.erase_node(clone_node) # type: ignore[arg-type] + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] + graph.erase_node(dequant_per_channel) # type: ignore[arg-type] + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_convolution_node_pattern( + _dequant_per_channel_pattern, dtype=torch.float32 +): + assert dtype in [torch.float32, torch.bfloat16] + dequant_convolution_node_pattern = CallFunction( + aten.convolution.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + _dequant_per_channel_pattern, + KeywordArg("b"), + KeywordArg("stride"), + KeywordArg("padding"), + KeywordArg("dilation"), + KeywordArg("is_transposed"), + KeywordArg("out_padding"), + KeywordArg("groups"), + ) + return dequant_convolution_node_pattern + + +def _generate_qconv_weight_prepack_patterns(dtype=torch.float32): + assert dtype in [torch.float32, torch.bfloat16] + return ( + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_weight_pattern, + dtype, + ), + # There is another pattern due to the pass of convert_conv_weights_to_channels_last + # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362. 
+ # Depend on some heuristics, it may or may not insert to(channel_last) node + # between convolution and dequant_per_channel node + _generate_dequant_convolution_node_pattern( + dequantize_per_channel_clone_weight_pattern + if dtype == torch.float32 + else dequantize_per_channel_to_bf16_clone_weight_pattern, + dtype, + ), + ) + + +def _get_linear_node(match, input_dim_exceeds_two, input_contiguous): + output_reshape_node = None + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node = match.output_node() + assert output_reshape_node.target is aten.reshape.default + linear_node = output_reshape_node.args[0] + else: + linear_nodes = filter_nodes(match.nodes, aten.bmm.default) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + else: + linear_node = match.output_node() + + assert linear_node.target in ( + aten.addmm.default, + aten.mm.default, + aten.bmm.default, + ) + return linear_node, output_reshape_node + + +def _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous +): + act_reshape_node = None + activation_to_bf16_node = None + act_expand_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + assert act_reshape_node.target is aten.reshape.default + if dtype == torch.float32: + # pattern: linear -> reshape -> dequant + dequant_node = act_reshape_node.args[0] + else: + # pattern: linear -> reshape -> to_bf16 -> dequant + activation_to_bf16_node = act_reshape_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + # bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous + act_expand_node = linear_node.args[input_index] + assert act_expand_node.target is aten.expand.default + if dtype == torch.float32: + dequant_node = act_expand_node.args[0] + else: + activation_to_bf16_node = act_expand_node.args[0] + dequant_node = activation_to_bf16_node.args[0] + else: + if dtype == torch.float32: + # pattern: linear -> dequant + dequant_node = linear_node.args[input_index] + else: + # pattern: linear -> to_bf16 -> dequant + activation_to_bf16_node = linear_node.args[input_index] + dequant_node = activation_to_bf16_node.args[0] + return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node + + +def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous): + def _inner(match): + # Check dequant pattern has only 1 user. 
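When a >2-D, non-contiguous activation reaches linear, the op is decomposed into expand + bmm, which is why the bmm branch validated below cross-checks the expand sizes against the activation and weight shapes. A sketch with assumed sizes:

import torch

B, M, K, N = 2, 3, 4, 5
act = torch.randn(B, M, K)
wgt = torch.randn(N, K)
act_expand_size = [B, M, K]            # must equal act.size()
wgt_expand_size = [B, K, N]            # [act.size(0), wgt.size(1), wgt.size(0)]
out = torch.bmm(act.expand(*act_expand_size),
                wgt.permute(1, 0).expand(*wgt_expand_size))
assert torch.allclose(out, torch.nn.functional.linear(act, wgt))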
+ ( + linear_node, + _, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + + input_index = 1 if linear_node.target is aten.addmm.default else 0 + assert dtype in [torch.float32, torch.bfloat16] + ( + dequant_node, + _, + _, + _, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + assert dequant_node.target in [ + quantized_decomposed.dequantize_per_tensor.default, + quantized_decomposed.dequantize_per_tensor.tensor, + ] + + if len(list(dequant_node.users)) != 1: + # Ensure the dequant pattern only has 1 user + # since we will delete the dequant pattern here + return False + + # Extra check for bmm pattern + if input_dim_exceeds_two and not input_contiguous: + # Check for act + # Act expand size should be exactly same as act size + act_expand_size = match.kwargs["act_expand_size"] + act_node = match.kwargs["x"] + if not ( + hasattr(act_node, "meta") + and isinstance(act_node.meta.get("val", None), torch.Tensor) + and (act_node.meta["val"].size() == torch.Size(act_expand_size)) + ): + return False + + # Check for wgt + # wgt permute dims should be [1, 0] + wgt_permute_dims = match.kwargs["permute_axes"] + if wgt_permute_dims != [1, 0]: + return False + + # Check below wgt size items: + # wgt before expand should with dim 2 + # Expand size should with dim 3 + # Expand size[0] should same as act size[0] + # Expand size[1] should same as wgt size[1] + # Expand size[2] should same as wgt size[0] + qweight_node = match.kwargs["q_weight"] + wgt_expand_size = match.kwargs["wgt_expand_size"] + if not ( + hasattr(qweight_node, "meta") + and isinstance(qweight_node.meta.get("val", None), torch.Tensor) + and len(qweight_node.meta["val"].size()) == 2 + and len(wgt_expand_size) == 3 + and wgt_expand_size[0] == act_node.meta["val"].size()[0] + and wgt_expand_size[1] == qweight_node.meta["val"].size()[1] + and wgt_expand_size[2] == qweight_node.meta["val"].size()[0] + ): + return False + + return True + + return _inner + + +def _register_qlinear_weight_prepack_pass( + pattern, + pass_number, + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, +): + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_dequant_linear_pattern( + dtype, input_dim_exceeds_two, input_contiguous + ), + pass_number=pass_number, + ) + def qlinear_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + int8 activation + | + dequant_per_tensor + | + mm/addmm <- t <- dequant_per_channel <- int8_weight + + Insert weight prepack node and change the pattern to: + int8 activation + | + onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight + """ + assert dtype in [torch.float32, torch.bfloat16] + ( + linear_node, + output_reshape_node, + ) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + ( + dequant_node, + act_reshape_node, + activation_to_bf16_node, + act_expand_node, + ) = _get_linear_dq_node( + linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous + ) + + if input_dim_exceeds_two and not input_contiguous: + wgt_expand_node = linear_node.args[weight_index] + assert wgt_expand_node.target is aten.expand.default + t_node = wgt_expand_node.args[0] + else: + t_node = linear_node.args[weight_index] + + if dtype == torch.float32: + dequant_per_channel = t_node.args[0] + else: + weight_to_bf16_node = t_node.args[0] + dequant_per_channel = 
weight_to_bf16_node.args[0] + assert ( + dequant_per_channel.target + is quantized_decomposed.dequantize_per_channel.default + ) + + # Activation QParams + qx, x_zp, x_scale = ( + kwargs["x"], + kwargs["x_zp"], + kwargs["x_scale"], + ) + + # Weight QParams + qw, w_scale, w_zp = ( + kwargs["q_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # Params + bias = kwargs["b"] if "b" in kwargs else None + + x_shape = qx.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + qw, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + new_args: tuple[Any, ...] = ( + qx, + x_scale, + x_zp, + prepack_weight_node, + w_scale, + w_zp, + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + Node = torch.fx.node.Node + if isinstance(x_scale, Node) and isinstance(x_zp, Node): + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + else: + new_linear_node = graph.call_function( + torch.ops.onednn.qlinear_pointwise.default, args=new_args + ) + if input_dim_exceeds_two: + if input_contiguous: + output_reshape_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_reshape_node.meta) + else: + if bias: + output_add_node_for_bias = match.output_node() + assert output_add_node_for_bias.target is aten.add.Tensor + output_add_node_for_bias.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(output_add_node_for_bias.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + else: + linear_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(linear_node.meta) + + # Erase the original linear node + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(output_reshape_node) + elif not input_contiguous and bias: + graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined] + graph.erase_node(linear_node) + if input_dim_exceeds_two: + if input_contiguous: + graph.erase_node(act_reshape_node) + else: + graph.erase_node(act_expand_node) + graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined] + if dtype == torch.bfloat16: + graph.erase_node(activation_to_bf16_node) + # Erase the dequant pattern + graph.erase_node(dequant_node) + # Erase the dequant per channel pattern + graph.erase_node(t_node) + if dtype == torch.bfloat16: + graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined] + graph.erase_node(dequant_per_channel) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _generate_dequant_linear_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + input_dim_exceeds_two=False, + is_tensor_overload=False, +): + assert dtype in [torch.float32, torch.bfloat16] + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + dequant_linear_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + _may_generate_pattern_with_reshape( + 
_may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ), + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _generate_dequant_bmm_node_pattern( + _dequant_per_channel_pattern, + dtype=torch.float32, + with_bias=False, + is_tensor_overload=False, +): + # When activation of linear dim exceed 2 and not contiguous + t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype) + + assert dtype in [torch.float32, torch.bfloat16] + dequant_bmm_pattern = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern(is_tensor_overload), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + + def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias): + if _with_bias: + return CallFunction( + aten.add.Tensor, + _dequant_bmm_pattern, + KeywordArg("b"), + ) + else: + return _dequant_bmm_pattern + + return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias) + + +def _generate_qlinear_weight_prepack_patterns( + dtype=torch.float32, + input_dim_exceeds_two=False, + input_contiguous=True, + with_bias=False, + is_tensor_overload=False, +): + if input_dim_exceeds_two and not input_contiguous: + return _generate_dequant_bmm_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + with_bias, + is_tensor_overload, + ) + else: + return _generate_dequant_linear_node_pattern( + dequantize_per_channel_weight_pattern, + dtype, + input_dim_exceeds_two, + is_tensor_overload, + ) + + +def _generate_linear_dynamic_fp16_pattern( + _dequant_weight_pattern, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + dtype = torch.float32 + t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype) + + if input_dim_exceeds_two and not input_contiguous: + # pattern is + # x -> expand -> bmm (-> add) (-> relu) + # w -> dequant -> permute -> expand / + pattern_no_bias = CallFunction( + aten.bmm.default, + CallFunction( + aten.expand.default, + KeywordArg("x"), + KeywordArg("act_expand_size"), + ), + CallFunction( + aten.expand.default, + t_pattern, + KeywordArg("wgt_expand_size"), + ), + ) + pattern_with_bias = CallFunction( + aten.add.Tensor, + pattern_no_bias, + KeywordArg("b"), + ) + if relu_fused: + pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias) + pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias) + return pattern_with_bias, pattern_no_bias + + x_pattern_with_reshape = _may_generate_pattern_with_reshape( + KeywordArg("x"), + KeywordArg("act_reshape_size"), + input_dim_exceeds_two, + ) + 
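These patterns target the dynamic-fp16 linear flavor: an fp32 activation multiplied against a weight that was round-tripped through fp16, optionally followed by relu. A minimal eager sketch (assumed shapes) of the computation being matched:

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)
w = torch.randn(8, 16)
b = torch.randn(8)
# weight is cast to fp16 and back to fp32, then a regular linear (+ optional relu) runs
y = F.relu(F.linear(x, w.to(torch.float16).to(torch.float32), b))
print(y.shape)  # torch.Size([4, 8])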
dequant_linear_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.addmm.default, + KeywordArg("b"), + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + dequant_linear_no_bias_pattern = generate_pattern_with_unary( + _may_generate_pattern_with_reshape( + CallFunction( + aten.mm.default, + x_pattern_with_reshape, + t_pattern, + ), + KeywordArg("output_reshape_size"), + input_dim_exceeds_two, + ), + aten.relu.default if relu_fused else None, + ) + return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern + + +def _register_dequant_promotion(): + dequant_pattern_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases: + # 4 dequantization patterns will be matched based on the dtype and input dimension size. + # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 + # quant + # + - - - - | - - - - + + # | dequant | + # | | | + # | OPT(to_bf16) | + # | | | + # | OPT(reshape) | + # | / \ | + # | node1 node2 | + # + - - | - - - | - - + + # OPT(reshape) OPT(reshape) + # + - - | - - - | - - + + # OPT(to_fp32) OPT(to_fp32) + # + - - | - - - | - - + + # quant quant + _register_dequant_promotion_pass( + _may_generate_pattern_with_reshape( + _may_generate_pattern_with_dtype_convert( + get_dequantize_per_tensor_activation_pattern( + is_tensor_overload=is_tensor_overload + ), + KeywordArg("autocast_act_dtype"), + dtype == torch.bfloat16, + ), + KeywordArg("act_reshape_size"), + with_reshape=input_dim_exceeds_two, + ), + pass_number=0, + dtype=dtype, + ) # pass_number=0 to run before weight prepack + + +def _register_qconv_weight_prepack(): + for dtype in [torch.float32, torch.bfloat16]: + weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qconv_weight_prepack_pass( + weight_prepack_pattern, pass_number=1, dtype=dtype + ) + + +def _register_qlinear_weight_prepack(): + # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguous. + # Then convert the pattern into a QLinear node with int8_fp32/bf16. 
+ # Case 1: int8-mixed-fp32, input dim size is 2 + # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous + # Case 3: int8-mixed-bf16, input dim size is 2 + # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | OPT(reshape) permute | + # | \ / | + # | addmm/mm | + # | | | + # | OPT(reshape) | + + # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous + # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous + + # + - - - - | - - - - - - | - - - - - + + # | dq_per_tensor dq_per_channel | + # | | | | + # | OPT(to_bf16) OPT(to_bf16) | + # | | | | + # | expand permute | + # | \ | | + # | expand | + # | / | + # | bmm | + # | | | + # | OPT(add) | + + linear_weight_prepack_cases = itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ) + + # Step 1: register patterns from mm and addmm + for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases: + weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns( + dtype, + input_dim_exceeds_two, + is_tensor_overload=is_tensor_overload, + ) + for weight_prepack_pattern in weight_prepack_patterns: + # Register to pass_number 1, so we can do dequant promotion in pass_number 0. + _register_qlinear_weight_prepack_pass( + weight_prepack_pattern, + pass_number=1, + dtype=dtype, + input_dim_exceeds_two=input_dim_exceeds_two, + ) + + # Step 2: register patterns from bmm + # Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous + # refer to: + # https://github.com/pytorch/pytorch/blob/80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968 + # in this case, we can convert it back to qlinear + for dtype, with_bias, is_tensor_overload in itertools.product( + [torch.float32, torch.bfloat16], [True, False], [True, False] + ): + bmm_pattern = _generate_qlinear_weight_prepack_patterns( + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + with_bias=with_bias, + is_tensor_overload=is_tensor_overload, + ) + _register_qlinear_weight_prepack_pass( + bmm_pattern, + pass_number=1 + if with_bias + else 2, # if with_bias, there is an output add, so we should try to match it firstly + dtype=dtype, + input_dim_exceeds_two=True, + input_contiguous=False, + ) + + +def _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number, + input_dim_exceeds_two=False, + input_contiguous=True, + relu_fused=False, +): + def _extra_check_fn(match: Match): + return match.kwargs["dtype_fp16"] == torch.float16 + + @register_freezing_graph_pattern( + pattern, + extra_check=_extra_check_fn, + pass_number=pass_number, + ) + def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs): + """ + Match the pattern: + fp32 activation + | + mm/addmm <- t <- to_fp32 <- to_fp16 <- weight + | + (reshape) <- (relu) + + OR + + fp32 activation + | + expand + | + bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight + | + (add) <- (relu) + + Insert weight prepack node and change the pattern to: + fp32 activation + | + onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight + (or onednn.linear_relu_dynamic_fp16) + """ + # find params + x = kwargs["x"] + w = kwargs["w"] + bias = kwargs["b"] if "b" in kwargs else None + + # find linear node + nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default] + linear_nodes = [] + for node in 
nodes_to_find: + linear_nodes.extend(filter_nodes(match.nodes, node)) + assert len(linear_nodes) == 1 + linear_node = linear_nodes[0] + assert isinstance(linear_node, torch.fx.node.Node) + input_index = 1 if linear_node.target is aten.addmm.default else 0 + weight_index = input_index + 1 + + # find relu node + relu_node = None + if relu_fused: + relu_node = match.output_node() + assert isinstance(relu_node, torch.fx.node.Node) + + # find reshape node, expand node and add node + ( + act_reshape_node, + output_reshape_node, + expand_x_node, + expand_w_node, + add_bias_node, + ) = (None, None, None, None, None) + t_node = None + if input_dim_exceeds_two: + if input_contiguous: + act_reshape_node = linear_node.args[input_index] + t_node = linear_node.args[weight_index] + output_reshape_node = next(iter(linear_node.users)) + assert output_reshape_node.target is aten.reshape.default + else: + expand_x_node = linear_node.args[input_index] + expand_w_node = linear_node.args[weight_index] + assert isinstance(expand_w_node, torch.fx.node.Node) + t_node = expand_w_node.args[0] + if bias: + add_bias_node = next(iter(linear_node.users)) + assert add_bias_node.target is aten.add.Tensor + else: + t_node = linear_node.args[weight_index] + assert isinstance(t_node, torch.fx.node.Node) + + w_to_fp32_node = t_node.args[0] + assert ( + isinstance(w_to_fp32_node, torch.fx.node.Node) + and w_to_fp32_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + w_to_fp16_node = w_to_fp32_node.args[0] + assert ( + isinstance(w_to_fp16_node, torch.fx.node.Node) + and w_to_fp16_node.target + is quantized_decomposed.convert_element_type.no_fuse + ) + + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. + x_shape = None + graph = match.graph + with graph.inserting_before(linear_node): + # Insert weight prepack node and the qlinear node + packed_weight_inputs = ( + w, + x_shape, + ) + packed_weight_op = torch.ops.onednn.linear_prepack_fp16 + prepack_weight_node = graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + # create new linear node and insert on graph + new_args: tuple[Any, ...] 
= ( + x, + prepack_weight_node, + bias, + ) + linear_op = ( + torch.ops.onednn.linear_relu_dynamic_fp16.default + if relu_fused + else torch.ops.onednn.linear_dynamic_fp16.default + ) + new_linear_node = graph.call_function(linear_op, args=new_args) + out_node = match.output_node() + out_node.replace_all_uses_with(new_linear_node) + + # Erase the original nodes in the reverse order + new_linear_node.meta.update(out_node.meta) + if relu_node is not None: + graph.erase_node(relu_node) + if output_reshape_node is not None: + graph.erase_node(output_reshape_node) + if add_bias_node is not None: + graph.erase_node(add_bias_node) + graph.erase_node(linear_node) + if act_reshape_node is not None: + assert isinstance(act_reshape_node, torch.fx.node.Node) + graph.erase_node(act_reshape_node) + if expand_x_node is not None: + assert isinstance(expand_x_node, torch.fx.node.Node) + graph.erase_node(expand_x_node) + if expand_w_node is not None: + assert isinstance(expand_w_node, torch.fx.node.Node) + graph.erase_node(expand_w_node) + graph.erase_node(t_node) + graph.erase_node(w_to_fp32_node) + graph.erase_node(w_to_fp16_node) + + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +def _register_linear_dynamic_fp16_weight_prepack(): + to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse + weight_pattern = CallFunction( + to_dtype_op, + CallFunction( + to_dtype_op, + KeywordArg("w"), + KeywordArg("dtype_fp16"), + ), + KeywordArg("dtype_fp32"), + ) + cases = itertools.product( + [False, True], # input_dim_exceeds_two + [True, False], # input_contiguous + [False, True], # relu fused + ) + for input_dim_exceeds_two, input_contiguous, relu_fused in cases: + patterns = _generate_linear_dynamic_fp16_pattern( + weight_pattern, + input_dim_exceeds_two, + input_contiguous, + relu_fused, + ) + for pattern in patterns: + _register_linear_dynamic_fp16_weight_prepack_pass( + pattern, + pass_number=0 if relu_fused else 1, + input_dim_exceeds_two=input_dim_exceeds_two, + input_contiguous=input_contiguous, + relu_fused=relu_fused, + ) + + +def _register_smooth_quant_int_mm_pattern(): + """ + The pattern is: + (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape + or + (with bias) pattern_no_bias -> add (-> reshape -> reshape) + """ + + # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist + # When torch.compile'ing with dynamic=False, they don't exist + def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True): + return CallFunction( + aten.mul.Tensor, + CallFunction( + aten.mul.Tensor, + CallFunction( + prims.convert_element_type.default, + CallFunction( + aten._int_mm.default, + CallFunction( + aten.reshape.default, + KeywordArg("a"), + KeywordArg("in_shape"), + ) + if reshape_a + else KeywordArg("a"), + KeywordArg("b"), + ), + KeywordArg("dtype"), + ), + ( + CallFunction( + aten.expand.default, + KeywordArg("x_scale"), + Arg(), + ) + if expand_a_scale + else KeywordArg("x_scale") + ), + ), + KeywordArg("w_scale"), + ) + + def _with_outer_reshape(pattern): + return CallFunction( + aten.reshape.default, pattern, KeywordArg("out_shape_no_bias") + ) + + # for torch.compile(dynamic=False) + pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False)) + pattern_with_bias_1 = CallFunction( + aten.add.Tensor, + pattern_no_bias_1, + KeywordArg("bias"), + ) + # for 
torch.compile(dynamic=True) + pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True)) + pattern_with_bias_2 = CallFunction( + aten.reshape.default, + CallFunction( + aten.reshape.default, + CallFunction( + aten.add.Tensor, + pattern_no_bias_2, + KeywordArg("bias"), + ), + Arg(), + ), + KeywordArg("out_shape_with_bias"), + ) + + # The following patterns are for torchao int8_dynamic_activation_int8_weight linear, + # when both activation and weights are symmetrically quantized. + # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used. + # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias. + # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op. + pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=False, reshape_a=False + ) + pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias( + expand_a_scale=True, reshape_a=False + ) + + def _validate_pattern(match: Match): + if len(match.nodes) not in [4, 5, 6, 7, 10]: + return False + # Make sure weight is a constant + aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0] + if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node): + return False + if aten_int_mm_node.args[1].op != "get_attr": + return False + + if len(match.nodes) == 10: + # Check the two tailing reshape nodes can be fused + if match.nodes[9].args[1] != match.nodes[6].args[1]: + return False + if len(match.nodes) == 10 or ( + len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor + ): + bias_idx = 7 if len(match.nodes) == 10 else 6 + # Check bias shape + bias_node = match.nodes[bias_idx].args[1] + if not isinstance(bias_node, torch.fx.node.Node): + return False + if len(bias_node.meta.get("tensor_meta").shape) != 1: # type: ignore[union-attr] + return False + return True + + pattern_to_pass_number = { + pattern_no_bias_2: 0, + pattern_with_bias_2: 0, + pattern_no_bias_1: 1, + pattern_with_bias_1: 1, + pattern1_with_no_outer_or_act_reshape: 2, + pattern2_with_no_outer_or_act_reshape: 2, + } + for pattern, pass_number in pattern_to_pass_number.items(): + + @register_freezing_graph_pattern( + pattern, + extra_check=_validate_pattern, + pass_number=pass_number, + ) + def _int_mm_weight_prepack(match: Match, *args, **kwargs): + bias = kwargs.get("bias", None) + x = kwargs["a"] + weight = kwargs["b"] + dtype = kwargs["dtype"] + x_scale = kwargs["x_scale"] + w_scale = kwargs["w_scale"] + x_shape = x.meta.get("tensor_meta").shape + if has_free_symbols(x_shape): + # For dynamic shape case, we can't get activation shape ahead of runtime. 
+ x_shape = None + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + transpose_node = match.graph.call_function( + aten.permute.default, args=(weight, [1, 0]) + ) + contig_node = match.graph.call_function( + aten.contiguous.default, args=(transpose_node,) + ) + packed_weight_inputs = ( + contig_node, + x_shape, + ) + packed_weight_op = torch.ops.onednn.qlinear_prepack + prepack_weight_node = match.graph.call_function( + packed_weight_op, args=packed_weight_inputs + ) + + dummy_zp = None + w_scale = match.graph.call_function( + prims.convert_element_type.default, args=(w_scale, torch.float32) + ) + + x_scale_shape = x_scale.meta.get("tensor_meta").shape + x_scale_is_scalar = False + if not has_free_symbols(x_scale_shape): + prod = 1 + for d in x_scale_shape: + prod *= d + x_scale_is_scalar = prod == 1 + + new_args: tuple[Any, ...] + if x_scale_is_scalar: + # in this case, we can call onednn.qlinear directly + new_args = ( + x, + x_scale, + dummy_zp, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + bias, + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise.tensor, args=new_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + else: + # onednn.qlinear does not support per-channel quantization of x + # so in this case, we have to apply x scale and add bias ourselves after qlinear + in_shape = kwargs.get("in_shape", None) + if in_shape is None: + x_reshaped = x + else: + x_reshaped = match.graph.call_function( + aten.reshape.default, args=(x, in_shape) + ) + new_args = ( + x_reshaped, + 1.0, # x_scale + 0, # x_zp + prepack_weight_node, + w_scale, + dummy_zp, # w_zp + None, # bias + 1.0, # output_scale + 0, # output_zero_point + dtype, # output_dtype + "none", # post op name + [], # post op args + "", # post op algorithm + ) + new_linear_node = match.graph.call_function( + torch.ops.onednn.qlinear_pointwise, args=new_args + ) + # apply x scale + new_out_node = match.graph.call_function( + aten.mul.Tensor, args=(new_linear_node, x_scale) + ) + + # Add bias and reshape + has_outer_reshape = ( + kwargs.get("out_shape_with_bias", None) is not None + or kwargs.get("out_shape_no_bias", None) is not None + ) + + if has_outer_reshape: + out_shape = kwargs.get( + "out_shape_with_bias", kwargs["out_shape_no_bias"] + ) + if bias is not None: + new_out_node = match.graph.call_function( + aten.add.Tensor, args=(new_out_node, bias) + ) + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + else: + if has_outer_reshape: + new_out_node = match.graph.call_function( + aten.reshape.default, + args=(new_out_node, out_shape), # type: ignore[possibly-undefined] + ) + out_node.replace_all_uses_with(new_out_node) + new_out_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len( + match.nodes + ) + + +class PostOpAttr: + def __init__( + self, + binary_op_name: str = "none", + alpha=None, + unary_op_name: str = "none", + scalars_attr=None, + algorithm_attr=None, + ) -> None: + self.binary_op_name = binary_op_name + self.alpha = alpha if alpha else 1.0 + 
self.unary_op_name = unary_op_name + self.scalars_attr = scalars_attr if scalars_attr else [] + self.algorithm_attr = algorithm_attr if algorithm_attr else "" + + +def _register_qconv_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qconv_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qconv(match: Match, *args, **kwargs): + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + # Conv Params + b, stride, padding, dilation, groups = ( + kwargs["b"], + kwargs["stride"], + kwargs["padding"], + kwargs["dilation"], + kwargs["groups"], + ) + output_dtype = _get_pattern_output_dtype(match) + assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16] + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] + if (output_dtype == torch.uint8 or output_dtype == torch.int8) + else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + if post_op_attr.unary_op_name == "hardtanh": + min_value = kwargs.get("min_value") + max_value = kwargs.get("max_value") + post_op_attr.scalars_attr = [min_value, max_value] + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + accum = ( + kwargs["accum"] + if output_dtype in [torch.uint8, torch.int8] + else kwargs["accum_after_dequant"] + ) + accum_scale = ( + kwargs["accum_scale"] + if output_dtype in [torch.uint8, torch.int8] + else 1.0 + ) + accum_zp = ( + kwargs["accum_zp"] + if output_dtype in [torch.uint8, torch.int8] + else 0 + ) + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + accum, + b, + stride, + padding, + dilation, + groups, + o_inv_scale, + o_zero_point, + output_dtype, + accum_scale, + accum_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_conv_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_conv_node) + new_conv_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qconv2d_binary_matcher_count" + if has_binary_post_op + else "qconv_unary_matcher_count" + ) + nodes_key = ( + "qconv2d_binary_matcher_nodes" + if has_binary_post_op + else "qconv_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + return qconv + + +def _register_qconv_unary_fusion(): + from .mkldnn_fusion import _hardswish_fusion, _hardtanh_fusion, _silu_fusion + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + # Priority 1 to match: QConv2d Unary pattern with int8 output + # If a pattern1 is a sub-set of pattern2, we 
should try to match pattern2 firstly. + # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + conv_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + get_qconv_pt2e_pattern(1), + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output + conv_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + get_qconv_pt2e_pattern(1), aten.relu.default + ), + PostOpAttr( + "none", None, "hardtanh", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardtanh_fusion, + get_qconv_pt2e_pattern(1), + 1, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "hardswish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _hardswish_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "swish", [], "" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _silu_fusion, + get_qconv_pt2e_pattern(1 if is_bf16 else 2), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in conv_unary_replace_float_out_patterns.items(): + # Register qconv2d pattern for ExternKernel Lowering + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv_pointwise.default, # computation_op + unary_attr, # unary_attr + ) + + +def _register_qconv_binary_fusion(): + for int8_mixed_bf16_with_inplace_add in [False, True]: + # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output + swap_binary_inputs_list = [False, True] + binary_replace_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + ), + PostOpAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + 
int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ), + ), + } + ) + + for binary_unary_attr, patterns in binary_replace_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ) + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + if int8_mixed_bf16_with_inplace_add: + _register_qconv_post_op_fusion_pass( + patterns, + 3, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + else: + _register_qconv_post_op_fusion_pass( + patterns, + 4, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + } + ) + + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qconv_post_op_fusion_pass( + patterns, + 4 if int8_mixed_bf16_with_inplace_add else 5, # pass_number + torch.ops.onednn.qconv2d_pointwise.binary, # computation_op + binary_unary_attr, # binary_unary_attr + ) + + +def _register_qlinear_post_op_fusion_pass( + pattern, + pass_number, + computation_op, + post_op_attr, +): + has_binary_post_op = post_op_attr.binary_op_name != "none" + + @register_freezing_graph_pattern( + pattern, + extra_check=_is_valid_qlinear_post_op_fusion_pattern(has_binary_post_op), + pass_number=pass_number, + ) + def qlinear_post_op_fusion(match: Match, *args, **kwargs): + """ + Match the pattern: + qlinear - post op + """ + output_dtype = _get_pattern_output_dtype(match) + # Activation QParams + x, x_scale, x_zp = ( + kwargs["x"], + kwargs["x_scale"], + kwargs["x_zp"], + ) + # Weight QParams + packed_weight, w_scale, w_zp = ( + kwargs["packed_weight"], + kwargs["w_scale"], + kwargs["w_zp"], + ) + + # bias + b = kwargs["b"] if "b" in kwargs else None + + # Output QParams + o_inv_scale = ( + kwargs["o_inv_scale"] + if (output_dtype in [torch.uint8, torch.int8]) + else 1.0 + ) + o_zero_point = ( + kwargs["o_zp"] if (output_dtype in [torch.uint8, torch.int8]) else 0 + ) + assert ( + kwargs["postop_name"] == "none" + ) # Expected no post op fused in weight prepack phase + + out_node = match.output_node() + with match.graph.inserting_before(out_node): + if not has_binary_post_op: + computation_args: tuple[Any, ...] 
= ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + b, + o_inv_scale, + o_zero_point, + output_dtype, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + else: + other = kwargs["other"] if "other" in kwargs else kwargs["accum"] + x2_scale = 1.0 + x2_zp = 0 + computation_args = ( + x, + x_scale, + x_zp, + packed_weight, + w_scale, + w_zp, + other, + b, + o_inv_scale, + o_zero_point, + output_dtype, + x2_scale, + x2_zp, + post_op_attr.binary_op_name, + post_op_attr.alpha, + post_op_attr.unary_op_name, + post_op_attr.scalars_attr, + post_op_attr.algorithm_attr, + ) + new_linear_node = match.graph.call_function( + computation_op, args=computation_args + ) + out_node.replace_all_uses_with(new_linear_node) + new_linear_node.meta.update(out_node.meta) + for node in reversed(match.nodes): + match.graph.erase_node(node) + count_key = ( + "qlinear_binary_matcher_count" + if has_binary_post_op + else "qlinear_unary_matcher_count" + ) + nodes_key = ( + "qlinear_binary_matcher_nodes" + if has_binary_post_op + else "qlinear_unary_matcher_nodes" + ) + counters["inductor"][count_key] += 1 + counters["inductor"][nodes_key] += len(match.nodes) + + +def _register_qlinear_unary_fusion(): + from .mkldnn_fusion import ( + _gelu_fusion_1 as _gelu_fusion_erf, + _gelu_fusion_2 as _gelu_fusion_tanh, + ) + + for original_pattern_output_dtype in [torch.float32, torch.bfloat16]: + is_bf16 = original_pattern_output_dtype == torch.bfloat16 + for x_scale_zp_are_tensors in (False, True): + qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors) + computation_op = ( + torch.ops.onednn.qlinear_pointwise.tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.default + ) + # Priority 1 to match: QLinear Unary pattern with int8 output + linear_unary_replace_patterns = { + PostOpAttr( + "none", None, "none", [], "" + ): generate_pattern_with_output_quant( + qlinear_pattern, + ), + PostOpAttr( + "none", None, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary(qlinear_pattern, aten.relu.default), + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): generate_pattern_with_output_quant( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 4 + ), + 4, + is_bf16, + ), + with_dtype_convert=is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output + linear_unary_replace_float_out_patterns = { + PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( + qlinear_pattern, aten.relu.default + ), + PostOpAttr( + "none", None, "gelu", [], "none" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_erf, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 if is_bf16 else 2 + ), + 2, + is_bf16, + ), + Arg(), + is_bf16, + ), + PostOpAttr( + "none", None, "gelu", [], "tanh" + ): _may_generate_pattern_with_dtype_convert( + _unary_fusion_pattern( + _gelu_fusion_tanh, + get_qlinear_pt2e_pattern( + x_scale_zp_are_tensors, 1 
if is_bf16 else 4 + ), + 4, + is_bf16, + ), + Arg(), + is_bf16, + ), + } + + for unary_attr, patterns in linear_unary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + computation_op, + unary_attr, # unary_attr + ) + + +def _register_qlinear_binary_fusion(): + r""" + Supported linear-binary(-unary) patterns + + linear(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + 1. int8-mixed-fp32 + +---+---------------+-----------+------------------------------+---------+ + | # | Add type | Quant out | Pattern | Post op | + +---+---------------+-----------+------------------------------+---------+ + | 1 | In-/out-place | Yes | linear + fp32 -> (relu) -> q | add | + +---+---------------+-----------+------------------------------+---------+ + | 2 | In-/out-place | No | linear + fp32 -> (relu) | sum | + +---+---------------+-----------+------------------------------+---------+ + + 2. int8-mixed-bf16 + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | # | X2 dtype | Add type | Quant out | Pattern | Post op | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 1 | BF16 | In-/out-place | Yes | linear + bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 2 | BF16 | In-/out-place | No | linear + bf16 -> (relu) | sum | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 3 | FP32 | Out-place | Yes | linear + fp32 -> (relu) -> q | add | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 4 | FP32 | Out-place | No | linear + fp32 -> (relu) | sum | + | | | In-place right| | | | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 5 | FP32 | In-place left | Yes | linear + fp32 -> to_bf16 -> (relu) -> q | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + | 6 | FP32 | In-place left | No | linear + fp32 -> to_bf16 -> (relu) | add | + +---+----------+---------------+-----------+-----------------------------------------+---------+ + + Note + (1) The positions of linear and the extra input can be swapped. + (2) we don't insert q-dq before the extra input of linear-add by recipe. But if q-dq is found at the + extra input, we don't match that pattern because we cannot match all these patterns in 3 passes. 
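+
+    Illustrative sketch (informal; argument order taken from the replacement logic in
+    _register_qlinear_post_op_fusion_pass above, not a stable API contract): a matched
+    linear -> add -> (relu) -> (quant) subgraph is rewritten into a single call roughly
+    of the form
+
+        torch.ops.onednn.qlinear_pointwise.binary(
+            x, x_scale, x_zp, packed_weight, w_scale, w_zp,
+            other, bias, o_inv_scale, o_zero_point, output_dtype,
+            x2_scale, x2_zp, binary_op_name, alpha, unary_op_name, [], "")
+
+    where binary_op_name is "add" or "sum" and unary_op_name is "none" or "relu" per the
+    tables above, and the .binary_tensor overload is used when the activation
+    scale/zero-point are tensors.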
+ """ + for x_scale_zp_are_tensors in (False, True): + qlinear_binary_op = ( + torch.ops.onednn.qlinear_pointwise.binary_tensor + if x_scale_zp_are_tensors + else torch.ops.onednn.qlinear_pointwise.binary + ) + unary_postop_list = ["none", "relu"] + unary_postop_dict = { + "none": None, + "relu": aten.relu.default, + } + convert_dtype_after_binary_list = [False, True] + + # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output + # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16, + # totally 3 patterns (2 are identical) + swap_binary_inputs_list = [False, True] + int8_mixed_bf16_list = [False, True] + combinations = itertools.product( + unary_postop_list, + int8_mixed_bf16_list, + swap_binary_inputs_list, + convert_dtype_after_binary_list, + ) + qlinear_binary_replace_patterns = {} + for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations: + if not int8_mixed_bf16 and cvt_dtype_binary: + # No convert node after binary node if dtypes are all fp32 + continue + qlinear_binary_replace_patterns.update( + { + PostOpAttr( + "add", 1.0, unary_op, [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + # If fp32 extra input is inplace added to bf16 linear output, + # a to_bf16 node is inserted after binary + dtype_convert=cvt_dtype_binary, + swap_inputs=swap_inputs, + ), + unary_postop_dict[unary_op], + ), + ) + } + ) + for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 3, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output + # Covers case (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr("add", 1.0, "relu", [], ""): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + aten.relu.default, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 4, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output + # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16, + # 
totally 2 patterns (2 are identical) + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("accum"), + dtype_convert=False, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output + # Covers (6) of int8-mixed-bf16 + binary_replace_float_out_patterns = {} + for swap_binary_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + PostOpAttr( + "add", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_qlinear_pt2e_pattern(x_scale_zp_are_tensors), + KeywordArg("other"), + dtype_convert=True, + swap_inputs=swap_binary_inputs, + ), + } + ) + for ( + binary_unary_attr, + patterns, + ) in binary_replace_float_out_patterns.items(): + _register_qlinear_post_op_fusion_pass( + patterns, + 5, # pass_number + qlinear_binary_op, # computation_op + binary_unary_attr, + ) + + +@functools.cache +def _register_quantization_weight_pack_pass(): + # Step 1: Dequant promotion for int8-mixed-fp32/bf16 + _register_dequant_promotion() + + # Step 2: QConv weight prepack + _register_qconv_weight_prepack() + + # Step 3: QLinear weight prepack + _register_qlinear_weight_prepack() + _register_linear_dynamic_fp16_weight_prepack() + + # Step 4: weight prepack for SmoothQuant from Torchao + _register_smooth_quant_int_mm_pattern() + + # Step 5: QLinear post op Fusion + if not torch.ops.mkldnn._is_mkldnn_acl_supported(): + # skip fusion on ARM + _register_qconv_unary_fusion() + _register_qconv_binary_fusion() + _register_qlinear_unary_fusion() + _register_qlinear_binary_fusion() + + +def _is_valid_concat_linear_woq_int4_fusion(computation_nodes): + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + act = computation_nodes[0].args[0] + wgt = computation_nodes[0].args[1] + in_feature_size = wgt.meta.get("val").size(1) # type: ignore[union-attr] + group_size = computation_nodes[0].args[2] + return len(computation_nodes) >= 2 and all( + ( + node.target == computation_op + and node.args[0] == act # share same activation + and ( + node.args[1].meta.get("val").size(1) == in_feature_size + ) # same in feature size + and (node.args[1] != wgt or gemm_idx == 0) + and node.args[1].op == "get_attr" # wgt are all constants + and node.args[2] == group_size # same group size + ) + for gemm_idx, node in enumerate(computation_nodes) + ) + + +def concat_linear_woq_int4(gm: torch.fx.GraphModule): + """ + Concat Linear optimization pass for WOQ int4 + This pass fuses the original pattern: + def ... + return (woq_int4(x, w1, group_size, scale_zp1), woq_int4(x, w2, group_size, scale_zp1) ...) + into a single operation: + def ... 
+ concat_res = woq_int4(x, concat_w, group_size, concat_scale_zp) + return split(concat_res, split_size_list) + """ + + def concat_wgt(packed_wgts, scale_zps, group_size, act_dtype): + # Concat the wgts and scale_zps, and repack the wgt + unpacked_wgts = [] + for packed_wgt in packed_wgts: + # Get the unpacked weight list + # Same as https://github.com/pytorch/pytorch/pull/156174 + K = packed_wgt.size(1) * 2 + N = packed_wgt.size(0) + x = torch.eye(K).to(dtype=act_dtype) + qscales_and_zeros = ( + torch.tensor([1.0, 8.0]) + .to(dtype=act_dtype) + .expand(K // group_size, N, 2) + .contiguous() + ) + unpacked_wgts.append( + torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + packed_wgt, + group_size, + qscales_and_zeros, + ) + .t() + .contiguous() + .to(torch.int32) # N, K + ) + concat_unpacked_wgt = torch.cat(unpacked_wgts, dim=0) + repack_w = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + concat_unpacked_wgt, 1 + ) + concat_scale_zp = torch.cat(scale_zps, dim=1).contiguous() + return repack_w, concat_scale_zp + + graph = gm.graph + computation_op = torch.ops.aten._weight_int4pack_mm_for_cpu.default + for node in graph.find_nodes(op="call_function", target=computation_op): + if ( + not node._erased + and isinstance(node.meta.get("val"), torch.Tensor) + and node.meta["val"].device.type == "cpu" + ): + act = node.args[0] + users = list(act.users) + if _is_valid_concat_linear_woq_int4_fusion(users): + with graph.inserting_before(node): + assert all(user.args[1].op == "get_attr" for user in users) + computation_node_0 = users[0] + packed_wgts = [getattr(gm, user.args[1].target) for user in users] + group_size = computation_node_0.args[2] + scale_zps = [getattr(gm, user.args[3].target) for user in users] + out_feature_size_list = [ + packed_wgt.size(0) for packed_wgt in packed_wgts + ] + repack_w, concat_scale_zp = concat_wgt( + packed_wgts, scale_zps, group_size, act.meta.get("val").dtype + ) + repack_w_node_name = computation_node_0.args[1].target + "_concat" + concat_scale_zp_node_name = ( + computation_node_0.args[3].target + "_concat" + ) + gm.register_buffer(repack_w_node_name, repack_w) + setattr(gm, repack_w_node_name, repack_w) + gm.register_buffer(concat_scale_zp_node_name, concat_scale_zp) + setattr(gm, concat_scale_zp_node_name, concat_scale_zp) + + repack_w_node = graph.create_node( + "get_attr", repack_w_node_name, (), {} + ) + with graph.inserting_after(repack_w_node): + concat_scale_zp_node = graph.create_node( + "get_attr", concat_scale_zp_node_name, (), {} + ) + + with graph.inserting_after(concat_scale_zp_node): + concat_int4_gemm_node = graph.create_node( + "call_function", + computation_op, + ( + act, + repack_w_node, + group_size, + concat_scale_zp_node, + ), + ) + with graph.inserting_after(concat_int4_gemm_node): + split_node = graph.create_node( + "call_function", + torch.ops.aten.split_with_sizes.default, + ( + concat_int4_gemm_node, + out_feature_size_list, + 1, # split dim + ), + ) + with graph.inserting_after(split_node): + for gemm_idx, user in enumerate(users): + assert user.target == computation_op + get_item = graph.create_node( + "call_function", + operator.getitem, + ( + split_node, + gemm_idx, + ), + ) + with graph.inserting_after(get_item): + clone_node = graph.create_node( + "call_function", + torch.ops.aten.clone.default, + (get_item,), + {"memory_format": torch.contiguous_format}, + ) + user.replace_all_uses_with(clone_node) + graph.erase_node(user) + + +def quant_lift_up(graph_module: torch.fx.GraphModule): + """ + Lift up the quant node before view 
like nodes. It can benefit performance + of Attention like block. For example, we have the pattern as: + + DQ + DQ LINEAR + LINEAR VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + Q Q + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + We want to lift up the the quant nodes from matmul before view like nodes + as the output of Linear node. + + DQ + DQ LINEAR + LINEAR Q + Q VIEW + VIEW PERMUTE + PERMUTE TRANSPOSE + DQ DQ + Matmul + DIV + ADD + SOFTMAX + + It produces a DQ->LINEAR->Q pattern which can be fused by backend. + """ + + def is_view_op(node): + return node.op == "call_function" and node.target in _VIEW_OPS + + for node in graph_module.graph.nodes: + # Leslie: Here we verify that the quant node has exactly + # one input FX node, with constant scalar value for scale and zero point. + # For the case input of quant node has more than one input FX nodes, + # extend the implementation to lift up all the connected nodes + # before the view nodes to keep the topological order. + if ( + node.op == "call_function" + and node.target in _PER_TENSOR_QUANTIZE_OPS + and len(node.all_input_nodes) == 1 + and is_view_op(node.all_input_nodes[0]) + ): + quant_node = node + input_node_of_quant = quant_node.args[0] + + # Check the nodes along lift up path has only 1 user node + # Propagate view like node to find where to insert the new quant node + could_lift_up = True + current_node = quant_node + input_node = current_node.args[0] + while is_view_op(input_node): + if len(input_node.users) != 1: + could_lift_up = False + break + current_node = input_node + input_node = current_node.args[0] + + # Further check the input node of the first view node has only 1 user node + if could_lift_up and len(input_node.users) == 1: + # Replace dequant's input from quant to quant's input + quant_node.replace_all_uses_with(input_node_of_quant) + # Insert the new quant node + with graph_module.graph.inserting_before(current_node): + new_quant_node = graph_module.graph.node_copy(quant_node) + input_node.replace_all_uses_with(new_quant_node) + + # Update inputs of new_quant_node + def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node: + if n == input_node_of_quant: + return input_node + else: + return n + + new_args = map_arg(new_quant_node.args, maybe_replace_node) + new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node) + new_quant_node.args = new_args # type: ignore[assignment] + new_quant_node.kwargs = new_kwargs # type: ignore[assignment] + graph_module.graph.erase_node(quant_node) + + graph_module.graph.lint() + graph_module.recompile() diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..f723a69a6b69b99ce68a4786205f99c117b10bc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/reinplace.py @@ -0,0 +1,766 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Callable, cast, Union + +import torch +import torch.fx.node +from torch._C._dynamo.guards import compute_overlapping_tensors +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.utils import ReinplaceCounters, ReInplaceTrigger +from torch._higher_order_ops.triton_kernel_wrap import ( + kernel_side_table, + triton_kernel_wrapper_functional, +) +from torch._inductor import config, inductor_prims 
+from torch._inductor.fx_utils import get_node_storage, is_node_realized +from torch._inductor.lowering import ( + inplaceable_foreach_ops as inplaceable_foreach_ops_lowerings, +) +from torch._inductor.virtualized import V +from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.reinplace import _is_view_op +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass(frozen=True) +class InplaceableOp: + inplace_op: Callable[..., Any] + mutated_arg: int + extra_check: Callable[[torch.fx.Node], bool] = lambda node: True + + +_SCATTER_OP_TO_VIEW = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} +_VIEW_OP_TO_SCATTER = {v: k for k, v in _SCATTER_OP_TO_VIEW.items()} + + +def graph_call_function(graph: torch.fx.Graph, fn, *args, **kwargs): + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (args, kwargs), + ) + with V.fake_mode: + fake_result = fn(*fake_args, **fake_kwargs) + + node = graph.call_function(fn, args, kwargs) + node.meta["val"] = fake_result + return node + + +@dataclass +class ViewOp: + target: torch._ops.OpOverload + args: tuple[Any, ...] + kwargs: dict[str, Any] + + +def _inplace_generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + tmp = inp + for view in view_ops: + fake_args, fake_kwargs = pytree.tree_map( + lambda node: node.meta["val"] if isinstance(node, torch.fx.Node) else node, + (view.args, view.kwargs), + ) + tmp = view.target(tmp, *fake_args, **fake_kwargs) + try: + tmp.copy_(src) + except RuntimeError as e: + raise RuntimeError( + f"shape error in scatter op, can not broadcast {src.shape} to {tmp.shape}" + ) from e + return inp + + +def _generalized_scatter( + inp: torch.Tensor, src: torch.Tensor, view_ops: list[ViewOp] +) -> torch.Tensor: + out = inp.clone() + return _inplace_generalized_scatter(out, src, view_ops) + + +def _decompose_scatter_functional_helper( + graph: torch.fx.Graph, + inp: torch.Tensor, + src: torch.Tensor, + view_ops: list[ViewOp], +) -> torch.fx.Node: + view_op, view_ops_tail = view_ops[0], view_ops[1:] + + if view_ops_tail: + view = graph_call_function( + graph, view_op.target, inp, *view_op.args, **view_op.kwargs + ) + src = _decompose_scatter_functional_helper(graph, view, src, view_ops[1:]) # type: ignore[assignment] + + return graph_call_function( + graph, + _VIEW_OP_TO_SCATTER[view_op.target], + inp, + src, + *view_op.args, + **view_op.kwargs, + ) + + +def _decompose_scatter_functional( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter to a sequence of view_scatter operations + + e.g. 
_generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + view = aten.slice(inp, 0, 0, 10) + view_updated = aten.slice_scatter(view, src, 1, 10, -10) + inp_updated = aten.slice_scatter(inp, view_updated, 0, 0, 10) + """ + assert node.target is _generalized_scatter + return _decompose_scatter_functional_helper(graph, *node.args) # type: ignore[arg-type] + + +def _decompose_scatter_mutating( + graph: torch.fx.Graph, node: torch.fx.Node +) -> torch.fx.Node: + """Decompose _generalized_scatter using mutations + + e.g. _generalized_scatter(inp, src, [(aten.slice, 0, 0, 10), (aten.slice, 1, 10, -10)]) + + will become + + inp_updated = aten.clone(inp) + slice1 = aten.slice(inp_updated, 0, 0, 10) + slice2 = aten.slice(slice1, 1, 10, -10) + slice2.copy_(src) + + """ + assert node.target in (_generalized_scatter, _inplace_generalized_scatter) + inp, src, view_ops = node.args + assert not node.kwargs + + if node.target is _generalized_scatter: + inp = graph_call_function(graph, aten.clone, inp) + + tmp = inp + for view in view_ops: # type: ignore[union-attr] + tmp = graph_call_function(graph, view.target, tmp, *view.args, **view.kwargs) # type: ignore[union-attr] + + graph_call_function(graph, aten.copy_.default, tmp, src) + return inp # type: ignore[return-value] + + +# View ops whose view_scatter op is lowered into mutations anyway, +# so is never a pessimisation to decompose. +_ALWAYS_MUTATING_SCATTER_OPS = OrderedSet( + [ + aten.as_strided.default, + aten.diagonal.default, + ] +) + + +def scatter_always_uses_mutation(node: torch.fx.Node) -> bool: + _, _, view_ops = node.args + view_ops = cast(Sequence[torch.fx.node.Argument], view_ops) + return any( + target in _ALWAYS_MUTATING_SCATTER_OPS + for view in view_ops + if isinstance(target := getattr(view, "target", None), torch._ops.OpOverload) + ) + + +def should_reinplace_scatter(node: torch.fx.Node) -> bool: + """Choose between mutating and functional scatter decompositions + + Reinplacing view scatter ops can be pessimising as it blocks fusion with the + input or output tensor computations. However, it is still profitable if the + input and output would have been realized anyway. 
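+
+    For example (informal): when `inp` is a graph input whose mutation is observed via a
+    trailing `inp.copy_(scatter_output)` epilogue, both `inp` and the scatter output are
+    realized regardless, so the mutating decomposition is chosen; this is the
+    placeholder/get_attr + copy_ check implemented below.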
+ + """ + inp, _src, _view_ops = node.args + + # Mutating scatter ops unconditionally realize input and output + if scatter_always_uses_mutation(node): + return True + + if is_node_realized(inp) and is_node_realized(node): # type: ignore[arg-type] + return True + + # If the output is copied back into the input, this forces both to be + # realized as the output is a user of the input + if inp.op in ("placeholder", "get_attr") and any( # type: ignore[union-attr] + user.target is aten.copy_.default and user.args[0] is inp for user in node.users + ): + return True + + # Otherwise, assume fusions will make functional variants profitable + return False + + +def decompose_generalized_scatter(graph: torch.fx.Graph) -> None: + """Replace _generalized_scatter with normal aten ops""" + for node in itertools.chain( + graph.find_nodes(op="call_function", target=_generalized_scatter), + graph.find_nodes(op="call_function", target=_inplace_generalized_scatter), + ): + use_mutation = ( + node.target is _inplace_generalized_scatter + or scatter_always_uses_mutation(node) + ) + + with graph.inserting_before(node): + if use_mutation: + new_node = _decompose_scatter_mutating(graph, node) + else: + new_node = _decompose_scatter_functional(graph, node) + + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def canonicalize_view_scatter_ops(graph: torch.fx.Graph) -> None: + """ + This canonicalizes view scatter ops into a generalized form, defined as: + def scatter(inp, src, views): + tmp = inp.clone() + for view in views: + tmp = view(tmp) + tmp.copy_(src) + + We also fuse consecutive view scatter ops of the form + a = scatter(view2(self), src, [view1]) + b = scatter(self, a, [view2]) + which can be rewritten as + b = scatter(self, src, [view2, view1]) + a = view2(b) + + This is both more efficient as we only do a single scatter, and also + easier to reinplace since there is only one use of `self` + """ + + node_to_view_base: dict[torch.fx.Node, torch.fx.Node] = {} + node_to_view_op: dict[torch.fx.Node, list[ViewOp]] = defaultdict(list) + + def handle_views(node: torch.fx.Node): + inp = node.args[0] + node_to_view_base[node] = node_to_view_base.get(inp, inp) # type: ignore[arg-type, assignment] + node_to_view_op[node] = [ + *node_to_view_op[inp], # type: ignore[index] + ViewOp( + node.target, # type: ignore[arg-type] + args=node.args[1:], + kwargs=node.kwargs, + ), + ] + + def handle_view_scatter(node: torch.fx.Node): + assert len(node.args) >= 2 + inp, src = node.args[:2] + + assert isinstance(node.target, torch._ops.OpOverload) + scatter_view_op = ViewOp( + _SCATTER_OP_TO_VIEW[node.target], + args=node.args[2:], + kwargs=node.kwargs, + ) + + def can_fuse(): + if src.target is not _generalized_scatter: # type: ignore[union-attr] + return False + src_inp, _src_src, _src_scatter_view_op = src.args # type: ignore[union-attr] + + inp_base = node_to_view_base.get(inp, inp) # type: ignore[arg-type] + src_base = node_to_view_base.get(src_inp, src_inp) # type: ignore[arg-type] + return inp_base is src_base and node_to_view_op[src_inp] == [ # type: ignore[index] + *node_to_view_op[inp], # type: ignore[index] + scatter_view_op, + ] + + if not can_fuse(): + with graph.inserting_before(node): + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src, + [scatter_view_op], + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + return + + _src_inp, src_src, src_scatter_view_op = src.args # type: ignore[union-attr] + with graph.inserting_before(src): # type: 
ignore[arg-type] + new_node = graph_call_function( + graph, + _generalized_scatter, + inp, + src_src, + [scatter_view_op, *src_scatter_view_op], # type: ignore[misc] + ) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + if src.users: # type: ignore[union-attr] + new_src = graph_call_function( + graph, + _SCATTER_OP_TO_VIEW[node.target], + new_node, + *node.args[2:], + **node.kwargs, + ) + + handle_views(new_src) + src.replace_all_uses_with(new_src) # type: ignore[union-attr] + + graph.erase_node(src) # type: ignore[arg-type] + + for node in graph.nodes: + if _is_view_op(node.target): + handle_views(node) + elif node.target in _SCATTER_OP_TO_VIEW: + handle_view_scatter(node) + + +inplaceable_ops: dict[Callable[..., Any], InplaceableOp] = { + aten.index_put.default: InplaceableOp(aten.index_put_.default, 0), + aten._unsafe_index_put.default: InplaceableOp(inductor_prims._unsafe_index_put_, 0), + _generalized_scatter: InplaceableOp( + _inplace_generalized_scatter, + 0, + extra_check=should_reinplace_scatter, + ), +} + +try: + c10d_functional = torch.ops._c10d_functional + inplaceable_collective_ops: dict[Callable[..., Any], InplaceableOp] = { + c10d_functional.all_reduce.default: InplaceableOp( + c10d_functional.all_reduce_.default, 0 + ), + c10d_functional.all_reduce_coalesced.default: InplaceableOp( + c10d_functional.all_reduce_coalesced_.default, 0 + ), + } + inplaceable_ops.update(inplaceable_collective_ops) +except AttributeError: + # _c10d_functional ops are only available when torch + # is built with USE_DISTRIBUTED=1. + pass + +inplaceable_foreach_ops: dict[torch._ops.OpOverload, InplaceableOp] = {} +for outplace_op, inplace_op in inplaceable_foreach_ops_lowerings.items(): + inplaceable_foreach_ops[outplace_op] = InplaceableOp(inplace_op, 0) + + +inplaceable_triton_ops = OrderedSet([triton_kernel_wrapper_functional]) + + +# Operators that don't depend on the tensor data +META_ONLY_OPS = OrderedSet( + [ + aten.sym_size.int, + aten.sym_stride.int, + aten.sym_numel.default, + aten.sym_storage_offset.default, + ] +) + + +def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: + """ + Reinplaces in-placeable operations. + If there are no uses of a view of the mutated arg after the current node, + it is possible to inplace the op. + This above algorithm could be justified by observing side effects. While + we traverse the graph in forwards direction, only latter nodes could view + side effects of the current node. If the current node is not used later as + well as no view of this node is used later in the graph, then it is safe to + inplace as there would be no way to observe the side effects. + This condition is slightly different for graph inputs where they can only + be inplaced if the above condition is true and there's a copy_ in the + epilogue that signals that the caller wants to observe the mutation. + + Unlike JIT Inductor, AOTInductor currently unlifts weights and buffers from + input args, so instead of checking mutation on placeholder, AOTInductor + checks mutation on get_attr. This is subject to change in future. + """ + + copy_args_to_copy_nodes = {} + # maps argument to the first copy_ node that mutates it. 
+ copy_nodes = {} + mutated_inputs = OrderedSet[Any]() + storage_to_nodes = defaultdict(list) + node_order: dict[Any, int] = {} + for i, node in enumerate(reversed(graph.nodes)): + node_order[node] = len(graph.nodes) - i - 1 + storage_to_nodes[get_node_storage(node)].append(node) + if node.target == aten.copy_.default and node.args[0].op in ( + "placeholder", + "get_attr", + ): + dst = node.args[0] + src = node.args[1] + # If the target is a getitem and it indexes a possible clone, + # then skip over it + if src.target == operator.getitem and ( + ( + src.args[0].target == triton_kernel_wrapper_functional + and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] + ) + or (src.args[0].target in inplaceable_foreach_ops) + or (src.args[0].target == torch.ops.higher_order.auto_functionalized) + ): + src = src.args[0] + + copy_args_to_copy_nodes[(dst, src)] = node + copy_nodes[dst] = node + + mutated_inputs.add(node.args[0]) + + def any_use_of_views_after_node(node, shared_view_nodes, *, copy_node, mutated_arg): + node_loc = node_order[node] + copy_node_loc = node_order[copy_node] if copy_node is not None else None + + def is_meta_only_user(node): + if _is_view_op(node.target): + return all(is_meta_only_user(u) for u in node.users) + return node.target in META_ONLY_OPS + + for view in shared_view_nodes: + for user in view.users: + user_loc = node_order[user] + # Skip all users before node + if user_loc <= node_loc: + continue + # Ignore uses after the copy_ epilogue node, where the input + # has already been mutated anyway + if copy_node_loc is not None and copy_node_loc <= user_loc: + continue + # Reinplacing does not change shape metadata + if is_meta_only_user(user): + continue + # If our graph looks like: + # foo(mutated_arg) + # mutated_arg.copy_(other) + # then it's safe for us to reinplace foo because mutated_arg + # will get overwritten anyways. + if ( + user.target is torch.ops.aten.copy_.default + and mutated_arg is user.args[0] + ): + continue + return True + return False + + def can_inplace(node, mutated_arg): + # ls should be a list of tensors that all shares the same storage. + def _overlap(ls) -> bool: + try: + return len(compute_overlapping_tensors(ls)) != 0 + except GuardOnDataDependentSymNode: + # If we fail with data dependent error we assume they all overlap. + return True + + if isinstance(mutated_arg, (list, tuple)): + # TODO Using _overlap here causes a several issues. + unique_storages = OrderedSet(get_node_storage(arg) for arg in mutated_arg) + if len(unique_storages) != len(mutated_arg): + # At least two Tensors in mutated_arg alias each other, so we can't reinplace it. + # We can probably do better (that is, reinplace one of them and clone the other) + # but that requires more work and mutable List[Tensor] are not that common. + return False + return all(can_inplace(node, arg) for arg in mutated_arg) + + if get_node_storage(mutated_arg) is None: + return False + + shared_view_nodes = storage_to_nodes[get_node_storage(mutated_arg)] + + # Only keep tensor that might overlap with mutated_arg. + shared_view_nodes = [ + v + for v in shared_view_nodes + if _overlap([mutated_arg.meta["val"], v.meta["val"]]) + ] + + if mutated_arg.op in ("placeholder", "get_attr"): + # Get the first copy_ node that mutates the mutated_arg. + copy_node = copy_nodes.get(mutated_arg, None) + if copy_node is None: + # There is no copy_ back to the candidate mutated_arg (which is a graph input). 
+ # Therefore the semantics of the program are that it does not mutate + # mutated_arg, so we cannot re-inplace it. + return False + if any_use_of_views_after_node( + node, shared_view_nodes, copy_node=copy_node, mutated_arg=mutated_arg + ): + return False + + return True + elif any(view.op in ("placeholder", "get_attr") for view in shared_view_nodes): + # This should never happen in auto_functionalize_v2 non-inference mode, + # since all mutated_arg are bases. + + # If mutated arg is view of any of the inputs of the graph, + # do not allow for inplacing. + # This would require more sophisticated algorithm to handle + return False + else: + return not any_use_of_views_after_node( + node, shared_view_nodes, copy_node=None, mutated_arg=mutated_arg + ) + + def log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ): + # Total size of possibly_missed_reinplacing_opportunities for tensors with static shapes. + missed_bytes = 0 + + def bytes(node): + t = node.meta.get("val", None) + if ( + t is not None + and isinstance(t.element_size(), int) + and isinstance(t.numel(), int) + ): + return t.element_size() * t.numel() + else: + return 0 + + for node in missed_nodes: + if isinstance(node, (list, tuple)): + for n in node: + missed_bytes += bytes(n) + else: + missed_bytes += bytes(node) + + log.info( + "For node %s, attempted to reinplace %s. We were unable to reinplace %s; " + "%s (if non-empty) are possible missed reinplacing opportunities that may be bad for " + "memory usage and performance. Total size of missed opportunities with static shapes is" + " : %s bytes.", + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_bytes, + ) + + ReinplaceCounters.add_missed_opportunities(trigger, len(missed_args)) + ReinplaceCounters.add_missed_bytes(trigger, missed_bytes) + + replace_dict: dict[torch.fx.Node, torch.fx.Node] = {} + + def reinplace_and_refine_tensors_to_clone( + old_tensors_to_clone, kwargs, node_name, trigger + ): + tensors_to_clone: list[str] = [] + storage_of_reinplaced_args = OrderedSet[Union[int, None]]() + + # Those used to count possibly_missed_reinplacing_opportunities + missed_nodes = [] + missed_args = [] + + # TODO this logic can be made more precise using _overlap + def tensor_with_same_storage_already_reinplaced(arg): + if isinstance(arg, (list, tuple)): + return any( + get_node_storage(a) in storage_of_reinplaced_args for a in arg + ) + return get_node_storage(mutated_arg) in storage_of_reinplaced_args + + for arg in old_tensors_to_clone: + assert arg in kwargs + + mutated_arg = kwargs[arg] + + # Let's say we have: + # - op(x, y) that mutates both x and y + # - new_x, new_y = functional_op(x, y) is the functional variant + # If we are presented with functional_op(x, x), we must not reinplace + # this into op(x, x), because then it would be writing to the same Tensor. + # Instead, it's OK to reinplace one of them and to clone the other: + # >>> y = x.clone() + # >>> op(x, y) + # This also applies if we have views: functional_op(x, x[0]) + # should not reinplace into op(x, x[0]). + should_attempt_reinplace = not tensor_with_same_storage_already_reinplaced( + mutated_arg + ) + if should_attempt_reinplace and can_inplace(node, mutated_arg): + # In general, we probably do not need those optimizations. 
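To make the aliasing restriction described in the comments above concrete, here is a small eager-mode sketch. functional_op is a made-up stand-in for the functional variant of an op that mutates both inputs; the point is only that naively reinplacing both arguments when they share storage clobbers one result, while reinplacing one and cloning the other, as the comment suggests, preserves both.

    import torch

    def functional_op(a, b):
        # Hypothetical functional variant of an op that conceptually mutates
        # both inputs: it returns what each input should become.
        return a + 1, b * 10

    x = torch.ones(3)

    # Naive full reinplacement with aliased args (functional_op(x, x) -> op_(x, x))
    # makes both in-place writes land in the same storage, so the second write
    # acts on the already-updated value (2 * 10 = 20 instead of 1 * 10 = 10).
    bad = x.clone()
    bad.add_(1)
    bad.mul_(10)

    # Safe strategy from the comment: reinplace one argument, clone the other.
    a = x            # reinplaced in place
    b = x.clone()    # clone for the aliased second argument
    a.add_(1)
    b.mul_(10)
    expected_a, expected_b = functional_op(torch.ones(3), torch.ones(3))
    assert torch.equal(a, expected_a) and torch.equal(b, expected_b)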
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + if not trigger == ReInplaceTrigger.AUTO_FUNC_V2: + for user in node.users: + # For auto_functionalize_v2, arg is the index of the base, where base at index i corresponds to + # output at index size(out)+i. + # This used to compare strings with integers for auto_functionalize_v2. Not sure + # if it was needed for inplaceable_triton_ops? + if user.target == operator.getitem and user.args[1] == arg: + replace_dict[user] = mutated_arg + + if isinstance(mutated_arg, (list, tuple)): + for a in mutated_arg: + storage_of_reinplaced_args.add(get_node_storage(a)) + else: + storage_of_reinplaced_args.add(get_node_storage(mutated_arg)) + else: + if should_attempt_reinplace: + missed_args.append(arg) + missed_nodes.append(mutated_arg) + + tensors_to_clone.append(arg) + + log_inplace_results( + node_name, + old_tensors_to_clone, + tensors_to_clone, + missed_args, + missed_nodes, + trigger, + ) + return tensors_to_clone + + for node in graph.nodes: + if (inplaceable_op := inplaceable_ops.get(node.target, None)) is not None: + mutated_arg = node.args[inplaceable_op.mutated_arg] + if can_inplace(node, mutated_arg) and inplaceable_op.extra_check(node): + # TODO(yifu): this doesn't properly remove copy epilogues for + # ops that mutate multiple inputs. Need to revise the copy + # node tracking logic to support the case. + copy_node = copy_args_to_copy_nodes.get((mutated_arg, node)) + if copy_node is not None: + replace_dict[copy_node] = copy_node.args[0] + node.target = inplaceable_op.inplace_op + elif node.target == torch.ops.higher_order.auto_functionalized_v2: + _mutable_op = node.args[0] + kwargs = node.kwargs + + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + ReInplaceTrigger.AUTO_FUNC_V2, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone + elif node.target == torch.ops.higher_order.auto_functionalized: + _mutable_op = node.args[0] + from torch._higher_order_ops.auto_functionalize import get_mutable_args + + tensors_to_clone, _ = get_mutable_args(_mutable_op) + # Don't try to reinplace Optional[Tensor] args that are None. + tensors_to_clone = [ + t for t in tensors_to_clone if node.kwargs[t] is not None + ] + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + tensors_to_clone, + node.kwargs, + _mutable_op._name, + ReInplaceTrigger.AUTO_FUNC_V1, + ) + + # Stash the metadata.
There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = tensors_to_clone + elif node.target in inplaceable_triton_ops: + kernel_idx = node.kwargs["kernel_idx"] + kernel = kernel_side_table.get_kernel(kernel_idx) + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + if isinstance(kernel, JITFunction): + kernel_name = kernel.fn.__name__ + elif isinstance(kernel, Autotuner): + if config.is_fbcode(): + # Autotuner has different implementations for AMD and NV + if torch.version.hip is None: + kernel_name = kernel.base_fn.__name__ + else: + kernel_name = kernel.fn.__name__ + else: + kernel_name = kernel.base_fn.__name__ + else: + raise AssertionError("Unknown triton kernel type") + + # inplaceable_triton_ops take an additional argument called + # tensors_to_clone which contain a list of tensors to clone + # This pass iterates over them and sees which ones are safe + # to eliminate (i.e. no longer need the clones) + tensors_to_clone = reinplace_and_refine_tensors_to_clone( + node.kwargs["tensors_to_clone"], + node.kwargs["kwargs"], + kernel_name, + ReInplaceTrigger.TRITON_OPS, + ) + + kwargs = dict(node.kwargs) + kwargs["tensors_to_clone"] = tensors_to_clone + node.kwargs = immutable_dict(kwargs) + if "eager_input_vals" in node.meta: + # We changed the kwargs, so we need to update eager_input_vals + # to something sane. + args, kwargs = node.meta["eager_input_vals"] + new_kwargs = {**kwargs} + new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) + new_kwargs = immutable_dict(new_kwargs) + node.meta["eager_input_vals"] = (args, new_kwargs) + elif ( + inplaceable_op := inplaceable_foreach_ops.get(node.target, None) + ) is not None: + mutated_args = node.args[inplaceable_op.mutated_arg] + + if not all((arg, node) in copy_args_to_copy_nodes for arg in mutated_args): + continue + + if can_inplace(node, mutated_args): + for arg in mutated_args: + copy_node = copy_args_to_copy_nodes[(arg, node)] + replace_dict[copy_node] = copy_node.args[0] + + node.target = inplaceable_op.inplace_op + for node, replacement in replace_dict.items(): + while replacement in replace_dict: + replacement = replace_dict[replacement] + replace_dict[node] = replacement + + node.replace_all_uses_with(replacement) + graph.erase_node(node) + + +def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: + with enable_python_dispatcher(): + canonicalize_view_scatter_ops(graph) + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py new file mode 100644 index 0000000000000000000000000000000000000000..911dc76623076d1492079f97d3a118cf88f421c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/replace_random.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-defs +import collections +import logging + +import torch +from torch.fx.passes.graph_transform_observer import GraphTransformObserver +from torch.fx.passes.shape_prop import _extract_tensor_metadata + +from .. 
import config, inductor_prims +from ..pattern_matcher import ( + CallFunctionVarArgs, + Match, + PatternMatcherPass, + register_graph_pattern, +) +from ..virtualized import V + + +log = logging.getLogger(__name__) +patterns = PatternMatcherPass() +aten = torch.ops.aten + + +def replace_random_passes(gm: torch.fx.GraphModule): + """Modify the given FX graph to use backend-native random ops""" + if config.fallback_random: + return 0 + + count = patterns.apply(gm) + with GraphTransformObserver(gm, "fuse_seed_creation_pass"): + count += fuse_seed_creation_pass(gm.graph) + + return count + + +def fuse_seed_creation_pass(graph: torch.fx.Graph): + """ + Horizontally fuse all the seed generation on each device + + a = inductor_seed(dev) + b = inductor_seed(dev) + + Becomes: + seeds = inductor_seeds(2, dev) + a = inductor_lookup_seed(seeds, 0) + b = inductor_lookup_seed(seeds, 1) + + We do this because seed creation is entirely launch overhead bound. + """ + device_seeds = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.seed).match(node): + device_seeds[node.args[0]].append(node) + + if not device_seeds: + return 0 + + for device, seeds in device_seeds.items(): + with graph.inserting_before(seeds[0]): + combined = graph.call_function(inductor_prims.seeds, (len(seeds), device)) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(seeds)], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, seed in enumerate(seeds): + with graph.inserting_before(seed): + new_seed = graph.call_function( + inductor_prims.lookup_seed, (combined, idx) + ) + seed.replace_all_uses_with(new_seed) + new_seed.meta.update(seed.meta) + graph.erase_node(seed) + + return len(device_seeds) + + +def default_kwargs(device): + return {} + + +def get_device(device): + if device is not None: + return device + return torch.empty([]).device # default device + + +@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns) +@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns) +def replace_random( + match: Match, + size, + *, + generator=None, + dtype=None, + device=None, + layout=None, + pin_memory=None, +): + if generator is not None: + return + + def replacement(size): + result = inductor_prims.random( + size, inductor_prims.seed(device), mode, **default_kwargs(device) + ) + if dtype is not None: + result = result.to(dtype) + return result + + mode = { + aten.rand: "rand", + aten.randn: "randn", + }[ + match.output_node().target.overloadpacket # type: ignore[union-attr] + ] # type: ignore[union-attr] + device = get_device(device) + match.replace_by_example(replacement, [size]) + + +@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns) +def replace_randint( + match: Match, + low, + high, + size, + *, + dtype=torch.int64, + device=None, + layout=None, + pin_memory=None, +): + def replacement(low, high, size): + result = inductor_prims.randint(low, high, size, inductor_prims.seed(device)) + return result.to(dtype) + + device = get_device(device) + match.replace_by_example(replacement, [low, high, size]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14ed88eee3f3aef8e41cd098fae8cc299790c2f0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa72a82ecfc4672e8a8eab8e78b5fc72482e493c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_1.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a921e704176e8716036cca47575994c8732732b5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_10.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a55ced9da8c34915195f6a7ce92d0d121745f716 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_11.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2bb3523f64bf308ab721ac7c11a5a2bdb9c3e22 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_12.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..780af09585b4cbe363eb4209edeb645448507ce3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_13.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b6f6c8e6263d218c0d314b1f8137c96faa660dc1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_14.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6398067231ec467c2926e90d6520234e50459b3f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_15.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e15ebb9eec1b3508acda1a12170fde12559e6d0d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_16.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e15e364bb516503632f3f705baa1e1591c820ac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_17.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a09f1525742bd8fd1653dc4dd4745f576636cfae Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_18.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68e36c8b7a7ba125ac8fe306d94ab28ee2fc9f9b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_19.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b19f27b3eb8bb40c8373172da36b2388bb294f2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..ef4aaf4dcb965b2ecb30f2648c6ade740f06a3c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_20.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c473c1348c4876516642021ebb84c0fa2acfe1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_21.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fd267599aa054ba77bdaa589e6202c58656895a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_22.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c2f836b472117ded9f32f7de518ee888462685f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_23.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48b7cc798068065dec103bb9235dbf60bd687b02 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_3.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e688b6c819eb083eef950441ee2998af50f2e84 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_4.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f03fcbb77349a807d87c063238bf814c14335f16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_5.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..4c40df374abe1d6253c8108eccd170f26731fc15 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_6.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da14445cfb236015fdb3d1db3ae23c03eb6b97bf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_7.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11b7d41461070b7f2228440bdbd372078d4f399c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_8.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b701e2302d184ceced8075d313899821dc06491f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/_sfdp_pattern_9.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c79458e4e7d2c046550a2ec98db8a84203cf1d7d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/addmm_pattern.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02ea0c97ffdfa8dba31823f58cb9f3c168c9931a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/bmm_pattern.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..038a8aa9ad13830253b4b9fb5d22800446ea7aee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/__pycache__/mm_pattern.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a8d6c7183138bef3d6dfd273cfca4b34566ec3 --- 
/dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_1.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) 
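For orientation, the node patterns serialized in this file correspond to a decomposed scaled-dot-product-attention computation: Q @ K^T, scaling by 1/inv_scale, a softmax, then a matmul with V. The reference below is an editorial sketch of the matched computation, not code from the file; the function name, tensor shapes, and inv_scale value are illustrative only.

    import torch

    def sdpa_reference(query, key, value, inv_scale):
        # Eager-mode version of the computation the pattern nodes encode:
        # scores = (Q @ K^T) / inv_scale, softmax over the last dim, then @ V.
        scores = torch.matmul(query, key.transpose(-2, -1)) / inv_scale
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, value)

    q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
    out = sdpa_reference(q, k, v, inv_scale=16 ** 0.5)
    print(out.shape)  # torch.Size([2, 4, 8, 16])

The serialized pattern itself is shape-agnostic: the Ignored() arguments stand in for sizes and strides, while KeywordArg('query'), KeywordArg('key'), KeywordArg('value'), and KeywordArg('inv_scale') bind the inputs of the matched subgraph.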
+view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), 
_users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_1_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) 
+sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_1_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py new file mode 100644 index 0000000000000000000000000000000000000000..c7b5ebdf5f8876b37176e60927cc54a38f958824 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, 
expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, 
expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = 
CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_3, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 
= CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_10_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py new file mode 100644 index 0000000000000000000000000000000000000000..49db4465614c9e5c1ab86641ffd759939f990f4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_11.py @@ -0,0 +1,204 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, 
Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = 
CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = 
CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_11_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_11_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py new file mode 100644 index 0000000000000000000000000000000000000000..d3df4d11773a3e02ebaa5319f3b2b35ff034fe2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_12.py @@ -0,0 +1,220 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = 
CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) 
+convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_12_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) 
+bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_12_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c0458ab3a422615c623fc8d63e0163c279e4f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_13.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +neg_default = CallFunction(aten.neg.default, div_Tensor) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4, _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, fma_default, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, fma_default) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default, _users=2) +amax_default = CallFunction(aten.amax.default, bmm_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, bmm_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, 
exp_default, sum_dim_IntList) +_sfdp_pattern_13_inference = CallFunction(aten.bmm.default, div_Tensor, KeywordArg('value'), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) +convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, mul_Tensor_1, KeywordArg('value')) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default_1) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, bmm_default_2, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, convert_element_type_default_5, permute_default_2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, convert_element_type_default_5) +permute_default_4 = CallFunction(aten.permute.default, bmm_default_4, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, mul_Tensor_1, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, KeywordArg('tangents_1')) +_sfdp_pattern_13_half_training = MultiOutputPattern([bmm_default_1, + bmm_default_3, + permute_default_4, + bmm_default_5, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('query'), permute_default) 
+convert_element_type_default = CallFunction(prims.convert_element_type.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +_sfdp_pattern_13_half_inference = CallFunction(aten.bmm.default, convert_element_type_default_1, KeywordArg('value'), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bf0ba0aa3ca403f155bdb8d40e20249858a24b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_14.py @@ -0,0 +1,210 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = 
CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, 
KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_14_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) 
+div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_14_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py new file mode 100644 index 0000000000000000000000000000000000000000..eaa03b0ba62cf6f7927f22f81fa7eac42918b3e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_15.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) 
+view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_8, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_8, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_15_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) 
+permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_15_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py new file mode 100644 index 0000000000000000000000000000000000000000..618980029aac68b8f01eb2406bc85146108e2e68 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_16.py @@ -0,0 +1,599 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, 
gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) 
+fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, 
KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = 
CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) 
+convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) 
+view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) 
+amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, 
memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, 
memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, 
expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_16_half_mask_fp32_bs1_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) 
+view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_16_half_mask_fp32_bs1_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py new file mode 100644 index 0000000000000000000000000000000000000000..d4251ec92b4ede82c3f1327af5ebeb3b6a0f0eeb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_17.py @@ -0,0 +1,246 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, 
device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, 
clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, 
exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, KeywordArg('inv_scale')) +view_default_9 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, 
permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_17_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_3, KeywordArg('inv_scale')) +where_self = CallFunction(aten.where.self, expand_default, full_default, div_Tensor) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_17_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py new file mode 100644 
index 0000000000000000000000000000000000000000..845e4ccc0c19a6ca918b48a8a0558f467d8184ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -0,0 +1,453 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = 
CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, 
div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = 
CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) 
+convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = 
CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = 
CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_5, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_18_half_bs1_training = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), 
device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_18_half_bs1_inference = MultiOutputPattern([view_default_5, + permute_default_1, + permute_default_3 +]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py new file mode 100644 index 0000000000000000000000000000000000000000..b3bb74ab9f7102533416786deeb903ee5dad858e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -0,0 +1,209 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, 
div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), fma_default, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, 
Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_1, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) 
+scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +where_self_1 = CallFunction(aten.where.self, KeywordArg('causal_mask'), convert_element_type_default_3, scalar_tensor_default) +div_Tensor_2 = CallFunction(aten.div.Tensor, where_self_1, full_default) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_19_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) +full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) +add_Tensor = CallFunction(aten.add.Tensor, where_self, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_19_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7a664f3194c1da490cde9c5e48aecd52b53609 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_2.py @@ -0,0 +1,174 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = 
CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = 
CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_1, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_1) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_2_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, 
Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_2_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py new file mode 100644 index 0000000000000000000000000000000000000000..4c8a046f78459af74e3cd87f0d660a6f3974a302 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -0,0 +1,244 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, 
Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, fma_default) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) 
+permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3, _users=2) +amax_default = CallFunction(aten.amax.default, where_self, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, where_self, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_3 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), 
pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +view_default_6 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +scalar_tensor_default = CallFunction(aten.scalar_tensor.default, Ignored(), dtype=Ignored(), layout=torch.strided, device=Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_7 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_7, permute_default_4) +view_default_8 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = 
CallFunction(aten.mul.Tensor, view_default_8, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +where_self_1 = CallFunction(aten.where.self, expand_default, scalar_tensor_default, convert_element_type_default_5) +view_default_9 = CallFunction(aten.view.default, where_self_1, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) +view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) +view_default_11 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_11, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_7) +view_default_12 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_12, Ignored()) +_sfdp_pattern_20_half_training = MultiOutputPattern([view_default_6, + permute_default_6, + permute_default_9, + permute_default_11, + None, + None +]) + + +eq_Scalar = CallFunction(aten.eq.Scalar, KeywordArg('attn_mask'), Ignored()) +view_default = CallFunction(aten.view.default, eq_Scalar, Ignored()) +expand_default = CallFunction(aten.expand.default, view_default, Ignored()) +full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_2, memory_format=torch.contiguous_format) +view_default_2 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default_1, view_default_2) +view_default_3 = CallFunction(aten.view.default, bmm_default, Ignored()) +where_self = CallFunction(aten.where.self, expand_default, full_default, view_default_3) +convert_element_type_default = 
CallFunction(prims.convert_element_type.default, where_self, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_4 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_4, memory_format=torch.contiguous_format) +view_default_5 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_4, view_default_5) +_sfdp_pattern_20_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py new file mode 100644 index 0000000000000000000000000000000000000000..74cbc9a0fe47a0700acacb199942b13f5f7f091a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_21.py @@ -0,0 +1,217 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 
= CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_inference = 
CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) 
+convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_21_half_training = MultiOutputPattern([view_default_7, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = 
CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +_sfdp_pattern_21_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py new file mode 100644 index 0000000000000000000000000000000000000000..2d72f96c0aac67b8ee382942f67fcb36f2870e43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_22.py @@ -0,0 +1,229 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), 
True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, 
Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +view_default_3 = CallFunction(aten.view.default, add_Tensor, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) 
+div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, convert_element_type_default_5, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_6, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_22_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + 
permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, view_default_2, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_1, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_1, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_2, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_22_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py new file mode 100644 index 0000000000000000000000000000000000000000..8c86ceb4c1f43d4d0d280c5f76ffb8744e64dbca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_23.py @@ -0,0 +1,225 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_9, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +view_default_10 = CallFunction(aten.view.default, fma_default, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) 
+view_default_12 = CallFunction(aten.view.default, view_default_11, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, view_default_2, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_4, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_4, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_inference = 
MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_8 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_6, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_8, permute_default_4) +view_default_9 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_9, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = 
CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_10 = CallFunction(aten.view.default, convert_element_type_default_5, Ignored()) +view_default_11 = CallFunction(aten.view.default, view_default_10, Ignored()) +convert_element_type_default_6 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +convert_element_type_default_7 = CallFunction(prims.convert_element_type.default, convert_element_type_default_6, Ignored()) +view_default_12 = CallFunction(aten.view.default, convert_element_type_default_7, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_12, permute_default_5) +view_default_13 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_13, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_12) +view_default_14 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_14, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_5, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_8) +view_default_15 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_15, Ignored()) +_sfdp_pattern_23_half_training = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_4 = CallFunction(aten.view.default, view_default_3, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_4, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default_2, Ignored(), True) +sub_Tensor = 
CallFunction(aten.sub.Tensor, convert_element_type_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_3, Ignored()) +view_default_5 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_6 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_5, view_default_6) +view_default_7 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +_sfdp_pattern_23_half_inference = MultiOutputPattern([view_default_7, + permute_default_1, + permute_default_3 +]) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py new file mode 100644 index 0000000000000000000000000000000000000000..463bc0265817ae1a2ed2f668ceaf7d434345caa2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_3.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, 
_users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = 
CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_3_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) 
+mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale_factor')) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_3_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_3_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py new file mode 100644 index 0000000000000000000000000000000000000000..823b4df9b8333be6df6d9debd1a2bbc2d2acb8af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_4.py @@ -0,0 +1,190 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, 
gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), 
Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, 
view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_4_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_4_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py new file mode 100644 index 0000000000000000000000000000000000000000..24da097b10d18f3b2930dcb70a6d432e595eab52 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -0,0 +1,178 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = 
CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = 
CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_3, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_5_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, 
sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_5_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc1ac285b6c03c45623c9151895a4de46a202eb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -0,0 +1,194 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = 
CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) +amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor_1, Ignored()) +view_default_3 = 
CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored(), _users=2) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_2) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_3, 
Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_6_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +view_default = CallFunction(aten.view.default, expand_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +view_default_4 = CallFunction(aten.view.default, expand_default_3, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_6_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py 
b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py new file mode 100644 index 0000000000000000000000000000000000000000..19849033ab0ab1032daa4c3dfdad1b8c963e9417 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) 
+neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, 
Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = 
CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_7_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, 
permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_7_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py new file mode 100644 index 0000000000000000000000000000000000000000..afe0d0488a6bcf36a1e9e9b335131f73be1cbb60 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -0,0 +1,205 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, 
mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = 
CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_8_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +view_default_8 = 
CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, view_default_9, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11 +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +expand_default = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) 
+_sfdp_pattern_8_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py new file mode 100644 index 0000000000000000000000000000000000000000..c23d1c43a72c84e9ac762390f07348f749061c58 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -0,0 +1,221 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = 
CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) +view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +view_default_8 = CallFunction(aten.view.default, fma_default, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) 
+view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, view_default_2, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, view_default_2, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor_1) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, mul_Tensor, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) 
+view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor_1) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_2, mul_Tensor_2) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +view_default_8 = CallFunction(aten.view.default, convert_element_type_default_4, Ignored(), _users=2) +permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) +permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_8 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_9 = CallFunction(aten.permute.default, permute_default_8, Ignored()) +permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +_sfdp_pattern_9_half_training = MultiOutputPattern([view_default_5, + permute_default_6, + permute_default_9, + permute_default_11, + None +]) + + +permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, 
clone_default, Ignored()) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +permute_default_2 = CallFunction(aten.permute.default, permute_default_1, Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default_2, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, view_default_2, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor_1 = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) +expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_9_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..9afb42d3ec7fc111bf55ebb6ef4f29b73bef5724 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/addmm_pattern.py @@ -0,0 +1,53 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +addmm_default = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha')) +mul_Scalar = CallFunction(aten.mul.Scalar, KeywordArg('tangents_1'), KeywordArg('beta')) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, mul_Scalar, Ignored(), True) +view_default = CallFunction(aten.view.default, sum_dim_IntList, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +mul_Scalar_1 = CallFunction(aten.mul.Scalar, mm_default, KeywordArg('alpha')) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mul_Scalar_2 = CallFunction(aten.mul.Scalar, mm_default_1, KeywordArg('alpha')) +addmm_pattern_training = MultiOutputPattern([addmm_default, + view_default, + mul_Scalar_1, + mul_Scalar_2, + None, + None +]) + + +addmm_pattern_inference = CallFunction(aten.addmm.default, KeywordArg('input'), KeywordArg('mat1'), KeywordArg('mat2'), beta=KeywordArg('beta'), alpha=KeywordArg('alpha'), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf26d0ea6c6a20683caa63d4e6b5732b571ded4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/bmm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. 
+# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +bmm_default = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, permute_default_1, KeywordArg('tangents_1')) +bmm_pattern_training = MultiOutputPattern([bmm_default, + bmm_default_1, + bmm_default_2 +]) + + +bmm_pattern_inference = CallFunction(aten.bmm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py new file mode 100644 index 0000000000000000000000000000000000000000..6c17ba76cbe15c7ecf7bce80c0b2d4e0a90fc63d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/serialized_patterns/mm_pattern.py @@ -0,0 +1,45 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +mm_default = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2')) +permute_default = CallFunction(aten.permute.default, KeywordArg('mat2'), Ignored()) +mm_default_1 = CallFunction(aten.mm.default, KeywordArg('tangents_1'), permute_default) +permute_default_1 = CallFunction(aten.permute.default, KeywordArg('mat1'), Ignored()) +mm_default_2 = CallFunction(aten.mm.default, permute_default_1, KeywordArg('tangents_1')) +mm_pattern_training = MultiOutputPattern([mm_default, + mm_default_1, + mm_default_2 +]) + + +mm_pattern_inference = CallFunction(aten.mm.default, KeywordArg('mat1'), KeywordArg('mat2'), _users=0) diff --git a/phivenv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py new file mode 100644 index 0000000000000000000000000000000000000000..9536a99fda1379782260496e78106151df27c1a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/fx_passes/split_cat.py @@ -0,0 +1,2960 @@ +# mypy: allow-untyped-defs +import itertools +import logging +import operator +from collections import defaultdict +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union +from typing_extensions import TypeAlias + +import torch +from torch._dynamo.utils import counters +from 
torch.fx.experimental.symbolic_shapes import free_symbols +from torch.utils._ordered_set import OrderedSet + +from ..pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethodVarArgs, + FailedMatch, + get_arg_value, + Ignored, + KeywordArg, + ListOf, + Match, + MatchContext, + MULTIPLE, + PatternExpr, + PatternMatcherPass, + register_graph_pattern, + RepeatedExpr, +) +from .group_batch_fusion import is_node_meta_valid, POST_GRAD_FUSIONS, PRE_GRAD_FUSIONS + + +log = logging.getLogger(__name__) + +_Arguments: TypeAlias = tuple[torch.fx.node.Argument, ...] +_TransformParam: TypeAlias = tuple[ + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], + Optional[_Arguments], +] +_Range: TypeAlias = tuple[int, int] + + +PRE_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} +POST_GRAD_PATTERNS: dict[str, PatternMatcherPass] = {} + +pre_grad_pass_names = [ + "normalization_pass", + "remove_split_with_size_one_pass", + "merge_getitem_cat_pass", + "merge_stack_tahn_unbind_pass", + "merge_splits_pass", + "mutate_cat_pass", + "split_cat_pass", + "unbind_stack_pass", + "split_cat_to_slices_pass", + "unbind_cat_to_view_pass", + "split_stack_to_cats_pass", + "unbind_stack_to_slices_pass", + "move_reshape_out_of_split_stack_pass", +] + +post_grad_pass_names = [ + "normalization_aten_pass", + "decompose_mm_pass", + "unbind_stack_aten_pass", + "shape_padding_multiplier", + "pad_aten_mm_pass", + "split_cat_aten_pass", + "select_cat_aten_pass", + "move_view_after_cat_aten_pass", +] + +for pass_name in pre_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in PRE_GRAD_FUSIONS: + continue + PRE_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + +for pass_name in post_grad_pass_names: + # exclude all passes from the group batch fusion + # they do not use pattern matcher + if pass_name in POST_GRAD_FUSIONS: + continue + POST_GRAD_PATTERNS[pass_name] = PatternMatcherPass( + pass_name=pass_name, + ) + + +def construct_pattern_matcher_pass(pass_name: str): + """ + Return the specific pattern_matcher_pass given the pass name. + """ + if pass_name in PRE_GRAD_PATTERNS: + return PRE_GRAD_PATTERNS[pass_name] + else: + return POST_GRAD_PATTERNS[pass_name] + + +def _get_split_args_default(split_node): + input_kwarg = "tensor" + split_size_kwarg = "split_size_or_sections" + dim_kwarg = "dim" + default_dim_value = 0 + if split_node.op == "call_method": + split_size_kwarg = "split_size" + return ( + get_arg_value(split_node, 0, input_kwarg), + get_arg_value(split_node, 1, split_size_kwarg), + get_arg_value(split_node, 2, dim_kwarg) or default_dim_value, + ) + + +def _get_dim(node: Any): + assert isinstance(node, torch.fx.Node) + if "dim" in node.kwargs: + assert isinstance(node.kwargs["dim"], int) + return node.kwargs["dim"] + if node.target == torch.unbind: + if len(node.args) == 2: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + if node.target == torch.split: + if len(node.args) == 3: + assert isinstance(node.args[-1], int) + return node.args[-1] + return 0 # defaults to dim=0 + raise AssertionError( + f"Can't extract `dim` from {node.target} {node.args} {node.kwargs}" + ) + + +# noqa: W605 +# ############The pattern to be optimized is######### +# unbind (dim=0) +# / ... 
\ +# getitem getitem -> user=1 +# | | +# split split -> dim=1, user=1, split_section_size=1 +# | | +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + +# ################After transformation############# +# unbind (dim=0) +# / ... \ +# getitem getitem -> user=1 +# \ / +# cat (dim=1) -> user=1 +# | + + +def normalize_split_base( + match: Match, + _get_split_args: Callable[ + [torch.fx.Node], tuple[Optional[torch.fx.Node], Optional[Any], Optional[int]] + ], +): + """ + Normalize split with split_size into split_with_sizes, so that we only deal with one type of split in + subsequent optimizations + """ + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["example_value"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_split_default(match: Match, *args, **kwargs): + return normalize_split_base(match, _get_split_args_default) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.split, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("split", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("remove_split_with_size_one_pass"), +) +def remove_split_with_size_one(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("example value absent for node: %s", split_node) + return + assert isinstance(split_node.meta["example_value"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["example_value"]] + + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
+ return + # remove the dummy split whose split sections size is one + # theoretically nodes with no users should be removed, but we have seen the corner case + # thus we add its users check to walk around the StopIteration error. + if len(split_sections) == 1 and len(split_node.users.keys()) > 0: + # find the grand children of the split_node + next_users = find_next_users(split_node) + user = next(iter(split_node.users.keys())) + # replace the users of grand child node with the input node + for next_user in next_users: + next_user.replace_input_with(user, split_input) + # erase the split node and its child + graph.erase_node(user) + graph.erase_node(split_node) + counters["inductor"]["remove_split_with_size_one_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.unbind, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallMethodVarArgs("unbind", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_unbind_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + input = get_arg_value(node, 0, "input") + dim = get_arg_value(node, 1, "dim") + if dim is None: + axis = node.kwargs.get("axis") + if axis is not None: + dim = axis + else: + dim = 0 + if input is None: + log.debug("couldn't find unbind args") + return + if not is_node_meta_valid(input): + log.debug("example value absent for node: %s", input) + return + ndim = input.meta["example_value"].ndim + if dim < 0: # Normalize unbind dim + dim += ndim + with graph.inserting_after(node): + new_node = graph.call_function( + torch.unbind, + args=(input,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.cat, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_cat_default(match: Match, *args, **kwargs): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = cat_node.meta["example_value"].dim() + + def is_empty_tensor(x): + # special case where torch.cat supports cat'ing with an empty tensor + x_shape = x.meta["example_value"].shape + return len(x_shape) == 1 and guard_size_oblivious(x_shape[0] == 0) + + assert all( + ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors + ) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + new_args = (tensors,) + new_kwargs = {"dim": cat_dim} + if ( + cat_node.args == new_args + and cat_node.kwargs == new_kwargs + and cat_node.op == "call_function" + ): + return + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + kwargs=new_kwargs, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + 
graph.erase_node(cat_node) + counters["inductor"]["normalization_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.stack, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_stack_default(match: Match, *args, **kwargs): + node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(node, 0, "tensors") + dim = get_arg_value(node, 1, "dim") or 0 + if tensors is None or dim is None: + log.debug("couldn't find stack args") + return + assert isinstance(tensors, (list, tuple)) + + # A bug in pytorch, some nodes miss the example_value metadata + for tensor in itertools.chain([node], tensors): + if not is_node_meta_valid(tensor): + log.debug("example value absent for node: %s", tensor) + return + + ndim = node.meta["example_value"].dim() + if dim < 0: # Normalize dim + dim += ndim + + with graph.inserting_after(node): + new_node = graph.call_function( + node.target, # type: ignore[arg-type] + args=(tensors,), + kwargs={"dim": dim}, + ) + node.replace_all_uses_with(new_node) + new_node.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["normalization_pass"] += 1 + + +def find_next_users(split_node: torch.fx.Node) -> list[torch.fx.Node]: + next_users = [] + for getitem_node in split_node.users.keys(): + for getitem_user in getitem_node.users.keys(): + if getitem_user not in next_users: + next_users.append(getitem_user) + return next_users + + +@register_graph_pattern( + CallMethodVarArgs("squeeze", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_squeeze_default(match: Match, *args, **kwargs): + squeeze_node = match.nodes[0] + squeeze_input = get_arg_value(squeeze_node, 0) + + if "dim" in squeeze_node.kwargs: + assert len(squeeze_node.args) == 1 + dim = squeeze_node.kwargs["dim"] + elif len(squeeze_node.args) == 1: + # squeeze(Tensor) + dim = None + elif len(squeeze_node.args) == 2: + # squeeze(Tensor self, int dim) + # squeeze(Tensor self, int[] dim) + dim = squeeze_node.args[1] + else: + # squeeze(Tensor self, int[] dim) (called with varargs) + dim = squeeze_node.args[1:] + + if isinstance(dim, Sequence) and len(dim) == 1: + dim = dim[0] + + with match.graph.inserting_after(squeeze_node): + if dim is None: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,) + ) + else: + new_squeeze_node = match.graph.call_function( + torch.squeeze, args=(squeeze_input,), kwargs={"dim": dim} + ) + squeeze_node.replace_all_uses_with(new_squeeze_node) + new_squeeze_node.meta.update(squeeze_node.meta) + match.graph.erase_node(squeeze_node) + + +@register_graph_pattern( + CallMethodVarArgs("reshape", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_reshape_default(match: Match, *args, **kwargs): + reshape_node = match.nodes[0] + if not is_node_meta_valid(reshape_node): + log.debug("example value absent for node: %s", reshape_node) + return + reshape_input = get_arg_value(reshape_node, 0) + + if free_symbols(reshape_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", reshape_node) + return + + with match.graph.inserting_after(reshape_node): + new_reshape_node = match.graph.call_function( + torch.reshape, + args=(reshape_input, tuple(reshape_node.meta["example_value"].shape)), + ) + reshape_node.replace_all_uses_with(new_reshape_node) + new_reshape_node.meta.update(reshape_node.meta) + match.graph.erase_node(reshape_node) + + +@register_graph_pattern( + 
CallMethodVarArgs("clamp", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallFunctionVarArgs(torch.clamp, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_clamp_default(match: Match, *args, **kwargs): + clamp_node = match.nodes[0] + if not is_node_meta_valid(clamp_node): + log.debug("example value absent for node: %s", clamp_node) + return + + if free_symbols(clamp_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", clamp_node) + return + if len(clamp_node.args) > 1: + args = (get_arg_value(clamp_node, 0),) + kwargs = { + "min": get_arg_value(clamp_node, 1, kwarg_name="min"), + "max": get_arg_value(clamp_node, 2, kwarg_name="max"), + } + else: + args = clamp_node.args + kwargs = clamp_node.kwargs + with match.graph.inserting_after(clamp_node): + new_clamp_node = match.graph.call_function( + torch.clamp, + args=args, + kwargs=kwargs, + ) + clamp_node.replace_all_uses_with(new_clamp_node) + new_clamp_node.meta.update(clamp_node.meta) + match.graph.erase_node(clamp_node) + + +@register_graph_pattern( + CallMethodVarArgs("detach", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_detach_default(match: Match, *args, **kwargs): + detach_node = match.nodes[0] + if not is_node_meta_valid(detach_node): + log.debug("example value absent for node: %s", detach_node) + return + + if free_symbols(detach_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", detach_node) + return + + with match.graph.inserting_after(detach_node): + new_detach_node = match.graph.call_function( + torch.detach, + args=detach_node.args, + ) + detach_node.replace_all_uses_with(new_detach_node) + new_detach_node.meta.update(detach_node.meta) + match.graph.erase_node(detach_node) + + +class TorchSplit(CallFunction): + """ + Matches a call to torch.split if it is in a normalized form. Ensures that all users of + splits are unique getitems. + """ + + def __init__(self, arg, sizes, func=torch.split) -> None: + # using KeywordArg("dim") for `dim` checks they all match + super().__init__(func, arg, sizes, _users=MULTIPLE, dim=KeywordArg("dim")) + + def _match(self, node: torch.fx.Node, ctx: MatchContext): + m = super()._match(node, ctx) + if not m: + return m + split_sections = node.args[1] + if not isinstance(split_sections, (list, tuple)): + return FailedMatch("split not normalized") + # check users are all unique getitems + seen_idxs = OrderedSet[int]() + for user in node.users: + if not CallFunction(operator.getitem, Arg(), Arg()).match(user): + # This should ideally never happen. 
Split user should always be a getitem + return FailedMatch(f"user of split not a getitem: {user}") + if not isinstance(user.args[1], int): + return FailedMatch("only integer getitems are handled") + if user.args[1] in seen_idxs: + return FailedMatch(f"duplicate getitem {user.args[1]}") + if user.args[-1] < 0: # type: ignore[operator] + # This shouldn't ideally happen as dynamo normalizes indexes to positive + return FailedMatch("negative index") + seen_idxs.add(user.args[1]) + return m + + +@register_graph_pattern( + TorchSplit( + CallFunction( + operator.getitem, + TorchSplit( + KeywordArg("first_split_input"), + KeywordArg("first_split_sections"), + ), + Ignored(), + ), + KeywordArg("next_split_sections"), + ), + pass_dict=construct_pattern_matcher_pass("merge_splits_pass"), +) +def merge_splits( + match: Match, + first_split_input: torch.fx.Node, + first_split_sections: list[int], + next_split_sections: list[int], + # Note: dim is implicitly passed by TorchSplit, as it internally uses a pattern with dim + dim: int, +): + node = match.output_node() + # it is possible that the split has no users, + # we check the corner case and skip the pattern + if len(node.users.keys()) == 0: + return + graph = match.graph + first_split = node.args[0].args[0] # type: ignore[union-attr] + next_split_index = node.args[0].args[1] # type: ignore[union-attr] + + new_split_sections = list(first_split_sections) + new_split_sections[next_split_index : next_split_index + 1] = next_split_sections # type: ignore[operator, misc] + + first_split_dim = _get_dim(first_split) + + to_remove = [] + + with graph.inserting_before(first_split): # type: ignore[arg-type] + # Add the new split node + new_split = graph.call_function( + torch.split, + args=(first_split_input, new_split_sections), + kwargs={"dim": first_split_dim}, + ) + if is_node_meta_valid(first_split_input): + new_split.meta["example_value"] = torch.split( + first_split_input.meta["example_value"], + new_split_sections, + dim=first_split_dim, + ) + first_split_num_to_user = { + user.args[1]: user + for user in first_split.users.keys() # type: ignore[union-attr] + } + + new_split_num = 0 + for split_num in range(len(first_split_sections)): + if split_num not in first_split_num_to_user: + new_split_num += 1 + continue + old_getitem = first_split_num_to_user[split_num] + if split_num != next_split_index: + old_getitem.update_arg(0, new_split) + old_getitem.update_arg(1, new_split_num) + new_split_num += 1 + else: + next_split_num_to_user = { + user.args[1]: user for user in node.users.keys() + } + # It is not necessary all getitems from the split node are used. + for next_split_num in range(len(next_split_sections)): + with graph.inserting_after(new_split): + new_getitem = graph.call_function( + operator.getitem, args=(new_split, new_split_num) + ) + new_split_num += 1 + if next_split_num not in next_split_num_to_user: + continue + next_getitem = next_split_num_to_user[next_split_num] + new_getitem.meta.update(next_getitem.meta) + next_getitem.replace_all_uses_with(new_getitem) + to_remove.append(next_getitem) + to_remove.append(node) + to_remove.append(old_getitem) + + to_remove.append(first_split) # type: ignore[arg-type] + for node in to_remove: + graph.erase_node(node) + + counters["inductor"]["merge_splits_pass"] += 1 + + +class SplitCatSimplifier: + """ + Helper class to simplify split-cat pattern. In simple cases, both split and cat node can be removed in a "split->cat" + pattern. 
However, there are various cases where they can't and we need to simplify split/ add transforms before cat. + Some such cases are: + 1. Final node has additional args (not coming from the initial split) + 2. Shuffling of args between split/cat + 3. Some final nodes are non-(cat/stack) + 4. Split-dim != cat-dim (but equal split) + + Note that any combination of the above cases can happen. + + To deal with 1, 2, & 3 - we iterate over all users of split. And figure out common "ranges" that can be merged. + Then, we simplify the split accordingly. In the best case, split can be entirely removed. + + To deal with 4, we add some transformations (unflatten + movedim) (See `get_transform_params`). + + Finally, depending on final node being cat or stack, unsqueeze/flatten needs to be added. + + """ + + def simplify( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + ): + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # Gather inputs of the next users. When inputs come from `split_node`, they are instead represented by + # a tuple indicating the split ranges. See `get_user_input_list` for more details + user_inputs_list = self.get_user_input_list(split_node, next_users) + # Simplify the split_sections based on user_inputs_list. In simpler cases, len(simplified_split_ranges) == 1 and + # we can simply replace the split node. Otherwise, we simplify it. + simplified_split_ranges = self.get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges: # Simplification not possible + return + transform_params_list = self.get_transform_params( + split_node, next_users, user_inputs_list + ) + if not transform_params_list: + return + + # Start actual replacement + user_inputs_list_new = self.replace_split( + graph, split_node, split_sections, user_inputs_list, simplified_split_ranges + ) + self.replace_cat( + graph, + split_node, + next_users, + user_inputs_list_new, + transform_params_list, # type: ignore[arg-type] + ) + self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] + counters["inductor"]["unbind_stack_pass"] += 1 + + def get_user_input_list( + self, split_node: torch.fx.Node, next_users: list[torch.fx.Node] + ) -> list[list[Union[torch.fx.Node, _Range]]]: + """ + Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner + list represents the inputs to that particular node. 
This list can either contain + - a tuple representing the ranges of get_items that should go into the cat (closed interval) + - torch.fx.Node representing "other" inputs (which are not coming from our split) + """ + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]] = [] + for user in next_users: + if user.target in (torch.cat, torch.stack): + user_inputs_list.append(self.get_merged_user_inputs(split_node, user)) + else: + user_inputs_list.append(self.get_non_cat_node_input(split_node, user)) # type: ignore[arg-type] + return user_inputs_list + + def get_merged_user_inputs( + self, split_node: torch.fx.Node, cat_node: torch.fx.Node + ) -> list[Union[torch.fx.Node, _Range]]: + user_inputs = get_arg_value(cat_node, 0, "tensors") + simplified_user_inputs = [] + split_users = OrderedSet(split_node.users.keys()) + for user_input in user_inputs: + if user_input not in split_users: + simplified_user_inputs.append(user_input) + else: + # Add which "getitem" cat depends on + simplified_user_inputs.append(user_input.args[1]) + return self.merge_consecutive_inputs(simplified_user_inputs) + + def get_non_cat_node_input( + self, split_node: torch.fx.Node, node: torch.fx.Node + ) -> list[_Range]: + """ + Get input for a non cat node in the same format as `get_merged_user_inputs` + """ + node_input = [] + split_users = OrderedSet(split_node.users.keys()) + for node_arg in node.all_input_nodes: + if node_arg in split_users: + getitem_num = get_arg_value(node_arg, 1) + node_input.append((getitem_num, getitem_num)) + return node_input + + def merge_consecutive_inputs( + self, inputs: list[Union[torch.fx.Node, int]] + ) -> list[Union[torch.fx.Node, _Range]]: + """ + Merge consecutive inputs going into a user node. + + For e.g. + [arg0, 0, 1, 2, arg1] -> [arg0, (0, 2), arg1] + """ + merged_ranges = [] + cur_range = None + for input_ in inputs: + if isinstance(input_, int): + if not cur_range: + cur_range = [input_, input_] + elif input_ == cur_range[1] + 1: + cur_range[1] += 1 + else: + merged_ranges.append(tuple(cur_range)) + cur_range = [input_, input_] + else: + if cur_range: + merged_ranges.append(tuple(cur_range)) + cur_range = None + merged_ranges.append(input_) # type: ignore[arg-type] + if cur_range: + merged_ranges.append(tuple(cur_range)) + return merged_ranges # type: ignore[return-value] + + def get_simplified_split_ranges( + self, + split_sections, + next_users, + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[_Range]]: + ranges = OrderedSet[Any]() + for user_inputs in user_inputs_list: + ranges.update(u for u in user_inputs if isinstance(u, tuple)) + + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + split_ranges = sorted( + [(cumulative_sizes[r[0]], cumulative_sizes[r[1] + 1]) for r in ranges] + ) + + if not self.has_non_overlapping_ranges( + split_ranges, + ): # This need not be a strict condition + # However, we keep it now for simplicity. 
+ return None + split_ranges = self.fill_gaps(split_ranges, 0, cumulative_sizes[-1]) + if len(split_sections) == len(split_ranges): # Simplification not possible + return None + counters["inductor"]["scmerge_split_sections_removed"] = len( + split_sections + ) - len(split_ranges) + return split_ranges + + def has_non_overlapping_ranges(self, ranges: list[_Range]) -> bool: + for range_, next_range in zip(ranges, ranges[1:]): + if range_[1] > next_range[0]: + return False + return True + + def fill_gaps(self, ranges: list[_Range], min_: int, max_: int) -> list[_Range]: + cur = min_ + filled_ranges = [] + for a, b in ranges: + if cur < a: + filled_ranges.append((cur, a)) + filled_ranges.append((a, b)) + cur = b + if filled_ranges[-1][1] < max_: + filled_ranges.append((filled_ranges[-1][1], max_)) + return filled_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[list[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. + + We replace a split node with an unflatten followed by a movedim + """ + split_dim = _get_dim(split_node) + split_sections = split_node.args[1] + transform_params_list: list[list[_TransformParam]] = [] + + for user_node, user_inputs in zip(next_users, user_inputs_list): + if user_node.target not in (torch.cat, torch.stack): + transform_params_list.append([]) + continue + + cat_dim = get_arg_value(user_node, 1, "dim") + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if split_dim == cat_dim and user_node.target == torch.cat: + # No transform needed + transform_params.append((None, None, None, None)) + elif isinstance(user_input, tuple): # Split being simplified + # Verify equal split + subset_split_sections = split_sections[ # type: ignore[index] + user_input[0] : user_input[1] + + 1 # type: ignore[index] + ] + # All sections should be equal + if len(OrderedSet(subset_split_sections)) != 1: # type: ignore[arg-type] + return None + + num_splits = len(subset_split_sections) # type: ignore[arg-type] + unflatten_params = (split_dim, (num_splits, -1)) + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + transform_params.append( + (unflatten_params, movedim_params, None, None) + ) + elif ( + user_node.target == torch.stack or split_dim != cat_dim + ): # We need to unsqueeze inputs not coming through split + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-split inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + def replace_split( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + split_sections: list[int], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + split_ranges: list[_Range], + ) -> list[list[torch.fx.Node]]: + """ + Replace the split node. It can either remove the split node if len(split_ranges) == 1, or simplify it + into a split with lesser sections if len(split_ranges) > 1. + + Returns the new `user_inputs_list`, with tuples replaced with new getitems from the newer split node. 
+ """ + split_input = split_node.args[0] + split_dim = _get_dim(split_node) + if len(split_ranges) == 1: # We can completely eliminate the split node + split_items = [split_input] + else: + with graph.inserting_after(split_node): + new_split = graph.call_function( + torch.split, + args=( + split_input, + [r[1] - r[0] for r in split_ranges], + ), + kwargs={"dim": split_dim}, + ) + if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] + new_split.meta["example_value"] = torch.split( + split_input.meta["example_value"], # type: ignore[union-attr] + [r[1] - r[0] for r in split_ranges], + dim=split_dim, + ) + counters["inductor"]["scmerge_split_added"] += 1 + split_items = [] + with graph.inserting_after(new_split): + for i in range(len(split_ranges)): + getitem = graph.call_function(operator.getitem, args=(new_split, i)) + if is_node_meta_valid(new_split): + getitem.meta["example_value"] = new_split.meta["example_value"][ + i + ] + split_items.append(getitem) + # Now assign the right getitem to the right input + cumulative_sizes = [0] + torch.cumsum(torch.tensor(split_sections), 0).tolist() + new_user_inputs_list = [] + for user_inputs in user_inputs_list: + new_user_inputs = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # Find the correct new getitem (present in split_items) + new_user_inputs.append( + split_items[ + split_ranges.index( + ( + cumulative_sizes[user_input[0]], + cumulative_sizes[user_input[1] + 1], + ) + ) + ] + ) + else: + new_user_inputs.append(user_input) + new_user_inputs_list.append(new_user_inputs) + return new_user_inputs_list # type: ignore[return-value] + + def replace_cat( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list_new, + transform_params_list: list[list[_TransformParam]], + ): + split_dim = _get_dim(split_node) + split_users = split_node.users.keys() + new_cats = [] + for user_node, user_inputs_new, transform_params in zip( + next_users, user_inputs_list_new, transform_params_list + ): + if user_node.target not in (torch.cat, torch.stack): + # Change the args and kwargs of non-cat/stack nodes. 
Replace old getitems (belonging to + # the original split node) with the newer getitems + next_cat_input = 0 + for input_node in user_node.all_input_nodes: + if input_node in split_users: + user_node.replace_input_with( + input_node, user_inputs_new[next_cat_input] + ) + next_cat_input += 1 + continue + + # Handle cat/stack user nodes + cat_dim = get_arg_value(user_node, 1, "dim") + user_inputs_new_transformed, user_inputs_new_transformed_meta = [], [] + # For `unsqueeze` transform, we will combine consecutive inputs with the same unsqueeze params, and stack them + to_stack, to_stack_meta = [], [] + stack_dim = None + with graph.inserting_before(user_node): + for user_input_new, transform_param in zip( + user_inputs_new, transform_params + ): + if not is_node_meta_valid(user_input_new): + log.debug("example value absent for node: %s", user_input_new) + return + # Apply transforms + ( + unflatten_params, + movedim_params, + unsqueeze_params, + flatten_params, + ) = transform_param + if unsqueeze_params and ( + stack_dim is None or stack_dim == unsqueeze_params[0] + ): + to_stack.append(user_input_new) + to_stack_meta.append(user_input_new.meta["example_value"]) + stack_dim = unsqueeze_params[0] + continue + elif to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + to_stack, to_stack_meta = [], [] + stack_dim = None + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + if unsqueeze_params: + to_stack.append(user_input_new) + stack_dim = unsqueeze_params[0] + to_stack_meta.append(user_input_new.meta["example_value"]) + continue + + if unflatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.unflatten, args=(user_input_new, *unflatten_params) + ) + user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *unflatten_params, # type: ignore[arg-type] + ) + if movedim_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.movedim, args=(user_input_new, *movedim_params) + ) + user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type] + user_input_new_meta, # type: ignore[arg-type] + *movedim_params, # type: ignore[arg-type] + ) + if flatten_params: + user_input_new_meta = user_input_new.meta["example_value"] + user_input_new = graph.call_function( + torch.flatten, args=(user_input_new, *flatten_params) + ) + user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type] + user_input_new_meta, + *flatten_params, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(user_input_new) + user_inputs_new_transformed_meta.append( + user_input_new.meta["example_value"] + ) + if to_stack: + stacked_input = graph.call_function( + torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} + ) + stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type] + to_stack_meta, + dim=stack_dim, # type: ignore[arg-type] + ) + user_inputs_new_transformed.append(stacked_input) + user_inputs_new_transformed_meta.append( + stacked_input.meta["example_value"] + ) + + with graph.inserting_after(user_node): + if len(user_inputs_new_transformed) > 1: + new_cat_node = 
graph.call_function( + torch.cat, + args=(user_inputs_new_transformed,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + user_inputs_new_transformed_meta, + dim=cat_dim, + ) + counters["inductor"]["scmerge_cat_added"] += 1 + else: + new_cat_node = user_inputs_new_transformed[-1] + new_cat_node.meta["example_value"] = ( + user_inputs_new_transformed_meta[-1] + ) + + if ( + user_node.target == torch.cat + and split_dim != cat_dim + and split_node.target == torch.split + ): + with graph.inserting_after(new_cat_node): + new_cat_node_meta = new_cat_node.meta["example_value"] + new_cat_node = graph.call_function( + torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) + ) + new_cat_node.meta["example_value"] = torch.flatten( + new_cat_node_meta, + cat_dim, + cat_dim + 1, + ) + user_node.replace_all_uses_with(new_cat_node) + new_cats.append(new_cat_node) + + def erase_old_nodes( + self, + graph: torch.fx.Graph, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + ): + to_remove = [split_node] + counters["inductor"]["scmerge_split_removed"] += 1 + to_remove.extend(split_node.users.keys()) + for next_user in next_users: + if next_user.target not in (torch.cat, torch.stack): + continue + counters["inductor"]["scmerge_cat_removed"] += 1 + to_remove.append(next_user) + for node in reversed(to_remove): + if len(node.users.keys()) == 0: + graph.erase_node(node) + + +class UnbindCatRemover(SplitCatSimplifier): + """ + Helper class to merge Unbind->Cat/Stack. Many of the cases are similar to SplitCatSimplifier. + + Unbind can't be simplified like splits. So, we can only remove the unbind node. Other than this, + other cases like multiple users, additional args, dim mismatch are similar to `SplitCatSimplifier`, + hence we extend that class. + """ + + def remove_unbind( + self, + graph: torch.fx.Graph, + unbind_node: torch.fx.Node, + ): + if not is_node_meta_valid(unbind_node): + return + # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node + # before we do the unbind remove, otherwise it will hit the error when we unbind part of them + getitem_indices = [ + getitem_node.args[1] for getitem_node in unbind_node.users.keys() + ] + if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] + getitem_indices + ) != len(unbind_node.meta["example_value"]): + return + num_unbind = len(getitem_indices) + split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] + + super().simplify(graph, unbind_node, split_sections) + + def get_simplified_split_ranges( + self, + split_sections: list[int], + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[_Range]]: + simplified_split_ranges = super().get_simplified_split_ranges( + split_sections, next_users, user_inputs_list + ) + if not simplified_split_ranges or len(simplified_split_ranges) != 1: + return None + return simplified_split_ranges + + def get_transform_params( + self, + split_node: torch.fx.Node, + next_users: list[torch.fx.Node], + user_inputs_list: list[list[Union[torch.fx.Node, _Range]]], + ) -> Optional[list[list[_TransformParam]]]: + """ + Figure out what transforms are needed for each input to each cat node. 
+ + Here is the rough transforms we apply: + + x -> unbind -> stack => x -> movedim + + x -> unbind -> cat => x -> movedim -> flatten + + When cat/stack nodes have additional args: + + addn ---| addn -> unsqueeze ---| + x -> unbind -> stack => x -> movedim -> cat + + addn ---| addn ---| + x -> unbind -> cat => x -> movedim -> flatten -> cat + + (Note application of these depends on the dims as well) + + + """ + split_dim = _get_dim(split_node) + transform_params_list: list[list[_TransformParam]] = [] + for user_node, user_inputs in zip(next_users, user_inputs_list): + cat_dim = get_arg_value(user_node, 1, "dim") or 0 + transform_params: list[_TransformParam] = [] + for user_input in user_inputs: + if isinstance(user_input, tuple): + # User input is coming from unbind + movedim_params = ( + (split_dim, cat_dim) if split_dim != cat_dim else None + ) + flatten_params = None + if user_node.target == torch.cat: + flatten_params = (cat_dim, cat_dim + 1) + transform_params.append( + (None, movedim_params, None, flatten_params) + ) + elif ( + user_node.target == torch.stack + ): # We need to unsqueeze inputs not coming through unbind into cat + transform_params.append((None, None, (cat_dim,), None)) + else: # Non-unbind inputs + transform_params.append((None, None, None, None)) + transform_params_list.append(transform_params) + return transform_params_list + + +class GetItem(CallFunction): + def __init__(self, arg, index, _users=1) -> None: + super().__init__(operator.getitem, arg, index, _users=_users) + + def find_anchor_nodes(self, ctx: MatchContext, searched: OrderedSet[torch.fx.Node]): + # We generally match GetItem with arg being an Arg(). So, we never return the anchor + # nodes as the stored node in ctx.pattern_to_node is returned. Here we override find_anchor_nodes + # to not use ctx.pattern_to_node + for pattern in self.flat_args_kwargs[0]: + if isinstance(pattern, PatternExpr): + for other_node in pattern.find_anchor_nodes(ctx, searched): + if not isinstance(other_node, torch.fx.Node): + continue + for node in other_node.users: + if node not in searched: + if self._match_fns(node): + yield node + searched.add(node) + + +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + KeywordArg("dim"), + _users=MULTIPLE, + ), + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + RepeatedExpr( + CallFunction( + torch.squeeze, + GetItem( + TorchSplit( + KeywordArg("split_input"), + KeywordArg("split_sizes"), + ), + Ignored(), + ), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ) + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def merge_split_squeeze( + match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int +): + graph = match.graph + split = next(node for node in match.nodes if node.target == torch.split) + if not all(s == 1 for s in split_sizes): + return + if isinstance(dim, Sequence): + return + next_users = find_next_users(split) + if not all(node.target == torch.squeeze for node in next_users): + return + with graph.inserting_before(match.output_node()): + unbind = graph.call_function( + torch.unbind, args=(split_input,), kwargs={"dim": dim} + ) + if is_node_meta_valid(split_input): + unbind.meta["example_value"] = torch.unbind( + split_input.meta["example_value"], dim=dim + ) + for item_index, getitem_node in sorted( + [ + (getitem_node.args[1], getitem_node) + for getitem_node in 
split.users.keys() + ] + ): + squeeze = next(iter(getitem_node.users.keys())) + new_get_item = graph.call_function( + operator.getitem, args=(unbind, item_index) + ) + squeeze.replace_all_uses_with(new_get_item) + new_get_item.meta.update(squeeze.meta) + graph.erase_node(squeeze) + graph.erase_node(getitem_node) + graph.erase_node(split) + counters["inductor"]["split_cat_pass"] += 1 + + +getitem_unbind = ListOf( + GetItem( + CallFunction( + torch.unbind, + KeywordArg("unbind_input"), + dim=KeywordArg("dim"), + _users=MULTIPLE, + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction([torch.stack, torch.cat], getitem_unbind, Ignored(), _users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], tensors=getitem_unbind, dim=Ignored(), _users=MULTIPLE + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), +) +def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + UnbindCatRemover().remove_unbind(match.graph, unbind_node) + + +getitem_split = ListOf( + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +reshape_getitem_split = ListOf( + CallFunction( + torch.reshape, + CallFunction( + operator.getitem, + TorchSplit( + Ignored(), + KeywordArg("split_sections"), + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + tensors=getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +@register_graph_pattern( + CallFunction( + [torch.stack, torch.cat], + getitem_split, + Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_pass"), +) +def simplify_split_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + SplitCatSimplifier().simplify(match.graph, split_node, split_sections) + + +# noqa: W605 +# ############pattern to be optimized is######### + +# split_node(dim=1) +# / \ ... / \ +# getitem getitem getitem getitem -> user=1 +# \ / \ / +# cat (user=mul, dim=1) cat(user=mul, dim=1) +# | \ | \ + +# ################after transformation############# + +# split_node(dim=1) +# / ... 
\ +# getitem getitem +# | \ | \ + + +def has_same_parent_node(node: torch.fx.Node): + # the input nodes of the node should come from the same parent + prev_node = None + for getitem in node.args[0]: # type: ignore[union-attr] + if getitem.target != operator.getitem: # type: ignore[union-attr] + return False + if prev_node is None: + prev_node = getitem.args[0] # type: ignore[union-attr] + else: + if getitem.args[0] != prev_node: # type: ignore[union-attr] + return False + return True + + +def remove_zeros(split_sections: list[int]): + """ + Remove zeros from the list and get the index mapping dict from getitem + in split node to getitem in new split node + """ + new_split_sections, index_mapping = [], {} + idx = 0 + for i in range(len(split_sections)): + if split_sections[i] > 0: + new_split_sections.append(split_sections[i]) + index_mapping[i] = idx + idx += 1 + + return new_split_sections, index_mapping + + +def is_sorted_and_consecutive(arr: list[int]) -> bool: + # check if the array is sorted + if arr == sorted(arr): + # check if the differences between adjacent elements are all 1 + return all(x[1] - x[0] == 1 for x in zip(arr, arr[1:])) + else: + return False + + +def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -> int: + """ + Calculate the fused tensor size in the indices + """ + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + return fused_tensor_size + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("merge_getitem_cat_pass"), +) +def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + # 'immutable_list' object does not support mutation. 
Create a new copy of it + split_sections = list(split_sections) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") + # check the all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + # check all getitem only has one single user + if ( + split_dim != cat_dim + or not has_same_parent_node(cat_user) + or not all(len(arg.users) == 1 for arg in cat_user.args[0]) # type: ignore[union-attr] + ): + continue + # find the index of getitems to be cated/stacked + # type: ignore[union-attr] + indices = [arg.args[1] for arg in cat_user.args[0]] # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # update the arg of cat user, only keep the first getitem + cat_user.update_arg(0, cat_user.args[0][0]) # type: ignore[index] + # calculate the fused tensor sizes in the indices + fused_tensor_size = 0 + for i in range(len(split_node.args[1])): # type: ignore[arg-type] + if i in indices: + fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] + # update the split sections + split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index] + split_node, + indices, # type: ignore[arg-type] + ) + # padding others with zeros to keep the same dict size + for i in indices[1:]: + split_sections[i] = 0 # type: ignore[index] + # remove all unused indexes in the split_node + new_split_sections, index_mapping = remove_zeros(split_sections) + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.split, + args=(split_input, split_sections), + kwargs={"dim": split_dim}, + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + # remove all unused getitem nodes + to_remove = [cat_user] + # dictionary keys changed during iteration + new_split_getitem_nodes = list(new_split_node.users.keys()) + for getitem_node in new_split_getitem_nodes: + if getitem_node.args[1] in indices[1:]: + to_remove.append(getitem_node) + # update meta data of getitem + elif getitem_node.args[1] == indices[0]: + cat_user.replace_all_uses_with(getitem_node) + getitem_node.meta.update(cat_user.meta) + else: + # update getitem index for new split node + getitem_node.update_arg(1, index_mapping[getitem_node.args[1]]) + graph.erase_node(split_node) + for getitem_node in to_remove: + graph.erase_node(getitem_node) + # update the split sections of new split node + new_split_node.update_arg(1, new_split_sections) + split_node = new_split_node + split_sections = new_split_sections + + counters["inductor"]["merge_getitem_cat_pass"] += 1 + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op /cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) -> -> user=multiple +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ \ / \ +# other_op + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("mutate_cat_pass"), +) +def mutate_cat_node(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + graph = match.graph + split_node = next(node for node in match.nodes if node.target == torch.split) + _split_input, _split_size, split_dim = _get_split_args_default(split_node) + # if the cat and split have different dims, return + # Find the next users (i.e. users after the getitem) + next_users = find_next_users(split_node) + for cat_user in next_users: + if cat_user.target == torch.cat: + cat_dim = get_arg_value(cat_user, 1, "dim") or 0 + # check that all getitems in the cat_user from the same node + # check the input of the cat has all getitem from the split + if split_dim != cat_dim or not has_same_parent_node(cat_user): + continue + # find the index of getitems to be cat + indices, idx_to_getitem = [], {} + for getitem in cat_user.args[0]: # type: ignore[union-attr] + indices.append(getitem.args[1]) # type: ignore[union-attr] + idx_to_getitem[getitem.args[1]] = getitem # type: ignore[union-attr] + # the getitems to be merged must be consecutive, otherwise + # returned sliced tensor could be wrong + if not is_sorted_and_consecutive(indices): # type: ignore[arg-type] + continue + # case 1: the cat uses all getitems from the split + if len(split_sections) == len(cat_user.args[0]): # type: ignore[arg-type] + # replace the users of the cat node to be the input of the split node + cat_user.replace_all_uses_with(split_node.args[0]) # type: ignore[arg-type] + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + # case 2: the cat uses some getitems from the split + elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] + # check the split dim, and construct the slice tuple + start_fused_size = calculate_fused_tensor_size( + split_node, + list(range(indices[0])), # type: ignore[arg-type] + ) + end_fused_size = start_fused_size + calculate_fused_tensor_size( + split_node, + indices, # type: ignore[arg-type] + ) + slice_list = [] + for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append(slice(start_fused_size, end_fused_size, None)) + with graph.inserting_after(split_node): + slice_node = graph.call_function( + operator.getitem, + args=(split_node.args[0], tuple(slice_list)), + ) + cat_user.replace_all_uses_with(slice_node) + slice_node.meta.update(cat_user.meta) + + # remove the cat node + graph.erase_node(cat_user) + counters["inductor"]["mutate_cat_pass"] += 1 + + +getitem_split_aten = ListOf( + CallFunction( + operator.getitem, + CallFunctionVarArgs([torch.ops.aten.split_with_sizes.default], users=MULTIPLE), + Ignored(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split.Tensor, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_size, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_size is None: 
+ log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + assert isinstance(split_node.meta["val"], (list, tuple)) + split_sections = [t.size()[split_dim] for t in split_node.meta["val"]] + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. + return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + split_section_list = [split_size] * (len(split_node.meta["val"])) + new_args = (split_input, split_section_list) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.split_with_sizes.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_split_with_size_default_aten(match: Match, *args, **kwargs): + split_node = match.nodes[0] + graph = match.graph + split_input, split_sections, split_dim = _get_split_args_default(split_node) + if split_input is None or split_dim is None or split_sections is None: + log.debug("couldn't find split args") + return + if not is_node_meta_valid(split_node): + log.debug("val absent for node: %s", split_node) + return + if any(isinstance(section, torch.SymInt) for section in split_sections): + # TODO dynamic_shapes with assume_static_by_default=False fails while AOT Autograd tracing. 
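+        # symbolic (SymInt) sections cannot be normalized here, so keep the original split node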
+ return + if split_dim < 0: # Normalize split dim + split_dim += split_input.meta["val"].dim() + + new_args = (split_input, split_sections) + new_kwargs = {"dim": split_dim} + if ( + split_node.args == new_args + and split_node.kwargs == new_kwargs + and split_node.op == "call_function" + ): + return + + with graph.inserting_after(split_node): + new_split_node = graph.call_function( + torch.ops.aten.split_with_sizes.default, + args=new_args, + kwargs=new_kwargs, # type: ignore[arg-type] + ) + split_node.replace_all_uses_with(new_split_node) + new_split_node.meta.update(split_node.meta) + graph.erase_node(split_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_aten_pass"), +) +def merge_split_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + split_node = match.nodes[0] + threshold_to_cat = torch._inductor.config.post_grad_fusion_options[ + "split_cat_aten_pass" + ].get("threshold_to_cat", 10) + # get the getitem nodes from the split node + getitem_nodes = list(split_node.users.keys()) + for cat_node in list(getitem_nodes[0].users.keys()): + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + if len(cat_inputs) < threshold_to_cat: + continue + # check split node and cat node has same dim, and all getitem nodes have same parent node + parent_to_indices = defaultdict(list) # type: ignore[var-annotated] + parent_to_getitems = defaultdict(list) # type: ignore[var-annotated] + for cat_input in cat_inputs: + # skip all non-getitem cat input + if cat_input.target != operator.getitem: + continue + current_getitem_parent = cat_input.args[0] + split_dim = get_arg_value(current_getitem_parent, 2, "dim") + if split_dim != cat_dim: + break + getitem_idx = cat_input.args[1] + if ( + current_getitem_parent not in parent_to_indices + ) or getitem_idx != parent_to_indices[current_getitem_parent][-1][-1] + 1: + parent_to_indices[current_getitem_parent].append([getitem_idx]) + parent_to_getitems[current_getitem_parent].append([cat_input]) + else: + parent_to_getitems[current_getitem_parent][-1].append(cat_input) + parent_to_indices[current_getitem_parent][-1].append(getitem_idx) + + cat_inputs_list = list(cat_inputs) + update_cat_arg = [] + # iterate through the indices to construct the slice nodes + for parent, indices in parent_to_indices.items(): + for idx, indice in enumerate(indices): + start, end = indice[0], indice[-1] + split_sections = list(parent.args[1]) + input_of_current_getitem_parent = parent.args[0] + if len(indice) >= threshold_to_cat or len(indice) == len( + split_sections + ): + if len(indice) != len(split_sections): + # get the start and end slicing indices + slice_node = graph.call_function( + torch.ops.aten.slice.Tensor, + args=( + input_of_current_getitem_parent, + split_dim, # type: ignore[possibly-undefined] + sum(split_sections[:start]), + sum(split_sections[: end + 1]), + ), + ) + else: + slice_node = input_of_current_getitem_parent + # find the index in the cat_inputs_list given the getitem node + update_cat_arg.append( + ( + slice_node, + cat_inputs_list.index(parent_to_getitems[parent][idx][0]), + cat_inputs_list.index(parent_to_getitems[parent][idx][-1]), + ) + ) + + result = [] + i = 0 + for slice_tensor, start, end in update_cat_arg: + while i < start: + result.append(cat_inputs_list[i]) + i += 1 + 
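+            # splice the fused slice node in where its run of getitem inputs used to be, then skip past that run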
result.append(slice_tensor) + i = end + 1 + while i < len(cat_inputs_list): + result.append(cat_inputs_list[i]) + i += 1 + + cat_node.update_arg(0, result) + for getitem_node in getitem_nodes: + if len(getitem_node.users) == 0: + graph.erase_node(getitem_node) + if len(split_node.users) == 0: + graph.erase_node(split_node) + counters["inductor"]["split_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + ListOf( + CallFunctionVarArgs(torch.ops.aten.select.int, users=MULTIPLE), + partial=True, + ), + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("select_cat_aten_pass"), +) +def merge_select_cat_aten(match: Match, *args, **kwargs): + graph = match.graph + node = match.nodes[0] + node_input = get_arg_value(node, 0, "tensors") + # get the select nodes from the node + select_nodes = list(node_input.users.keys()) + for cat_node in list(node.users.keys()): + if cat_node.target == torch.ops.aten.cat.default: + cat_dim = get_arg_value(cat_node, 1, "dim") + cat_inputs = get_arg_value(cat_node, 0, "tensors") + # check all select nodes has same slice dim + if not all( + select_node.args[1] == select_nodes[0].args[1] + for select_node in select_nodes + ): + continue + # We only consider the case where selece slice dim and cat node has same dim + if select_nodes[0].args[1] != cat_dim: + continue + if not is_node_meta_valid(cat_node): + continue + # check the cat node has consecutive indices + indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr] + if ( + not is_sorted_and_consecutive(indices) # type: ignore[arg-type] + or len(select_nodes) != len(cat_inputs) + ): + continue + # check all the select nodes can be merged to the cat node input + if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr] + continue + # reshape the node input to be the same shape as the cat node + with graph.inserting_before(node): + view_node = graph.call_function( + torch.ops.aten.view.default, + args=(node_input, cat_node.meta["val"].shape), + ) + # replace the node input with the new node + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["select_cat_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunctionVarArgs(torch.ops.aten.cat.default, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_aten_pass"), +) +def normalize_cat_default_aten(match: Match, *args, **kwargs): + cat_node = match.nodes[0] + graph = match.graph + tensors = get_arg_value(cat_node, 0, "tensors") + cat_dim = get_arg_value(cat_node, 1, "dim") + if cat_dim is None: + cat_axis = cat_node.kwargs.get("axis") + if cat_axis is not None: + cat_dim = cat_axis + else: + cat_dim = 0 + if tensors is None or cat_dim is None: + log.debug("couldn't find cat args") + return + assert isinstance(tensors, (list, tuple)) + for tensor in itertools.chain([cat_node], tensors): + if "val" not in tensor.meta: + log.debug("val absent for node: %s", tensor) + return + + ndim = cat_node.meta["val"].dim() + + def is_empty_tensor(x: torch.fx.Node) -> bool: + # special case where torch.ops.aten.cat.default supports cat'ing with an empty tensor + x_shape = x.meta["val"].shape + return len(x_shape) == 1 and x_shape[0] == 0 + + assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in 
tensors) + + if cat_dim < 0: # Normalize cat dim + cat_dim += ndim + + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.ops.aten.cat.default, + args=(tensors,), + kwargs={"dim": cat_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["normalization_aten_pass"] += 1 + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat, + ListOf(CallFunctionVarArgs(torch.ops.aten.unsqueeze)), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_aten_pass"), +) +def merge_unbind_stack_aten(match: Match, *args, **kwargs): + node = match.nodes[-1] + graph = match.graph + # pyre-fixme[6] + unsqueeze_nodes = list(node.args[0]) # type: ignore[arg-type] + cat_dim = get_arg_value(node, 1, "dim") + # check the unsqueeze nodes come from the select nodes + if not all( + get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + for unsqueeze_node in unsqueeze_nodes + ): + return + select_nodes = [ + get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + ] + parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") + # check the target of select_nodes are the same + if not all( + select_node.target == torch.ops.aten.select for select_node in select_nodes + ): + return + # check the select nodes come from the same parent node + if not all( + get_arg_value(select_node, 0, "input") == parent_of_select_node + for select_node in select_nodes + ): + return + if len(unsqueeze_nodes) != len(select_nodes): + return + # check the select nodes have the same dim + if not all( + get_arg_value(select_node, 1, "dim") == cat_dim for select_node in select_nodes + ): + return + # check the select nodes have consecutive indices starting from 0 + if get_arg_value(select_nodes[0], 2, "index") != 0 or not is_sorted_and_consecutive( + [get_arg_value(select_node, 2, "index") for select_node in select_nodes] + ): + return + # check the users of parent of select node only from unsqueeze nodes that go to the cat node + # we simply check the number of users of the parent of select node + if len(parent_of_select_node.users.keys()) != len(node.args[0]): # type: ignore[arg-type] + return + node.replace_all_uses_with(parent_of_select_node) + graph.erase_node(node) + for unsqueeze_node in unsqueeze_nodes: + graph.erase_node(unsqueeze_node) + for select_node in select_nodes: + if len(select_node.users) == 0: + graph.erase_node(select_node) + counters["inductor"]["unbind_stack_aten_pass"] += 1 + + +def divide_into_consecutive_sublists(indices: list[int]) -> list[list[int]]: + n = len(indices) + if n <= 1: + return [indices] + + # Initialize the list of sublists + sublists = [] + + # Iterate over the indices + i = 0 + while i < n: + # Initialize the current sublist + sublist = [indices[i]] + + # Iterate over the remaining indices + j = i + 1 + while j < n and indices[j] == indices[j - 1] + 1: + # Add the next index to the current sublist + sublist.append(indices[j]) + j += 1 + + # Add the current sublist to the list of sublists + sublists.append(sublist) + # Move to the next index + i = j + + return sublists + + +def update_args_from_split_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + split_input, 
split_size, split_dim = _get_split_args_default(parents_seen[-1]) + # case 1: the number of getitems is the same as the split size, eliminate the split + if len(split_size) == len(getitem_indices) and is_sorted_and_consecutive( + getitem_indices + ): + # we can merge the getitems from the previous parent + new_cat_args.append(split_input) + new_cat_args_meta.append(split_input.meta["example_value"]) + else: + if len(getitem_indices) > 0: + # case 2: the number of getitems is smaller than the split size but larger than the threshold, and + # the indices of getitems are not all consecutive, we need to divide the indices into multiple groups + geitem_indices_sublist = divide_into_consecutive_sublists(getitem_indices) + for sublist in geitem_indices_sublist: + if len(sublist) >= threshold_to_cat: + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + start_fused_size = sum(split_size[: sublist[0]]) + end_fused_size = sum(split_size[: sublist[-1] + 1]) + slice_list = [] + for i in range(len(split_input.meta["example_value"].shape)): # type: ignore[union-attr] + if i != split_dim: + slice_list.append(slice(None, None, None)) + else: + slice_list.append( + slice(start_fused_size, end_fused_size, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(split_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = split_input.meta[ + "example_value" + ][tuple(slice_list)] + new_cat_args.append(slice_node) + new_cat_args_meta.append(slice_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in sublist: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append( + idx_to_getitems[i].meta["example_value"] + ) + + +def reshape_cat_node( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + unbind_input: torch.fx.Node, + cat_dim: int, + unbind_dim: int, + cat_shape: torch.Size, +) -> torch.fx.Node: + if cat_dim != unbind_dim: + # construct the permute node args, which has the same shape as the slice node + # then it has the same dim as the unbind_input, i.e., shape of cat + 1 + with graph.inserting_after(cat_node): + permute_list = list(range(len(cat_shape) + 1)) + permute_list[unbind_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[unbind_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(unbind_input, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + unbind_input.meta["example_value"], permute_list + ) # type: ignore[arg-type] + else: + permute_node = unbind_input + with graph.inserting_after(permute_node): + reshape_node = graph.call_function( + torch.reshape, args=(permute_node, tuple(cat_shape)) + ) + reshape_node.meta["example_value"] = torch.reshape( + permute_node.meta["example_value"], tuple(cat_shape) + ) # type: ignore[arg-type] + return reshape_node + + +def update_args_from_unbind_getitem( + graph: torch.fx.Graph, + node: torch.fx.Node, # cat or stack node + getitem_indices: list[int], + parents_seen: list[torch.fx.Node], + new_cat_args: list[torch.fx.Node], + new_cat_args_meta: list[torch.fx.Node], + idx_to_getitems: dict[int, torch.fx.Node], + threshold_to_cat: int = 2, +): + unbind_input = get_arg_value(parents_seen[-1], 0, "input") # split or unbind input + unbind_dim = get_arg_value(parents_seen[-1], 1, "dim") # split or unbind dim + cat_dim = 
get_arg_value(node, 1, "dim") # cat or stack dim + # case 1: the number of getitems is the same as the split size, eliminate the split + size = list(unbind_input.meta["example_value"].shape)[unbind_dim] + if size == len(getitem_indices): + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + # we can merge the getitems from the previous parent + reshape_node = reshape_cat_node( + graph, node, unbind_input, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + elif len(getitem_indices) >= threshold_to_cat and is_sorted_and_consecutive( + getitem_indices + ): + # case 2: the number of getitems is smaller than the split size but larger than the threshold + # we need to slice the input of parent + cat_shape = torch.cat( + [idx_to_getitems[i].meta["example_value"] for i in getitem_indices], + dim=cat_dim, + ).shape + slice_list = [] + for i in range(len(cat_shape) + 1): + if i != unbind_dim: + slice_list.append(slice(None, None, None)) # start, end, step + else: + slice_list.append( + slice(getitem_indices[0], getitem_indices[-1] + 1, None) + ) + with graph.inserting_after(node): + slice_node = graph.call_function( + operator.getitem, + args=(unbind_input, tuple(slice_list)), + ) + slice_node.meta["example_value"] = torch.narrow( + unbind_input.meta["example_value"], + unbind_dim, + getitem_indices[0], + getitem_indices[-1] - getitem_indices[0] + 1, + ) + reshape_node = reshape_cat_node( + graph, node, slice_node, cat_dim, unbind_dim, cat_shape + ) + new_cat_args.append(reshape_node) + new_cat_args_meta.append(reshape_node.meta["example_value"]) + else: + # case 3: the number of getitems is smaller than the threshold, no merge is done + # get the getitems based on the indexes + for i in getitem_indices: + new_cat_args.append(idx_to_getitems[i]) + new_cat_args_meta.append(idx_to_getitems[i].meta["example_value"]) + + +def construct_cat_args( + graph: torch.fx.Graph, + cat_or_stack_node: torch.fx.Node, + inputs: list[torch.fx.Node], + split_or_unbind_node: torch.fx.Node, + threshold_to_cat: int = 2, + run_update_func: Callable = update_args_from_split_getitem, # type: ignore[type-arg] +) -> tuple[list[torch.fx.Node], list[torch.Tensor]]: + new_cat_args, parents_seen, getitem_indices, idx_to_getitems = [], [], [], {} # type: ignore[var-annotated] + new_cat_args_meta = [] # type: ignore[var-annotated] + for input in inputs: + if input.target != operator.getitem: + # update the last arg based on getitem_indices and parents_seens + if len(parents_seen) > 0: + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + # reset the indices array + getitem_indices, idx_to_getitems = [], {} + else: + # get the parent node of the getitem input + parent, idx = input.args[0], input.args[1] # type: ignore[union-attr] + if parent.target != split_or_unbind_node.target: # type: ignore[union-attr] + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # cannot use parents_seen to check since the first item could be non getitem node + if len(parents_seen) == 0: + parents_seen.append(parent) + idx_to_getitems[idx] = input + getitem_indices.append(idx) + # case: we 
only have one getitem input, and it is in the last position + if input == inputs[-1]: + new_cat_args.append(input) + new_cat_args_meta.append(input.meta["example_value"]) + continue + # if it is the last input in the tensors, we also check if it can be optimized + if parent != parents_seen[-1] or input == inputs[-1]: + if input == inputs[-1]: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + run_update_func( # type: ignore[arg-type, union-attr] + graph, + cat_or_stack_node, + getitem_indices, + parents_seen, + new_cat_args, + new_cat_args_meta, + idx_to_getitems, # type: ignore[arg-type, union-attr] + threshold_to_cat, + ) + # reset the indices array for the next parent + # remember to add the last element since it is the first + # item in this round of parent + # add the parent to the list of seen parents + parents_seen.append(parent) + getitem_indices, idx_to_getitems = [idx], {idx: input} + else: + getitem_indices.append(idx) + idx_to_getitems[idx] = input + return new_cat_args, new_cat_args_meta + + +def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.Node]): + nodes = OrderedSet[Any]() + for input in inputs: + if input.target == operator.getitem: + nodes.add(input.args[0]) # type: ignore[union-attr] + if len(input.users.keys()) == 0: + graph.erase_node(input) + # check the split node to remove if it has no users + for node in nodes: + if len(node.users.keys()) == 0: # type: ignore[union-attr] + graph.erase_node(node) # type: ignore[arg-type] + + +# ############pattern to be optimized is######### + +# split_node(dim=1) -> user=multiple +# / \ ... / \ +# other inputs getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# split_node(dim=1) other inputs -> -> user=multiple +# / \ +# cat (user=mul, dim=1, split_node) + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_cat_to_slices_pass"), +) +def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_nodes = [node for node in match.nodes if node.target == torch.split] + if split_nodes: + split_node = next(node for node in split_nodes) + else: + # Handle the case where there are no nodes with a target of torch.split + return + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_cat_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + next_users = find_next_users(split_node) + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, _ = construct_cat_args( + graph, + cat_node, + cat_inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: if new cat args has length 1, we can remove the cat node + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + 
counters["inductor"]["split_cat_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(cat_inputs): + new_args = (new_cat_args,) + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=new_args, + # split and cat have the same dim + kwargs={"dim": split_dim}, + ) + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove the cat node + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) + counters["inductor"]["split_cat_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=0) -> user=multiple +# / \ ... / \ +# getitem getitem getitem getitem -> user=multiple +# \ / \ +# cat(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.cat, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), +) +def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_cat_to_view_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for cat_node in next_users: + if cat_node.target != torch.cat or not is_node_meta_valid(cat_node): + continue + inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + cat_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + # get the view shape + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + cat_node.replace_all_uses_with(new_cat_args[0]) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(cat_node, 1, "dim") + with graph.inserting_after(cat_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) # type: ignore[arg-type] + cat_node.replace_all_uses_with(new_cat_node) + new_cat_node.meta.update(cat_node.meta) + # remove inputs of cat_node if they have no users + cat_inputs = cat_node.args[0] # type: ignore[union-attr] + graph.erase_node(cat_node) + remove_split_unbind_children(graph, cat_inputs) # type: ignore[arg-type] + counters["inductor"]["unbind_cat_to_view_pass"] += 1 + + +def reshape_cat_node_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + split_or_unbind_dim: int, +) -> None: + # reshape the cat node to the stack node shape + stack_shape = stack_node.meta["example_value"].shape + stack_dim = _get_dim(stack_node) + if stack_dim != split_or_unbind_dim: + # case 1: the stack dim is not the same as the 
split dim + # we need to reshape the split input before we do the reshape + reshape_list = list(stack_shape) + reshape_list[stack_dim], reshape_list[split_or_unbind_dim] = ( + reshape_list[split_or_unbind_dim], + reshape_list[stack_dim], + ) + reshape_node = graph.call_function( + torch.reshape, + args=(cat_node, tuple(reshape_list)), + ) + reshape_node.meta["example_value"] = torch.reshape( + cat_node.meta["example_value"], tuple(reshape_list) + ) + permute_list = list(range(len(stack_shape))) + permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( + permute_list[split_or_unbind_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(reshape_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + reshape_node.meta["example_value"], permute_list + ) + else: + # case 2: the stack dim is the same as the split dim + # we can directly reshape the split input + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, *stack_shape), # type: ignore[arg-type] + ) + stack_node.replace_all_uses_with(reshape_node) + reshape_node.meta.update(stack_node.meta) + stack_inputs = stack_node.args[0] # type: ignore[union-attr] + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + + +def convert_reshape_cat_arg_to_stack( + graph: torch.fx.Graph, + cat_node: torch.fx.Node, + stack_node: torch.fx.Node, + stack_node_shape: torch.Size, + stack_dim: int, + split_dim: int, +) -> torch.fx.Node: + # reshape the cat node to the stack node shape + cat_shape = cat_node.meta["example_value"].shape + if stack_dim != split_dim: + permute_list = list(range(len(cat_shape))) + permute_list[stack_dim], permute_list[split_dim] = ( + permute_list[split_dim], + permute_list[stack_dim], + ) + permute_node = graph.call_function( + torch.permute, + args=(cat_node, permute_list), + ) + permute_node.meta["example_value"] = torch.permute( + cat_node.meta["example_value"], permute_list + ) + else: + permute_node = cat_node + reshape_node = graph.call_function( + torch.Tensor.view, + args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] + ) + reshape_node.meta["example_value"] = torch.Tensor.view( + permute_node.meta["example_value"], + tuple(stack_node_shape), # type: ignore[arg-type] + ) + return reshape_node + + +# ############pattern to be optimized is######### +# | | +# split split (dim=1) +# / \ / \ +# getitem ... getitem other ops +# \ | / / +# stack(user=mul, dim=1 or 2) -> can be different dim +# | + +# ################after transformation############# + +# / \ ... 
/ \ +# getitem getitem getitem getitem -> user=multiple +# \ / +# cat(user=mul, dim=1) cat_other_opts +# \ / +# cat +# | +# view +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("split_stack_to_cats_pass"), +) +def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): + if not isinstance(split_sections, (list, tuple)): # Unnormalized split + return + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = get_arg_value(split_node, 2, "dim") or 0 + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "split_stack_to_cats_pass" + ].get("threshold_to_cat", 10) + # get the stack_node and check its inputs and meta data + next_users = find_next_users(split_node) + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": split_dim}, + ) + cat_node.meta["example_value"] = torch.cat( # type: ignore[arg-type] + new_cat_args_meta, dim=split_dim + ) + reshape_cat_node_to_stack(graph, cat_node, stack_node, split_dim) + counters["inductor"]["split_stack_to_cats_pass"] += 1 + + +# ############pattern to be optimized is######### + +# unbind(dim=1) -> user=multiple +# \ ... 
/ \ +# others getitem getitem getitem -> user=multiple +# \ \ / \ +# stack(user=mul, dim=1) other_op +# | + +# ################after transformation############# + +# input_of_unbind +# | \ +# slice +# | +# view others +# | / +# stack +# | + + +@register_graph_pattern( + CallFunction( + torch.stack, + getitem_unbind, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), +) +def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): + unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + graph = match.graph + # get the cat_node and check its inputs and meta data + next_users = find_next_users(unbind_node) + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "unbind_stack_to_slices_pass" + ].get("threshold_to_cat", 10) + # get the cat_node and check its inputs and meta data + for stack_node in next_users: + if stack_node.target != torch.stack or not is_node_meta_valid(stack_node): + continue + inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + new_cat_args, new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + unbind_node, + threshold_to_cat, + update_args_from_unbind_getitem, + ) + unbind_dim = get_arg_value(unbind_node, 1, "dim") or 0 + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_cat_node_to_stack(graph, new_cat_args[0], stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # get the view shape + cat_dim = get_arg_value(stack_node, 1, "dim") + with graph.inserting_after(stack_node): + new_cat_node = graph.call_function( + torch.cat, + args=(new_cat_args,), + kwargs={"dim": cat_dim}, + ) + new_cat_node.meta["example_value"] = torch.cat( + new_cat_args_meta, dim=cat_dim + ) + reshape_cat_node_to_stack(graph, new_cat_node, stack_node, unbind_dim) + counters["inductor"]["unbind_stack_to_slices_pass"] += 1 + + +# ############pattern to be optimized is######### +# input +# | +# split(dim=1) -> user=multiple +# \ \ +# others getitem getitem +# \ \ / +# reshape reshape reshape other_op +# \ \ / / +# stack(user=mul, dim=0) +# | + +# ################after transformation############# +# input +# | +# permute +# | +# reshape others +# | / +# cat (dim=0) +# | + + +def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: + # cat_arg must be the split input + view_shape_list = [] + for user in cat_arg.users.keys(): + if user.target == torch.split: + for getitem in user.users.keys(): + if getitem.target == operator.getitem: + reshape_user = [ + user + for user in getitem.users.keys() + if user.target == torch.reshape + ] + if len(reshape_user) > 0: + view_shape_list = list( + reshape_user[0] + .meta["example_value"] + .unsqueeze(stack_dim) + .shape + ) + view_shape_list[stack_dim] = -1 + return view_shape_list + return view_shape_list + + +@register_graph_pattern( + CallFunction( + torch.stack, + reshape_getitem_split, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), +) +def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): + split_node = next(node for node in match.nodes if node.target == torch.split) + split_dim = _get_dim(split_node) + split_users = list(split_node.users.keys()) + stack_nodes = [node 
for node in match.nodes if node.target == torch.stack] + graph = match.graph + threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ + "move_reshape_out_of_split_stack_pass" + ].get("threshold_to_cat", 10) + for stack_node in stack_nodes: + if not is_node_meta_valid(stack_node): + log.debug("example value absent for node: %s", stack_node) + continue + stack_dim = _get_dim(stack_node) + stack_inputs = get_arg_value(stack_node, 0, "tensors") # type: ignore[union-attr] + inputs = [] + for stack_input in stack_inputs: + if stack_input.target != torch.reshape: + inputs.append(stack_input) + else: + inputs.append(stack_input.args[0]) # type: ignore[union-attr] + new_cat_args, _new_cat_args_meta = construct_cat_args( + graph, + stack_node, + inputs, + split_node, + threshold_to_cat, + update_args_from_split_getitem, + ) + # At least one node would be in the returned new_cat_args + # case 1: only one node in the new cat args, don't need to cat + if len(new_cat_args) == 1: + reshape_node = convert_reshape_cat_arg_to_stack( + graph, + new_cat_args[0], + stack_node, + stack_node.meta["example_value"].shape, + stack_dim, + split_dim, + ) + stack_node.replace_all_uses_with(reshape_node) + # remove stack node + graph.erase_node(stack_node) + # check the input of stack node, and remove nodes that have no users + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + continue + if len(new_cat_args) > 1 and len(new_cat_args) < len(inputs): + # decompose the cat args into multiple stack nodes, i.e., we stack + # all the nodes exist in the stack inputs and reshape the rest followed by a cat + stack_node_input, stack_node_input_meta, cat_inputs = [], [], [] # type: ignore[var-annotated] + for cat_arg in new_cat_args: + if cat_arg not in stack_inputs: + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + # cat_arg must be the split input + view_shape_list = get_view_shape_list(cat_arg, stack_dim) + stack_node_shape = torch.reshape( + cat_arg.meta["example_value"], tuple(view_shape_list) + ).shape # type: ignore[union-attr] + cat_inputs.append( + convert_reshape_cat_arg_to_stack( + graph, + cat_arg, + stack_node, + stack_node_shape, + stack_dim, + split_dim, + ) + ) + stack_node_input, stack_node_input_meta = [], [] + else: + stack_node_input.append(cat_arg) + stack_node_input_meta.append(cat_arg.meta["example_value"]) + + if len(stack_node_input) > 0: + with graph.inserting_after(stack_node): + decomposed_stack_node = graph.call_function( + torch.stack, + args=(stack_node_input,), + kwargs={"dim": stack_dim}, + ) + decomposed_stack_node.meta["example_value"] = torch.stack( + stack_node_input_meta, dim=stack_dim + ) + cat_inputs.append(decomposed_stack_node) + + with graph.inserting_after(stack_node): + cat_node = graph.call_function( + torch.cat, + args=(cat_inputs,), + kwargs={"dim": stack_dim}, + ) + stack_node.replace_all_uses_with(cat_node) + cat_node.meta.update(stack_node.meta) + graph.erase_node(stack_node) + remove_split_unbind_children(graph, stack_inputs) # type: ignore[arg-type] + remove_split_unbind_children(graph, split_users) # type: 
ignore[arg-type] + counters["inductor"]["move_reshape_out_of_split_stack_pass"] += 1 + + +view_getitem_split_aten = ListOf( + CallFunction( + [torch.ops.aten.reshape.default], + CallFunction( + operator.getitem, + CallFunctionVarArgs( + torch.ops.aten.split_with_sizes.default, users=MULTIPLE + ), + Ignored(), + _users=MULTIPLE, + ), + Arg(), + _users=MULTIPLE, + ), + partial=True, +) + + +@register_graph_pattern( + CallFunction( + torch.ops.aten.cat.default, + view_getitem_split_aten, + dim=Ignored(), + _users=MULTIPLE, + ), + pass_dict=construct_pattern_matcher_pass("move_view_after_cat_aten_pass"), +) +def move_view_after_cat(match: Match, *args, **kwargs): + split_node = next( + node + for node in match.nodes + if node.target == torch.ops.aten.split_with_sizes.default + ) + split_input, split_section, split_dim = _get_split_args_default(split_node) + split_users = list(split_node.users.keys()) + getitem_indices = [ + getitem.args[1] for getitem in split_users if getitem.target == operator.getitem + ] + if not is_sorted_and_consecutive(getitem_indices): # type: ignore[arg-type] + return + cat_nodes = [ + node for node in match.nodes if node.target == torch.ops.aten.cat.default + ] + graph = match.graph + for cat_node in cat_nodes: + if not is_node_meta_valid(cat_node): + log.debug("example value absent for node: %s", cat_node) + continue + cat_dim = _get_dim(cat_node) + cat_inputs = get_arg_value(cat_node, 0, "tensors") # type: ignore[union-attr] + # we only consider the following special case + if len(cat_inputs) != len(split_section): + continue + # check if the cat inputs are all the view nodes + if not all( + view_node.target == torch.ops.aten.reshape.default + for view_node in cat_inputs + ): + continue + # check if the view nodes are all from getitem nodes + if not all( + view_node.args[0].target == operator.getitem for view_node in cat_inputs + ): + continue + view_indices = [view.args[0].args[1] for view in cat_inputs] + if not is_sorted_and_consecutive(view_indices): # type: ignore[arg-type] + continue + if cat_dim != split_dim: + # construct permute node + permute_list = list(range(len(cat_node.meta["val"].shape) + 1)) + permute_list[split_dim], permute_list[cat_dim] = ( + permute_list[cat_dim], + permute_list[split_dim], + ) + permute_node = graph.call_function( + torch.ops.aten.permute.default, + args=(split_input, permute_list), + ) + else: + permute_node = split_input + + with graph.inserting_before(cat_node): + view_node = graph.call_function( + torch.ops.aten.reshape.default, + args=(permute_node, list(cat_node.meta["val"].shape)), + ) + cat_node.replace_all_uses_with(view_node) + view_node.meta.update(cat_node.meta) + graph.erase_node(cat_node) + counters["inductor"]["move_view_after_cat_aten_pass"] += 1 diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..460428ac8b741b5745a5fd2b7074dbdcd8453d7f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/__init__.py @@ -0,0 +1 @@ +from . 
import mm, mm_common, mm_plus_mm diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b69f3f7d5fe148084faee53891ee29a46668b48 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45981c794f828f2cc794e974586674fe9972c4e8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/bmm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af583e8d28cd053e5a5918d04206aa6419bd4aeb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd81d64c108d3215440d5c8c3365782b1ebefa48 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c83e2767054ac275d89ea0b7461b592bd22f580b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/flex_decoding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07be62e0b6ef8a42103f29ec5088ca0ef9960fe1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea6d50ec3c83a43be96245dff4217b158fcc006a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17cec1b242876c8c81ca7536c1f3d740c05e5e42 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_plus_mm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled_grouped.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled_grouped.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0cd22e40a6ddefc3521d39f17e9fb96644eca4e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/kernel/__pycache__/mm_scaled_grouped.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/bmm.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/bmm.py new file mode 100644 index 0000000000000000000000000000000000000000..8d7c7afdd229a7e025d6c8e07fd255a2cc6336a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/bmm.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs +import logging + +import torch +from torch._dynamo.utils import counters +from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate + +from .. import ir, lowering as L +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_cpp_bmm_template, + use_cutlass_template, + use_triton_template, +) +from ..virtualized import V +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + is_batch_stride_largest, + mm_args, + mm_config_kwargs, + mm_options, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@SymbolicGridFn +def bmm_grid(b, m, n, meta, *, cdiv): + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 128 or n > 128 or k > 128: + return True + return m * n > 2**12 + + +bmm_template = TritonTemplate( + name="bmm", + grid=bmm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", -2)}} + N = {{size("B", -1)}} + K = {{size("A", -1)}} + + stride_aq = {{stride("A", 0)}} + stride_am = {{stride("A", 1)}} + stride_ak = {{stride("A", 2)}} + + stride_bq = {{stride("B", 0)}} + stride_bk = {{stride("B", 1)}} + stride_bn = {{stride("B", 2)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
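+        # masked-out elements of the last K block (when K is not a multiple of BLOCK_K) were loaded as 0.0, so they do not affect the accumulated dot product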
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask")}} +""", + cache_codegen_enabled_for_template=True, +) + +aten_bmm = ExternKernelChoice(torch.bmm, "at::bmm_out") +aten_bmm_dtype = ExternKernelChoice( + torch.bmm, + "at::_bmm_out_dtype_cuda", + name="bmm_dtype", + op_overload=aten.bmm.dtype_out, +) +aten_baddbmm = ExternKernelChoice( + torch.baddbmm, "at::baddbmm_out", op_overload=aten.baddbmm.out +) + + +@L.register_lowering(aten.bmm) +def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None): + """ + Lowering for autotuning aten.bmm with different backends (Aten, Triton, CUTLASS, etc.) + """ + if all(x.get_device().type == "cpu" for x in [mat1, mat2]): + # decompose to small ops when memory bound + if mat1.get_size()[1] == 1 or mat2.get_size()[2] == 1: + mat1 = L.unsqueeze(mat1, -1) + mat2 = L.unsqueeze(mat2, 1) + return L.sum_(L.mul(mat1, mat2), axis=2) + + def is_valid_to_require_contiguous(t): + if not ir.is_storage_and_layout(t): + return True + _, layout = ir.as_storage_and_layout(t, freeze=False) + return isinstance(layout, ir.FlexibleLayout) + + def is_preferred_layout_as_bmm_input(sizes, strides): + # contiguous on one of the last two dims + return ( + strides[-1] == 1 and (sizes[-2] == 1 or strides[-2] >= sizes[-1]) + ) or (strides[-2] == 1 and (sizes[-1] == 1 or strides[-1] >= sizes[-2])) + + # Make the input of bmm contiguous + # if it is not contiguous on either of the last two dims, + # because bmm cpu implementation would do contiguous() if not. + # This is to avoid additional copies in bmm. 
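+    # Illustrative (hypothetical) shapes: a (2, 4, 8) batch of matrices with strides (32, 8, 1),
+    # or its transposed view with strides (32, 1, 4), already has a preferred layout, while a
+    # strided slice with strides (64, 16, 2) does not and is made contiguous by the helper below.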
+ def may_require_contiguous(t, meta_t): + sizes = meta_t.meta["val"].size() + strides = meta_t.meta["val"].stride() + if not is_preferred_layout_as_bmm_input(sizes, strides): + t = ir.ExternKernel.require_contiguous(t) + return t + + if is_valid_to_require_contiguous(mat1): + meta_mat1 = V.graph.current_node.args[0] + mat1 = may_require_contiguous(mat1, meta_mat1) + if is_valid_to_require_contiguous(mat2): + meta_mat2 = V.graph.current_node.args[1] + mat2 = may_require_contiguous(mat2, meta_mat2) + + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=out_dtype + ) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] # Extract batch dimension + counters["aten_mm_info"][f"aten.bmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.bmm: batch=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + if out_dtype: + assert mat1.get_device().type == "cuda", "out_dtype is only supported for CUDA" + aten_func = aten_bmm_dtype.bind((mat1, mat2), layout, out_dtype=out_dtype) + else: + aten_func = aten_bmm.bind((mat1, mat2), layout) + + # options to tune from + choices = [aten_func] if use_aten_gemm_kernels() else [] + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + dtype = mat1.get_dtype() + if use_triton_template(layout): + # TODO: add out_dtype support for Triton Template + assert out_dtype is None, "out_dtype is not supported for Triton" + for config in bmm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + _, is_nonzero = _is_static_problem(layout) + batch_stride_largest = is_batch_stride_largest(mat1, mat2, layout) + if ( + batch_stride_largest + and is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("bmm") + ): + from ..codegen.cuda.gemm_template import CUTLASS3xGemmTemplate + + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) # type: ignore[arg-type] + + if use_cpp_bmm_template(layout, mat1, mat2): + from ..codegen.cpp_bmm_template import CppBmmTemplate + + CppBmmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + if use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + + return autotune_select_algorithm("bmm", choices, [mat1, mat2], layout) + + +@L.register_lowering(aten.baddbmm) +def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + m, n, k, layout, mat1, mat2, inp = mm_args(mat1, mat2, inp, layout=layout) + + # below is for getting an overview logging info of inductor mms + batch_size = mat1.get_size()[0] + counters["aten_mm_info"][f"aten.baddbmm_{batch_size}_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.baddbmm: batch_size=%s, m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, inp=%s, output_layout=%s", + batch_size, + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + inp.get_dtype(), + layout, + ) + + # options to tune from + choices = ( + [aten_baddbmm.bind((inp, mat1, mat2), layout, alpha=alpha, beta=beta)] + if use_aten_gemm_kernels() + else [] + ) + + device_type = ir.get_device_type(mat1) + bmm_configs = V.choices.get_base_mm_configs(device_type) + + if use_triton_template(layout): + for config in bmm_configs( + 
m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + bmm_template.maybe_append_choice( + choices, + input_nodes=(inp, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + return autotune_select_algorithm("baddbmm", choices, [inp, mat1, mat2], layout) diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/conv.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5a52a8eebd80262d33b728cfca49050b55b532da --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/conv.py @@ -0,0 +1,696 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import logging +from typing import Optional, TYPE_CHECKING, TypedDict + +import torch +from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate + +from .. import config, ir +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + SymbolicGridFn, + TritonTemplate, +) +from ..utils import ( + is_ones, + is_zeros, + pad_listlike, + sympy_product, + use_ck_conv_template, + use_triton_template, +) +from ..virtualized import V +from .mm_common import mm_config_kwargs + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ..ir import TensorBox + +log = logging.getLogger(__name__) + + +aten = torch.ops.aten + + +@SymbolicGridFn +def conv2d_grid(n, c, h, w, meta, *, cdiv): + return ( + cdiv(n * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +@SymbolicGridFn +def conv3d_grid(n, c, d, h, w, meta, *, cdiv): + return ( + cdiv(n * d * h * w, meta["BLOCK_M"]), + cdiv(c, meta["BLOCK_N"]), + meta["GROUPS"], + ) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + if m > 256 or n > 256 or k > 256: + return True + return m * n * k > 2**17 + + +LOOP_BODY_2D = """ + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +""" +This is a relatively simple conv implementation that can likely be +improved. 
Many alternate conv versions can be found here: +https://github.com/pytorch/torchdynamo/pull/971 +""" +conv2d_template = TritonTemplate( + name="convolution2d", + grid=conv2d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 1)}} + IN_H = {{size("X", 2)}} + IN_W = {{size("X", 3)}} + OUT_C = {{size(None, 1)}} + OUT_H = {{size(None, 2)}} + OUT_W = {{size(None, 3)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xh = {{stride("X", 2)}} + stride_xw = {{stride("X", 3)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wh = {{stride("W", 2)}} + stride_ww = {{stride("W", 3)}} + + nhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = nhw % OUT_W + nh = nhw // OUT_W + idx_y_h = nh % OUT_H + idx_n = nh // OUT_H + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_2D + + """ +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for ijk in range(KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (ijk % BLOCK_K_COUNT) * BLOCK_K + ij = ijk // BLOCK_K_COUNT + i = ij // KERNEL_W + j = ij % KERNEL_W + """ + + LOOP_BODY_2D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +LOOP_BODY_3D = """ + idx_x_d = d - PADDING_D + idx_y_d * STRIDE_D + idx_x_h = i - PADDING_H + idx_y_h * STRIDE_H + idx_x_w = j - PADDING_W + idx_y_w * STRIDE_W + idx_x_c = tl.arange(0, BLOCK_K) + k + + x_ptrs = x_base + ( + (idx_x_d * stride_xd)[:, None] + + (idx_x_h * stride_xh)[:, None] + + (idx_x_w * stride_xw)[:, None] + + (idx_x_c * stride_xc)[None, :] + ) + mask_x = ( + (idx_n < BATCH)[:, None] + & (idx_x_d >= 0)[:, None] + & (idx_x_d < IN_D)[:, None] + & (idx_x_h >= 0)[:, None] + & (idx_x_h < IN_H)[:, None] + & (idx_x_w >= 0)[:, None] + & (idx_x_w < IN_W)[:, None] + & (idx_x_c < GROUP_IN_C)[None, :] + ) + matrix_x = tl.load(x_ptrs, mask=mask_x, other=0.0) + + w_ptrs = w_base + ( + (idx_x_c * stride_wc_in)[:, None] + + (d * stride_wd) + (i * stride_wh) + (j * stride_ww) + ) + mask_w = (idx_x_c[:, None] < GROUP_IN_C) & (idx_y_c[None, :] < GROUP_OUT_C) + matrix_w = tl.load(w_ptrs, mask=mask_w, other=0.0) + acc += tl.dot(matrix_x, matrix_w, allow_tf32=ALLOW_TF32) +""" + +conv3d_template = TritonTemplate( + name="convolution3d", + grid=conv3d_grid, + source=r""" +{{def_kernel("X", "W")}} + # Tensor dimensions + BATCH = {{size("X", 0)}} + IN_C = {{size("X", 
1)}} + IN_D = {{size("X", 2)}} + IN_H = {{size("X", 3)}} + IN_W = {{size("X", 4)}} + OUT_C = {{size(None, 1)}} + OUT_D = {{size(None, 2)}} + OUT_H = {{size(None, 3)}} + OUT_W = {{size(None, 4)}} + + # Strides: + stride_xn = {{stride("X", 0)}} + stride_xc = {{stride("X", 1)}} + stride_xd = {{stride("X", 2)}} + stride_xh = {{stride("X", 3)}} + stride_xw = {{stride("X", 4)}} + stride_wc_out = {{stride("W", 0)}} + stride_wc_in = {{stride("W", 1)}} + stride_wd = {{stride("W", 2)}} + stride_wh = {{stride("W", 3)}} + stride_ww = {{stride("W", 4)}} + + ndhw = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + idx_y_w = ndhw % OUT_W + ndh = ndhw // OUT_W + idx_y_h = ndh % OUT_H + nd = ndh // OUT_H + idx_y_d = nd % OUT_D + idx_n = nd // OUT_D + idx_y_c = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + +{% if GROUPS == 1 %} + group = 0 + GROUP_IN_C = IN_C + GROUP_OUT_C = OUT_C +{% else %} + group = tl.program_id(2) + GROUP_IN_C = IN_C // GROUPS + GROUP_OUT_C = OUT_C // GROUPS +{% endif %} + + x_base = X + (group * stride_xc * GROUP_IN_C + idx_n * stride_xn)[:, None] + w_base = ( + W + (group * stride_wc_out * GROUP_OUT_C + idx_y_c * stride_wc_out)[None, :] + ) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{% if UNROLL %} +{% for d in range(KERNEL_D) %} +{% for i in range(KERNEL_H) %} +{% for j in range(KERNEL_W) %} + d = {{d}} + i = {{i}} + j = {{j}} + for k in range(0, GROUP_IN_C, BLOCK_K): + """ + + LOOP_BODY_3D + + """ +{% endfor %} +{% endfor %} +{% endfor %} +{% else %} + # Could be simplified, but slightly slower: + # for d in range(KERNEL_D): + # for i in range(KERNEL_H): + # for j in range(KERNEL_W): + # for k in range(0, GROUP_IN_C, BLOCK_K): + BLOCK_K_COUNT = (GROUP_IN_C + BLOCK_K - 1) // BLOCK_K + for dijk in range(KERNEL_D * KERNEL_H * KERNEL_W * BLOCK_K_COUNT): + k = (dijk % BLOCK_K_COUNT) * BLOCK_K + dij = dijk // BLOCK_K_COUNT + j = dij % KERNEL_W + di = dij // KERNEL_W + i = di % KERNEL_H + d = di // KERNEL_H + """ + + LOOP_BODY_3D + + """ +{% endif %} + + mask = ( + (idx_n < BATCH)[:, None] + & (idx_y_d < OUT_D)[:, None] + & (idx_y_h < OUT_H)[:, None] + & (idx_y_w < OUT_W)[:, None] + & (idx_y_c < GROUP_OUT_C)[None, :] + ) + idx_n = idx_n[:, None] + idx_c = idx_y_c[None, :] + group * GROUP_OUT_C + idx_d = idx_y_d[:, None] + idx_h = idx_y_h[:, None] + idx_w = idx_y_w[:, None] + + # inductor generates a suffix + {{store_output(("idx_n", "idx_c", "idx_d", "idx_h", "idx_w"), "acc", "mask")}} +""", +) + +aten_convolution = ExternKernelChoice( + torch.convolution, + "at::convolution", + has_out_variant=False, + op_overload=aten.convolution.default, +) + + +def conv1x1_via_mm(x, w, *, out): + w = torch.squeeze(torch.squeeze(w, -1), -1) + return torch.matmul( + x.permute(0, 2, 3, 1), w.permute(1, 0), out=out.permute(0, 2, 3, 1) + ) + + +aten_conv1x1_via_mm = ExternKernelChoice(conv1x1_via_mm, None) + + +class ConvLayoutParams(TypedDict): + stride: tuple[int, ...] + padding: tuple[int, ...] + dilation: tuple[int, ...] + transposed: bool + output_padding: tuple[int, ...] 
+ groups: int + + +def conv_layout( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: tuple[int, ...], + dilation: tuple[int, ...], + transposed: bool, + output_padding: tuple[int, ...], + groups: int, +) -> ir.Layout: + """Determine output layout for a convolution""" + with V.graph.fake_mode: + output = torch.ops.aten.convolution( + ir.ir_node_to_tensor(x, guard_shape=True), + ir.ir_node_to_tensor(weight, guard_shape=True), + ir.ir_node_to_tensor(bias, guard_shape=True), + V.graph.sizevars.size_hints(stride), # type: ignore[arg-type] + V.graph.sizevars.size_hints(padding), # type: ignore[arg-type] + V.graph.sizevars.size_hints(dilation), # type: ignore[arg-type] + transposed, + V.graph.sizevars.size_hints(output_padding), # type: ignore[arg-type] + groups, + ) + sizes = ir.convert_shape_to_inductor(output.size()) + stride = ir.convert_shape_to_inductor(output.stride()) # type: ignore[assignment] + + return ir.FixedLayout( + x.get_device_or_error(), + x.get_dtype(), + sizes, + stride, + ) + + +def channels_last_order(rank): + order = list(reversed(range(rank))) + order.insert(1, order.pop(-1)) + return order + + +def convert_1x1_conv_to_mm(x, weight, bias): + # special case for 1x1 convolution, which is actually just a matmul + rank = len(weight.get_size()) + for _ in range(rank - 2): + weight = L[aten.squeeze](weight, dim=-1) + weight = L[aten.permute](weight, [1, 0]) + + x = ir.ExternKernel.require_stride_order(x, channels_last_order(rank)) + x_permute = list(range(rank)) + x_permute.append(x_permute.pop(1)) + x = L[aten.permute](x, x_permute) + *sizes, in_chan = x.get_size() + x = L[aten.reshape](x, [sympy_product(sizes), in_chan]) + if bias is None: + result = L[aten.mm](x, weight) + else: + result = L[aten.addmm](bias, x, weight) + result = L[aten.reshape](result, [*sizes, -1]) + result_permute = list(range(rank)) + result_permute.insert(1, result_permute.pop(-1)) + return L[aten.permute](result, result_permute) + + +@register_lowering(aten.convolution) +def convolution( + x: TensorBox, + weight: TensorBox, + bias: Optional[TensorBox], + stride: Sequence[int], + padding: Sequence[int], + dilation: Sequence[int], + transposed: bool, + output_padding: Sequence[int], + groups: int, +): + stride = tuple(stride) + padding = tuple(padding) + dilation = tuple(dilation) + output_padding = tuple(output_padding) + if not isinstance(groups, int): + groups = V.graph.sizevars.evaluate_static_shape(groups) + assert isinstance(groups, int) + + # Need use hint for triton template since the template does not + # work with a dynamic shape. + # + # No need to evaluate_static_shape for dilation and output_padding + # since the template is only used when dilation is 1 and output_padding + # is 0. + stride = tuple(V.graph.sizevars.evaluate_static_shapes(stride)) + padding = tuple(V.graph.sizevars.evaluate_static_shapes(padding)) + + kwargs: ConvLayoutParams = { + "stride": stride, + "padding": padding, + "dilation": dilation, + "transposed": transposed, + "output_padding": output_padding, + "groups": groups, + } + + device_type = ir.get_device_type(x) + + if len(x.get_size()) == len(weight.get_size()) - 1: + # add batch dimension to simplify rest of function + return L[aten.squeeze]( + convolution(L[aten.expand](x, [1, *x.get_size()]), weight, bias, **kwargs), + dim=0, + ) + + out_chan, in_chan, *kernel_shape = V.graph.sizevars.evaluate_static_shapes( + weight.get_size() + ) + + # Always convert conv1D to 2D for Intel GPU. 
+ # Only conv2D can be converted to channel last layout, + # which have much better performance. + if len(x.get_size()) == 3 and len(kernel_shape) == 1 and device_type == "xpu": + kwargs.update( + { + "stride": (1,) + stride, + "padding": (0,) + padding, + "dilation": (1,) + dilation, + "output_padding": (0,) + output_padding, + } + ) + # (N, C, L) -> (N, C, 1, L) + x = L[aten.unsqueeze](x, dim=2) + weight = L[aten.unsqueeze](weight, dim=2) + + return L[aten.squeeze]( + convolution(x, weight, bias, **kwargs), + dim=2, + ) + + ndim = len(kernel_shape) + stride = pad_listlike(stride, ndim) + padding = pad_listlike(padding, ndim) + dilation = pad_listlike(dilation, ndim) + output_padding = pad_listlike(output_padding, ndim) + + def channels_last_conv(): + if V.graph.layout_opt and ndim == 2: + return True + + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + return req_stride_order == ir.NHWC_STRIDE_ORDER + + autotuning_gemm = config.max_autotune or config.max_autotune_gemm + + if ( + (config.conv_1x1_as_mm or (autotuning_gemm and channels_last_conv())) + and is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + and groups == 1 + and V.graph.sizevars.statically_known_gt(sympy_product(x.get_size()), 0) + ): + return convert_1x1_conv_to_mm(x, weight, bias) + + if bias is not None and device_type != "cpu": + # peel off the bias, cudnn is slower with it + result = convolution(x, weight, None, **kwargs) + return L[aten.add]( + result, L[aten.view](bias, [result.get_size()[1]] + ndim * [1]) + ) + + x.realize() + weight.realize() + + # ndim can be 1 for convolution in models such as demucs + # TODO: check if it's beneficial to convert Conv1d to Conv2d and then + # apply channels last. + if V.graph.layout_opt and ndim == 2: + V.graph.num_channels_last_conv += 1 + x = ir.ExternKernel.require_channels_last(x) + # TODO maybe we can convert weights to channels last just once before + # running the model. + weight = ir.ExternKernel.require_channels_last(weight) + layout = conv_layout(x, weight, None, **kwargs) + else: + layout = conv_layout(x, weight, None, **kwargs) + req_stride_order = ir.get_stride_order( + V.graph.sizevars.size_hints(layout.stride) + ) + x = ir.ExternKernel.require_stride_order(x, req_stride_order) + weight = ir.ExternKernel.require_stride_order(weight, req_stride_order) + + ordered_kwargs_for_cpp_kernel = [ + "stride", + "padding", + "dilation", + "transposed", + "output_padding", + "groups", + ] + if bias is None: + args = [x, weight] + kwargs["bias"] = None # type: ignore[typeddict-unknown-key] + ordered_kwargs_for_cpp_kernel.insert(0, "bias") + else: + args = [x, weight, bias] + bias.realize() + bias.freeze_layout() + V.graph.sizevars.evaluate_static_shapes(bias.get_size()) + + choices = [] + if torch._inductor.utils._use_conv_autotune_backend("ATEN"): + choices = [ + aten_convolution.bind( + args, + layout, + ordered_kwargs_for_cpp_kernel, + **kwargs, + ) + ] + + if ( + torch._inductor.utils._use_conv_autotune_backend("TRITON") + and use_triton_template(layout) + # templates only support these: + and is_ones(dilation) + and not transposed + and is_zeros(output_padding) + # there are some odd models where this check fails (e.g. 
shufflenet_v2_x1_0) + and V.graph.sizevars.statically_known_equals(in_chan, x.get_size()[1]) # type: ignore[arg-type] + ): + if ( + is_ones(kernel_shape) + and is_ones(stride) + and is_zeros(padding) + and groups == 1 + ): + choices.append(aten_conv1x1_via_mm.bind(args, layout)) + + conv_configs = V.choices.get_conv_configs(device_type) + + for cfg in conv_configs( + sympy_product([x.get_size()[0], *x.get_size()[2:]]), + out_chan, + in_chan, + **mm_config_kwargs(device_type, _is_large_block_for_cpu), + ): + if ndim == 2: + conv2d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_H=kernel_shape[0], + KERNEL_W=kernel_shape[1], + STRIDE_H=stride[0], + STRIDE_W=stride[1], + PADDING_H=padding[0], + PADDING_W=padding[1], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + elif ndim == 3: + conv3d_template.maybe_append_choice( + choices, + input_nodes=(x, weight), + layout=layout, + KERNEL_D=kernel_shape[0], + KERNEL_H=kernel_shape[1], + KERNEL_W=kernel_shape[2], + STRIDE_D=stride[0], + STRIDE_H=stride[1], + STRIDE_W=stride[2], + PADDING_D=padding[0], + PADDING_H=padding[1], + PADDING_W=padding[2], + GROUPS=groups, + # TODO(jansel): try unroll for bigger kernels once fixed: + # https://github.com/triton-lang/triton/issues/1254 + UNROLL=is_ones(kernel_shape), + ALLOW_TF32=torch.backends.cudnn.allow_tf32, + num_stages=cfg.num_stages, + num_warps=cfg.num_warps, + **cfg.kwargs, + ) + if use_ck_conv_template(layout): + CKGroupedConvFwdTemplate.add_ck_conv_choices( + choices, + layout, + input_nodes=(x, weight) + ((bias,) if bias is not None else tuple()), + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=ndim, + ) + return autotune_select_algorithm("convolution", choices, args, layout) + + +@register_lowering(aten._convolution) +def _convolution( + x, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, +): + return convolution( + x, weight, bias, stride, padding, dilation, transposed, output_padding, groups + ) + + +def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): + assert fx_node.target == torch.ops.aten.convolution.default + if V.graph.layout_opt: + return args, kwargs + else: + return constrain_to_fx_strides(fx_node, *args, **kwargs) + + +add_layout_constraint(aten.convolution, constrain_conv_to_fx_strides) diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0c085cff518127bf3103f3b669c84285cfd203cb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_attention.py @@ -0,0 +1,2763 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel""" + +import copy +import logging +import math +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Optional, Union + +import sympy + +import torch +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import tree_map +from torch.utils._sympy.numbers import int_oo +from 
torch.utils._sympy.value_ranges import ValueRanges + +from ..ir import ( + Buffer, + ComputedBuffer, + ExternKernel, + FixedLayout, + FlexibleLayout, + get_fill_order, + InputBuffer, + IRNode, + MutationLayoutSHOULDREMOVE, + Scatter, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import ( + _full, + check_and_broadcast_indices, + empty, + empty_strided, + expand, + index_output_size_and_inner_fn, + lowerings, + register_lowering, + to_dtype, +) +from ..select_algorithm import ( + autotune_select_algorithm, + realize_inputs, + SymbolicGridFn, + TritonTemplate, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten +Expr = sympy.Expr + + +def construct_strides( + sizes: Sequence[int], + fill_order: Sequence[int], +) -> Sequence[int]: + """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" + # Initialize strides + assert len(sizes) == len(fill_order), ( + "Length of sizes must match the length of the fill order" + ) + strides = [0] * len(sizes) + + # Start with stride 1 for the innermost dimension + current_stride = 1 + + # Iterate through the fill order populating strides + for dim in fill_order: + strides[dim] = current_stride + current_stride *= sizes[dim] + + return strides + + +def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]): + """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp + + Args: + size: The size of the output tensor + orig_strides: The strides of the input tensor + Returns: + List[int]: Dense non-overlapping strides that preserve the input tensor's layout permutation. + The returned strides follow the same stride propagation rules as TensorIterator. This matches + The behavior of empty_like() + """ + fill_order = get_fill_order(orig_strides, V.graph.sizevars.shape_env) + return construct_strides(size, fill_order) + + +@SymbolicGridFn +def flex_attention_grid(batch_size, q_heads, num_queries, d_model, meta, *, cdiv): + """How is this kernel parallelized? + We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) + Each block is responsible for iterating over blocks of keys and values calculating + the final attention output. 
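+
+    A worked example with illustrative numbers (not taken from this file): for
+    batch_size=2, q_heads=8, num_queries=1024 and BLOCK_M=128, the value
+    returned below is (cdiv(1024, 128), 2 * 8, 1) == (8, 16, 1): 8 query blocks
+    along axis 0 and one program per (batch, head) pair along axis 1.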
+ """ + return (cdiv(num_queries, meta["BLOCK_M"]), batch_size * q_heads, 1) + + +def create_placeholder( + name: str, + dtype: torch.dtype, + device: torch.device, + size: Optional[list[int]] = None, +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer( + name=name, + layout=FixedLayout( + device, + dtype, + size if size else [], + FlexibleLayout.contiguous_strides(size) if size else [], + ), + ) + return TensorBox.create(input_buffer) + + +def maybe_realize(args: list[Optional[IRNode]]): + """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" + return tree_map( + lambda x: ( + realize_inputs(x) + if x is not None and not isinstance(x, sympy.Symbol) + else x + ), + args, + ) + + +def get_float32_precision(): + if ( + torch.get_float32_matmul_precision() == "highest" + or torch.version.hip + or torch.mtia.is_available() + ): + return "'ieee'" + else: + return "'tf32'" + + +def zeros_and_scatter_lowering(shape: list[int], indices, values): + # Always accumulate into fp32 then cast + grad = _full(0, values.get_device(), torch.float32, shape) + assert isinstance(grad, TensorBox) + grad.realize() + x_size = grad.get_size() + values = to_dtype(values, grad.get_dtype()) + indices_loaders = [i.make_loader() if i is not None else None for i in indices] + indices, tensor_indices = check_and_broadcast_indices(indices, grad.get_device()) + # We can use the first one since they are all required to be the same size + tensor_size = list(indices[tensor_indices[0]].get_size()) + indexed_size = [x_size[i] for i in range(len(indices))] + + expected_vals_size, inner_fn = index_output_size_and_inner_fn( + x_size, + indices, + tensor_indices, + tensor_size, + indices_loaders, + indexed_size, + None, + check=True, + ) + + values = expand(values, expected_vals_size) + device = grad.get_device() + assert device is not None + scatter = Scatter( + device=device, + dtype=grad.get_dtype(), + inner_fn=values.make_loader(), + ranges=expected_vals_size, # iter_ranges, + output_indexer=inner_fn, + scatter_mode="atomic_add", + ) + + buffer = ComputedBuffer( + name=grad.data.data.name, # type: ignore[attr-defined] + layout=MutationLayoutSHOULDREMOVE(grad), + data=scatter, + ) + return buffer + + +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] + + +def build_subgraph_module_buffer( + args: list[TensorBox], graph_module: torch.fx.GraphModule +) -> SubgraphResults: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that are passed into the subgraph. Contains both fixed and lifted inputs. 
+ subgraph: The Subgraph ir for which to produce the output node + """ + from ..subgraph_lowering import PointwiseSubgraphLowering + + pw_subgraph = PointwiseSubgraphLowering( + graph_module, + root_graph_lowering=V.graph, + allowed_mutations=OrderedSet([torch.ops.flex_lib.zeros_and_scatter.default]), + additional_lowerings={ + torch.ops.flex_lib.zeros_and_scatter.default: zeros_and_scatter_lowering + }, + ) + with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type] + pw_subgraph.run(*args) + + # Since we are allowing mutations/buffer creation, we need to register any fresh buffers + # creating during the pointwise subgraph lowering + if len(pw_subgraph.buffers) > 0: + for buffer in pw_subgraph.buffers: + V.graph.register_buffer(buffer) + + def convert_output_node_to_buffer(output_buffer) -> Optional[ComputedBuffer]: + if output_buffer is None: + return None + if isinstance(output_buffer, ComputedBuffer): + # These nodes are coming from the output of zeros_and_scatter + return output_buffer + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + return tree_map(convert_output_node_to_buffer, pw_subgraph.graph_outputs) + + +def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults: + return build_subgraph_module_buffer(args, subgraph.graph_module) + + +def get_fwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, mask_graph_buffer: SubgraphResults +) -> list[Optional[ComputedBuffer]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + return [*subgraph_buffer, *mask_graph_buffer] + + +# Inner Triton functions shared by flex_attention & split-k decoding kernels. 
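+# A worked example of the block-jump arithmetic in get_offset_for_next_block
+# below, using illustrative numbers that are not taken from this file: with
+# SPARSE_BLOCK=128 and BLOCK=32, SPARSE_BLOCK_MULTIPLE is 4, so four inner-loop
+# iterations cover one sparse block.  If col_indices = [0, 3, ...], iterations
+# 0-2 simply return BLOCK=32; iteration 3 (the last one inside sparse block 0)
+# returns (3 - 0) * 128 - (4 - 1) * 32 = 288, which jumps the K/V pointers from
+# offset 96 directly to the start of sparse block 3 at offset 384.  A rough
+# pure-Python sketch of the same arithmetic (for reading only, the real code is
+# the Triton source string below):
+#
+#   def _offset_for_next_block(loop_iter, col_indices, sparse_block, multiple, block):
+#       if (loop_iter + 1) % multiple != 0:
+#           return block
+#       cur_idx = loop_iter // multiple
+#       return (col_indices[cur_idx + 1] - col_indices[cur_idx]) * sparse_block \
+#           - (multiple - 1) * block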
+compute_next_offset_func = r""" +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset +""" + +get_bounded_indices_func = r""" +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices +""" + + +load_checked_block = r""" +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") +""" + +load_checked_2d = r""" +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr) +""" + +compute_flex_attention = r""" +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. 
+ # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) // HQ + off_hq = tl.program_id(1) % HQ + + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN*HQ*ZQ, QK_HEAD_DIM], + strides=[QK_HEAD_DIM, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + desc_k = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN*ZKV*HQ, V_HEAD_DIM], + strides=[V_HEAD_DIM, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
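+    # Illustrative example (numbers not from this kernel): with
+    # SPARSE_Q_BLOCK_SIZE=128 and BLOCK_M=64, SPARSE_Q_MULTIPLE is 2, so the
+    # query program with q_start=5 reads block-mask metadata row 5 // 2 = 2 of
+    # its (sparse_idx_z, sparse_idx_hq) slice via the offsets computed below.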
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. 
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) // HQ + idx_hq = tl.program_id(1) % HQ + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = tl.program_id(1) + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) + """ + + +compute_forward_inner = r""" +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i + +""" + + +compute_forward_block_mn = r""" +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + start_n, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( # load in row major + desc_k, + [start_n.to(tl.int32) , kv_start], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start.to(tl.int32) + start_n.to(tl.int32),0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +""" + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, + source=compute_flex_attention + + compute_forward_inner + + compute_next_offset_func + + compute_forward_block_mn + + load_checked_block + + get_bounded_indices_func, +) + + +def _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + """Decide which kernel to use, return true if use flex decoding kernel. + Note: + Since the number of splits is calculated based of the the number of batch and head dims + we need to ensure that the batch and head dims are statically known. Otherwise we just + use the main flex_attention kernel. 
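+
+    Summarizing the checks below: the decoding kernel is chosen only when the
+    query sequence length is statically known to be < 128 (and > 0), the batch
+    and head dims are static ints, the block mask carries a compatible number
+    of head entries, and FORCE_USE_FLEX_ATTENTION is not set in kernel_options.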
+ """ + force_flex = kernel_options.get("FORCE_USE_FLEX_ATTENTION", False) + short_query_length = V.graph.sizevars.evaluate_expr( + sympy.Lt(query.get_size()[-2], 128) + ) + non_zero_length = V.graph.sizevars.evaluate_expr(sympy.Gt(query.get_size()[-2], 0)) + static_batch = isinstance(query.get_size()[0], (int, sympy.Integer)) + static_num_heads = isinstance(query.get_size()[1], (int, sympy.Integer)) + if enable_gqa: + # in the current flex decoding triton kernel, grouped query heads for the + # same kv head are handled by the same block. So it's hard to support different + # kv num blocks for grouped query heads. We just fall back to main flex_attention + # kernel where each query head is handled by a separate block. + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Eq(kv_indices.get_size()[1], 1) + ) + else: + valid_block_mask_num_heads = V.graph.sizevars.evaluate_expr( + sympy.Or( + sympy.Eq(kv_indices.get_size()[1], 1), + sympy.Eq(kv_indices.get_size()[1], query.get_size()[1]), + ) + ) + return ( + not force_flex + and short_query_length + and static_batch + and static_num_heads + and non_zero_length + and valid_block_mask_num_heads + ) + + +_h100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (32, 64, 4, 3), + (torch.float32, 256): (32, 32, 4, 3), + (torch.bfloat16, 64): (128, 128, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (64, 32, 4, 3), + (torch.float16, 64): (128, 128, 4, 3), + (torch.float16, 128): (128, 128, 8, 3), + (torch.float16, 256): (64, 32, 4, 3), +} + +_a100_default_config = { + (torch.float32, 64): (128, 32, 4, 3), + (torch.float32, 128): (128, 32, 4, 3), + (torch.float32, 256): (64, 16, 4, 3), + (torch.bfloat16, 64): (128, 64, 4, 3), + (torch.bfloat16, 128): (128, 64, 8, 3), + (torch.bfloat16, 256): (32, 64, 4, 3), + (torch.float16, 64): (128, 64, 4, 3), + (torch.float16, 128): (128, 64, 8, 3), + (torch.float16, 256): (32, 64, 4, 3), +} + +_rocm_default_config = { + (torch.float32, 64): (128, 32, 4, 1), + (torch.float32, 128): (128, 32, 4, 1), + (torch.float32, 256): (64, 16, 4, 1), + (torch.bfloat16, 64): (128, 64, 8, 1), + (torch.bfloat16, 128): (128, 64, 8, 1), + (torch.bfloat16, 256): (32, 64, 8, 1), + (torch.float16, 64): (128, 64, 8, 1), + (torch.float16, 128): (128, 64, 8, 1), + (torch.float16, 256): (32, 64, 4, 1), +} + + +class Mode(Enum): + fwd = auto() + bwd = auto() + + +def create_num_blocks_fake_generator(sparse_indices): + # The idea here is that we need to create a real tensor with real data + # that's representative for benchmarking. + # For example, returning all zeros for the `kv_num_blocks` input would mean + # that we are computing 0 blocks for each row, which would provide bogus + # autotuning results. + # + # In this case, we choose to use min(16, max_block) blocks, because I + # (Horace) think it'll probably result in pretty representative performance. + # If it's too short then prefetching won't help. If it's too long then + # autotuning will take longer for no good reason. 
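+    # For example (hypothetical shapes): if `sparse_indices` has 8 columns and
+    # `x` is a [2, 4, 16] KV_NUM_BLKS buffer, the helper below returns a
+    # [2, 4, 16] tensor filled with 8, i.e. every row pretends to visit 8 KV
+    # blocks while this choice is being benchmarked.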
+ def create_num_blocks_fake(x) -> torch.Tensor: + num_blocks_for_autotuning = V.graph.sizevars.size_hint(sparse_indices.shape[-1]) + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + return torch.full( + size, + num_blocks_for_autotuning, + dtype=x.get_dtype(), + device=x.get_device(), + ) + + return create_num_blocks_fake + + +def create_indices_fake(x) -> torch.Tensor: + size = [V.graph.sizevars.size_hint(i) for i in x.get_size()] + indices = torch.arange(0, size[-1], dtype=x.get_dtype(), device=x.get_device()) + indices = indices.expand(size).contiguous() + return indices + + +from torch._inductor.kernel.flex_decoding import create_flex_decoding_kernel + +from ..codegen.cpp_flex_attention_template import CppFlexAttentionTemplate + + +def check_cpu_supported(): + import os + import sys + + requires_avx2_on_cpu = ( + torch.cpu._is_avx2_supported() and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ) + supported = ( + requires_avx2_on_cpu + and not torch.xpu.is_available() + and not sys.platform == "darwin" + ) + return supported + + +def contiguous_last_dim(x): + """Ensure that realized IR node has a contiguous stride in the last dimension.""" + strides = x.maybe_get_stride() + if strides and strides[-1] != 1: + contiguous_stride_order = list(reversed(range(len(x.get_size())))) + return ExternKernel.require_stride_order(x, contiguous_stride_order) + return x + + +def lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + if kernel_options["OUTPUT_LOGSUMEXP"]: + raise NotImplementedError( + "torch.compile on CPU only supports inference and `return_lse` is not supported yet." + ) + if not check_cpu_supported(): + raise NotImplementedError( + "torch.compile on current platform is not supported for CPU." + ) + + fake_buffers: list[Buffer] = [] # noqa: F821 + + # [Note] Handle the case where the split sizes are not statically known. + # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. + # We use symbols to represent them during the compilation here. + # They'll be replaced by the string "cur_qSplitSize" and "cur_kvSplitSize" in + # the modification function of the CppFlexAttentionTemplate class. + cur_qSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + cur_kvSplitSize = V.graph.sizevars.shape_env.create_unbacked_symint().node.expr + shape_env = V.graph.sizevars.shape_env + + # We don't know the concrete value of cur_qSplitSize and cur_kvSplitSize during the compilation. + # Mark symbols > 1 to ensure broadcasting is always applied. + # This avoids treating them as equal when `eq(var, 1)` is evaluated in `broadcast_symbolic_shapes`. 
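+    # Concretely, the ranges set below make a guard like Eq(cur_qSplitSize, 1)
+    # statically false, so the [cur_qSplitSize, 1] and [1, cur_kvSplitSize]
+    # placeholders created further down always broadcast to a full
+    # [cur_qSplitSize, cur_kvSplitSize] tile.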
+ shape_env.var_to_range[cur_qSplitSize] = ValueRanges(2, int_oo) + shape_env.var_to_range[cur_kvSplitSize] = ValueRanges(2, int_oo) + + score_dtype = torch.float + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + if subgraph_buffer is not None: + if isinstance(subgraph_buffer, list): + for _buf in subgraph_buffer: + if _buf is not None: + _buf.freeze_layout() + else: + subgraph_buffer.freeze_layout() + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device(), size) + for name, dtype, size in [ + ("score", score_dtype, [cur_qSplitSize, cur_kvSplitSize]), + ("b", torch.int64, []), + ("h", torch.int64, []), + ("q_idx", torch.int64, [cur_qSplitSize, 1]), + ("kv_idx", torch.int64, [1, cur_kvSplitSize]), + ] + ] + + # The original mask_graph works on a scalar and only includes + # the logic of calculating the mask value. + # We need to add the logic of applying the mark to the qk_data tensor + # into the graph for the later codegen of this part. + # Example: + # mask_graph: + # def mask_fn(b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # return mask + # The converted_mask_graph should be: + # def converted_mask_fn(qk_data, b, h, q_idx, kv_idx): + # mask = q_idx >= kv_idx + # qk_data = torch.where(mask, qk_data, torch.full_like(qk_data, -float("inf"))) + # return qk_data + def convert_mask_graph_module(mask_graph): + gm = copy.deepcopy(mask_graph.graph_module) + graph = gm.graph + # Add qk_data as the first input + with graph.inserting_before(next(iter(graph.nodes))): + qk_data_node = graph.placeholder("qk_data") + + # Find the node that returns the mask + output_node = None + for node in graph.nodes: + if node.op == "output": + output_node = node + break + + # Get the mask node + assert output_node is not None + mask_node = output_node.args[0] + + size_node = [cur_qSplitSize, cur_kvSplitSize] + # Create a new node for torch.full + with graph.inserting_after(mask_node): + full_node = graph.call_function( + torch.full, + args=(size_node, -float("inf")), + kwargs={"dtype": score_dtype}, + ) + + # Create a new node for torch.where + with graph.inserting_after(full_node): + where_node = graph.call_function( + torch.ops.aten.where, args=(mask_node, qk_data_node, full_node) + ) + + # Update the output node to return the result of torch.where + output_node.args = (where_node,) + + graph.lint() + converted = torch.fx.GraphModule(gm, graph) + return converted + + converted_mask_graph_module = convert_mask_graph_module(mask_graph) + + mask_graph_buffer = build_subgraph_module_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), + converted_mask_graph_module, + ) + + # Clear the pending fresh unbacked symbols that are created for cur_qSplitSize and cur_kvSplitSize in the current kernel. 
+ pending = V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols + V.graph.sizevars.shape_env.pending_fresh_unbacked_symbols = [ + x for x in pending if x not in (cur_qSplitSize, cur_kvSplitSize) + ] + + buffer_list = ( + placeholder_inps + + list(score_mod_other_buffers) + + mask_graph_placeholder_inps + + list(mask_mod_other_buffers) + ) + for item in buffer_list: + if isinstance(item, TensorBox): + fake_buffers.append(item.data.data) # type: ignore[attr-defined] + + # CPU kernel requires last dim to be contiguous + query, key, value = map(contiguous_last_dim, [query, key, value]) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + if len(OrderedSet([query.get_name(), key.get_name(), value.get_name()])) != 3: + raise NotImplementedError( + "Unsupported for now if query, key, value are the same buffer." + ) + if query.get_dtype() not in [torch.float, torch.bfloat16, torch.float16]: + raise NotImplementedError( + "`torch.float` , `torch.float16` and `torch.bfloat16` are supported in FlexAttention for CPU device. " + f"Found input tensors are `{query.get_dtype()}`." + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + B = Bq + + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, query.get_stride()) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + _choices: list[Any] = [] + input_nodes = [query, key, value, kv_num_blocks, kv_indices] + if not full_kv_num_blocks: + no_full_kv_block = True + else: + no_full_kv_block = False + input_nodes += [full_kv_num_blocks] + input_nodes += [full_kv_indices] + has_other_buffer = False + kernel_input_name_to_buffer = {} + if score_mod_other_buffers or mask_mod_other_buffers: + has_other_buffer = True + + for prefix, buffers in [ + ("score_others", score_mod_other_buffers), + ("mask_others", mask_mod_other_buffers), + ]: + kernel_input_name_to_buffer.update( + {f"{prefix}_{i}": buf for i, buf in enumerate(buffers)} + ) + input_nodes += [ + value + for value in kernel_input_name_to_buffer.values() + if not isinstance(value, sympy.Symbol) + ] + + skip_mask_score = kernel_options.get("SKIP_MASK_SCORE", False) + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) + ), ( + "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." + ) + assert V.graph.sizevars.evaluate_expr( + sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) + ), ( + "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." 
+ ) + CppFlexAttentionTemplate.add_choices( + choices=_choices, + input_nodes=input_nodes, + layout=layout, + scale=scale, + score_mod=None if skip_mask_score else subgraph_buffer, + mask_mod=None if skip_mask_score else mask_graph_buffer, + kv_block_size=SPARSE_KV_BLOCK_SIZE, + q_block_size=SPARSE_Q_BLOCK_SIZE, + has_other_buffer=has_other_buffer, + no_full_kv_block=no_full_kv_block, + fake_buffers=fake_buffers, + len_score_other=len(score_mod_other_buffers), + len_mask_other=len(mask_mod_other_buffers), + kernel_input_name_to_buffer=kernel_input_name_to_buffer, + block_vars=(cur_qSplitSize, cur_kvSplitSize), + ) + inputs_for_autotuning = [ + query, + key, + value, + ] + res = autotune_select_algorithm( + "flex_attention", + _choices, + inputs_for_autotuning, + layout, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + res.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + res.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (res,) + + +def is_power_of_2(n): + return n != 0 and ((n & (n - 1)) == 0) + + +def next_power_of_two(n): + if n <= 0: + return 1 + return 2 ** math.ceil(math.log2(n)) + + +def set_head_dim_values( + kernel_options: dict[str, Any], qk_head_dim, v_head_dim, graph_sizevars +): + """ + Mutates kernel options, adding head dimension calculations. + + Args: + kernel_options: Dictionary to populate with options + qk_head_dim: Query/Key head dimension + v_head_dim: Value head dimension + graph_sizevars: Graph size variables object with evaluate_static_shape method + + """ + # QK dimensions + qk_head_dim_static = graph_sizevars.evaluate_static_shape(qk_head_dim) + kernel_options.setdefault("QK_HEAD_DIM", qk_head_dim_static) + kernel_options.setdefault( + "QK_HEAD_DIM_ROUNDED", next_power_of_two(qk_head_dim_static) + ) + + # V dimensions + v_head_dim_static = graph_sizevars.evaluate_static_shape(v_head_dim) + kernel_options.setdefault("V_HEAD_DIM", v_head_dim_static) + kernel_options.setdefault( + "V_HEAD_DIM_ROUNDED", next_power_of_two(v_head_dim_static) + ) + + # Safety flag + kernel_options.setdefault( + "SAFE_HEAD_DIM", + is_power_of_2(qk_head_dim_static) and is_power_of_2(v_head_dim_static), + ) + + +@register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) +def flex_attention( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + if query.get_device().type == "cpu": + return lower_cpu( + query, + key, + value, + subgraph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + # below is cuda path if device is not cpu + # tl.dot does not support embedding size less than 16 + small_dqk = V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-1], 16)) + small_dv = V.graph.sizevars.evaluate_expr(sympy.Lt(value.get_size()[-1], 16)) + if small_dqk or small_dv: + raise NotImplementedError( + f"NYI: embedding dimension of the query, key, and value must be " + f"at least 16 but got E={query.get_size()[-1]} and Ev={value.get_size()[-1]}" + ) + + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype 
in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + placeholder_inps + list(score_mod_other_buffers), subgraph + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. + kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + enable_gqa = V.graph.sizevars.evaluate_expr( + sympy.Ne(query.get_size()[1], key.get_size()[1]), + ) + if _use_flex_decoding(query, kv_indices, kernel_options, enable_gqa): + return create_flex_decoding_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + ( + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) + mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) + + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), ( + "Query length must be greater than 0" + ) + assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), ( + "Key length must be greater than 0" + ) + + B = Bq + + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + # NB it is okay that the v_head_dim is different + # We are using these to match fill order of the output. + q_strides = query.get_stride() + # Construct output layout with strides matching the query. + out_size = [B, Hq, seq_len_q, v_head_dim] + out_strides = infer_dense_strides(out_size, q_strides) + + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + [B, Hq, seq_len_q, v_head_dim], + stride=[sympy.sympify(s) for s in out_strides], + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = [B, Hq, seq_len_q] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA broadcast factor. + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. 
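+ # A "full" block is one whose elements are all unmasked, so only score_mod has to run on it; partial blocks additionally need mask_mod applied.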
+ # full_kv_num_blocks is None if partial blocks are not computed + has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices = ( + empty(0, device=query.get_device()) for _ in range(2) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_fwd_configs(head_dim, dtype) + + # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + original_kernel_options = kernel_options.copy() + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + ): + if len(configs) == 1: + raise ValueError( + f"Q and KV block size must be divisible by BLOCK_M and BLOCK_N. We " + f"got Q_BLOCK_SIZE={SPARSE_Q_BLOCK_SIZE} and KV_BLOCK_SIZE={SPARSE_KV_BLOCK_SIZE}." + ) + continue + + cur_kernel_options = original_kernel_options.copy() + # Performance tuning + # Triton parameters + # Remove prefix for forward kernels options and delete backward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Disabling TMA by default, only explicit kernel_options supported for now + cur_kernel_options.setdefault("USE_TMA", False) + + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + error = flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout, + subgraphs=[ + subgraph_buffer, + mask_graph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + if error is not None and len(configs) == 1: + raise error + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + input_gen_fns = { + 4: 
create_num_blocks_fake_generator(kv_indices), + 5: create_indices_fake, + 6: create_num_blocks_fake_generator(full_kv_indices), + 7: create_indices_fake, + } + + out = autotune_select_algorithm( + "flex_attention", + choices, + # Need to filter out symbols since there is an invariant + # that all input_nodes are of type IRNode + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + out.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + out.data.data.subgraph_outs = get_fwd_subgraph_outputs( + subgraph_buffer, mask_graph_buffer + ) + + return (out, logsumexp) + + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid( + batch_size, q_heads, num_queries, d_model, kv_heads, num_key_value, meta +): + """How is this kernel parallelized? + Currently this is only parallelizing over batch* kv_heads, but we can, and want to + parallelize over ceil_div(q_heads//kv_heads * num_key_value, key_value_block_size). + To do this will either require atomic updates to some grad values or to have a two pass kernel design. + """ + import triton + + return ( + triton.cdiv(num_queries, meta["BLOCK_M2"]) * (q_heads // kv_heads) + + triton.cdiv(num_key_value, meta["BLOCK_N1"]), + 1, + batch_size * kv_heads, + ) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. 
+ # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_hz = tl.program_id(2) + off_zq = off_hz // HKV # q batch idx + off_hkv = off_hz % HKV # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
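+ # LSE and DELTA are laid out contiguously as [Z, HQ, Q_LEN], so a flat (z * HQ + hq) * Q_LEN offset suffices for them; Q/DO/DQ are addressed through their explicit strides instead.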
+ q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. + q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. 
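+ # The same K/V tiles are reused for every query head in this KV group; the GQA_SHARED_HEADS loop below accumulates each head's dk/dv contribution into the same tiles.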
+ k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. 
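+ # dV is stored directly through the DV pointer; dK is scaled by SM_SCALE and written via store_output (see the note on DK at the top of this template).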
+ dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. 
+ offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. 
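+ # dP = dO @ V^T and dS = P * (dP - Delta), where Delta = rowsum(OUT * dO) was precomputed into DELTA (see the header above).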
+ # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = offs_m2[:, None] < Q_LEN and offs_n2[None, :] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. 
+ offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = offs_m1[None, :] < Q_LEN and offs_n1[:, None] < KV_LEN + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv + """ + + compute_next_offset_func + + get_bounded_indices_func + + load_checked_2d, +) + + +def validate_joint_graph(joint_graph: torch.fx.Graph): + """We do some pre lowering graph checks in order to raise nicer error messages""" + for node in joint_graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.flex_lib.zeros_and_scatter.default + ): + for user in node.users: + if user.op != "output": + raise NotImplementedError( + "Using multiple indexing operations on the same tensor that requires gradients " + "in a score_mod function is not currently supported. " + "This typically happens when indexing the same tensor multiple times, like:\n\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias[kv_idx] # bias used twice!\n\n" + "A valid workaround is to clone() the tensors that will be indexed multiple times. For example:\n\n" + " bias1 = bias.clone()\n" + " def score_mod(score, b, h, q_idx, kv_idx):\n" + " return score + bias[q_idx] + bias1[kv_idx]\n\n" + "Note that this solution will use additional memory." 
+ ) + return + + +@dataclass(frozen=True) +class JointOutputResult: + """Results from processing joint outputs.""" + + grad_input: ComputedBuffer + captured_grads_compute: list[ComputedBuffer] + captured_grads: list[Optional[TensorBox]] + mutated_grads: list[TensorBox] + + +def process_joint_outputs( + all_joint_outputs: SubgraphResults, num_placeholders: int +) -> JointOutputResult: + """Process joint outputs and extract various buffers needed for lowering + + Args: + all_joint_outputs: List of all the outputs from build_subgraphs + num_placeholders: The number of placeholder inputs, used to skip over unused backward compute buffers + + Returns: + JointOutputResult containing processed buffers and gradients + """ + assert isinstance(all_joint_outputs, list) + assert all_joint_outputs[0] is not None, ( + "joint_subgraph_buffer is None - this is a bug!" + ) + + joint_buffer = all_joint_outputs[0] + other_grads = all_joint_outputs[num_placeholders - 1 :] + + # outer_grads has the structure: Len(other_buffer_grads) if buffer doesn't require grad than it will be None + # We only grab the buffers that require grad for inlining into kernel + grads_compute = [buf for buf in other_grads if buf is not None] + + def get_out(buf): + if buf is None: + return None + assert isinstance(buf, ComputedBuffer) + assert buf.name is not None + return TensorBox.create(V.graph.get_buffer(buf.name)) + + grads_out = [get_out(x) for x in other_grads] + mutated_grads = [buf for buf in grads_out if buf is not None] + + return JointOutputResult( + grad_input=joint_buffer, + captured_grads_compute=grads_compute, + captured_grads=grads_out, + mutated_grads=mutated_grads, + ) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + ( + _, # q_length + _, # kv_length + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + SPARSE_Q_BLOCK_SIZE, + SPARSE_KV_BLOCK_SIZE, + mask_graph, + ) = block_mask + + ( + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ) = maybe_realize( + [ + query, + key, + value, + grad_out, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + q_num_blocks, + q_indices, + full_q_num_blocks, + full_q_indices, + ] + ) + + device = query.get_device() + dtype = query.get_dtype() + Bq, Hq, seq_len_q, qk_head_dim = query.get_size() + Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() + + assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" + ) + + kernel_options = dict(kernel_options) + # Mark symbols in custom kernel options as static shapes and add guards. 
+ kernel_options = { + k: V.graph.sizevars.evaluate_static_shape(v) + if isinstance(v, sympy.Symbol) + else v + for k, v in kernel_options.items() + } + kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) + if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: + kernel_options.setdefault("IS_DIVISIBLE", False) + else: + kernel_options.setdefault("IS_DIVISIBLE", True) + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) + for name, dtype in [ + ("score", dtype), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + fw_subgraph_buffer = build_subgraph_buffer( + fwd_placeholder_inps + list(score_mod_other_buffers), fw_graph + ) + + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("grad_score_mod", dtype, device) + ] + # Sometimes we have weird unused nodes here + joint_graph.graph_module.graph.eliminate_dead_code() + + # It is hard to raise nice errors for some joint graphs during subgraph lowering + # This lets us do some checks before attempting to lower + validate_joint_graph(joint_graph.graph_module.graph) + + all_joint_outputs = build_subgraph_buffer( + joint_placeholder_inps + list(score_mod_other_buffers), + joint_graph, + ) + + joint_outputs = process_joint_outputs( + all_joint_outputs, len(joint_placeholder_inps) + ) + + mask_graph_placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + mask_graph_buffer = build_subgraph_buffer( + mask_graph_placeholder_inps + list(mask_mod_other_buffers), mask_graph + ) + + mask_graph_buffer = mask_graph_buffer + + # Construct layout with stride order matching K + key_size = [Bq, Hkv, seq_len_kv, qk_head_dim] + key_strides = infer_dense_strides(key_size, key.get_stride()) + + layout_broadcasted_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key_size, + stride=[sympy.sympify(s) for s in key_strides], + ) + + # Create delta which will is needed for the bwd's kernel + grad_lse_exp2 = lowerings[aten.mul](grad_logsumexp, 1 / math.log(2)) + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + delta = lowerings[aten.sub](delta, grad_lse_exp2) + delta = ExternKernel.require_contiguous(delta) + + grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta]) + + # # see NOTE:[TritonTemplates with multiple outputs] + query_size = [Bq, Hq, seq_len_q, qk_head_dim] + grad_query_strides = infer_dense_strides(query_size, query.get_stride()) + grad_query = empty_strided( + query_size, + stride=[sympy.sympify(s) for s in grad_query_strides], + dtype=query.get_dtype(), + device=query.get_device(), + ) + + # Construct output layout with stride order matching value + value_size = [Bq, Hkv, seq_len_kv, v_head_dim] + value_strides = infer_dense_strides(value_size, value.get_stride()) + + broadcasted_grad_value = empty_strided( + value_size, + stride=[sympy.sympify(s) for s in value_strides], + dtype=value.get_dtype(), + device=value.get_device(), + ) + + kernel_options.setdefault("SM_SCALE", scale) + + # Determine GQA factor + gqa_shared_heads = Hq // Hkv + kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) + + # Inside of Triton kernel, only apply partial masking if partial blocks are computed. + # full_kv_num_blocks is torch.zeros([1, 1, 1]) if partial blocks are not computed. 
+ has_full_blocks = full_kv_num_blocks is not None + kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks) + if not has_full_blocks: + full_kv_num_blocks, full_kv_indices, full_q_num_blocks, full_q_indices = ( + empty(0, device=query.get_device()) for _ in range(4) + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + choices: list[Any] = [] + + dtype = query.get_dtype() + head_dim = V.graph.sizevars.evaluate_static_shape(query.get_size()[-1]) + configs = V.choices.get_flex_attention_bwd_configs(head_dim, dtype) + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + original_kernel_options = kernel_options.copy() + for conf in configs: + if ( + SPARSE_KV_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_m != 0 + or SPARSE_KV_BLOCK_SIZE % conf.block_n != 0 + or SPARSE_Q_BLOCK_SIZE % conf.block_n != 0 + ): + continue + + # Performance tuning + # Triton heuristics + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for backward kernels options and delete forward kernel options. + for k in list(cur_kernel_options.keys()): + if k.startswith("bwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("fwd_"): + cur_kernel_options.pop(k) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m) + cur_kernel_options.setdefault("BLOCK_N1", conf.block_n) + cur_kernel_options.setdefault("BLOCK_M2", conf.block_n) + cur_kernel_options.setdefault("BLOCK_N2", conf.block_m) + + # Blocksparse options + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + + # ROCm specific kernargs + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ], + layout=layout_broadcasted_k, # We use store_output only for grad_key + subgraphs=[ + fw_subgraph_buffer, + joint_outputs.grad_input, + mask_graph_buffer, + joint_outputs.captured_grads_compute, + ], + mutated_inputs=[ + grad_query, + broadcasted_grad_value, + *joint_outputs.mutated_grads, + ], + call_sizes=query.get_size() + key.get_size()[1:3], + **cur_kernel_options, + ) + inputs_for_autotuning = ( + [ + query, + key, + value, + logsumexp, + delta, + grad_out, + grad_query, + broadcasted_grad_value, + kv_num_blocks, + kv_indices, + q_num_blocks, + q_indices, + full_kv_num_blocks, + full_kv_indices, + full_q_num_blocks, + full_q_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + + joint_outputs.mutated_grads + ) + input_gen_fns = { + 8: 
create_num_blocks_fake_generator(kv_indices), # kv_num_blocks + 9: create_indices_fake, + 10: create_num_blocks_fake_generator(q_indices), # q_num_blocks + 11: create_indices_fake, + 12: create_num_blocks_fake_generator(full_kv_indices), # full_kv_num_blocks + 13: create_indices_fake, + 14: create_num_blocks_fake_generator(full_q_indices), # full_q_num_blocks + 15: create_indices_fake, + } + + broadcasted_grad_key = autotune_select_algorithm( + "flex_attention_backward", + choices, + [x for x in inputs_for_autotuning if isinstance(x, torch._inductor.ir.IRNode)], + layout_broadcasted_k, + input_gen_fns=input_gen_fns, + ) # [Bq, Hkv, seq_len_kv, k_head_dim] + + # need subgraph inputs and outputs to analyze all symints used in flex attention + broadcasted_grad_key.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + broadcasted_grad_key.data.data.subgraph_outs = get_bwd_subgraph_outputs( + fw_subgraph_buffer, mask_graph_buffer, joint_outputs + ) + + if V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv)): + grad_key = broadcasted_grad_key + grad_value = broadcasted_grad_value + else: + assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), ( + f"Bq and Bkv must broadcastable. " + f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} " + f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" + ) + grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) + grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) + + return (grad_query, grad_key, grad_value, tuple(joint_outputs.captured_grads)) + + +def get_bwd_subgraph_outputs( + subgraph_buffer: SubgraphResults, + mask_graph_buffer: SubgraphResults, + joint_outputs: JointOutputResult, +) -> list[Optional[Union[ComputedBuffer, TensorBox]]]: + subgraph_buffer = ( + subgraph_buffer if isinstance(subgraph_buffer, Sequence) else [subgraph_buffer] + ) + mask_graph_buffer = ( + mask_graph_buffer + if isinstance(mask_graph_buffer, Sequence) + else [mask_graph_buffer] + ) + joint_output_buffers = [ + joint_outputs.grad_input, + *joint_outputs.captured_grads_compute, + *joint_outputs.captured_grads, + *joint_outputs.mutated_grads, + ] + + return [*subgraph_buffer, *mask_graph_buffer, *joint_output_buffers] diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..89ad46dd0d49992a314cb534e65301b0a2ed0e30 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/flex_decoding.py @@ -0,0 +1,628 @@ +# mypy: allow-untyped-defs +"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" + +from typing import Any + +import sympy + +import torch +from torch._inductor.virtualized import V + +from .. 
import ir +from ..ir import FixedLayout, FlexibleLayout +from ..lowering import empty, empty_strided, lowerings +from ..runtime.runtime_utils import is_power_of_2, next_power_of_2 +from ..select_algorithm import autotune_select_algorithm, SymbolicGridFn, TritonTemplate +from .flex_attention import ( + compute_forward_block_mn, + compute_forward_inner, + compute_next_offset_func, + create_indices_fake, + create_num_blocks_fake_generator, + get_bounded_indices_func, + get_fwd_subgraph_outputs, + load_checked_2d, + load_checked_block, + maybe_realize, +) + + +aten = torch.ops.aten +prims = torch.ops.prims + + +@SymbolicGridFn +def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, meta): + """How is this kernel parallelized? + We create a grid of (batch_size * kv_heads, SPLIT_KV, 1) + Each block is responsible for iterating over blocks of keys and values calculating + the local output for their tile of keys and values over all full length of query. + groups of SPLIT_KV blocks then combine their output to produce the final result. + """ + + return (batch_size * kv_heads, meta["SPLIT_KV"], 1) + + +flex_decoding_template = TritonTemplate( + name="flex_decoding", + grid=flex_decoding_grid, + source=r""" + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. 
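+ # Each (batch, kv-head, split) program also writes its local rowmax into M and its local sumexp into L, so the SPLIT_KV partial results can be reduced into the final output after this kernel.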
+ + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. 
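+ # Split off_t owns TILE_KV_MULTIPLE consecutive BLOCK_N-sized blocks: [off_t * TILE_KV_MULTIPLE, (off_t + 1) * TILE_KV_MULTIPLE).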
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. 
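+ # (Presumably this evens out work across splits: a split that starts early in the partial blocks starts late in the full blocks.)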
+        block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
+        block_n_end = block_n_start + TILE_KV_MULTIPLE
+        indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+        off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
+        off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
+
+        # last valid block according to sparse mask
+        block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
+
+        K_block_ptr = tl.make_block_ptr(
+            base=K + k_offset,
+            shape=(QK_HEAD_DIM, KV_LEN),  # (d, N)
+            strides=(stride_kk, stride_kn),
+            offsets=(0, off_n),
+            block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
+            order=(0, 1)
+        )
+        V_block_ptr = tl.make_block_ptr(
+            base=V + v_offset,
+            shape=(KV_LEN, V_HEAD_DIM),
+            strides=(stride_vn, stride_vk),
+            offsets=(off_n, 0),
+            block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
+            order=(1, 0)
+        )
+        offs_n = tl.arange(0, BLOCK_N) + off_n
+
+        acc, l_i, m_i = forward_inner(
+            {{gen_argdefs()}},
+            q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
+            # accumulated values
+            acc, l_i, m_i,
+            # offsets
+            off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
+            None,
+            # block sparse data
+            kv_indices, kv_num_blocks,
+            block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
+            MATMUL_PRECISION,
+            IS_FULL_BLOCKS=True,
+        )
+
+    m_offset = off_t * stride_mt + off_z * stride_mz
+    l_offset = off_t * stride_lt + off_z * stride_lz
+
+    M_block_ptr = tl.make_block_ptr(
+        base=M + m_offset,
+        shape=(G, Q_LEN),  # (G, M)
+        strides=(stride_mh, stride_mm),
+        offsets=(off_hkv*G, 0),
+        block_shape=(G, BLOCK_M_PER_HQ),
+        order=(1, 0)
+    )
+    L_block_ptr = tl.make_block_ptr(
+        base=L + l_offset,
+        shape=(G, Q_LEN),  # (G, M)
+        strides=(stride_lh, stride_lm),
+        offsets=(off_hkv*G, 0),
+        block_shape=(G, BLOCK_M_PER_HQ),
+        order=(1, 0)
+    )
+
+    # Store output, logsumexp and rowmax for cross-CTA reduction (all in float32,
+    # even when the input data are in fp16).
+    m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
+    l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
+    if SAFE_M_BOUNDARY:
+        tl.store(M_block_ptr, m_i)
+        tl.store(L_block_ptr, l_i)
+    else:
+        tl.store(M_block_ptr, m_i, boundary_check=(1,))
+        tl.store(L_block_ptr, l_i, boundary_check=(1,))
+
+    # -- store output
+    idx_z = off_z
+    idx_t = off_t
+    idx_hq = off_hkv*G + off_g[:, None, None]
+    idx_m = off_m[None, :, None]
+    idx_d = offs_vd[None, None, :]
+
+    mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
+    acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
+    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
+"""
+    + compute_forward_inner
+    + get_bounded_indices_func
+    + load_checked_block
+    + load_checked_2d
+    + compute_next_offset_func
+    + compute_forward_block_mn,
+)
+
+
+def get_split_k(B: int, H: int, Mk: int) -> int:
+    num_SM = torch.cuda.get_device_properties("cuda").multi_processor_count
+    bh = max(B * H, 1)  # NOTE: Handle the B*H == 0 case
+    assert isinstance(bh, (int, sympy.Integer)), "B and H must be concrete integers"
+    split_k = num_SM // bh * 2  # Each SM should get at least one block.
+    # TODO: workload evening at runtime for splits fully masked out.
+    # Before we have runtime workload evening, assign 2 splits per SM.
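+    # Illustrative example (hypothetical device numbers): with 108 SMs and B*H = 16,
+    # split_k = 108 // 16 * 2 = 12, i.e. the KV sequence is reduced over 12 splits.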
+    split_k = max(split_k, 1)
+
+    return split_k
+
+
+def create_flex_decoding_kernel(*args, **kwargs):
+    from .flex_attention import set_head_dim_values
+
+    (
+        query,
+        key,
+        value,
+        block_mask,
+        scale,
+        kernel_options,
+        score_mod_subgraph,
+        mask_mod_subgraph,
+        score_mod_other_buffers,
+        mask_mod_other_buffers,
+    ) = args
+    (
+        _,  # q_length
+        _,  # kv_length
+        kv_num_blocks,
+        kv_indices,
+        full_kv_num_blocks,  # full_kv_num_blocks,
+        full_kv_indices,  # full_kv_indices,
+        _,  # q_num_blocks
+        _,  # q_indices
+        _,  # full_q_num_blocks,
+        _,  # full_q_indices,
+        _,  # SPARSE_Q_BLOCK_SIZE,
+        SPARSE_KV_BLOCK_SIZE,
+        _,
+    ) = block_mask
+
+    Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
+    Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
+
+    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
+        f"Bq and Bkv must be broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    )
+
+    B = Bq
+    kernel_options = dict(kernel_options)
+    # Mark symbols in custom kernel options as static shapes and add guards.
+    kernel_options = {
+        k: V.graph.sizevars.evaluate_static_shape(v)
+        if isinstance(v, sympy.Symbol)
+        else v
+        for k, v in kernel_options.items()
+    }
+
+    # TODO: Fix flex decoding non-divisible case!
+    if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0:
+        kernel_options.setdefault("IS_DIVISIBLE", False)
+    else:
+        kernel_options.setdefault("IS_DIVISIBLE", True)
+
+    # Calculate GQA head sharing
+    gqa_shared_heads = Hq // Hkv
+    if not is_power_of_2(gqa_shared_heads):
+        raise ValueError(
+            "Number of query heads sharing the same KV head must be a power of 2."
+        )
+    kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads)
+
+    # Determine if there are "full" blocks where we only need to apply score_mod, and can skip mask_mod
+    has_full_blocks = full_kv_num_blocks is not None
+    kernel_options.setdefault("HAS_FULL_BLOCKS", has_full_blocks)
+    if not has_full_blocks:
+        # Create a placeholder full block list in case it is empty
+        full_kv_num_blocks, full_kv_indices = (
+            empty(0, device=query.get_device()) for _ in range(2)
+        )
+
+    (
+        query,
+        key,
+        value,
+        kv_num_blocks,
+        kv_indices,
+        full_kv_num_blocks,
+        full_kv_indices,
+    ) = maybe_realize(
+        [
+            query,
+            key,
+            value,
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+        ]
+    )
+    score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
+    mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
+
+    choices: list[Any] = []
+    dtype = key.get_dtype()
+    head_dim = V.graph.sizevars.evaluate_static_shape(key.get_size()[-1])
+    configs = V.choices.get_flex_decode_configs(head_dim, dtype)
+
+    # TODO: fix autotuning.
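+    # Each of the SPLIT_KV partial results records its per-row max (buf_M) and sum of
+    # exponentials (buf_L); the reduction at the end of this function stitches the
+    # splits back into a single attention output and logsumexp.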
+ + kernel_options.setdefault("SM_SCALE", scale) + kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv)) + MAX_SPLIT_KV = kernel_options["SPLIT_KV"] + + # create config dependent intermediate buffers + buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim] + buf_ML_shape = buf_ACC_shape[:-1] + buf_M = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The rowmax is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + buf_L = empty_strided( + buf_ML_shape, + None, + dtype=torch.float32, # The intermediate sumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + + layout_acc = FixedLayout( + query.get_device(), + torch.float32, + buf_ACC_shape, + FlexibleLayout.contiguous_strides(buf_ACC_shape), + ) + + set_head_dim_values(kernel_options, qk_head_dim, v_head_dim, V.graph.sizevars) + + kernel_options.setdefault( + "BLOCK_M", + ( + # m + # if V.graph.sizevars.evaluate_expr(sympy.Lt(query.get_size()[-2], 0)) + # else # Always use a BLOCK_M > 16 before Triton fix https://github.com/triton-lang/triton/pull/4061 is in pin + max( + next_power_of_2( + V.graph.sizevars.size_hint( + seq_len_q, + fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type] + ) + * gqa_shared_heads + ), + 16, + ) + ), + ) + + query = ir.ExternKernel.realize_input(query) + stride_b, stride_hq, stride_seq_len_q, stride_qk_head_dim = query.get_stride() + + # Reshape query for GQA: [B, Hq, Mq, D] -> [B, Hkv, G, Mq, D] + gqa_query_shape = (B, Hkv, gqa_shared_heads, seq_len_q, qk_head_dim) + gqa_query_stride = ( + stride_b, + stride_hq * gqa_shared_heads, + stride_hq, + stride_seq_len_q, + stride_qk_head_dim, + ) + query = lowerings[aten.as_strided](query, gqa_query_shape, gqa_query_stride) + + V.graph.sizevars.guard_leq( + seq_len_q * gqa_shared_heads, sympy.Integer(kernel_options["BLOCK_M"]) + ) + + kernel_options.setdefault( + "SAFE_M_BOUNDARY", + ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + ) + # TODO: This feels sketchy + kernel_options.setdefault("SAFE_N_BOUNDARY", True) + # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. + SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + + original_kernel_options = kernel_options.copy() + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + + # Default config for warp specialization + num_consumer_groups, num_buffers_warp_spec = 0, 0 + + for conf in configs: + if SPARSE_KV_BLOCK_SIZE % conf.block_n != 0: + continue + + cur_kernel_options = original_kernel_options.copy() + # Remove prefix for forward kernels options and delete backward kernel options. 
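+        # For example (hypothetical key names): a user-supplied "fwd_num_warps" becomes
+        # "num_warps" for this forward kernel, while any "bwd_*" option is dropped.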
+ for k in list(cur_kernel_options.keys()): + if k.startswith("fwd_"): + v = cur_kernel_options.pop(k) + cur_kernel_options[k[4:]] = v + if k.startswith("bwd_"): + cur_kernel_options.pop(k) + # Performance tuning + cur_kernel_options.setdefault("BLOCK_N", conf.block_n) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("num_warps", conf.num_warps) + cur_kernel_options.setdefault("num_stages", conf.num_stages) + + if cur_kernel_options.get("num_consumer_groups", False): + cur_kernel_options.setdefault("num_consumer_groups", num_consumer_groups) + cur_kernel_options.setdefault( + "num_buffers_warp_spec", num_buffers_warp_spec + ) + + # Set default to False + cur_kernel_options.setdefault("USE_TMA", False) + + # Add ROCm-specific parameters if they exist in the config + for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: + if hasattr(conf, attrib): + cur_kernel_options[attrib] = getattr(conf, attrib) + + flex_decoding_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ], + layout=layout_acc, + subgraphs=[ + score_mod_subgraph, + mask_mod_subgraph, + ], + mutated_inputs=[buf_M, buf_L], + call_sizes=query.get_size(), + **cur_kernel_options, + ) + + inputs_for_flex_decoding = ( + [ + query, + key, + value, + buf_M, + buf_L, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + ] + + list(score_mod_other_buffers) + + list(mask_mod_other_buffers) + ) + + input_gen_fns = { + 5: create_num_blocks_fake_generator(kv_indices), + 6: create_indices_fake, + 7: create_num_blocks_fake_generator(full_kv_indices), + 8: create_indices_fake, + } + + buf_ACC = autotune_select_algorithm( + "flex_decoding", + choices, + inputs_for_flex_decoding, + layout_acc, + input_gen_fns=input_gen_fns, + ) + + # need subgraph inputs and outputs to analyze all symints used in flex attention + buf_ACC.data.data.subgraph_inps = list(score_mod_other_buffers) + list( + mask_mod_other_buffers + ) + buf_ACC.data.data.subgraph_outs = get_fwd_subgraph_outputs( + score_mod_subgraph, mask_mod_subgraph + ) + + # Reduction + + g_M = lowerings[aten.max](buf_M, dim=1, keepdim=True)[0] + # See [Note] Handle fully masked out rows: + # g_M Is the global max among split kv blocks. 
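+    # Sketch of the recombination performed by the lowerings below (base-2 exponentials):
+    #   alpha_t = exp2(M_t - g_M)
+    #   out     = sum_t(alpha_t * ACC_t) / sum_t(alpha_t * L_t)
+    #   lse     = log2(sum_t(alpha_t * L_t)) + g_M
+    # Fully masked rows (g_M == -inf) use alpha_t = 1 and a denominator of 1 to avoid NaNs.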
+ masked_rows = lowerings[aten.eq](g_M, -float("inf")) + adj_M = lowerings[aten.sub](buf_M, g_M) + adj_M = lowerings[aten.where](masked_rows, 0, adj_M) + alpha = lowerings[aten.exp2](adj_M) + + buf_L = lowerings[aten.mul](buf_L, alpha) + g_L = lowerings[aten.sum](buf_L, axis=1) + masked_rows_squeezed = lowerings[aten.squeeze](masked_rows, dim=1) + g_L = lowerings[aten.where](masked_rows_squeezed, 1.0, g_L) + logsumexp = lowerings[aten.log2](g_L) + logsumexp = lowerings[aten.add](logsumexp, lowerings[aten.squeeze](g_M, dim=1)) + + alpha_unseq = lowerings[aten.unsqueeze](alpha, 4) + buf_ACC = lowerings[aten.mul](buf_ACC, alpha_unseq) + output = lowerings[aten.sum](buf_ACC, axis=1) + L_unseq = lowerings[aten.unsqueeze](g_L, 3) + output = lowerings[aten.div](output, L_unseq) + output = lowerings[prims.convert_element_type](output, query.get_dtype()) + + return ( + output, + logsumexp, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/mm.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..435fe9790eaebd026f34d4fc7f5b1d90b58b0c15 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm.py @@ -0,0 +1,1373 @@ +# mypy: allow-untyped-defs +import functools +import logging +from typing import Any, Optional + +import sympy + +import torch +from torch._dynamo.utils import counters +from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + context_add_strides, + context_add_using_tf32, + mm_operations, +) +from torch._inductor.codegen.cpp_gemm_template import CppGemmTemplate +from torch._inductor.virtualized import V +from torch.fx.experimental.proxy_tensor import make_fx +from torch.torch_version import TorchVersion + +from .. 
import config as inductor_config, ir +from ..codegen.cuda.gemm_template import CUTLASS2xGemmTemplate, CUTLASS3xGemmTemplate +from ..codegen.rocm.ck_tile_universal_gemm_template import CKTileGemmTemplate +from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from ..codegen.subgraph import SubgraphTemplate +from ..ir import FlexibleLayout, is_triton +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + _use_cutlass_for_op, + get_k_splits, + get_tma_workspace_arg, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_ck_tile_gemm_template, + use_cpp_gemm_template, + use_cutlass_template, + use_decompose_k_choice, + use_triton_template, + use_triton_tma_template, +) +from .mm_common import ( + _is_static_problem, + addmm_epilogue, + mm_args, + mm_config_kwargs, + mm_grid, + mm_options, + persistent_mm_grid, + persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, +) + + +try: + import triton + + triton_version = TorchVersion(triton.__version__) + has_triton = True +except ImportError: + triton_version = TorchVersion("0.0.0") + has_triton = False + +log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + +mm_template = TritonTemplate( + name="mm", + grid=mm_grid, + source=( + r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and M >= BLOCK_M: + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and N >= BLOCK_N: + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% 
endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""" + if (torch.version.hip is None) or triton_version >= "3.3.0" + # FIXME: To get around rocm failures like https://github.com/pytorch/pytorch/actions/runs/13123783322/job/36617154943 + # The only difference between the two templates is M >= BLOCK_M and N >= BLOCK_N checking. + # See more details in https://github.com/pytorch/pytorch/pull/146293 + else r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", indent_width=8)}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""" + ), + cache_codegen_enabled_for_template=True, + prologue_loads_all_inputs=True, +) + +persistent_tma_mm_template = TritonTemplate( + name="mm_persistent_tma", + grid=persistent_mm_grid, + source=r""" +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + start_pid = tl.program_id(0) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + 
num_tiles = grid_m * grid_n + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + width = GROUP_M * grid_n + rk_for_mask = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + {%- if TMA_EXPERIMENTAL_API %} + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + global_size=[M, K] if A_ROW_MAJOR else [K, M], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + global_size=[K, N] if B_ROW_MAJOR else [N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + a_desc = a_desc_ptr + b_desc = b_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K] if A_ROW_MAJOR else [K, M], + strides=[K, 1] if A_ROW_MAJOR else [M, 1], + block_shape=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[K, N] if B_ROW_MAJOR else [N, K], + strides=[N, 1] if B_ROW_MAJOR else [K, 1], + block_shape=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + ) + {%- endif %} + + pid_m = 0 + pid_n = 0 + rm = 0 + rn = 0 + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + + rm = pid_m * BLOCK_M + rn = pid_n * BLOCK_N + + rk = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], + A.dtype.element_ty, + ) + b = tl._experimental_descriptor_load( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], + B.dtype.element_ty, + ) + {%- else %} + a = tl.load_tensor_descriptor( + a_desc, + [rm, rk] if A_ROW_MAJOR else [rk, rm], + ) + b = tl.load_tensor_descriptor( + b_desc, + [rk, rn] if B_ROW_MAJOR else [rn, rk], + ) + {%- endif %} + acc += tl.dot( + a if A_ROW_MAJOR else a.T, + b if B_ROW_MAJOR else b.T, + allow_tf32=ALLOW_TF32, + ) + + if ki == k_tiles - 1: + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + +""", +) + +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = 
tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + {%- if TMA_EXPERIMENTAL_API %} + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + a_desc = a_desc_ptr + b_desc = a_desc_ptr + {%- else %} + a_desc = triton.language.make_tensor_descriptor( + base=A, + shape=[M, K], + strides=[K, 1], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = triton.language.make_tensor_descriptor( + base=B, + shape=[N, K], + strides=[K, 1], + block_shape=[BLOCK_N, BLOCK_K], + ) + {%- endif %} + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + {%- if TMA_EXPERIMENTAL_API %} + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], 
B.dtype.element_ty + ) + {%- else %} + a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + {%- endif %} + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + + +# prevent duplication registration of extern functions +@functools.cache +def lazy_register_extern_choice(fn): + return ExternKernelChoice(fn) + + +aten_mm = ExternKernelChoice(torch.mm, "at::mm_out") + +aten_addmm = ExternKernelChoice( + torch.addmm, "at::addmm_out", op_overload=aten.addmm.default +) + +aten__int_mm = ExternKernelChoice(torch._int_mm, "at::_int_mm_out") + +aten__sparse_semi_structured_mm = ExternKernelChoice( + torch._sparse_semi_structured_mm, + "at::_sparse_semi_structured_mm", + has_out_variant=False, +) + +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + + +def _is_int8_mat(mat): + return mat.get_dtype() in (torch.int8, torch.uint8) + + +def _is_large_block_for_cpu(m, n, k): + # Thresholds are experimentally determined to reduce Triton CPU compile times + return m * n > 2**13 + + +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + +def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): + """ + Giving torch.addmm a 1D tensor calls a different (faster) cublasLt + kernel under the hood. There are a few shapes where this is slower, + but they are rare. 
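+    For example, a bias expanded from shape [N] to [M, N] has stride(0) == 0, so the
+    check below unwraps it back to a 1D row before calling torch.addmm.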
+ """ + if inp.stride(0) == 0 or inp.size(0) == 1: + return torch.addmm(inp[0], mat1, mat2, out=out, alpha=alpha, beta=beta) + return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) + + +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +aten_bias_addmm = ExternKernelChoice(bias_addmm, None) + + +def decomposeK(a, b, k_splits): + m = a.shape[0] + n = b.shape[1] + k = a.shape[1] + + k_parts = k // k_splits + B = k_splits + a_reshaped = torch.permute(a.reshape(m, B, k_parts), (1, 0, 2)) + b_reshaped = b.reshape(B, k_parts, n) + result = torch.bmm(a_reshaped, b_reshaped, out_dtype=torch.float32) + reduced_buf = torch.sum(result, 0) + return reduced_buf.to(a.dtype) + + +@register_lowering(aten.mm, type_promotion_kind=None) +def tuned_mm(mat1, mat2, *, layout=None): + """ + Lowering for autotuning aten.mm with different backends (Aten, Triton, CUTLASS, etc.) + """ + m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + device_type = ir.get_device_type(mat1) + name = "mm" + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + aten_layout = layout + if not (inductor_config.max_autotune or inductor_config.max_autotune_gemm): + aten_layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + + # options to tune from + choices = ( + [aten_mm.bind((mat1, mat2), aten_layout)] if use_aten_gemm_kernels() else [] + ) + static_shape, is_nonzero = _is_static_problem(layout) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + extra_mm_configs = V.choices.get_extra_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + ) + + from torch._inductor.ir import get_free_symbols + + # Only do split-k 
optimization if K is much larger than m, n and m, n are small + # and if there aren't any unbacked symbols + unbacked_symbols = any( + len(get_free_symbols(itr, unbacked_only=True)) > 0 + for itr in ( + mat1.get_size(), + mat1.get_stride(), + mat2.get_size(), + mat2.get_stride(), + ) + ) + if use_decompose_k_choice(m, n, k) and not unbacked_symbols: + from torch._dispatch.python import enable_python_dispatcher + + from ..decomposition import select_decomp_table + + k_splits = get_k_splits(m, n, k) + for k_split in k_splits: + if not V.graph.sizevars.statically_known_true( + sympy.Eq(sympy.Mod(k, k_split), 0) + ): + continue + + with enable_python_dispatcher(): + decompositions = select_decomp_table() + + decompose_k_subgraph_template = SubgraphTemplate( + name=f"decompose_k_mm_{k_split}_split", + make_fx_graph=make_fx( + functools.partial(decomposeK, k_splits=k_split), + decompositions, + ), + ) + + decompose_k_subgraph_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) + if is_nonzero and use_ck_tile_gemm_template(layout, m, n, k): + CKTileGemmTemplate.add_choices(choices, layout, [mat1, mat2]) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [mat1, mat2], + ) + + input_nodes = [mat1, mat2] + if ( + is_nonzero + and use_triton_template(layout) + and torch._inductor.config.run_autoheuristic(name) + and is_triton(mat1) + ): + always_included = [] + if use_aten_gemm_kernels(): + always_included.append("extern_mm") + num_choices_before_extra_configs = len(choices) + for config in extra_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + # using AutoHeuristic for ranking + ah_choices = mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + mm_operations(), + None, + top_k=10, + always_included=always_included, + ) + if not torch._inductor.config.collect_autoheuristic(name): + # if we are collecting data, we do not want to modify choices + if ah_choices is not None and len(ah_choices) > 0: + # the order in which autoheuristic returns choices is not the same as + # as the order of choices, which affects things like epilogue fusion. 
+ # once epilogue fusion benchmarks choices in sorted order, I think we can + # just use the order returned by autoheuristic + choices = [choice for choice in choices if choice in ah_choices] + else: + choices = choices[:num_choices_before_extra_configs] + + for k in inductor_config.external_matmul: + choices.append(lazy_register_extern_choice(k).bind((mat1, mat2), layout)) + + return autotune_select_algorithm(name, choices, [mat1, mat2], layout) + + +@register_lowering(aten._int_mm, type_promotion_kind=None) +def tuned_int_mm(mat1, mat2, *, layout=None): + m, n, k, layout, mat1, mat2 = mm_args( + mat1, mat2, layout=layout, out_dtype=torch.int32 + ) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._int_mm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._int_mm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat1) + + static_shape, is_nonzero = _is_static_problem(layout) + use_cutlass = static_shape and is_nonzero and use_cutlass_template(layout, m, n, k) + + choices = ( + [aten__int_mm.bind((mat1, mat2), layout)] if use_aten_gemm_kernels() else [] + ) + + if use_cutlass and _use_cutlass_for_op("int_mm"): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2], fuseable=True, non_fuseable=True + ) + + int8_mm_configs = V.choices.get_int8_mm_configs(device_type) + + if is_nonzero and use_triton_template(layout, enable_int32=True): + for config in int8_mm_configs( + m, n, k, **mm_config_kwargs(device_type, _is_large_block_for_cpu) + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + ) + + return autotune_select_algorithm("int_mm", choices, [mat1, mat2], layout) + + +@register_lowering(aten.addmm, type_promotion_kind=None) +def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): + device_type = ir.get_device_type(mat1) + m, n, k, layout, mat1, mat2, inp_expanded = mm_args(mat1, mat2, inp, layout=layout) + static_shape, is_nonzero = _is_static_problem(layout) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten.addmm_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten.addmm: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat1.get_dtype(), + mat2.get_dtype(), + layout, + ) + + if (not is_nonzero) or ( + not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) + ): + # Use a FlexibleLayout if we are not autotuning. + # This allows padding strides for the output. 
+ from torch._inductor.ir import FixedLayout, FlexibleLayout + + if isinstance(layout, FixedLayout): + layout = FlexibleLayout( + device=layout.device, dtype=layout.dtype, size=layout.size + ) + choices = ( + [ + aten_addmm.bind( + (inp, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + return autotune_select_algorithm("addmm", choices, [inp, mat1, mat2], layout) + + choices = ( + [ + aten_addmm.bind( + (inp_expanded, mat1, mat2), + layout, + alpha=alpha, + beta=beta, + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + use_aten_gemm_kernels() + and inp_expanded.get_stride()[0] == 0 + and inp_expanded.get_device().type == "cuda" + and inductor_config.triton.autotune_cublasLt + ): + # unexpand inp to make sure fused addmm from cublasLt is used + choices.insert( + 0, + aten_bias_addmm.bind( + (inp_expanded, mat1, mat2), layout, alpha=alpha, beta=beta + ), + ) + + mm_configs = V.choices.get_base_mm_configs(device_type) + persistent_mm_configs = V.choices.get_persistent_mm_configs(device_type) + + dtype = mat1.get_dtype() + if is_nonzero and use_triton_template(layout): + for config in mm_configs( + m, + n, + k, + **mm_config_kwargs(device_type, _is_large_block_for_cpu, dtype.itemsize), + ): + mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + **mm_options(config, m, n, k, layout), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + epilogue_fn_hash=str(["addmm_epilogue", layout.dtype, alpha, beta]), + ) + + if use_triton_tma_template(mat1, mat2): + for config in persistent_mm_configs( + m, + n, + k, + **mm_config_kwargs( + device_type, _is_large_block_for_cpu, dtype.itemsize + ), + ): + persistent_tma_mm_template.maybe_append_choice( + choices, + input_nodes=(inp_expanded, mat1, mat2), + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat1.get_device(), + ), + **mm_options(config, m, n, k, layout), + **persistent_mm_options(mat1, mat2), + prefix_args=1, + epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta), + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("addmm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices( + choices, + layout, + [mat1, mat2, inp_expanded], + alpha=alpha, + beta=beta, + input_reorder=[2, 0, 1], + ) + + if use_cpp_gemm_template(layout, mat1, mat2): + CppGemmTemplate.add_choices( + choices, + layout, + [inp_expanded, mat1, mat2], + alpha=alpha, + beta=beta, + has_bias=True, + ) + + return autotune_select_algorithm( + "addmm", choices, [inp_expanded, mat1, mat2], layout + ) + + +@register_lowering(aten._sparse_semi_structured_mm, type_promotion_kind=None) +def tuned_sparse_semi_structured_mm( + mat1, mat1_meta, mat2, *, out_dtype=None, layout=None +): + from torch._inductor.select_algorithm import realize_inputs + + mat1, mat1_meta, mat2 = realize_inputs(mat1, mat1_meta, mat2) + m1, k1 = mat1.get_size() + m2, _ = mat1_meta.get_size() + k2, n = mat2.get_size() + m = V.graph.sizevars.guard_equals(m1, m2) + k = V.graph.sizevars.guard_equals(2 * k1, k2) + + if layout is None: + from torch._inductor.ir import FixedLayout + + layout = FixedLayout( + mat2.get_device(), + out_dtype if out_dtype else mat2.get_dtype(), + [m, n], + [n, 1], + ) + 
else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + choices = ( + [ + aten__sparse_semi_structured_mm.bind( + (mat1, mat1_meta, mat2), layout, out_dtype=out_dtype + ) + ] + if use_aten_gemm_kernels() + else [] + ) + + if ( + m * n != 0 + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("sparse_semi_structured_mm") + ): + CUTLASS2xGemmTemplate.add_cutlass_gemm_choices( + choices, layout, [mat1, mat2, mat1_meta], fuseable=True, non_fuseable=True + ) + + return autotune_select_algorithm( + "sparse_semi_structured_mm", choices, [mat1, mat1_meta, mat2], layout + ) + + +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + """ + Performs an optimized matrix multiplication where scaling factors are applied + to the inputs and/or output. + + Args: + mat1 (Tensor): First input matrix + mat2 (Tensor): Second input matrix + scale1 (Tensor): Scale factor applied to mat1 (supports broadcasting) + scale2 (Tensor): Scale factor applied to mat2 (supports broadcasting) + bias (Tensor, optional): Optional bias tensor to add to the result + layout: Layout hint for optimization + + Returns: + Tensor: The result of the scaled matrix multiplication + """ + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] + + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + # We dont have triton lowerings for the MX variants yet + if scale_a.dtype != torch.float32: + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] 
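+        # The scales (and optional bias) ride along as epilogue suffix args
+        # (suffix_args below), so they are normalized here to 2D shapes the
+        # template can index: a 0-d scale becomes [1, 1], a [N] bias becomes [1, N].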
+ if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if V.graph.sizevars.guard_or_false(sympy.Le(k, 16)): + # Triton crashes however uncommon for real workloads + continue + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and V.graph.sizevars.guard_or_false(sympy.Lt(k, 32)): + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + epilogue_fn_hash="scale_mm_epilogue", + ) + + if ( + is_nonzero + and use_cutlass_template(layout, m, n, k) + and _use_cutlass_for_op("scaled_mm") + ): + CUTLASS3xGemmTemplate.add_cutlass_gemm_choices( + choices, + layout, + input_nodes, # type: ignore[arg-type] + use_fast_accum=use_fast_accum, # type: ignore[arg-type] + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + +@functools.cache +def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: + props = torch.cuda.get_device_properties(index or 0) + return props.major <= 7 + + +def dims_are_int(dims): + return all(isinstance(dim, int) for dim in dims) + + +def mm_autoheuristic( + mat1, + mat2, + m, + n, + k, + choices, + name, + input_nodes, + ops, + precondition, + top_k: Optional[int] = None, + always_included=None, +): + m, n, k = get_size_hints(mat1, mat2, m, n, k) + if not dims_are_int([m, n, k]): + return None + mat1_stride, mat2_stride = get_size_hints_strides(mat1, mat2) + + def get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride): + context = AHContext() + context.add_feature("m", m) + context.add_feature("k", k) + context.add_feature("n", n) + context.add_feature("mat1_dtype", mat1.layout.dtype, is_categorical=True) + context.add_feature("mat2_dtype", 
mat2.layout.dtype, is_categorical=True) + context_add_strides(context, "mat1", mat1_stride) + context_add_strides(context, "mat2", mat2_stride) + context.add_feature( + "mat1_iscontig", mat1.layout.is_contiguous(), is_categorical=True + ) + context.add_feature( + "mat2_iscontig", mat2.layout.is_contiguous(), is_categorical=True + ) + if name == "mm": + context_add_using_tf32(context, mat1.layout.dtype) + return context + + def fallback(): + return None + + context = get_context(m, k, n, mat1, mat2, mat1_stride, mat2_stride) + autoheuristic = AutoHeuristicSelectAlgorithm( + fallback=fallback, + choices=choices, + input_nodes=input_nodes, + context=context, + name=name, + augment_context=ops, + precondition=precondition, + ) + + if top_k is not None: + # TODO: is there a cleaner way to ensure aten.mm is always included? + return autoheuristic.get_top_k_choices_caller( + top_k, always_included=always_included + ) + + return autoheuristic.get_choice_caller() + + +def get_size_hints(mat1, mat2, m, n, k): + if not isinstance(m, int) or not isinstance(k, int): + (m, k) = V.graph.sizevars.size_hints( + mat1.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + + if not isinstance(n, int) or not isinstance(k, int): + (k, n) = V.graph.sizevars.size_hints( + mat2.get_size(), + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + return m, n, k + + +def get_size_hints_strides(mat1, mat2): + mat1_stride = mat1.layout.stride + mat2_stride = mat2.layout.stride + strides = [mat1_stride, mat2_stride] + strides_hints = [] + for stride in strides: + if not isinstance(stride, int): + stride = V.graph.sizevars.size_hints( + stride, + fallback=torch._inductor.config.unbacked_symint_fallback, + ) + strides_hints.append(stride) + return strides_hints[0], strides_hints[1] diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_common.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_common.py new file mode 100644 index 0000000000000000000000000000000000000000..a7cd3ff0485a94ec1845c685dfbe334fa6fb4980 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_common.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import logging +from collections.abc import Sequence +from typing import Any + +import sympy + +import torch +from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn +from torch._inductor.utils import sympy_product +from torch._inductor.virtualized import V + +from .. import config as inductor_config +from ..codegen.wrapper import PythonWrapperCodegen +from ..ir import _IntLike, Layout, TensorBox +from ..utils import get_num_sms, TMA_DESCRIPTOR_SIZE + + +log = logging.getLogger(__name__) + + +@SymbolicGridFn +def mm_grid(m, n, meta, *, cdiv): + """ + The CUDA grid size for matmul triton templates. + """ + return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) + + +@SymbolicGridFn +def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min): + """Defines the grid for persistent kernels.""" + return ( + min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), + 1, + 1, + ) + + +@SymbolicGridFn +def persistent_grouped_mm_grid(*args): + meta = args[-1] + return (meta["NUM_SMS"], 1, 1) + + +def acc_type(dtype): + if dtype in (torch.float16, torch.bfloat16): + return "tl.float32" + return f"tl.{dtype}".replace("torch.", "") + + +def mm_options(config, sym_m, sym_n, sym_k, layout): + """ + Common options to matmul triton templates. 
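+    Returns the kwargs handed to the Triton mm templates: EVEN_K, ALLOW_TF32,
+    USE_FAST_ACCUM, ACC_TYPE, num_stages/num_warps, the config's block sizes, and a
+    GROUP_M default of 8 when the config does not provide one.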
+ """ + even_k_symbolic = ( + # it isn't worth guarding on this + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( + not inductor_config.force_same_precision + or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) + ) + options_dict = dict( + EVEN_K=even_k_symbolic, + ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm + ACC_TYPE=acc_type(layout.dtype), + num_stages=config.num_stages, + num_warps=config.num_warps, + **config.kwargs, + ) + + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + + +def tma_options() -> dict[str, Any]: + from torch.utils._triton import has_triton_stable_tma_api + + return {"TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api()} + + +def persistent_mm_options(mat1, mat2): + res = dict( + A_ROW_MAJOR=not mat1.layout.is_transposed(), + B_ROW_MAJOR=not mat2.layout.is_transposed(), + NUM_SMS=get_num_sms(), + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + ) + res.update(tma_options()) + return res + + +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatible + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + mm_template_options.update(tma_options()) + + return mm_template_options + + +def mm_args( + mat1, + mat2, + *others, + layout=None, + out_dtype=None, + use_4x2_dim=False, + mat2_transposed=False, +): + """ + Common arg processing for mm,bmm,addmm,etc + """ + mat1, mat2 = realize_inputs(mat1, mat2) + *b1, m, k1 = mat1.get_size() + if mat2_transposed: + *b2, n, k2 = mat2.get_size() + else: + *b2, k2, n = mat2.get_size() + b = [V.graph.sizevars.guard_equals(a, b) for a, b in zip(b1, b2)] + if use_4x2_dim: + k2 = k2 * 2 + k = V.graph.sizevars.guard_equals(k1, k2) + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + layout = FixedLayout( + mat1.get_device(), + out_dtype, + [*b, m, n], + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." 
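+    # Any extra operands (e.g. the addmm bias) are broadcast to the output shape below,
+    # so downstream templates see inputs with a consistent [*, m, n] layout.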
+ from ..lowering import expand + + others = [realize_inputs(expand(x, layout.size)) for x in others] + + return [m, n, k, layout, mat1, mat2, *others] + + +def mm_config_kwargs(device, exclude_condition, dtype_size=None): + if device == "cpu": + return { + "scale": 0.5, + "exclude": exclude_condition, + } + + if dtype_size and inductor_config.max_autotune_gemm_search_space == "EXHAUSTIVE": + return { + "dtype_size": dtype_size, + } + return {} + + +def addmm_epilogue(dtype, alpha, beta): + def epilogue(acc, bias): + if alpha != 1: + acc = V.ops.mul(acc, V.ops.constant(alpha, dtype)) + if beta != 1: + bias = V.ops.mul(bias, V.ops.constant(beta, dtype)) + return V.ops.add(acc, bias) + + return epilogue + + +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + +def _is_static_problem(layout: Layout) -> tuple[bool, bool]: + """ + Check if input tensors and output layout have static shapes and non-zero sizes. + + Args: + layout: Output layout object with a 'size' attribute. + + Returns: + Tuple[bool, bool]: (is_static, is_nonzero) + is_static: True if all shapes are statically known + is_nonzero: True if all dimensions are non-zero + """ + static_shape = True + static_size = PythonWrapperCodegen.statically_known_list_of_ints_or_none( + layout.size + ) + if static_size is None: + nonzero = True + for s in layout.size: + sz = PythonWrapperCodegen.statically_known_int_or_none(s) + if sz is not None and sz == 0: + nonzero = False + break + return False, nonzero + numel = 1 + for dim in static_size: + numel *= dim + nonzero = numel > 0 + return static_shape, nonzero + + +def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: + def is_row_major(stride: Sequence[_IntLike]) -> bool: + return stride[-1] == 1 + + def is_col_major(stride: Sequence[_IntLike]) -> bool: + return stride[-2] == 1 + + def has_zero_dim(size: Sequence[_IntLike]) -> bool: + return bool(size[0] == 0 or size[1] == 0) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +def is_batch_stride_largest(mat1, mat2, layout) -> bool: + """ + Checking if the batch stride is the largest in the stride. 
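+    Concretely: each 3D operand and the output layout must satisfy
+    stride[0] == size[1] * size[2].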
+ """ + sizes = [mat1.get_size(), mat2.get_size(), layout.size] + strides = [mat1.get_stride(), mat2.get_stride(), layout.stride] + for size, stride in zip(sizes, strides): + assert len(size) == len(stride) == 3, "Expect 3D tensors" + if stride[0] != sympy_product(size[1:]): + return False + + return True diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py new file mode 100644 index 0000000000000000000000000000000000000000..db566a5923deb9265a94d31017b8c58c835a7d7f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_plus_mm.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs + +import torch + +from .. import ir +from ..lowering import lowerings +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + TritonTemplate, +) +from ..utils import use_aten_gemm_kernels, use_triton_template +from ..virtualized import V +from .mm_common import mm_args, mm_grid, mm_options + + +aten = torch.ops.aten + +aten_mm_plus_mm = ExternKernelChoice( + torch.ops.inductor._mm_plus_mm, "torch::inductor::_mm_plus_mm" +) + +mm_plus_mm_template = TritonTemplate( + name="mm_plus_mm", + grid=mm_grid, + debug=False, + source=r""" +{{def_kernel("A", "B", "C", "D")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K1 = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + # K2 = {{size("C", 1)}} + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + stride_cm = {{stride("C", 0)}} + stride_ck = {{stride("C", 1)}} + stride_dk = {{stride("D", 0)}} + stride_dn = {{stride("D", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if (((stride_am == 1 and stride_ak == M) or (stride_am == K1 and stride_ak == 1)) + and ((stride_cm == 1 and stride_ck == M) or (stride_cm == K1 and stride_ck == 1))): + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + ram = rm % M + + if (((stride_bk == 1 and stride_bn == K1) or (stride_bk == N and stride_bn == 1)) + and ((stride_dk == 1 and stride_dn == K1) or (stride_dk == N and stride_dn == 1))): + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + rbn = rn % N + + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + C = C + (ram[:, None] * stride_cm + rk[None, :] * stride_ck) + D = D + (rk[:, None] * stride_dk + rbn[None, :] * stride_dn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k1 in range(K1, 0, -BLOCK_K): + # First matmul with A @ B + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k1, other=0.) + b = tl.load(B, mask=rk[:, None] < k1, other=0.) 
+ acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + for k2 in range(K1, 0, -BLOCK_K): + + # Second matmul with C @ D + if EVEN_K: + c = tl.load(C) + d = tl.load(D) + else: + c = tl.load(C, mask=rk[None, :] < k2, other=0.) + d = tl.load(D, mask=rk[:, None] < k2, other=0.) + acc += tl.dot(c, d, allow_tf32=ALLOW_TF32) + C += BLOCK_K * stride_ck + D += BLOCK_K * stride_dk + + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", + cache_codegen_enabled_for_template=True, +) + + +def tuned_mm_plus_mm(mat1, mat2, mat3, mat4, *, layout=None): + """ + Computes mm(mat1, mat2) + mm(mat3, mat4) + """ + m1, n1, k1, layout1, mat1, mat2 = mm_args(mat1, mat2, layout=layout) + m2, n2, _, layout2, mat3, mat4 = mm_args(mat3, mat4, layout=layout) + device_type = ir.get_device_type(mat1) + + # Optimization is optional, because we can always just not do the fusion + if ( + m1 * n1 == 0 + or m2 * n2 == 0 + or not V.graph.sizevars.statically_known_list_equals( + mat1.get_size(), mat3.get_size() + ) + or not V.graph.sizevars.statically_known_list_equals( + mat2.get_size(), mat4.get_size() + ) + ): + # TODO(jansel): support different K values when this is fixed: + # https://github.com/triton-lang/triton/issues/967 + return lowerings[aten.add]( + lowerings[aten.mm](mat1, mat2), lowerings[aten.mm](mat3, mat4) + ) + + assert layout1 == layout2 + # options to tune from + choices = ( + [aten_mm_plus_mm.bind((mat1, mat2, mat3, mat4), layout1)] + if use_aten_gemm_kernels() + else [] + ) + + mm_configs = V.choices.get_mm_plus_mm_configs(device_type) + if use_triton_template(layout1): + for config in mm_configs(): + # see https://github.com/triton-lang/triton/issues/1298 + # BLOCK_K = K causes llvm error + if V.graph.sizevars.statically_known_lt(config.kwargs["BLOCK_K"], k1): + mm_plus_mm_template.maybe_append_choice( + choices, + input_nodes=(mat1, mat2, mat3, mat4), + layout=layout1, + **mm_options(config, m1, n1, k1, layout1), + ) + + return autotune_select_algorithm( + "mm_plus_mm", choices, [mat1, mat2, mat3, mat4], layout1 + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..ba3005a0e666c1df703b2bff355ee4a7b6ffb17f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/kernel/mm_scaled_grouped.py @@ -0,0 +1,741 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from torch._dynamo.utils import counters +from torch._inductor.runtime.triton_compat import tl +from torch._inductor.virtualized import V +from torch.utils._triton import has_triton + +from ..ir import ChoiceCaller, Layout, TensorBox +from ..lowering import register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_gpu_shared_memory, + get_num_sms, + has_free_symbols, + use_aten_gemm_kernels, +) +from .mm_common import ( + _is_static_problem, + check_supported_striding, + persistent_grouped_mm_grid, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + + +@dataclass +class Config: + kwargs: dict[str, int] + num_stages: int + num_warps: int + + +_NV_CONFIGS = [ + Config( + { + "BLOCK_M": block_size_m, + 
"BLOCK_N": block_size_n, + "BLOCK_K": block_size_k, + "NUM_CONSUMER_GROUPS": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + for block_size_m in [16, 32, 64, 128] + for block_size_n in [64, 128, 256] + for block_size_k in [64, 128, 256] + for num_stages in [3, 4] + for num_warps in [4, 8] +] + + +def grouped_mm_configs(): + return _NV_CONFIGS + + +def early_config_prune(g, m, configs, named_args): + dtsize = 1 + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, num_consumer_groups = ( + kw["BLOCK_M"], + kw["BLOCK_N"], + kw["BLOCK_K"], + config.num_stages, + config.num_warps, + getattr(config, "num_consumer_groups", 0), + ) + + # 1. Prune NV configs depending on g and m. + if not has_free_symbols((g, m)): + a_is_2d, b_is_2d = named_args["A_IS_2D"], named_args["B_IS_2D"] + m_avg = m // g if a_is_2d and not b_is_2d else m + if m_avg <= 16: + if BLOCK_M > 32: + continue + elif m_avg <= 32: + if BLOCK_M > 64: + continue + elif m_avg <= 64: + if BLOCK_M <= 16: + continue + else: + if BLOCK_M <= 32: + continue + + # 2. make sure we have enough smem + max_shared_memory = get_gpu_shared_memory() + + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory > max_shared_memory: + continue + + use_warp_specialization = num_consumer_groups >= 1 + + # 3. make sure we can partition for ws + if use_warp_specialization: + if num_warps != 4: + continue + + # "tritongpu-warp-spec-data-partition" + m_slice = BLOCK_M // num_consumer_groups + n_slice = BLOCK_N // num_consumer_groups + if m_slice < 64 and n_slice < 256: + continue + + pruned_configs.append(config) + + return pruned_configs + + +triton_grouped_mm_source = r""" +{%- if SCALED %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr", "scale_a_ptr", "scale_b_ptr")}} +{%- endif %} +{%- else %} +{%- if A_IS_2D or B_IS_2D %} +{{def_kernel("a_ptr", "b_ptr", "offsets_ptr")}} +{%- else %} +{{def_kernel("a_ptr", "b_ptr")}} +{%- endif %} +{%- endif %} + tidx = tl.program_id(0) + +{%- set M_IS_VARYING = A_IS_2D and not B_IS_2D %} +{%- set N_IS_VARYING = not A_IS_2D and B_IS_2D %} +{%- set K_IS_VARYING = A_IS_2D and B_IS_2D %} + +{%- if A_IS_2D %} +{%- if B_IS_2D %} + G = {{size("offsets_ptr", 0)}} +{%- else %} + G = {{size("b_ptr", 0)}} +{%- endif %} +{%- else %} +{%- if B_IS_2D %} + G = {{size("a_ptr", 0)}} +{%- else %} + G = {{size("a_ptr", 0)}} +{%- endif %} +{%- endif %} + + # the b_ptr tensor is given with its last two dims transposed, revert here + + M = {{size("a_ptr", -2)}} + N = {{size("b_ptr", -1)}} + K = {{size("a_ptr", -1)}} + + A_STRIDE_M = {{stride("a_ptr", -2)}} + A_STRIDE_K = {{stride("a_ptr", -1)}} +{%- if not A_IS_2D %} + A_STRIDE_G = {{stride("a_ptr", 0)}} +{%- if SCALED %} + SCALE_A_STRIDE_G = {{stride("scale_a_ptr", 0)}} +{%- endif %} +{%- endif %} + B_STRIDE_N = {{stride("b_ptr", -1)}} + B_STRIDE_K = {{stride("b_ptr", -2)}} +{%- if not B_IS_2D %} + B_STRIDE_G = {{stride("b_ptr", 0)}} +{%- if SCALED %} + SCALE_B_STRIDE_G = {{stride("scale_b_ptr", 0)}} +{%- endif %} +{%- endif %} + +{%- if USE_TMA_LOAD %} +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + a_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + a_desc = tl.make_tensor_descriptor( +{%- endif %} + a_ptr, +{%- if A_IS_2D %} + shape=[M, K], + # fixme: strides=[A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", -2)}}, {{stride("a_ptr", 
-1)}}], + block_shape=[BLOCK_M, BLOCK_K], +{%- else %} + shape=[G, M, K], + # fixme: strides=[A_STRIDE_G, A_STRIDE_M, A_STRIDE_K], + strides=[{{stride("a_ptr", 0)}}, {{stride("a_ptr", -2)}}, {{stride("a_ptr", -1)}}], + block_shape=[1, BLOCK_M, BLOCK_K], +{%- endif %} + ) + +{%- if USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR %} + b_desc = tl._experimental_make_tensor_descriptor( +{%- else %} + b_desc = tl.make_tensor_descriptor( +{%- endif %} + b_ptr, +{%- if B_IS_2D %} + shape=[N, K], + # fixme: strides=[B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[BLOCK_N, BLOCK_K], +{%- else %} + shape=[G, N, K], + # fixme: strides=[B_STRIDE_G, B_STRIDE_N, B_STRIDE_K], + strides=[{{stride("b_ptr", 0)}}, {{stride("b_ptr", -1)}}, {{stride("b_ptr", -2)}}], + block_shape=[1, BLOCK_N, BLOCK_K], +{%- endif %} + ) +{%- endif %} + +{%- if M_IS_VARYING %} + m_end_offset = 0 +{%- endif %} +{%- if N_IS_VARYING %} + n_end_offset = 0 +{%- endif %} +{%- if K_IS_VARYING %} + k_end_offset = 0 +{%- endif %} + iterated_tiles = 0 + for g in tl.range(G): +{%- if M_IS_VARYING %} + # Move across groups + m_start_offset = m_end_offset + m_end_offset = tl.load(offsets_ptr + g) + m_size = m_end_offset - m_start_offset +{%- if SCALED %} + m_scale_start_offset = m_start_offset +{%- endif %} +{%- else %} + m_start_offset = 0 + m_size = M +{%- if SCALED %} + m_scale_start_offset = g * M +{%- endif %} +{%- endif %} + +{%- if N_IS_VARYING %} + # Move across groups + n_start_offset = n_end_offset + n_end_offset = tl.load(offsets_ptr + g) + n_size = n_end_offset - n_start_offset +{%- if SCALED %} + n_scale_start_offset = n_start_offset +{%- endif %} +{%- else %} + n_start_offset = 0 + n_size = N +{%- if SCALED %} + n_scale_start_offset = g * N +{%- endif %} +{%- endif %} + + if m_size > 0 and n_size > 0: +{%- if K_IS_VARYING %} + # Move across groups + k_start_offset = k_end_offset + k_end_offset = tl.load(offsets_ptr + g) + k_size = k_end_offset - k_start_offset +{%- else %} + k_start_offset = 0 + k_size = K +{%- endif %} + + num_m_tiles = tl.cdiv(m_size, BLOCK_M) + num_n_tiles = tl.cdiv(n_size, BLOCK_N) + num_tiles = num_m_tiles * num_n_tiles + + # Move across tiles + while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles: + gidx = tidx - iterated_tiles + # Split M first and N second. 
+ tile_m_idx = gidx % num_m_tiles + tile_n_idx = gidx // num_m_tiles + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + +{%- if USE_TMA_LOAD %} + m_offset = (m_start_offset + tile_m_idx * BLOCK_M).to(tl.int32) + n_offset = (n_start_offset + tile_n_idx * BLOCK_N).to(tl.int32) + + for k_offset in range(0, k_size, BLOCK_K): +{%- if A_IS_2D %} + a = a_desc.load([m_offset, k_start_offset + k_offset]) +{%- else %} + a = a_desc.load([g, m_offset, k_start_offset + k_offset]).reshape(BLOCK_M, BLOCK_K) +{%- endif %} +{%- if B_IS_2D %} + b = b_desc.load([n_offset, k_start_offset + k_offset]) +{%- else %} + b = b_desc.load([g, n_offset, k_start_offset + k_offset]).reshape(BLOCK_N, BLOCK_K) +{%- endif %} + +{%- if K_IS_VARYING %} + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- endif %} + +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} +{%- else %} + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = k_start_offset + tl.arange(0, BLOCK_K) + a_ptrs = ( + a_ptr +{%- if not A_IS_2D %} + + g * A_STRIDE_G +{%- endif %} + + (m_start_offset + offs_am[:, None]) * A_STRIDE_M + + offs_k[None, :] * A_STRIDE_K + ) + b_ptrs = ( + b_ptr +{%- if not B_IS_2D %} + + g * B_STRIDE_G +{%- endif %} + + (n_start_offset + offs_bn[:, None]) * B_STRIDE_N + + offs_k[None, :] * B_STRIDE_K + ) + for k_offset in range(0, k_size, BLOCK_K): + a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size) + b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size) + if k_offset + BLOCK_K > k_size: + group_offs_k = k_offset + tl.arange(0, BLOCK_K) + a = tl.where(group_offs_k < k_size, a, 0) + b = tl.where(group_offs_k < k_size, b, 0) +{%- if USE_FAST_ACCUM %} + accumulator = tl.dot(a, b.T, accumulator) +{%- else %} + accumulator += tl.dot(a, b.T) +{%- endif %} + a_ptrs += BLOCK_K + b_ptrs += BLOCK_K +{%- endif %} + + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) +{%- if SCALED %} + scale_a = tl.load( + scale_a_ptr +{%- if A_IS_2D %} + + m_scale_start_offset +{%- else %} + + g * SCALE_A_STRIDE_G +{%- endif %} + + offs_am[:, None], + mask=offs_am[:, None] < m_size, + ) + scale_b = tl.load( + scale_b_ptr +{%- if B_IS_2D %} + + n_scale_start_offset +{%- else %} + + g * SCALE_B_STRIDE_G +{%- endif %} + + offs_bn[None, :], + mask=offs_bn[None, :] < n_size, + ) + c = accumulator.to(tl.float32) * scale_a * scale_b +{%- else %} + c = accumulator.to(tl.float32) +{%- endif %} + +{%- if M_IS_VARYING %} + idx_m = (m_start_offset + offs_am[:, None]) +{%- else %} + idx_m = offs_am[:, None] +{%- endif %} +{%- if N_IS_VARYING %} + idx_n = (n_start_offset + offs_bn[None, :]) +{%- else %} + idx_n = offs_bn[None, :] +{%- endif %} + mask = offs_am[:, None] < m_size and offs_bn[None, :] < n_size +{%- if M_IS_VARYING or N_IS_VARYING %} + {{store_output(("idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- else %} + {{store_output(("g", "idx_m", "idx_n"), "c", "mask", indent_width=16)}} +{%- endif %} + tidx += NUM_SMS + + iterated_tiles += num_tiles +""" + + +triton_grouped_mm_template = TritonTemplate( + name="grouped_mm", + grid=persistent_grouped_mm_grid, + source=triton_grouped_mm_source, +) + +triton_scaled_grouped_mm_template = TritonTemplate( + name="scaled_grouped_mm", + grid=persistent_grouped_mm_grid, + 
source=triton_grouped_mm_source, +) + + +def grouped_mm_args( + mat1: TensorBox, + mat2: TensorBox, + offs: Optional[TensorBox], + layout=None, + out_dtype=None, +): + mat1, mat2 = realize_inputs(mat1, mat2) + if offs is not None: + realize_inputs(offs) + mat1_size = mat1.get_size() + mat2_size = mat2.get_size() + + m1dim, m2dim = len(mat1_size), len(mat2_size) + + assert m1dim == 2 or m1dim == 3 + assert m2dim == 2 or m2dim == 3 + + if layout is None: + from torch._inductor.ir import FixedLayout + + if out_dtype is None: + out_dtype = mat1.get_dtype() + + dims = [] + if m1dim == 2: + if m2dim == 2: + assert offs is not None + dims = [offs.get_size()[0], mat1_size[0], mat2_size[1]] + else: + dims = [mat1_size[0], mat2_size[-1]] + else: + if m2dim == 2: + dims = [mat1_size[1], mat2_size[1]] + else: + dims = [mat1_size[0], mat1_size[1], mat2_size[-1]] + layout = FixedLayout( + mat1.get_device(), + out_dtype, + dims, + ) + else: + assert out_dtype is None, "out_dtype is ignored if layout is specified." + + return (mat1_size, mat2_size, layout, mat1, mat2, offs) + + +aten__grouped_mm = ExternKernelChoice( + torch._grouped_mm, + "at::_grouped_mm", + op_overload=aten._grouped_mm, + has_out_variant=False, +) + + +aten__scaled_grouped_mm = ExternKernelChoice( + torch._scaled_grouped_mm, + "at::_scaled_grouped_mm", + op_overload=aten._scaled_grouped_mm, + has_out_variant=False, +) + + +def can_use_triton_kernel( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox], + bias: Optional[TensorBox], + scale_result: Optional[TensorBox], +) -> bool: + if not ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + return False + if not has_triton(): + return False + + # The _grouped_mm()/_scaled_grouped_mm() operator do not support + # bias nor scale_result yet. 
+ if bias is not None: + return False + if scale_result is not None: + return False + + if len(mat_a.get_size()) == 2 or len(mat_b.get_size()) == 2: + return offs is not None + else: + return offs is None + + +def create_offsets(x, m1_size, m2_size, offs_size): + m1_is_2d = len(m1_size) == 2 + m2_is_2d = len(m2_size) == 2 + if m1_is_2d: + if m2_is_2d: + k = V.graph.sizevars.size_hint(m1_size[1]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = k / noffs + return torch.linspace( + step, k, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + + else: + m = V.graph.sizevars.size_hint(m1_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = m / noffs + return torch.linspace( + step, m, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + if m2_is_2d: + n = V.graph.sizevars.size_hint(m2_size[0]) + noffs = V.graph.sizevars.size_hint(offs_size[0]) + step = n / noffs + return torch.linspace( + step, n, noffs, dtype=x.get_dtype(), device=x.get_device() + ) + else: + return None + + +def _tuned_grouped_mm_common( + operator_name: str, + algorithm_name: str, + extern_kernel_choice: ExternKernelChoice, + kernel_template: TritonTemplate, + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: Optional[TensorBox] = None, + scale_b: Optional[TensorBox] = None, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: Optional[bool] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + assert (scale_a is None) == (scale_b is None) + assert scale_result is None or scale_a is not None + + m1_size, m2_size, layout, mat_a, mat_b, offs = grouped_mm_args( + mat_a, mat_b, offs, layout=layout, out_dtype=out_dtype + ) + counters["aten_mm_info"][operator_name] += 1 + log_message = f"Tuned {operator_name}: mat1_shape=%s, mat2_shape=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s" + log.info( + log_message, + m1_size, + m2_size, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + if scale_a is not None and scale_b is not None: + check_supported_striding(mat_a, mat_b) + + # workaround for Inductor not supporting optional tensor input arguments + input_nodes: list[Any] = [mat_a, mat_b] + if scale_a is not None: + input_nodes.append(realize_inputs(scale_a)) + if scale_b is not None: + input_nodes.append(realize_inputs(scale_b)) + if offs is not None: + input_nodes.append(realize_inputs(offs)) + + if use_fast_accum is None: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + ) + else: + aten_choice = extern_kernel_choice.bind( + input_nodes, + layout, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + if use_fast_accum is None: + use_fast_accum = False + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + # Checking only for the equality of corresponding dims of + # multiplicands here, relying on meta function checks for + # everything else. 
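+        # The four 2D/3D combinations below also determine what the offsets
+        # tensor splits: M when only mat_a is 2D, N when only mat_b is 2D,
+        # and K when both are 2D.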
+ if is_nonzero and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result): + scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + g = offs.get_size()[0] + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.guard_equals(g1, g2) + V.graph.sizevars.guard_equals(k1, k2) + a_is_2d, b_is_2d = False, False + + triton_has_make_tensor_descriptor = hasattr(tl, "make_tensor_descriptor") + triton_has_experimental_make_tensor_descriptor = hasattr( + tl, "_experimental_make_tensor_descriptor" + ) + use_tma_load = ( + triton_has_make_tensor_descriptor + or triton_has_experimental_make_tensor_descriptor + ) + # The make_tensor_descriptor imposes this additional limitation. + use_tma_load = use_tma_load and ( + mat_a.get_stride()[-1] == 1 and mat_b.get_stride()[-2] == 1 + ) + + kwargs = { + "SCALED": scaled, + "A_IS_2D": a_is_2d, + "B_IS_2D": b_is_2d, + "USE_FAST_ACCUM": use_fast_accum, + "NUM_SMS": get_num_sms(), + "USE_TMA_LOAD": use_tma_load, + "USE_EXPERIMENTAL_MAKE_TENSOR_DESCRIPTOR": triton_has_experimental_make_tensor_descriptor, + } + + for config in early_config_prune(g, m, grouped_mm_configs(), kwargs): + kernel_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + num_stages=config.num_stages, + num_warps=config.num_warps, + **kwargs, + **config.kwargs, + ) + + input_gen_fns = { + 4: lambda x: create_offsets( + x, m1_size, m2_size, offs.get_size() if offs is not None else None + ), + } + return autotune_select_algorithm( + algorithm_name, choices, input_nodes, layout, input_gen_fns=input_gen_fns + ) + + +@register_lowering(aten._grouped_mm.default, type_promotion_kind=None) +def tuned_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._grouped_mm.default", + "grouped_mm", + aten__grouped_mm, + triton_grouped_mm_template, + mat_a, + mat_b, + None, + None, + offs, + bias, + None, + out_dtype, + None, + layout, + ) + + +@register_lowering(aten._scaled_grouped_mm.default, type_promotion_kind=None) +def tuned_scaled_grouped_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + offs: Optional[TensorBox] = None, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + """Auto-tuning for _scaled_grouped_mm() operator.""" + + return _tuned_grouped_mm_common( + "aten._scaled_grouped_mm.default", + "scaled_grouped_mm", + aten__scaled_grouped_mm, + triton_scaled_grouped_mm_template, + mat_a, + mat_b, + scale_a, + scale_b, + offs, + bias, + scale_result, + out_dtype, + use_fast_accum, + layout, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/__init__.py 
b/phivenv/Lib/site-packages/torch/_inductor/package/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84ef7e8fc4eb7dd0ea5bb77e01e8b7282899a15d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/package/__init__.py @@ -0,0 +1 @@ +from .package import AOTICompiledModel, load_package, package_aoti diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c1af0d137630c660ac863e2d881ce50c0f80d1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..754aee68c163148561355e6337fd23fd7430f04c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/build_package.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/package.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/package.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0abf50f22729f1429a04f9ed694a3e04d97d3f34 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/package/__pycache__/package.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/build_package.py b/phivenv/Lib/site-packages/torch/_inductor/package/build_package.py new file mode 100644 index 0000000000000000000000000000000000000000..b4a64dc1369f7b6e33ebc718363d8d031e7ebbe6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/package/build_package.py @@ -0,0 +1,15 @@ +build_package_contents = """ +import os +from pathlib import Path + +from torch._inductor.package.package import compile_so + +curr_dir = Path(__file__).parent +aoti_files = [ + os.path.join(root, file) + for root, dirs, files in os.walk(curr_dir) + for file in files +] + +output_so = compile_so(curr_dir, aoti_files, curr_dir) +""" diff --git a/phivenv/Lib/site-packages/torch/_inductor/package/package.py b/phivenv/Lib/site-packages/torch/_inductor/package/package.py new file mode 100644 index 0000000000000000000000000000000000000000..0666e320a69612c37ab70dae0eb33ea9b9660917 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/package/package.py @@ -0,0 +1,138 @@ +import io +import json +import logging +import os +import tempfile +from typing import IO + +import torch +from torch._inductor import config +from torch._inductor.cpp_builder import BuildOptionsBase, CppBuilder +from torch.export.pt2_archive._package import ( + AOTI_FILES, + AOTICompiledModel, + load_pt2, + package_pt2, +) +from torch.types import FileLike + + +log = logging.getLogger(__name__) + + +def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str: + def get_aoti_file_with_suffix(suffix: str) -> str: + for file in aoti_files: + if file.endswith(suffix): + return file + raise RuntimeError(f"Unable to find file with suffix {suffix}") + + # Compile all the files into a .so + cpp_file = os.path.join(aoti_dir, get_aoti_file_with_suffix(".cpp")) + consts_o = os.path.join(aoti_dir, get_aoti_file_with_suffix(".o")) + + file_name = os.path.splitext(cpp_file)[0] + + # Parse 
compile flags and build the .o file + with open(file_name + "_compile_flags.json") as f: + compile_flags = json.load(f) + + compile_options = BuildOptionsBase( + **compile_flags, use_relative_path=config.is_fbcode() + ) + object_builder = CppBuilder( + name=file_name, + sources=cpp_file, + BuildOption=compile_options, + ) + output_o = object_builder.get_target_file_path() + object_builder.build() + + # Parse linker flags and build the .so file + with open(file_name + "_linker_flags.json") as f: + linker_flags = json.load(f) + + linker_options = BuildOptionsBase( + **linker_flags, use_relative_path=config.is_fbcode() + ) + so_builder = CppBuilder( + name=os.path.split(so_path)[-1], + sources=[output_o, consts_o], + BuildOption=linker_options, + output_dir=so_path, + ) + output_so = so_builder.get_target_file_path() + so_builder.build() + + # mmapped weights + serialized_weights_filename = file_name + "_serialized_weights.bin" + if serialized_weights_filename in aoti_files: + with open(serialized_weights_filename, "rb") as f_weights: + serialized_weights = f_weights.read() + + with open(output_so, "a+b") as f_so: + so_size = f_so.tell() + # Page align the weights + f_so.write(b" " * (16384 - so_size % 16384)) + f_so.write(serialized_weights) + + return output_so + + +def package_aoti( + archive_file: FileLike, + aoti_files: AOTI_FILES, +) -> FileLike: + """ + Saves the AOTInductor generated files to the PT2Archive format. + + Args: + archive_file: The file name to save the package to. + aoti_files: This can either be a singular path to a directory containing + the AOTInductor files, or a dictionary mapping the model name to the + path to its AOTInductor generated files. + """ + + return package_pt2( + archive_file, + aoti_files=aoti_files, + ) + + +def load_package( + path: FileLike, + model_name: str = "model", + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, +) -> AOTICompiledModel: # type: ignore[type-arg] + try: + pt2_contents = load_pt2( + path, + run_single_threaded=run_single_threaded, + num_runners=num_runners, + device_index=device_index, + ) + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in package") + return pt2_contents.aoti_runners[model_name] + except RuntimeError: + log.warning("Loading outdated pt2 file. Please regenerate your package.") + + if isinstance(path, (io.IOBase, IO)): + with tempfile.NamedTemporaryFile(suffix=".pt2") as f: + # TODO(angelayi): We shouldn't need to do this -- miniz should + # handle reading the buffer. 
This is just a temporary workaround + path.seek(0) + f.write(path.read()) + log.debug("Writing buffer to tmp file located at %s.", f.name) + loader = torch._C._aoti.AOTIModelPackageLoader( + f.name, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) + + path = os.fspath(path) # AOTIModelPackageLoader expects (str, str) + loader = torch._C._aoti.AOTIModelPackageLoader( + path, model_name, run_single_threaded, num_runners, device_index + ) + return AOTICompiledModel(loader) diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__init__.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fb1e3f3614d650dd1e89eafd1fc24fe886d4e65 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..768eea595d784659781606299d7af9ecbc5a411c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/autotune_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0c2280c05ba9f3fd409672cf8ad94ba14e7e1d2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/benchmarking.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d43ed085129c720b8c3f942c1072279f9f44758 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/cache_dir_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60257c5cdff1382459137dfa47bf953565fe3983 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/compile_tasks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3305afbaecf73d393174844ee7a7fe28961e1332 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/coordinate_descent_tuner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-39.pyc 
b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61ecc24db3feb3be08f7aca718f182eec12c5c03 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/halide_helpers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..717203f019cfdab29565268343baab71f571f51b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/hints.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59fbf159894a15b49d170fff7dca3ac6fbeab7e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/runtime_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a21f67c5d340ae34373a10e9a90feb72c924e413 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/static_cuda_launcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01545d301d70ae4d81109a354155093e8e758dac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_compat.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0db1dab85db0cf96d8c386a9065254dca42ceef1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_helpers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_heuristics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_heuristics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c531e3e4d5e322a5a2c6145a3b297bc60bdc88f3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_inductor/runtime/__pycache__/triton_heuristics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..64700dcafbaf1ac2e24346fdf0763b675ac7b57c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/autotune_cache.py @@ -0,0 +1,640 @@ +""" +PyTorch Inductor Autotuning Cache System + +This module implements a caching system for autotuning configurations in PyTorch's Inductor compiler. 
+It provides mechanisms to store and retrieve optimal kernel configurations both locally and remotely, +which significantly speeds up compilation by reusing previously discovered optimal parameters. + +The caching system includes: +- Local filesystem caching for individual machine reuse +- Remote caching for sharing optimizations across machines +- Bundled caching to efficiently store multiple related configurations +- Cache invalidation based on PyTorch versions and backend changes +- Serialization/deserialization support for worker processes + +Key components: +- AutotuneCache: Main class for managing cache access and storage +- AutotuneCacheBundler: Bundles multiple cache entries for efficient storage +- LocalAutotuneCache: Handles filesystem-based caching +- _LocalAutotuneCacheBackend: Low-level file operations for cache storage +- AutotuneCacheArtifact: Integration with PyTorch's artifact system + +This caching system is critical for performance as it eliminates the need to re-run +expensive autotuning operations when the same kernels are compiled multiple times. +""" + +from __future__ import annotations + +import dataclasses +import hashlib +import logging +import os +import os.path +import re +from typing import Any, Optional, TYPE_CHECKING +from typing_extensions import override + +import torch +from torch._inductor.runtime.runtime_utils import cache_dir +from torch.compiler._cache import ( + CacheArtifact, + CacheArtifactFactory, + CacheArtifactManager, +) +from torch.utils._triton import has_triton + +from ..remote_cache import ( + create_cache, + JsonDataTy, + RemoteCache, + RemoteCacheBackend, + RemoteCacheJsonSerde, +) +from .triton_compat import Config, HAS_WARP_SPEC + + +if TYPE_CHECKING: + from ..remote_cache import Sample + +log = logging.getLogger(__name__) + + +_InductorMetaTy = dict[str, object] + + +def inductor_meta_from_config() -> _InductorMetaTy: + from torch._inductor import config + + backend_hash = None + if has_triton(): + try: + backend_hash = torch.utils._triton.triton_hash_with_backend() + except RuntimeError: + # This can get the error: + # RuntimeError: 0 active drivers ([]). There should only be one. + pass + + is_hip = None + if torch.version.hip is not None: + is_hip = True + + return { + "autotune_local_cache": config.autotune_local_cache, + "autotune_remote_cache": config.autotune_remote_cache, + "backend_hash": backend_hash, + "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache, + "coordinate_descent_tuning": config.coordinate_descent_tuning, + "is_fbcode": config.is_fbcode(), + "is_hip": is_hip, + } + + +@CacheArtifactFactory.register +class AutotuneCacheArtifact(CacheArtifact): + @override + def populate_cache(self) -> None: + autotune_cache = _LocalAutotuneCacheBackend() + key = os.path.join(cache_dir(), self.key) + autotune_cache._put(key, self.content) + + @override + @staticmethod + def type() -> str: + return "autotune" + + @override + @staticmethod + def encode(content: JsonDataTy) -> bytes: + assert not isinstance(content, bytes) + serde = RemoteCacheJsonSerde() + content_bytes = serde.encode(content) + assert isinstance(content_bytes, bytes) + return content_bytes + + +@dataclasses.dataclass +class AutotuneCache: + configs_hash: str + local_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None + remote_cache: Optional[tuple[RemoteCache[JsonDataTy], str]] = None + + # Create a AutotuneCache. Returns None if none of the caches can be used. 
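+    # Illustrative usage (a sketch; the names below mirror this class's own API):
+    #     cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
+    #     if cache is not None:
+    #         best = cache.read_best(inductor_meta, configs)
+    #         ...
+    #         cache.save(config, time_taken_ns)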
+ @staticmethod + def create( + inductor_meta: _InductorMetaTy, filename: str, configs_hash: str + ) -> Optional[AutotuneCache]: + cache = AutotuneCache(configs_hash) + key = AutotuneCache._prepare_key(filename) + cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key) + cache._setup_remote_autotune_cache(inductor_meta, key) + if cache.local_cache or cache.remote_cache: + return cache + else: + return None + + @staticmethod + def _prepare_key(filename: str) -> str: + from torch.compiler import config as cconfig + + # base of filename is already sha256 hash the source contents + key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}" + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + # Read the best config options from the most local cache and return it. + def _read(self) -> Optional[dict[str, JsonDataTy]]: + if local_cache := self.local_cache: + cache, key = local_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + if remote_cache := self.remote_cache: + cache, key = remote_cache + if best_config := cache.get(key): + if isinstance(best_config, dict): + return best_config + + return None + + # Read the best config options from the most local cache and figure out + # which `configs` represents that option. + def read_best( + self, inductor_meta: _InductorMetaTy, configs: list[Config] + ) -> Optional[Config]: + if best := self._read(): + return _load_cached_autotuning( + best, self.configs_hash, configs, inductor_meta + ) + return None + + # Set up local filesystem caching information + def _setup_local_cache( + self, inductor_meta: _InductorMetaTy, dirname: str, cache_key: str + ) -> None: + if not inductor_meta.get("autotune_local_cache", True): + return + + from ..codecache import torch_key + + """ + [Note: torch_key in autotune cache key] + Include torch_key() in the cache key so that different versions + of torch result in cache invalidation. This is important in case + of changes to the best_config format or other code changes that + are not backward compatible w.r.t. the cache. 
+ """ + hasher = hashlib.sha256() + hasher.update(cache_key.encode("utf-8")) + hasher.update(torch_key()) + updated_cache_key = hasher.hexdigest() + + cache_filename = f"{dirname}/{updated_cache_key}.best_config" + local_cache = LocalAutotuneCache() + self.local_cache = (local_cache, cache_filename) + + # Set up remote caching information + def _setup_remote_autotune_cache( + self, inductor_meta: _InductorMetaTy, cache_key: str + ) -> None: + if not _should_use_remote_autotune_cache(inductor_meta): + return + + if (backend_hash := inductor_meta.get("backend_hash", None)) is None: + log.debug( + "backend_hash is not passed on the inductor_meta, unable to use autotune remote cache" + ) + return + assert isinstance(backend_hash, str) + + from ..codecache import torch_key + + is_fbcode = bool(inductor_meta.get("is_fbcode", False)) + + salt = "autotune-best-config-v2" + # re: torch_key - see [Note: torch_key in autotune cache key] + key = torch_key().hex() + backend_hash + self.configs_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + remote_cache = create_cache( + key, + is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if not remote_cache: + return + + # Save the args passed to create_cache + # in case AutotuneCache needs to be pickled + self.remote_cache_full_key = key + self.is_fbcode = is_fbcode + self.remote_cache = (remote_cache, cache_key) + + # The AutotuneCache may be serialized/deserialized if we're using + # AsyncCompile worker processes to run triton compilation. + # This is because AutotuneCache instances are created on the worker + # process, but we need to run AutotuneCache.save on the parent process + # when actually doing autotuning. + def __getstate__(self) -> dict[str, Any]: + # The remote cache handles themselves may not be serializable + # So clear it and reconstruct it on setstate + remote_cache = getattr(self, "remote_cache", None) + return { + **self.__dict__, + # Save the cache_key portion + "remote_cache": remote_cache and remote_cache[1], + } + + def __setstate__(self, state: dict[str, Any]) -> None: + # Reconstruct the remote cache on the parent class + self.__dict__.update(state) + if self.remote_cache is not None: + assert isinstance(self.remote_cache, str) + assert hasattr(self, "remote_cache_full_key") + assert hasattr(self, "is_fbcode") + cache_key = self.remote_cache + remote_cache = create_cache( + self.remote_cache_full_key, + self.is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if remote_cache is not None: + self.remote_cache = (remote_cache, cache_key) + else: + log.warning("Warning, failed to recreate remote cache after pickling") + self.remote_cache = None + + # Save the config in the caches + def save( + self, + config: Config, + time_taken_ns: int, + found_by_coordesc: bool = False, + triton_cache_hash: Optional[str] = None, + ) -> None: + data = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + "configs_hash": self.configs_hash, + "found_by_coordesc": found_by_coordesc, + "time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS + "triton_cache_hash": triton_cache_hash, + } + if HAS_WARP_SPEC: + data.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr( + config, "num_buffers_warp_spec", 0 + ), + } + ) + + if local_cache := self.local_cache: + cache, key = local_cache + cache.put(key, data) + AutotuneCacheBundler.put(key, data) + autotune_artifact_key = 
os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, data + ) + + if log.isEnabledFor(logging.DEBUG): + type_str = "coordesc" if found_by_coordesc else "heuristic" + log.debug("Save %s tuning result to %s", type_str, key) + + if remote_cache := self.remote_cache: + cache, key = remote_cache + cache.put(key, data) + + +class _AutotuneCacheBundlerImpl: + """ + Caches a set of LocalAutotuneCacheBackend entries together in a single + cache. + """ + + _key: str + _cache: RemoteCache[JsonDataTy] + + # All known entries from LocalAutotuneCache.put() + _entries: dict[str, JsonDataTy] + + def end_compile(self) -> None: + # TODO: Do we need to compute time_taken_ms and encode that somehow? + if self._entries: + self._cache.put(self._key, self._entries) + + def put(self, basename: str, data: JsonDataTy) -> None: + # Do we need to worry about duplicates? We only have a single local fs + # entry - so probably not. + self._entries[basename] = data + + def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None: + self._key = key + self._cache = cache + self._entries = {} + + def sync(self) -> None: + # We don't currently use this - but we could async load starting at + # `begin_compile` and wait for the load to be finished here. + pass + + @classmethod + def _should_use_bundled_autotune_remote_cache( + cls, inductor_meta: _InductorMetaTy + ) -> bool: + # The bundled autotune cache is only available if you've also got local + # caching enabled (because we feed the bundled data to the local cache). + if not inductor_meta.get("autotune_local_cache", True): + return False + + # Check if the we're enabled via config + if ( + bundled_autotune_remote_cache := inductor_meta.get( + "bundled_autotune_remote_cache" + ) + ) is not None: + return bool(bundled_autotune_remote_cache) + + if not cls._get_is_fbcode(inductor_meta): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:bundled_autotune_remote_cache_version" + ) + return REMOTE_CACHE_VERSION >= jk + + def _load_cache(self) -> bool: + from torch._inductor import codecache + + # The single key is defined on construction of the cache. + entries = self._cache.get(self._key) + if entries is None or not isinstance(entries, dict): + # We couldn't load the cache - so mark _entries as non-None so we + # store local cache values. + return False + + # Go through the entries we got from the cache and save them locally. 
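+        # Also accumulate any recorded time_saved_ns so the distributed
+        # compile timeout can be extended accordingly (see the call below).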
+ time_saved_ns = 0 + for basename, data in entries.items(): + # Reconstruct the final filename (see put()) + root, ext = _splitext_nodot(basename) + _, _, filename = codecache.get_path(root, ext) + if isinstance(data, dict) and (tsns := data.get("time_saved_ns")): + time_saved_ns += int(tsns) # type: ignore[arg-type] + local_cache = LocalAutotuneCache() + local_cache.put(filename, data) + + codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + + return True + + @staticmethod + def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool: + return bool(inductor_meta.get("is_fbcode", False)) + + @staticmethod + def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str: + backend_hash = inductor_meta["backend_hash"] + assert isinstance(backend_hash, str) + return backend_hash + + +class AutotuneCacheBundler: + _bundler: Optional[_AutotuneCacheBundlerImpl] = None + + def __init__(self) -> None: + pass + + # Call this before we start any autotune computation for an inductor python + # file. On a cache hit it copies the individual results into the local + # autotune caches. + @classmethod + def begin_compile( + cls, + inductor_meta: _InductorMetaTy, + *, + code: Optional[str] = None, + code_hash: Optional[str] = None, + ) -> None: + assert cls._bundler is None + + if code is not None: + assert code_hash is None, "Cannot specify both code and code_hash" + code_hash = _comment_stripped_hash(code) + assert code_hash is not None + + if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache( + inductor_meta + ): + return + + cache = create_cache( + "bundled-autotune-v1", + _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta), + "FbRemoteBundledAutotuneCache", + "RemoteBundledAutotuneCache", + ) + if not cache: + return + + # We're starting a compilation phase. We have a cache key for the code + # we're compiling. We'll get the individual autotune bundles later (via + # self.put()). For now create the AutotuneCacheBundler and try to load + # from the cache. + + salt = "bundled-autotune-best-configs-v1" + backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta) + # TODO: The autotune cache includes configs_hash in the key. The problem + # is that the configs_hash includes info from the individual pointwise() + # calls (size_hints, for example) which we can't know yet. I *think* + # that info is basically present in the `code_hash` (since it's a + # parameter to the pointwise decorator) - but is there other info we + # need to include from inductor_meta? + key = code_hash + backend_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + bundler = _AutotuneCacheBundlerImpl(key, cache) + if not bundler._load_cache(): + # We couldn't load from the cache - so save the data so we can store + # the saved autotunes. + cls._bundler = bundler + + # If we get a cache hit don't bother saving any of the individual + # autotune results. + + # Call this after all individual autotune results are finished for a + # inductor python file. If we gathered any individual results then we bundle + # those and put it into the cache. 
+ @classmethod + def end_compile(cls) -> None: + if bundler := cls._bundler: + cls._bundler = None + bundler.end_compile() + + @classmethod + def sync(cls) -> None: + if bundler := cls._bundler: + bundler.sync() + + @classmethod + def put(cls, filename: str, data: JsonDataTy) -> None: + if bundler := cls._bundler: + # The filename comes in as something like + # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is + # basename[1:3]). Strip it down and make sure that it looks like a path + # we could reconstruct (because it's possible for the caller to + # customize the path). + basename = os.path.basename(filename) + + # TODO: check cache_dir() vs filename, then strip dirname + bundler.put(basename, data) + + +# Remove the comments from the code (which include things like run ids and file +# paths) and then hash the result. +def _comment_stripped_hash(code: str) -> str: + code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE) + return torch._inductor.codecache.code_hash(code) + + +def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool: + if (config := inductor_meta.get("autotune_remote_cache")) is not None: + return bool(config) + if not inductor_meta.get("is_fbcode"): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:autotune_memcache_version" + ) + + +def _load_cached_autotuning( + best_config: dict[str, JsonDataTy], + configs_hash: str, + configs: list[Config], + inductor_meta: _InductorMetaTy, +) -> Optional[Config]: + if best_config is None: + return None + if best_config.pop("configs_hash", None) != configs_hash: + return None + + # Remove time taken for comparison + best_config.pop("time_taken_ms", None) + + best_config.pop("triton_cache_hash", None) + + if inductor_meta.get("coordinate_descent_tuning") and best_config.pop( + "found_by_coordesc", False + ): + num_warps = best_config.pop("num_warps") + num_stages = best_config.pop("num_stages") + + # Extract common arguments + config_args = { + "num_warps": num_warps, + "num_stages": num_stages, + } + + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": best_config.pop("num_consumer_groups", 0), + "num_buffers_warp_spec": best_config.pop( + "num_buffers_warp_spec", 0 + ), + } + ) + + # Create the triton_config with the appropriate arguments + triton_config = Config(best_config, **config_args) + triton_config.found_by_coordesc = True + return triton_config + + matching_configs = [ + cfg + for cfg in configs + if all(val == best_config.get(key) for key, val in cfg.kwargs.items()) + and cfg.num_warps == best_config.get("num_warps") + and cfg.num_stages == best_config.get("num_stages") + ] + if len(matching_configs) != 1: + return None + + return matching_configs[0] + + +class _LocalAutotuneCacheBackend(RemoteCacheBackend[bytes]): + @override + def _get(self, key: str) -> Optional[bytes]: + try: + with open(key, "rb") as fd: + return fd.read() + except FileNotFoundError: + return None + + @override + def _put(self, key: str, data: bytes) -> None: + os.makedirs(os.path.dirname(key), exist_ok=True) + from torch._inductor import codecache + + codecache.write_atomic(key, data) + + +class LocalAutotuneCache(RemoteCache[JsonDataTy]): + def __init__(self) -> None: + backend = _LocalAutotuneCacheBackend() + 
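+        # best_config entries are serialized as JSON files on the local filesystem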
serde = RemoteCacheJsonSerde() + super().__init__(backend, serde) + + @override + def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: + AutotuneCacheBundler.sync() + result = super()._get(key, sample) + if result is not None: + assert isinstance(result, dict) + # What? Why are we doing a put() here? Imagine we have a new model + # that reuses some existing kernels that have already been + # compiled. If we didn't do a `put` here (on cache hit) then the new + # model would only bundle *newly* compiled kernels, not existing + # kernels that were already compiled and cached. + AutotuneCacheBundler.put(key, result) + autotune_artifact_key = os.path.join(*key.split(os.sep)[-2:]) + CacheArtifactManager.record_artifact( + AutotuneCacheArtifact.type(), autotune_artifact_key, result + ) + return result + + @override + def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None: + AutotuneCacheBundler.put(key, value) + super()._put(key, value, sample) + + +def _splitext_nodot(basename: str) -> tuple[str, str]: + root, ext = os.path.splitext(basename) + if ext: + ext = ext[1:] + return root, ext diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..aa15611aed20050cae91149dba15bbe8c10a469d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/benchmarking.py @@ -0,0 +1,290 @@ +import inspect +import time +from functools import cached_property, wraps +from itertools import chain +from statistics import median +from typing import Any, Callable +from typing_extensions import Concatenate, ParamSpec, Self, TypeVar + +import torch +from torch._dynamo.utils import counters, dynamo_timed +from torch._inductor.config import use_experimental_benchmarker + + +logger = torch._logging.getArtifactLogger(__name__, "benchmarking") +use_experimental_benchmarker = ( + use_experimental_benchmarker and torch.cuda.is_available() +) + + +MILLISECONDS_PER_SECOND = 1000 + +P = ParamSpec("P") +T = TypeVar("T") + + +def time_and_count( + fn: Callable[Concatenate[Any, P], T], +) -> Callable[Concatenate[Any, P], T]: + """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo + counters. It is expected that `fn` is a method of `Benchmarker` or one of its + subclasses; typing limitations prevent us from declaring this directly. + """ + + @wraps(fn) + def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: + fn_qual_name = f"{self.__class__.__name__}.{fn.__name__}" + counters["inductor"][f"benchmarking.{fn_qual_name}"] += 1 + with dynamo_timed(fn_qual_name, log_pt2_compile_event=False): + return fn(self, *args, **kwargs) + + return wrapper + + +class Benchmarker: + def __init__(self: Self) -> None: + pass + + @time_and_count + def benchmark( + self: Self, + fn: Callable[..., Any], + fn_args: tuple[Any, ...], + fn_kwargs: dict[str, Any], + **kwargs: Any, + ) -> float: + """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the + actual runtime calculation is dictated by the benchmarking implementation, but may be + one of [mean, median, minimum, etc.]). Functions as a convenience wrapper around + device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises + `ValueError(...)` if we can't safely infer the device type of `fn`; for example, + if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device + types are found. 
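+
+        Example (illustrative; `a` and `b` are tensors on the same device):
+            benchmarker.benchmark(torch.mm, (a, b), {})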
+ + Arguments: + - fn: The function to benchmark. + - fn_args: The function's arguments. + - fn_kwargs: The function's kwargs. + + Keyword Arguments: + - **kwargs: The benchmarking implementation's kwargs. + + Returns: + - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. + """ + inferred_device = None + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + if not isinstance(arg_or_kwarg, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg.device + elif arg_or_kwarg.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + if inferred_device is None: + raise ValueError( + "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 + ) + _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + if inferred_device == torch.device("cpu"): + return self.benchmark_cpu(_callable, **kwargs) + # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking + # implementation which was written specifically with CUDA devices in mind, we may want to + # explore alternate implementations for other device types. + return self.benchmark_gpu(_callable, **kwargs) + + @time_and_count + def benchmark_cpu( + self: Self, _callable: Callable[[], Any], warmup: int = 20, rep: int = 100 + ) -> float: + """Benchmark the CPU callable, `_callable`, and return the median runtime, + in milliseconds. + + Arguments: + - _callable: The CPU callable to benchmark. + + Keyword Arguments: + - warmup: Optionally, the duration, in milliseconds, to run `_callable` + before benchmarking starts. + - rep: Optionally, the duration, in milliseconds, to run `_callable` + during benchmarking. + + Returns: + - The median runtime of `_callable`, in milliseconds. + """ + + def run_for(ms: int) -> list[float]: + timings = [] + run_start_t = time.perf_counter() + while True: + start_t = time.perf_counter() + _callable() + end_t = time.perf_counter() + timings.append((end_t - start_t) * MILLISECONDS_PER_SECOND) + if ((end_t - run_start_t) * MILLISECONDS_PER_SECOND) > ms: + break + return timings + + run_for(warmup) + return median(run_for(rep)) + + @time_and_count + def benchmark_gpu(self: Self, *args: Any, **kwargs: Any) -> float: + raise NotImplementedError + + +class TritonBenchmarker(Benchmarker): + @cached_property + def triton_do_bench(self: Self) -> Callable[..., Any]: + """Lazily import Triton's `do_bench`.""" + try: + from triton.testing import do_bench + except ImportError as e: + raise NotImplementedError("requires Triton") from e + return do_bench + + @time_and_count + def benchmark_gpu(self: Self, _callable: Callable[[], Any], **kwargs: Any) -> float: + """Benchmark the GPU callable, `_callable`, and return the runtime, in milliseconds. + + Arguments: + - _callable: The GPU callable to benchmark. + + Keyword Arguments: + - quantiles: Optionally, a tuple of floats denoting the requested quantiles. + - return_mode: Optionally, the requested return mode. Currently, Triton's + `do_bench` supports min, max, mean, and median return modes. + - **kwargs: Additional kwargs passed to Triton's `do_bench`. + + Returns: + - The runtime of `callable`, in milliseconds. If `kwargs["quantiles"]` is specified, + this is the first requested quantile. Else, if `kwargs["return_mode"]` is specified, + this is the requested return mode. 
Otherwise, this is the median. + """ + do_bench_params = inspect.signature(self.triton_do_bench).parameters + for kwarg in list(kwargs.keys()): + if kwarg not in do_bench_params: + del kwargs[kwarg] + if "quantiles" in kwargs: + return self.triton_do_bench(_callable, **kwargs)[0] + elif "return_mode" in kwargs: + return self.triton_do_bench(_callable, **kwargs) + return self.triton_do_bench(_callable, **kwargs, return_mode="median") + + +class InductorBenchmarker(TritonBenchmarker): + @cached_property + def L2_cache_size(self: Self) -> int: + """Get the L2 cache size, in bytes, of the current device.""" + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + return props.L2_cache_size + + def get_event_pairs( + self: Self, iters: int + ) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]: + """Get `iters` pairs of CUDA events.""" + return [ + ( + torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True), + ) + for _ in range(iters) + ] + + def get_event_pairs_min_timing( + self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]] + ) -> float: + """Get the minimum timing, in milliseconds, for a group of CUDA event pairs.""" + return min( + [ + start_event.elapsed_time(end_event) + for start_event, end_event in event_pairs + ] + ) + + @time_and_count + def benchmark_gpu( + self: Self, + _callable: Callable[[], Any], + estimation_iters: int = 5, + memory_warmup_iters: int = 100, + benchmark_iters: int = 100, + max_benchmark_duration: int = 25, + **kwargs: Any, + ) -> float: + """Benchmark a GPU callable using a custom benchmarking implementation. + + Arguments: + - _callable: The callable to benchmark. + + Keyword Arguments: + - estimation_iters: Optionally, the number of iterations to run `_callable` + during runtime estimation. + - memory_warmup_iters: Optionally, the number of iterations to flush the L2 + cache before starting benchmarking. + - benchmark_iters: Optionally, the number of iterations to run `_callable` + during the benchmarking. + - max_benchmark_duration: Optionally, the maximum duration of the benchmarking, + in milliseconds. An estimated duration is calculated based on the values + of `memory_warmup_iters` and `benchmark_iters`, along with the estimated + runtime of `_callable` and various other factors, and we then shrink + `benchmark_iters` to fit in the allotted maximum duration. + - **kwargs: Additional kwargs that may be passed to the fallback. + + Returns: + - The minimum runtime of `_callable`, in milliseconds. 
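+
+ Note: between timed iterations a buffer the size of the device's L2 cache
+ is zeroed so every measurement starts from a cold cache, and the value
+ returned is the smaller of the estimation-phase and benchmark-phase minima.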
+ """ + # we don't want any outside errors propagating into benchmarking + torch.cuda.synchronize() + + # warmup `_callable` (and catches any failures in the process) + _callable() + torch.cuda.synchronize() + + # see https://github.com/triton-lang/triton/pull/840 for why `dtype=torch.int` + buffer = torch.empty(self.L2_cache_size // 4, dtype=torch.int, device="cuda") + buffer.zero_() + + # estimate the runtime of `_callable` + event_pairs = self.get_event_pairs(estimation_iters) + for start_event, end_event in event_pairs: + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + estimated_timing = self.get_event_pairs_min_timing(event_pairs) + + # adjust `benchmark_iters` to fit in the maximum benchmarking duration + benchmark_iters = max( + min(benchmark_iters, int(max_benchmark_duration // estimated_timing)), 1 + ) + + # do the memory warmup + for _ in range(memory_warmup_iters): + buffer.zero_() + + # benchmark `_callable` + event_pairs = self.get_event_pairs(benchmark_iters) + for start_event, end_event in event_pairs: + buffer.zero_() + start_event.record() + _callable() + end_event.record() + torch.cuda.synchronize() + benchmarked_timing = self.get_event_pairs_min_timing(event_pairs) + + # explicitly delete the buffer, sometimes helps memory + # footprint metrics in OSS Inductor performance benchmarks + del buffer + + # return the minimum of `estimated_timing` and `benchmarked_timing`, + # we just want the minimum timing overall so we might as well check both + return min(estimated_timing, benchmarked_timing) + + +benchmarker = ( + InductorBenchmarker() if use_experimental_benchmarker else TritonBenchmarker() +) diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/cache_dir_utils.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/cache_dir_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1bad34be0a7366f9fce34380c9eac3113751efb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/cache_dir_utils.py @@ -0,0 +1,54 @@ +import getpass +import os +import re +import tempfile +from collections.abc import Generator +from contextlib import contextmanager + +from torch._environment import is_fbcode + + +# Factoring out to file without torch dependencies + + +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def default_cache_dir() -> str: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + return os.path.join( + tempfile.gettempdir() if not is_fbcode() else "/var/tmp", + "torchinductor_" + sanitized_username, + ) + + +def triton_cache_dir(device: int) -> str: + if (directory := os.getenv("TRITON_CACHE_DIR")) is not None: + return directory + return os.path.join( + cache_dir(), + "triton", + str(device), + ) + + +@contextmanager +def temporary_cache_dir(directory: str) -> Generator[None, None, None]: + from torch._inductor.utils import clear_caches + + original = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory + try: + clear_caches() + yield + finally: + clear_caches() + if original is None: + del os.environ["TORCHINDUCTOR_CACHE_DIR"] + else: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = original diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py 
b/phivenv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2f65486f8118535aac031f9868f49d85fabdd7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/compile_tasks.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import functools +import linecache +import os +import sys +import time +import warnings +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, TYPE_CHECKING + + +if TYPE_CHECKING: + from torch._inductor.runtime.triton_heuristics import CachingAutotuner + + +def _reload_python_module( + key: str, path: str, set_sys_modules: bool = True +) -> ModuleType: + with open(path) as f: + try: + code = compile(f.read(), path, "exec", dont_inherit=True) + except Exception as e: + raise RuntimeError( + f"Failed to import {path}\n{type(e).__name__}: {e}" + ) from None + mod = ModuleType(f"{__name__}.{key}") + mod.__file__ = path + mod.key = key # type: ignore[attr-defined] + exec(code, mod.__dict__, mod.__dict__) + if set_sys_modules: + sys.modules[mod.__name__] = mod + return mod + + +@functools.cache +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas = Path(__file__).absolute().parents[1] / "bin" / "ptxas" + if not ptxas.exists(): + return + if ptxas.is_file() and os.access(ptxas, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = str(ptxas) + else: + warnings.warn(f"{ptxas} exists but is not an executable") + + +def _worker_compile_triton( + load_kernel: Callable[[], CachingAutotuner], + extra_env: dict[str, str], + extra_config: dict[str, Any], +) -> tuple[CachingAutotuner, int]: + _set_triton_ptxas_path() + os.environ.update(extra_env) + from torch._inductor import config + + with config.patch(extra_config): + start_ns = time.time_ns() + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + elapsed_ns = time.time_ns() - start_ns + kernel.prepare_for_pickle() + # We can release this memory in the compile subprocesses: + linecache.clearcache() + return kernel, elapsed_ns // 1000 diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..a15aaa05508ed040286c1132d5966064395d4d47 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -0,0 +1,302 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import logging +from typing import Callable, Optional, TYPE_CHECKING + +from .hints import TRITON_MAX_BLOCK +from .runtime_utils import red_text, triton_config_to_hashable + + +if TYPE_CHECKING: + from .triton_compat import triton + + +log = logging.getLogger(__name__) + + +def get_field(config, name): + if name == "num_warps": + return config.num_warps + elif name == "num_stages": + return config.num_stages + elif name == "waves_per_eu": + return config.kwargs.get(name, int(8 // config.num_warps)) + else: + return config.kwargs.get(name, None) + + +def set_field(config, name, value): + if name == "num_warps": + config.num_warps = value + elif name == "num_stages": + config.num_stages = value + else: + config.kwargs[name] = value + + +class CoordescTuner: + """ + The coordinate descent tuner. Tune one field/coordinate at a time. + + TODO will it be necessary to tune multiple fields simultaneously. 
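+
+ In the current implementation each pass visits every tunable field, tries the
+ neighbouring values produced by `get_neighbour_values` (num_stages moves by
+ +/-1, every other field is doubled or halved), and keeps a candidate only if
+ it is at least 0.1% faster than the incumbent; passes repeat until no field
+ improves.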
+ + + TODO: what if both increasing and decreasing a field can improve perf. + i.e., there are multiple local optima.. + """ + + def __init__( + self, is_mm=False, name="unknown", size_hints=None, inductor_meta=None + ): + self.is_mm = is_mm # we will tune num_stages for mm + self.cached_benchmark_results = {} + self.name = name + self.size_hints = size_hints + self.inductor_meta = inductor_meta or {} + + def get_config_max(self, prefix: str) -> int: + max_block = TRITON_MAX_BLOCK[prefix.upper()] + size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None + return min(max_block, size_hint) if size_hint is not None else max_block + + def get_warpsmax(self): + # Currently, CUDA has a maximum of 1024 threads, so 32 is the max + # number of warps. + return 1024 // 32 + + def cache_benchmark_result(self, config, timing): + self.cached_benchmark_results[triton_config_to_hashable(config)] = timing + + def lookup_in_cache(self, config): + return self.cached_benchmark_results.get(triton_config_to_hashable(config)) + + def call_func(self, func, config): + found = self.lookup_in_cache(config) + if found is not None: + log.debug(" CACHED") + return found + timing = func(config) + self.cache_benchmark_result(config, timing) + return timing + + @property + def tunable_fields(self): + out = [ + "XBLOCK", + "YBLOCK", + "ZBLOCK", + # NOTE: we should not tune R0_BLOCK for persistent reduction. + # We rely on the fact that persistent reduction's triton.Config + # does not have the R0_BLOCK field to guarantee that. + "R0_BLOCK", + "R1_BLOCK", + # the following 3 are for mm + "BLOCK_M", + "BLOCK_N", + "BLOCK_K", + "num_warps", + ] + if self.is_mm: + out.append("num_stages") + if self.inductor_meta.get("is_hip") is True: + out.append("waves_per_eu") + + return out + + def value_too_large(self, name: str, val: int) -> bool: + block_suffix = "BLOCK" + if name.endswith(block_suffix): + prefix = name.strip(block_suffix).lower() + return val > self.get_config_max(prefix) + if name == "num_warps": + return val > self.get_warpsmax() + if name == "waves_per_eu": + return val > 8 + + return False + + def get_neighbour_values(self, name, orig_val, radius=1, include_self=False): + """ + Get neighbour values in 'radius' steps. The original value is not + returned as it's own neighbour. + """ + assert radius >= 1 + + def update(cur_val, inc=True): + if name == "num_stages": + if inc: + return cur_val + 1 + else: + return cur_val - 1 + else: + if inc: + return cur_val * 2 + else: + return cur_val // 2 + + out = [] + # increment loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, True) + if self.value_too_large(name, cur_val): + break + out.append(cur_val) + + # decrement loop + cur_val = orig_val + for _ in range(radius): + cur_val = update(cur_val, False) + if cur_val <= 0: + break + out.append(cur_val) + + if include_self: + out.append(orig_val) + return out + + @staticmethod + def has_improvement(baseline, test): + threshold = 0.001 # 0.1% + return test is not None and test < baseline * (1 - threshold) + + def check_all_tuning_directions( + self, + func: Callable[["triton.Config"], float], + best_config, + best_timing, + ): + """ + Check all directions. We only do this once the regular coordinate + descent tuning find no better choices any more. + We only have a few tunable fields, so this should be fine. 
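+
+ Candidates are the Cartesian product of each effective field's neighbour
+ values (radius `coordinate_descent_search_radius`, default 1), so the number
+ of configs tried grows quickly with the number of tunable fields.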
+ """ + candidate_values_list = [] + effective_fields = [] + for field in self.tunable_fields: + old_value = get_field(best_config, field) + if old_value is None: + continue + candidate_values = self.get_neighbour_values( + field, + old_value, + radius=self.inductor_meta.get("coordinate_descent_search_radius", 1), + include_self=True, + ) + candidate_values_list.append(candidate_values) + effective_fields.append(field) + + choices = itertools.product(*candidate_values_list) + improved = False + for choice in choices: + assert len(choice) == len(effective_fields) + candidate_config = copy.deepcopy(best_config) + for new_val, field in zip(choice, effective_fields): + set_field(candidate_config, field, new_val) + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config = candidate_config + best_timing = candidate_timing + + return improved, best_config, best_timing + + def compare_config(self, func, candidate_config, best_config, best_timing): + """ + Check if candidate_config is better than best_config. + + Return a tuple of (compare_result, candidate_timing). + compare_result is true iff candidate_config is better. + """ + log.debug("Try config %s", candidate_config) + try: + candidate_timing = self.call_func(func, candidate_config) + except Exception as e: + log.debug("Got exception %s", e) + return False, float("inf") + + if self.has_improvement(best_timing, candidate_timing): + log.debug( + "Tune from %s %f -> %s %f", + best_config, + best_timing, + candidate_config, + candidate_timing, + ) + + return True, candidate_timing + return False, candidate_timing + + def autotune( + self, + func: Callable[["triton.Config"], float], + baseline_config: "triton.Config", + baseline_timing: Optional[float] = None, + ) -> "triton.Config": + if baseline_timing is None: + baseline_timing = self.call_func(func, baseline_config) + + log.debug("= Do coordinate descent tuning for %s =", self.name) + log.debug( + "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing + ) + improved = True + best_config = baseline_config + best_timing = baseline_timing + tunable_fields = self.tunable_fields + + while improved: + improved = False + + for name in tunable_fields: + cur_val = get_field(best_config, name) + # some kernel don't have R0_BLOCK/YBLOCK/ZBLOCK. So cur_val may be None + if cur_val is None: + continue + + # It's possible that candidate_values is empty. + # E.g., if XBLOCK is 1 initially and size_hint for x is also 1. + # We would not try either larger or smaller XBLOCK in this case. + candidate_values = self.get_neighbour_values(name, cur_val) + + for next_val in candidate_values: + candidate_config = copy.deepcopy(best_config) + set_field(candidate_config, name, next_val) + + cmp_res, candidate_timing = self.compare_config( + func, candidate_config, best_config, best_timing + ) + if cmp_res: + improved = True + best_config, best_timing = candidate_config, candidate_timing + + if not improved and self.inductor_meta.get( + "coordinate_descent_check_all_directions" + ): + old_best_timing = best_timing + improved, best_config, best_timing = self.check_all_tuning_directions( + func, best_config, best_timing + ) + + if improved: + msg = red_text( + "Coordinate descend tuning found improvement of %.3fx by looking in all directions." 
+ ) + log.debug( + msg, + old_best_timing / best_timing, + ) + + log.debug( + "Improve from %s %f -> %s %f, %.3fx", + baseline_config, + baseline_timing, + best_config, + best_timing, + baseline_timing / best_timing, + ) + + return best_config diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..043bb64ee08a8d1b7ae9a23ba05053ff42e1353b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/halide_helpers.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +try: + import halide as hl # type: ignore[import-untyped, import-not-found] +except ImportError: + hl = None + +PHILOX_N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +if hl is not None: + PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9) + PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85) + PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53) + PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57) +else: + PHILOX_KEY_A_U32 = None + PHILOX_KEY_B_U32 = None + PHILOX_ROUND_A_U32 = None + PHILOX_ROUND_B_U32 = None + + +def _pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = hl.max(hl.f32(1.0e-7), u1) + th = hl.f32(6.283185307179586) * u2 + r = hl.sqrt(hl.f32(-2.0) * hl.log(u1)) + return r * hl.cos(th), r * hl.sin(th) + + +def _uint_to_uniform_float(x): + """ + Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1). + """ + + # TODO: + # conditions can be simplified + # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1) + # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132. + assert x.type() == hl.UInt(32) or x.type() == hl.Int(32) + x = hl.cast(hl.Int(32), x) + scale = hl.f64(4.6566127342e-10) + x = hl.select(x < 0, -x - 1, x) + return x * scale + + +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds): + def umulhi(a, b): + a = hl.cast(hl.UInt(64), a) + b = hl.cast(hl.UInt(64), b) + return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF)) + + for _ in range(n_rounds): + _c0, _c2 = c0, c2 + + c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0 + c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1 + c1 = PHILOX_ROUND_B_U32 * _c2 + c3 = PHILOX_ROUND_A_U32 * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A_U32 + k1 = k1 + PHILOX_KEY_B_U32 + + return c0, c1, c2, c3 + + +def halide_philox(seed, c0, c1, c2, c3, n_rounds): + seed = hl.cast(hl.UInt(64), seed) + + assert c0.type().bits() == 32 + + seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF)) + seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF)) + + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +def randint4x(seed, offset, n_rounds): + offset = hl.cast(hl.UInt(32), offset) + _0 = hl.u32(0) + return halide_philox(seed, offset, _0, _0, _0, n_rounds) + + +def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + i1, i2, i3, i4 = randint4x(seed, offset, n_rounds) + u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + u3 = _uint_to_uniform_float(i3) + u4 = _uint_to_uniform_float(i4) + return u1, u2, u3, u4 + + +def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + source = randint(seed, offset, n_rounds) + return _uint_to_uniform_float(source) + + +def randn(seed, offset): + i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + 
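+ # Convert two of the four Philox outputs to uniforms and combine them with
+ # the Box-Muller transform to produce a single standard-normal sample.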
u1 = _uint_to_uniform_float(i1) + u2 = _uint_to_uniform_float(i2) + n1, _ = _pair_uniform_to_normal(u1, u2) + return n1 + + +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) + r0 = hl.cast(hl.UInt(64), r0) + r1 = hl.cast(hl.UInt(64), r1) + + result = r0 | (r1 << 32) + size = high - low + result = result % hl.cast(hl.UInt(64), size) + result = hl.cast(hl.Int(64), result) + low + return result diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/hints.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/hints.py new file mode 100644 index 0000000000000000000000000000000000000000..89a6a747b4d1c8495bb89d7614460490d2d86027 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/hints.py @@ -0,0 +1,221 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import collections +import functools +import typing +from enum import auto, Enum +from typing import Optional, Union + +from torch.utils._triton import has_triton_package + + +# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values +# NOTE: if these fail asserts submit a PR to increase them +TRITON_MAX_BLOCK = { + "X": 4096, + "Y": 1024, + "Z": 1024, + "R0_": 4096 * 16, # * 16 is multi-kernel only + "R1_": 2048 * 16, # * 16 is multi-kernel only +} +TRITON_MAX_RSPLIT = 64 + + +class ReductionHint(Enum): + INNER = 0 + OUTER = 1 + OUTER_TINY = 2 + DEFAULT = 3 + + +class TileHint(Enum): + SQUARE = 0 + DEFAULT = 1 + + +# Define `AttrsDescriptorWrapper` function with clear conditional handling +if has_triton_package(): + import triton + import triton.backends.compiler + import triton.compiler.compiler + + if hasattr(triton.backends.compiler, "AttrsDescriptor"): + # Triton 3.2.0 - the second implementation + from triton.backends.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "tt.divisibility": divisible_by_16, + "tt.equal_to": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + res = AttrsDescriptor.from_dict( + {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__} + ) + assert res.property_values["tt.divisibility"] == 16 + assert res.property_values["tt.equal_to"] == 1 + return res + + elif hasattr(triton.compiler.compiler, "AttrsDescriptor"): + # Triton 3.0.0 - the original implementation + from triton.compiler.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) + + else: + # Triton in 2025: + # note: there's also a range of triton commits not currently supported + # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still + # used, but the contents are different. 
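+ # In these versions AttrsDescriptor is gone entirely; argument hints are
+ # passed as a plain dict keyed by argument-index tuples, and only the
+ # divisibility-by-16 property is forwarded here.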
+ + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16} + +else: + # Define a namedtuple as a fallback when AttrsDescriptor is not available + AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] + "AttrsDescriptor", + ["divisible_by_16", "equal_to_1"], + defaults=[(), ()], + ) + + +_NUM_THREADS_PER_WARP = 32 + + +class HeuristicType(Enum): + PERSISTENT_REDUCTION = auto() + POINTWISE = auto() + REDUCTION = auto() + SPLIT_SCAN = auto() + TEMPLATE = auto() + USER_AUTOTUNE = auto() + FIXED = auto() + + +class AutotuneHint(Enum): + ONE_ELEMENT_PER_THREAD = 0 + + # Triton codegen tries to codegen set of AutotuneHints. + # Enum.__repr__ looks like """ + # which isn't valid python. + # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32". + __repr__ = Enum.__str__ + + +class DeviceProperties(typing.NamedTuple): + """Copy device properties into a data structure not requiring torch to be imported""" + + type: str # type: ignore[assignment] + index: int # type: ignore[assignment] + multi_processor_count: int + cc: int + major: Optional[int] = None + regs_per_multiprocessor: Optional[int] = None + max_threads_per_multi_processor: Optional[int] = None + warp_size: Optional[int] = None + + @classmethod + @functools.cache + def create(cls, device) -> DeviceProperties: + import torch + from torch._dynamo.device_interface import get_interface_for_device + + device_type = device.type + + if torch.version.hip and device_type == "cuda": + device_type = "hip" + + device_interface = get_interface_for_device(device) + props = device_interface.get_device_properties(device) + try: + multi_processor_count = props.multi_processor_count + except AttributeError: + if device_type == "xpu": + multi_processor_count = props.gpu_subslice_count + elif device_type == "mps": + # TODO: Fetch the actual value from ioreg + multi_processor_count = 8 + else: + raise + return cls( + type=device_type, + index=device.index, + multi_processor_count=multi_processor_count, + cc=device_interface.get_compute_capability(device), + major=getattr(props, "major", None), + regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None), + max_threads_per_multi_processor=getattr( + props, "max_threads_per_multi_processor", None + ), + warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None), + ) + + +class HalideInputSpec(typing.NamedTuple): + ctype: str + name: str + shape: Optional[list[str]] = None + stride: Optional[list[str]] = None + offset: Optional[str] = None + alias_of: Optional[str] = None + + def bindings_type(self) -> str: + if self.ctype in ("at::Half*", "at::BFloat16*"): + return "uint16_t*" # half not defined + return self.ctype + + def halide_type(self) -> str: + if self.ctype == "at::Half*": + return "halide_type_t(halide_type_float, 16)" # half not defined + if self.ctype == "at::BFloat16*": + return "halide_type_t(halide_type_bfloat, 16)" # half not defined + return f"halide_type_of<{self.ctype.replace('*', '')}>()" + + def is_scalar(self) -> bool: + return self.shape is None + + def is_buffer(self) -> bool: + return self.shape is not None + + +class HalideMeta(typing.NamedTuple): + argtypes: list[HalideInputSpec] + target: str + scheduler: Optional[str] = None + scheduler_flags: Optional[dict[str, Union[int, str]]] = None + cuda_device: Optional[int] = None + + def args(self) -> list[str]: + """Command line args to pass to halide generator""" + args = 
[f"target={self.target}"] + if self.scheduler: + args.append(f"autoscheduler={self.scheduler}") + if self.scheduler_flags: + assert self.scheduler + for k, v in self.scheduler_flags.items(): + args.append(f"autoscheduler.{k}={v}") + return args + + def is_cuda(self) -> bool: + return self.cuda_device is not None diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50f85e4d655cb1722a13756f49fbeb73da40c3d9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/runtime_utils.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import functools +import operator +from typing import Any, TYPE_CHECKING + +import torch + +# NOTE: other files rely on the imports below +from torch._dynamo import callback as compilation_callback # noqa: F401 +from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 + cache_dir, + default_cache_dir, + triton_cache_dir, +) + + +if TYPE_CHECKING: + from collections.abc import Hashable + + from .triton_compat import Config + + +def conditional_product(*args: int) -> int: + return functools.reduce(operator.mul, [x for x in args if x]) + + +def ceildiv(number: int, denom: int) -> int: + return -(number // -denom) + + +def is_power_of_2(n: int) -> bool: + """Returns whether n = 2 ** m for some integer m.""" + return n > 0 and n & n - 1 == 0 + + +def next_power_of_2(n: int) -> int: + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n |= n >> 32 + n += 1 + return n + + +def get_num_bytes(*args: torch.Tensor, num_in_out_args: int = 0) -> int: + """ + Return the total number of bytes the arguments of tensor type takes. + + For in/out args, tensor sizes are counted twice: once for reading and + once for writing. + + The first num_in_out_args arguments are in out tensors. + """ + return sum( + arg.numel() * arg.element_size() * (1 + int(i < num_in_out_args)) + for i, arg in enumerate(args) + if isinstance(arg, torch.Tensor) + ) + + +def triton_config_to_hashable(cfg: Config) -> Hashable: + """ + Convert triton config to a tuple that can uniquely identify it. We can use + the return value as a dictionary key. + """ + items = sorted(cfg.kwargs.items()) + items.append(("num_warps", cfg.num_warps)) + items.append(("num_stages", cfg.num_stages)) + return tuple(items) + + +def validate_triton_config(cfg: Config) -> None: + # [Note: Triton pre_hook in inductor] + # pre-hook is a lambda function, which we don't attempt to serialize. + # right now, if a pre-hook is attached to the config, it will not be saved; + # and then it won't be used when the config is loaded from cache. + # So we assert - if we do get a pre_hook, it might get ignored after caching. 
+ assert getattr(cfg, "pre_hook", None) is None, ( + "triton configs with pre_hooks not supported" + ) + + +def create_bandwidth_info_str( + ms: float, + num_gb: float, + gb_per_s: float, + prefix: str = "", + suffix: str = "", + color: bool = True, +) -> str: + info_str = f"{prefix}{ms:.3f}ms \t{num_gb:.3f} GB \t {gb_per_s:7.2f}GB/s{suffix}" + slow = ms > 0.012 and gb_per_s < 650 + return red_text(info_str) if color and slow else info_str + + +def get_max_y_grid() -> int: + return 65535 + + +try: + import colorama + + HAS_COLORAMA = True +except ModuleNotFoundError: + HAS_COLORAMA = False + colorama = None # type: ignore[assignment] + + +if HAS_COLORAMA: + + def _color_text(msg: str, color: str) -> str: + return getattr(colorama.Fore, color.upper()) + msg + colorama.Fore.RESET + +else: + + def _color_text(msg: str, color: str) -> str: + return msg + + +def green_text(msg: str) -> str: + return _color_text(msg, "green") + + +def yellow_text(msg: str) -> str: + return _color_text(msg, "yellow") + + +def red_text(msg: str) -> str: + return _color_text(msg, "red") + + +def blue_text(msg: str) -> str: + return _color_text(msg, "blue") + + +def get_first_attr(obj: Any, *attrs: str) -> Any: + """ + Return the first available attribute or throw an exception if none is present. + """ + for attr in attrs: + if hasattr(obj, attr): + return getattr(obj, attr) + + raise AssertionError(f"{obj} does not has any of the attributes: {attrs}") + + +dynamo_timed = torch._dynamo.utils.dynamo_timed # type: ignore[has-type] + + +def triton_hash_to_path_key(key: str) -> str: + # In early versions of Triton, the hash is directly used in the path name. + # Later, the hash is converted to base64 before being used in the path name. + # Later, the base64 conversion was replaced to the base32 + # + # This code tries to import _base64 and falls back to _base32 if _base64 is unavailable. + # + # To handle this, try to import the to-base64-conversion function. + # If it exists, use it; otherwise, try using _base32; if both are unavailable, use the hash directly. + try: + from triton.runtime.cache import _base64 + + return _base64(key) + except Exception: + try: + from triton.runtime.cache import _base32 + + return _base32(key) + except Exception: + return key + + +def compile_mps_shader(source: str) -> Any: + """ + Compiles shader source but raise more actionable error message when needed + """ + try: + return torch.mps.compile_shader(source) + except SyntaxError as err: + raise SyntaxError(f"failed to compile {source} with {err.msg}") from err diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/static_cuda_launcher.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/static_cuda_launcher.py new file mode 100644 index 0000000000000000000000000000000000000000..46961bcdfd395ba2b59dad625635e062999ab68f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/static_cuda_launcher.py @@ -0,0 +1,237 @@ +import functools +import os +from typing import Any, Optional +from typing_extensions import Unpack + +from .triton_compat import ASTSource, CompiledKernel, knobs as triton_knobs + + +class StaticallyLaunchedCudaKernel: + """ + Parses the metadata of a CompiledKernel from Triton into a structure that can + launch the cuda kernel directly. Only works for triton kernels compiled to cubin. + + Doing this avoids C++ codegen and compilation during compile, since we can use a + statically compiled library to launch the kernel. 
To avoid mallocing for the arguments, + we have a launcher for different numbers of arguments up to a max. StaticCudaLauncher + only supports # of arguments up until 10 for now. + + Workflow: + Compile time: + 1. Compile a kernel with triton and get a CompiledKernel + 2. Instantiate kernel = StaticallyLaunchedCudaKernel(triton_kernel) + 3. Write to a cubin file: kernel.write_cubin_to_file(filepath) + 4. Call kernel.load_kernel() (CUDA should be initialized by this point) to load the cubin + Runtime: + 5. Call kernel.run(grid, stream, args) to launch the kernel + + Note that after step 3, StaticallyLaunchedCudaKernel is fully pickleable/serializable. + This allows it to be cached by FXGraphCache/TritonBundler, as well as sent from the worker + to the parent process in inductor. + + There are two main versions of triton that we wish to support: 3.3 and 3.2. Triton makes considerable changes + to how it handles constants in 3.3, so there's some special logic necessary to handle both versions. + """ + + def __init__(self, kernel: CompiledKernel) -> None: + self.name = kernel.src.fn.__name__ + self.cubin_raw = kernel.asm.get("cubin", None) + self.cubin_path = kernel._cubin_path + + # Used by torch.compile to filter constants in older triton versions + self.arg_names = kernel.src.fn.arg_names + + # Const exprs that are declared by the triton kernel directly + # Used to generate the kernel launcher's def args + self.declared_constexprs = kernel.src.fn.constexprs + + self.hash = kernel.hash + + if triton_knobs is None: + launch_enter = kernel.__class__.launch_enter_hook + launch_exit = kernel.__class__.launch_exit_hook + else: + launch_enter = triton_knobs.runtime.launch_enter_hook + launch_exit = triton_knobs.runtime.launch_exit_hook + + if launch_enter is not None or launch_exit is not None: + raise NotImplementedError( + "We don't support launch enter or launch exit hooks" + ) + self.num_warps = kernel.metadata.num_warps + self.shared = ( + kernel.shared if hasattr(kernel, "shared") else kernel.metadata.shared + ) + + # Newer triton versions pass an extra global scratch parameter to the compiled cuda kernel. + # Inductor never uses this field or enables it, but we still have to pass + # an extra None into the set of params if its enabled + if hasattr(kernel.metadata, "global_scratch_size"): + if kernel.metadata.global_scratch_size > 0: + raise NotImplementedError("Global scratch not yet supported") + else: + self.has_global_scratch = True + else: + self.has_global_scratch = False + + self.arg_tys = self.arg_ty_from_signature(kernel.src) + self.function: Optional[int] = ( + None # Loaded by load_kernel(on the parent process) + ) + num_ctas = 1 + if hasattr(kernel, "num_ctas"): + num_ctas = kernel.num_ctas + elif hasattr(kernel, "metadata"): + num_ctas = kernel.metadata.num_ctas + + if num_ctas != 1: + raise NotImplementedError( + "Static cuda launcher only supports num_ctas == 1" + ) + + def reload_cubin_from_raw(self, filepath: str) -> str: + """ + If the cubin file triton generated gets deleted under us, we can + reload it from the raw cubin file. 
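+ The raw cubin bytes are kept on the instance until `load_kernel()` succeeds,
+ so the file can be re-created at the caller-supplied path if needed.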
+ """ + if self.cubin_path is None: + assert self.cubin_raw is not None + os.makedirs(os.path.dirname(filepath), exist_ok=True) + with open(filepath, "wb") as f: + f.write(self.cubin_raw) + self.cubin_path = filepath + return self.cubin_path + + def load_kernel(self, device: int) -> None: + from torch._C import _StaticCudaLauncher + + if self.function is not None: + return + + assert hasattr(self, "cubin_path") + assert self.cubin_path is not None + (self.function, self.n_regs, self.n_spills) = _StaticCudaLauncher._load_kernel( + self.cubin_path, self.name, self.shared, device + ) + # Don't need the cubin path anymore now that we've loaded + self.cubin_path = None + self.cubin_raw = None + + @staticmethod + @functools.lru_cache + def type_mappings() -> dict[str, str]: + return { + "i1": "i", + "i8": "b", + "i16": "h", + "i32": "i", + "i64": "l", + "u1": "I", + "u8": "B", + "u16": "H", + "u32": "I", + "u64": "K", + "fp16": "f", + "bf16": "f", + "fp32": "f", + "f32": "f", + "fp64": "d", + # TODO handle nvTmaDesc/CUtensormap + } + + def extract_type(self, ty: str) -> str: + """ + Takes a triton type from CompiledKernel.signature and + converts it into a single char encoding. _StaticCudaLauncher + will switch on this char to figure out what type the underlying + value should be passed to the triton kernel as. + """ + if ty[0] == "*": + return "O" + elif ty == "nvTmaDesc": + raise NotImplementedError("nvTmaDesc kernels are not yet supported") + return StaticallyLaunchedCudaKernel.type_mappings()[ty] + + def arg_ty_from_signature(self, src: ASTSource) -> str: + def index_key(i: Any) -> int: + if isinstance(i, str): + return src.fn.arg_names.index(i) + elif isinstance(i, tuple): + # In triton 3.3, src.fn.constants has tuples as a key + return i[0] + else: + return i + + signature = {index_key(key): value for key, value in src.signature.items()} + # Triton uses these as the main way to filter out constants passed to their cubin + constants = [index_key(key) for key in getattr(src, "constants", dict())] + # This value is always a superset of kernel.fn.constexprs: kernel.fn.constexprs are + # constants declared by the triton kernel directly, whereas this list can have + # constants that are unused by the triton kernel that triton figured out during + # compilation. + self.full_constexprs = constants + # Despite requiring them to be passed in, the triton CUDA launcher + # completely ignores the constexprs passed into it when generating code. + # So we can ignore them here too + params = [] + + for i in sorted(signature.keys()): + ty = signature[i] + # In newer triton versions, constants are passed in to signature with type `constexpr` + # In older triton versions, there can be constants in src.constants that are not `constexpr` in signature + # so we check both here + if ty == "constexpr" or i in constants: + pass + else: + params.append(self.extract_type(ty)) + return "".join(params) + + def __getstate__(self) -> dict[str, Any]: + # Remove objects that are no longer valid for pickling + state = self.__dict__.copy() + state["function"] = None + # Cubin paths aren't consistent across processes, so we clear + # and reload them. + state["cubin_path"] = None + return state + + def run( + self, + grid_x: int, + grid_y: int, + grid_z: int, + stream: int, + *args: Unpack[tuple[object, ...]], + ) -> None: + """Actually run the kernel at runtime. 
This function is the hot codepath.""" + from torch._C import _StaticCudaLauncher + + # Assert load_kernel() has been called and args match + assert self.function is not None + + # TODO: actually, if the args *don't* match, we probably should + # throw an exception. But if inductor is the only one calling this + # thing, it should always match. + # Get rid of constants before passing to cubin launcher + + # Add a None if triton wants an extra parameter to the cubin + if self.has_global_scratch: + arg_tys = self.arg_tys + "O" + args = (*args, None) + else: + arg_tys = self.arg_tys + assert len(args) == len(arg_tys) + + # TODO: can handle grid functions here or in C++, so + # that we don't need the grid handler above. + _StaticCudaLauncher._launch_kernel( + self.function, + grid_x, + grid_y, + grid_z, + self.num_warps, + self.shared, + arg_tys, + args, + stream, + ) diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_compat.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_compat.py new file mode 100644 index 0000000000000000000000000000000000000000..fcdea4f3d5d7c5b4e1d1da17c562181ad0798914 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_compat.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import inspect +from typing import Any, Union + +import torch + + +try: + import triton +except ImportError: + triton = None + + +if triton is not None: + import triton.language as tl + from triton import Config + from triton.compiler import CompiledKernel + from triton.runtime.autotuner import OutOfResources + from triton.runtime.jit import KernelInterface + + try: + from triton.runtime.autotuner import PTXASError + except ImportError: + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + try: + from triton.compiler.compiler import ASTSource + except ImportError: + ASTSource = None + + try: + from triton.backends.compiler import GPUTarget + except ImportError: + + def GPUTarget( + backend: str, + arch: Union[int, str], + warp_size: int, + ) -> Any: + if torch.version.hip: + return [backend, arch, warp_size] + return (backend, arch) + + # In the latest triton, math functions were shuffled around into different modules: + # https://github.com/triton-lang/triton/pull/3172 + try: + from triton.language.extra import libdevice + + libdevice = tl.extra.libdevice # noqa: F811 + math = tl.math + except ImportError: + if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"): + libdevice = tl.extra.cuda.libdevice + math = tl.math + elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"): + libdevice = tl.extra.intel.libdevice + math = tl.math + else: + libdevice = tl.math + math = tl + + try: + from triton.language.standard import _log2 + except ImportError: + + def _log2(x: Any) -> Any: + raise NotImplementedError + + def _triton_config_has(param_name: str) -> bool: + if not hasattr(triton, "Config"): + return False + if not hasattr(triton.Config, "__init__"): + return False + return param_name in inspect.signature(triton.Config.__init__).parameters + + HAS_WARP_SPEC = ( + hasattr(tl, "async_task") + and _triton_config_has("num_consumer_groups") + and _triton_config_has("num_buffers_warp_spec") + ) + + try: + from triton import knobs + except ImportError: + knobs = None + + builtins_use_semantic_kwarg = ( + "_semantic" in inspect.signature(triton.language.core.view).parameters + ) +else: + + def _raise_error(*args: Any, **kwargs: Any) -> Any: + raise RuntimeError("triton package is not installed") + + 
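+ # Minimal stand-ins so the rest of the runtime can import these names even
+ # when triton is not installed.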
class OutOfResources(Exception): # type: ignore[no-redef] + pass + + class PTXASError(Exception): # type: ignore[no-redef] + pass + + Config = object + CompiledKernel = object + KernelInterface = object + ASTSource = None + GPUTarget = None + _log2 = _raise_error + libdevice = None + math = None + knobs = None + builtins_use_semantic_kwarg = False + + class triton: # type: ignore[no-redef] + @staticmethod + def jit(*args: Any, **kwargs: Any) -> Any: + return _raise_error + + class tl: # type: ignore[no-redef] + @staticmethod + def constexpr(val: Any) -> Any: + return val + + tensor = Any + dtype = Any + + HAS_WARP_SPEC = False + + +def cc_warp_size(cc: Union[str, int]) -> int: + if torch.version.hip: + cc_str = str(cc) + if "gfx10" in cc_str or "gfx11" in cc_str: + return 32 + else: + return 64 + else: + return 32 + + +try: + autograd_profiler = torch.autograd.profiler +except AttributeError: # Compile workers only have a mock version of torch + + class autograd_profiler: # type: ignore[no-redef] + _is_profiler_enabled = False + + +__all__ = [ + "Config", + "CompiledKernel", + "OutOfResources", + "KernelInterface", + "PTXASError", + "ASTSource", + "GPUTarget", + "tl", + "_log2", + "libdevice", + "math", + "triton", + "cc_warp_size", + "knobs", +] diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcf5e1ee1aeccef86ec86f6ada28e1fa4ec08d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_helpers.py @@ -0,0 +1,737 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math as pymath +import warnings +from functools import wraps +from typing import Any, Callable, TypeVar + +from .triton_compat import ( # noqa: F401 + _log2, + builtins_use_semantic_kwarg, + libdevice, + math, + tl, + triton, +) + + +_T = TypeVar("_T") +_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e)) + + +def set_driver_to_cpu(): + driver = triton.runtime.driver + if backend := triton.backends.backends.get("cpu", None): + if isinstance(driver.active, backend.driver): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + # This can be a hard error once triton-cpu is merged into fbcode + warnings.warn( + "Could not find an active CPU backend. Generated kernels will not be executable!" + ) + + +def set_driver_to_gpu(): + driver = triton.runtime.driver + for name, backend in triton.backends.backends.items(): + if backend.driver.is_active() and name != "cpu": + # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4, + # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute. 
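+ # Check both the driver itself and, when wrapped in a LazyProxy, its
+ # unwrapped `_obj` attribute.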
+ if ( + isinstance(driver.active, backend.driver) + or hasattr(driver.active, "_obj") + and isinstance(driver.active._obj, backend.driver) + ): + # Don't re-initialize backend if it is already active + return + driver.set_active(backend.driver()) + return + raise RuntimeError("Could not find an active GPU backend") + + +def get_backend_options(): + from triton.runtime import driver + + target = driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + options = backend.parse_options(dict()) + return options.__dict__ + + +@triton.jit +def promote_to_tensor(x): + # Addition promotes to tensor for us + return x + tl.zeros((1,), tl.int1) + + +@triton.jit +def div_floor_integer(a, b): + # NOTE: a // b is C division, but we want floor division + # Based on c10::div_floor_integer + quot = a // b + remainder = a % b + fixed = tl.where(remainder != 0, quot - 1, quot) + return tl.where((a < 0) != (b < 0), fixed, quot) + + +@triton.jit +def remainder_integer(a, b): + # NOTE: a % b matches C division, not floor division + remainder = a % b + return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder) + + +@triton.jit +def is_floating(x): + return promote_to_tensor(x).dtype.is_floating() + + +@triton.jit +def _prod_accumulate(a, b): + return a * b + + +@triton.jit +def prod(input, axis): + return tl.reduce(input, axis, _prod_accumulate) + + +@triton.jit +def minimum(a, b): + mask = a < b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def maximum(a, b): + mask = a > b + if is_floating(a): + mask |= a != a + return tl.where(mask, a, b) + + +@triton.jit +def min2(a, dim): + return tl.reduce(a, dim, minimum) + + +@triton.jit +def max2(a, dim): + return tl.reduce(a, dim, maximum) + + +@triton.jit +def minimum_with_index(a_value, a_index, b_value, b_index): + mask = a_value < b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def maximum_with_index(a_value, a_index, b_value, b_index): + mask = a_value > b_value + equal = a_value == b_value + if is_floating(a_value): + a_isnan = a_value != a_value + b_isnan = b_value != b_value + mask |= a_isnan and not b_isnan + # Consider NaNs as equal + equal |= a_isnan and b_isnan + + # Prefer lowest index if values are equal + mask |= equal & (a_index < b_index) + return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index) + + +@triton.jit +def min_with_index(value, index, dim): + return tl.reduce((value, index), dim, minimum_with_index) + + +@triton.jit +def max_with_index(value, index, dim): + return tl.reduce((value, index), dim, maximum_with_index) + + +@triton.jit +def exp(x, use_fast_math: tl.constexpr): + if use_fast_math: + return libdevice.exp2(x * _LOG_2_E) + else: + return math.exp(x) + + +@triton.jit +def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr): + out_max = max2(lhs_max, dim) + out_max_keepdim = out_max[:, None] + delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim) + out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim) + return out_max, out_sum + + +@triton.jit +def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: 
tl.constexpr): + """ + When we do combine, we assume lhs is the accumulator and rhs is the next + block of data. + Then rhs_sum is always 1. With that assumption, we can save some registers + and computation. + """ + out_max = maximum(lhs_max, rhs_max) + + lhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math) + ) + rhs_scale = tl.where( + out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math) + ) + + # Should be + # out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale + # but since rhs_sum is all 1, we can simplify it. + out_sum = lhs_sum * lhs_scale + rhs_scale + return out_max, out_sum + + +@triton.jit +def welford_reduce(value, mean, m2, weight, first_iteration): + if first_iteration: + new_weight = tl.full(weight.shape, 1, weight.dtype) + new_mean = value + new_m2 = tl.zeros_like(m2) + else: + delta = value - mean + new_weight = weight + 1 + new_mean = mean + delta / new_weight + new_m2 = m2 + delta * (value - new_mean) + return new_mean, new_m2, new_weight + + +@triton.jit +def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): + delta = mean_2 - mean_1 + new_weight = weight_1 + weight_2 + w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight) + return ( + mean_1 + delta * w2_over_w, + m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w, + new_weight, + ) + + +@triton.jit +def welford(mean, m2, weight, dim): + return tl.reduce((mean, m2, weight), dim, welford_combine) + + +@triton.jit +def device_assert_then(cond, msg, r): + tl.device_assert(cond, msg) + return r + + +@triton.jit +def randint64(seed, offset, low, high): + r0, r1, _r2, _r3 = tl.randint4x(seed, offset) + r0 = r0.to(tl.uint64) + r1 = r1.to(tl.uint64) + result = r0 | (r1 << 32) + size = high - low + result = result % size.to(tl.uint64) + result = result.to(tl.int64) + low + return result + + +@triton.jit +def _any_combine(a, b): + return a | b + + +@triton.jit +def any(a, dim): + return tl.reduce(a, dim, _any_combine) + + +@triton.jit +def bucketize_binary_search( + values: tl.tensor, + boundaries_ptr: tl.tensor, + BOUNDARIES_SIZE: int, + BOUNDARIES_UNDERLYING_NUMEL: int, + BOUNDARIES_STRIDE: int, + boundary_indices: tl.tensor, + indexing_dtype: tl.dtype, + right: "bool", # triton can't handle the unquoted bool annotation + sorter_ptr: tl.tensor, + SORTER_STRIDE: int, + sorter_indices: tl.tensor, +): + """ + See [Note: Inductor bucketize op] + + Inputs: + ------- + values: the values to bucketize. + boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D. + BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one + individual set of boundaries). + BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring + any striding. + BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor + boundary_indices: a tensor of the same size as "values"; each element is an index + into a 1-D, un-strided boundaries tensor, pointing to the first element in the set + of boundaries used for that value. + indexing_dtype: the dtype used for indexing into the boundaries tensor, and the + return dtype. + right: if true, use boundary intervals closed on the left; otherwise use intervals + closed on the right. + sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries, + but potentially different striding. If present, this allows us to treat boundaries + as sorted even if the elements of boundaries are unsorted. 
+ SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last + dimension of the sorter tensor. + sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices". + BLOCK_SHAPE: the shape of the data block being processed. + """ + + low = tl.zeros(values.shape, dtype=indexing_dtype) + high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype) + + full_range = BOUNDARIES_SIZE + 1 + while full_range > 1: + mid = (high + low) // 2 + mask = ( + mid * BOUNDARIES_STRIDE + boundary_indices + ) < BOUNDARIES_UNDERLYING_NUMEL and mid < BOUNDARIES_SIZE + mid_indices = ( + mid + if sorter_ptr is None or SORTER_STRIDE is None + else tl.load( + sorter_ptr + sorter_indices + SORTER_STRIDE * mid, + mask=mask, + other=0, + ) + ) + + bucket_upper_bound = tl.load( + boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices, + mask=mask, + other=0, + ) + if right: + is_above = values >= bucket_upper_bound + else: + is_above = values > bucket_upper_bound + + low = tl.where(is_above & mask, mid + 1, low) + high = tl.where(is_above, high, mid) + + full_range = (full_range + 1) // 2 + + return low + + +@triton.jit +def pack_value_flag( + value, + flag, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK) + return flag.to(DTYPE_PACK) | (uv << bitwidth) + + +@triton.jit +def unpack_value( + pack, + DTYPE_VALUE, + DTYPE_VALUE_AS_UINT, +): + # Workaround for triton bug, tensor.to doesn't unwrap constexpr values + DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE) + DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT) + bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth + value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT) + return value_uint.to(DTYPE_VALUE, bitcast=True) + + +@triton.jit +def unpack_flag(pack, DTYPE_FLAG): + return pack.to(DTYPE_FLAG) + + +@triton.jit +def exclusive_scan_decoupled_lookback( + scratch_base, + block_value, + index, + combine_fn, + DTYPE_VALUE_AS_UINT: tl.constexpr, + DTYPE_PACK: tl.constexpr, +): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value`` + DTYPE_PACK: Unsigned type twice the width of block_value + + NOTE: This function is limited to values which are 32-bits or less because + we need to pack (value, flag) into a single unsigned int. 
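+
+ Each block first publishes its own partial value tagged with flag 1, then
+ walks backwards over earlier blocks, spinning until a flag is set and
+ stopping as soon as it sees flag 2 (an inclusive prefix); finally it
+ publishes its own inclusive prefix tagged with flag 2 for later blocks.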
+ """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + DTYPE_VALUE = block_value.dtype + pack = pack_value_flag( + block_value, + tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + if index > 0: + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], DTYPE_VALUE) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + # tl.atomic_load + flag = tl.full([], 0, DTYPE_VALUE_AS_UINT) + while flag == 0: + pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed") + flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT) + + value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + pack = pack_value_flag( + inclusive_prefix, + tl.full([], 2, DTYPE_VALUE_AS_UINT), + DTYPE_VALUE_AS_UINT, + DTYPE_PACK, + ) + tl.atomic_xchg(scratch_base + index, pack, sem="relaxed") + return exclusive_prefix + + +@triton.jit +def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn): + """Compute exclusive scan of a scalar value between blocks + + Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back + + scratch_base: Pointer to scratch space in global memory + block_value: Scalar value for this block, must be 64-bits wide + index: Scalar index of this block relative to the current scan + combine_fn: Function ``(value, value) -> value`` which is scanned over + init: Scalar value equal to the identity of combine_fn + """ + # Publish block sum so subsequent blocks don't get stuck waiting for us + if index > 0: + block_value_u64 = block_value.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 1, block_value_u64) + tl.debug_barrier() + flag_one = tl.full([], 1, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release") + + # Calculate exclusive prefix scan + exclusive_prefix = tl.zeros([], block_value.dtype) + prefix_valid = False + test_target = index - 1 + while test_target >= 0: + flag = tl.full([], 0, tl.uint64) + while flag == 0: + flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire") + + value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32)) + value = value_u64.to(block_value.dtype, bitcast=True) + if prefix_valid: + exclusive_prefix = combine_fn(value, exclusive_prefix) + else: + exclusive_prefix = value + prefix_valid = True + + if flag == 2: + test_target = -1 + else: + test_target = test_target - 1 + + # Make inclusive block sum visible to other blocks + if prefix_valid: + inclusive_prefix = combine_fn(exclusive_prefix, block_value) + else: + inclusive_prefix = block_value + inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True) + tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64) + tl.debug_barrier() + flag_two = tl.full([], 2, tl.uint64) + tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release") + + return exclusive_prefix + + +@triton.jit +def frexp(x): + # TODO(isuruf): use inline_asm_elementwise here + y = libdevice.ilogb(x) + 1 + exponent = tl.where(x == 0, 0, y) + 
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y)) + return mantissa, exponent + + +@triton.jit +def _compare_and_swap_with_index( + x, + idxs, + rnumel, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)] + + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + y = tl.reshape(x, shape) + iy = y.to(idtype, bitcast=True) + # slice left/right with 'stride' 2**(n_dims - i - 1) + right_mask = tl.arange(0, 2)[None, :, None].to(idtype) + left_mask = (1 - right_mask).to(idtype) + ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape) + iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape) + ileft = tl.reshape(ileft, x.shape) + iright = tl.reshape(iright, x.shape) + left = ileft.to(x.dtype, bitcast=True) + right = iright.to(x.dtype, bitcast=True) + + # idx + y_idx = tl.reshape(idxs, shape) + left_idx = tl.broadcast_to( + tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + right_idx = tl.broadcast_to( + tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape + ) + left_idx = tl.reshape(left_idx, x.shape) + right_idx = tl.reshape(right_idx, x.shape) + + # valid + if rnumel is None: + left_valid_mask = tl.full(x.shape, True, tl.int1) + right_valid_mask = tl.full(x.shape, True, tl.int1) + else: + left_valid_mask = left_idx < rnumel + right_valid_mask = right_idx < rnumel + + # actual compare-and-swap + ix = x.to(idtype, bitcast=True) + + if descending: + cond = left < right + else: + cond = left > right + + if stable: + # When stable sorting, tie break by index + cond = cond | ((left == right) & (left_idx > right_idx)) + + cond = (right_valid_mask > left_valid_mask) | ( + (right_valid_mask == left_valid_mask) & cond + ) + cond = (cond ^ flip).to(tl.int1) + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs)) + + return ret.to(x.dtype, bitcast=True), new_idxs + + +@triton.jit +def _bitonic_merge_with_index( + x, + idxs, + rnumel, + stage: tl.constexpr, + alternating: tl.constexpr, + n_dims: tl.constexpr, + stable: tl.constexpr, + descending: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if alternating: + shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape( + tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape + ) + else: + flip = False + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, idxs = _compare_and_swap_with_index( + x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending + ) + return x, idxs + + +@triton.jit +def sort_with_index( + x, # value + idxs, # index + rnumel, # number of elements + dim: tl.constexpr = None, + stable: tl.constexpr = tl.constexpr(False), + descending: tl.constexpr = tl.constexpr(False), +): + x, idxs = tl.broadcast(x, idxs) + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert( + _dim == len(x.shape) - 1, "only minor dimension is currently supported" + ) + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = _log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, idxs = _bitonic_merge_with_index( + x, + idxs, + rnumel, + i, + alternating=i < n_dims, + n_dims=n_dims, + stable=stable, + descending=descending, + ) + return x, idxs + + +@triton.jit +def select_one(x, mask, dim, keep_dims=False): + idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False) + ix = x.to(idtype, bitcast=True) + iy = tl.sum(ix * mask, dim, keep_dims=keep_dims) + return iy.to(x.dtype, bitcast=True) + + +@triton.jit +def x_grid_barrier(sem): + """ + Wait for all other thread blocks in grid sharing same y/z program_id + to reach this barrier before returning. + + Args: + sem: an uint32 semaphores, zero or 0x80000000 initialized. Must be unique to each y/z program ID. + """ + # ensure stores before this are visible + tl.debug_barrier() + + one_i32 = 1 + one_u32 = one_i32.to(tl.uint32) # type: ignore[attr-defined] + expected = tl.num_programs(0).to(tl.uint32) + if tl.program_id(0) == 0: + nb = 0x80000000 - (expected - one_u32) + else: + nb = one_u32 + + old_arrive = tl.atomic_add(sem, nb, sem="release") + + bar_flipped = False + while not bar_flipped: + # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it + current_arrive = tl.atomic_add(sem, 0, sem="acquire") + # current_arrive = tl.load(sem, volatile=True) + bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0 + + # TODO(jansel): is this needed? + tl.debug_barrier() + + +def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]: + """ + Decorator to mark a function as a Triton built-in function. These functions + are evaluated at compile time. + + Args: + f (function): The function to be marked as a Triton built-in. + + Returns: + function: The same function, marked as a Triton built-in. + """ + if builtins_use_semantic_kwarg: + # support Triton before and after https://github.com/triton-lang/triton/pull/7054 + @wraps(f) + def wrapper(*args, **kwargs): + kwargs["_builder"] = kwargs["_semantic"] + del kwargs["_semantic"] + return f(*args, **kwargs) + else: + wrapper = f # type: ignore[assignment] + + wrapper.__triton_builtin__ = True # type: ignore[attr-defined] + return wrapper + + +@triton_builtin +def constexpr_next_power_of_2( + n: tl.constexpr, *, _builder: object = None +) -> tl.constexpr: + """ + A version triton.next_power_of_two that can be used within a kernel on constants. 
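+
+ For example, constexpr_next_power_of_2(tl.constexpr(5)) evaluates to tl.constexpr(8)
+ at compile time (an illustrative value; any positive constexpr integer works the same way).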
+ """ + assert isinstance(n, tl.constexpr) + return tl.constexpr(triton.next_power_of_2(n.value)) + + +@triton_builtin +def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr: + """ + Work around triton compile error: `ValueError: `other` cannot be provided without `mask`` + A compile-time to check to return either `val` or `None` depending on the value of mask. + """ + if isinstance(mask, tl.constexpr) and mask.value is None: + return tl.constexpr(None) + return val diff --git a/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a9a61c0eab7526f01649039ba12803e1e636b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_inductor/runtime/triton_heuristics.py @@ -0,0 +1,3022 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import builtins +import copy +import dataclasses +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import os.path +import re +import sys +import threading +import time +from collections import namedtuple +from typing import ( + Any, + Callable, + Generic, + Literal, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) + +import torch +from torch._dynamo.utils import set_feature_use +from torch._prims_common import compute_required_storage_length +from torch.utils._ordered_set import OrderedSet + +from ..triton_bundler import TritonBundler +from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict +from . import triton_helpers +from .autotune_cache import AutotuneCache +from .benchmarking import benchmarker +from .coordinate_descent_tuner import CoordescTuner +from .hints import ( + _NUM_THREADS_PER_WARP, + AutotuneHint, + DeviceProperties, + HeuristicType, + ReductionHint, + TileHint, + TRITON_MAX_BLOCK, + TRITON_MAX_RSPLIT, +) +from .runtime_utils import ( + ceildiv, + compilation_callback, + conditional_product, + create_bandwidth_info_str, + dynamo_timed, + get_first_attr, + get_max_y_grid, + get_num_bytes, + next_power_of_2, + triton_cache_dir, + triton_config_to_hashable, + triton_hash_to_path_key, + validate_triton_config, +) +from .static_cuda_launcher import StaticallyLaunchedCudaKernel +from .triton_compat import ( + ASTSource, + autograd_profiler, + cc_warp_size, + CompiledKernel, + Config, + GPUTarget, + HAS_WARP_SPEC, + KernelInterface, + knobs, + OutOfResources, + PTXASError, + triton, +) + + +class NoTritonConfigsError(RuntimeError): + pass + + +if TYPE_CHECKING: + from collections.abc import Container, Hashable + + from torch._guards import CompileId + + LauncherType = Any + +_KernelType = Union[CompiledKernel, StaticallyLaunchedCudaKernel] +_T = TypeVar("_T", bound=_KernelType) + +log = logging.getLogger(__name__) + + +def get_total_reduction_numel(numels: dict[str, int]) -> int: + return conditional_product( + *[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)] + ) + + +def autotune_hints_to_configs( + hints: OrderedSet[AutotuneHint], + size_hints, + block_size: int, + device_props: DeviceProperties, +) -> list[Config]: + """ + AutotuneHints can be attached to the metadata of triton kernels for providing + suggestions about what to try for autotuning. 
One reason to do this is if there are + some configs that are only useful in specific scenarios, in which case we can avoid + wasting compile time on autotuning unless we know we are in one of those scenarios. + + Based on those hints, this function will generate a list of additional autotuning + configs to try. + """ + xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...] + configs: list[Config] = [] + for hint in hints: + if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD: + if len(size_hints) == 1: + xyz_options = ((block_size // 4, None, None),) + elif len(size_hints) == 2: + xyz_options = ((block_size // 4, 1, None), (1, block_size // 4, None)) + elif len(size_hints) == 3: + xyz_options = ( + (block_size // 4, 1, 1), + (1, block_size // 4, 1), + (1, 1, block_size // 4), + ) + configs.extend( + triton_config( + size_hints, + *xyz, + num_elements_per_warp=( + device_props.warp_size if device_props.warp_size else 32 + ), + ) + for xyz in xyz_options + ) + + return configs + + +def disable_pointwise_autotuning(inductor_meta): + # Autotuning can give different benchmarking results from run to run, and + # therefore we disable autotuning when use_deterministic flag is on. + if inductor_meta.get("are_deterministic_algorithms_enabled"): + return True + return not inductor_meta.get("autotune_pointwise", True) + + +def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): + call_args = [] + call_kwargs = {} + for arg in args: + if isinstance(arg, (int, bool)): + call_args.append(str(arg)) + else: + call_args.append("T") + for k, v in kwargs.items(): + if isinstance(arg, (int, bool)): + call_kwargs[k] = v + else: + call_kwargs[k] = v + if not triton_version_uses_attrs_dict(): + call_kwargs.update(launcher.config.kwargs) + call_kwargs["num_warps"] = launcher.config.num_warps + call_kwargs["num_stages"] = launcher.config.num_stages + if HAS_WARP_SPEC: + call_kwargs["num_consumer_groups"] = getattr( + launcher.config, "num_consumer_groups", 0 + ) + call_kwargs["num_buffers_warp_spec"] = getattr( + launcher.config, "num_buffers_warp_spec", 0 + ) + args_str = [*call_args] + args_str.extend(f"{k}={v}" for k, v in call_kwargs.items()) + args_str = ", ".join(args_str) + abs_path = os.path.abspath(sys.argv[0]) + with open(f"{abs_path}.launch_params", "a") as f: + f.write(f"{kernel_name} | {args_str} | {grid!r}\n") + + +def check_autotune_cache( + configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any] +) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]: + """ + Given a list of configs, checks autotune cache and return metadata + """ + autotune_cache = None + autotune_cache_info = {} + disabled = inductor_meta.get("force_disable_caches", False) + if ( + not disabled + and filename is not None + and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning")) + and not os.environ.get("TRITON_INTERPRET", "0") == "1" + ): + configs_hash = hash_configs(configs) + + autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash) + if autotune_cache: + if best_config := autotune_cache.read_best(inductor_meta, configs): + configs = [best_config] + autotune_cache_info["best_config"] = triton_config_to_hashable( + best_config + ) + autotune_cache_info["autotune_cache_state"] = "hit" + + else: + autotune_cache_info["autotune_cache_state"] = "miss" + autotune_cache_info["num_configs"] = len(configs) + if inductor_meta.get("coordinate_descent_tuning"): + autotune_cache_info["coordesc_tuning"] = True + if len(configs) == 1: + # This is the 
config that coordinate descent tuning started at, which + # is not the same as the final config chosen (i.e. only_config, best_config) + autotune_cache_info["coordesc_tuning_start_config"] = ( + triton_config_to_hashable(configs[0]) + ) + else: + if len(configs) == 1: + autotune_cache_info["autotune_cache_state"] = "only 1 config" + autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0]) + + if disabled: + autotune_cache_info["autotune_cache_state"] = "force_disabled" + log.debug("autotune caching is disabled by config.force_disable_caches") + + return configs, autotune_cache, autotune_cache_info + + +class CachingAutotuner(KernelInterface): + """ + Simplified version of Triton autotuner that has no invalidation + key and caches the best config to disk to improve cold start times. + Unlike the main triton Autotuner, this version can precompile all + configs, and does not rely on the Triton JIT. + """ + + def __init__( + self, + fn, + triton_meta, # passed directly to triton + configs, + save_cache_hook, + mutated_arg_names: list[str], # see [Note: clone mutated buffers] + optimize_mem, + heuristic_type, + size_hints=None, + inductor_meta=None, # metadata not relevant to triton + custom_kernel=False, # whether the kernel is inductor-generated or custom + filename: Optional[str] = None, + reset_to_zero_arg_names: Optional[list[str]] = None, + autotune_cache_info: Optional[dict[str, Any]] = None, + ): + super().__init__() + + assert len(configs) > 0, "Non-empty TritonConfig list required for compiling" + # makes sure there are no pre-hooks on any of the triton configs + for cfg in configs: + validate_triton_config(cfg) + + self.fn = fn + self.device_props: DeviceProperties = triton_meta["device"] + self.triton_meta = { + **triton_meta, + "device": self.device_props.index, + "device_type": self.device_props.type, + } + self.inductor_meta = {} if inductor_meta is None else inductor_meta + self.save_cache_hook = save_cache_hook + self.mutated_arg_names = mutated_arg_names + self.reset_to_zero_arg_names = ( + [] if reset_to_zero_arg_names is None else reset_to_zero_arg_names + ) + self.optimize_mem = optimize_mem + self.configs = configs + self.heuristic_type = heuristic_type + self.custom_kernel = custom_kernel + self.cuda_kernel_saved = False + self.autotune_cache_info = autotune_cache_info + if log.isEnabledFor(logging.DEBUG): + log.debug( + "CachingAutotuner gets %d configs for %s", + len(self.configs), + self.fn.__name__, + ) + for c in self.configs: + log.debug(c) + + self.compile_results: list[CompileResult[_KernelType]] = [] + self.launchers: list[LauncherType] = [] + self.lock = threading.Lock() + if os.getenv("TRITON_CACHE_DIR") is None: + os.environ["TRITON_CACHE_DIR"] = triton_cache_dir( + self.triton_meta.get("device", 0) + ) + log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"]) + + self.size_hints = size_hints + self.coordesc_tuner = CoordescTuner( + is_mm=False, + name=self.fn.__name__, + size_hints=size_hints, + inductor_meta=self.inductor_meta, + ) + self.filename = filename + + # used for profiling + self.kernel_hash: str = "" + + # Kernels are stored in the codecache with the filename as a hash of the code. + # We rely on this to obtain the kernel hash + if self.filename is not None: + base_name = os.path.basename(self.filename) + if ".py" in base_name: + self.kernel_hash = os.path.splitext(base_name)[0] + + self.precompile_time_taken_ns = 0 + self.autotune_time_taken_ns = 0 + # Dumps the launch configs after autotuning. 
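+ # (see _dump_launch_params above; enabled via TORCHINDUCTOR_DUMP_LAUNCH_PARAMS=1)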
+ self.dump_launch_params = ( + os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" + ) + + self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" + + # Compile-time info included in runtime logginging + self.compile_id: Optional[CompileId] = None + self.is_backward = False + + def is_statically_launchable(self): + """ + Checks if every compiled kernel is statically launchable, which + allows us to efficiently cache it in FXGraphCache + """ + if not self.compile_results: + return False + return all( + isinstance(x, StaticTritonCompileResult) for x in self.compile_results + ) + + def recheck_autotune_cache( + self, reload_kernel_from_src: Callable[[], CachingAutotuner] + ) -> None: + """ + On cache load on static autotuner, we need to recheck the autotune cache, since + a best config could have been found from a previous run + """ + assert self.is_statically_launchable() + + configs = [result.config for result in self.compile_results] + + (cached_configs, _, autotune_cache_info) = check_autotune_cache( + configs, self.filename, self.inductor_meta + ) + self.autotune_cache_info = autotune_cache_info + # I.e. there was an autotune cache hit + if len(cached_configs) == 1 and len(configs) > 1: + best_config = cached_configs[0] + # Grab the best compiled config, if it's in the list of available ones + best_config_hash = triton_config_to_hashable(best_config) + + for compile_result in self.compile_results: + if triton_config_to_hashable(compile_result.config) == best_config_hash: + self.compile_results = [compile_result] + return + + # If the best config isn't in our list of compile results, + # it's likely because it was found by coordesc after the cache + # already saved + if best_config.found_by_coordesc: + with dynamo_timed("CachingAutotuner.slow_precompile_config"): + if self.fn.fn is None: + self.fn = reload_kernel_from_src().fn + self.compile_results = [self._precompile_config(best_config)] + + def set_compile_info( + self, compile_id: Optional[CompileId], is_backward: bool + ) -> None: + self.compile_id = compile_id + self.is_backward = is_backward + + def precompile( + self, + warm_cache_only=False, + reload_kernel: Optional[Callable[[], CachingAutotuner]] = None, + static_triton_bundle_key: Optional[str] = None, + ): + if warm_cache_only: + self._precompile_worker() + return + with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. 
Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel + self._precompile_worker() + if static_triton_bundle_key is not None and self.is_statically_launchable(): + TritonBundler.put_static_autotuner(static_triton_bundle_key, self) + self._make_launchers() + self._dynamic_scale_rblock() + + def _precompile_worker(self): + if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), # type: ignore[attr-defined] + self.triton_meta.get("device", 0), + ) + return + assert not self.launchers + if not self.configs: + raise NoTritonConfigsError("No triton configs are available") + + compile_results = [] + exc = None + for c in self.configs: + try: + compile_results.append(self._precompile_config(c)) + except (OutOfResources, PTXASError) as e: + exc = e + if len(compile_results) == 0: + raise NoTritonConfigsError( + f"No valid triton configs. {type(exc).__name__}: {exc}" + ) + self.compile_results = compile_results + self.configs = None + + def _dynamic_scale_rblock(self): + # TODO(jansel): we should find a way to move this extra compile into the worker process + # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg. + device_prop = self.device_props + if ( + self.inductor_meta.get("dynamic_scale_rblock", True) + and not self.inductor_meta.get("persistent_reduction") + and self.heuristic_type == HeuristicType.REDUCTION + and self.size_hints is not None + # Disable for Intel as Triton is not ready to return n_regs for a compiled_binary. + and device_prop.type in ["cuda", "hip"] + and device_prop.major + and (device_prop.major >= 8 or torch.version.hip) + and device_prop.regs_per_multiprocessor is not None + ): + assert device_prop.regs_per_multiprocessor + assert device_prop.max_threads_per_multi_processor + assert device_prop.multi_processor_count + seen_config_hashes: Optional[OrderedSet[Hashable]] = None + warp_size = device_prop.warp_size or 32 + for result in self.compile_results: + triton_config = result.config + compiled_binary = result.kernel + assert len(self.size_hints) >= 2 + xblock = triton_config.kwargs.get("XBLOCK", 1) + reduction_kwargs = [ + kwarg for kwarg in triton_config.kwargs if kwarg.startswith("R") + ] + rblocks = [triton_config.kwargs[kwarg] for kwarg in reduction_kwargs] + total_block = (self.size_hints["x"] + xblock - 1) // xblock + nreg = getattr(compiled_binary, "n_regs", None) + if nreg is None: + continue + + # make sure rblocks are not too small + if conditional_product(*rblocks) <= 64: + continue + + # each SM of A100 has 65536 32-bit registers. To maximize + # the theoretical occupancy, we need run 2048 threads on each + # SM. So each thread should use no more than 65536 / 2048 + # = 32 registers. In cases where occupancy matters, and each + # thread uses too many registers, reduce R0_BLOCK to reduce + # the register usage. + # For kernel https://gist.github.com/shunting314/e4cccc031fe30d378b9b23c08c238cbd + # from PLBartForCausalLM, latency improve from + # 7.795ms to 4.883ms. 
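+ # Illustrative arithmetic (A100-like numbers): with nreg = 64, warp_size = 32 and
+ # num_warps = 8, nreg_per_block = 64 * 32 * 8 = 16384, so at most
+ # 65536 // 16384 = 4 blocks fit on one SM; if total_block exceeds
+ # 4 * multi_processor_count, the largest Rn_BLOCK is halved below.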
+ # + if ( + nreg + <= device_prop.regs_per_multiprocessor + // device_prop.max_threads_per_multi_processor + ): + continue + + nreg_per_warp = nreg * warp_size + nreg_per_block = nreg_per_warp * triton_config.num_warps + + # Previously we set max_blocks_per_sm to 'max_threads_per_multi_processo / (32 * num_warps)' + # The formula below is a tighter upper bound since we have the assumption that + # nreg > device_prop.regs_per_multiprocessor // device_prop.max_threads_per_multi_processor + # due to the if condition above and: + # regs_per_multiprocessor / nreg_per_block + # = regs_per_multiprocessor / (nreg * 32 * num_warps) + # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) + # = max_threads_per_multi_processor / (32 * num_warps) + # Using a tigher upper bound can reveal more optimization opportunities. + max_blocks_per_sm = max( + device_prop.regs_per_multiprocessor // nreg_per_block, 1 + ) + + if total_block <= max_blocks_per_sm * device_prop.multi_processor_count: + # no need to improve occupancy + continue + new_config = copy.deepcopy(triton_config) + + # Reduce the largest Rn_BLOCK by a factor of 2. + largest_rkwarg: str = max( + reduction_kwargs, key=triton_config.kwargs.__getitem__ + ) + new_config.kwargs[largest_rkwarg] //= 2 + + if seen_config_hashes is None: + seen_config_hashes = OrderedSet( + [ + triton_config_to_hashable(x.config) + for x in self.compile_results + ] + ) + new_config_hash = triton_config_to_hashable(new_config) + if new_config_hash in seen_config_hashes: + continue + seen_config_hashes.add(new_config_hash) + log.debug( + "Dynamically scale down %s from TritonConfig(%s) and get a new TritonConfig(%s)", + largest_rkwarg, + triton_config, + new_config, + ) + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + self.compile_results.append(self._precompile_config(new_config)) + + self._make_launchers() + + def _make_launchers(self): + if len(self.launchers) == len(self.compile_results): + return + + from torch._dynamo.device_interface import DeviceGuard + + device_interface = self.get_device_interface() + + # load binary to the correct device + with DeviceGuard(device_interface, self.triton_meta["device"]): + # need to initialize context + device_interface.synchronize(device_interface.current_device()) + launchers = [] + exc = None + for result in self.compile_results: + try: + launchers.append(result.make_launcher()) + + except (OutOfResources, PTXASError, torch.cuda.OutOfMemoryError) as e: + exc = e + if len(launchers) == 0: + raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") + self.launchers = launchers + + def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]: + """Drop stuff from triton.JITFunction that does not pickle. + This must be called after precompile so that these things are no longer needed. 
+ Returns a tuple of old values + """ + old_values = ( + self.fn.fn, + self.fn.__globals__, + self.fn.used_global_vals, + self.fn.repr, + self.launchers, + ) + self.fn.fn = None + self.fn.__globals__ = None + self.fn.used_global_vals = None + self.fn.repr = _ConstRepr(self.fn.repr(self.fn)) + self.launchers = [] + return old_values + + def prepare_for_caching(self) -> None: + """ + Statically Launched CUDA Kernels have a raw cubin on them + that we don't need to store in the cache(since TritonBundler handles the collection for us) + """ + for result in self.compile_results: + if isinstance(result, StaticTritonCompileResult): + # Don't save this in the inductor cache, as it is very large + result.kernel.cubin_raw = None + + def __getstate__(self) -> dict[str, Any]: + assert not self.launchers, ( + "pickle should not be called with after make_launchers()" + ) + return { + **self.__dict__, + "lock": None, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + self.lock = threading.Lock() + + def get_device_interface(self): + # this code cannot run in compile workers, because it imports from torch + from torch._dynamo.device_interface import get_interface_for_device + + return get_interface_for_device(self.device_props.type.replace("hip", "cuda")) + + def _precompile_config(self, cfg: Config) -> CompileResult[_KernelType]: + """Ahead of time compile a given autotuner config.""" + compile_meta = copy.deepcopy(self.triton_meta) + cfg_kwargs = cfg.kwargs + if self.device_props.type == "hip": + cfg_kwargs = {**cfg_kwargs} + for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"): + if k in cfg_kwargs: + compile_meta[k] = cfg_kwargs.pop(k) + compile_meta["constants"].update(cfg_kwargs) + for i in self.fn.constexprs: + arg_name = self.fn.arg_names[i] + if arg_name not in compile_meta["constants"] and ( + arg_name == "num_warps" or arg_name == "num_stages" + ): + compile_meta["constants"][arg_name] = getattr(cfg, arg_name) + compile_meta["num_warps"] = cfg.num_warps + compile_meta["num_stages"] = cfg.num_stages + if HAS_WARP_SPEC: + compile_meta["num_consumer_groups"] = getattr(cfg, "num_consumer_groups", 0) + compile_meta["num_buffers_warp_spec"] = getattr( + cfg, "num_buffers_warp_spec", 0 + ) + compile_meta["debug"] = self.inductor_meta.get( + "assert_indirect_indexing", True + ) and not self.inductor_meta.get("is_hip", False) + + # device type will be "hip" rather than "cuda" here + compile_meta["device_type"] = self.device_props.type + compile_meta["cc"] = self.device_props.cc + + if self.device_props.type == "cpu": + triton_helpers.set_driver_to_cpu() + else: + triton_helpers.set_driver_to_gpu() + + if not ASTSource: + raise RuntimeError("Installed triton version too old, please upgrade") + + compile_args = ( + ASTSource( + self.fn, + compile_meta["signature"], + compile_meta["constants"], + compile_meta["configs"][0], + ), + ) + + if self.device_props.type == "mtia": + from mtia.host_runtime.torch_mtia.acc_flags import ( # type: ignore[import-not-found] + build_codename, + ) + + arch = build_codename() + else: + arch = compile_meta["cc"] + + target = GPUTarget( + compile_meta["device_type"], + arch, + cc_warp_size(compile_meta["cc"]), + ) + + options = { + "num_warps": compile_meta["num_warps"], + "num_stages": compile_meta["num_stages"], + "debug": compile_meta["debug"], + "sanitize_overflow": False, # turn off additional asserts added for overflow checks + } + if HAS_WARP_SPEC: + options.update( + { + "num_consumer_groups": 
compile_meta.get("num_consumer_groups", 0), + "num_buffers_warp_spec": compile_meta.get( + "num_buffers_warp_spec", 0 + ), + } + ) + if self.device_props.type == "hip": + if "waves_per_eu" in compile_meta: + options["waves_per_eu"] = compile_meta["waves_per_eu"] + if "matrix_instr_nonkdim" in compile_meta: + options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"] + compile_kwargs = { + "target": target, + "options": options, + } + + try: + binary = triton.compile(*compile_args, **compile_kwargs) + except Exception: + log.exception( + "Triton compilation failed: %s\n%s\nmetadata: %s", + self.inductor_meta.get("kernel_name", "triton_"), + self.fn.src, + compile_meta, + ) + raise + TritonBundler.put( + triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) + ) + # If the binary has a cubin file to directly launch, save it on the binary + static_launcher = StaticTritonCompileResult.can_statically_launch( + binary, self.inductor_meta, self.triton_meta, self.heuristic_type + ) + + if static_launcher is not None: + result = StaticTritonCompileResult( + static_launcher, cfg, compile_meta, self.inductor_meta + ) + return result + + return TritonCompileResult(binary, cfg, compile_meta, self.inductor_meta) + + def _get_args_with_constexprs(self, args, launcher): + """ + `args` is passed in with only the non-constexpr args (because the constexpr arg values + depend on the config). However, in later triton versions, the constexpr args need to be + added into the args list. + """ + if triton_version_uses_attrs_dict(): + # first: aggregate the constexpr args in (index, val) pairs + # so we can sort them by index. + constexpr_args: list[tuple[int, Any]] = [] + for arg_name, arg_val in launcher.config.kwargs.items(): + if arg_name in self.fn.arg_names: + constexpr_args.append((self.fn.arg_names.index(arg_name), arg_val)) + + constexpr_args.sort() + new_args = [*args] + for arg_idx, arg_val in constexpr_args: + new_args.insert(arg_idx, arg_val) + + return new_args + return args + + def bench(self, launcher, *args, with_profiler=False, **kwargs): + """Measure the performance of a given launcher""" + # we don't skip configs with spilled registers when auto-tuning custom + # (user-written) Triton kernels, as (i) we don't have any knowledge or + # control over the kernel code; (ii) there is empirical evidence that + # for some (complicated) custom Triton kernels, a register-spilling + # config may yield the best latency. 
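+ # For inductor-generated kernels, a spill count above the spill_threshold
+ # (default 16) is taken as a sign the config is not worth benchmarking at all.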
+ if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( + "spill_threshold", 16 + ): + log.debug( + "Skip config %s because of register spilling: %d", + launcher.config, + launcher.n_spills, + ) + return float("inf") + + device_interface = self.get_device_interface() + stream = device_interface.get_raw_stream(device_interface.current_device()) + + cpu_copies = self.copy_args_to_cpu_if_needed(*args, **kwargs) + + def kernel_call(): + cloned_args, cloned_kwargs = self.maybe_clone_args( + cpu_copies, *args, **kwargs + ) + # reset to zero before evaluating any config + self.reset_to_zero_args(*args, **kwargs) + args_with_constexprs = self._get_args_with_constexprs(cloned_args, launcher) + launcher( + *args_with_constexprs, + **cloned_kwargs, + stream=stream, + ) + self.restore_args_from_cpu(cpu_copies) + + if with_profiler: + from torch._inductor.utils import do_bench_using_profiling + + return do_bench_using_profiling(kernel_call, warmup=10, rep=40) + + if self.device_props.type == "cpu": + return benchmarker.benchmark_cpu(kernel_call) + + return benchmarker.benchmark_gpu(kernel_call, rep=40) + + def copy_args_to_cpu_if_needed(self, *args, **kwargs): + """ + To support benchmarking in the presence of mutated args, we need to avoid + autotuning contanminating them. We try to pass cloned args to the kernel. + If those clones would increase the peak memory usage, however, we instead + copy to cpu and restore them after each iteration. Figure out the args + to be copied and do the copying. + """ + if not self.optimize_mem: + return {} + + copies = {} + budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated() + + def maybe_copy(name, arg): + if name in self.mutated_arg_names and arg.is_cuda: + nonlocal budget + assert isinstance(arg, torch.Tensor) + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + size = required_storage_length * arg.element_size() + if size > budget: + cpu_arg = torch.empty_strided( + (required_storage_length,), + (1,), + dtype=arg.dtype, + device="cpu", + pin_memory=True, + ) + cpu_arg.copy_( + arg.as_strided((required_storage_length,), (1,)), + non_blocking=True, + ) + copies[name] = (arg, cpu_arg) + else: + budget -= size + + for name, arg in zip(self.fn.arg_names, args): + maybe_copy(name, arg) + + for name, arg in kwargs.items(): + maybe_copy(name, arg) + + return copies + + def restore_args_from_cpu(self, cpu_copies): + for pair in cpu_copies.values(): + arg, cpu_arg = pair + required_storage_length = compute_required_storage_length( + arg.size(), + arg.stride(), + 0, + ) + arg.as_strided((required_storage_length,), (1,)).copy_( + cpu_arg, non_blocking=True + ) + + def reset_to_zero_args(self, *args, **kwargs): + if not self.reset_to_zero_arg_names: + return + for i, arg in enumerate(args): + if self.fn.arg_names[i] in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + for name, arg in kwargs.items(): + if name in self.reset_to_zero_arg_names: + assert isinstance( + arg, + torch.Tensor, + ), ( + "self.reset_to_zero_arg_names should only contain valid argument names" + ) + arg.zero_() + + def maybe_clone_args( + self, exclude: Container[str], *args, **kwargs + ) -> tuple[list[Any], dict[str, Any]]: + """ + Prepare new args and kwargs by cloning any in-place buffers + (that are not in the provided exclusion list), to avoid autotune + contaminating them. 
Avoid cloning the other buffers because it + leads to increased memory usage. + """ + from ..compile_fx import clone_preserve_strides + + def prepare_arg(name, arg): + if name in self.mutated_arg_names and name not in exclude: + assert isinstance(arg, torch.Tensor) + return clone_preserve_strides(arg) + else: + return arg + + cloned_args = [ + prepare_arg(name, arg) + for name, arg in itertools.zip_longest(self.fn.arg_names[: len(args)], args) + ] + cloned_kwargs = {name: prepare_arg(name, arg) for name, arg in kwargs.items()} + return cloned_args, cloned_kwargs + + def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: + return self.maybe_clone_args(OrderedSet(), *args, **kwargs) + + def benchmark_all_configs(self, *args, **kwargs): + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + compilation_callback.callback_handler.install_callbacks( + compilation_callback.CallbackTrigger.TRITON_AUTOTUNING, + str(self.compile_id), + ), + ): + timings = { + launcher: self.bench(launcher, *args, **kwargs) + for launcher in self.launchers + } + + for k, v in timings.items(): + self.coordesc_tuner.cache_benchmark_result(k.config, v) + + if log.isEnabledFor(logging.DEBUG): + log.debug("Benchmark all input configs for %s, get:", self.fn.__name__) + for k, v in timings.items(): + log.debug( + "%s: %f, nreg %d, nspill %d, #shared-mem %s", + k.config, + v, + k.n_regs, + k.n_spills, + k.shared, + ) + + self.reset_to_zero_args(*args, **kwargs) + return timings + + def autotune_to_one_config(self, *args, **kwargs): + """Do the actual autotuning""" + start_time = time.time_ns() + timings = self.benchmark_all_configs(*args, **kwargs) + benchmark_time_taken_ns = time.time_ns() - start_time + self.launchers = [builtins.min(timings, key=timings.get)] + self.autotune_time_taken_ns = ( + self.precompile_time_taken_ns + benchmark_time_taken_ns + ) + + # log the best config + launcher = self.launchers[0] + log.debug( + "Best config for %s: %s: %f, nreg %d, nspill %d, #shared-mem %s", + self.fn.__name__, + launcher.config, + timings[launcher], + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + + if self.save_cache_hook: + self.save_cache_hook( + launcher.config, + self.autotune_time_taken_ns, + triton_cache_hash=launcher.cache_hash, + ) + + def save_gpu_kernel(self, stream, launcher): + key = self.inductor_meta.get("kernel_name", None) # unique kernel name + assert key is not None, "kernel_name can not be None" + params = { + "mangled_name": ( + launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"] + ), + "num_warps": ( + launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps + ), + "shared_mem": ( + launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared + ), + "stream": stream, + # User defined triton kernels will have arbitrary kwarg names + "config": config_to_dict(launcher.config), + "inductor_meta": self.inductor_meta, + "triton_meta": self.triton_meta, + "def_args": launcher.def_args, + "call_args": launcher.call_args, + "global_scratch": launcher.global_scratch, + } + from torch._inductor.codecache import CudaKernelParamCache + + 
bin_type = {"hip": "hsaco", "xpu": "spv"}.get(self.device_props.type, "cubin") + binary = launcher.bin.asm[bin_type] + # Also store asm code which can be used for debugging and generating cpp package + asm_type = {"hip": "amdgcn", "cuda": "ptx", "xpu": "spv"}.get( + self.device_props.type, None + ) + asm = launcher.bin.asm.get(asm_type, None) + + CudaKernelParamCache.set(key, params, binary, bin_type, asm, asm_type) + self.cuda_kernel_saved = True + + def coordinate_descent_tuning(self, launcher, *args, **kwargs): + """ + Coordinate descent tuning can be run with or without max-autotune. + + The only difference between these two is the starting config for coordinate_descent tuning. + E.g., assuming regular autotune only get one config C1; while max-autotune get 4 configs C1, C2, C3, C4 + and max-autotune figure out C3 is the best. + + Then if coordinate desecnt tuning is run with max-autotune disabled, it will start from C1; + while if coordinate descent tuning is run with max-autotune enabled, it will start from C3. + """ + if ( + self.heuristic_type == HeuristicType.TEMPLATE + or self.heuristic_type == HeuristicType.USER_AUTOTUNE + ): + # skip triton template + return launcher + + config2launcher = {launcher.config: launcher} + + # TODO: should we just load the kernels ahead of time if we know we're going to call this? + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + + def benchmark_one_config(config): + with self.lock: + launcher = self._precompile_config(config).make_launcher() + config2launcher[config] = launcher + + out = self.bench(launcher, *args, **kwargs) + log.debug( + "COORDESC: %s: %f, nreg %d, nspill %d, #shared-mem %d", + launcher.config, + out, + launcher.n_regs, + launcher.n_spills, + launcher.shared, + ) + return out + + assert not ( + self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION + and "R0_BLOCK" in launcher.config.kwargs + ), ( + "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" + ) + start_time = time.time_ns() + best_config = self.coordesc_tuner.autotune( + benchmark_one_config, launcher.config, None + ) + coordesc_time_taken_ns = time.time_ns() - start_time + best_config.found_by_coordesc = True + + if self.save_cache_hook: + self.save_cache_hook( + best_config, + self.autotune_time_taken_ns + coordesc_time_taken_ns, + found_by_coordesc=True, + ) + + if best_config not in config2launcher: + # On a Coordesc cache hit, we might not have loaded the launcher + # This can happen because PyCodeCache saves CachingAutotuners in memory, + # even for separate compile IDs (which can have different inputs without changing output code) + config2launcher[best_config] = self._precompile_config( + best_config + ).make_launcher() + return config2launcher[best_config] + + def run( + self, + *args, + stream, + benchmark_run=False, + **kwargs, + ): # type:ignore[override] + if hasattr(triton, "set_allocator"): + + def alloc_fn(size: int, align: int, stream: Optional[int]): + return torch.empty( + size, dtype=torch.int8, device=self.device_props.type + ) + + triton.set_allocator(alloc_fn) + + if self.triton_interpret: + args, grid = self._interpret_args_grid(args, self.configs[0]) + return self.fn[grid]( + *args, + 
**kwargs, + **self.configs[0].kwargs, + ) + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + + if not getattr( + self.launchers[0].config, "found_by_coordesc", False + ) and self.inductor_meta.get("coordinate_descent_tuning", False): + self.launchers = [ + self.coordinate_descent_tuning(self.launchers[0], *args, **kwargs) + ] + + (launcher,) = self.launchers + if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): + self.save_gpu_kernel(stream, launcher) + + args = self._get_args_with_constexprs(args, launcher) + + if self.dump_launch_params: + new_args, grid = self._interpret_args_grid(args, launcher.config) + _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) + + # it is faster than entering and exiting a context manager, even if the context + # manager is a nullcontext. + if autograd_profiler._is_profiler_enabled: + kernel_kwargs_str = ",".join( + f"{k}={v}" for (k, v) in launcher.config.kwargs.items() + ) + + profiler_kwargs = { + "kernel_file": (self.filename or ""), + "kernel_hash": self.kernel_hash, + "kernel_backend": "triton", + "stream": stream, + "num_warps": launcher.config.num_warps, + "num_stages": launcher.config.num_stages, + "kernel_kwargs": kernel_kwargs_str, + } + + with torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + args, + profiler_kwargs, + ): + return launcher( + *args, + **kwargs, + stream=stream, + ) + else: + return launcher( + *args, + **kwargs, + stream=stream, + ) + + def _interpret_args_grid( + self, args: tuple[Any, ...], cfg: Config + ) -> tuple[tuple[Any, ...], tuple[int, int, int]]: + grid = GridExpr.from_meta(self.inductor_meta, cfg).eval_slow( + dict( + zip( + [ + *self.triton_meta["signature"].keys(), + *self.inductor_meta.get("extra_launcher_args", ()), + ], + args, + ) + ) + ) + if self.inductor_meta.get("extra_launcher_args"): + args = args[: -len(self.inductor_meta["extra_launcher_args"])] + return args, grid + + +class _ConstRepr: + def __init__(self, value: str): + self.value = value + + def __call__(self, _=None) -> str: + return self.value + + +class CompileResult(Generic[_T]): + def __init__( + self, + kernel: _T, + config: Config, + compile_meta: dict[str, Any], + inductor_meta: dict[str, Any], + ): + self.kernel = kernel + self.config = config + self.compile_meta = compile_meta + self.inductor_meta = inductor_meta + + def make_launcher(self) -> LauncherType: ... + + def _gen_launcher_code(self, scope, def_args, runner_args) -> LauncherType: + grid = GridExpr.from_meta(self.inductor_meta, self.config) + # grid.prefix is usually empty, grid.x_grid is something like `-(xnumel//-1024)` + lines = [ + f"def launcher({', '.join(def_args)}, stream):", + *[f" {line}" for line in grid.prefix], + f" grid_0 = {grid.x_grid}", + f" grid_1 = {grid.y_grid}", + f" grid_2 = {grid.z_grid}", + f" runner({', '.join(runner_args)})", + ] + launcher_code = "\n".join(lines) + exec(launcher_code, scope) + return scope["launcher"] + + def _get_arg_lists( + self, arg_names, constexprs + ) -> tuple[list[str], list[str], OrderedSet[str]]: + """ + Return a bunch of intermediate lists of args needed for generating + launcher code. 
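+
+ Roughly: call_args are the arguments forwarded to the compiled binary, def_args make up
+ the signature of the generated launcher(), and none_args are constant args whose value
+ is None and which were dropped from the signature.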
+ """ + compile_meta = self.compile_meta + cfg = self.config + known_constants = OrderedSet( + arg for i, arg in enumerate(arg_names) if i in constexprs + ) + + """ + https://github.com/pytorch/pytorch/issues/115344 + + self.fn.constexprs doesn't properly deal with None args, so when we filter out + an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well. + We also don't want to modify self.fn. + + We know that we removed something from the signature if: + 1. It's in compile_meta["constants"] + 2. It isn't a constant we already know about + Note: The value of interest has already been added to compile_meta['constants'], + so we use self.fn.constexprs instead. + 3. It isn't in the compile_meta signature + """ + none_args = OrderedSet( + k + for k, v in compile_meta["constants"].items() + if v is None and k not in known_constants + ) + none_args = none_args.difference(OrderedSet(compile_meta["signature"].keys())) + + if triton_version_uses_attrs_dict(): + call_args = arg_names + def_args = arg_names + if ( + "num_warps" in compile_meta["constants"] + or "num_stages" in compile_meta["constants"] + ): + # num_warps/num_stages are special implicit args that are not in the signature + # see test_triton_kernel_special_params + def_args = [ + arg for arg in def_args if arg not in ("num_warps", "num_stages") + ] + repl = { + k: str(compile_meta["constants"].get(k)) + for k in ("num_warps", "num_stages") + } + call_args = [repl.get(arg, arg) for arg in call_args] + else: + call_args = [ + arg + for i, arg in enumerate(arg_names) + if i not in constexprs and arg not in none_args + ] + cfg_dict = config_to_dict(cfg) + def_args = [ + name + for name in arg_names + if name not in cfg_dict and name not in none_args + ] + + if "extra_launcher_args" in self.inductor_meta: + def_args = [*def_args, *self.inductor_meta["extra_launcher_args"]] + + return call_args, def_args, none_args + + +class CannotStaticallyLaunchKernel(Exception): + pass + + +class StaticTritonCompileResult(CompileResult[StaticallyLaunchedCudaKernel]): + """ + TritonCompileResult that uses StaticCudaLauncher, + which vastly simplifies the setup and metadata needed to be kept. 
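+
+ The wrapped kernel is a StaticallyLaunchedCudaKernel rather than a triton CompiledKernel,
+ so make_launcher() below calls its run() directly instead of going through the triton
+ launch machinery.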
+ """ + + @staticmethod + def can_statically_launch( + kernel: CompiledKernel, + inductor_meta: dict[str, Any], + triton_meta: dict[str, Any], + heuristic_type: HeuristicType, + ) -> Optional[StaticallyLaunchedCudaKernel]: + if not torch._inductor.config.use_static_cuda_launcher: + return None + + def check_can_launch() -> StaticallyLaunchedCudaKernel: + if triton_meta.get("device_type", None) != "cuda": + # Only cuda kernels + raise CannotStaticallyLaunchKernel("Non-cuda device") + + if torch._inductor.config.cpp_wrapper: + # If we're running with cpp wrapper, it doesn't + # make sense to statically compile since everything + # is codegenned anyway + raise CannotStaticallyLaunchKernel("Cpp wrapper enabled") + + if ( + heuristic_type == HeuristicType.USER_AUTOTUNE + and not torch._inductor.config.static_launch_user_defined_triton_kernels + ): + # Don't support user defined triton kernels yet + raise CannotStaticallyLaunchKernel("User defined triton kernel") + + if inductor_meta.get("store_cubin", None): + # Requires storing the entire binary + raise CannotStaticallyLaunchKernel("store_cubin is enabled") + + cubin_location = os.path.join( + triton_cache_dir(triton_meta.get("device", 0)), + triton_hash_to_path_key(kernel.hash), + f"{kernel.src.fn.__name__}.cubin", + ) + + if not os.path.exists(cubin_location): + raise CannotStaticallyLaunchKernel( + f"Cubin path not found: {cubin_location}" + ) + + else: + kernel._cubin_path = cubin_location + + try: + static_kernel = StaticallyLaunchedCudaKernel(kernel) + except NotImplementedError as e: + raise CannotStaticallyLaunchKernel(f"NotImplemented: {str(e)}") from e + + return static_kernel + + try: + result = check_can_launch() + return result + except CannotStaticallyLaunchKernel as e: + log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e)) + if torch._inductor.config.strict_static_cuda_launcher: + raise e + return None + + def reload_cubin_path(self): + """ + When loading from cache on disk, we want to reload cubin + files from their appropriate location on disc. + """ + cubin_location = os.path.join( + triton_cache_dir(self.compile_meta.get("device", 0)), + triton_hash_to_path_key(self.kernel.hash), + f"{self.kernel.name}.cubin", + ) + if not os.path.exists(cubin_location): + if self.kernel.cubin_raw is not None: + # We saved the raw cubin, so write it to he appropriate location + self.kernel.reload_cubin_from_raw(cubin_location) + else: + raise RuntimeError( + "Cubin file saved by TritonBundler not found at %s", cubin_location + ) + self.kernel.cubin_path = cubin_location + + def make_launcher(self) -> LauncherType: + # If at least one static make_launcher call occurs, + # we're sure static cuda launcher was used for this compile + set_feature_use("static_cuda_launcher", True) + # Load the binary on the parent + if not self.kernel.cubin_path: + self.reload_cubin_path() + device = self.compile_meta.get("device", 0) + if device is None: + device = 0 + self.kernel.load_kernel(device) + scope = { + "runner": self.kernel.run, + } + + # NOTE: Constexpr handling for triton and static cuda launcher + + # Triton kernels have two types of constexprs: *declared* ones, which are ones the user + # has explicitly declared as tl.constexpr, and *implied* ones, which are expressions triton + # deems constant while compiling/analyzing the code (i.e. unused parameters, for example) + + # Triton kernels handle constexprs slightly differently depending on which version of triton + # we care about (we support 3.2.0 and 3.3.0). 
+ + # In 3.2.0, triton kernels do not require passing any declared constexprs into the kernel + # In 3.3.0, triton kernels require all declared constexprs be passed into the kernel, where + # they are subsequently ignored. + # When statically launching, since we're launching from the triton generated cubin, we actually want to + # always get rid of all const exprs, declared or implied, since the underlying cubin file has all + # of the constants stripped away anyway. + + # But CachingAutotuner.run will pass us a different number of arguments depending on + # whether or not we're in triton 3.2.0 or later, so we grab def_args with the same logic + # as the (non static) TritonCompileResult. We then generate call_args ourselves, since we + # want only a subset of the arguments passed to triton. + # Here, arg_names is exactly fn.src.arg_names and declared_constexprs is exactly fn.src.constexprs, + # which matches behavior with regular TritonCompileResult + _, def_args, none_args = self._get_arg_lists( + self.kernel.arg_names, self.kernel.declared_constexprs + ) + + call_args = [ + arg + for i, arg in enumerate(self.kernel.arg_names) + if i not in self.kernel.full_constexprs and arg not in none_args + ] + + # StaticallyLaunchedCudaKernel.run takes in order grid_0, grid_1, grid_2, stream, and call_args + runner_args = ["grid_0", "grid_1", "grid_2", "stream", *call_args] + launcher = self._gen_launcher_code(scope, def_args, runner_args) + launcher.config = self.config # type: ignore[attr-defined] + launcher.n_regs = self.kernel.n_regs # type: ignore[attr-defined] + launcher.n_spills = self.kernel.n_spills # type: ignore[attr-defined] + launcher.shared = self.kernel.shared # type: ignore[attr-defined] + launcher.cache_hash = triton_hash_to_path_key(self.kernel.hash) # type: ignore[attr-defined] + launcher.store_cubin = False # type: ignore[attr-defined] + launcher._is_static = True # type: ignore[attr-defined] + return launcher + + +class TritonCompileResult(CompileResult[CompiledKernel]): + """ + Upstream Triton CompileKernel can not be pickled. This is a wrapper + to support serialization and generate the launcher function. + """ + + @staticmethod + @functools.lru_cache(32) + def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any: + return namedtuple("KernelMetadata", sorted(fields)) + + @staticmethod + def _serialize_metadata(metadata): + """ + Triton uses a nested class called KernelMetadata to store metadata information. + Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear + in the toplevel namespace of the module. So these serialization/deser functions + are used to convert the namedtuples to a dict and back. + + As for packed_metadata, depending on the triton backend, KernelMetadata can be + a namedtuple, or a regular tuple! So the serialization function branches on whether + the metadata to be serialized is a namedtuple or regular, serializable one. 
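+
+ For example, a KernelMetadata namedtuple is flattened to a plain dict via _asdict()
+ here, and _deserialize_metadata rebuilds an equivalent namedtuple from that dict
+ through _kernel_metadata_cls.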
+ """ + + def is_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if is_namedtuple(metadata): + return metadata._asdict() + else: + return metadata + + @staticmethod + def _deserialize_metadata(metadata): + if isinstance(metadata, dict): + return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))( + **metadata + ) + else: + return metadata + + def __getstate__(self) -> dict[str, Any]: + kernel = self.kernel + # replace the fields that don't pickle nicely + kernel_state = { + **kernel.__dict__, + # See doc about serializing metadata above + "metadata": self._serialize_metadata(kernel.metadata), + "packed_metadata": self._serialize_metadata( + getattr(kernel, "packed_metadata", None) + ), + "module": None, # regenerated by kernel._init_handles() + "function": None, # regenerated by kernel._init_handles() + "run": None, # regenerated by kernel._init_handles() + } + return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item] + + def __setstate__(self, state: dict[str, Any]) -> None: + # src = ASTSource.__new__(ASTSource) + # src.__setstate__(state["kernel"]["src"]) + # TODO(jansel): need to fixup src.fn which is now None + kernel = CompiledKernel.__new__(CompiledKernel) + metadata = state["kernel"]["metadata"] + packed_metadata = state["kernel"]["packed_metadata"] + kernel.__dict__.update( + { + **state["kernel"], + # "src": src, + "metadata": self._deserialize_metadata(metadata), + "packed_metadata": self._deserialize_metadata(packed_metadata), + } + ) + self.__dict__.update(state) + self.kernel = kernel + + def make_launcher(self) -> LauncherType: + """ + Launching triton kernels is performance sensitive, we compile + a custom Python function get the grid() and reorder the args to + the underlying wrapper. + """ + cfg = self.config + compile_meta = self.compile_meta + binary = self.kernel + fn = binary.src.fn + binary._init_handles() + (call_args, def_args, none_args) = self._get_arg_lists( + fn.arg_names, fn.constexprs + ) + binary_shared = ( + binary.shared if hasattr(binary, "shared") else binary.metadata.shared + ) + + if knobs is None: + launch_enter = binary.__class__.launch_enter_hook + launch_exit = binary.__class__.launch_exit_hook + else: + launch_enter = knobs.runtime.launch_enter_hook + launch_exit = knobs.runtime.launch_exit_hook + + import math as math_lib + + import torch as torch_lib + + scope = { + "grid_meta": cfg.kwargs, + "bin": binary, + "launch_enter_hook": launch_enter, + "launch_exit_hook": launch_exit, + "metadata": ( + binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata + ), + "shared": binary_shared, + "num_warps": ( + binary.num_warps + if hasattr(binary, "num_warps") + else binary.metadata.num_warps + ), + "cta_args": ( + ( + binary.num_ctas, + *get_first_attr(binary, "cluster_dims", "clusterDims"), + ) + if hasattr(binary, "num_ctas") + else ( + (binary.metadata.num_ctas, *binary.metadata.cluster_dims) + if hasattr(binary, "metadata") + else () + ) + ), + "function": get_first_attr(binary, "function", "cu_function"), + "runner": get_first_attr(binary, "run", "c_wrapper"), + "math": math_lib, + "torch": torch_lib, + } + + if not hasattr(binary, "launch_metadata"): + # launch args before CompiledKernel.launch_metadata is added. 
+ # TODO(jansel): delete this branch in mid-2025 + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "num_warps", + "*cta_args", + "shared", + "stream", + "function", + "launch_enter_hook", + "launch_exit_hook", + "metadata", + *call_args, + ] + else: # args after CompiledKernel.launch_metadata: https://github.com/triton-lang/triton/pull/3492 + # Getting the kernel launch args is extremely perf-sensitive. Evaluating + # `bin.launch_metadata` is relatively expensive, and returns None unless a + # `launch_enter_hook` is installed. So if we don't have that hook installed, + # we want to burn None in to the launch args with zero overhead. + # See https://github.com/pytorch/pytorch/issues/123597 + if launch_enter: + launch_metadata = f"bin.launch_metadata((grid_0, grid_1, grid_2), stream, {', '.join(call_args)})" + else: + launch_metadata = "None" + runner_args = [ + "grid_0", + "grid_1", + "grid_2", + "stream", + "function", + "metadata", + launch_metadata, + "launch_enter_hook", + "launch_exit_hook", + *call_args, + ] + + launcher = self._gen_launcher_code(scope, def_args, runner_args) + + launcher = scope["launcher"] + launcher.config = cfg + launcher.n_regs = getattr(binary, "n_regs", None) + launcher.n_spills = getattr(binary, "n_spills", None) + launcher.shared = binary_shared + launcher.cache_hash = triton_hash_to_path_key(binary.hash) + launcher.store_cubin = self.inductor_meta.get("store_cubin", False) + # store this global variable to avoid the high overhead of reading it when calling run + if launcher.store_cubin: + launcher.fn = fn + launcher.bin = binary + if triton_version_uses_attrs_dict(): + # arg filtering wasn't done above + cfg_dict = config_to_dict(cfg) + def_args = [x for x in def_args if x not in cfg_dict] + call_args = [ + x + for x in call_args + if compile_meta["signature"].get(x, "constexpr") != "constexpr" + and x not in none_args + ] + launcher.def_args = def_args + launcher.call_args = call_args + kernel_metadata = getattr(self.kernel, "metadata", None) + launcher.global_scratch = getattr( + kernel_metadata, "global_scratch_size", None + ) + return launcher + + +def _find_names(obj): + import gc + import inspect + + frame = inspect.currentframe() + while frame is not None: + frame.f_locals + frame = frame.f_back + obj_names = [] + for referrer in gc.get_referrers(obj): + if isinstance(referrer, dict): + for k, v in referrer.items(): + if v is obj: + obj_names.append(k) + return obj_names + + +collected_calls: list[Any] = [] + + +def start_graph(): + collected_calls.clear() + + +def end_graph(output_file): + if len(collected_calls) == 0: + return + overall_time = sum(call[0] for call in collected_calls) + overall_gb = sum(call[1] for call in collected_calls) + cur_file = inspect.stack()[1].filename + summary_str = ( + f"SUMMARY ({cur_file})\n" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s" + ) + log.info( + "%s", + summary_str, + ) + if output_file is not None: + # sort perf numbers in descending order, i.e. 
placing the + # most runtime-heavy kernels at the top of the list + sorted_calls = sorted(collected_calls, key=lambda c: float(c[0]), reverse=True) + try: + with open(output_file, "a") as file: + log.info( + "Save profile bandwidth results to %s", + output_file, + ) + file.write("====================\n") + file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") + for ms, num_gb, gb_per_s, kernel_name in sorted_calls: + # also display the runtime percentage for each kernel + percentage = f"{ms / overall_time * 100:.2f}%" + suffix = f" \t {percentage} \t {kernel_name}" + bw_info_str = create_bandwidth_info_str( + ms, + num_gb, + gb_per_s, + suffix=suffix, + color=False, + ) + file.write(bw_info_str + "\n") + file.write(f"{summary_str}\n\n") + except Exception as e: + log.warning( + "failed to write profile bandwidth result into %s: %s", + output_file, + e, + ) + + +class DebugAutotuner(CachingAutotuner): + def __init__( + self, + *args, + regex_filter="", + with_profiler=False, + with_bandwidth_info=True, + **kwargs, + ): + self.regex_filter = regex_filter + self.with_profiler = with_profiler + self.with_bandwidth_info = with_bandwidth_info + super().__init__(*args, **kwargs) + self.cached = None + + def run(self, *args, stream, **kwargs): + if not self.with_bandwidth_info: + super().run(*args, stream=stream, **kwargs, benchmark_run=True) + return + else: + possible_names = _find_names(self) + kernel_name = f"{max(possible_names, key=len)}" + if not re.match(self.regex_filter, kernel_name): + return + + if len(self.launchers) != 1: + if len(self.launchers) == 0: + start_time = time.time_ns() + self.precompile() + self.precompile_time_taken_ns = time.time_ns() - start_time + if len(self.launchers) > 1: + self.autotune_to_one_config(*args, **kwargs) + (launcher,) = self.launchers + + if launcher.store_cubin: + self.save_gpu_kernel(stream, launcher) + + if self.cached is None: + ms = self.bench(launcher, *args, with_profiler=self.with_profiler) + num_in_out_ptrs = len( + [ + arg_name + for arg_name in self.fn.arg_names + if arg_name.startswith("in_out_ptr") + ] + ) + num_gb = self.inductor_meta.get("kernel_num_gb", None) + if num_gb is None: + num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9 + gb_per_s = num_gb / (ms / 1e3) + self.cached = ms, num_gb, gb_per_s, kernel_name + collected_calls.append((ms, num_gb, gb_per_s, kernel_name)) + log.info( + "%s", + create_bandwidth_info_str( + ms, num_gb, gb_per_s, suffix=f" \t {kernel_name}" + ), + ) + else: + # in AOTI, we will call the kernel and its timing info has been cached already + collected_calls.append(self.cached) + + +def hash_configs(configs: list[Config]): + """ + Hash used to check for changes in configurations + """ + hasher = hashlib.sha256() + for cfg in configs: + hasher.update( + f"{sorted(cfg.kwargs.items())} {cfg.num_warps} {cfg.num_stages}\n".encode() + ) + return hasher.hexdigest() + + +def cached_autotune( + size_hints: Optional[list[int]], + configs: list[Config], + triton_meta, + heuristic_type, + filename=None, + inductor_meta=None, + custom_kernel=False, +): + """ + A copy of triton.autotune that calls our subclass. Our subclass + has additional debugging, error handling, and on-disk caching. 
+ """ + configs = unique_configs(configs) + assert len(configs) == 1 or filename + inductor_meta = {} if inductor_meta is None else inductor_meta + + configs, autotune_cache, autotune_cache_info = check_autotune_cache( + configs, filename, inductor_meta + ) + mutated_arg_names = inductor_meta.pop("mutated_arg_names", ()) + optimize_mem = inductor_meta.pop("optimize_mem", True) + + if "restore_value" in triton_meta: + mutated_arg_names += triton_meta.pop("restore_value") + + reset_to_zero_arg_names: list[str] = [] + if "reset_to_zero" in triton_meta: + reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) + + def decorator(fn): + # Remove XBLOCK from config if it's not a function argument. + # This way, coordinate descent tuning will not try to tune it. + # + # Context: When TritonKernel.no_x_dim is True, we hardcode XBLOCK to 1. + import inspect + + if "XBLOCK" not in inspect.signature(fn.fn).parameters: + for tconfig in configs: + if "XBLOCK" in tconfig.kwargs: + assert tconfig.kwargs["XBLOCK"] == 1 + tconfig.kwargs.pop("XBLOCK") + + if inductor_meta.get("profile_bandwidth"): + return DebugAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + regex_filter=inductor_meta["profile_bandwidth_regex"], + with_profiler=inductor_meta[ + "profile_bandwidth_with_do_bench_using_profiling" + ], + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + with_bandwidth_info=True, + ) + return CachingAutotuner( + fn, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + configs=configs, + save_cache_hook=autotune_cache and autotune_cache.save, + mutated_arg_names=mutated_arg_names, + reset_to_zero_arg_names=reset_to_zero_arg_names, + optimize_mem=optimize_mem, + heuristic_type=heuristic_type, + size_hints=size_hints, + custom_kernel=custom_kernel, + filename=filename, + autotune_cache_info=autotune_cache_info, + ) + + return decorator + + +def unique_configs(configs: list[Config]): + """Remove duplicate configurations""" + seen: OrderedSet[Hashable] = OrderedSet() + pruned_configs = [] + + for cfg in configs: + key = triton_config_to_hashable(cfg) + if key not in seen: + seen.add(key) + pruned_configs.append(cfg) + return pruned_configs + + +def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): + for numel, label in zip((xnumel, ynumel, znumel), "XYZ"): + if numel is None: + continue + block = cfg[f"{label}BLOCK"] + if numel == 1: + assert block == 1, ( + f"TritonKernel.indexing assumes numel == 1 => BLOCK == 1" + f" but {label.lower()}numel=={numel} and {label}BLOCK={block} (cfg={cfg})." + ) + max_block = TRITON_MAX_BLOCK[label] + max_block_str = f'config.triton.max_block["{label}"]' + assert max_block % block == 0, ( + f"TritonKernel.indexing assumes {label}BLOCK divides {max_block_str}" + f" but {label}BLOCK={block} and {max_block_str}={max_block} (cfg={cfg})." + ) + + +def check_max_block(cfg: dict[str, int]): + """ + Check that block sizes are within the maximum allowed. + """ + for var, val in cfg.items(): + block_suffix = "BLOCK" + if block_suffix in var: + prefix = var.removesuffix(block_suffix) + max_block = TRITON_MAX_BLOCK[prefix] + assert val <= max_block, ( + f"'{var}' too large. Maximum: {max_block}. Actual: {val}." 
+ ) + + +def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): + # On AMD GPU each warp has 64 lanes which is double the size on NV GPU, + # therefore using half the number of warps here correspondingly. + if torch.version.hip: + max_num_warps = (max_num_warps + 1) // 2 + min_num_warps = (min_num_warps + 1) // 2 + # persistent reduction is register intensive + if register_intensive: + max_num_warps = max_num_warps // 2 + return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps)) + + +def _check_max_grid_x(size_hints, x, num_warps): + # Check if maxGridSize is exceeded - if so then must scale XBLOCK further + max_grid_x = 2147483647 + warp_size = ( + 64 if torch.version.hip else 32 + ) # TODO: query warp size once #129663 is merged + num_blocks = (size_hints["x"] + x - 1) // x + + while (num_blocks * num_warps * warp_size) > max_grid_x and x < size_hints["x"]: + x *= 2 # Scale up XBLOCK if grid exceeds limits + num_blocks = num_blocks // 2 + if (num_blocks * num_warps * warp_size) > max_grid_x: + raise AssertionError( + "Reduction config exceeds cudaDeviceProp maxGridSize. Please raise a pytorch issue" + ) + return x, num_blocks + + +def triton_config( + size_hints, + x, + y=None, + z=None, + num_stages=1, + num_elements_per_warp=256, + min_elem_per_thread=0, +) -> Config: + """ + Construct a pointwise triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + + num_elements_per_warp is a suggestion for controlling how many warps + the triton config should contain. e.g.: if x=16, y=8, z=4 then + num_elements = 16*8*4 = 512. Then if we set num_elements_per_warp=128, + we'll launch 512 (elem) / 128 (elem/warp) = 4 warps. Note that it's + just a suggestion, and sometimes other adjustment heuristics will + override the num_elements_per_warp. + + min_elem_per_thread controls the minimum number of elements + processed by each thread. It's always enforced. + """ + # Ideally we want to read this from some device config + + maxGridSize = [2147483647, 65535, 65535] + + target = conditional_product(x, y, z) + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + if y: + y = min(y, size_hints["y"]) + if z: + z = min(z, size_hints["z"]) + + # if we are below original block size, scale up where we can; + # or if the calculated grid size is larger than the limit, we bump up the corresponding dimension + while x < min(size_hints["x"], TRITON_MAX_BLOCK["X"]) and ( + x * maxGridSize[0] < size_hints["x"] or conditional_product(x, y, z) < target + ): + x *= 2 + while ( + y + and y < min(size_hints["y"], TRITON_MAX_BLOCK["Y"]) + and ( + y * maxGridSize[1] < size_hints["y"] + or conditional_product(x, y, z) < target + ) + ): + y *= 2 + while ( + z + and z < min(size_hints["z"], TRITON_MAX_BLOCK["Z"]) + and ( + z * maxGridSize[2] < size_hints["z"] + or conditional_product(x, y, z) < target + ) + ): + z *= 2 + + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) + # we are going to arrive at 2 warps only if bs was too small due to + # numel being too small. 
However to workaround some ptx bugs we still + # want at least 4 warps if there's enough elements per thread + # given that this is a rare situation, don't expect this to affect perf + # in general + # see https://github.com/pytorch/pytorch/pull/97950 + if conditional_product(x, y, z) >= 128 and not torch.version.hip: + num_warps = max(num_warps, 4) + xnumel = size_hints["x"] + ynumel = size_hints.get("y") + znumel = size_hints.get("z") + + # Increase x to satisfy min_elem_per_thread requirements. + block_size = max( + conditional_product(x, y, z), + min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, + ) + x *= math.ceil(block_size / conditional_product(x, y, z)) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x = min(x, size_hints["x"]) + + cfg = {"XBLOCK": x} + if y: + cfg["YBLOCK"] = y + if z: + cfg["ZBLOCK"] = z + check_max_block(cfg) + check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: + """ + Converts a linear reduction numel to ND, in row major order. + This order is often desirable as it presents opportunities to coalesce memory + accesses. + For example, if r = 64 and size_hints = [32,32], this function returns [32, 2]. + This unraveling works because both r and size_hints are powers of 2. + """ + # Shrink r to size_hints. + r = min(r, get_total_reduction_numel(size_hints)) + num_reduction_dims = len( + [prefix for prefix in size_hints if prefix_is_reduction(prefix)] + ) + + remaining = r + rnumels = {} + for idx in range(num_reduction_dims - 1, -1, -1): + prefix = f"r{idx}_" + max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()]) + dim = min(max_size, remaining) + assert remaining % dim == 0, ( + f"Expected dimension '{dim}' to divide remaining size '{remaining}'" + ) + rnumels[prefix] = dim + remaining //= dim + + # Sanity check the results. + final_numel = conditional_product(*rnumels.values()) + assert r == final_numel, ( + f"Expected ND reduction size ({rnumels}) to have {r} elements." + ) + assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), ( + f"rnumels exceed size_hints. {rnumels} > {size_hints}" + ) + + return rnumels + + +def triton_config_reduction( + size_hints, + x: int, + r: int, + num_stages=1, + num_warps=None, + register_intensive=False, +) -> Config: + """ + Construct a reduction triton config with some adjustment heuristics + based on size_hints. Size_hints is a tuple of numels in each tile + dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. 
+ rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + + def total_numel() -> int: + return conditional_product(x, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + + if num_warps is None: + num_warps = total_numel() // 128 + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + + for prefix in sorted(rnumels): + while total_numel() > target: + if rnumels[prefix] == 1: + break + rnumels[prefix] //= 2 + + cfg = _get_config({"x": x, **rnumels}) + check_max_block(cfg) + check_config(cfg, xnumel=size_hints["x"]) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def _get_config(numels: dict[str, int]) -> dict[str, int]: + """ + Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc. + """ + + return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} + + +def triton_config_tiled_reduction( + size_hints, x, y, r, num_stages=1, register_intensive=False +): + """ + Construct a tile reduction triton config with some adjustment + heuristics based on size_hints. Size_hints is a tuple of numels in + each tile dimension and will be rounded up to the nearest power of 2. + """ + # Convert the linear reduction numel into a multi-dimensional block. + rnumels = _get_nd_reduction_numels(r, size_hints) + + # shrink sizes to size hints + x = min(x, size_hints["x"]) + y = min(y, size_hints["y"]) + + def total_numel() -> int: + return conditional_product(x, y, *rnumels.values()) + + target = total_numel() + if conditional_product(*size_hints.values()) < target: + target //= 8 + + # if we are below original block size, scale up where we can + while x < size_hints["x"] and total_numel() < target: + x *= 2 + for prefix in sorted(rnumels): + while rnumels[prefix] < size_hints[prefix] and total_numel() < target: + rnumels[prefix] *= 2 + while y < size_hints["y"] and total_numel() < target: + y *= 2 + + cfg = _get_config({"x": x, "y": y, **rnumels}) + num_warps = _num_warps(total_numel() // 256, min_num_warps=1) + num_warps = _num_warps( + num_warps, max_num_warps=16, register_intensive=register_intensive + ) + check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) + check_max_block(cfg) + return Config(cfg, num_warps=num_warps, num_stages=num_stages) + + +def pointwise( + size_hints, + triton_meta, + tile_hint=None, + filename=None, + min_elem_per_thread=0, + inductor_meta=None, +): + """ + Construct @triton.heuristics() based on size_hints. 
+ """ + inductor_meta = {} if inductor_meta is None else inductor_meta + assert not inductor_meta.get("no_x_dim") + + numel = functools.reduce(operator.mul, size_hints.values()) + bs = max(256, min(numel // 128, 1024)) + + hinted_configs = autotune_hints_to_configs( + inductor_meta.get("autotune_hints", OrderedSet()), + size_hints, + bs, + triton_meta["device"], + ) + + triton_config_with_settings = functools.partial( + triton_config, min_elem_per_thread=min_elem_per_thread + ) + + configs = None + if len(size_hints) == 1: + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, bs)] + else: + configs = [ + triton_config_with_settings(size_hints, bs, num_elements_per_warp=256), + triton_config_with_settings( + size_hints, bs // 2, num_elements_per_warp=64 + ), + *hinted_configs, + ] + if len(size_hints) == 2: + if ( + disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + ) and not ( + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + ): + configs = [triton_config_with_settings(size_hints, 32, 32)] + else: + configs = [ + triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 + triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings(size_hints, bs, 1), + triton_config_with_settings(size_hints, 1, bs), + *hinted_configs, + ] + if len(size_hints) == 3: + if disable_pointwise_autotuning(inductor_meta): + configs = [triton_config_with_settings(size_hints, 16, 16, 16)] + else: + configs = [ + triton_config_with_settings(size_hints, 16, 16, 16), + triton_config_with_settings(size_hints, 64, 8, 8), + triton_config_with_settings(size_hints, 8, 64, 8), + triton_config_with_settings(size_hints, 8, 8, 64), + triton_config_with_settings(size_hints, bs, 1, 1), + triton_config_with_settings(size_hints, 1, bs, 1), + triton_config_with_settings(size_hints, 1, 1, bs), + *hinted_configs, + ] + + if not configs: + raise NotImplementedError(f"size_hints: {size_hints}") + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.POINTWISE, + filename=filename, + ) + + +def _reduction_configs( + *, size_hints: dict[str, int], inductor_meta: dict[str, Any] +) -> list[Config]: + reduction_hint = inductor_meta.get("reduction_hint", None) + + # Convert reductions to 1D, to simplify heuristics. + rnumel = get_total_reduction_numel(size_hints) + + register_intensive = False + MAX_R0_BLOCK = 2048 + if ( + size_hints["x"] >= 1024 + and inductor_meta.get("num_load", 0) + inductor_meta.get("num_reduction", 0) + >= 10 + ): + # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. + # Consider load and reduction since load need move data into registers and + # reduction needs an accumulator. + # + # The magic numbers are a bit arbitrary. + # + # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes + # triton makes it to use less registers with worse perf. Check: + # https://github.com/pytorch/pytorch/issues/126463 + # + # The heuristic is a very simple one since registers can be reused. But + # hopefully it can be a good enough indicator. 
+ MAX_R0_BLOCK = 1024 + register_intensive = True + + def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False): + # For 3D case with tiling scores, create an adapted version + if "y" in size_hints: + assert "tiling_scores" in inductor_meta + return adapt_config_for_tiling( + size_hints, + inductor_meta["tiling_scores"], + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + else: + # For other cases, use the original function + return triton_config_reduction( + size_hints, + x, + r, + num_warps=num_warps, + num_stages=num_stages, + register_intensive=register_intensive, + ) + + contiguous_config = make_config( + 1, + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + outer_config = make_config(64, 8, register_intensive=register_intensive) + tiny_config = make_config( + 2 * (256 // rnumel) if rnumel <= 256 else 1, + min(rnumel, MAX_R0_BLOCK), + register_intensive=register_intensive, + ) + # For 3d tiling, default to more autotuning initially + if "y" in size_hints: + pass + elif inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ): + pass # skip all these cases + elif reduction_hint == ReductionHint.INNER: + return [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + return [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + return [tiny_config] + if disable_pointwise_autotuning(inductor_meta): + return [make_config(32, 128)] + return [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + + +def match_target_block_product( + size_hints, tiling_scores, target_block_product, min_block_size=1 +): + """ + Distribute block sizes across dimensions according to tiling scores, + aiming to match a target product of block sizes. 
+ """ + total_score = sum(tiling_scores.values()) + if total_score == 0: + # just assume even score with no minimum block size + min_block_size = 1 + tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product) + + # First, give each coalescing dimension at least min_block_size + block_sizes = {} + relative_scores = {} + curr_block_product = 1 + + for dim, score in tiling_scores.items(): + if score == 0: + block_sizes[dim] = 1 + continue + + block_sizes[dim] = min_block_size + curr_block_product *= min_block_size + relative_scores[dim] = score / total_score + + # Scale up dimensions by their relative scores until we reach the target + while curr_block_product < target_block_product and len(relative_scores): + dim, score = max(relative_scores.items(), key=lambda item: item[1]) + + # Check if we've hit the max for this dimension + if ( + block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()] + or block_sizes[dim] >= size_hints[dim] + ): + del relative_scores[dim] + continue + + block_sizes[dim] *= 2 + relative_scores[dim] /= 2 + curr_block_product *= 2 + + return block_sizes + + +def adapt_config_for_tiling( + size_hints, + tiling_scores, + original_x, + original_r, + num_warps=None, + num_stages=1, + register_intensive=False, + persistent_reduction=False, +) -> Config: + """ + Create an adapted configuration based on tiling scores, + redistributing the same total block size (x * r) according to tiling scores. + """ + assert all(s in tiling_scores for s in size_hints) + target_block_product = original_x * original_r + block_sizes = match_target_block_product( + size_hints, tiling_scores, target_block_product + ) + + return triton_config_tiled_reduction( + size_hints, + block_sizes["x"], + block_sizes["y"], + block_sizes["r0_"], + num_stages=num_stages, + register_intensive=register_intensive, + ) + + +def reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """args to @triton.heuristics()""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + assert triton_meta is not None + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def cooperative_reduction( + size_hints, + reduction_hint, + triton_meta, + filename, + inductor_meta, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + # Cooperative reductions currently only support a single reduction dimension. 
+ assert len(size_hints) == 2, ( + "Cooperative reductions don't support tiling reduction dims" + ) + xnumel, rnumel = size_hints["x"], size_hints["r0_"] + + # TODO(jansel): we should base target on the SM count of the local GPU + target = 64 + split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT)) + assert rnumel >= split + assert split <= TRITON_MAX_RSPLIT + if inductor_meta["persistent_reduction"]: + configs = _persistent_reduction_configs( + {"x": xnumel, "r0_": rnumel // split}, reduction_hint, inductor_meta + ) + else: + configs = _reduction_configs( + size_hints={"x": xnumel, "r0_": rnumel // split}, + inductor_meta=inductor_meta, + ) + for config in configs: + config.kwargs["RSPLIT"] = split + # TODO(jansel): add more configs in max_autotune + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.REDUCTION, + filename=filename, + ) + + +def _persistent_reduction_configs( + size_hints, + reduction_hint=False, + inductor_meta=None, +): + xnumel = size_hints["x"] + rnumel = get_total_reduction_numel(size_hints) + + MAX_PERSISTENT_BLOCK_NUMEL = 4096 + + if "y" not in size_hints: + configs = [ + triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) + for xblock in (1, 8, 32, 128) + if xblock == 1 + or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) + ] + else: + configs = [] + assert "tiling_scores" in inductor_meta + x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} + for target_block_size in (1, 8, 32, 64, 128): + if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: + continue + + block_sizes = match_target_block_product( + size_hints, x_y_scores, target_block_size + ) + configs.append( + triton_config_tiled_reduction( + size_hints, block_sizes["x"], block_sizes["y"], rnumel + ) + ) + + # defer to more autotuning, initially + if "y" in size_hints: + pass + # TODO(jansel): we should be able to improve these heuristics + elif reduction_hint == ReductionHint.INNER and rnumel >= 256: + configs = configs[:1] + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + for c in configs: + # we don't need Rn_BLOCK for persistent reduction + for prefix in size_hints: + if prefix_is_reduction(prefix): + c.kwargs.pop(f"{prefix.upper()}BLOCK") + + if disable_pointwise_autotuning(inductor_meta): + configs = configs[:1] + + return configs + + +def persistent_reduction( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + configs = _persistent_reduction_configs(size_hints, reduction_hint, inductor_meta) + + return cached_autotune( + size_hints, + configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + filename=filename, + heuristic_type=HeuristicType.PERSISTENT_REDUCTION, + ) + + +def split_scan( + size_hints, + reduction_hint=False, + triton_meta=None, + filename=None, + inductor_meta=None, +): + """Heuristic for TritonSplitScanKernel""" + inductor_meta = {} if inductor_meta is None else inductor_meta + inductor_meta["reduction_hint"] = reduction_hint + if inductor_meta.get("no_x_dim"): + size_hints["x"] = 1 + + 
assert triton_meta is not None + if len(size_hints) != 2: + raise NotImplementedError(f"size_hints: {size_hints}") + + configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta) + + # Fixup configs to enforce the minimum Rn_BLOCK size + min_rblock = inductor_meta.get("min_split_scan_rblock", 256) + for cfg in configs: + for var in list(cfg.kwargs.keys()): + if var.startswith("R") and cfg.kwargs[var] < min_rblock: + cfg.kwargs[var] = min_rblock + + return cached_autotune( + size_hints, + configs=configs, + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.SPLIT_SCAN, + filename=filename, + ) + + +def template( + num_stages, + num_warps, + triton_meta, + num_consumer_groups=0, + num_buffers_warp_spec=0, + filename=None, + inductor_meta=None, +): + """ + Compile a triton template + """ + # Prepare the base configuration + config_args = { + "num_stages": num_stages, + "num_warps": num_warps, + } + + # Conditionally add arguments based on HAS_WARP_SPEC + if HAS_WARP_SPEC: + config_args.update( + { + "num_consumer_groups": num_consumer_groups, + "num_buffers_warp_spec": num_buffers_warp_spec, + } + ) + return cached_autotune( + None, + [triton.Config({}, **config_args)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]: + """Extract triton.Config options that should become kwargs""" + popped = {} + for key in ( + "num_warps", + "num_stages", + "num_ctas", + "maxnreg", + "num_consumer_groups", + "num_buffers_warp_spec", + ): + val = config.pop(key, None) + if val is not None: + popped[key] = val + return popped + + +def config_to_dict(config: Config) -> dict[str, Any]: + config_dict = { + **config.kwargs, + "num_warps": config.num_warps, + "num_stages": config.num_stages, + } + if HAS_WARP_SPEC: + config_dict.update( + { + "num_consumer_groups": getattr(config, "num_consumer_groups", 0), + "num_buffers_warp_spec": getattr(config, "num_buffers_warp_spec", 0), + } + ) + return config_dict + + +def config_from_dict(config: dict[str, Any]) -> Config: + config = {**config} + return Config(config, **_pop_config_kwargs(config)) + + +def fixed_config(config, filename, triton_meta, inductor_meta): + """ + Used when the configuration is already decided at compile time + """ + config = {**config} + return cached_autotune( + None, + [triton.Config(config, **_pop_config_kwargs(config))], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.FIXED, + filename=filename, + ) + + +def user_autotune( + configs, triton_meta, filename=None, inductor_meta=None, custom_kernel=False +): + """ + Compile a user defined triton kernel + """ + if len(configs) == 0: + configs = [triton.Config({})] + else: + configs = [*map(config_from_dict, configs)] + return cached_autotune( + None, + configs, + triton_meta=triton_meta, + heuristic_type=HeuristicType.USER_AUTOTUNE, + filename=filename, + inductor_meta=inductor_meta, + custom_kernel=custom_kernel, + ) + + +def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): + """ + Compile a triton foreach kernel + """ + return cached_autotune( + None, + [triton.Config({}, num_stages=1, num_warps=num_warps)], + triton_meta=triton_meta, + inductor_meta=inductor_meta, + heuristic_type=HeuristicType.TEMPLATE, + filename=filename, + ) + + +@dataclasses.dataclass +class GridExpr: + """Generate code for grid size expressions in launcher""" + + 
inductor_meta: dict[str, Any] + mode: Literal["python", "cpp"] = "python" + prefix: list[str] = dataclasses.field(default_factory=list) + x_grid: Union[str, int] = 1 + y_grid: Union[str, int] = 1 + z_grid: Union[str, int] = 1 + + def __post_init__(self) -> None: + assert self.mode in ("python", "cpp") + + def generate(self, meta: dict[str, int]) -> None: + raise NotImplementedError + + def ceildiv( + self, numel: Union[str, int], block: Union[None, int, str] + ) -> Union[str, int]: + if block is None or block == 1: + return numel + if isinstance(numel, int) and isinstance(block, int): + return ceildiv(numel, block) # constant fold + if self.mode == "python": + return f"-(({numel}) // -({block}))" + # trick above doesn't work in C++ due to rounding differences + return f"(({numel} + ({block} - 1)) / ({block}))" + + def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]: + """Codegen for max function with constant folding, constants are represented as int""" + items = self._constant_fold(max, seq) + if len(items) <= 1: + return items[0] + if self.mode == "python": + return f"max({', '.join(map(str, items))})" + return functools.reduce(lambda x, y: f"std::max({x}, {y})", items) + + def summation(self, seq: list[Union[int, str]]) -> Union[int, str]: + """Codegen for sum function with constant folding, constants are represented as int""" + items = self._constant_fold(sum, seq) + if len(items) <= 1: + return items[0] + return " + ".join(map(str, items)) + + def _constant_fold( + self, fn: Callable[[list[int]], int], seq: list[Union[int, str]] + ) -> list[Union[int, str]]: + """Constant fold through a commutative fn where ints are constants""" + items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)] + const_items = [x for x in seq if isinstance(x, int)] + if const_items: + items.append(fn(const_items)) + return items + + def assign_tmp(self, name: str, expr: Union[str, int]) -> str: + # Grid functions are one per kernel, so name collisions are fine + if self.mode == "python": + return f"{name} = {expr}" + if self.mode == "cpp": + return f"uint32_t {name} = {expr};" + raise AssertionError(f"invalid mode {self.mode}") + + @staticmethod + def from_meta( + inductor_meta: dict[str, Any], + cfg: Union[Config, dict[str, int]], + mode: Literal["python", "cpp"] = "python", + ) -> GridExpr: + grid_cls = globals()[inductor_meta["grid_type"]] + assert issubclass(grid_cls, GridExpr) + grid = grid_cls(inductor_meta=inductor_meta, mode=mode) + if isinstance(cfg, Config): + cfg = config_to_dict(cfg) + grid.generate(cfg) + return grid + + def eval_slow(self, meta: dict[str, int]) -> tuple[int, int, int]: + scope = {**meta} + for line in self.prefix: + exec(line, scope) + exec(f"grid_0 = {self.x_grid}", scope) + exec(f"grid_1 = {self.y_grid}", scope) + exec(f"grid_2 = {self.z_grid}", scope) + return scope["grid_0"], scope["grid_1"], scope["grid_2"] + + +class Grid1D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class Grid2D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + + +class Grid3D(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.y_grid = self.ceildiv("ynumel", meta.get("YBLOCK")) + self.z_grid = self.ceildiv("znumel", meta.get("ZBLOCK")) + + +class Grid2DWithYZOverflow(GridExpr): + def 
generate(self, meta: dict[str, int]) -> None: + self.x_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + self.prefix.extend( + [ + self.assign_tmp( + "y_grid_raw_", self.ceildiv("ynumel", meta.get("YBLOCK")) + ), + self.assign_tmp( + "y_grid_div_", self.ceildiv("y_grid_raw_", get_max_y_grid()) + ), + ] + ) + self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") + self.z_grid = "y_grid_div_" + + +class CooperativeReductionGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + self.x_grid = str(meta["RSPLIT"]) + self.y_grid = self.ceildiv("xnumel", meta.get("XBLOCK")) + + +class SplitScanGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + assert meta.get("XBLOCK", 1) == 1 + self.x_grid = self.ceildiv("r0_numel", meta.get("R0_BLOCK")) + self.y_grid = "xnumel" + + +class FixedGrid(GridExpr): + @staticmethod + def setup_grid_as_args() -> dict[str, Any]: + """Inductor meta so the launcher takes three extra grid arguments""" + return { + "grid_type": FixedGrid.__name__, + "fixed_grid": ["_grid_0", "_grid_1", "_grid_2"], + "extra_launcher_args": ["_grid_0", "_grid_1", "_grid_2"], + } + + def generate(self, meta: dict[str, int]) -> None: + self.x_grid, self.y_grid, self.z_grid = self.inductor_meta["fixed_grid"] + + +class PrecomputedGrid(GridExpr): + def generate(self, meta: dict[str, int]) -> None: + for candidate in self.inductor_meta["precomputed_grids"]: + if all(meta.get(k) == v for k, v in candidate["config"].items()): + self.x_grid, self.y_grid, self.z_grid = candidate[self.mode] + return + raise AssertionError( + f"Precomputed grid not found for {meta} in {self.inductor_meta['precomputed_grids']}" + ) + + +class ComboKernelGrid(GridExpr): + def generate(self, meta: dict[str, int]): + combo_meta = self.inductor_meta["combo_grid_meta"] + if combo_meta["default_config"]: + meta = {**combo_meta["default_config"], **meta} + no_x_dims = [] + xnumels = [] + ynumels = [] + znumels = [] + for num in range(combo_meta["num_kernels"]): + assert ( + combo_meta[f"xnumel_{num}"] is None or combo_meta[f"xnumel_{num}"] > 0 + ) + no_x_dims.append(combo_meta[f"no_x_dim_{num}"]) + xnumels.append(combo_meta[f"xnumel_{num}"] or f"xnumel_{num}") + if f"ynumel_{num}" in combo_meta: + ynumels.append(combo_meta[f"ynumel_{num}"] or f"ynumel_{num}") + if f"znumel_{num}" in combo_meta: + znumels.append(combo_meta[f"znumel_{num}"] or f"znumel_{num}") + + self.x_grid = self.combo_x_grid(xnumels, no_x_dims, meta) + if combo_meta["min_blocks"]: + self.x_grid = self.maximum([self.x_grid, combo_meta["min_blocks"]]) + if ynumels: + self.y_grid = self.ceildiv(self.maximum(ynumels), meta.get("YBLOCK")) + if znumels: + self.z_grid = self.ceildiv(self.maximum(znumels), meta.get("ZBLOCK")) + + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> Union[str, int]: + raise NotImplementedError + + +class SequentialComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> Union[str, int]: + assert len(xnumels) == len(no_x_dims) + return self.summation( + [ + self.ceildiv(x, 1 if no_x_dim else meta.get("XBLOCK")) + for x, no_x_dim in zip(xnumels, no_x_dims) + ] + ) + + +class RoundRobinComboKernelGrid(ComboKernelGrid): + def combo_x_grid( + self, + xnumels: list[Union[int, str]], + no_x_dims: list[bool], + meta: dict[str, int], + ) -> str: + assert len(xnumels) == len(no_x_dims) + num_kernels = 
self.inductor_meta["combo_grid_meta"]["num_kernels"]
+        exprs = [x for x, no_x_dim in zip(xnumels, no_x_dims) if no_x_dim]
+        xnumels_x_dim = [x for x, no_x_dim in zip(xnumels, no_x_dims) if not no_x_dim]
+        if xnumels_x_dim:
+            exprs.append(self.ceildiv(self.maximum(xnumels_x_dim), meta.get("XBLOCK")))
+        return f"({self.maximum(exprs)}) * {num_kernels}"
diff --git a/phivenv/Lib/site-packages/torch/_lazy/__init__.py b/phivenv/Lib/site-packages/torch/_lazy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a3d1de2dfa09c3a1fd1cdad5465ff8edad09c29
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/_lazy/__init__.py
@@ -0,0 +1,55 @@
+# mypy: allow-untyped-defs
+
+import torch._C._lazy
+from torch.utils._pytree import tree_flatten, tree_unflatten
+
+from .closure import add_step_closure, run_step_closures
+
+
+def mark_step(device: str = "", wait=False):
+    """Triggers a mark step, which amounts to
+    - collecting a group of 'live' lazy tensors to index into the compilation cache
+      (lowering/compiling their IR graphs if not cached)
+    - kicking off execution of the compiled function
+    - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
+    """
+    # TODO(whc) expand this to include backend hooks and align with XLA backend needs
+    torch._C._lazy._mark_step(device, [], wait=wait)
+
+    run_step_closures()
+
+
+def wait_device_ops(devices=None):
+    """Waits for all the async operations on the given devices to complete.
+    Args:
+      devices (string..., optional): The devices whose async ops need to be waited
+        for. If empty, all the local devices will be waited for.
+    """
+    if devices is None:
+        devices = []
+    torch._C._lazy._wait_device_ops(devices=devices)
+
+
+def sync_multi(tensors, devices):
+    """
+    Sync the list of lazy tensors so their IR gets lowered for the active backend
+    and the compiled computation graph gets cached.
+ """ + torch._C._lazy._sync_multi(tensors, devices) + + +def get_tensor_id(tensor): + """Return a unique id of the lazy tensor maintained by LTC""" + return torch._C._lazy._get_tensor_id(tensor) + + +def to_cpu(tensors, devices=None): + devices = devices or ["lazy"] + + flattened, spec = tree_flatten(tensors) + sync_multi(flattened, devices) + return tree_unflatten([t.to("cpu") for t in flattened], spec) + + +def save(tensors, *args, **kwargs): + torch.save(to_cpu(tensors), *args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6424c9e5d426af3ebde544b0fe1274368adf7c6f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e896d9c3a1cf9196a72234ffc56f23637e5c0768 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/closure.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cae3b2e9de6c3606457a057388958cf7f55ef79 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/computation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aa987f51debe9dd5285e9ff85b4988a115a8f47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9493adaefec691d61fc524361a1c945ab0908460 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/debug.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..558a6901456bef2fd98a6fce907384d90a0a8664 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/device_context.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1680e4f2269cf9f31b9bfaa15b4aad1dc6293bb4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/extract_compiled_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0756a338598c146e4d0bc26c08d4a26c4b085d71 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ir_cache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2de5b24d14cb1cebac83e94420bd8cb99b1e4a1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/metrics.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d850fd7d10979caa65cee572c17d2f87113a035 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/tensor_factory_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c551e00592c77cea9ab254e2cda408940dec8a9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_lazy/__pycache__/ts_backend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_lazy/closure.py b/phivenv/Lib/site-packages/torch/_lazy/closure.py new file mode 100644 index 0000000000000000000000000000000000000000..a51b9804c4d25fa741c531aeac6acc84c3a10015 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/closure.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +import os +import threading +from queue import Empty as EmptyQueue, Queue + +from torch._lazy.device_context import get_device_context + + +class ClosureHandler: + def __init__(self) -> None: + pass + + def run(self, closure): + """Run closure function + + Args: + closure: callable function to run + """ + closure() + + def __call__(self, closures): + for closure in closures: + self.run(closure) + + +class AsyncClosureHandler(ClosureHandler): + """Handler for Asynchronous Step Closures + Args: + max_queue_size: The maximum length of the closure queue after which + the training loop will block until closures are evaluated. + By default, a reasonable limit of a maximum of 100 on the queue. + This value can be set using the `XLA_MAX_ASYNC_QUEUE` environment + variable. 
+ """ + + def __init__(self, max_queue_size=100): + super().__init__() + self._closure_queue: Queue = Queue( + int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size)) + ) + self._closure_exception: Queue = Queue() + self._closure_lock = threading.Lock() + self._closure_event_loop_finished = threading.Event() + self._closure_event_loop = None + + def start_event_loop(self): + """Start closure event loop if not started""" + if self._closure_event_loop is None: + + def event_loop(): + # Run loop until closure event is set and closure queue is empty + while True: + try: + closure = self._closure_queue.get(block=True, timeout=3) + closure() + self._closure_queue.task_done() + except EmptyQueue: + with self._closure_lock: + if self._closure_queue.empty(): + self._closure_event_loop_finished.set() + return + except Exception as e: + self._closure_exception.put(e) + return + + self._closure_event_loop = threading.Thread(target=event_loop) + self._closure_event_loop.start() + + def run(self, closure): + with self._closure_lock: + self._closure_queue.put(closure, block=True) + if ( + self._closure_event_loop is None + or not self._closure_event_loop.is_alive() + ): + try: + e = self._closure_exception.get(block=False) + raise RuntimeError( + "Cannot run asynchronous closure due to previously raised exception" + ) from e + except EmptyQueue: + self._closure_event_loop = None + self.start_event_loop() + + +def add_step_closure(closure, args=(), run_async=False): + """Adds a closure to the list of the ones to be run at the end of the step. + Many times during model training there is the need to print/report (print to + console, post to tensorboard, etc...) information which require the content of + intermediary tensors to be inspected. + Inspecting different tensors content in different points of the model code + requires many executions and typically causes performance issues. + Adding a step closure will ensure that it will be run after the barrier, when + all the live tensors will be already materialized to device data. + Live tensors which will include the ones captured by the closure arguments. + So using `add_step_closure()` will ensure a single execution will be + performed, even when multiple closures are queued, requiring multiple tensors + to be inspected. + Step closures will be run sequentially in the order they have been queued. + Note that even though using this API the execution will be optimized, it is + advised to throttle the printing/reporting events once every N steps. + Args: + closure (callable): The function to be called. + args (tuple): The arguments to be passed to the closure. + run_async: If True, run the closure asynchronously. 
+ """ + devctx = get_device_context() + closures_type = "async_step_closures" if run_async else "step_closures" + step_closures = getattr(devctx, closures_type, None) + if step_closures is None: + step_closures = [] + setattr(devctx, closures_type, step_closures) + step_closures.append(lambda a=args: closure(*a)) + + +def run_step_closures(): + devctx = get_device_context() + async_step_closures = getattr(devctx, "async_step_closures", None) + if async_step_closures is not None: + devctx.async_step_closures = [] # type: ignore[attr-defined] + async_closure_handler = getattr(devctx, "async_closure_handler", None) + if async_closure_handler is None: + async_closure_handler = AsyncClosureHandler() + devctx.async_closure_handler = async_closure_handler # type: ignore[attr-defined] + async_closure_handler(async_step_closures) + + step_closures = getattr(devctx, "step_closures", None) + if step_closures is not None: + devctx.step_closures = [] # type: ignore[attr-defined] + closure_handler = getattr(devctx, "closure_handler", None) + if closure_handler is None: + closure_handler = ClosureHandler() + devctx.closure_handler = closure_handler # type: ignore[attr-defined] + closure_handler(step_closures) + return devctx diff --git a/phivenv/Lib/site-packages/torch/_lazy/computation.py b/phivenv/Lib/site-packages/torch/_lazy/computation.py new file mode 100644 index 0000000000000000000000000000000000000000..b4d6bd4989c6b8c29d6cf88130f3615a4160101e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/computation.py @@ -0,0 +1,27 @@ +# mypy: allow-untyped-defs +import torch._C._lazy +import torch._C._lazy_ts_backend + + +def get_tensors_ts_device_data_node(tensors): + """Return tensor ids and eager tensors for DeviceData nodes in the + IR for the passed in lazy tensors. + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. + """ + return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors) + + +def get_graph_hash(tensors): + """Return the graph hash for the passed in lazy tensors""" + return torch._C._lazy._get_graph_hash(tensors) + + +def run_cached_graph(hash_str, graph_inputs): + """Running the cached computation graph with the given inputs + + TODO: This API is currently ts backend specific. We are working on + generalizing it to all backends including XLA. 
+ """ + return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs) diff --git a/phivenv/Lib/site-packages/torch/_lazy/config.py b/phivenv/Lib/site-packages/torch/_lazy/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e5ea4e578509979fb07d6dc1fa66bf9c20df9947 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/config.py @@ -0,0 +1,16 @@ +import torch._C._lazy + + +def get_force_fallback() -> str: + """Get the config used to force LTC fallback""" + return torch._C._lazy._get_force_fallback() + + +def set_force_fallback(configval: str) -> None: + """Set the config used to force LTC fallback""" + torch._C._lazy._set_force_fallback(configval) + + +def set_reuse_ir(val: bool) -> None: + """Set the config to reuse IR nodes for faster tracing""" + torch._C._lazy._set_reuse_ir(val) diff --git a/phivenv/Lib/site-packages/torch/_lazy/debug.py b/phivenv/Lib/site-packages/torch/_lazy/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..5ae8ff0919cf0afdb338cf839d39b6f62c04b853 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/debug.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def render_ir_graph(tensors): + """Return a text dump of the LTC IR graph in dot format for the tensors. + The text can be processed by tools like dot to be rendered in pdf,png etc.""" + return torch._C._lazy._get_tensors_dot(tensors) + + +def dump_ir(tensors, ir_format): + """Return a dump of the tensors in the specified format. + Valid format are + - text: for LTC IR + - backend: for the activate backend IR + """ + if ir_format == "text": + return torch._C._lazy._get_tensors_text(tensors) + elif ir_format == "backend": + return torch._C._lazy._get_tensors_backend(tensors) + else: + raise RuntimeError(f"Unrecognized IR format: {ir_format}") diff --git a/phivenv/Lib/site-packages/torch/_lazy/device_context.py b/phivenv/Lib/site-packages/torch/_lazy/device_context.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6a66823d37b5fd51ba5ab81dc120d18814e6e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/device_context.py @@ -0,0 +1,25 @@ +import threading +from typing import Any, Optional + +import torch._C._lazy + + +class DeviceContext: + _CONTEXTS: dict[str, Any] = {} + _CONTEXTS_LOCK = threading.Lock() + + def __init__(self, device: str) -> None: + self.device = device + + +def get_device_context(device: Optional[str] = None) -> DeviceContext: + if device is None: + device = torch._C._lazy._get_default_device_type() + else: + device = str(device) + with DeviceContext._CONTEXTS_LOCK: + devctx = DeviceContext._CONTEXTS.get(device, None) + if devctx is None: + devctx = DeviceContext(device) + DeviceContext._CONTEXTS[device] = devctx + return devctx diff --git a/phivenv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py b/phivenv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..bb80250351d8fe1f029e1da3dd86db0d5c83981c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/extract_compiled_graph.py @@ -0,0 +1,225 @@ +# mypy: allow-untyped-defs +import copy +import dataclasses +import itertools +import os +from typing import Any, Callable + +import torch +import torch._lazy as lazy +import torch._lazy.metrics as metrics +from torch import fx +from torch._lazy import computation, debug as lazy_debug +from torch._lazy.tensor_factory_functions import tensor_factory_functions + + +debug = 
os.environ.get("debug_extract_compiled_graph") is not None + + +@dataclasses.dataclass +class GraphInputMatcher: + """ + The GraphInputMatcher class setup the graph inputs for future calls after lazy tracing. + Specifically, those graph inputs corresponding to method parameters should be replaced with the + arguments for the current call. + + tensor_id_to_arg_idx maps the tensor id to the parameter index. + graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + TS/XLA graph inputs. + """ + + tensor_id_to_arg_idx: dict[int, int] + graph_input_tensor_ids: list[int] + # there are 2 categories of graph_input_tensors. + # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are + # most likely const tensors and we can get its content from graph_input_tensors + # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get + # the tensor from method arguments + graph_input_ivalues: list[Any] + + # get the real graph input tensors + def __call__(self, args): + real_input = [] + for tensor_id, traced_ivalue in zip( + self.graph_input_tensor_ids, self.graph_input_ivalues + ): + arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) + if arg_idx is None: + inp = traced_ivalue + else: + inp = args[arg_idx] + real_input.append(inp) + return real_input + + +class ReturnValueHandler: + r""" + When ltc_sync_multi is called on multi tensors, the compiled graph + will contain output only for unique tensors - if a tensor appears multiple + times in the input to _ltc_sync_multi, only the first occurance matters. + + However from python level, we still expect multi tensors returned with duplciation + even if the TS graph dedup the output. e.g. for method: + + def forward(self, a): + return a, a + + the TS graph captured by LTC will return a single tensor, but Python method expects 2. + + This class dedup the lazy tensors first to get the index that will be used + to duplicate the eager tensors later. + """ + + def __init__(self, lazy_out_list): + self.index: list[list[int]] = [] + self.total_count = len(lazy_out_list) + + tensor_id_to_idx: dict[int, int] = {} + for dup_idx, lazy_tensor in enumerate(lazy_out_list): + uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None) + if uniq_idx is not None: + self.index[uniq_idx].append(dup_idx) + else: + uniq_idx = len(self.index) + self.index.append([dup_idx]) + tensor_id_to_idx[id(lazy_tensor)] = uniq_idx + + def duplicate_eager_tensors(self, eager_tensor_list): + duplicated_list = [None] * self.total_count + assert len(eager_tensor_list) == len(self.index) + + for uniq_idx, eager_tensor in enumerate(eager_tensor_list): + for dup_idx in self.index[uniq_idx]: + duplicated_list[dup_idx] = eager_tensor + return duplicated_list + + +def force_lazy_device(model: fx.GraphModule): + """ + Factory methods in a Fx graph may create tensors for a specific eager devices. + If we take no actions, those eager tensors will be mixed with lazy tensors and + cause crash. This method overwrite those eager device to lazy device. 
+ """ + + def tolazydevice(dev): + if isinstance(dev, torch.device): + return torch.device("lazy", index=dev.index) + return dev + + def hasDeviceArg(args, kwargs): + return any( + isinstance(arg, torch.device) + for arg in itertools.chain(args, kwargs.values()) + ) + + for nd in model.graph.nodes: + nd.args = tuple(tolazydevice(arg) for arg in nd.args) + nd.kwargs = {k: tolazydevice(v) for k, v in nd.kwargs.items()} + + # For torchbench like yolov3, hf_Bart, dynamo generates Fx graph that return + # eager tensors on the default device + # (check https://gist.github.com/shunting314/eabdf6c769c59bc384469717b8f9bb7f for yolove, + # and https://gist.github.com/shunting314/8d5e2d9348a3258959d3954186c48814 for hf_Bart). + # To force those tensors on the lazy device, we can not simply override + # the device argument since there is no explicit device argument. + # What we are doing here is, for the list of covered tensor factory methods + # we add a lazy device argument explicity. + # + # TODO: This solution is no ideal since we may miss some factory methods. In future + # when we support lazy mode, this method can be replaced by that. + if nd.target in tensor_factory_functions and not hasDeviceArg( + nd.args, nd.kwargs + ): + kwargs = dict(nd.kwargs) # nd.kwargs is immutable. make a mutable copy. + kwargs["device"] = torch.device("lazy") + nd.kwargs = kwargs + + model.recompile() + + +def get_fallback_ops(): + fallback_ops = [] + for opname in metrics.counter_names(): + if "aten::" not in opname: + continue + val = int(metrics.counter_value(opname)) + if val > 0: + fallback_ops.append(f"{opname}={val}") + + return fallback_ops + + +def extract_compiled_graph(model: fx.GraphModule, example_inputs) -> Callable: + """ + Optimize an eager model with LTC and returns a wrapper to execute the + compiled graph directly without retracing. It depends on other mechanisms + like TorchDynamo guards to guarantee the returned wrapper is only called + when it's safe. 
+ """ + lazy_args = [arg.to(device="lazy") for arg in example_inputs] + args_tensor_ids = [lazy.get_tensor_id(lazy_arg) for lazy_arg in lazy_args] + tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + lazy_model = copy.deepcopy(model).to(device=torch.device("lazy")) + force_lazy_device(lazy_model) + + # This line executes lazy tracing and enable us extracting compiled graph later + metrics.reset() + lazy_out = lazy_model(*lazy_args) + fallback_ops = get_fallback_ops() + metrics.reset() + + if len(fallback_ops) > 0: + raise RuntimeError( + f"Fail to extact the compiled graph because of fallback: {','.join(fallback_ops)}" + ) + + if not isinstance(lazy_out, (tuple, list)): + lazy_out = (lazy_out,) + + args_and_out = tuple(lazy_args) + tuple(lazy_out) + return_value_handler = ReturnValueHandler(args_and_out) + if debug: + print("Fx code:\n", model.code) + print("LTC IR:", lazy_debug.dump_ir(args_and_out, "text")) + + # TODO: this part is TS backend specific for now and will be generalized to + # support XLA + ( + graph_input_tensor_ids, + graph_input_ivalues, + ) = computation.get_tensors_ts_device_data_node(args_and_out) + assert len(graph_input_tensor_ids) == len(graph_input_ivalues) + graph_input_matcher = GraphInputMatcher( + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + ) + + graph_hash = computation.get_graph_hash(args_and_out) + + if debug: + print("graph_hash", graph_hash) + print(f"args_tensor_ids {args_tensor_ids}") + print("tensor ids from device data:", graph_input_tensor_ids) + + # sync the list of output tensors so the computation graph for these + # tensors will be cached. Those computation graphs can be retrieved + # by graph hash later. + lazy.sync_multi(args_and_out, []) + + def optimized_mod(*args): + if len(args_and_out) == 0: + return () + graph_input = graph_input_matcher(args) + res = return_value_handler.duplicate_eager_tensors( + computation.run_cached_graph(graph_hash, graph_input) + ) + + assert len(res) == len(args_and_out) + for i, arg in enumerate(args): + # only copy those tensors that get inplace updated + if arg is not res[i]: + arg.copy_(res[i]) + + # skip the args + return res[len(args) :] + + return optimized_mod diff --git a/phivenv/Lib/site-packages/torch/_lazy/ir_cache.py b/phivenv/Lib/site-packages/torch/_lazy/ir_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0c4037405bba1bf64c2e41c5dc5ea899eb9281 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/ir_cache.py @@ -0,0 +1,14 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def dump(dot_file_name: str): + """Dump TrieCache in the dot format""" + return torch._C._lazy._dump_ir_cache(dot_file_name) + + +def reset(): + """Clear TrieCache. This is needed in testing to avoid + node reusing between different tests. 
+ """ + return torch._C._lazy._clear_ir_cache() diff --git a/phivenv/Lib/site-packages/torch/_lazy/metrics.py b/phivenv/Lib/site-packages/torch/_lazy/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..4879c3e4800c099b959d4f964b00de2e9cdc2306 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/metrics.py @@ -0,0 +1,22 @@ +# mypy: allow-untyped-defs +import torch._C._lazy + + +def reset(): + """Resets all metric counters.""" + torch._C._lazy._reset_metrics() + + +def counter_names(): + """Retrieves all the currently active counter names.""" + return torch._C._lazy._counter_names() + + +def counter_value(name: str): + """Return the value of the counter with the speficied name""" + return torch._C._lazy._counter_value(name) + + +def metrics_report(): + """Return the combined (lazy core and backend) metric report""" + return torch._C._lazy._metrics_report() diff --git a/phivenv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py b/phivenv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..98b4ec6c5cea99edb937a6160141af360621823f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/tensor_factory_functions.py @@ -0,0 +1,49 @@ +import torch + + +""" +tensor_factory_functions defines the list of torch functions that create tensors. +The list is grabbed by searching thru native_functions.yaml by the following +regular expression: + + cat native_functions.yaml | grep 'func:' | grep -v "Tensor.*->" | grep "[-]>.*Tensor" + +It's possible that new tensor factory functions are added making this list stale. +Use at your own risk or regenerate the list. +""" +tensor_factory_functions = ( + torch._cudnn_init_dropout_state, + torch.arange, + torch.bartlett_window, + torch.blackman_window, + torch._empty_affine_quantized, + torch.empty_strided, + torch.eye, + torch.full, + torch.from_file, + torch.hann_window, + torch.hamming_window, + torch.kaiser_window, + torch.linspace, + torch.logspace, + torch.ones, + torch.scalar_tensor, + torch.rand, + torch.randint, + torch.randn, + torch.randperm, + torch.range, + torch._efficientzerotensor, + torch.zeros, + torch.tril_indices, + torch.triu_indices, + # Note: the following functions match the regular expression search above but + # they are not available in the torch module. Comment out. 
+ # torch._sparse_coo_tensor_with_dims, + # torch.fft_fftfreq, + # torch.fft_rfftfreq, +) + ( + # torch.tensor is special since it's not in native_functions.yaml + # add it separately + torch.tensor, +) diff --git a/phivenv/Lib/site-packages/torch/_lazy/ts_backend.py b/phivenv/Lib/site-packages/torch/_lazy/ts_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..6050e9dca4555553bba07eaec11c2ea6137d4b30 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_lazy/ts_backend.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch._C._lazy_ts_backend + + +def init(): + """Initializes the lazy Torchscript backend""" + torch._C._lazy_ts_backend._init() diff --git a/phivenv/Lib/site-packages/torch/_library/__init__.py b/phivenv/Lib/site-packages/torch/_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87063792979bc996f2df3ad0a2b41f93f688ce4c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/__init__.py @@ -0,0 +1,6 @@ +import torch._library.autograd +import torch._library.fake_impl +import torch._library.simple_registry +import torch._library.utils +from torch._library.fake_class_registry import register_fake_class +from torch._library.triton import capture_triton, triton_op, wrap_triton diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b4e5a859c4fe7b3a126b7bd895b83f0a9d556e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/autograd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70a1a006da45d957da4efd48d2884c842e6004dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/autograd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/custom_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/custom_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36e5b3eea1eaedcf8da81b11b1ac084c20d842e4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/custom_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b49e7994a46f486f798a7789789064f6a5e6227 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_class_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0cd033d2bdc05e85e655f1773bed5308644ba5a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_profile.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_profile.cpython-39.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..d55a40f0f2133fe7e5cd27ae7d03e4b629a13fb9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/fake_profile.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/infer_schema.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/infer_schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7821b4782a63d4037cb94a2e07297114e9a8b693 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/infer_schema.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e81a846290224fce86a2f5c5e218dfa25f946a64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/simple_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/triton.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/triton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59cade2d2613ad484569d5ac8120255c8f00d80e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/triton.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..580a3e15a56eb8f366c0970a32afa60f370ad419 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_library/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_library/autograd.py b/phivenv/Lib/site-packages/torch/_library/autograd.py new file mode 100644 index 0000000000000000000000000000000000000000..09dc4f1abfbd329832c5f90e53fc173f25244577 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/autograd.py @@ -0,0 +1,239 @@ +# mypy: allow-untyped-defs +import dataclasses +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol + +from torch import _C, _ops, autograd, Tensor +from torch.utils import _pytree + +from . 
import utils + + +class InfoProtocol(Protocol): + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +@dataclasses.dataclass +class Info: + _backward_fn: Optional[Callable] + _setup_context_fn: Optional[Callable] + + +def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable: + name: str = f"GeneratedBackwardFor_{op._namespace}_{op._opname}_{op._overloadname}" + + has_kwarg_only_args = utils.has_kwarg_only_args(op._schema) + + @dataclass + class Metadata: + keyset: _C.DispatchKeySet + keyword_only_args: dict[str, Any] + + def forward_no_grad(*args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + return result + + def forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + + with _C._AutoDispatchBelowAutograd(): + keyset = metadata.keyset + kwargs = metadata.keyword_only_args + result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs) + if info._setup_context_fn: + # The Dispatcher will remove args that are equal to their default + # values from (args, kwargs). We're going to add it back so that + # the user can access them. + # + # This is OK to do: The Dispatcher removed the args for serialization + # FC/BC reasons (that is, a graph will not store args that are equal + # to their default values), but that doesn't matter here. If the user + # adds a new default arg, then they must update + # their setup_context (along with the rest of their operator + # registrations) + args, kwargs = utils.fill_defaults(op._schema, args, kwargs) + + if has_kwarg_only_args: + info._setup_context_fn( + ctx=ctx, inputs=args, keyword_only_inputs=kwargs, output=result + ) + else: + info._setup_context_fn(ctx=ctx, inputs=args, output=result) + return result + + def backward(ctx, *grads): + if info._backward_fn: + try: + prev_needs_input_grad = ctx.needs_input_grad + ctx.needs_input_grad = ctx.needs_input_grad[:-1] + result = info._backward_fn(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + if isinstance(result, tuple): + return (*result, None) + return result, None + raise RuntimeError( + f"Trying to backward through {op} but no autograd " + f"formula was registered. " + f"Please use register_autograd to add one." + ) + + Generated = type( + name, + (autograd.Function,), + { + "forward": staticmethod(forward), + "backward": staticmethod(backward), + }, + ) + + schema = op._schema + if any( + utils.is_tensorlist_like_type(a.type) + for a in (*schema.arguments, *schema.returns) + ): + Generated = supports_tensorlist(Generated) + + # The dispatcher passes any keyword-only-args as kwargs and the + # rest of the args (even if specified as kwargs) as args. + def autograd_impl(keyset, *args, **keyword_only_args): + if _C.is_grad_enabled() and _C._any_requires_grad(*args): + result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] + else: + result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) + return result + + return autograd_impl + + +def supports_tensorlist(cls: Any) -> Any: + """Allows a given autograd.Function class to support List[Tensor] inputs/outputs. + + Regular autograd.Function has a constraint that it only directly supports autograd for + Tensors. Applying @supports_tensorlist enables an autograd.Function to support + autograd for List[Tensor] inputs and outputs. 
+ """ + orig_forward = cls.forward + orig_backward = cls.backward + orig_apply = cls.apply + + @dataclass + class Metadata: + input_spec: spec_t + output_spec: Optional[spec_t] = None + result_is_tuple: Optional[bool] = None + + def new_forward(ctx, *args): + metadata = args[-1] + args = args[:-1] + if not isinstance(metadata, Metadata): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.forward directly. " + "You should probably be calling .apply instead. " + "Please file an issue if not." + ) + args = unflatten(list(args), metadata.input_spec) + result = orig_forward(ctx, *args) + metadata.result_is_tuple = isinstance(result, tuple) + if not metadata.result_is_tuple: + result = (result,) + flat_result, output_spec = flatten(result, not_list_of_tensor) + metadata.output_spec = output_spec + + if hasattr(ctx, "_pt_metadata"): + raise RuntimeError( + "Please don't set ctx._pt_metadata; PyTorch uses it to store info" + ) + ctx._pt_metadata = metadata + + return tuple(flat_result) + + def new_backward(ctx, *grads): + if not hasattr(ctx, "_pt_metadata"): + raise NotImplementedError( + "NYI: calling supports_tensorlist autograd.Function.backward directly. " + "This will automatically get called by PyTorch autograd. " + "Please file an issue if you need this." + ) + + metadata = ctx._pt_metadata + grads = unflatten(list(grads), metadata.output_spec) + + # If the user's input is ([x, y, z], w), + # then needs_input_grad is (bool, bool, bool, bool, bool). + # We need to + # 1. get rid of the additional bool (which comes from the extra + # `metadata input`) + # 2. unflatten to get the right structure. + prev_needs_input_grad = ctx.needs_input_grad + try: + ctx.needs_input_grad = unflatten( + list(ctx.needs_input_grad[:-1]), metadata.input_spec + ) + grad_inputs = orig_backward(ctx, *grads) + finally: + ctx.needs_input_grad = prev_needs_input_grad + + if not isinstance(grad_inputs, tuple): + grad_inputs = (grad_inputs,) + # Assume that any Nones in the backward are Tensors. + # If the forward has an arg that is [1, 2, 3], the backward should + # return None as the grad. + # If the forward has an arg that is [tensor, tensor], the backward + # may return [None, None], [grad, None], [None, grad], or [grad, grad]. + flat_grad_inputs, grad_inputs_spec = flatten( + grad_inputs, not_list_of_optional_tensor + ) + if grad_inputs_spec != metadata.input_spec: + raise RuntimeError( + f"Expected the return from backward to be of the same structure " + f"as the inputs. 
Got: {grad_inputs_spec} (return from backward), " + f"{metadata.input_spec} (inputs)" + ) + return tuple(flat_grad_inputs + [None]) + + def new_apply(*args): + flat_args, input_spec = flatten(args, is_leaf=not_list_of_tensor) + metadata = Metadata(input_spec) + result = orig_apply(*flat_args, metadata) # type: ignore[misc] + assert metadata.output_spec is not None + result = unflatten(list(result), metadata.output_spec) + if not metadata.result_is_tuple: + assert isinstance(result, tuple) + assert len(result) == 1 + return result[0] + return result + + cls.forward = new_forward + cls.backward = new_backward + cls.apply = new_apply + return cls + + +def not_list_of_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(not isinstance(l, Tensor) for l in tree) + return True + + +def not_list_of_optional_tensor(tree): + if isinstance(tree, tuple): + return False + if isinstance(tree, list): + return any(l is not None and not isinstance(l, Tensor) for l in tree) + return True + + +flatten = _pytree.tree_flatten +unflatten = _pytree.tree_unflatten +spec_t = _pytree.TreeSpec diff --git a/phivenv/Lib/site-packages/torch/_library/custom_ops.py b/phivenv/Lib/site-packages/torch/_library/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0b979c6fabb6a32b9d7f90dd08b5c06868d20278 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/custom_ops.py @@ -0,0 +1,931 @@ +# mypy: allow-untyped-defs +import collections +import inspect +import logging +import weakref +from collections.abc import Iterable, Sequence +from contextlib import contextmanager +from typing import Any, Callable, Literal, Optional, overload, Union + +import torch +from torch import _C, _ops, Tensor +from torch.types import _dtype +from torch.utils._exposed_in import exposed_in + +from . import autograd, utils + + +device_types_t = Optional[Union[str, Sequence[str]]] +log = logging.getLogger(__name__) + + +@overload +def custom_op( + name: str, + fn: Literal[None] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, +) -> Callable[[Callable[..., object]], "CustomOpDef"]: + ... + + +@overload +def custom_op( + name: str, + fn: Callable[..., object], + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, +) -> "CustomOpDef": + ... + + +@exposed_in("torch.library") +def custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, + tags: Optional[Sequence[_C.Tag]] = None, +) -> Union[Callable[[Callable[..., object]], "CustomOpDef"], "CustomOpDef"]: + """Wraps a function into custom operator. + + Reasons why you may want to create a custom op include: + - Wrapping a third-party library or custom kernel to work with PyTorch + subsystems like Autograd. + - Preventing torch.compile/export/FX tracing from peeking inside your function. + + This API is used as a decorator around a function (please see examples). + The provided function must have type hints; these are needed to interface + with PyTorch's various subsystems. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). 
+ To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. + device_types (None | str | Sequence[str]): The device type(s) the function + is valid for. If no device type is provided, then the function + is used as the default implementation for all device types. + Examples: "cpu", "cuda". + When registering a device-specific implementation for an operator that accepts no Tensors, + we require the operator to have a "device: torch.device argument". + schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + .. note:: + We recommend not passing in a ``schema`` arg and instead letting us infer + it from the type annotations. It is error-prone to write your own schema. + You may wish to provide your own schema if our interpretation of + the type annotation is not what you want. + For more info on how to write a schema string, see + `here `_ + + Examples:: + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> @custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that only works for one device type. + >>> @custom_op("mylib::numpy_sin_cpu", mutates_args=(), device_types="cpu") + >>> def numpy_sin_cpu(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> x = torch.randn(3) + >>> y = numpy_sin_cpu(x) + >>> assert torch.allclose(y, x.sin()) + >>> + >>> # Example of a custom op that mutates an input + >>> @custom_op("mylib::numpy_sin_inplace", mutates_args={"x"}, device_types="cpu") + >>> def numpy_sin_inplace(x: Tensor) -> None: + >>> x_np = x.numpy() + >>> np.sin(x_np, out=x_np) + >>> + >>> x = torch.randn(3) + >>> expected = x.sin() + >>> numpy_sin_inplace(x) + >>> assert torch.allclose(x, expected) + >>> + >>> # Example of a factory function + >>> @torch.library.custom_op("mylib::bar", mutates_args={}, device_types="cpu") + >>> def bar(device: torch.device) -> Tensor: + >>> return torch.ones(3) + >>> + >>> bar("cpu") + + """ + + def inner(fn: Callable[..., object]) -> CustomOpDef: + import torch + + if schema is None: + schema_str = torch.library.infer_schema(fn, mutates_args=mutates_args) + else: + schema_str = schema + + namespace, opname = name.split("::") + result = CustomOpDef(namespace, opname, schema_str, fn, tags) + if schema is not None: + # Check that schema's alias annotations match those of `mutates_args`. + expected = set() + for arg in result._opoverload._schema.arguments: + if arg.alias_info is not None and arg.alias_info.is_write: + expected.add(arg.name) + if expected != set(mutates_args): + raise ValueError( + f"Attempted to create a custom op with `mutates_args={mutates_args}` " + f"and `schema={schema}. 
The schema suggests that the op mutates {expected}, " + f"which is different from what was provided to us in `mutates_args`. " + f"Please make these consistent." + ) + result.register_kernel(device_types)(fn) + return result + + if fn is None: + return inner + return inner(fn) + + +class CustomOpDef: + """CustomOpDef is a wrapper around a function that turns it into a custom op. + + It has various methods for registering additional behavior for this + custom op. + + You should not instantiate CustomOpDef directly; instead, use the + :func:`torch.library.custom_op` API. + """ + + def __init__( + self, + namespace: str, + name: str, + schema: str, + fn: Callable, + tags: Optional[Sequence[_C.Tag]] = None, + ) -> None: + # Fields used to interface with the PyTorch dispatcher + self._namespace = namespace + self._name = name + self._schema = schema + self._tags = tags if tags is not None else [] + + self._init_fn = fn + + self._backend_fns: dict[Union[str, None], Callable] = {} + self._abstract_fn: Optional[Callable] = None + self._setup_context_fn: Optional[Callable] = None + self._backward_fn: Optional[Callable] = None + self._torch_dispatch_fns: dict[type, Callable] = {} + self._vmap_fn: Optional[Callable] = None + self._autocast_cuda_dtype: Optional[_dtype] = None + self._autocast_cpu_dtype: Optional[_dtype] = None + + self._lib = get_library_allowing_overwrite(self._namespace, self._name) + self._register_to_dispatcher(self._tags) + self._disabled_kernel: set = set() + OPDEFS[self._qualname] = self + + @property + def _qualname(self) -> str: + return f"{self._namespace}::{self._name}" + + def __repr__(self) -> str: + return f"<CustomOpDef({self._qualname})>" + + @contextmanager + def set_kernel_enabled(self, device_type: str, enabled: bool = True): + """ + Disable or re-enable an already registered kernel for this custom operator. + + If the kernel is already disabled/enabled, this is a no-op. + + Note: + If a kernel is first disabled and then registered, it is disabled until enabled again. + + Args: + device_type (str): The device type to disable/enable the kernel for. + enabled (bool): Whether to enable or disable the kernel. + + Example: + >>> inp = torch.randn(1) + >>> + >>> # define custom op `f`.
+ >>> @custom_op("mylib::f", mutates_args=()) + >>> def f(x: Tensor) -> Tensor: + >>> return torch.zeros(1) + >>> + >>> print(f(inp)) # tensor([0.]), default kernel + >>> + >>> @f.register_kernel("cpu") + >>> def _(x): + >>> return torch.ones(1) + >>> + >>> print(f(inp)) # tensor([1.]), CPU kernel + >>> + >>> # temporarily disable the CPU kernel + >>> with f.set_kernel_enabled("cpu", enabled = False): + >>> print(f(inp)) # tensor([0.]) with CPU kernel disabled + + """ + action = "enable" if enabled else "disable" + originally_disabled = device_type in self._disabled_kernel + if device_type not in self._backend_fns: + log.warning( + "Attempted to %s kernel for %s but no kernel was registered for this device type.", + action, + device_type, + ) + + if not enabled: + if originally_disabled: + log.warning( + "Attempted to disable kernel for %s but it was already disabled.", + device_type, + ) + else: + self._disabled_kernel.add(device_type) + else: # enable the kernel + if not originally_disabled: + log.warning( + "Attempted to enable kernel for %s but it was already enabled.", + device_type, + ) + else: + self._disabled_kernel.remove(device_type) + + try: + yield + finally: + # restore original state + if originally_disabled: + self._disabled_kernel.add(device_type) + else: + self._disabled_kernel.discard(device_type) + + def register_kernel( + self, device_types: device_types_t, fn: Optional[Callable] = None, / + ) -> Callable: + """Register an implementation for a device type for this operator. + + Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu". + This API may be used as a decorator. + + Args: + fn (Callable): The function to register as the implementation for + the given device types. + device_types (str | Sequence[str]): The device device_types to register an impl to. 
+ + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> import numpy as np + >>> + >>> # Create a custom op that works on cpu + >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu") + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np) + >>> + >>> # Add implementations for the cuda device + >>> @numpy_sin.register_kernel("cuda") + >>> def _(x): + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> x_cpu = torch.randn(3) + >>> x_cuda = x_cpu.cuda() + >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin()) + >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin()) + + """ + + def inner(fn): + if device_types is None or isinstance(device_types, str): + dtypes: list[Union[str, None]] = [device_types] + else: + dtypes = list(device_types) + for device_type in dtypes: + if device_type not in self._backend_fns: + + def backend_impl(*args, **kwargs): + result = self._backend_fns[device_type](*args, **kwargs) + + def get_module(): + fn = self._backend_fns[device_type] + return inspect.getmodule(fn) + + utils._c_check_aliasing_constraint( + self._name, + args, + kwargs, + result, + get_module, + ) + return result + + if device_type is None: + self._lib.impl( + self._name, backend_impl, "CompositeExplicitAutograd" + ) + else: + self._lib.impl( + self._name, + backend_impl, + _C._dispatch_key_for_device(device_type), + ) + + # Wrap function to choose between the default implementation or the device-specific + # implementation depending on if the kernel is disabled. + @torch._disable_dynamo + def wrapped_fn(*args, **kwargs): + if device_type in self._disabled_kernel: + return self._init_fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + self._backend_fns[device_type] = wrapped_fn + return fn + + if device_types is not None and not utils.has_tensor_arg( + self._opoverload._schema + ): + device_arg_index = utils.get_device_arg_index(self._opoverload._schema) + if device_arg_index is None: + raise ValueError( + "Functions without tensor inputs are required to have a `device: torch.device` argument" + ) + self._register_backend_select_dispatcher(device_arg_index) + + # See NOTE: [Supporting decorator and non-decorator usage] + if fn is None: + return inner + return inner(fn) + + def register_fake(self, fn: Callable, /) -> Callable: + r"""Register a FakeTensor implementation for this custom op. + + This is necessary to get the operator to work efficiently with torch.compile. + + The Fake impl (sometimes also known as a meta kernel or abstract impl) + specifies the behavior of this operator on Tensors that carry no data. + Given some input Tensors with certain properties + (sizes/strides/storage_offset/device), it specifies what the properties of + the output Tensors are. + + Please see :func:`torch.library.impl_abstract` for more details. + + Args: + fn (Callable): The function to register as the FakeTensor + implementation. 
+ + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> # Example 1: an operator without data-dependent output shape + >>> @torch.library.custom_op("mylib::linear", mutates_args=()) + >>> def linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor: + >>> return (x @ weight.t()) + bias + >>> + >>> @linear.register_fake + >>> def _(x, weight, bias): + >>> assert x.dim() == 2 + >>> assert weight.dim() == 2 + >>> assert bias.dim() == 1 + >>> assert x.shape[1] == weight.shape[1] + >>> assert weight.shape[0] == bias.shape[0] + >>> assert x.device == weight.device + >>> return x.new_empty(x.size(0), weight.size(0)) + >>> + >>> x = torch.randn(2, 2) + >>> weight = torch.randn(2, 2) + >>> bias = torch.randn(2) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(linear, fullgraph=True)(x, weight, bias) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, torch.nn.functional.linear(x, weight, bias)) + >>> + >>> # Example 2: an operator with data-dependent output shape + >>> @torch.library.custom_op("mylib::nonzero", mutates_args=()) + >>> def nonzero(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + >>> + >>> @nonzero.register_fake + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an abstract impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> x = torch.tensor([0, 1, 2, 0, 0, 1]) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> out = torch.compile(nonzero, fullgraph=True)(x) + >>> # xdoctest: +SKIP("Requires Python <= 3.11") + >>> assert torch.allclose(out, x.nonzero()) + + """ + self._abstract_fn = fn + return fn + + def register_torch_dispatch( + self, torch_dispatch_class: Any, fn: Optional[Callable] = None, / + ) -> Callable: + r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. + + This allows for open registration to specify the behavior between the operator + and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` + or the operator directly. + + Please see :func:`torch.library.register_torch_dispatch` for examples and more details. + """ + + def register(fn): + if torch_dispatch_class not in self._torch_dispatch_fns: + + def inner(*args, **kwargs): + return self._torch_dispatch_fns[torch_dispatch_class]( + *args, **kwargs + ) + + self._lib._register_torch_dispatch_rule( + self._name, torch_dispatch_class, inner + ) + self._torch_dispatch_fns[torch_dispatch_class] = fn + return fn + + if fn is None: + return register + else: + return register(fn) + + def register_autograd( + self, + backward: Callable, + /, + *, + setup_context: Optional[Callable] = None, + ) -> None: + r"""Register a backward formula for this custom op. + + In order for an operator to work with autograd, you need to register + a backward formula: + 1. You must tell us how to compute gradients during the backward pass + by providing us a "backward" function. + 2. If you need any values from the forward to compute gradients, you can + use `setup_context` to save values for backward. + + ``backward_fn`` runs during the backward pass. 
It accepts ``(ctx, *grads)``: + - ``grads`` is one or more gradients. The number of gradients matches + the number of outputs of the operator. + The ``ctx`` object is `the same ctx object `_ used by + :class:`torch.autograd.Function`. The semantics of ``backward_fn`` are the + same as :meth:`torch.autograd.Function.backward`. + + ``setup_context(ctx, inputs, output)`` runs during the forward pass. + Please save quantities needed for backward onto the ``ctx`` object via + either :meth:`torch.autograd.function.FunctionCtx.save_for_backward` + or assigning them as attributes of ``ctx``. If your custom op has + kwarg-only arguments, we expect the signature of ``setup_context`` + to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``. + + Both ``setup_context_fn`` and ``backward_fn`` must be traceable. That is, + they may not directly access :meth:`torch.Tensor.data_ptr` and they must + not depend on or mutate global state. If you need a non-traceable backward, + you can make it a separate custom_op that you call inside ``backward_fn``. + + If you need different autograd behavior on different devices, then we + recommend creating two different custom operators, one for each device + that needs different behavior, and switching between them at runtime. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> + >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=()) + >>> def numpy_sin(x: Tensor) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = np.sin(x_np) + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, output) -> Tensor: + >>> x, = inputs + >>> ctx.save_for_backward(x) + >>> + >>> def backward(ctx, grad): + >>> x, = ctx.saved_tensors + >>> return grad * x.cos() + >>> + >>> numpy_sin.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_sin(x) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, x.cos()) + >>> + >>> # Example with a keyword-only arg + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor: + >>> x_np = x.cpu().numpy() + >>> y_np = x_np * val + >>> return torch.from_numpy(y_np).to(device=x.device) + >>> + >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> Tensor: + >>> ctx.val = keyword_only_inputs["val"] + >>> + >>> def backward(ctx, grad): + >>> return grad * ctx.val + >>> + >>> numpy_mul.register_autograd(backward, setup_context=setup_context) + >>> + >>> x = torch.randn(3, requires_grad=True) + >>> y = numpy_mul(x, val=3.14) + >>> grad_x, = torch.autograd.grad(y, x, torch.ones_like(y)) + >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14)) + + """ + schema = self._opoverload._schema + if not utils.is_functional_schema(schema): + raise RuntimeError( + f"Cannot register autograd formula for non-functional operator " + f"{self} with schema {schema}. Please create " + f"a functional operator and register an autograd formula for that." 
+ ) + + self._backward_fn = backward + self._setup_context_fn = setup_context + + def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: + if torch._running_with_deploy(): + utils.warn_deploy(stacklevel=5) + return + + lib = self._lib + schema_str = self._name + self._schema + cpp_schema = _C.parse_schema(schema_str) + if utils.has_kwarg_only_tensors(cpp_schema): + # If you want to support this, the progression is: + # - supporting kwarg-only Tensors that are non-differentiable + # - supporting kwarg-only Tensors (regardless of differentiability) + raise NotImplementedError( + f"custom_op with kwarg-only Tensor args. Please make your " + f"tensors not kwarg-only. Got: {schema_str}" + ) + + lib.define( + schema_str, + tags=[_C.Tag.pt2_compliant_tag, *tags], + ) + self._opoverload = utils.lookup_op(self._qualname) + + def fake_impl(*args, **kwargs): + if self._abstract_fn is None: + if utils.can_generate_trivial_fake_impl(self._opoverload): + return None + raise RuntimeError( + f"There was no fake impl registered for {self}. " + f"This is necessary for torch.compile/export/fx tracing to work. " + f"Please use `{self._init_fn.__name__}.register_fake` to add an " + f"fake impl." + ) + return self._abstract_fn(*args, **kwargs) + + lib._register_fake(self._name, fake_impl, _stacklevel=4) + + autograd_impl = autograd.make_autograd_impl(self._opoverload, self) + lib.impl(self._name, autograd_impl, "Autograd", with_keyset=True) + + schema = self._opoverload._schema + if schema.is_mutable: + mutated_idxs, mutated_keys = utils.mutated_args_kwargs(schema) + + def adinplaceorview_impl(keyset, *args, **kwargs): + for idx in mutated_idxs: + increment_version(args[idx]) + for key in mutated_keys: + increment_version(kwargs[key]) + with _C._AutoDispatchBelowADInplaceOrView(): + return self._opoverload.redispatch( + keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs + ) + + lib.impl( + self._name, + adinplaceorview_impl, + "ADInplaceOrView", + with_keyset=True, + ) + + def _register_backend_select_dispatcher(self, device_arg_index: int): + """ + Switch on the device argument to select the correct backend to dispatch to. + """ + + def backend_select(keyset, *args, **kwargs): + device = args[device_arg_index].type + if device not in self._backend_fns: + raise RuntimeError( + f"{self._name} does not have a kernel registered for {device}. " + "Please use register_kernel to do so." + ) + dispatch_key = _C._dispatch_key_for_device(device) + dispatch_key = getattr(_C.DispatchKey, dispatch_key) + return self._opoverload.redispatch( + _C.DispatchKeySet(dispatch_key), *args, **kwargs + ) + + self._lib.impl(self._name, backend_select, "BackendSelect", with_keyset=True) + + def __call__(self, *args, **kwargs): + return self._opoverload(*args, **kwargs) + + def register_vmap( + self, + func: Optional[Callable] = None, + ): + r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op. + + This API may be used as a decorator. + + In order for an operator to work with :func:`torch.vmap`, you may need to register a + vmap implementation in the following signature: + + ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``, + + where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``. + + It specifies how do we compute the batched version of ``op`` given inputs with an additional + dimension (specified by ``in_dims``). + + For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. 
It is ``None`` + if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer + specifying what dimension of the Tensor is being vmapped over. + + ``info`` is a collection of additional metadata that may be helpful: + ``info.batch_size`` specifies the size of the dimension being vmapped over, while + ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`. + + The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``, + ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim`` + per output that specifies if the output has the vmapped dimension and what index it is in. + + Examples: + >>> import torch + >>> import numpy as np + >>> from torch import Tensor + >>> from typing import Tuple + >>> + >>> def to_numpy(tensor): + >>> return tensor.cpu().numpy() + >>> + >>> lib = torch.library.Library("mylib", "FRAGMENT") + >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=()) + >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: + >>> x_np = to_numpy(x) + >>> dx = torch.tensor(3 * x_np ** 2, device=x.device) + >>> return torch.tensor(x_np ** 3, device=x.device), dx + >>> + >>> def numpy_cube_vmap(info, in_dims, x): + >>> result = numpy_cube(x) + >>> return result, (in_dims[0], in_dims[0]) + >>> + >>> numpy_cube.register_vmap(numpy_cube_vmap) + >>> + >>> x = torch.randn(3) + >>> torch.vmap(numpy_cube)(x) + >>> + >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=()) + >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor: + >>> return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) + >>> + >>> @numpy_mul.register_vmap + >>> def numpy_mul_vmap(info, in_dims, x, y): + >>> x_bdim, y_bdim = in_dims + >>> x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) + >>> y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) + >>> result = x * y + >>> result = result.movedim(-1, 0) + >>> return result, 0 + >>> + >>> + >>> x = torch.randn(3) + >>> y = torch.randn(3) + >>> torch.vmap(numpy_mul)(x, y) + """ + from torch._functorch.autograd_function import custom_function_call_vmap_helper + from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter + + def register(func): + need_register = self._vmap_fn is None + self._vmap_fn = func + + if need_register: + + def wrapped_func(keyset, *args, **kwargs): + interpreter = retrieve_current_functorch_interpreter() + return custom_function_call_vmap_helper( + interpreter, self._vmap_fn, self._opoverload, *args, **kwargs + ) + + self._lib.impl( + self._name, wrapped_func, "FuncTorchBatched", with_keyset=True + ) + + if func is None: + return register + else: + return register(func) + + def register_autocast( + self, + device_type: str, + cast_inputs: _dtype, + ): + r"""Register an autocast dispatch rule for this custom op. + + Valid `device_type` include: "cpu" and "cuda". + + Args: + op (str | OpOverload): The operator to register an autocast dispatch rule to. + device_type(str): Device type to use. 'cuda' or 'cpu'. + The type is the same as the `type` attribute of a :class:`torch.device`. + Thus, you may obtain the device type of a tensor using `Tensor.device.type`. + cast_inputs (:class:`torch.dtype`): When custom op runs in an autocast-enabled region, + casts incoming floating-point Tensors to the target dtype (non-floating-point Tensors + are not affected), then executes custom op with autocast disabled. 
+ lib (Optional[Library]): If provided, the lifetime of this registration + + Examples:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch import Tensor + >>> from torch.library import custom_op + >>> + >>> # Create a custom op that works on cuda + >>> @torch.library.custom_op("mylib::my_sin", mutates_args=()) + >>> def my_sin(x: Tensor) -> Tensor: + >>> return torch.sin(x) + >>> + >>> # Register autocast dispatch rule for the cuda device + >>> torch.library.register_autocast("mylib::my_sin", "cuda", torch.float16) + >>> + >>> x = torch.randn(3, dtype=torch.float32, device="cuda") + >>> with torch.autocast("cuda", dtype=torch.float16): + >>> y = torch.ops.mylib.my_sin(x) + >>> assert y.dtype == torch.float16 + + """ + if not isinstance(device_type, str): + raise ValueError( + f"Expected `device_type` of type `str`, got: `{type(device_type)}`" + ) + if device_type not in ["cpu", "cuda"]: + raise ValueError(f"Unknown device type: {device_type}") + + need_register_cuda = self._autocast_cuda_dtype is None + need_register_cpu = self._autocast_cpu_dtype is None + if device_type == "cuda": + self._autocast_cuda_dtype = cast_inputs + else: + self._autocast_cpu_dtype = cast_inputs + + def kernel(_, *args, **kwargs): + assert len(kwargs) == 0, "Custom ops do not support kwargs yet." + autocast_keyset = torch._C.DispatchKeySet( + torch._C.DispatchKey.AutocastCPU + ) | torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastCUDA) + with torch._C._ExcludeDispatchKeyGuard(autocast_keyset): + return self._opoverload(*_cast(args, device_type, cast_inputs)) + + if need_register_cuda and self._autocast_cuda_dtype: + self._lib.impl(self._name, kernel, "AutocastCUDA", with_keyset=True) + elif need_register_cpu and self._autocast_cpu_dtype: + self._lib.impl(self._name, kernel, "AutocastCPU", with_keyset=True) + + return kernel + + +# TODO: Merge this function with torch.amp.autocast_mode._cast, and refactor it +# into a utility function once custom ops support arbitrary input types. +def _cast(value, device_type: str, dtype: _dtype): + if isinstance(value, torch.Tensor): + is_eligible = ( + value.is_floating_point() + and value.device.type == device_type + and (value.dtype is not torch.float64) + ) + return value.to(dtype) if is_eligible else value + elif isinstance(value, (str, bytes)): + return value + elif isinstance(value, collections.abc.Iterable): + iterable = (_cast(v, device_type, dtype) for v in value) + if isinstance(value, (list, tuple)): + return type(value)(iterable) + else: + return iterable + else: + return value + + +def increment_version(val: Any) -> None: + if isinstance(val, Tensor): + torch.autograd.graph.increment_version(val) + elif isinstance(val, (tuple, list)): + for v in val: + if isinstance(v, Tensor): + torch.autograd.graph.increment_version(v) + + +# NOTE: [Supporting decorator and non-decorator usage] +# +# Some APIs may be both used as a decorator and not as a decorator. +# For example: +# +# >>> def fn(x): +# >>> return x.sin() +# >>> +# >>> # Usage 1: not as a decorator +# >>> numpy_sin.register_kernel("cuda", fn) +# >>> +# >>> # Usage 2: as a decorator +# >>> @numpy_sin.register_kernel("cuda") +# >>> def fn2(x): +# >>> return x.sin +# +# The way we support this is that `register_kernel` accepts an optional `fn`. +# If `fn` is provided (Usage 1), then we know that the user is using it not +# as a decorator. +# If `fn` is not provided (Usage 2), then `register_kernel` needs to return a +# decorator. 
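A minimal sketch of the optional-`fn` pattern described in the note above; `register_thing` and `REGISTRY` are hypothetical names used only for illustration and are not part of this patch:

REGISTRY = {}

def register_thing(key, fn=None):
    # Shared decorator body: record the function under `key` and return it unchanged.
    def decorator(f):
        REGISTRY[key] = f
        return f
    if fn is None:
        # Decorator usage: @register_thing("some_key")
        return decorator
    # Direct-call usage: register_thing("some_key", fn)
    return decorator(fn)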
+ + +OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {} +OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + + +def get_library_allowing_overwrite( + namespace: str, name: str +) -> "torch.library.Library": + qualname = f"{namespace}::{name}" + + if qualname in OPDEF_TO_LIB: + OPDEF_TO_LIB[qualname]._destroy() + del OPDEF_TO_LIB[qualname] + + lib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + OPDEF_TO_LIB[qualname] = lib + return lib + + +def _maybe_get_opdef( + op: Union[CustomOpDef, _ops.OpOverload, str] +) -> Optional[CustomOpDef]: + if isinstance(op, CustomOpDef): + return op + if isinstance(op, _ops.OpOverload): + op = op._name + assert isinstance(op, str) + if op in OPDEFS: + return OPDEFS[op] + return None diff --git a/phivenv/Lib/site-packages/torch/_library/fake_class_registry.py b/phivenv/Lib/site-packages/torch/_library/fake_class_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..fef31805fc5aaf9462ddc27ba4cbc8d195cc1606 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/fake_class_registry.py @@ -0,0 +1,341 @@ +# mypy: allow-untyped-defs +import copy +import logging +from typing import Any, Optional, Protocol, Union + +import torch +from torch._library.utils import parse_namespace +from torch.utils._python_dispatch import _disable_current_modes + + +log = logging.getLogger(__name__) + + +class FakeScriptObject: + def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject): + self.wrapped_obj = wrapped_obj + + # The fully qualified name of the class of original script object + self.script_class_name = script_class_name + try: + with _disable_current_modes(): + self.real_obj = copy.deepcopy(x) + except RuntimeError: + log.warning( + "Unable to deepcopy the custom object %s. " + "Defaulting to the user given object. This might be " + "dangerous as side effects may be directly applied " + "to the object.", + script_class_name, + ) + self.real_obj = x + + +class FakeScriptMethod: + def __init__( + self, + self_fake_obj: FakeScriptObject, + method_name: str, + schema: Optional[torch.FunctionSchema], + ): + self.self_fake_obj = self_fake_obj + self.method_name = method_name + self.schema = schema + + def __call__(self, *args, **kwargs): + from torch._higher_order_ops.torchbind import call_torchbind + + return call_torchbind(self.self_fake_obj, self.method_name, *args, **kwargs) + + +class HasStaticMethodFromReal(Protocol): + @classmethod + def from_real(cls, real_obj: torch.ScriptObject): + pass + + +class FakeClassRegistry: + def __init__(self) -> None: + self._registered_class: dict[str, Any] = {} + + def has_impl(self, full_qualname: str) -> bool: + return full_qualname in self._registered_class + + def get_impl(self, full_qualname: str) -> Any: + self._check_registered(full_qualname) + return self._registered_class[full_qualname] + + def register(self, full_qualname: str, fake_class=None) -> None: + if self.has_impl(full_qualname): + log.warning( + "%s is already registered. Previous fake class is overridden with %s.", + full_qualname, + fake_class, + ) + self._registered_class[full_qualname] = fake_class + + def deregister(self, full_qualname: str) -> Any: + if not self.has_impl(full_qualname): + log.warning( + "Cannot deregister %s. Please use register_fake_class to register it first." 
+ " Or do you dereigster it twice?", + full_qualname, + ) + else: + return self._registered_class.pop(full_qualname) + + def clear(self) -> None: + self._registered_class.clear() + + def _check_registered(self, full_qualname: str) -> None: + if full_qualname not in self._registered_class: + raise RuntimeError( + f"{full_qualname} is not registered. Please use register_fake_class to register it first." + ) + + +global_fake_class_registry = FakeClassRegistry() + + +# TODO: add this check at compile time for __obj_flatten__. +def _check_valid_flat_script_obj(flat_x): + if not isinstance(flat_x, tuple): + raise RuntimeError("Expect flat x to be a tuple.") + + for tp in flat_x: + if not isinstance(tp, tuple): + raise RuntimeError("Expect flat x to be a tuple of tuples.") + + if not len(tp) == 2 or not isinstance(tp[0], str): + raise RuntimeError( + "Expect element of flat x to be a tuple of two elements with first element being a string" + ) + + +def tracing_with_real(x: torch.ScriptObject) -> bool: + if not hasattr(x, "tracing_mode"): + return False + + assert x.tracing_mode() in [ + "real", + "fake", + ], f"tracing_mode can be either real or fake but got {x.tracing_mode()}" + return x.tracing_mode() == "real" + + +def maybe_to_fake_obj( + fake_mode, x: torch.ScriptObject +) -> Union[FakeScriptObject, torch.ScriptObject]: + import torch.utils._pytree as pytree + from torch.utils._python_dispatch import _disable_current_modes + + # When tracing with real mode, people should implement meta kernels that can + # handle the case of real script object + fake tensor inputs. + if tracing_with_real(x): + return x + + # x.__obj_flatten__() could be calling some tensor operations inside but we don't + # want to call these ops in surrounding dispatch modes when executing it. + # Otherwise, for example, the fake tensor modes will error out when the tensors inside + # script obeject execute some operations like clone if allow_non_fake_input flag is set. + with _disable_current_modes(): + flat_x = x.__obj_flatten__() # type: ignore[attr-defined] + + _check_valid_flat_script_obj(flat_x) + + fake_flattened = pytree.tree_map_only( + torch.Tensor, + lambda t: fake_mode.from_tensor(t), + flat_x, + ) + + fake_x = _find_fake_class_for_script_object(x).__obj_unflatten__(fake_flattened) + + fake_x_wrapped = FakeScriptObject(fake_x, x._type().qualified_name(), x) # type: ignore[attr-defined] + + for name in x._method_names(): # type: ignore[attr-defined] + attr = getattr(fake_x, name, None) + if attr is not None: + if not callable(attr): + raise RuntimeError(f"Expect {name} to be a callable but got {attr}.") + + real_attr = getattr(x, name) # type: ignore[attr-defined] + + # real attr sometimes is not torch.ScriptMethod thus doesn't have schema e.g. __init___ or __eq__ + method_schema: Optional[torch.FunctionSchema] = None + if isinstance(real_attr, torch.ScriptMethod): + method_schema = real_attr.schema # type: ignore[attr-defined] + + setattr( + fake_x_wrapped, + name, + FakeScriptMethod(fake_x_wrapped, name, method_schema), + ) + else: + override_skip_list = {"__obj_flatten__", "__get_state__", "__set_state__"} + if name not in override_skip_list: + log.warning("fake object of %s doesn't implement method %s.", x, name) + return fake_x_wrapped + + +def register_fake_class(qualname, fake_class: Optional[HasStaticMethodFromReal] = None): + r"""Register a fake implementation for this class. 
+ + It's in the same spirit of registering a fake implementation for + an operator but with the difference that it + associates a fake class with the original torch bind class (registered + with torch::class_). In this way, torch.compile can handle them properly + in components such as Dynamo and AOTAutograd. + + This API may be used as a decorator (see example). For the fake class, users + are required to provide a from_real classmethod that takes a real object and + returns an instance of the fake class. All tensors in the fake object should also + be properly fakified with to_fake_tensor() in from_real. + + + Examples: + # For a custom class Foo defined in test_custom_class_registration.cpp: + + TORCH_LIBRARY(_TorchScriptTesting, m) { + m.class_("_TensorQueue") + .def(torch::init()) + .def("push", &TensorQueue::push) + .def("pop", &TensorQueue::pop) + .def("top", &TensorQueue::top) + .def("size", &TensorQueue::size) + .def("clone_queue", &TensorQueue::clone_queue) + .def("__obj_flatten__", &TensorQueue::__obj_flatten__) + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr& self) + -> c10::Dict { + return self->serialize(); + }, + // __setstate__ + [](c10::Dict data) + -> c10::intrusive_ptr { + return c10::make_intrusive(std::move(data)); + }); + }; + # We could register a fake class FakeTensorQueue in Python as follows: + import torch + + @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") + class FakeTensorQueue: + def __init__(self, queue): + self.queue = queue + + @classmethod + def __obj_unflatten__(cls, flattened_ctx): + return cls(**dict(ctx)) + + def push(self, x): + self.queue.append(x) + + def pop(self): + return self.queue.pop(0) + + def size(self): + return len(self.queue) + + In this example, the original TensorQeue need to addd a __obj_flatten__ method + to the class TensorQueue and the flattend result is passed into FakeTensorQueue's + __obj_unflatten__ as inputs to create a fake class. This protocol allows pytorch to look + at the contents of the script object and properly handle them in the subsystems + like dynamo, aot_aotugrad or more. + """ + + def inner(fake_class: HasStaticMethodFromReal): + ns, name = parse_namespace(qualname) + + # This also checks whether the refered torch::class_ exists. + torch._C._get_custom_class_python_wrapper(ns, name) + + from_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_method: + raise RuntimeError( + f"{fake_class} doesn't define a classmethod {_CONVERT_FROM_REAL_NAME}." + ) + + if not isinstance(fake_class.__dict__[_CONVERT_FROM_REAL_NAME], classmethod): + raise RuntimeError( + f"{_CONVERT_FROM_REAL_NAME} method is not a classmethod." + ) + + global_fake_class_registry.register(_full_qual_class_name(qualname), fake_class) + return fake_class + + if fake_class is None: + return inner + return inner(fake_class) + + +def deregister_fake_class(qualname): + return global_fake_class_registry.deregister(_full_qual_class_name(qualname)) + + +def has_fake_class(full_qualname) -> bool: + return global_fake_class_registry.has_impl(full_qualname) + + +def find_fake_class(full_qualname) -> Optional[Any]: + if not has_fake_class(full_qualname): + return None + return global_fake_class_registry.get_impl(full_qualname) + + +def _full_qual_class_name(qualname: str) -> str: + ns, name = parse_namespace(qualname) + return "__torch__.torch.classes." + ns + "." 
+ name + + +def _is_script_object(obj: Any) -> bool: + return isinstance( + obj, torch.ScriptObject + ) and obj._type().qualified_name().startswith( # type: ignore[attr-defined] + "__torch__.torch.classes" + ) + + +# Return the namespace and class name from fully qualified name. +def _ns_and_class_name(full_qualname: str) -> tuple[str, str]: + splits = full_qualname.split(".") + assert len(splits) == 5, f"Could not split {full_qualname=}" + _torch, _torch_ns, _classes, ns, class_name = splits + return ns, class_name + + +def _find_fake_class_for_script_object(x: torch.ScriptObject) -> Any: + full_qualname = x._type().qualified_name() # type: ignore[attr-defined] + ns, class_name = _ns_and_class_name(full_qualname) + fake_class = find_fake_class(full_qualname) + if fake_class is None: + raise RuntimeError( + f" ScriptObject's {full_qualname} haven't registered a fake class." + f" Please use register_fake_class({ns}::{class_name}) to annotate a fake class for the script obj." + f" Specifically, create a python class that implements a fake version for all the methods" + f" that're used in the program and put annotated class in the program e.g. after loading the library." + f" The fake methods can be written in the same way as a meta kernel for an operator but need to additionally" + f" simulate the object's states. Be sure to add a {_CONVERT_FROM_REAL_NAME} classmethod" + f" to enable creating a fake obj from a real one." + ) + return fake_class + + +_CONVERT_FROM_REAL_NAME = "__obj_unflatten__" + + +def _fake_obj_from_real(fake_mode, x) -> Any: + fake_class = _find_fake_class_for_script_object(x) + + from_real_method = getattr(fake_class, _CONVERT_FROM_REAL_NAME, None) + if not from_real_method: + raise RuntimeError( + f"{fake_class} must define a classmethod {_CONVERT_FROM_REAL_NAME}" + f" that converts the real object to the fake object." + ) + + # from_real defined by user need the ctx to fakify the tensor states. + ctx = torch._library.fake_impl.FakeImplCtx(fake_mode, None) + with torch._library.fake_impl.set_ctx_getter(lambda: ctx): + return fake_class.from_real(x) diff --git a/phivenv/Lib/site-packages/torch/_library/fake_impl.py b/phivenv/Lib/site-packages/torch/_library/fake_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..3578d5dd41850c3f065990ba6c60beb54c832343 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/fake_impl.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +from typing import Callable +from typing_extensions import deprecated + +import torch +from torch._library.utils import Kernel, RegistrationHandle + + +class FakeImplHolder: + """A holder where one can register an fake impl to.""" + + def __init__(self, qualname: str): + self.qualname: str = qualname + # kernels stores all registered fake kernels, ordered by registration + # time ascendingly (newer registration after older registration). If an + # operator library gets loaded that overrides an existing fake kernel, + # both kernels will be in the list, but the newest one will be the one + # that is run. If the library is unloaded, we will remove the kernel + # from this list. 
+ self.kernels: list[Kernel] = [] + + @property + def kernel(self): + if len(self.kernels) == 0: + return None + return self.kernels[-1] + + @kernel.setter + def kernel(self, value): + raise RuntimeError("Unable to directly set kernel.") + + def register( + self, func: Callable, source: str, lib, *, allow_override=False + ) -> RegistrationHandle: + """Register an fake impl. + + Returns a RegistrationHandle that one can use to de-register this + fake impl. + """ + + if not allow_override: + if self.kernel is not None: + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an fake impl registered at " + f"{self.kernel.source}." + ) + if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an DispatchKey::Meta implementation via a " + f"pre-existing torch.library or TORCH_LIBRARY registration. " + f"Please either remove that registration or don't call " + f"register_fake." + ) + + if torch._C._dispatch_has_kernel_for_dispatch_key( + self.qualname, "CompositeImplicitAutograd" + ): + raise RuntimeError( + f"register_fake(...): the operator {self.qualname} " + f"already has an implementation for this device type via a " + f"pre-existing registration to " + f"DispatchKey::CompositeImplicitAutograd." + f"CompositeImplicitAutograd operators do not need an fake " + f"impl; " + f"instead, the operator will decompose into its constituents " + f"and those " + f"can have fake impls defined on them." + ) + + # Store the kernel in this holder + kernel = Kernel(func, source) + self.kernels.append(kernel) + + def deregister_fake_kernel(): + self.kernels.remove(kernel) + + meta_kernel = construct_meta_kernel(self.qualname, self) + lib.impl(self.qualname, meta_kernel, "Meta", allow_override=allow_override) + + handle = RegistrationHandle(deregister_fake_kernel) + return handle + + +def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable: + assert fake_impl_holder.kernel is not None + + @functools.wraps(fake_impl_holder.kernel.func) + def meta_kernel(*args, **kwargs): + assert fake_impl_holder.kernel is not None + source = fake_impl_holder.kernel.source + + def error_on_ctx(): + raise RuntimeError( + f"{qualname} ({source}): You're trying to run this operator " + f"with meta Tensors (as opposed to FakeTensors), but this " + f"operator may return an output Tensor with data-dependent shape. Meta " + f"Tensors don't support operators with outputs that have data-dependent shapes " + f"but FakeTensors do. " + f"If your operator does not return an output with data-dependent shape, " + f"make sure the FakeTensor and/or meta kernel does not call " + f"torch.library.get_ctx(). Otherwise, please use FakeTensors." + ) + + with set_ctx_getter(error_on_ctx): + return fake_impl_holder.kernel(*args, **kwargs) + + return meta_kernel + + +def get_none(): + return None + + +global_ctx_getter: Callable = get_none + + +@contextlib.contextmanager +def set_ctx_getter(ctx_getter): + global global_ctx_getter + prev = global_ctx_getter + try: + global_ctx_getter = ctx_getter + yield + finally: + global_ctx_getter = prev + + +class FakeImplCtx: + """ + Context object for writing fake implementations for custom operators. 
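+
+    A ctx is not constructed directly by users. Inside a fake impl, the current
+    ctx is obtained with ``torch.library.get_ctx()`` and is mainly used to
+    allocate data-dependent sizes (a brief sketch; see ``new_dynamic_size`` below):
+
+    >>> ctx = torch.library.get_ctx()
+    >>> nnz = ctx.new_dynamic_size()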
+ """ + + def __init__(self, _fake_mode, _op): + self._fake_mode = _fake_mode + self._shape_env = _fake_mode.shape_env + self._op = _op + + @deprecated( + "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead", + category=FutureWarning, + ) + def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt: + return self.new_dynamic_size(min=min, max=max) + + def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt: + """Constructs a new symint (symbolic int) representing a data-dependent value. + + This is useful for writing the fake implementation (which is necessary + for torch.compile) for a CustomOp where an output Tensor has a size + that depends on the data of the input Tensors. + + Args: + min (int): A statically known inclusive lower bound for this symint. Default: 0 + max (Optional[int]): A statically known inclusive upper bound for this + symint. Default: None + + .. warning: + + It is important that the ``min`` and ``max`` (if not None) values are set + correctly, otherwise, there will be undefined behavior under + torch.compile. The default value of ``min`` is 2 due to torch.compile + specializing on 0/1 sizes. + + You must also verify that your implementation on concrete Tensors + (e.g. CPU/CUDA) only returns Tensors where the size that corresponds + to the symint also has respects these constraint. + The easiest way to do this is to add an assertion in the CPU/CUDA/etc + implementation that the size follows these bounds. + + Example:: + + >>> # An operator with data-dependent output shape + >>> lib = torch.library.Library("mymodule", "FRAGMENT") + >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor") + >>> + >>> @torch.library.register_fake("mymodule::custom_nonzero") + >>> def _(x): + >>> # Number of nonzero-elements is data-dependent. + >>> # Since we cannot peek at the data in an fake impl, + >>> # we use the ctx object to construct a new symint that + >>> # represents the data-dependent size. + >>> ctx = torch.library.get_ctx() + >>> nnz = ctx.new_dynamic_size() + >>> shape = [nnz, x.dim()] + >>> result = x.new_empty(shape, dtype=torch.int64) + >>> return result + >>> + >>> @torch.library.impl(lib, "custom_nonzero", "CPU") + >>> def _(x): + >>> x_np = x.numpy() + >>> res = np.stack(np.nonzero(x_np), axis=1) + >>> return torch.tensor(res, device=x.device) + + """ + if ( + self._shape_env is None + or not self._shape_env.allow_dynamic_output_shape_ops + ): + raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op) + + if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt): + raise ValueError( + f"ctx.new_dynamic_size(min={min}, max={max}): expected " + f"min and max to be statically known ints but got SymInt. " + f"This is not supported." + ) + + if min < 0: + raise ValueError( + f"ctx.new_dynamic_size(min={min}, ...): expected min to be " + f"greater than or equal to 0: this API can only create " + f"non-negative sizes." 
+ ) + + return allocate_size(self._shape_env, min, max) + + +def allocate_size(shape_env, min_val=0, max_val=None): + result = shape_env.create_unbacked_symint() + torch.fx.experimental.symbolic_shapes._constrain_range_for_size( + result, min=min_val, max=max_val + ) + return result diff --git a/phivenv/Lib/site-packages/torch/_library/fake_profile.py b/phivenv/Lib/site-packages/torch/_library/fake_profile.py new file mode 100644 index 0000000000000000000000000000000000000000..5fee99e6d36cd4c5ca17646fd32b08118a6b6a43 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/fake_profile.py @@ -0,0 +1,323 @@ +import contextlib +import io +import logging +import os +from collections.abc import Generator +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +from torch._library.custom_ops import _maybe_get_opdef +from torch.types import FileLike + + +log = logging.getLogger(__name__) + + +class MissingOpProfile(RuntimeError): + """ + This is raised when we don't have an operator profile available for the + given inputs. + """ + + +@dataclass(frozen=True) +class TensorMetadata: + rank: int + dtype: torch.dtype + device: torch.device + layout: torch.layout + + @staticmethod + def maybe_from_tensor(t: Any) -> Optional["TensorMetadata"]: + if not isinstance(t, torch.Tensor): + return None + return TensorMetadata(t.dim(), t.dtype, t.device, t.layout) + + +@dataclass(frozen=True) +class OpProfile: + args_profile: tuple[Optional[TensorMetadata]] + out_profile: Union[TensorMetadata, tuple[TensorMetadata]] + + +def _generate_fake_kernel(op_name: str, op_profile: set[OpProfile]) -> Callable: + def _match_args(args_profile: tuple[Optional[TensorMetadata]], args: Any) -> bool: + return all( + TensorMetadata.maybe_from_tensor(arg) == args_profile[i] + for i, arg in enumerate(args) + ) + + def _generate_res( + out_profile: Union[TensorMetadata, tuple[TensorMetadata]], + ) -> Union[torch.Tensor, list[torch.Tensor]]: + ctx = torch.library.get_ctx() + + def _generate_tensor_out(t: TensorMetadata) -> torch.Tensor: + fake_shape = [ctx.new_dynamic_size() for _ in range(t.rank)] + fake_strides = [-1] * t.rank + expected = 1 + fake_stride = expected + for i in range(t.rank): + fake_strides[i] = fake_stride # type: ignore[assignment] + fake_stride = fake_stride * fake_shape[i] # type: ignore[assignment] + + return torch.empty_strided( + fake_shape, + fake_strides, + device=t.device, + dtype=t.dtype, + layout=t.layout, + ) + + if isinstance(out_profile, TensorMetadata): + return _generate_tensor_out(out_profile) + else: + return [_generate_tensor_out(t) for t in out_profile] + + def _fake_kernel(*args, **kwargs): # type: ignore[no-untyped-def] + for profile in op_profile: + if _match_args(profile.args_profile, (*args, *kwargs.values())): + return _generate_res(profile.out_profile) + + raise MissingOpProfile( + f"No fake kernel was found for {op_name}, and although we have " + "previously registered some profiles to generate a fake kernel, " + f"no profiles match the given inputs: {args, kwargs}." + ) + + return _fake_kernel + + +@contextlib.contextmanager +def unsafe_generate_fake_kernels(op_profiles: dict[str, set[OpProfile]]) -> Generator: + """ + Registers a fake kernel based on the given operator profiles. This fake + kernel registration will override any existing fake kernel registrations. + + The input is a dictionary mapping operator names to a set of operator + profiles, which we will use to generate fake kernels. 
The operator profiles + are a record of the input and output tensor metadata. Based on this + information we will match a given input to the recorded profile, and return + an output with the same metadata as in the recorded profile. If a profile + doesn't exist then an exception will be thrown. + + The fake kernel generation is considerd unsafe because it relies on the + rigid, pre-defined operator profiles that do not account for potential + variations in output behavior. Specifically, the generated kernels assume a + fixed relationship between input and output ranks. However, in reality, it's + possible that data-dependent operations may produce outputs of different + ranks even when given inputs of the same rank. The generated fake kernels + are inflexible and unable to accommodate these nuances, making them + potentially unsafe. + + Args: + op_profiles (dict[str, set[OpProfile]]): A dictionary mapping operator + name to a set of operator profiles from which we will generate fake + kernels. + + Examples: + + >>> # Example: Registering an op-profile from draft-export + >>> import torch + >>> from torch.export._draft_export import draft_export + >>> + >>> @torch.library.custom_op("mylib::foo", mutates_args=()) + >>> def foo(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> class M(torch.nn.Module): + >>> def forward(self, a, b): + >>> res = torch.ops.mylib.foo(a, b) # no fake impl + >>> return res + >>> + >>> ep = draft_export(M(), (torch.ones(3, 4), torch.ones(3, 4)) + >>> + >>> with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + >>> decomp = ep.run_decompositions() + + """ + + libs: list[torch.library.Library] = [] + # Stores old fake impls from custom ops declared through @custom_op + old_fake_impls: dict[str, Callable] = {} + for op_name, profiles in op_profiles.items(): + log.warning( + "Registering fake profile for %s. This will override any existing " + "fake kernel registration.", + op_name, + ) + + op_name_split = op_name.split(".") + namespace, op_name_str = op_name_split[0], op_name_split[1] + op_str = f"{namespace}::{op_name_str}" + + fake_kernel = _generate_fake_kernel(op_str, profiles) + + if opdef := _maybe_get_opdef(op_str): + # If the op is a CustomOpDef, save the existing abstract_fn so that + # we can restore it after this contextmanager + if opdef._abstract_fn is not None: + old_fake_impls[op_str] = opdef._abstract_fn + opdef.register_fake(fake_kernel) + + else: + # Create a new library so that we can register a new fake impl. + # These libraries will then be destroyed after the contextmanager, + # which will automatically restore the previously registered fake + # impls. + newlib = torch.library.Library(namespace, "FRAGMENT") # noqa: TOR901 + torch.library.register_fake( + op_str, fake_kernel, lib=newlib, allow_override=True + ) + libs.append(newlib) + + try: + yield libs + finally: + # Destroying the libraries will automatically restore the previously + # registered fake impls + for lib in libs: + lib._destroy() + + # Restore abstract_fns for CustomOpDefs + for op_str, old_fake in old_fake_impls.items(): + opdef = _maybe_get_opdef(op_str) + assert opdef is not None + opdef.register_fake(old_fake) + + +def get_torch_version() -> str: + version = torch.__version__.split(".") + return f"{int(version[0])}.{int(version[1])}" + + +def generate_yaml_from_profiles(op_profiles: dict[str, set[OpProfile]]) -> str: + """ + Generates a yaml string from the given operator profiles which can be saved + to a file. 
The yaml string can be loaded back into an operator profile + structure using `read_profiles_from_yaml`. + """ + import yaml + + from torch._export.serde.serialize import ( + _TORCH_TO_SERIALIZE_DTYPE, + _TORCH_TO_SERIALIZE_LAYOUT, + ) + + def serialize_tensor_metadata(t: TensorMetadata) -> dict: + return { + "rank": t.rank, + "dtype": _TORCH_TO_SERIALIZE_DTYPE[t.dtype].value, + "device": str(t.device), + "layout": _TORCH_TO_SERIALIZE_LAYOUT[t.layout].value, + } + + def serialize_op_profile(op: OpProfile) -> dict: + return { + "args_profile": [ + serialize_tensor_metadata(arg) + for arg in op.args_profile + if arg is not None + ], + "out_profile": ( + serialize_tensor_metadata(op.out_profile) + if isinstance(op.out_profile, TensorMetadata) + else [serialize_tensor_metadata(out) for out in op.out_profile] + ), + } + + serialized_data = { + operator: [serialize_op_profile(profile) for profile in profiles] + for operator, profiles in op_profiles.items() + } + return yaml.dump( + {"torch_version": get_torch_version(), "operators": serialized_data}, + sort_keys=False, + ) + + +def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None: + """ + Serializes the given operator profiles into a yaml format and saves it to + the given file. The operator profile can be loaded back using `load_op_profiles`. + """ + yaml_str = generate_yaml_from_profiles(op_profiles) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with open(f, "w") as file: + file.write(yaml_str) + + elif isinstance(f, io.BytesIO): + f.write(yaml_str.encode("utf-8")) + + else: + raise ValueError(f"Invalid type of file {f}") + + +def read_profiles_from_yaml(yaml_str: str) -> dict[str, set[OpProfile]]: + """ + Reads the yaml saved by `save_op_profiles` and returns the operator profiles. + """ + import yaml + + from torch._export.serde.serialize import ( + _SERIALIZE_TO_TORCH_DTYPE, + _SERIALIZE_TO_TORCH_LAYOUT, + ) + + def deserialize_tensor_metadata(data: dict) -> TensorMetadata: + return TensorMetadata( + rank=data["rank"], + dtype=_SERIALIZE_TO_TORCH_DTYPE[data["dtype"]], + device=torch.device(data["device"]), + layout=_SERIALIZE_TO_TORCH_LAYOUT[data["layout"]], + ) + + def deserialize_op_profile(data: dict) -> OpProfile: + args_profile = tuple( + deserialize_tensor_metadata(arg) for arg in data["args_profile"] + ) + out_profile_data = data["out_profile"] + out_profile: Union[tuple[TensorMetadata], TensorMetadata] = ( + tuple(deserialize_tensor_metadata(out) for out in out_profile_data) # type: ignore[assignment] + if isinstance(out_profile_data, list) + else deserialize_tensor_metadata(out_profile_data) + ) + return OpProfile(args_profile=args_profile, out_profile=out_profile) # type: ignore[arg-type] + + loaded_data = yaml.safe_load(yaml_str) + loaded_torch_version = loaded_data["torch_version"] + + if loaded_torch_version != get_torch_version(): + raise RuntimeError( + "Unable to load outdated profile. It was saved with torch version: " + f"{loaded_torch_version} but the current torch version is: {get_torch_version()}" + ) + + operators_data = loaded_data["operators"] + return { + operator: {deserialize_op_profile(profile) for profile in profiles} + for operator, profiles in operators_data.items() + } + + +def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]: + """ + Loads the saved operator profiles from `save_op_profiles`. 
+ """ + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with open(f) as file: + yaml_str = file.read() + + elif isinstance(f, io.BytesIO): + yaml_str = f.read().decode("utf-8") + + else: + raise ValueError(f"Invalid type of file {f}") + + return read_profiles_from_yaml(yaml_str) diff --git a/phivenv/Lib/site-packages/torch/_library/infer_schema.py b/phivenv/Lib/site-packages/torch/_library/infer_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..c3aa6125b4148c75743a596404eb5a877c8c5dab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/infer_schema.py @@ -0,0 +1,324 @@ +# mypy: allow-untyped-defs +import collections +import inspect +import typing +from types import GenericAlias +from typing import Optional, Union + +import torch +from torch import device, dtype, Tensor, types +from torch.utils._exposed_in import exposed_in + + +# This is used as a negative test for +# test_custom_ops.py::TestTypeConversion::test_type_eval. +_TestTensor = torch.Tensor + + +@exposed_in("torch.library") +def infer_schema( + prototype_function: typing.Callable, + /, + *, + mutates_args, + op_name: Optional[str] = None, +) -> str: + r"""Parses the schema of a given function with type hints. The schema is inferred from the + function's type hints, and can be used to define a new operator. + + We make the following assumptions: + + * None of the outputs alias any of the inputs or each other. + * | String type annotations "device, dtype, Tensor, types" without library specification are + | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" + | without library specification are assumed to be typing.*. + * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", + | it assumes that all inputs to the operator are being mutates. + + Callers (e.g. the custom ops API) are responsible for checking these assumptions. + + Args: + prototype_function: The function from which to infer a schema for from its type annotations. + op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the + name is not included in the inferred schema. Note that the input schema to + ``torch.library.Library.define`` requires a operator name. + mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. + + Returns: + The inferred schema. + + Example: + >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: + >>> return x.sin() + >>> + >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) + foo(Tensor x) -> Tensor + >>> + >>> infer_schema(foo_impl, mutates_args={}) + (Tensor x) -> Tensor + """ + UNKNOWN_MUTATES = "unknown" + pf_globals = prototype_function.__globals__ + pf_locals = None + # TODO: Once our minimum version is py3.10+ pass `eval_str=True` to + # inspect.signature() and we no longer need to deal with stringified + # annotations below. + sig = inspect.signature(prototype_function) + + def error_fn(what): + raise ValueError(f"infer_schema(func): {what} Got func with signature {sig})") + + def convert_type_string(annotation_type: str): + try: + return eval(annotation_type, pf_globals, pf_locals) + except Exception: + error_fn( + f"Unsupported type annotation {annotation_type}. It is not a type." 
+ ) + + def unstringify_types( + tys: tuple[Union[type[object], str], ...], + ) -> tuple[tuple[typing.Any, ...], bool]: + res = [] + changed = False + for ty in tys: + ty, ty_changed = unstringify_type(ty) + res.append(ty) + changed |= ty_changed + if changed: + return tuple(res), True + else: + return tys, False # type: ignore[return-value] + + def unstringify_type(ty: Union[type[object], str]) -> tuple[typing.Any, bool]: + # Dig through a generic type and if it contains a stringified type + # convert that to a real type. The second return value indicates if the + # type contained a string or not. + if isinstance(ty, str): + return convert_type_string(ty), True + elif origin := typing.get_origin(ty): + args, args_changed = unstringify_types(typing.get_args(ty)) + if args_changed: + return GenericAlias(origin, args), True + + return ty, False + + params = [] + seen_args = set() + saw_kwarg_only_arg = False + for idx, (name, param) in enumerate(sig.parameters.items()): + if not supported_param(param): + error_fn("We do not support positional-only args, varargs, or varkwargs.") + + if param.kind == inspect.Parameter.KEYWORD_ONLY: + # The first time we see a kwarg-only arg, add "*" to the schema. + if not saw_kwarg_only_arg: + params.append("*") + saw_kwarg_only_arg = True + + if param.annotation is inspect.Parameter.empty: + error_fn(f"Parameter {name} must have a type annotation.") + + # The annotation might be converted to a string by annotation, + # we convert it to the actual type. + annotation_type, _ = unstringify_type(param.annotation) + + if annotation_type not in SUPPORTED_PARAM_TYPES: + if annotation_type.__origin__ is tuple: + list_type = tuple_to_list(annotation_type) + example_type_str = "\n\n" + # Only suggest the list type if this type is supported. + if list_type in SUPPORTED_PARAM_TYPES.keys(): + example_type_str = f"For example, {list_type}.\n\n" + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " + f"{example_type_str}" + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + else: + error_fn( + f"Parameter {name} has unsupported type {param.annotation}. " + f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." + ) + + schema_type = SUPPORTED_PARAM_TYPES[annotation_type] + if type(mutates_args) == str: + if mutates_args != UNKNOWN_MUTATES: + raise ValueError( + "mutates_args must either be a sequence of the names of " + "the arguments that are mutated or the string 'unknown'. " + ) + if schema_type.startswith("Tensor"): + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + elif name in mutates_args: + if not schema_type.startswith("Tensor"): + error_fn( + f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" + ) + schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" + seen_args.add(name) + if param.default is inspect.Parameter.empty: + params.append(f"{schema_type} {name}") + else: + default_repr = None + if param.default is None or isinstance(param.default, (int, float, bool)): + default_repr = str(param.default) + elif isinstance(param.default, (str, torch.device)): + default_repr = f'"{param.default}"' + elif isinstance(param.default, torch.dtype): + dtype_repr = str(param.default) + torch_dot = "torch." 
+ assert dtype_repr.startswith(torch_dot) + default_repr = dtype_repr[len(torch_dot) :] + else: + error_fn( + f"Parameter {name} has an unsupported default value type {type(param.default)}. " + f"Please file an issue on GitHub so we can prioritize this." + ) + params.append(f"{schema_type} {name}={default_repr}") + if mutates_args != UNKNOWN_MUTATES: + mutates_args_not_seen = set(mutates_args) - seen_args + if len(mutates_args_not_seen) > 0: + error_fn( + f"{mutates_args_not_seen} in mutates_args were not found in " + f"the custom op's signature. " + f"mutates_args should contain the names of all args that the " + f"custom op mutates, or just the string 'unknown' if you don't know." + ) + return_annotation, _ = unstringify_type(sig.return_annotation) + ret = parse_return(return_annotation, error_fn) + if op_name is not None: + return f"{op_name}({', '.join(params)}) -> {ret}" + return f"({', '.join(params)}) -> {ret}" + + +def derived_types( + base_type: Union[type, typing._SpecialForm], + cpp_type: str, + list_base: bool, + optional_base_list: bool, + optional_list_base: bool, +): + result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [ + (base_type, cpp_type), + (typing.Optional[base_type], f"{cpp_type}?"), + ] + + def derived_seq_types(typ: Union[type, typing._SpecialForm]): + return ( + typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006 + typing.List[typ], # type: ignore[valid-type] # noqa: UP006 + GenericAlias(collections.abc.Sequence, (typ,)), + GenericAlias(list, (typ,)), + ) + + if list_base: + result.extend( + (seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type) + ) + if optional_base_list: + result.extend( + (seq_typ, f"{cpp_type}?[]") + for seq_typ in derived_seq_types(typing.Optional[base_type]) + ) + if optional_list_base: + result.extend( + (typing.Optional[seq_typ], f"{cpp_type}[]?") + for seq_typ in derived_seq_types(base_type) + ) + return result + + +def get_supported_param_types(): + data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [ + # (python type, schema type, type[] variant, type?[] variant, type[]? variant + (Tensor, "Tensor", True, True, False), + (int, "SymInt", True, False, True), + (float, "float", True, False, True), + (bool, "bool", True, False, True), + (str, "str", False, False, False), + (types.Number, "Scalar", True, False, False), + (dtype, "ScalarType", False, False, False), + (device, "Device", False, False, False), + ] + result = [] + for line in data: + result.extend(derived_types(*line)) + return dict(result) + + +SUPPORTED_RETURN_TYPES = { + Tensor: "Tensor", + typing.List[Tensor]: "Tensor[]", # noqa: UP006 + list[Tensor]: "Tensor[]", + int: "SymInt", + float: "float", + bool: "bool", + types.Number: "Scalar", +} + + +def parse_return(annotation, error_fn): + if annotation is None: + return "()" + + if annotation is inspect.Parameter.empty: + error_fn("No return type annotation was provided. Please add one.") + + origin = typing.get_origin(annotation) + if origin is not tuple: + if annotation not in SUPPORTED_RETURN_TYPES.keys(): + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." + ) + return SUPPORTED_RETURN_TYPES[annotation] + + args = typing.get_args(annotation) + for arg in args: + if arg not in SUPPORTED_RETURN_TYPES: + error_fn( + f"Return has unsupported type {annotation}. " + f"The valid types are: {SUPPORTED_RETURN_TYPES}." 
+ ) + output_ty = ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + + # use (()) to represent tuple with single element + if len(args) == 1: + output_ty = "(" + output_ty + ")" + return "(" + output_ty + ")" + + +SUPPORTED_PARAM_TYPES = get_supported_param_types() + + +def supported_param(param: inspect.Parameter) -> bool: + return param.kind in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + + +def tuple_to_list(tuple_type: type[tuple]) -> type[list]: + """ + Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type. + """ + type_args = getattr(tuple_type, "__args__", None) + # Account for different python versions, e.g. python 3.8 would give () + # but python 3.12 would give None. + if ( + tuple_type is typing.Tuple # noqa: UP006 + or tuple_type is tuple + or type_args == () + or type_args is None + ): + # Handle the case of an empty tuple type + return list + elif len(type_args) == 1: + # General case: create a List with the same type arguments + return list[type_args[0]] # type: ignore[valid-type] + elif len(type_args) == 2 and type_args[1] is Ellipsis: + return list[type_args[0]] # type: ignore[valid-type] + else: + return list[typing.Union[tuple(type_args)]] # type: ignore[misc, return-value] diff --git a/phivenv/Lib/site-packages/torch/_library/simple_registry.py b/phivenv/Lib/site-packages/torch/_library/simple_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..a5283666d9e6e5aa257a0d04d2d9e2c5e46d6d0c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/simple_registry.py @@ -0,0 +1,85 @@ +# mypy: allow-untyped-defs +from typing import Callable, Optional + +from .fake_impl import FakeImplHolder +from .utils import RegistrationHandle + + +__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"] + + +class SimpleLibraryRegistry: + """Registry for the "simple" torch.library APIs + + The "simple" torch.library APIs are a higher-level API on top of the + raw PyTorch DispatchKey registration APIs that includes: + - fake impl + + Registrations for these APIs do not go into the PyTorch dispatcher's + table because they may not directly involve a DispatchKey. For example, + the fake impl is a Python function that gets invoked by FakeTensor. + Instead, we manage them here. + + SimpleLibraryRegistry is a mapping from a fully qualified operator name + (including the overload) to SimpleOperatorEntry. + """ + + def __init__(self): + self._data = {} + + def find(self, qualname: str) -> "SimpleOperatorEntry": + if qualname not in self._data: + self._data[qualname] = SimpleOperatorEntry(qualname) + return self._data[qualname] + + +singleton: SimpleLibraryRegistry = SimpleLibraryRegistry() + + +class SimpleOperatorEntry: + """This is 1:1 to an operator overload. + + The fields of SimpleOperatorEntry are Holders where kernels can be + registered to. + """ + + def __init__(self, qualname: str): + self.qualname: str = qualname + self.fake_impl: FakeImplHolder = FakeImplHolder(qualname) + self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = ( + GenericTorchDispatchRuleHolder(qualname) + ) + + # For compatibility reasons. We can delete this soon. 
+ @property + def abstract_impl(self): + return self.fake_impl + + +class GenericTorchDispatchRuleHolder: + def __init__(self, qualname): + self._data = {} + self.qualname = qualname + + def register( + self, torch_dispatch_class: type, func: Callable + ) -> RegistrationHandle: + if self.find(torch_dispatch_class): + raise RuntimeError( + f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}" + ) + self._data[torch_dispatch_class] = func + + def deregister(): + del self._data[torch_dispatch_class] + + return RegistrationHandle(deregister) + + def find(self, torch_dispatch_class): + return self._data.get(torch_dispatch_class, None) + + +def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]: + return singleton.find(op.__qualname__).torch_dispatch_rules.find( + torch_dispatch_class + ) diff --git a/phivenv/Lib/site-packages/torch/_library/triton.py b/phivenv/Lib/site-packages/torch/_library/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca2fd66cfaa153c9c570d6173b3db7b71df63a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/triton.py @@ -0,0 +1,274 @@ +import contextlib +import threading +from collections.abc import Generator, Iterable +from typing import Any, Callable, Optional, Union + +from torch.utils._exposed_in import exposed_in + +from .custom_ops import custom_op, CustomOpDef +from .infer_schema import infer_schema + + +@exposed_in("torch.library") +def triton_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + schema: Optional[str] = None, +) -> Callable: + """Create a custom operator whose implementation is backed by 1+ triton kernels. + + This is a more structured way of using triton kernels with PyTorch. + Prefer using triton kernels with no ``torch.library`` custom operator wrappers + (like :func:`torch.library.custom_op`, :func:`torch.library.triton_op`) because + that is simpler; + only use :func:`torch.library.custom_op`/:func:`torch.library.triton_op` if you + want to create an operator that behaves like PyTorch built-in operators. + For example, you may use a ``torch.library`` wrapper API to define the + behavior of the triton kernel when passed a tensor subclass or under + a TorchDispatchMode. + + Use :func:`torch.library.triton_op` instead of :func:`torch.library.custom_op` + when the implementation + consists of 1+ triton kernels. :func:`torch.library.custom_op` treats + custom operators as opaque (:func:`torch.compile` and + :func:`torch.export.export` will never trace into them), but ``triton_op`` + makes the implementation visible to these subsystems, allowing them + to optimize the triton kernel(s). + + Note that ``fn`` must only consist of calls to PyTorch-understood + operators and triton kernels. Any triton kernels called inside ``fn`` + must be wrapped in a call to :func:`torch.library.wrap_triton`. + + Args: + name (str): A name for the custom op that looks like "{namespace}::{name}", + e.g. "mylib::my_linear". The name is used as the op's stable identifier + in PyTorch subsystems (e.g. torch.export, FX graphs). + To avoid name collisions, please use your project name as the namespace; + e.g. all custom ops in pytorch/fbgemm use "fbgemm" as the namespace. + mutates_args (Iterable[str] or "unknown"): The names of args that the function mutates. + This MUST be accurate, otherwise, the behavior is undefined. If "unknown", + it pessimistically assumes that all inputs to the operator are being mutated. 
+ schema (None | str): A schema string for the operator. If None + (recommended) we'll infer a schema for the operator from its type + annotations. We recommend letting us infer a schema unless you + have a specific reason not to. + Example: "(Tensor x, int y) -> (Tensor, Tensor)". + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> import torch + >>> from torch.library import triton_op, wrap_triton + >>> + >>> import triton + >>> from triton import language as tl + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> @triton_op("mylib::add", mutates_args={}) + >>> def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> # NB: we need to wrap the triton kernel in a call to wrap_triton + >>> wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) + >>> return output + >>> + >>> @torch.compile + >>> def f(x, y): + >>> return add(x, y) + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> + >>> z = f(x, y) + >>> assert torch.allclose(z, x + y) + + """ + + def dec(fn: Callable[..., object]) -> CustomOpDef: + def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def] + # Optimization: we're passing regular Tensors into the triton kernel, so + # no need to go through HOP dispatch + with set_wrap_triton_enabled(False): + return fn(*args, **kwargs) + + result = custom_op( + name, + backend_fn, + mutates_args=mutates_args, + schema=infer_schema(fn, mutates_args=mutates_args), + ) + from .._subclasses.functional_tensor import FunctionalTensorMode + + # We require that the user pass us a function that is make_fx traceable, + # so we can just register it as the Fake/meta kernel. + result.register_fake(fn) + + # We decompose the operator when FunctionalTensorMode is active. + # The goal is to decompose the operator in AOTDispatcher. + # - With torch.compile, this means that the backend (usually Inductor) + # can see a call to the triton kernel(s) and so it can directly optimize + # them by inlining them into the lowering process. + def functional_decomp( # type: ignore[no-untyped-def] + mode, op, types, args, kwargs + ): + # NOTE [Export custom triton op] + # For torch.export (strict and non-strict), we don't do functional decomposition. + # Instead, we preserve the custom triton ops as custom ops. This is because we want + # the exported program to be high-level and serializable. If we decompose + # the custom op to a functional hop and make it a node in exported program, + # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited + # functions and triton dtypes. This is undesireble because: + # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes. + # - exported program will contain the implementation detail (e.g. triton source code) for a specific + # backend (GPU), which is probably at a wrong level of abstraction. 
+ # - changes to triton or the serialization logic for triton arguments can be BC breaking + # + # In the short term, we expect users to have a separate aot_compile stage that compiles the exported program + # into a Cubin file on the same machine that users call export, which does autotuning and removes triton + # dependency and serve the model with Cubin. This guarantees that triton changes won't break BC. + # In the long term, we may export multiple cubins for the triton op directly + from torch.export._trace import custom_triton_ops_decomposition_disabled + + if custom_triton_ops_decomposition_disabled(): + return mode.__torch_dispatch__(op, types, args, kwargs) + else: + with mode: + return fn(*args, **kwargs) + + result.register_torch_dispatch(FunctionalTensorMode, functional_decomp) + return result + + if fn is None: + return dec + else: + return dec(fn) + + +wrap_triton_enabled = threading.local() +wrap_triton_enabled_default = True + + +@contextlib.contextmanager +def set_wrap_triton_enabled(enabled: bool) -> Generator[None, None, None]: + """If triton kernels annotated with @wrap_triton should dispatch via HOP + or go straight to the triton kernel execution. + + We have this switch because eager-mode performance of HOP dispatch is slow + enough to matter (~1ms) and we know that wrap_triton isn't necessary in + some situations (eager-mode with regular Tensors) + """ + try: + prev = is_wrap_triton_enabled() + wrap_triton_enabled.value = enabled + yield + finally: + wrap_triton_enabled.value = prev + + +def is_wrap_triton_enabled() -> bool: + return getattr(wrap_triton_enabled, "value", wrap_triton_enabled_default) + + +def capture_triton(triton_kernel: Callable, /) -> Any: + """This API has been renamed to wrap_triton""" + return wrap_triton(triton_kernel) + + +@exposed_in("torch.library") +def wrap_triton(triton_kernel: Callable, /) -> Any: + """Allows capture of a triton kernel into a graph via make_fx or + non-strict ``torch.export``. + + These technologies perform Dispatcher-based tracing (via + ``__torch_dispatch__``) and cannot see calls to raw triton kernels. + The ``wrap_triton`` API wraps a triton kernel into a callable that + can actually be traced into a graph. + + Please use this API together with :func:`torch.library.triton_op`. 
+ + Examples: + + >>> # xdoctest: +SKIP + >>> import torch + >>> import triton + >>> from triton import language as tl + >>> from torch.fx.experimental.proxy_tensor import make_fx + >>> from torch.library import wrap_triton + >>> + >>> @triton.jit + >>> def add_kernel( + >>> in_ptr0, + >>> in_ptr1, + >>> out_ptr, + >>> n_elements, + >>> BLOCK_SIZE: "tl.constexpr", + >>> ): + >>> pid = tl.program_id(axis=0) + >>> block_start = pid * BLOCK_SIZE + >>> offsets = block_start + tl.arange(0, BLOCK_SIZE) + >>> mask = offsets < n_elements + >>> x = tl.load(in_ptr0 + offsets, mask=mask) + >>> y = tl.load(in_ptr1 + offsets, mask=mask) + >>> output = x + y + >>> tl.store(out_ptr + offsets, output, mask=mask) + >>> + >>> def add(x, y): + >>> output = torch.empty_like(x) + >>> n_elements = output.numel() + >>> + >>> def grid_fn(meta): + >>> return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + >>> + >>> wrap_triton(add_kernel)[grid_fn](x, y, output, n_elements, 16) + >>> return output + >>> + >>> x = torch.randn(3, device="cuda") + >>> y = torch.randn(3, device="cuda") + >>> gm = make_fx(add)(x, y) + >>> print(gm.code) + >>> # def forward(self, x_1, y_1): + >>> # empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False) + >>> # triton_kernel_wrapper_mutation_proxy = triton_kernel_wrapper_mutation( + >>> # kernel_idx = 0, constant_args_idx = 0, + >>> # grid = [(1, 1, 1)], kwargs = { + >>> # 'in_ptr0': x_1, 'in_ptr1': y_1, 'out_ptr': empty_like, + >>> # 'n_elements': 3, 'BLOCK_SIZE': 16 + >>> # }) + >>> # return empty_like + + """ + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper + + if not isinstance(triton_kernel, (JITFunction, Autotuner)): + raise RuntimeError( + "wrap_triton only works on functions annotated with triton.jit or triton.autotune" + ) + if not is_wrap_triton_enabled(): + return triton_kernel + return TraceableTritonKernelWrapper(triton_kernel, None, None) diff --git a/phivenv/Lib/site-packages/torch/_library/utils.py b/phivenv/Lib/site-packages/torch/_library/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0743451141b2fe1e8480d744988a8d3c08a7d36b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_library/utils.py @@ -0,0 +1,525 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import sys +import warnings +from collections.abc import Iterable, Iterator +from typing import Any, Callable, Union + +import torch +import torch.utils._pytree as pytree +from torch import _C, _utils_internal +from torch._ops import OpOverload + + +def warn_deploy(stacklevel=3): + warnings.warn( + "Python torch.library APIs do nothing under torch::deploy (multipy). " + "Please instead use C++ custom operator registration APIs.", + RuntimeWarning, + stacklevel=stacklevel, + ) + + +@dataclasses.dataclass +class Kernel: + """Models a (function, source location)""" + + func: Callable + source: str + + def __call__(self, *args, **kwargs): + return self.func(*args, **kwargs) + + +class RegistrationHandle: + """Does something when someone calls .destroy() on it""" + + def __init__(self, on_destroy: Callable): + self._on_destroy = on_destroy + + def destroy(self) -> None: + self._on_destroy() + + +def get_source(stacklevel: int) -> str: + """Get a string that represents the caller. + + Example: "/path/to/foo.py:42" + + Use stacklevel=1 to get the caller's source + Use stacklevel=2 to get the caller's caller's source + etc. 
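+
+    A small illustration (the path and line number shown are made up):
+
+    >>> def caller():
+    >>>     return get_source(1)
+    >>> caller()  # doctest: +SKIP
+    '/path/to/foo.py:42'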
+ """ + frame = inspect.getframeinfo(sys._getframe(stacklevel)) + source = f"{frame.filename}:{frame.lineno}" + return source + + +def parse_namespace(qualname: str) -> tuple[str, str]: + splits = qualname.split("::") + if len(splits) != 2: + raise ValueError( + f"Expected `qualname` to be of the form " + f'"namespace::name", but got {qualname}. ' + f"The qualname passed to the torch.library APIs must consist " + f"of a namespace and a name, e.g. aten::sin" + ) + return splits[0], splits[1] + + +def lookup_op(qualname: str) -> OpOverload: + namespace, name = parse_namespace(qualname) + if "." in name: + name, overload = name.split(".") + else: + overload = "default" + ns = getattr(torch.ops, namespace) + packet = getattr(ns, name) + return getattr(packet, overload) + + +def is_builtin(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + return op.namespace in {"aten", "prim", "prims"} + + +def is_functional_schema(schema: Any) -> bool: + """Check if the schema is functional. + + An operator is functional if: + - it does not mutate any of its inputs + - it does not return a view on any of its inputs + - it has at least one return + """ + + def is_functional(schema): + if schema.is_mutable: + return False + rets = schema.returns + is_non_mutating_view = len(rets) > 0 and any( + r.alias_info is not None and not r.alias_info.is_write for r in rets + ) + if is_non_mutating_view: + return False + if not schema.returns: + return False + return True + + if isinstance(schema, torch._C.FunctionSchema): + return is_functional(schema) + + # Lazy import because not all PyTorch builds have torchgen + from torchgen.model import FunctionSchema + + if isinstance(schema, str): + schema = FunctionSchema.parse(schema) + assert isinstance(schema, FunctionSchema) + return is_functional(schema) + + +# should be torch._C.JitType but that annotation is busted +def is_tensorlist_like_type(typ: Any) -> bool: + return ( + typ == _C.ListType(_C.TensorType.get()) + or typ == _C.ListType(_C.OptionalType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.TensorType.get())) + or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get()))) + ) + + +# should be torch._C.JitType but that annotation is busted +def is_tensor_like_type(typ: Any) -> bool: + return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get()) + + +def mutates_and_returns_first_arg(op: OpOverload): + """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg. + + TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this, + but not all PyTorch builds have torchgen (due to the yaml dependency being weird). + Figure this out. + + Example: add_(Tensor(a!) 
x, Tensor y) -> Tensor(a) + """ + if op.namespace != "aten": + return False + schema = op._schema + if not len(schema.returns) == 1: + return False + if schema.returns[0].alias_info is None: + return False + alias_set = schema.returns[0].alias_info.after_set + if len(alias_set) != 1: + return False + loc = next(iter(alias_set)) + if len(schema.arguments) < 1: + return False + first_arg = schema.arguments[0] + if first_arg.alias_info is None: + return False + if not first_arg.alias_info.is_write: + return False + alias_set = first_arg.alias_info.after_set + if len(alias_set) != 1: + return False + if loc != next(iter(alias_set)): + return False + for arg in schema.arguments[1:]: + if arg.alias_info is not None: + return False + return True + + +def fill_defaults(schema, args, kwargs): + new_args = [] + new_kwargs = {} + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + new_kwargs[info.name] = kwargs[info.name] + else: + new_kwargs[info.name] = info.default_value + else: + if i < len(args): + new_args.append(args[i]) + else: + new_args.append(info.default_value) + return tuple(new_args), new_kwargs + + +def zip_schema( + schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Iterable[tuple[_C.Argument, Any]]: + """zips schema.arguments and (args, kwargs) together. + + Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload: + that is, (args, kwargs) must be bindable to the schema (args, kwargs). + """ + assert len(schema.arguments) >= len(args) + len(kwargs) + for i in range(len(schema.arguments)): + info = schema.arguments[i] + if info.kwarg_only: + if info.name in kwargs: + yield info, kwargs[info.name] + continue + if i >= len(args): + if not info.kwarg_only and info.name in kwargs: + yield info, kwargs[info.name] + # args that are equal to their default values are not populated + # if they are followed by args that are equal to their defaults. + # Skip these. + continue + yield info, args[i] + return + + +def hop_schema_from_fx_node(node): + from torchgen.gen_schema_utils import FunctionSchemaGen + + hop = node.target + if not isinstance(hop, torch._ops.HigherOrderOperator): + raise RuntimeError("fx_node's target must be a hop.") + + def _collect_example_val(node): + meta_val = node.meta.get("val", None) + if meta_val is None: + assert node.op == "get_attr" + meta_val = getattr(node.graph.owning_module, node.target) + return meta_val + + example_inputs = [] + for arg in node.args: + if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)): + example_inputs.append(_collect_example_val(arg)) + elif isinstance( + arg, (torch.fx.immutable_collections.immutable_list, list, tuple) + ): + example_inputs.append([_collect_example_val(x) for x in arg]) + else: + raise RuntimeError(f"Unsupported arg type {type(arg)}") + + # Bound the arguments to make sure number of inputs are correct + bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind( + *example_inputs + ) + + # We treat example_output as a single value in return. This is to differentiate 1. return a single val + # vs 2. return a tuple with one element. + example_output = _collect_example_val(node) + return FunctionSchemaGen.from_example( + hop._name, tuple(bound_args.arguments.items()), (list(example_output),) + ) + + +def can_generate_trivial_fake_impl(op: OpOverload) -> bool: + assert isinstance(op, OpOverload) + if is_builtin(op): + # We control the built-ins. 
These may (in rare cases) + # do input metadata mutation (which we have banned on custom ops) + return False + schema = op._schema + # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution + if not schema.is_mutable: + return False + if len(schema.returns) > 0: + return False + # If the op returns nothing, then it has a trivial fake impl. + return True + + +def requires_set_python_module() -> bool: + """If an op was defined in C++ and extended from Python using the + torch.library APIs, returns if we require that there have been a + m.set_python_module("mylib.ops") call from C++ that associates + the C++ op with a python module. + """ + return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True) + + +def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs): + assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode) + args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values())) + # TODO: need to double check the semantics of the "types" argument to torch_dispatch. + # It's generated in PyInterpreter.cpp, but seems to be generated in two places, + # where in one case we only include tensors with the python key, and in another + # we include **all** tensors. + overload_types = [ + type(a) + for a in args_flattened + if isinstance(a, torch.Tensor) + and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python) + ] + # TODO: check that I got these args correct (in C++, we pass in "0000"??) + + return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs) + + +def has_kwarg_only_args(schema: _C.FunctionSchema): + return any(a.kwarg_only for a in schema.arguments) + + +def has_kwarg_only_tensors(schema: _C.FunctionSchema): + for a in schema.arguments: + if not (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)): + continue + if not a.kwarg_only: + continue + return True + return False + + +def has_tensor_arg(schema: _C.FunctionSchema) -> bool: + """ + Given a schema, returns True if the schema has a Tensor arg. + A Tensor arg is any arg with a type annotation that might involve Tensor. + """ + return any( + (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type)) + for a in schema.arguments + ) + + +def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]: + """ + Given a schema, returns the id of the `device: torch.device` argument. + If it does not exist, returns None. + """ + for index, arg in enumerate(schema.arguments): + if arg.type is _C.DeviceObjType.get() and arg.name == "device": + return index + return None + + +def iter_tensors( + args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1 +) -> Iterator[torch.Tensor]: + def check(arg): + if isinstance(arg, torch.Tensor): + yield arg + elif allowed_nesting > 0 and isinstance(arg, (tuple, list)): + yield from iter_tensors(tuple(arg), {}, allowed_nesting - 1) + + for arg in args: + yield from check(arg) + for kwarg in kwargs.values(): + yield from check(kwarg) + + +def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"): + """ + custom operators' outputs must not alias any inputs or other outputs. 
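+
+    Illustrative sketch (hypothetical op name "mylib::bad_identity"): a custom op
+    whose implementation returns its input unchanged trips this check at runtime.
+
+        @torch.library.custom_op("mylib::bad_identity", mutates_args=())
+        def bad_identity(x: torch.Tensor) -> torch.Tensor:
+            return x            # aliases the input -> RuntimeError from this check
+            # return x.clone()  # returning a copy would satisfy the constraint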
+ """ + storages = {id(t.untyped_storage()) for t in prev if isinstance(t, torch.Tensor)} + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + for tensor in iter_tensors(tuple_result, {}): + key = id(tensor.untyped_storage()) + if id(tensor.untyped_storage()) in storages: + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + storages.add(key) + + +def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"): + """ + custom operators' outputs must not have any aliases + This version uses C++ implementation for perf. + Only List container is supported. + Tensors in Lists with not only Tensors are checked. + """ + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result): + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + + +class MutationChecker: + """ + Check if an operator mutated its arguments. + Usage: + + checker = MutationChecker(op, flat_args, args_spec) + op(*args, **kwargs) + checker.check() + """ + + def __init__(self, op, flat_args, args_spec): + self.op = op + self.args_spec = args_spec + self.flat_args = flat_args + self.real_pre_hashes = [ + hash_tensor(a) if isinstance(a, torch.Tensor) else None for a in flat_args + ] + + def check(self): + real_post_hashes = [ + hash_tensor(a) if isinstance(a, torch.Tensor) else None + for a in self.flat_args + ] + was_mutated = [ + not torch.equal(pre, post) + and not (pre.isnan().all() and post.isnan().all()) + if isinstance(pre, torch.Tensor) and isinstance(post, torch.Tensor) + else None + for pre, post in zip(self.real_pre_hashes, real_post_hashes) + ] + was_mutated_args, was_mutated_kwargs = pytree.tree_unflatten( + was_mutated, self.args_spec + ) + for info, was_mutated in zip_schema( + self.op._schema, was_mutated_args, was_mutated_kwargs + ): + + def check_one(info, was_mutated): + if info.is_write == was_mutated: + return + raise RuntimeError( + f"{self.op._name}: for argument '{info.name}': the operator's schema " + f"{self.op._schema} specified that " + f"the operator {'mutates' if info.is_write else 'does not mutate'} " + f"the argument, but this seems to be emperically wrong. " + f"Please make the schema and operator behavior consistent. " + f"You can specify that an operator mutates a Tensor by " + f"e.g. changing its schema type from 'Tensor name' to 'Tensor(a!) name'" + f"(use different identifiers (a, b, c, ...) 
for different Tensors)" + ) + + if is_tensor_like_type(info.type): + check_one(info, was_mutated) + elif is_tensorlist_like_type(info.type): + was_any_mutated = False if was_mutated is None else any(was_mutated) + check_one(info, was_any_mutated) + + +def hash_tensor(t: torch.Tensor) -> torch.Tensor: + """Some inexpensive hash. Used as a quick and dirty indicator for tensor mutation""" + return t.detach().float().mean() + + +def has_fake_kernel(op: torch._ops.OpOverload) -> bool: + """If an operator (that stays alive until FakeTensorMode) has a Fake kernel. + Don't use this if the operator decomposes before FakeTensorMode. + """ + if can_generate_trivial_fake_impl(op): + return True + name = op._name + if torch._C._dispatch_has_kernel_for_dispatch_key( + name, "CompositeImplicitAutograd" + ): + return True + opdef = torch._library.custom_ops._maybe_get_opdef(name) + if opdef is None: + # the non-torch.library.custom_op path + if torch._C._dispatch_has_kernel_for_dispatch_key( + name, "CompositeExplicitAutograd" + ): + return True + entry = torch._library.simple_registry.singleton.find(name) + if entry.fake_impl.kernel is not None: + return True + if torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta"): + return True + else: + # the torch.library.custom_op path + if opdef._abstract_fn is not None: + return True + return False + + +def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]: + idxs = [] + keys = [] + for i, info in enumerate(schema.arguments): + if info.alias_info is not None and info.alias_info.is_write: + if info.kwarg_only: + keys.append(info.name) + else: + idxs.append(i) + return idxs, keys + + +tags_by_priority = [ + _C.Tag.needs_exact_strides, + _C.Tag.needs_contiguous_strides, + _C.Tag.needs_fixed_stride_order, + _C.Tag.flexible_layout, +] + + +def get_layout_constraint_tag(fn, *, with_default=True): + for tag in tags_by_priority: + if tag in fn.tags: + return tag + if with_default: + if is_builtin(fn): + return _C.Tag.flexible_layout + import torch._functorch + from torch._functorch import config + + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + return None diff --git a/phivenv/Lib/site-packages/torch/_logging/__init__.py b/phivenv/Lib/site-packages/torch/_logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb669c421b142737b89bb548d739d97a8871cb1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_logging/__init__.py @@ -0,0 +1,19 @@ +# Top level logging module for torch logging +# Design doc: https://docs.google.com/document/d/1ZRfTWKa8eaPq1AxaiHrq4ASTPouzzlPiuquSBEJYwS8/edit# +# Simple setup for onboarding (see above doc for more detail): +# 1. register any top-level log qualified name for your module in torch._logging._registrations (see there for examples) +# 2. register any artifacts ( below) in torch._logging._registrations +# a. 
call getArtifactLogger(__name__, ) at your logging site instead of the standard logger to log your artifact +import torch._logging._registrations + +from ._internal import ( + _init_logs, + DEFAULT_LOGGING, + dtrace_structured, + get_structured_logging_overhead, + getArtifactLogger, + LazyString, + set_logs, + trace_structured, + warning_once, +) diff --git a/phivenv/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f602712c868b721fe12bbdbe3953484fa6b94afe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_logging/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfc94f59938a5dded8f08ffc1e360617f8400c47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_logging/__pycache__/_internal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85c4e25f831c429f989f7c98bd7ee85a9bd605d0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_logging/__pycache__/_registrations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_logging/__pycache__/scribe.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_logging/__pycache__/scribe.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79069d9c302932cdf084cf36c1d78f0a8d9802b6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_logging/__pycache__/scribe.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16c128c649a81f04f7e31d6a6e371adce9fb11db Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_logging/__pycache__/structured.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_logging/_internal.py b/phivenv/Lib/site-packages/torch/_logging/_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..28648d4c23265f9441e28eca2babbb8601c82865 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_logging/_internal.py @@ -0,0 +1,1343 @@ +# mypy: allow-untyped-defs +import functools +import hashlib +import importlib.util +import itertools +import json +import logging +import os +import os.path +import pathlib +import re +import sys +import tempfile +import time +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Callable, Generic, Optional, Union +from typing_extensions import ParamSpec +from weakref import WeakSet + +import torch._logging.structured +from torch._guards import CompileId +from torch._utils_internal import log_trace_structured_event +from torch.utils._traceback import CapturedTraceback + + +_P = ParamSpec("_P") + +log = logging.getLogger(__name__) + +# This is a synthetic logger which doesn't correspond to an actual logger, +# but handles all of our "tracing" logging, which is structured and doesn't go +# to stderr but always goes to a dedicated log 
file. We don't put these +# loggers in the classic module hierarchy, because we don't want a suppression +# of logs to also cause a trace to get suppressed (traces typically are not +# collected, unless we are in prod, in which case they always are collected.) +# +# TODO: Maybe we should allow for some sub-hierarchy so you can control which +# traces you want to collect, for performance reasons. +# +# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit +trace_log = logging.getLogger("torch.__trace") + +DEFAULT_LOG_LEVEL = logging.WARNING +LOG_ENV_VAR = "TORCH_LOGS" +LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT" +LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT" +LOG_TRACE_ID_FILTER = "TORCH_LOGS_TRACE_ID_FILTER" +TRACE_ENV_VAR = "TORCH_TRACE" +DTRACE_ENV_VAR = "TORCH_DTRACE" + +LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None + +GET_DTRACE_STRUCTURED = False + + +@dataclass +class LogRegistry: + # shorthand name to log qualified name + # Note: this only contains loggers registered + # from register_log + # e.g. "dynamo" -> "torch._dynamo" + log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict) + + # artifact logger qualified names, + # this is populated lazily, as calls to getArtifactLogger + # currently formatted as .__ + # e.g. "torch._dynamo.convert_frame.__guards" + artifact_log_qnames: set[str] = field(default_factory=set) + + # child logs of registered logs if specified via open + # registration by the user (ie placing "torch._dynamo.output_graph" in the env var) + # these need to be tracked so their levels can be reset properly + # e.g. "torch._dynamo.output_graph" + child_log_qnames: set[str] = field(default_factory=set) + + # artifact names, populated by register_artifact + # e.g. "guards" + artifact_names: set[str] = field(default_factory=set) + + # Artifacts that should be visible by default in the error message + visible_artifacts: set[str] = field(default_factory=set) + + # A short description of each artifact + artifact_descriptions: dict[str, str] = field(default_factory=dict) + + # artifacts which are not displayed unless explicitly named in the + # settings. Ex. output_code is NOT displayed even if the inductor + # log level is set to DEBUG. 
It must be explicitly named in the settings + off_by_default_artifact_names: set[str] = field(default_factory=set) + + # logging format string for artifacts + artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict) + + def is_artifact(self, name): + return name in self.artifact_names + + def is_log(self, alias): + return alias in self.log_alias_to_log_qnames + + # register a log with an alias + def register_log(self, alias, log_qnames: Union[str, list[str]]) -> None: + if isinstance(log_qnames, str): + log_qnames = [log_qnames] + self.log_alias_to_log_qnames[alias] = log_qnames + + # register an artifact name + def register_artifact_name( + self, name, description, visible, off_by_default, log_format + ) -> None: + self.artifact_names.add(name) + if visible: + self.visible_artifacts.add(name) + self.artifact_descriptions[name] = description + + # if off by default, don't enable it + # when log_name's log_level is set to DEBUG + if off_by_default: + self.off_by_default_artifact_names.add(name) + + if log_format is not None: + self.artifact_log_formatters[name] = logging.Formatter(log_format) + + # register the qualified name of an artifact log + # this is needed to know which logs need to be reset + # whenever the log_state is changed + def register_artifact_log(self, artifact_log_qname) -> None: + self.artifact_log_qnames.add(artifact_log_qname) + + def register_child_log(self, log_qname) -> None: + self.child_log_qnames.add(log_qname) + + # flattens all the qnames together (TODO: consider memoizing?) + def get_log_qnames(self) -> set[str]: + return set(itertools.chain.from_iterable(self.log_alias_to_log_qnames.values())) + + def get_artifact_log_qnames(self): + return set(self.artifact_log_qnames) + + def get_child_log_qnames(self): + return set(self.child_log_qnames) + + def is_off_by_default(self, artifact_qname): + return artifact_qname in self.off_by_default_artifact_names + + +@dataclass +class LogState: + # qualified log names -> currently set log level + log_qname_to_level: dict[str, str] = field(default_factory=dict) + + # the set of currently enabled artifacts + artifact_names: set[str] = field(default_factory=set) + + def enable_artifact(self, artifact_name) -> None: + self.artifact_names.add(artifact_name) + + def is_artifact_enabled(self, name): + return name in self.artifact_names + + def enable_log(self, log_qnames, log_level) -> None: + if isinstance(log_qnames, str): + log_qnames = [log_qnames] + for log_qname in log_qnames: + self.log_qname_to_level[log_qname] = log_level + + def get_log_level_pairs(self): + """Returns all qualified module names for which the user requested + explicit logging settings. + + .. warning: + + This function used to return all loggers, regardless of whether + or not the user specified them or not; it now only returns logs + which were explicitly mentioned by the user (and torch, which + always is implicitly requested when we initialize our logging + subsystem.) 
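+
+        Illustrative sketch (qualified name chosen for the example):
+
+            state = LogState()
+            state.enable_log("torch._dynamo", logging.DEBUG)
+            list(state.get_log_level_pairs())   # [("torch._dynamo", 10)]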
+ """ + return self.log_qname_to_level.items() + + def clear(self) -> None: + self.log_qname_to_level.clear() + self.artifact_names.clear() + + +log_registry = LogRegistry() +log_state = LogState() + +# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING) +DEFAULT_LOGGING = { + "dynamo": logging.INFO, + "aot": logging.INFO, + "inductor": logging.INFO, + "fsdp": logging.INFO, + "ddp_graphs": True, + "graph_breaks": True, + "guards": True, + "recompiles": True, + "dynamic": logging.INFO, +} + + +def set_logs( + *, + all: Optional[int] = None, + dynamo: Optional[int] = None, + aot: Optional[int] = None, + autograd: Optional[int] = None, + dynamic: Optional[int] = None, + inductor: Optional[int] = None, + distributed: Optional[int] = None, + c10d: Optional[int] = None, + ddp: Optional[int] = None, + fsdp: Optional[int] = None, + dtensor: Optional[int] = None, + onnx: Optional[int] = None, + bytecode: bool = False, + aot_graphs: bool = False, + aot_joint_graph: bool = False, + ddp_graphs: bool = False, + graph: bool = False, + graph_code: bool = False, + graph_code_verbose: bool = False, + graph_breaks: bool = False, + graph_sizes: bool = False, + guards: bool = False, + recompiles: bool = False, + recompiles_verbose: bool = False, + trace_source: bool = False, + trace_call: bool = False, + trace_bytecode: bool = False, + output_code: bool = False, + kernel_code: bool = False, + schedule: bool = False, + perf_hints: bool = False, + pre_grad_graphs: bool = False, + post_grad_graphs: bool = False, + ir_pre_fusion: bool = False, + ir_post_fusion: bool = False, + onnx_diagnostics: bool = False, + fusion: bool = False, + overlap: bool = False, + export: Optional[int] = None, + modules: Optional[dict[str, Union[int, bool]]] = None, + cudagraphs: bool = False, + sym_node: bool = False, + compiled_autograd: bool = False, + compiled_autograd_verbose: bool = False, + cudagraph_static_inputs: bool = False, + benchmarking: bool = False, + autotuning: bool = False, + graph_region_expansion: bool = False, + inductor_metrics: bool = False, + hierarchical_compile: bool = False, +) -> None: + """ + Sets the log level for individual components and toggles individual log + artifact types. + + .. warning:: This feature is a prototype and may have compatibility + breaking changes in the future. + + .. note:: The ``TORCH_LOGS`` environment variable has complete precedence + over this function, so if it was set, this function does nothing. + + A component is a set of related features in PyTorch. All of the log + messages emitted from a given component have their own log levels. If the + log level of a particular message has priority greater than or equal to its + component's log level setting, it is emitted. Otherwise, it is suppressed. + This allows you to, for instance, silence large groups of log messages that + are not relevant to you and increase verbosity of logs for components that + are relevant. The expected log level values, ordered from highest to lowest + priority, are: + + * ``logging.CRITICAL`` + * ``logging.ERROR`` + * ``logging.WARNING`` + * ``logging.INFO`` + * ``logging.DEBUG`` + * ``logging.NOTSET`` + + See documentation for the Python ``logging`` module for more information on + log levels: ``_ + + An artifact is a particular type of log message. Each artifact is assigned + to a parent component. A component can emit many different kinds of + artifacts. 
In general, an artifact is emitted if either its corresponding + setting in the argument list below is turned on or if its parent component + is set to a log level less than or equal to the log level of the artifact. + + Keyword args: + all (:class:`Optional[int]`): + The default log level for all components. Default: ``logging.WARN`` + + dynamo (:class:`Optional[int]`): + The log level for the TorchDynamo component. Default: ``logging.WARN`` + + aot (:class:`Optional[int]`): + The log level for the AOTAutograd component. Default: ``logging.WARN`` + + autograd (:class:`Optional[int]`): + The log level for autograd. Default: ``logging.WARN`` + + inductor (:class:`Optional[int]`): + The log level for the TorchInductor component. Default: ``logging.WARN`` + + dynamic (:class:`Optional[int]`): + The log level for dynamic shapes. Default: ``logging.WARN`` + + distributed (:class:`Optional[int]`): + Whether to log c10d communication operations and other debug info from PyTorch Distributed components. + Default: ``logging.WARN`` + + c10d (:class:`Optional[int]`): + Whether to log c10d communication operations related debug info in PyTorch Distributed components. + Default: ``logging.WARN`` + + ddp (:class:`Optional[int]`): + Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components. + Default: ``logging.WARN`` + + fsdp (:class:`Optional[int]`): + Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components. + Default: ``logging.WARN`` + + dtensor (:class:`Optional[int]`): + Whether to log debug info related to ``DTensor``(DTensor) in PyTorch Distributed components. + Default: ``logging.WARN`` + + onnx (:class:`Optional[int]`): + The log level for the ONNX exporter component. Default: ``logging.WARN`` + + bytecode (:class:`bool`): + Whether to emit the original and generated bytecode from TorchDynamo. + Default: ``False`` + + aot_graphs (:class:`bool`): + Whether to emit the graphs generated by AOTAutograd. Default: ``False`` + + aot_joint_graph (:class:`bool`): + Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False`` + + ddp_graphs (:class:`bool`): + Whether to emit graphs generated by DDPOptimizer. Default: ``False`` + + graph (:class:`bool`): + Whether to emit the graph captured by TorchDynamo in tabular format. + Default: ``False`` + + graph_code (:class:`bool`): + Whether to emit the python source of the graph captured by TorchDynamo. + Default: ``False`` + + graph_code_verbose (:class:`bool`): + Whether to emit verbose/intermediate FX pass logs for graph code. Default: ``False`` + + graph_breaks (:class:`bool`): + Whether to emit the graph breaks encountered by TorchDynamo. + Default: ``False`` + + graph_sizes (:class:`bool`): + Whether to emit tensor sizes of the graph captured by TorchDynamo. + Default: ``False`` + + guards (:class:`bool`): + Whether to emit the guards generated by TorchDynamo for each compiled + function. Default: ``False`` + + recompiles (:class:`bool`): + Whether to emit a guard failure reason and message every time + TorchDynamo recompiles a function. Default: ``False`` + + recompiles_verbose (:class:`bool`): + Whether to emit all guard failure reasons when TorchDynamo recompiles + a function, even those that are not actually run. Default: ``False`` + + trace_source (:class:`bool`): + Whether to emit when TorchDynamo begins tracing a new line. 
Default: ``False`` + + trace_call (:class:`bool`): + Whether to emit detailed line location when TorchDynamo creates an FX node + corresponding to function call. Python 3.11+ only. Default: ``False`` + + trace_bytecode (:class:`bool`): + Whether to emit bytecode instructions and traced stack state as TorchDynamo + traces bytecode. Default: ``False`` + + output_code (:class:`bool`): + Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False`` + + kernel_code (:class:`bool`): + Whether to emit the TorchInductor output code on a per-kernel bases. Default: ``False`` + + schedule (:class:`bool`): + Whether to emit the TorchInductor schedule. Default: ``False`` + + perf_hints (:class:`bool`): + Whether to emit the TorchInductor perf hints. Default: ``False`` + + pre_grad_graphs (:class:`bool`): + Whether to emit the graphs before inductor grad passes. Default: ``False`` + + post_grad_graphs (:class:`bool`): + Whether to emit the graphs generated by after post grad passes. Default: ``False`` + + ir_pre_fusion (:class:`bool`): + Whether to emit the graphs before inductor fusion passes. Default: ``False`` + + ir_post_fusion (:class:`bool`): + Whether to emit the graphs after inductor fusion passes. Default: ``False`` + + onnx_diagnostics (:class:`bool`): + Whether to emit the ONNX exporter diagnostics in logging. Default: ``False`` + + fusion (:class:`bool`): + Whether to emit detailed Inductor fusion decisions. Default: ``False`` + + overlap (:class:`bool`): + Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False`` + + sym_node (:class:`bool`): + Whether to emit debug info for various SymNode opterations. Default: ``False`` + + export (:class:`Optional[int]`): + The log level for export. Default: ``logging.WARN`` + + benchmarking (:class:`bool`): + Whether to emit detailed Inductor benchmarking information. Default: ``False`` + + modules (dict): + This argument provides an alternate way to specify the above log + component and artifact settings, in the format of a keyword args + dictionary given as a single argument. There are two cases + where this is useful (1) if a new log component or artifact has + been registered but a keyword argument for it has not been added + to this function and (2) if the log level for an unregistered module + needs to be set. This can be done by providing the fully-qualified module + name as the key, with the log level as the value. Default: ``None`` + + cudagraph_static_inputs (:class:`bool`): + Whether to emit debug info for cudagraph static input detection. Default: ``False`` + + autotuning (:class:`bool`): + Autotuning choice logs, such as kernel source, perf, and tuning parameters. Default: ``False`` + + graph_region_expansion (:class:`bool`): + Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False`` + + inductor_metrics (:class:`bool`): + Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False`` + + hierarchical_compile (:class:`bool`): + Whether to emit debug info for hierarchical compilation. Default: ``False`` + + Example:: + + >>> # xdoctest: +SKIP + >>> import logging + + # The following changes the "dynamo" component to emit DEBUG-level + # logs, and to emit "graph_code" artifacts. 
+ + >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True) + + # The following enables the logs for a different module + + >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG}) + """ + # ignore if env var is set + if LOG_ENV_VAR in os.environ: + log.warning( + "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs" + ) + return + + log_state.clear() + + modules = modules or {} + + def _set_logs(**kwargs) -> None: + for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr] + if val is None: + continue + + if log_registry.is_artifact(alias): + if not isinstance(val, bool): + raise ValueError( + f"Expected bool to enable artifact {alias}, received {val}" + ) + + if val: + log_state.enable_artifact(alias) + elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames: + if val not in logging._levelToName: + raise ValueError( + f"Unrecognized log level for log {alias}: {val}, valid level values " + f"are: {','.join([str(k) for k in logging._levelToName.keys()])}" + ) + + log_state.enable_log( + log_registry.log_alias_to_log_qnames.get(alias, alias), val + ) + elif _is_valid_module(alias): + if not _has_registered_parent(alias): + log_registry.register_log(alias, alias) + else: + log_registry.register_child_log(alias) + log_state.enable_log( + log_registry.log_alias_to_log_qnames.get(alias, alias), val + ) + else: + raise ValueError( + f"Unrecognized log or artifact name passed to set_logs: {alias}" + ) + + _init_logs() + + _set_logs( + torch=all, + dynamo=dynamo, + aot=aot, + autograd=autograd, + inductor=inductor, + dynamic=dynamic, + bytecode=bytecode, + aot_graphs=aot_graphs, + aot_joint_graph=aot_joint_graph, + ddp_graphs=ddp_graphs, + distributed=distributed, + c10d=c10d, + ddp=ddp, + fsdp=fsdp, + dtensor=dtensor, + graph=graph, + graph_code=graph_code, + graph_code_verbose=graph_code_verbose, + graph_breaks=graph_breaks, + graph_sizes=graph_sizes, + guards=guards, + recompiles=recompiles, + recompiles_verbose=recompiles_verbose, + trace_source=trace_source, + trace_call=trace_call, + trace_bytecode=trace_bytecode, + output_code=output_code, + kernel_code=kernel_code, + schedule=schedule, + perf_hints=perf_hints, + pre_grad_graphs=pre_grad_graphs, + post_grad_graphs=post_grad_graphs, + ir_pre_fusion=ir_pre_fusion, + ir_post_fusion=ir_post_fusion, + onnx=onnx, + onnx_diagnostics=onnx_diagnostics, + fusion=fusion, + overlap=overlap, + sym_node=sym_node, + export=export, + cudagraphs=cudagraphs, + compiled_autograd=compiled_autograd, + compiled_autograd_verbose=compiled_autograd_verbose, + cudagraph_static_inputs=cudagraph_static_inputs, + benchmarking=benchmarking, + autotuning=autotuning, + graph_region_expansion=graph_region_expansion, + inductor_metrics=inductor_metrics, + hierarchical_compile=hierarchical_compile, + ) + + +def get_loggers() -> list[logging.Logger]: + """ + Returns: a list of all registered loggers + """ + return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()] + + +def register_log(setting_name, log_name) -> None: + """ + Enables a log to be controlled by the env var and user API with the setting_name + Args: + setting_name: the shorthand name used in the env var and user API + log_name: the log name that the setting_name is associated with + """ + log_registry.register_log(setting_name, log_name) + + +def register_artifact( + setting_name, description, visible=False, off_by_default=False, log_format=None +) -> None: + """ + Enables an 
artifact to be controlled by the env var and user API with name + Args: + setting_name: the shorthand name used in the env var and user API + description: A description of what this outputs + visible: Whether it gets suggested to users by default + off_by_default: whether this artifact should be logged when the ancestor loggers + are enabled at level DEBUG + """ + log_registry.register_artifact_name( + setting_name, description, visible, off_by_default, log_format + ) + + +def getArtifactLogger(module_qname, artifact_name) -> logging.Logger: + if artifact_name not in log_registry.artifact_names: + raise ValueError( + f"Artifact name: {repr(artifact_name)} not registered," + f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations." + ) + qname = module_qname + f".__{artifact_name}" + log = logging.getLogger(qname) + log.artifact_name = artifact_name # type: ignore[attr-defined] + log_registry.register_artifact_log(qname) + configure_artifact_log(log) + return log + + +INCR_VERBOSITY_CHAR = "+" +DECR_VERBOSITY_CHAR = "-" +VERBOSITY_REGEX = ( + "(" + + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)]) + + "?)" +) + + +def configure_artifact_log(log) -> None: + # If the artifact is off by default, then it should only be logged when explicitly + # enabled; set propagate to False so that this artifact is not propagated + # to its ancestor logger + if log_registry.is_off_by_default(log.artifact_name): + log.propagate = False + + # enable artifact logging when explicitly enabled + if log_state.is_artifact_enabled(log.artifact_name): + log.setLevel(logging.DEBUG) + log.propagate = True + + +# match a comma separated list of loggable names (whitespace allowed after commas) +def _gen_settings_regex(): + return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?") + + +def _validate_settings(settings): + return re.fullmatch(_gen_settings_regex(), settings) is not None + + +def help_message(verbose=False): + def pad_to(s, length=30): + assert len(s) <= length + return s + " " * (length - len(s)) + + if verbose: + printed_artifacts = log_registry.artifact_names + else: + printed_artifacts = log_registry.visible_artifacts + if verbose: + heading = "All registered names" + else: + heading = "Visible registered names (use TORCH_LOGS='+help' for full list)" + lines = ( + ["all"] + + sorted(log_registry.log_alias_to_log_qnames.keys()) + + sorted( + [ + f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}" + for name in printed_artifacts + ] + ) + ) + setting_info = " " + "\n ".join(lines) + examples = """ +Examples: + TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to + logging.DEBUG and AOT to logging.INFO + + TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to + logging.ERROR and TorchInductor to logging.DEBUG + + TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact + + TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo + to logging.DEBUG and enable the schedule artifact + + TORCH_LOGS="+some.random.module,schedule" will set the log level of + some.random.module to logging.DEBUG and enable the schedule artifact + + TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format + string will set the output format + Valid keys are "levelname", "message", "pathname", "levelno", "lineno", + "filename" and "name". + + TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as + well. This is useful when the output is long. 
+""" # flake8: noqa: B950 + msg = f""" +TORCH_LOGS Info +{examples} + +{heading} +{setting_info} +""" + return msg + + +def _invalid_settings_err_msg(settings, verbose=False): + valid_settings = ( + ["all"] + + list(log_registry.log_alias_to_log_qnames.keys()) + + list(log_registry.artifact_names) + ) + valid_settings = ", ".join(sorted(valid_settings)) + msg = f""" +Invalid log settings: {settings}, must be a comma separated list of fully +qualified module names, registered log names or registered artifact names. +For more info on various settings, try TORCH_LOGS="help" +Valid settings: +{valid_settings} +""" + return msg + + +@functools.lru_cache +def _parse_log_settings(settings): + if settings == "": + return {} + + if settings == "help": + raise ValueError(help_message(verbose=False)) + elif settings == "+help": + raise ValueError(help_message(verbose=True)) + if not _validate_settings(settings): + raise ValueError(_invalid_settings_err_msg(settings)) + + settings = re.sub(r"\s+", "", settings) + log_names = settings.split(",") + + def get_name_level_pair(name): + clean_name = name.replace(INCR_VERBOSITY_CHAR, "") + clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "") + + if name[0] == INCR_VERBOSITY_CHAR: + level = logging.DEBUG + elif name[0] == DECR_VERBOSITY_CHAR: + level = logging.ERROR + else: + level = logging.INFO + + return clean_name, level + + log_state = LogState() + + for name in log_names: + name, level = get_name_level_pair(name) + + if name == "all": + name = "torch" + + if log_registry.is_log(name): + assert level is not None + log_qnames = log_registry.log_alias_to_log_qnames[name] + log_state.enable_log(log_qnames, level) + elif log_registry.is_artifact(name): + log_state.enable_artifact(name) + elif _is_valid_module(name): + if not _has_registered_parent(name): + log_registry.register_log(name, name) + else: + log_registry.register_child_log(name) + log_state.enable_log(name, level) + else: + raise ValueError(_invalid_settings_err_msg(settings)) + + return log_state + + +def _is_valid_module(qname): + spec = importlib.util.find_spec(qname) + return spec is not None + + +def _update_log_state_from_env() -> None: + global log_state + log_setting = os.environ.get(LOG_ENV_VAR, None) + if log_setting is not None: + log_state = _parse_log_settings(log_setting) + + +def _has_registered_parent(log_qname) -> bool: + cur_log = logging.getLogger(log_qname) + + registered_log_qnames = log_registry.get_log_qnames() + + while cur_log.parent: + if cur_log.name in registered_log_qnames: + return True + cur_log = cur_log.parent + + return False + + +def make_module_path_relative(abs_path): + """ + Given an absolute filepath corresponding to a Python module which was + loaded via normal import mechanisms using sys.path, convert it into + a relative path relative to one of the Python search paths. 
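+
+    Illustrative sketch (hypothetical site-packages location on sys.path):
+
+        make_module_path_relative(
+            "/usr/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py"
+        )   # -> "torch/_dynamo/convert_frame.py"
+
+    Paths that live under no sys.path entry come back as the resolved absolute path.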
+ """ + + abs_path = pathlib.Path(abs_path).resolve() + + for path in sys.path: + try: + rel_path = abs_path.relative_to(path) + except ValueError: + continue + else: + return str(rel_path) + + return str(abs_path) + + +# apply custom formats to artifacts when necessary +class TorchLogsFormatter(logging.Formatter): + def __init__( + self, *, trace: bool = False, trace_id_filter: Optional[set[str]] = None + ) -> None: + super().__init__() + self._is_trace = trace + self._trace_id_filter = trace_id_filter + + def format(self, record): + artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None) + if artifact_name is not None: + artifact_formatter = log_registry.artifact_log_formatters.get( + artifact_name, None + ) + if artifact_formatter is not None: + return artifact_formatter.format(record) + + record.message = record.getMessage() + record.asctime = self.formatTime(record, "%m%d %H:%M:%S") + + # exception handling - copied from logging.Formatter.format + s = record.message + if record.exc_info: + # Cache the traceback text to avoid converting it multiple times + # (it's constant anyway) + if not record.exc_text: + record.exc_text = self.formatException(record.exc_info) + if record.exc_text: + if s[-1:] != "\n": + s = s + "\n" + s = s + record.exc_text + if record.stack_info: + if s[-1:] != "\n": + s = s + "\n" + s = s + self.formatStack(record.stack_info) + + record.rankprefix = "" + if not self._is_trace and dist.is_available() and dist.is_initialized(): + record.rankprefix = f"[rank{dist.get_rank()}]:" + + record.traceid = "" + if ( + not self._is_trace + and (trace_id := torch._guards.CompileContext.current_trace_id()) + is not None + ): + record.traceid = f" [{trace_id}]" + + glog_level_to_abbr = { + "DEBUG": "V", # V is for VERBOSE in glog + "INFO": "I", + "WARNING": "W", + "ERROR": "E", + "CRITICAL": "C", + } + + shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname) + + record.artifactprefix = "" + if artifact_name is not None: + record.artifactprefix = f" [__{artifact_name}]" + + filepath = make_module_path_relative(record.pathname) + + if ( + self._trace_id_filter + and record.traceid.strip() not in self._trace_id_filter + ): + return "" + + prefix = ( + f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs * 1000):06d} {record.process} " + f"{filepath}:" + f"{record.lineno}]{record.traceid}{record.artifactprefix}" + ) + if self._is_trace: + assert s == "" + try: + r = f"{prefix} {json.dumps(record.metadata)}" + except TypeError: + log.warning("failing metadata: %r", record.metadata) + raise + if record.payload is not None: + r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) + return r + else: + lines = s.split("\n") + return "\n".join(f"{prefix} {l}" for l in lines) + + +def _default_formatter(): + fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None) + trace_id_filter = { + item.strip() + for item in os.environ.get(LOG_TRACE_ID_FILTER, "").split(",") + if item.strip() + } + if fmt is None: + return TorchLogsFormatter(trace_id_filter=trace_id_filter) + else: + if fmt in ("short", "basic"): + fmt = logging.BASIC_FORMAT + return logging.Formatter(fmt) + + +DEFAULT_FORMATTER = _default_formatter() + + +def _setup_handlers(create_handler_fn, log) -> None: + debug_handler = _track_handler(create_handler_fn()) + debug_handler.setFormatter(DEFAULT_FORMATTER) + debug_handler.setLevel(logging.DEBUG) + log.addHandler(debug_handler) + + +handlers = WeakSet() # type: ignore[var-annotated] + + +# mark handlers that we've created +# so we 
don't modify user handlers +def _track_handler(handler): + handlers.add(handler) + return handler + + +def _is_torch_handler(handler): + return handler in handlers + + +# clears all torch handlers on specified loggers +def _clear_handlers(log) -> None: + to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)] + for handler in to_remove: + log.removeHandler(handler) + + +def _reset_logs() -> None: + # reset all registered logs + for log_qname in log_registry.get_log_qnames(): + log = logging.getLogger(log_qname) + log.setLevel(logging.WARNING) + log.propagate = False + _clear_handlers(log) + + # reset all artifact and child logs + for artifact_log_qname in itertools.chain( + log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames() + ): + log = logging.getLogger(artifact_log_qname) + log.setLevel(logging.NOTSET) + log.propagate = True + + trace_log.propagate = False + _clear_handlers(trace_log) + + +def _get_log_state(): + return log_state + + +def _set_log_state(state) -> None: + global log_state + log_state = state + + +def _init_logs(log_file_name=None) -> None: + global GET_DTRACE_STRUCTURED + + _reset_logs() + _update_log_state_from_env() + + out = os.environ.get(LOG_OUT_ENV_VAR, None) + if out is not None: + log_file_name = out + + # First, reset all known (registered) loggers to NOTSET, so that they + # respect their parent log level + for log_qname in log_registry.get_log_qnames(): + # But not the top level torch level: this defaults to WARNING so + # that our log messages don't leak to the lower levels + if log_qname == "torch": + continue + log = logging.getLogger(log_qname) + log.setLevel(logging.NOTSET) + + # Now, for all loggers which the user requested to have non-standard + # logging behavior, modify their log levels + for log_qname, level in log_state.get_log_level_pairs(): + log = logging.getLogger(log_qname) + log.setLevel(level) + + # Finally, setup handlers for all registered loggers + for log_qname in log_registry.get_log_qnames(): + log = logging.getLogger(log_qname) + _setup_handlers( + logging.StreamHandler, + log, + ) + + if log_file_name is not None: + _setup_handlers( + lambda: logging.FileHandler(log_file_name), + log, + ) + + # configure artifact loggers, note: this must happen last + # since the levels of ancestor loggers are taken into account + for artifact_log_qname in log_registry.get_artifact_log_qnames(): + log = logging.getLogger(artifact_log_qname) + configure_artifact_log(log) + + # Setup handler for the special trace_log, with different default + # configuration + trace_dir_name = os.environ.get(TRACE_ENV_VAR, None) + + if dtrace_dir_name := os.environ.get(DTRACE_ENV_VAR, None): + GET_DTRACE_STRUCTURED = True + trace_dir_name = dtrace_dir_name + + # This handler may remove itself if trace_dir_name is None and we are not + # actually in an FB environment. This allows us to defer actually + # initializing it until we actually need to log anything. This is + # important because JK initializes a C++ singleton, which will pork our + # process if we subsequently fork. + global LOG_TRACE_HANDLER + if LOG_TRACE_HANDLER is None: + LOG_TRACE_HANDLER = LazyTraceHandler(trace_dir_name) + # This log is ALWAYS at debug level. We will additionally test if there + # are any handlers before deciding to actually call logging on this. 
Do + # not manually call + trace_log.setLevel(logging.DEBUG) + trace_log_handler = _track_handler(LOG_TRACE_HANDLER) + trace_log_handler.setFormatter(TorchLogsFormatter(trace=True)) + trace_log.addHandler(trace_log_handler) + + +class LazyTraceHandler(logging.StreamHandler): + """Like FileHandler, but the file is allocated lazily only upon the first log message""" + + def __init__(self, root_dir: Optional[str]) -> None: + # This is implemented in the same way that delay is implemented on + # FileHandler + self.root_dir = root_dir + logging.Handler.__init__(self) + self.stream = None + self._builtin_open = open + + # cloned from FileHandler in cpython + def close(self) -> None: + self.acquire() + try: + try: + if self.stream: + try: + self.flush() + finally: + stream = self.stream + self.stream = None + if hasattr(stream, "close"): + stream.close() + finally: + # Issue #19523: call unconditionally to + # prevent a handler leak when delay is set + # Also see Issue #42378: we also rely on + # self._closed being set to True there + logging.StreamHandler.close(self) + finally: + self.release() + + def emit(self, record) -> None: + if self.stream is None: + if self.root_dir is None: + TRACE_LOG_DIR = "/logs" + + import torch.version as torch_version + + if ( + hasattr(torch_version, "git_version") + and os.getenv("MAST_HPC_JOB_NAME") is None + ): + log.info( + "LazyTraceHandler: disabled because not fbcode or conda on mast" + ) + elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"): + log.info( + "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False" + ) + elif not os.path.exists(TRACE_LOG_DIR): + log.info( + "LazyTraceHandler: disabled because %s does not exist", + TRACE_LOG_DIR, + ) + elif not os.access(TRACE_LOG_DIR, os.W_OK): + log.info( + "LazyTraceHandler: disabled because %s is not writeable", + TRACE_LOG_DIR, + ) + else: + self.root_dir = TRACE_LOG_DIR + + if self.root_dir is not None: + os.makedirs(self.root_dir, exist_ok=True) + ranksuffix = "" + if dist.is_available() and dist.is_initialized(): + ranksuffix = f"rank_{dist.get_rank()}_" + self.stream = tempfile.NamedTemporaryFile( + mode="w+", + suffix=".log", + prefix=f"dedicated_log_torch_trace_{ranksuffix}", + dir=self.root_dir, + delete=False, + ) + log.info("LazyTraceHandler: logging to %s", self.stream.name) + else: + # We go poof, remove and no-op + trace_log.removeHandler(self) + return + if self.stream: + super().emit(record) + + +@functools.cache +def warning_once(logger_obj, *args, **kwargs) -> None: + """ + This function is similar to `logger.warning()`, but will emit the warning with the same message only once + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache. + The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to + another type of cache that includes the caller frame information in the hashing function. 
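+
+    Illustrative sketch (arbitrary message text):
+
+        log = logging.getLogger(__name__)
+        for _ in range(3):
+            warning_once(log, "falling back to eager for this op")   # logged once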
+ """ + logger_obj.warning(*args, **kwargs) + + +class LazyString(Generic[_P]): + def __init__( + self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs + ) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + + def __str__(self) -> str: + return self.func(*self.args, **self.kwargs) + + +# Logs the time it takes to do structured logging by frame/compile id +# key is always {frame_id}_{frame_compile_id} +structured_logging_overhead: dict[str, float] = defaultdict(float) + + +def add_structured_logging_overhead(time_spent: float) -> None: + global structured_logging_overhead + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + # Why not trace_id.attempt, like structured logging? + # We aggregate across all attempts because + # a compilation metric is logged per successful attempt + key = f"{frame_id}_{frame_compile_id}" + # TODO: deal with structured logging that occurs outside of specific compile ids + # It's hard to figure out where we would log that if we want it in compilation metrics + # itself. + if key is not None: + key = str(key) + structured_logging_overhead[key] += time_spent + + +def get_structured_logging_overhead() -> Optional[float]: + key = None + if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None: + frame_id = trace_id.compile_id.frame_id + frame_compile_id = trace_id.compile_id.frame_compile_id + key = f"{frame_id}_{frame_compile_id}" + if key is not None: + return structured_logging_overhead.get(key) + else: + return None + + +def trace_structured_artifact( + name: str, # this will go in metadata + encoding: str, + payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, +) -> None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": name, + "encoding": encoding, + }, + payload_fn=payload_fn, + ) + + +def trace_structured( + name: str, + # NB: metadata expected to be dict so adding more info is forward compatible + # Tuple[str, int] is a special case for string interning + metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict, + *, + payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, + suppress_context: bool = False, + expect_trace_id: bool = True, # Whether or not we expect to have a current trace id + record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging + compile_id: Optional[CompileId] = None, # Optional if unavailable in the trace +) -> None: + """ + metadata is an arbitrary JSON compatible struct, but it's expected to not be + too long (e.g., less than 1MB) + + payload is an arbitrary string, which can be arbitrarily long (but expected to have + newlines so no lines are too long) + """ + assert "name" not in [ + "rank", + "compiled_autograd_id", + "frame_id", + "frame_compile_id", + "attempt", + ] + assert callable( + metadata_fn + ), f"metadata_fn should be callable, but got {type(metadata_fn)}" + assert callable( + payload_fn + ), f"payload_fn should be callable, but got {type(payload_fn)}" + # trace_log never propagates and is ALWAYS DEBUG, so also check that there + # are handlers instead of checking the log level + if trace_log.handlers: + start_time = time.time_ns() + record: dict[str, object] = {} + record[name] = metadata_fn() + if not suppress_context: + # TODO: Actually, the rank probably should just be emitted once at + # the top, and not 
repeatedly spammed in all the logs, since it + # never changes and we assume no interleaving + if dist.is_available() and dist.is_initialized(): + record["rank"] = dist.get_rank() + + trace_id = torch._guards.CompileContext.current_trace_id() + if expect_trace_id and trace_id is None and compile_id is None: + # Record the stack of the log call to better diagnose why we + # don't have a frame id for it + record["stack"] = torch._logging.structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ) + else: + cid = trace_id.compile_id if trace_id else compile_id + if cid is not None: + if cid.compiled_autograd_id is not None: + record["compiled_autograd_id"] = cid.compiled_autograd_id + if cid.frame_id is not None: + record["frame_id"] = cid.frame_id + if cid.frame_compile_id is not None: + record["frame_compile_id"] = cid.frame_compile_id + if trace_id: + record["attempt"] = trace_id.attempt + + payload = payload_fn() + if payload is not None: + if not isinstance(payload, str): + if isinstance(payload, list): + # special case to look better + payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]" + else: + + def json_default(obj): + # Sets aren't json serializable + if isinstance(obj, set): + return list(obj) + raise TypeError( + f"Object of type {type(obj)} is not JSON serializable" + ) + + # force newlines so we are unlikely to overflow line limit + payload = json.dumps(payload, default=json_default, indent=0) + h = hashlib.md5(usedforsecurity=False) + h.update(payload.encode("utf-8")) + record["has_payload"] = h.hexdigest() + trace_log.debug( + "", extra={"metadata": record, "payload": payload}, stacklevel=2 + ) + log_trace_structured_event(name, record) + + if record_logging_overhead: + # Convert to seconds from nanoseconds, add it to the frame compile total + structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9 + add_structured_logging_overhead(structured_logging_overhead_s) + + +def dtrace_structured( + name: str, + # NB: metadata expected to be dict so adding more info is forward compatible + # Tuple[str, int] is a special case for string interning + metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict, + *, + payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None, + suppress_context: bool = False, + expect_trace_id: bool = False, # Whether or not we expect to have a current trace id + record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging +) -> None: + """ + For logging more detailed information used for debugging. This may result in + the program becoming slow. 
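+
+    Illustrative sketch (hypothetical event name; only emitted when the
+    TORCH_DTRACE directory is configured):
+
+        dtrace_structured(
+            "shape_guard_detail",
+            metadata_fn=lambda: {"expr": "s0 <= 16"},
+            payload_fn=lambda: "longer free-form debug text goes in the payload",
+        )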
+ """ + if GET_DTRACE_STRUCTURED: + trace_structured( + name, + metadata_fn, + payload_fn=payload_fn, + suppress_context=suppress_context, + expect_trace_id=expect_trace_id, + record_logging_overhead=record_logging_overhead, + ) + + +import torch._guards +import torch._utils_internal +import torch.distributed as dist diff --git a/phivenv/Lib/site-packages/torch/_logging/_registrations.py b/phivenv/Lib/site-packages/torch/_logging/_registrations.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd8a58d6776c5d8f3c6165f26e264a7661e77a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_logging/_registrations.py @@ -0,0 +1,248 @@ +# flake8: noqa: B950 +from ._internal import register_artifact, register_log + + +DYNAMIC = [ + "torch.fx.experimental.symbolic_shapes", + "torch.fx.experimental.sym_node", + "torch.fx.experimental.recording", +] +DISTRIBUTED = [ + "torch.distributed", + "torch._dynamo.backends.distributed", + "torch.nn.parallel.distributed", +] + +register_log( + "async_compile", + [ + "torch._inductor.async_compile", + "torch._inductor.compile_worker.tracked_process_pool", + ], +) +register_log( + "cache", ("torch._inductor.remote_cache", "torch._inductor.fb.remote_cache") +) +register_log("dynamo", ["torch._dynamo", *DYNAMIC]) +register_log("fake_tensor", ["torch._subclasses.fake_tensor"]) +register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"]) +register_log("autograd", "torch.autograd") +register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"]) + +register_artifact( + "cudagraphs", + "Logs information from wrapping inductor generated code with cudagraphs.", +) + +register_log("dynamic", DYNAMIC) +register_log("torch", "torch") +register_log("distributed", DISTRIBUTED) +register_log( + "c10d", ["torch.distributed.distributed_c10d", "torch.distributed.rendezvous"] +) +register_log( + "ddp", ["torch.nn.parallel.distributed", "torch._dynamo.backends.distributed"] +) +register_log("pp", ["torch.distributed.pipelining"]) +register_log("fsdp", ["torch.distributed.fsdp", "torch.distributed._composable.fsdp"]) +register_log("dtensor", ["torch.distributed._tensor", "torch.distributed.tensor"]) +register_log("onnx", "torch.onnx") +register_log( + "export", + [ + "torch._dynamo", + "torch.export", + "torch.export.dynamic_shapes", + *DYNAMIC, + "torch._export.converter", + "torch._export.non_strict_utils", + "torch._export.serde.serialize", + "torch.fx.experimental.proxy_tensor", + ], +) + +register_artifact( + "guards", + "This prints the guards for every compiled Dynamo frame. It does not tell you where the guards come from.", + visible=True, +) +register_artifact("verbose_guards", "", off_by_default=True) +register_artifact( + "bytecode", + "Prints the original and modified bytecode from Dynamo. Mostly useful if you're debugging our bytecode generation in Dynamo.", + off_by_default=True, +) +register_artifact( + "graph", + "Prints the dynamo traced graph (prior to AOTDispatch) in a table. If you prefer python code use `graph_code` instead. ", +) +register_artifact("graph_code", "Like `graph`, but gives you the Python code instead.") +register_artifact( + "graph_code_verbose", + "Verbose FX pass logs, e.g. from tensorify_python_scalars and runtime_assert.", +) +register_artifact( + "graph_sizes", "Prints the sizes of all FX nodes in the dynamo graph." +) +register_artifact( + "trace_source", + "As we execute bytecode, prints the file name / line number we are processing and the actual source code. 
Useful with `bytecode`", +) +register_artifact( + "trace_call", + "Like trace_source, but it will give you the per-expression blow-by-blow if your Python is recent enough.", +) +register_artifact( + "trace_bytecode", + "As we trace bytecode, prints the instruction and the current stack.", +) +register_artifact( + "aot_graphs", + "Prints the FX forward and backward graph generated by AOTDispatch, after partitioning. Useful to understand what's being given to Inductor", + visible=True, +) +register_artifact( + "aot_joint_graph", + "Print FX joint graph from AOTAutograd, prior to partitioning. Useful for debugging partitioning", +) +register_artifact( + "aot_graphs_effects", + "Prints the FX forward and backward graph generated by AOTDispatch, useful for debugging effects processing.", + visible=True, +) +register_artifact( + "pre_grad_graphs", + "Prints the FX graph before inductor pre grad passes. Useful to understand what's being given to Inductor before grad passes", +) +register_artifact( + "post_grad_graphs", + "Prints the FX graph generated by post grad passes. Useful to understand what's being given to Inductor after post grad passes", +) +register_artifact( + "ir_pre_fusion", + "Prints the IR before inductor fusion passes.", + off_by_default=True, +) +register_artifact( + "ir_post_fusion", + "Prints the IR after inductor fusion passes.", + off_by_default=True, +) +register_artifact( + "compiled_autograd", + "Prints various logs in compiled_autograd, including but not limited to the graphs. Useful for debugging compiled_autograd.", + visible=True, +) +register_artifact( + "compiled_autograd_verbose", + "Will affect performance. Prints compiled_autograd logs with C++ info e.g. autograd node -> fx node mapping", + off_by_default=True, +) +register_artifact( + "ddp_graphs", + "Only relevant for compiling DDP. DDP splits into multiple graphs to trigger comms early. This will print each individual graph here.", +) +register_artifact( + "recompiles", + "Prints the reason why we recompiled a graph. Very, very useful.", + visible=True, +) +register_artifact( + "recompiles_verbose", + "Prints all guard checks that fail during a recompilation. " + "At runtime, Dynamo will stop at the first failed check for each failing guard. " + "So not all logged failing checks are actually ran by Dynamo.", + visible=True, + off_by_default=True, +) +register_artifact( + "graph_breaks", + "Prints whenever Dynamo decides that it needs to graph break (i.e. create a new graph). Useful for debugging why torch.compile has poor performance", + visible=True, +) +register_artifact( + "not_implemented", + "Prints log messages whenever we return NotImplemented in a multi-dispatch, letting you trace through each object we attempted to dispatch to", +) +register_artifact( + "output_code", + "Prints the code that Inductor generates (either Triton or C++)", + off_by_default=True, + visible=True, +) +register_artifact( + "kernel_code", + "Prints the code that Inductor generates (on a per-kernel basis)", + off_by_default=True, + visible=True, +) +register_artifact( + "schedule", + "Inductor scheduler information. Useful if working on Inductor fusion algo", + off_by_default=True, +) +register_artifact("perf_hints", "", off_by_default=True) +register_artifact("onnx_diagnostics", "", off_by_default=True) +register_artifact( + "fusion", + "Detailed Inductor fusion decisions. 
More detailed than 'schedule'", + off_by_default=True, +) +register_artifact( + "loop_ordering", + "Logs related to loop ordering", + off_by_default=True, +) +register_artifact( + "loop_tiling", + "Logs related to loop ordering", + off_by_default=True, +) + +register_artifact( + "overlap", + "Detailed Inductor compute/comm overlap decisions", + off_by_default=True, +) +register_artifact( + "sym_node", + "Logs extra info for various SymNode operations", + off_by_default=True, +) +register_artifact( + "trace_shape_events", + "Logs traces for every ShapeEnv operation that we record for replay", + off_by_default=True, +) +register_artifact( + "cudagraph_static_inputs", + "Logs static inputs handling in dynamo, AOT, and cudagraphs", + off_by_default=True, +) +register_artifact( + "benchmarking", + "Detailed Inductor benchmarking information.", + off_by_default=True, +) +register_artifact( + "autotuning", + "Autotuning choice logs, such as kernel source, perf, and tuning parameters.", + off_by_default=True, +) +register_artifact( + "graph_region_expansion", + "Logs detailed steps of the duplicate graph region tracker expansion algorithm", + off_by_default=True, +) + +register_artifact( + "inductor_metrics", + "Logs Inductor metrics, such as num_bytes, nodes_num_elem, node_runtimes", + off_by_default=True, +) +register_artifact( + "hierarchical_compile", + "Logs debug info for hierarchical compilation", + off_by_default=True, +) +register_artifact("custom_format_test_artifact", "Testing only", log_format="") diff --git a/phivenv/Lib/site-packages/torch/_logging/scribe.py b/phivenv/Lib/site-packages/torch/_logging/scribe.py new file mode 100644 index 0000000000000000000000000000000000000000..befb1cc8dd7b571d5327d5a5ea331b34b1d6d6c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_logging/scribe.py @@ -0,0 +1,63 @@ +from typing import Callable, Union +from typing_extensions import TypeAlias + + +try: + from fbscribelogger import ( # type: ignore[import-untyped, import-not-found] + make_scribe_logger, + ) +except ImportError: + TAtom: TypeAlias = Union[int, float, bool, str] + TField: TypeAlias = Union[TAtom, list[TAtom]] + TLazyField: TypeAlias = Union[TField, Callable[[], TField]] + + def make_scribe_logger(name: str, thrift_src: str) -> Callable[..., None]: + def inner(**kwargs: TLazyField) -> None: + pass + + return inner + + +open_source_signpost = make_scribe_logger( + "TorchOpenSourceSignpost", + """ +struct TorchOpenSourceSignpostLogEntry { + + # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA. + 4: optional string commit_sha; + + # Commit date (not author date) of the commit in commit_sha as timestamp, e.g., 1724208105. Increasing if merge bot is used, though not monotonic; duplicates occur when stack is landed. + 5: optional i64 commit_date; + + # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF. + 6: optional string github_ref; + + # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED. + 7: optional bool github_ref_protected; + + # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT. + 8: optional string github_run_attempt; + + # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID. 
+ 9: optional string github_run_id; + + # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER. + 10: optional string github_run_number_str; + + # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, linux.2xlarge). + 11: optional string job_name; + + # The GitHub user who triggered the job. Derived from GITHUB_TRIGGERING_ACTOR. + 12: optional string github_triggering_actor; + 13: optional string name; # Event name + 14: optional string parameters; # Parameters (JSON data) + 16: optional string subsystem; # Subsystem the event is associated with + + # The unit timestamp in second for the Scuba Time Column override + 17: optional i64 time; + + # The weight of the record according to current sampling rate + 18: optional i64 weight; +} +""", # noqa: B950 +) diff --git a/phivenv/Lib/site-packages/torch/_logging/structured.py b/phivenv/Lib/site-packages/torch/_logging/structured.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0b2e3f272dffe82d8b8712a9a18297b7bcbfbd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_logging/structured.py @@ -0,0 +1,108 @@ +""" +Utilities for converting data types into structured JSON for dumping. +""" +import inspect +import os +import traceback +from collections.abc import Sequence +from typing import Any, Optional + +import torch._logging._internal + + +INTERN_TABLE: dict[str, int] = {} + + +DUMPED_FILES: set[str] = set() + + +def intern_string(s: Optional[str]) -> int: + if s is None: + return -1 + + r = INTERN_TABLE.get(s, None) + if r is None: + r = len(INTERN_TABLE) + INTERN_TABLE[s] = r + torch._logging._internal.trace_structured( + "str", lambda: (s, r), suppress_context=True + ) + return r + + +def dump_file(filename: str) -> None: + if "eval_with_key" not in filename: + return + if filename in DUMPED_FILES: + return + DUMPED_FILES.add(filename) + from torch.fx.graph_module import _loader + + torch._logging._internal.trace_structured( + "dump_file", + metadata_fn=lambda: { + "name": filename, + }, + payload_fn=lambda: _loader.get_source(filename), + ) + + +def from_traceback(tb: Sequence[traceback.FrameSummary]) -> list[dict[str, Any]]: + # dict naming convention here coincides with + # python/combined_traceback.cpp + r = [ + { + "line": frame.lineno, + "name": frame.name, + "filename": intern_string(frame.filename), + "loc": frame.line, + } + for frame in tb + ] + return r + + +def get_user_stack(num_frames: int) -> list[dict[str, Any]]: + from torch._guards import TracingContext + from torch.utils._traceback import CapturedTraceback + + user_tb = TracingContext.extract_stack() + if user_tb: + return from_traceback(user_tb[-1 * num_frames :]) + + tb = CapturedTraceback.extract().summary() + + # Filter out frames that are within the torch/ codebase + torch_filepath = os.path.dirname(inspect.getfile(torch)) + os.path.sep + for i, frame in enumerate(reversed(tb)): + if torch_filepath not in frame.filename: + # Only display `num_frames` frames in the traceback + filtered_tb = tb[len(tb) - i - num_frames : len(tb) - i] + return from_traceback(filtered_tb) + + return from_traceback(tb[-1 * num_frames :]) + + +def get_framework_stack( + num_frames: int = 25, cpp: bool = False +) -> list[dict[str, Any]]: + """ + Returns the traceback for the user stack and the framework stack + """ + from torch.fx.experimental.symbolic_shapes import uninteresting_files + from torch.utils._traceback import CapturedTraceback + + tb = 
CapturedTraceback.extract(cpp=cpp).summary() + tb = [ + frame + for frame in tb + if ( + ( + frame.filename.endswith(".py") + and frame.filename not in uninteresting_files() + ) + or ("at::" in frame.name or "torch::" in frame.name) + ) + ] + + return from_traceback(tb[-1 * num_frames :]) diff --git a/phivenv/Lib/site-packages/torch/_numpy/__init__.py b/phivenv/Lib/site-packages/torch/_numpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..110e222289cc2e037efc3fc2867c88090851139f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/__init__.py @@ -0,0 +1,34 @@ +# mypy: ignore-errors + +from . import fft, linalg, random +from ._dtypes import * # noqa: F403 +from ._funcs import * # noqa: F403 +from ._getlimits import finfo, iinfo +from ._ndarray import ( + array, + asarray, + ascontiguousarray, + can_cast, + from_dlpack, + ndarray, + newaxis, + result_type, +) +from ._ufuncs import * # noqa: F403 +from ._util import AxisError, UFuncTypeError + + +from math import pi, e # usort: skip + + +all = all +alltrue = all + +any = any +sometrue = any + +inf = float("inf") +nan = float("nan") + +False_ = False +True_ = True diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d0713099d978038cb8d32e61ad977f082de6566 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cce1c89e553b7daa17e13f269b0bfda88fa4b96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_binary_ufuncs_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..193b01b644cf374edb773269f62c8fdecfffb988 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_casting_dicts.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fda2caf751d499785d30d0bee7200d25280741e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cbc4bc5dcefa28351724f9a95a2f2bfd22e5e90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_dtypes_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a7fa5bdd8c474a5f5f7ca6c61d40f3c4dc8f05f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4ca9289bc76c98c9ffb00e215edb8267cd18f40 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_funcs_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ccd3cb160ca6322bd35506b6476cdabeda5c4ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_getlimits.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98b0646211c40eeb86802d699f6a07522e76870b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ndarray.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0df2fa666e670e0c2f7895ab6ae4b9daa089ce7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_normalizations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a21f8444ff11572e1ef5a8e5f201c2779c8a1aff Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_reductions_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcc4d2b8e06bb29551437ae8ed58f168baf75fec Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_ufuncs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b0c093214c88b201effd58ab3f321df6c34faf7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_unary_ufuncs_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..232b4be1d21b75faa41d8f2372b368206819114a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/_util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fda052cd6b706f503b717394f663b1dcaa40d43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/fft.cpython-39.pyc differ diff --git 
a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cff54f320437f11078ed12ab8285c80303ca47a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/linalg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5007f813cb880bf4448732b119929078bc8aa3d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/__pycache__/random.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py b/phivenv/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3086a7cc5dd078ff03ee1bb67b91290be24c6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_binary_ufuncs_impl.py @@ -0,0 +1,85 @@ +# mypy: ignore-errors + +"""Export torch work functions for binary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module. +""" + +import torch +from torch import ( # noqa: F401 + add, + arctan2, + bitwise_and, + bitwise_left_shift as left_shift, + bitwise_or, + bitwise_right_shift as right_shift, + bitwise_xor, + copysign, + divide, + eq as equal, + float_power, + floor_divide, + fmax, + fmin, + fmod, + gcd, + greater, + greater_equal, + heaviside, + hypot, + lcm, + ldexp, + less, + less_equal, + logaddexp, + logaddexp2, + logical_and, + logical_or, + logical_xor, + maximum, + minimum, + multiply, + nextafter, + not_equal, + pow as power, + remainder, + remainder as mod, + subtract, + true_divide, +) + +from . import _dtypes_impl, _util + + +# work around torch limitations w.r.t. numpy +def matmul(x, y): + # work around: + # - RuntimeError: expected scalar type Int but found Double + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool' + # - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' + dtype = _dtypes_impl.result_type_impl(x, y) + is_bool = dtype == torch.bool + is_half = (x.dtype == torch.float16 or y.dtype == torch.float16) and ( + x.is_cpu or y.is_cpu + ) + + work_dtype = dtype + if is_bool: + work_dtype = torch.uint8 + if is_half: + work_dtype = torch.float32 + + x = _util.cast_if_needed(x, work_dtype) + y = _util.cast_if_needed(y, work_dtype) + + result = torch.matmul(x, y) + + if work_dtype != dtype: + result = result.to(dtype) + + return result + + +# a stub implementation of divmod, should be improved after +# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch +def divmod(x, y): + return x // y, x % y diff --git a/phivenv/Lib/site-packages/torch/_numpy/_casting_dicts.py b/phivenv/Lib/site-packages/torch/_numpy/_casting_dicts.py new file mode 100644 index 0000000000000000000000000000000000000000..bcac74593963fc41fa929cb6d9bbc29d12218ad3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_casting_dicts.py @@ -0,0 +1,1368 @@ +# mypy: ignore-errors + +import torch + + +# These two dicts are autogenerated with autogen/gen_dtypes.py, +# using numpy version 1.24.3. 
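The two autogenerated tables that follow are plain nested dictionaries keyed by torch dtypes, so consulting them is just a couple of dict lookups. A small sketch of how they can be read (the `can_cast_torch` helper below is invented for illustration; the real public entry points such as `can_cast` and `result_type` live in `torch._numpy._ndarray`):

```python
import torch
from torch._numpy import _casting_dicts as _cd

# _can_cast_dict[casting][from_dtype][to_dtype] -> bool
assert _cd._can_cast_dict["safe"][torch.int8][torch.int16]
assert not _cd._can_cast_dict["no"][torch.int8][torch.int16]

# _result_type_dict[dtype1][dtype2] -> promoted torch dtype
assert _cd._result_type_dict[torch.uint8][torch.int8] == torch.int16

# An invented convenience wrapper over the first table:
def can_cast_torch(from_dtype, to_dtype, casting="same_kind"):
    return _cd._can_cast_dict[casting][from_dtype][to_dtype]
```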
+ +_can_cast_dict = { + "no": { + torch.float16: { + torch.float16: True, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + 
torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: True, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: False, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: True, + }, + }, + "equiv": { + torch.float16: { + torch.float16: True, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: True, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, 
+ torch.complex128: False, + torch.uint8: False, + torch.uint16: True, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int8: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: True, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: False, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: False, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: True, + }, + }, + "safe": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: 
False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint16: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: True, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: True, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int16: { + torch.float16: False, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int32: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int64: { + torch.float16: False, + torch.float32: False, + torch.float64: True, + torch.complex64: False, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: True, + torch.bool: False, 
+ }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, + "same_kind": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.float64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex64: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.complex128: { + torch.float16: False, + torch.float32: False, + torch.float64: False, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: False, + torch.int16: False, + torch.int32: False, + torch.int64: False, + torch.bool: False, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, 
+ torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.int64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: False, + torch.uint16: False, + torch.uint32: False, + torch.uint64: False, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: False, + }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, + "unsafe": { + torch.float16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.float32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.float64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.complex64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.complex128: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + 
torch.bool: True, + }, + torch.uint16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.uint64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int8: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int16: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int32: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.int64: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + torch.bool: { + torch.float16: True, + torch.float32: True, + torch.float64: True, + torch.complex64: True, + torch.complex128: True, + torch.uint8: True, + torch.uint16: True, + torch.uint32: True, + torch.uint64: True, + torch.int8: True, + torch.int16: True, + torch.int32: True, + torch.int64: True, + torch.bool: True, + }, + }, +} + + +_result_type_dict = { + torch.float16: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.float16, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, + torch.int8: torch.float16, + torch.int16: torch.float32, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float16, + }, + torch.float32: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.float32, + torch.uint16: torch.float32, + torch.uint32: torch.float64, + torch.uint64: torch.float64, + torch.int8: torch.float32, + torch.int16: torch.float32, + 
torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float32, + }, + torch.float64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.float64, + torch.uint16: torch.float64, + torch.uint32: torch.float64, + torch.uint64: torch.float64, + torch.int8: torch.float64, + torch.int16: torch.float64, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.float64, + }, + torch.complex64: { + torch.float16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: torch.complex128, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.complex64, + torch.uint16: torch.complex64, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, + torch.int8: torch.complex64, + torch.int16: torch.complex64, + torch.int32: torch.complex128, + torch.int64: torch.complex128, + torch.bool: torch.complex64, + }, + torch.complex128: { + torch.float16: torch.complex128, + torch.float32: torch.complex128, + torch.float64: torch.complex128, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.complex128, + torch.uint16: torch.complex128, + torch.uint32: torch.complex128, + torch.uint64: torch.complex128, + torch.int8: torch.complex128, + torch.int16: torch.complex128, + torch.int32: torch.complex128, + torch.int64: torch.complex128, + torch.bool: torch.complex128, + }, + torch.uint8: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.uint8, + }, + torch.uint16: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint16, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int32, + torch.int16: torch.int32, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.uint16, + }, + torch.uint32: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint32, + torch.uint16: torch.uint32, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bool: torch.uint32, + }, + torch.uint64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.uint64, + torch.uint16: torch.uint64, + torch.uint32: torch.uint64, + torch.uint64: torch.uint64, + torch.int8: torch.float64, + torch.int16: torch.float64, + torch.int32: torch.float64, + torch.int64: torch.float64, + torch.bool: torch.uint64, + }, + torch.int8: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: 
torch.complex128, + torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, + torch.int8: torch.int8, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int8, + }, + torch.int16: { + torch.float16: torch.float32, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.int16, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, + torch.int8: torch.int16, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int16, + }, + torch.int32: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.int32, + torch.uint16: torch.int32, + torch.uint32: torch.int64, + torch.uint64: torch.float64, + torch.int8: torch.int32, + torch.int16: torch.int32, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.int32, + }, + torch.int64: { + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.float64: torch.float64, + torch.complex64: torch.complex128, + torch.complex128: torch.complex128, + torch.uint8: torch.int64, + torch.uint16: torch.int64, + torch.uint32: torch.int64, + torch.uint64: torch.float64, + torch.int8: torch.int64, + torch.int16: torch.int64, + torch.int32: torch.int64, + torch.int64: torch.int64, + torch.bool: torch.int64, + }, + torch.bool: { + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.float64: torch.float64, + torch.complex64: torch.complex64, + torch.complex128: torch.complex128, + torch.uint8: torch.uint8, + torch.uint16: torch.uint16, + torch.uint32: torch.uint32, + torch.uint64: torch.uint64, + torch.int8: torch.int8, + torch.int16: torch.int16, + torch.int32: torch.int32, + torch.int64: torch.int64, + torch.bool: torch.bool, + }, +} diff --git a/phivenv/Lib/site-packages/torch/_numpy/_dtypes.py b/phivenv/Lib/site-packages/torch/_numpy/_dtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..083d3415a1cf8f9ed3ba67a0b81e293fe4d5e174 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_dtypes.py @@ -0,0 +1,453 @@ +# mypy: ignore-errors + +""" Define analogs of numpy dtypes supported by pytorch. +Define the scalar types and supported dtypes and numpy <--> torch dtype mappings. +""" +import builtins + +import torch + +from . import _dtypes_impl + + +# ### Scalar types ### + + +class generic: + name = "generic" + + def __new__(cls, value): + # NumPy scalars are modelled as 0-D arrays + # so a call to np.float32(4) produces a 0-D array. 
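To make the comment above concrete, here is a short, assumption-laden example of that 0-D behaviour when `torch._numpy` is used as a NumPy stand-in (a private module, shown only for illustration):

```python
import torch._numpy as tnp

x = tnp.float32(4)       # routed through generic.__new__ -> asarray(4, dtype=float32)
print(x.shape)           # () -- a 0-D array, mimicking a NumPy scalar
print(x.dtype)           # dtype("float32")

y = tnp.asarray([1, 2, 3])
z = tnp.float64(y)       # an existing ndarray is converted via .astype(float64)
print(z.dtype)           # dtype("float64")
```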
+ + from ._ndarray import asarray, ndarray + + if isinstance(value, str) and value in ["inf", "nan"]: + value = {"inf": torch.inf, "nan": torch.nan}[value] + + if isinstance(value, ndarray): + return value.astype(cls) + else: + return asarray(value, dtype=cls) + + +################## +# abstract types # +################## + + +class number(generic): + name = "number" + + +class integer(number): + name = "integer" + + +class inexact(number): + name = "inexact" + + +class signedinteger(integer): + name = "signedinteger" + + +class unsignedinteger(integer): + name = "unsignedinteger" + + +class floating(inexact): + name = "floating" + + +class complexfloating(inexact): + name = "complexfloating" + + +_abstract_dtypes = [ + "generic", + "number", + "integer", + "signedinteger", + "unsignedinteger", + "inexact", + "floating", + "complexfloating", +] + +# ##### concrete types + +# signed integers + + +class int8(signedinteger): + name = "int8" + typecode = "b" + torch_dtype = torch.int8 + + +class int16(signedinteger): + name = "int16" + typecode = "h" + torch_dtype = torch.int16 + + +class int32(signedinteger): + name = "int32" + typecode = "i" + torch_dtype = torch.int32 + + +class int64(signedinteger): + name = "int64" + typecode = "l" + torch_dtype = torch.int64 + + +# unsigned integers + + +class uint8(unsignedinteger): + name = "uint8" + typecode = "B" + torch_dtype = torch.uint8 + + +class uint16(unsignedinteger): + name = "uint16" + typecode = "H" + torch_dtype = torch.uint16 + + +class uint32(signedinteger): + name = "uint32" + typecode = "I" + torch_dtype = torch.uint32 + + +class uint64(signedinteger): + name = "uint64" + typecode = "L" + torch_dtype = torch.uint64 + + +# floating point + + +class float16(floating): + name = "float16" + typecode = "e" + torch_dtype = torch.float16 + + +class float32(floating): + name = "float32" + typecode = "f" + torch_dtype = torch.float32 + + +class float64(floating): + name = "float64" + typecode = "d" + torch_dtype = torch.float64 + + +class complex64(complexfloating): + name = "complex64" + typecode = "F" + torch_dtype = torch.complex64 + + +class complex128(complexfloating): + name = "complex128" + typecode = "D" + torch_dtype = torch.complex128 + + +class bool_(generic): + name = "bool_" + typecode = "?" + torch_dtype = torch.bool + + +# name aliases +_name_aliases = { + "intp": int64, + "int_": int64, + "intc": int32, + "byte": int8, + "short": int16, + "longlong": int64, # XXX: is this correct? + "ulonglong": uint64, + "ubyte": uint8, + "half": float16, + "single": float32, + "double": float64, + "float_": float64, + "csingle": complex64, + "singlecomplex": complex64, + "cdouble": complex128, + "cfloat": complex128, + "complex_": complex128, +} +# We register float_ = float32 and so on +for name, obj in _name_aliases.items(): + vars()[name] = obj + + +# Replicate this NumPy-defined way of grouping scalar types, +# cf tests/core/test_scalar_methods.py +sctypes = { + "int": [int8, int16, int32, int64], + "uint": [uint8, uint16, uint32, uint64], + "float": [float16, float32, float64], + "complex": [complex64, complex128], + "others": [bool_], +} + + +# Support mappings/functions + +_names = {st.name: st for cat in sctypes for st in sctypes[cat]} +_typecodes = {st.typecode: st for cat in sctypes for st in sctypes[cat]} +_torch_dtypes = {st.torch_dtype: st for cat in sctypes for st in sctypes[cat]} + + +_aliases = { + "u1": uint8, + "i1": int8, + "i2": int16, + "i4": int32, + "i8": int64, + "b": int8, # XXX: srsly? 
+ "f2": float16, + "f4": float32, + "f8": float64, + "c8": complex64, + "c16": complex128, + # numpy-specific trailing underscore + "bool_": bool_, +} + + +_python_types = { + int: int64, + float: float64, + complex: complex128, + builtins.bool: bool_, + # also allow stringified names of python types + int.__name__: int64, + float.__name__: float64, + complex.__name__: complex128, + builtins.bool.__name__: bool_, +} + + +def sctype_from_string(s): + """Normalize a string value: a type 'name' or a typecode or a width alias.""" + if s in _names: + return _names[s] + if s in _name_aliases.keys(): + return _name_aliases[s] + if s in _typecodes: + return _typecodes[s] + if s in _aliases: + return _aliases[s] + if s in _python_types: + return _python_types[s] + raise TypeError(f"data type {s!r} not understood") + + +def sctype_from_torch_dtype(torch_dtype): + return _torch_dtypes[torch_dtype] + + +# ### DTypes. ### + + +def dtype(arg): + if arg is None: + arg = _dtypes_impl.default_dtypes().float_dtype + return DType(arg) + + +class DType: + def __init__(self, arg): + # a pytorch object? + if isinstance(arg, torch.dtype): + sctype = _torch_dtypes[arg] + elif isinstance(arg, torch.Tensor): + sctype = _torch_dtypes[arg.dtype] + # a scalar type? + elif issubclass_(arg, generic): + sctype = arg + # a dtype already? + elif isinstance(arg, DType): + sctype = arg._scalar_type + # a has a right attribute? + elif hasattr(arg, "dtype"): + sctype = arg.dtype._scalar_type + else: + sctype = sctype_from_string(arg) + self._scalar_type = sctype + + @property + def name(self): + return self._scalar_type.name + + @property + def type(self): + return self._scalar_type + + @property + def kind(self): + # https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html + return _torch_dtypes[self.torch_dtype].name[0] + + @property + def typecode(self): + return self._scalar_type.typecode + + def __eq__(self, other): + if isinstance(other, DType): + return self._scalar_type == other._scalar_type + try: + other_instance = DType(other) + except TypeError: + return False + return self._scalar_type == other_instance._scalar_type + + @property + def torch_dtype(self): + return self._scalar_type.torch_dtype + + def __hash__(self): + return hash(self._scalar_type.name) + + def __repr__(self): + return f'dtype("{self.name}")' + + __str__ = __repr__ + + @property + def itemsize(self): + elem = self.type(1) + return elem.tensor.element_size() + + def __getstate__(self): + return self._scalar_type + + def __setstate__(self, value): + self._scalar_type = value + + +typecodes = { + "All": "efdFDBbhil?", + "AllFloat": "efdFD", + "AllInteger": "Bbhil", + "Integer": "bhil", + "UnsignedInteger": "B", + "Float": "efd", + "Complex": "FD", +} + + +# ### Defaults and dtype discovery + + +def set_default_dtype(fp_dtype="numpy", int_dtype="numpy"): + """Set the (global) defaults for fp, complex, and int dtypes. + + The complex dtype is inferred from the float (fp) dtype. It has + a width at least twice the width of the float dtype, + i.e., it's complex128 for float64 and complex64 for float32. + + Parameters + ---------- + fp_dtype + Allowed values are "numpy", "pytorch" or dtype_like things which + can be converted into a DType instance. + Default is "numpy" (i.e. float64). + int_dtype + Allowed values are "numpy", "pytorch" or dtype_like things which + can be converted into a DType instance. + Default is "numpy" (i.e. int64). 
+ + Returns + ------- + The old default dtype state: a namedtuple with attributes ``float_dtype``, + ``complex_dtypes`` and ``int_dtype``. These attributes store *pytorch* + dtypes. + + Notes + ------------ + This functions has a side effect: it sets the global state with the provided dtypes. + + The complex dtype has bit width of at least twice the width of the float + dtype, i.e. it's complex128 for float64 and complex64 for float32. + + """ + if fp_dtype not in ["numpy", "pytorch"]: + fp_dtype = dtype(fp_dtype).torch_dtype + if int_dtype not in ["numpy", "pytorch"]: + int_dtype = dtype(int_dtype).torch_dtype + + if fp_dtype == "numpy": + float_dtype = torch.float64 + elif fp_dtype == "pytorch": + float_dtype = torch.float32 + else: + float_dtype = fp_dtype + + complex_dtype = { + torch.float64: torch.complex128, + torch.float32: torch.complex64, + torch.float16: torch.complex64, + }[float_dtype] + + if int_dtype in ["numpy", "pytorch"]: + int_dtype = torch.int64 + else: + int_dtype = int_dtype + + new_defaults = _dtypes_impl.DefaultDTypes( + float_dtype=float_dtype, complex_dtype=complex_dtype, int_dtype=int_dtype + ) + + # set the new global state and return the old state + old_defaults = _dtypes_impl.default_dtypes + _dtypes_impl._default_dtypes = new_defaults + return old_defaults + + +def issubclass_(arg, klass): + try: + return issubclass(arg, klass) + except TypeError: + return False + + +def issubdtype(arg1, arg2): + # cf https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numerictypes.py#L356-L420 + + # We also accept strings even if NumPy doesn't as dtypes are serialized as their + # string representation in dynamo's graph + def str_to_abstract(t): + if isinstance(t, str) and t in _abstract_dtypes: + return globals()[t] + return t + + arg1 = str_to_abstract(arg1) + arg2 = str_to_abstract(arg2) + + if not issubclass_(arg1, generic): + arg1 = dtype(arg1).type + if not issubclass_(arg2, generic): + arg2 = dtype(arg2).type + return issubclass(arg1, arg2) + + +__all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"] +__all__ += list(_names.keys()) # noqa: PLE0605 +__all__ += list(_name_aliases.keys()) # noqa: PLE0605 +__all__ += _abstract_dtypes # noqa: PLE0605 diff --git a/phivenv/Lib/site-packages/torch/_numpy/_dtypes_impl.py b/phivenv/Lib/site-packages/torch/_numpy/_dtypes_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..19e2cf10756f3c37a9264207347ae0da2a0205e6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_dtypes_impl.py @@ -0,0 +1,217 @@ +# mypy: ignore-errors + +"""Dtypes/scalar type implementaions with torch dtypes. + +Here `dtype` is always a torch.dtype, this module knows nothing about +scalar types, wrapper dtypes or anything like that. PyTorch only. 
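+
+The default float/complex/int dtypes are resolved lazily from
+torch._dynamo.config the first time default_dtypes() is called.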
+""" +from collections import namedtuple + +import torch + + +# defaults : mimic NumPy, allow user control +DefaultDTypes = namedtuple( + "DefaultDTypes", ["float_dtype", "complex_dtype", "int_dtype"] +) + +# a global state +# We set it the first time we call default_dtypes() to avoid importing +# torch._dynamo.config and create a circular reference +_default_dtypes = None + + +def default_dtypes(): + global _default_dtypes + if _default_dtypes is None: + import torch._dynamo.config as config + + _default_dtypes = DefaultDTypes( + float_dtype=getattr(torch, config.numpy_default_float), + complex_dtype=getattr(torch, config.numpy_default_complex), + int_dtype=getattr(torch, config.numpy_default_int), + ) + assert isinstance(_default_dtypes.float_dtype, torch.dtype) + assert isinstance(_default_dtypes.complex_dtype, torch.dtype) + assert isinstance(_default_dtypes.int_dtype, torch.dtype) + return _default_dtypes + + +def get_default_dtype_for(dtype): + """Default scalar type given sctype category.""" + if dtype == torch.bool: + return dtype + if dtype.is_complex: + return default_dtypes().complex_dtype + if dtype.is_floating_point: + return default_dtypes().float_dtype + # else, it must be (some) integer + return default_dtypes().int_dtype + + +from . import _casting_dicts as _cd + + +def can_cast_impl(from_torch_dtype, to_torch_dtype, casting): + return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype] + + +def result_type_impl(*tensors): + # NB: torch dtypes here + dtyp = tensors[0].dtype + if len(tensors) == 1: + return dtyp + + for curr in tensors[1:]: + dtyp = _cd._result_type_dict[dtyp][curr.dtype] + + return dtyp + + +def python_type_for_torch(dtyp): + """Get a python scalar type a torch dtype""" + if dtyp.is_floating_point: + typ = float + elif dtyp.is_complex: + typ = complex + elif dtyp == torch.bool: + typ = bool + else: + typ = int + return typ + + +# ### NEP 50 helpers ### + +_SCALAR_TYPES = (int, bool, float, complex) + +_SCALAR_AND_SYMBOLIC_TYPES = ( + *_SCALAR_TYPES, + torch.SymInt, + torch.SymFloat, + torch.SymBool, +) + +_NEP50_FUNCS_TENSOR_ONLY = ( + "minimum", + "maximum", + "logaddexp", + "logaddexp2", + "lcm", + "gcd", + "hypot", + "heaviside", + "fmod", + "fmin", + "fmax", + "copysign", + "arctan2", +) + + +def is_scalar(x): + return isinstance(x, _SCALAR_TYPES) + + +def is_scalar_or_symbolic(x): + return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES) + + +def _dtype_for_scalar(py_type): + return { + bool: torch.bool, + torch.SymBool: torch.bool, + int: torch.int64, + torch.SymInt: torch.int64, + float: torch.float64, + torch.SymFloat: torch.float64, + complex: torch.complex128, + }[py_type] + + +def _dtype_for_scalar_or_tensor(x): + return x.dtype if isinstance(x, torch.Tensor) else _dtype_for_scalar(type(x)) + + +def is_float_or_fp_tensor(x): + return _dtype_for_scalar_or_tensor(x).is_floating_point + + +def is_complex_or_complex_tensor(x): + return _dtype_for_scalar_or_tensor(x).is_complex + + +def _category(dtype): + return { + torch.bool: 0, + torch.SymBool: 0, + # int + torch.uint8: 1, + torch.int8: 1, + torch.int16: 1, + torch.int32: 1, + torch.int64: 1, + torch.SymInt: 1, + # float + torch.float16: 2, + torch.float32: 2, + torch.float64: 2, + torch.SymFloat: 2, + # complex + torch.complex64: 3, + torch.complex128: 3, + }[dtype] + + +def nep50_to_tensors(x1, x2, handle_weaks, function_name): + """If either of inputs is a python scalar, type-promote with NEP 50.""" + + def to_tensor(scalar, dtype=None): + if dtype is None: + dtype = 
_dtype_for_scalar(type(scalar)) + dtype = get_default_dtype_for(dtype) + return torch.as_tensor(scalar, dtype=dtype) + + x1_is_weak = not isinstance(x1, torch.Tensor) + x2_is_weak = not isinstance(x2, torch.Tensor) + if not handle_weaks or (x1_is_weak and x2_is_weak): + x1 = to_tensor(x1) if x1_is_weak else x1 + x2 = to_tensor(x2) if x2_is_weak else x2 + return x1, x2 + + # scalar tensor: NEP 50 + assert x1_is_weak != x2_is_weak + + weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1) + + # find the dtype for the weak's type + weak_dtype = _dtype_for_scalar(type(weak)) + + cat_weak = _category(weak_dtype) + cat_not_weak = _category(not_weak.dtype) + + dt = not_weak.dtype if cat_weak <= cat_not_weak else None + + # special-case complex + float32 + if weak_dtype.is_complex and not_weak.dtype == torch.float32: + dt = torch.complex64 + + # detect overflows: in PyTorch, uint8(-1) wraps around to 255, + # while NEP50 mandates an exception. + # + # Note that we only check if each element of the binop overflows, + # not the result. Consider, e.g. `uint8(100) + 200`. Operands are OK + # in uint8, but the result overflows and wrap around 255. + # Numpy emits a RuntimeWarning, PyTorch does not, and we do not either. + if cat_weak == 1 and cat_not_weak == 1: + # integers + iinfo = torch.iinfo(not_weak.dtype) + if not (iinfo.min <= weak <= iinfo.max): + raise OverflowError( + f"Python integer {weak} out of bounds for {not_weak.dtype}" + ) + if weak_dtype != dt or function_name in _NEP50_FUNCS_TENSOR_ONLY: + # finally, can make `weak` into a 0D tensor, if both parameters are required to be tensor. + weak = to_tensor(weak, dt) + + return (weak, not_weak) if x1_is_weak else (not_weak, weak) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_funcs.py b/phivenv/Lib/site-packages/torch/_numpy/_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba676c633e457d0201bd3be25a51b5974f5d746 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_funcs.py @@ -0,0 +1,76 @@ +# mypy: ignore-errors + +import inspect +import itertools + +from . import _funcs_impl, _reductions_impl +from ._normalizations import normalizer + + +# _funcs_impl.py contains functions which mimic NumPy's eponymous equivalents, +# and consume/return PyTorch tensors/dtypes. +# They are also type annotated. +# Pull these functions from _funcs_impl and decorate them with @normalizer, which +# - Converts any input `np.ndarray`, `torch._numpy.ndarray`, list of lists, Python scalars, etc into a `torch.Tensor`. 
+# - Maps NumPy dtypes to PyTorch dtypes +# - If the input to the `axis` kwarg is an ndarray, it maps it into a tuple +# - Implements the semantics for the `out=` arg +# - Wraps back the outputs into `torch._numpy.ndarrays` + + +def _public_functions(mod): + def is_public_function(f): + return inspect.isfunction(f) and not f.__name__.startswith("_") + + return inspect.getmembers(mod, is_public_function) + + +# We fill in __all__ in the loop below +__all__ = [] + +# decorate implementer functions with argument normalizers and export to the top namespace +for name, func in itertools.chain( + _public_functions(_funcs_impl), _public_functions(_reductions_impl) +): + if name in ["percentile", "quantile", "median"]: + decorated = normalizer(func, promote_scalar_result=True) + elif name == "einsum": + # normalized manually + decorated = func + else: + decorated = normalizer(func) + + decorated.__qualname__ = name + decorated.__name__ = name + vars()[name] = decorated + __all__.append(name) + + +""" +Vendored objects from numpy.lib.index_tricks +""" + + +class IndexExpression: + """ + Written by Konrad Hinsen + last revision: 1999-7-23 + + Cosmetic changes by T. Oliphant 2001 + """ + + def __init__(self, maketuple): + self.maketuple = maketuple + + def __getitem__(self, item): + if self.maketuple and not isinstance(item, tuple): + return (item,) + else: + return item + + +index_exp = IndexExpression(maketuple=True) +s_ = IndexExpression(maketuple=False) + + +__all__ += ["index_exp", "s_"] diff --git a/phivenv/Lib/site-packages/torch/_numpy/_funcs_impl.py b/phivenv/Lib/site-packages/torch/_numpy/_funcs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..48bad82cf993060d7db3f23c945a3e15e04430b1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_funcs_impl.py @@ -0,0 +1,2058 @@ +# mypy: ignore-errors + +"""A thin pytorch / numpy compat layer. + +Things imported from here have numpy-compatible signatures but operate on +pytorch tensors. +""" +# Contents of this module ends up in the main namespace via _funcs.py +# where type annotations are used in conjunction with the @normalizer decorator. +from __future__ import annotations + +import builtins +import itertools +import operator +from typing import Optional, TYPE_CHECKING + +import torch + +from . 
import _dtypes_impl, _util + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ._normalizations import ( + ArrayLike, + ArrayLikeOrScalar, + CastingModes, + DTypeLike, + NDArray, + NotImplementedType, + OutArray, + ) + + +def copy( + a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False +): + return a.clone() + + +def copyto( + dst: NDArray, + src: ArrayLike, + casting: Optional[CastingModes] = "same_kind", + where: NotImplementedType = None, +): + (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting) + dst.copy_(src) + + +def atleast_1d(*arys: ArrayLike): + res = torch.atleast_1d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def atleast_2d(*arys: ArrayLike): + res = torch.atleast_2d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def atleast_3d(*arys: ArrayLike): + res = torch.atleast_3d(*arys) + if isinstance(res, tuple): + return list(res) + else: + return res + + +def _concat_check(tup, dtype, out): + if tup == (): + raise ValueError("need at least one array to concatenate") + + """Check inputs in concatenate et al.""" + if out is not None and dtype is not None: + # mimic numpy + raise TypeError( + "concatenate() only takes `out` or `dtype` as an " + "argument, but both were provided." + ) + + +def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): + """Figure out dtypes, cast if necessary.""" + + if out is not None or dtype is not None: + # figure out the type of the inputs and outputs + out_dtype = out.dtype.torch_dtype if dtype is None else dtype + else: + out_dtype = _dtypes_impl.result_type_impl(*tensors) + + # cast input arrays if necessary; do not broadcast them agains `out` + tensors = _util.typecast_tensors(tensors, out_dtype, casting) + + return tensors + + +def _concatenate( + tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind" +): + # pure torch implementation, used below and in cov/corrcoef below + tensors, axis = _util.axis_none_flatten(*tensors, axis=axis) + tensors = _concat_cast_helper(tensors, out, dtype, casting) + return torch.cat(tensors, axis) + + +def concatenate( + ar_tuple: Sequence[ArrayLike], + axis=0, + out: Optional[OutArray] = None, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(ar_tuple, dtype, out=out) + result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting) + return result + + +def vstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.vstack(tensors) + + +row_stack = vstack + + +def hstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.hstack(tensors) + + +def dstack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + # XXX: in numpy 1.24 dstack does not have dtype and casting keywords + # but {h,v}stack do. Hence add them here for consistency. 
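+ # _concat_check rejects an empty tuple, and _concat_cast_helper casts every
+ # input to a common result dtype before delegating to torch.dstack.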
+ _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.dstack(tensors) + + +def column_stack( + tup: Sequence[ArrayLike], + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords + # but row_stack does. (because row_stack is an alias for vstack, really). + # Hence add these keywords here for consistency. + _concat_check(tup, dtype, out=None) + tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) + return torch.column_stack(tensors) + + +def stack( + arrays: Sequence[ArrayLike], + axis=0, + out: Optional[OutArray] = None, + *, + dtype: Optional[DTypeLike] = None, + casting: Optional[CastingModes] = "same_kind", +): + _concat_check(arrays, dtype, out=out) + + tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting) + result_ndim = tensors[0].ndim + 1 + axis = _util.normalize_axis_index(axis, result_ndim) + return torch.stack(tensors, axis=axis) + + +def append(arr: ArrayLike, values: ArrayLike, axis=None): + if axis is None: + if arr.ndim != 1: + arr = arr.flatten() + values = values.flatten() + axis = arr.ndim - 1 + return _concatenate((arr, values), axis=axis) + + +# ### split ### + + +def _split_helper(tensor, indices_or_sections, axis, strict=False): + if isinstance(indices_or_sections, int): + return _split_helper_int(tensor, indices_or_sections, axis, strict) + elif isinstance(indices_or_sections, (list, tuple)): + # NB: drop split=..., it only applies to split_helper_int + return _split_helper_list(tensor, list(indices_or_sections), axis) + else: + raise TypeError("split_helper: ", type(indices_or_sections)) + + +def _split_helper_int(tensor, indices_or_sections, axis, strict=False): + if not isinstance(indices_or_sections, int): + raise NotImplementedError("split: indices_or_sections") + + axis = _util.normalize_axis_index(axis, tensor.ndim) + + # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n + l, n = tensor.shape[axis], indices_or_sections + + if n <= 0: + raise ValueError + + if l % n == 0: + num, sz = n, l // n + lst = [sz] * num + else: + if strict: + raise ValueError("array split does not result in an equal division") + + num, sz = l % n, l // n + 1 + lst = [sz] * num + + lst += [sz - 1] * (n - num) + + return torch.split(tensor, lst, axis) + + +def _split_helper_list(tensor, indices_or_sections, axis): + if not isinstance(indices_or_sections, list): + raise NotImplementedError("split: indices_or_sections: list") + # numpy expects indices, while torch expects lengths of sections + # also, numpy appends zero-size arrays for indices above the shape[axis] + lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] + num_extra = len(indices_or_sections) - len(lst) + + lst.append(tensor.shape[axis]) + lst = [ + lst[0], + ] + [a - b for a, b in zip(lst[1:], lst[:-1])] + lst += [0] * num_extra + + return torch.split(tensor, lst, axis) + + +def array_split(ary: ArrayLike, indices_or_sections, axis=0): + return _split_helper(ary, indices_or_sections, axis) + + +def split(ary: ArrayLike, indices_or_sections, axis=0): + return _split_helper(ary, indices_or_sections, axis, strict=True) + + +def hsplit(ary: ArrayLike, indices_or_sections): + if ary.ndim == 0: + raise ValueError("hsplit only works on arrays of 1 or more dimensions") + axis = 1 if ary.ndim > 1 else 0 + return _split_helper(ary, indices_or_sections, axis, strict=True) + + +def vsplit(ary: ArrayLike, 
indices_or_sections): + if ary.ndim < 2: + raise ValueError("vsplit only works on arrays of 2 or more dimensions") + return _split_helper(ary, indices_or_sections, 0, strict=True) + + +def dsplit(ary: ArrayLike, indices_or_sections): + if ary.ndim < 3: + raise ValueError("dsplit only works on arrays of 3 or more dimensions") + return _split_helper(ary, indices_or_sections, 2, strict=True) + + +def kron(a: ArrayLike, b: ArrayLike): + return torch.kron(a, b) + + +def vander(x: ArrayLike, N=None, increasing=False): + return torch.vander(x, N, increasing) + + +# ### linspace, geomspace, logspace and arange ### + + +def linspace( + start: ArrayLike, + stop: ArrayLike, + num=50, + endpoint=True, + retstep=False, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or retstep or not endpoint: + raise NotImplementedError + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + # XXX: raises TypeError if start or stop are not scalars + return torch.linspace(start, stop, num, dtype=dtype) + + +def geomspace( + start: ArrayLike, + stop: ArrayLike, + num=50, + endpoint=True, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or not endpoint: + raise NotImplementedError + base = torch.pow(stop / start, 1.0 / (num - 1)) + logbase = torch.log(base) + return torch.logspace( + torch.log(start) / logbase, + torch.log(stop) / logbase, + num, + base=base, + ) + + +def logspace( + start, + stop, + num=50, + endpoint=True, + base=10.0, + dtype: Optional[DTypeLike] = None, + axis=0, +): + if axis != 0 or not endpoint: + raise NotImplementedError + return torch.logspace(start, stop, num, base=base, dtype=dtype) + + +def arange( + start: Optional[ArrayLikeOrScalar] = None, + stop: Optional[ArrayLikeOrScalar] = None, + step: Optional[ArrayLikeOrScalar] = 1, + dtype: Optional[DTypeLike] = None, + *, + like: NotImplementedType = None, +): + if step == 0: + raise ZeroDivisionError + if stop is None and start is None: + raise TypeError + if stop is None: + # XXX: this breaks if start is passed as a kwarg: + # arange(start=4) should raise (no stop) but doesn't + start, stop = 0, start + if start is None: + start = 0 + + # the dtype of the result + if dtype is None: + dtype = ( + _dtypes_impl.default_dtypes().float_dtype + if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step)) + else _dtypes_impl.default_dtypes().int_dtype + ) + work_dtype = torch.float64 if dtype.is_complex else dtype + + # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager. + if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)): + raise NotImplementedError + + if (step > 0 and start > stop) or (step < 0 and start < stop): + # empty range + return torch.empty(0, dtype=dtype) + + result = torch.arange(start, stop, step, dtype=work_dtype) + result = _util.cast_if_needed(result, dtype) + return result + + +# ### zeros/ones/empty/full ### + + +def empty( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.empty(shape, dtype=dtype) + + +# NB: *_like functions deliberately deviate from numpy: it has subok=True +# as the default; we set subok=False and raise on anything else. 
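+ # Note: the `shape=` argument of the *_like functions below is emulated by
+ # reshaping the torch result, so it only works for shapes with the same
+ # total number of elements as the input.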
+ + +def empty_like( + prototype: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.empty_like(prototype, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def full( + shape, + fill_value: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if isinstance(shape, int): + shape = (shape,) + if dtype is None: + dtype = fill_value.dtype + if not isinstance(shape, (tuple, list)): + shape = (shape,) + return torch.full(shape, fill_value, dtype=dtype) + + +def full_like( + a: ArrayLike, + fill_value, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + # XXX: fill_value broadcasts + result = torch.full_like(a, fill_value, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def ones( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.ones(shape, dtype=dtype) + + +def ones_like( + a: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.ones_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +def zeros( + shape, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.zeros(shape, dtype=dtype) + + +def zeros_like( + a: ArrayLike, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "K", + subok: NotImplementedType = False, + shape=None, +): + result = torch.zeros_like(a, dtype=dtype) + if shape is not None: + result = result.reshape(shape) + return result + + +# ### cov & corrcoef ### + + +def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True): + """Prepare inputs for cov and corrcoef.""" + + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 + if y_tensor is not None: + # make sure x and y are at least 2D + ndim_extra = 2 - x_tensor.ndim + if ndim_extra > 0: + x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape) + if not rowvar and x_tensor.shape[0] != 1: + x_tensor = x_tensor.mT + x_tensor = x_tensor.clone() + + ndim_extra = 2 - y_tensor.ndim + if ndim_extra > 0: + y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape) + if not rowvar and y_tensor.shape[0] != 1: + y_tensor = y_tensor.mT + y_tensor = y_tensor.clone() + + x_tensor = _concatenate((x_tensor, y_tensor), axis=0) + + return x_tensor + + +def corrcoef( + x: ArrayLike, + y: Optional[ArrayLike] = None, + rowvar=True, + bias=None, + ddof=None, + *, + dtype: Optional[DTypeLike] = None, +): + if bias is not None or ddof is not None: + # deprecated in NumPy + raise NotImplementedError + xy_tensor = _xy_helper_corrcoef(x, y, rowvar) + + is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + xy_tensor = _util.cast_if_needed(xy_tensor, dtype) + result = torch.corrcoef(xy_tensor) + + if is_half: + result = result.to(torch.float16) + + return result + + +def cov( + m: ArrayLike, + y: 
Optional[ArrayLike] = None, + rowvar=True, + bias=False, + ddof=None, + fweights: Optional[ArrayLike] = None, + aweights: Optional[ArrayLike] = None, + *, + dtype: Optional[DTypeLike] = None, +): + m = _xy_helper_corrcoef(m, y, rowvar) + + if ddof is None: + ddof = 1 if bias == 0 else 0 + + is_half = (m.dtype == torch.float16) and m.is_cpu + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + + m = _util.cast_if_needed(m, dtype) + result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights) + + if is_half: + result = result.to(torch.float16) + + return result + + +def _conv_corr_impl(a, v, mode): + dt = _dtypes_impl.result_type_impl(a, v) + a = _util.cast_if_needed(a, dt) + v = _util.cast_if_needed(v, dt) + + padding = v.shape[0] - 1 if mode == "full" else mode + + if padding == "same" and v.shape[0] % 2 == 0: + # UserWarning: Using padding='same' with even kernel lengths and odd + # dilation may require a zero-padded copy of the input be created + # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.) + raise NotImplementedError("mode='same' and even-length weights") + + # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights + aa = a[None, :] + vv = v[None, None, :] + + result = torch.nn.functional.conv1d(aa, vv, padding=padding) + + # torch returns a 2D result, numpy returns a 1D array + return result[0, :] + + +def convolve(a: ArrayLike, v: ArrayLike, mode="full"): + # NumPy: if v is longer than a, the arrays are swapped before computation + if a.shape[0] < v.shape[0]: + a, v = v, a + + # flip the weights since numpy does and torch does not + v = torch.flip(v, (0,)) + + return _conv_corr_impl(a, v, mode) + + +def correlate(a: ArrayLike, v: ArrayLike, mode="valid"): + v = torch.conj_physical(v) + return _conv_corr_impl(a, v, mode) + + +# ### logic & element selection ### + + +def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0): + if x.numel() == 0: + # edge case allowed by numpy + x = x.new_empty(0, dtype=int) + + int_dtype = _dtypes_impl.default_dtypes().int_dtype + (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe") + + return torch.bincount(x, weights, minlength) + + +def where( + condition: ArrayLike, + x: Optional[ArrayLikeOrScalar] = None, + y: Optional[ArrayLikeOrScalar] = None, + /, +): + if (x is None) != (y is None): + raise ValueError("either both or neither of x and y should be given") + + if condition.dtype != torch.bool: + condition = condition.to(torch.bool) + + if x is None and y is None: + result = torch.where(condition) + else: + result = torch.where(condition, x, y) + return result + + +# ###### module-level queries of object properties + + +def ndim(a: ArrayLike): + return a.ndim + + +def shape(a: ArrayLike): + return tuple(a.shape) + + +def size(a: ArrayLike, axis=None): + if axis is None: + return a.numel() + else: + return a.shape[axis] + + +# ###### shape manipulations and indexing + + +def expand_dims(a: ArrayLike, axis): + shape = _util.expand_shape(a.shape, axis) + return a.view(shape) # never copies + + +def flip(m: ArrayLike, axis=None): + # XXX: semantic difference: np.flip returns a view, torch.flip copies + if axis is None: + axis = tuple(range(m.ndim)) + else: + axis = _util.normalize_axis_tuple(axis, m.ndim) + return torch.flip(m, axis) + + +def flipud(m: ArrayLike): + return torch.flipud(m) + + +def fliplr(m: ArrayLike): + return torch.fliplr(m) + + +def rot90(m: ArrayLike, k=1, axes=(0, 1)): + axes = 
_util.normalize_axis_tuple(axes, m.ndim) + return torch.rot90(m, k, axes) + + +# ### broadcasting and indices ### + + +def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False): + return torch.broadcast_to(array, size=shape) + + +# This is a function from tuples to tuples, so we just reuse it +from torch import broadcast_shapes + + +def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False): + return torch.broadcast_tensors(*args) + + +def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"): + ndim = len(xi) + + if indexing not in ["xy", "ij"]: + raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.") + + s0 = (1,) * ndim + output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)] + + if indexing == "xy" and ndim > 1: + # switch first and second axis + output[0] = output[0].reshape((1, -1) + s0[2:]) + output[1] = output[1].reshape((-1, 1) + s0[2:]) + + if not sparse: + # Return the full N-D matrix (not only the 1-D vector) + output = torch.broadcast_tensors(*output) + + if copy: + output = [x.clone() for x in output] + + return list(output) # match numpy, return a list + + +def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False): + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791 + dimensions = tuple(dimensions) + N = len(dimensions) + shape = (1,) * N + if sparse: + res = () + else: + res = torch.empty((N,) + dimensions, dtype=dtype) + for i, dim in enumerate(dimensions): + idx = torch.arange(dim, dtype=dtype).reshape( + shape[:i] + (dim,) + shape[i + 1 :] + ) + if sparse: + res = res + (idx,) + else: + res[i] = idx + return res + + +# ### tri*-something ### + + +def tril(m: ArrayLike, k=0): + return torch.tril(m, k) + + +def triu(m: ArrayLike, k=0): + return torch.triu(m, k) + + +def tril_indices(n, k=0, m=None): + if m is None: + m = n + return torch.tril_indices(n, m, offset=k) + + +def triu_indices(n, k=0, m=None): + if m is None: + m = n + return torch.triu_indices(n, m, offset=k) + + +def tril_indices_from(arr: ArrayLike, k=0): + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + # Return a tensor rather than a tuple to avoid a graphbreak + return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k) + + +def triu_indices_from(arr: ArrayLike, k=0): + if arr.ndim != 2: + raise ValueError("input array must be 2-d") + # Return a tensor rather than a tuple to avoid a graphbreak + return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k) + + +def tri( + N, + M=None, + k=0, + dtype: Optional[DTypeLike] = None, + *, + like: NotImplementedType = None, +): + if M is None: + M = N + tensor = torch.ones((N, M), dtype=dtype) + return torch.tril(tensor, diagonal=k) + + +# ### equality, equivalence, allclose ### + + +def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): + dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False): + dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +def _tensor_equal(a1, a2, equal_nan=False): + # Implementation of array_equal/array_equiv. 
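+ # Shapes must match exactly here; array_equiv (below) first broadcasts the
+ # two tensors and only then reuses this helper.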
+ if a1.shape != a2.shape: + return False + cond = a1 == a2 + if equal_nan: + cond = cond | (torch.isnan(a1) & torch.isnan(a2)) + return cond.all().item() + + +def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False): + return _tensor_equal(a1, a2, equal_nan=equal_nan) + + +def array_equiv(a1: ArrayLike, a2: ArrayLike): + # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not + try: + a1_t, a2_t = torch.broadcast_tensors(a1, a2) + except RuntimeError: + # failed to broadcast => not equivalent + return False + return _tensor_equal(a1_t, a2_t) + + +def nan_to_num( + x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None +): + # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble' + if x.is_complex(): + re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf) + im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf) + return re + 1j * im + else: + return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +# ### put/take_along_axis ### + + +def take( + a: ArrayLike, + indices: ArrayLike, + axis=None, + out: Optional[OutArray] = None, + mode: NotImplementedType = "raise", +): + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + idx = (slice(None),) * axis + (indices, ...) + result = a[idx] + return result + + +def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) + axis = _util.normalize_axis_index(axis, arr.ndim) + return torch.take_along_dim(arr, indices, axis) + + +def put( + a: NDArray, + indices: ArrayLike, + values: ArrayLike, + mode: NotImplementedType = "raise", +): + v = values.type(a.dtype) + # If indices is larger than v, expand v to at least the size of indices. Any + # unnecessary trailing elements are then trimmed. + if indices.numel() > v.numel(): + ratio = (indices.numel() + v.numel() - 1) // v.numel() + v = v.unsqueeze(0).expand((ratio,) + v.shape) + # Trim unnecessary elements, regardless if v was expanded or not. Note + # np.put() trims v to match indices by default too. + if indices.numel() < v.numel(): + v = v.flatten() + v = v[: indices.numel()] + a.put_(indices, v) + return None + + +def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): + (arr,), axis = _util.axis_none_flatten(arr, axis=axis) + axis = _util.normalize_axis_index(axis, arr.ndim) + + indices, values = torch.broadcast_tensors(indices, values) + values = _util.cast_if_needed(values, arr.dtype) + result = torch.scatter(arr, axis, indices, values) + arr.copy_(result.reshape(arr.shape)) + return None + + +def choose( + a: ArrayLike, + choices: Sequence[ArrayLike], + out: Optional[OutArray] = None, + mode: NotImplementedType = "raise", +): + # First, broadcast elements of `choices` + choices = torch.stack(torch.broadcast_tensors(*choices)) + + # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`: + # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939) + idx_list = [ + torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1)) + for i, dim in enumerate(choices.shape) + ] + + idx_list[0] = a + return choices[tuple(idx_list)].squeeze(0) + + +# ### unique et al. 
### + + +def unique( + ar: ArrayLike, + return_index: NotImplementedType = False, + return_inverse=False, + return_counts=False, + axis=None, + *, + equal_nan: NotImplementedType = True, +): + (ar,), axis = _util.axis_none_flatten(ar, axis=axis) + axis = _util.normalize_axis_index(axis, ar.ndim) + + result = torch.unique( + ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis + ) + + return result + + +def nonzero(a: ArrayLike): + return torch.nonzero(a, as_tuple=True) + + +def argwhere(a: ArrayLike): + return torch.argwhere(a) + + +def flatnonzero(a: ArrayLike): + return torch.flatten(a).nonzero(as_tuple=True)[0] + + +def clip( + a: ArrayLike, + min: Optional[ArrayLike] = None, + max: Optional[ArrayLike] = None, + out: Optional[OutArray] = None, +): + return torch.clamp(a, min, max) + + +def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None): + return torch.repeat_interleave(a, repeats, axis) + + +def tile(A: ArrayLike, reps): + if isinstance(reps, int): + reps = (reps,) + return torch.tile(A, reps) + + +def resize(a: ArrayLike, new_shape=None): + # implementation vendored from + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497 + if new_shape is None: + return a + + if isinstance(new_shape, int): + new_shape = (new_shape,) + + a = a.flatten() + + new_size = 1 + for dim_length in new_shape: + new_size *= dim_length + if dim_length < 0: + raise ValueError("all elements of `new_shape` must be non-negative") + + if a.numel() == 0 or new_size == 0: + # First case must zero fill. The second would have repeats == 0. + return torch.zeros(new_shape, dtype=a.dtype) + + repeats = -(-new_size // a.numel()) # ceil division + a = concatenate((a,) * repeats)[:new_size] + + return reshape(a, new_shape) + + +# ### diag et al. ### + + +def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): + axis1 = _util.normalize_axis_index(axis1, a.ndim) + axis2 = _util.normalize_axis_index(axis2, a.ndim) + return torch.diagonal(a, offset, axis1, axis2) + + +def trace( + a: ArrayLike, + offset=0, + axis1=0, + axis2=1, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype) + return result + + +def eye( + N, + M=None, + k=0, + dtype: Optional[DTypeLike] = None, + order: NotImplementedType = "C", + *, + like: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.default_dtypes().float_dtype + if M is None: + M = N + z = torch.zeros(N, M, dtype=dtype) + z.diagonal(k).fill_(1) + return z + + +def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None): + return torch.eye(n, dtype=dtype) + + +def diag(v: ArrayLike, k=0): + return torch.diag(v, k) + + +def diagflat(v: ArrayLike, k=0): + return torch.diagflat(v, k) + + +def diag_indices(n, ndim=2): + idx = torch.arange(n) + return (idx,) * ndim + + +def diag_indices_from(arr: ArrayLike): + if not arr.ndim >= 2: + raise ValueError("input array must be at least 2-d") + # For more than d=2, the strided formula is only valid for arrays with + # all dimensions equal, so we check first. 
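+ # For example, a (3, 3) input yields (tensor([0, 1, 2]), tensor([0, 1, 2])).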
+ s = arr.shape + if s[1:] != s[:-1]: + raise ValueError("All dimensions of input must be of equal length") + return diag_indices(s[0], arr.ndim) + + +def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): + if a.ndim < 2: + raise ValueError("array must be at least 2-d") + if val.numel() == 0 and not wrap: + a.fill_diagonal_(val) + return a + + if val.ndim == 0: + val = val.unsqueeze(0) + + # torch.Tensor.fill_diagonal_ only accepts scalars + # If the size of val is too large, then val is trimmed + if a.ndim == 2: + tall = a.shape[0] > a.shape[1] + # wrap does nothing for wide matrices... + if not wrap or not tall: + # Never wraps + diag = a.diagonal() + diag.copy_(val[: diag.numel()]) + else: + # wraps and tall... leaving one empty line between diagonals?! + max_, min_ = a.shape + idx = torch.arange(max_ - max_ // (min_ + 1)) + mod = idx % min_ + div = idx // min_ + a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()] + else: + idx = diag_indices_from(a) + # a.shape = (n, n, ..., n) + a[idx] = val[: a.shape[0]] + + return a + + +def vdot(a: ArrayLike, b: ArrayLike, /): + # 1. torch only accepts 1D arrays, numpy flattens + # 2. torch requires matching dtype, while numpy casts (?) + t_a, t_b = torch.atleast_1d(a, b) + if t_a.ndim > 1: + t_a = t_a.flatten() + if t_b.ndim > 1: + t_b = t_b.flatten() + + dtype = _dtypes_impl.result_type_impl(t_a, t_b) + is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu) + is_bool = dtype == torch.bool + + # work around torch's "dot" not implemented for 'Half', 'Bool' + if is_half: + dtype = torch.float32 + elif is_bool: + dtype = torch.uint8 + + t_a = _util.cast_if_needed(t_a, dtype) + t_b = _util.cast_if_needed(t_b, dtype) + + result = torch.vdot(t_a, t_b) + + if is_half: + result = result.to(torch.float16) + elif is_bool: + result = result.to(torch.bool) + + return result + + +def tensordot(a: ArrayLike, b: ArrayLike, axes=2): + if isinstance(axes, (list, tuple)): + axes = [[ax] if isinstance(ax, int) else ax for ax in axes] + + target_dtype = _dtypes_impl.result_type_impl(a, b) + a = _util.cast_if_needed(a, target_dtype) + b = _util.cast_if_needed(b, target_dtype) + + return torch.tensordot(a, b, dims=axes) + + +def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): + dtype = _dtypes_impl.result_type_impl(a, b) + is_bool = dtype == torch.bool + if is_bool: + dtype = torch.uint8 + + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + if a.ndim == 0 or b.ndim == 0: + result = a * b + else: + result = torch.matmul(a, b) + + if is_bool: + result = result.to(torch.bool) + + return result + + +def inner(a: ArrayLike, b: ArrayLike, /): + dtype = _dtypes_impl.result_type_impl(a, b) + is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu) + is_bool = dtype == torch.bool + + if is_half: + # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" + dtype = torch.float32 + elif is_bool: + dtype = torch.uint8 + + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + result = torch.inner(a, b) + + if is_half: + result = result.to(torch.float16) + elif is_bool: + result = result.to(torch.bool) + return result + + +def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): + return torch.outer(a, b) + + +def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None): + # implementation vendored from + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685 + if axis is not None: + axisa, axisb, axisc = (axis,) * 3 + + # Check 
axisa and axisb are within bounds + axisa = _util.normalize_axis_index(axisa, a.ndim) + axisb = _util.normalize_axis_index(axisb, b.ndim) + + # Move working axis to the end of the shape + a = torch.moveaxis(a, axisa, -1) + b = torch.moveaxis(b, axisb, -1) + msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)" + if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): + raise ValueError(msg) + + # Create the output array + shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape) + if a.shape[-1] == 3 or b.shape[-1] == 3: + shape += (3,) + # Check axisc is within bounds + axisc = _util.normalize_axis_index(axisc, len(shape)) + dtype = _dtypes_impl.result_type_impl(a, b) + cp = torch.empty(shape, dtype=dtype) + + # recast arrays as dtype + a = _util.cast_if_needed(a, dtype) + b = _util.cast_if_needed(b, dtype) + + # create local aliases for readability + a0 = a[..., 0] + a1 = a[..., 1] + if a.shape[-1] == 3: + a2 = a[..., 2] + b0 = b[..., 0] + b1 = b[..., 1] + if b.shape[-1] == 3: + b2 = b[..., 2] + if cp.ndim != 0 and cp.shape[-1] == 3: + cp0 = cp[..., 0] + cp1 = cp[..., 1] + cp2 = cp[..., 2] + + if a.shape[-1] == 2: + if b.shape[-1] == 2: + # a0 * b1 - a1 * b0 + cp[...] = a0 * b1 - a1 * b0 + return cp + else: + assert b.shape[-1] == 3 + # cp0 = a1 * b2 - 0 (a2 = 0) + # cp1 = 0 - a0 * b2 (a2 = 0) + # cp2 = a0 * b1 - a1 * b0 + cp0[...] = a1 * b2 + cp1[...] = -a0 * b2 + cp2[...] = a0 * b1 - a1 * b0 + else: + assert a.shape[-1] == 3 + if b.shape[-1] == 3: + cp0[...] = a1 * b2 - a2 * b1 + cp1[...] = a2 * b0 - a0 * b2 + cp2[...] = a0 * b1 - a1 * b0 + else: + assert b.shape[-1] == 2 + cp0[...] = -a2 * b1 + cp1[...] = a2 * b0 + cp2[...] = a0 * b1 - a1 * b0 + + return torch.moveaxis(cp, -1, axisc) + + +def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): + # Have to manually normalize *operands and **kwargs, following the NumPy signature + # We have a local import to avoid poluting the global space, as it will be then + # exported in funcs.py + from ._ndarray import ndarray + from ._normalizations import ( + maybe_copy_to, + normalize_array_like, + normalize_casting, + normalize_dtype, + wrap_tensors, + ) + + dtype = normalize_dtype(dtype) + casting = normalize_casting(casting) + if out is not None and not isinstance(out, ndarray): + raise TypeError("'out' must be an array") + if order != "K": + raise NotImplementedError("'order' parameter is not supported.") + + # parse arrays and normalize them + sublist_format = not isinstance(operands[0], str) + if sublist_format: + # op, str, op, str ... [sublistout] format: normalize every other argument + + # - if sublistout is not given, the length of operands is even, and we pick + # odd-numbered elements, which are arrays. + # - if sublistout is given, the length of operands is odd, we peel off + # the last one, and pick odd-numbered elements, which are arrays. + # Without [:-1], we would have picked sublistout, too. 
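+ # For example, einsum(a, [0, 1], b, [1, 2], [0, 2]) has five operands; the
+ # trailing [0, 2] is sublistout and operands[:-1][::2] picks out (a, b).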
+ array_operands = operands[:-1][::2] + else: + # ("ij->", arrays) format + subscripts, array_operands = operands[0], operands[1:] + + tensors = [normalize_array_like(op) for op in array_operands] + target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype + + # work around 'bmm' not implemented for 'Half' etc + is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors) + if is_half: + target_dtype = torch.float32 + + is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] + if is_short_int: + target_dtype = torch.int64 + + tensors = _util.typecast_tensors(tensors, target_dtype, casting) + + from torch.backends import opt_einsum + + try: + # set the global state to handle the optimize=... argument, restore on exit + if opt_einsum.is_available(): + old_strategy = torch.backends.opt_einsum.strategy + old_enabled = torch.backends.opt_einsum.enabled + + # torch.einsum calls opt_einsum.contract_path, which runs into + # https://github.com/dgasmith/opt_einsum/issues/219 + # for strategy={True, False} + if optimize is True: + optimize = "auto" + elif optimize is False: + torch.backends.opt_einsum.enabled = False + + torch.backends.opt_einsum.strategy = optimize + + if sublist_format: + # recombine operands + sublists = operands[1::2] + has_sublistout = len(operands) % 2 == 1 + if has_sublistout: + sublistout = operands[-1] + operands = list(itertools.chain.from_iterable(zip(tensors, sublists))) + if has_sublistout: + operands.append(sublistout) + + result = torch.einsum(*operands) + else: + result = torch.einsum(subscripts, *tensors) + + finally: + if opt_einsum.is_available(): + torch.backends.opt_einsum.strategy = old_strategy + torch.backends.opt_einsum.enabled = old_enabled + + result = maybe_copy_to(out, result) + return wrap_tensors(result) + + +# ### sort and partition ### + + +def _sort_helper(tensor, axis, kind, order): + if tensor.dtype.is_complex: + raise NotImplementedError(f"sorting {tensor.dtype} is not supported") + (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis) + axis = _util.normalize_axis_index(axis, tensor.ndim) + + stable = kind == "stable" + + return tensor, axis, stable + + +def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): + # `order` keyword arg is only relevant for structured dtypes; so not supported here. 
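+ # kind="stable" maps to torch.sort(..., stable=True); any other kind falls
+ # back to torch's default (non-stable) sort.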
+ a, axis, stable = _sort_helper(a, axis, kind, order) + result = torch.sort(a, dim=axis, stable=stable) + return result.values + + +def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None): + a, axis, stable = _sort_helper(a, axis, kind, order) + return torch.argsort(a, dim=axis, stable=stable) + + +def searchsorted( + a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None +): + if a.dtype.is_complex: + raise NotImplementedError(f"searchsorted with dtype={a.dtype}") + + return torch.searchsorted(a, v, side=side, sorter=sorter) + + +# ### swap/move/roll axis ### + + +def moveaxis(a: ArrayLike, source, destination): + source = _util.normalize_axis_tuple(source, a.ndim, "source") + destination = _util.normalize_axis_tuple(destination, a.ndim, "destination") + return torch.moveaxis(a, source, destination) + + +def swapaxes(a: ArrayLike, axis1, axis2): + axis1 = _util.normalize_axis_index(axis1, a.ndim) + axis2 = _util.normalize_axis_index(axis2, a.ndim) + return torch.swapaxes(a, axis1, axis2) + + +def rollaxis(a: ArrayLike, axis, start=0): + # Straight vendor from: + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259 + # + # Also note this function in NumPy is mostly retained for backwards compat + # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing) + # so let's not touch it unless hard pressed. + n = a.ndim + axis = _util.normalize_axis_index(axis, n) + if start < 0: + start += n + msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" + if not (0 <= start < n + 1): + raise _util.AxisError(msg % ("start", -n, "start", n + 1, start)) + if axis < start: + # it's been removed + start -= 1 + if axis == start: + # numpy returns a view, here we try returning the tensor itself + # return tensor[...] 
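+ # e.g. rollaxis(a, 2, 3) on a 3-D tensor hits this branch and returns `a`
+ # unchanged (numpy would return a trivial view instead).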
+ return a + axes = list(range(0, n)) + axes.remove(axis) + axes.insert(start, axis) + return a.view(axes) + + +def roll(a: ArrayLike, shift, axis=None): + if axis is not None: + axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) + if not isinstance(shift, tuple): + shift = (shift,) * len(axis) + return torch.roll(a, shift, axis) + + +# ### shape manipulations ### + + +def squeeze(a: ArrayLike, axis=None): + if axis == (): + result = a + elif axis is None: + result = a.squeeze() + else: + if isinstance(axis, tuple): + result = a + for ax in axis: + result = a.squeeze(ax) + else: + result = a.squeeze(axis) + return result + + +def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"): + # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh) + newshape = newshape[0] if len(newshape) == 1 else newshape + return a.reshape(newshape) + + +# NB: cannot use torch.reshape(a, newshape) above, because of +# (Pdb) torch.reshape(torch.as_tensor([1]), 1) +# *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int + + +def transpose(a: ArrayLike, axes=None): + # numpy allows both .transpose(sh) and .transpose(*sh) + # also older code uses axes being a list + if axes in [(), None, (None,)]: + axes = tuple(reversed(range(a.ndim))) + elif len(axes) == 1: + axes = axes[0] + return a.permute(axes) + + +def ravel(a: ArrayLike, order: NotImplementedType = "C"): + return torch.flatten(a) + + +def diff( + a: ArrayLike, + n=1, + axis=-1, + prepend: Optional[ArrayLike] = None, + append: Optional[ArrayLike] = None, +): + axis = _util.normalize_axis_index(axis, a.ndim) + + if n < 0: + raise ValueError(f"order must be non-negative but got {n}") + + if n == 0: + # match numpy and return the input immediately + return a + + if prepend is not None: + shape = list(a.shape) + shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1 + prepend = torch.broadcast_to(prepend, shape) + + if append is not None: + shape = list(a.shape) + shape[axis] = append.shape[axis] if append.ndim > 0 else 1 + append = torch.broadcast_to(append, shape) + + return torch.diff(a, n, axis=axis, prepend=prepend, append=append) + + +# ### math functions ### + + +def angle(z: ArrayLike, deg=False): + result = torch.angle(z) + if deg: + result = result * (180 / torch.pi) + return result + + +def sinc(x: ArrayLike): + return torch.sinc(x) + + +# NB: have to normalize *varargs manually +def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1): + N = f.ndim # number of dimensions + + varargs = _util.ndarrays_to_tensors(varargs) + + if axis is None: + axes = tuple(range(N)) + else: + axes = _util.normalize_axis_tuple(axis, N) + + len_axes = len(axes) + n = len(varargs) + if n == 0: + # no spacing argument - use 1 in all axes + dx = [1.0] * len_axes + elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0): + # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0) + dx = varargs * len_axes + elif n == len_axes: + # scalar or 1d array for each axis + dx = list(varargs) + for i, distances in enumerate(dx): + distances = torch.as_tensor(distances) + if distances.ndim == 0: + continue + elif distances.ndim != 1: + raise ValueError("distances must be either scalars or 1d") + if len(distances) != f.shape[axes[i]]: + raise ValueError( + "when 1d, distances must match " + "the length of the corresponding dimension" + ) + if not (distances.dtype.is_floating_point or distances.dtype.is_complex): + distances = distances.double() + + diffx = 
torch.diff(distances) + # if distances are constant reduce to the scalar case + # since it brings a consistent speedup + if (diffx == diffx[0]).all(): + diffx = diffx[0] + dx[i] = diffx + else: + raise TypeError("invalid number of arguments") + + if edge_order > 2: + raise ValueError("'edge_order' greater than 2 not supported") + + # use central differences on interior and one-sided differences on the + # endpoints. This preserves second order-accuracy over the full domain. + + outvals = [] + + # create slice objects --- initially all are [:, :, ..., :] + slice1 = [slice(None)] * N + slice2 = [slice(None)] * N + slice3 = [slice(None)] * N + slice4 = [slice(None)] * N + + otype = f.dtype + if _dtypes_impl.python_type_for_torch(otype) in (int, bool): + # Convert to floating point. + # First check if f is a numpy integer type; if so, convert f to float64 + # to avoid modular arithmetic when computing the changes in f. + f = f.double() + otype = torch.float64 + + for axis, ax_dx in zip(axes, dx): + if f.shape[axis] < edge_order + 1: + raise ValueError( + "Shape of array too small to calculate a numerical gradient, " + "at least (edge_order + 1) elements are required." + ) + # result allocation + out = torch.empty_like(f, dtype=otype) + + # spacing for the current axis (NB: np.ndim(ax_dx) == 0) + uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0 + + # Numerical differentiation: 2nd order interior + slice1[axis] = slice(1, -1) + slice2[axis] = slice(None, -2) + slice3[axis] = slice(1, -1) + slice4[axis] = slice(2, None) + + if uniform_spacing: + out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx) + else: + dx1 = ax_dx[0:-1] + dx2 = ax_dx[1:] + a = -(dx2) / (dx1 * (dx1 + dx2)) + b = (dx2 - dx1) / (dx1 * dx2) + c = dx1 / (dx2 * (dx1 + dx2)) + # fix the shape for broadcasting + shape = [1] * N + shape[axis] = -1 + a = a.reshape(shape) + b = b.reshape(shape) + c = c.reshape(shape) + # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + # Numerical differentiation: 1st order edges + if edge_order == 1: + slice1[axis] = 0 + slice2[axis] = 1 + slice3[axis] = 0 + dx_0 = ax_dx if uniform_spacing else ax_dx[0] + # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0 + + slice1[axis] = -1 + slice2[axis] = -1 + slice3[axis] = -2 + dx_n = ax_dx if uniform_spacing else ax_dx[-1] + # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n + + # Numerical differentiation: 2nd order edges + else: + slice1[axis] = 0 + slice2[axis] = 0 + slice3[axis] = 1 + slice4[axis] = 2 + if uniform_spacing: + a = -1.5 / ax_dx + b = 2.0 / ax_dx + c = -0.5 / ax_dx + else: + dx1 = ax_dx[0] + dx2 = ax_dx[1] + a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2)) + b = (dx1 + dx2) / (dx1 * dx2) + c = -dx1 / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + slice1[axis] = -1 + slice2[axis] = -3 + slice3[axis] = -2 + slice4[axis] = -1 + if uniform_spacing: + a = 0.5 / ax_dx + b = -2.0 / ax_dx + c = 1.5 / ax_dx + else: + dx1 = ax_dx[-2] + dx2 = ax_dx[-1] + a = (dx2) / (dx1 * (dx1 + dx2)) + b = -(dx2 + dx1) / (dx1 * dx2) + c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2)) + # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + 
c * f[-1] + out[tuple(slice1)] = ( + a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)] + ) + + outvals.append(out) + + # reset the slice object in this dimension to ":" + slice1[axis] = slice(None) + slice2[axis] = slice(None) + slice3[axis] = slice(None) + slice4[axis] = slice(None) + + if len_axes == 1: + return outvals[0] + else: + return outvals + + +# ### Type/shape etc queries ### + + +def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None): + if a.is_floating_point(): + result = torch.round(a, decimals=decimals) + elif a.is_complex(): + # RuntimeError: "round_cpu" not implemented for 'ComplexFloat' + result = torch.complex( + torch.round(a.real, decimals=decimals), + torch.round(a.imag, decimals=decimals), + ) + else: + # RuntimeError: "round_cpu" not implemented for 'int' + result = a + return result + + +around = round +round_ = round + + +def real_if_close(a: ArrayLike, tol=100): + if not torch.is_complex(a): + return a + if tol > 1: + # Undocumented in numpy: if tol < 1, it's an absolute tolerance! + # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon + # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577 + tol = tol * torch.finfo(a.dtype).eps + + mask = torch.abs(a.imag) < tol + return a.real if mask.all() else a + + +def real(a: ArrayLike): + return torch.real(a) + + +def imag(a: ArrayLike): + if a.is_complex(): + return a.imag + return torch.zeros_like(a) + + +def iscomplex(x: ArrayLike): + if torch.is_complex(x): + return x.imag != 0 + return torch.zeros_like(x, dtype=torch.bool) + + +def isreal(x: ArrayLike): + if torch.is_complex(x): + return x.imag == 0 + return torch.ones_like(x, dtype=torch.bool) + + +def iscomplexobj(x: ArrayLike): + return torch.is_complex(x) + + +def isrealobj(x: ArrayLike): + return not torch.is_complex(x) + + +def isneginf(x: ArrayLike, out: Optional[OutArray] = None): + return torch.isneginf(x) + + +def isposinf(x: ArrayLike, out: Optional[OutArray] = None): + return torch.isposinf(x) + + +def i0(x: ArrayLike): + return torch.special.i0(x) + + +def isscalar(a): + # We need to use normalize_array_like, but we don't want to export it in funcs.py + from ._normalizations import normalize_array_like + + try: + t = normalize_array_like(a) + return t.numel() == 1 + except Exception: + return False + + +# ### Filter windows ### + + +def hamming(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.hamming_window(M, periodic=False, dtype=dtype) + + +def hanning(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.hann_window(M, periodic=False, dtype=dtype) + + +def kaiser(M, beta): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype) + + +def blackman(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.blackman_window(M, periodic=False, dtype=dtype) + + +def bartlett(M): + dtype = _dtypes_impl.default_dtypes().float_dtype + return torch.bartlett_window(M, periodic=False, dtype=dtype) + + +# ### Dtype routines ### + +# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666 + + +array_type = [ + [torch.float16, torch.float32, torch.float64], + [None, torch.complex64, torch.complex128], +] +array_precision = { + torch.float16: 0, + torch.float32: 1, + torch.float64: 2, + torch.complex64: 1, + torch.complex128: 2, +} + + +def common_type(*tensors: ArrayLike): + is_complex = False + precision = 0 + for a in tensors: + t = a.dtype + if 
iscomplexobj(a): + is_complex = True + if not (t.is_floating_point or t.is_complex): + p = 2 # array_precision[_nx.double] + else: + p = array_precision.get(t, None) + if p is None: + raise TypeError("can't get common type for non-numeric array") + precision = builtins.max(precision, p) + if is_complex: + return array_type[1][precision] + else: + return array_type[0][precision] + + +# ### histograms ### + + +def histogram( + a: ArrayLike, + bins: ArrayLike = 10, + range=None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + if normed is not None: + raise ValueError("normed argument is deprecated, use density= instead") + + if weights is not None and weights.dtype.is_complex: + raise NotImplementedError("complex weights histogram.") + + is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex) + is_w_int = weights is None or not weights.dtype.is_floating_point + if is_a_int: + a = a.double() + + if weights is not None: + weights = _util.cast_if_needed(weights, a.dtype) + + if isinstance(bins, torch.Tensor): + if bins.ndim == 0: + # bins was a single int + bins = operator.index(bins) + else: + bins = _util.cast_if_needed(bins, a.dtype) + + if range is None: + h, b = torch.histogram(a, bins, weight=weights, density=bool(density)) + else: + h, b = torch.histogram( + a, bins, range=range, weight=weights, density=bool(density) + ) + + if not density and is_w_int: + h = h.long() + if is_a_int: + b = b.long() + + return h, b + + +def histogram2d( + x, + y, + bins=10, + range: Optional[ArrayLike] = None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821 + if len(x) != len(y): + raise ValueError("x and y must have the same length.") + + try: + N = len(bins) + except TypeError: + N = 1 + + if N != 1 and N != 2: + bins = [bins, bins] + + h, e = histogramdd((x, y), bins, range, normed, weights, density) + + return h, e[0], e[1] + + +def histogramdd( + sample, + bins=10, + range: Optional[ArrayLike] = None, + normed=None, + weights: Optional[ArrayLike] = None, + density=None, +): + # have to normalize manually because `sample` interpretation differs + # for a list of lists and a 2D array + if normed is not None: + raise ValueError("normed argument is deprecated, use density= instead") + + from ._normalizations import normalize_array_like, normalize_seq_array_like + + if isinstance(sample, (list, tuple)): + sample = normalize_array_like(sample).T + else: + sample = normalize_array_like(sample) + + sample = torch.atleast_2d(sample) + + if not (sample.dtype.is_floating_point or sample.dtype.is_complex): + sample = sample.double() + + # bins is either an int, or a sequence of ints or a sequence of arrays + bins_is_array = not ( + isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins) + ) + if bins_is_array: + bins = normalize_seq_array_like(bins) + bins_dtypes = [b.dtype for b in bins] + bins = [_util.cast_if_needed(b, sample.dtype) for b in bins] + + if range is not None: + range = range.flatten().tolist() + + if weights is not None: + # range=... 
is required : interleave min and max values per dimension + mm = sample.aminmax(dim=0) + range = torch.cat(mm).reshape(2, -1).T.flatten() + range = tuple(range.tolist()) + weights = _util.cast_if_needed(weights, sample.dtype) + w_kwd = {"weight": weights} + else: + w_kwd = {} + + h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd) + + if bins_is_array: + b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)] + + return h, b + + +# ### odds and ends + + +def min_scalar_type(a: ArrayLike, /): + # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288 + + from ._dtypes import DType + + if a.numel() > 1: + # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified." + return DType(a.dtype) + + if a.dtype == torch.bool: + dtype = torch.bool + + elif a.dtype.is_complex: + fi = torch.finfo(torch.float32) + fits_in_single = a.dtype == torch.complex64 or ( + fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max + ) + dtype = torch.complex64 if fits_in_single else torch.complex128 + + elif a.dtype.is_floating_point: + for dt in [torch.float16, torch.float32, torch.float64]: + fi = torch.finfo(dt) + if fi.min <= a <= fi.max: + dtype = dt + break + else: + # must be integer + for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]: + # Prefer unsigned int where possible, as numpy does. + ii = torch.iinfo(dt) + if ii.min <= a <= ii.max: + dtype = dt + break + + return DType(dtype) + + +def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs): + if mode != "constant": + raise NotImplementedError + value = kwargs.get("constant_values", 0) + # `value` must be a python scalar for torch.nn.functional.pad + typ = _dtypes_impl.python_type_for_torch(array.dtype) + value = typ(value) + + pad_width = torch.broadcast_to(pad_width, (array.ndim, 2)) + pad_width = torch.flip(pad_width, (0,)).flatten() + + return torch.nn.functional.pad(array, tuple(pad_width), value=value) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_getlimits.py b/phivenv/Lib/site-packages/torch/_numpy/_getlimits.py new file mode 100644 index 0000000000000000000000000000000000000000..75036ce6ab4b0b417be7ea0a308ec19018304fd4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_getlimits.py @@ -0,0 +1,15 @@ +# mypy: ignore-errors + +import torch + +from . import _dtypes + + +def finfo(dtyp): + torch_dtype = _dtypes.dtype(dtyp).torch_dtype + return torch.finfo(torch_dtype) + + +def iinfo(dtyp): + torch_dtype = _dtypes.dtype(dtyp).torch_dtype + return torch.iinfo(torch_dtype) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_ndarray.py b/phivenv/Lib/site-packages/torch/_numpy/_ndarray.py new file mode 100644 index 0000000000000000000000000000000000000000..7237bfd700188c0e7f69b24d9a518c4411435902 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_ndarray.py @@ -0,0 +1,592 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import builtins +import math +import operator +from collections.abc import Sequence + +import torch + +from . 
import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util +from ._normalizations import ( + ArrayLike, + normalize_array_like, + normalizer, + NotImplementedType, +) + + +newaxis = None + +FLAGS = [ + "C_CONTIGUOUS", + "F_CONTIGUOUS", + "OWNDATA", + "WRITEABLE", + "ALIGNED", + "WRITEBACKIFCOPY", + "FNC", + "FORC", + "BEHAVED", + "CARRAY", + "FARRAY", +] + +SHORTHAND_TO_FLAGS = { + "C": "C_CONTIGUOUS", + "F": "F_CONTIGUOUS", + "O": "OWNDATA", + "W": "WRITEABLE", + "A": "ALIGNED", + "X": "WRITEBACKIFCOPY", + "B": "BEHAVED", + "CA": "CARRAY", + "FA": "FARRAY", +} + + +class Flags: + def __init__(self, flag_to_value: dict): + assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + self._flag_to_value = flag_to_value + + def __getattr__(self, attr: str): + if attr.islower() and attr.upper() in FLAGS: + return self[attr.upper()] + else: + raise AttributeError(f"No flag attribute '{attr}'") + + def __getitem__(self, key): + if key in SHORTHAND_TO_FLAGS.keys(): + key = SHORTHAND_TO_FLAGS[key] + if key in FLAGS: + try: + return self._flag_to_value[key] + except KeyError as e: + raise NotImplementedError(f"{key=}") from e + else: + raise KeyError(f"No flag key '{key}'") + + def __setattr__(self, attr, value): + if attr.islower() and attr.upper() in FLAGS: + self[attr.upper()] = value + else: + super().__setattr__(attr, value) + + def __setitem__(self, key, value): + if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys(): + raise NotImplementedError("Modifying flags is not implemented") + else: + raise KeyError(f"No flag key '{key}'") + + +def create_method(fn, name=None): + name = name or fn.__name__ + + def f(*args, **kwargs): + return fn(*args, **kwargs) + + f.__name__ = name + f.__qualname__ = f"ndarray.{name}" + return f + + +# Map ndarray.name_method -> np.name_func +# If name_func == None, it means that name_method == name_func +methods = { + "clip": None, + "nonzero": None, + "repeat": None, + "round": None, + "squeeze": None, + "swapaxes": None, + "ravel": None, + # linalg + "diagonal": None, + "dot": None, + "trace": None, + # sorting + "argsort": None, + "searchsorted": None, + # reductions + "argmax": None, + "argmin": None, + "any": None, + "all": None, + "max": None, + "min": None, + "ptp": None, + "sum": None, + "prod": None, + "mean": None, + "var": None, + "std": None, + # scans + "cumsum": None, + "cumprod": None, + # advanced indexing + "take": None, + "choose": None, +} + +dunder = { + "abs": "absolute", + "invert": None, + "pos": "positive", + "neg": "negative", + "gt": "greater", + "lt": "less", + "ge": "greater_equal", + "le": "less_equal", +} + +# dunder methods with right-looking and in-place variants +ri_dunder = { + "add": None, + "sub": "subtract", + "mul": "multiply", + "truediv": "divide", + "floordiv": "floor_divide", + "pow": "power", + "mod": "remainder", + "and": "bitwise_and", + "or": "bitwise_or", + "xor": "bitwise_xor", + "lshift": "left_shift", + "rshift": "right_shift", + "matmul": None, +} + + +def _upcast_int_indices(index): + if isinstance(index, torch.Tensor): + if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8): + return index.to(torch.int64) + elif isinstance(index, tuple): + return tuple(_upcast_int_indices(i) for i in index) + return index + + +# Used to indicate that a parameter is unspecified (as opposed to explicitly +# `None`) +class _Unspecified: + pass + + +_Unspecified.unspecified = _Unspecified() + +############################################################### +# ndarray class # 
+############################################################### + + +class ndarray: + def __init__(self, t=None): + if t is None: + self.tensor = torch.Tensor() + elif isinstance(t, torch.Tensor): + self.tensor = t + else: + raise ValueError( + "ndarray constructor is not recommended; prefer" + "either array(...) or zeros/empty(...)" + ) + + # Register NumPy functions as methods + for method, name in methods.items(): + fn = getattr(_funcs, name or method) + vars()[method] = create_method(fn, method) + + # Regular methods but coming from ufuncs + conj = create_method(_ufuncs.conjugate, "conj") + conjugate = create_method(_ufuncs.conjugate) + + for method, name in dunder.items(): + fn = getattr(_ufuncs, name or method) + method = f"__{method}__" + vars()[method] = create_method(fn, method) + + for method, name in ri_dunder.items(): + fn = getattr(_ufuncs, name or method) + plain = f"__{method}__" + vars()[plain] = create_method(fn, plain) + rvar = f"__r{method}__" + vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar) + ivar = f"__i{method}__" + vars()[ivar] = create_method( + lambda self, other, fn=fn: fn(self, other, out=self), ivar + ) + + # There's no __idivmod__ + __divmod__ = create_method(_ufuncs.divmod, "__divmod__") + __rdivmod__ = create_method( + lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__" + ) + + # prevent loop variables leaking into the ndarray class namespace + del ivar, rvar, name, plain, fn, method + + @property + def shape(self): + return tuple(self.tensor.shape) + + @property + def size(self): + return self.tensor.numel() + + @property + def ndim(self): + return self.tensor.ndim + + @property + def dtype(self): + return _dtypes.dtype(self.tensor.dtype) + + @property + def strides(self): + elsize = self.tensor.element_size() + return tuple(stride * elsize for stride in self.tensor.stride()) + + @property + def itemsize(self): + return self.tensor.element_size() + + @property + def flags(self): + # Note contiguous in torch is assumed C-style + return Flags( + { + "C_CONTIGUOUS": self.tensor.is_contiguous(), + "F_CONTIGUOUS": self.T.tensor.is_contiguous(), + "OWNDATA": self.tensor._base is None, + "WRITEABLE": True, # pytorch does not have readonly tensors + } + ) + + @property + def data(self): + return self.tensor.data_ptr() + + @property + def nbytes(self): + return self.tensor.storage().nbytes() + + @property + def T(self): + return self.transpose() + + @property + def real(self): + return _funcs.real(self) + + @real.setter + def real(self, value): + self.tensor.real = asarray(value).tensor + + @property + def imag(self): + return _funcs.imag(self) + + @imag.setter + def imag(self, value): + self.tensor.imag = asarray(value).tensor + + # ctors + def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): + if order != "K": + raise NotImplementedError(f"astype(..., order={order} is not implemented.") + if casting != "unsafe": + raise NotImplementedError( + f"astype(..., casting={casting} is not implemented." 
+ ) + if not subok: + raise NotImplementedError(f"astype(..., subok={subok} is not implemented.") + if not copy: + raise NotImplementedError(f"astype(..., copy={copy} is not implemented.") + torch_dtype = _dtypes.dtype(dtype).torch_dtype + t = self.tensor.to(torch_dtype) + return ndarray(t) + + @normalizer + def copy(self: ArrayLike, order: NotImplementedType = "C"): + return self.clone() + + @normalizer + def flatten(self: ArrayLike, order: NotImplementedType = "C"): + return torch.flatten(self) + + def resize(self, *new_shape, refcheck=False): + # NB: differs from np.resize: fills with zeros instead of making repeated copies of input. + if refcheck: + raise NotImplementedError( + f"resize(..., refcheck={refcheck} is not implemented." + ) + if new_shape in [(), (None,)]: + return + + # support both x.resize((2, 2)) and x.resize(2, 2) + if len(new_shape) == 1: + new_shape = new_shape[0] + if isinstance(new_shape, int): + new_shape = (new_shape,) + + if builtins.any(x < 0 for x in new_shape): + raise ValueError("all elements of `new_shape` must be non-negative") + + new_numel, old_numel = math.prod(new_shape), self.tensor.numel() + + self.tensor.resize_(new_shape) + + if new_numel >= old_numel: + # zero-fill new elements + assert self.tensor.is_contiguous() + b = self.tensor.flatten() # does not copy + b[old_numel:].zero_() + + def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified): + if dtype is _Unspecified.unspecified: + dtype = self.dtype + if type is not _Unspecified.unspecified: + raise NotImplementedError(f"view(..., type={type} is not implemented.") + torch_dtype = _dtypes.dtype(dtype).torch_dtype + tview = self.tensor.view(torch_dtype) + return ndarray(tview) + + @normalizer + def fill(self, value: ArrayLike): + # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and + # error out on D > 0 arrays + self.tensor.fill_(value) + + def tolist(self): + return self.tensor.tolist() + + def __iter__(self): + return (ndarray(x) for x in self.tensor.__iter__()) + + def __str__(self): + return ( + str(self.tensor) + .replace("tensor", "torch.ndarray") + .replace("dtype=torch.", "dtype=") + ) + + __repr__ = create_method(__str__) + + def __eq__(self, other): + try: + return _ufuncs.equal(self, other) + except (RuntimeError, TypeError): + # Failed to convert other to array: definitely not equal. 
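+            # e.g. `arr == "not an array"`: conversion of the right-hand side
+            # fails, so the result is an all-False boolean array of arr's shape.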
+ falsy = torch.full(self.shape, fill_value=False, dtype=bool) + return asarray(falsy) + + def __ne__(self, other): + return ~(self == other) + + def __index__(self): + try: + return operator.index(self.tensor.item()) + except Exception as exc: + raise TypeError( + "only integer scalar arrays can be converted to a scalar index" + ) from exc + + def __bool__(self): + return bool(self.tensor) + + def __int__(self): + return int(self.tensor) + + def __float__(self): + return float(self.tensor) + + def __complex__(self): + return complex(self.tensor) + + def is_integer(self): + try: + v = self.tensor.item() + result = int(v) == v + except Exception: + result = False + return result + + def __len__(self): + return self.tensor.shape[0] + + def __contains__(self, x): + return self.tensor.__contains__(x) + + def transpose(self, *axes): + # np.transpose(arr, axis=None) but arr.transpose(*axes) + return _funcs.transpose(self, axes) + + def reshape(self, *shape, order="C"): + # arr.reshape(shape) and arr.reshape(*shape) + return _funcs.reshape(self, shape, order=order) + + def sort(self, axis=-1, kind=None, order=None): + # ndarray.sort works in-place + _funcs.copyto(self, _funcs.sort(self, axis, kind, order)) + + def item(self, *args): + # Mimic NumPy's implementation with three special cases (no arguments, + # a flat index and a multi-index): + # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/methods.c#L702 + if args == (): + return self.tensor.item() + elif len(args) == 1: + # int argument + return self.ravel()[args[0]] + else: + return self.__getitem__(args) + + def __getitem__(self, index): + tensor = self.tensor + + def neg_step(i, s): + if not (isinstance(s, slice) and s.step is not None and s.step < 0): + return s + + nonlocal tensor + tensor = torch.flip(tensor, (i,)) + + # Account for the fact that a slice includes the start but not the end + assert isinstance(s.start, int) or s.start is None + assert isinstance(s.stop, int) or s.stop is None + start = s.stop + 1 if s.stop else None + stop = s.start + 1 if s.start else None + + return slice(start, stop, -s.step) + + if isinstance(index, Sequence): + index = type(index)(neg_step(i, s) for i, s in enumerate(index)) + else: + index = neg_step(0, index) + index = _util.ndarrays_to_tensors(index) + index = _upcast_int_indices(index) + return ndarray(tensor.__getitem__(index)) + + def __setitem__(self, index, value): + index = _util.ndarrays_to_tensors(index) + index = _upcast_int_indices(index) + + if not _dtypes_impl.is_scalar(value): + value = normalize_array_like(value) + value = _util.cast_if_needed(value, self.tensor.dtype) + + return self.tensor.__setitem__(index, value) + + take = _funcs.take + put = _funcs.put + + def __dlpack__(self, *, stream=None): + return self.tensor.__dlpack__(stream=stream) + + def __dlpack_device__(self): + return self.tensor.__dlpack_device__() + + +def _tolist(obj): + """Recursively convert tensors into lists.""" + a1 = [] + for elem in obj: + if isinstance(elem, (list, tuple)): + elem = _tolist(elem) + if isinstance(elem, ndarray): + a1.append(elem.tensor.tolist()) + else: + a1.append(elem) + return a1 + + +# This is the ideally the only place which talks to ndarray directly. +# The rest goes through asarray (preferred) or array. 
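+# Rough usage sketch (illustrative only): `asarray`/`array` below wrap an
+# array_like in an ndarray backed by a torch.Tensor, and operators dispatch
+# through the _ufuncs/_funcs layers, e.g.
+#
+#     a = asarray([[1, 2], [3, 4]])   # a.tensor is a torch.Tensor
+#     b = a + 1                       # routed through _ufuncs.add via __add__
+#     c = a.reshape(4)                # delegates to _funcs.reshape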
+ + +def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None): + if subok is not False: + raise NotImplementedError("'subok' parameter is not supported.") + if like is not None: + raise NotImplementedError("'like' parameter is not supported.") + if order != "K": + raise NotImplementedError + + # a happy path + if ( + isinstance(obj, ndarray) + and copy is False + and dtype is None + and ndmin <= obj.ndim + ): + return obj + + if isinstance(obj, (list, tuple)): + # FIXME and they have the same dtype, device, etc + if obj and all(isinstance(x, torch.Tensor) for x in obj): + # list of arrays: *under torch.Dynamo* these are FakeTensors + obj = torch.stack(obj) + else: + # XXX: remove tolist + # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists + obj = _tolist(obj) + + # is obj an ndarray already? + if isinstance(obj, ndarray): + obj = obj.tensor + + # is a specific dtype requested? + torch_dtype = None + if dtype is not None: + torch_dtype = _dtypes.dtype(dtype).torch_dtype + + tensor = _util._coerce_to_tensor(obj, torch_dtype, copy, ndmin) + return ndarray(tensor) + + +def asarray(a, dtype=None, order="K", *, like=None): + return array(a, dtype=dtype, order=order, like=like, copy=False, ndmin=0) + + +def ascontiguousarray(a, dtype=None, *, like=None): + arr = asarray(a, dtype=dtype, like=like) + if not arr.tensor.is_contiguous(): + arr.tensor = arr.tensor.contiguous() + return arr + + +def from_dlpack(x, /): + t = torch.from_dlpack(x) + return ndarray(t) + + +def _extract_dtype(entry): + try: + dty = _dtypes.dtype(entry) + except Exception: + dty = asarray(entry).dtype + return dty + + +def can_cast(from_, to, casting="safe"): + from_ = _extract_dtype(from_) + to_ = _extract_dtype(to) + + return _dtypes_impl.can_cast_impl(from_.torch_dtype, to_.torch_dtype, casting) + + +def result_type(*arrays_and_dtypes): + tensors = [] + for entry in arrays_and_dtypes: + try: + t = asarray(entry).tensor + except (RuntimeError, ValueError, TypeError): + dty = _dtypes.dtype(entry) + t = torch.empty(1, dtype=dty.torch_dtype) + tensors.append(t) + + torch_dtype = _dtypes_impl.result_type_impl(*tensors) + return _dtypes.dtype(torch_dtype) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_normalizations.py b/phivenv/Lib/site-packages/torch/_numpy/_normalizations.py new file mode 100644 index 0000000000000000000000000000000000000000..abbcc3288c9b9de060581182d9a1278faf007d66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_normalizations.py @@ -0,0 +1,259 @@ +# mypy: ignore-errors + +""" "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on. +""" +from __future__ import annotations + +import functools +import inspect +import operator +import typing + +import torch + +from . import _dtypes, _dtypes_impl, _util + + +ArrayLike = typing.TypeVar("ArrayLike") +Scalar = typing.Union[int, float, complex, bool] +ArrayLikeOrScalar = typing.Union[ArrayLike, Scalar] + +DTypeLike = typing.TypeVar("DTypeLike") +AxisLike = typing.TypeVar("AxisLike") +NDArray = typing.TypeVar("NDArray") +CastingModes = typing.TypeVar("CastingModes") +KeepDims = typing.TypeVar("KeepDims") + +# OutArray is to annotate the out= array argument. +# +# This one is special is several respects: +# First, It needs to be an NDArray, and we need to preserve the `result is out` +# semantics. Therefore, we cannot just extract the Tensor from the out array. +# So we never pass the out array to implementer functions and handle it in the +# `normalizer` below. 
+# Second, the out= argument can be either keyword or positional argument, and +# as a positional arg, it can be anywhere in the signature. +# To handle all this, we define a special `OutArray` annotation and dispatch on it. +# +OutArray = typing.TypeVar("OutArray") + +try: + from typing import NotImplementedType +except ImportError: + NotImplementedType = typing.TypeVar("NotImplementedType") + + +def normalize_array_like(x, parm=None): + from ._ndarray import asarray + + return asarray(x).tensor + + +def normalize_array_like_or_scalar(x, parm=None): + if _dtypes_impl.is_scalar_or_symbolic(x): + return x + return normalize_array_like(x, parm) + + +def normalize_optional_array_like_or_scalar(x, parm=None): + if x is None: + return None + return normalize_array_like_or_scalar(x, parm) + + +def normalize_optional_array_like(x, parm=None): + # This explicit normalizer is needed because otherwise normalize_array_like + # does not run for a parameter annotated as Optional[ArrayLike] + return None if x is None else normalize_array_like(x, parm) + + +def normalize_seq_array_like(x, parm=None): + return tuple(normalize_array_like(value) for value in x) + + +def normalize_dtype(dtype, parm=None): + # cf _decorators.dtype_to_torch + torch_dtype = None + if dtype is not None: + dtype = _dtypes.dtype(dtype) + torch_dtype = dtype.torch_dtype + return torch_dtype + + +def normalize_not_implemented(arg, parm): + if arg != parm.default: + raise NotImplementedError(f"'{parm.name}' parameter is not supported.") + + +def normalize_axis_like(arg, parm=None): + from ._ndarray import ndarray + + if isinstance(arg, ndarray): + arg = operator.index(arg) + return arg + + +def normalize_ndarray(arg, parm=None): + # check the arg is an ndarray, extract its tensor attribute + if arg is None: + return arg + + from ._ndarray import ndarray + + if not isinstance(arg, ndarray): + raise TypeError(f"'{parm.name}' must be an array") + return arg.tensor + + +def normalize_outarray(arg, parm=None): + # almost normalize_ndarray, only return the array, not its tensor + if arg is None: + return arg + from ._ndarray import ndarray + + # Dynamo can pass torch tensors as out arguments, + # wrap it in an ndarray before processing + if isinstance(arg, torch.Tensor): + arg = ndarray(arg) + + if not isinstance(arg, ndarray): + raise TypeError(f"'{parm.name}' must be an array") + return arg + + +def normalize_casting(arg, parm=None): + if arg not in ["no", "equiv", "safe", "same_kind", "unsafe"]: + raise ValueError( + f"casting must be one of 'no', 'equiv', 'safe', 'same_kind', or 'unsafe' (got '{arg}')" + ) + return arg + + +normalizers = { + "ArrayLike": normalize_array_like, + "ArrayLikeOrScalar": normalize_array_like_or_scalar, + "Optional[ArrayLike]": normalize_optional_array_like, + "Sequence[ArrayLike]": normalize_seq_array_like, + "Optional[ArrayLikeOrScalar]": normalize_optional_array_like_or_scalar, + "Optional[NDArray]": normalize_ndarray, + "Optional[OutArray]": normalize_outarray, + "NDArray": normalize_ndarray, + "Optional[DTypeLike]": normalize_dtype, + "AxisLike": normalize_axis_like, + "NotImplementedType": normalize_not_implemented, + "Optional[CastingModes]": normalize_casting, +} + + +def maybe_normalize(arg, parm): + """Normalize arg if a normalizer is registered.""" + normalizer = normalizers.get(parm.annotation, None) + return normalizer(arg, parm) if normalizer else arg + + +# ### Return value helpers ### + + +def maybe_copy_to(out, result, promote_scalar_result=False): + # NB: here out is either an ndarray or 
None + if out is None: + return result + elif isinstance(result, torch.Tensor): + if result.shape != out.shape: + can_fit = result.numel() == 1 and out.ndim == 0 + if promote_scalar_result and can_fit: + result = result.squeeze() + else: + raise ValueError( + f"Bad size of the out array: out.shape = {out.shape}" + f" while result.shape = {result.shape}." + ) + out.tensor.copy_(result) + return out + elif isinstance(result, (tuple, list)): + return type(result)( + maybe_copy_to(o, r, promote_scalar_result) for o, r in zip(out, result) + ) + else: + raise AssertionError # We should never hit this path + + +def wrap_tensors(result): + from ._ndarray import ndarray + + if isinstance(result, torch.Tensor): + return ndarray(result) + elif isinstance(result, (tuple, list)): + result = type(result)(wrap_tensors(x) for x in result) + return result + + +def array_or_scalar(values, py_type=float, return_scalar=False): + if return_scalar: + return py_type(values.item()) + else: + from ._ndarray import ndarray + + return ndarray(values) + + +# ### The main decorator to normalize arguments / postprocess the output ### + + +def normalizer(_func=None, *, promote_scalar_result=False): + def normalizer_inner(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + sig = inspect.signature(func) + params = sig.parameters + first_param = next(iter(params.values())) + + # NumPy's API does not have positional args before variadic positional args + if first_param.kind == inspect.Parameter.VAR_POSITIONAL: + args = [maybe_normalize(arg, first_param) for arg in args] + else: + # NB: extra unknown arguments: pass through, will raise in func(*args) below + args = ( + tuple( + maybe_normalize(arg, parm) + for arg, parm in zip(args, params.values()) + ) + + args[len(params.values()) :] + ) + + kwds = { + name: maybe_normalize(arg, params[name]) if name in params else arg + for name, arg in kwds.items() + } + + result = func(*args, **kwds) + + # keepdims + bound_args = None + if "keepdims" in params and params["keepdims"].annotation == "KeepDims": + # keepdims can be in any position so we need sig.bind + bound_args = sig.bind(*args, **kwds).arguments + if bound_args.get("keepdims", False): + # In this case the first arg is the initial tensor and + # the second arg is (optionally) the axis + tensor = args[0] + axis = bound_args.get("axis") + result = _util.apply_keepdims(result, axis, tensor.ndim) + + # out + if "out" in params: + # out can be in any position so we need sig.bind + if bound_args is None: + bound_args = sig.bind(*args, **kwds).arguments + out = bound_args.get("out") + result = maybe_copy_to(out, result, promote_scalar_result) + result = wrap_tensors(result) + + return result + + return wrapped + + if _func is None: + return normalizer_inner + else: + return normalizer_inner(_func) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_reductions_impl.py b/phivenv/Lib/site-packages/torch/_numpy/_reductions_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..d656ea8f823ed14dadeb859d682f8a8e2fd494af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_reductions_impl.py @@ -0,0 +1,459 @@ +# mypy: ignore-errors + +""" Implementation of reduction operations, to be wrapped into arrays, dtypes etc +in the 'public' layer. + +Anything here only deals with torch objects, e.g. "dtype" is a torch.dtype instance etc +""" +from __future__ import annotations + +import functools +from typing import Optional, TYPE_CHECKING + +import torch + +from . 
import _dtypes_impl, _util + + +if TYPE_CHECKING: + from ._normalizations import ( + ArrayLike, + AxisLike, + DTypeLike, + KeepDims, + NotImplementedType, + OutArray, + ) + + +def _deco_axis_expand(func): + """ + Generically handle axis arguments in reductions. + axis is *always* the 2nd arg in the function so no need to have a look at its signature + """ + + @functools.wraps(func) + def wrapped(a, axis=None, *args, **kwds): + if axis is not None: + axis = _util.normalize_axis_tuple(axis, a.ndim) + + if axis == (): + # So we insert a length-one axis and run the reduction along it. + # We cannot return a.clone() as this would sidestep the checks inside the function + newshape = _util.expand_shape(a.shape, axis=0) + a = a.reshape(newshape) + axis = (0,) + + return func(a, axis, *args, **kwds) + + return wrapped + + +def _atleast_float(dtype, other_dtype): + """Return a dtype that is real or complex floating-point. + + For inputs that are boolean or integer dtypes, this returns the default + float dtype; inputs that are complex get converted to the default complex + dtype; real floating-point dtypes (`float*`) get passed through unchanged + """ + if dtype is None: + dtype = other_dtype + if not (dtype.is_floating_point or dtype.is_complex): + return _dtypes_impl.default_dtypes().float_dtype + return dtype + + +@_deco_axis_expand +def count_nonzero(a: ArrayLike, axis: AxisLike = None, *, keepdims: KeepDims = False): + return a.count_nonzero(axis) + + +@_deco_axis_expand +def argmax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + *, + keepdims: KeepDims = False, +): + if a.is_complex(): + raise NotImplementedError(f"argmax with dtype={a.dtype}.") + + axis = _util.allow_only_single_axis(axis) + + if a.dtype == torch.bool: + # RuntimeError: "argmax_cpu" not implemented for 'Bool' + a = a.to(torch.uint8) + + return torch.argmax(a, axis) + + +@_deco_axis_expand +def argmin( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + *, + keepdims: KeepDims = False, +): + if a.is_complex(): + raise NotImplementedError(f"argmin with dtype={a.dtype}.") + + axis = _util.allow_only_single_axis(axis) + + if a.dtype == torch.bool: + # RuntimeError: "argmin_cpu" not implemented for 'Bool' + a = a.to(torch.uint8) + + return torch.argmin(a, axis) + + +@_deco_axis_expand +def any( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + axis_kw = {} if axis is None else {"dim": axis} + return torch.any(a, **axis_kw) + + +@_deco_axis_expand +def all( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + axis_kw = {} if axis is None else {"dim": axis} + return torch.all(a, **axis_kw) + + +@_deco_axis_expand +def amax( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + if a.is_complex(): + raise NotImplementedError(f"amax with dtype={a.dtype}") + + return a.amax(axis) + + +max = amax + + +@_deco_axis_expand +def amin( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + if a.is_complex(): + raise NotImplementedError(f"amin with 
dtype={a.dtype}") + + return a.amin(axis) + + +min = amin + + +@_deco_axis_expand +def ptp( + a: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, +): + return a.amax(axis) - a.amin(axis) + + +@_deco_axis_expand +def sum( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + assert dtype is None or isinstance(dtype, torch.dtype) + + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + + axis_kw = {} if axis is None else {"dim": axis} + return a.sum(dtype=dtype, **axis_kw) + + +@_deco_axis_expand +def prod( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + initial: NotImplementedType = None, + where: NotImplementedType = None, +): + axis = _util.allow_only_single_axis(axis) + + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + + axis_kw = {} if axis is None else {"dim": axis} + return a.prod(dtype=dtype, **axis_kw) + + +product = prod + + +@_deco_axis_expand +def mean( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + dtype = _atleast_float(dtype, a.dtype) + + axis_kw = {} if axis is None else {"dim": axis} + result = a.mean(dtype=dtype, **axis_kw) + + return result + + +@_deco_axis_expand +def std( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + ddof=0, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + in_dtype = dtype + dtype = _atleast_float(dtype, a.dtype) + tensor = _util.cast_if_needed(a, dtype) + result = tensor.std(dim=axis, correction=ddof) + return _util.cast_if_needed(result, in_dtype) + + +@_deco_axis_expand +def var( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, + ddof=0, + keepdims: KeepDims = False, + *, + where: NotImplementedType = None, +): + in_dtype = dtype + dtype = _atleast_float(dtype, a.dtype) + tensor = _util.cast_if_needed(a, dtype) + result = tensor.var(dim=axis, correction=ddof) + return _util.cast_if_needed(result, in_dtype) + + +# cumsum / cumprod are almost reductions: +# 1. no keepdims +# 2. 
axis=None flattens + + +def cumsum( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + if dtype is None: + dtype = a.dtype + + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + + return a.cumsum(axis=axis, dtype=dtype) + + +def cumprod( + a: ArrayLike, + axis: AxisLike = None, + dtype: Optional[DTypeLike] = None, + out: Optional[OutArray] = None, +): + if dtype == torch.bool: + dtype = _dtypes_impl.default_dtypes().int_dtype + if dtype is None: + dtype = a.dtype + + (a,), axis = _util.axis_none_flatten(a, axis=axis) + axis = _util.normalize_axis_index(axis, a.ndim) + + return a.cumprod(axis=axis, dtype=dtype) + + +cumproduct = cumprod + + +def average( + a: ArrayLike, + axis=None, + weights: ArrayLike = None, + returned=False, + *, + keepdims=False, +): + if weights is None: + result = mean(a, axis=axis) + wsum = torch.as_tensor(a.numel() / result.numel(), dtype=result.dtype) + else: + if not a.dtype.is_floating_point: + a = a.double() + + # axis & weights + if a.shape != weights.shape: + if axis is None: + raise TypeError( + "Axis must be specified when shapes of a and weights differ." + ) + if weights.ndim != 1: + raise TypeError( + "1D weights expected when shapes of a and weights differ." + ) + if weights.shape[0] != a.shape[axis]: + raise ValueError( + "Length of weights not compatible with specified axis." + ) + + # setup weight to broadcast along axis + weights = torch.broadcast_to(weights, (a.ndim - 1) * (1,) + weights.shape) + weights = weights.swapaxes(-1, axis) + + # do the work + result_dtype = _dtypes_impl.result_type_impl(a, weights) + numerator = sum(a * weights, axis, dtype=result_dtype) + wsum = sum(weights, axis, dtype=result_dtype) + result = numerator / wsum + + # We process keepdims manually because the decorator does not deal with variadic returns + if keepdims: + result = _util.apply_keepdims(result, axis, a.ndim) + + if returned: + if wsum.shape != result.shape: + wsum = torch.broadcast_to(wsum, result.shape).clone() + return result, wsum + else: + return result + + +# Not using deco_axis_expand as it assumes that axis is the second arg +def quantile( + a: ArrayLike, + q: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + overwrite_input=False, + method="linear", + keepdims: KeepDims = False, + *, + interpolation: NotImplementedType = None, +): + if overwrite_input: + # raise NotImplementedError("overwrite_input in quantile not implemented.") + # NumPy documents that `overwrite_input` MAY modify inputs: + # https://numpy.org/doc/stable/reference/generated/numpy.percentile.html#numpy-percentile + # Here we choose to work out-of-place because why not. + pass + + if not a.dtype.is_floating_point: + dtype = _dtypes_impl.default_dtypes().float_dtype + a = a.to(dtype) + + # edge case: torch.quantile only supports float32 and float64 + if a.dtype == torch.float16: + a = a.to(torch.float32) + + if axis is None: + a = a.flatten() + q = q.flatten() + axis = (0,) + else: + axis = _util.normalize_axis_tuple(axis, a.ndim) + + # FIXME(Mario) Doesn't np.quantile accept a tuple? + # torch.quantile does accept a number. If we don't want to implement the tuple behaviour + # (it's deffo low prio) change `normalize_axis_tuple` into a normalize_axis index above. 
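+    # allow_only_single_axis collapses the axis tuple to a single int and
+    # raises NotImplementedError for multi-axis tuples (see _util).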
+ axis = _util.allow_only_single_axis(axis) + + q = _util.cast_if_needed(q, a.dtype) + + return torch.quantile(a, q, axis=axis, interpolation=method) + + +def percentile( + a: ArrayLike, + q: ArrayLike, + axis: AxisLike = None, + out: Optional[OutArray] = None, + overwrite_input=False, + method="linear", + keepdims: KeepDims = False, + *, + interpolation: NotImplementedType = None, +): + # np.percentile(float_tensor, 30) : q.dtype is int64 => q / 100.0 is float32 + if _dtypes_impl.python_type_for_torch(q.dtype) == int: + q = q.to(_dtypes_impl.default_dtypes().float_dtype) + qq = q / 100.0 + + return quantile( + a, + qq, + axis=axis, + overwrite_input=overwrite_input, + method=method, + keepdims=keepdims, + interpolation=interpolation, + ) + + +def median( + a: ArrayLike, + axis=None, + out: Optional[OutArray] = None, + overwrite_input=False, + keepdims: KeepDims = False, +): + return quantile( + a, + torch.as_tensor(0.5), + axis=axis, + overwrite_input=overwrite_input, + out=out, + keepdims=keepdims, + ) diff --git a/phivenv/Lib/site-packages/torch/_numpy/_ufuncs.py b/phivenv/Lib/site-packages/torch/_numpy/_ufuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..139aa89ebc5016dad0c522efaa2970c308f2df59 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_ufuncs.py @@ -0,0 +1,334 @@ +# mypy: ignore-errors + +from __future__ import annotations + +from typing import Optional + +import torch + +from . import _binary_ufuncs_impl, _dtypes_impl, _unary_ufuncs_impl, _util +from ._normalizations import ( + ArrayLike, + ArrayLikeOrScalar, + CastingModes, + DTypeLike, + normalizer, + NotImplementedType, + OutArray, +) + + +def _ufunc_postprocess(result, out, casting): + if out is not None: + result = _util.typecast_tensor(result, out.dtype.torch_dtype, casting) + result = torch.broadcast_to(result, out.shape) + return result + + +# ############# Binary ufuncs ###################### + +_binary = [ + name + for name in dir(_binary_ufuncs_impl) + if not name.startswith("_") and name not in ["torch", "matmul", "divmod", "ldexp"] +] + + +NEP50_FUNCS = ( + "add", + "subtract", + "multiply", + "floor_divide", + "true_divide", + "divide", + "remainder", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", + "hypot", + "arctan2", + "logaddexp", + "logaddexp2", + "heaviside", + "copysign", + "fmax", + "minimum", + "fmin", + "maximum", + "fmod", + "gcd", + "lcm", + "pow", +) + + +def deco_binary_ufunc(torch_func): + """Common infra for binary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. 
+ """ + + @normalizer + def wrapped( + x1: ArrayLikeOrScalar, + x2: ArrayLikeOrScalar, + /, + out: Optional[OutArray] = None, + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, + ): + if dtype is not None: + + def cast(x, dtype): + if isinstance(x, torch.Tensor): + return _util.typecast_tensor(x, dtype, casting) + else: + return torch.as_tensor(x, dtype=dtype) + + x1 = cast(x1, dtype) + x2 = cast(x2, dtype) + elif isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + else: + x1, x2 = _dtypes_impl.nep50_to_tensors( + x1, x2, torch_func.__name__ in NEP50_FUNCS, torch_func.__name__ + ) + + result = torch_func(x1, x2) + + return _ufunc_postprocess(result, out, casting) + + wrapped.__qualname__ = torch_func.__name__ + wrapped.__name__ = torch_func.__name__ + + return wrapped + + +# matmul's signature is _slightly_ different from other ufuncs: +# - no where=... +# - additional axis=..., axes=... +# - no NEP50 scalars in or out +@normalizer +def matmul( + x1: ArrayLike, + x2: ArrayLike, + /, + out: Optional[OutArray] = None, + *, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, + axes: NotImplementedType = None, + axis: NotImplementedType = None, +): + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + + result = _binary_ufuncs_impl.matmul(x1, x2) + + result = _ufunc_postprocess(result, out, casting) + return result + + +# ldexp casting is special : the dtype of the result == dtype of the 1st arg +@normalizer +def ldexp( + x1: ArrayLikeOrScalar, + x2: ArrayLikeOrScalar, + /, + out: Optional[OutArray] = None, + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, +): + if dtype is not None: + if isinstance(x1, torch.Tensor): + x1 = _util.typecast_tensor(x1, dtype, casting) + else: + x1 = torch.as_tensor(x1, dtype=dtype) + else: + if not isinstance(x1, torch.Tensor): + x1 = torch.as_tensor(x1) + x1 = _util.cast_int_to_float(x1) + + x2 = torch.as_tensor(x2) + # the second arg must be integer + if _dtypes_impl._category(x2.dtype) != 1: + raise ValueError("ldexp 2nd arg must be integer") + + result = _binary_ufuncs_impl.ldexp(x1, x2) + + if x1.dtype == torch.float16: + # torch.ldexp(f16, int) -> f32, undo it + result = result.to(torch.float16) + + return _ufunc_postprocess(result, out, casting) + + +# nin=2, nout=2 +@normalizer +def divmod( + x1: ArrayLike, + x2: ArrayLike, + out1: Optional[OutArray] = None, + out2: Optional[OutArray] = None, + /, + out: tuple[Optional[OutArray], Optional[OutArray]] = (None, None), + *, + where: NotImplementedType = True, + casting: Optional[CastingModes] = "same_kind", + order: NotImplementedType = "K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature: NotImplementedType = None, + extobj: NotImplementedType = None, +): 
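+    # Intended semantics (assumed from the NumPy ufunc this mirrors): return
+    # elementwise (x1 // x2, x1 % x2), e.g. divmod([7, 8], 3) gives
+    # (array([2, 2]), array([1, 2])).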
+ # make sure we either have no out arrays at all, or there is either + # out1, out2, or out=tuple, but not both + num_outs = sum(x is not None for x in [out1, out2]) + if num_outs == 1: + raise ValueError("both out1 and out2 need to be provided") + elif num_outs == 2: + o1, o2 = out + if o1 is not None or o2 is not None: + raise TypeError( + "cannot specify 'out' as both a positional and keyword argument" + ) + else: + out1, out2 = out + + if dtype is None: + dtype = _dtypes_impl.result_type_impl(x1, x2) + x1, x2 = _util.typecast_tensors((x1, x2), dtype, casting) + + quot, rem = _binary_ufuncs_impl.divmod(x1, x2) + + quot = _ufunc_postprocess(quot, out1, casting) + rem = _ufunc_postprocess(rem, out2, casting) + return quot, rem + + +# +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py +# +for name in _binary: + ufunc = getattr(_binary_ufuncs_impl, name) + vars()[name] = deco_binary_ufunc(ufunc) + + +def modf(x, /, *args, **kwds): + quot, rem = divmod(x, 1, *args, **kwds) + return rem, quot + + +_binary = _binary + ["divmod", "modf", "matmul", "ldexp"] + + +# ############# Unary ufuncs ###################### + + +_unary = [ + name + for name in dir(_unary_ufuncs_impl) + if not name.startswith("_") and name != "torch" +] + + +# these are ufunc(int) -> float +_fp_unary = [ + "arccos", + "arccosh", + "arcsin", + "arcsinh", + "arctan", + "arctanh", + "cbrt", + "cos", + "cosh", + "deg2rad", + "degrees", + "exp", + "exp2", + "expm1", + "log", + "log10", + "log1p", + "log2", + "rad2deg", + "radians", + "reciprocal", + "sin", + "sinh", + "sqrt", + "square", + "tan", + "tanh", + "trunc", +] + + +def deco_unary_ufunc(torch_func): + """Common infra for unary ufuncs. + + Normalize arguments, sort out type casting, broadcasting and delegate to + the pytorch functions for the actual work. + """ + + @normalizer + def wrapped( + x: ArrayLike, + /, + out: Optional[OutArray] = None, + *, + where=True, + casting: Optional[CastingModes] = "same_kind", + order="K", + dtype: Optional[DTypeLike] = None, + subok: NotImplementedType = False, + signature=None, + extobj=None, + ): + if dtype is not None: + x = _util.typecast_tensor(x, dtype, casting) + + if torch_func.__name__ in _fp_unary: + x = _util.cast_int_to_float(x) + + result = torch_func(x) + result = _ufunc_postprocess(result, out, casting) + return result + + wrapped.__qualname__ = torch_func.__name__ + wrapped.__name__ = torch_func.__name__ + + return wrapped + + +# +# Attach ufuncs to this module, for a further export to the public namespace in __init__.py +# +for name in _unary: + ufunc = getattr(_unary_ufuncs_impl, name) + vars()[name] = deco_unary_ufunc(ufunc) + + +__all__ = _binary + _unary # noqa: PLE0605 diff --git a/phivenv/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py b/phivenv/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..00827fbabeb8ba1c3a6dfc72ce3826cfb9bbd3c1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_unary_ufuncs_impl.py @@ -0,0 +1,72 @@ +# mypy: ignore-errors + +"""Export torch work functions for unary ufuncs, rename/tweak to match numpy. +This listing is further exported to public symbols in the `_numpy/_ufuncs.py` module. 
+""" + +import torch +from torch import ( # noqa: F401 + absolute as fabs, + arccos, + arccosh, + arcsin, + arcsinh, + arctan, + arctanh, + bitwise_not, + bitwise_not as invert, + ceil, + conj_physical as conjugate, + cos, + cosh, + deg2rad, + deg2rad as radians, + exp, + exp2, + expm1, + floor, + isfinite, + isinf, + isnan, + log, + log10, + log1p, + log2, + logical_not, + negative, + rad2deg, + rad2deg as degrees, + reciprocal, + round as fix, + round as rint, + sign, + signbit, + sin, + sinh, + sqrt, + square, + tan, + tanh, + trunc, +) + + +# special cases: torch does not export these names +def cbrt(x): + return torch.pow(x, 1 / 3) + + +def positive(x): + return +x + + +def absolute(x): + # work around torch.absolute not impl for bools + if x.dtype == torch.bool: + return x + return torch.absolute(x) + + +# TODO set __name__ and __qualname__ +abs = absolute +conj = conjugate diff --git a/phivenv/Lib/site-packages/torch/_numpy/_util.py b/phivenv/Lib/site-packages/torch/_numpy/_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7792af8e2549eb1fc39749e8db54b8935e296e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/_util.py @@ -0,0 +1,261 @@ +# mypy: ignore-errors + +"""Assorted utilities, which do not need anything other then torch and stdlib. +""" + +import operator + +import torch + +from . import _dtypes_impl + + +# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504 +def is_sequence(seq): + if isinstance(seq, str): + return False + try: + len(seq) + except Exception: + return False + return True + + +class AxisError(ValueError, IndexError): + pass + + +class UFuncTypeError(TypeError, RuntimeError): + pass + + +def cast_if_needed(tensor, dtype): + # NB: no casting if dtype=None + if dtype is not None and tensor.dtype != dtype: + tensor = tensor.to(dtype) + return tensor + + +def cast_int_to_float(x): + # cast integers and bools to the default float dtype + if _dtypes_impl._category(x.dtype) < 2: + x = x.to(_dtypes_impl.default_dtypes().float_dtype) + return x + + +# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h +def normalize_axis_index(ax, ndim, argname=None): + if not (-ndim <= ax < ndim): + raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}") + if ax < 0: + ax += ndim + return ax + + +# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378 +def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): + """ + Normalizes an axis argument into a tuple of non-negative integer axes. + + This handles shorthands such as ``1`` and converts them to ``(1,)``, + as well as performing the handling of negative indices covered by + `normalize_axis_index`. + + By default, this forbids axes from being specified multiple times. + Used internally by multi-axis-checking logic. + + Parameters + ---------- + axis : int, iterable of int + The un-normalized index or indices of the axis. + ndim : int + The number of dimensions of the array that `axis` should be normalized + against. + argname : str, optional + A prefix to put before the error message, typically the name of the + argument. + allow_duplicate : bool, optional + If False, the default, disallow an axis from being specified twice. + + Returns + ------- + normalized_axes : tuple of int + The normalized axis index, such that `0 <= normalized_axis < ndim` + """ + # Optimization to speed-up the most common cases. 
+ if type(axis) not in (tuple, list): + try: + axis = [operator.index(axis)] + except TypeError: + pass + # Going via an iterator directly is slower than via list comprehension. + axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) + if not allow_duplicate and len(set(map(int, axis))) != len(axis): + if argname: + raise ValueError(f"repeated axis in `{argname}` argument") + else: + raise ValueError("repeated axis") + return axis + + +def allow_only_single_axis(axis): + if axis is None: + return axis + if len(axis) != 1: + raise NotImplementedError("does not handle tuple axis") + return axis[0] + + +def expand_shape(arr_shape, axis): + # taken from numpy 1.23.x, expand_dims function + if type(axis) not in (list, tuple): + axis = (axis,) + out_ndim = len(axis) + len(arr_shape) + axis = normalize_axis_tuple(axis, out_ndim) + shape_it = iter(arr_shape) + shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] + return shape + + +def apply_keepdims(tensor, axis, ndim): + if axis is None: + # tensor was a scalar + shape = (1,) * ndim + tensor = tensor.expand(shape).contiguous() + else: + shape = expand_shape(tensor.shape, axis) + tensor = tensor.reshape(shape) + return tensor + + +def axis_none_flatten(*tensors, axis=None): + """Flatten the arrays if axis is None.""" + if axis is None: + tensors = tuple(ar.flatten() for ar in tensors) + return tensors, 0 + else: + return tensors, axis + + +def typecast_tensor(t, target_dtype, casting): + """Dtype-cast tensor to target_dtype. + + Parameters + ---------- + t : torch.Tensor + The tensor to cast + target_dtype : torch dtype object + The array dtype to cast all tensors to + casting : str + The casting mode, see `np.can_cast` + + Returns + ------- + `torch.Tensor` of the `target_dtype` dtype + + Raises + ------ + ValueError + if the argument cannot be cast according to the `casting` rule + + """ + can_cast = _dtypes_impl.can_cast_impl + + if not can_cast(t.dtype, target_dtype, casting=casting): + raise TypeError( + f"Cannot cast array data from {t.dtype} to" + f" {target_dtype} according to the rule '{casting}'" + ) + return cast_if_needed(t, target_dtype) + + +def typecast_tensors(tensors, target_dtype, casting): + return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors) + + +def _try_convert_to_tensor(obj): + try: + tensor = torch.as_tensor(obj) + except Exception as e: + mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." + raise NotImplementedError(mesg) # noqa: B904 + return tensor + + +def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): + """The core logic of the array(...) function. + + Parameters + ---------- + obj : tensor_like + The thing to coerce + dtype : torch.dtype object or None + Coerce to this torch dtype + copy : bool + Copy or not + ndmin : int + The results as least this many dimensions + is_weak : bool + Whether obj is a weakly typed python scalar. + + Returns + ------- + tensor : torch.Tensor + a tensor object with requested dtype, ndim and copy semantics. + + Notes + ----- + This is almost a "tensor_like" coersion function. Does not handle wrapper + ndarrays (those should be handled in the ndarray-aware layer prior to + invoking this function). + """ + if isinstance(obj, torch.Tensor): + tensor = obj + else: + # tensor.dtype is the pytorch default, typically float32. 
If obj's elements + # are not exactly representable in float32, we've lost precision: + # >>> torch.as_tensor(1e12).item() - 1e12 + # -4096.0 + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32)) + try: + tensor = _try_convert_to_tensor(obj) + finally: + torch.set_default_dtype(default_dtype) + + # type cast if requested + tensor = cast_if_needed(tensor, dtype) + + # adjust ndim if needed + ndim_extra = ndmin - tensor.ndim + if ndim_extra > 0: + tensor = tensor.view((1,) * ndim_extra + tensor.shape) + + # copy if requested + if copy: + tensor = tensor.clone() + + return tensor + + +def ndarrays_to_tensors(*inputs): + """Convert all ndarrays from `inputs` to tensors. (other things are intact)""" + from ._ndarray import ndarray + + if len(inputs) == 0: + return ValueError() + elif len(inputs) == 1: + input_ = inputs[0] + if isinstance(input_, ndarray): + return input_.tensor + elif isinstance(input_, tuple): + result = [] + for sub_input in input_: + sub_result = ndarrays_to_tensors(sub_input) + result.append(sub_result) + return tuple(result) + else: + return input_ + else: + assert isinstance(inputs, tuple) # sanity check + return ndarrays_to_tensors(inputs) diff --git a/phivenv/Lib/site-packages/torch/_numpy/fft.py b/phivenv/Lib/site-packages/torch/_numpy/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..ac26d8bc787c90023cd6b0e7a4b9abcb336dee92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/fft.py @@ -0,0 +1,130 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import functools + +import torch + +from . import _dtypes_impl, _util +from ._normalizations import ArrayLike, normalizer + + +def upcast(func): + """NumPy fft casts inputs to 64 bit and *returns 64-bit results*.""" + + @functools.wraps(func) + def wrapped(tensor, *args, **kwds): + target_dtype = ( + _dtypes_impl.default_dtypes().complex_dtype + if tensor.is_complex() + else _dtypes_impl.default_dtypes().float_dtype + ) + tensor = _util.cast_if_needed(tensor, target_dtype) + return func(tensor, *args, **kwds) + + return wrapped + + +@normalizer +@upcast +def fft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.fft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def ifft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ifft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def rfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.rfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def irfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.irfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def fftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.fftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def ifftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.ifftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.rfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfftn(a: ArrayLike, s=None, axes=None, norm=None): + return torch.fft.irfftn(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.fft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.ifft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def rfft2(a: ArrayLike, s=None, axes=(-2, -1), 
norm=None): + return torch.fft.rfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None): + return torch.fft.irfft2(a, s, dim=axes, norm=norm) + + +@normalizer +@upcast +def hfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.hfft(a, n, dim=axis, norm=norm) + + +@normalizer +@upcast +def ihfft(a: ArrayLike, n=None, axis=-1, norm=None): + return torch.fft.ihfft(a, n, dim=axis, norm=norm) + + +@normalizer +def fftfreq(n, d=1.0): + return torch.fft.fftfreq(n, d) + + +@normalizer +def rfftfreq(n, d=1.0): + return torch.fft.rfftfreq(n, d) + + +@normalizer +def fftshift(x: ArrayLike, axes=None): + return torch.fft.fftshift(x, axes) + + +@normalizer +def ifftshift(x: ArrayLike, axes=None): + return torch.fft.ifftshift(x, axes) diff --git a/phivenv/Lib/site-packages/torch/_numpy/linalg.py b/phivenv/Lib/site-packages/torch/_numpy/linalg.py new file mode 100644 index 0000000000000000000000000000000000000000..a4422b638f8ccc9f4fd601dcc458ab814d4aea21 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/linalg.py @@ -0,0 +1,243 @@ +# mypy: ignore-errors + +from __future__ import annotations + +import functools +import math +from typing import TYPE_CHECKING + +import torch + +from . import _dtypes_impl, _util +from ._normalizations import ArrayLike, KeepDims, normalizer + + +if TYPE_CHECKING: + from collections.abc import Sequence + + +class LinAlgError(Exception): + pass + + +def _atleast_float_1(a): + if not (a.dtype.is_floating_point or a.dtype.is_complex): + a = a.to(_dtypes_impl.default_dtypes().float_dtype) + return a + + +def _atleast_float_2(a, b): + dtyp = _dtypes_impl.result_type_impl(a, b) + if not (dtyp.is_floating_point or dtyp.is_complex): + dtyp = _dtypes_impl.default_dtypes().float_dtype + + a = _util.cast_if_needed(a, dtyp) + b = _util.cast_if_needed(b, dtyp) + return a, b + + +def linalg_errors(func): + @functools.wraps(func) + def wrapped(*args, **kwds): + try: + return func(*args, **kwds) + except torch._C._LinAlgError as e: + raise LinAlgError(*e.args) # noqa: B904 + + return wrapped + + +# ### Matrix and vector products ### + + +@normalizer +@linalg_errors +def matrix_power(a: ArrayLike, n): + a = _atleast_float_1(a) + return torch.linalg.matrix_power(a, n) + + +@normalizer +@linalg_errors +def multi_dot(inputs: Sequence[ArrayLike], *, out=None): + return torch.linalg.multi_dot(inputs) + + +# ### Solving equations and inverting matrices ### + + +@normalizer +@linalg_errors +def solve(a: ArrayLike, b: ArrayLike): + a, b = _atleast_float_2(a, b) + return torch.linalg.solve(a, b) + + +@normalizer +@linalg_errors +def lstsq(a: ArrayLike, b: ArrayLike, rcond=None): + a, b = _atleast_float_2(a, b) + # NumPy is using gelsd: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991 + # on CUDA, only `gels` is available though, so use it instead + driver = "gels" if a.is_cuda or b.is_cuda else "gelsd" + return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver) + + +@normalizer +@linalg_errors +def inv(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.inv(a) + return result + + +@normalizer +@linalg_errors +def pinv(a: ArrayLike, rcond=1e-15, hermitian=False): + a = _atleast_float_1(a) + return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian) + + +@normalizer +@linalg_errors +def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None): + a, b = _atleast_float_2(a, b) + return torch.linalg.tensorsolve(a, b, dims=axes) + + +@normalizer +@linalg_errors +def tensorinv(a: 
ArrayLike, ind=2): + a = _atleast_float_1(a) + return torch.linalg.tensorinv(a, ind=ind) + + +# ### Norms and other numbers ### + + +@normalizer +@linalg_errors +def det(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.det(a) + + +@normalizer +@linalg_errors +def slogdet(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.slogdet(a) + + +@normalizer +@linalg_errors +def cond(x: ArrayLike, p=None): + x = _atleast_float_1(x) + + # check if empty + # cf: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + if x.numel() == 0 and math.prod(x.shape[-2:]) == 0: + raise LinAlgError("cond is not defined on empty arrays") + + result = torch.linalg.cond(x, p=p) + + # Convert nans to infs (numpy does it in a data-dependent way, depending on + # whether the input array has nans or not) + # XXX: NumPy does this: https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744 + return torch.where(torch.isnan(result), float("inf"), result) + + +@normalizer +@linalg_errors +def matrix_rank(a: ArrayLike, tol=None, hermitian=False): + a = _atleast_float_1(a) + + if a.ndim < 2: + return int((a != 0).any()) + + if tol is None: + # follow https://github.com/numpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885 + atol = 0 + rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps + else: + atol, rtol = tol, 0 + return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian) + + +@normalizer +@linalg_errors +def norm(x: ArrayLike, ord=None, axis=None, keepdims: KeepDims = False): + x = _atleast_float_1(x) + return torch.linalg.norm(x, ord=ord, dim=axis) + + +# ### Decompositions ### + + +@normalizer +@linalg_errors +def cholesky(a: ArrayLike): + a = _atleast_float_1(a) + return torch.linalg.cholesky(a) + + +@normalizer +@linalg_errors +def qr(a: ArrayLike, mode="reduced"): + a = _atleast_float_1(a) + result = torch.linalg.qr(a, mode=mode) + if mode == "r": + # match NumPy + result = result.R + return result + + +@normalizer +@linalg_errors +def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False): + a = _atleast_float_1(a) + if not compute_uv: + return torch.linalg.svdvals(a) + + # NB: ignore the hermitian= argument (no pytorch equivalent) + result = torch.linalg.svd(a, full_matrices=full_matrices) + return result + + +# ### Eigenvalues and eigenvectors ### + + +@normalizer +@linalg_errors +def eig(a: ArrayLike): + a = _atleast_float_1(a) + w, vt = torch.linalg.eig(a) + + if not a.is_complex() and w.is_complex() and (w.imag == 0).all(): + w = w.real + vt = vt.real + return w, vt + + +@normalizer +@linalg_errors +def eigh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigh(a, UPLO=UPLO) + + +@normalizer +@linalg_errors +def eigvals(a: ArrayLike): + a = _atleast_float_1(a) + result = torch.linalg.eigvals(a) + if not a.is_complex() and result.is_complex() and (result.imag == 0).all(): + result = result.real + return result + + +@normalizer +@linalg_errors +def eigvalsh(a: ArrayLike, UPLO="L"): + a = _atleast_float_1(a) + return torch.linalg.eigvalsh(a, UPLO=UPLO) diff --git a/phivenv/Lib/site-packages/torch/_numpy/random.py b/phivenv/Lib/site-packages/torch/_numpy/random.py new file mode 100644 index 0000000000000000000000000000000000000000..57155b7bf9f081366dac3cfe706bc5b0c7231a2d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/random.py @@ -0,0 +1,191 @@ +# mypy: ignore-errors + +"""Wrapper to mimic (parts of) np.random API surface. 
+ +NumPy has strict guarantees on reproducibility etc; here we don't give any. + +Q: default dtype is float64 in numpy + +""" +from __future__ import annotations + +import functools +from math import sqrt +from typing import Optional + +import torch + +from . import _dtypes_impl, _util +from ._normalizations import array_or_scalar, ArrayLike, normalizer + + +__all__ = [ + "seed", + "random_sample", + "sample", + "random", + "rand", + "randn", + "normal", + "choice", + "randint", + "shuffle", + "uniform", +] + + +def use_numpy_random(): + # local import to avoid ref cycles + import torch._dynamo.config as config + + return config.use_numpy_random_stream + + +def deco_stream(func): + @functools.wraps(func) + def inner(*args, **kwds): + if not use_numpy_random(): + return func(*args, **kwds) + else: + import numpy + + from ._ndarray import ndarray + + f = getattr(numpy.random, func.__name__) + + # numpy funcs accept numpy ndarrays, unwrap + args = tuple( + arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args + ) + kwds = { + key: val.tensor.numpy() if isinstance(val, ndarray) else val + for key, val in kwds.items() + } + + value = f(*args, **kwds) + + # `value` can be either numpy.ndarray or python scalar (or None) + if isinstance(value, numpy.ndarray): + value = ndarray(torch.as_tensor(value)) + + return value + + return inner + + +@deco_stream +def seed(seed=None): + if seed is not None: + torch.random.manual_seed(seed) + + +@deco_stream +def random_sample(size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).uniform_() + return array_or_scalar(values, return_scalar=size == ()) + + +def rand(*size): + if size == (): + size = None + return random_sample(size) + + +sample = random_sample +random = random_sample + + +@deco_stream +def uniform(low=0.0, high=1.0, size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).uniform_(low, high) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def randn(*size): + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.randn(size, dtype=dtype) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def normal(loc=0.0, scale=1.0, size=None): + if size is None: + size = () + dtype = _dtypes_impl.default_dtypes().float_dtype + values = torch.empty(size, dtype=dtype).normal_(loc, scale) + return array_or_scalar(values, return_scalar=size == ()) + + +@deco_stream +def shuffle(x): + # no @normalizer because we do not cast e.g. 
lists to tensors + from ._ndarray import ndarray + + if isinstance(x, torch.Tensor): + tensor = x + elif isinstance(x, ndarray): + tensor = x.tensor + else: + raise NotImplementedError("We do not random.shuffle lists in-place") + + perm = torch.randperm(tensor.shape[0]) + xp = tensor[perm] + tensor.copy_(xp) + + +@deco_stream +def randint(low, high=None, size=None): + if size is None: + size = () + if not isinstance(size, (tuple, list)): + size = (size,) + if high is None: + low, high = 0, low + values = torch.randint(low, high, size=size) + return array_or_scalar(values, int, return_scalar=size == ()) + + +@deco_stream +@normalizer +def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None): + # https://stackoverflow.com/questions/59461811/random-choice-with-pytorch + if a.numel() == 1: + a = torch.arange(a) + + # TODO: check a.dtype is integer -- cf np.random.choice(3.4) which raises + + # number of draws + if size is None: + num_el = 1 + elif _util.is_sequence(size): + num_el = 1 + for el in size: + num_el *= el + else: + num_el = size + + # prepare the probabilities + if p is None: + p = torch.ones_like(a) / a.shape[0] + + # cf https://github.com/numpy/numpy/blob/main/numpy/random/mtrand.pyx#L973 + atol = sqrt(torch.finfo(p.dtype).eps) + if abs(p.sum() - 1.0) > atol: + raise ValueError("probabilities do not sum to 1.") + + # actually sample + indices = torch.multinomial(p, num_el, replacement=replace) + + if _util.is_sequence(size): + indices = indices.reshape(size) + + samples = a[indices] + + return samples diff --git a/phivenv/Lib/site-packages/torch/_numpy/testing/__init__.py b/phivenv/Lib/site-packages/torch/_numpy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5780dd1ef69b5a171e70ddc1ffa9f1679cfa98d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/testing/__init__.py @@ -0,0 +1,20 @@ +# mypy: ignore-errors + +from .utils import ( + _gen_alignment_data, + assert_, + assert_allclose, + assert_almost_equal, + assert_array_almost_equal, + assert_array_equal, + assert_array_less, + assert_equal, + assert_raises_regex, + assert_warns, + HAS_REFCOUNT, + IS_WASM, + suppress_warnings, +) + + +# from .testing import assert_allclose # FIXME diff --git a/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a4da6d5036d1c874018c63c0e480b45ff3890f4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..546e600f540f442f0840ec06bebad4c8b0c8612e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_numpy/testing/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_numpy/testing/utils.py b/phivenv/Lib/site-packages/torch/_numpy/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..581d35307920a9830b79c952a1d0c612848f318b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_numpy/testing/utils.py @@ -0,0 +1,2386 @@ +# mypy: ignore-errors + +""" +Utility function to facilitate testing. 
+ +""" +import contextlib +import gc +import operator +import os +import platform +import pprint +import re +import shutil +import sys +import warnings +from functools import wraps +from io import StringIO +from tempfile import mkdtemp, mkstemp +from warnings import WarningMessage + +import torch._numpy as np +from torch._numpy import arange, asarray as asanyarray, empty, float32, intp, ndarray + + +__all__ = [ + "assert_equal", + "assert_almost_equal", + "assert_approx_equal", + "assert_array_equal", + "assert_array_less", + "assert_string_equal", + "assert_", + "assert_array_almost_equal", + "build_err_msg", + "decorate_methods", + "print_assert_equal", + "verbose", + "assert_", + "assert_array_almost_equal_nulp", + "assert_raises_regex", + "assert_array_max_ulp", + "assert_warns", + "assert_no_warnings", + "assert_allclose", + "IgnoreException", + "clear_and_catch_warnings", + "temppath", + "tempdir", + "IS_PYPY", + "HAS_REFCOUNT", + "IS_WASM", + "suppress_warnings", + "assert_array_compare", + "assert_no_gc_cycles", + "break_cycles", + "IS_PYSTON", +] + + +verbose = 0 + +IS_WASM = platform.machine() in ["wasm32", "wasm64"] +IS_PYPY = sys.implementation.name == "pypy" +IS_PYSTON = hasattr(sys, "pyston_version_info") +HAS_REFCOUNT = getattr(sys, "getrefcount", None) is not None and not IS_PYSTON + + +def assert_(val, msg=""): + """ + Assert that works in release mode. + Accepts callable msg to allow deferring evaluation until failure. + + The Python built-in ``assert`` does not work when executing code in + optimized mode (the ``-O`` flag) - no byte-code is generated for it. + + For documentation on usage, refer to the Python documentation. + + """ + __tracebackhide__ = True # Hide traceback for py.test + if not val: + try: + smsg = msg() + except TypeError: + smsg = msg + raise AssertionError(smsg) + + +def gisnan(x): + return np.isnan(x) + + +def gisfinite(x): + return np.isfinite(x) + + +def gisinf(x): + return np.isinf(x) + + +def build_err_msg( + arrays, + err_msg, + header="Items are not equal:", + verbose=True, + names=("ACTUAL", "DESIRED"), + precision=8, +): + msg = ["\n" + header] + if err_msg: + if err_msg.find("\n") == -1 and len(err_msg) < 79 - len(header): + msg = [msg[0] + " " + err_msg] + else: + msg.append(err_msg) + if verbose: + for i, a in enumerate(arrays): + if isinstance(a, ndarray): + # precision argument is only needed if the objects are ndarrays + # r_func = partial(array_repr, precision=precision) + r_func = ndarray.__repr__ + else: + r_func = repr + + try: + r = r_func(a) + except Exception as exc: + r = f"[repr failed for <{type(a).__name__}>: {exc}]" + if r.count("\n") > 3: + r = "\n".join(r.splitlines()[:3]) + r += "..." + msg.append(f" {names[i]}: {r}") + return "\n".join(msg) + + +def assert_equal(actual, desired, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal. + + Given two objects (scalars, lists, tuples, dictionaries or numpy arrays), + check that all elements of these objects are equal. An exception is raised + at the first conflicting values. + + When one of `actual` and `desired` is a scalar and the other is array_like, + the function checks that each element of the array_like object is equal to + the scalar. + + This function handles NaN comparisons as if NaN was a "normal" number. + That is, AssertionError is not raised if both objects have NaNs in the same + positions. This is in contrast to the IEEE standard on NaNs, which says + that NaN compared to anything must return False. 
+ + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal. + + Examples + -------- + >>> np.testing.assert_equal([4,5], [4,6]) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal: + item=1 + ACTUAL: 5 + DESIRED: 6 + + The following comparison does not raise an exception. There are NaNs + in the inputs, but they are in the same positions. + + >>> np.testing.assert_equal(np.array([1.0, 2.0, np.nan]), [1, 2, np.nan]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + num_nones = sum([actual is None, desired is None]) + if num_nones == 1: + raise AssertionError(f"Not equal: {actual} != {desired}") + elif num_nones == 2: + return True + # else, carry on + + if isinstance(actual, np.DType) or isinstance(desired, np.DType): + result = actual == desired + if not result: + raise AssertionError(f"Not equal: {actual} != {desired}") + else: + return True + + if isinstance(desired, str) and isinstance(actual, str): + assert actual == desired + return + + if isinstance(desired, dict): + if not isinstance(actual, dict): + raise AssertionError(repr(type(actual))) + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in desired.keys(): + if k not in actual: + raise AssertionError(repr(k)) + assert_equal(actual[k], desired[k], f"key={k!r}\n{err_msg}", verbose) + return + if isinstance(desired, (list, tuple)) and isinstance(actual, (list, tuple)): + assert_equal(len(actual), len(desired), err_msg, verbose) + for k in range(len(desired)): + assert_equal(actual[k], desired[k], f"item={k!r}\n{err_msg}", verbose) + return + + from torch._numpy import imag, iscomplexobj, isscalar, ndarray, real, signbit + + if isinstance(actual, ndarray) or isinstance(desired, ndarray): + return assert_array_equal(actual, desired, err_msg, verbose) + msg = build_err_msg([actual, desired], err_msg, verbose=verbose) + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except (ValueError, TypeError): + usecomplex = False + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_equal(actualr, desiredr) + assert_equal(actuali, desiredi) + except AssertionError: + raise AssertionError(msg) # noqa: B904 + + # isscalar test to check cases such as [np.nan] != np.nan + if isscalar(desired) != isscalar(actual): + raise AssertionError(msg) + + # Inf/nan/negative zero handling + try: + isdesnan = gisnan(desired) + isactnan = gisnan(actual) + if isdesnan and isactnan: + return # both nan, so equal + + if desired == 0 and actual == 0: + if not signbit(desired) == signbit(actual): + raise AssertionError(msg) + + except (TypeError, ValueError, NotImplementedError): + pass + + try: + # Explicitly use __eq__ for comparison, gh-2552 + if not (desired == actual): + raise AssertionError(msg) + + except (DeprecationWarning, FutureWarning) as e: + # this handles the case when the 
two types are not even comparable + if "elementwise == comparison" in e.args[0]: + raise AssertionError(msg) # noqa: B904 + else: + raise + + +def print_assert_equal(test_string, actual, desired): + """ + Test if two objects are equal, and print an error message if test fails. + + The test is performed with ``actual == desired``. + + Parameters + ---------- + test_string : str + The message supplied to AssertionError. + actual : object + The object to test for equality against `desired`. + desired : object + The expected result. + + Examples + -------- + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 1]) # doctest: +SKIP + >>> np.testing.print_assert_equal('Test XYZ of func xyz', [0, 1], [0, 2]) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: Test XYZ of func xyz failed + ACTUAL: + [0, 1] + DESIRED: + [0, 2] + + """ + __tracebackhide__ = True # Hide traceback for py.test + import pprint + + if not (actual == desired): + msg = StringIO() + msg.write(test_string) + msg.write(" failed\nACTUAL: \n") + pprint.pprint(actual, msg) + msg.write("DESIRED: \n") + pprint.pprint(desired, msg) + raise AssertionError(msg.getvalue()) + + +def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies that the elements of `actual` and `desired` satisfy. + + ``abs(desired-actual) < float64(1.5 * 10**(-decimal))`` + + That is a looser test than originally documented, but agrees with what the + actual implementation in `assert_array_almost_equal` did up to rounding + vagaries. An exception is raised at conflicting values. For ndarrays this + delegates to assert_array_almost_equal + + Parameters + ---------- + actual : array_like + The object to check. + desired : array_like + The expected object. + decimal : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> from torch._numpy.testing import assert_almost_equal + >>> assert_almost_equal(2.3333333333333, 2.33333334) + >>> assert_almost_equal(2.3333333333333, 2.33333334, decimal=10) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 10 decimals + ACTUAL: 2.3333333333333 + DESIRED: 2.33333334 + + >>> assert_almost_equal(np.array([1.0,2.3333333333333]), + ... np.array([1.0,2.33333334]), decimal=9) + Traceback (most recent call last): + ... 
+ AssertionError: + Arrays are not almost equal to 9 decimals + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 6.666699636781459e-09 + Max relative difference: 2.8571569790287484e-09 + x: torch.ndarray([1.0000, 2.3333], dtype=float64) + y: torch.ndarray([1.0000, 2.3333], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import imag, iscomplexobj, ndarray, real + + # Handle complex numbers: separate into real/imag to handle + # nan/inf/negative zero correctly + # XXX: catch ValueError for subclasses of ndarray where iscomplex fail + try: + usecomplex = iscomplexobj(actual) or iscomplexobj(desired) + except ValueError: + usecomplex = False + + def _build_err_msg(): + header = f"Arrays are not almost equal to {decimal:d} decimals" + return build_err_msg([actual, desired], err_msg, verbose=verbose, header=header) + + if usecomplex: + if iscomplexobj(actual): + actualr = real(actual) + actuali = imag(actual) + else: + actualr = actual + actuali = 0 + if iscomplexobj(desired): + desiredr = real(desired) + desiredi = imag(desired) + else: + desiredr = desired + desiredi = 0 + try: + assert_almost_equal(actualr, desiredr, decimal=decimal) + assert_almost_equal(actuali, desiredi, decimal=decimal) + except AssertionError: + raise AssertionError(_build_err_msg()) # noqa: B904 + + if isinstance(actual, (ndarray, tuple, list)) or isinstance( + desired, (ndarray, tuple, list) + ): + return assert_array_almost_equal(actual, desired, decimal, err_msg) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(_build_err_msg()) + else: + if not desired == actual: + raise AssertionError(_build_err_msg()) + return + except (NotImplementedError, TypeError): + pass + if abs(desired - actual) >= np.float64(1.5 * 10.0 ** (-decimal)): + raise AssertionError(_build_err_msg()) + + +def assert_approx_equal(actual, desired, significant=7, err_msg="", verbose=True): + """ + Raises an AssertionError if two items are not equal up to significant + digits. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + Given two numbers, check that they are approximately equal. + Approximately equal is defined as the number of significant digits + that agree. + + Parameters + ---------- + actual : scalar + The object to check. + desired : scalar + The expected object. + significant : int, optional + Desired precision, default is 7. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + >>> np.testing.assert_approx_equal(0.12345677777777e-20, 0.1234567e-20) # doctest: +SKIP + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345671e-20, # doctest: +SKIP + ... 
significant=8) + >>> np.testing.assert_approx_equal(0.12345670e-20, 0.12345672e-20, # doctest: +SKIP + ... significant=8) + Traceback (most recent call last): + ... + AssertionError: + Items are not equal to 8 significant digits: + ACTUAL: 1.234567e-21 + DESIRED: 1.2345672e-21 + + the evaluated condition that raises the exception is + + >>> abs(0.12345670e-20/1e-21 - 0.12345672e-20/1e-21) >= 10**-(8-1) + True + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + (actual, desired) = map(float, (actual, desired)) + if desired == actual: + return + # Normalized the numbers to be in range (-10.0,10.0) + # scale = float(pow(10,math.floor(math.log10(0.5*(abs(desired)+abs(actual)))))) + scale = 0.5 * (np.abs(desired) + np.abs(actual)) + scale = np.power(10, np.floor(np.log10(scale))) + try: + sc_desired = desired / scale + except ZeroDivisionError: + sc_desired = 0.0 + try: + sc_actual = actual / scale + except ZeroDivisionError: + sc_actual = 0.0 + msg = build_err_msg( + [actual, desired], + err_msg, + header=f"Items are not equal to {significant:d} significant digits:", + verbose=verbose, + ) + try: + # If one of desired/actual is not finite, handle it specially here: + # check that both are nan if any is a nan, and test for equality + # otherwise + if not (gisfinite(desired) and gisfinite(actual)): + if gisnan(desired) or gisnan(actual): + if not (gisnan(desired) and gisnan(actual)): + raise AssertionError(msg) + else: + if not desired == actual: + raise AssertionError(msg) + return + except (TypeError, NotImplementedError): + pass + if np.abs(sc_desired - sc_actual) >= np.power(10.0, -(significant - 1)): + raise AssertionError(msg) + + +def assert_array_compare( + comparison, + x, + y, + err_msg="", + verbose=True, + header="", + precision=6, + equal_nan=True, + equal_inf=True, + *, + strict=False, +): + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import all, array, asarray, bool_, inf, isnan, max + + x = asarray(x) + y = asarray(y) + + def array2string(a): + return str(a) + + # original array for output formatting + ox, oy = x, y + + def func_assert_same_pos(x, y, func=isnan, hasval="nan"): + """Handling nan/inf. + + Combine results of running func on x and y, checking that they are True + at the same locations. + + """ + __tracebackhide__ = True # Hide traceback for py.test + x_id = func(x) + y_id = func(y) + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implementations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if (x_id == y_id).all().item() is not True: + msg = build_err_msg( + [x, y], + err_msg + f"\nx and y {hasval} location mismatch:", + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + # If there is a scalar, then here we know the array has the same + # flag as it everywhere, so we should return the scalar flag. 
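+        # The caller ORs the masks returned here and drops the flagged
+        # (matching nan/inf) positions before comparing the remaining values;
+        # mismatched positions raise the "{hasval} location mismatch" error
+        # built above, e.g. comparing [1.0, nan] against [1.0, 2.0] fails here,
+        # while [1.0, nan] against [1.0, nan] passes.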
+ if isinstance(x_id, bool) or x_id.ndim == 0: + return bool_(x_id) + elif isinstance(y_id, bool) or y_id.ndim == 0: + return bool_(y_id) + else: + return y_id + + try: + if strict: + cond = x.shape == y.shape and x.dtype == y.dtype + else: + cond = (x.shape == () or y.shape == ()) or x.shape == y.shape + if not cond: + if x.shape != y.shape: + reason = f"\n(shapes {x.shape}, {y.shape} mismatch)" + else: + reason = f"\n(dtypes {x.dtype}, {y.dtype} mismatch)" + msg = build_err_msg( + [x, y], + err_msg + reason, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + + flagged = bool_(False) + + if equal_nan: + flagged = func_assert_same_pos(x, y, func=isnan, hasval="nan") + + if equal_inf: + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == +inf, hasval="+inf" + ) + flagged |= func_assert_same_pos( + x, y, func=lambda xy: xy == -inf, hasval="-inf" + ) + + if flagged.ndim > 0: + x, y = x[~flagged], y[~flagged] + # Only do the comparison if actual values are left + if x.size == 0: + return + elif flagged: + # no sense doing comparison if everything is flagged. + return + + val = comparison(x, y) + + if isinstance(val, bool): + cond = val + reduced = array([val]) + else: + reduced = val.ravel() + cond = reduced.all() + + # The below comparison is a hack to ensure that fully masked + # results, for which val.ravel().all() returns np.ma.masked, + # do not trigger a failure (np.ma.masked != True evaluates as + # np.ma.masked, which is falsy). + if not cond: + n_mismatch = reduced.size - int(reduced.sum(dtype=intp)) + n_elements = flagged.size if flagged.ndim != 0 else reduced.size + percent_mismatch = 100 * n_mismatch / n_elements + remarks = [ + f"Mismatched elements: {n_mismatch} / {n_elements} ({percent_mismatch:.3g}%)" + ] + + # with errstate(all='ignore'): + # ignore errors for non-numeric types + with contextlib.suppress(TypeError, RuntimeError): + error = abs(x - y) + if np.issubdtype(x.dtype, np.unsignedinteger): + error2 = abs(y - x) + np.minimum(error, error2, out=error) + max_abs_error = max(error) + remarks.append( + "Max absolute difference: " + array2string(max_abs_error.item()) + ) + + # note: this definition of relative error matches that one + # used by assert_allclose (found in np.isclose) + # Filter values where the divisor would be zero + nonzero = bool_(y != 0) + if all(~nonzero): + max_rel_error = array(inf) + else: + max_rel_error = max(error[nonzero] / abs(y[nonzero])) + remarks.append( + "Max relative difference: " + array2string(max_rel_error.item()) + ) + + err_msg += "\n" + "\n".join(remarks) + msg = build_err_msg( + [ox, oy], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise AssertionError(msg) + except ValueError: + import traceback + + efmt = traceback.format_exc() + header = f"error during assertion:\n\n{efmt}\n\n{header}" + + msg = build_err_msg( + [x, y], + err_msg, + verbose=verbose, + header=header, + names=("x", "y"), + precision=precision, + ) + raise ValueError(msg) # noqa: B904 + + +def assert_array_equal(x, y, err_msg="", verbose=True, *, strict=False): + """ + Raises an AssertionError if two array_like objects are not equal. + + Given two array_like objects, check that the shape is equal and all + elements of these objects are equal (but see the Notes for the special + handling of a scalar). An exception is raised at shape mismatch or + conflicting values. 
In contrast to the standard usage in numpy, NaNs + are compared like numbers, no assertion is raised if both objects have + NaNs in the same positions. + + The usual caution for verifying equality with floating point numbers is + advised. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function checks that each element of the array_like object is equal to + the scalar. This behaviour can be disabled with the `strict` parameter. + + Examples + -------- + The first assert does not raise an exception: + + >>> np.testing.assert_array_equal([1.0,2.33333,np.nan], + ... [np.exp(0),2.33333, np.nan]) + + Use `assert_allclose` or one of the nulp (number of floating point values) + functions for these cases instead: + + >>> np.testing.assert_allclose([1.0,np.pi,np.nan], + ... [1, np.sqrt(np.pi)**2, np.nan], + ... rtol=1e-10, atol=0) + + As mentioned in the Notes section, `assert_array_equal` has special + handling for scalars. Here the test checks that each value in `x` is 3: + + >>> x = np.full((2, 5), fill_value=3) + >>> np.testing.assert_array_equal(x, 3) + + Use `strict` to raise an AssertionError when comparing a scalar with an + array: + + >>> np.testing.assert_array_equal(x, 3, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (shapes (2, 5), () mismatch) + x: torch.ndarray([[3, 3, 3, 3, 3], + [3, 3, 3, 3, 3]]) + y: torch.ndarray(3) + + The `strict` parameter also ensures that the array data types match: + + >>> x = np.array([2, 2, 2]) + >>> y = np.array([2., 2., 2.], dtype=np.float32) + >>> np.testing.assert_array_equal(x, y, strict=True) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not equal + + (dtypes dtype("int64"), dtype("float32") mismatch) + x: torch.ndarray([2, 2, 2]) + y: torch.ndarray([2., 2., 2.]) + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__eq__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not equal", + strict=strict, + ) + + +def assert_array_almost_equal(x, y, decimal=6, err_msg="", verbose=True): + """ + Raises an AssertionError if two objects are not equal up to desired + precision. + + .. note:: It is recommended to use one of `assert_allclose`, + `assert_array_almost_equal_nulp` or `assert_array_max_ulp` + instead of this function for more consistent floating point + comparisons. + + The test verifies identical shapes and that the elements of ``actual`` and + ``desired`` satisfy. + + ``abs(desired-actual) < 1.5 * 10**(-decimal)`` + + That is a looser test than originally documented, but agrees with what the + actual implementation did up to rounding vagaries. 
An exception is raised + at shape mismatch or conflicting values. In contrast to the standard usage + in numpy, NaNs are compared like numbers, no assertion is raised if both + objects have NaNs in the same positions. + + Parameters + ---------- + x : array_like + The actual object to check. + y : array_like + The desired, expected object. + decimal : int, optional + Desired precision, default is 6. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_allclose: Compare two array_like objects for equality with desired + relative and/or absolute precision. + assert_array_almost_equal_nulp, assert_array_max_ulp, assert_equal + + Examples + -------- + the first assert does not raise an exception + + >>> np.testing.assert_array_almost_equal([1.0,2.333,np.nan], + ... [1.0,2.333,np.nan]) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33339,np.nan], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 5.999999999994898e-05 + Max relative difference: 2.5713661239633743e-05 + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3334, nan], dtype=float64) + + >>> np.testing.assert_array_almost_equal([1.0,2.33333,np.nan], + ... [1.0,2.33333, 5], decimal=5) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not almost equal to 5 decimals + + x and y nan location mismatch: + x: torch.ndarray([1.0000, 2.3333, nan], dtype=float64) + y: torch.ndarray([1.0000, 2.3333, 5.0000], dtype=float64) + + """ + __tracebackhide__ = True # Hide traceback for py.test + from torch._numpy import any as npany, float_, issubdtype, number, result_type + + def compare(x, y): + try: + if npany(gisinf(x)) or npany(gisinf(y)): + xinfid = gisinf(x) + yinfid = gisinf(y) + if not (xinfid == yinfid).all(): + return False + # if one item, x and y is +- inf + if x.size == y.size == 1: + return x == y + x = x[~xinfid] + y = y[~yinfid] + except (TypeError, NotImplementedError): + pass + + # make sure y is an inexact type to avoid abs(MIN_INT); will cause + # casting of x later. + dtype = result_type(y, 1.0) + y = asanyarray(y, dtype) + z = abs(x - y) + + if not issubdtype(z.dtype, number): + z = z.astype(float_) # handle object arrays + + return z < 1.5 * 10.0 ** (-decimal) + + assert_array_compare( + compare, + x, + y, + err_msg=err_msg, + verbose=verbose, + header=f"Arrays are not almost equal to {decimal:d} decimals", + precision=decimal, + ) + + +def assert_array_less(x, y, err_msg="", verbose=True): + """ + Raises an AssertionError if two array_like objects are not ordered by less + than. + + Given two array_like objects, check that the shape is equal and all + elements of the first object are strictly smaller than those of the + second object. An exception is raised at shape mismatch or incorrectly + ordered values. Shape mismatch does not raise if an object has zero + dimension. In contrast to the standard usage in numpy, NaNs are + compared, no assertion is raised if both objects have NaNs in the same + positions. + + + + Parameters + ---------- + x : array_like + The smaller object to check. + y : array_like + The larger object to compare. 
+ err_msg : string + The error message to be printed in case of failure. + verbose : bool + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired objects are not equal. + + See Also + -------- + assert_array_equal: tests objects for equality + assert_array_almost_equal: test objects for equality up to precision + + + + Examples + -------- + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) + >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 3 (33.3%) + Max absolute difference: 1.0 + Max relative difference: 0.5 + x: torch.ndarray([1., 1., nan], dtype=float64) + y: torch.ndarray([1., 2., nan], dtype=float64) + + >>> np.testing.assert_array_less([1.0, 4.0], 3) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + Mismatched elements: 1 / 2 (50%) + Max absolute difference: 2.0 + Max relative difference: 0.6666666666666666 + x: torch.ndarray([1., 4.], dtype=float64) + y: torch.ndarray(3) + + >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + Traceback (most recent call last): + ... + AssertionError: + Arrays are not less-ordered + + (shapes (3,), (1,) mismatch) + x: torch.ndarray([1., 2., 3.], dtype=float64) + y: torch.ndarray([4]) + + """ + __tracebackhide__ = True # Hide traceback for py.test + assert_array_compare( + operator.__lt__, + x, + y, + err_msg=err_msg, + verbose=verbose, + header="Arrays are not less-ordered", + equal_inf=False, + ) + + +def assert_string_equal(actual, desired): + """ + Test if two strings are equal. + + If the given strings are equal, `assert_string_equal` does nothing. + If they are not equal, an AssertionError is raised, and the diff + between the strings is shown. + + Parameters + ---------- + actual : str + The string to test for equality against the expected string. + desired : str + The expected string. + + Examples + -------- + >>> np.testing.assert_string_equal('abc', 'abc') # doctest: +SKIP + >>> np.testing.assert_string_equal('abc', 'abcd') # doctest: +SKIP + Traceback (most recent call last): + File "", line 1, in + ... + AssertionError: Differences in strings: + - abc+ abcd? + + + """ + # delay import of difflib to reduce startup time + __tracebackhide__ = True # Hide traceback for py.test + import difflib + + if not isinstance(actual, str): + raise AssertionError(repr(type(actual))) + if not isinstance(desired, str): + raise AssertionError(repr(type(desired))) + if desired == actual: + return + + diff = list( + difflib.Differ().compare(actual.splitlines(True), desired.splitlines(True)) + ) + diff_list = [] + while diff: + d1 = diff.pop(0) + if d1.startswith(" "): + continue + if d1.startswith("- "): + l = [d1] + d2 = diff.pop(0) + if d2.startswith("? "): + l.append(d2) + d2 = diff.pop(0) + if not d2.startswith("+ "): + raise AssertionError(repr(d2)) + l.append(d2) + if diff: + d3 = diff.pop(0) + if d3.startswith("? 
"): + l.append(d3) + else: + diff.insert(0, d3) + if d2[2:] == d1[2:]: + continue + diff_list.extend(l) + continue + raise AssertionError(repr(d1)) + if not diff_list: + return + msg = f"Differences in strings:\n{''.join(diff_list).rstrip()}" + if actual != desired: + raise AssertionError(msg) + + +import unittest + + +class _Dummy(unittest.TestCase): + def nop(self): + pass + + +_d = _Dummy("nop") + + +def assert_raises_regex(exception_class, expected_regexp, *args, **kwargs): + """ + assert_raises_regex(exception_class, expected_regexp, callable, *args, + **kwargs) + assert_raises_regex(exception_class, expected_regexp) + + Fail unless an exception of class exception_class and with message that + matches expected_regexp is thrown by callable when invoked with arguments + args and keyword arguments kwargs. + + Alternatively, can be used as a context manager like `assert_raises`. + + Notes + ----- + .. versionadded:: 1.9.0 + + """ + __tracebackhide__ = True # Hide traceback for py.test + return _d.assertRaisesRegex(exception_class, expected_regexp, *args, **kwargs) + + +def decorate_methods(cls, decorator, testmatch=None): + """ + Apply a decorator to all methods in a class matching a regular expression. + + The given decorator is applied to all public methods of `cls` that are + matched by the regular expression `testmatch` + (``testmatch.search(methodname)``). Methods that are private, i.e. start + with an underscore, are ignored. + + Parameters + ---------- + cls : class + Class whose methods to decorate. + decorator : function + Decorator to apply to methods + testmatch : compiled regexp or str, optional + The regular expression. Default value is None, in which case the + nose default (``re.compile(r'(?:^|[\\b_\\.%s-])[Tt]est' % os.sep)``) + is used. + If `testmatch` is a string, it is compiled to a regular expression + first. + + """ + if testmatch is None: + testmatch = re.compile(rf"(?:^|[\\b_\\.{os.sep}-])[Tt]est") + else: + testmatch = re.compile(testmatch) + cls_attr = cls.__dict__ + + # delayed import to reduce startup time + from inspect import isfunction + + methods = [_m for _m in cls_attr.values() if isfunction(_m)] + for function in methods: + try: + if hasattr(function, "compat_func_name"): + funcname = function.compat_func_name + else: + funcname = function.__name__ + except AttributeError: + # not a function + continue + if testmatch.search(funcname) and not funcname.startswith("_"): + setattr(cls, funcname, decorator(function)) + return + + +def _assert_valid_refcount(op): + """ + Check that ufuncs don't mishandle refcount of object `1`. + Used in a few regression tests. + """ + if not HAS_REFCOUNT: + return True + + import gc + + import numpy as np + + b = np.arange(100 * 100).reshape(100, 100) + c = b + i = 1 + + gc.disable() + try: + rc = sys.getrefcount(i) + for _ in range(15): + d = op(b, c) + assert_(sys.getrefcount(i) >= rc) + finally: + gc.enable() + del d # for pyflakes + + +def assert_allclose( + actual, + desired, + rtol=1e-7, + atol=0, + equal_nan=True, + err_msg="", + verbose=True, + check_dtype=False, +): + """ + Raises an AssertionError if two objects are not equal up to desired + tolerance. + + Given two array_like objects, check that their shapes and all elements + are equal (but see the Notes for the special handling of a scalar). An + exception is raised if the shapes mismatch or any values conflict. In + contrast to the standard usage in numpy, NaNs are compared like numbers, + no assertion is raised if both objects have NaNs in the same positions. 
+ + The test is equivalent to ``allclose(actual, desired, rtol, atol)`` (note + that ``allclose`` has different default values). It compares the difference + between `actual` and `desired` to ``atol + rtol * abs(desired)``. + + .. versionadded:: 1.5.0 + + Parameters + ---------- + actual : array_like + Array obtained. + desired : array_like + Array desired. + rtol : float, optional + Relative tolerance. + atol : float, optional + Absolute tolerance. + equal_nan : bool, optional. + If True, NaNs will compare equal. + err_msg : str, optional + The error message to be printed in case of failure. + verbose : bool, optional + If True, the conflicting values are appended to the error message. + + Raises + ------ + AssertionError + If actual and desired are not equal up to specified precision. + + See Also + -------- + assert_array_almost_equal_nulp, assert_array_max_ulp + + Notes + ----- + When one of `actual` and `desired` is a scalar and the other is + array_like, the function checks that each element of the array_like + object is equal to the scalar. + + Examples + -------- + >>> x = [1e-5, 1e-3, 1e-1] + >>> y = np.arccos(np.cos(x)) + >>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0) + + """ + __tracebackhide__ = True # Hide traceback for py.test + + def compare(x, y): + return np.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + actual, desired = asanyarray(actual), asanyarray(desired) + header = f"Not equal to tolerance rtol={rtol:g}, atol={atol:g}" + + if check_dtype: + assert actual.dtype == desired.dtype + + assert_array_compare( + compare, + actual, + desired, + err_msg=str(err_msg), + verbose=verbose, + header=header, + equal_nan=equal_nan, + ) + + +def assert_array_almost_equal_nulp(x, y, nulp=1): + """ + Compare two arrays relatively to their spacing. + + This is a relatively robust method to compare two arrays whose amplitude + is variable. + + Parameters + ---------- + x, y : array_like + Input arrays. + nulp : int, optional + The maximum number of unit in the last place for tolerance (see Notes). + Default is 1. + + Returns + ------- + None + + Raises + ------ + AssertionError + If the spacing between `x` and `y` for one or more elements is larger + than `nulp`. + + See Also + -------- + assert_array_max_ulp : Check that all items of arrays differ in at most + N Units in the Last Place. + spacing : Return the distance between x and the nearest adjacent number. + + Notes + ----- + An assertion is raised if the following condition is not met:: + + abs(x - y) <= nulp * spacing(maximum(abs(x), abs(y))) + + Examples + -------- + >>> x = np.array([1., 1e-10, 1e-20]) + >>> eps = np.finfo(x.dtype).eps + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps/2 + x) # doctest: +SKIP + + >>> np.testing.assert_array_almost_equal_nulp(x, x*eps + x) # doctest: +SKIP + Traceback (most recent call last): + ... + AssertionError: X and Y are not equal to 1 ULP (max is 2) + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ax = np.abs(x) + ay = np.abs(y) + ref = nulp * np.spacing(np.where(ax > ay, ax, ay)) + if not np.all(np.abs(x - y) <= ref): + if np.iscomplexobj(x) or np.iscomplexobj(y): + msg = f"X and Y are not equal to {nulp:d} ULP" + else: + max_nulp = np.max(nulp_diff(x, y)) + msg = f"X and Y are not equal to {nulp:d} ULP (max is {max_nulp:g})" + raise AssertionError(msg) + + +def assert_array_max_ulp(a, b, maxulp=1, dtype=None): + """ + Check that all items of arrays differ in at most N Units in the Last Place. 
+ + Parameters + ---------- + a, b : array_like + Input arrays to be compared. + maxulp : int, optional + The maximum number of units in the last place that elements of `a` and + `b` can differ. Default is 1. + dtype : dtype, optional + Data-type to convert `a` and `b` to if given. Default is None. + + Returns + ------- + ret : ndarray + Array containing number of representable floating point numbers between + items in `a` and `b`. + + Raises + ------ + AssertionError + If one or more elements differ by more than `maxulp`. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). + + See Also + -------- + assert_array_almost_equal_nulp : Compare two arrays relatively to their + spacing. + + Examples + -------- + >>> a = np.linspace(0., 1., 100) + >>> res = np.testing.assert_array_max_ulp(a, np.arcsin(np.sin(a))) # doctest: +SKIP + + """ + __tracebackhide__ = True # Hide traceback for py.test + import numpy as np + + ret = nulp_diff(a, b, dtype) + if not np.all(ret <= maxulp): + raise AssertionError( + f"Arrays are not almost equal up to {maxulp:g} " + f"ULP (max difference is {np.max(ret):g} ULP)" + ) + return ret + + +def nulp_diff(x, y, dtype=None): + """For each item in x and y, return the number of representable floating + points between them. + + Parameters + ---------- + x : array_like + first input array + y : array_like + second input array + dtype : dtype, optional + Data-type to convert `x` and `y` to if given. Default is None. + + Returns + ------- + nulp : array_like + number of representable floating point numbers between each item in x + and y. + + Notes + ----- + For computing the ULP difference, this API does not differentiate between + various representations of NAN (ULP difference between 0x7fc00000 and 0xffc00000 + is zero). 
+ + Examples + -------- + # By definition, epsilon is the smallest number such as 1 + eps != 1, so + # there should be exactly one ULP between 1 and 1 + eps + >>> nulp_diff(1, 1 + np.finfo(x.dtype).eps) # doctest: +SKIP + 1.0 + """ + import numpy as np + + if dtype: + x = np.asarray(x, dtype=dtype) + y = np.asarray(y, dtype=dtype) + else: + x = np.asarray(x) + y = np.asarray(y) + + t = np.common_type(x, y) + if np.iscomplexobj(x) or np.iscomplexobj(y): + raise NotImplementedError("_nulp not implemented for complex array") + + x = np.array([x], dtype=t) + y = np.array([y], dtype=t) + + x[np.isnan(x)] = np.nan + y[np.isnan(y)] = np.nan + + if not x.shape == y.shape: + raise ValueError(f"x and y do not have the same shape: {x.shape} - {y.shape}") + + def _diff(rx, ry, vdt): + diff = np.asarray(rx - ry, dtype=vdt) + return np.abs(diff) + + rx = integer_repr(x) + ry = integer_repr(y) + return _diff(rx, ry, t) + + +def _integer_repr(x, vdt, comp): + # Reinterpret binary representation of the float as sign-magnitude: + # take into account two-complement representation + # See also + # https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/ + rx = x.view(vdt) + if not (rx.size == 1): + rx[rx < 0] = comp - rx[rx < 0] + else: + if rx < 0: + rx = comp - rx + + return rx + + +def integer_repr(x): + """Return the signed-magnitude interpretation of the binary representation + of x.""" + import numpy as np + + if x.dtype == np.float16: + return _integer_repr(x, np.int16, np.int16(-(2**15))) + elif x.dtype == np.float32: + return _integer_repr(x, np.int32, np.int32(-(2**31))) + elif x.dtype == np.float64: + return _integer_repr(x, np.int64, np.int64(-(2**63))) + else: + raise ValueError(f"Unsupported dtype {x.dtype}") + + +@contextlib.contextmanager +def _assert_warns_context(warning_class, name=None): + __tracebackhide__ = True # Hide traceback for py.test + with suppress_warnings() as sup: + l = sup.record(warning_class) + yield + if not len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError("No warning raised" + name_str) + + +def assert_warns(warning_class, *args, **kwargs): + """ + Fail unless the given callable throws the specified warning. + + A warning of class warning_class should be thrown by the callable when + invoked with arguments args and keyword arguments kwargs. + If a different type of warning is thrown, it will not be caught. + + If called with all arguments other than the warning class omitted, may be + used as a context manager: + + with assert_warns(SomeWarning): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.4.0 + + Parameters + ---------- + warning_class : class + The class defining the warning that `func` is expected to throw. + func : callable, optional + Callable to test + *args : Arguments + Arguments for `func`. + **kwargs : Kwargs + Keyword arguments for `func`. + + Returns + ------- + The value returned by `func`. + + Examples + -------- + >>> import warnings + >>> def deprecated_func(num): + ... warnings.warn("Please upgrade", DeprecationWarning) + ... return num*num + >>> with np.testing.assert_warns(DeprecationWarning): + ... 
assert deprecated_func(4) == 16 + >>> # or passing a func + >>> ret = np.testing.assert_warns(DeprecationWarning, deprecated_func, 4) + >>> assert ret == 16 + """ + if not args: + return _assert_warns_context(warning_class) + + func = args[0] + args = args[1:] + with _assert_warns_context(warning_class, name=func.__name__): + return func(*args, **kwargs) + + +@contextlib.contextmanager +def _assert_no_warnings_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + with warnings.catch_warnings(record=True) as l: + warnings.simplefilter("always") + yield + if len(l) > 0: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError(f"Got warnings{name_str}: {l}") + + +def assert_no_warnings(*args, **kwargs): + """ + Fail if the given callable produces any warnings. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_warnings(): + do_something() + + The ability to be used as a context manager is new in NumPy v1.11.0. + + .. versionadded:: 1.7.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + The value returned by `func`. + + """ + if not args: + return _assert_no_warnings_context() + + func = args[0] + args = args[1:] + with _assert_no_warnings_context(name=func.__name__): + return func(*args, **kwargs) + + +def _gen_alignment_data(dtype=float32, type="binary", max_size=24): + """ + generator producing data with different alignment and offsets + to test simd vectorization + + Parameters + ---------- + dtype : dtype + data type to produce + type : string + 'unary': create data for unary operations, creates one input + and output array + 'binary': create data for unary operations, creates two input + and output array + max_size : integer + maximum size of data to produce + + Returns + ------- + if type is 'unary' yields one output, one input array and a message + containing information on the data + if type is 'binary' yields one output array, two input array and a message + containing information on the data + + """ + ufmt = "unary offset=(%d, %d), size=%d, dtype=%r, %s" + bfmt = "binary offset=(%d, %d, %d), size=%d, dtype=%r, %s" + for o in range(3): + for s in range(o + 2, max(o + 3, max_size)): + if type == "unary": + + def inp(): + return arange(s, dtype=dtype)[o:] + + out = empty((s,), dtype=dtype)[o:] + yield out, inp(), ufmt % (o, o, s, dtype, "out of place") + d = inp() + yield d, d, ufmt % (o, o, s, dtype, "in place") + yield out[1:], inp()[:-1], ufmt % ( + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp()[1:], ufmt % ( + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp()[:-1], inp()[1:], ufmt % (o, o + 1, s - 1, dtype, "aliased") + yield inp()[1:], inp()[:-1], ufmt % (o + 1, o, s - 1, dtype, "aliased") + if type == "binary": + + def inp1(): + return arange(s, dtype=dtype)[o:] + + inp2 = inp1 + out = empty((s,), dtype=dtype)[o:] + yield out, inp1(), inp2(), bfmt % (o, o, o, s, dtype, "out of place") + d = inp1() + yield d, d, inp2(), bfmt % (o, o, o, s, dtype, "in place1") + d = inp2() + yield d, inp1(), d, bfmt % (o, o, o, s, dtype, "in place2") + yield out[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "out of place", + ) + yield out[:-1], inp1()[:-1], 
inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "out of place", + ) + yield inp1()[1:], inp1()[:-1], inp2()[:-1], bfmt % ( + o + 1, + o, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[1:], inp2()[:-1], bfmt % ( + o, + o + 1, + o, + s - 1, + dtype, + "aliased", + ) + yield inp1()[:-1], inp1()[:-1], inp2()[1:], bfmt % ( + o, + o, + o + 1, + s - 1, + dtype, + "aliased", + ) + + +class IgnoreException(Exception): + "Ignoring this exception due to disabled feature" + + +@contextlib.contextmanager +def tempdir(*args, **kwargs): + """Context manager to provide a temporary test folder. + + All arguments are passed as this to the underlying tempfile.mkdtemp + function. + + """ + tmpdir = mkdtemp(*args, **kwargs) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir) + + +@contextlib.contextmanager +def temppath(*args, **kwargs): + """Context manager for temporary files. + + Context manager that returns the path to a closed temporary file. Its + parameters are the same as for tempfile.mkstemp and are passed directly + to that function. The underlying file is removed when the context is + exited, so it should be closed at that time. + + Windows does not allow a temporary file to be opened if it is already + open, so the underlying file must be closed after opening before it + can be opened again. + + """ + fd, path = mkstemp(*args, **kwargs) + os.close(fd) + try: + yield path + finally: + os.remove(path) + + +class clear_and_catch_warnings(warnings.catch_warnings): + """Context manager that resets warning registry for catching warnings + + Warnings can be slippery, because, whenever a warning is triggered, Python + adds a ``__warningregistry__`` member to the *calling* module. This makes + it impossible to retrigger the warning in this module, whatever you put in + the warnings filters. This context manager accepts a sequence of `modules` + as a keyword argument to its constructor and: + + * stores and removes any ``__warningregistry__`` entries in given `modules` + on entry; + * resets ``__warningregistry__`` to its previous state on exit. + + This makes it possible to trigger any warning afresh inside the context + manager without disturbing the state of warnings outside. + + For compatibility with Python 3.0, please consider all arguments to be + keyword-only. + + Parameters + ---------- + record : bool, optional + Specifies whether warnings should be captured by a custom + implementation of ``warnings.showwarning()`` and be appended to a list + returned by the context manager. Otherwise None is returned by the + context manager. The objects appended to the list are arguments whose + attributes mirror the arguments to ``showwarning()``. + modules : sequence, optional + Sequence of modules for which to reset warnings registry on entry and + restore on exit. To work correctly, all 'ignore' filters should + filter by one of these modules. + + Examples + -------- + >>> import warnings + >>> with np.testing.clear_and_catch_warnings( # doctest: +SKIP + ... modules=[np.core.fromnumeric]): + ... warnings.simplefilter('always') + ... warnings.filterwarnings('ignore', module='np.core.fromnumeric') + ... # do something that raises a warning but ignore those in + ... 
# np.core.fromnumeric + """ + + class_modules = () + + def __init__(self, record=False, modules=()): + self.modules = set(modules).union(self.class_modules) + self._warnreg_copies = {} + super().__init__(record=record) + + def __enter__(self): + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod_reg = mod.__warningregistry__ + self._warnreg_copies[mod] = mod_reg.copy() + mod_reg.clear() + return super().__enter__() + + def __exit__(self, *exc_info): + super().__exit__(*exc_info) + for mod in self.modules: + if hasattr(mod, "__warningregistry__"): + mod.__warningregistry__.clear() + if mod in self._warnreg_copies: + mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class suppress_warnings: + """ + Context manager and decorator doing much the same as + ``warnings.catch_warnings``. + + However, it also provides a filter mechanism to work around + https://bugs.python.org/issue4180. + + This bug causes Python before 3.4 to not reliably show warnings again + after they have been ignored once (even within catch_warnings). It + means that no "ignore" filter can be used easily, since following + tests might need to see the warning. Additionally it allows easier + specificity for testing warnings and can be nested. + + Parameters + ---------- + forwarding_rule : str, optional + One of "always", "once", "module", or "location". Analogous to + the usual warnings module filter mode, it is useful to reduce + noise mostly on the outmost level. Unsuppressed and unrecorded + warnings will be forwarded based on this rule. Defaults to "always". + "location" is equivalent to the warnings "default", match by exact + location the warning warning originated from. + + Notes + ----- + Filters added inside the context manager will be discarded again + when leaving it. Upon entering all filters defined outside a + context will be applied automatically. + + When a recording filter is added, matching warnings are stored in the + ``log`` attribute as well as in the list returned by ``record``. + + If filters are added and the ``module`` keyword is given, the + warning registry of this module will additionally be cleared when + applying it, entering the context, or exiting it. This could cause + warnings to appear a second time after leaving the context if they + were configured to be printed once (default) and were already + printed before the context was entered. + + Nesting this context manager will work as expected when the + forwarding rule is "always" (default). Unfiltered and unrecorded + warnings will be passed out and be matched by the outer level. + On the outmost level they will be printed (or caught by another + warnings context). The forwarding rule argument can modify this + behaviour. + + Like ``catch_warnings`` this context manager is not threadsafe. + + Examples + -------- + + With a context manager:: + + with np.testing.suppress_warnings() as sup: + sup.filter(DeprecationWarning, "Some text") + sup.filter(module=np.ma.core) + log = sup.record(FutureWarning, "Does this occur?") + command_giving_warnings() + # The FutureWarning was given once, the filtered warnings were + # ignored. 
All other warnings abide outside settings (may be + # printed/error) + assert_(len(log) == 1) + assert_(len(sup.log) == 1) # also stored in log attribute + + Or as a decorator:: + + sup = np.testing.suppress_warnings() + sup.filter(module=np.ma.core) # module must match exactly + @sup + def some_function(): + # do something which causes a warning in np.ma.core + pass + """ + + def __init__(self, forwarding_rule="always"): + self._entered = False + + # Suppressions are either instance or defined inside one with block: + self._suppressions = [] + + if forwarding_rule not in {"always", "module", "once", "location"}: + raise ValueError("unsupported forwarding rule.") + self._forwarding_rule = forwarding_rule + + def _clear_registries(self): + if hasattr(warnings, "_filters_mutated"): + # clearing the registry should not be necessary on new pythons, + # instead the filters should be mutated. + warnings._filters_mutated() + return + # Simply clear the registry, this should normally be harmless, + # note that on new pythons it would be invalidated anyway. + for module in self._tmp_modules: + if hasattr(module, "__warningregistry__"): + module.__warningregistry__.clear() + + def _filter(self, category=Warning, message="", module=None, record=False): + if record: + record = [] # The log where to store warnings + else: + record = None + if self._entered: + if module is None: + warnings.filterwarnings("always", category=category, message=message) + else: + module_regex = module.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=category, message=message, module=module_regex + ) + self._tmp_modules.add(module) + self._clear_registries() + + self._tmp_suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + else: + self._suppressions.append( + (category, message, re.compile(message, re.IGNORECASE), module, record) + ) + + return record + + def filter(self, category=Warning, message="", module=None): + """ + Add a new suppressing filter or apply it if the state is entered. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. + """ + self._filter(category=category, message=message, module=module, record=False) + + def record(self, category=Warning, message="", module=None): + """ + Append a new recording filter or apply it if the state is entered. + + All warnings matching will be appended to the ``log`` attribute. + + Parameters + ---------- + category : class, optional + Warning class to filter + message : string, optional + Regular expression matching the warning message. + module : module, optional + Module to filter for. Note that the module (and its file) + must match exactly and cannot be a submodule. This may make + it unreliable for external modules. + + Returns + ------- + log : list + A list which will be filled with all matched warnings. + + Notes + ----- + When added within a context, filters are only added inside + the context and will be forgotten when the context is exited. 
+ """ + return self._filter( + category=category, message=message, module=module, record=True + ) + + def __enter__(self): + if self._entered: + raise RuntimeError("cannot enter suppress_warnings twice.") + + self._orig_show = warnings.showwarning + self._filters = warnings.filters + warnings.filters = self._filters[:] + + self._entered = True + self._tmp_suppressions = [] + self._tmp_modules = set() + self._forwarded = set() + + self.log = [] # reset global log (no need to keep same list) + + for cat, mess, _, mod, log in self._suppressions: + if log is not None: + del log[:] # clear the log + if mod is None: + warnings.filterwarnings("always", category=cat, message=mess) + else: + module_regex = mod.__name__.replace(".", r"\.") + "$" + warnings.filterwarnings( + "always", category=cat, message=mess, module=module_regex + ) + self._tmp_modules.add(mod) + warnings.showwarning = self._showwarning + self._clear_registries() + + return self + + def __exit__(self, *exc_info): + warnings.showwarning = self._orig_show + warnings.filters = self._filters + self._clear_registries() + self._entered = False + del self._orig_show + del self._filters + + def _showwarning( + self, message, category, filename, lineno, *args, use_warnmsg=None, **kwargs + ): + for cat, _, pattern, mod, rec in (self._suppressions + self._tmp_suppressions)[ + ::-1 + ]: + if issubclass(category, cat) and pattern.match(message.args[0]) is not None: + if mod is None: + # Message and category match, either recorded or ignored + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + # Use startswith, because warnings strips the c or o from + # .pyc/.pyo files. + elif mod.__file__.startswith(filename): + # The message and module (filename) match + if rec is not None: + msg = WarningMessage( + message, category, filename, lineno, **kwargs + ) + self.log.append(msg) + rec.append(msg) + return + + # There is no filter in place, so pass to the outside handler + # unless we should only pass it once + if self._forwarding_rule == "always": + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + return + + if self._forwarding_rule == "once": + signature = (message.args, category) + elif self._forwarding_rule == "module": + signature = (message.args, category, filename) + elif self._forwarding_rule == "location": + signature = (message.args, category, filename, lineno) + + if signature in self._forwarded: + return + self._forwarded.add(signature) + if use_warnmsg is None: + self._orig_show(message, category, filename, lineno, *args, **kwargs) + else: + self._orig_showmsg(use_warnmsg) + + def __call__(self, func): + """ + Function decorator to apply certain suppressions to a whole + function. + """ + + @wraps(func) + def new_func(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return new_func + + +@contextlib.contextmanager +def _assert_no_gc_cycles_context(name=None): + __tracebackhide__ = True # Hide traceback for py.test + + # not meaningful to test if there is no refcounting + if not HAS_REFCOUNT: + yield + return + + assert_(gc.isenabled()) + gc.disable() + gc_debug = gc.get_debug() + try: + for _ in range(100): + if gc.collect() == 0: + break + else: + raise RuntimeError( + "Unable to fully collect garbage - perhaps a __del__ method " + "is creating more reference cycles?" 
+ ) + + gc.set_debug(gc.DEBUG_SAVEALL) + yield + # gc.collect returns the number of unreachable objects in cycles that + # were found -- we are checking that no cycles were created in the context + n_objects_in_cycles = gc.collect() + objects_in_cycles = gc.garbage[:] + finally: + del gc.garbage[:] + gc.set_debug(gc_debug) + gc.enable() + + if n_objects_in_cycles: + name_str = f" when calling {name}" if name is not None else "" + raise AssertionError( + "Reference cycles were found{}: {} objects were collected, " + "of which {} are shown below:{}".format( + name_str, + n_objects_in_cycles, + len(objects_in_cycles), + "".join( + "\n {} object with id={}:\n {}".format( + type(o).__name__, + id(o), + pprint.pformat(o).replace("\n", "\n "), + ) + for o in objects_in_cycles + ), + ) + ) + + +def assert_no_gc_cycles(*args, **kwargs): + """ + Fail if the given callable produces any reference cycles. + + If called with all arguments omitted, may be used as a context manager: + + with assert_no_gc_cycles(): + do_something() + + .. versionadded:: 1.15.0 + + Parameters + ---------- + func : callable + The callable to test. + \\*args : Arguments + Arguments passed to `func`. + \\*\\*kwargs : Kwargs + Keyword arguments passed to `func`. + + Returns + ------- + Nothing. The result is deliberately discarded to ensure that all cycles + are found. + + """ + if not args: + return _assert_no_gc_cycles_context() + + func = args[0] + args = args[1:] + with _assert_no_gc_cycles_context(name=func.__name__): + func(*args, **kwargs) + + +def break_cycles(): + """ + Break reference cycles by calling gc.collect + Objects can call other objects' methods (for instance, another object's + __del__) inside their own __del__. On PyPy, the interpreter only runs + between calls to gc.collect, so multiple calls are needed to completely + release all cycles. + """ + + gc.collect() + if IS_PYPY: + # a few more, just to make sure all the finalizers are called + gc.collect() + gc.collect() + gc.collect() + gc.collect() + + +def requires_memory(free_bytes): + """Decorator to skip a test if not enough memory is available""" + import pytest + + def decorator(func): + @wraps(func) + def wrapper(*a, **kw): + msg = check_free_memory(free_bytes) + if msg is not None: + pytest.skip(msg) + + try: + return func(*a, **kw) + except MemoryError: + # Probably ran out of memory regardless: don't regard as failure + pytest.xfail("MemoryError raised") + + return wrapper + + return decorator + + +def check_free_memory(free_bytes): + """ + Check whether `free_bytes` amount of memory is currently free. + Returns: None if enough memory available, otherwise error message + """ + env_var = "NPY_AVAILABLE_MEM" + env_value = os.environ.get(env_var) + if env_value is not None: + try: + mem_free = _parse_size(env_value) + except ValueError as exc: + raise ValueError( # noqa: B904 + f"Invalid environment variable {env_var}: {exc}" + ) + + msg = ( + f"{free_bytes / 1e9} GB memory required, but environment variable " + f"NPY_AVAILABLE_MEM={env_value} set" + ) + else: + mem_free = _get_mem_available() + + if mem_free is None: + msg = ( + "Could not determine available memory; set NPY_AVAILABLE_MEM " + "environment variable (e.g. NPY_AVAILABLE_MEM=16GB) to run " + "the test." + ) + mem_free = -1 + else: + msg = f"{free_bytes / 1e9} GB memory required, but {mem_free / 1e9} GB available" + + return msg if mem_free < free_bytes else None + + +def _parse_size(size_str): + """Convert memory size strings ('12 GB' etc.) 
to float""" + suffixes = { + "": 1, + "b": 1, + "k": 1000, + "m": 1000**2, + "g": 1000**3, + "t": 1000**4, + "kb": 1000, + "mb": 1000**2, + "gb": 1000**3, + "tb": 1000**4, + "kib": 1024, + "mib": 1024**2, + "gib": 1024**3, + "tib": 1024**4, + } + + size_re = re.compile( + r"^\s*(\d+|\d+\.\d+)\s*({})\s*$".format("|".join(suffixes.keys())), + re.IGNORECASE, + ) + + m = size_re.match(size_str.lower()) + if not m or m.group(2) not in suffixes: + raise ValueError(f"value {size_str!r} not a valid size") + return int(float(m.group(1)) * suffixes[m.group(2)]) + + +def _get_mem_available(): + """Return available memory in bytes, or None if unknown.""" + try: + import psutil + + return psutil.virtual_memory().available + except (ImportError, AttributeError): + pass + + if sys.platform.startswith("linux"): + info = {} + with open("/proc/meminfo") as f: + for line in f: + p = line.split() + info[p[0].strip(":").lower()] = int(p[1]) * 1024 + + if "memavailable" in info: + # Linux >= 3.14 + return info["memavailable"] + else: + return info["memfree"] + info["cached"] + + return None + + +def _no_tracing(func): + """ + Decorator to temporarily turn off tracing for the duration of a test. + Needed in tests that check refcounting, otherwise the tracing itself + influences the refcounts + """ + if not hasattr(sys, "gettrace"): + return func + else: + + @wraps(func) + def wrapper(*args, **kwargs): + original_trace = sys.gettrace() + try: + sys.settrace(None) + return func(*args, **kwargs) + finally: + sys.settrace(original_trace) + + return wrapper + + +def _get_glibc_version(): + try: + ver = os.confstr("CS_GNU_LIBC_VERSION").rsplit(" ")[1] + except Exception: + ver = "0.0" + + return ver + + +_glibcver = _get_glibc_version() + + +def _glibc_older_than(x): + return _glibcver != "0.0" and _glibcver < x diff --git a/phivenv/Lib/site-packages/torch/_prims/__init__.py b/phivenv/Lib/site-packages/torch/_prims/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f19ce64ae15916824cbff73a383355d5a379c42a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims/__init__.py @@ -0,0 +1,2958 @@ +# mypy: allow-untyped-defs +import operator +from collections.abc import Sequence +from enum import Enum +from functools import partial, reduce +from typing import Callable, Optional, Union + +import torch +import torch._prims_common as utils +import torch.library +from torch import sym_float, Tensor +from torch._C import _get_default_device +from torch._higher_order_ops.effects import new_token_tensor +from torch._library.utils import is_functional_schema +from torch._prims.debug_prims import register_debug_prims +from torch._prims.rng_prims import register_rng_prims +from torch._prims_common import ( + Dim, + DimsSequenceType, + DimsType, + IntLike, + Number, + NumberType, + RETURN_TYPE, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + type_to_dtype, +) +from torch._prims_common.wrappers import backwards_not_supported +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.overrides import handle_torch_function, has_torch_function +from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten + + +prim = torch.library.Library("prims", "DEF") +prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd") +prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect") +prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd") +prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta") + +# 
Experimental module containing prototype "primitive" operations. + +__all__ = [ + # + # Common datastructures and helpers + # + "RETURN_TYPE", + # + # Elementwise unary prims + # + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", + "cos", + "cosh", + "bessel_i0", + "bessel_i0e", + "bessel_i1", + "bessel_i1e", + "bessel_j0", + "bessel_j1", + "bitwise_not", + "cbrt", + "ceil", + "conj_physical", + "digamma", + "erf", + "erf_inv", + "erfc", + "erfcx", + "exp", + "expm1", + "exp2", + "fill", + "floor", + "imag", + "isfinite", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "ndtri", + "neg", + "real", + "reciprocal", + "round", + "sign", + "signbit", + "sin", + "sinh", + "spherical_bessel_j0", + "sqrt", + "tan", + "tanh", + "trunc", + # + # Elementwise binary prims + # + "add", + "atan2", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + # 'complex', # needs custom meta + "div", + "eq", + "fmax", + "fmin", + "fmod", + "frexp", + "gcd", + "ge", + "gt", + "hypot", + "igamma", + "igammac", + "le", + "lt", + "maximum", + "minimum", + "mul", + "ne", + "nextafter", + "pow", + "remainder", + "rsqrt", + "shift_left", + "shift_right_arithmetic", + "shift_right_logical", # not implemented + "sub", + "zeta", + # + # View prims + # + "as_strided", + "broadcast_in_dim", + "collapse_view", + "conj", + "expand_dims", + "slice", + "slice_in_dim", # implemented using slice -- make this a ref? + "split_dim", + "squeeze", + "transpose", + "view_of", + "view_element_type", + # + # Functionalized view mutations + # + "as_strided_scatter", + # + # Shape prims + # + "collapse", + "cat", + "reshape", + "rev", + # + # Conditional prims + # + "where", + # + # Data conversion and movement prims + # + "clone", + "convert_element_type", + "device_put", + "item", + "maximum_value", + "minimum_value", + "copy_strided", + # + # Inplace prims + # + "copy_to", + "resize", + # "_set", # Commented out, see note below + # + # Reduction prims + # + "amax", + "amin", + "prod", + "sum", + "xor_sum", + "var", + # + # Tensor Creation Prims + # + "empty_strided", + "empty_permuted", + "scalar_tensor", + "iota", + # + # Linear algebra (linalg) Prims + # + "svd", + # + # Randomness Prims + # + "normal", + "_uniform_helper", + # + # FFT prims + # + "fft_r2c", + "fft_c2c", + "fft_c2r", + # + # prims for making/sinking tokens + # + "_make_token", + "_sink_tokens", +] + + +def TensorMeta( + tensorlike: Optional[Union[NumberType, torch.Tensor]] = None, + *, + shape: Optional[ShapeType] = None, + strides: Optional[StrideType] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[Union[torch.device, str]] = None, +): + if isinstance(tensorlike, Number): + assert not shape and (shape is None or isinstance(shape, Sequence)) + assert not strides and (strides is None or isinstance(strides, Sequence)) + inferred_shape: tuple[int, ...] = () + inferred_strides: tuple[int, ...] 
= () + inferred_dtype = type_to_dtype(type(tensorlike)) + inferred_device = torch.device("cpu") + # TODO: This looks wrong, a number that is wrapped into a tensor + # needs to behave differently than a scalar tensor for type + # promotion purposes + elif tensorlike is not None: + assert isinstance(tensorlike, torch.Tensor) + inferred_shape = tuple(tensorlike.shape) + inferred_strides = tuple(tensorlike.stride()) + inferred_dtype = tensorlike.dtype + inferred_device = tensorlike.device + else: + # If no tensorlike "example" is given then all metadata + # must be provided explicitly + assert shape is not None + assert strides is not None + assert dtype is not None + assert device is not None + + shape = inferred_shape if shape is None else tuple(shape) # type: ignore[possibly-undefined] + strides = inferred_strides if strides is None else tuple(strides) # type: ignore[possibly-undefined] + dtype = inferred_dtype if dtype is None else dtype # type: ignore[possibly-undefined] + device = inferred_device if device is None else device # type: ignore[possibly-undefined] + + if isinstance(device, str): + device = torch.device(device) + + return torch.empty_strided(shape, strides, dtype=dtype, device=device) + + +def _make_prim( + *, + schema: str, + return_type: Union[RETURN_TYPE, tuple[RETURN_TYPE, ...]], + meta: Callable, + impl_aten: Callable, + doc: str, + tags: Optional[Sequence[torch.Tag]] = None, + use_old_custom_ops_api: bool = False, + register_conj_neg_fallthrough: bool = False, +): + """ + Creates a primitive operation. + + """ + + def _prim_impl(*args, **kwargs): + # always run the meta function because aten implementation will + # typically accept more inputs (e.g., it will do promotion and + # broadcasting) which we want to reject + meta(*args, **kwargs) + return impl_aten(*args, **kwargs) + + # Right now prims don't support autograd (we can and should add an + # argument that provides an implementation for backward here.) 
Because we + # don't have derivative formulas, we must setup a custom autograd function + # that raises an error if backwards is invoked + def _autograd_impl(*args, **kwargs): + return backwards_not_supported(_prim)(*args, **kwargs) + + def _backend_select_impl(*args, **kwargs): + if kwargs.get("device") and kwargs["device"].type == "meta": + return meta(*args, **kwargs) + if any(isinstance(x, torch.device) and x.type == "meta" for x in args): + return meta(*args, **kwargs) + else: + return _prim_impl(*args, **kwargs) + + name = schema.split("(")[0] + schema = schema[len(name) :] + + # register non-functional ops with old custom ops API + cpp_schema = torch._C.parse_schema(name + schema) + if use_old_custom_ops_api or not is_functional_schema(cpp_schema): + prim.define(name + schema, tags=torch.Tag.pt2_compliant_tag) + prim_impl.impl(name, _prim_impl) + prim_autograd_impl.impl(name, _autograd_impl) + prim_meta_impl.impl(name, meta) + else: + mutates_args = [ + arg.name + for arg in cpp_schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + prim_def = torch.library.custom_op( + "prims::" + name, + _prim_impl, + mutates_args=tuple(mutates_args), + schema=schema, + ) + prim_def.register_fake(meta) + + # all view ops get conj/neg fallthroughs + if return_type == RETURN_TYPE.VIEW or register_conj_neg_fallthrough: + prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Conjugate") + prim_def._lib.impl(name, torch.library.fallthrough_kernel, "Negative") + + _prim_packet = getattr(torch._ops.ops.prims, name) + _prim = _prim_packet.default + if tags: + _prim._tags = tags + elif aten_packet := getattr(torch.ops.aten, name, None): + overload_tags = [ + getattr(aten_packet, overload).tags for overload in aten_packet.overloads() + ] + tags_intersection = set(overload_tags[0]) + tags_intersection.intersection_update(*overload_tags[1:]) + + # dont inadvertently add to prim ops + tags_intersection.discard(torch.Tag.core) + # causes errors with python ref executor tests, none of the + # data dependent pytorch ops actually decompose to prims + tags_intersection.discard(torch.Tag.data_dependent_output) + + # iter over first tags for determinism + _prim._tags = tuple(t for t in overload_tags[0] if t in tags_intersection) + + from torch._subclasses.fake_tensor import contains_tensor_types + + if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments) or str( + _prim + ) in [ + # See https://github.com/pytorch/pytorch/issues/103532 + "prims.device_put.default" + ]: + prim_backend_select_impl.impl(name, _backend_select_impl) + + for p in (_prim_packet, _prim): + p.__doc__ = doc + p.return_type = return_type # type: ignore[attr-defined] + + p.schema = schema + p.prim_impl = _prim_impl + p.prim_meta_impl = meta + p.impl_aten = impl_aten + + return _prim + + +class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + + +# TODO: implement dtype validation here, too, or on the corresponding refs +def _prim_elementwise_meta( + *args, + type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, + args_with_fixed_dtypes: Optional[tuple[TensorLikeType, ...]] = None, +) -> FakeTensor: + """ + Meta function for elementwise operations that produce outputs in the same dtype + as their inputs. + + Stride logic is currently incorrect. 
+ """ + + assert len(args) > 0 + + utils.check_same_dtype(*args) + + args_ = list(args) + if args_with_fixed_dtypes is not None: + args_ = list(args_with_fixed_dtypes) + args_ + + utils.check_same_device(*args_, allow_cpu_scalar_tensors=True) + utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True) + + l2p_perm = utils.compute_elementwise_output_logical_to_physical_perm(*args_) + shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True) + + # Acquires the dtype + dtype = None + scalar_type = None + for arg in args: + if isinstance(arg, TensorLike): + if not utils.is_cpu_scalar_tensor(arg): + dtype = arg.dtype + break + else: + dtype = arg.dtype + elif isinstance(arg, Number): + scalar_type = type(arg) + + if dtype is None and scalar_type is not None: + dtype = utils.type_to_dtype(scalar_type) + + # Acquires the device (if it exists) or number + device = None + number = None + for arg in args_: + if isinstance(arg, TensorLike): + if utils.is_cpu_scalar_tensor(arg): + if device is None: + device = arg.device + # keep going, in case there is a cuda tensor later + else: + device = arg.device + break + + elif isinstance(arg, Number): + if number is None: + number = arg + + # NOTE: type promotion behavior here is mostly hidden from tests because + # references will typically handle the type promotion properly even if this doesn't + # (but getting it wrong will cause too many casts to be inserted in traces!) + if device is not None: + assert dtype is not None + if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT: + dtype = dtype + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL: + dtype = torch.bool + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.INT_TO_FLOAT: + if utils.is_integer_dtype(dtype) or utils.is_boolean_dtype(dtype): + dtype = torch.get_default_dtype() + elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: + if utils.is_complex_dtype(dtype): + dtype = utils.corresponding_real_dtype(dtype) + else: + dtype = dtype + + assert shape is not None + return torch.empty_permuted(shape, l2p_perm, device=device, dtype=dtype) # type: ignore[return-value] + + # Number case + # TODO: fix number type promotion (bool, complex->float) + + # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) + seen_float = False + if isinstance(number, (torch.SymInt, torch.SymFloat)): + for a in args: + assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" + seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) + if seen_float: + number = sym_float(number) + + return TensorMeta(number) # type: ignore[arg-type] + + +def _complex_only_elementwise_meta(*args, **kwargs): + torch._check( + utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported" + ) + return _prim_elementwise_meta(*args, **kwargs) + + +def _make_elementwise_unary_prim( + name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs +): + """ + Creates an elementwise unary prim. + """ + + return _make_prim( + schema=f"{name}(Tensor self) -> Tensor", + meta=partial(_prim_elementwise_meta, type_promotion=type_promotion), + return_type=RETURN_TYPE.NEW, + **kwargs, + ) + + +def _make_elementwise_binary_prim( + name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs +): + """ + Creates an elementwise binary prim. 
+ """ + + return _make_prim( + schema=f"{name}(Tensor self, Tensor other) -> Tensor", + meta=partial(_prim_elementwise_meta, type_promotion=type_promotion), + return_type=RETURN_TYPE.NEW, + **kwargs, + ) + + +def _not_impl(*args, **kwargs): + raise NotImplementedError + + +# +# Elementwise unary operations +# + + +abs = _make_elementwise_unary_prim( + "abs", + impl_aten=torch.abs, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) + +acos = _make_elementwise_unary_prim( + "acos", + impl_aten=torch.acos, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +acosh = _make_elementwise_unary_prim( + "acosh", + impl_aten=torch.acosh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +asin = _make_elementwise_unary_prim( + "asin", + impl_aten=torch.asin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +asinh = _make_elementwise_unary_prim( + "asinh", + impl_aten=torch.asinh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atan = _make_elementwise_unary_prim( + "atan", + impl_aten=torch.atan, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atanh = _make_elementwise_unary_prim( + "atanh", + impl_aten=torch.atanh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +cos = _make_elementwise_unary_prim( + "cos", + impl_aten=torch.cos, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +cosh = _make_elementwise_unary_prim( + "cosh", + impl_aten=torch.cosh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_j0 = _make_elementwise_unary_prim( + "bessel_j0", + impl_aten=torch.special.bessel_j0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_j1 = _make_elementwise_unary_prim( + "bessel_j1", + impl_aten=torch.special.bessel_j1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i0 = _make_elementwise_unary_prim( + "bessel_i0", + impl_aten=torch.i0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i0e = _make_elementwise_unary_prim( + "bessel_i0e", + impl_aten=torch.special.i0e, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i1 = _make_elementwise_unary_prim( + "bessel_i1", + impl_aten=torch.special.i1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bessel_i1e = _make_elementwise_unary_prim( + "bessel_i1e", + impl_aten=torch.special.i1e, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_not = _make_elementwise_unary_prim( + "bitwise_not", + impl_aten=torch.bitwise_not, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _cbrt_aten(a: torch.Tensor) -> Tensor: + torch._check( + not a.is_complex(), + lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)", + ) + # Returns the real cubic root of the number. + # Note that if a < 0, pow(a, (1. / 3.)) returns th complex number + # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i} + # which is a complex number. 
+ # For more info see the section Note in + # https://en.cppreference.com/w/cpp/numeric/math/cbrt + return torch.copysign(torch.pow(a.abs(), 1 / 3), a) + + +cbrt = _make_elementwise_unary_prim( + "cbrt", + impl_aten=_cbrt_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ceil = _make_elementwise_unary_prim( + "ceil", + impl_aten=torch.ceil, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType: + if not input.dtype.is_complex: + raise RuntimeError("prims.conj_physical is only defined for complex dtypes") + + strides = utils.compute_elementwise_output_strides(input) + return TensorMeta(input, strides=strides) + + +conj_physical = _make_prim( + schema="conj_physical(Tensor self) -> Tensor", + meta=_conj_physical_meta, + impl_aten=torch._conj_physical, + doc="Returns the physical conjugation of a complex tensor", + return_type=RETURN_TYPE.NEW, +) + + +def _clone_meta( + input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + if memory_format != torch.preserve_format: + return torch.empty( + input.shape, + dtype=input.dtype, + layout=input.layout, + device=input.device, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + strides = utils.compute_elementwise_output_strides(input) + return torch.empty_strided( + input.shape, + strides, + dtype=input.dtype, + layout=input.layout, + device=input.device, + ) + + +clone = _make_prim( + schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", + meta=_clone_meta, + impl_aten=torch.clone, + doc="Returns the copy of a tensor", + return_type=RETURN_TYPE.NEW, + register_conj_neg_fallthrough=True, +) + +digamma = _make_elementwise_unary_prim( + "digamma", + impl_aten=torch.digamma, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erf = _make_elementwise_unary_prim( + "erf", + impl_aten=torch.erf, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erf_inv = _make_elementwise_unary_prim( + "erf_inv", + impl_aten=torch.special.erfinv, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erfc = _make_elementwise_unary_prim( + "erfc", + impl_aten=torch.special.erfc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +erfcx = _make_elementwise_unary_prim( + "erfcx", + impl_aten=torch.special.erfcx, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +exp = _make_elementwise_unary_prim( + "exp", + impl_aten=torch.exp, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +expm1 = _make_elementwise_unary_prim( + "expm1", + impl_aten=torch.special.expm1, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +exp2 = _make_elementwise_unary_prim( + "exp2", + impl_aten=torch.special.exp2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType: + return _prim_elementwise_meta( + a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + +# NOTE: fill uses _make_prim directly because it has a value parameter +fill = _make_prim( + schema="fill(Tensor self, Scalar value) -> Tensor", + return_type=RETURN_TYPE.NEW, + meta=_fill_meta, + impl_aten=torch.fill, + doc="", +) + +floor = _make_elementwise_unary_prim( + "floor", + impl_aten=torch.floor, + doc="", + 
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +imag = _make_prim( + schema="imag(Tensor(a) self) -> Tensor(a)", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.imag, + doc="", +) + +isfinite = _make_elementwise_unary_prim( + "isfinite", + impl_aten=torch.isfinite, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +lgamma = _make_elementwise_unary_prim( + "lgamma", + impl_aten=torch.lgamma, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log = _make_elementwise_unary_prim( + "log", + impl_aten=torch.log, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log1p = _make_elementwise_unary_prim( + "log1p", + impl_aten=torch.log1p, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log2 = _make_elementwise_unary_prim( + "log2", + impl_aten=torch.log2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +log10 = _make_elementwise_unary_prim( + "log10", + impl_aten=torch.log10, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +real = _make_prim( + schema="real(Tensor(a) self) -> Tensor(a)", + meta=partial( + _complex_only_elementwise_meta, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + ), + return_type=RETURN_TYPE.VIEW, + impl_aten=torch.real, + doc="", +) + +reciprocal = _make_elementwise_unary_prim( + "reciprocal", + impl_aten=torch.reciprocal, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ndtri = _make_elementwise_unary_prim( + "ndtri", + impl_aten=torch.special.ndtri, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +neg = _make_elementwise_unary_prim( + "neg", + impl_aten=torch.neg, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +round = _make_elementwise_unary_prim( + "round", + impl_aten=torch.round, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +rsqrt = _make_elementwise_unary_prim( + "rsqrt", + impl_aten=torch.rsqrt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sign = _make_elementwise_unary_prim( + "sign", + impl_aten=torch.sign, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +signbit = _make_elementwise_unary_prim( + "signbit", + impl_aten=torch.signbit, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sin = _make_elementwise_unary_prim( + "sin", + impl_aten=torch.sin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sinh = _make_elementwise_unary_prim( + "sinh", + impl_aten=torch.sinh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +spherical_bessel_j0 = _make_elementwise_unary_prim( + "spherical_bessel_j0", + impl_aten=torch.special.spherical_bessel_j0, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +sqrt = _make_elementwise_unary_prim( + "sqrt", + impl_aten=torch.sqrt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +tan = _make_elementwise_unary_prim( + "tan", + impl_aten=torch.tan, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +tanh = _make_elementwise_unary_prim( + "tanh", + impl_aten=torch.tanh, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) 
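+
+# A minimal usage sketch (the helper name below is an assumed example, not part
+# of the prims API): the unary prims registered above are thin wrappers over
+# their ATen implementations and can be called through torch.ops.prims or
+# through the module-level handles such as `tanh`.
+def _example_unary_prim_usage() -> Tensor:
+    # For a floating-point tensor this dispatches to torch.tanh via impl_aten.
+    x = torch.linspace(-1.0, 1.0, steps=5)
+    return torch.ops.prims.tanh(x)
+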
+ +trunc = _make_elementwise_unary_prim( + "trunc", + impl_aten=torch.trunc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +# +# Elementwise binary operations +# + +add = _make_elementwise_binary_prim( + name="add", + impl_aten=torch.add, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +atan2 = _make_elementwise_binary_prim( + name="atan2", + impl_aten=torch.atan2, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_and = _make_elementwise_binary_prim( + "bitwise_and", + impl_aten=torch.bitwise_and, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_or = _make_elementwise_binary_prim( + "bitwise_or", + impl_aten=torch.bitwise_or, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +bitwise_xor = _make_elementwise_binary_prim( + "bitwise_xor", + impl_aten=torch.bitwise_xor, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +# TODO: complex needs a special meta to account for its float -> complex behavior +# complex = _make_elementwise_binary_prim( +# impl_aten=torch.complex, +# doc="", +# ) + + +# div prim performs truncation division on integer inputs +# and true division for floating and complex inputs +def _div_aten(a, b): + is_integral = isinstance(a, (bool, int, torch.SymInt)) or ( + isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype) + ) + + if is_integral: + return torch.div(a, b, rounding_mode="trunc") + else: + return torch.true_divide(a, b) + + +div = _make_elementwise_binary_prim( + "div", + impl_aten=_div_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +eq = _make_elementwise_binary_prim( + "eq", + impl_aten=torch.eq, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +fmax = _make_elementwise_binary_prim( + "fmax", + impl_aten=torch.fmax, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +fmin = _make_elementwise_binary_prim( + "fmin", + impl_aten=torch.fmin, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +fmod = _make_elementwise_binary_prim( + "fmod", + impl_aten=torch.fmod, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +gcd = _make_elementwise_binary_prim( + "gcd", + impl_aten=torch.gcd, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +ge = _make_elementwise_binary_prim( + "ge", + impl_aten=torch.ge, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +gt = _make_elementwise_binary_prim( + "gt", + impl_aten=torch.gt, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +hypot = _make_elementwise_binary_prim( + "hypot", + impl_aten=torch.hypot, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +igamma = _make_elementwise_binary_prim( + "igamma", + impl_aten=torch.special.gammainc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +igammac = _make_elementwise_binary_prim( + "igammac", + impl_aten=torch.special.gammaincc, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +le = _make_elementwise_binary_prim( + "le", + impl_aten=torch.le, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +lt = _make_elementwise_binary_prim( + "lt", + impl_aten=torch.lt, + doc="", + 
type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + + +# Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs +def _maximum_aten( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +) -> TensorLikeType: + if isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + + return torch.maximum(a, b) # type: ignore[arg-type] + + +maximum = _make_elementwise_binary_prim( + "maximum", + impl_aten=_maximum_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +def _minimum_aten( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +) -> TensorLikeType: + if isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + + return torch.minimum(a, b) # type: ignore[arg-type] + + +minimum = _make_elementwise_binary_prim( + "minimum", + impl_aten=_minimum_aten, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +mul = _make_elementwise_binary_prim( + "mul", + impl_aten=torch.mul, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +ne = _make_elementwise_binary_prim( + "ne", + impl_aten=torch.ne, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) + +nextafter = _make_elementwise_binary_prim( + "nextafter", + impl_aten=torch.nextafter, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +pow = _make_elementwise_binary_prim( + "pow", + impl_aten=torch.pow, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +remainder = _make_elementwise_binary_prim( + "remainder", + impl_aten=torch.remainder, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +shift_left = _make_elementwise_binary_prim( + "shift_left", + impl_aten=torch.bitwise_left_shift, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +shift_right_arithmetic = _make_elementwise_binary_prim( + "shift_right_arithmetic", + impl_aten=torch.bitwise_right_shift, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +shift_right_logical = _not_impl + +sub = _make_elementwise_binary_prim( + "sub", + impl_aten=torch.sub, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + +zeta = _make_elementwise_binary_prim( + "zeta", + impl_aten=torch.special.zeta, + doc="", + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, +) + + +# +# View operations +def _as_strided_meta( + a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int +) -> TensorLikeType: + assert len(size) == len(stride) + assert storage_offset >= 0 + utils.validate_strides(stride) + utils.validate_shape(size) + + if reduce(operator.mul, size) == 0: + # NOTE: This special case is to avoid having to acquire the storage below + # as_strided to shapes with no elements are trivially valid, so it's OK + pass + elif isinstance(a, torch.Tensor): + utils.check_in_bounds_for_storage( + a._typed_storage(), size, stride, storage_offset + ) + + return torch.as_strided(a, size, stride, storage_offset) + + +def _as_strided_aten( + a: Tensor, size: ShapeType, stride: StrideType, 
storage_offset: int +) -> Tensor: + return torch.as_strided(a, size, stride, storage_offset) + + +_as_strided_doc = """ + Creates a view of the tensor with the given shape (size), strides (stride) and + storage offset (storage_offset). +""" + +as_strided = _make_prim( + schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)", + meta=_as_strided_meta, + impl_aten=_as_strided_aten, + return_type=RETURN_TYPE.VIEW, + doc=_as_strided_doc, +) + + +def _broadcast_in_dim_meta( + a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int] +): + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + sym_or, + ) + + # Type checks + assert isinstance(a, TensorLike) + assert isinstance(shape, Sequence) + assert isinstance(broadcast_dimensions, Sequence) + + # every dimension must be accounted for + assert a.ndim == len(broadcast_dimensions) + + # broadcast shape must have weakly more dimensions + assert len(shape) >= a.ndim + + # broadcast_dimensions must be an ascending sequence + # (no relative reordering of dims) of integers and + # each dimension must be within the new shape + def _greater_than_reduce(acc, x): + assert isinstance(x, Dim) + assert x > acc + assert x < len(shape) + + return x + + reduce(_greater_than_reduce, broadcast_dimensions, -1) + + # shape must be broadcastable to + for idx, new_idx in enumerate(broadcast_dimensions): + torch._check( + sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]), + lambda: f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}", + ) + + new_strides = [] + original_idx = 0 + for idx in range(len(shape)): + if idx in broadcast_dimensions: + # Assigns a stride of zero to dimensions + # which were actually broadcast + if guard_or_false(a.shape[original_idx] == 1): + if guard_or_false(a.shape[original_idx] == shape[idx]): + new_strides.append(a.stride()[original_idx]) + else: + new_strides.append(0) + else: + torch._check( + a.shape[original_idx] == shape[idx], + lambda: f"non-broadcasting semantics require {a.shape[original_idx]} == {shape[idx]}", + ) + new_strides.append(a.stride()[original_idx]) + original_idx = original_idx + 1 + else: + if guard_or_true(shape[idx] != 1): + # consistent with previous use of guard_size_oblivious + new_strides.append(0) + elif original_idx == a.ndim: + new_strides.append(1) + else: + new_strides.append(a.stride()[original_idx] * a.size()[original_idx]) + + return a.as_strided(shape, new_strides, a.storage_offset()) + + +def _broadcast_in_dim_aten(a, shape, broadcast_dimensions): + s = list(shape) + for broadcast_dimension in broadcast_dimensions: + s[broadcast_dimension] = -1 + + v = a + for idx, x in enumerate(s): + if x != -1: + v = v.unsqueeze(idx) + + return v.expand(shape) + + +_broadcast_in_dim_doc = """ + Creates a view of a with the specified shape. + + Allows adding dimensions of any length and broadcasting + dimensions of length one in a to any length. + + The location of the broadcast dimensions must be specified + using the broadcast_dimensions argument. Changing the + relative order of dimensions is not supported. 
+ """ + +broadcast_in_dim = _make_prim( + schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)", + meta=_broadcast_in_dim_meta, + impl_aten=_broadcast_in_dim_aten, + return_type=RETURN_TYPE.VIEW, + doc=_broadcast_in_dim_doc, +) + + +def _validate_collapse_args(a: Tensor, start: int, end: int) -> None: + # Special-case for zero dimensional tensors + ndim = max(1, a.dim()) + utils.validate_idx(ndim, start) + utils.validate_idx(ndim, end) + + # Verifies end is strictly greater than start + # (Collapse requires a non-empty interval) + torch._check_value( + end >= start, + lambda: f"Attempting to collapse but end, {end}, is less than start, {start}!", + ) + + +def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]: + """ + Returns the shape of a with dims in [start, end) merged into a single dimension. + """ + # Special-case for zero dimensional tensors + shape = (1,) if len(shape) == 0 else tuple(shape) + + dim_length = 1 + for s in shape[start : end + 1]: + dim_length = dim_length * s + + return shape[0:start] + (dim_length,) + shape[end + 1 :] + + +def _collapse_view_helper( + a: TensorLikeType, start: int, end: int +) -> tuple[Optional[ShapeType], Optional[StrideType]]: + assert isinstance(a, TensorLike) + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + _validate_collapse_args(a, start, end) + + # Special-case for zero dimensional tensors + if a.ndim == 0: + shape = (1,) + strides = (1,) + else: + shape = a.shape # type: ignore[assignment] + strides = a.stride() # type: ignore[assignment] + + if a.ndim == 0 or (end == start): + return shape, strides + + length = shape[end] + stride = strides[end] + for idx in range(end - 1, start - 1, -1): + if guard_size_oblivious(shape[idx] == 0) or guard_size_oblivious( + shape[idx + 1] == 0 + ): + length = 0 + stride = 0 + break + + if guard_size_oblivious(shape[idx] == 1): + continue + + length = length * shape[idx] + if guard_size_oblivious(stride < strides[idx]): + stride = stride + else: + stride = strides[idx] + + if ( + guard_size_oblivious(a.numel() > 0) + and guard_size_oblivious(shape[idx + 1] != 1) + and not guard_size_oblivious( + strides[idx] == strides[idx + 1] * shape[idx + 1] + ) + ): + return None, None + + new_shape = shape[:start] + (length,) + shape[end + 1 :] + new_strides = strides[:start] + (stride,) + strides[end + 1 :] + + # NOTE: when the input has no elements it's restrided as if it were contiguous + if guard_size_oblivious(a.numel() == 0): + new_strides = utils.make_contiguous_strides_for(new_shape) + + return new_shape, new_strides + + +def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType: + new_shape, new_strides = _collapse_view_helper(a, start, end) + + if new_shape is None: + msg = "Attempting to view a collapsed tensor, but no such view exists!" + raise ValueError(msg) + + assert new_strides is not None + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor: + new_shape = _collapsed_shape(a.shape, start, end) + return a.view(new_shape) + + +_collapse_view_doc = """ + Creates a view of a with the dimensions between + start (inclusive) and end (exclusive) merged into a + single dimension. + + If it's not possible to take such a view then an error + is thrown. See collapse instead. + + The dimensions can be merged if and only if + they are all "nested" with each other. 
That is, they all + have the property that + + stride[i] = stride[i+1] * shape[i+1] + + for all i in [start, end - 1). + """ + +collapse_view = _make_prim( + schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)", + meta=_collapse_view_meta, + impl_aten=_collapse_view_aten, + return_type=RETURN_TYPE.VIEW, + doc=_collapse_view_doc, +) + + +def _conj_meta(a: TensorLikeType) -> TensorLikeType: + if not a.dtype.is_complex: + raise RuntimeError("Expected complex dtype in prims.conj") + out = a.as_strided(a.shape, a.stride(), a.storage_offset()) + torch._C._set_conj(out, not a.is_conj()) + return out + + +_conj_doc = """ +Returns a conjugated view of the original tensor +""" + +conj = _make_prim( + schema="conj(Tensor(a) a) -> Tensor(a)", + meta=_conj_meta, + impl_aten=torch.conj, + return_type=RETURN_TYPE.VIEW, + doc=_conj_doc, +) + + +def expand_dims( + a: TensorLikeType, dimensions: DimsSequenceType, ndim=None +) -> TensorLikeType: + """ + Creates a view of a with a.ndim + len(dimensions) dimensions, with new + dimensions of length one at the dimensions specified by dimensions. + """ + if ndim is not None: + # TODO: this is only here to support the unsqueeze ref + dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type] + else: + dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type] + if len(set(dims)) != len(dims): + msg = f"Received duplicate dimensions to expand in {str(dimensions)}" + raise ValueError(msg) + + new_shape = list(a.shape) + for idx in dims: + new_shape.insert(idx, 1) + + broadcast_dimensions = [ + idx for idx in range(len(new_shape)) if idx not in dimensions + ] + return broadcast_in_dim(a, new_shape, broadcast_dimensions) + + +def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType: + assert isinstance(a, TensorLike) + utils.validate_idx(a.ndim, dim) + utils.validate_dim_length(outer_length) + + # Verifies the dim can be split with the specified lhs_length + inner_length = a.shape[dim] // outer_length + + if (a.shape[dim] % outer_length) != 0: + msg = ( + f"Attempting to split dimension of length {a.shape[dim]}, " + f"but outer length of {outer_length} divides it with a remainder!" + ) + raise ValueError(msg) + + new_shape: list[int] = [] + new_strides: list[int] = [] + for idx in range(a.ndim): + if idx == dim: + new_shape.extend((outer_length, inner_length)) + new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx])) + else: + new_shape.append(a.shape[idx]) + new_strides.append(a.stride()[idx]) + + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor: + inner_length = a.shape[dim] // outer_length + new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :] + + return a.view(new_shape) + + +_split_dim_doc = """ + Creates a view of a with the given dimension (of length l) split + into two dimensions, with the outer of the two having + length outer_length and the inner of the two having computed + length inner_length such outer_length * inner_length = l. 
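+
+    For example (illustrative), a view of a tensor of shape (2, 6) split at
+    dim=1 with outer_length=3 has shape (2, 3, 2):
+
+        split_dim(a, 1, 3)  # (2, 6) -> (2, 3, 2), so inner_length == 2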
+ """ + +# TODO: consider renaming split_dim_view +split_dim = _make_prim( + schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)", + meta=_split_dim_meta, + impl_aten=_split_dim_aten, + return_type=RETURN_TYPE.VIEW, + doc=_split_dim_doc, +) + + +# Note: allows dimensions to be specified redundantly +def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType: + assert isinstance(a, TensorLike) + + for idx in dimensions: + utils.validate_idx(a.ndim, idx) + assert a.shape[idx] == 1 + + new_shape = [] + new_strides = [] + for idx in range(len(a.shape)): + if idx in dimensions: + continue + + new_shape.append(a.shape[idx]) + new_strides.append(a.stride()[idx]) + + return a.as_strided(new_shape, new_strides, a.storage_offset()) + + +_squeeze_doc = """ + Creates a view of the tensor with the specified dimensions removed. + + The removed dimensions must each have length one. + """ + +squeeze = _make_prim( + schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)", + meta=_squeeze_meta, + impl_aten=torch.squeeze, + return_type=RETURN_TYPE.VIEW, + doc=_squeeze_doc, +) + + +def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType: + if a.ndim != len(permutation): + msg = f"Attempting to permute a tensor of rank {a.ndim}, but received a permutation of length {len(permutation)}!" + raise ValueError(msg) + + if not utils.is_valid_permutation(a.ndim, permutation): + msg = f"Received an invalid permutation, {permutation}!" + raise ValueError(msg) + + new_shape = [0] * a.ndim + new_strides = [0] * a.ndim + for idx, dim in enumerate(permutation): + new_shape[idx] = a.shape[dim] + new_strides[idx] = a.stride()[dim] + + return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset()) + + +def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor: + return torch.permute(a, permutation) + + +_transpose_doc = """ + Creates a view of the tensor with its dimensions permuted. + + The length of the permutation must be the rank of the tensor, + and each element of the permutation specifies the new order + for the corresponding dimension. + """ + +transpose = _make_prim( + schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)", + meta=_transpose_meta, + impl_aten=_transpose_aten, + return_type=RETURN_TYPE.VIEW, + doc=_transpose_doc, +) + + +def _view_of_meta(a: TensorLikeType) -> TensorLikeType: + return a.as_strided(a.shape, a.stride(), a.storage_offset()) + + +def _view_of_aten(a: Tensor) -> Tensor: + return a.view(a.shape) + + +_view_of_doc = """ + Creates a view of the tensor. + """ + +view_of = _make_prim( + schema="view_of(Tensor(a) a) -> Tensor(a)", + meta=_view_of_meta, + impl_aten=_view_of_aten, + return_type=RETURN_TYPE.VIEW, + doc=_view_of_doc, +) + + +def _view_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + return a.view(dtype) + + +def _view_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: + return a.view(dtype) + + +_view_element_type_doc = """ + Creates a view of the tensor with a different dtype. 
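+
+    For example (illustrative, assuming ``a`` is a float32 tensor), the data
+    can be reinterpreted bitwise as int32 without copying; the shape is
+    unchanged because both dtypes have the same element size:
+
+        view_element_type(a, torch.int32)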
+ """ + +view_element_type = _make_prim( + schema="view_of_dtype(Tensor(a) a, ScalarType dtype) -> Tensor(a)", + meta=_view_element_type_meta, + impl_aten=_view_element_type_aten, + return_type=RETURN_TYPE.VIEW, + doc=_view_element_type_doc, +) + +# +# Functionalized view mutations +# + + +def _as_strided_scatter_meta( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: int, +) -> TensorLikeType: + utils.validate_shape(size) + utils.validate_strides(stride) + + required_size = utils.compute_required_storage_length(size, stride, storage_offset) + torch._check( + input.numel() >= required_size, + lambda: ( + f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} " + f" and itemsize {input.element_size()} requiring a storage size of " + f"{required_size * input.element_size()} are out of bounds " + f"for storage of size {input.numel() * input.element_size()}" + ), + ) + torch._check( + utils.is_same_shape(src.shape, size), + lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}", + ) + + return utils.clone_preserve_strides(input) + + +_as_strided_scatter_doc = """ + Creates a new tensor equivalent to ``out = input.clone()`` after mutation by + ``out.as_strided(size, stride, storage_offset).copy_(src)``. +""" + +as_strided_scatter = _make_prim( + schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor", + meta=_as_strided_scatter_meta, + impl_aten=torch.as_strided_scatter, + return_type=RETURN_TYPE.NEW, + doc=_as_strided_scatter_doc, +) + + +# +# Shape operations +# + + +def _collapse_meta(a: Tensor, start: int, end: int) -> Tensor: + # Special-case for zero dimensional tensors + _validate_collapse_args(a, start, end) + new_shape = _collapsed_shape(a.shape, start, end) + return a.new_empty(new_shape) + + +def _collapse_aten(a: Tensor, start: int, end: int) -> Tensor: + new_shape = _collapsed_shape(a.shape, start, end) + out = a.new_empty(new_shape) + with torch.no_grad(): + out.view_as(a).copy_(a) + return out + + +_collapse_doc = """ +Collapse a span of neighboring dimensions into one. + +See collapse_view for the corresponding view operation. +""" +collapse = _make_prim( + schema="collapse(Tensor a, int start, int end) -> Tensor", + meta=_collapse_meta, + impl_aten=_collapse_aten, + return_type=RETURN_TYPE.NEW, + doc=_collapse_doc, +) + + +# TODO: review stride logic +# NB: unlike torch.cat, this is more strict about empty tensors and dim is +# never negative +def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType: + # Verifies same shape (except in the concat dimension) + assert dim >= 0 + shape = tensors[0].shape + sym_sum_args = [] + for tensor_idx, tensor in enumerate(tensors): + assert len(shape) == len(tensor.shape) + for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): + if idx == dim: + sym_sum_args.append(length) + else: + torch._check( + length == common_length, + lambda: f"Sizes of tensors must match except in dimension {dim}. 
" + f"Expected {common_length} in dimension {idx} but got {length} for tensor number " + f"{tensor_idx} in the list", + ) + + new_shape = list(tensors[0].shape).copy() + new_shape[dim] = torch.sym_sum(sym_sum_args) + return TensorMeta( + tensors[0], + shape=new_shape, + strides=utils.make_contiguous_strides_for(new_shape), + ) + + +def _cat_aten(tensors: Union[tuple[Tensor, ...], list[Tensor]], dim: int) -> Tensor: + return torch.cat(tensors, dim) + + +_cat_doc = """ + Concatenates tensors along the specified dimension. + + The tensors' shapes must have the same rank and same length for other dimensions. + """ + +cat = _make_prim( + schema="cat(Tensor[] tensors, int dim) -> Tensor", + meta=_cat_meta, + impl_aten=_cat_aten, + return_type=RETURN_TYPE.NEW, + doc=_cat_doc, +) + + +def _reshape_meta(a: TensorLikeType, shape: ShapeType): + assert isinstance(a, TensorLike) + utils.validate_shape(shape) + + # Validates the tensor and the requested shape have the + # same number of elements + numel = reduce(operator.mul, shape) + if numel != a.numel(): + msg = f"Attempting to reshape a tensor with {a.numel()} elements to a shape with {numel} elements!" + raise ValueError(msg) + + return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape)) + + +def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor: + return a.reshape(shape).clone(memory_format=torch.contiguous_format) + + +_reshape_doc = """ + Creates a contiguous tensor with the specified shape + containing a copy of the data in a. + """ +reshape = _make_prim( + schema="reshape(Tensor a, SymInt[] shape) -> Tensor", + meta=_reshape_meta, + impl_aten=_reshape_aten, + return_type=RETURN_TYPE.NEW, + doc=_reshape_doc, +) + + +def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + utils.validate_dimension_indices(a.ndim, dims) + return torch.empty_like(a, memory_format=torch.preserve_format) + + +_rev_doc = """ + Reverses the order of elements along the given dimensions. + """ + +rev = _make_prim( + schema="rev(Tensor a, int[] dims) -> Tensor", + meta=_rev_meta, + impl_aten=torch.flip, + return_type=RETURN_TYPE.NEW, + doc=_rev_doc, +) + +# +# Conditional prims +# + + +def _where_meta( + pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType +) -> TensorLikeType: + return _prim_elementwise_meta( + a, + b, + type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT, + args_with_fixed_dtypes=(pred,), + ) + + +_where_doc = """ + Selects elements from a and b according to pred. + + Where pred is true the result contains the element from a, and + where pred is false the result contains the element from b. 
+ """ + +where = _make_prim( + schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor", + meta=_where_meta, + impl_aten=torch.where, + return_type=RETURN_TYPE.NEW, + doc=_where_doc, +) + + +# +# Type conversions +# +def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + # Type checks + assert isinstance(a, TensorLike) + assert isinstance(dtype, torch.dtype) + + # dtype conversion preserves dense strides + if torch._prims_common.is_non_overlapping_and_dense(a): + strides = a.stride() + else: + strides = utils.compute_elementwise_output_strides(a) + + return TensorMeta(a, strides=strides, dtype=dtype) + + +def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor: + # Propagates requires grad when possible + if not utils.is_grad_dtype(dtype): + requires_grad = False + else: + # TODO: update meta objects so this can be acquired directly + try: + requires_grad = a.requires_grad + except Exception: + requires_grad = False + + result = torch.empty_like( + a, device=a.device, dtype=dtype, requires_grad=requires_grad + ) + with torch.no_grad(): + return copy_to(result, a) + + +_convert_element_type_doc = """ + Creates a copy of a tensor with the given dtype. + """ + +convert_element_type = _make_prim( + schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor", + meta=_convert_element_type_meta, + impl_aten=_convert_element_type_aten, + return_type=RETURN_TYPE.NEW, + doc=_convert_element_type_doc, + tags=(torch.Tag.pointwise,), +) + + +def _device_put_meta( + a: TensorLikeType, device: Union[str, torch.device], non_blocking=False +) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(device, (str, torch.device)) + assert isinstance(non_blocking, bool) + + return TensorMeta(a, device=utils.canonicalize_device(device)) + + +def _device_put_aten( + a: Tensor, device: Union[str, torch.device], non_blocking=False +) -> Tensor: + return a.to(device, non_blocking=non_blocking) + + +_device_put_doc = """ + Creates a copy of a tensor on the given device. + """ + +device_put = _make_prim( + schema="device_put(Tensor a, Device device, bool non_blocking=False) -> Tensor", + meta=_device_put_meta, + impl_aten=_device_put_aten, + return_type=RETURN_TYPE.NEW, + doc=_device_put_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _item_meta(a: TensorLikeType) -> FakeTensor: + number_type = utils.dtype_to_type(a.dtype) + return TensorMeta(number_type(-1)) + + +_item_doc = """ + Converts a tensor with one element to a Python number. +""" + + +# We can't call into python dispatcher for item again +# because the current prim decomp calls into python dispatcher +# again. https://github.com/pytorch/pytorch/issues/136050 +def _item_aten_no_python_dispatcher(*args, **kwargs): + with torch._dispatch.python.no_python_dispatcher(): + return torch.Tensor.item(*args, **kwargs) + + +# TODO: create a new return type for scalars? 
+# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +item = _make_prim( + schema="item(Tensor a) -> Scalar", + meta=_item_meta, + impl_aten=_item_aten_no_python_dispatcher, + return_type=RETURN_TYPE.NEW, + doc=_item_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor: + number_type = utils.dtype_to_type(dtype) + return TensorMeta(number_type(-1)) + + +def _maximum_value_aten(dtype: torch.dtype): + if dtype == torch.bool: + return True + elif dtype.is_complex or dtype.is_floating_point: + return torch.finfo(dtype).max + else: + return torch.iinfo(dtype).max + + +_maximum_value_doc = """ + Return the maximum finite value for a dtype. +""" + +# TODO: create a new return type for scalars? +# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +maximum_value = _make_prim( + schema="maximum_value(ScalarType dtype) -> Scalar", + meta=_maximum_value_meta, + impl_aten=_maximum_value_aten, + return_type=RETURN_TYPE.NEW, + doc=_maximum_value_doc, +) + + +# NOTE: need to model meta scalars +# See https://github.com/pytorch/pytorch/issues/78070 +def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor: + number_type = utils.dtype_to_type(dtype) + return TensorMeta(number_type(-1)) + + +def _minimum_value_aten(dtype: torch.dtype): + if dtype == torch.bool: + return False + elif dtype.is_complex or dtype.is_floating_point: + return torch.finfo(dtype).min + else: + return torch.iinfo(dtype).min + + +_minimum_value_doc = """ + Return the minimum finite value for a dtype. +""" + +# TODO: create a new return type for scalars? +# FIXME: currently returns integers for boolean tensors +# https://github.com/pytorch/pytorch/issues/78071 +minimum_value = _make_prim( + schema="minimum_value(ScalarType dtype) -> Scalar", + meta=_minimum_value_meta, + impl_aten=_minimum_value_aten, + return_type=RETURN_TYPE.NEW, + doc=_minimum_value_doc, +) + +# +# Inplace operators +# + + +def _copy_to_meta(a: TensorLikeType, b: TensorLikeType): + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + # Validates the cast is safe + # TODO: move this as an option on the reference + # a_typ = utils.dtype_to_type(a.dtype) + # b_typ = utils.dtype_to_type(b.dtype) + # if a_typ is not utils.get_higher_type(a_typ, b_typ): + # raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!") + + # Validates the tensors have the same number of elements + if a.numel() != b.numel(): + msg = f"Attempting to copy {b.numel()} elements to a tensor with {a.numel()} elements!" + raise RuntimeError(msg) + + return a + + +def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor: + return a.copy_(b) + + +_copy_to_doc = """ + Copies the data in b to a and returns the modified a. + """ + +# TODO: Remove safe casting and implement on reference instead +copy_to = _make_prim( + schema="copy_to(Tensor(a!) 
a, Tensor b) -> Tensor(a!)",
+    meta=_copy_to_meta,
+    impl_aten=_copy_to_aten,
+    return_type=RETURN_TYPE.INPLACE,
+    doc=_copy_to_doc,
+    register_conj_neg_fallthrough=True,
+)
+
+
+def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
+    assert isinstance(a, TensorLike)
+    return torch.empty_strided(
+        a.shape,
+        stride,
+        dtype=a.dtype,
+        layout=a.layout,
+        device=a.device,
+        requires_grad=a.requires_grad,
+    )
+
+
+def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
+    out = torch.empty_strided(
+        a.size(),
+        stride=stride,
+        dtype=a.dtype,
+        layout=a.layout,
+        device=a.device,
+        requires_grad=a.requires_grad,
+    )
+    out.copy_(a)
+    return out
+
+
+_copy_strided_doc = """
+    Copies the data in a to a new tensor. The new tensor has the same shape
+    and size as a, but uses the given stride.
+    """
+
+
+copy_strided = _make_prim(
+    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
+    meta=_copy_strided_meta,
+    impl_aten=_copy_strided_aten,
+    return_type=RETURN_TYPE.NEW,
+    doc=_copy_strided_doc,
+)
+
+
+def _resize_meta(a: TensorLikeType, shape: ShapeType):
+    return a.resize_(shape)
+
+
+def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
+    return a.resize_(shape)
+
+
+_resize_doc = """
+    Gives a tensor with no elements a new shape, returning the modified tensor.
+
+    The tensor's strides are contiguous and its values are uninitialized.
+    """
+
+# TODO: review support for arbitrary resizes
+resize = _make_prim(
+    schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
+    meta=_resize_meta,
+    impl_aten=_resize_aten,
+    return_type=RETURN_TYPE.INPLACE,
+    doc=_resize_doc,
+)
+
+
+def _reduction_meta(inp, dims, *, output_dtype=None):
+    """
+    Meta function for single-output reduction operations.
+    Stride logic is incorrect.
+    """
+    assert isinstance(inp, TensorLike)
+    if output_dtype is None:
+        output_dtype = inp.dtype
+    output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
+    return TensorMeta(
+        shape=output_shape,
+        strides=utils.make_contiguous_strides_for(output_shape),
+        dtype=output_dtype,
+        device=inp.device,
+    )
+
+
+def _var_reduction_meta(inp, dims, correction):
+    if utils.is_complex_dtype(inp.dtype):
+        output_dtype = utils.corresponding_real_dtype(inp.dtype)
+    else:
+        output_dtype = inp.dtype
+    return _reduction_meta(inp, dims, output_dtype=output_dtype)
+
+
+_sum_doc = """
+    Computes the sum of elements in the input tensor over the list of dimensions
+    specified in the dim argument
+    """
+_xor_sum_doc = """
+    Computes the xor sum of elements in the input tensor over the list of dimensions
+    specified in the dim argument
+    """
+_prod_doc = """
+    Computes the product of elements in the input tensor over the list of dimensions
+    specified in the dim argument
+    """
+_amax_doc = """
+    Computes the maximum value of elements in the input tensor over the list of dimensions
+    specified in the dim argument
+    """
+_amin_doc = """
+    Computes the minimum value of elements in the input tensor over the list of dimensions
+    specified in the dim argument
+    """
+_var_doc = """
+    Computes the biased variance of x over the list of dimensions specified in the dim argument
+    """
+
+
+def _make_reduction_prim(name: str, impl_aten, doc):
+    """Creates a reduction prim."""
+    return _make_prim(
+        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType?
output_dtype=None) -> Tensor", + meta=_reduction_meta, + impl_aten=impl_aten, + return_type=RETURN_TYPE.NEW, + doc=doc, + ) + + +def _make_var_reduction_prim(name: str, impl_aten, doc): + """Creates a reduction prim.""" + return _make_prim( + schema=f"{name}(Tensor inp, int[]? dims, float? correction=1, *, ScalarType? output_dtype=None) -> Tensor", + meta=_var_reduction_meta, + impl_aten=impl_aten, + return_type=RETURN_TYPE.NEW, + doc=doc, + ) + + +sum = _make_reduction_prim( + name="sum", + impl_aten=torch.sum, + doc=_sum_doc, +) + + +def _xor_sum_aten( + inp: TensorLikeType, + dims: Optional[DimsSequenceType], + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + raise NotImplementedError("xor_sum only implemented with inductor") + + +xor_sum = _make_reduction_prim( + name="xor_sum", + impl_aten=_xor_sum_aten, + doc=_xor_sum_doc, +) + + +def _prod_aten( + inp: TensorLikeType, + dims: Optional[DimsSequenceType], + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + if dims is not None: + if len(dims) == 0: + return inp.clone() + for d in sorted(dims, reverse=True): + assert d >= 0 + inp = torch.prod(inp, d, dtype=dtype) + return inp + else: + return torch.prod(inp, dims, dtype=dtype) + + +prod = _make_reduction_prim( + name="prod", + impl_aten=_prod_aten, + doc=_prod_doc, +) + + +# torch.var, but correction is not kwarg-only +def torch_var(input, dim=None, correction=1, **kwargs): + return torch.var(input, dim=dim, correction=correction, **kwargs) + + +var = _make_var_reduction_prim( + name="var", + impl_aten=torch_var, + doc=_var_doc, +) + +amax = _make_reduction_prim( + name="amax", + impl_aten=torch.amax, + doc=_amax_doc, +) + +amin = _make_reduction_prim( + name="amin", + impl_aten=torch.amin, + doc=_amin_doc, +) + + +_iota_doc = """ + Constructs a 1-D tensor t where ``t[i] == start + i * step``. 
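+
+    For example (illustrative):
+
+        iota(5, start=2, step=3, dtype=torch.int64,
+             device=torch.device("cpu"), requires_grad=False)
+        # tensor([ 2,  5,  8, 11, 14])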
+""" + + +# TODO: layout, pin_memory, memory_format +# TODO: model requires_grad on TensorMeta +def _iota_meta( + length: int, + *, + start: int, + step: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + torch._check( + utils.is_integer_dtype(dtype), + lambda: "prims.iota only supports integer dtypes", + ) + torch._check(step != 0, lambda: "step must be nonzero") + return torch.empty( + length, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def _iota_aten( + length: int, + *, + start: int, + step: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + end = start + length * step + return torch.arange( + start, end, step, dtype=dtype, device=device, requires_grad=requires_grad + ) + + +iota = _make_prim( + schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + return_type=RETURN_TYPE.NEW, + meta=_iota_meta, + impl_aten=_iota_aten, + doc=_iota_doc, +) + + +# TODO: layout, pin_memory, memory_format +# TODO: model requires_grad on TensorMeta +def _empty_meta( + shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _empty_aten( + shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool +) -> Tensor: + return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad) + + +_empty_doc = """ + Creates a tensor with uninitialized values and the specified shape, dtype, and device. +""" + +empty = _make_prim( + schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_empty_meta, + impl_aten=_empty_aten, + return_type=RETURN_TYPE.NEW, + doc=_empty_doc, +) + + +def _empty_strided_meta( + shape: ShapeType, + strides: StrideType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +_empty_strided_doc = """ + Creates a tensor with uninitialized values. +""" + +# TODO: add layout, pin_memory +empty_strided = _make_prim( + schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + return_type=RETURN_TYPE.NEW, + meta=_empty_strided_meta, + impl_aten=torch.empty_strided, + doc=_empty_strided_doc, +) + + +def _empty_permuted_meta( + shape: ShapeType, + physical_layout: DimsSequenceType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout]) + dim = len(shape) + torch._check( + len(physical_layout) == dim, + lambda: ( + "Number of dimensions in the tensor input does not match the " + f"length of the physical layout; i.e. len(size) = {dim} " + f"is not equal to len(physical_layout) = {len(physical_layout)}" + ), + ) + strides = [0] * len(shape) + seen_dims = set() + for p, l in enumerate(physical_layout): + torch._check( + 0 <= l < dim, + lambda: ( + f"Dimension out of range (expected to be between 0 and {dim - 1}, but got " + f"{l} at index {p}). NB: negative dims " + "not currently supported; file an issue if you want it." 
+ ), + ) + torch._check(l not in seen_dims, lambda: "Duplicate dim not allowed") + strides[l] = p_strides[p] + seen_dims.add(l) + return TensorMeta( + shape=shape, + strides=strides, + dtype=dtype, + device=device, + ) + + +_empty_permuted_doc = """ + Creates a tensor with uninitialized values according to some physical layout, + that is guaranteed to be non-overlapping and dense. +""" + +# TODO: add layout, pin_memory +empty_permuted = _make_prim( + schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + return_type=RETURN_TYPE.NEW, + meta=_empty_permuted_meta, + impl_aten=torch.empty_permuted, + doc=_empty_permuted_doc, +) + + +def _full_meta( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _full_aten( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> Tensor: + # Note that Mypy thinks torch.full can't accept a complex fill_value + return torch.full( + shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + ) + + +_full_doc = """ + Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device. +""" + +# TODO: add layout +full = _make_prim( + schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_full_meta, + impl_aten=_full_aten, + return_type=RETURN_TYPE.NEW, + doc=_full_doc, +) + + +def _full_like_meta( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> TensorLikeType: + strides = utils.compute_elementwise_output_strides(a) + if a.numel() == 0: + strides = a.stride() + + return TensorMeta(a, strides=strides, dtype=dtype, device=device) + + +def _full_like_aten( + a: Tensor, + fill_value: NumberType, + *, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, +) -> Tensor: + # Note that Mypy thinks torch.full can't accept a complex fill_value + return torch.full_like( + a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad # type: ignore[arg-type] + ) + + +_full_like_doc = """ + Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the + given tensor by default. The dtype and device settings can be overridden + by specifying them explicitly. 
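+
+    For example (illustrative, assuming ``a`` is a floating-point tensor; note
+    that this prim takes dtype, device, and requires_grad as keyword arguments):
+
+        full_like(a, 7.0, dtype=a.dtype, device=a.device, requires_grad=False)
+        # a tensor of sevens with the same shape, dtype, and device as a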
+""" + +full_like = _make_prim( + schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", + meta=_full_like_meta, + impl_aten=_full_like_aten, + return_type=RETURN_TYPE.NEW, + doc=_full_like_doc, +) + + +def _scalar_tensor_meta( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> TensorLikeType: + shape: ShapeType = [] + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device) + + +def _scalar_tensor_aten( + scalar: NumberType, + *, + dtype: torch.dtype, + device: torch.device, +) -> Tensor: + if isinstance(scalar, complex) and ( + dtype is None or not utils.is_complex_dtype(dtype) + ): + raise TypeError("Complex scalar requires complex tensor dtype.") + # Note that Mypy thinks torch.scalar can't accept a complex scalar + return torch.scalar_tensor(scalar, dtype=dtype, device=device) # type: ignore[arg-type] + + +_scalar_tensor_doc = """ + Wraps a Number into a Tensor with the specified dtype and device. +""" + +# TODO: add layout and pin_memory support +scalar_tensor = _make_prim( + schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor", + meta=_scalar_tensor_meta, + impl_aten=_scalar_tensor_aten, + return_type=RETURN_TYPE.NEW, + doc=_scalar_tensor_doc, +) + + +# +# Linear algebra (linalg) prims +# + + +def _svd_meta( + A: TensorLikeType, *, full_matrices: bool +) -> tuple[TensorLikeType, TensorLikeType, TensorLikeType]: + utils.check_is_matrix(A, "linalg.svd") + utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False) + + A_shape = A.shape + batch = A_shape[:-2] + m, n = A_shape[-2:] + k = min(m, n) + + shape_U = batch + (m, m if full_matrices else k) + strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False) + U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device) + + shape_S = batch + (k,) + strides_S = utils.make_contiguous_strides_for(shape_S) + S = TensorMeta( + shape=shape_S, + strides=strides_S, + dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype, + device=A.device, + ) + + shape_Vh = batch + (n if full_matrices else k, n) + # The CPU backend returns V, but the cuSolver backend returns V^H + # TODO The MAGMA backend returns V, so this is wrong if used with the MAGMA backend + is_cuda = A.device.type == "cuda" + strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda) + Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device) + # Also makes sure this is CUDA or HIP: + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip + if A.numel() != 0 and Vh.is_complex() and torch.cuda.is_available(): + Vh = Vh.conj() + return U, S, Vh + + +def _svd_aten( + A: TensorLikeType, *, full_matrices: bool +) -> tuple[Tensor, Tensor, Tensor]: + return torch.linalg.svd(A, full_matrices=full_matrices) + + +_svd_doc = """ + Returns the SVD of a matrix or batch of matrices. + + The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned. 
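+
+    For example (illustrative), for a real matrix A of shape (m, n) with
+    k = min(m, n):
+
+        U, S, Vh = svd(A, full_matrices=False)
+        # U: (m, k), S: (k,), Vh: (k, n); A is approximately U @ torch.diag(S) @ Vh
+
+    With full_matrices=True the shapes are U: (m, m) and Vh: (n, n) instead.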
+""" + +svd = _make_prim( + schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)", + meta=_svd_meta, + impl_aten=_svd_aten, + return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW), + doc=_svd_doc, +) + + +# +# Randomness Prims +# + + +def _normal_meta( + shape: ShapeType, + *, + mean: Union[float, complex], + std: float, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, + generator: Optional[torch.Generator] = None, +) -> TensorLikeType: + torch._check( + std >= 0.0, + lambda: f"expected non-negative standard deviation, but got std={std}", + ) + + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}", + ) + + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _normal_aten( + shape: ShapeType, + *, + mean: Union[float, complex], + std: float, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, + generator: Optional[torch.Generator] = None, +) -> Tensor: + a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad) + with torch.no_grad(): + # NOTE: normal_ is incorrectly annotated to expect mean to be a float + a.normal_(mean, std, generator=generator) # type: ignore[arg-type] + return a + + +_normal_doc = """ + Constructs a tensor filled with values drawn from a normal distribution with the specified mean + and standard deviation. + + Only supports floating-point types. +""" + +normal = _make_prim( + schema=( + "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" # noqa: B950 + ), + return_type=RETURN_TYPE.NEW, + meta=_normal_meta, + impl_aten=_normal_aten, + doc=_normal_doc, +) + + +def _uniform_meta( + shape: ShapeType, + *, + low: float, + high: float, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, +) -> TensorLikeType: + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device) + + +def _uniform_aten( + shape: ShapeType, + *, + low: float, + high: float, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, +) -> Tensor: + a = torch.empty(shape, dtype=dtype, device=device) + a.uniform_(low, high, generator=generator) + return a + + +_uniform_doc = """ + Constructs a tensor filled with values drawn uniformly from low to high. +""" + +# TODO: we should more seriously review randomness modeling and prims +_uniform_helper = _make_prim( + schema=( + "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? 
generator=None) -> Tensor" + ), + return_type=RETURN_TYPE.NEW, + meta=_uniform_meta, + impl_aten=_uniform_aten, + doc=_uniform_doc, +) + +# +# FFT prims +# + + +def _fft_r2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + if onesided: + last_dim = dim[-1] + shape[last_dim] = shape[last_dim] // 2 + 1 + + dtype = utils.corresponding_complex_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_r2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + onesided: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_r2c(input, dim, normalization, onesided) + + +_fft_r2c_doc = """ + Performs a real to complex Fast Fourier Transform +""" + + +fft_r2c = _make_prim( + schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor", + meta=_fft_r2c_meta, + impl_aten=_fft_r2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_r2c_doc, +) + + +def _fft_c2c_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = input.shape + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta( + shape=shape, strides=strides, dtype=input.dtype, device=input.device + ) + + +def _fft_c2c_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + forward: bool, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2c(input, dim, normalization, forward) + + +_fft_c2c_doc = """ + Performs either a Fast Fourier Transform, or its inverse +""" + + +fft_c2c = _make_prim( + schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor", + meta=_fft_c2c_meta, + impl_aten=_fft_c2c_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2c_doc, +) + + +def _fft_c2r_meta( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + dim = utils.canonicalize_dims(input.ndim, dim) + utils.validate_no_repeating_dims(dim) + + shape = list(input.shape) + shape[dim[-1]] = last_dim_size + dtype = utils.corresponding_real_dtype(input.dtype) + strides = utils.make_contiguous_strides_for(shape) + return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device) + + +def _fft_c2r_aten( + input: TensorLike, + *, + dim: DimsSequenceType, + last_dim_size: int, +) -> TensorLikeType: + normalization = 0 # No normalization + return torch._fft_c2r(input, dim, normalization, last_dim_size) + + +_fft_c2r_doc = """ + Performs a complex to real Inverse Fast Fourier Transform +""" + + +fft_c2r = _make_prim( + schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor", + meta=_fft_c2r_meta, + impl_aten=_fft_c2r_aten, + return_type=RETURN_TYPE.NEW, + doc=_fft_c2r_doc, +) + + +def _frexp_meta(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]: + torch._check( + self.dtype.is_floating_point, + lambda: "torch.frexp() only supports floating-point dtypes", + ) + return torch.empty_like(self), torch.empty_like(self, dtype=torch.int32) + + +frexp = _make_prim( + schema="frexp(Tensor self) -> (Tensor mantissa, Tensor exponent)", + meta=_frexp_meta, + return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW), + impl_aten=torch.frexp, + doc="", +) + + +def _make_token_aten() -> TensorLikeType: + 
return new_token_tensor() + + +_make_token = _make_prim( + schema="_make_token() -> Tensor", + meta=_make_token_aten, + return_type=RETURN_TYPE.NEW, + impl_aten=_make_token_aten, + doc="Creates a token used for keeping track of side effects.", +) + + +def _sink_tokens_aten(tokens) -> None: + pass + + +_sink_tokens = _make_prim( + schema="_sink_tokens(Tensor[] tokens) -> ()", + meta=_sink_tokens_aten, + return_type=RETURN_TYPE.NONE, + impl_aten=_sink_tokens_aten, + doc="Sink all of the tokens which were previously used for keeping track of side effects.", +) + + +register_rng_prims() +register_debug_prims() diff --git a/phivenv/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..966a62f2581b4790b56ba69d7eb4e9cbd54183dc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91fa96671ce12a48add4d0fd4dad0f5411699d7f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims/__pycache__/context.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..870d1e1c2a05eb15927ab9c641819946f9aa35ec Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims/__pycache__/debug_prims.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..785867c1cbc8744db84096398264adc88803729e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims/__pycache__/executor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..928d30e097dc10f5c1ec667e5e698055d34cf9f5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims/__pycache__/rng_prims.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims/context.py b/phivenv/Lib/site-packages/torch/_prims/context.py new file mode 100644 index 0000000000000000000000000000000000000000..9630ff0ed674cc836eecef9aad18b82873bd249d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims/context.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import functools +from contextlib import nullcontext +from typing import Any, Callable, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec + + +if TYPE_CHECKING: + from collections.abc import Sequence + +import torch +import torch._decomp +import torch._prims +import torch._refs +import torch._refs.nn +import torch._refs.nn.functional +import torch._refs.special +import torch.overrides +from torch._prims_common import torch_function_passthrough + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +@functools.cache +def torch_to_refs_map() -> dict[Any, Any]: + """ + Mapping of torch API functions to torch._refs functions. + E.g. 
torch_to_refs_map()[torch.add] == torch._refs.add + """ + modules = [ + (torch, torch._refs), + (torch.nn, torch._refs.nn), + (torch.nn.functional, torch._refs.nn.functional), + (torch.special, torch._refs.special), + (torch.fft, torch._refs.fft), + (torch.linalg, torch._refs.linalg), + ] + r: dict[Any, Any] = { + torch.Tensor.__invert__: torch._refs.bitwise_not, + torch.Tensor.__xor__: torch._refs.bitwise_xor, + torch.Tensor.__and__: torch._refs.bitwise_and, + torch.Tensor.__or__: torch._refs.bitwise_or, + torch.Tensor.__eq__: torch._refs.eq, + torch.Tensor.__rsub__: torch._refs.rsub, + torch.Tensor.__rtruediv__: torch._refs.rtruediv, + torch.Tensor.__floordiv__: torch._refs.floor_divide, + torch.Tensor.__rfloordiv__: torch._refs.rfloordiv, + torch.Tensor.__pow__: torch._refs.pow, + torch.Tensor.__rpow__: torch._refs.rpow, + torch.Tensor.new_empty: torch._refs.new_empty, + torch.Tensor.new_full: torch._refs.new_full, + torch.Tensor.new_zeros: torch._refs.new_zeros, + torch.Tensor.new_ones: torch._refs.new_ones, + torch.Tensor.fill_: torch._refs.fill_, + torch.Tensor.zero_: torch._refs.zero_, + torch.Tensor.to: torch._refs.to, + torch.Tensor.sum_to_size: torch._refs.sum_to_size, + # TODO: Should these methods be mapped some other way? + torch.Tensor.copy_: torch._prims.copy_to, + torch.Tensor.resize: torch._prims.resize, + } + for mod_torch, mod_refs in modules: + for s in mod_refs.__all__: # type: ignore[attr-defined] + r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s) + + # Support remapping torch.Tensor.foo to _refs.foo + for s in dir(torch.Tensor): + if s in torch._refs.__all__: + r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s) + + # Support conversions + for s in torch._refs._conversions.__all__: + tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s) + r[tensor_attr] = torch._refs._conversions.__dict__.get(s) + + return r + + +@functools.cache +def all_prims() -> set[Any]: + """ + Set of all prim functions, e.g., torch._prims.add in all_prims() + """ + return {torch._prims.__dict__.get(s) for s in torch._prims.__all__} + + +class TorchRefsMode(torch.overrides.TorchFunctionMode): + """ + Switches the interpretation of torch.* functions and Tensor methods to + use PrimTorch refs in torch._refs. (Direct calls to _refs are unaffected.) + + >>> # xdoctest: +SKIP + >>> with TorchRefsMode(): + ... torch.add(x, y) # calls torch._refs.add(x, y) + + By default, this context manager will fall back on the torch.* if the + ref does not exist; set strict=True to error if this occurs. + If the ref exists we still would like to fall back on the torch.* sometimes, + this behavior can be customized by passing a function to should_fallback_fn. 
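+
+    An illustrative sketch of the should_fallback_fn hook (the callback receives
+    the mode, the original function, the matching ref, and the call's args and
+    kwargs, and returns True to fall back to the original function):
+
+    >>> # xdoctest: +SKIP
+    >>> def fallback_for_add(mode, orig_func, func, args, kwargs):
+    ...     return orig_func is torch.add
+    >>> with TorchRefsMode(should_fallback_fn=fallback_for_add):
+    ...     torch.add(x, y)  # falls back to torch.add
+    ...     torch.mul(x, y)  # uses torch._refs.mul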
+ """ + + def __init__( + self, + strict: bool = False, + should_fallback_fn: Callable[..., bool] = lambda *_: False, + prims_mode_cls: type = nullcontext, + ) -> None: + self.strict = strict + self.should_fallback_fn = should_fallback_fn + self.prims_mode_cls = prims_mode_cls + + def __torch_function__( + self, + orig_func: Callable[_P, _R], + types: Sequence[type], + args: Sequence[Any] = (), + kwargs: dict[str, Any] | None = None, + ) -> Any: + if kwargs is None: + kwargs = {} + # For primitive operations, run them as is without interception + # Unless we are in prims_mode, in which case we want to use nvprims + if orig_func in torch_function_passthrough or orig_func in all_prims(): + with self.prims_mode_cls(): + return orig_func(*args, **kwargs) + mapping = torch_to_refs_map() + func = mapping.get(orig_func, None) + + # For torch.ops.aten.*, use registered decompositions from torch._decomp + # torch._decomp.decomposition_table provides a mapping from + # torch.ops.aten.* to torch._refs or torch._decomp.decompositions + # implementations. + # There're other ways to implement this functionality, + # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417 + if func is None and isinstance(orig_func, torch._ops.OpOverload): + func = torch._decomp.decomposition_table.get(orig_func, None) + elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket): + default = getattr(orig_func, "default", None) + if default is None and orig_func._dir: + default = getattr(orig_func, orig_func._dir[0], None) + if default is not None: + func = torch._decomp.decomposition_table.get(default, None) + + if func is not None: + # If the ref exists query whether we should use it or not + if self.should_fallback_fn(self, orig_func, func, args, kwargs): + return orig_func(*args, **kwargs) + # torch calls inside func should be interpreted as refs calls + with self: + return func(*args, **kwargs) + if self.strict: + raise RuntimeError( + f"no _refs support for {torch.overrides.resolve_name(orig_func)}" + ) + return orig_func(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/_prims/debug_prims.py b/phivenv/Lib/site-packages/torch/_prims/debug_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..53232c231fae06c0609343fde891ea0066f485f6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims/debug_prims.py @@ -0,0 +1,54 @@ +# mypy: allow-untyped-defs +import contextlib +from typing import Optional + +import torch +from torch.utils._content_store import ContentStoreReader + + +LOAD_TENSOR_READER: Optional[ContentStoreReader] = None + + +@contextlib.contextmanager +def load_tensor_reader(loc): + global LOAD_TENSOR_READER + assert LOAD_TENSOR_READER is None + # load_tensor is an "op", and we will play merry hell on + # Inductor's memory planning if we return a tensor that + # aliases another tensor that we previously returned from + # an operator. So unlike standard ContentStoreReader use, + # we disable the cache so that you always get fresh storages + # (no aliasing for you!) 
+ LOAD_TENSOR_READER = ContentStoreReader(loc, cache=False) + try: + yield + finally: + LOAD_TENSOR_READER = None + + +def register_debug_prims(): + torch.library.define( + "debugprims::load_tensor", + "(str name, int[] size, int[] stride, *, ScalarType dtype, Device device) -> Tensor", + ) + + @torch.library.impl("debugprims::load_tensor", "BackendSelect") + def load_tensor_factory(name, size, stride, dtype, device): + if LOAD_TENSOR_READER is None: + from torch._dynamo.testing import rand_strided + + return rand_strided(size, stride, dtype, device) + else: + from torch._dynamo.utils import clone_input + + # device argument here takes care of coercion + r = LOAD_TENSOR_READER.read_tensor(name, device=device) + assert list(r.size()) == size, f"{r.size()} != {size}" + assert list(r.stride()) == stride, f"{r.stride()} != {stride}" + assert r.device == device, f"{r.device} != {device}" + + # Unlike the other properties, we will do coercions for dtype + # mismatch + if r.dtype != dtype: + r = clone_input(r, dtype=dtype) + return r diff --git a/phivenv/Lib/site-packages/torch/_prims/executor.py b/phivenv/Lib/site-packages/torch/_prims/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..23f2d0e164e383d8fff79ca10b11fd82275cc09a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims/executor.py @@ -0,0 +1,67 @@ +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +from torch._prims.context import TorchRefsMode +from torch.fx import GraphModule +from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx + + +T = TypeVar("T") +P = ParamSpec("P") +Ts = TypeVarTuple("Ts") + + +def execute( + gm: GraphModule, + *args: Unpack[Ts], + executor: str = "aten", + executor_parameters: Optional[dict] = None, +) -> Any: + """ + Prototype ATen executor. + + Just executes the context's graph. + """ + + if executor == "aten": + return gm.forward(*args) + + msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten." + raise ValueError(msg) + + +def make_traced(fn: Callable[P, T]) -> Callable[P, T]: + """ + Returns a function that, when called, will + trace its torch operations to prims and then + execute those prims on the requested trace executor + (possibly lowering them to that trace executor first). + + Only supports the torch operations defined in _torch_to_reference_map + in context.py and operations with positional args. All args must + be tensors. + In the near future all these restrictions will be lifted. 
+ + Example usage: + + def foo(a, b): + return torch.add(a, b) + + traced_foo = make_traced(foo) + + a = torch.randn((1, 2, 3, 4, 5), device='cuda') + b = torch.randn((1, 2, 3, 4, 5), device='cuda') + result = traced_foo(a, b, executor='aten') + """ + + def _traced(*args: P.args, **kwargs: P.kwargs) -> T: + executor = str(kwargs.pop("executor", "aten")) + + # TODO: caching + wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs) + + with TorchRefsMode(): + gm = make_fx(wrapped)(all_args) + return execute(gm, all_args, executor=executor) + + return _traced # type: ignore[return-value] diff --git a/phivenv/Lib/site-packages/torch/_prims/rng_prims.py b/phivenv/Lib/site-packages/torch/_prims/rng_prims.py new file mode 100644 index 0000000000000000000000000000000000000000..5cd51363b4946e36ac8d2e59efb1880866de4156 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims/rng_prims.py @@ -0,0 +1,389 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.utils._pytree as pytree +from torch import _prims +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.types import _device, _dtype + + +def throw_on_non_cuda(device): + raise RuntimeError( + f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not " + f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is " + "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU." + ) + + +def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None): + rngprim_def = torch.library.custom_op( + "rngprims::" + name, impl_aten, mutates_args=(), schema=schema + ) + rngprim_def.register_fake(impl_meta) + + prim_packet = getattr(torch._ops.ops.rngprims, name) + prim = prim_packet.default + if tags: + prim._tags = tags + + for p in (prim_packet, prim): + p.__doc__ = doc + p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined] + + p.schema = name + schema + p.impl_aten = impl_aten + p.prim_meta_impl = impl_meta + + +# Philox rand offsets could be shared in future with other philox ops, so +# keeping these functions in global scope. +def philox_rand_offset_meta( + shape: torch.Size, +): + return _prims.TensorLike(torch.tensor(0, dtype=torch.int64)) + + +def philox_rand_offset( + shape: torch.Size, +): + # For impl, look at the function calc_execution_policy in the file + # aten/src/ATen/native/cuda/DistributionTemplates.h. 
The impl was copied at + # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6 + numel_scalar = 1 + for dim_size in shape: + numel_scalar *= dim_size + numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64) + + block_size = 256 + unroll = 4 + curand4_engine_calls = 4 + device_property = torch.cuda.get_device_properties(torch.cuda.current_device()) + blocks_per_sm = device_property.max_threads_per_multi_processor // block_size + grid_size = (numel + block_size - 1) // block_size + grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm) + offset = ( + (numel - 1) // (block_size * grid_size * unroll) + 1 + ) * curand4_engine_calls + return offset + + +def register_philox_rand(): + name = "philox_rand" + schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950 + + def _philox_rand_meta( + shape: torch.Size, + seed: torch.Tensor, + offset: torch.Tensor, + stride: Optional[tuple[int, ...]], + device: _device, + dtype: _dtype, + ): + # stride arg will be useful for distributed usecase. Currently, its unused. + assert stride is None + stride = make_contiguous_strides_for(shape) + random_values = _prims.TensorMeta( + shape=shape, strides=stride, dtype=dtype, device=device + ) + offset = philox_rand_offset_meta(shape) + return (random_values, offset) + + def _philox_rand( + shape: torch.Size, + seed: torch.Tensor, + offset: torch.Tensor, + stride: Optional[tuple[int, ...]], + device: _device, + dtype: _dtype, + ): + # stride arg will be useful for distributed usecase. Currently, its unused. + assert stride is None + if device.type == "cpu": + devices = [] + else: + devices = [device] + + if device.type != "cuda": + raise throw_on_non_cuda(device) + + with torch.random.fork_rng(devices): + CUDARngStateHelper.set_torch_state_tensor(seed, offset) + random_values = torch.rand(shape, device=device, dtype=dtype) + + return random_values, philox_rand_offset(shape) + + register_rng_prim( + name=name, + schema=schema, + impl_aten=_philox_rand, + impl_meta=_philox_rand_meta, + doc="Philox based stateless rand operator", + tags=(torch.Tag.nondeterministic_seeded,), + ) + + +def get_device(args, kwargs): + if kwargs.get("device"): + device = kwargs.get("device") + if isinstance(device, str): + device = torch.device(device) + return device.type + + devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)} + if any(dev == "cuda" for dev in devices): + return "cuda" + elif any(dev == "xpu" for dev in devices): + return "xpu" + elif any(dev == "hpu" for dev in devices): + return "hpu" + elif any(dev == "cpu" for dev in devices): + return "cpu" + return None + + +def register_run_and_save_rng_state_op(): + class RunAndSaveRngState(HigherOrderOperator): + def __init__(self): + super().__init__("run_and_save_rng_state") + + def __call__(self, op, *args, **kwargs): + return super().__call__(op, *args, **kwargs) + + run_and_save_rng_state = RunAndSaveRngState() + + run_and_save_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(run_and_save_rng_state, deferred_error=True) + ) + + @run_and_save_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, **kwargs): + return torch.cuda.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(op, *args, **kwargs): + return torch.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(op, *args, **kwargs): + if 
hasattr(torch, "hpu"): + return torch.hpu.get_rng_state(), op(*args, **kwargs) + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_and_save_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(op, *args, **kwargs): + return torch.xpu.get_rng_state(), op(*args, **kwargs) + + @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, **kwargs): + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, op, *args, **kwargs): + # Check device to call the right impl + with mode: + return impl_backend_select(op, *args, **kwargs) + + @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, op, *args, **kwargs): + out = impl_backend_select(op, *args, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args)) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + return run_and_save_rng_state + + +def register_run_with_rng_state_op(): + class RunWithRngState(HigherOrderOperator): + def __init__(self): + super().__init__("run_with_rng_state") + + def __call__(self, rng_state, op, *args, **kwargs): + return super().__call__(rng_state, op, *args, **kwargs) + + run_with_rng_state = RunWithRngState() + + run_with_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(run_with_rng_state, deferred_error=True) + ) + + @run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(rng_state, op, *args, **kwargs): + current_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state.cpu()) + out = op(*args, **kwargs) + torch.cuda.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(DispatchKey.CPU) + def impl_cpu(rng_state, op, *args, **kwargs): + current_state = torch.get_rng_state() + torch.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(DispatchKey.HPU) + def impl_hpu(rng_state, op, *args, **kwargs): + if hasattr(torch, "hpu"): + current_state = torch.hpu.get_rng_state() + torch.hpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.hpu.set_rng_state(current_state) + return out + raise RuntimeError("functionalize a hpu RNG operator is not supported.") + + @run_with_rng_state.py_impl(DispatchKey.XPU) + def impl_xpu(rng_state, op, *args, **kwargs): + current_state = torch.xpu.get_rng_state() + torch.xpu.set_rng_state(rng_state) + out = op(*args, **kwargs) + torch.xpu.set_rng_state(current_state) + return out + + @run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs): + # TODO: you don't need to do this, the dispatch here already disabled + # it + with disable_proxy_modes_tracing(): + out = run_with_rng_state(rng_state, op, *args, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args)) + proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs) + out_proxy = mode.tracer.create_proxy( + "call_function", run_with_rng_state, proxy_args, proxy_kwargs + ) + 
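A minimal eager-mode sketch of the contract these two higher-order ops encode, using only torch.get_rng_state / torch.set_rng_state (illustrative only, not part of the patch):

import torch

def noisy_op():
    return torch.rand(3)

# run_and_save_rng_state: capture the state, then run the op.
saved_state = torch.get_rng_state()
first = noisy_op()

# run_with_rng_state: swap the saved state in, rerun, then restore,
# mirroring impl_cpu above.
current = torch.get_rng_state()
torch.set_rng_state(saved_state)
replayed = noisy_op()
torch.set_rng_state(current)

assert torch.equal(first, replayed)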
return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + @run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(rng_state, op, *args, **kwargs): + impl_map = { + "cuda": impl_cuda, + "cpu": impl_cpu, + "hpu": impl_hpu, + "xpu": impl_xpu, + } + device = get_device(args, kwargs) + assert device in impl_map, f"Backend not supported for {device}" + impl = impl_map[device] + return impl(rng_state, op, *args, **kwargs) + + @run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs): + # Skip setting the set_rng_state as it does not work well with fake tensors. + # And it does not matter for the fake tensor mode. + with mode: + return op(*args, **kwargs) + + @run_with_rng_state.py_functionalize_impl + def impl_functional(ctx, rng_state, op, *args, **kwargs): + unwrapped_rng_state = ctx.unwrap_tensors(rng_state) + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = run_with_rng_state( + unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs + ) + return ctx.wrap_tensors(out) + + return run_with_rng_state + + +run_and_save_rng_state = register_run_and_save_rng_state_op() +run_with_rng_state = register_run_with_rng_state_op() + + +def register_graphsafe_run_with_rng_state_op(): + class GraphSafeRunWithRngState(HigherOrderOperator): + def __init__(self): + super().__init__("graphsafe_run_with_rng_state") + + def __call__(self, op, *args, rng_state=None, **kwargs): + return super().__call__(op, *args, rng_state=rng_state, **kwargs) + + graphsafe_run_with_rng_state = GraphSafeRunWithRngState() + + graphsafe_run_with_rng_state.py_impl(DispatchKey.Autograd)( + autograd_not_implemented(graphsafe_run_with_rng_state, deferred_error=True) + ) + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA) + def impl_cuda(op, *args, rng_state=None, **kwargs): + device_idx = rng_state.device.index + generator = torch.cuda.default_generators[device_idx] + current_state = generator.graphsafe_get_state() + generator.graphsafe_set_state(rng_state) + out = op(*args, **kwargs) + generator.graphsafe_set_state(current_state) + return out + + @graphsafe_run_with_rng_state.py_impl(DispatchKey.BackendSelect) + def impl_backend_select(op, *args, rng_state=None, **kwargs): + device = get_device(args, kwargs) + assert ( + device == "cuda" + ), f"GraphSafe RNG operations only supported for CUDA, got {device}" + return impl_cuda(op, *args, rng_state=rng_state, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(FakeTensorMode) + def impl_fake_tensor_mode(mode, op, *args, rng_state=None, **kwargs): + with mode: + return op(*args, **kwargs) + + @graphsafe_run_with_rng_state.py_impl(ProxyTorchDispatchMode) + def impl_proxy_dispatch_mode(mode, op, *args, rng_state=None, **kwargs): + with disable_proxy_modes_tracing(): + out = graphsafe_run_with_rng_state(op, *args, rng_state=rng_state, **kwargs) + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args)) + proxy_kwargs = pytree.tree_map( + mode.tracer.unwrap_proxy, {"rng_state": rng_state, **kwargs} + ) + out_proxy = mode.tracer.create_proxy( + "call_function", graphsafe_run_with_rng_state, proxy_args, proxy_kwargs + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + @graphsafe_run_with_rng_state.py_functionalize_impl + def impl_functional(ctx, op, *args, rng_state=None, **kwargs): + unwrapped_rng_state = ( + ctx.unwrap_tensors(rng_state) if rng_state is not 
None else None + ) + unwrapped_args = ctx.unwrap_tensors(args) + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + + with ctx.redispatch_to_next(): + out = graphsafe_run_with_rng_state( + op, *unwrapped_args, rng_state=unwrapped_rng_state, **unwrapped_kwargs + ) + return ctx.wrap_tensors(out) + + return graphsafe_run_with_rng_state + + +graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op() + + +def register_rng_prims(): + register_philox_rand() diff --git a/phivenv/Lib/site-packages/torch/_prims_common/__init__.py b/phivenv/Lib/site-packages/torch/_prims_common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7a3c8ffc5a027074af4c1e8610b04cab681d27f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims_common/__init__.py @@ -0,0 +1,2107 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import operator +import typing +import warnings +from collections.abc import Sequence +from contextlib import nullcontext +from enum import Enum +from functools import reduce +from typing import ( + Any, + Callable, + cast, + NamedTuple, + Optional, + overload, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import deprecated, TypeAlias + +import torch +from torch import sym_float, sym_int, sym_max + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + + import sympy + + class _WorksWithInt(typing.Protocol): + def __add__(self, other: Any) -> typing.Self: + ... + + def __radd__(self, other: Any) -> typing.Self: + ... + + def __mul__(self, other: Any) -> typing.Self: + ... + + def __rmul__(self, other: Any) -> typing.Self: + ... 
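For the graphsafe variant above, the state being swapped belongs to a per-device CUDA generator. A rough sketch with the ordinary get_state/set_state calls (the op itself uses the graphsafe_* variants), guarded on CUDA availability:

import torch

if torch.cuda.is_available():
    gen = torch.cuda.default_generators[0]   # per-device default generator
    saved = gen.get_state()
    a = torch.rand(3, device="cuda")
    gen.set_state(saved)                     # rewind the generator
    b = torch.rand(3, device="cuda")
    assert torch.equal(a, b)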
+ + _IntLikeT = TypeVar("_IntLikeT", bound=_WorksWithInt) + + +ShapeType: TypeAlias = Union[torch.Size, list[int], tuple[int, ...]] +StrideType: TypeAlias = Union[list[int], tuple[int, ...]] +DimsType: TypeAlias = Union[int, list[int], tuple[int, ...]] +DimsSequenceType: TypeAlias = Union[list[int], tuple[int, ...]] +# TODO: Type[torch.SymInt], Type[torch.SymFloat] +NumberTypeType: TypeAlias = Union[type[bool], type[int], type[float], type[complex]] +# TODO: This needs a lot more type annotations +# NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat] +NumberType: TypeAlias = Union[bool, int, float, complex] +RealNumberType: TypeAlias = Union[bool, int, float] + +Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat, torch.SymBool) +# I don't call it Integral because numbers.Integral includes bool, but IntLike +# does not +Dim = int +IntLike = (int, torch.SymInt) +FloatLike = (float, torch.SymFloat) +BoolLike = (bool, torch.SymBool) +IntWithoutSymInt = int +FloatWithoutSymFloat = float +DeviceLikeType: TypeAlias = Union[str, torch.device, int] +Tensor = torch.Tensor + + +torch_function_passthrough = { + torch.device, + torch.sym_not, + torch.sym_float, + torch.sym_int, + torch.sym_max, + torch.sym_min, + torch._sym_sqrt, # type: ignore[attr-defined] + torch.sym_ite, + torch.Tensor.dim, + torch.Tensor.ndim.__get__, # type: ignore[attr-defined] + torch.Tensor.numel, + torch.Tensor.size, + torch.Tensor.storage_offset, + torch.Tensor.stride, + torch.Tensor.dtype.__get__, # type: ignore[attr-defined] + torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined] + torch.Tensor.shape.__get__, # type: ignore[attr-defined] + torch.Tensor.device.__get__, # type: ignore[attr-defined] + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.layout.__get__, # type: ignore[attr-defined] + torch.Tensor.is_contiguous, + # For TorchRefsMode only + torch.Tensor.__format__, + torch.Tensor.__repr__, + torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined] + torch.Tensor.__getitem__, +} + + +TensorLikeType = torch.Tensor +TensorLike = torch.Tensor +TensorSequenceType: TypeAlias = Union[list[TensorLikeType], tuple[TensorLikeType, ...]] +TensorOrNumberLikeType: TypeAlias = Union[TensorLikeType, NumberType] + +CustomOutParamAnnotation = "__custom_out_param__" + + +def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if len(a) != len(b): + return False + + for x, y in zip(a, b): + if allow_rhs_unbacked: + # TODO: We should check that the symbols are consistent + # with each other + if isinstance(y, torch.SymInt): + continue + # NB: Naively, you would not expect to have to do an oblivious guard + # here because there is seemingly no broadcasting here, but in fact we + # use this in some situations to determine if we need to do an expand + # on the tensor because they don't line up, so you can definitely end + # up trying to prove u0 != 1 in this situation. 
See + # python test/test_proxy_tensor.py -k test_cumsum_unbacked + if guard_size_oblivious(x != y): + return False + + return True + + +def _maybe_get_pytype(t): + if t is torch.SymFloat: + return float + elif t is torch.SymInt: + return int + elif t is torch.SymBool: + return bool + else: + return t + + +# TODO: look at using torch.testing.assert_close instead with an option +# to just compare metadata +def compare_tensor_meta( + a: TensorLikeType, + b: TensorLikeType, + check_sizes=True, + check_strides=False, + *, + allow_rhs_unbacked=False, + check_conj=True, +): + """ + Checks that two tensor likes have the same shape, + dtype and device. + + In the future this will validate additional metadata, like + strides. + """ + from torch._subclasses.fake_tensor import MetadataMismatchError + + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + + if check_sizes and not same_shape( + a.shape, b.shape, allow_rhs_unbacked=allow_rhs_unbacked + ): + msg = f"Shapes {a.shape} and {b.shape} are not equal!" + raise MetadataMismatchError(msg) + + if a.dtype != b.dtype: + msg = f"Dtypes {a.dtype} and {b.dtype} are not equal!" + raise MetadataMismatchError(msg) + + if a.device != b.device: + # Handles special cuda:0 vs cuda case + # TODO: we should review why this happens and see about fixing it + if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and ( + str(b.device) == "cuda:0" or str(b.device) == "cuda" + ): + pass + else: + msg = f"Devices {a.device} and {b.device} are not equal!" + raise MetadataMismatchError(msg) + + # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050 + if check_strides: + same_strides, idx = check_significant_strides( + a, b, allow_rhs_unbacked=allow_rhs_unbacked + ) + if not same_strides: + msg = f"Stride mismatch! Strides are {a.stride()} and {b.stride()} (mismatched at {idx})!" + raise MetadataMismatchError(msg) + + if a.storage_offset() != b.storage_offset(): + msg = f"Storage offset mismatch! Storage offsets are {a.storage_offset()} and {b.storage_offset()}!" + raise MetadataMismatchError(msg) + + if check_conj: + if a.is_conj() != b.is_conj(): + raise MetadataMismatchError( + f"Conj mismatch! is_conj is set to {a.is_conj()} and {b.is_conj()}" + ) + + if a.is_neg() != b.is_neg(): + raise MetadataMismatchError( + f"Neg mismatch! 
is_neg is set to {a.is_neg()} and {b.is_neg()}" + ) + + +def _check_strides_helper( + a: TensorLikeType, + b: TensorLikeType, + *, + only_cuda=True, + significant_only=True, + allow_rhs_unbacked=False, +) -> tuple[bool, Optional[int]]: + # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch + # See https://github.com/pytorch/pytorch/issues/77553 + # Only compares strides that are "meaningful" -- strides for dimensions with length > 1 + # and for tensors with more than one element + if ( + not only_cuda or a.device.type == "cuda" or b.device.type == "cuda" + ) and a.numel() > 0: + for idx in range(a.ndim): + check = not significant_only or a.shape[idx] > 1 + # TODO: Check the symbols are consistent with each other + if isinstance(b.stride()[idx], torch.SymInt): + continue + if a.stride()[idx] != b.stride()[idx] and check: + return False, idx + + return True, None + + +def check_significant_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, allow_rhs_unbacked=False +) -> tuple[bool, Optional[int]]: + return _check_strides_helper( + a, + b, + only_cuda=only_cuda, + significant_only=True, + allow_rhs_unbacked=allow_rhs_unbacked, + ) + + +def check_all_strides( + a: TensorLikeType, b: TensorLikeType, *, only_cuda=True +) -> tuple[bool, Optional[int]]: + return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False) + + +# This function is equivalent to compute_contiguous() from TensorImpl.cpp +def is_contiguous(a: TensorLikeType, false_if_dde=False) -> bool: + """ + Tests whether a tensor is contiguous or not. + + Tensors are contiguous when they have no elements, + one element, or when they have "nested" strides. + """ + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + is_nested_int, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious + + if maybe_guard_or_false(a.numel() < 2): + return True + + expected_stride = 1 + for x, y in reversed(tuple(zip(a.shape, a.stride()))): + # Skips checking strides when a dimension has length 1. + if maybe_guard_or_false(x == 1): + continue + + if maybe_guard_or_true(y != expected_stride): + return False + + # if x is 0 then a is contiguous anyway. So in the check above for non-contiguity condition we can + # can assume x is not 0 in expected_stride equation. This make the check consistent with + # make_contiguous_strides_for. If we make a tensor and used strides from make_contiguous_strides_for + # and then called definitely_contiguous we should get True. 
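The stride walk being built up here can be restated in plain Python (a sketch that ignores symbolic shapes and the guard helpers): a tensor is row-major contiguous when, scanning dims from last to first and skipping length-1 dims, each stride equals the product of the sizes already visited.

import torch

def dense_row_major(shape, strides):
    expected = 1
    for size, stride in reversed(list(zip(shape, strides))):
        if size == 1:
            continue
        if stride != expected:
            return False
        expected *= size
    return True

t = torch.empty(2, 3, 4)
assert dense_row_major(t.shape, t.stride()) and t.is_contiguous()
u = t.transpose(0, 2)                      # shape (4, 3, 2), strides (1, 4, 12)
assert not dense_row_major(u.shape, u.stride()) and not u.is_contiguous()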
+ expected_stride *= ( + x if is_nested_int(x) else sym_max(x, 1) + ) # type:ignore[assignment] + + return True + + +# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp +def is_channels_last_contiguous_2d(a: Tensor, false_if_dde=False) -> bool: + # NHWC or not channels last 2D contiguous + if a.ndim != 4: + return False + + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious + + expected_stride = 1 + for idx in (1, 3, 2, 0): + length = a.shape[idx] + if maybe_guard_or_false(length == 1): + continue + + stride = a.stride()[idx] + if maybe_guard_or_true(stride != expected_stride): + return False + + expected_stride *= length + + return True + + +def is_channels_last_contiguous_3d(a: Tensor, false_if_dde=False) -> bool: + # NDHWC or not channels last 3D contiguous + if a.ndim != 5: + return False + + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_or_true, + guard_size_oblivious, + ) + + maybe_guard_or_false = guard_or_false if false_if_dde else guard_size_oblivious + maybe_guard_or_true = guard_or_true if false_if_dde else guard_size_oblivious + + expected_stride = 1 + for idx in (1, 4, 3, 2, 0): + length = a.shape[idx] + if maybe_guard_or_false(length == 1): + continue + + stride = a.stride()[idx] + if maybe_guard_or_true(stride != expected_stride): + return False + + expected_stride *= length + + return True + + +_memory_formats = { + torch.contiguous_format, + torch.preserve_format, + torch.channels_last, + torch.channels_last_3d, +} + + +def validate_memory_format(memory_format: torch.memory_format): + torch._check( + memory_format in _memory_formats, + lambda: f"Received unknown memory format {memory_format}!", + ) + + +def is_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False +) -> bool: + validate_memory_format(memory_format) + + if memory_format == torch.contiguous_format: + return is_contiguous(a, false_if_dde) + if memory_format == torch.channels_last: + return is_channels_last_contiguous_2d(a, false_if_dde) + if memory_format == torch.channels_last_3d: + return is_channels_last_contiguous_3d(a, false_if_dde) + + torch._check( + False, + lambda: f"is_contiguous received unsupported memory format {memory_format}", + ) + + +def definitely_contiguous(a: TensorLikeType) -> bool: + return is_contiguous(a, false_if_dde=True) + + +# similar to is_channels_last_contiguous_2d but return false on data dependency. +def definitely_channels_last_contiguous_2d(a: Tensor) -> bool: + return is_channels_last_contiguous_2d(a, false_if_dde=True) + + +# similar to is_channels_last_contiguous_3d but return false on data dependency. +def definitely_channels_last_contiguous_3d(a: Tensor) -> bool: + return is_channels_last_contiguous_3d(a, false_if_dde=True) + + +# similar to is_contiguous_for_memory_format but return false on data dependency. +def definitely_contiguous_for_memory_format( # type: ignore[return] + a: Tensor, *, memory_format: torch.memory_format +) -> bool: + return is_contiguous_for_memory_format( + a, memory_format=memory_format, false_if_dde=True + ) + + +# NOTE: that tensors with no elements and channels last is ??? +def is_channels_last_contiguous(a: Tensor) -> bool: + """ + True when a tensor is channels-last contiguous. 
+ + This requires that: + + - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions + - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the + stride of the 'C' dimension (Cs) is 1 and the strides corresponding to + each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are + "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension, + for example. + """ + return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a) + + +# similar to is_channels_last_contiguous but return false on data dependency. +def definitely_channels_last_contiguous(a: Tensor) -> bool: + return definitely_channels_last_contiguous_2d( + a + ) or definitely_channels_last_contiguous_3d(a) + + +def is_non_overlapping_and_dense(a: Tensor) -> bool: + """ + True when a tensor is non-overlapping and dense. + + A tensor is non-overlapping and dense when there exists a permutation of + its dimensions that is contiguous. + """ + + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if a.is_sparse: + return False + + # Short-circuits if the tensor is already contiguous or channels-last contiguous + if definitely_contiguous(a) or definitely_channels_last_contiguous(a): + return True + + # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + if a.ndim == 1: + return a.stride()[0] == 1 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + # + # This sort is done in a size-oblivious way, which helps if we do a + # comparison like 2048*u0 > u0; we just want this to return True + # (and not worry about what if u0 is zero). + class K(NamedTuple): + size: int + stride: int + + def __lt__(self, other): + return guard_size_oblivious(self.stride < other.stride) + + def __gt__(self, other): + return guard_size_oblivious(self.stride > other.stride) + + def __le__(self, other): + return guard_size_oblivious(self.stride <= other.stride) + + def __ge__(self, other): + return guard_size_oblivious(self.stride >= other.stride) + + def __eq__(self, other): + return guard_size_oblivious(self.stride == other.stride) + + lengths_and_strides = sorted(map(K, a.shape, a.stride())) + + expected_stride = 1 + for length, stride in lengths_and_strides: + if guard_size_oblivious(length == 1): + continue + + if guard_size_oblivious(stride != expected_stride): + return False + + expected_stride *= length + + return True + + +# NOTE: Based on the implementation in TensorIterator.cpp, but note that +# the note [Computing output strides] is incorrect, because it +# says that strides will be preserved even if they are not +# "non overlapping and dense", but this is incorrect. The +# output of elementwise operations are always given +# non overlapping and dense strides. +# This is also INCORRECT because it does not model TensorIterator's +# short-circuit, which can cause different strides. +def compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=False +) -> list[int]: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if not _skip_checks and len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" 
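A quick eager check of the (1, 3, 2, 0) stride walk defined above, assuming a rank-4 NHWC tensor (illustrative, not part of the patch):

import torch

x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
assert x.is_contiguous(memory_format=torch.channels_last)
assert not x.is_contiguous()               # not row-major contiguous
assert x.stride() == (60, 1, 15, 3)        # C stride 1, then W, H, N nested on top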
+ raise ValueError(msg) + + if not _skip_checks: + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + if not _skip_checks: + tensors = tuple( + a + for a in tensors + if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return [] + + # Short-circuits for shapes with zero or one dimensions + # TODO: are these necessary? + ndim = tensors[0].ndim + if ndim == 0: + return [] + if ndim == 1: + return [0] + + # Short-circuits if contiguous or channels last, following the fake fast path. + # This reduces the number of guards we end up making + is_contiguous = True + is_channels_last = True + for t in tensors: + is_contiguous = is_contiguous and definitely_contiguous_for_memory_format( + t, memory_format=torch.contiguous_format + ) + is_channels_last = is_channels_last and definitely_contiguous_for_memory_format( + t, memory_format=torch.channels_last + ) + + if is_contiguous and not is_channels_last: + return list(range(ndim)) + + if is_channels_last and not is_contiguous: + return [0, *list(range(2, ndim)), 1] + + shape = tensors[0].shape + + def should_swap(idx_a, idx_b): + for tensor in tensors: + stride_a = tensor.stride()[idx_a] + stride_b = tensor.stride()[idx_b] + + if guard_size_oblivious(stride_a == 0) or guard_size_oblivious( + stride_b == 0 + ): + continue + + if guard_size_oblivious(stride_a < stride_b): + return -1 + + if guard_size_oblivious(stride_a > stride_b): + return 1 + + # stride_a == stride_b + if guard_size_oblivious(shape[idx_a] > shape[idx_b]): + return 1 + + # Note: this case is hit if all strides are zero, + # or all strides are equal and all dimensions have the same length + return 0 + + # The "sort" order for the permutation is back-to-front, but + # the natural order for permutations is front-to-back. Do the + # sorting back-to-front and then reverse it on output. + # + # also, note this returns the logical to physical shape permutation + perm = list(reversed(range(ndim))) + + # insertion sort with support for ambiguous comparisons + for i in range(1, ndim): + dim1 = i + for dim0 in reversed(range(i)): + comparison = should_swap(perm[dim0], perm[dim1]) + if comparison > 0: + perm[dim0], perm[dim1] = perm[dim1], perm[dim0] + dim1 = dim0 + elif comparison < 0: + break + + return list(reversed(perm)) + + +def compute_elementwise_output_strides(*tensors) -> tuple[int, ...]: + """ + Computes the output strides for elementwise operations. + """ + if len(tensors) == 0: + msg = "Can't compute elementwise output strides for zero tensors!" 
+ raise ValueError(msg) + + check_same_shape(*tensors, allow_cpu_scalar_tensors=True) + + # Filters the tensors to actual tensors + tensors = tuple( + a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a) + ) + + # Short-circuits for CPU scalar case + if len(tensors) == 0: + return () + + ndim = tensors[0].ndim + shape = tensors[0].shape + + if ndim == 0: + return () + if ndim == 1: + return (1,) + + logical_to_physical_perm = compute_elementwise_output_logical_to_physical_perm( + *tensors, _skip_checks=True + ) + permuted_shape = apply_perm(shape, logical_to_physical_perm) # to physical + + new_strides = make_contiguous_strides_for(permuted_shape) + permuted_strides = apply_perm( + new_strides, invert_perm(logical_to_physical_perm) + ) # to logical + + return tuple(permuted_strides) + + +# Identity permutation is [0, 1, 2] +def apply_perm(inp, perm): + ndim = len(inp) + permuted_inp = [-1] * ndim + for idx, x in enumerate(perm): + permuted_inp[idx] = inp[x] + return permuted_inp + + +def invert_perm(perm): + ndim = len(perm) + new_perm = [-1] * ndim + for idx, x in enumerate(perm): + new_perm[x] = idx + return new_perm + + +# +# Common helper functions +# + + +def validate_dim_length(length: int): + """ + Validates that an object represents a valid + dimension length. + """ + + if isinstance(length, (int, torch.SymInt)): + torch._check_is_size(length) + else: + # sometimes called with sympy expression by inductor + assert length >= 0 + + +def validate_shape(shape: ShapeType): + """ + Validates that a sequence represents a valid shape. + """ + + assert isinstance(shape, Sequence), type(shape) + for l in shape: + validate_dim_length(l) + + +def validate_strides(strides: StrideType): + """ + Verifies the object specifies valid strides. + """ + + assert isinstance(strides, Sequence) + for stride in strides: + assert stride >= 0 + + +def validate_idx(rank: int, idx: int): + """ + Validates that idx is a valid index for the given shape. + Assumes the index is already canonicalized. + """ + + assert isinstance(idx, Dim) + assert isinstance(rank, Dim) + + assert idx >= 0 and idx < rank or idx == 0 + + +def validate_dimension_indices(rank: int, indices: DimsSequenceType): + for idx in indices: + validate_idx(rank, idx) + + +def validate_exclusive_idx(rank: int, ex_idx: int): + """ + Validates that ex_idx is a valid exclusive index + for the given shape. + """ + + assert isinstance(ex_idx, Dim) + assert isinstance(rank, Dim) + assert ex_idx > 0 and ex_idx <= rank + + +# "Wraps" a dim (up to one time) for the given rank, allowing dims to be +# specified using negative indices. If `wrap_scalar` is true then scalar +# tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise, +# idx should be in the range [-rank, rank-1]. 
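What the permutation computed above models can be observed directly in eager mode: elementwise ops produce dense outputs whose layout follows their inputs, so a channels-last input yields a channels-last output (a small sanity check, not part of the patch):

import torch

x = torch.randn(2, 3, 4, 5).to(memory_format=torch.channels_last)
y = torch.relu(x)                          # unary elementwise op
assert y.is_contiguous(memory_format=torch.channels_last)
assert y.stride() == x.stride()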
+def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int: + if rank < 0: + msg = f"Rank cannot be negative but got {rank}" + raise IndexError(msg) + + if rank == 0: + if not wrap_scalar: + msg = f"Dimension specified as {idx} but tensor has no dimensions" + raise IndexError(msg) + rank = 1 + + if idx >= 0 and idx < rank: + return idx + + if idx < 0: + _idx = idx + rank + else: + _idx = idx + + if _idx < 0 or _idx >= rank: + # Same error message as in aten/src/ATen/WrapDimUtils.h:49 + msg = f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})" + raise IndexError(msg) + + return _idx + + +# Takes a dimension or sequence of dimensions and "wraps" them, +# mapping negative offsets to positive ones +@overload +def canonicalize_dims( + rank: int, indices: Sequence[int], wrap_scalar: bool = True +) -> tuple[int, ...]: + pass + + +@overload +def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int: + pass + + +def canonicalize_dims(rank, indices, wrap_scalar=True): + if isinstance(indices, Dim): + return canonicalize_dim(rank, indices, wrap_scalar) + + return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices) + + +def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool: + """ + Validates that perm is a permutation of length rank. + """ + + return isinstance(perm, Sequence) and sorted(perm) == list(range(rank)) + + +def is_same_shape(a: Sequence, b: Sequence) -> bool: + """ + Compares two shapes a and b, returning True if they are the same + (their ranks and corresponding lengths match) and False otherwise. + """ + + return tuple(a) == tuple(b) + + +def is_cpu_scalar_tensor(a: Any) -> bool: + return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu" + + +def check_same_device(*args, allow_cpu_scalar_tensors): + """ + Checks that all Tensors in args have the same device. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True + """ + # Short-circuits if all (one or fewer) arguments are trivially on the same device + if len(args) <= 1: + return + + # Note: cannot initialize device to the first arg's device (it may not have one) + device = None + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if device is None: + device = arg.device + + if device != arg.device: + msg = ( + "Tensor on device " + + str(arg.device) + + " is not on the expected device " + + str(device) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same device, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +def canonicalize_device(device: DeviceLikeType) -> torch.device: + if isinstance(device, torch.device): + return device + + assert isinstance(device, str) + return torch.device(device) + + +# Asserts if any of the following are true: +# - a non-scalar or non-Tensor is given +# - the shape of any tensors is distinct +def check_same_shape(*args, allow_cpu_scalar_tensors: bool): + """ + Checks that all Tensors in args have the same shape. 
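Assuming the module added here is importable as torch._prims_common, the wrapping rule in canonicalize_dim above behaves as follows for a rank-4 tensor (sketch):

from torch._prims_common import canonicalize_dim

assert canonicalize_dim(4, -1) == 3
assert canonicalize_dim(4, -4) == 0
assert canonicalize_dim(4, 2) == 2
try:
    canonicalize_dim(4, 4)                 # outside [-4, 3]
except IndexError:
    pass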
+ + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensor objects in args have different devices + """ + shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + msg = f"Shape {arg.shape} is not the expected shape {shape}!" + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same shape, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Acquires a common shape, if it exists, from one or more tensor arguments, +# filtering number arguments +def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]: + shape = None + scalar_shape = None + + for arg in args: + if isinstance(arg, Number): + continue + elif isinstance(arg, TensorLike): + if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg): + scalar_shape = arg.shape + continue + + if shape is None: + shape = arg.shape + + if not is_same_shape(shape, arg.shape): + return None + else: + return None + + return shape if shape is not None else scalar_shape + + +# Extracts dimensions that might be passed either as a list/tuple or as varargs. +# A typical case is Tensor.permute . +def extract_dims_from_varargs( + dims: Union[DimsSequenceType, tuple[DimsSequenceType, ...]] +) -> DimsSequenceType: + if dims and isinstance(dims[0], Sequence): + assert len(dims) == 1 + dims = cast(tuple[DimsSequenceType], dims) + return dims[0] + else: + return cast(DimsSequenceType, dims) + + +def extract_shape_from_varargs( + shape: Union[ShapeType, tuple[ShapeType]], + validate=True, +) -> tuple[int, ...]: + """ + Returns a shape from varargs. + + In PyTorch, operations that accept shapes often accept them as varargs, like + foo(*shape). However a user can pass the shape as a sequence of integers, + like this: + + foo(1, 2, 3) + + or as a sequence of integers + + foo((1, 2, 3)) + + In the first case shape will be a tuple of integers, and in the second case it's a tuple + containing a tuple of integers. This validates those inputs and canonicalizes them + to a tuple of integers. + """ + + # Handles tuple unwrapping + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = shape[0] + + if validate: + validate_shape(shape) # type: ignore[arg-type] + return shape # type: ignore[return-value] + + +def infer_size_shapes(a: ShapeType, b: ShapeType) -> tuple[int, ...]: + ndim = max(len(a), len(b)) + expandedSizes = [0] * ndim + + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = len(a) - 1 - offset + dimB = len(b) - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + torch._check( + (sizeA == sizeB) or (sizeA == 1) or (sizeB == 1), + lambda: ( + f"The size of tensor a ({sizeA}) must match the size of " + f"tensor b ({sizeB}) at non-jagged dimension {i}" + ), + ) + + # 1s map to the other size (even 0) + expandedSizes[i] = sizeB if sizeA == 1 else sizeA + + return tuple(expandedSizes) + + +def infer_size(shape: ShapeType, numel: int) -> tuple[int, ...]: + """ + Infers the size of a dim with size -1, if it exists. + Also checks that new shape is compatible with the number of elements. 
+ """ + from torch.fx.experimental.symbolic_shapes import guard_or_false + + dim = None + newsize = 1 + for i, d in enumerate(shape): + if guard_or_false(d == -1): + torch._check(dim is None, lambda: "only one dimension can be inferred") + dim = i + else: + torch._check( + d >= 0, + lambda: ( + f"invalid shape dimension {d}. If this was symbolic, it was assumed to not be -1." + "If this was meant to be inferred, please explicitly pass in -1." + ), + ) + newsize *= d + if dim is None: + torch._check( + numel == newsize, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + else: + torch._check( + newsize != 0, + lambda: ( + f"cannot reshape tensor of 0 elements into shape {list(shape)} because the " + f"unspecified dimension size -1 can be any value and is ambiguous" + if guard_or_false(numel == 0) + else f"shape '{list(shape)}' is invalid for input of size {numel}" + ), + ) + torch._check( + numel % newsize == 0, + lambda: f"shape '{list(shape)}' is invalid for input of size {numel}", + ) + # Convert to list to produce a compatible error message with core + # PyTorch, which prints sequences in square brackets. + shape = list(shape) + shape[dim] = numel // newsize + # NB: This is pretty important when you have unbacked SymInts. + # Suppose you have (i0, 12) resizing into (2, -1, 12). The old + # range for i0 is typically [2, inf], which means if you divide + # by two the new range should be [1, inf]. But this is bad news + # if you have an unbacked SymInt: we need to reapply the unsound + # assumption that the size is >= 2. + torch._check_is_size(shape[dim]) + return tuple(shape) + + +_integer_dtypes = ( + torch.uint8, + torch.uint16, + torch.uint32, + torch.uint64, + torch.int8, + torch.int16, + torch.int32, + torch.int64, +) +_low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32) +_complex_dtypes = (torch.complex32, torch.complex64, torch.complex128) + + +def is_boolean_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype is torch.bool + + +def is_integer_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _integer_dtypes + + +def is_low_precision_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _low_precision_dtypes + + +def is_float_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype.is_floating_point + + +def is_complex_dtype(dtype: torch.dtype) -> bool: + assert isinstance(dtype, torch.dtype) + return dtype in _complex_dtypes + + +def is_grad_dtype(dtype: torch.dtype) -> bool: + """ + Checks if the dtype can require a gradient. + """ + return dtype.is_floating_point or is_complex_dtype(dtype) + + +_complex_to_real_dtype_map = { + torch.complex128: torch.float64, + torch.complex64: torch.float32, + torch.complex32: torch.float16, +} + +_real_to_complex_dtype_map = { + torch.float16: torch.complex32, + torch.bfloat16: torch.complex64, + torch.float32: torch.complex64, + torch.float64: torch.complex128, +} + + +def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype: + return _complex_to_real_dtype_map[dtype] + + +def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype: + return _real_to_complex_dtype_map[dtype] + + +def dtype_to_type(dtype: torch.dtype) -> type: + """ + Computes the corresponding Python type (AKA "type kind") for the + given dtype. 
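A hedged sketch of the two shape rules above as they surface in eager mode: infer_size fills the single -1 dimension with numel // prod(known dims), and infer_size_shapes implements the usual right-aligned broadcast rule (torch.broadcast_shapes gives the same answer):

import torch

t = torch.arange(24)
assert t.reshape(2, -1, 3).shape == (2, 4, 3)     # 24 // (2 * 3) == 4
try:
    t.reshape(5, -1)                              # 24 is not divisible by 5
except RuntimeError:
    pass

assert torch.broadcast_shapes((2, 1, 4), (3, 1)) == (2, 3, 4)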
+ """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return bool + if dtype in _integer_dtypes: + return int + if dtype.is_floating_point: + return float + if dtype in _complex_dtypes: + return complex + + raise ValueError("Invalid dtype!") + + +def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]: + """ + Computes the corresponding Python type constructor for the + given dtype. + """ + assert isinstance(dtype, torch.dtype) + + if dtype is torch.bool: + return lambda x: bool(x) + if dtype in _integer_dtypes: + return sym_int + if dtype.is_floating_point: + return sym_float + if dtype in _complex_dtypes: + # TODO: type error here is real, replace with sym_complex + return lambda x: complex(x) # type: ignore[arg-type] + + raise ValueError("Invalid dtype!") + + +def type_to_dtype(typ: type) -> torch.dtype: + """ + Computes the corresponding dtype for a Number type. + """ + + assert isinstance(typ, type) + + if typ in (bool, torch.SymBool): + return torch.bool + if typ in (int, torch.SymInt): + return torch.long + if typ in (float, torch.SymFloat): + return torch.get_default_dtype() + # TODO: sym_complex_float? + if typ is complex: + return corresponding_complex_dtype(torch.get_default_dtype()) + + raise ValueError(f"Invalid type {typ}!") + + +def get_dtype(x: Union[torch.Tensor, NumberType]): + if isinstance(x, torch.Tensor): + return x.dtype + else: + return type_to_dtype(type(x)) + + +_ordered_types = (bool, int, float, complex) + + +def check_fp_or_complex( + dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True +): + """ + Checks whether the input is floating point or complex. + If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32 + """ + torch._check( + is_float_dtype(dtype) or is_complex_dtype(dtype), + lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}", + ) + torch._check( + allow_low_precision_dtypes or not is_low_precision_dtype(dtype), + lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}", + ) + + +def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"): + torch._check( + len(A.shape) >= 2, + lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.", + ) + + +def get_higher_type(a: type, b: type) -> type: + """ + Returns the higher of the two given Number types. + + The types are ordered bool -> int -> float -> complex. + """ + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + # Type checking + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + if a is b: + return a + + for typ in _ordered_types: + if a is typ: + return b + if b is typ: + return a + + raise ValueError("Unknown Python scalar type!") + + +# Returns the higher of two torch datatypes a and b or, if the two +# are not ordered relative to each other, the next +# higher datatype +def get_higher_dtype( + a: Optional[Union[torch.dtype, TensorLikeType, NumberType]], + b: Optional[Union[torch.dtype, TensorLikeType, NumberType]], +) -> Optional[torch.dtype]: + """ + Computes the "lowest" datatype that is weakly + "higher" than both a and b. 
+ """ + + # Type checking + assert a is None or isinstance(a, (torch.dtype, TensorLike, Number)) + assert b is None or isinstance(b, (torch.dtype, TensorLike, Number)) + + def _extract_dtype( + x: Optional[Union[torch.dtype, TensorLikeType, NumberType]] + ) -> Optional[torch.dtype]: + if x is None: + return None + if isinstance(x, torch.dtype): + return x + if isinstance(x, TensorLike): + return x.dtype + if isinstance(x, Number): + return type_to_dtype(type(x)) + + raise RuntimeError("Unexpected type given to _extract_dtype!") + + a, b = _extract_dtype(a), _extract_dtype(b) + + if a is b: + return a + + if a is None: + return b + + if b is None: + return a + + ordered_datatypes = ( + (torch.bool,), + (torch.uint8, torch.int8), + (torch.int16,), + (torch.int32,), + (torch.int64,), + (torch.float16, torch.bfloat16), + (torch.float32,), + (torch.float64,), + (torch.complex32,), + (torch.complex64,), + (torch.complex128,), + ) + + for idx, dtypes in enumerate(ordered_datatypes): + if a in dtypes and b in dtypes: + return ordered_datatypes[idx + 1][0] + if a in dtypes: + return b + if b in dtypes: + return a + + raise RuntimeError("Unexpected termination!") + + +def check_pin_memory(pin_memory: bool): + torch._check_not_implemented( + not pin_memory, lambda: "PrimTorch does not support pinned memory" + ) + + +def check_layout(layout: torch.layout): + torch._check_not_implemented( + layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}" + ) + + +# TODO: maybe unify with can_cast_to? +def is_weakly_lesser_type(a: type, b: type) -> bool: + """ + Compares two types, a and b, returning True if a is weakly "less" than b. + + The comparison is determined by the following type ordering: bool, int, float, complex. + """ + + a, b = _maybe_get_pytype(a), _maybe_get_pytype(b) + + if a not in _ordered_types or b not in _ordered_types: + raise RuntimeError(f"Expected builtin numeric types, found {a}, {b}") + + for typ in _ordered_types: + if a == typ: + return True + if b == typ: + return False + + raise RuntimeError("Unexpected termination!") + + +def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool: + for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype): + if fn(cast_to): + return True + if fn(cast_from): + return False + + raise ValueError(f"Received unknown dtypes {cast_to}, {cast_from}!") + + +def check_same_dtype(*args): + """ + Checks that all Tensors in args have the same device and that all Numbers have the + same corresponding Python type. + + Raises a RuntimeError when: + - args contains an object whose type is not Tensor or Number + - two Tensors objects in args have different dtypes + - two Number objects in args have different types + - there are Tensors and Numbers in args, and one of those Tensors corresponding + Python types is different from the type of one of those Numbers + """ + full_dtype = None + scalar_type = None + + for arg in args: + if isinstance(arg, Number): + # Scalar type checking is disabled (and may be removed in the future) + continue + # if scalar_type is None: + # scalar_type = type(arg) + + # if scalar_type is not type(arg): + # msg = ( + # "Scalar of type " + # + str(type(arg)) + # + " is not the expected type of " + # + str(scalar_type) + # + "!" 
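A few spot checks of the bucketed ordering in get_higher_dtype above (assuming the module is importable as torch._prims_common): dtypes that share a bucket promote to the first dtype of the next bucket, otherwise the higher of the two wins.

import torch
from torch._prims_common import get_higher_dtype

assert get_higher_dtype(torch.uint8, torch.int8) == torch.int16
assert get_higher_dtype(torch.float16, torch.bfloat16) == torch.float32
assert get_higher_dtype(torch.int32, torch.float16) == torch.float16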
+ # ) + # raise RuntimeError(msg) + elif isinstance(arg, TensorLike): + if full_dtype is None: + full_dtype = arg.dtype + if scalar_type is None: + scalar_type = dtype_to_type(arg.dtype) + + if full_dtype is not arg.dtype: + msg = ( + "Tensor with dtype " + + str(arg.dtype) + + " is not the expected dtype of " + + str(full_dtype) + + "!" + ) + raise RuntimeError(msg) + + arg_type = dtype_to_type(arg.dtype) + if arg_type is not scalar_type: + msg = ( + "Tensor with corresponding Python type " + + str(arg_type) + + " is not the expected type of " + + str(scalar_type) + + "!" + ) + raise RuntimeError(msg) + else: + msg = ( + "Unexpected type when checking for same dtype, " + str(type(arg)) + "!" + ) + raise RuntimeError(msg) + + +# Maps datatypes to their computation types for elementwise operations +_computation_dtype_map = { + torch.bfloat16: torch.float32, + torch.float16: torch.float32, + torch.complex32: torch.complex64, +} + + +def get_computation_dtype(dtype: torch.dtype) -> torch.dtype: + return _computation_dtype_map.get(dtype, dtype) + + +_cpu_acc_type_map = { + torch.bfloat16: torch.float64, + torch.float16: torch.float64, + torch.float32: torch.float64, + torch.complex32: torch.complex128, + torch.complex64: torch.complex128, +} + + +def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype: + # Equivalent to at::toAccumulateType, prefer computation_dtype where possible + if device.type == "cpu": + return _cpu_acc_type_map.get(dtype, dtype) + else: + return get_computation_dtype(dtype) + + +class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum): + DEFAULT = (0,) + NO_OPMATH = (1,) + INT_TO_FLOAT = (2,) + ALWAYS_BOOL = (3,) + COMPLEX_TO_FLOAT = (4,) + BOOL_TO_LONG = (5,) + + +class REDUCTION_OUTPUT_TYPE_KIND(Enum): + SAME = (0,) + COMPLEX_TO_FLOAT = (1,) # for complex types outputs corresponding real type + KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean + ALWAYS_BOOL = (3,) + + +# Describes the return type of the primitive: +# +# - NEW, a new tensor is created +# - VIEW, a view of an input tensor is returned +# - INPLACE, one or more input tensors is modified +# +# these descriptors are mututally exclusive and exhaustive. +class RETURN_TYPE(Enum): + NEW = (0,) + VIEW = (1,) + INPLACE = (2,) + NONE = (3,) + + +# TODO: when NumberType contains the sym types, can simplify this +def number_type( + x: Union[NumberType, torch.SymInt, torch.SymFloat, torch.SymBool] +) -> type: + if isinstance(x, torch.SymInt): + return int + elif isinstance(x, torch.SymFloat): + return float + elif isinstance(x, torch.SymBool): + return bool + else: + return type(x) + + +def expr_type(x: sympy.Basic) -> type: + import sympy + + if x.kind is sympy.core.kind.BooleanKind: + return bool + elif x.is_integer: # type: ignore[attr-defined] + return int + else: + # NB: Not strictly correct, but we don't support SymPy complex or bool. + return float + + +# TODO: document type promotion kinds +def elementwise_dtypes( + *_args, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, +) -> tuple[torch.dtype, torch.dtype]: + """ + Computes the computation and result dtypes for elementwise type promotion + on the given arguments and with the given elementwise type promotion kind. + + Note that not all inputs to an elementwise operation necessarily participate in type promotion. 
+ For example, the "alpha" parameter of torch.add does not participate in type promotion, + although it may be cast to the Python type corresponding to the computation dtype that + the type promotion algorithm determines. + + Default elementwise type promotion, which all other type promotion kinds tweak (see below), + first decides which of four ordered types to use: + + bool -> integer -> floating point -> complex + + The selected type is the "lowest" type in the above list such that all number arguments + have a weakly "lower" type and all tensor arguments have a weakly lower corresponding + type for their dtype. + + Once the type is determined, the particular result dtype is found. The dtypes are + partially ordered as follows: + + bool -> uint8, int8 -> int16 -> int32 -> int64 -> + float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128 + + The result dtype is selected by: + - if no tensor's dtype has the same corresponding type as the one selected, + then the result dtype is the (default) dtype corresponding to the selected type + (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype) + - if the result type is complex then the dtype is: + - the default complex dtype if there are no floating point or complex tensors + - if there are floating point or complex tensors with one or more dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + (for example, double + cfloat -> cdouble) + - if there are only floating point or complex tensors with zero dimensions, then + the complex dtype corresponding to the highest corresponding complex dtype among those tensors + - if the first two cases do not apply, the result dtype is the highest dtype among + all tensors with one or more dimensions of the output type, and if there are no such + tensors then it's the highest dtype among all tensors with zero dimensions of the output type + (for example, long + half -> half, even if the half tensor has zero dimensions) + + The "corresponding complex dtypes" are: + float16 -> complex32 + bfloat16 -> complex64 + float32 -> complex64 + float64 -> complex128 + complex32 -> complex32 + complex64 -> complex64 + complex128 -> complex128 + + The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation + dtype by mapping low precision floating point and complex dtypes as follows: + + float16 -> float32 + bfloat16 -> float32 + complex32 -> complex64 + + This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the + computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels + which perform no mathematical operations on their tensors (see below for examples). + + The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype, + and computation dtypes to the appropriate op math dtype. + + The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this + mapping: + + complex32 -> float16 + complex64 -> float32 + complex128 -> float64 + + Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does. + + The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long. + + The ALWAYS_BOOL type promotion kind always sets the result dtype to bool. 
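The worked examples in the rules above can be checked against eager type promotion (illustrative asides using torch.result_type, not part of the patch):

import torch

# long + half -> half
assert torch.result_type(torch.ones(2, dtype=torch.long),
                         torch.ones(2, dtype=torch.half)) == torch.half
# double + cfloat -> cdouble
assert torch.result_type(torch.ones(2, dtype=torch.double),
                         torch.ones(2, dtype=torch.cfloat)) == torch.cdouble
# 1.5 + an integer tensor -> default floating point dtype
assert (torch.tensor([1, 2]) + 1.5).dtype == torch.get_default_dtype()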
+ + Example operators for each type promotion option: + DEFAULT : add + NO_OPMATH : where, nextafter, cat + INT_TO_FLOAT : sin + COMPLEX_TO_FLOAT : abs + BOOL_TO_LONG : pow + ALWAYS_BOOL : eq + + """ + + args = tuple(x for x in _args if x is not None) + + highest_type: type = bool + + # Import sympy locally, as importing it eagerly at a module level is too slow + # See https://dev-discuss.pytorch.org/t/delving-into-what-happens-when-you-import-torch/1589 + import sympy + + for x in args: + if not isinstance(x, (Number, TensorLike, sympy.Basic)): + msg = f"Unexpected type {str(type(x))} when computing elementwise type promotion!" + raise ValueError(msg) + + if isinstance(x, Number): + highest_type = get_higher_type(highest_type, number_type(x)) + elif isinstance(x, sympy.Basic): + highest_type = get_higher_type(highest_type, expr_type(x)) + else: + # x is a TensorLike + highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype)) + + result_dtype = None + + def _find_highest_dtype_filtered( + args, filter, *, float_as_complex=False + ) -> Optional[torch.dtype]: + zero_dim_tensor_dtype = None + one_plus_dim_tensor_dtype = None + for x in args: + if isinstance(x, TensorLike) and filter(x.dtype): + _dtype = x.dtype + if float_as_complex and is_float_dtype(_dtype): + _dtype = corresponding_complex_dtype(_dtype) + if x.ndim == 0: + zero_dim_tensor_dtype = get_higher_dtype( + zero_dim_tensor_dtype, _dtype + ) + else: + # x.ndim > 0 + one_plus_dim_tensor_dtype = get_higher_dtype( + one_plus_dim_tensor_dtype, _dtype + ) + + # Prefers dtype of tensors with one or more dimensions + if one_plus_dim_tensor_dtype is not None: + return one_plus_dim_tensor_dtype + + return zero_dim_tensor_dtype + + if highest_type is float: + result_dtype = _find_highest_dtype_filtered(args, is_float_dtype) + result_dtype = ( + torch.get_default_dtype() if result_dtype is None else result_dtype + ) + elif highest_type is complex: + result_dtype = _find_highest_dtype_filtered( + args, + lambda x: is_float_dtype(x) or is_complex_dtype(x), + float_as_complex=True, + ) + if result_dtype is None: + result_dtype = corresponding_complex_dtype(torch.get_default_dtype()) + elif highest_type is int: + result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype) + result_dtype = torch.long if result_dtype is None else result_dtype + else: + # highest_type is bool + result_dtype = torch.bool + + if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH: + return result_dtype, result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT: + if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype): + result_dtype = torch.get_default_dtype() + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT: + # NOTE: computation can still occur in a complex dtype + computation_dtype = get_computation_dtype(result_dtype) + if is_complex_dtype(result_dtype): + result_dtype = corresponding_real_dtype(result_dtype) + return computation_dtype, result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG: + if is_boolean_dtype(result_dtype): + return torch.long, torch.long + return get_computation_dtype(result_dtype), result_dtype + elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL: + return get_computation_dtype(result_dtype), 
torch.bool + else: + raise ValueError(f"Unknown type promotion kind {str(type_promotion_kind)}") + + +def reduction_dtypes( + arg, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, + dtype: Optional[torch.dtype] = None, +) -> tuple[torch.dtype, Optional[torch.dtype]]: + # even though some reductions, like amin or amax, don't strictly require type promotion, + # all the math ops (including comparisons) are still defined only for a computation type, + # so promotion will still happen. We are doing it explicitly here + inp_dtype = dtype if dtype is not None else arg.dtype + computation_dtype = get_computation_dtype(inp_dtype) + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME + or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ): + result_dtype = dtype if dtype else arg.dtype + if ( + output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + and is_complex_dtype(result_dtype) + ): + result_dtype = corresponding_real_dtype(result_dtype) + elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE: + result_dtype = None + else: # ALWAYS_BOOL + result_dtype = torch.bool + return computation_dtype, result_dtype + + +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides +def make_contiguous_strides_for( + shape: ShapeType, row_major: bool = True +) -> tuple[Union[_IntLikeT, int], ...]: + """ + Returns the strides of a contiguous tensor if row_major + If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices + This is often used when calling external libraries like BLAS/LAPACK/cuSolver... + """ + # contiguous_strides from c10/util/strides.h + validate_shape(shape) + if not shape: + return () + + from torch.fx.experimental.symbolic_shapes import is_nested_int + + multiplier: Union[_IntLikeT, int] = 1 + strides = [] + for l in reversed(shape): + strides.append(multiplier) + multiplier *= ( + l if is_nested_int(l) else sym_max(l, 1) + ) # type:ignore[assignment] + + result = tuple(reversed(strides)) + + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h + if row_major: + return result + else: + if len(shape) < 2: + return result + return result[:-2] + (1, max(shape[-2], 1)) + + +def make_channels_last_1d_strides_for( + shape: Sequence[_IntLikeT], +) -> tuple[Union[_IntLikeT, int], ...]: + torch._check( + len(shape) == 3, + lambda: "Only tensors of rank 3 can use the channels_last_1d memory format", + ) + + multiplier: Union[_IntLikeT, int] = 1 + strides: list[Union[_IntLikeT, int]] = [0] * 3 + for idx in (1, -1, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_2d_strides_for( + shape: Sequence[_IntLikeT], +) -> tuple[Union[_IntLikeT, int], ...]: + # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5? 
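Assuming the module is importable as torch._prims_common, the stride construction above matches what freshly allocated tensors report; with row_major=False the last two strides become the Fortran-ordered (1, rows). A sketch:

import torch
from torch._prims_common import make_contiguous_strides_for

assert make_contiguous_strides_for((2, 3, 4)) == (12, 4, 1)
assert torch.empty(2, 3, 4).stride() == (12, 4, 1)
assert make_contiguous_strides_for((2, 3, 4), row_major=False) == (12, 1, 3)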
+ torch._check( + len(shape) == 4, + lambda: "Only tensors of rank 4 can use the channels_last memory format", + ) + + multiplier: Union[_IntLikeT, int] = 1 + strides: list[Union[_IntLikeT, int]] = [0] * 4 + for idx in (1, -1, -2, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_3d_strides_for( + shape: Sequence[_IntLikeT], +) -> tuple[Union[_IntLikeT, int], ...]: + torch._check( + len(shape) == 5, + lambda: "Only tensors of rank 5 can use the channels_last_3d memory format", + ) + + multiplier: Union[_IntLikeT, int] = 1 + strides: list[Union[_IntLikeT, int]] = [0] * 5 + for idx in (1, -1, -2, -3, 0): + # NOTE: intentionally divergence from make_contiguous_strides_for + # This is consistent with eager + strides[idx] = multiplier + multiplier *= shape[idx] + + return tuple(strides) + + +def make_channels_last_strides_for( + shape: Sequence[_IntLikeT], +) -> tuple[Union[_IntLikeT, int], ...]: + ndim = len(shape) if isinstance(shape, Sequence) else 1 + if ndim == 3: + return make_channels_last_1d_strides_for(shape) + elif ndim == 4: + return make_channels_last_2d_strides_for(shape) + elif ndim == 5: + return make_channels_last_3d_strides_for(shape) + else: + raise RuntimeError( + f"no channels last format strides exist in {ndim} dimensions" + ) + + +def compute_reduction_output_shape( + shape: ShapeType, dimensions: Sequence +) -> tuple[int, ...]: + for idx in dimensions: + validate_idx(len(shape), idx) + + new_shape = [] + for idx in range(len(shape)): + if idx in dimensions: + continue + + new_shape.append(shape[idx]) + + return tuple(new_shape) + + +def validate_no_repeating_dims(dims: Sequence): + if len(dims) != len(set(dims)): + raise RuntimeError("duplicate value in the list of dims") + + +def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> tuple[int, ...]: + if dims is None: + return tuple(range(len(shape))) + dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims) + validate_no_repeating_dims(dims) + return dims + + +def set_correction( + unbiased: Optional[bool] = None, + correction: Optional[NumberType] = None, +) -> float: + if correction is not None and unbiased is not None: + raise RuntimeError("cannot specify both correction and unbiased arguments") + elif correction is None and unbiased is None: + correction = 1.0 + elif correction is None and unbiased is not None: + correction = 0.0 if unbiased is False else 1.0 + # NB: we don't actually support symint here, but it's harmless to accept + if not isinstance(correction, (IntLike, FloatLike)): + raise ValueError("correction argument should be integer or float") + if correction < 0: + raise ValueError("correction argument should be non-negative") + return sym_float(correction) + + +def compute_required_storage_length( + shape: ShapeType, strides: StrideType, storage_offset: int +) -> int: + """Computes the minimum storage size to hold the given tensor geometry. 
+ + Example + ======= + + This is the size of a newly allocated tensor's storage, in units of elements + + >>> t = torch.empty((10, 20)) + >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset()) + 200 + + >>> # xdoctest: +SKIP(failing) + >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11)) + >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset()) + >>> size == t.storage().size() + True + + A valid tensor may have a larger storage size, but never smaller + + >>> slice = torch.empty(100)[20:40] + >>> slice.storage().size() + 100 + + >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset()) + 40 + + """ + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + # Short-circuits if the shape has no elements + if guard_size_oblivious(reduce(operator.mul, shape, 1) == 0): + return 0 + + max_offset = sum((x - 1) * y for x, y in zip(shape, strides)) + # +1 to account for the first element which offsets are taken from + return 1 + storage_offset + max_offset + + +def check_in_bounds_for_storage( + a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int +): + """ + Determines if the given shape, strides, and offset are valid for the given storage. + """ + + required_length = compute_required_storage_length(shape, strides, storage_offset) + if a.size() < required_length: + msg = ( + f"Can't view a storage of size {a.size()} with an offset of {storage_offset}, " + f"shape of {str(shape)}, and strides of {str(strides)}, " + f"which requires a storage of size {required_length}" + ) + raise ValueError(msg) + + +# NOTE: This function should ideally be removed, but some Meta internal models +# packaged with `torch.package` are using it, so it will have to be removed +# at some point in the future when those models no longer use this function. +@deprecated( + "`torch._prims_common.check` is deprecated and will be removed in the future. " + "Please use `torch._check*` functions instead.", + category=FutureWarning, +) +def check( + b: bool, s: Callable[[], str], exc_type: type[Exception] = RuntimeError +) -> None: + """ + Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails. + Error message is a callable producing a string (to avoid wasting time + string formatting in non-error case, and also to make it easier for torchdynamo + to trace.) + + .. note:: This function is planned for removal in the future. Please use + `torch._check*` functions instead. 
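+
+    A minimal usage sketch (illustrative; the tensor ``x`` is hypothetical)::
+
+        check(x.ndim == 2, lambda: f"expected a matrix, got {x.ndim} dims", ValueError)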
+ """ + torch._check_with(exc_type, b, s) + + +# This combines is_channels_last_strides_2d and is_channels_last_strides_3d in +# c10/core/MemoryFormat.h into one function +def are_strides_like_channels_last( + shape: Sequence[int], strides: Sequence[int] +) -> bool: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + ndim = len(shape) + + if ndim == 4: + # Check for channels_last_2d + dim_order = [1, 3, 2, 0] + elif ndim == 5: + # Check for channels_last_3d + dim_order = [1, 4, 3, 2, 0] + else: + return False + + if guard_size_oblivious(strides[1] == 0): + return False + + min = 0 + for d in dim_order: + if guard_size_oblivious(shape[d] == 0): + return False + if guard_size_oblivious(strides[d] < min): + return False + if d == 0 and min == strides[1]: + return False + min = strides[d] + if guard_size_oblivious(strides[d] > 1): + min *= shape[d] + return True + + +def suggest_memory_format(x: TensorLikeType) -> torch.memory_format: + if x.layout != torch.strided: + return torch.contiguous_format + + if are_strides_like_channels_last(x.shape, x.stride()): + return torch.channels_last if x.ndim == 4 else torch.channels_last_3d + + return torch.contiguous_format + + +def prod(xs: Sequence[NumberType]) -> NumberType: + """Product of elements in input sequence. Returns 1 for empty sequence""" + return reduce(operator.mul, xs, 1) + + +def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool: + """Checks if a shape can be expanded to another shape. + This is equivalent to checking if the two shapes are broadcastable. + """ + # This is a Python implementation of + # aten/src/ATen/ExpandUtils.h:is_expandable_to + if len(shape) > len(desired): + return False + for i in range(len(shape)): + if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1: + return False + return True + + +def mask_tensor(mask: TensorLikeType, t: TensorLikeType): + """ + Similar to torch.where(mask, t, 0) but if t is boolean, + result is also boolean and not promoted to int. + """ + # torch.where(mask, t, False) is equivalent + # but feels hacky and might break in the future + if t.dtype is torch.bool: + return mask.logical_and(t) + else: + return torch.where(mask, t, 0) + + +def get_aten_op(fn: Callable, name: str): + """ + Given the __module__ of reference and its name, it returns + (our best guess of) the ATen name of the associated operation + + Note: In ATen, the __name__ of a function within a module often + starts by the module name. E.g. 
linalg_eigh, or special_zeta + """ + module = fn.__module__ + prefix = "torch._refs" + assert module.startswith(prefix) + module = module[len(prefix) :] + # We want to go from .special / .nn.functional + # to special and special_ / nn_functional_ + if module: + module = module[1:] + module = module.replace(".", "_") + module = module + "_" + return getattr(torch._ops.ops.aten, f"{module}{name}") + + +def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: + return dtype if dtype is not None else torch.get_default_dtype() + + +def device_or_default(device: Optional[DeviceLikeType]) -> DeviceLikeType: + return device if device is not None else torch.device("cpu") + + +def layout_or_default(layout: Optional[torch.layout]) -> torch.layout: + return layout if layout is not None else torch.strided + + +def clone_preserve_strides(x): + needed_size = compute_required_storage_length( + x.size(), x.stride(), x.storage_offset() + ) + # Our eager implementations for *_scatter ops are all primitives w.r.t autograd, + # so these as_strided() calls are not seen by autograd. + # We need to mimic this behavior in our ref/prim implementations. + # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided" + # We should revisit this when we add a compositional as_strided op, + # and also as part of https://github.com/pytorch/pytorch/issues/90507 + try: + old = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, True + ) + buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone() + return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset()) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old + ) + + +def alert_not_deterministic(caller: str): + if torch.are_deterministic_algorithms_enabled(): + if torch.is_deterministic_algorithms_warn_only_enabled(): + warnings.warn( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True, warn_only=True)'. " + f"You can file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." + ) + else: + torch._check( + False, + lambda: ( + f"{caller} does not have a deterministic implementation, but you set " + f"'torch.use_deterministic_algorithms(True)'. You can turn off " + f"determinism just for this operation, or you can use the " + f"'warn_only=True' option, if that's acceptable for your application. " + f"You can also file an issue at https://github.com/pytorch/pytorch/issues " + f"to help us prioritize adding deterministic support for this operation." 
+ ), + ) + + +class CUDARngStateHelper: + @staticmethod + def get_torch_state_as_tuple(fake_mode=nullcontext()): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA not available") + + with fake_mode: + seed = torch.tensor(torch.cuda.initial_seed()) + offset = torch.tensor(torch.cuda._get_rng_state_offset()) + return seed, offset + + @staticmethod + def set_torch_state_tensor(seed, offset): + # Rng state is [64-bit seed, 64-bit offset] + seed_portion = seed.reshape([1]).view(torch.uint8) + offset_portion = offset.reshape([1]).view(torch.uint8) + new_state = torch.cat([seed_portion, offset_portion]) + torch.cuda.set_rng_state(new_state) + + @staticmethod + def set_new_offset(relative_offset): + torch.cuda._set_rng_state_offset(relative_offset.item()) diff --git a/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7852ab90a0a491c6ac82ec78a6c71e66635fecf3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c078d596c692d43c46cf96433e33723ef68ad2e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_prims_common/__pycache__/wrappers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_prims_common/wrappers.py b/phivenv/Lib/site-packages/torch/_prims_common/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..44aeb2a296335b86ff144631a794b027d5a85715 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_prims_common/wrappers.py @@ -0,0 +1,479 @@ +# mypy: allow-untyped-defs +import inspect +import types +import warnings +from collections.abc import Sequence +from functools import wraps +from types import GenericAlias +from typing import Callable, NamedTuple, Optional, overload, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +import torch._prims_common as utils +from torch._prims_common import ( + CustomOutParamAnnotation, + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_flatten, tree_unflatten + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +@overload +def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType: + pass + + +@overload +def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType: + pass + + +@overload +def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence: + pass + + +@overload +def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None: + pass + + +# TODO: implement ref.cast with an option to enforce safe casting +def _maybe_convert_to_dtype(a, dtype): + if isinstance(a, TensorLike): + if a.dtype != dtype: + return a.to(dtype) + return a + if isinstance(a, Number): + return utils.dtype_to_type_ctor(dtype)(a) # type: ignore[arg-type] + if isinstance(a, Sequence): + return tuple(_maybe_convert_to_dtype(x, dtype) for x in a) + # Passthrough None because some functions wrapped with type promotion + # wrapper might have optional args + if a is None: + return None + + raise ValueError( + f"Received 
unsupported type {type(a)}. Expected TensorLike, Number, or Sequence." + ) + + +def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: + if not isinstance(a, Number): + msg = f"Found unknown type {type(a)} when trying to convert scalars!" + raise ValueError(msg) + if not utils.is_weakly_lesser_type(type(a), typ): + msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!" + raise ValueError(msg) + + return typ(a) + + +def _annotation_has_type(*, typ, annotation): + if hasattr(annotation, "__args__"): + for a in annotation.__args__: + if _annotation_has_type(typ=typ, annotation=a): + return True + return False + + return typ is annotation + + +class elementwise_type_promotion_wrapper: + """ + Adds elementwise type promotion to a Python reference implementation. + + Takes two kwargs, type_promoting_args and type_promotion_kind. + + type_promoting_args must be a string Sequence specifiying the argument names of all + arguments that participate in type promotion (and should be type promoted). If the + arg specifies a Sequence-type then every element of the Sequence will participate in + type promotion. + + type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND. + See its documentation for details. + + The return_dtype will be coerced to the wrapped function's dtype arg if it is available and + not None. + + Other type promotion behavior, like validating the Python type of scalar arguments, must + be handled separately. + """ + + def __init__( + self, + *, + type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND, + type_promoting_args: Optional[Sequence[str]] = None, + ): + self.type_promoting_arg_names = type_promoting_args + self.type_promotion_kind = type_promotion_kind + + def __call__(self, fn: Callable) -> Callable: + sig = inspect.signature(fn) + + # TorchDynamo tracing of inspect causes fake tensor dynamo_wrapped tests to fail + # PYTORCH_TEST_WITH_DYNAMO=1 python test/test_fake_tensor.py FakeTensorTest.test_basic + @torch._disable_dynamo + @wraps(fn) + def _fn(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + type_promoting_args = tuple( + bound.arguments[x] + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + ) + + flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args) + compute_dtype, result_dtype = utils.elementwise_dtypes( + *flattened_type_promoting_args, + type_promotion_kind=self.type_promotion_kind, + ) + + promoted_args = { + x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype) + for x in self.type_promoting_arg_names # type: ignore[union-attr] + if x in bound.arguments.keys() + } + bound.arguments.update(promoted_args) + + result = fn(**bound.arguments) + + # Override the return_dtype if a dtype arg is present and not None + if "dtype" in bound.arguments: + maybe_dtype = bound.arguments["dtype"] + if maybe_dtype: # dtype cannot be None + result_dtype = maybe_dtype + + if isinstance(result, TensorLike): + return _maybe_convert_to_dtype(result, result_dtype) + if isinstance(result, Sequence): + return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result) + raise AssertionError(f"Unhandled result type: {type(result)}") + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn + + +# Returns True if resize is necessary +def _resize_output_check(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return False + if out.numel() != 
0:
+        msg = (
+            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
+            f"which does not match the required output shape {str(shape)}. "
+            "This behavior is deprecated, and in a future PyTorch release outputs will not "
+            "be resized unless they have zero elements. "
+            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
+        )
+        warnings.warn(msg)
+        return True
+
+
+# TODO: handle tuples of tensors
+def _maybe_resize_out(
+    out: TensorLikeType,
+    shape: ShapeType,
+    memory_format: Optional[torch.memory_format] = None,
+):
+    if _resize_output_check(out, shape):
+        return out.resize_(shape, memory_format=memory_format)
+    else:
+        return out
+
+
+def is_cpu_scalar(x: TensorLikeType) -> bool:
+    return x.dim() == 0 and x.device.type == "cpu"
+
+
+def check_copy_devices(*, copy_from: TensorLikeType, copy_to: TensorLikeType) -> None:
+    if copy_from.device != copy_to.device:
+        msg = (
+            f"Attempting to copy from device {copy_from.device} "
+            f"to device {copy_to.device}, but cross-device copies are not allowed!"
+        )
+        raise RuntimeError(msg)
+
+
+def _safe_copy_out(
+    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
+):
+    # Checks same device
+    if not is_cpu_scalar(copy_from):
+        check_copy_devices(copy_from=copy_from, copy_to=copy_to)
+
+    # Checks safe cast
+    if exact_dtype:
+        torch._check(
+            copy_from.dtype == copy_to.dtype,
+            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
+            f"but got {copy_to.dtype} instead",
+        )
+    else:
+        torch._check(
+            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
+            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
+            "but this can't be cast because it is not safe!",
+        )
+
+    return copy_to.copy_(copy_from)
+
+
+def out_wrapper(
+    *out_names: str,
+    exact_dtype: bool = False,
+    pass_is_out: bool = False,
+    preserve_memory_format: bool = False,
+) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
+    # The wrapped function needs to convert the output parameters to ensure
+    # compatibility between the Python API (which always uses "out" as the
+    # parameter name and may be a tuple) and the Aten API (which may have
+    # multiple output parameters and use different parameter names such as
+    # "grad_input", "indices" or "values".)
+
+    default_out_names = ("out",)
+    if len(out_names) == 0:
+        # Use default in out name
+        out_names = default_out_names
+
+    is_tensor = len(out_names) == 1
+
+    def maybe_compute_memory_format(t):
+        return utils.suggest_memory_format(t) if preserve_memory_format else None
+
+    def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
+        """
+        Adds the out parameter to a Python reference.
+        """
+        out_type = (
+            TensorLikeType
+            if is_tensor
+            else GenericAlias(
+                tuple, tuple(TensorLikeType for _ in range(len(out_names)))
+            )
+        )
+        # For backward compatibility - should be able to remove once PEP585
+        # conversion is complete.
+ bc_out_type = ( + TensorLikeType + if is_tensor + else types.GenericAlias( + tuple, tuple(TensorLikeType for _ in range(len(out_names))) + ) + ) + return_type = ( + TensorLikeType + if is_tensor + else NamedTuple( + f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names] + ) + ) + + sig = inspect.signature(fn) + factory_kwargs = ("device", "dtype") + is_factory_fn = all(p in sig.parameters for p in factory_kwargs) + + @wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs): + out = kwargs.pop("out", None) + if is_factory_fn and out is not None: + for k in factory_kwargs: + out_attr = getattr(out, k) + if k not in kwargs: + kwargs[k] = out_attr + + def maybe_check_copy_devices(out): + if isinstance(out, TensorLike) and isinstance(args[0], TensorLike): + check_copy_devices(copy_from=args[0], copy_to=out) + + if isinstance(out, (tuple, list)): + for o in out: + maybe_check_copy_devices(o) + else: + maybe_check_copy_devices(out) + + if pass_is_out: + result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type] + else: + result = fn(*args, **kwargs) + if result is NotImplemented: + return NotImplemented + assert ( + (isinstance(result, TensorLike) and is_tensor) + or ( + isinstance(result, tuple) # type: ignore[arg-type] + and len(result) == len(out_names) # type: ignore[arg-type] + ) + or ( + fn.__name__ == "unbind" + and isinstance(result, (list, tuple)) # type: ignore[arg-type] + ) + ) + # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829 + if out is not None: + # Naively you might expect this assert to be true, but + # it's not: + # + # assert type(out) == type(result) + # + # The reason is that functions under this wrapper can + # get registered to the Meta dispatch key, and that + # means they can be executed in a context where tensor + # subclasses are disabled (with no_dispatch), which is a + # handy way for an is-a tensor subclass (e.g., + # FakeTensor) to have the normal meta backend create a + # meta tensor, to be wrapped once it gets returned. + # In this situation, you will get a FakeTensor as + # the output tensor, but not the result--which will + # be a normal meta tensor, but this is perfectly + # harmless. 
+ if is_tensor and fn.__name__ != "unbind": + assert isinstance(out, TensorLike) + # These two operations are done in-place + _maybe_resize_out( + out, result.shape, maybe_compute_memory_format(result) # type: ignore[union-attr] + ) + _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + if fn.__name__ != "unbind": + assert isinstance(out, tuple) # type: ignore[arg-type] + else: + assert isinstance(out, (list, tuple)) # type: ignore[arg-type] + torch._check_type( + len(out) == len(result), # type: ignore[arg-type] + lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type] + ) + for r, o in zip(result, out): # type: ignore[arg-type] + # These two operations are done in-place + _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r)) + _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype) # type: ignore[arg-type] + else: + out = result + # mypy does not see through the definition of out_type given that it's in a different scope + return out if is_tensor else return_type(*out) # type: ignore[operator] + + out_param = inspect.Parameter( + "out", + kind=inspect.Parameter.KEYWORD_ONLY, + default=None, + annotation=out_type, + ) + # Mark that the function now returns a tuple + assert isinstance( + sig.return_annotation, (str, TypeVar) + ) or sig.return_annotation in ( + sig.empty, + out_type, + bc_out_type, + ) + params = *sig.parameters.values(), out_param + + # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear + # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by + # Parameter.kind guarantees that all the parameters are in legal order. + params = sorted(params, key=lambda p: p.kind) + + _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=params, return_annotation=return_type # type: ignore[arg-type] + ) + + _fn.__annotations__ = dict(getattr(fn, "__annotations__", {})) + _fn.__annotations__["out"] = out_type + _fn.__annotations__["return"] = return_type + + # In the special case of having a single tensor out parameter with a + # name other than out, add a special annotation to name the parameter + if is_tensor and out_names != default_out_names: + _fn.__annotations__[CustomOutParamAnnotation] = out_names[0] + + # Add an indicator attribute that can be used in special cases + # where having a function wrapped by `out_wrapper` is not desirable e.g. 
+ # jit + _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper" # type: ignore[attr-defined] + + return _fn + + return _out_wrapper + + +def _maybe_remove_out_wrapper(fn: Callable): + return inspect.unwrap( + fn, + stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"), + ) + + +def backwards_not_supported(prim): + def redispatch_prim(args, kwargs): + with torch._C._AutoDispatchBelowAutograd(): + return prim(*args, **kwargs) + + class BackwardsNotSupported(torch.autograd.Function): + @staticmethod + def forward(ctx, args_spec, *flat_args): + args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type] + return redispatch_prim(args, kwargs) + + @staticmethod + def backward(ctx, *args): + raise RuntimeError("backwards not supported on prim") + + @wraps(prim) + def _autograd_impl(*args, **kwargs): + flat_args, args_spec = tree_flatten((args, kwargs)) + if torch.is_grad_enabled() and any( + a.requires_grad for a in flat_args if isinstance(a, torch.Tensor) + ): + # TODO: There is a subtle bug here: prims like copy_to + # return their input argument after mutating it; and custom + # autograd function will incorrectly turn the result into + # a view which will fail test_python_ref_executor tests. + # At the moment, we sidestep this by observing that the + # unit tests don't ever try to run the executor with + # autograd, so we don't exercise the buggy case, but if + # you ever want to feed autograd through this, be aware + # of it! We need a way of properly implementing autograd + # for mutating operations in Python to do this. + return BackwardsNotSupported.apply(args_spec, *flat_args) + else: + return redispatch_prim(args, kwargs) + + return _autograd_impl + + +# TODO: when tracing this will add torch tensors and not TensorMeta objects +# to the trace -- we should fix this by adding a tracing context and NumberMeta classes +# TODO: this wrapper is currently untested +def elementwise_unary_scalar_wrapper( + fn: Callable[_P, _T], +) -> Callable[_P, Union[_T, NumberType]]: + """ + Allows unary operators that accept tensors to work with Python numbers. 
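+
+    For example (illustrative): when ``fn`` is a unary reference such as
+    ``torch._refs.cos``, calling the wrapped function with a Python number like
+    ``0.0`` converts the scalar into a 0-dim tensor of the matching dtype, applies
+    ``fn``, and returns ``result.item()`` (a Python number) rather than a tensor.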
+ """ + sig = inspect.signature(fn) + + @wraps(fn) + def _fn(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], Number): + dtype = utils.type_to_dtype(type(args[0])) + args_ = list(args) + args_[0] = torch.tensor(args[0], dtype=dtype) + result = fn(*args_, **kwargs) + assert isinstance(result, torch.Tensor) + return result.item() + + return fn(*args, **kwargs) + + _fn.__signature__ = sig # type: ignore[attr-defined] + return _fn diff --git a/phivenv/Lib/site-packages/torch/_refs/__init__.py b/phivenv/Lib/site-packages/torch/_refs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc1947075f507b541f9899df5a616c0946a4461 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/__init__.py @@ -0,0 +1,6707 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import builtins +import collections +import inspect +import itertools +import math +import operator +import warnings +from collections.abc import Iterable, Sequence +from enum import Enum +from functools import partial, reduce, singledispatch, wraps +from typing import Any, Callable, cast, Optional, overload, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch.utils._pytree as pytree +from torch import sym_float, sym_int +from torch._prims_common import ( + BoolLike, + definitely_contiguous, + definitely_contiguous_for_memory_format, + DeviceLikeType, + Dim, + DimsSequenceType, + DimsType, + dtype_to_type, + ELEMENTWISE_TYPE_PROMOTION_KIND, + FloatLike, + FloatWithoutSymFloat, + IntLike, + is_weakly_lesser_type, + Number, + NumberType, + RealNumberType, + REDUCTION_OUTPUT_TYPE_KIND, + ShapeType, + StrideType, + TensorLike, + TensorLikeType, + TensorOrNumberLikeType, + TensorSequenceType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) + + +# Experimental module containing prototype Python references for existing +# PyTorch operations. 
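+#
+# For example (illustrative): torch._refs.add(torch.ones(2), 1.0) computes the same
+# result as torch.add(torch.ones(2), 1.0), but it is implemented in Python on top of
+# torch._prims, so it can be traced and decomposed.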
+ +__all__ = [ + # + # Elementwise Unary References + # + "abs", + "acos", + "acosh", + "asinh", + "asin", + "atan", + "atanh", + "bitwise_not", + # "cbrt", # No corresponding torch operation + "ceil", + "conj_physical", + "cos", + "cosh", + "count_nonzero", + "deg2rad", + "digamma", + "erf", + "erfinv", + "erfc", + "exp", + "expm1", + "exponential", + "exp2", + "fill", + "fill_", + "floor", + "frac", + "geometric", + "index_add", + "index_copy", + "index_copy_", + "index_select", + "index_fill", + "index_fill_", + "isfinite", + "isinf", + "isposinf", + "isneginf", + "isnan", + "isreal", + "i0", + "lerp", + "lgamma", + "log", + "log1p", + "log2", + "log10", + "log_normal", + "log_softmax", + "mvlgamma", + "norm", + "normal", + "nan_to_num", + "neg", + "positive", + "rad2deg", + "reciprocal", + "round", # TODO: model kwargs + "sigmoid", + "sgn", + "sign", + "signbit", + "sin", + "sinc", + "sinh", + "softmax", + "sqrt", + "square", + "tan", + "tanh", + "trace", + "trunc", + # + # Elementwise Binary References + # + "add", + "atan2", + "bitwise_and", + "bitwise_left_shift", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "clamp_min", + "clamp_max", + "copysign", + "div", + "eq", + "float_power", + "floor_divide", + "fmax", + "fmin", + "fmod", + "gcd", + "ge", + "gt", + "heaviside", + "hypot", + "igamma", + "igammac", + "imag", + "isclose", + "lcm", + # 'ldexp', + "le", + "logaddexp", + "logaddexp2", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "logsumexp", + "lt", + # 'max', # implement with reductions + "maximum", + # 'min', # implement with reductions + "minimum", + "mul", + "ne", + "nextafter", + # 'polar', # abs, cos, sin + "pow", + "real", + "rpow", + "remainder", + "rsub", + "rtruediv", + "rfloordiv", + "sub", + "true_divide", + "trunc_divide", + "xlogy", + # + # Elementwise Ternary References + # + "addcdiv", + "addcmul", + "clamp", + # + # Conditional references + # + "masked_fill", + "masked_fill_", + "where", + # + # Data conversion and movement references + # + "clone", + "copy_to", # TODO: add OpInfo (or implement .to) + "item", + "to", + # + # Reduction ops + # + "all", + "amax", + "amin", + "any", + "cumsum", + "cumprod", + "mean", + "dot", + "vdot", + "std", + "std_mean", + "sum", + "sum_to_size", + "prod", + "var", + "var_mean", + # + # Linear algebra ops + # + "addr", + # + # View & Shape Ops + # + "alias", + "alias_copy", + "atleast_1d", + "atleast_2d", + "atleast_3d", + "as_strided", + "as_strided_copy", + "as_strided_scatter", + "block_diag", + "broadcast_shapes", + "broadcast_tensors", + "broadcast_to", + "cat", + "chunk", + "column_stack", + "conj", + "constant_pad_nd", + "contiguous", + "diag_embed", + "diag", + "diagonal", + "diagonal_copy", + "diagonal_scatter", + "dsplit", + "dstack", + "expand", + "expand_as", + "expand_copy", + "flatten", + "flip", + "fliplr", + "flipud", + "hsplit", + "hstack", + "meshgrid", + "movedim", + "narrow", + "narrow_copy", + "native_group_norm", + "native_layer_norm", + "permute", + "permute_copy", + "ravel", + "repeat", + "reshape", + "reshape_as", + "roll", + "rot90", + "rsqrt", + "split_with_sizes", + "stack", + "swap_axes", # alias for transpose + "squeeze", + "squeeze_copy", + "t", + "t_copy", + "T", + "take_along_dim", + "tensor_split", + "transpose", + "transpose_copy", + "unbind_copy", + "unfold", + "unfold_copy", + "unsqueeze", + "unsqueeze_copy", + "view", + "view_as", + "view_copy", + "vsplit", + "vstack", + "view_as_complex", + "unflatten", + "unbind", + "triu", + "tril", + "triu_indices", 
+ "tril_indices", + # + # Tensor Creation + # + "arange", + "cauchy", + "empty", + "empty_like", + "empty_permuted", + "empty_strided", + "eye", + "full", + "full_like", + "linspace", + "logspace", + "new_empty", + "new_empty_strided", + "new_full", + "new_ones", + "new_zeros", + "ones", + "ones_like", + "randn", + "scalar_tensor", + "zero", + "zeros", + "zeros_like", + # + # Test-related functions + # + "allclose", + "equal", + # + # Statistical operations + # + "bucketize", + # + # Misc + # + "is_complex", + "renorm", + "stft", + "istft", +] + +Tensor = torch.Tensor +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] +aten = torch._ops.ops.aten + +# Note that the docstrings for the public methods from this file are in +# torch/_torch_docs.py + + +def is_noncontiguous_supported(device): + return device is None or device.type != "hpu" + + +def handle_noncontiguous_outputs(input_tlist, output): + device = None + from torch._subclasses.fake_tensor import FakeTensor + + for t in input_tlist: + if isinstance(t, FakeTensor): + device = t.fake_device + break + + if not is_noncontiguous_supported(device): + output = output.contiguous() + + return output + + +def _broadcast_shapes(*_shapes): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + shapes = tuple( + (x,) if isinstance(x, IntLike) else x + for x in filter(lambda x: x is not None, _shapes) + ) + + # Short-circuits on no input + if len(shapes) == 0: + return None + + # Type checking + # TODO: make common validations available as utils + for shape in shapes: + assert isinstance(shape, Sequence) + + # Computes common shape + common_shape: list[Union[int, torch.SymInt]] = [ + 1, + ] * reduce(max, (len(shape) for shape in shapes)) + for arg_idx, shape in enumerate(shapes): + for idx in range(-1, -1 - len(shape), -1): + if guard_size_oblivious(common_shape[idx] == 1): + if shape[idx] < 0: + raise ValueError( + "Attempting to broadcast a dimension with negative length!" + ) + common_shape[idx] = shape[idx] + elif guard_size_oblivious(shape[idx] != 1): + torch._check( + common_shape[idx] == shape[idx], + lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}", + ) + + return common_shape + + +def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True): + # Computes common shape + common_shape = _broadcast_shapes( + *(t.shape if isinstance(t, TensorLike) else None for t in args) + ) + + def __maybe_broadcast(x, shape): + if x is None: + return None + elif isinstance(x, Number): + return x + elif isinstance(x, TensorLike): + if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x): + return x + + if not utils.same_shape(x.shape, common_shape): + return x.expand(common_shape) + + return x + else: + raise RuntimeError( + "Unexpected type when broadcasting: " + str(type(x)) + "!" 
+ ) + + return tuple(__maybe_broadcast(x, common_shape) for x in args) + + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + + +# +# Elementwise unary references +# + +infer_aten_op = object() + + +# TODO: add type promotion support +def _make_elementwise_unary_reference( + type_promotion_kind, + *, + aten_op=infer_aten_op, + extra_meta=None, + exact_dtype=False, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op + + @wraps(prim) + @out_wrapper(exact_dtype=exact_dtype) + @elementwise_unary_scalar_wrapper + @elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=type_promotion_kind, + ) + def _ref(a: TensorLikeType) -> TensorLikeType: + if extra_meta is not None: + extra_meta(a) + + output = prim(a) + return handle_noncontiguous_outputs([a], output) + + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, prim.__name__) + if aten_op is not None: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +def _make_alias(fn, name): + """ + This function defines an alias of another function and sets its __name__ argument. + It also sets its __module__ argument to the module of the caller. + Note that when naively doing `alias = fn`, we have that `alias.__name__ == "fn"`, and + `alias.__module__ == fn.__module__`. + """ + + def _fn(*args, **kwargs): + return fn(*args, **kwargs) + + _fn.__name__ = name + _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"] # type: ignore[union-attr] + return _fn + + +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(aten, inplace_name))(_fn) # type: ignore[assignment] + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... 
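+    # For example (illustrative): applied to a reference such as ``floor``, this
+    # returns a ``floor_`` that forwards to ``floor(a, out=a)``, registers it as the
+    # decomposition for ``aten.floor_``, and appends "floor_" to the defining
+    # module's ``__all__``.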
+ from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, + exact_dtype=True, +) +def abs(a): + return prims.abs(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acos(a): + return prims.acos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def acosh(a): + return prims.acosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asin(a): + return prims.asin(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def asinh(a): + return prims.asinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atan(a): + return prims.atan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def atanh(a): + return prims.atanh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) +def bitwise_not(a): + return prims.bitwise_not(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def ceil(a): + return prims.ceil(a) + + +@register_decomposition(aten.is_complex) +def is_complex(input: TensorLikeType): + return utils.is_complex_dtype(input.dtype) + + +@register_decomposition(aten.conj_physical) +@out_wrapper() +def conj_physical(input: TensorLikeType): + if not utils.is_complex_dtype(input.dtype): + return input + return prims.conj_physical(input) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cos(a): + return prims.cos(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def cosh(a): + return prims.cosh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def digamma(a): + return prims.digamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erf(a): + return prims.erf(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfinv(a): + return prims.erf_inv(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def erfc(a): + return prims.erfc(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp(a): + return prims.exp(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def expm1(a): + return prims.expm1(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def exp2(a): + return prims.exp2(a) + + +# Fill has its own implementation because it has a value parameter +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a,"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType: + assert isinstance(a, TensorLike) + assert isinstance(value, Number) + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(value), python_type): + msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + + return prims.fill(a, value) + + +def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType: + r = prims.fill(a, value) + prims.copy_to(a, r) + return a + + +@register_decomposition(aten.zero) +@out_wrapper() +def zero(input: TensorLikeType) -> TensorLikeType: + return torch.zeros_like(input) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def floor(a): + return prims.floor(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def frac(x: TensorLikeType) -> TensorLikeType: + trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x)) + return torch.sub(x, trunc_x) + + +# imag does not use _make_elementwise_unary_reference because it does not support out +def imag(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + torch._check( + utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors." + ) + return prims.imag(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isfinite(a: TensorLikeType) -> TensorLikeType: + if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype): + return prims.isfinite(a) + + return ones_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isinf(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a))) + if utils.is_float_dtype(a.dtype): + return torch.abs(a) == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def isposinf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def isneginf(a: TensorLikeType) -> TensorLikeType: + torch._check( + not utils.is_complex_dtype(a.dtype), + lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}", + ) + if utils.is_float_dtype(a.dtype): + return a == float("-inf") + return torch.zeros_like(a, dtype=torch.bool) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def isnan(a: TensorLikeType) -> TensorLikeType: + return prims.ne(a, a) + + +# alias +mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma") # type: ignore[has-type] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + aten_op=None, # CompositeImplicitAutograd +) +def isreal(a: TensorLikeType) -> TensorLikeType: + if utils.is_complex_dtype(a.dtype): + return torch.imag(a) == 0 + return torch.ones_like(a, dtype=torch.bool) + + +# TODO: if this is special maybe it should be defined there and imported here? 
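+# For example (illustrative): with INT_TO_FLOAT promotion, applying i0 to an integer
+# or bool tensor computes and returns in the default float dtype, so
+# i0(torch.arange(3)) yields a float32 tensor when the default dtype is float32.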
+@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0 +) +def i0(a): + return prims.bessel_i0(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def lgamma(a): + return prims.lgamma(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log(a): + return prims.log(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log1p(a): + return prims.log1p(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log2(a): + return prims.log2(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def log10(a): + return prims.log10(a) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logsumexp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logsumexp( + self: TensorLikeType, dim: DimsType, keepdim: bool = False +) -> TensorLikeType: + if not isinstance(dim, Iterable): + dim = (dim,) + if self.numel() == 0: + return torch.sum(torch.exp(self), dim, keepdim).log() + maxes = torch.amax(torch.real(self), dim, keepdim=True) + maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0) + maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim) + result = torch.sum(torch.exp(self - maxes), dim, keepdim) + return result.log().add(maxes_squeezed) + + +@register_decomposition(aten.nan_to_num) +@out_wrapper() +def nan_to_num( + a: TensorLikeType, + nan: Optional[NumberType] = 0.0, + posinf: Optional[NumberType] = None, + neginf: Optional[NumberType] = None, +) -> TensorLikeType: + assert isinstance(a, TensorLike) + + if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + return a.clone() + + if nan is None: + nan = 0.0 + + if posinf is None: + posinf = torch.finfo(a.dtype).max + + if neginf is None: + neginf = torch.finfo(a.dtype).min + + result = torch.where(torch.isnan(a), nan, a) # type: ignore[call-overload] + result = torch.where(torch.isneginf(a), neginf, result) # type: ignore[call-overload] + result = torch.where(torch.isposinf(a), posinf, result) # type: ignore[call-overload] + return result + + +def _neg_meta(a: TensorLikeType): + torch._check( + a.dtype is not torch.bool, + lambda: ( + "Negation, the `-` operator, on a bool tensor is not supported. " + "If you are trying to invert a mask, use the `~` or `logical_not()` " + "operator instead." + ), + ) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta +) +def neg(a): + return prims.neg(a) + + +# positive does not use _make_elementwise_unary_reference because it does not support out +# CompositeImplicitAutograd - don't register decomp +def positive(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if a.dtype is torch.bool: + msg = "positive does not support bool tensors." 
+ raise RuntimeError(msg) + return a + + +# real does not use _make_elementwise_unary_reference because it does not support out +def real(a: TensorLikeType) -> TensorLikeType: + assert isinstance(a, TensorLike) + if utils.is_complex_dtype(a.dtype): + return prims.real(a) + return a + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def reciprocal(a): + return prims.reciprocal(a) + + +@register_decomposition(aten.round) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType: + if decimals == 0: + return prims.round(a) + else: + ten_pow = 10**decimals + ten_neg_pow = 10 ** (-decimals) + return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rsqrt(a): + return prims.rsqrt(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sigmoid(a: TensorLikeType) -> TensorLikeType: + return true_divide(1, add(1, exp(neg(a)))) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def sgn(a): + if utils.is_complex_dtype(a.dtype): + a_abs = a.abs() + return torch.where(a_abs == 0, 0, a / a_abs) + else: + return a.sign() + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def sign(a): + return prims.sign(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + exact_dtype=True, +) +def signbit(a): + return prims.signbit(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sin(a): + return prims.sin(a) + + +# Autograd note: This will give the right first derivative at zero (by chance), +# but not the right second derivative +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinc(a): + a = math.pi * a + return torch.where(a == 0, 1, torch.sin(a) / a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sinh(a): + return prims.sinh(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def sqrt(a): + return prims.sqrt(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, + aten_op=None, # CompositeImplicitAutograd, +) +def square(a: TensorLikeType) -> TensorLikeType: + return mul(a, a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tan(a): + return prims.tan(a) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def tanh(a): + return prims.tanh(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + exact_dtype=True, +) +def trunc(a): + return prims.trunc(a) + + +# TODO: register this as a real ref/decomposition once TorchInductor supports complex! 
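+# For example (illustrative): view_as_complex on a float32 tensor of shape (3, 2)
+# whose last dimension has stride 1 reinterprets it as a complex64 tensor of shape
+# (3,), with input[i, 0] as the real part and input[i, 1] as the imaginary part.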
+def view_as_complex(self: TensorLikeType) -> TensorLikeType: + input_dtype = self.dtype + torch._check( + utils.is_float_dtype(input_dtype), + lambda: f"view_as_complex is only supported for floating point" + f"tensors, but got a tensor of scalar type: {input_dtype}", + ) + sizes = self.size() + torch._check( + len(sizes) != 0, + lambda: "Input tensor must have one or more dimensions", + ) + torch._check( + sizes[-1] == 2, + lambda: "Tensor must have a last dimension of size 2", + ) + + old_strides = self.stride() + torch._check( + old_strides[-1] == 1, + lambda: "Tensor must have a last dimension with stride 1", + ) + dims = old_strides[:-1] + torch._check( + builtins.all(stride % 2 == 0 for stride in dims), + lambda: "Tensor must have a stride divisible by 2 for all but last dimension", + ) + torch._check( + self.storage_offset() % 2 == 0, + lambda: "Tensor must have a storage_offset divisible by 2", + ) + return prims.view_element_type( + self, utils.corresponding_complex_dtype(input_dtype) + ).squeeze(-1) + + +def _make_elementwise_binary_reference( + type_promotion_kind, + aten_op=infer_aten_op, + name=None, + has_out=True, + supports_lhs_python_scalar=True, + supports_rhs_python_scalar=True, + supports_two_python_scalars=False, + should_register_decomposition=True, +) -> Callable: + def inner(prim: Callable): + nonlocal aten_op, name + if name is None: + name = prim.__name__ + + @wraps(prim) + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + def _ref( + a: Union[Tensor, NumberType], + b: Union[Tensor, NumberType], + ) -> Tensor: + torch._check_value( + supports_lhs_python_scalar or not isinstance(a, Number), + lambda: f"{name}: Received a lhs Python scalar to an elementwise binary " + "operation that does not accept lhs scalars!", + ) + torch._check_value( + supports_rhs_python_scalar or not isinstance(b, Number), + lambda: f"{name}: Received a rhs Python scalar to an elementwise binary " + "operation that does not accept rhs scalars!", + ) + torch._check_value( + supports_two_python_scalars + or not (isinstance(a, Number) and isinstance(b, Number)), + lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", + ) + a, b = _maybe_broadcast(a, b) + output = prim(a, b) + return handle_noncontiguous_outputs([a, b], output) + + if has_out: + _ref = out_wrapper()(_ref) # type: ignore[assignment] + + _ref.__name__ = name + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, name) + if aten_op is not None and should_register_decomposition: + register_decomposition(aten_op)(_ref) + + return _ref + + return inner + + +# Add has its own implementation because it has an alpha argument +@register_decomposition(aten.add) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def add( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: Optional[NumberType] = None, +): + """ + Reference implementation of torch.add + """ + + a, b = _maybe_broadcast(a, b) + + if alpha is not None: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if python_type != bool and not utils.is_weakly_lesser_type( + type(alpha), python_type + ): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + if isinstance(b, TensorLike): + b = prims.mul(b, alpha) + else: + b = b * alpha + + output = prims.add(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def atan2(a, b): + return prims.atan2(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_and(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_left(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_or(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_right_arithmetic(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_xor(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, +) +def copysign( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + if isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!" + raise RuntimeError(msg) + return where(signbit(b), neg(abs(a)), abs(a)) + + +# complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) + + +@register_decomposition(aten.div) +@out_wrapper() +def div( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + rounding_mode: Optional[str] = None, +): + """ + Reference implementation of torch.div + """ + if rounding_mode is None: + return true_divide(a, b) + elif rounding_mode == "trunc": + return trunc_divide(a, b) + elif rounding_mode == "floor": + return floor_divide(a, b) + else: + msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}." 
+ raise ValueError(msg) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.eq(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) +def pow( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> TensorLikeType: + assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType) + + if isinstance(b, Number): + if b == 1.0: + return a.clone() # type: ignore[return-value,union-attr] + elif b == 2.0: + return a * a # type: ignore[return-value] + elif b == 0.5: + return torch.sqrt(a) # type: ignore[arg-type] + elif isinstance(a, Number): + if a == 1.0: + return torch.fill(b, True) + if a == 2.0 and ( + utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype) + ): + return torch.exp2(b) + + return prims.pow(a, b) + + +# Float power has its own implementation because it has unique type promotion. +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def float_power( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], +) -> Tensor: + if isinstance(a, Number) and isinstance(b, Number): + raise ValueError( + "Receive two Number inputs to an elementwise binary operation!" + ) + + # Handles type promotion + dtype = utils.get_higher_dtype(a, b) + assert dtype is not None + if utils.is_complex_dtype(dtype): + dtype = torch.complex128 + else: + dtype = torch.float64 + + # Float power has the following contiguous cast behavior to be + # consistent with its C++ impl + a = _maybe_convert_to_dtype(a, dtype) + b = _maybe_convert_to_dtype(b, dtype) + + a, b = _maybe_broadcast(a, b) + return pow(a, b) + + +# >>> a = torch.tensor(-0.2500, dtype=torch.float64) +# tensor(-0.250000000000000, dtype=torch.float64) +# +# >>> b = torch.tensor(-0.0010, dtype=torch.float64) +# tensor(-0.001000000000000, dtype=torch.float64) +# +# Note: In this case, casting float to double will expand the float mantissa with zeros, +# while creating a double generates a distinct mantissa. +# >>> torch.tensor(-0.001).to(dtype=torch.float64) +# tensor(-0.001000000047497, dtype=torch.float64) +# +# Floor Division +# The difference is caused because torch.remainder(a, b) = -0.001. +# +# >>> torch.floor(torch.true_divide(a, b)) +# tensor(250., dtype=torch.float64) +# +# >>> torch.div(a, b, rounding_mode='floor') +# tensor(249., dtype=torch.float64) +# +# Definition: a // b = (a - remainder(a, b)) / b +# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b) +# tensor(249., dtype=torch.float64) +# +# For reference, see CPython's implementation: +# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, + should_register_decomposition=False, +) +def floor_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + # Wrap scalars because some references only accept tensor arguments. 
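+ # When only one operand is a Python scalar it is wrapped with the dtype and device of
+ # the tensor operand; two tensors on different devices are reconciled by moving b to
+ # a's device, except that a CPU `a` with a non-CPU `b` raises a RuntimeError below.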
+ if isinstance(a, Number) and isinstance(b, Number): + a = scalar_tensor(a) + b = scalar_tensor(b) + elif isinstance(b, Number) and isinstance(a, Tensor): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(a, Number) and isinstance(b, Tensor): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device: + if a.device == torch.device("cpu"): + msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!" + raise RuntimeError(msg) + else: + b = prims.device_put(b, device=a.device) + + assert isinstance(a, Tensor) and isinstance(b, Tensor) + dtype = a.dtype + if utils.is_float_dtype(dtype): + return _floor_divide_float(a, b) + elif utils.is_integer_dtype(dtype): + return _floor_divide_integer(a, b) + else: + torch._check(False, lambda: f"{dtype} not supported for floor_divide") + + +def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor: + a, b = _maybe_broadcast(a, b) + + if not a.dtype.is_signed: + return prims.div(a, b) + + # Convert truncation to flooring: + offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0) + return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype) + + +def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: + mod = fmod(a, b) + div = true_divide(sub(a, mod), b) + + # Ensure that the remainder has the same sign as denominator + different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0)) + non_zero_remainder = ne(mod, 0) + mask = bitwise_and(non_zero_remainder, different_signed_inputs) + div = where(mask, sub(div, 1), div) + + # Map quotient to nearest integer value + floor_div = floor(div) + mask = gt(sub(div, floor_div), 0.5) + floor_div = where(mask, add(floor_div, 1), floor_div) + + basic_div = true_divide(a, b) + zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device) + + # If quotient is zero, copy signbit from true_divide quotient + floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div)) + + # If denominator is zero, then follow true_divide behavior + return where(ne(b, 0), floor_div, basic_div) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=True, +) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + + +@register_decomposition(aten.frexp) +@out_wrapper("mantissa", "exponent") +def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]: + return torch.return_types.frexp(prims.frexp(self)) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + + +@_make_elementwise_binary_reference( + 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: + input_eq_zero = torch.eq(input, 0) + input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input)) + zeros_and_ones = torch.where(input_lt_zero, 0, 1) + output = torch.where(input_eq_zero, values, zeros_and_ones) + return output + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) + + +def _check_close_args( + name: str, + a: TensorLikeType, + b: TensorLikeType, + rtol: float, + atol: float, +) -> None: + torch._check_value( + a.dtype == b.dtype, + lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!", + ) + torch._check( + rtol >= 0, + lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!", + ) + torch._check( + atol >= 0, + lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!", + ) + + +# CompositeImplicitAutograd - don't register decomp +def isclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> TensorLikeType: + _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol) + + close = eq(a, b) + if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)): + close = logical_or(close, logical_and(isnan(a), isnan(b))) + + # Note: In case of zero tolerances the closeness inequality degenerates to an equality check. + # In this case, the short-circuit prevents false positives as detailed in the paragraph below. + if atol == 0 and rtol == 0: + return close + + # Note [closeness error computation] + # atol and rtol are provided as doubles, so the computation + # rtol * other will produce a float or complex tensor. + # When the difference (self - other) is compared to it then the + # tensor representing the difference will also be cast to float or complex. + # However, since (self - other) in uint8 is very likely to produce a + # negative value, this moves the cast forward so the difference is + # always computed in a float or complex type. 
+ # If the values of the integer tensors cannot be exactly represented + # by the default scalar type then this may cause an incorrect result. + if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype): + a = prims.convert_element_type(a, torch.get_default_dtype()) + b = prims.convert_element_type(b, torch.get_default_dtype()) + + allowed_error = add(atol, abs(mul(b, rtol))) + actual_error = abs(sub(a, b)) + + # Computes finite closeness + result = logical_or( + close, logical_and(isfinite(actual_error), le(actual_error, allowed_error)) + ) + + return result + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): + dtype = a.dtype + # promoting to int32 to maintain 100% consistency with C++ and to + # prevent overflow in case of int8 and int16 + promote_to_int = dtype in (torch.int8, torch.int16) + if promote_to_int: + a = prims.convert_element_type(a, torch.int32) + b = prims.convert_element_type(b, torch.int32) + + g = torch.gcd(a, b) + # Avoid division by zero in case gcd(0, 0) == 0 + g = torch.where(g == 0, 1, g) + res = torch.abs(prims.div(a, g) * b) + return res if not promote_to_int else prims.convert_element_type(res, dtype) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + # Nb. this implementation does not distribute the gradients evenly when a == b + mask = torch.real(a) >= torch.real(b) + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and( + torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b) + ) + if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype): + # are you wondering what this bunch of codes are for? edge cases! + neg_min_mask = torch.real(min_) < 0 + inf_vals = torch.where( + neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_)) + ) + non_nan_vals = torch.where( + inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_)) + ) + # the type for full_like does not include tensor yet + nan_mask = torch.isnan(min_) + return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals) # type: ignore[call-overload] + else: + return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_))) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + torch._check( + not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)), + lambda: "logaddexp2 doesn't support complex dtypes", + ) + # Nb. 
this implementation does not distribute the gradients evenly when a == b + mask = a >= b + max_ = torch.where(mask, a, b) + min_ = torch.where(mask, b, a) + inf_mask = torch.logical_and(torch.isinf(a), a == b) + inv_log_2 = 1.0 / math.log(2) + result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2 + return torch.where(inf_mask, a, result) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_and(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a & b + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) +def logical_not(a: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + return a == 0 + return ~a + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return bitwise_or(a, b) + + +# TODO: skip unnecessary conversion of long to float +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_xor(a: TensorLikeType, b: TensorLikeType): + if not utils.is_boolean_dtype(a.dtype): + a = a != 0 + if not utils.is_boolean_dtype(b.dtype): + b = b != 0 + return a ^ b + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, + supports_lhs_python_scalar=False, +) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + + +# reverse sub +@register_decomposition(aten.rsub) +@out_wrapper() +def rsub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + alpha: NumberType = 1, +): + if isinstance(a, Number): + msg = "Received a Number for the first argument, but expected a Tensor" + raise ValueError(msg) + + return torch.sub(b, a, alpha=alpha) + + +# TODO: consider refactoring this with add impl +# sub has its own implementation 
because it has an alpha argument +@register_decomposition(aten.sub) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def sub( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + *, + alpha: NumberType = 1, +): + """ + Reference implementation of torch.sub + """ + + a, b = _maybe_broadcast(a, b) + + if isinstance(a, TensorLike) and isinstance(b, TensorLike): + torch._check( + not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype), + lambda: ( + "Subtraction, the `-` operator, with two bool tensors is not supported. " + "Use the `^` or `logical_xor()` operator instead." + ), + ) + + if alpha != 1: + dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + python_type = utils.dtype_to_type(dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" + raise ValueError(msg) + if isinstance(b, torch.Tensor): + b = prims.mul(b, alpha) + else: + # Carefully not to use prims.mul if b is a scalar / symint. + # prims.mul always returns a tensor, + # which will mess with type promotion. + b = b * alpha + + output = prims.sub(a, b) + return handle_noncontiguous_outputs([a, b], output) + + +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) + + +@register_decomposition(aten.xlogy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
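+ # The computation below returns 0 wherever a == 0 (regardless of what log(b) would be
+ # there), while NaN values in b always propagate to the output.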
+ if isinstance(b, TensorLike) and isinstance(a, Number): + a = scalar_tensor(a, dtype=b.dtype, device=b.device) + elif isinstance(a, TensorLike) and isinstance(b, Number): + b = scalar_tensor(b, dtype=a.dtype, device=a.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( + a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] +): + dtype = utils.get_dtype(a) + if utils.is_integer_dtype(dtype): + return prims.div(a, b) + + return trunc(prims.div(a, b)) + + +# +# Elementwise Ternary References +# + + +@register_decomposition(aten.addcdiv) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def addcdiv( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcdiv + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 / tensor2 + + +@register_decomposition(aten.addcmul) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "tensor1", "tensor2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addcmul( + self: TensorLikeType, + tensor1: TensorLikeType, + tensor2: TensorLikeType, + *, + value: NumberType = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.addcmul + """ + if value is not None: + dtype = self.dtype # no scalars allowed, see add + python_type = utils.dtype_to_type(dtype) + torch._check_value( + utils.is_weakly_lesser_type(type(value), python_type), + lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!", + ) + + return self + value * tensor1 * tensor2 + + +@register_decomposition(aten.clamp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "min", "max"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def clamp( + a: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + # NOTE: grad behavior with implementation `where` is not consistent on `nan` + if min is None and max is None: + msg = "clamp called but both min and max are none!" + raise ValueError(msg) + if min is not None: + a_isnan = torch.isnan(a) + condition = torch.bitwise_or(torch.ge(a, min), a_isnan) # type: ignore[arg-type] + # we should also propagate `nan` coming from boundaries. However, that's + # not necessary since `ge` would already `False` when either operands has + # a `nan`. 
So this line below is redundant + # `condition = bitwise_and(condition, bitwise_not(isnan(min)))` + a = torch.where(condition, a, min) # type: ignore[arg-type] + if max is not None: + a_isnan = torch.isnan(a) + # same as above, no need to adjust `nan` from `max` + condition = torch.bitwise_or(torch.le(a, max), a_isnan) # type: ignore[arg-type] + a = torch.where(condition, a, max) # type: ignore[arg-type] + + return a + + +@register_decomposition(aten.clamp_min) +@out_wrapper() +def clamp_min( + self: TensorLikeType, + min: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, min=min) # type: ignore[arg-type] + + +@register_decomposition(aten.clamp_max) +@out_wrapper() +def clamp_max( + self: TensorLikeType, + max: Optional[TensorOrNumberLikeType] = None, +) -> TensorLikeType: + return torch.clamp(self, max=max) # type: ignore[arg-type] + + +# +# Conditional references +# + + +# https://pytorch.org/docs/stable/generated/torch.where.html +# TODO: implement alternate where +@register_decomposition(aten.where) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def where( + pred: Tensor, + a: Optional[TensorOrNumberLikeType] = None, + b: Optional[TensorOrNumberLikeType] = None, +): + """ """ + + if a is None or b is None: + raise NotImplementedError + + utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True) + torch._check( + pred.dtype is torch.bool, + lambda: f"expected predicate to be bool, got {pred.dtype}", + ) + + pred, a, b = _maybe_broadcast(pred, a, b) + return prims.where(pred, a, b) + + +# +# Data Movement References +# +@register_decomposition(aten.clone) +@out_wrapper() +def clone( + a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format +) -> TensorLikeType: + result = prims.clone(a, memory_format=memory_format) + return result + + +def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True): + if not allow_cross_device and a.device != b.device: + msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!" + raise RuntimeError(msg) + + return prims.copy_to(a, b) + + +@register_decomposition(aten.item) +def item(a: TensorLikeType) -> NumberType: + if a.numel() != 1: + msg = f"Can't convert a tensor with {a.numel()} elements to a number!" + raise ValueError(msg) + + # NOTE: explicit conversion is necessary for bool! + # See https://github.com/pytorch/pytorch/issues/78071 + number_type = utils.dtype_to_type(a.dtype) + return number_type(prims.item(a)) + + +# fast path when `to` returns an alias to input. 
This mimics the same function in aten +def _to_will_alias( + a: TensorLikeType, + device: Optional[DeviceLikeType] = None, + dtype: Optional[torch.dtype] = None, + copy: Optional[bool] = None, + layout: Optional[torch.layout] = None, + memory_format: Optional[torch.memory_format] = None, + pin_memory: Optional[bool] = False, + non_blocking: bool = False, # not using non_blocking +) -> bool: + return ( + not copy + and (device is None or a.device == device) + and (dtype is None or a.dtype == dtype) + and (layout is None or a.layout == layout) + # is_pinned issue #84925 + # and (pin_memory is None or pin_memory == a.is_pinned()) + and ( + memory_format is None + or memory_format == torch.preserve_format + or utils.is_contiguous_for_memory_format(a, memory_format=memory_format) + ) + ) + + +@singledispatch +def _to_dispatch(*args, **kwargs): + raise NotImplementedError + + +@_to_dispatch.register +def _to_device( + device: torch.device, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + kwargs = { + "device": device, + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_device_str( + device: str, + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + kwargs = { + "device": torch.device(device), + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_dtype( + dtype: torch.dtype, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + kwargs = { + "dtype": dtype, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +@_to_dispatch.register +def _to_other( + other: Tensor, + non_blocking: bool = False, + copy: bool = False, + memory_format: Optional[torch.memory_format] = None, +) -> dict[str, Any]: + device = other.device + dtype = other.dtype + layout = other.layout + # is_pinned issue #84925 + # pin_memory = other.is_pinned() + kwargs = { + "device": device, + "dtype": dtype, + "layout": layout, + "non_blocking": non_blocking, + "copy": copy, + "memory_format": memory_format, + } + return kwargs + + +# remove to_kwargs that is already present in `a` +def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict): + options_to_check = ["dtype", "device", "layout", "memory_format"] + # "device" option could be passed a str instead torch.device + if "device" in to_kwargs and isinstance(to_kwargs["device"], str): + to_kwargs["device"] = torch.device(to_kwargs["device"]) + + for kw in options_to_check: + if kw in to_kwargs: + if ( + (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format) + or ( + kw == "device" + and to_kwargs[kw].type == a.device.type + and ( + not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index + ) + ) + or ( + getattr(a, kw, None) == to_kwargs[kw] + ) # this also handles {"memory_format": None} + ): + to_kwargs.pop(kw) + + +def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType: + # handled dispatch via positional arguments + if len(args) != 0: + kwargs = _to_dispatch(*args, **kwargs) + + # TODO: is_pinned is not currently supported in refs or fake_tensor + # https://github.com/pytorch/pytorch/issues/84925 + assert "pin_memory" not 
in kwargs + _canonicalize_to_arguments(a, kwargs) + + if _to_will_alias(a, **kwargs): + return a + + copy = kwargs.pop("copy") if "copy" in kwargs else False + non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False + + # short-circuit to `prims.convert_element_type` when `to` is just a dtype change + if ( + (copy or (kwargs.get("dtype", a.dtype) != a.dtype)) + and (not non_blocking) + and ("memory_format" not in kwargs) + and ("device" not in kwargs) + and ("layout" not in kwargs) + # is_pinned issue #84925 + # and ("pin_memory" not in kwargs) + ): + return prims.convert_element_type(a, kwargs.get("dtype", a.dtype)) + + result = torch.empty_like(a, **kwargs) + # TODO: non_blocking should be handled by `copy_to` + copy_to(result, a) + return result + + +# +# Reduction references +# + + +def _reduction( + a: TensorLikeType, + prim: Callable, + *, + has_identity: bool = True, + accepts_dim_tuple: bool = True, # to handle min/argmin that accept single dim only + dims: Optional[DimsType] = None, + keepdims: bool = False, + dtype: Optional[torch.dtype] = None, # should be specified for ops that support it + out: Optional[Tensor] = None, + output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND, +) -> TensorLikeType: # it is usually SAME, but I want + # ref writers to actually think about what to put here + assert isinstance(a, TensorLike) + if a.ndim > 64: + raise RuntimeError( + f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + + if out is not None: + assert isinstance(out, TensorLike) + if dtype is not None: + # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms + if dtype != out.dtype: + raise RuntimeError( + "dtype argument and out dtype must match in reduction" + ) + if not accepts_dim_tuple: + assert dims is None or isinstance(dims, Dim) + if isinstance(dims, Dim): + dims = (dims,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dims) + if not has_identity: + valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims) + if not valid_shape: + raise RuntimeError( + "reducing over zero-size dimension for reduction operation without identity" + ) + computation_dtype, result_dtype = utils.reduction_dtypes( + a, output_dtype_kind, dtype + ) + a = _maybe_convert_to_dtype(a, computation_dtype) # type: ignore[method-assign] + result = prim(a, dims) + if keepdims: + output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)] + broadcast_dims = [i for i in range(a.ndim) if i not in dims] + result = prims.broadcast_in_dim(result, output_shape, broadcast_dims) + + if out is not None: + assert result_dtype is not None + if dtype is not None and result_dtype != out.dtype: + raise RuntimeError( + "Expected the dtype of reduction result and out to match" + ) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + + if result.dtype != result_dtype and result_dtype is not None: + result = prims.convert_element_type(result, result_dtype) + + return result + + +def _make_copy_from_view(fn): + """ + Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. 
torch.diagonal_copy) + """ + aten_fn = getattr(aten, fn.__name__) + annotations = getattr(fn, "__annotations__", {}) + fn = out_wrapper()(aten_fn) + + @wraps(fn) + def _fn(*args, out=None, **kwargs): + result = fn(*args, out=out, **kwargs) + if out is not None: + return result + + return pytree.tree_map( + lambda x: x.clone(memory_format=torch.contiguous_format), + result, + ) + + copy_name = f"{fn.__name__}_copy" + _fn.__name__ = copy_name + _fn.__annotations__.update(annotations) + register_decomposition(getattr(aten, copy_name))(_fn) + return _fn + + +@register_decomposition(aten.all) +@out_wrapper() +def all( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim)) + + if a.dtype == torch.uint8: + result = result.to(dtype=torch.uint8) + + return result + + +@register_decomposition(aten.any) +@out_wrapper() +def any( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, +) -> TensorLikeType: + a_ = _maybe_convert_to_dtype(a, torch.bool) + if isinstance(dim, (list, tuple)) and len(dim) == 0: + result = a_.clone() + else: + result = a_.sum(dim=dim, keepdim=keepdim).ne(False) + + # Preserves uint8 -- probably a legacy mask thing + if a.dtype is torch.uint8: + return prims.convert_element_type(result, torch.uint8) + + return result + + +@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out]) +def sum( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def sum_to_size( + a: Tensor, + *shape, +) -> Tensor: + shape = utils.extract_shape_from_varargs(shape, validate=False) + torch._check( + utils.is_expandable_to(shape, a.shape), + lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"', + ) + # In ATen scalar tensors are sent through sum and the result is returned as + # type promoted + if utils.is_same_shape(shape, a.shape) and len(shape) > 0: + return prims.view_of(a) + leading_dims = a.ndim - len(shape) + reduce_dims = tuple(range(leading_dims)) + tuple( + i + for i in range(leading_dims, len(shape)) + if shape[i - leading_dims] == 1 and a.shape[i] != 1 + ) + return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None) + + +@register_decomposition(aten.prod) +def prod( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + keepdim: bool = False, + *, + dtype=None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + if dtype is None: + if out is not None: + dtype = out.dtype + elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype): + dtype = torch.int64 + else: + dtype = a.dtype + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + return _reduction( + a, + prims.prod, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=out, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amin) +def amin( + a: 
TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amin, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +@register_decomposition(aten.amax) +def amax( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + return _reduction( + a, + prims.amax, + dims=dim, + keepdims=keepdim, + dtype=None, + out=out, + has_identity=False, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME, + ) + + +def _dim_var_dispatch(dim=None, unbiased=None): + # There's the following overload of torch.var: + # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + # We need to explicitly convert bool dims to unbiased arg + if unbiased is None and isinstance(dim, bool): + unbiased = dim + dim = None + return dim, unbiased + + +@register_decomposition(aten.var) +@out_wrapper() +def var( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + + result = _reduction( + a, + partial(prims.var, correction=correction), + dims=dim, + keepdims=keepdim, + dtype=None, + out=None, + has_identity=True, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, + ) + return result + + +@register_decomposition(aten.std) +@out_wrapper() +def std( + a: TensorLikeType, + dim: Union[Optional[int], Optional[list[int]]] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +) -> TensorLikeType: + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var = torch.var(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return _maybe_convert_to_dtype(a_std, dtype) + + +@register_decomposition(aten.mean) +def mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype=None, + out=None, +) -> TensorLikeType: + # reduces over all dimensions if dim=() is passed + if dim == () or dim == []: + dim = None + orig_dtype = dtype + if dtype is None: + dtype = a.dtype + result = _reduction( + a, + prims.sum, + dims=dim, + keepdims=keepdim, + dtype=dtype, + out=None, + output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE, + ) + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: ( + f"mean(): could not infer output dtype. " + f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either " + f"a floating point or complex dtype. 
Got: {dtype}" + ), + ) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + dims = utils.reduction_dims(a.shape, dim) # type: ignore[arg-type] + nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1) + result = true_divide(result, nelem) + result_dtype = a.dtype if dtype is None else dtype + result = _maybe_convert_to_dtype(result, result_dtype) # type: ignore[method-assign] + if out is not None: + assert isinstance(out, TensorLike) + out = _maybe_resize_out(out, result.shape) + return _safe_copy_out(copy_from=result, copy_to=out) # type: ignore[arg-type] + return result + + +@register_decomposition(aten.std_mean) +@out_wrapper("out0", "out1") +def std_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + *, + unbiased: Optional[bool] = None, + keepdim: bool = False, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + correction = utils.set_correction(unbiased, correction) + opmath_dtype, dtype = utils.reduction_dtypes( + a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT + ) + original_dtype = a.dtype + a = _maybe_convert_to_dtype(a, opmath_dtype) + a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim) + a_std = torch.sqrt(a_var) + assert dtype is not None + return ( + _maybe_convert_to_dtype(a_std, dtype), + _maybe_convert_to_dtype(a_mean, original_dtype), + ) + + +@register_decomposition(aten.var_mean) +@out_wrapper("out0", "out1") +def var_mean( + a: TensorLikeType, + dim: Optional[DimsType] = None, + unbiased: Optional[bool] = None, + keepdim: bool = False, + *, + correction: Optional[NumberType] = None, +): + dim, unbiased = _dim_var_dispatch(dim, unbiased) + v = var(a, dim, unbiased, keepdim, correction=correction) + m = mean(a, dim, keepdim) + return v, m + + +@register_decomposition(aten.addr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "vec1", "vec2"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def addr( + self: TensorLikeType, + vec1: TensorLikeType, + vec2: TensorLikeType, + *, + beta: NumberType = 1, + alpha: NumberType = 1, +) -> TensorLikeType: + torch._check( + vec1.ndim == 1, + lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D", + ) + torch._check( + vec2.ndim == 1, + lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D", + ) + for arg, arg_name in ((alpha, "alpha"), (beta, "beta")): + if isinstance(arg, bool): + torch._check( + utils.is_boolean_dtype(self.dtype) + and utils.is_boolean_dtype(vec1.dtype) + and utils.is_boolean_dtype(vec2.dtype), + lambda: f"Boolean {arg_name} only supported for Boolean results.", + ) + self = self.expand(vec1.shape[0], vec2.shape[0]) + if utils.is_boolean_dtype(self.dtype): + # Integers are accepted for booleans + torch._check( + is_weakly_lesser_type(type(beta), int), + lambda: f"expected bool/int beta but got {type(beta)}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), int), + lambda: f"expected bool/int alpha but got {type(beta)}", + ) + if not beta: + return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False) + else: + return torch.logical_or( + self, + torch.outer(vec1, vec2) if alpha else torch.full_like(self, False), + ) + else: + torch._check( + is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert {type(beta)} to {self.dtype}", + ) + torch._check( + is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)), + lambda: f"cannot safely convert 
{type(alpha)} to {self.dtype}", + ) + if beta == 0: + # This means NaNs from self are dropped if beta is zero + return alpha * torch.outer(vec1, vec2) + else: + return beta * self + alpha * torch.outer(vec1, vec2) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_1d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_1d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_) + return res if len(res) > 1 else res[0] + + +# Helper function with assert to avoid MyPy error +# of incompatible type passed to unsqueeze +def _unsqueeze_atleast( + at_least_fn: Callable, dim: int, arg: TensorLikeType +) -> TensorLikeType: + arg_ = at_least_fn(arg) + assert isinstance(arg_, TensorLike) + return unsqueeze(arg_, dim) + + +# CompositeImplicitAutograd - don't register decomp +def atleast_2d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_2d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0) + res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +# CompositeImplicitAutograd - don't register decomp +def atleast_3d( + arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType +) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]: + """Reference implementation of :func:`torch.atleast_3d`.""" + if not args and isinstance(arg, collections.abc.Sequence): + args_ = arg + else: + assert not isinstance(arg, collections.abc.Sequence) + args_ = (arg,) + args + unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1) + res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_) + return res if len(res) > 1 else res[0] + + +def as_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = ( + storage_offset if storage_offset is not None else a.storage_offset() + ) + return prims.as_strided(a, size, stride, storage_offset_int) + + +@register_decomposition(aten.as_strided_scatter) +@out_wrapper() +def as_strided_scatter( + input: TensorLikeType, + src: TensorLikeType, + size: ShapeType, + stride: StrideType, + storage_offset: Optional[int] = None, +) -> TensorLikeType: + storage_offset_int = 0 if storage_offset is None else storage_offset + return prims.as_strided_scatter(input, src, size, stride, storage_offset_int) + + +def broadcast_shapes(*shapes) -> ShapeType: + return torch.Size(_broadcast_shapes(*shapes)) + + +@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta) +def broadcast_tensors(*tensors) -> list[TensorLikeType]: + if len(tensors) == 1 and not isinstance(tensors[0], Tensor): + tensors = tensors[0] + return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False)) + + +# CompositeImplicitAutograd - don't register decomp +def broadcast_to(a: TensorLikeType, size: 
ShapeType) -> TensorLikeType: + start = len(size) - len(a.shape) + dims = tuple(range(start, len(a.shape) + start)) + return prims.broadcast_in_dim(a, size, dims) + + +@register_decomposition(aten.cat) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("tensors",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, +) +def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + format = f + assert format is not None + return format + + if len(tensors) == 0: + msg = "cat expects at least one tensor, but received zero!" + raise ValueError(msg) + + for tensor in tensors: + assert isinstance(tensor, TensorLike) + + utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False) + + from torch.fx.experimental.symbolic_shapes import ( + guard_or_false, + guard_size_oblivious, + ) + + # This is a bit tricky. Naively, you would expect to just pick one + # arbitrary tensor and check that all tensors match this tensor. However, + # there is legacy behavior which says that if you have a 1-D empty tensor + # (0,), this is permissible. So you can't assume that all the tensors + # have same dimensionality, and you can't assume that the first tensor is + # the correct stencil. + # + # We'll implement this in a few passes. First, we will try to infer the + # ndim of the cat output. If this ndim != 1, then we know that all ndim = + # 1 inputs must be empty, or are errors. If this ndim == 1, then life + # is easy (the legacy special case coincides with regular handling). + # + # NB: The regular implementation of cat just filters out empty inputs, + # but we do it slightly different here for better handling for unbacked + # SymInts + + example = None + for i, t in enumerate(tensors): + if example is None: + if t.ndim != 1: + example = t + else: + if t.ndim != 1: + torch._check( + t.ndim == example.ndim, + lambda: "Number of dimensions of tensors must match. " + f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for " + f"tensor number {i} in the list", + ) + + if example is None: + # example is None if everything is 1-D. If so, just arbitrarily pick + # the first one + example = tensors[0] + + shape = example.shape + filtered = [] + for tensor_idx, tensor in enumerate(tensors): + if len(shape) != len(tensor.shape): + assert tensor.ndim == 1 # we've already checked this above + # Don't suggest the legacy behavior in the error message + torch._check( + # NB: it is not enough to simply assert that tensor.shape[0] == 0; + # this MUST be true even under guard size oblivious. + # Effectively, we must actually know that the shape is zero, + # passing an unbacked SymInt which we will defer a runtime + # assert on won't cut it. This is a policy decision (size + # oblivious semantics say that u0 tensors never are inferred + # to be zero size, even if they must be that for the cat to go + # through), and is load bearing for our Inductor lowerings + # (which assume that size oblivious tests are OK to determine + # if a shape is permissibly zero.) + guard_size_oblivious(tensor.shape[0] == 0), + lambda: f"Number of dimensions of tensors must match. 
" + f"Expected {example.ndim}-D tensors, but got 1-D for " + f"tensor number {tensor_idx} in the list", + ) + else: + # Remove inputs that are 1-D, zero size + if tensor.ndim == 1 and guard_or_false(tensor.shape[0] == 0): + continue + # Don't bother checking size match, prims.cat will handle it + filtered.append(tensor) + + memory_format = cat_compute_output_memory_format(tensors) + + if len(filtered) == 0: + t = tensors[0] + + # TODO: fix this to work with meta tensors + try: + # BUG? This looks like it wants to call builtins.any() but is + # actually calling .any() (in this file). Changing to builtins.any() + # causes tests to fail: + # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \ + # TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32 + requires_grad = bool(any(x.requires_grad for x in tensors)) # type: ignore[arg-type] + except Exception: + requires_grad = False # type: ignore[assignment] + + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + dim = utils.canonicalize_dim(filtered[0].ndim, dim) + utils.validate_idx(filtered[0].ndim, dim) + + return prims.cat(filtered, dim).clone(memory_format=memory_format) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def column_stack(tensors: TensorSequenceType) -> TensorLikeType: + aligned_tensors = tuple( + x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors + ) + return cat(aligned_tensors, 1) + + +def conj(input: TensorLikeType) -> TensorLikeType: + if not utils.is_complex_dtype(input.dtype): + return input + if input.is_sparse: + return torch.conj_physical(input) + return prims.conj(input) + + +# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp +@register_decomposition(aten.constant_pad_nd) +@out_wrapper() +def constant_pad_nd( + input: TensorLikeType, pad: list[int], value: NumberType = 0 +) -> TensorLikeType: + torch._check( + len(pad) % 2 == 0, + lambda: f"Length of pad must be even but instead it equals {len(pad)}", + ) + + input_sizes = input.shape + l_inp = len(input_sizes) + + l_pad = len(pad) // 2 + l_diff = l_inp - l_pad + + torch._check( + l_inp >= l_pad, + lambda: "Length of pad should be no more than twice the number of " + f"dimensions of the input. Pad length is {len(pad)} while the input has " + f"{l_inp} dimensions.", + ) + + c_input = input + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] < 0: + c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]) + + if pad[pad_idx + 1] < 0: + c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1]) + + # If all the pads are negative we can return the result. + # Avoid early exiting if all pads = 0 to prevent specialization on export. + # During export, raw if statements are specialized on the input, meaning + # that we lose a branch depending on the example input used to export. + # Here, this is either the case where all pads = 0, or the case where at + # least one pad > 0 and the rest are >= 0. + # Avoiding the early exit when all pads = 0 ensures we can export + # constant_pad_nd for cases when all pads >= 0. + # Note: if any pads are negative, this code specializes due to the if statements above. 
+ if builtins.all(p < 0 for p in pad): + return c_input.clone() + + new_shape = list(input_sizes[:l_diff]) + + for i in range(l_pad): + pad_idx = len(pad) - ((i + 1) * 2) + new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1] + torch._check( + new_dim > 0, + lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding " + f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, " + f"which is invalid. Check dimension {l_diff + i} of your input.", + ) + new_shape.append(new_dim) + + memory_format = utils.suggest_memory_format(input) + output = torch.empty( + new_shape, + dtype=input.dtype, + device=input.device, + requires_grad=input.requires_grad, + memory_format=memory_format, + ) + + if value == 0 and input.dtype == torch.bool: + value = False + # torch.fill isn't typed to allow complex values + output = torch.fill(output, value) # type: ignore[arg-type] + + c_output = output + for i in range(l_diff, l_inp): + pad_idx = 2 * (l_inp - i - 1) + if pad[pad_idx] >= 0: + c_output = c_output.narrow( + i, pad[pad_idx], c_output.shape[i] - pad[pad_idx] + ) + if pad[pad_idx + 1] >= 0: + c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1]) + + prims.copy_to(c_output, c_input) + return output + + +def contiguous( + a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format +) -> Tensor: + torch._check( + memory_format != torch.preserve_format, + lambda: "preserve memory format is unsupported by the contiguous operator", + ) + + # TODO: make logic consistent with aten contiguous + if definitely_contiguous_for_memory_format(a, memory_format=memory_format): + return a + + return torch.clone(a, memory_format=memory_format) + + +@out_wrapper() +def dstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList") + aligned_tensors = atleast_3d(*tensors) + return cat(aligned_tensors, 2) + + +@register_decomposition(aten.expand) +def expand(a: Tensor, *shape) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_or + + # NOTE: cannot use utils.extract_shape_from_varargs here + # because that also validates the shape, but the shape + # given to expand may be "invalid" + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = tuple(shape[0]) + + torch._check( + len(shape) >= len(a.shape), + lambda: "expand: the requested shape has too few dimensions!", + ) + + offset = len(shape) - len(a.shape) + shape_ = list(shape) + for idx, x in enumerate(a.shape): + offset_idx = idx + offset + requested_length = shape[offset_idx] + + # expand(in -> out) has 3 different semantics: + # 1) out == -1 -> size = in, stride unchanged + # 2) in == 1 -> size = out, stride = 0 + # 3) in == out -> size = in, stride unchanged + # + # the code below is written for unbacked semantics s.t. we assume unbacked symbols don't + # represent -1 unless explicitly specified, and the user is opting for case 2) or 3). + # the sym_or allows either case, but in the decomposition's current state, broadcast_in_dim() + # will either assume case 3) (via validate_shape() marking the expanded shape size-like), or will + # raise a data-dependent error trying to figure out if the stride is 0, requiring the user to manually + # select between the semantics of cases 2) and 3). 
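+ # guard_or_false falls back to False for sizes that are not statically known, so only a
+ # requested length known to be -1 keeps the input size; otherwise the length must be
+ # non-negative and either equal the input size or expand an input dimension of size 1.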
+ if guard_or_false(requested_length == -1): + shape_[offset_idx] = x + else: + torch._check( + sym_or(x == 1, requested_length == x), + lambda: f"expand: attempting to expand a dimension of length {x} -> {requested_length}!", + ) + torch._check(requested_length >= 0) + shape_[offset_idx] = requested_length + + # At this point shape must be valid + utils.validate_shape(shape_) + + return prims.broadcast_in_dim( + a, shape_, tuple(range(offset, len(a.shape) + offset)) + ) + + +# CompositeImplicitAutograd - don't register decomp +def expand_as(a: Tensor, b: Tensor) -> Tensor: + return a.expand(b.shape) + + +def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]: + if chunks <= 0: + msg = f"Expected at least one chunk, but got {chunks}!" + raise ValueError(msg) + + dim = utils.canonicalize_dim(a.ndim, dim) + length = a.shape[dim] + chunk_size = math.ceil(length / chunks) + full_chunks = math.floor(length / chunk_size) + tail_chunk_size = length % chunk_size + + result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)] + + if tail_chunk_size != 0: + result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size)) + + return tuple(result) + + +# Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless +# a 0D tensor is flattened, in which case it's returned in 1D) +# CompositeImplicitAutograd - don't register decomp +def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType: + start_dim = utils.canonicalize_dim(a.ndim, start_dim) + end_dim = utils.canonicalize_dim(a.ndim, end_dim) + + # Short-circuits on no-op + if start_dim == end_dim and a.ndim != 0: + return a + + # Tries to take a view + # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) + new_shape, _new_strides = prims._collapse_view_helper(a, start_dim, end_dim) + if new_shape is not None: + return prims.collapse_view(a, start_dim, end_dim) + + # Makes a copy if it can't make a view + return prims.collapse(a, start_dim, end_dim) + + +@register_decomposition(aten.flip) +@out_wrapper() +def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType: + if not isinstance(dims, tuple) and not isinstance(dims, list): + raise ValueError("dims has to be a sequence of ints") + dims = utils.canonicalize_dims(a.ndim, dims) # type: ignore[assignment] + utils.validate_no_repeating_dims(dims) + return prims.rev(a, dims) + + +# CompositeImplicitAutograd - don't register decomp +def fliplr(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 2: + raise RuntimeError("Input must be >= 2-d.") + + return flip(a, (1,)) + + +# CompositeImplicitAutograd - don't register decomp +def flipud(a: TensorLikeType) -> TensorLikeType: + if a.ndim < 1: + raise RuntimeError("Input must be >= 1-d.") + + return flip(a, (0,)) + + +# CompositeImplicitAutograd - don't register decomp +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + torch._check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + start = cast(int, start) + torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + torch._check(length >= 0, lambda: "narrow(): length must be 
non-negative.") + dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + torch._check_with( + IndexError, + -dim_length <= start and start <= dim_length, + lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})", + ) + if start < 0: + start = start + dim_length + torch._check( + start <= dim_length - length, + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) + new_shape = list(a.shape) + new_shape[dim] = length + return a.as_strided( + new_shape, a.stride(), a.storage_offset() + a.stride(dim) * start + ) + + +def _normalize( + a: Tensor, norm_dims: DimsType, eps: float +) -> tuple[Tensor, Tensor, Tensor]: + """Computes mean and 1/std of a tensor along norm_dims. + + Used as a helper function for normalization layers. + + Args: + a (Tensor): input tensor + norm_dims (DimsType): dimensions to normalize over + eps (float): epsilon for numerical stability + + Returns: + out (Tensor): normalized tensor. + mean (Tensor): mean of the tensor along norm_dims. + rstd (Tensor): 1/std of the tensor along norm_dims. + """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) + computation_dtype = utils.get_computation_dtype(a.dtype) + a_acc = _maybe_convert_to_dtype(a, computation_dtype) + assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean + biased_var, mean = torch.var_mean( + a_acc, dim=norm_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + out = (a_acc - mean) * rstd + return out, mean, rstd + + +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + computation_dtype = utils.get_computation_dtype(input.dtype) + input_acc = _maybe_convert_to_dtype(input, computation_dtype) + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input_acc, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims) + biased_var, mean = torch.var_mean( + input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True + ) + rstd = torch.rsqrt(biased_var + eps) + if input.device.type == "cpu" and weight is not None: + weight_reshaped = torch.reshape( + weight, [1, num_groups, num_channels // num_groups, 1] + ) + w = rstd * weight_reshaped + b = -mean * w + if bias is not None: + bias_reshaped = torch.reshape( + bias, [1, num_groups, num_channels // num_groups, 1] + ) + b = b + bias_reshaped + w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1]) + b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1]) + broadcast_dims = list(range(2, input.ndim)) + unsqueeze_w = 
_unsqueeze_multiple(w, broadcast_dims) + unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims) + out = input_acc * unsqueeze_w + unsqueeze_b + else: + out = (input_reshaped - mean) * rstd + out = out.view(input.shape) + broadcast_dims = [0] + list(range(2, input.ndim)) + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + out = out * unsqueeze_weight + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = torch.squeeze(mean, reduction_dims) + rstd = torch.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + +@register_decomposition(aten.native_layer_norm) +@out_wrapper("out0", "out1", "out2") +def native_layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, +) -> tuple[Tensor, Tensor, Tensor]: + normalized_ndim = len(normalized_shape) + torch._check( + normalized_ndim >= 1, + lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., " + + "containing at least one element, but got normalized_shape = " + + str(normalized_shape), + ) + # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False + # while torch.Size([1, 2, 3]) == (1, 2, 3) is True + # therefore we use tuple(normalized_shape) + torch._check( + weight is None or weight.shape == tuple(normalized_shape), + lambda: "Expected weight to be of same shape as normalized_shape, but got " + + "weight of shape " + + str(weight.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + bias is None or bias.shape == tuple(normalized_shape), + lambda: "Expected bias to be of same shape as normalized_shape, but got " + + "bias of shape " + + str(bias.shape) # type: ignore[union-attr] + + " and normalized_shape = " + + str(normalized_shape), + ) + torch._check( + input.ndim >= normalized_ndim + and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape), + lambda: "Given normalized_shape=" + + str(normalized_shape) + + ", expected input with shape " + + str(normalized_shape) + + ", but got input of size " + + str(input.shape), + ) + + input = contiguous(input) + if weight is not None: + weight = contiguous(weight) + if bias is not None: + bias = contiguous(bias) + + axis = input.ndim - normalized_ndim + reduction_dims = list(range(axis, input.ndim)) + out, mean, rstd = _normalize(input, reduction_dims, eps) + + if weight is None and bias is not None: + out = out + bias + elif weight is not None and bias is None: + out = out * weight + elif weight is not None and bias is not None: + out = out * weight + bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: ignore[assignment] + if input.device.type in ["cpu", "mtia"]: + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + return (out, mean, rstd) + + +@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default) +def native_layer_norm_fake(fake_mode, func, *args, **kwargs): + return native_layer_norm(*args) + + +# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode. 
+# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu +@register_decomposition(aten.permute) +def permute(a: TensorLikeType, *dims) -> TensorLikeType: + _permutation = utils.canonicalize_dims( + a.ndim, utils.extract_dims_from_varargs(dims) + ) + return prims.transpose(a, _permutation) + + +@register_decomposition(aten.renorm) +@out_wrapper() +def renorm( + input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType +) -> TensorLikeType: + torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued") + torch._check(p > 0, lambda: "renorm: non-positive norm not supported") + torch._check( + not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued" + ) + torch._check( + maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}" + ) + ndim = input.ndim + torch._check( + ndim > 1, + lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions", + ) + + dim = utils.canonicalize_dim(ndim, dim) + reduce_dims = list(range(ndim)) + del reduce_dims[dim] + + # For half and bfloat16, calculate norm in float precision then cast + # normalization factor to half + acc_type = utils.get_computation_dtype(input.dtype) + if acc_type != input.dtype: + norm = torch.linalg.vector_norm( + input, p, reduce_dims, keepdim=True, dtype=acc_type + ) + else: + norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True) + + eps = 1e-7 + norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) + if acc_type != input.dtype: + norm_factor = prims.convert_element_type(norm_factor, input.dtype) + return (input * norm_factor).contiguous() + + +# CompositeImplicitAutograd - don't register decomp +@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd) +def stft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: Optional[bool] = None, + return_complex: Optional[bool] = None, + align_to_window: Optional[bool] = None, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"stft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + torch._check( + not center or align_to_window is None, + "stft only supports align_to_window for center = False.", + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + if return_complex is None: + return_complex_ = input.is_complex() or ( + window is not None and utils.is_complex_dtype(window.dtype) + ) + torch._check( + return_complex_, + ( + "stft requires the return_complex parameter be given for real inputs, " + + "and will further require that return_complex=True in a future PyTorch release." 
+ ), + ) + else: + return_complex_ = return_complex + + torch._check( + utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype), + lambda: "stft expected a tensor of floating point or complex values", + ) + torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor") + + original_ndim = input.ndim + if original_ndim == 1: + input = input.unsqueeze(0) + + if center: + extra_dims = 3 - input.ndim + pad_amount = n_fft // 2 + extended_shape = [*itertools.repeat(1, extra_dims), *input.shape] + input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode) + input = input.view(input.size()[extra_dims:]) + + length = input.size(1) + torch._check( + 0 < n_fft <= length, + lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}", + ) + torch._check( + hop_length_ > 0, + lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}", + ) + torch._check( + 0 < win_length_ <= n_fft, + lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}", + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: ( + f"expected a 1D window tensor of size equal to win_length={win_length_}, " + + f"but got window with size {window.shape}" # type: ignore[union-attr] + ), + ) + + if win_length_ < n_fft: + if window is None: + window = torch.ones(win_length_, dtype=input.dtype, device=input.device) + left = (n_fft - win_length_) // 2 + window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left]) + + if not center and align_to_window: + input_pad_amount = (n_fft - win_length_) // 2 + input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode) + + input = input.unfold(dimension=-1, size=n_fft, step=hop_length_) + + if window is not None: + input = input * window + + complex_fft = utils.is_complex_dtype(input.dtype) + onesided = onesided if onesided is not None else not complex_fft + norm = "ortho" if normalized else None + if onesided: + torch._check( + not complex_fft, + lambda: "Cannot have onesided output if window or input is complex", + ) + out = torch.fft.rfft(input, dim=-1, norm=norm) + else: + out = torch.fft.fft(input, dim=-1, norm=norm) + + out.transpose_(1, 2) + + if original_ndim == 1: + out = out.squeeze_(0) + + return out if return_complex_ else torch.view_as_real(out) + + +# CompositeImplicitAutograd - don't register decomp +@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def istft( + input: Tensor, + n_fft: int, + hop_length: Optional[int] = None, + win_length: Optional[int] = None, + window: Optional[Tensor] = None, + center: bool = True, + normalized: bool = False, + onesided: Optional[bool] = None, + length: Optional[int] = None, + return_complex=False, +) -> Tensor: + torch._check( + window is None or window.device == input.device, + lambda: ( + f"istft input and window must be on the same device but got self on {input.device}" + + f" and window on {window.device}" # type: ignore[union-attr] + ), + ) + + hop_length_ = hop_length if hop_length is not None else n_fft // 4 + win_length_ = win_length if win_length is not None else n_fft + + torch._check( + utils.is_complex_dtype(input.dtype), + lambda: ( + "istft input and window must be on the same device but got self on " + + f"{input.device} and window on {window.device}" # type: ignore[union-attr] + ), + ) + n_frames = input.size(-1) + fft_size = input.size(-2) + + expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1) + torch._check(input.numel() > 0, lambda: "istft input 
tensor cannot be empty") + torch._check( + 2 <= input.ndim <= 3, + lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}", + ) + onesided_ = onesided if onesided is not None else fft_size != n_fft + + if onesided_: + torch._check( + n_fft // 2 + 1 == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}" + ), + ) + else: + torch._check( + n_fft == fft_size, + lambda: ( + "istft expected the frequency dimension (3rd to the last) of the input tensor " + + "to match n_fft when onesided=False, but got {fft_size}", + ), + ) + + torch._check( + 0 < hop_length_ <= win_length_, + lambda: "istft expected 0 < hop_length <= win_length", + ) + torch._check( + 0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft" + ) + torch._check( + window is None or window.shape == (win_length_,), + lambda: "Invalid window shape. window has to be 1D and length of `win_length`", + ) + + if window is None: + real_dtype = utils.corresponding_real_dtype(input.dtype) + window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device) + else: + window_ = window + + if win_length_ != n_fft: + left = (n_fft - win_length_) // 2 + window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0) + + original_ndim = input.ndim + if input.ndim == 2: + input = input.unsqueeze(0) + + input = input.transpose(1, 2) + norm = "ortho" if normalized else None + if return_complex: + torch._check( + not onesided_, + lambda: "cannot have onesided output if window or input is complex", + ) + input = torch.fft.ifft(input, dim=-1, norm=norm) + else: + torch._check( + window is None or not utils.is_complex_dtype(window.dtype), + lambda: "Complex windows are incompatible with return_complex=False", + ) + if not onesided_: + input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1) + input = torch.fft.irfft(input, dim=-1, norm=norm) + + assert input.size(2) == n_fft + + y_tmp = input * window_.view([1, 1, n_fft]) + y = aten.unfold_backward( + y_tmp, + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + window_envelop = aten.unfold_backward( + window_.pow(2).expand((1, n_frames, n_fft)), + input_sizes=(y_tmp.size(0), expected_output_signal_len), + dim=1, + size=n_fft, + step=hop_length_, + ) + + assert expected_output_signal_len == y.size(1) + assert expected_output_signal_len == window_envelop.size(1) + + start = n_fft // 2 if center else 0 + if length is not None: + end = start + length + elif center: + end = expected_output_signal_len - n_fft // 2 + else: + end = expected_output_signal_len + + length = max(0, end - start) + y = y.narrow(dim=1, start=start, length=length) + window_envelop = window_envelop.narrow(dim=1, start=start, length=length) + + y = y / window_envelop + if original_ndim == 2: + y = y.squeeze(0) + + if end > expected_output_signal_len: + warnings.warn( + "The length of signal is shorter than the length parameter. Result is being " + + "padded with zeros in the tail. 
Please check your center and hop_length settings" + ) + y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0) + return y + + +# Get the new shape and stride after applying unfold to an input tensor +def _get_unfold_shape_stride( + a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int +): + a_ndim = len(a_shape) + dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True) + max_size = 1 if a_ndim == 0 else a_shape[dim] + last_stride = 1 if a_ndim == 0 else a_stride[dim] + + torch._check( + size <= max_size, + lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}", + ) + + torch._check( + step > 0, + lambda: f"Step is {step} but must be > 0", + ) + + shape = list(a_shape) + strides = list(a_stride) + shape.append(size) + strides.append(last_stride) + if dim < a_ndim: + shape[dim] = (shape[dim] - size) // step + 1 + strides[dim] *= step + return shape, strides + + +@register_decomposition(aten.repeat) +@out_wrapper() +def repeat(a: Tensor, *repeat_shape) -> Tensor: + repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False) + torch._check( + len(repeat_shape) >= len(a.shape), + lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor", + ) + + if len(repeat_shape) == 0: + return torch.clone(a) + + num_new_dimensions = len(repeat_shape) - a.ndim + padded_shape = [1] * num_new_dimensions + for dim_size in a.shape: + padded_shape.append(dim_size) + + target_shape = tuple( + padded_size * repeat_size + for padded_size, repeat_size in zip(padded_shape, repeat_shape) + ) + + # return an empty tensor if one of the repeat_shape dimensions is zero + if 0 in repeat_shape: + return torch.empty( + target_shape, + dtype=a.dtype, + device=a.device, + requires_grad=a.requires_grad, + memory_format=utils.suggest_memory_format(a), + ) + + urtensor_shape = target_shape + urtensor_stride = utils.make_contiguous_strides_for(target_shape) + for dim, dim_size in enumerate(padded_shape): + # repeat each dimension by using unfold_copy operation + urtensor_shape, urtensor_stride = _get_unfold_shape_stride( + urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1) + ) + + # derive permute order by sorting urtensor strides + enumerated_stride = list(enumerate(urtensor_stride)) + enumerated_stride.sort(key=operator.itemgetter(1), reverse=True) + permute_order, _sorted_stride = zip(*enumerated_stride) + + # add new and expand dimensions according to urtensor + repeat_xtensor = a.expand(urtensor_shape) + + # clone tensor to concretize expanded dimensions + cloned_result = torch.clone(repeat_xtensor) + + # transpose axis so strides are in sorted order + permuted_result = cloned_result.permute(permute_order) + + # reshape to get contiguous tensor with correct target shape + return permuted_result.reshape(target_shape) + + +def _reshape_view_helper_core_alg( + a: TensorLikeType, shape, allow_copy: bool +) -> TensorLikeType: + # NOTE [Reshape Algorithm] + # This algorithm works by attempting to greedily construct the desired dimensions in + # the output shape, left to right. It does this by, conceptually, accumulating + # dimensions of the original tensor, also left to right, until the dimension + # can be constructed using prims.split_dim. + # The algorithm also has special handling for tail squeezes/unsqueezes, like + # if a reshape from (5, 5) to (5, 5, 1) or vice versa. 
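+    #
+    # A small worked example of the greedy construction: reshaping (2, 3, 4) -> (6, 4)
+    # accumulates dims 0 and 1 (2 * 3 == 6) and flattens them, while reshaping
+    # (6, 4) -> (2, 3, 4) needs no flatten and only splits dim 0 into (2, 3)
+    # via prims.split_dim.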
+ # + # This algorithm does not flatten the original tensor and then split dims as appropriate + # because that would create copies more often than this algorithm. flatten is the only + # operation below which can create a view or a copy, and while it prefers creating + # views it may sometimes create a copy if the tensor's strides do not permit a view. + # As a result, this algorithm tries to minimize flattening. + # + # Note that a better version of this algorithm may exist. Regions which could be + # flattened without creating a copy can be identified in advance, and that might + # allow fewer flatten calls or faster short-circuiting to make a copy. + idx = 0 + a_ = a + for length in shape: + # Handles tail unsqueezes + if idx >= a_.ndim: + assert length == 1 + last_dim = a_.ndim - 1 + # NOTE: using split_dim instead of unsqueeze may seem silly here, + # but it's necessary to get the strides correct + a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim]) + idx = idx + 1 + continue + + # Skips dimensions that are already the correct length + if length == a_.shape[idx]: + idx = idx + 1 + continue + + accum = a_.shape[idx] + end = idx + while accum % length != 0: + end += 1 + accum *= a_.shape[end] + if end != idx: + # NOTE: in this case multiple dimensions must be flatten to create the desired dimension + # This flattening is why reshape sometimes creates a copy -- because flattening + # may return a view of a copy + + # Checks if collapse can be a view and short-circuits to copying reshape if it can't + new_shape, _new_strides = prims._collapse_view_helper(a_, idx, end) + if new_shape is None: + if allow_copy: + return prims.reshape(a, shape) + + msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!" + raise ValueError(msg) + + a_ = flatten(a_, idx, end) + + # Splits the (possibly flattened) dimension to create the desired dim length. + # guard_or_true is safe due to the tail unsqueeze routine. 
+ if accum != length: + a_ = prims.split_dim(a_, idx, length) + + idx = idx + 1 + + # Squeezes tail + while idx < a_.ndim: + torch._check( + a_.shape[idx] == 1, + lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}", + ) + a_ = squeeze(a_, idx) + + if a_ is a: + return prims.view_of(a) + else: + return a_ + + +def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType: + # Creates a valid shape + shape = utils.extract_shape_from_varargs(shape, validate=False) + # Reshape may be given a shape with a -1 length + # This indicates that the dimension's length should be inferred + shape = utils.infer_size(shape, a.numel()) + + # Special-cases tensors with no elements + if a.numel() == 0: + return as_strided(a, shape, utils.make_contiguous_strides_for(shape)) + + # Special-cases reshaping zero dim tensors + if a.ndim == 0: + _a = a + for length in shape: + assert length == 1 + _a = unsqueeze(_a, -1) + if _a is a: + return prims.view_of(a) + else: + return _a + + # Special-cases reshaping to zero dim tensors + if len(shape) == 0: + _a = a + for length in a.shape: + assert length == 1 + _a = squeeze(_a, -1) + if _a is a: + return prims.view_of(a) + else: + return _a + + if definitely_contiguous(a): + # Special-cases for nd_to_1d + if len(shape) == 1 and a.ndim > 1: + return torch.as_strided(a, [a.numel()], [1]) + # Special-cases for 1d_to_2d + if len(shape) == 2 and a.ndim == 1: + dim0 = shape[0] + dim1 = shape[1] + return torch.as_strided(a, [dim0, dim1], [dim1, 1]) + + shape_numel = reduce(operator.mul, shape, 1) + torch._check( + a.numel() == shape_numel, + f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!", + ) + + # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape + return _reshape_view_helper_core_alg(a, shape, allow_copy) + + +# CompositeImplicitAutograd - don't register decomp +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)) Function call +# torch.reshape doesn't support unpacked shapes +def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=True) + + +# CompositeImplicitAutograd - don't register decomp +def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.reshape(other.size()) + + +@register_decomposition(aten.roll) +@out_wrapper() +def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType: + """Reference implementation of :func:`torch.roll`.""" + dims = utils.canonicalize_dims(a.ndim, dims) + # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1 + if not isinstance(shifts, Iterable): + shifts = (shifts,) + if not isinstance(dims, Iterable): + dims = (dims,) + + # Avoid modulo by zero + if a.numel() == 0: + # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors + return a.clone() + + if a.dim() == 0 and len(dims) > 0: + raise IndexError( + f"Dimension specified as {dims[0]} but tensor has no dimensions" + ) + + len_shifts = len(shifts) + len_dims = len(dims) + if len_shifts != 1 or len_dims != 1: + if len_shifts == 0: + raise RuntimeError("`shifts` required") + # Takes care of the case when dims is not specified (default) + # By default, the tensor is flattened before shifting, after which the original shape is restored + if len_dims == 0 and len_shifts == 1: + return torch.roll(torch.flatten(a), shifts, 
0).view(a.shape) + if len_shifts != len_dims: + raise RuntimeError( + f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}" + ) + assert len_dims > 1 + tail_shifts = shifts[1:] + tail_dims = dims[1:] + first_dim_rolled = torch.roll(a, (shifts[0],), dims[0]) + return torch.roll(first_dim_rolled, tail_shifts, tail_dims) + + # This path is taken when only one dimension is rolled + # For example to get `first_dim_rolled` above + dim = dims[0] + size = a.shape[dim] + start = (size - shifts[0]) % size + idx = torch.arange(size, device=a.device) + return a.index_select(dim, torch.fmod(start + idx, size)) + + +@register_decomposition(aten.rot90) +@out_wrapper() +def rot90( + a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1) +) -> TensorLikeType: + """Reference implementation of :func:`torch.rot90`.""" + if len(dims) != 2: + raise RuntimeError( + f"expected total rotation dims == 2, but got dims = {len(dims)}" + ) + if a.ndim < 2: + raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}") + + # Do this after the initial checks to be compatible with the behavior in + # core. + dims = utils.canonicalize_dims(a.ndim, dims) + + if dims[0] == dims[1]: + raise RuntimeError( + f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}" + ) + k = k % 4 # Rotation direction is from the second towards the first axis for k < 0 + if k == 1: + return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1]) + elif k == 2: + return torch.flip(a, dims) + elif k == 3: + return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1]) + else: + return a.clone(memory_format=torch.contiguous_format) + + +def _check_stack_inputs(tensors: TensorSequenceType) -> None: + entry_shape = tensors[0].shape + for i in range(1, len(tensors)): + assert tensors[i].shape == entry_shape, ( + f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 " + f"and {tensors[i].shape} at entry {i}" + ) + + +@register_decomposition(aten.stack) +@out_wrapper() +def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + assert len(tensors) > 0, "stack expects a non-empty TensorList" + wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim) + # Refs need sparse support to check other condition + if wrapped_dim < tensors[0].ndim: # and not tensors[0].is_sparse: + _check_stack_inputs(tensors) + result_sizes = list(tensors[0].shape) + result_sizes.insert(wrapped_dim, len(tensors)) + out = torch.cat(tensors, wrapped_dim) + return out.view(result_sizes) + + # If dim == tensors[0].ndim, view cannot efficiently handle it + return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + result_dtype = dtype or a.dtype + computation_dtype = utils.get_computation_dtype(result_dtype) + a_ = _maybe_convert_to_dtype(a, computation_dtype) + if a.numel() == 0: + a_exp = exp(a_) + else: + a_max = amax(a_, dim, keepdim=True) + a_exp = exp(a_ - a_max) + return _maybe_convert_to_dtype( + true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype + ) # type: ignore[return-value] + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def hstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList") + aligned_tensors = atleast_1d(*tensors) + if 
aligned_tensors[0].ndim == 1: + return cat(aligned_tensors, 0) + return cat(aligned_tensors, 1) + + +# CompositeImplicitAutograd - don't register decomp +@out_wrapper() +def vstack(tensors: TensorSequenceType) -> TensorLikeType: + torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList") + aligned_tensors = atleast_2d(*tensors) + return cat(aligned_tensors, 0) + + +# CompositeImplicitAutograd - don't register decomp +def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType: + dim = utils.canonicalize_dim(a.ndim, dim) + torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty") + return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :])) + + +@register_decomposition(aten.unbind) +def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + dim = utils.canonicalize_dim(t.ndim, dim) + torch._check_index( + len(t.shape) > 0, + lambda: "Dimension specified as 0 but tensor has no dimensions", + ) + if guard_size_oblivious(t.shape[dim] == 0): + return () + else: + return tuple( + torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) + ) + + +@out_wrapper() +def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + return x.clone(memory_format=torch.contiguous_format).index_copy_( + dim, index, tensor + ) + + +def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + # Treat scalars as elements of \R^1 + y = x.unsqueeze(0) if x.ndim == 0 else x + idx = (slice(None),) * dim + (index,) + y[idx] = tensor + return x + + +@register_decomposition(aten.index_fill) +@out_wrapper() +def index_fill( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=False) + + +@register_decomposition(aten.index_fill_) +def index_fill_( + x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike] +): + return _index_fill(x, dim, index, value, inplace=True) + + +def _index_fill( + x: TensorLike, + dim: int, + index: TensorLike, + value: Union[NumberType, TensorLike], + *, + inplace: bool, +): + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if isinstance(value, TensorLike): + torch._check( + value.ndim == 0, + lambda: "Only supports 0-dimensional value tensor. " # type: ignore[union-attr] + f"Got a tensor with {value.ndim} dimensions.", + ) # type: ignore[arg-type] + else: + value = torch.scalar_tensor( + value, + dtype=x.dtype, + layout=x.layout, + device=x.device, # type: ignore[arg-type] + ) + + # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them + zero_dim = x.ndim == 0 + y = x.unsqueeze(0) if zero_dim else x + # index_copy does not broadcast on value so we have to do it manually + shape = list(y.shape) + shape[dim] = index.numel() + value = value.expand(shape) + index_copy = Tensor.index_copy_ if inplace else torch.index_copy + out = index_copy(y, dim, index, value) # type: ignore[operator] + if inplace: + return x + else: + if zero_dim: + # The clone is necessary so that it returns a fresh tensor rather than a view + out = out.squeeze(0).clone() + # index_fill preserves the strides. 
index_copy always returns contiguous tensors + if out.stride() != x.stride(): + new_out = torch.empty_like(x) + new_out.copy_(out) + out = new_out + return out + + +@out_wrapper() +def index_add( + x: TensorLike, + dim: int, + index: TensorLike, + tensor: TensorLike, + *, + alpha: NumberType = 1, +): + # index_add always returns a new contiguous tensor + return x.clone(memory_format=torch.contiguous_format).index_add_( + dim, + index, + tensor, + alpha=alpha, # type: ignore[arg-type] + ) + + +@register_decomposition(aten.index_select) +@out_wrapper() +def index_select(x: TensorLike, dim: int, index: TensorLike): + dim = utils.canonicalize_dims(x.ndim, dim) + torch._check( + index.ndim <= 1, + lambda: f"Index should have dimension 1 or 0 (got {index.ndim})", + ) + if index.ndim == 0: + index = index.unsqueeze(0) + if x.ndim == 0: + # Treat scalars as elements of \R^1 + # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction + return torch.empty_like(x).index_copy(0, index, x.expand_as(index)) + + idx = (slice(None),) * dim + (index,) + return x[idx] + + +@register_decomposition(aten.squeeze.dims) +def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if dim is None: + dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1) + return prims.squeeze(a, dims) if dims else prims.view_of(a) + + ndim = a.ndim + dim = utils.canonicalize_dims(ndim, dim) + dims = (dim,) if isinstance(dim, Dim) else dim + # Short-circuits if the tensor has no dimensions + if ndim == 0: + assert len(dims) == 0 or dims == (0,) + return prims.view_of(a) + + # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1 + dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1)) + if len(dims) == 0: + return prims.view_of(a) + if len(dims) == 1: + return prims.squeeze(a, dims) + dims_list = list(dims) + dims_list = sorted(dims_list, reverse=True) + for i in dims_list: + a = squeeze(a, i) + return a + + +@register_decomposition(aten.split_with_sizes) +def split_with_sizes( + self: Tensor, split_sizes: list[int], dim: int = 0 +) -> list[Tensor]: + # NB: Perform the check_is_size tests first so that the + # sum test does not try to do a replacement + for i in range(len(split_sizes)): + torch._check_is_size( + split_sizes[i], + lambda: "split_with_sizes expects split_sizes have only non-negative entries", + ) + torch._check_with( + ValueError, + builtins.sum(split_sizes) == self.shape[dim], + lambda: f"Split sizes add up to {builtins.sum(split_sizes)} but got the tensor's size of {self.shape[dim]}", + ) + + splits = [] + offset = self.storage_offset() + + for split_size in split_sizes: + new_shape = list(self.shape) + new_shape[dim] = split_size + # We reimplement narrow here to avoid a lot of checks in the + # decomposition of narrow which calls slice_in_dim and slice + splits.append(self.as_strided(new_shape, self.stride(), offset)) + offset = offset + self.stride()[dim] * split_size + return splits + + +# Note: does not work with TensorMetas because of data-dependent control-flow +# CompositeImplicitAutograd - don't register decomp +def tensor_split( + a: TensorLikeType, + indices_or_sections: Union[Tensor, DimsType], + dim: int = 0, +) -> tuple[TensorLikeType, ...]: + _dim = utils.canonicalize_dim(a.ndim, dim) + if a.ndim == 0: + msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!" 
+ raise ValueError(msg) + + # If indices_or_sections is a tensor, it must be a CPU Long tensor + if isinstance(indices_or_sections, TensorLike): + if not indices_or_sections.device.type == "cpu": + msg = ( + f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, " + f"but received one on {indices_or_sections.device}" + ) + raise ValueError(msg) + if indices_or_sections.dtype != torch.long: + msg = ( + "tensor_split: if indices_or_sections is a tensor it must have long dtype, " + f" but received one with dtype {indices_or_sections.dtype}" + ) + raise ValueError(msg) + + # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length + if isinstance(indices_or_sections, IntLike) or ( + isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0 + ): + sections: int = ( + indices_or_sections # type: ignore[assignment] + if isinstance(indices_or_sections, Number) + else indices_or_sections.item() + ) + + if sections <= 0: + msg = f"tensor_split: number of sections must be greater than 0, but was {sections}" + raise ValueError(msg) + + dim_size = a.shape[_dim] + min_split_size = math.floor(dim_size / sections) + num_splits_one_extra = dim_size % sections + + split_sizes = [] + for split_idx in range(sections): + split_size = ( + min_split_size + 1 + if (split_idx < num_splits_one_extra) + else min_split_size + ) + split_sizes.append(split_size) + + return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim)) + # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits + else: + indices = indices_or_sections + if isinstance(indices_or_sections, TensorLike): + if indices_or_sections.ndim != 1: + msg = ( + "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, " + f"but received a tensor with {indices_or_sections.ndim} dimensions" + ) + raise ValueError(msg) + + indices = indices_or_sections.tolist() + + indices = [0] + list(indices) + [a.shape[_dim]] + split_sizes = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)] + return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim)) + + +# CompositeImplicitAutograd - don't register decomp +def hsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 1, + lambda: ( + "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + dim = 0 if a.ndim == 1 else 1 + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[dim] % split_size == 0), + lambda: ( + "torch.hsplit attempted to split along dimension " + + str(dim) + + ", but the size of the dimension " + + str(a.shape[dim]) + + " is not divisible by the split_size " + + str(split_size) + + "!" + ), + ) + return tensor_split(a, split_size, dim) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "hsplit(): received an invalid combination of arguments. 
" + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, dim) + + +# CompositeImplicitAutograd - don't register decomp +def vsplit( + a: TensorLikeType, indices_or_sections: DimsType +) -> tuple[TensorLikeType, ...]: + torch._check( + a.ndim >= 2, + lambda: ( + "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with " + + str(a.ndim) + + " dimensions!" + ), + ) + if isinstance(indices_or_sections, IntLike): + split_size = indices_or_sections + torch._check( + (split_size != 0 and a.shape[0] % split_size == 0), + lambda: ( + f"torch.vsplit attempted to split along dimension 0" + f", but the size of the dimension " + f"{a.shape[0]}" + f" is not divisible by the split_size " + f"{split_size}" + f"!" + ), + ) + return tensor_split(a, split_size, 0) + + torch._check_type( + isinstance(indices_or_sections, (list, tuple)), + lambda: ( + "vsplit(): received an invalid combination of arguments. " + "Expected indices_or_sections to be of type int, list of ints or tuple of ints " + f"but got type {type(indices_or_sections)}" + ), + ) + + split_sizes = indices_or_sections + return tensor_split(a, split_sizes, 0) + + +@register_decomposition(aten.diag.out) +@out_wrapper() +def diag( + self: TensorLikeType, + offset: int = 0, +) -> TensorLikeType: + ndim = self.dim() + torch._check( + ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D" + ) + if ndim == 1: + return torch.diag_embed(self, offset) + else: + return torch.diagonal_copy(self, offset) + + +@register_decomposition(aten.diagonal_scatter) +@out_wrapper() +def diagonal_scatter( + input: TensorLikeType, + src: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + out = utils.clone_preserve_strides(input) + diag = out.diagonal(offset, dim1, dim2) + torch._check( + diag.shape == src.shape, + lambda: "expected src to have a size equal to the diagonal of the input." 
+ f"Got {src.shape} for a diagonal of shape {diag.shape}", + ) + copy_to(diag, src) + return out + + +@register_decomposition(aten.diagonal) +def diagonal( + self: TensorLikeType, + offset: int = 0, + dim1: int = 0, + dim2: int = 1, +) -> TensorLikeType: + """ + Reference implementation of torch.diagonal + """ + num_dims = self.dim() + dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims) + dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims) + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + storage_offset = self.storage_offset() + + if offset >= 0: + diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0) + else: + diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0) + + if diag_size > 0: + if offset >= 0: + storage_offset += offset * self.stride()[dim2] + else: + storage_offset -= offset * self.stride()[dim1] + + sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)] + sizes.append(diag_size) + + strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)] + strides.append(self.stride()[dim1] + self.stride()[dim2]) + + result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset) + + return result + + +@register_decomposition(aten.diag_embed) +@out_wrapper() +def diag_embed( + t: TensorLikeType, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + """ + Reference implementation of torch.diag_embed + """ + # convert from negative dims + rank = t.ndim + 1 + dim1 = utils.canonicalize_dim(rank=rank, idx=dim1) + dim2 = utils.canonicalize_dim(rank=rank, idx=dim2) + + # as per the docs, exchanging dims is equivalent to changing the sign of + # offset + if dim1 > dim2: + dim1, dim2 = dim2, dim1 + offset = -offset + + torch._check( + dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}" + ) + + # as per the docs, the size of last dim is placed at dim1 and dim2 + last_dim = t.size(-1) + + if offset != 0: + # add padding to match the new size + t_shape = list(t.shape) + t_shape[-1] = builtins.abs(offset) + z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False) + pair = (z, t) if offset > 0 else (t, z) + t = torch.cat(pair, dim=-1) + # make sure the diagonal always has the same size + last_dim += builtins.abs(offset) + + # preserve original data, but place 1 at dim1 and move last dim to dim2 + t = t.unsqueeze(dim1).movedim(-1, dim2) + + # generate ranges shifting indices based on offset + a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64) + b_range = torch.arange( + offset, last_dim + offset, device=t.device, dtype=torch.int64 + ) + + # broadcast + cond = a_range == b_range.unsqueeze(-1) + cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))] + cond = cond.reshape(cond_shape) + + # aten.diag_embed always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(cond, t).contiguous() + + +@register_decomposition(aten.block_diag) +@out_wrapper() +def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType: + """ + Reference implementation of torch.block_diag + """ + tensors_2d = [ + tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors + ] + + ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d) + device = tensors_2d[0].device + + result = [] + + col_start = 0 + for i, tensor in enumerate(tensors_2d): + torch._check( + tensor.dim() 
== 2, + lambda: "Input tensors must have 2 or fewer dimensions. " + f"Input {i} has {tensor.dim()} dimensions", + ) + torch._check( + tensor.device == device, + lambda: "Input tensors must all be on the same device. " + f"Input 0 is on device {device} and input {i} is on device {tensor.device}.", + ) + row, col = tensor.shape + left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype) + right = torch.zeros( + (row, ncols - col_start - col), device=device, dtype=tensor.dtype + ) + result += [torch.cat((left, tensor, right), dim=1)] + col_start += col + + return torch.cat(result, dim=0) + + +def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType: + """ + This is used as an input to PythonRefInfo. `torch.block_diag` + expects arguments splatted, but `aten.block_diag` expects only + one argument that is a list of Tensors. + """ + return _block_diag_iterable(tensors) # type: ignore[arg-type] + + +# CompositeImplicitAutograd - don't register decomp +def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType: + if a.ndim < 3: + raise RuntimeError( + f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!" + ) + if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0): + raise RuntimeError( + "torch.dsplit attempted to split along dimension 2, " + + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!" + ) + return tensor_split(a, sections, 2) + + +@register_decomposition(aten.t.default) +def t(a: TensorLikeType): + # TODO: Add sparse support + # if a.is_sparse: + # sparse_dim = a.sparse_dim() + # dense_dim = a.dense_dim() + # if not (sparse_dim <= 2 and dense_dim == 0): + # raise RuntimeError( + # f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and" + # f"{dense_dim} dense dimensions" + # ) + if a.ndim > 2: + raise RuntimeError( + f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D" + ) + return torch.transpose(a, 0, 0 if a.ndim < 2 else 1) + + +# CompositeImplicitAutograd - don't register decomp +def T(a: TensorLikeType) -> TensorLikeType: + # n != 2 && n != 0 is deprecated in regular PyTorch. + torch._check( + a.ndim in (0, 2), + lambda: ( + "The use of `x.T` on tensors of dimension other than 0 or 2 " + "to reverse their shape is not supported." 
+ ), + ) + return a.t() + + +@register_decomposition(aten.alias) +def alias(a: TensorLikeType) -> TensorLikeType: + return prims.view_of(a) + + +@register_decomposition(aten.transpose) +def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType: + _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1)) # type: ignore[misc] + + if a.ndim <= 1 or dim0 == dim1: + return aten.alias.default(a) + + _permutation = list(range(0, a.ndim)) + _permutation[_dim0] = _dim1 + _permutation[_dim1] = _dim0 + return torch.permute(a, _permutation) + + +# Aliases for transpose +swap_axes = transpose + + +@register_decomposition(aten.unfold) +def unfold( + self: TensorLikeType, dimension: int, size: int, step: int +) -> TensorLikeType: + shape, strides = _get_unfold_shape_stride( + self.shape, self.stride(), dimension, size, step + ) + return self.as_strided(shape, strides) + + +@register_decomposition(aten.unfold_copy) +@out_wrapper() +def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int): + return self.unfold(dimension, size, step).clone( + memory_format=torch.contiguous_format + ) + + +def _cumsumprod_common( + func, + init, + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + # We implement all the kwargs of a reduction. ATen just handles dtype + # nb. This decomposition may not be as efficient as a backend-specific implementation + ndim = a.ndim + dim = utils.canonicalize_dim(ndim, dim) + if ndim == 0: + return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out) + a = a.unsqueeze(dim + 1) + rg = torch.arange(a.shape[dim], device=a.device) + mask = rg.unsqueeze(1) <= rg + for _ in range(ndim - dim - 1): + mask = mask.unsqueeze(-1) + masked_a = torch.where(mask, a, init) + return func(masked_a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumsum) +def cumsum( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out) + + +@register_decomposition(aten.cumprod) +def cumprod( + a: TensorLikeType, + dim: int, + *, + dtype: Optional[torch.dtype] = None, + out: Optional[Tensor] = None, +) -> TensorLikeType: + return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out) + + +# Note: although squeeze is documented as having the out= kwarg it doesn't +@register_decomposition(aten.unsqueeze) +def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: + # Note that unsqueeze canonicalizes with rank + 1 because it allows + # a new innermost dimension to be specified + ndim = a.ndim + 1 + dim = utils.canonicalize_dim(ndim, dim) + return prims.expand_dims(a, (dim,), ndim=ndim) + + +# NOTE: shape is a vararg because Tensor.reshape can be called with as +# Tensor.view(a, b, c) or Tensor.view((a, b, c)) Function call torch.view +# doesn't support unpacked shapes +# TODO: Turn this into a decomposition (currently fails on reshape meta tests) +@register_decomposition(aten.view.default) +def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType: + return _reshape_view_helper(a, *shape, allow_copy=False) + + +# CompositeImplicitAutograd - don't register decomp +def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType: + return self.view(other.size()) + + +# CompositeImplicitAutograd - don't register decomp +def ravel(a: TensorLikeType) -> TensorLikeType: + return reshape(a, (-1,)) + + +# 
CompositeImplicitAutograd - don't register decomp +# missing ref impl. for aten.gather +@out_wrapper() +def take_along_dim( + a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None +) -> torch.Tensor: + torch._check( + a.ndim == indices.ndim, + lambda: ( + "torch.take_along_dim(): input and indices should have the same " + f"number of dimensions, but got {a.ndim} dimensions for input, and " + f"{indices.ndim} dimensions for indices" + ), + ) + + torch._check( + utils.is_integer_dtype(indices.dtype), + lambda: ( + "torch.take_along_dim(): dtype of indices should be int but got " + f"{indices.dtype} instead" + ), + ) + + if dim is None: + return torch.gather(a.view(-1), 0, indices.view(-1)) + else: + self_sizes = list(a.shape) + self_sizes[dim] = indices.size(dim) + broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size()) + indices_broadcast = broadcast_to(indices, broadcast_shape) + + indices_sizes = list(indices.shape) + indices_sizes[dim] = a.size(dim) + broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size()) + self_broadcast = broadcast_to(a, broadcast_shape) + + return torch.gather(self_broadcast, dim, indices_broadcast) + + +@out_wrapper() +def empty( + *shape, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, + memory_format: torch.memory_format = torch.contiguous_format, +) -> TensorLikeType: + torch._check( + memory_format != torch.preserve_format, + lambda: "torch.empty: the Preserve memory format is not supported", + ) + + shape = utils.extract_shape_from_varargs(shape) + + if memory_format == torch.contiguous_format: + strides = utils.make_contiguous_strides_for(shape) + elif memory_format == torch.channels_last_3d: + strides = utils.make_channels_last_3d_strides_for(shape) + else: # memory_format == torch.channels_last + torch._check( + memory_format == torch.channels_last, + lambda: f"torch.empty: received an unknown memory format {memory_format}!", + ) + strides = utils.make_channels_last_2d_strides_for(shape) + + return torch.empty_strided( + shape, + strides, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@out_wrapper() +def empty_permuted( + shape, + physical_layout, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + return prims.empty_permuted( + shape, + physical_layout, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_empty) +@out_wrapper() +def new_empty( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty( + size, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.new_empty_strided) +@out_wrapper() +def new_empty_strided( + a: TensorLikeType, + size: ShapeType, + stride: StrideType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> 
TensorLikeType: + """ + Reference implementation of torch.Tensor.new_empty_strided + """ + + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.empty_strided( + size, + stride, + dtype=dtype, + device=device, + pin_memory=pin_memory, + layout=layout, + ) + + +@register_decomposition(aten.zeros.default) +@out_wrapper() +def zeros( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + False if dtype == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_zeros) +@out_wrapper() +def new_zeros( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.ones.default) +@out_wrapper() +def ones( + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + size = utils.extract_shape_from_varargs(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + return torch.full( + size, + True if dtype == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_ones) +@out_wrapper() +def new_ones( + a: TensorLikeType, + size: ShapeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.new_full) +@out_wrapper() +def new_full( + a: TensorLikeType, + size: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + return torch.full( + size, + fill_value, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +@aten.empty.out.py_impl(DispatchKey.CompositeImplicitAutograd) +def empty_out( + size: TensorLikeType, + out: TensorLikeType, 
+ memory_format: Optional[torch.memory_format] = None, +) -> TensorLikeType: + return out + + +@register_decomposition(aten.empty_like) +@out_wrapper() +def empty_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + dtype = a.dtype if dtype is None else dtype + layout = a.layout if layout is None else layout + device = a.device if device is None else device + + if memory_format != torch.preserve_format: + return torch.empty( + a.shape, + dtype=dtype, + layout=layout, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # memory_format == torch.preserve_format + logical_to_physical_perm = ( + utils.compute_elementwise_output_logical_to_physical_perm(a) + ) + # identity perm is [2, 1, 0] + return torch.empty_permuted( + a.shape, + logical_to_physical_perm, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + + +@register_decomposition([aten.arange.start_step, aten.arange.start_out]) +@out_wrapper() +def arange( + start: NumberType = 0, + end: Optional[NumberType] = None, + step: NumberType = 1, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + device = torch.device(utils.device_or_default(device)) + + assert not isinstance(start, complex) + assert not isinstance(end, complex) + assert not isinstance(step, complex) + + # Case: torch.arange(5) + if end is None: + end = start + start = 0 + torch._check(step != 0, lambda: "step must be nonzero") + if step > 0: + torch._check( + end >= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + elif step < 0: + torch._check( + end <= start, + lambda: "upper bound and lower bound inconsistent with step sign", + ) + + def is_finite(x): + return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x) + + torch._check( + is_finite(start) and is_finite(end), + lambda: f"unsupported range: {start} -> {end}", + ) + torch._check( + is_finite(step), + lambda: f"step must be finite but got {step}", + ) + + args = (start, end, step) + integer_args = builtins.all(isinstance(arg, IntLike) for arg in args) + + if dtype is None: + dtype = torch.int64 if integer_args else torch.get_default_dtype() + + is_integer = utils.is_integer_dtype(dtype) + if is_integer or integer_args: + xstart = sym_int(start) + xend = sym_int(end) + xstep = sym_int(step) + + # For int64 we truncate arguments to int before calculating length, but + # other integral dtypes we don't. Weird... but needed to match ATen shapes. + if dtype == torch.int64 or integer_args: + # Uses floordiv to avoid ceil in inductor. 
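+            # Worked example: the floordiv form equals ceil((end - start) / step) for
+            # nonzero integer step; e.g. torch.arange(0, 10, 3) gives sgn = 1 and
+            # length = (10 - 0 + 3 - 1) // 3 = 4, matching the output [0, 3, 6, 9].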
+ sgn = bool(xstep > 0) - bool(xstep < 0) # type: ignore[possibly-undefined] + length = (xend - xstart + xstep - sgn) // xstep # type: ignore[possibly-undefined] + else: + length = math.ceil((end - start) / step) + + if is_integer: + return prims.iota( + length, + start=xstart, # type: ignore[possibly-undefined] + step=xstep, # type: ignore[possibly-undefined] + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + index = prims.iota( + length, + start=0, + step=1, + dtype=torch.int64, + device=device, + requires_grad=False, + ) + + computation_dtype = ( + torch.long if integer_args else utils.get_acc_type(dtype, device) + ) + index = _maybe_convert_to_dtype(index, computation_dtype) + result = start + step * index + result = _maybe_convert_to_dtype(result, dtype) + + if requires_grad: + result.requires_grad_(True) + return result + + +@register_decomposition(aten.lerp) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("start", "end", "weight"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]): + inputs = [start, end] + if isinstance(weight, Number): + weight = start.new_full((), weight) # type: ignore[arg-type] + else: + inputs.append(weight) + assert isinstance(weight, Tensor) # mypy + # We implement it this way for numerical stability. We assume (in the stability optimisation) + # that 0 <= weight <= 1. We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: + # If weight.abs() >= 0.5: + # return (1 - weight) * (start - end) + end + mask = weight.abs() >= 0.5 + coeff = torch.where(mask, weight - 1, weight) + base = torch.where(mask, end, start) + output = coeff * (end - start) + base + # make sure the decomposition output's stride is same as non-decomposition path. 
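+    # e.g. channels-last `start`/`end` inputs should yield a channels-last result,
+    # as in eager; copy_strided below materializes that layout when it differs.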
+ stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs)) + if output.stride() != stride: + output = prims.copy_strided(output, stride) + + return handle_noncontiguous_outputs(inputs, output) + + +@register_decomposition(aten.linspace) +@out_wrapper() +def linspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, TensorLikeType], + steps: NumberType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, torch.float64) + if isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "linspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, torch.float64) + + if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + torch._check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + + # steps does not participate in the computation of the dtype + torch._check_type( + isinstance(steps, IntLike), + lambda: f"received an invalid combination of arguments - got \ +({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})", + ) + assert isinstance(steps, IntLike) # for mypy + torch._check(steps >= 0, lambda: "number of steps must be non-negative") + + factory_kwargs = { + "layout": layout, + "device": device, + "pin_memory": pin_memory, + "requires_grad": requires_grad, + } + if steps == 0: + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + if isinstance(start, TensorLikeType): + empty_tensor = torch.empty((steps,), dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + return torch.ops.aten.copy.default(empty_tensor, start) + else: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes + rg = torch.arange(0, steps, **factory_kwargs) # type: ignore[arg-type] + + # Small types need to be computed in higher precision as this is, at heart, an associative scan + dtype_red = ( + torch.int64 + if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype)) + else dtype + ) + computation_dtype, _ = utils.reduction_dtypes( + rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red + ) + cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype) + + # We implement torch.lerp without performing rg / (steps - 1) explicitly + # With this we get out[0] == start, out[-1] == end + step = (end - start) / (steps - 1) + out = torch.where( + rg < steps / 2, + start + step * cast_rg(rg), # type: ignore[arg-type,operator] + end - step * cast_rg((steps - 1) - rg), # type: ignore[arg-type,operator] + ) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] + + +@register_decomposition(aten.logspace) +@out_wrapper() +def logspace( + start: Union[NumberType, TensorLikeType], + end: Union[NumberType, 
TensorLikeType], + steps: NumberType, + base: NumberType = 10, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + if dtype is None: + dtype = torch.get_default_dtype() + + # NB: NumPy doesn't have this cast + if prims.utils.is_integer_dtype(dtype): + if isinstance(start, FloatLike): + start = sym_int(start) + elif isinstance(start, TensorLikeType): + torch._check( + start.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + start = _maybe_convert_to_dtype(start, dtype) + if isinstance(end, FloatLike): + end = sym_int(end) + elif isinstance(end, TensorLikeType): + torch._check( + end.dim() == 0, + lambda: "logspace only supports 0-dimensional start and end tensors", + ) + end = _maybe_convert_to_dtype(end, dtype) + + if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + dtype = default_complex_dtype + _dtype = None # torch.linspace will update the correct dtype + else: + _dtype = torch.float64 + + assert not isinstance(base, complex) # for mypy + if base < 0: + raise NotImplementedError + ret = torch.linspace( # type: ignore[misc] + start, # type: ignore[arg-type] + end, # type: ignore[arg-type] + steps, # type: ignore[arg-type] + dtype=_dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return _maybe_convert_to_dtype(torch.pow(base, ret), dtype) # type: ignore[arg-type,return-value] + + +@overload +def meshgrid(tensors: Sequence[TensorLikeType], indexing: str): + pass + + +@overload +def meshgrid(*tensors: TensorLikeType, indexing: str): + pass + + +@register_decomposition(aten.meshgrid) # type: ignore[misc] +def meshgrid( + *tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]], + indexing: str, +) -> list[TensorLikeType]: + # This ref simultaneously handles two overloads (see stubs above) + # The `indexing` argument is currently optional for torch.meshgrid, but we + # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276 + if isinstance(tensors[0], (list, tuple)): + assert len(tensors) == 1 + tensors = tuple(tensors[0]) + + torch._check( + builtins.all(isinstance(a, TensorLike) for a in tensors), + lambda: "meshgrid expects its inputs to be tensors", + ) + + torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList") + + for i in range(len(tensors) - 1): + torch._check( + tensors[i].dtype == tensors[i + 1].dtype, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same dtype", + ) + torch._check( + tensors[i].device == tensors[i + 1].device, # type: ignore[union-attr] + lambda: "meshgrid expects all tensors to have the same device", + ) + + swap_first_and_second_tensors = False + if indexing == "xy": + swap_first_and_second_tensors = len(tensors) >= 2 + if swap_first_and_second_tensors: + tensors = (tensors[1], tensors[0], *tensors[2:]) + else: + torch._check( + indexing == "ij", + lambda: ( + 'torch.meshgrid: indexing must be one of "xy" or "ij", ' + f"but received: {indexing}" + ), + ) + + result_shape: list[int] = [] + for t in tensors: + assert isinstance(t, TensorLike) # mypy + torch._check( + t.ndim == 0 or t.ndim == 1, + lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}", + ) + 
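# A 0-D tensor behaves like a length-1 vector here, so numel() gives the
+        # grid extent along this axis for both the 0-D and 1-D cases.
+        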
result_shape.append(t.numel()) + + grids: list[TensorLikeType] = [] + for i, t in enumerate(tensors): + assert isinstance(t, TensorLike) # mypy + if t.ndim == 0: + t = t.view((1,)) + grids.append(prims.broadcast_in_dim(t, result_shape, (i,))) + + if swap_first_and_second_tensors: + # Swap outputs if we originally swapped at the beginning + grids[0], grids[1] = grids[1], grids[0] + + return grids + + +# CompositeImplicitAutograd - don't register decomp +def movedim( + input: TensorLikeType, + source: Union[int, DimsSequenceType], + destination: Union[int, DimsSequenceType], +) -> TensorLikeType: + """ + Reference implementation of torch.movedim + """ + if type(source) is int: + source = (source,) + if type(destination) is int: + destination = (destination,) + + # Converts to list to produce a compatible error message with core PyTorch, + # which prints sequences in square brackets. + torch._check( + len(source) == len(destination), # type: ignore[arg-type] + lambda: ( + "movedim: Invalid source or destination dims: source " # type: ignore[arg-type] + f"({list(source)} dims) should contain the same number " # type: ignore[arg-type] + f"of dims as destination ({list(destination)} dims)" # type: ignore[arg-type] + ), + ) + + rank = input.ndim + ss = tuple(utils.canonicalize_dims(rank=rank, indices=source)) # type: ignore[arg-type] + ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination)) # type: ignore[arg-type] + + sss = set(ss) + dss = set(ds) + + # See above on why this converts to list in error messages. + torch._check( + len(ss) == len(sss), + lambda: f"movedim: repeated dim in `source` ({list(source)})", # type: ignore[arg-type] + ) + torch._check( + len(ds) == len(dss), + lambda: f"movedim: repeated dim in `destination` ({list(destination)})", # type: ignore[arg-type] + ) + + m = dict(zip(ds, ss)) + dims = [] + si = 0 # source index + for di in range(rank): + # check if the destination index is in the mapping + s = m.get(di) + if s is not None: + # insert source index if found + dims.append(s) + else: + # insert source index sequentially, skipping indices from the mapping + while si in sss: + si += 1 + dims.append(si) + si += 1 + + result = torch.permute(input, tuple(dims)) + + return result + + +# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints +@register_decomposition(aten.empty_strided) +@out_wrapper() +def empty_strided( + shape: Union[ShapeType, tuple[ShapeType]], + strides: StrideType, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + # Layout == strided, pin_memory is False + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + shape = utils.extract_shape_from_varargs(shape) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + return prims.empty_strided( + shape, + strides, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +@register_decomposition(aten.eye) +@out_wrapper() +def eye( + n: int, + m: Optional[int] = None, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, # TODO: unused +) -> TensorLikeType: + """ + Reference implementation of torch.eye + """ + if m is None: + m = n + + torch._check(n >= 0, lambda: f"n must be 
greater or equal to 0, got {n}") + torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}") + + range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False) + range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False) + + cond = range_n.unsqueeze(-1) == range_m + if dtype is torch.bool: + return cond + else: + one = torch.ones( + (1,), + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=False, + ) + return torch.where(cond, one, 0) + # TODO: Use requires_grad. All refs taking the requires_grad kwarg must + # return a leaf tensor. + # result.requires_grad_(requires_grad) + + +@register_decomposition([aten.full.default, aten.full.out]) +@out_wrapper() +def full( + shape: ShapeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + + e = empty( + shape, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + return torch.fill(e, fill_value) # type: ignore[arg-type] + + +def full_like( + a: TensorLikeType, + fill_value: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + e = torch.empty_like( + a, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + return fill(e, fill_value) + + +@register_decomposition(aten.zeros_like) +@out_wrapper() +def zeros_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + False if (dtype or a.dtype) == torch.bool else 0, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.ones_like) +@out_wrapper() +def ones_like( + a: TensorLikeType, + *, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, + requires_grad: bool = False, + memory_format: torch.memory_format = torch.preserve_format, +) -> TensorLikeType: + return torch.full_like( + a, + True if (dtype or a.dtype) == torch.bool else 1, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + requires_grad=requires_grad, + memory_format=memory_format, + ) + + +@register_decomposition(aten.randn.default) +@out_wrapper() +def randn( + *shape, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLikeType] = None, + layout: Optional[torch.layout] = None, + requires_grad: bool = False, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_pin_memory(pin_memory) + + shape_ = utils.extract_shape_from_varargs(shape) + + dtype = 
utils.dtype_or_default(dtype) + device = utils.device_or_default(device) + + return prims.normal( + shape_, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=requires_grad, + ) + + +def scalar_tensor( + a: NumberType, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device: Optional[DeviceLikeType] = None, + pin_memory: bool = False, +) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + dtype = dtype if dtype is not None else utils.type_to_dtype(type(a)) + device = device if device is not None else torch.device("cpu") + return prims.scalar_tensor(a, dtype=dtype, device=device) + + +# +# Randomness References +# + + +def _uniform_helper( + shape: ShapeType, + low: Union[bool, int, float] = 0.0, + high: Union[bool, int, float] = 1.0, + *, + dtype: torch.dtype, + device: DeviceLikeType, +) -> TensorLikeType: + utils.validate_shape(shape) + + assert isinstance(low, Number) + assert isinstance(high, Number) + low = sym_float(low) + high = sym_float(high) + + assert isinstance(dtype, torch.dtype) + device = utils.canonicalize_device(device) + + return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device) + + +@register_decomposition(aten.masked_fill) +@out_wrapper() +def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType): + python_type = utils.dtype_to_type(a.dtype) + if isinstance(value, Number): + value_type = type(value) + else: + # NOTE: Could not use value = item(value) as it resulted in + # RuntimeError: Cannot cast FakeTensor(cpu) to number + value_ndim = value.ndim + torch._check( + value_ndim == 0, + lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension", + ) + # `masked_fill` allows cpu scalar to be moved to cuda, xpu and hpu but not otherwise. + is_cpu_scalar = ( + a.device.type + in ["cuda", "xpu", "mps", torch._C._get_privateuse1_backend_name(), "hpu"] + and value.device.type == "cpu" + ) + torch._check( + is_cpu_scalar or value.device == a.device, + lambda: "Expected `value` to be on same device as `a`", + ) + value_type = utils.dtype_to_type(value.dtype) + + if value_type is complex: + # only downcasting from complex to lower type is not allowed. + # We allow casting `value` to lower type for other case + # Eg. float -> int. 
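+        # i.e. a complex `value` may only fill a complex tensor; e.g. masked_fill
+        # on a float tensor with a complex scalar fails the check below.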
+ # Ref: https://github.com/pytorch/pytorch/issues/79195 + torch._check( + utils.is_weakly_lesser_type(value_type, python_type), + lambda: f"could not convert to type {python_type} without overflow", + ) + + # Since `where` allows type-promotion, + # cast value to correct type before passing to `where` + value = _maybe_convert_to_dtype(value, a.dtype) + r = torch.where(mask, value, a) # type: ignore[arg-type] + + # aten.mask_fill always return a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return r.contiguous() + + +@register_decomposition(aten.masked_fill_) +def masked_fill_( + a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType +) -> TensorLikeType: + b = torch.masked_fill(a, mask, value) # type: ignore[arg-type] + a.copy_(b) + return a + + +# CompositeImplicitAutograd - don't register decomp +def allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + + return bool( + torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item() + ) + + +def equal(a: TensorLikeType, b: TensorLikeType) -> bool: + utils.check_same_device(a, b, allow_cpu_scalar_tensors=False) + utils.check_same_dtype(a, b) + + # Shape check + if a.ndim != b.ndim: + return False + + for x, y in zip(a.shape, b.shape): + if x != y: + return False + + # Short-circuits if there are no elements to validate + if a.numel() == 0: + return True + + return item(all(eq(a, b))) # type: ignore[return-value] + + +@register_decomposition(aten.norm) +@out_wrapper(exact_dtype=True) +def norm( + input: TensorLikeType, + p: Optional[Union[float, str]] = "fro", + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # In these cases we compute the "Frobenius norm" + if ( + p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2) + ) or p is None: + p = 2 + if isinstance(dim, Dim): + dim = [dim] + if isinstance(p, str): + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if dim is None: + dim = tuple(range(input.ndim)) + return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype) + + +@register_decomposition(aten.trace) +@out_wrapper() +def trace(self: TensorLikeType) -> TensorLikeType: + torch._check( + self.ndim == 2, lambda: "expected a matrix, but got tensor with dim {self.ndim}" + ) + return torch.sum(torch.diag(self, 0)) + + +def _make_r_binary_op(base_op): + def rop( + a: Union[TensorLikeType, NumberType], + b: Union[TensorLikeType, NumberType], + ) -> TensorLikeType: + return base_op(b, a) + + return rop + + +rtruediv = _make_r_binary_op(true_divide) +rfloordiv = _make_r_binary_op(floor_divide) +rpow = _make_r_binary_op(pow) + + +@register_decomposition(aten.triu) +@out_wrapper() +def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) >= diagonal + + # aten.triu always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return 
utils.mask_tensor(mask, a).contiguous() + + +@register_decomposition(aten.tril) +@out_wrapper() +def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType: + torch._check( + a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions" + ) + h, w = a.shape[-2:] + mask = ( + torch.arange(w, device=a.device).unsqueeze(-2) + - torch.arange(h, device=a.device).unsqueeze(-1) + ) <= diagonal + + # aten.tril always returns a new contiguous tensor + # contiguous() is needed to correctly model the output stride + return utils.mask_tensor(mask, a).contiguous() + + +# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h +# The components of the matrix that belong to the lower triangle with offset +# form a pentagon that can be broken down into a top trapezoid and a bottom +# rectangle. For the implementation of tril_indices, we need the sizes of +# both of these, as well as the length of the top side of the trapezoid. +def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0) + m_last_row = max(0, min(col, row + offset)) + n_row_all = max(0, min(row, row + offset)) + n_row_trapezoid = m_last_row - m_first_row + 1 + + # Number of elements in top trapezoid + trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2 + # Number of elements in bottom rectangle + diff_row = n_row_all - n_row_trapezoid + rectangle_size = max(0, diff_row * col) + + return trapezoid_size, rectangle_size, m_first_row + + +def _trilu_checks( + name: str, + row: int, + col: int, + dtype: torch.dtype, + layout: torch.layout, + pin_memory: bool, +): + torch._check(row >= 0, lambda: f"row must be non-negative, got {row}") + torch._check(col >= 0, lambda: f"col must be non-negative, got {col}") + torch._check( + dtype in (torch.int32, torch.int64), + lambda: f"\"{name}\" not implemented for '{dtype}'", + ) + + +# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu +@register_decomposition(aten.tril_indices) +@out_wrapper() +def tril_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset) + row_offset = max(0, -offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # first we do the indices for top trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = m_first_row - 0.5 + row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1)) + col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + # then bottom rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset) + col_inds2 = xs2 % col + + return torch.stack( + (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2))) + ) + + +# Similar to _get_tril_sizes above, but here there is a top trapezoid and +# a bottom rectangle instead. 
Note that you can't reduce this to +# _get_tril_sizes(col, row, -offset) because that would correspond to +# decomposing into a left trapezoid and right rectangle. +def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]: + if row == 0 or col == 0: + return 0, 0, 0 + + m_first_row = max(0, col - offset) if offset > 0 else col + + # Number of elements in top rectangle + rectangle_size = max(0, min(row, -offset) * col) + + # Number of elements in bottom trapezoid + trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1) + triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril) + trapezoid_size = triu_size - rectangle_size + + return trapezoid_size, rectangle_size, m_first_row + + +@register_decomposition(aten.triu_indices) +@out_wrapper() +def triu_indices( + row: int, + col: int, + offset: int = 0, + *, + dtype: torch.dtype = torch.long, + layout: torch.layout = torch.strided, + device: DeviceLikeType = "cpu", + pin_memory: bool = False, +) -> TensorLikeType: + _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory) + + trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset) + col_offset = max(0, offset) + + arange_kw = partial( + torch.arange, layout=layout, device=device, pin_memory=pin_memory + ) + + # indices for top rectangle + xs2 = arange_kw(0, rectangle_size, dtype=dtype) + row_inds2 = xs2 // col + col_inds2 = xs2 % col + + # bottom trapezoid + xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64) + b = -0.5 - m_first_row + row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1)) + col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5) + row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype) + col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype) + + if col: + row_inds1 = row_inds1 + (rectangle_size // col) + col_inds1 = col_inds1 + col_offset + + return torch.stack( + (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1))) + ) + + +@register_decomposition(aten.bucketize) +@out_wrapper(exact_dtype=True) +def bucketize( + a: TensorOrNumberLikeType, + boundaries: TensorLikeType, + *, + out_int32: bool = False, + right: bool = False, +): + torch._check( + boundaries.dim() == 1, + lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})", + ) + + a = a if isinstance(a, torch.Tensor) else torch.tensor(a) + out_dtype = torch.int32 if out_int32 else torch.int64 + n_boundaries = boundaries.shape[-1] + if n_boundaries == 0: + return torch.zeros_like(a) + # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`) + # each element of `a` belongs to. 
We use binary search to achieve logarithimic complexity, + # but each step of the search is done "in parallel" over all elements of `a` + # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end + start = torch.zeros(a.shape, device=a.device, dtype=torch.int64) + end = start + n_boundaries + # Max depth of the binary search + # Since we can't break out of the loop at different points for different elements of a, + # we just do the max amount of iterations that binary search requires and add condition + # tensor (cond_update below) to stop updating once the search terminates + + # For first iteration through loop we can skip some checks, we have separate implementation + mid = start + (end - start) // 2 + mid_val = boundaries[mid] + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where(cond_mid, start, mid + 1) + + if n_boundaries > 1: + cond_update = torch.ones_like(a, dtype=torch.bool) + niters = int(math.log2(n_boundaries)) + for _ in range(niters): + end = torch.where(cond_mid & cond_update, mid, end) + cond_update = start < end + # start might end up pointing to 1 past the end, we guard against that + mid = torch.where(cond_update, start + (end - start) // 2, 0) + mid_val = boundaries[mid] + # If right is true, the buckets are closed on the *left* + # (i.e., we are doing the equivalent of std::upper_bound in C++) + # Otherwise they are closed on the right (std::lower_bound) + if right: + cond_mid = mid_val > a + else: + cond_mid = mid_val >= a + start = torch.where((~cond_mid) & cond_update, mid + 1, start) + + return start.to(dtype=out_dtype) + + +@register_decomposition(aten.cauchy) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def cauchy(self, median=0, sigma=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Cauchy distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + sigma > 0.0, + lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}", + ) + return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5)) + + +@register_decomposition(aten.exponential) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def exponential(self, rate=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Exponential distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + rate > 0.0, + lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", + ) + + uniform_val = torch.rand_like(self) + + # copying numerics of transformation::exponential see comment: + # curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0. 
+ # we need log to be not 0, and not underflow when converted to half + # fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args + epsilon = torch.finfo(uniform_val.dtype).eps / 2 + condition = uniform_val >= 1.0 - epsilon + log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val)) + + return -1 / rate * log_uniform + + +@register_decomposition(aten.geometric) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def geometric(self, p, generator=None): + assert generator is None + # TODO: fix inductor rand_like for integer, bool dtypes + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"geometric not implemented for {self.dtype}", + ) + torch._check( + 0 < p and p < 1, + lambda: f"geometric_ expects p to be in (0, 1), but got p={p}", + ) + return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1 + + +@register_decomposition(aten.log_normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def log_normal(self, mean=1, std=2, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"log_normal not implemented for {self.dtype}", + ) + torch._check( + 0 < std, + lambda: f"log_normal_ expects std > 0.0, but found std={std}", + ) + return torch.exp(std * torch.randn_like(self) + mean) + + +# TODO: add support for functionalization aten.normal_functional +# NOTE: the device and dtype will be ignored when shape is None +@register_decomposition(aten.normal) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=( + "mean", + "std", + ), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def normal( + mean=0, + std=1, + size=None, + *, + generator=None, + dtype=None, + layout=None, + device=None, + pin_memory=None, +): + assert layout is None or layout == torch.strided + + if not isinstance(std, TensorLike): + torch._check( + std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}" + ) + + if size is None: + tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike)) + torch._check( + len(tensors) > 0, + lambda: "normal expects that either mean or std is a tensor, or size is defined", + ) + torch._check( + layout is None and pin_memory is None, + lambda: "Cannot pass layout, or pin_memory without size", + ) + + size = _broadcast_shapes(*(t.shape for t in tensors)) + dtype = tensors[0].dtype + device = tensors[0].device + else: + torch._check( + not isinstance(mean, TensorLike) and not isinstance(std, TensorLike), + lambda: "normal expects mean and std to be scalars when size is defined", + ) + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.device("cpu") if device is None else device + + normal_samples = prims.normal( + size, + mean=0.0, + std=1.0, + dtype=dtype, + device=device, + requires_grad=False, + generator=generator, + ) + return std * normal_samples + mean + + +@register_decomposition(aten.normal_) +def normal_(self, mean=0, std=1, *, generator=None): + return normal(mean, std, self.shape, out=self, generator=generator) + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def rad2deg(self: 
TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "rad2deg is not supported for complex tensors.", + ) + M_180_PI = 57.295779513082320876798154814105170332405472466564 + return self * M_180_PI + + +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT) +def deg2rad(self: TensorLikeType): + torch._check( + not utils.is_complex_dtype(self.dtype), + lambda: "deg2rad is not supported for complex tensors.", + ) + M_PI_180 = 0.017453292519943295769236907684886127134428718885417 + return self * M_PI_180 + + +@register_decomposition(aten.count_nonzero) +@out_wrapper() +def count_nonzero(self, dim: Optional[DimsType] = None): + return (self != 0).sum(dim) + + +def _dot_check(self, other): + torch._check( + self.dim() == 1 and other.dim() == 1, + lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors", + ) + + torch._check( + self.dtype == other.dtype, + lambda: "dot : expected both vectors to have same dtype, but found " + f"{self.dtype} and {other.dtype}", + ) + + def numel_error(): + return ( + f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the" + f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively" + ) + + torch._check(self.numel() == other.numel(), numel_error) + + +def _dot_check_wrapper(fn): + @wraps(fn) + def wrapper(self, other): + _dot_check(self, other) + return fn(self, other) + + return wrapper + + +@register_decomposition(aten.dot) +@out_wrapper(exact_dtype=True) +@_dot_check_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def dot(self, other): + if self.is_complex(): + if self.is_conj(): + if other.is_conj(): + return torch.dot(self.conj(), other.conj()).conj() + else: + return torch.vdot(self.conj(), other) + elif other.is_conj(): + return torch.vdot(other.conj(), self) + + return (self * other).sum() + + +@register_decomposition(aten.vdot) +@out_wrapper(exact_dtype=True) +@_dot_check_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("self", "other"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vdot(self, other): + if not self.is_complex(): + return torch.dot(self, other) + + if self.is_conj(): + if other.is_conj(): + return torch.vdot(other.conj(), self.conj()) + else: + return torch.dot(self.conj(), other) + elif other.is_conj(): + return torch.dot(self, other.conj()).conj() + + # The decomposition fails if you do self.conj()... 
not sure why + return (self.conj_physical() * other).sum() + + +@register_decomposition(aten.select_scatter) +@out_wrapper() +def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int): + dim = utils.canonicalize_dim(x.ndim, dim) + mask_shape = [1] * x.ndim + mask_shape[dim] = -1 + if index < 0: + index = index + x.shape[dim] + mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index + src = torch.unsqueeze(src, dim).expand(x.shape) + return torch.where(mask, src, x) + + +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +add_ = _make_inplace(add) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +bitwise_and_ = _make_inplace(bitwise_and) +bitwise_left_shift_ = _make_inplace(bitwise_left_shift) +bitwise_not_ = _make_inplace(bitwise_not) +bitwise_or_ = _make_inplace(bitwise_or) +bitwise_right_shift_ = _make_inplace(bitwise_right_shift) +bitwise_xor_ = _make_inplace(bitwise_xor) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +cumprod_ = _make_inplace(cumprod) +deg2rad_ = _make_inplace(deg2rad) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = _make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +gcd_ = _make_inplace(gcd) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +i0_ = _make_inplace(i0) +lcm_ = _make_inplace(lcm) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_not_ = _make_inplace(logical_not) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mul_ = _make_inplace(mul) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +rad2deg_ = _make_inplace(rad2deg) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = _make_inplace(sqrt) +square_ = _make_inplace(square) +sub_ = _make_inplace(sub) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) +cauchy_ = _make_inplace(cauchy) +exponential_ = 
_make_inplace(exponential) +geometric_ = _make_inplace(geometric) +log_normal_ = _make_inplace(log_normal) +zero_ = _make_inplace(zero) + +alias_copy = _make_copy_from_view(aten.alias) +as_strided_copy = _make_copy_from_view(aten.as_strided) +diagonal_copy = _make_copy_from_view(aten.diagonal) +expand_copy = _make_copy_from_view(aten.expand) +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(aten.narrow) +squeeze_copy = _make_copy_from_view(aten.squeeze) +permute_copy = _make_copy_from_view(aten.permute) +t_copy = _make_copy_from_view(aten.t) +transpose_copy = _make_copy_from_view(aten.transpose) +unbind_copy = _make_copy_from_view(aten.unbind) +unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) +view_copy = _make_copy_from_view(aten.view) + + +# xref: isStorage in torch/csrc/DynamicTypes.cpp +def _isStorage(obj): + return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage)) + + +# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp +def _compute_sizes(seq, scalar_type): + MAX_DIMS = 128 + is_storage = _isStorage(seq) + sizes = [] + # TODO: this is inaccurate, we actually test PySequence_Check + while isinstance(seq, (list, tuple)): + length = len(seq) + if is_storage: + length //= scalar_type.itemsize + sizes.append(length) + if len(sizes) > MAX_DIMS: + raise ValueError(f"too many dimensions '{type(seq).__name__}'") + if length == 0: + break + try: + handle = seq[0] + except Exception: + raise ValueError( # noqa: B904 + f"could not determine the shape of object type '{type(seq).__name__}'" + ) + seq = handle + + return sizes + + +# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp +def _infer_scalar_type(obj): + if isinstance(obj, FloatLike): + return torch.get_default_dtype() + if isinstance(obj, IntLike) and not isinstance(obj, bool): # careful! + return torch.int64 + if isinstance(obj, BoolLike): + return torch.bool + if isinstance(obj, complex): + default_dtype = torch.get_default_dtype() + if default_dtype is torch.float: + return torch.cfloat + elif default_dtype is torch.double: + return torch.cdouble + elif default_dtype is torch.half: + return torch.chalf + else: + raise RuntimeError("invalid default scalar type for complex") + if isinstance(obj, torch.Tensor): + return obj.dtype + if isinstance(obj, str): + raise TypeError(f"new(): invalid data type '{type(obj).__name__}'") + # TODO: this is inaccurate, we actually test PySequence_Check + if isinstance(obj, (list, tuple)): + scalarType = None + length = len(obj) + # match NumPy semantics, except use default tensor type instead of + # double. + if length == 0: + return torch.get_default_dtype() + for i in range(length): + cur_item = obj[i] + # TODO: test this + """ + if cur_item is obj: + raise TypeError("new(): self-referential lists are incompatible") + """ + item_scalarType = _infer_scalar_type(cur_item) # recurse! 
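+            # e.g. [1, 2.0]: the first item infers torch.int64, the second the
+            # default float dtype, and the promotion below yields the default float dtype.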
+ if scalarType is not None: + scalarType = torch.promote_types(scalarType, item_scalarType) + else: + scalarType = item_scalarType + if scalarType is torch.cdouble: + # this won't change (unless we hit undefined, but that will + # fail later) + return scalarType + return scalarType + raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}") + + +# Analogous to recursive_store +# xref: recursive_store in torch/csrc/utils/tensor_new.cpp +def _recursive_build( + scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType] +): + if isinstance(obj, Tensor) and obj.numel() == 1: + return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(()) + elif isinstance(obj, Tensor): + # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode + # >>> torch.tensor([torch.randn(2)]) + # ValueError: only one element tensors can be converted to Python scalars + # + # But it is possible with a NumPy array + # >>> torch.tensor([np.random.uniform(size=(2,))]).shape + # torch.Size([1, 2]) + return obj.detach().to(dtype=scalarType, device="cpu", copy=True) + elif isinstance(obj, Number): + return torch.scalar_tensor(obj, dtype=scalarType) + + # seq can be a list of tensors + seq = obj + return ( + torch.empty(0) + if not seq + else torch.stack([_recursive_build(scalarType, item) for item in seq]) + ) + + +# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp +def _internal_new_from_data( + options, + scalar_type, + device_opt, + data, + copy_variables, + copy_numpy, + type_inference, + pin_memory=False, +): + if isinstance(data, torch.Tensor): + torch._check( + not pin_memory, lambda: "Can't pin tensor constructed from a variable" + ) + var = data + if copy_variables: + var = var.detach() + inferred_scalar_type = var.dtype if type_inference else scalar_type + device = device_opt if device_opt is not None else var.device + return var.to( + device=device, + dtype=inferred_scalar_type, + non_blocking=False, + copy=copy_variables, + ) + + # TODO + if hasattr(data, "__cuda_array_interface__"): + return NotImplemented + + # TODO: test for numpy input with PyArray_Check + + device = device_opt if device_opt is not None else options["device"] + inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type + + # NB: Don't need to avoid tracing, as we aren't going to do any manual + # pointer filling tricks + if _isStorage(data): + return NotImplemented + else: + if torch.device(device).type == "meta": + return NotImplemented + + # In the C implementation, we would directly start poking the memory + # of a freshly allocated CPU tensor. 
Here, we're going to do an + # alternate, heinously slow implementation: turn each individual + # scalar into a tensor, and then repeatedly cat them together + tensor = _recursive_build(inferred_scalar_type, data) + + tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False) + + # NB: lift_fresh is not needed, because we built the tensor from scalars + # guaranteeing a fresh tensor in this case + return tensor + + +# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp +def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False): + # TODO (or not): support names kwarg + if isinstance(data, torch.Tensor): + warnings.warn( + "To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() " + "or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor)", + UserWarning, + stacklevel=2, + ) + type_inference = dtype is None + new_tensor = _internal_new_from_data( + # device="cpu" because that's what you get with torch.tensor(2) no + # device by default + {"device": "cpu"}, # TODO: use torch.get_default_tensor_type + dtype if dtype is not None else torch.get_default_dtype(), + device, + data, + copy_variables=True, + copy_numpy=True, + type_inference=type_inference, + pin_memory=pin_memory, + ) + new_tensor.detach_() + if requires_grad: + new_tensor.requires_grad_(requires_grad) + return new_tensor + + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + +import torch._refs._conversions +import torch._refs.fft +import torch._refs.linalg +import torch._refs.nn.functional +import torch._refs.special diff --git a/phivenv/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a67ffab982988a3316b45f5301488f7f3bd6ab8c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/__pycache__/_conversions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dff108bce3490277d7038035c93bdf409af7e9c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/__pycache__/fft.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_refs/_conversions.py b/phivenv/Lib/site-packages/torch/_refs/_conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..ca766da8a590f7db64d5ca690390cde8b0c87b7f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/_conversions.py @@ -0,0 +1,119 @@ +# mypy: allow-untyped-defs +import torch +import torch._prims_common as utils + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._prims_common import TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes + + +# Data conversion references. +# +# Note: this module breaks the usual _refs to torch naming scheme where +# _refs.foo.bar is a ref for torch.foo.bar. 
The following definitions are not +# part of _refs/__init__.py to avoid name clashes with Python builtin types +# (like int). + +__all__ = [ + # dtypes + "bfloat16", + "bool", + "byte", + "cdouble", + "cfloat", + "chalf", + "char", + "double", + "float", + "half", + "int", + "long", + "short", + # misc + "complex", + "polar", +] + + +def _make_conversion_method(name: str, dtype: torch.dtype): + def fn( + self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format + ) -> TensorLikeType: + return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload] + + fn.__name__ = name + return fn + + +bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16) + +bool = _make_conversion_method("bool", torch.bool) + +byte = _make_conversion_method("byte", torch.uint8) + +cdouble = _make_conversion_method("cdouble", torch.cdouble) + +cfloat = _make_conversion_method("cfloat", torch.cfloat) + +chalf = _make_conversion_method("chalf", torch.complex32) + +char = _make_conversion_method("char", torch.int8) + +double = _make_conversion_method("double", torch.double) + +float = _make_conversion_method("float", torch.float) + +half = _make_conversion_method("half", torch.half) + +int = _make_conversion_method("int", torch.int) + +long = _make_conversion_method("long", torch.long) + +short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch._ops.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. +@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + torch._check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + torch._check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result + + +@register_decomposition(torch._ops.ops.aten.polar) +# Note: polar has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
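+# polar constructs abs * exp(i * angle); e.g. polar of abs=2.0, angle=0.0 is 2+0j.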
+@out_wrapper(exact_dtype=True) +def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType: + result = torch.complex(abs, angle) + result.real = abs * torch.cos(angle) + result.imag = abs * torch.sin(angle) + return result diff --git a/phivenv/Lib/site-packages/torch/_refs/fft.py b/phivenv/Lib/site-packages/torch/_refs/fft.py new file mode 100644 index 0000000000000000000000000000000000000000..a36f64e4e9b44ca17f6640e9caada8a84f7b4b5f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/fft.py @@ -0,0 +1,591 @@ +import math +from collections.abc import Iterable, Sequence +from typing import Literal, NamedTuple, Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +from torch._decomp import register_decomposition +from torch._prims_common import DimsType, ShapeType, TensorLikeType +from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper + + +__all__ = [ + # Transforms + "fft", + "fft2", + "fftn", + "hfft", + "hfft2", + "hfftn", + "rfft", + "rfft2", + "rfftn", + "ifft", + "ifft2", + "ifftn", + "ihfft", + "ihfft2", + "ihfftn", + "irfft", + "irfft2", + "irfftn", + # Helpers + "fftshift", + "ifftshift", +] + +NormType = Union[None, Literal["forward", "backward", "ortho"]] +_NORM_VALUES = {None, "forward", "backward", "ortho"} +aten = torch._ops.ops.aten + + +def _apply_norm( + x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool +) -> TensorLikeType: + """Apply normalization to the un-normalized FFT result""" + torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}") + + if norm == "ortho": + return x * (1 / math.sqrt(signal_numel)) + + normalize = (not forward and (norm is None or norm == "backward")) or ( + forward and norm == "forward" + ) + return x * (1 / signal_numel) if normalize else x + + +def _promote_type_fft( + dtype: torch.dtype, require_complex: bool, device: torch.device +) -> torch.dtype: + """Helper to promote a dtype to one supported by the FFT primitives""" + if dtype.is_complex: + return dtype + + # Promote integral to default float type + if not dtype.is_floating_point: + dtype = torch.get_default_dtype() + + allowed_types = [torch.float32, torch.float64] + maybe_support_half = device.type in ["cuda", "meta"] + + if maybe_support_half: + allowed_types.append(torch.float16) + torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}") + + if require_complex: + dtype = utils.corresponding_complex_dtype(dtype) + + return dtype + + +def _maybe_promote_tensor_fft( + t: TensorLikeType, require_complex: bool = False +) -> TensorLikeType: + """Helper to promote a tensor to a dtype supported by the FFT primitives""" + cur_type = t.dtype + new_type = _promote_type_fft(cur_type, require_complex, t.device) + return _maybe_convert_to_dtype(t, new_type) # type: ignore[return-value] + + +def _resize_fft_input( + x: TensorLikeType, dims: tuple[int, ...], sizes: tuple[int, ...] +) -> TensorLikeType: + """ + Fixes the shape of x such that x.size(dims[i]) == sizes[i], + either by zero-padding, or by slicing x starting from 0. 
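+    For example, a length-5 transform dim resized to 8 is padded with three
+    trailing zeros, while resizing it to 3 keeps only the first three elements.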
+ """ + assert len(dims) == len(sizes) + must_copy = False + x_sizes = x.shape + pad_amount = [0] * len(x_sizes) * 2 + for i in range(len(dims)): + if sizes[i] == -1: + continue + + if x_sizes[dims[i]] < sizes[i]: + must_copy = True + pad_idx = len(pad_amount) - 2 * dims[i] - 1 + pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]] + + if x_sizes[dims[i]] > sizes[i]: + x = x.narrow(dims[i], 0, sizes[i]) + + return torch.constant_pad_nd(x, pad_amount) if must_copy else x + + +def _fft_c2r( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to real FFT (irfft or hfft)""" + input = _maybe_promote_tensor_fft(input, require_complex=True) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1) + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + if n is not None: + input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,)) + + if forward: + input = torch.conj(input) + + output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size) + return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward) + + +def _fft_r2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, + onesided: bool, +) -> TensorLikeType: + """Common code for performing any real to complex FFT (rfft or ihfft)""" + torch._check( + not input.dtype.is_complex, + lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}", + ) + input = _maybe_promote_tensor_fft(input) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_r2c(input, dim=dims, onesided=onesided) + ret = _apply_norm(ret, norm, dim_size, forward) + return ret if forward else torch.conj(ret) + + +def _fft_c2c( + func_name: str, + input: TensorLikeType, + n: Optional[int], + dim: int, + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for performing any complex to complex FFT (fft or ifft)""" + torch._check( + input.dtype.is_complex, + lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}", + ) + dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),) + dim_size = n if n is not None else input.shape[dim] + torch._check( + dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified" + ) + + if n is not None: + input = _resize_fft_input(input, dims, (n,)) + + ret = prims.fft_c2c(input, dim=dims, forward=forward) + return _apply_norm(ret, norm, dim_size, forward) + + +@register_decomposition(aten.fft_fft) +@out_wrapper() +def fft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("fft", input, n, dim, norm, forward=True) + else: + return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False) + + +@register_decomposition(aten.fft_ifft) +@out_wrapper() +def ifft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + if input.dtype.is_complex: + return _fft_c2c("ifft", 
input, n, dim, norm, forward=False) + else: + return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False) + + +@register_decomposition(aten.fft_rfft) +@out_wrapper() +def rfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True) + + +@register_decomposition(aten.fft_irfft) +@out_wrapper() +def irfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("irfft", input, n, dim, norm, forward=False) + + +@register_decomposition(aten.fft_hfft) +@out_wrapper() +def hfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_c2r("hfft", input, n, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ihfft) +@out_wrapper() +def ihfft( + input: TensorLikeType, + n: Optional[int] = None, + dim: int = -1, + norm: NormType = None, +) -> TensorLikeType: + return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True) + + +class _ShapeAndDims(NamedTuple): + shape: tuple[int, ...] + dims: tuple[int, ...] + + +def _canonicalize_fft_shape_and_dim_args( + input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType] +) -> _ShapeAndDims: + """Convert the shape and dim arguments into a canonical form where neither are optional""" + input_dim = input.ndim + input_sizes = input.shape + + if dim is not None: + if not isinstance(dim, Sequence): + dim = (dim,) + ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False) + + # Check dims are unique + torch._check( + len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique" + ) + + if shape is not None: + if not isinstance(shape, Sequence): + shape = (shape,) + + # Has shape, might have dim + torch._check( + dim is None or len(dim) == len(shape), + lambda: "When given, dim and shape arguments must have the same length", + ) + transform_ndim = len(shape) + + torch._check( + transform_ndim <= input_dim, + lambda: f"Got shape with {transform_ndim} values but input tensor " + f"only has {input_dim} dimensions.", + ) + + # If shape is given, dims defaults to the last len(shape) dimensions + if dim is None: + ret_dims = tuple(range(input_dim - transform_ndim, input_dim)) + + # Translate any -1 values in shape to the default length + ret_shape = tuple( + s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims) # type: ignore[possibly-undefined] + ) + elif dim is None: + # No shape, no dim + ret_dims = tuple(range(input_dim)) + ret_shape = tuple(input_sizes) + else: + # No shape, has dim + ret_shape = tuple(input_sizes[d] for d in ret_dims) # type: ignore[possibly-undefined] + + for n in ret_shape: + torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified") + + return _ShapeAndDims(shape=ret_shape, dims=ret_dims) # type: ignore[possibly-undefined] + + +def _prod(xs: Iterable[int]) -> int: + """Compute product of a list""" + prod = 1 + for x in xs: + prod *= x + return prod + + +def _fftn_c2c( + function_name: str, + input: TensorLikeType, + shape: tuple[int, ...], + dim: tuple[int, ...], + norm: NormType, + forward: bool, +) -> TensorLikeType: + """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)""" + torch._check( + input.dtype.is_complex, + lambda: f"{function_name} expects a complex input tensor, " + f"but got {input.dtype}", + ) + x = 
_resize_fft_input(input, dim, shape) + output = prims.fft_c2c(x, dim=dim, forward=forward) + return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward) + + +@register_decomposition(aten.fft_fftn) +@out_wrapper() +def fftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("fftn", x, shape, dim, norm, forward=True) + + +@register_decomposition(aten.fft_ifftn) +@out_wrapper() +def ifftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + x = _maybe_promote_tensor_fft(input, require_complex=True) + return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False) + + +@register_decomposition(aten.fft_rfftn) +@out_wrapper() +def rfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_r2c(input, dim=dim, onesided=True) + return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True) + + +@register_decomposition(aten.fft_ihfftn) +@out_wrapper() +def ihfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + torch._check( + not input.dtype.is_complex, + lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}", + ) + shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis") + input = _maybe_promote_tensor_fft(input, require_complex=False) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True) + + if len(dim) == 1: + tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False) + return prims.conj(tmp) + + tmp = prims.conj_physical(tmp) + tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False) + return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False) + + +class _CanonicalizeC2rReturn(NamedTuple): + shape: tuple[int, ...] + dim: tuple[int, ...] 
+ last_dim_size: int + + +def _canonicalize_fft_c2r_shape_and_dim_args( + fname: str, + input: TensorLikeType, + s: Optional[ShapeType], + dim: Optional[DimsType], +) -> _CanonicalizeC2rReturn: + """Canonicalize shape and dim arguments for n-dimensional c2r transforms, + as well as calculating the last_dim_size which is shape[dim[-1]] for the output""" + (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim) + torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis") + + if s is None or s[-1] == -1: + last_dim_size = 2 * (input.shape[dim[-1]] - 1) + else: + last_dim_size = shape[-1] + + torch._check( + last_dim_size >= 1, + lambda: f"Invalid number of data points ({last_dim_size}) specified", + ) + + shape_list = list(shape) + shape_list[-1] = last_dim_size // 2 + 1 + return _CanonicalizeC2rReturn( + shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size + ) + + +@register_decomposition(aten.fft_irfftn) +@out_wrapper() +def irfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "irfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size) + return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False) + + +@register_decomposition(aten.fft_hfftn) +@out_wrapper() +def hfftn( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = None, + norm: NormType = None, +) -> TensorLikeType: + shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args( + "hfftn", input, s, dim + ) + input = _maybe_promote_tensor_fft(input, require_complex=True) + input = _resize_fft_input(input, dim, shape) + + tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input + tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True) + tmp = prims.conj_physical(tmp) + out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size) + return _apply_norm(out, norm, last_dim_size, forward=True) + + +@register_decomposition(aten.fft_fft2) +@out_wrapper() +def fft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.fftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ifft2) +@out_wrapper() +def ifft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ifftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_rfft2) +@out_wrapper() +def rfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.rfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_irfft2) +@out_wrapper() +def irfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.irfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_hfft2) +@out_wrapper() +def hfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return 
torch.fft.hfftn(input, s=s, dim=dim, norm=norm) + + +@register_decomposition(aten.fft_ihfft2) +@out_wrapper() +def ihfft2( + input: TensorLikeType, + s: Optional[ShapeType] = None, + dim: Optional[DimsType] = (-2, -1), + norm: NormType = None, +) -> TensorLikeType: + return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm) + + +def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> list[int]: + """Convert Optional[DimsType] to a simple list, defaulting to all dimensions""" + if dim is None: + return list(range(x.ndim)) + elif not isinstance(dim, Sequence): + return [dim] + else: + return list(dim) + + +@register_decomposition(aten.fft_fftshift) +def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [input.shape[d] // 2 for d in dims] + return torch.roll(input, shift, dims) + + +@register_decomposition(aten.fft_ifftshift) +def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType: + dims = _default_alldims(dim, input) + shift = [(input.shape[d] + 1) // 2 for d in dims] + return torch.roll(input, shift, dims) diff --git a/phivenv/Lib/site-packages/torch/_refs/linalg/__init__.py b/phivenv/Lib/site-packages/torch/_refs/linalg/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c6f4b4a47777cb6bb1ead5c629a62600e4a519 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/linalg/__init__.py @@ -0,0 +1,343 @@ +# mypy: allow-untyped-defs +from functools import partial +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +import torch._refs.linalg as linalg +from torch import Tensor +from torch._prims_common import ( + check_fp_or_complex, + check_is_matrix, + Dim, + DimsType, + ELEMENTWISE_TYPE_PROMOTION_KIND, + IntLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + elementwise_type_promotion_wrapper, + out_wrapper, +) + + +__all__ = [ + "diagonal", + "matrix_norm", + "norm", + "svd", + "svdvals", + "vector_norm", + "vecdot", + "cross", +] + + +def _check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str): + """ + Checks related to the dtype kwarg in `linalg.*norm` functions + """ + if dtype is not None: + torch._check( + utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype), + lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}", + ) + torch._check( + utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype), + lambda: "{fn_name}: dtype should be {d} for {d} inputs. 
Got {dtype}".format( + fn_name=fn_name, + d="complex" if utils.is_complex_dtype(x_dtype) else "real", + dtype=dtype, + ), + ) + torch._check( + utils.get_higher_dtype(dtype, x_dtype) == dtype, + lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible " + "without narrowing to the specified dtype ({dtype})", + ) + + +import operator + +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition +from torch._decomp.decompositions import pw_cast_for_opmath + + +@register_decomposition(torch._ops.ops.aten.linalg_cross) +@out_wrapper() +@pw_cast_for_opmath +def cross(a: Tensor, b: Tensor, dim: int = -1): + torch._check( + a.ndim == b.ndim, + lambda: "linalg.cross: inputs must have the same number of dimensions.", + ) + torch._check( + a.size(dim) == 3 and b.size(dim) == 3, + lambda: f"linalg.cross: inputs dim {dim} must have length 3, got {a.size(dim)} and {b.size(dim)}", + ) + a, b = torch.broadcast_tensors(a, b) + dim = utils.canonicalize_dim(a.ndim, dim) + idx = torch.arange(3, device=a.device) + return a.index_select(dim, (idx + 1) % 3) * b.index_select( + dim, (idx + 2) % 3 + ) - a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + + +def diagonal( + input: TensorLikeType, + *, + offset: int = 0, + dim1: int = -2, + dim2: int = -1, +) -> TensorLikeType: + return torch.diagonal(input, offset=offset, dim1=dim1, dim2=dim2) + + +def _check_vector_norm_args( + x: TensorLikeType, ord: Union[float, int] = 2, dim: Optional[DimsType] = None +): + from torch.fx.experimental.symbolic_shapes import sym_or + + if not (ord < 0.0 or ord == float("inf")): + return + + torch._check( + sym_or( + x.numel() != 0, + not isinstance(dim, IntLike) and dim is not None and len(dim) != 0, + ), + "linalg.vector_norm cannot compute the {ord} norm on an empty tensor " + "because the operation does not have an identity", + ) + + shape = x.shape + if dim is not None and not isinstance(dim, IntLike): + for d in dim: + torch._check( + sym_or(x.numel() != 0, d < len(shape) and d >= 0 and shape[d] != 0), + "linalg.vector_norm cannot compute the {ord} norm on the " + f"dimension {d} because this dimension is empty and the " + "operation does not have an identity", + ) + + +@register_decomposition(torch._ops.ops.aten.linalg_vector_norm) +@out_wrapper(exact_dtype=True) +def vector_norm( + x: TensorLikeType, + ord: Union[float, int] = 2, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> Tensor: + from torch.fx.experimental.symbolic_shapes import guard_or_false + + check_fp_or_complex(x.dtype, "linalg.vector_norm") + + if isinstance(dim, Dim): + dim = [dim] # type: ignore[assignment] + + _check_vector_norm_args(x, ord, dim) + + _check_norm_dtype(dtype, x.dtype, "linalg.vector_norm") + + computation_dtype, result_dtype = utils.reduction_dtypes( + x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype + ) + + to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype) + + # Implementation + if ord == 0.0: + return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype) + elif ord == float("inf"): + return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + elif ord == float("-inf"): + return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim)) # type: ignore[return-value,arg-type] + else: + # From here on the computation dtype is important as the reduction is non-trivial + x = 
_maybe_convert_to_dtype(x, computation_dtype) # type: ignore[assignment] + reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) + + is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if dim == []: + dim = None + + if (dim is None and x.numel() == 1) or ( + dim is not None + and (x.ndim > 0 and all(guard_or_false(x.shape[d] == 1) for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return x.flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + + if not (is_ord_even and utils.is_float_dtype(x.dtype)): + x = torch.abs(x) + return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] + + +def _backshift_permutation(dim0, dim1, ndim): + # Auxiliary function for matrix_norm + # Computes the permutation that moves the two given dimensions to the back + ret = [i for i in range(ndim) if i != dim0 and i != dim1] + ret.extend((dim0, dim1)) + return ret + + +def _inverse_permutation(perm): + # Given a permutation, returns its inverse. It's equivalent to argsort on an array + return [i for i, j in sorted(enumerate(perm), key=operator.itemgetter(1))] + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def matrix_norm( + A: TensorLikeType, + ord: Union[float, str] = "fro", + dim: DimsType = (-2, -1), + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # shape + check_is_matrix(A, "linalg.matrix_norm") + # dim + dim = utils.canonicalize_dims(A.ndim, dim) + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) == 2, lambda: "linalg.matrix_norm: dim must be a 2-tuple. Got {dim}" + ) + torch._check( + dim[0] != dim[1], + lambda: "linalg.matrix_norm: dims must be different. 
Got ({dim[0]}, {dim[1]})", + ) + # dtype arg + _check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm") + + if isinstance(ord, str): + # ord + torch._check( + ord in ("fro", "nuc"), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc" + ) + + if ord == "fro": + return vector_norm(A, 2, dim, keepdim, dtype=dtype) + else: # ord == "nuc" + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: + # ord + abs_ord = abs(ord) + torch._check( + abs_ord in (2, 1, float("inf")), + lambda: "linalg.matrix_norm: Order {ord} not supported.", + ) + # dtype + check_fp_or_complex( + A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2 + ) + + max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim) + + if abs_ord == 2.0: + if dtype is not None: + A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment] + perm = _backshift_permutation(dim[0], dim[1], A.ndim) + result = max_min(svdvals(prims.transpose(A, perm)), dim=-1) + if keepdim: + inv_perm = _inverse_permutation(perm) + result = prims.transpose(torch.unsqueeze(result, -1), inv_perm) + return result + else: # 1, -1, inf, -inf + dim0, dim1 = dim + if abs_ord == float("inf"): + dim0, dim1 = dim1, dim0 + if not keepdim and (dim0 < dim1): + dim1 -= 1 + return max_min( + vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1 + ) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def norm( + A: TensorLikeType, + ord: Optional[Union[float, str]] = None, + dim: Optional[DimsType] = None, + keepdim: bool = False, + *, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + if dim is not None: + if isinstance(dim, Dim): + dim = (dim,) # type: ignore[assignment] + torch._check( + len(dim) in (1, 2), + lambda: "linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}", + ) + elif ord is not None: + torch._check( + A.ndim in (1, 2), + lambda: "linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. 
Got {A.ndim}D", + ) + + if ord is not None and ( + (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2) + ): + if dim is None: + dim = (0, 1) + return matrix_norm(A, ord, dim, keepdim, dtype=dtype) + else: + if ord is None: + ord = 2.0 + return vector_norm(A, ord, dim, keepdim, dtype=dtype) # type: ignore[arg-type] + + +# CompositeImplicitAutograd +@out_wrapper("U", "S", "Vh", exact_dtype=True) +def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]: + return prims.svd(A, full_matrices=full_matrices) + + +# CompositeImplicitAutograd +@out_wrapper(exact_dtype=True) +def svdvals(A: TensorLikeType) -> Tensor: + return svd(A, full_matrices=False)[1] + + +# CompositeImplicitAutograd +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("x", "y"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def vecdot(x: Tensor, y: Tensor, dim: int = -1) -> Tensor: + check_fp_or_complex(x.dtype, "linalg.vecdot") + return (x.conj() * y).sum(dim=dim) diff --git a/phivenv/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5da732b5a0987d55e8cfec84ae783ae2f09b8d27 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/linalg/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_refs/nn/__init__.py b/phivenv/Lib/site-packages/torch/_refs/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..feb6ae10c7741c940bdb7937c2886f976e744140 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/nn/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/phivenv/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65bba31455339311d52c6e11144a9f664fcca380 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_refs/nn/functional/__init__.py b/phivenv/Lib/site-packages/torch/_refs/nn/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e65f2dcacb67939dbd573b4a176d0f7d311fbcf3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/nn/functional/__init__.py @@ -0,0 +1,1289 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import math +from functools import wraps +from typing import Callable, Optional, TypeVar, Union +from typing_extensions import Concatenate, ParamSpec + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + NumberType, + ShapeType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import ( + elementwise_type_promotion_wrapper, + elementwise_unary_scalar_wrapper, + out_wrapper, +) +from torch._refs import _make_inplace + + +__all__ = [ + "alpha_dropout", + "celu", + "celu_", + "channel_shuffle", + "dropout", + "elu", + "elu_", + "gelu", + "glu", + "group_norm", + "hardshrink", + "hardtanh", + "hinge_embedding_loss", + "huber_loss", + "l1_loss", + "layer_norm", + "leaky_relu", + "log_softmax", + "margin_ranking_loss", + "mish", + "mish_", + "mse_loss", + 
"nll_loss", + "pairwise_distance", + "pdist", + "poisson_nll_loss", + "prelu", + "relu", + "relu6", + "selu", + "selu_", + "smooth_l1_loss", + "softmax", + "softmin", + "softplus", + "softshrink", + "tanhshrink", + "threshold", + "threshold_", + "triplet_margin_loss", +] + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +Tensor = torch.Tensor +aten = torch._ops.ops.aten +DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] + + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. + + """ + + return ( + refs._uniform_helper( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return self + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + +def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]: + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T: + a = args[0] + if "inplace" not in kwargs: + kwargs["inplace"] = False + if kwargs["inplace"]: + torch._check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + kwargs["inplace"] = False + kwargs["out"] = a + return fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + return _fn + + +# celu is implemented specially because it has an alpha argument +# celu is very similar to elu +@register_decomposition(aten.celu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def celu( + a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.celu + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if alpha is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(alpha), python_type): + msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + rhs = alpha * torch.expm1(torch.true_divide(a, alpha)) # type: ignore[arg-type] + else: + rhs = torch.expm1(a) + + return torch.where(a > 0, a, rhs) + + +@_inplace_wrapper +@out_wrapper() +def dropout( + a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False +) -> TensorLikeType: + if inplace: + raise NotImplementedError + + if not training: + return a + + torch._check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(a) + + if p == 0: + return a + + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale + + +@register_decomposition(aten.elu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def elu( + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.elu + """ + if inplace: + raise NotImplementedError + + # nb. This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + torch._check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + torch._check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) + + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) + + +@register_decomposition(aten.relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu + """ + + if inplace: + raise NotImplementedError + + return torch.where(torch.le(a, 0), 0, a) + + +@register_decomposition(aten.channel_shuffle) +@out_wrapper() +def channel_shuffle(input: TensorLikeType, groups: int) -> TensorLikeType: + """ + Reference implementation of :func:`torch.nn.functional.channel_shuffle`. + """ + from torch._meta_registrations import device_hint + + torch._check( + input.dim() > 2, + lambda: f"channel_shuffle expects input with > 2 dims, but got input with sizes {list(input.size())}", + ) + c = input.shape[1] + torch._check( + groups > 0, + lambda: f"Number of groups to divide channels in must be positive. Value of groups:{groups}", + ) + torch._check( + (c % groups) == 0, + lambda: f"Number of channels must be divisible by groups. 
Got {c} channels and {groups} groups.", + ) + n = input.shape[0] + cg = c // groups + dhw = input.shape[2:] + + if input.numel() == 0 or ( + device_hint(input) == "cuda" and (groups == 1 or groups == c) + ): + return input.view(input.shape) + + return ( + input.reshape(n, groups, cg, *dhw) + .transpose(1, 2) + .reshape(input.shape) + .contiguous() + ) + + +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + torch._check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + torch._check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + +def layer_norm( + input: Tensor, + normalized_shape: ShapeType, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.layer_norm`. + """ + return torch.native_layer_norm(input, normalized_shape, weight, bias, eps)[0] + + +@register_decomposition(aten.leaky_relu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def leaky_relu( + a: TensorLikeType, negative_slope: float = 0.01, inplace: bool = False +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.leaky_relu + """ + + if inplace: + raise NotImplementedError + + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(negative_slope), python_type): + msg = f"negative_slope argument of type {type(negative_slope)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + return torch.where(torch.gt(a, 0), a, torch.mul(a, negative_slope)) + + +@register_decomposition(aten.mish) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.mish + """ + + if inplace: + raise NotImplementedError + return a * torch.tanh(torch.nn.functional.softplus(a)) + + +@register_decomposition(aten.selu) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def selu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.selu + """ + if inplace: + raise NotImplementedError + + alpha = 1.6732632423543772848170429916717 + scale = 1.0507009873554804934193349852946 + + rhs = alpha * torch.expm1(a) + + return scale * torch.where(a > 0, a, rhs) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# CompositeImplicitAutograd - don't register decomp +def softmin( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.softmax(a=-a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# softplus is implemented specially because it has beta and threshold arguments +@register_decomposition(aten.softplus) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def softplus( + a: TensorLikeType, + beta: Optional[NumberType] = None, + threshold: NumberType = 20, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.softplus + """ + + if inplace: + raise NotImplementedError + + rhs: TensorLikeType + if beta is not None: + python_type = utils.dtype_to_type(a.dtype) + if not utils.is_weakly_lesser_type(type(beta), python_type): + msg = f"beta argument of type {type(beta)} cannot be safely cast to type {python_type}!" 
+ raise ValueError(msg) + scaled_input = a * beta + rhs = torch.true_divide(torch.log1p(torch.exp(scaled_input)), beta) # type: ignore[arg-type] + + else: + scaled_input = a + rhs = torch.log1p(torch.exp(scaled_input)) + + return torch.where(scaled_input > threshold, a, rhs) + + +@aten.hardshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.hardshrink) +@out_wrapper() +def hardshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # hardshrink(x) = x if x > lambd + # = x if x < -lambd + # = 0 otherwise + return torch.where(torch.abs(a) <= lambd, 0, a) + + +@aten.softshrink.default.py_impl(DispatchKey.Autograd) +@register_decomposition(aten.softshrink) +@out_wrapper() +def softshrink(a: TensorLikeType, lambd: float = 0.5): + # Formula for reference, + # softshrink(x) = x - lambd if x > lambd + # = x + lambd if x < -lambd + # = 0 otherwise + torch._check( + lambd >= 0, + lambda: f"lambda must be greater or equal to 0, but found to be {lambd}", + ) + # We implement this in one torch.where to generate better code in the backward + # see https://github.com/pytorch/pytorch/pull/107052#discussion_r1293748211 + # We multiply by 0 for dealing with nans + return torch.where(torch.abs(a) > lambd, a - torch.sign(a) * lambd, a * 0) + + +# Losses +def _reduction_int_to_str(reduction: int) -> str: + from torch._decomp.decompositions import Reduction + + if reduction == Reduction.NONE.value: + return "none" + elif reduction == Reduction.MEAN.value: + return "mean" + elif reduction == Reduction.SUM.value: + return "sum" + else: + raise ValueError(f"{reduction} is not a valid value for reduction") + + +def _apply_loss_reduction(loss: TensorLikeType, reduction: str) -> TensorLikeType: + if reduction == "sum": + return torch.sum(loss) + elif reduction == "mean": + return torch.mean(loss) + else: # reduction == "none" + return loss + + +def _check_reduction_value(reduction: str): + if reduction not in ("mean", "sum", "none"): + raise ValueError(f"{reduction} is not a valid value for reduction") + + +# This helper function maps depreciated arguments, "size_average" and "reduce" +# to their corresponding "reduction" string argument +def _get_string_reduction_arg( + *, size_average: Optional[bool], reduce: Optional[bool] +) -> str: + if size_average is None: + size_average = True + if reduce is None: + reduce = True + if size_average and reduce: + ret = "mean" + elif reduce: + ret = "sum" + else: + ret = "none" + return ret + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." 
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.abs(input - target) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def smooth_l1_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.smooth_l1_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + + if beta == 0.0: + return torch.nn.functional.l1_loss( + input, target, size_average=size_average, reduce=reduce, reduction=reduction + ) + else: + loss = torch.abs(input - target) + loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta) + return _apply_loss_reduction(loss, reduction) + + +# Forwarding alias: the functional variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: Optional[int] = None, + _stacklevel: int = 3, # for compat when using TorchRefsMode(strict=True) + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + # The error is for compat with regular PyTorch, which has this behavior + # deprecated. For PrimTorch, it's fine to drop support for deprecated + # behavior because it requires explicit opt in. This error is to inform + # users how to update their calls. + torch._check(dim is not None, lambda: "implicit dim not supported, use dim=X") + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@register_decomposition(aten.margin_ranking_loss) +def margin_ranking_loss( + input1: TensorLikeType, + input2: TensorLikeType, + target: TensorLikeType, + margin: float = 0.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = max(0, -target * (input1 - input2) + margin) + if input1.ndim != input2.ndim or input1.ndim != target.ndim: + raise RuntimeError( + "margin_ranking_loss : All input tensors should have same dimension but got sizes: " + f"input1: {input1.shape}, input2: {input2.shape}, target: {target.shape} " + ) + _check_reduction_value(reduction) + loss = torch.clamp_min(-target * (input1 - input2) + margin, 0) + return _apply_loss_reduction(loss, reduction) + + +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT, +) +def mse_loss( + input: TensorLikeType, + target: TensorLikeType, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." 
+ reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + loss = torch.pow(input - target, 2) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hinge_embedding_loss) +def hinge_embedding_loss( + input: TensorLikeType, + target: TensorLikeType, + margin: float = 1.0, + reduction: str = "mean", +) -> TensorLikeType: + # loss_without_reduction = input if y == 1 + # = max(0, margin - input) if y == -1 + _check_reduction_value(reduction) + margin_clamp = torch.clamp_min(margin - input, 0) + output_margin = torch.where(target != 1, margin_clamp, 0) + output_self = torch.where(target != -1, input, 0) + loss = output_margin + output_self + return _apply_loss_reduction(loss, reduction) + + +def _nll_loss_nd( + input: TensorLikeType, + target: TensorLikeType, + weight: Optional[TensorLikeType], + reduction: str, + ignore_index: int, +) -> TensorLikeType: + torch._check( + input.ndim > 0 and input.ndim <= 3, + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", + ) + + torch._check( + (input.ndim == 1) or (input.shape[0] == target.shape[0]), + lambda: f"Expected input batch size {input.shape[0]} to match target batch size {target.shape[0]}.", + ) + + _check_reduction_value(reduction) + + flat_target = torch.flatten(target) + ignore_classes_mask = torch.eq(flat_target, ignore_index) + + # TODO: Enable data-dependent checks with debug mode + # TODO: This check does not work with FakeTensor inputs; See Issue #85834 + # Explicit cast for class_check to bool; See Issue #78071 + """ + from torch._subclasses.fake_tensor import FakeTensor + num_classes = input.shape[1] if input.ndim > 1 else input.shape[0] + valid_classes_mask = torch.logical_and( + (flat_target >= 0), (flat_target < num_classes) + ) + class_check = torch.all(torch.logical_or(ignore_classes_mask, valid_classes_mask)) + torch._check( + isinstance(target, FakeTensor) or bool(class_check.item()), + lambda: "A target class is out-of-bounds and not the ignore index.", + ) + """ + + ignore_class_weight = torch.scalar_tensor(0, dtype=input.dtype, device=input.device) + class_weight = ( + torch.scalar_tensor(1, dtype=input.dtype, device=input.device) + if weight is None + else weight[flat_target] + ) + current_weight = torch.where( + ignore_classes_mask, + ignore_class_weight, + class_weight, + ) + + if input.ndim == 1: + # implicit batch size = 1 + # input (1 batch size, C classes) + loss = -input[target] * current_weight + elif input.ndim == 2: + # input (N batch size, C classes) + batch_size = input.shape[0] + loss = -input[torch.arange(batch_size), target] * current_weight + else: + # 3D case (N batch size, C classe, K dimensions) + # input (N batch size, C classes, K) + batch_size = input.shape[0] + extent = input.shape[2] + numel = batch_size * extent + indices = torch.arange(numel) + bdx = indices // extent + kdx = indices % extent + loss = -input[bdx, flat_target, kdx] * current_weight + loss = torch.reshape(loss, target.shape) + + if reduction == "none": + return loss + elif reduction == "sum": + return torch.sum(loss) + else: + # calculate weighted mean of the loss function + return torch.sum(loss) / torch.sum(current_weight) + + +@register_decomposition(aten.nll_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def nll_loss( + input: TensorLikeType, + target: TensorLikeType, + weight: 
Optional[TensorLikeType] = None, + size_average: Optional[bool] = None, + ignore_index: int = -100, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.nll_loss + """ + torch._check( + input.ndim > 0, + lambda: f"Expected input tensor to have 1 or more dimensions (got {input.ndim})", + ) + + # TODO: raise exception instead of converting value + # msg = "size_average and reduce args are deprecated, please use reduction argument." + # Convert these options for consistency with the eager mode + if size_average is not None or reduce is not None: + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + # The expected behavior when the target and input have zero elements: + # reduction = 'none' --- tensor([]) + # reduction = 'sum' --- tensor(0.) + # reduction = 'mean' --- tensor(nan) + # Mean reduction on empty tensors produces NaN. See the discussion in + # https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162 + if input.numel() == 0 and target.numel() == 0: + if reduction == "none": + return torch.zeros_like(target) + elif reduction == "sum": + return torch.empty_like(target) + else: + return torch.full_like(target, float("nan")) + + # The _nll_loss_nd helper function handles the most common cases. + # ndim == 1 (Single Example) + # => Batch Size: 1, Input: (C), Target: () + # ndim == 2 (k = 1) + # => Batch Size: N, Input: (N, C), Target: (N) + # ndim == 3 (k > 1) + # => Batch Size: N, Input: (N, C, K), Target: (N, K) + if input.ndim <= 3: + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + + # For ndim > 3, we reshape the input and target to 3-D case. + # Input (N batch-size, C classes, k-dimensions) + # Target (N batch-size, k-dimensions) + torch._check( + input.ndim > 0 and target.ndim > 0 and target.shape[1:] == input.shape[2:], + lambda: ( + "Expected input and target to both have ndim > 0 and " + "target.shape[1:] == input.shape[2:], but got " + f"target.shape {target.shape} and input.shape {input.shape}" + ), + ) + + batch_size = input.shape[0] + num_classes = input.shape[1] + out_size = [batch_size] + list(target.shape[1:]) + + input = torch.reshape(input, [batch_size, num_classes, -1]) + target = torch.reshape(target, [batch_size, -1]) + if reduction != "none": + return _nll_loss_nd(input, target, weight, reduction, ignore_index) + else: + result = _nll_loss_nd(input, target, weight, reduction, ignore_index) + # reshape flattened inner-dim to original k-dimensions + return torch.reshape(result, out_size) + + +# TODO: This ref supports int reduction and out kwarg to be compatible with ATen: +# https://github.com/pytorch/pytorch/issues/83931 +# TODO: Could be rewritten to support complex: +# https://github.com/pytorch/pytorch/pull/85041 +@register_decomposition(aten.huber_loss) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def huber_loss( + input: TensorLikeType, + target: TensorLikeType, + reduction: Union[str, int] = "mean", + delta: float = 1.0, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.huber_loss + """ + if type(reduction) is int: + reduction = _reduction_int_to_str(reduction) + _check_reduction_value(reduction) # type: ignore[arg-type] + torch._check( + delta > 0, + lambda: "huber_loss does not support non-positive values for delta.", + ) + z = (input - target).abs() + loss = 
torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta)) + return _apply_loss_reduction(loss, reduction) # type: ignore[arg-type] + + +# tanhshrink does not use _make_elementwise_unary_reference because it does not support out +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def tanhshrink(a: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.tanhshrink + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" + ) + return a - torch.tanh(a) + + +@register_decomposition(aten.threshold) +@_inplace_wrapper +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def threshold( + a: TensorLikeType, + threshold: NumberType, + value: Union[bool, int, float], + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.threshold + """ + + if inplace: + raise NotImplementedError + + return torch.where(a <= threshold, value, a) + + +# CompositeImplicitAutograd - don't register decomp +# No elementwise type promotion - core op doesn't explicitly type promote +def triplet_margin_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + margin: float = 1.0, + p: float = 2, + eps: float = 1e-6, + swap: bool = False, + size_average: Optional[bool] = None, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + + if margin <= 0: + raise ValueError(f"margin must be greater than 0, got {margin}") + + # torch.nn.functional.triplet_margin_with_distance_loss has no ref defined + # since it's a pure Python implementation. Use this helper instead. + return _triplet_margin_with_distance_loss( + anchor=anchor, + positive=positive, + negative=negative, + distance_function=lambda x, y: torch.pairwise_distance(x, y, p, eps), + margin=margin, + swap=swap, + reduction=reduction, + ) + + +# Pure Python impl - don't register decomp and don't add a ref. Defined as a +# helper here since triplet_margin_loss can be nicely implemented with it. 
+def _triplet_margin_with_distance_loss( + anchor: TensorLikeType, + positive: TensorLikeType, + negative: TensorLikeType, + *, + distance_function: Optional[ + Callable[[TensorLikeType, TensorLikeType], TensorLikeType] + ] = None, + margin: float = 1.0, + swap: bool = False, + reduction: str = "mean", +) -> TensorLikeType: + _check_reduction_value(reduction) + + a_dim = anchor.ndim + p_dim = positive.ndim + n_dim = negative.ndim + torch._check( + a_dim == p_dim and p_dim == n_dim, + lambda: ( + f"The anchor, positive, and negative tensors are expected to have " + f"the same number of dimensions, but got: anchor {a_dim}D, " + f"positive {p_dim}D, and negative {n_dim}D inputs" + ), + ) + + if distance_function is None: + distance_function = torch.pairwise_distance + + dist_pos = distance_function(anchor, positive) + dist_neg = distance_function(anchor, negative) + # The distance swap is described in the paper "Learning shallow + # convolutional feature descriptors with triplet losses" by V. Balntas, E. + # Riba et al. If True, and if the positive example is closer to the + # negative example than the anchor is, swaps the positive example and the + # anchor in the loss computation. + if swap: + dist_swap = distance_function(positive, negative) + dist_neg = torch.minimum(dist_neg, dist_swap) + loss = torch.clamp_min(margin + dist_pos - dist_neg, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.hardtanh) +@_inplace_wrapper +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def hardtanh( + a: TensorLikeType, + min_val: NumberType = -1, + max_val: NumberType = 1, + inplace: bool = False, +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.hardtanh + """ + if inplace: + raise NotImplementedError + if utils.is_boolean_dtype(a.dtype): + raise RuntimeError("Bool inputs not supported for hardtanh") + + # preserve legacy behavior of boundaries not causing type promotion + if utils.is_integer_dtype(a.dtype): + min_val = int(min_val) # type: ignore[arg-type] + max_val = int(max_val) # type: ignore[arg-type] + if not (a.dtype != torch.uint8 or (min_val >= 0 and max_val >= 0)): + raise RuntimeError( + "Cannot do hardtanh on an unsigned type with negative limits" + ) + + if min_val > max_val: # type: ignore[operator] + raise ValueError("min_val cannot be greater than max_val") + + return torch.clamp(a, min_val, max_val) # type: ignore[arg-type] + + +@register_decomposition(aten.gelu) +@out_wrapper() +@elementwise_unary_scalar_wrapper +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def gelu(a: TensorLikeType, approximate: str = "none") -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.gelu + """ + if not isinstance(a, TensorLike): + raise RuntimeError( + "Expected a tensor input for an elementwise unary operation!" 
+ ) + M_SQRT2 = 1.41421356237309504880 + M_SQRT1_2 = 0.70710678118654752440 + M_2_SQRTPI = 1.12837916709551257390 + if approximate == "tanh": + kBeta = M_SQRT2 * M_2_SQRTPI * 0.5 + kKappa = 0.044715 + a_cube = a * a * a + inner = kBeta * (a + kKappa * a_cube) + return 0.5 * a * (1 + torch.tanh(inner)) + elif approximate == "none": + kAlpha = M_SQRT1_2 + return a * 0.5 * (1 + torch.erf(a * kAlpha)) + else: + raise RuntimeError("approximate argument must be either none or tanh.") + + +# CompositeImplicitAutograd - don't register decomp +@elementwise_type_promotion_wrapper( + type_promoting_args=("input", "target"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def poisson_nll_loss( + input: TensorLikeType, + target: TensorLikeType, + log_input: bool = True, + full: bool = False, + size_average: Optional[bool] = None, + eps: float = 1e-8, + reduce: Optional[bool] = None, + reduction: str = "mean", +) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.poisson_nll_loss + """ + if size_average is not None or reduce is not None: + # TODO: Raise exception instead of converting value. This is only for + # primTorch since it can drop support for deprecated arguments. + # msg = "size_average and reduce args are deprecated, please use reduction argument." + reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce) + _check_reduction_value(reduction) + if log_input: + loss = torch.exp(input) - target * input + else: + loss = input - target * torch.log(input + eps) + + if full: + stirling_term = ( + target * torch.log(target) - target + 0.5 * torch.log(2 * torch.pi * target) + ) + # avoid inplace add + loss = loss + stirling_term.masked_fill(target <= 1, 0) + return _apply_loss_reduction(loss, reduction) + + +@register_decomposition(aten.prelu) +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "weight"), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.prelu + """ + torch._check( + isinstance(a, TensorLike), + lambda: f"prelu: Expected `a` to be tensor, but got: {type(a)}", + ) + torch._check( + isinstance(weight, TensorLike), + lambda: f"prelu: Expected `weight` to be tensor, but got: {type(weight)}", + ) + + if weight.numel() != 1: + torch._check(a.ndim > 0, lambda: "Not allow zero-dim input tensor.") + channel_size = a.shape[1] if a.ndim >= 2 else 1 + torch._check( + weight.numel() == channel_size, + lambda: f"Mismatch of parameter numbers and input channel size. 
Found parameter numbers =" + f" {weight.numel()} and channel size = {channel_size}.", + ) + + torch._check( + weight.ndim == 0 or weight.ndim == 1, + lambda: f"prelu: Expected `weight` to be a scalar or 1D tensor, but got: " + f"ndim = {weight.ndim}", + ) + if a.ndim == 0: + weight = weight[0] if weight.ndim == 1 else weight + else: + weight = prims.broadcast_in_dim( + weight, a.shape, () if weight.ndim == 0 else (0 if a.ndim == 1 else 1,) + ) + + return torch.where(a > 0, a, a * weight) + + +@register_decomposition(aten.relu6) +@_inplace_wrapper +@out_wrapper() +def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: + """ + Reference implementation of torch.nn.functional.relu6 + """ + if inplace: + raise NotImplementedError + + # See https://github.com/pytorch/pytorch/pull/81142#discussion_r918220126 + # It may be better to use clamp here, but we use hardtanh to replicate + # the behavior of the existing implementation + return torch.nn.functional.hardtanh(a, 0, 6) + + +@register_decomposition(aten.glu) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: + dim = utils.canonicalize_dims(a.ndim, dim) + torch._check( + a.shape[dim] % 2 == 0, + lambda: f"Halving dimension must be even, but dimension {dim} is size {a.shape[dim]}", + ) + b, c = torch.tensor_split(a, 2, dim) + + return b * torch.sigmoid(c) + + +@register_decomposition(aten.pairwise_distance) +@out_wrapper() +def pairwise_distance( + x1: TensorLikeType, + x2: TensorLikeType, + p: NumberType = 2.0, + eps: NumberType = 1e-6, + keepdim=False, +) -> TensorLikeType: + return torch.linalg.vector_norm(x1 - x2 + eps, ord=p, dim=-1, keepdim=keepdim) + + +@register_decomposition(aten.pdist) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, +) +def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: + torch._check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") + torch._check(p >= 0, lambda: "pdist only supports non-negative p values") + # For p == 2 we can use an efficient implementation, but other values of p + # require creating a much bigger tensor for an intermediate step + if p == 2: + aTa = torch.mm(a, a.T) + aTa_diag = torch.diag(aTa) + t = torch.sqrt(torch.clamp(aTa_diag + aTa_diag.unsqueeze(-1) - 2 * aTa, min=0)) + else: + t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) + i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) + return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +@register_decomposition(aten.pixel_shuffle) +@out_wrapper() +def pixel_shuffle(self: Tensor, upscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_shuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] // upscale_factor**2 + HW_out = (self.shape[-2] * upscale_factor, self.shape[-1] * upscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, r1_dim, r2_dim, H_dim, W_dim = range(n, n + 5) + return ( + self.view( + *batch, + C_out, + upscale_factor, + upscale_factor, + self.shape[-2], + self.shape[-1], + ) + .permute(*B_dims, C_dim, H_dim, r1_dim, W_dim, r2_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + 
+@register_decomposition(aten.pixel_unshuffle) +@out_wrapper() +def pixel_unshuffle(self: Tensor, downscale_factor: int): + torch._check( + self.dim() >= 3, + lambda: f"pixel_unshuffle expects input to have at least 3 dimensions, but got input with {self.dim} dimension(s)", + ) + batch = self.shape[:-3] + C_out = self.shape[-3] * downscale_factor**2 + HW_out = (self.shape[-2] // downscale_factor, self.shape[-1] // downscale_factor) + n = len(batch) + B_dims = range(n) + C_dim, H_dim, r1_dim, W_dim, r2_dim = range(n, n + 5) + return ( + self.view( + *batch, + self.shape[-3], + HW_out[0], + downscale_factor, + HW_out[1], + downscale_factor, + ) + .permute(*B_dims, C_dim, r1_dim, r2_dim, H_dim, W_dim) + .reshape(*batch, C_out, *HW_out) + .clone(memory_format=utils.suggest_memory_format(self)) + ) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/phivenv/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b286889f888ee1e50cbd78ffc113716f608da0f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/nn/functional/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_refs/special/__init__.py b/phivenv/Lib/site-packages/torch/_refs/special/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..343e65886121ef8866aed31b13e5eb7c7cb74e65 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_refs/special/__init__.py @@ -0,0 +1,236 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +import torch._prims as prims +import torch._prims_common as utils +import torch._refs as refs +from torch import Tensor +from torch._decomp import register_decomposition +from torch._prims_common import ( + ELEMENTWISE_TYPE_PROMOTION_KIND, + Number, + NumberType, + TensorLike, + TensorLikeType, +) +from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper +from torch._refs import ( + _make_alias, + _make_elementwise_binary_reference, + _make_elementwise_unary_reference, +) + + +__all__ = [ + "bessel_j0", + "bessel_j1", + "entr", + "erfcx", + "expit", + "i0e", + "i1", + "i1e", + "log_ndtr", + "logit", + "log_softmax", + "multigammaln", + "ndtr", + "ndtri", + "softmax", + "spherical_bessel_j0", + "xlog1py", + "zeta", +] +aten = torch._ops.ops.aten + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j0(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def bessel_j1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_j1(a) + + +@register_decomposition(aten.special_entr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def entr(a: TensorLikeType) -> TensorLikeType: + return torch.where( + torch.isnan(a), + a, + torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)), + ) + + +@register_decomposition(aten.special_erfcx) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + 
type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def erfcx(a: TensorLikeType) -> TensorLikeType: + return prims.erfcx(a) + + +# alias for sigmoid +expit = _make_alias(torch.sigmoid, "expit") + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i0e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i0e(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1(a) + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def i1e(a: TensorLikeType) -> TensorLikeType: + return prims.bessel_i1e(a) + + +@register_decomposition(aten.special_log_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def log_ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + t = a * M_SQRT1_2 + return torch.where( + a < 1.0, + torch.log(torch.special.erfcx(-t) / 2) - t * t, + torch.log1p(-torch.erfc(t) / 2), + ) + + +@register_decomposition(aten.logit) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("self",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType: + if eps is None: + eps = -1.0 + lo = eps + hi = 1 - eps + self = torch.where(self < lo, lo, torch.where(self > hi, hi, self)) + return torch.log(torch.true_divide(self, torch.sub(1, self))) + + +@register_decomposition(aten.special_xlog1py) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]): + torch._check( + isinstance(a, TensorLike) or isinstance(b, TensorLike), + lambda: 'Expected either argument a or b to be a Tensor"', + ) + + # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors. 
+ if isinstance(a, TensorLike) and isinstance(b, Number): + b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device) + elif isinstance(b, TensorLike) and isinstance(a, Number): + a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device) + + # mypy: expected "Tensor" + assert isinstance(a, TensorLike) + assert isinstance(b, TensorLike) + rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log1p(b))) + return torch.where(torch.isnan(b), float("nan"), rhs) + + +@register_decomposition(aten.mvlgamma) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType: + c = 0.25 * p * (p - 1) * math.log(math.pi) + b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device) + return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c + + +@register_decomposition(aten.special_ndtr) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtr(a: TensorLikeType) -> TensorLikeType: + # Note: M_SQRT1_2 is the value of 1 / sqrt(2) + M_SQRT1_2 = 0.707106781186547524400844362104849039 + a_sqrt_2 = a * M_SQRT1_2 + return (1 + torch.erf(a_sqrt_2)) * 0.5 + + +@register_decomposition(aten.special_ndtri) +@out_wrapper() +@elementwise_type_promotion_wrapper( + type_promoting_args=("a",), + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def ndtri(a: TensorLikeType) -> TensorLikeType: + return prims.ndtri(a) + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def log_softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.log_softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +# Forwarding alias: the special variant doesn't support the out kwarg +# CompositeImplicitAutograd - don't register decomp +def softmax( + a: TensorLikeType, + dim: int, + dtype: Optional[torch.dtype] = None, +) -> TensorLikeType: + return torch.softmax(a=a, dim=dim, dtype=dtype) # type: ignore[call-overload] + + +@_make_elementwise_unary_reference( + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: + return prims.spherical_bessel_j0(a) + + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) diff --git a/phivenv/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..243d4711a68e26fd05e8cee234219126b74f4cd7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_refs/special/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_strobelight/__init__.py b/phivenv/Lib/site-packages/torch/_strobelight/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/__init__.cpython-39.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..15ac338fedc186d9aff99ba79b339af896a3baba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f39344fa93a9b1186d8a172837be0677a4e05ad2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/cli_function_profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be14880d37b3faca51fcf5beafcc3bc2563ae4e0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_strobelight/__pycache__/compile_time_profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_strobelight/cli_function_profiler.py b/phivenv/Lib/site-packages/torch/_strobelight/cli_function_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..bafdc48505d43697cee1637e12200d8f9bd94a1b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_strobelight/cli_function_profiler.py @@ -0,0 +1,321 @@ +# mypy: disallow-untyped-defs + +import functools +import logging +import os +import re +import subprocess +import time +from collections.abc import Sequence +from threading import Lock +from timeit import default_timer as timer +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import ParamSpec + + +logger = logging.getLogger("strobelight_function_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +class StrobelightCLIProfilerError(Exception): + """ + Raised when an error happens during strobelight profiling + """ + + +def _pid_namespace_link(pid: Optional[int] = None) -> str: + """Returns the link to the process's namespace, example: pid:[4026531836]""" + PID_NAMESPACE_PATH = "/proc/{}/ns/pid" + pid = pid or os.getpid() + return os.readlink(PID_NAMESPACE_PATH.format(pid)) + + +def _pid_namespace(pid: Optional[int] = None) -> int: + """Returns the process's namespace id""" + pid = pid or os.getpid() + link = _pid_namespace_link(pid) + return int(link[link.find("[") + 1 : -1]) + + +def _command_to_string(command: Sequence[str]) -> str: + return " ".join(command) + + +class StrobelightCLIFunctionProfiler: + """ + Note: this is a Meta only tool. + + StrobelightCLIFunctionProfiler can be used to profile a Python function and + generate a strobelight link with the results. It works on Meta servers but + does not require an fbcode target. + When stop_at_error is False (the default), an error during profiling does not + prevent the work function from running. + + Check function_profiler_example.py for an example. + """ + + # This lock is used to make sure only one thread is running the profiler at any point.
+ _lock = Lock() + + def __init__( + self, + *, + stop_at_error: bool = False, + max_profile_duration_sec: int = 60 * 10, + sample_each: float = 1e7, # sample each sample_each cycles. + run_user_name: str = "pytorch-strobelight-ondemand", + timeout_wait_for_running_sec: int = 60, + timeout_wait_for_finished_sec: int = 60, + recorded_env_variables: Optional[list[str]] = None, + sample_tags: Optional[list[str]] = None, + stack_max_len: int = 127, + async_stack_max_len: int = 127, + ): + self.stop_at_error = stop_at_error + self.max_profile_duration_sec = max_profile_duration_sec + self.sample_each = sample_each + self.run_user_name = run_user_name + self.timeout_wait_for_running_sec = timeout_wait_for_running_sec + self.timeout_wait_for_finished_sec = timeout_wait_for_finished_sec + # Results of the most recent run. + # Tracks the strobelight run id of the most recent run + self.current_run_id: Optional[int] = None + self.profile_result: Optional[list[str]] = None + self.sample_tags = sample_tags + + def _run_async(self) -> None: + processId = os.getpid() + namespace = _pid_namespace(processId) + command = [ + "strobeclient", + "run", + "--profiler", + "pyperf", + "--event", + "cycles", + "--async", + "--sample-interval", + f"{int(self.sample_each)}", + "--duration-ms", + f"{int(self.max_profile_duration_sec * 1000)}", + "--pid", + f"{namespace}:{processId}", + ] + + if self.sample_tags: + command.append("--sample-tags") + command.append(",".join(self.sample_tags)) + + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in run_async:{output}" + ) + + if match := re.search(r"INFO Run Id: (-?\d+)", output): + self.current_run_id = int(match.group(1)) + return + + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, unexpected result {output}" + ) + + def _wait_for_running(self, counter: int = 0) -> None: + if counter > 20: + raise StrobelightCLIProfilerError( + "wait_for_running called more than 20 times" + ) + + command = ["strobeclient", "getRunStatus", "--run-id", f"{self.current_run_id}"] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to start strobelight profiling, error in wait_for_running:{output}" + ) + + if match := re.search("Profile run status: (.*)", output): + current_status = match.group(1) + if current_status == "RUNNING": + return + elif current_status == "PREPARING": + time.sleep(10) + self._wait_for_running(counter + 1) + return + else: + raise StrobelightCLIProfilerError(f"unexpected {current_status} phase") + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _stop_run(self) -> None: + command = ["strobeclient", "stopRun", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, return code is not 0 :{output}" + ) + + if match := 
re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Success!"): + return + else: + raise StrobelightCLIProfilerError( + f"failed to stop strobelight profiling, got {current_status} result" + ) + + raise StrobelightCLIProfilerError(f"unexpected output\n: {output} ") + + def _get_results(self) -> None: + command = ["strobeclient", "getRunStatus", "--run-id", str(self.current_run_id)] + logger.debug("running command: %s", _command_to_string(command)) + result = subprocess.run(command, capture_output=True) + output = result.stderr.decode("utf-8") + logger.debug("output:\n{%s}", output) + + if result.returncode != 0: + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, return code is not 0 : {output}" + ) + + if match := re.search("INFO ::1:(.*)", output): + current_status = match.group(1) + if current_status.__contains__("Profile run status: PROCESSING"): + time.sleep(10) + self._get_results() + return + elif not current_status.__contains__("Profile run finished with SUCCESS"): + raise StrobelightCLIProfilerError( + f"failed to extract profiling results, unexpected response {output}" + ) + + self.profile_result = [] + for item in re.findall( + r"(Total samples(.*)|GraphProfiler(.*)|Icicle view \(python stack\)(.*))", + output, + ): + self.profile_result += item[0] + logger.info(item[0]) + + def _stop_strobelight_no_throw( + self, + collect_results: bool, + ) -> None: + try: + # call stop run + self._stop_run() + logger.info("strobelight profiling stopped") + + logger.debug("collection stopped") + + if not collect_results: + return + + self._get_results() + except Exception: + logger.warning("error during stop_strobelight", exc_info=True) + + # Return true if strobelight started and is running. Never throw. + def _start_strobelight(self) -> bool: + strobelight_started = False + try: + self._run_async() + strobelight_started = True + logger.info("strobelight run id is: %s", self.current_run_id) + self._wait_for_running() + logger.info("strobelight profiling running") + return True + + except Exception: + logger.warning("error during start_strobelight:", exc_info=True) + if strobelight_started: + self._stop_strobelight_no_throw(collect_results=False) + return False + + def profile( + self, work_function: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs + ) -> Optional[_R]: + self.current_run_id = None + self.profile_result = None + + if locked := StrobelightCLIFunctionProfiler._lock.acquire(False): + if not locked: + if self.stop_at_error: + raise StrobelightCLIProfilerError("concurrent runs not supported") + + logger.warning("concurrent runs not supported") + return work_function(*args, **kwargs) + + started = self._start_strobelight() + if not started: + if self.stop_at_error: + StrobelightCLIFunctionProfiler._lock.release() + raise StrobelightCLIProfilerError( + "failed to start strobelight profiling" + ) + result = work_function(*args, **kwargs) + StrobelightCLIFunctionProfiler._lock.release() + return result + + try: + logger.debug("collection started") + start = timer() + result = work_function(*args, **kwargs) + end = timer() + total_time = end - start # Time in seconds, e.g. 
5.38091952400282 + logger.info("work function took %s seconds", total_time) + self._stop_strobelight_no_throw(collect_results=True) + StrobelightCLIFunctionProfiler._lock.release() + return result + except Exception as error: + logger.warning("work function throw exception", exc_info=True) + self._stop_strobelight_no_throw(collect_results=False) + StrobelightCLIFunctionProfiler._lock.release() + raise error + return None + + +# A function decorator that wraps profile, if no profiler is provided one with +# default args is created. A function can be annotated as: +# @strobelight() +# @strobelight(profiler = StrobelightFunctionProfiler(stop_at_error=True,..)) +# @strobelight(stop_at_error=True,...) +def strobelight( + profiler: Optional[StrobelightCLIFunctionProfiler] = None, **kwargs: Any +) -> Callable[[Callable[_P, _R]], Callable[_P, Optional[_R]]]: + if not profiler: + profiler = StrobelightCLIFunctionProfiler(**kwargs) + + def strobelight_inner( + work_function: Callable[_P, _R] + ) -> Callable[_P, Optional[_R]]: + @functools.wraps(work_function) + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + return profiler.profile(work_function, *args, **kwargs) + + return wrapper_function + + return strobelight_inner diff --git a/phivenv/Lib/site-packages/torch/_strobelight/compile_time_profiler.py b/phivenv/Lib/site-packages/torch/_strobelight/compile_time_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..b56a39516d9d5494a186c9e7b01c4f9d15628204 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_strobelight/compile_time_profiler.py @@ -0,0 +1,224 @@ +# mypy: disallow-untyped-defs + +import json +import logging +import os +import re +import subprocess +from datetime import datetime +from socket import gethostname +from typing import Any, Optional + +from torch._strobelight.cli_function_profiler import StrobelightCLIFunctionProfiler + + +logger = logging.getLogger("strobelight_compile_time_profiler") + +console_handler = logging.StreamHandler() +formatter = logging.Formatter( + "%(name)s, line %(lineno)d, %(asctime)s, %(levelname)s: %(message)s" +) +console_handler.setFormatter(formatter) + +logger.addHandler(console_handler) +logger.setLevel(logging.INFO) +logger.propagate = False + + +def get_fburl(url: str) -> str: + short_url = url + # Attempt to shorten the URL + try: + result = subprocess.run( + ["fburl", url], capture_output=True, stdin=subprocess.DEVNULL + ) + if result.returncode == 0: + short_url = result.stdout.decode("utf-8") + except Exception as e: + logger.warning("URL shortening failed: %s, using long URL", repr(e)) + return short_url + + +def get_strobelight_url(identifier: str) -> str: + scuba_json = { + "aggregateList": [], + "aggregation_field": "async_stack_complete", + "b_constraints": [[]], + "c_constraints": [[]], + "cols": ["namespace_id", "namespace_process_id"], + "compare": "none", + "constraints": [ + [{"column": "sample_tags", "op": "all", "value": [f'["{identifier}"]']}] + ], + "derivedCols": [], + "end": "now", + "enumCols": [], + "filterMode": "DEFAULT", + "hideEmptyColumns": "false", + "ignoreGroupByInComparison": "false", + "is_timeseries": "false", + "mappedCols": [], + "metric": "count", + "modifiers": [], + "order": "weight", + "order_desc": "true", + "param_dimensions": [ + {"dim": "py_async_stack", "op": "edge", "param": "0", "anchor": "0"} + ], + "purposes": [], + "return_remainder": "false", + "samplingRatio": "1", + "should_pivot": "false", + "start": "-30 days", + "timezone": 
"America/Los_Angeles", + "top": 10000, + } + scuba_url_prefix = "https://www.internalfb.com/intern/scuba/query/?dataset=pyperf_experimental/on_demand&drillstate=" + scuba_url_suff = "&view=GraphProfilerView&&normalized=1726332703&pool=uber" + long_url = scuba_url_prefix + json.dumps(scuba_json) + scuba_url_suff + return get_fburl(long_url) + + +class StrobelightCompileTimeProfiler: + success_profile_count: int = 0 + failed_profile_count: int = 0 + ignored_profile_runs: int = 0 + inside_profile_compile_time: bool = False + enabled: bool = False + + # A regex that can be used to filter out what frames to profile. ex: "1/.*" + frame_id_filter: Optional[str] = os.environ.get("COMPILE_STROBELIGHT_FRAME_FILTER") + + # A unique identifier that is used as the run_user_name in the strobelight profile to + # associate all compile time profiles together. + identifier: Optional[str] = None + + current_phase: Optional[str] = None + + profiler: Optional[Any] = None + + max_stack_length: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 500) + ) + max_profile_time: int = int( + os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30) + ) + # Collect sample each x cycles. + sample_each: int = int( + float(os.environ.get("COMPILE_STROBELIGHT_SAMPLE_RATE", 1e7)) + ) + + @classmethod + def get_frame(cls) -> str: + from torch._guards import CompileContext + + return (str)(CompileContext.current_trace_id()) + + @classmethod + def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: + if cls.enabled: + logger.info("compile time strobelight profiling already enabled") + return + + logger.info("compile time strobelight profiling enabled") + + if profiler_class is StrobelightCLIFunctionProfiler: + import shutil + + if not shutil.which("strobeclient"): + logger.info( + "strobeclient not found, cant enable compile time strobelight profiling, seems" + "like you are not on a FB machine." + ) + return + + cls.enabled = True + cls._cls_init() + # profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler. + # we have pass different functionProfilerClass for meta-internal fbcode targets. + # NB: the actual implementation in Meta is at + # fbcode/caffe2/fb/strobelight/function_profiler.py + cls.profiler = profiler_class( + sample_each=cls.sample_each, + max_profile_duration_sec=cls.max_profile_time, + stack_max_len=cls.max_stack_length, + async_stack_max_len=cls.max_stack_length, + run_user_name="pt2-profiler/" + + os.environ.get("USER", os.environ.get("USERNAME", "")), + sample_tags={cls.identifier}, + ) + + @classmethod + def _cls_init(cls) -> None: + cls.identifier = "{date}{pid}{hostname}".format( + date=datetime.now().strftime("%Y-%m-%d-%H:%M:%S"), + pid=os.getpid(), + hostname=gethostname(), + ) + + logger.info("Unique sample tag for this run is: %s", cls.identifier) + logger.info( + "URL to access the strobelight profile at the end of the run: %s", + get_strobelight_url(cls.identifier), + ) + + @classmethod + def _log_stats(cls) -> None: + logger.info( + "%s strobelight success runs out of %s non-recursive compilation events.", + cls.success_profile_count, + cls.success_profile_count + cls.failed_profile_count, + ) + + # TODO use threadlevel meta data to tags to record phases. 
+ @classmethod + def profile_compile_time( + cls, func: Any, phase_name: str, *args: Any, **kwargs: Any + ) -> Any: + def skip() -> Any: + return func(*args, **kwargs) + + if not cls.enabled: + return skip() + + if cls.profiler is None: + logger.error("profiler is not set") + return + + frame_id = cls.get_frame() + + if cls.inside_profile_compile_time: + cls.ignored_profile_runs += 1 + logger.info( + "profile_compile_time is requested for phase: %s, frame %s, while already in running phase: %s," + "frame %s, recursive call ignored", + phase_name, + frame_id, + cls.current_phase, + frame_id, + ) + return skip() + + if cls.frame_id_filter is not None: + should_run = re.match(cls.frame_id_filter, frame_id) is not None + if not should_run: + logger.info( + "profiling frame %s is skipped due to frame_id_filter %s", + frame_id, + cls.frame_id_filter, + ) + return skip() + + cls.inside_profile_compile_time = True + cls.current_phase = phase_name + logger.info("profiling frame %s", frame_id) + work_result = cls.profiler.profile(func, *args, **kwargs) + + if cls.profiler.profile_result is not None: + cls.success_profile_count += 1 + else: + cls.failed_profile_count += 1 + + cls._log_stats() + cls.inside_profile_compile_time = False + return work_result diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__init__.py b/phivenv/Lib/site-packages/torch/_subclasses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21ec5c02e3679fda228ffc300a746c0626d76a28 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/__init__.py @@ -0,0 +1,17 @@ +import torch +from torch._subclasses.fake_tensor import ( + DynamicOutputShapeException, + FakeTensor, + FakeTensorMode, + UnsupportedFakeTensorException, +) +from torch._subclasses.fake_utils import CrossRefFakeMode + + +__all__ = [ + "FakeTensor", + "FakeTensorMode", + "UnsupportedFakeTensorException", + "DynamicOutputShapeException", + "CrossRefFakeMode", +] diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baa6863152ed3ac3215c2710ad43335fb4318e96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7384a1e3315ff8090d175b46bfdea88396a6935 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/_fake_tensor_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec8be578eded56e8ec4667a6371da9b228061fd1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_impls.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6292282e5a5f285b7bd6ec8748881ccf987aeae1 Binary files /dev/null and 
b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4351a3cab3f861510d04b6de6b9e1f76d800d0b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/fake_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69cc24afb91455e5174a43cf575a8a0e2a7c3437 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/functional_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0393b2a5ab611eeb1e6f1d3c2d2b0977ce15dfc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/meta_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba3cda2a3a14cad7b37eaf71109faed2d736b160 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_subclasses/__pycache__/schema_check_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_subclasses/_fake_tensor_utils.py b/phivenv/Lib/site-packages/torch/_subclasses/_fake_tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1e5ea46d47cbd156ecc0eb3ab06ae2c66892f6b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/_fake_tensor_utils.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch import SymInt +from torch.fx.experimental.sym_node import SymNode +from torch.types import py_sym_types, PySymType +from torch.utils._backport_slots import dataclass_slots + + +if TYPE_CHECKING: + import sympy + + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + from .fake_tensor import _DispatchCacheKey, _MetadataIntLike + + +@dataclass_slots +@dataclass(frozen=True) +class _DeconstructedSymNode: + """ + Represents a SymNode without the associated ShapeEnv + """ + + # n.b. 
keep the same protocol as SymNode + _expr: sympy.Expr + pytype: type + _hint: Optional[Union[int, float, bool]] + constant: Optional[Union[int, float, bool]] + fx_node: torch.fx.Node + + @staticmethod + def from_node(node: SymNode) -> _DeconstructedSymNode: + return _DeconstructedSymNode( + node._expr, node.pytype, node._hint, node.constant, node.fx_node + ) + + def extract(self, shape_env: ShapeEnv) -> SymNode: + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def __str__(self) -> str: + return str(self._expr) + + def __repr__(self) -> str: + return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}" + + def __eq__(self, other: object) -> bool: + raise NotImplementedError + + def __hash__(self) -> int: + raise NotImplementedError + + # _value_eq to match SymNode + def _value_eq(self, other: object) -> bool: + if isinstance(other, (SymNode, _DeconstructedSymNode)): + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + else: + return False + + # _value_hash to match SymNode + def _value_hash(self) -> int: + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + +@dataclass_slots +@dataclass(frozen=True) +class _DeconstructedSymType: + """ + Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv + """ + + ty: type[PySymType] + node: _DeconstructedSymNode + + @staticmethod + def from_sym_type(value: PySymType) -> _DeconstructedSymType: + return _DeconstructedSymType(type(value), value.node) + + def extract(self, shape_env: ShapeEnv) -> PySymType: + return self.ty(self.node.extract(shape_env)) + + def __str__(self) -> str: + return f"{self.ty}({self.node})" + + def __repr__(self) -> str: + return f"_DeconstructedSymType({self.ty}, {self.node!r})" + + def __eq__(self, other: object) -> bool: + return NotImplemented + + def __hash__(self) -> int: + return NotImplemented + + +@dataclass_slots +@dataclass(frozen=True) +class _InputBackref: + value: int + + +@dataclass_slots +@dataclass +class _PySymInputStub: + """ + Represents a SymInt in the cached key. Needed because SymInt doesn't + support __eq__ or __hash__ directly. + """ + + # value can be: + # PySymType: This is the 'normal' SymInt value, wrapped so we can use + # hash/eq as value hash/eq (normally SymInt does object + # hash/eq). + # _DeconstructedSymType: This is used when storing the _PySymInputStub in + # the cache to avoid cyclic ShapeEnv references. + # _InputBackref: This is a back-reference to a previous _PySymInputStub in + # the key. + value: Union[PySymType, _DeconstructedSymType, _InputBackref] + + def __init__( + self, value: Union[PySymType, _DeconstructedSymType, _InputBackref] + ) -> None: + # For inputs (values in the `key`) we need to keep the PySymType intact + # - this way if we need to reuse it as an output we can properly copy + # the original value. + self.value = value + + def strip_shape_env(self) -> None: + if isinstance(self.value, py_sym_types): + self.value = _DeconstructedSymType.from_sym_type(self.value) + + def extract(self, shape_env: ShapeEnv) -> PySymType: + if isinstance(self.value, _DeconstructedSymType): + return self.value.extract(shape_env) + else: + # We should never see an _InputBackref here - anyone extracting a + # value should be pulling from the original entry (the one this + # backref points at). 
+ assert not isinstance(self.value, _InputBackref) + return self.value + + def __str__(self) -> str: + return str(self.value) + + def __repr__(self) -> str: + return f"_PySymInputStub({self.value!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _PySymInputStub): + return False + elif isinstance(self.value, _InputBackref) or isinstance( + other.value, _InputBackref + ): + return self.value == other.value + else: + return self.value.node._value_eq(other.value.node) + + def __hash__(self) -> int: + if isinstance(self.value, _InputBackref): + return hash(self.value) + else: + return self.value.node._value_hash() + + +@dataclass_slots +@dataclass +class _SymIntOutputStub: + """ + Represents a SymInt in the cached output. + """ + + # This is either an `int` which represents the index in the key to copy the + # SymNode from or it's the deconstructed SymNode itself. + value: Union[int, _DeconstructedSymNode] + + def __init__(self, value: SymInt, key_path: Optional[int]) -> None: + if key_path is None: + self.value = _DeconstructedSymNode.from_node(value.node) + else: + self.value = key_path + + def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt: + if isinstance(self.value, _DeconstructedSymNode): + return SymInt(self.value.extract(shape_env)) + else: + src = key.key[self.value] + assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt) + return src.value + + def __repr__(self) -> str: + return f"_SymIntOutputStub({self.value!r})" + + def __eq__(self, other: object) -> bool: + raise NotImplementedError + + def __hash__(self) -> int: + raise NotImplementedError + + +@dataclass_slots +@dataclass +class _CacheKeyState: + """ + State used while building our cache key. + """ + + # We track the SymNodes so when we get the output we can see if it exactly + # matches one of the inputs so we can uncache it properly. + sym_node_lookup: dict[int, int] # id(SymNode) -> index + + # This is a list of all seen input sympy.Symbols. We use it when building + # the cache entry to see if the output value has any symbols that we didn't + # see on input. See _has_unrepresented_symbols(). + known_symbols: set[sympy.Symbol] + + # There are cases where we're asked to perform an op when we have no + # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a + # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it + # here. + shape_env: Optional[ShapeEnv] + + def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None: + self.sym_node_lookup = {} + self.known_symbols = set() + self.shape_env = shape_env + + def cache_on_shape_env(self) -> bool: + """ + Returns true if the CacheKey needs to be cached on the ShapeEnv + rather than the global cache. + + If our inputs contain a SymNode then we can't cache this operation on + the global cache because the cached output will implicitly depend on + guard values which might not be true on some other ShapeEnv. So unless + we're also going to cache the guards we need to cache this operation on + the ShapeEnv instead of globally. 
+ """ + return bool(self.sym_node_lookup) + + def convert_sym_int(self, result: list[object], arg: SymInt) -> None: + node_id = id(arg.node) + if node_id in self.sym_node_lookup: + result.append(_InputBackref(self.sym_node_lookup[node_id])) + else: + self.sym_node_lookup[node_id] = len(result) + self.known_symbols.update(arg.node.expr.free_symbols) + if self.shape_env is None: + self.shape_env = arg.node.shape_env + result.append(_PySymInputStub(arg)) + + def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike: + if isinstance(arg, SymInt): + return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None)) + else: + return arg diff --git a/phivenv/Lib/site-packages/torch/_subclasses/fake_impls.py b/phivenv/Lib/site-packages/torch/_subclasses/fake_impls.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb1d2d96ff11028f0e551f7a664f807ac1f01f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/fake_impls.py @@ -0,0 +1,1102 @@ +# mypy: ignore-errors + +import functools +import itertools +import math +import sys +from typing import Callable, Union + +import torch +import torch._custom_op +import torch._logging +from torch._dispatch.python import no_python_dispatcher +from torch._ops import OpOverload +from torch._prims_common import ( + definitely_contiguous_for_memory_format, + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + is_boolean_dtype, + is_float_dtype, + is_integer_dtype, +) +from torch._subclasses.fake_tensor import ( + DataDependentOutputException, + DynamicOutputShapeException, + FakeTensor, + in_kernel_invocation_manager, + run_fallback_kernel, + UnsupportedOperatorException, +) +from torch.fx.operator_schemas import normalize_function +from torch.utils._stats import count_label + + +pytree = torch.utils._pytree + +__all__ = [ + "op_implementations_checks", + "get_fast_op_impls", + "stride_incorrect_op", + "has_meta", +] + +op_implementations_dict = {} +op_implementations_checks = [] + + +aten = torch._ops.ops.aten + + +def ordered_set(*items): + return dict.fromkeys(items, True) + + +# This function indicates if the backend device +# supports non-contiguous tensors +def is_noncontiguous_supported(device): + return device.type != "hpu" + + +_like_tensor_constructors = ordered_set( + aten.empty_like.default, + aten.empty_like.out, + aten.full_like.default, + aten.full_like.out, + aten.ones_like.default, + aten.ones_like.out, + aten.rand_like.default, + aten.rand_like.out, + aten.randn_like.default, + aten.randn_like.out, + aten.randint_like.default, + aten.randint_like.Tensor, + aten.randint_like.Tensor_out, + aten.randint_like.out, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.zeros_like.default, + aten.zeros_like.out, + aten.new_empty.default, + aten.new_empty.out, + aten.new_empty_strided.default, + aten.new_empty_strided.out, + aten.new_full.default, + aten.new_full.out, + aten.new_zeros.default, + aten.new_zeros.out, + aten.new_ones.default, + aten.new_ones.out, +) + + +_device_not_kwarg_ops = ordered_set( + aten._resize_output_.default, + aten._nested_tensor_from_tensor_list.default, + aten._nested_tensor_from_tensor_list.out, + aten.pin_memory.default, + aten.to.device, + aten.to.prim_Device, + aten.is_pinned.default, + aten._pin_memory.default, + aten._pin_memory.out, + aten._resize_output.default, + aten._resize_output.out, +) + +# this op is never actually used +_non_kwarg_device_constructors = (aten._list_to_tensor,) + + +def contains_tensor_types(type): + tensor_type = 
torch._C.TensorType.get() + return type.isSubtypeOf(tensor_type) or any( + contains_tensor_types(e) for e in type.containedTypes() + ) + + +@functools.cache +def _is_tensor_constructor(func: OpOverload): + assert isinstance(func, OpOverload) + schema = func._schema + if any(contains_tensor_types(arg.type) for arg in schema.arguments): + return False + # TODO: no real reason to restrict multiple outputs + return ( + len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() + ) + + +def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]): + def impl_decorator(op_impl): + if isinstance(run_impl_check, OpOverload): + assert ( + run_impl_check not in op_implementations_dict + ), f"duplicate registration: {run_impl_check}" + op_implementations_dict[run_impl_check] = op_impl + elif isinstance(run_impl_check, (list, tuple)): + for op in run_impl_check: + register_op_impl(op)(op_impl) + else: + assert callable(run_impl_check) + op_implementations_checks.append((run_impl_check, op_impl)) + + return op_impl + + return impl_decorator + + +def _is_op_registered_to_fake_rule(op): + return op in op_implementations_dict + + +def _deregister_op_impl(op): + if op in op_implementations_dict: + del op_implementations_dict[op] + for check, impl in op_implementations_checks: + if check is op: + op_implementations_checks.remove((check, impl)) + break + + +@register_op_impl(op_implementations_dict.__contains__) +def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs): + return op_implementations_dict[func](fake_mode, func, *args, **kwargs) + + +@register_op_impl(_is_tensor_constructor) +@register_op_impl([*_like_tensor_constructors]) +def constructors(fake_mode, func, *args, **kwargs): + assert func not in _non_kwarg_device_constructors + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + if "names" in kwargs: + raise UnsupportedOperatorException( + "torch.compile doesn't support named tensors" + ) + + if func in _like_tensor_constructors: + default_device = new_kwargs["input"].device + # TODO: file issue + args = (new_kwargs.pop("input"),) + else: + # cpu is default device if none is specified + default_device = torch.device("cpu") + args = () + out_device = new_kwargs.pop("device", None) + out_device = out_device if out_device is not None else default_device + new_kwargs["device"] = torch.device("meta") + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? hmmm) + with in_kernel_invocation_manager(fake_mode): + r = func(*args, **new_kwargs) + return FakeTensor(fake_mode, r, out_device) + + +@register_op_impl(aten.is_pinned.default) +def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ) + inp = new_kwargs.pop("input") + # we'll ignore device argument because it is deprecated and not + # actually used by is_pinned. 
+ with in_kernel_invocation_manager(fake_mode): + r = func(inp) + return r + + +@register_op_impl(aten.to.prim_Device) +@register_op_impl(aten.to.device) +def non_kwarg_to(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ) + input_device = new_kwargs["device"] + out_device = input_device if input_device else new_kwargs["input"].device + new_kwargs["device"] = torch.device("meta") + inp = new_kwargs.pop("input") + with in_kernel_invocation_manager(fake_mode): + r = func(inp, **new_kwargs) + # TODO: I think this does the wrong thing if r is inp + return fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, r, out_device + ) + + +def stride_incorrect_op(op): + return False + + +# These operators have meta implementations with incorrect strides +@register_op_impl(stride_incorrect_op) +def wordaround_stride_incorrect_op(fake_mode, func, *args, **kwargs): + # This is a workaround for meta implmentations with incorrect strides + + def is_symbolic(x): + if isinstance(x, FakeTensor): + return x._has_symbolic_sizes_strides + if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)): + return True + return False + + # For static shapes, we can fall back to eager for the real strides + if fake_mode.allow_fallback_kernels: + require_dynamic = any( + is_symbolic(x) for x in itertools.chain(args, kwargs.values()) + ) + if not require_dynamic: + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None) + + raise UnsupportedOperatorException(func) + + +# Dont default to default device handling, +# since the device of `the_template` is ignored +@register_op_impl(aten.resize_as_.default) +def resize_as_(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + return func(*args, **kwargs) + + +@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default) +def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs): + # TODO: remove me + return constructors(fake_mode, func, *args, **kwargs) + + +# index.Tensor data-dependent in only some conditions +@register_op_impl( + lambda func: torch.Tag.dynamic_output_shape in func.tags + and func + not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor] +) +def dyn_shape(fake_mode, func, *args, **kwargs): + raise DynamicOutputShapeException(func) + + +def _unique( + fake_mode, + func, + arg, + dim, + sorted=True, + return_inverse=False, + return_counts=False, + *, + unique_consecutive=False, +): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo + + # Do not use a memo for unique_dim + if dim is not None or nnz is None: + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(arg.numel()) and arg.numel() == 0: + # If numel is zero, then the output size must be zero. + # In this case, we must not allocate an unbacked SymInt, + # because if we do, it will immediately get refined to + # zero, but this will be inconsistent with size oblivious + # tests (which will continue to claim that the unbacked + # symint cannot equal zero). 
We could also unconditionally + # allocate an unbacked SymInt and not refine its range, + # but this seems more precise. + nnz = 0 + else: + nnz = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + numel = arg.numel() if dim is None else arg.size(dim) + if not has_free_symbols(numel): + maxval = int(numel) + + _constrain_range_for_size(nnz, max=maxval) + + if dim is None: + if unique_consecutive: + arg.unique_consecutive_memo = nnz + else: + arg.unique_memo = nnz + + if dim is None: + ret = [arg.new_empty((nnz,))] + else: + ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])] + + return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") + if return_inverse or return_if_dim_and_cpu: + inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) + else: + inverse = arg.new_empty(0) + ret.append(inverse) + + if return_counts or return_if_dim_and_cpu: + counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) + else: + counts = arg.new_empty(0) + ret.append(counts) + + return tuple(ret) + + +@register_op_impl(aten._unique2.default) +def unique2( + fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False +): + return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts) + + +@register_op_impl(aten.unique_dim.default) +def unique_dim( + fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False +): + return _unique( + fake_mode, + func, + arg, + # normalize dim to be non-negative + dim if dim >= 0 else dim % max(arg.ndim, 1), + sorted, + return_inverse, + return_counts, + ) + + +@register_op_impl(aten.unique_consecutive.default) +def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None): + return _unique( + fake_mode, + func, + arg, + dim, + False, + return_inverse, + return_counts, + unique_consecutive=True, + ) + + +@register_op_impl(aten.repeat_interleave.Tensor) +def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): + if output_size is None: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + raise DynamicOutputShapeException(func) + + output_size = fake_mode.shape_env.create_unbacked_symint() + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(output_size) + # TODO: consider a memo + return repeats.new_empty(output_size) + + +@register_op_impl(torch.ops.aten.item.default) +@register_op_impl(torch.ops.aten._local_scalar_dense.default) +def local_scalar_dense(fake_mode, func, arg): + if (r := arg.item_memo) is not None: + return r + if fake_mode.shape_env is None or ( + not fake_mode.shape_env.allow_scalar_outputs + and not fake_mode.allow_scalar_outputs + ): + # Without symints/symfloats, cannot handle this + raise DataDependentOutputException(func) + if is_float_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symfloat() + elif is_integer_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symint() + elif is_boolean_dtype(arg.dtype): + r = fake_mode.shape_env.create_unbacked_symbool() + else: + raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}") + arg.item_memo = r + return r + + +@register_op_impl(torch.ops.aten.nonzero_numpy.default) +def nonzero_numpy(fake_mode, func, arg): + return torch.ops.aten.nonzero.default(arg).unbind(1) + + +@register_op_impl(torch.ops.aten.nonzero.default) +def 
nonzero(fake_mode, func, arg): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + if (nnz := arg.nonzero_memo) is None: + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + + if not has_free_symbols(arg.numel()) and arg.numel() == 0: + # If numel is zero, then the output size must be zero. + # In this case, we must not allocate an unbacked SymInt, + # because if we do, it will immediately get refined to + # zero, but this will be inconsistent with size oblivious + # tests (which will continue to claim that the unbacked + # symint cannot equal zero). We could also unconditionally + # allocate an unbacked SymInt and not refine its range, + # but this seems more precise. + nnz = 0 + else: + nnz = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + if not has_free_symbols(arg.numel()): + maxval = int(arg.numel()) + else: + prod_node = math.prod(arg.shape).node + prod_range = bound_sympy( + prod_node.expr, prod_node.shape_env.var_to_range + ) + if isinstance(prod_range.upper, IntInfinity): + maxval = sys.maxsize - 1 + else: + maxval = prod_range.upper + + _constrain_range_for_size(nnz, max=maxval) + + arg.nonzero_memo = nnz + + return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64) + + +@register_op_impl(torch.ops.aten._padded_dense_to_jagged_forward.default) +def _padded_dense_to_jagged_forward(fake_mode, func, padded, offsets, total_L=None): + # only one jagged dim is supported for now + assert len(offsets) == 1 + + if not total_L: + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + total_L = fake_mode.shape_env.create_unbacked_symint() + + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + + if not has_free_symbols(padded.numel()): + maxval = int(padded.numel()) + + _constrain_range_for_size(total_L, min=0, max=maxval) + + output_shape = (total_L, *padded.shape[2:]) + return padded.new_empty(output_shape) + + +@register_op_impl(torch.ops.aten.masked_select.default) +def masked_select(fake_mode, func, self, mask): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + nnz = fake_mode.shape_env.create_unbacked_symint() + + # see nonzero for commentary + maxval = sys.maxsize - 1 + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import ( + _constrain_range_for_size, + has_free_symbols, + ) + from torch.utils._sympy.numbers import IntInfinity + from torch.utils._sympy.value_ranges import bound_sympy + + # If num elements is expressed symbolically, calculate + # the concrete value based on upper bounds. Otherwise, + # we can set max val directly. 
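+    # Illustrative example (assumed shapes, not taken from any test): for a
+    # `self` of shape (s0, 4) whose ShapeEnv bounds s0 by [2, 64], the product
+    # bound 256 becomes the value fed to _constrain_range_for_size below; when
+    # the bound is infinite we keep the sys.maxsize - 1 default from above.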
+ if not has_free_symbols(self.numel()): + num_elements = int(self.numel()) + else: + prod_node = math.prod(self.shape).node + prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range) + if isinstance(prod_range.upper, IntInfinity): + num_elements = sys.maxsize - 1 + else: + num_elements = prod_range.upper + if num_elements > 2: + maxval = num_elements + + _constrain_range_for_size(nnz, max=maxval) + + return self.new_empty((nnz,)) + + +@register_op_impl(torch.ops.aten._assert_tensor_metadata.default) +def assert_tensor_metadata( + fake_mode, + func, + t, + sizes=None, + strides=None, + dtype=None, + *, + device=None, + layout=None, +) -> None: + if sizes is not None: + assert ( + t.size() == sizes + ), f"Tensor sizes mismatch! Expected: {sizes}, Got: {t.size()}" + if strides is not None: + assert ( + t.stride() == strides + ), f"Tensor strides mismatch! Expected: {strides}, Got: {t.stride()}" + if dtype is not None: + assert ( + t.dtype == dtype + ), f"Tensor dtype mismatch! Expected: {dtype}, Got: {t.dtype}" + if layout is not None: + assert ( + t.layout == layout + ), f"Tensor layout mismatch! Expected: {layout}, Got: {t.layout()}" + if device is not None: + assert ( + t.device == device + ), f"Tensor device mismatch! Expected: {device}, Got: {t.device}" + + +# NB: this must be ordered after local_scalar_dense +@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags) +def data_dep(fake_mode, func, *args, **kwargs): + raise DataDependentOutputException(func) + + +# Bool Indices get Expanded as Masks +# See: IndexingUtils.h:expandTensors +def check_no_bool_index_tensors(func, self, indices): + for index in indices: + if index is not None and index.dtype in (torch.bool, torch.uint8): + raise DynamicOutputShapeException(func) + + +def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + with in_kernel_invocation_manager(fake_mode): + out = func(*args, **kwargs) + if not is_noncontiguous_supported(out_device): + out = out.new_empty(out.shape) + + if out is new_kwargs["input"]: + return out # copy_ + return FakeTensor(fake_mode, out, out_device) + + +_is_builtin_namespaces = ordered_set("aten", "prims", "prim") + + +def is_builtin(op): + return op.namespace in _is_builtin_namespaces + + +def has_meta(func): + return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta") + + +# These are for the `torch._foreach_...` ops like `torch._foreach_add`. 
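+# Rough flow (sketch of the registration below): the meta kernel is run once
+# for the whole foreach op, and each output meta tensor is then re-wrapped as a
+# FakeTensor whose device is the common device of the i-th entry of every input
+# tensor list.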
+@register_op_impl( + lambda func: is_builtin(func) + and func.name().startswith("aten::_foreach_") + and has_meta(func) +) +def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs): + tensor_lists = [ + arg + for arg in itertools.chain(args, kwargs.values()) + if isinstance(arg, (list, tuple)) + and len(arg) + and isinstance(arg[0], torch.Tensor) + ] + + try: + with in_kernel_invocation_manager(fake_mode): + out_meta = func(*args, **kwargs) + except NotImplementedError: + return NotImplemented + + if not out_meta: + return out_meta + + assert tensor_lists + out_fake = [] + + for i, meta_t in enumerate(out_meta): + device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists]) + out_fake.append( + fake_mode.fake_tensor_converter.from_meta_and_device( + fake_mode, meta_t, device + ) + ) + + return out_fake + + +# Dont default to default device handling, +# Since op can take in non-zero sized cpu +# index tensors with cuda self +@register_op_impl(aten.index.Tensor) +def index_tensor(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_index_Tensor + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + out_device = new_kwargs["input"].device + # ensure nonzero call goes to fake tensor + with fake_mode: + out = meta_index_Tensor(*args, **kwargs) + return out.to(out_device) + + +# Can take mixed meta/non-meta arguments; the meta registration +# will roughly do the right thing even when given real devices +@register_op_impl(aten._embedding_bag.default) +def embedding_bag(fake_mode, func, *args, **kwargs): + from torch._meta_registrations import meta_embedding_bag + + with fake_mode: + return meta_embedding_bag(*args, **kwargs) + + +# takes in multiple-devices, dont default to default device handling +@register_op_impl(aten._unsafe_index_put.default) +@register_op_impl(aten.copy.default) +@register_op_impl(aten.copy_.default) +@register_op_impl(aten.slice_scatter.default) +def multi_device_op_default(fake_mode, func, *args, **kwargs): + return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + + +# same with multi_device_op_default, but return the input +@register_op_impl(aten.copy.out) +@register_op_impl(aten.slice_scatter.out) +def multi_device_op_out(fake_mode, func, *args, **kwargs): + with in_kernel_invocation_manager(fake_mode): + func(*args, **kwargs) + + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + return new_kwargs["input"] + + +@register_op_impl(aten.index_put.default) +@register_op_impl(aten.index_put_.default) +def index_put_impl(fake_mode, func, *args, **kwargs): + _, new_kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + values = new_kwargs["values"] + self_device = new_kwargs["input"].fake_device + torch._check( + self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1), + lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})", + ) + + out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs) + if func is aten.index_put_.default: + return new_kwargs["input"] + else: + return out + + +@register_op_impl(aten._nested_tensor_from_tensor_list.default) +@register_op_impl(aten._nested_tensor_from_tensor_list.out) +@register_op_impl(aten._nested_view_from_buffer.default) +@register_op_impl(aten._nested_view_from_buffer_copy.default) +def 
nested_tensors_unsupported(fake_mode, func, *args, **kwargs): + raise UnsupportedOperatorException( + "torch.compile does not support strided NestedTensor" + ) + + +@register_op_impl( + [ + x + for x in _device_not_kwarg_ops + if x + not in ( + # these are already registered elsewhere + aten.is_pinned.default, + aten.to.device, + aten.to.prim_Device, + aten._nested_tensor_from_tensor_list.default, + aten._nested_tensor_from_tensor_list.out, + ) + ] +) +def nyi(fake_mode, func, *args, **kwargs): + assert func not in _device_not_kwarg_ops, f"NYI: {func}" + + +@register_op_impl([aten.convolution.default, aten.convolution_backward.default]) +def conv(fake_mode, func, *args, **kwargs): + _, kwargs = normalize_function( + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + device = kwargs["input"].fake_device + # need to re-enable mode so the tensors report fake device + with fake_mode: + # if the input is unsqueezed is done in Convolution.cpp we get segfault + k = kwargs["weight"].ndim + batch = kwargs["input"].shape[0] + + # Avoid importing sympy at a module level + from torch.fx.experimental.symbolic_shapes import has_hint + + if not has_hint(batch): + # TODO: We can make this a little more faithful with best effort + # channels last detection (but only if it's statically obvious!) + mem_fmt = None + elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu: + mem_fmt = None + else: + if func is aten.convolution.default: + conv_backend = torch._C._select_conv_backend(**kwargs) + else: + conv_backend = torch._C._select_conv_backend( + kwargs["input"], + kwargs["weight"], + bias=None, + stride=kwargs["stride"], + padding=kwargs["padding"], + dilation=kwargs["dilation"], + transposed=kwargs["transposed"], + output_padding=kwargs["output_padding"], + groups=kwargs["groups"], + bias_sizes=kwargs["bias_sizes"], + ) + mem_fmt = torch._C._conv_determine_backend_memory_format( + kwargs["input"], kwargs["weight"], conv_backend + ) + + def convert(t, mem_fmt): + if t is None: + return t + if mem_fmt is not None: + t = t.to(memory_format=mem_fmt) + return FakeTensor(fake_mode, t, device) + + with in_kernel_invocation_manager(fake_mode): + out = func(**kwargs) + + if func is aten.convolution.default: + return convert(out, mem_fmt) + else: + return ( + convert(out[0], mem_fmt), + convert(out[1], mem_fmt), + convert(out[2], None), + ) + + +@register_op_impl(torch.ops.aten.bincount.default) +def bincount(fake_mode, func, inputs, weights=None, minlength=0): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + new_size = fake_mode.shape_env.create_unbacked_symint() + + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(new_size) + torch._check(new_size >= minlength) + return inputs.new_empty(new_size) + + +@register_op_impl(torch.ops.aten._pack_padded_sequence.default) +def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first): + if ( + fake_mode.shape_env is None + or not fake_mode.shape_env.allow_dynamic_output_shape_ops + ): + # Without symints/symfloats, cannot handle this + raise DynamicOutputShapeException(func) + + new_batch_size = fake_mode.shape_env.create_unbacked_symint() + + from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + + _constrain_range_for_size(new_batch_size) + + if not batch_first: + # Inputs should have shape 
(batch_size, seq_len, *) + inputs = inputs.transpose(0, 1) + + res_size = inputs.shape[1:] + packed_data = inputs.new_empty(res_size) + batch_size = inputs.new_empty((new_batch_size,)) + return (packed_data, batch_size) + + +FAST_OP_IMPLEMENTATIONS = {} + + +# Unlike register_op_impl, these don't do the slow iteration for +# run_impl_check, and these run BEFORE decompositions +def register_fast_op_impl(func: OpOverload): + def impl_decorator(op_impl): + FAST_OP_IMPLEMENTATIONS[func] = op_impl + return op_impl + + return impl_decorator + + +# infer_size_impl in ExpandUtils +def infer_size(a, b): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + dimsA = len(a) + dimsB = len(b) + ndim = max(dimsA, dimsB) + expandedSizes = [0] * ndim + for i in range(ndim - 1, -1, -1): + offset = ndim - 1 - i + dimA = dimsA - 1 - offset + dimB = dimsB - 1 - offset + sizeA = a[dimA] if dimA >= 0 else 1 + sizeB = b[dimB] if dimB >= 0 else 1 + + # NB: It is very important to test for broadcasting, before testing + # sizeA == sizeB. This is because the broadcasting tests are likely + # to be statically known (in particular, if sizeA/sizeB is unbacked + # but size-like, we will unsoundly assume they never equal 1), but + # the sizeA == sizeB test may not be statically known. However, once + # we have established that no broadcasting is happening, the + # sizeA == sizeB is now expect_true and we can defer it as a runtime + # assert (this works because Python will return the terminal + # expression of an or statement as-is, without bool()'ing it; if this + # were not the case, we'd need to write this using torch.sym_or() or + # something like that). + torch._check( + guard_or_false(sizeA == 1) or guard_or_false(sizeB == 1) or sizeA == sizeB, + lambda: f"The size of tensor a ({sizeA}) " + f"must match the size of tensor b ({sizeB}) " + f"at non-singleton dimension {i})", + ) + expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA + return tuple(expandedSizes) + + +def make_fast_binary_impl( + slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +): + def fast_binary_impl(mode, *args, **kwargs): + def slow(msg): + count_label(f"slow {msg}") + with mode: + return slow_ref(*args, **kwargs) + + count_label("attempt fast") + + # Fast path (based off of TensorIterator fast path). + # Unfortunately, there is no way to easily deduplicate + # this with either the TensorIterator C++ implementation + # (which we don't want to SymIntify, and also the algorithm + # here is slightly different from TensorIterator to allow + # for broadcasting), nor the PrimTorch implementation + # (which does not actually implement a fast path.) + + operands = args + + # compute_shape + final_shape = None + for op in operands: + shape = op.shape if isinstance(op, torch.Tensor) else () + if final_shape is None: + final_shape = shape + # TODO: Minor optimization: track if the shapes + # were equal so you can skip the equality check + # below if unnecessary + final_shape = infer_size(final_shape, shape) + assert final_shape is not None + + from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_eq + + # Do some extra safety checks to see if the output + # stride is obvious + for op in operands: + if ( + isinstance(op, torch.Tensor) + and len(op.shape) == len(final_shape) + # take the slow path if result is not determined. + and guard_or_false(sym_eq(op.shape, final_shape)) + ): + break + else: + # if we never break in the for loop above we take the slow path. 
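+            # (this relies on Python's for/else: the `else` branch only runs
+            # when the loop above finished without hitting `break`)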
+ return slow("both tensors nontrivially broadcast") + + # compute_types + cpu = torch.device("cpu") + common_device = cpu + common_dtype = None + has_different_input_dtypes = False + for op in operands: + if not isinstance(op, torch.Tensor): + # Use elementwise_dtypes for the tricky case + has_different_input_dtypes = True + continue + if common_device == cpu and not op.device.type == "cpu": + common_device = op.device + # Slightly simplified here as target_dtype cannot vary + if common_dtype is None: + common_dtype = op.dtype + elif common_dtype != op.dtype: + has_different_input_dtypes = True + + if has_different_input_dtypes: + # compute promotion + # TODO: we don't need the compute type + _, common_dtype = elementwise_dtypes( + *operands, type_promotion_kind=type_promotion_kind + ) + + # check all tensors on same device + # cpu scalars are assumed allow + current_cpu_scalars_on_non_cpu = 0 + max_cpu_scalars_on_non_cpu = 1 # hard coded atm + for op in operands: + if not isinstance(op, torch.Tensor): + continue + if common_device != cpu and op.dim() == 0 and op.device == cpu: + if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu: + return slow("error") + current_cpu_scalars_on_non_cpu += 1 + elif op.device != common_device: + return slow("error") + + # compute_fast_setup_type + definitely_contiguous = True + definitely_channels_last = True + # TODO: is_non-overlapping_and_dense (not bound from Python + # no inplace, no out, everything defined + + if is_noncontiguous_supported(common_device): + for op in operands: + if not isinstance(op, torch.Tensor): + continue + definitely_contiguous = ( + definitely_contiguous + and definitely_contiguous_for_memory_format( + op, memory_format=torch.contiguous_format + ) + ) + definitely_channels_last = ( + definitely_channels_last + and definitely_contiguous_for_memory_format( + op, memory_format=torch.channels_last + ) + ) + if definitely_contiguous: + # do contiguous + count_label("fast is_contiguous") + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.contiguous_format, + ), + device=common_device, + ) + if definitely_channels_last: + count_label("fast channels_last") + # do channels last + return FakeTensor( + mode, + torch.empty( + final_shape, + dtype=common_dtype, + device="meta", + memory_format=torch.channels_last, + ), + device=common_device, + ) + + return slow("no contiguity match") + + return fast_binary_impl + + +# disable the python dispatcher to avoid decomposing detach() further +# (proxy_mode should still decompose detach() though) +def fast_detach(fake_mode, x, include_real=False): + with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode): + out = torch.ops.aten.detach.default(x) + if include_real: + return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor) + return FakeTensor(fake_mode, out, x.device) + + +@functools.cache +def get_fast_op_impls(): + import torch._refs + + register_fast_op_impl(torch.ops.aten.add.Tensor)( + make_fast_binary_impl(torch._refs.add) + ) + register_fast_op_impl(torch.ops.aten.sub.Tensor)( + make_fast_binary_impl(torch._refs.sub) + ) + register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] + register_fast_op_impl(torch.ops.aten.div.Tensor)( + make_fast_binary_impl( + torch._refs.div, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) + ) + register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach) + return 
FAST_OP_IMPLEMENTATIONS diff --git a/phivenv/Lib/site-packages/torch/_subclasses/fake_tensor.py b/phivenv/Lib/site-packages/torch/_subclasses/fake_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..8b5f89be9d180cb08e144d6f370f308143c15c11 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/fake_tensor.py @@ -0,0 +1,3258 @@ +# mypy: allow-untyped-decorators +from __future__ import annotations + +import atexit +import contextlib +import dataclasses +import functools +import logging +import math +import os +import threading +import traceback +import types +import typing +import weakref +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Callable, cast, Literal, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self, TypeGuard +from weakref import ReferenceType + +import torch +import torch._library.utils as library_utils +from torch import SymBool, SymFloat, SymInt, Tensor +from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.fake_profile import MissingOpProfile +from torch._logging import dtrace_structured +from torch._prims_common import suggest_memory_format +from torch._subclasses.meta_utils import ( + assert_eq, + assert_metadata_eq, + is_sparse_any, + is_sparse_compressed, + MetaConverter, +) +from torch._utils import render_call +from torch.fx.immutable_collections import immutable_dict +from torch.fx.operator_schemas import normalize_function +from torch.multiprocessing.reductions import StorageWeakRef +from torch.overrides import TorchFunctionMode +from torch.types import IntLikeType, py_sym_types +from torch.utils._backport_slots import dataclass_slots +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) +from torch.utils._pytree import KeyPath, keystr, PyTree, tree_map, tree_map_, TreeSpec +from torch.utils._stats import count +from torch.utils._traceback import CapturedTraceback + +from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub + + +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Mapping, Sequence + from types import TracebackType + + from torch._guards import Source + from torch._ops import OpOverload + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + +# TODO: Hack to unblock https://github.com/pytorch/pytorch/pull/108186 +# Proper fix tracked by https://github.com/pytorch/pytorch/issues/120105 +try: + not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") +except ValueError as e: + if "'not_implemented' not registered" in str(e): + not_implemented_log = logging.getLogger(__name__ + ".not_implemented") + else: + raise e + + +DimList = list + +pytree = torch.utils._pytree +T = TypeVar("T") + +aten = torch._ops.ops.aten + +CONSTANT_NUMEL_LIMIT = 1 + +RECURSION_COUNT = 0 + + +# Small helper that increments recursion count, and +# resets it when the object goes out of scope. Useful +# if you don't want to increase indentation which is +# what a context manager would do. 
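+# Illustrative usage (hypothetical, not from this file): a recursive helper can
+# hold `_token = IncrementRecursionCount()` at the top of its body, and the
+# global RECURSION_COUNT drops back down when `_token` is collected as the call
+# returns.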
+class IncrementRecursionCount: + def __init__(self) -> None: + global RECURSION_COUNT + RECURSION_COUNT += 1 + + def __del__(self) -> None: + global RECURSION_COUNT + RECURSION_COUNT -= 1 + + +@dataclass +class UnsupportedFakeTensorException(RuntimeError): + reason: str + + +@dataclass +class DynamicOutputShapeException(RuntimeError): + func: OpOverload + + +@dataclass +class DataDependentOutputException(RuntimeError): + func: OpOverload + + +@dataclass +class UnsupportedOperatorException(RuntimeError): + func: OpOverload + + +@dataclass +class UnsupportedMutationAliasingException(RuntimeError): + reason: str + + +@dataclass +class MetadataMismatchError(RuntimeError): + reason: str + + +class FakeTensorTLS(threading.local): + # Default to None, otherwise it'll be used to override _all_ + # `FakeTensorMode.allow_non_fake_inputs` in this thread. + allow_non_fake_inputs_override: Optional[bool] + + def __init__(self) -> None: + self.allow_non_fake_inputs_override = None + + +fake_tensor_tls = FakeTensorTLS() + + +def ordered_set(*items: T) -> dict[T, Literal[True]]: + return dict.fromkeys(items, True) + + +@contextlib.contextmanager +def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, None]: + old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + try: + yield old + finally: + if old is not None: + torch._C._set_dispatch_mode(old) + + +@contextlib.contextmanager +def disable_fake_tensor_cache(fake_mode: FakeTensorMode) -> Generator[None, None, None]: + old_value: bool = fake_mode.cache_enabled + try: + fake_mode.cache_enabled = False + yield + finally: + fake_mode.cache_enabled = old_value + + +def get_plain_tensors( + subclass: Tensor, *, out: list[Union[Tensor, int, SymInt]] +) -> list[Union[Tensor, int, SymInt]]: + # This function is used in Runtime, do not add redundant asserts + todo = [subclass] + while todo: + curr = todo.pop() + if not is_traceable_wrapper_subclass(curr): + out.append(curr) + continue + + inner_keys, _ = curr.__tensor_flatten__() + todo.extend(getattr(curr, key) for key in reversed(inner_keys)) + + return out + + +def is_fake(x: object) -> TypeGuard[Tensor]: + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(x, FakeTensor): + return True + if is_traceable_wrapper_subclass(x): + attrs, _ = type(x).__tensor_flatten__(x) + flattened_tensors = [getattr(x, attr) for attr in attrs] + all_fake = all(is_fake(x) for x in flattened_tensors) + any_fake = any(is_fake(x) for x in flattened_tensors) + assert all_fake == any_fake, "got mixed fake and real tensors!" 
+ return all_fake + elif isinstance(x, FunctionalTensor): + return is_fake(x.elem) + elif isinstance(x, Tensor) and torch._is_functional_tensor(x): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(x, reapply_views) + return is_fake(unwrapped) + elif isinstance(x, Tensor) and is_functorch_wrapped_tensor(x): + unwrapped = torch._C._functorch.get_unwrapped(x) + return is_fake(unwrapped) + return False + + +def maybe_get_fake_mode(t: object) -> Optional[FakeTensorMode]: + from torch._subclasses.functional_tensor import FunctionalTensor + + if isinstance(t, FakeTensor): + return t.fake_mode + if is_traceable_wrapper_subclass(t): + inner_tensor_names, _ = t.__tensor_flatten__() + modes = [ + maybe_get_fake_mode(getattr(t, t_name)) for t_name in inner_tensor_names + ] + m = modes[0] + assert all(m is x for x in modes) + return m + elif isinstance(t, FunctionalTensor): + return maybe_get_fake_mode(t.elem) + elif isinstance(t, Tensor) and torch._is_functional_tensor(t): + reapply_views = torch._C._functionalization_reapply_views_tls() + unwrapped = torch._C._functorch._unwrap_functional_tensor(t, reapply_views) + return maybe_get_fake_mode(unwrapped) + elif isinstance(t, Tensor) and is_functorch_wrapped_tensor(t): + unwrapped = torch._C._functorch.get_unwrapped(t) + return maybe_get_fake_mode(unwrapped) + return None + + +@functools.cache +def get_schema_info(func: OpOverload) -> torch._C._SchemaInfo: + return torch._C._SchemaInfo(func._schema) + + +# many of the decompositions registered to torch/_prims do not at the moment model +# aliasing or strides, so as an incremental step, just enable the decompositions in +# torch/_decomp/decompositions.py. +# decomps are used for aot autograd tracing so we would like to unify on their +# implementation and add additional testing to them +@functools.cache +def torch_decomp_decompositions(func: OpOverload) -> bool: + from torch._decomp import decomposition_table + + decompositions = torch._decomp.decompositions + # Note that the function in the decomposition table might be + # different from the one in the module because of the difference + # in out handling in aten API and torch public API + return decomposition_table[func].__module__.startswith( + "torch._decomp" + ) and decomposition_table[func].__name__ in dir(decompositions) + + +def tree_flatten_only(ty: type[T], tree: PyTree) -> list[T]: + flat_vals = pytree.tree_leaves(tree) + return [elem for elem in flat_vals if isinstance(elem, ty)] + + +def _is_plain_tensor(t: object) -> bool: + return ( + type(t) is Tensor + and t.layout == torch.strided + and not ( + t.is_sparse + or t.is_nested + or is_functorch_wrapped_tensor(t) + or is_legacy_batchedtensor(t) + or torch._is_functional_tensor(t) + ) + ) + + +# Similar to `MetaConverter`, this is a class for converting +# multiple tensors into fake tensors which share the same view/storage +# structure. Like `MetaConverter`, it uses `WeakIdRef` to +# hold a weak reference for all memoized tensors. 
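+# Sketch of the aliasing guarantee (assumed example): converting a tensor and a
+# view of it through the same converter yields fake tensors that share storage,
+# because the underlying MetaConverter memoizes per original tensor rather than
+# fakifying each argument independently.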
+class FakeTensorConverter: + @property + def tensor_memo( + self, + ) -> weakref.WeakValueDictionary: + # not valid until py3.10 + # weakref.WeakValueDictionary["torch._subclasses.meta_utils.MetaTensorId", Optional["FakeTensor"]] + return self.meta_converter.tensor_memo + + meta_converter: MetaConverter + constant_storage_mapping: dict[StorageWeakRef, list[ReferenceType]] + export: bool + + def __init__(self, *, copy_data: bool = False, export: bool = False) -> None: + self.meta_converter = MetaConverter(copy_data=copy_data) + self.export = export + + # map from to storage to corresponding constant tensors + self.constant_storage_mapping = {} + + def add_constant_storage_mapping(self, fake_tensor: FakeTensor) -> None: + # when you have a constant, aliased tensor: + # const_tensor.add_(torch.rand([1])) + # all aliases of it must become no longer const + assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None + weak_st = StorageWeakRef(fake_tensor.constant._typed_storage()) + + # we need a map from a weak storage to all of its corresponding + # constant tensors. python doesn't have the weak value equivalent + # of defaultdict(list), so we are using a WeakValueDictionary as one + if weak_st not in self.constant_storage_mapping: + self.constant_storage_mapping[weak_st] = [] + self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor)) + + def invalidate_constant_aliases(self, tensor: Tensor) -> None: + assert not isinstance(tensor, FakeTensor) + + weak_st = StorageWeakRef(tensor._typed_storage()) + if weak_st not in self.constant_storage_mapping: + return + + for weak_tensor_ref in self.constant_storage_mapping[weak_st]: + ten = weak_tensor_ref() + if ten is not None: + ten._fix_weakref() + ten.constant = None + + del self.constant_storage_mapping[weak_st] + + def _get_memo(self, t: Tensor) -> Optional[FakeTensor]: + tid = self.meta_converter.describer.lookup_tensor.get(t) + if tid is None: + return None + return self.tensor_memo.get(tid) + + def set_tensor_memo(self, t: Tensor, v: FakeTensor) -> None: + tid = self.meta_converter.describer.get_tensor_id(t) + self.meta_converter.tensor_memo[tid] = v + + # You can have a real tensor that you need to convert into a fake tensor. + # If you have a meta tensor already, call from_meta_and_device. + # + # You're allowed to pass a meta tensor to be turned into a fake + # tensor; although an odd thing to do, this can occur if you're doing + # cross ref testing and the inner test is already operating on meta tensors. 
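+    # Rough usage sketch (assumed, for illustration only):
+    #
+    #   mode = FakeTensorMode()
+    #   fake = mode.fake_tensor_converter.from_real_tensor(mode, torch.randn(3))
+    #   # `fake` is backed by a meta tensor but still reports device "cpu"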
+ def from_real_tensor( + self, + fake_mode: FakeTensorMode, + t: Tensor, + make_constant: bool = False, + shape_env: Optional[ShapeEnv] = None, + *, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + trace: bool = True, + ) -> FakeTensor: + # see note [Tensor Fakification and Symbol Caching] + if not symbolic_context and not source and shape_env: + if tracing_context := torch._guards.TracingContext.try_get(): + if t in tracing_context.tensor_to_context: + symbolic_context = tracing_context.tensor_to_context[t] + from torch.fx.experimental.symbolic_shapes import ( + StatefulSymbolicContext, + ) + + assert isinstance(symbolic_context, StatefulSymbolicContext) + source = symbolic_context.tensor_source + + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + # not yet supported in metatensors + if t.is_quantized: + raise UnsupportedFakeTensorException("quantized nyi in meta tensors") + if type(t) is torch.nn.Parameter: + assert not make_constant + + constant = t if make_constant else None + + # This callback is used by both subclass and inner tensors. Require the + # caller to explicitly specify the device in case outer and inner tensors + # have different devices. + def mk_fake_tensor( + make_meta_t: Callable[[], object], device: Union[torch.device, str] + ) -> FakeTensor: + # NB: don't use in_kernel_invocation_manager. to + # ensure FakeTensor can internally do constant computation + # as necessary. Invocation manager is "more correct" as + # it works for more operators in make_meta_t, but + # invariant is that make_meta_t only calls factories + # for which it is not strictly necessary to use the + # invocation manager (I think!) + with no_dispatch(): + return FakeTensor( + fake_mode, + make_meta_t(), + device, + # TODO: callback might be used in recursive contexts, in + # which case using t is wrong! BUG! + constant=constant, + ) + + out = self.meta_converter( + t, + shape_env=shape_env, + callback=mk_fake_tensor, + source=source, + symbolic_context=symbolic_context, + trace=trace, + ) + if out is NotImplemented: + raise UnsupportedFakeTensorException("meta converter nyi") + + from torch._dynamo.source import RandomValueSource + + value = None + if ( + not self.export + and _is_plain_tensor(t) # mostly, we want to know if item() works + and t.dim() == 0 + and t.device.type == "cpu" + # All integer types are fair game, because signed overflow is UB + # (and even int64 can overflow, since integers in Python are + # arbitrary precision). But only float64 is OK for float, because + # switching between float32 and float64 changes semantics in an + # observable way without hitting UB. + and t.dtype + in [torch.int64, torch.int32, torch.int16, torch.int8, torch.float64] + and source is not None + # Impede setting up item() on things coming from random. These + # are not "real" item() calls, instead UnspecializedPythonVariable + # is unsafely pretending an int is a tensor, which can sometimes + # implicitly cause an item call. The problem is this is pretty + # unsound: there's no reason substituting an int with a Tensor is + # going to give the same results. Today, you mostly get around + # this by typically not having capture_scalar_outputs on and graph + # breaking when someone tries to use the unspec variable in an + # int-y context. But allowing it through here would break that. + # So don't. + # + # Once random values are setup to be represented as + # SymNodeVariable, this condition can be removed. 
To check if + # you've done it right, this is a good test: + # + # PYTORCH_TEST_WITH_DYNAMO=1 python test/test_reductions.py -k + # TestReductionsCPU.test_dim_reduction_fns_fn_name_amax_cpu_bfloat16 + and not isinstance(source, RandomValueSource) + # In Dynamo, shape_env is never none (even with static shapes). + # However, FakeTensorMode can be used by hand and in some cases + # ShapeEnv is not allocated. + and shape_env is not None + ): + from torch._dynamo.source import CallMethodItemSource, FloatTensorSource + from torch.fx.experimental.symbolic_shapes import DimDynamic + + with no_dispatch(): + value = t.item() + if not math.isnan(value) and not math.isinf(value): + # Peephole strip out unnecessary torch.as_tensor(x).item() + if isinstance(source, FloatTensorSource): + item_source = source.base + else: + item_source = CallMethodItemSource(source) + symbol = shape_env.create_unspecified_symbol( + value, + source=item_source, + dynamic_dim=DimDynamic.DYNAMIC, + symbolic_context=symbolic_context, + ) + # NB: reusing item_memo here ensures that we invalidate on + # mutation + if t.dtype == torch.int64: + out.item_memo = shape_env.create_symintnode( + symbol, + hint=value, + source=item_source, + ) + elif t.dtype == torch.float64: + out.item_memo = shape_env.create_symfloatnode( + symbol, + hint=value, + source=item_source, + ) + if make_constant: + self.add_constant_storage_mapping(out) + # NB: meta_converter set the memo + return out + + # If you specify the device, it MUST be a meta tensor. + def from_meta_and_device( + self, + fake_mode: FakeTensorMode, + t: Tensor, + device: torch.device, + pytype: Optional[type[torch.Tensor]] = None, + dispatch_keys: Optional[torch.DispatchKeySet] = None, + ) -> FakeTensor: + assert ( + t.device.type == "meta" + ), f"tensor's device must be `meta`, got {t.device.type} instead" + # This is a bit abusive (this is not the "real" tensor) but whatever, + # the meta tensor should be fresh so there's no way to get it wrong + maybe_memo = self._get_memo(t) + if maybe_memo is not None: + return maybe_memo + out = FakeTensor( + fake_mode, t, device, pytype=pytype, dispatch_keys=dispatch_keys + ) + self.set_tensor_memo(t, out) + return out + + +@functools.cache +def init_gpu_context(device: torch.device) -> None: + # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first + if torch.cuda.is_available() or torch.xpu.is_available(): + ( + torch.empty(1, device=device) + if torch.version.hip is None + else torch.zeros(1, device=device) + ) + + +@contextlib.contextmanager +def in_kernel_invocation_manager( + fake_mode: FakeTensorMode, +) -> Generator[None, None, None]: + # See: note [Fake Tensor Dispatch Keys] + prev_in_kernel = fake_mode.in_kernel_invocation + meta_in_tls = torch._C._meta_in_tls_dispatch_include() + assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}" + + with torch._C._DisableTorchDispatch(): + fake_mode.in_kernel_invocation = True + # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave + # `Dense` turned on (because it's implied by `Meta`) + with torch._C._PreserveDispatchKeyGuard(): + torch._C._set_meta_in_tls_dispatch_include(True) + try: + yield + finally: + fake_mode.in_kernel_invocation = prev_in_kernel + # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel) + + +# Return if the function allows Python numbers to bind to Tensors +def should_allow_numbers_as_tensors(func: OpOverload) -> bool: + return torch._C._should_allow_numbers_as_tensors( + 
func.name().split("::")[-1].split(".")[0] + ) + + +class FakeTensorConfig: + debug = os.environ.get("TORCH_FAKE_TENSOR_DEBUG", "0") == "1" + + +# This memorizes unbacked SymInt or SymFloats representing quantities like the +# number of nonzero elements in this tensor or learning rate. There is one +# instance of the descriptor per particular quantity to memoize. +# +# Memoization is helpful if you do something like x[mask] and y[mask]; +# mask.nonzero() gets repeatedly called and should give a consistent unbacked +# SymInt. It needs to be invalidated in the same way constant is. +# +# Making this a descriptor may seem overly fancy, but actually it's the most +# convenient way to ensure access to FakeTensor during access, which is +# required for testing version counter and epoch validity. +class SymNumberMemoDescriptor: + _name: str + + # By default, SymInts in this memo are invalidated across versions/epochs. + # nested_ints however are preserved across epochs and across versions. + # Preserving across versions is okay for nested int since the association + # of a nested int is agnostic to the underlying data and nested ints are not + # shared across multiple distinct tensors. + _is_nested_int: bool + + def __init__(self, *, is_nested_int: bool = False) -> None: + self._is_nested_int = is_nested_int + + def __set_name__(self, owner: str, name: str) -> None: + self._name = name + + def _memo(self, obj: FakeTensor) -> str: + return f"_{self._name}" + + def _memo_vc(self, obj: FakeTensor) -> str: + return f"_{self._name}_vc" + + # When we retrace, we need to invalidate all the memos so that we can + # accurately identify the first time unbacked SymInts are allocated. + # This is only relevant for inputs; for intermediates, they will get fresh + # fake tensors so you won't have a memo anyway + def _memo_epoch(self, obj: FakeTensor) -> str: + return f"_{self._name}_epoch" + + def __get__( + self, obj: FakeTensor, objtype: Optional[type[FakeTensor]] = None + ) -> Optional[Union[torch.SymInt, torch.SymFloat]]: + if (r := getattr(obj, self._memo(obj))) is None: + return None + + # If backed, it's ok to preserve memo since we know it won't renumber. + if isinstance(r, torch.SymFloat) and r.node.hint is not None: + return r + + # Version counter based tracking isn't 100% sound but it's close + # enough + if ( + not self._is_nested_int and getattr(obj, self._memo_vc(obj)) != obj._version + ) or ( + not self._is_nested_int + and getattr(obj, self._memo_epoch(obj)) != obj.fake_mode.epoch + ): + setattr(obj, self._memo(obj), None) + return None + return r + + def __set__( + self, obj: FakeTensor, value: Optional[Union[torch.SymInt, torch.SymFloat]] + ) -> None: + if value is None: + setattr(obj, self._memo(obj), None) + setattr(obj, self._memo_vc(obj), None) + setattr(obj, self._memo_epoch(obj), None) + elif not obj.is_inference() or self._is_nested_int: + setattr(obj, self._memo(obj), value) + if not self._is_nested_int: + setattr(obj, self._memo_vc(obj), obj._version) + setattr(obj, self._memo_epoch(obj), obj.fake_mode.epoch) + + +class FakeTensor(Tensor): + """ + Meta tensors give you the ability to run PyTorch code without having to + actually do computation through tensors allocated on a `meta` device. + Because the device is `meta`, meta tensors do not model device propagation. + FakeTensor extends MetaTensors to also carry an additional `fake_device` + which tracks devices that would have been used. 
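+
+    A minimal usage sketch (illustrative):
+
+        with FakeTensorMode():
+            x = torch.empty(2, 2, device="cuda")
+            # no real CUDA memory is allocated; `x` is a FakeTensor whose
+            # `fake_device` is a cuda device while its data lives on "meta"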
+ """ + + fake_device: torch.device + fake_mode: FakeTensorMode + constant: Optional[Tensor] + real_tensor: Optional[Tensor] + + # TODO: Generalize this as needed, e.g., into a trie of memos, if + # you do something like x[0].item() (x[0] is fresh each time, so + # memo mechanism here won't work) + nonzero_memo = SymNumberMemoDescriptor() + item_memo = SymNumberMemoDescriptor() + unique_memo = SymNumberMemoDescriptor() + unique_consecutive_memo = SymNumberMemoDescriptor() + + # We expect nested_int_memo to be None when an offsets is a graph + # intermediate, or an input that has never been associated with a + # nested int. + nested_int_memo = SymNumberMemoDescriptor(is_nested_int=True) + + # FakeTensor doesn't fully emulate the original tensor's Python type + # and dispatch key set, therefore sometimes we want to track them + # separately. + pytype: Optional[type[Tensor]] + dispatch_keys: Optional[torch.DispatchKeySet] + + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FAKE + + @property + def device(self) -> torch.device: + if self.fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return self.fake_device + + @device.setter + def device(self, _: torch.device) -> None: + raise NotImplementedError + + # Note: [Fake Tensor Dispatch Keys] + # In order to model the behavior of device-specific autocast + # and autograd logic, we update the dispatch keys of FakeTensors + # to reflect their fake device. This includes the BackendComponent + # (DispatchKey::Meta -> DispatchKey::CUDA), and also the BackendComponent + # related Autocast and Autograd keys. __torch_dispatch__ sits below + # Autocast and Autograd, and is only invoked when we are at the + # kernel for the BackendComponent. Then, we add Meta to the + # thread-local dispatch include set to hit the meta kernel + # instead of the kernel of the BackendComponent for the fake device. + # The `device_for_backend_keys` does that below + # NOTE: this probably will not do the right thing for backends + # that have dispatch keys which are higher than the "meta" key: + # https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L189 + + # We don't support named tensors; graph break + @property + def names(self) -> list[str]: + raise UnsupportedFakeTensorException( + "torch.compile doesn't support named tensors" + ) + + @names.setter + def names(self, _: list[str]) -> None: + raise NotImplementedError + + @staticmethod + def __new__( + cls, + fake_mode: FakeTensorMode, + elem: Tensor, + device: torch.device, + constant: Optional[Tensor] = None, + real_tensor: Optional[Tensor] = None, + pytype: Optional[type[Tensor]] = None, + dispatch_keys: Optional[torch.DispatchKeySet] = None, + ) -> Self: + self = Tensor._make_subclass( + cls, + elem, + elem.requires_grad, + dispatch_device=True, + device_for_backend_keys=device, + ) + if not fake_mode._allow_unsafe_data_ptr_access: + torch._C._set_throw_on_mutable_data_ptr(self) + else: + torch._C._set_warn_deprecated_on_mutable_data_ptr(self) + + assert elem.device.type == "meta", elem.device.type + device = device if isinstance(device, torch.device) else torch.device(device) + # NB: it is fine, if a little confusing, for device to be meta + # (we are faking a meta tensor in that case). However, it often + # indicates some sort of confusion (e.g., you accidentally passed + # in a meta tensor when you should have passed in the real tensor). 
+ # So by default we disallow meta, and if you are working in a situation + # where it is helpful (e.g., crossref testing) you can turn it back + # on + if not fake_mode.allow_meta: + assert device.type != "meta" + # normalize device. + if device.type in ["cuda", "xpu"]: + init_gpu_context(device) + + if ( + device.type + in ["cuda", "hpu", "xpu", "mps", torch._C._get_privateuse1_backend_name()] + and device.index is None + ): + if device.type != "mps" and getattr(torch, device.type).is_initialized(): + device = torch.device( + f"{device.type}:{getattr(torch, device.type).current_device()}" + ) + else: + device = torch.device(f"{device.type}:0") + self.fake_device = device + self.fake_mode = fake_mode + self.constant = constant + self.pytype = pytype + self.dispatch_keys = dispatch_keys + assert not isinstance(real_tensor, FakeTensor) + self.real_tensor = real_tensor + self.nonzero_memo = None + self.item_memo = None + self.unique_memo = None + self.unique_consecutive_memo = None + self.nested_int_memo = None + + if FakeTensorConfig.debug: + self._debug_trace = CapturedTraceback.extract() # type: ignore[attr-defined] + return self + + # In some circumstances, a conventional Tensor constructor + # will get rewritten to call into FakeTensor. We must provide an + # __init__ method that can accept the Python interpreters initialization + # in such a situation; we must also be able to handle direct fake + # tensor construction via FakeTensor(). + # + # In particular, the __init__ call will look funny in the following case: + # + # with FakeTensorMode(): + # x = Tensor([1, 2, 3]) + # + # this desugars into: + # + # with FakeTensorMode(): + # x = Tensor.__new__([1, 2, 3]) + # # NB: x is a fake tensor, because of the mode! + # x.__init__([1, 2, 3]) # not the normal fake tensor args! + # + def __init__(self, *args: object, **kwargs: object) -> None: + super().__init__() + + @staticmethod + def from_tensor(t: Tensor, fake_mode: FakeTensorMode) -> FakeTensor: + return fake_mode.from_tensor(t) + + @classmethod + @count + def __torch_dispatch__( # type: ignore[override] # TODO + cls, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + # need to handle here to avoid infinite recursion + # see [in_kernel_invocation] + if func == torch.ops.prim.device.default: + assert len(args) == 1 and isinstance(args[0], FakeTensor) + if args[0].fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return args[0].fake_device + + # this handler must be done inside FakeTensor subclass, not mode, because + # we can end up dispatching here when we have a fake tensor with + # symbolic sizes running under in_kernel_invocation_manager. + # The subclass is asked to handle this query because size (not + # sym_size) was called, but we are unable to serve it directly because + # there are symbolic sizes in the class. The use of + # in_kernel_invocation_manager means it's incorrect to activate a + # mode to actually handle this (this caused + # https://github.com/pytorch/pytorch/issues/122772). + if handler := _DISPATCH_META_HANDLERS.get(func): + return handler(args) + + # Because fake mode can return NotImplemented (if it sees a subclass + # it doesn't know how to deal with), this test here is important + # because the next dispatch after a fake mode will attempt to use + # subclasses of tensors to dispatch, and any FakeTensor arguments + # will be considered eligible. 
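+        # For example (illustrative): if a tensor subclass this mode does not
+        # recognize shows up in `types`, we return NotImplemented below so the
+        # Python dispatcher can give that subclass's own __torch_dispatch__ a
+        # chance to handle the call instead.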
+ unrecognized_types = [ + t for t in types if not issubclass(t, FakeTensor) and t is not Tensor + ] + if unrecognized_types: + not_implemented_log.debug( + "FakeTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + fake_mode = None + for arg in pytree.arg_tree_leaves(*args, **kwargs): + if isinstance(arg, FakeTensor): + fake_mode = arg.fake_mode + break + + assert fake_mode is not None + + # If the fake mode is already active, don't try to reapply it! + # NotImplemented is the right thing to return here, because the + # typical situation this can occur is if ProxyTensorMode returned a + # NotImplemented because of a not implemented subclass; we may have + # unluckily attempted to hit FakeTensor's dispatch first, + # NotImplemented lets us keep chaining until we find the actual + # subclass + maybe_cur_fake_mode = torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FAKE + ) + if maybe_cur_fake_mode: + not_implemented_log.debug( + "FakeTensor mode already active: %s in %s", + fake_mode, + maybe_cur_fake_mode, + ) + return NotImplemented + + assert not fake_mode.in_kernel_invocation + + with fake_mode: + return func(*args, **kwargs) + + @staticmethod + def _find_common_device( + func: OpOverload, flat_args: Sequence[object] + ) -> tuple[torch.device, bool]: + # Returns: (common_device, has_scalar_only_inputs) + + # cpu - zero-dim tensors can be called in cuda kernels, + # so overwrite the common_device if it the only existing + # device comes from a cpu zero-dim tensor + common_device = None + has_scalar_only_inputs = False + is_cpu_zero_dim = None + + # list of ops which can have args(tensor/tensorList) in mixed device + mixed_device_fns = ordered_set( + aten._foreach_copy.default, + ) + + def check_cpu_device(device: torch.device) -> bool: + return device.type == "cpu" + + def cpu_zero_dim(t: Tensor) -> bool: + return check_cpu_device(t.device) and t.dim() == 0 + + def merge_devices(t: object) -> None: + nonlocal common_device + nonlocal is_cpu_zero_dim + if not isinstance(t, FakeTensor): + return + + if common_device is None: + common_device = t.device + is_cpu_zero_dim = cpu_zero_dim(t) + return + + t_is_cpu_zero_dim = cpu_zero_dim(t) + if t.device == common_device: + if is_cpu_zero_dim: + is_cpu_zero_dim = t_is_cpu_zero_dim + return + + # mismatching devices ! + # if current tensor is cpu 0 dim, defer to existing device + if t_is_cpu_zero_dim: + return + + # current device is from cpu 0 dim tensor, overwrite + if is_cpu_zero_dim: + common_device = t.device + is_cpu_zero_dim = t_is_cpu_zero_dim + return + + # if still device mismatches we will check ops which can work + # on different devices for ex. _foreach_copy, and one of the + # device must be cpu in this case we will return from here without + # throwing an error + if func in mixed_device_fns: + if any(map(check_cpu_device, (common_device, t.device))): + return + + # mismatching devices of non-zero dim tensors, throw + # This might be valid behavior and need to be explicitly modeled, e.g. 
reshape_as + raise RuntimeError( + f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}" + ) + + for arg in flat_args: + merge_devices(arg) + + # some functions that allow Python numbers to bind to Tensors + # if we have failed to find a device, and we're running one of these operators, + # we must have scalar only inputs + if should_allow_numbers_as_tensors(func) and common_device is None: + # ops with scalar only inputs always have result on cpu + has_scalar_only_inputs = True + common_device = torch.device("cpu") + + assert common_device is not None, f"Could not find common device for {func}" + + return common_device, has_scalar_only_inputs + + def get_nested_int( + self, + *, + coeff: Union[int, torch.SymInt] = 1, + ) -> torch.SymInt: + if self.nested_int_memo is None: + self.nested_int_memo = self.fake_mode.create_symbolic_nested_int( + nt_tensor_id=None + ) + assert isinstance(self.nested_int_memo, torch.SymInt) + return self.nested_int_memo * coeff + + # Similar to FunctionalTensor.tolist + def tolist(self) -> Any: + if self.dim() == 0: + return self.item() + elif self.dim() == 1: + return [elem.item() for elem in self] + else: + return [elem.tolist() for elem in self] + + +_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"] + + +@dataclass_slots +@dataclass +class TensorMetadata: + """ + The Tensor metadata relevant to hashing FakeTensors when caching. + """ + + dtype: torch.dtype + shape: tuple[_MetadataIntLike, ...] + stride: tuple[_MetadataIntLike, ...] + device: torch.device + layout: torch.layout + memory_format: Optional[torch.memory_format] + storage_offset: _MetadataIntLike + storage_bytes: Optional[_MetadataIntLike] + requires_grad: bool + is_quantized: bool + is_conj: bool + is_neg: bool + is_inference: bool + is_sparse: bool # read: is sparse COO + is_coalesced: Optional[bool] + dense_dim: Optional[int] + sparse_dim: Optional[int] + + def _flatten_into( + self, + result: list[object], + mode: FakeTensorMode, + state: _CacheKeyState, + ) -> None: + # Flatten the TensorMetadata out into `result`. Make sure to call + # state.convert_sym_int() on any SymInts. + for field in dataclasses.fields(self): + value = getattr(self, field.name) + if isinstance(value, (tuple, list, torch.Size)): + # This will recursively flatten the iterable, calling + # convert_sym_int() as necessary. + id_hashed_objects: list[object] = [] + mode._prep_args_for_hash(result, value, state, id_hashed_objects) + id_hashed_objects.clear() + elif isinstance(value, SymInt): + state.convert_sym_int(result, value) + else: + result.append(value) + + +def extract_tensor_metadata(t: Tensor) -> TensorMetadata: + """ + Extract the TensorMetadata of a tensor. + """ + memory_format = suggest_memory_format(t) + # Don't call is_contiguous() on a Tensor which has symbolic sizes or things + # will go badly (guards will be messed up?) 
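+    # Concretely (sketch): for tensors with SymInt sizes, or with any sparse
+    # layout, memory_format is recorded as None below rather than probing
+    # contiguity, so building the cache key avoids calling is_contiguous() on
+    # symbolic sizes.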
+ if ( + t._has_symbolic_sizes_strides + or is_sparse_any(t) + or not t.is_contiguous(memory_format=memory_format) + ): + memory_format = None # type: ignore[assignment] + + storage_offset = t.storage_offset() + + return TensorMetadata( + t.dtype, + t.shape, + t.stride() if t.layout == torch.strided else (), + t.device, + t.layout, + memory_format, + storage_offset, + # Only set storage_bytes for tensors that have storage (not sparse) + t.untyped_storage().nbytes() if not is_sparse_any(t) else None, + t.requires_grad, + t.is_quantized, + t.is_conj(), + t.is_neg(), + t.is_inference(), + t.is_sparse, + t.is_coalesced() if t.is_sparse else None, + t.dense_dim() if is_sparse_any(t) else None, + t.sparse_dim() if is_sparse_any(t) else None, + ) + + +@dataclass_slots +@dataclass +class _DispatchCacheKey: + """ + Key for the FakeTensor dispatch cache. + """ + + key: tuple[object, ...] + hashvalue: int + + def __init__(self, tup: tuple[object, ...]) -> None: + self.key = tup + self.hashvalue = hash(tup) + + def __eq__(self, other: object) -> bool: + return isinstance(other, _DispatchCacheKey) and self.key == other.key + + def __hash__(self) -> int: + return self.hashvalue + + def strip_shape_env(self) -> None: + # We need to strip the ShapeEnv from any values before we store in the + # cache so the cache doesn't keep our ShapeEnvs alive. + for v in self.key: + if isinstance(v, _PySymInputStub): + v.strip_shape_env() + + +# Default value for constant_value in _DispatchCacheEntryOutputInfo. This is +# only for checking and differentiates from None. +class SingletonConstant: + pass + + +@dataclass_slots +@dataclass(frozen=True) +class _DispatchCacheEntryOutputInfo: + """ + Entry type for the FakeTensor dispatch cache for an output. Accounts for three + possibilities: + 1) The op is inplace, and a hit means we need to alias the argument at a + given index. + 2) We need to synthesize a new FakeTensor given tensor metadata. For view + ops, we further capture the index of the arg to alias. + 3) if the tensor related fields are None, then it is a constant value (e.g. + None or integer) + """ + + inplace_idx: Optional[int] + metadata: Optional[TensorMetadata] + view_idx: Optional[int] + constant_value: Optional[Any] = SingletonConstant + + +@dataclass_slots +@dataclass(frozen=True) +class _DispatchCacheValidEntry: + """ + Entry type for the FakeTensor dispatch cache. It supports two types of outputs + 1) tensor + 2) tuple of tensors + + is_output_tuple flag helps in differentiating the return type + """ + + output_infos: tuple[_DispatchCacheEntryOutputInfo] + is_output_tuple: bool = False + + +@dataclass_slots +@dataclass(frozen=True) +class _DispatchCacheBypassEntry: + """ + Entry type for a negative cache entry. + """ + + reason: str + + +if TYPE_CHECKING: + _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry] + + +@dataclass_slots +@dataclass(frozen=True) +class _BypassDispatchCache(Exception): + """ + Signals cases that should skip FakeTensor caching. + """ + + reason: str + + +@dataclass_slots +@dataclass(frozen=True) +class DispatchCacheInfo: + """ + Information about the state of the FakeTensor dispatch cache. + """ + + hits: int + misses: int + bypasses: dict[str, int] + size: int + + +# We keep one instantiation of `fake_tensor_converter` active +# for the duration of `with FakeTensorMode()`. +# This allows accurate storage aliasing across invocation of +# different operators. 
While this will keep all freshly allocated +# tensors alive during `FakeTensorMode`, there will be no +# new allocations of Tensors which have non-meta storage so +# memory should not significantly increase. + + +class FakeTensorMode(TorchDispatchMode): + cache: dict[_DispatchCacheKey, _DispatchCacheEntry] = {} + cache_hits: int = 0 + cache_misses: int = 0 + cache_bypasses: dict[str, int] = defaultdict(int) + # Every time you retrace using the same fake tensor mode, you should + # advance the epoch so we don't reuse unbacked memos + epoch: int = 0 + in_kernel_invocation: bool = False + static_shapes: bool + shape_env: Optional[ShapeEnv] + _stack: Optional[str] + allow_meta: bool + + # NestedTensor uses a tensor_id_counter to uniquely identify offsets. + # This counter is incremented when an offsets is used to create an NJT + # for the first time. To avoid mutating eager state if we construct NJT + # during tracing, we maintain a separate counter on the FakeTensorMode. + # The initial count is set to the current eager tensor_id_counter value + # upon initialization, and every time you retrace using the same fake tensor + # mode, you should reset the counter to the initial count. + nt_tensor_id_counter: int = -1 + nt_tensor_id_initial_count: int = -1 + + def __init__( + self, + *, + allow_fallback_kernels: bool = True, + allow_non_fake_inputs: bool = False, + shape_env: Optional[ShapeEnv] = None, + static_shapes: Optional[bool] = None, + # TODO: This is a temporary measure, see + # https://github.com/pytorch/pytorch/pull/126245#discussion_r1604185748 + # We're currently solely using this to impede population of + # item_memo for 0d scalar tensor inputs when export, because this + # causes things that used to be deferred runtime asserts to turn into + # guards, and then the guards are just lost. We can potentially fix + # this by ensuring guards also get put in the graph, but this is + # pending a rework of how deferred runtime asserts in export. Once + # that's done, we can remove this. 
+ export: bool = False, + ) -> None: + log.debug("create_mode 0x%x", id(self)) + super().__init__() + self.allow_fallback_kernels = allow_fallback_kernels + + import torch._dynamo.config + import torch._functorch.config + + self.propagate_real_tensors = ( + torch._functorch.config.fake_tensor_propagate_real_tensors + ) + self.fake_tensor_converter = FakeTensorConverter( + copy_data=self.propagate_real_tensors, + export=export, + ) + + if static_shapes is not None: + self.static_shapes = static_shapes + else: + self.static_shapes = shape_env is None + + # This is temporarily patched to True in Dynamo to grandfather in some + # places where we unconditionally allow scalar outputs, TO BE REMOVED + self.allow_scalar_outputs = False + + self._allow_unsafe_data_ptr_access = ( + torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access + ) + self.allow_meta = torch._functorch.config.fake_tensor_allow_meta + self.cache_enabled: bool = ( + torch._dynamo.config.fake_tensor_cache_enabled + and not self.propagate_real_tensors + ) + self.cache_crosscheck_enabled = ( + torch._dynamo.config.fake_tensor_cache_crosscheck_enabled + ) + + # A flag that controls, whether we want to invoke ops on mix of + # real weights/global variables and fake inputs + self.allow_non_fake_inputs = allow_non_fake_inputs + + # [in_kernel_invocation] + # when FakeTensor is invoked in user code, .device should return + # the fake_device of the tensor so that code such as as `if x.is_cuda` + # or torch.zeros([10, 10], device=x.device) continues to execute as if + # the FakeTensor were real. However, within kernel execution, we return + # the `Meta` device because all computation within the kernels should + # behave as if the Tensors are on meta devices. Kernels should allocate + # new tensors on meta devices, and checks like `is_meta` should return true. + # within python refs, we always return the real device by defining + # the device property + self.in_kernel_invocation = False + + # True if we enter'ed and actually enabled fake tensor mode, + # false if it was a no-op. Not thread safe but neither is + # in_kernel_invocation + # If another fake mode was already active when we enter, we also stash it here. + # That way when we exit, we know to re-enable the previous fake mode. + self.enter_stack: list[ + tuple[bool, Optional[TorchDispatchMode], Optional[bool]] + ] = [] + + self.shape_env = shape_env + + self._stack_trace = traceback.extract_stack() + self._stack = None + + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.FAKE + + import torch.nested._internal.nested_tensor + + self.nt_tensor_id_initial_count = ( + torch.nested._internal.nested_tensor._tensor_id_counter + ) + self.nt_tensor_id_counter = self.nt_tensor_id_initial_count + + def reset_nt_tensor_id_counter(self) -> None: + self.nt_tensor_id_counter = self.nt_tensor_id_initial_count + + # Typically, there is only one fake tensor mode and you test for it by + # doing an isinstance test. However, in some situations, there might be + # TWO fake tensor modes. The canonical example of this is exporting + # a fake model: there is an outer fake mode created by the user, and + # an inner fake mode created by Dynamo. 
The two phase process is required + # because the outer fake mode typically won't have a ShapeEnv, even if + # the user is interested in exporting with dynamic shapes (so the inner + # fake mode will actually have a ShapeEnv and swap in symbolic sizes.) + # + # In this case, it's insufficient to test only one FakeTensor: you need + # to distinguish between our fake tensor and other fake tensors. That's + # what this function does. + def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]: + return isinstance(t, FakeTensor) and t.fake_mode is self + + # If we should avoid device init. This changes the behavior of various APIs: + # - We avoid constant-prop on Tensors with ops that move them to another device + # - We change the torch.tensor ctor contract to never materialize + # tensors on device + # (see NOTE: [torch.tensor, lift_fresh, and device movement]) + @property + def avoid_device_init(self) -> bool: + if torch.xpu._is_compiled(): + assert not torch.cuda._is_compiled() + return not torch.xpu.is_available() + + return not ( + torch.cuda.is_available() + or (hasattr(torch, "hpu") and torch.hpu.is_available()) + ) + + @property + def stack(self) -> str: + if self._stack is None: + self._stack = "".join(traceback.format_list(self._stack_trace)) + return self._stack + + @count + def __torch_dispatch__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + # FakeTensorMode should not be set when we're inside of it. + assert ( + torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None + ), func + try: + return self.dispatch(func, types, args, kwargs) + except TypeError: + log.exception("fake tensor raised TypeError") + raise + + # No-op if FakeTensorMode is already in use + def __enter__(self) -> Self: + import torch.nested._internal.nested_tensor + + prev_only_lift_cpu_tensors = None + if self.avoid_device_init: + # See NOTE: [torch.tensor, lift_fresh, and device movement] + prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors() + torch._C._set_only_lift_cpu_tensors(True) + maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key) + if self is not maybe_prev_fake_mode: + self.enter_stack.append( + (True, maybe_prev_fake_mode, prev_only_lift_cpu_tensors) + ) + return super().__enter__() + else: + # no-op (still need to re-set the fake mode though since we unset it) + torch._C._set_dispatch_mode(self) + self.enter_stack.append((False, None, prev_only_lift_cpu_tensors)) + return self + + def __exit__( + self, + a: Optional[type[BaseException]], + b: Optional[BaseException], + c: Optional[TracebackType], + ) -> None: + ( + live, + maybe_prev_fake_mode, + maybe_prev_only_lift_cpu_tensors, + ) = self.enter_stack.pop() + if live: + super().__exit__(a, b, c) + + # Re-enable the previous fake mode, if there was one. + if maybe_prev_fake_mode is not None: + torch._C._set_dispatch_mode(maybe_prev_fake_mode) + if maybe_prev_only_lift_cpu_tensors is not None: + torch._C._set_only_lift_cpu_tensors(maybe_prev_only_lift_cpu_tensors) + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + @classmethod + def cache_info(cls) -> DispatchCacheInfo: + """ + Query the state of the dispatch cache. + """ + return DispatchCacheInfo( + FakeTensorMode.cache_hits, + FakeTensorMode.cache_misses, + dict(FakeTensorMode.cache_bypasses), + len(FakeTensorMode.cache), + ) + + @classmethod + def cache_clear(cls) -> None: + """ + Clear the dispatch cache. 
+ """ + cls.cache_hits = 0 + cls.cache_misses = 0 + cls.cache_bypasses.clear() + cls.cache.clear() + + def _cached_dispatch_impl( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> object: + """ + Lookup a cache entry for the given arguments. If none exists, dispatch + and cache the result (if the result is eligible for caching). + """ + state = None + key = None + try: + state = _CacheKeyState(self.shape_env) + key = self._cache_key(state, func, args, kwargs) + except _BypassDispatchCache as e: + # We couldn't create the cache key at all + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) + FakeTensorMode.cache_bypasses[e.reason] += 1 + + if key is None: + # Do this dispatch outside the above except handler so if it + # generates its own exception there won't be a __context__ caused by + # the caching mechanism. + return self._dispatch_impl(func, types, args, kwargs) + + assert state is not None + if state.cache_on_shape_env(): + assert state.shape_env is not None + cache = state.shape_env.fake_tensor_cache + set_cache_key = _set_cache_key_for_shape_env + else: + cache = FakeTensorMode.cache + set_cache_key = _set_cache_key + entry = cache.get(key, None) + + if entry is not None: + if isinstance(entry, _DispatchCacheBypassEntry): + # This represents a negative cache entry - we already saw that the + # output is uncachable. Compute it from first principals. + FakeTensorMode.cache_bypasses[entry.reason] += 1 + return self._dispatch_impl(func, types, args, kwargs) + + # We have a cache entry. + output = self._output_from_cache_entry(state, entry, key, func, args) + FakeTensorMode.cache_hits += 1 + if self.cache_crosscheck_enabled: + # For debugging / testing: Validate that the output synthesized + # from the cache matches the output created by normal dispatch. + with disable_fake_tensor_cache(self): + self._crosscheck_cache_output(output, func, types, args, kwargs) + return output + + # We don't have a cache entry. + output = self._dispatch_impl(func, types, args, kwargs) + + try: + self._validate_cache_key(func, args, kwargs) + except _BypassDispatchCache as e: + # We ran "extra" checks on the cache key and determined that it's no + # good. Record the reason and mark it so we don't bother validating + # again. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func.name() == "invoke_subgraph" + ): + hc_log.debug( + "Fake tensor cache failed: identifier = %s, reason = %s", + args[1], + e.reason, + ) + FakeTensorMode.cache_bypasses[e.reason] += 1 + set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason)) + return output + + try: + entry = self._make_cache_entry(state, key, func, args, kwargs, output) + except _BypassDispatchCache as e: + # We had trouble making the cache entry. Record the reason and mark + # it. + FakeTensorMode.cache_bypasses[e.reason] += 1 + set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason)) + return output + + set_cache_key(cache, key, entry) + FakeTensorMode.cache_misses += 1 + return output + + def _cache_key( + self, + state: _CacheKeyState, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> _DispatchCacheKey: + """ + Create a cache key given the dispatch args. Raises _BypassDispatchCache + for any situation that precludes caching. 
+ """ + key_values = [ + func, + # Capture the default_dtype mode since that can affect the output tensor, + # e.g., when operating on constant float values. + torch.get_default_dtype(), + # Capture the current device to support, e.g., cache tensor creation, + # where there isn't necessarily a tensor to take the device from. + torch._C._get_default_device(), + # We want to create tensors from cached metadata only when the inference + # mode is the same. + torch.is_inference_mode_enabled(), + # Shape env settings could affect behavior. One example seen in the wild: + # Disallowing dynamic shapes can introduce a DynamicOutputShapeException + # where it wasn't seen on a previous instance of the same op. + self.shape_env.settings if self.shape_env else None, + ] + if state.known_symbols: + # If there are symbols then include the epoch - this is really more + # of a Shape env var which lives on the FakeTensorMode. + key_values.append(self.epoch) + # Collect the id_hashed objects to attach a weakref finalize later + id_hashed_objects: list[object] = [] + # Translate any FakeTensor args to metadata. + if args: + self._prep_args_for_hash(key_values, args, state, id_hashed_objects) + if kwargs: + self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects) + key = _DispatchCacheKey(tuple(key_values)) + + for id_hashed_obj in id_hashed_objects: + weakref.finalize( + id_hashed_obj, functools.partial(evict_fake_tensor_cache_key, key=key) + ) + id_hashed_objects.clear() + return key + + def _validate_cache_key( + self, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + """ + Validate that the cache key generated by _cache_key will be + reasonable. + """ + from torch._higher_order_ops.utils import registered_hop_fake_fns + + # For hops, we perform the validity check in _make_cache_entry because we + # need to have the output tensor. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func in registered_hop_fake_fns + ): + return + + # Avoid caching for any ops that would require a more sophisticated + # caching implementation, e.g., data dependent ops or ops that modify + # the inputs. + if torch.Tag.data_dependent_output in func.tags: + raise _BypassDispatchCache("data dependent output") + + if torch.Tag.dynamic_output_shape in func.tags: + if func is aten.index.Tensor: + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + ) + for index in new_kwargs["indices"]: + # index calls nonzero for bool or int8 tensors, and + # therefore has a dynamic shape output. For other dtypes, + # the output shape depends on the input shape (and not data) + if isinstance(index, torch.Tensor) and index.dtype in ( + torch.bool, + torch.int8, + ): + raise _BypassDispatchCache("dynamic output shape") + return + + raise _BypassDispatchCache("dynamic output shape") + + if torch.Tag.inplace_view in func.tags: + raise _BypassDispatchCache("inplace view") + + if func == aten._unsafe_view.default: + raise _BypassDispatchCache("unsafe view") + + if func in self.lift_fns: + raise _BypassDispatchCache("lift") + + if func.name() == "inductor::resize_storage_bytes_": + raise _BypassDispatchCache("inductor::resize_storage_bytes_") + + if not torch._library.utils.is_builtin(func): + raise _BypassDispatchCache("non-builtin") + + # In order to handle storage aliasing, we need to establish the alias + # for any view op on a cache hit. 
But CompositeImplicitAutograd ops may + # or may not alias the input, so just punt on caching these. + if func.is_view and torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + raise _BypassDispatchCache("CompositeImplicitAutograd") + + def _prep_args_for_hash( + self, + result: list[object], + args: Union[Mapping[str, object], Sequence[object], Iterable[object]], + state: _CacheKeyState, + id_hashed_objects: list[object], + ) -> None: + """ + Translate the provided args into a form suitable for caching at FakeTensor + dispatch, i.e., convert unhashable types like lists & dicts into tuples and + convert FakeTensors into metadata. Raises _BypassDispatchCache to signal + unsupported cases that should bypass caching. + """ + from torch._higher_order_ops.utils import FunctionalizeCtxWrapper + + if isinstance(args, dict): + self._prep_args_for_hash(result, args.keys(), state, id_hashed_objects) + self._prep_args_for_hash(result, args.values(), state, id_hashed_objects) + return + + for arg in args: + if isinstance(arg, FakeTensor): + if not self.is_our_fake(arg): + raise _BypassDispatchCache("not our fake") + if arg.constant is not None: + raise _BypassDispatchCache("constant attribute") + if is_sparse_any(arg): + raise _BypassDispatchCache(f"{arg.layout} tensor") + metadata = extract_tensor_metadata(arg) + metadata._flatten_into(result, self, state) + elif isinstance(arg, Tensor): + raise _BypassDispatchCache("non-fake tensor") + elif isinstance(arg, SymInt): + state.convert_sym_int(result, arg) + elif isinstance(arg, (SymBool, SymFloat)): + raise _BypassDispatchCache("symbolic shape") + elif isinstance(arg, (list, tuple, dict)): + self._prep_args_for_hash(result, arg, state, id_hashed_objects) + elif isinstance(arg, types.FunctionType): + raise _BypassDispatchCache("function argument") + elif isinstance(arg, torch.fx.GraphModule): + # This is used for invoke_subgraph where id(graph_module) allows + # us to cache fake outputs + result.append(type(arg)) + result.append(id(arg)) + id_hashed_objects.append(arg) + elif isinstance(arg, FunctionalizeCtxWrapper): + # Special case for AOT Dispatcher first pass, where the fake + # tensor is called on the functional wrapper of the subgraph. + result.append(hash(arg)) + # functional wrapper is destroyed after fake tensor prop. We + # need to put the finalizer on the subgraph. + id_hashed_objects.append(arg.subgraph) + else: + # It's important to capture the type of the arg since, e.g., 1 and 1.0 + # hash to the same value, but can produce different dtypes for the + # output tensor. + result.append(type(arg)) + result.append(arg) + + def _validate_output_for_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: Optional[FakeTensor], + ) -> None: + # Is this even possible? According to the signature this can be None but + # not `int`. So either the signature is a lie or (part of) this line is + # unnecessary... + if isinstance(output, (int, type(None))): + return + + if _has_unrepresented_symbols(state, output): + # Unbacked symbols are fine - but only if they're also represented + # in the input. If there are any new unbacked symbols then we can't + # cache this output. + raise _BypassDispatchCache("unrepresented symbol in output") + + # Some ops return tuples of Tensors, but it's rare, so avoid + # the complexity of caching other types. 
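# Taken together, these bypass rules decide whether an op's result ever lands in the
# class-level dispatch cache. A standalone sketch (illustrative only, assuming the
# default cache-enabled configuration) of observing the counters from the outside:
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    a = torch.empty(8, 8)
    for _ in range(3):
        a = a @ a  # repeated op with identical metadata; later calls may be cache hits
info = FakeTensorMode.cache_info()
print(info.hits, info.misses, info.bypasses)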
+ if not isinstance(output, FakeTensor): + raise _BypassDispatchCache("non-FakeTensor output") + + # Avoid caching FakeTensors with constants attached since those + # can be invalidated. + if output.constant is not None: + raise _BypassDispatchCache("constant attribute") + + # TODO: support caching sparse outputs? + if output.is_sparse: + raise _BypassDispatchCache("sparse output") + + if is_sparse_compressed(output): + raise _BypassDispatchCache("sparse compressed output") + + # Can an in-place op really reference a kwarg? If so, then we need + # to extend the implementation to handle it. + for kval in kwargs.values(): + if id(kval) == id(output): + raise _BypassDispatchCache("kwarg aliases output") + + def _get_output_info_for_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: FakeTensor, + ) -> _DispatchCacheEntryOutputInfo: + if isinstance(output, (int, torch.SymInt, type(None))): + return _DispatchCacheEntryOutputInfo( + inplace_idx=None, metadata=None, view_idx=None, constant_value=output + ) + + # If this is an in-place op, the entry records which input arg is aliased. + for idx in range(len(args)): + if id(args[idx]) == id(output): + return _DispatchCacheEntryOutputInfo( + inplace_idx=idx, metadata=None, view_idx=None + ) + + # Otherwise, create an entry that records the output tensor's metadata. + view_idx = None + if isinstance(func, torch._ops.OpOverload) and func.is_view: + idxs = [i for i, t in enumerate(args) if isinstance(t, Tensor)] + assert len(idxs) == 1 + view_idx = idxs[0] + + metadata = extract_tensor_metadata(output) + metadata.shape = tuple(state.convert_output(v) for v in metadata.shape) + metadata.stride = tuple(state.convert_output(v) for v in metadata.stride) + metadata.storage_offset = state.convert_output(metadata.storage_offset) + metadata.storage_bytes = ( + None + if metadata.storage_bytes is None + else state.convert_output(metadata.storage_bytes) + ) + + entry = _DispatchCacheEntryOutputInfo( + inplace_idx=None, + metadata=metadata, + view_idx=view_idx, + ) + + # N.B.: Some checks for bypassing the cache would be performed on the + # output tensor synthesized from the cached metadata. As an optimization, + # we can synthesize a tensor here and do the checks on that instance. + # This approach keeps the (more frequent) cache-hit path as lightweight + # as possible. + entry_for_synth_output = _DispatchCacheValidEntry( + output_infos=(entry,), is_output_tuple=False + ) + from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode + + try: + synth_output = self._output_from_cache_entry( + state, entry_for_synth_output, key, func, args + ) + except GuardOnDataDependentSymNode: + # This should probably never really happen. If it does it means that + # although the original call didn't get a data-dependent error when + # we tried to reconstruct the output we did - that's almost + # certainly a bug. + raise _BypassDispatchCache("data dependent symnode") from None + + # Make sure the dispatch_key_set from the synthesized output tensor will + # be the same. 
+ synth_key_set = torch._C._dispatch_key_set(synth_output) + key_set = torch._C._dispatch_key_set(output) + if synth_key_set != key_set: + raise _BypassDispatchCache("dispatch_key_set mismatch") + + return entry + + def _make_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: Optional[FakeTensor], + ) -> _DispatchCacheValidEntry: + """ + Make a cache entry object for the given 'output' Tensor. Raises + _BypassDispatchCache if the output tensor has characteristics that + prevent caching it. + """ + from torch._higher_order_ops.utils import registered_hop_fake_fns + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + + # For hops, lets look at the output tensor to find any unbacked symints. + # If there are none, then we rely on the existing checks to validate + # caching. + # NB: Note that the HOPs that sta alive till FakeTensor are functional, + # once they support mutations, we will have to revisit this logic. + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func in registered_hop_fake_fns + ): + assert isinstance(output, tuple) + non_cacheable = any( + isinstance(o, (torch.Tensor, torch.SymInt)) + and has_free_unbacked_symbols(o) + for o in output + ) + if non_cacheable: + raise _BypassDispatchCache(f"unbacked symbol in HOP {func} output") + + if isinstance(output, (int, torch.SymInt, type(None))): + output_info = _DispatchCacheEntryOutputInfo( + inplace_idx=None, metadata=None, view_idx=None, constant_value=output + ) + return _DispatchCacheValidEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + if isinstance(output, tuple): + for out_element in output: + self._validate_output_for_cache_entry( + state, key, func, args, kwargs, out_element + ) + else: + self._validate_output_for_cache_entry( + state, key, func, args, kwargs, output + ) + + if isinstance(output, tuple): + output_infos = [ + self._get_output_info_for_cache_entry( + state, key, func, args, kwargs, out_elem + ) + for out_elem in output + ] + return _DispatchCacheValidEntry( + output_infos=tuple(output_infos), is_output_tuple=True + ) + + else: + output_info = self._get_output_info_for_cache_entry( + state, key, func, args, kwargs, output + ) + return _DispatchCacheValidEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + def _get_output_tensor_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheEntryOutputInfo, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Optional[FakeTensor]: + if ( + entry.inplace_idx is None + and entry.metadata is None + and entry.view_idx is None + ): + assert entry.constant_value is not SingletonConstant + return entry.constant_value + if entry.inplace_idx is not None: + # This is an in-place op; return the aliased arg. + inplace_arg = args[entry.inplace_idx] + assert isinstance(inplace_arg, FakeTensor) + return inplace_arg + + # Synthesize a new FakeTensor with the cached metadata. 
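# Roughly, "synthesize" here means materializing an empty meta-device tensor with the
# recorded sizes and strides and then wrapping it. A minimal standalone sketch
# (illustrative only) of that step:
import torch

shape, stride, dtype = (2, 3), (3, 1), torch.float32
meta_t = torch.empty_strided(shape, stride, dtype=dtype, device="meta")
assert meta_t.is_meta and meta_t.shape == (2, 3) and meta_t.stride() == (3, 1)
# The code below additionally restores the conj/neg bits and the storage offset and, for
# view ops, re-attaches the input's untyped storage before wrapping the result in FakeTensor.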
+ metadata = entry.metadata + if metadata is None: + return None + + assert not is_sparse_any(metadata) + + def check_value( + value: _MetadataIntLike, state: _CacheKeyState + ) -> Union[IntLikeType]: + if isinstance(value, _SymIntOutputStub): + assert state.shape_env is not None + return value.extract(key, state.shape_env) + else: + assert not isinstance(value, _PySymInputStub) + return value + + shape = tuple(check_value(v, state) for v in metadata.shape) + stride = tuple(check_value(v, state) for v in metadata.stride) + storage_offset = check_value(metadata.storage_offset, state) + if metadata.storage_bytes is not None: + check_value(metadata.storage_bytes, state) + + maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext + if self.shape_env is not None: + maybe_suppress = self.shape_env.suppress_guards + + with in_kernel_invocation_manager(self), maybe_suppress(): + empty = torch.empty_strided( + shape, + stride, + dtype=metadata.dtype, + layout=metadata.layout, + device="meta", + requires_grad=metadata.requires_grad, + ) + + if metadata.is_conj: + torch._C._set_conj(empty, True) + if metadata.is_neg: + torch._C._set_neg(empty, True) + + if isinstance(func, torch._ops.OpOverload) and func.is_view: + # For view ops, the storage should be the same as the tensor input. + view_arg = args[cast(int, entry.view_idx)] + assert isinstance(view_arg, FakeTensor) + storage = view_arg.untyped_storage() + with in_kernel_invocation_manager(self), maybe_suppress(): + empty.set_(storage, storage_offset, shape, stride) + + return FakeTensor(self, empty, metadata.device) + + def _output_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheValidEntry, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]]: + """ + Create a new FakeTensor from the cache entry. + """ + + if entry.is_output_tuple: + outputs = [ + self._get_output_tensor_from_cache_entry( + state, output_info, key, func, args + ) + for output_info in entry.output_infos + ] + return tuple(outputs) + else: + return self._get_output_tensor_from_cache_entry( + state, entry.output_infos[0], key, func, args + ) + + def _crosscheck_cache_output( + self, + output: Union[Optional[FakeTensor], tuple[Optional[FakeTensor], ...]], + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + """ + Helper to validate that the output synthesized from the cache matches + the output created by normal dispatch. 
+ """ + + def assert_helper(a: Any, b: Any) -> None: + if isinstance(a, tuple): + assert isinstance(b, tuple) + assert len(a) == len(b) + for l, r in zip(a, b): + assert_helper(l, r) + elif isinstance(a, int): + assert isinstance(b, int) and a == b + elif a is None: + assert b is None + elif isinstance(a, py_sym_types): + assert type(a) == type(b) and a.node is b.node + elif isinstance(a, torch.Tensor): + assert isinstance(b, torch.Tensor) + assert_metadata_eq(assert_eq, a, b) + else: + raise RuntimeError(f"Unsupported type {type(a)}") + + try: + true_output = self._dispatch_impl(func, types, args, kwargs) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}: Dispatch raised={e}" + ) from e + try: + assert_helper(true_output, output) + except Exception as e: + raise RuntimeError( + f"FakeTensor cache crosscheck failure: func={func}, " + f"args={args}, kwargs={kwargs}" + ) from e + + def dispatch( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Mapping[str, object] = immutable_dict(), + ) -> object: + kwargs = kwargs or {} + with no_dispatch(): + log.debug("%s %s %s", func, args, kwargs) + + if func in _DISPATCH_META_HANDLERS: + return _DISPATCH_META_HANDLERS[func](args) + + if log.getEffectiveLevel() <= logging.DEBUG: + log.debug( + "%sFakeTensorMode.__torch_dispatch__: %s", " " * RECURSION_COUNT, func + ) + # NOTE: incr is intentionally unused for a RAII pattern + incr = IncrementRecursionCount() # noqa: F841 + + # Some attribute queries that can be serviced directly + # See Note [is_coalesced is dispatched] + if func in _DISPATCH_HANDLE_DIRECTLY: + # NB: no_dispatch is ok here too, this func is very simple + with in_kernel_invocation_manager(self): + return func(*args, **kwargs) + + if self.cache_enabled: + return self._cached_dispatch_impl(func, types, args, kwargs) + else: + return self._dispatch_impl(func, types, args, kwargs) + + def _maybe_infer_fake( + self, func: OpOverload, path: KeyPath, fake: object, real: object + ) -> tuple[Optional[object], bool]: + """ + Helper to cross-check fake/real output properties & values, + and create new fake vals if mismatched. 
+ Returns tuple of object & boolean, for whether or not it was overwrriten + """ + import sympy + + from torch._subclasses.fake_utils import _check_fake_real_tensors + + def _check_fake_real_vals(fake: Any, real: Any) -> None: + # use real values + ShapeEnv to check mismatches between potentially symbolic values + if isinstance(fake, (SymInt, SymFloat)): + # symbolic expression, ask ShapeEnv to substitute known backed/unbacked values + assert self.shape_env is not None + if ( + not fake.node.expr.free_symbols + - self.shape_env.var_to_val.keys() + - self.shape_env.unbacked_var_to_val.keys() + ): + if ( + self.shape_env._maybe_evaluate_static( + sympy.Eq(fake.node.expr, real), compute_hint=True + ) + is not sympy.S.true + ): + raise MetadataMismatchError( + f"mismatch between fake value {fake} and real value {real} " + ) + elif isinstance( + fake, (int, float, bool) + ): # concrete value, check direct equality + if fake != real: + raise MetadataMismatchError( + f"mismatch between fake value {fake} and real value {real} " + ) + + if isinstance(fake, torch.Tensor): + try: + _check_fake_real_tensors( + real, # type: ignore[arg-type] + fake, # type: ignore[arg-type] + context="Real tensor propagation found", + sizes=False, # manual check below + strides=False, # skip strides + storage_offset=True, + requires_grad=False, # issues with FakeTensorConverter preserving requires_grad + ) + except MetadataMismatchError as exc: + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": exc.reason, # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise MetadataMismatchError( + f"Real tensor propagation found a metadata mismatch between " + f"fake tensor {fake} and real tensor {real}, " + f" at output{keystr(path)}, for func: {func}" + ) from exc + + for j, (s_fake, s_real) in enumerate(zip(fake.size(), real.size())): # type: ignore[attr-defined] + try: + _check_fake_real_vals(s_fake, s_real) + except MetadataMismatchError as exc: + if ( + torch._functorch.config.generate_fake_kernels_from_real_mismatches + ): + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": exc.reason, # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise MetadataMismatchError( + f"Real tensor propagation found an output size mismatch between " + f"fake shape {s_fake} and real shape {s_real}, " + f"at output{keystr(path)}.size({j}), for func: {func}" + ) from exc + elif fake is None and real is not None: + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": f"mismatch between fake value {fake} and real value {real}", # noqa: F821 + }, + ) + return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] + raise MetadataMismatchError( + f"Real tensor propagation found a metadata mismatch between " + f"fake tensor {fake} and real tensor {real}, " + f" at output{keystr(path)}, for func: {func}" + ) + else: + try: + _check_fake_real_vals(fake, real) + except MetadataMismatchError as exc: + raise MetadataMismatchError( + f"Real tensor propagation found an output value mismatch between " + f"fake output value {fake} and real output value {real}, " + f"at output{keystr(path)}, for func: {func}" + ) from exc + 
return fake, False + + def _maybe_infer_fake_kernel_from_pytree_out( + self, + func: OpOverload, + fake_in: object, + real_in: object, + fake_out: object, + real_out: object, + ) -> Optional[object]: + """ + Helper to cross-check fake/real output properties & values, + and create new fake vals if mismatched, but at the kernel level. + Means this handles pytree outputs & checks aliasing. + """ + from torch._subclasses.fake_utils import _check_alias_info + + # we might have to clear pending unbacked symbols, if we override the kernel + pending_unbacked = None + if self.shape_env: + pending_unbacked = list(self.shape_env.pending_fresh_unbacked_symbols) + + def _clear_pending_unbacked() -> None: + self.shape_env.pending_fresh_unbacked_symbols = list( # type: ignore[union-attr] + set(self.shape_env.pending_fresh_unbacked_symbols).difference( # type: ignore[union-attr] + pending_unbacked # type: ignore[arg-type] + ) + ) + + fake_paths_leaves, fake_spec = pytree.tree_flatten_with_path(fake_out) + real_leaves, _ = pytree.tree_flatten(real_out) + try: + # catch aliasing mismatches between fake/real tensors + _check_alias_info( + "Real tensor propagation found", real_out, real_in, fake_out, fake_in + ) + except MetadataMismatchError as exc: + # if mismatch found, optionally infer fake kernel + if torch._functorch.config.generate_fake_kernels_from_real_mismatches: + dtrace_structured( + "mismatched_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + "reason": ( + f"Mismatched aliasing spec between fake kernel and real kernel: {exc.reason}" # noqa: F821 + ), + }, + ) + # if aliasing mismatches are found, it's likely that the fake tensor impl + # is incorrectly aliasing, since we don't support aliasing custom ops. + # in this case we can default to inferring non-aliasing fake kernels from the real outputs. + _clear_pending_unbacked() + return tree_map( + lambda x: _infer_fake_from_real_tensor(self, func, x), real_out + ) + else: + raise MetadataMismatchError( + f"Real tensor propagation found an aliasing mismatch between " + f"fake output {fake_out} and real output {real_out}, " + f" for func: {func}" + ) from exc + + # if no errors raised, run cross checks on fake/real tensors, + # optionally overriding individual fake tensors, if individual meta kernel output is incorrect. + fake_leaves, overrides = zip( + *[ + self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out) + for (_fake_path, _fake_out), _real_out in zip( + fake_paths_leaves, real_leaves + ) + ] + ) + if ( + any(overrides) and pending_unbacked + ): # only keep new pending unbacked symbols + _clear_pending_unbacked() + return pytree.tree_unflatten(fake_leaves, fake_spec) + + def _dispatch_impl( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> Optional[FakeTensor]: + from torch._higher_order_ops.utils import registered_hop_fake_fns + + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + + # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING + # We must throw NotImplemented in case of unrecognized types to handle subclasses. + # Throwing the exception will pass the control to the next __torch_dispatch__. + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. 
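# A simplified standalone version of the subclass check performed next (the real helper,
# defined near the bottom of this file, additionally excludes FakeTensor itself):
import torch

def looks_like_exotic_subclass(x: object) -> bool:
    return (
        isinstance(x, torch.Tensor)
        and type(x) is not torch.Tensor
        and type(x) is not torch.nn.Parameter
    )

assert not looks_like_exotic_subclass(torch.zeros(2))
assert not looks_like_exotic_subclass(torch.nn.Parameter(torch.zeros(2)))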
+ has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] + has_symbolic_sizes = any( + i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors + ) or any(isinstance(a, SymInt) for a in flat_args) + + converter = self.fake_tensor_converter + + is_lift_func = func in self.lift_fns + device_conversion_skip_const_prop = ( + func is torch.ops.aten._to_copy.default + and isinstance(args[0], torch.Tensor) + and args[0].device.type == "meta" + ) + + # To constant propagate through these functions: + # 1, If this is a lift due to a torch.tensor call, + # the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point. + # (Note that you can always call a lift fn manually, so we do + # have to check if there are any fake tensors!) + # 2, Some functions that allow Python numbers to bind to Tensors, e.g, torch.div + if (is_lift_func and not flat_arg_fake_tensors) or ( + should_allow_numbers_as_tensors(func) + and not has_symbolic_sizes + and not flat_arg_fake_tensors + and not device_conversion_skip_const_prop + ): + assert all( + t.constant is not None for t in flat_arg_fake_tensors + ), f"{func} should not have fake inputs without constants" + const_flat_args = [ + a.constant if self.is_our_fake(a) else a for a in flat_args + ] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + out = func(*const_args, **const_kwargs) + if type(out) is Tensor and self.may_turn_const(out): + # NB: not in_kernel_invocation_manager because we're doing real + # compute here + # NB: no_dispatch() here is VERY DANGEROUS (like, segfault + # dangerous) if this is actually a wrapper subclass tensor, + # therefore the exact type test above + with no_dispatch(): + out = out.clone() + return converter.from_real_tensor(self, out, make_constant=True) + + # if we are in the dispatch mode, we will enter this function even if the inputs + # are not FakeTensors. For now, throw if any non-Fake Tensor inputs + # and just support constructors. + + # this is generated from torch.tensor(), which does not use the + # dispatcher, to allow wrapper subclasses to wrap the new tensor + if is_lift_func: + assert len(kwargs) == 0 and len(args) == 1, f"{args} {kwargs}" + + if type(args[0]) is Tensor: + return converter.from_real_tensor(self, args[0]) + + # If we are trying to avoid device init, then we need to avoid constant + # prop on constant tensors for ops that change devices. 
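# For intuition, a standalone sketch (illustrative only) of the constant propagation
# described above: tiny literal tensors keep a real copy around so data-dependent queries
# such as item() still work, while ordinary ops stay metadata-only:
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.tensor(3.0)        # 0-d literal: a real copy is retained as its constant
    y = torch.empty(4, 4) + 1.0  # ordinary fake compute: shapes/dtypes only, no real data
    print(x.item())              # 3.0, served from the retained constant
    print(y.shape)               # torch.Size([4, 4])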
+ avoiding_device_init = False + if self.avoid_device_init: + if ( + func == torch.ops.aten._to_copy.default + and "device" in kwargs + and kwargs["device"] != "cpu" + ): + avoiding_device_init = True + if func == torch.ops.prims.device_put.default: + avoiding_device_init = True + + # Recompute flat_arg_fake_tensors here again in case some of the inputs + # were real tensors and fakified in validate_and_convert_non_fake_tensors + (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors( + func, converter, flat_args, args_spec + ) + del args, kwargs # Invalidated + + # The current constant handling only support tracing systems + # (aot autograd, torchdynamo) where each operation is run consecutively. + # Because each operation is run in order, we can trace out and support + # sequences like: x = torch.tensor(0.); y = x.add_(1) + # Whenver a constant is written to but with inputs that cannot be evaluated + # statically, such as random_(), we invalidate all constants that alias the input + # We will rely on functionalization for use of fake tensors constants as persistent + # objects on an FX Graph. + + # We dispatch size/stride/numel on the FakeTensor not its constant, so bail on inplace_view + all_constant = all(e.constant is not None for e in flat_arg_fake_tensors) + if ( + isinstance(func, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded not in func.tags + and torch.Tag.inplace_view not in func.tags + and all_constant + and len(flat_arg_fake_tensors) != 0 + and not has_symbolic_sizes + and not avoiding_device_init + and func is not aten._nested_tensor_from_tensor_list.default + ): + const_flat_args = [ + a.constant if self.is_our_fake(a) else a for a in flat_args + ] + const_args, const_kwargs = pytree.tree_unflatten(const_flat_args, args_spec) + + # NB: not in_kernel_invocation_manager(self) as we want to do REAL + # compute + with no_dispatch(): + out = func(*const_args, **const_kwargs) + + flat_out = pytree.tree_leaves(out) + flat_out_tensors = [t for t in flat_out if isinstance(t, Tensor)] + all_constant = all(self.may_turn_const(t) for t in flat_out_tensors) + + if all_constant: + return pytree.tree_map_only( + Tensor, + lambda t: converter.from_real_tensor(self, t, make_constant=True), + out, + ) + + # we weren't able to turn outputs to constants, + # so invalidate all constants that might be aliases of the outputs + for ten in flat_out_tensors: + converter.invalidate_constant_aliases(ten) + + # we are falling through to running non constant tensors, any input constant that + # is written to must be invalidated + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + + if ( + isinstance(func, torch._ops.HigherOrderOperator) + and func in registered_hop_fake_fns + ): + # Reenable the fake tensor mode for the registered fake function + maybe_ignore_fresh_unbacked_symbols = ( + contextlib.nullcontext + if self.shape_env is None + else self.shape_env.ignore_fresh_unbacked_symbols + ) + + with self, maybe_ignore_fresh_unbacked_symbols(): + return registered_hop_fake_fns[func](*args, **kwargs) + + self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) + + def maybe_to_real_tensor( + t: T, + ) -> Optional[Union[T, Tensor, torch._C.ScriptObject]]: + if isinstance(t, FakeTensor): + return t.real_tensor + elif isinstance(t, py_sym_types): + assert self.shape_env is not None + return t.node.pytype( + t.node.expr.xreplace(self.shape_env.var_to_val).xreplace( + self.shape_env.unbacked_var_to_val + ) + ) + elif isinstance(t, 
FakeScriptObject): + return t.real_obj + else: + return t + + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + free_unbacked_symbols, + ) + + nil = object() + + real_out = nil + if ( + self.propagate_real_tensors + and all(e.real_tensor is not None for e in flat_arg_fake_tensors) + and not any( + ( + isinstance(a, py_sym_types) + and (syms := free_unbacked_symbols(a)) + and self.shape_env is not None + and any(s not in self.shape_env.unbacked_var_to_val for s in syms) + ) + for a in flat_args + ) + ): + log.debug("propagate_real_tensors %s", func) + real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] + real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) + + is_builtin = library_utils.is_builtin(func) + if not is_builtin: + mutation_checker = library_utils.MutationChecker( + func, real_flat_args, args_spec + ) + + try: + real_out = func(*real_args, **real_kwargs) + except ZeroDivisionError as exc: + # we shouldn't broadly catch all errors here; + # some come from real-kernel mutation/aliasing checks we want to run. + # add more exception types as needed. + log.debug( + "real-tensor fallback failed for %s: %s; silently ignoring", + func, + exc, + ) + + if not is_builtin: + mutation_checker.check() # type: ignore[possibly-undefined] + library_utils.check_aliasing_constraint(func._name, flat_args, real_out) + + elif self.propagate_real_tensors: + # This can happen occasionally legitimately, specifically when you + # are inside the meta of a data dependent operation and you create + # a tensor on an unbacked SymInt; at this point in time we don't + # know what the unbacked SymInt is, but we will know later. + # However, if there's a bug in the condition above, this condition + # will also trigger. 
+ log.debug( + "SKIPPED propagate_real_tensors %s(%s, %s) %s", + func, + flat_arg_fake_tensors, + flat_args, + self.shape_env.unbacked_var_to_val if self.shape_env else None, + ) + + def maybe_propagate_real_tensors(fake_out: T) -> T: + import sympy + + log.debug("maybe_propagate_real_tensors %s", func) + + def go(t: object, real_t: Tensor) -> None: + if isinstance(t, FakeTensor): + # NB: unconditionally overwrite + log.debug( + "maybe_propagate_real_tensors %s -> %s", id(t), id(real_t) + ) + t.real_tensor = real_t + for s, real_s in zip(t.size(), real_t.size()): + go(s, real_s) # type: ignore[arg-type] + for s, real_s in zip(t.stride(), real_t.stride()): + go(s, real_s) # type: ignore[arg-type] + go(t.storage_offset(), real_t.storage_offset()) # type: ignore[arg-type] + elif isinstance(t, py_sym_types) and free_unbacked_symbols(t): + if isinstance(t.node.expr, sympy.Symbol): + assert self.shape_env is not None + self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) + elif ( + isinstance(s := t.node.expr, sympy.Eq) + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + ): + assert self.shape_env is not None + self.shape_env.set_unbacked_var_to_val(s, int(real_t)) + + if real_out is not nil: + # cross check fake/real outputs, and optionally override fake kernel mismatches + if ( + not torch._functorch.config.generate_fake_kernels_from_real_mismatches + ): + self._maybe_infer_fake_kernel_from_pytree_out( + func, + (args, kwargs), + (real_args, real_kwargs), + fake_out, + real_out, + ) + else: + # this can override the output only when the flag is True + fake_out = self._maybe_infer_fake_kernel_from_pytree_out( # type: ignore[assignment] + func, + (args, kwargs), + (real_args, real_kwargs), + fake_out, + real_out, + ) + + # populate unbacked_var_to_val + if ( + not isinstance(fake_out, Tensor) + and not isinstance(real_out, Tensor) + and type(fake_out) != type(real_out) + ): + # This can happen when decompositions have different return types, + # e.g. namedtuple vs. tuple vs. list. + tree_map_( + go, + tuple(pytree.tree_flatten(fake_out)), + tuple(pytree.tree_flatten(real_out)), + ) + else: + tree_map_(go, fake_out, real_out) + + # If a data-dependent op is used in a decomposition, we + # may need to get the unbacked settings "early" + # TODO: Is this really needed? 
+ compute_unbacked_bindings(self.shape_env, fake_out, peek=True) + + return fake_out + + # Try for fastpath + if has_symbolic_sizes: + fast_impl = get_fast_op_impls().get(func) + if fast_impl is not None: + return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs)) + + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table + + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table + + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not is_sparse_any(e) for e in flat_arg_fake_tensors) + ) + ): + with self: + return maybe_propagate_real_tensors( + decomposition_table[func](*args, **kwargs) + ) + + with self: + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return maybe_propagate_real_tensors(r) + + # prims already wrap FakeTensor inputs to FakeTensor outputs + # and do device logic, we dont need do anything but run them + # and ensure that Meta kernels are dispatched to (see) + # Fake Tensor Dispatch Keys + # TODO - we should be use the prim aten impl + # TODO - fix prims complex ops + if ( + "prims::" in func._schema.name + and hasattr(func, "prim_meta_impl") + and not stride_incorrect_op(func) + ): + with self: + return maybe_propagate_real_tensors( + func.prim_meta_impl(*args, **kwargs) + ) + + profiles = torch._dynamo.config._custom_ops_profile + if profiles is not None: + if func in profiles.data: + return profiles.generic_fake_kernel(func, self, *args, **kwargs) + + if ( + self.propagate_real_tensors + and real_out is not nil + and not library_utils.is_builtin(func) + and self.shape_env is not None + ): + # Automatically infer a Fake kernel if there isn't one. + if not library_utils.has_fake_kernel(func): + result = inferred_fake_kernel_from_real_out(self, func, real_out) + + dtrace_structured( + "missing_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + }, + ) + return maybe_propagate_real_tensors(result) + + # Users can register FakeTensor rules for custom operators + # Call them if they exist. 
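# A standalone sketch of such a registration via the public torch.library custom-op API
# (the "mylib::pad_to" operator itself is hypothetical):
import torch

@torch.library.custom_op("mylib::pad_to", mutates_args=())
def pad_to(x: torch.Tensor, size: int) -> torch.Tensor:
    out = x.new_zeros(size)
    out[: x.numel()] = x.flatten()
    return out

@pad_to.register_fake
def _(x: torch.Tensor, size: int) -> torch.Tensor:
    # Only metadata may be produced here; this is the rule the lookup below will find.
    return x.new_empty(size)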
+ maybe_fake_impl = torch._library.simple_registry.singleton.find( + func.name() + ).fake_impl.kernel + if maybe_fake_impl: + try: + ctx = torch._library.fake_impl.FakeImplCtx(self, func) + with torch._library.fake_impl.set_ctx_getter(lambda: ctx), self: + result = maybe_fake_impl(*args, **kwargs) + return maybe_propagate_real_tensors(result) + + except MissingOpProfile as e: + # If we have a fake kernel registered generated from OpProfiles + # but there doesn't exist a profile for the existing inputs, and we are in + if ( + self.propagate_real_tensors + and real_out is not nil + and not library_utils.is_builtin(func) + and self.shape_env is not None + ): + result = inferred_fake_kernel_from_real_out(self, func, real_out) + + dtrace_structured( + "missing_fake_kernel", + metadata_fn=lambda: { + "op": str(func), + }, + ) + return maybe_propagate_real_tensors(result) + else: + raise e + + # special handling for funcs registered through `register_op_impl`, + # e.g., manipulating args on constructor calls to construct meta tensors + # and then afterwards wrapping them to a FakeTensor + for run_impl_check, op_impl in op_implementations_checks: + if run_impl_check(func): + op_impl_out = op_impl(self, func, *args, **kwargs) + if op_impl_out is not NotImplemented: + return maybe_propagate_real_tensors(op_impl_out) + + def maybe_run_unsafe_fallback( + error: Optional[RuntimeError] = None, + ) -> Optional[FakeTensor]: + # We infer the meta of a custom ops that return None to just + # return None. custom ops are not allowed to mutate metadata + # of their inputs, so this is safe. + if torch._library.utils.can_generate_trivial_fake_impl(func): + return None + # no meta kernel registered, fallback to kernel for the device + if has_symbolic_sizes or not self.can_run_unsafe_fallback(func): + raise UnsupportedOperatorException(func) + if error is None: + error = UnsupportedOperatorException(func) + return run_fallback_kernel(self, func, flat_args, args_spec, error) + + # Optimization: If there is no Meta kernel, it takes a surprisingly long + # amount of time to catch the NotImplementedError, so we check it here. + if not has_meta(func): + fallback = maybe_run_unsafe_fallback() + return maybe_propagate_real_tensors(fallback) + + # run kernel registered to meta for func, which include + # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) + # It's possible that the kernel will return NotImplementedError + try: + with in_kernel_invocation_manager(self): + r = func(*args, **kwargs) + except NotImplementedError as not_implemented_error: + return maybe_run_unsafe_fallback(not_implemented_error) + except Exception: + log.exception("failed while attempting to run meta for %s", func) + raise + + return maybe_propagate_real_tensors( + self.wrap_meta_outputs_with_default_device_logic( + r, func, flat_args, device=kwargs.get("device") + ) + ) + + # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators + # outside of the pytorch/pytorch library! Any pre-existing things here + # are either in the pytorch/pytorch library or have been grandfathered in. + # The fallback does not always work and MAY CRASH and emit unreadable error messages + # so it should not be allowed by default. 
+ _can_run_unsafe_fallback_allowed_namespaces = ordered_set( + "debugprims", + "prims", + "aten", + "xla", + "vision", + "torchtext", + "torchaudio", + "quantized", + ) + + def can_run_unsafe_fallback(self, func: OpOverload) -> bool: + if not self.allow_fallback_kernels: + return False + # It's OK to try the fallback for built-in ops (e.g. aten, prims) + # because we control and test these but the fallback leads to unexpected behavior + # in user-defined custom ops + return ( + func.namespace in self._can_run_unsafe_fallback_allowed_namespaces + or func.name() == "fbgemm::gmm" + ) + + def validate_and_convert_non_fake_tensors( + self, + func: OpOverload, + converter: FakeTensorConverter, + flat_args: Sequence[object], + args_spec: TreeSpec, + ) -> tuple[list[object], list[FakeTensor]]: + """ + Checks if the list of tensors are fake tensors. + If not, try to convert them to fake tensors. + Returns the original args, kwargs, and a flattened list of (args, kwargs) that are fake tensors. + """ + flat_arg_fake_tensors: list[FakeTensor] = [] + + def validate(x: T) -> Union[T, FakeTensor]: + if not isinstance(x, Tensor): + return x + + nonlocal flat_arg_fake_tensors + if not self.is_our_fake(x): + if hasattr(func, "tags") and torch.Tag.inplace_view in func.tags: + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise AssertionError( + f"Can't call metadata mutating ops on non-Fake Tensor inputs. Found in {render_call(func, args, kwargs)}" + ) + allow_non_fake_inputs = ( + self.allow_non_fake_inputs + if fake_tensor_tls.allow_non_fake_inputs_override is None + else fake_tensor_tls.allow_non_fake_inputs_override + ) + if not allow_non_fake_inputs: + if isinstance(x, FakeTensor) and x.fake_mode is not self: + raise AssertionError("Mixing fake modes NYI") + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + raise AssertionError( + f"Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode " + f"with 'allow_non_fake_inputs'. Found in {render_call(func, args, kwargs)}" + ) + + out = converter.from_real_tensor(self, x) + else: + out = x + + flat_arg_fake_tensors.append(out) + return out + + validated_args = [validate(a) for a in flat_args] + return validated_args, flat_arg_fake_tensors + + def wrap_meta_outputs_with_default_device_logic( + self, + r: object, + func: OpOverload, + flat_args: Sequence[object], + device: torch.device, + ) -> PyTree: + converter = self.fake_tensor_converter + + # Lazily initialized, in case there are no tensor returns + common_device = None + has_scalar_only_inputs = False + + def wrap(e: T) -> Union[T, FakeTensor]: + nonlocal common_device + nonlocal has_scalar_only_inputs + + if not isinstance(e, Tensor): + return e + + if common_device is None: + ( + common_device, + has_scalar_only_inputs, + ) = FakeTensor._find_common_device(func, flat_args) + + is_our_fake = self.is_our_fake(e) + if is_our_fake: + torch._check( + e.device == common_device, + lambda: f"FakeTensor is wrapped to wrong device, found {e.device}, expected {common_device}", + ) + return cast(T, e) + elif converter is not None: + if has_scalar_only_inputs: + # Under FakeTensorMode, op accepts scalar only inputs, such as aten.add/sub/mul/div, + # returns a real scalar tensor on CPU. See TensorMeta() in _prims/__init__.py for details. + # We thus directly convert real tensor to fake tensor. 
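# Concretely (standalone sketch, illustrative only): an op invoked with only Python
# numbers under the mode produces a CPU fake tensor, which is the case handled here.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    out = torch.ops.aten.add.Tensor(1.0, 2.0)  # Python numbers bound to Tensor arguments
    print(out.device)  # cpu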
+ return converter.from_real_tensor(self, e) + else: + return converter.from_meta_and_device( + self, e, device or common_device + ) + else: + return e + + return tree_map(wrap, r) + + def create_symbolic_nested_int( + self, *, nt_tensor_id: Optional[int] = None + ) -> torch.SymInt: + # See Note: [Creating symbolic nested int] + # Returned nested int always has coeff=1; multiply the result by coeff if needed + import torch.nested._internal.nested_tensor + from torch.nested._internal.nested_int import NestedIntNode + + if nt_tensor_id is None: + nt_tensor_id = self.nt_tensor_id_counter + assert self.enter_stack, "should only called while FakeTensorMode is active" + self.nt_tensor_id_counter += 1 + hint = torch.SymInt(NestedIntNode(nt_tensor_id, 1)) + + src = torch._dynamo.source.EphemeralSource("intermediate_offsets_or_lengths") + assert self.shape_env is not None + ret = self.shape_env.create_symintnode( + sym=self.shape_env.create_symbol( + val=hint, + source=src, + ), + hint=hint, + source=src, + ) + return ret + + _cpp_meta_supports_symint = ordered_set( + aten.empty.memory_format, + aten.empty_strided.default, + aten.as_strided_scatter.default, + aten.as_strided.default, + aten.as_strided_.default, + aten.zeros.default, + aten.detach.default, + aten.view_as_real.default, + aten.view_as_complex.default, + aten.set_.source_Storage_storage_offset, + aten._sparse_coo_tensor_with_dims_and_tensors.default, + ) + + def cpp_meta_supports_symint(self, func: OpOverload) -> bool: + if torch.Tag.view_copy in func.tags: + return True + return func in self._cpp_meta_supports_symint + + lift_fns = ordered_set(aten.lift_fresh.default, aten.lift_fresh_copy.default) + + def may_turn_const(self, t: Tensor) -> bool: + return ( + t.numel() <= CONSTANT_NUMEL_LIMIT + and not is_sparse_any(t) + and not self.is_our_fake(t) + and not t.device.type == "meta" + ) + + def invalidate_written_to_constants( + self, + func: OpOverload, + flat_arg_fake_tensors: Sequence[FakeTensor], + args: Sequence[object], + kwargs: Mapping[str, object], + ) -> None: + any_constant = any(e.constant is not None for e in flat_arg_fake_tensors) + schema_info = get_schema_info(func) + if any_constant and schema_info.is_mutable(): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + ) + for k, v in new_kwargs.items(): + k = k if (k != "input" or schema_info.has_argument(k)) else "self" + if ( + self.is_our_fake(v) + and schema_info.is_mutable(k) + and v.constant is not None + ): + self.fake_tensor_converter.invalidate_constant_aliases(v.constant) + + def from_tensor( + self, + tensor: Tensor, + *, + static_shapes: Optional[bool] = None, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + trace: bool = True, + ) -> FakeTensor: + shape_env: Optional[ShapeEnv] = self.shape_env + if static_shapes is None: + static_shapes = self.static_shapes + if static_shapes: + assert ( + symbolic_context is None + ), "cannot set both static_shapes and symbolic_context" + shape_env = None + return self.fake_tensor_converter.from_real_tensor( + self, + tensor, + shape_env=shape_env, + source=source, + symbolic_context=symbolic_context, + trace=trace, + ) + + +_StoragePointer = object + + +def _has_unrepresented_symbols( + state: _CacheKeyState, output: Optional[FakeTensor] +) -> bool: + from torch.fx.experimental.symbolic_shapes import _iterate_exprs + + for s in _iterate_exprs(output): + for symbol in s.free_symbols: + 
if symbol not in state.known_symbols: + return True + + return False + + +# NB: returns fake tensors +def run_fallback_kernel( + fake_mode: FakeTensorMode, + func: OpOverload, + flat_args: Sequence[object], + args_spec: PyTree, + orig_not_implemented_exception: RuntimeError, +) -> FakeTensor: + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: + raise orig_not_implemented_exception + + inp_impls = {} + + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e: T) -> Union[T, Tensor]: + if fake_mode.is_our_fake(e): + out = torch.zeros_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + + r = func(*args, **kwargs) + + storages: set[_StoragePointer] = set() + + for e in flat_args: + if isinstance(e, Tensor): + if not is_sparse_any(e): + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e: T) -> Union[T, FakeTensor]: + if id(e) not in inp_impls and ( + isinstance(e, Tensor) + and not is_sparse_any(e) + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return fake_mode.fake_tensor_converter.from_real_tensor(fake_mode, e) + else: + return e + + return pytree.tree_map(map_out, r) + + +def _set_cache_key_for_shape_env( + cache: dict[_DispatchCacheKey, _DispatchCacheEntry], + key: _DispatchCacheKey, + entry: _DispatchCacheEntry, +) -> None: + key.strip_shape_env() + cache[key] = entry + + +def _set_cache_key( + cache: dict[_DispatchCacheKey, _DispatchCacheEntry], + key: _DispatchCacheKey, + entry: _DispatchCacheEntry, +) -> None: + cache[key] = entry + + +# Just for use to allow copying a module to fake tensors, +# does not apply elsewhere +class FakeCopyMode(TorchFunctionMode): + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.fake_mode = fake_mode + + def __torch_function__( + self, + func: OpOverload, + types: Sequence[type], + args: Sequence[object] = (), + kwargs: Optional[Mapping[str, object]] = None, + ) -> FakeTensor: + kwargs = kwargs if kwargs else {} + + # clone will get called in Parameter deepcopy + if func == torch._C.TensorBase.clone: + assert isinstance(args[0], Tensor) + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) + elif func == Tensor.__deepcopy__: + assert len(args) == 2 and len(kwargs) == 0 + tensor = cast(Tensor, args[0]) + memo = cast(dict[int, FakeTensor], args[1]) + + if id(tensor) in memo: + return memo[id(tensor)] + + out = self.fake_mode.from_tensor(tensor, static_shapes=True) + memo[id(tensor)] = out + return out + else: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + +def _device_handler(args: Sequence[object]) -> torch.device: + # NB: Don't use is_our_fake, just serve the fake information + # as is. Notice we don't use 'self'; we use args[0].fake_mode + # because they may not be the same. 
It would also be possible + # to return NotImplemented here, in which case the FakeTensor + # handler on args[0] would handle it, but we're being nice and + # short-circuiting quickly. + assert len(args) == 1 and isinstance(args[0], FakeTensor) + if args[0].fake_mode.in_kernel_invocation: + return torch.device("meta") + else: + return args[0].fake_device + + +# [subclass inputs] +# Suppose we enable fake tensor mode. This means that fake tensor +# mode will run first. But what if we do an operation that +# involves a tensor subclass that will desugar into normal tensor +# operations? Without returning NotImplemented, fake tensor mode will run first, +# decide that a conversion was made (since there was a non fake +# tensor argument), and report an error that converting non +# fake tensor is not supported. What we actually wanted to happen +# was to give the subclass a chance to figure out what it wants to +# before erroring out. Returning NotImplemented here allows this. +def _check_for_subclass(flat_args: Sequence[object]) -> bool: + return any(_check_for_subclass_arg(x) for x in flat_args) + + +def _check_for_subclass_arg(x: object) -> bool: + return ( + not isinstance(x, FakeTensor) + and isinstance(x, Tensor) + and type(x) is not Tensor + and type(x) is not torch.nn.Parameter + ) + + +_DISPATCH_META_HANDLERS = { + torch.ops.prim.device.default: _device_handler, + torch.ops.aten.size.default: lambda args: tuple( + int(s) for s in cast(Tensor, args[0]).size() + ), + torch.ops.aten.stride.default: lambda args: tuple( + int(s) for s in cast(Tensor, args[0]).stride() + ), + torch.ops.aten.storage_offset.default: lambda args: int( + cast(Tensor, args[0]).storage_offset() + ), +} + +_DISPATCH_HANDLE_DIRECTLY = ordered_set( + torch.ops.aten.is_coalesced.default, + torch.ops.aten.dense_dim.default, + torch.ops.aten.sparse_dim.default, + # _RecordFunction doesn't support __eq__ so make sure not to attempt to + # cache it. + torch.ops.profiler._record_function_exit._RecordFunction, +) + +from torch._subclasses.fake_impls import ( # noqa: F401 + _device_not_kwarg_ops, + _is_tensor_constructor, + _like_tensor_constructors, + contains_tensor_types, + get_fast_op_impls, + has_meta, + op_implementations_checks, + stride_incorrect_op, +) + + +def evict_fake_tensor_cache_key(key: _DispatchCacheKey) -> None: + if key in FakeTensorMode.cache: + FakeTensorMode.cache.pop(key) + + +@atexit.register +def dump_cache_stats() -> None: + log.info("FakeTensor cache stats:") + log.info(" cache_hits: %s", FakeTensorMode.cache_hits) + log.info(" cache_misses: %s", FakeTensorMode.cache_misses) + bypasses = FakeTensorMode.cache_bypasses + if bypasses: + log.info(" cache_bypasses:") + width = max(len(k) for k in bypasses) + for k, v in sorted(bypasses.items(), key=lambda i: -i[1]): + log.info(" %-*s %s", width + 1, f"{k}:", v) + + +def _infer_fake_from_real_tensor( + mode: FakeTensorMode, op: torch._ops.OpOverload, real_out: torch.Tensor +) -> torch.Tensor: + def unsupported(reason: str) -> None: + raise RuntimeError( + f"propagate_real_tensors: we cannot infer a Fake kernel " + f"(meta kernel) for operator {op._name} because {reason}. " + f"Please use torch.library.register_fake to add a Fake kernel." + ) + + if real_out.storage_offset() != 0: + unsupported( + f"a return has a non-zero storage offset {real_out.storage_offset()}" + ) + + # Since PT2 is rank specialized, there's no such thing as a symbolic + # output rank. So we can assume the fake tensor has the same number of + # dimensions as the real tensor output. 
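+    # For example, a real output of shape (3, 5) pins the fake output to rank 2,
+    # but each of those two sizes is still allocated as a fresh symbol below.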
+ # + # We shouldn't assume the Fake sizes/strides are exactly what we see on + # the real tensor output (perhaps we should give users a lever to toggle + # this). This is because there's a good amount of operators that return + # outputs with data-dependent output shape. + # So we infer the output sizes to all be unbacked symints + fake_shape = [ + torch._library.fake_impl.allocate_size(mode.shape_env) + for _ in range(real_out.dim()) + ] + + # We infer what the strides are. We had a couple of options for this: + # - assume the strides are computable from the sizes + # - use new fresh unbacked symints in the strides + # This doesn't work that well (PT2 doesn't support unbacked symint strides well) + # - use the real strides + # This can only be used if we assume the strides are static. + # We went with the first option. + fake_strides = [-1] * real_out.dim() + strides = [(s, idx) for idx, s in enumerate(real_out.stride())] + strides.sort(key=lambda x: (x[0], -x[1])) + expected = 1 + fake_stride = expected + for s, idx in strides: + if s != expected: + unsupported( + f"a return was not dense in memory (sizes {real_out.shape} strides {real_out.stride()})" + ) + fake_strides[idx] = fake_stride + expected = expected * real_out.shape[idx] + fake_stride = fake_stride * fake_shape[idx] + + with mode: + return torch.empty_strided( + fake_shape, + fake_strides, + device=real_out.device, + dtype=real_out.dtype, + layout=real_out.layout, + ) + + +def inferred_fake_kernel_from_real_out( + mode: FakeTensorMode, op: torch._ops.OpOverload, real_out: Any +) -> Any: + assert mode.shape_env is not None + + # Only support operators that have all Tensor outputs + # This is a general limitation on custom ops that we impose for PT2 + # to avoid baking non-symbolic float/int outputs into the graph. + real_flat_out, spec = pytree.tree_flatten(real_out) + if not all(isinstance(t, torch.Tensor) for t in real_flat_out): + raise RuntimeError( + f"propagate_real_tensors: we don't support operators that return " + f"non-Tensors. 
Got {op._schema}" + ) + + fake_flat_out = [_infer_fake_from_real_tensor(mode, op, t) for t in real_flat_out] + return pytree.tree_unflatten(fake_flat_out, spec) diff --git a/phivenv/Lib/site-packages/torch/_subclasses/fake_utils.py b/phivenv/Lib/site-packages/torch/_subclasses/fake_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53ec927219398b6505b0c686cd42166eec4fe5af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/fake_utils.py @@ -0,0 +1,304 @@ +# mypy: ignore-errors + +import functools +import warnings +from typing import Any, Callable, Union + +import torch +import torch.utils._pytree as pytree +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + MetadataMismatchError, + tree_flatten_only, + UnsupportedFakeTensorException, +) +from torch.utils._python_dispatch import TorchDispatchMode + + +aten = torch._ops.ops.aten + + +def outputs_alias_inputs(outputs, inputs): + input_storages = { + inp._typed_storage()._cdata + for inp in tree_flatten_only(torch.Tensor, inputs) + if torch._C._has_storage(inp) + } + return any( + torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages + for out in tree_flatten_only(torch.Tensor, outputs) + ) + + +def outputs_are_inputs(outputs, inputs): + input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)} + return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs)) + + +def output_alias_each_other(outputs): + storages = set() + for out in tree_flatten_only(torch.Tensor, outputs): + if not torch._C._has_storage(out): + continue + stor = out._typed_storage()._cdata + if stor in storages: + return True + storages.add(stor) + return False + + +def _check_alias_info(context, real_out, real_in, fake_out, fake_in): + r_aliasing = outputs_alias_inputs(real_out, real_in) + f_aliasing = outputs_alias_inputs(fake_out, fake_in) + if r_aliasing != f_aliasing: + raise MetadataMismatchError( + f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}" + ) + + r_identity_eq = outputs_are_inputs(real_out, real_in) + f_identity_eq = outputs_are_inputs(fake_out, fake_in) + if r_identity_eq != f_identity_eq: + raise MetadataMismatchError( + f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}" + ) + + r_output_alias_each_other = output_alias_each_other(real_out) + f_output_alias_each_other = output_alias_each_other(fake_out) + if r_output_alias_each_other != f_output_alias_each_other: + raise MetadataMismatchError( + f"{context} mismatch in outputs_alias_each_other check " + f"{f_output_alias_each_other} != {r_output_alias_each_other}" + ) + + +def is_sdpa_error(func, idx, e): + if ( + ( + func is aten._scaled_dot_product_flash_attention.default + or func is aten._flash_attention_forward.default + ) + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True + if ( + ( + func is aten._scaled_dot_product_efficient_attention.default + or func is aten._efficient_attention_forward.default + ) + and idx in (2, 3) + and "Devices" in repr(e) + ): + return True + if ( + func is aten._scaled_dot_product_cudnn_attention.default + and idx in (6, 7) + and "Devices" in repr(e) + ): + return True + return False + + +def try_convert_fake_to_real( + ten_list: list[Union[FakeTensor, Any]] +) -> list[Union[FakeTensor, torch.Tensor, Any]]: + """ + Attempt to convert fake tensors to a corresponding real tensor with the correct underlying storage by looking up + the 
FakeTensorMode meta to real storage mapping. On failure to find the storage mapping, the FakeTensor will + remain in the list. + + Note: this is not currently optimized (makes copies of the meta converter internal dictionaries) + """ + + fake_tensor = next( + (item for item in ten_list if isinstance(item, FakeTensor)), None + ) + if fake_tensor is None: + return ten_list + + fake_mode = fake_tensor.fake_mode + meta_converter = fake_mode.fake_tensor_converter.meta_converter + desc = meta_converter.describer + + storage_to_key = {v: k for k, v in meta_converter.storage_memo.items()} + key_to_real_storage = {v: k for k, v in desc.lookup_storage.items()} + out = [] + for t in ten_list: + if not isinstance(t, FakeTensor) or not t.layout == torch.strided: + out.append(t) + continue + + key = storage_to_key.get(t.untyped_storage()) + real_storage = None if key is None else key_to_real_storage.get(key) + if real_storage is None: + out.append(t) + continue + + unhinted = False + + def map_symint(s): + nonlocal unhinted + if not isinstance(s, torch.SymInt): + return s + unhinted = unhinted if not unhinted else s.node.has_hint() + return s.node.hint + + stor_offset = map_symint(t.storage_offset()) + size = [map_symint(s) for s in t.shape] + stride = [map_symint(s) for s in t.stride()] + + if unhinted: + out.append(t) + continue + + new_tensor = torch.empty( + [], + dtype=t.dtype, + device=t.device, + ) + new_tensor.set_( + real_storage, + storage_offset=stor_offset, + size=size, + stride=stride, + ) + out.append(new_tensor.clone()) + + return out + + +def _check_fake_real_tensors( + real_out: torch.Tensor, + fake_out: FakeTensor, + context="", + sizes=True, + strides=False, + storage_offset=True, + requires_grad=True, +): + if requires_grad: + if real_out.requires_grad != fake_out.requires_grad: + raise MetadataMismatchError( + f"{context} mismatched requires_grad-ness of outputs. 
" + f"This usually means that you have added autograd support " + f"for your operator at a dispatch key other than Autograd, " + f"which will lead to problems" + ) + + if torch._C._has_storage(real_out): + r_offset = real_out.storage_offset() + f_offset = fake_out.storage_offset() + if r_offset != f_offset: + raise MetadataMismatchError(f"{context} mismatched storage offset") + + torch._prims.utils.compare_tensor_meta( + real_out, + fake_out, + check_sizes=sizes, + check_strides=strides, + allow_rhs_unbacked=True, + ) + + +class CrossRefFakeMode(TorchDispatchMode): + def __init__( + self, + ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None, + *, + check_strides=True, + check_aliasing=True, + only_check_ops_with_meta=True, + ): + super().__init__() + self.ignore_op_fn = ( + ignore_op_fn if ignore_op_fn is not None else lambda fn: False + ) + self.check_strides = check_strides + self.check_aliasing = check_aliasing + self.only_check_ops_with_meta = only_check_ops_with_meta + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + fake_r = None + + # empty_like excluded for now due to sparse complex + # aten._to_dense.default this one is getting called with csc + if ( + func + not in ( + aten.lift_fresh.default, + aten.lift_fresh_copy.default, + aten.set_.source_Storage_storage_offset, + ) + and not self.ignore_op_fn(func) + and ( + not self.only_check_ops_with_meta + or torch._subclasses.fake_impls.has_meta(func) + ) + and torch.Tag.dynamic_output_shape not in func.tags + and torch.Tag.inplace_view not in func.tags + and torch.Tag.data_dependent_output not in func.tags + ): + # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + try: + # TODO: enable_python_dispatcher() here + with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode: + fake_args, fake_kwargs = pytree.tree_map_only( + torch.Tensor, + functools.partial(fake_mode.from_tensor, static_shapes=True), + (args, kwargs), + ) + with warnings.catch_warnings(): + fake_r = func(*fake_args, **fake_kwargs) + except UnsupportedFakeTensorException: + pass + + context = ( + f"When comparing the output of {func} on FakeTensor and concrete Tensors, " + f"found" + ) + r = func(*args, **kwargs) + if fake_r is not None: + r_flat = pytree.tree_leaves(r) + f_flat = pytree.tree_leaves(fake_r) + assert len(f_flat) == len( + r_flat + ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}" + + if self.check_aliasing: + _check_alias_info( + context, r, (args, kwargs), fake_r, (fake_args, fake_kwargs) + ) + + for idx, (r_out, f_out) in enumerate( + zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r)) + ): + r_is_ten = isinstance(r_out, torch.Tensor) + assert r_is_ten == isinstance( + f_out, torch.Tensor + ), f"{context} mismatched number of tensor outputs" + if r_is_ten: + try: + _check_fake_real_tensors( + r_out, + f_out, + sizes=True, + strides=self.check_strides, + storage_offset=True, + requires_grad=True, + ) + except Exception as e: + if is_sdpa_error(func, idx, e): + continue + error_message = ( + f"{context} mismatched tensor metadata: {e}" + if len(r_flat) == 1 + else f"{context} mismatched tensor metadata for output[{idx}]: {e}" + ) + raise MetadataMismatchError(error_message) from e + return r diff --git a/phivenv/Lib/site-packages/torch/_subclasses/functional_tensor.py b/phivenv/Lib/site-packages/torch/_subclasses/functional_tensor.py new file mode 100644 index 
0000000000000000000000000000000000000000..7016bde8fcbc2a790fb76bbc6606671bfd81b72c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/functional_tensor.py @@ -0,0 +1,781 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +import weakref +from abc import ABC, abstractmethod +from contextlib import AbstractContextManager +from typing import Any, Callable, Optional, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import _functionalization_reapply_views_tls as _reapply_views +from torch._ops import _get_dispatch_mode_pre_dispatch +from torch._subclasses.meta_utils import is_sparse_any +from torch.utils._python_dispatch import ( + _detect_infra_mode, + _disable_infra_mode, + return_and_correct_aliasing, + TorchDispatchMode, +) + + +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + + +# NOTE Some special handling for tensor conversion during export is needed. +# Normally, when tracing through the model with tensor.to(), the maybe-aliasing +# relationship between input and output tensors will be baked into the graph. +# For example, if we got a tensor with device cpu and call tensor.to("cpu"), +# it will become a no-op in the graph. For a whole graph capture, this is not +# sound so we need to do something different. Instead, in export we will try to +# preserve the tensor conversion by forcing a non-semantic-breaking aten::_to_copy +# operator to be traced in the graph, and subsequently banning mutations on all +# such converted tensors. +# In addition to patching .to() method call in functionalization, we will have to +# patch other similar methods like float() and cpu(), because they intentionally +# don't fall back to .to() methods, but have the same behavior as .to() according to +# pytorch document. https://pytorch.org/docs/stable/generated/torch.Tensor.float.html +# thus we simply force them to go through .to() call. +def _conversion_method_template(**extra_kwargs): + def _(self, *args, **kwargs): + return self.to(*args, **{**kwargs, **extra_kwargs}) + + return _ + + +class FunctionalTensor(torch.Tensor): + """ + Functional tensors represent tensors that will remove mutations + from a program. If you perform a mutable operation on a functional tensor, + it will re-dispatch to the functional variant of that operation. + + Historically, functionalization is implemented in C++ in the dispatcher. + This class is a lightweight python shim around the C++ functionalization logic. + + FunctionalTensor is required to be used with a corresponding + FunctionalTensormode active, because it relies + on using the mode for dispatch (which can properly handle factory functions). + """ + + elem: torch.Tensor + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + _mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + + # Note: The reason we add these extra keys to our FunctionalTensor subclass + # is to mirror the behavior of C++ functionalization (we can choose to change this + # later, as long as it doesn't break anything). + # FunctionalTensorWrapper copies **all** dispatch keys from the inner tensor + # to the wrapper, excluding functorch and python dispatch keys. + # Here I'm trying to re-use the keyset the functorch wrapper subclasses copy, + # except that they don't include ZeroTensor so I'm manually adding it in. 
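+    # (Concretely: the conjugate/negative bits and the ZeroTensor key are mirrored
+    # from the inner tensor onto the wrapper; see the longer note in __new__ below.)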
+ _extra_dispatch_keys = torch._C._additional_keys_to_prop_for_wrapper_tensors.add( + torch._C.DispatchKey.ZeroTensor + ) + + # These are all aten ops that correspond to metadata queries. + # We want FunctionalTensor to be able to handle them directly. + metadata_fns = [ + torch.ops.aten.is_contiguous.default, # type: ignore[has-type] + torch.ops.aten.is_contiguous.memory_format, # type: ignore[has-type] + torch.ops.aten.is_strides_like_format.default, # type: ignore[has-type] + torch.ops.aten.is_non_overlapping_and_dense.default, # type: ignore[has-type] + torch.ops.aten.size.default, # type: ignore[has-type] + torch.ops.aten.sym_size.default, # type: ignore[has-type] + torch.ops.aten.stride.default, # type: ignore[has-type] + torch.ops.aten.sym_stride.default, # type: ignore[has-type] + torch.ops.aten.storage_offset.default, # type: ignore[has-type] + torch.ops.aten.sym_storage_offset.default, # type: ignore[has-type] + torch.ops.aten.numel.default, # type: ignore[has-type] + torch.ops.aten.sym_numel.default, # type: ignore[has-type] + torch.ops.aten.dim.default, # type: ignore[has-type] + torch.ops.prim.device.default, # type: ignore[has-type] + ] + + # Used by auto_functionalize to determine base of tensors during inference mode. + _inference_mode_base: Optional["FunctionalTensor"] = None + + def __new__(cls, elem, mode): + assert torch._is_functional_tensor(elem) + + # In general, we'd like our functional tensor subclass to only be in charge of functionalization, + # and defer to the inner subclass for all other functionality. + # Example: If our inner tensor is a ZeroTensor, we would want to defer running the ZeroTensor fallback + # until after we redispatch to our inner ZeroTensor. + # However, there are a few keys that we need to mirror between the inner and outer tensors. + # Conjugate + # Negative + # Why? These keys are used to test metadata queries, like `.is_conj()` and `.is_neg()`. + # We **need** calls to is_conj() to return the same thing on the outer and inner tensors, + # Because user code / framework code that branches like so needs to do the same thing + # when it sees the outer FunctionalTensor: + # if (x.is_conj()) { + # return at::view_as_real(x.resolve_conj()); + # } else { + # return at::view_as_real(x); + # } + extra_dispatch_keys = ( + FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem) + ) + + out = torch.Tensor._make_wrapper_subclass( + # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. + # Calling the overload that has kwargs causes us to go down the first overload path, + # which will **always** specialize sizes. + # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 
+ cls, + elem.shape, # sizes + elem.stride() if not is_sparse_any(elem) else None, # strides + ( + elem.storage_offset() if not is_sparse_any(elem) else None + ), # storage_offset + None, # memory_format + elem.dtype, # dtype + elem.layout, # layout + elem.device, # device + False, # pin_memory + elem.requires_grad, # requires_grad + None, # dispatch_sizes_strides_policy + False, # dispatch_device + False, # dispatch_layout + extra_dispatch_keys, # _extra_dispatch_keys + ) + torch._C._set_throw_on_mutable_data_ptr(out) + out.elem = elem + + if ( + not mode.export + and torch.is_inference_mode_enabled() + and torch._inductor.config.enable_auto_functionalized_v2 + ): + if out.is_base_tensor(): + out._inference_mode_base = None + # This assumes that the FunctionalTensor.elem does not change its storage after this point. + # Otherwise this would be invalid. + mode._storage_to_base[out.elem.untyped_storage()] = out + else: + out._inference_mode_base = mode._storage_to_base[ + out.elem.untyped_storage() + ] + assert out._inference_mode_base is not None + return out + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[override] + unrecognized_types = [ + t + for t in types + if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor] + ] + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + if kwargs is None: + kwargs = {} + + # FunctionalTensor needs to plumb all metadata requests to the inner tensor. + # In theory we don't have to do this - but if we want to service metadata requests here, + # we need to carefully make sure all metadata is accurate (including metadata mutations) + if func in FunctionalTensor.metadata_fns: + # All metadata accesses should be plumbed to the inner tensor, that way we don't have to worry + # about the problem of keeping metadata in sync between the wrapper and inner tensor. + # This also alleviates us from having to manually handle metadata mutations on the wrapper. + assert len(kwargs) == 0 + if func in [ + torch.ops.aten.is_strides_like_format.default, + torch.ops.aten.is_contiguous.memory_format, + ]: + assert len(args) == 2 and isinstance(args[0], FunctionalTensor) + return func(torch._from_functional_tensor(args[0].elem), args[1]) + assert len(args) == 1 and isinstance(args[0], FunctionalTensor) + + return func(torch._from_functional_tensor(args[0].elem)) + # Originally I tried to implement my subclass without giving it a torch_dispatch, but I gave up: + # - _make_wrapper_subclass requires a __torch_dispatch__ + # - If we want to use _make_subclass(), we have a problem: the subclass will share a TensorImpl with the inner tensor, + # which is of type FunctionalTensorWrapper! We explicitly do not want our wrapper to be a FunctionalTensorWrapper. + # - If we use the default tensor.__new__(), we have another problem: it returns inner_tensor.alias(), + # which causes every subclass created above autograd to have autograd view metadata + # (in addition to also being a FunctionalTensorWrapper). + raise RuntimeError( + "Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()" + ) + + def __repr__(self) -> str: # type: ignore[override] + return f"FunctionalTensor({repr(self.elem)})" + + @staticmethod + def to_functional(x): + # We will do the wrapping for the user. 
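+        # Rough usage sketch (assuming a FunctionalTensorMode is already active on
+        # the dispatch mode stack):
+        #
+        #   ft = FunctionalTensor.to_functional(torch.ones(2))
+        #   # ft is a FunctionalTensor; mutable ops on it re-dispatch to their
+        #   # functional variants (see dispatch_functionalize below for the full
+        #   # to_fun/from_fun round trip).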
+ + assert not torch._is_functional_tensor(x) + # The only autograd metadata we care about on the FunctionalTensor is: + # - requires_grad (so autograd runs) + # - is_leaf (so that mutations on graph inputs that are not leaves are allowed by the autograd engine) + # this is handled by FunctionalTensor.to_functional + x_functional = torch._to_functional_tensor(x) + # Technically the FunctionalTensormode here is unnecessary, + # but it avoids spurious NotImplemented logs during `ProxyTorchDispatchMode` tracing. + # _mirror_autograd_meta_to queries tensor sizes, + # and otherwise the sym_size() call will go to the proxy mode before hitting + # FunctionalTensor.__torch_dispatch__ + + functional_mode = _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + assert functional_mode is not None + + with functional_mode: + torch._mirror_autograd_meta_to(x, x_functional) # type: ignore[attr-defined] + out = FunctionalTensor(x_functional, functional_mode) + torch._mirror_autograd_meta_to(x_functional, out) # type: ignore[attr-defined] + return out + + def from_functional(self): + torch._sync(self) + return torch._from_functional_tensor(self.elem) + + def is_base_tensor(self) -> bool: + return torch._is_functional_tensor_base(self.elem) + + def replace_(self, output) -> None: + torch._functionalize_replace(self.elem, output) + + def commit_update(self) -> None: + torch._functionalize_commit_update(self.elem) + + def sync(self) -> None: + torch._functionalize_sync(self.elem) + + def mark_mutation_hidden_from_autograd(self) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(self.elem) + + def tolist(self) -> Any: + if self.elem.dim() == 0: + return self.elem.item() + elif self.elem.dim() == 1: + return [elem.item() for elem in self.elem] + else: + return [elem.tolist() for elem in self.elem] + + def to(self, *args, **kwargs): + if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export: + torch.ops.aten._assert_tensor_metadata( + self, + dtype=self.dtype, + device=self.device, + layout=self.layout, + ) + return super().to(*args, **kwargs) + + def cuda(self, device=None, *args, **kwargs): + device = device or torch.cuda.current_device() + if len(args) > 0: + return self.to(device, *args, **kwargs) + else: + return self.to(device=device, **kwargs) + + char = _conversion_method_template(dtype=torch.int8) + cpu = _conversion_method_template(device=torch.device("cpu")) + bfloat16 = _conversion_method_template(dtype=torch.bfloat16) + byte = _conversion_method_template(dtype=torch.uint8) + double = _conversion_method_template(dtype=torch.float64) + float = _conversion_method_template(dtype=torch.float32) + bool = _conversion_method_template(dtype=torch.bool) + half = _conversion_method_template(dtype=torch.float16) + int = _conversion_method_template(dtype=torch.int32) + long = _conversion_method_template(dtype=torch.int64) + + # TODO(sparse-team): fixes #133174 but can we do without the relay? + def to_dense(self): # type: ignore[override] + return self.elem.to_dense() + + @property + def layout(self): # type: ignore[override] + return self.elem.layout + + def __bool__(self): + return bool(self.item()) + + +class FunctionalTensorMode(TorchDispatchMode): + def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False): + super().__init__() + self.export = export + self.is_on_stack = False + self.enter_stack = [] + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. 
+ self._mode_key = torch._C._TorchDispatchModeKey.FUNCTIONAL + self.pre_dispatch = pre_dispatch + # This will be turned off later for pre-dispatch functionalization + self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None # type: ignore[attr-defined] + # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep + # track of the ordering between side effectful operations. + self._tokens: dict[Any, torch.Tensor] = {} + + # Filled after forward tracing. + self._tokens_forward_output: dict[Any, torch.Tensor] = {} + + # Functionalization runs twice in AOTAutograd, once in + # `run_functionalized_fw_and_collect_metadata` to collect metadata to + # see which tensors need to be functionalized and discover how many + # tokens we need, and another time in `make_fx` which does the actual + # tracing to replace ops with their functional variants and handling + # side-effectful ops. In the second stage there should be no token + # discovery. This flag distinguishes between the two stages. + self._allow_token_discovery = _allow_token_discovery + + self._storage_to_base: weakref.WeakKeyDictionary[ + torch.storage.UntypedStorage, Optional[FunctionalTensor] + ] = weakref.WeakKeyDictionary() + + # No-op if FunctionalTensorMode is already in use + def __enter__(self): + def _get_prev_mode(): + if self._dispatch_key == torch._C.DispatchKey.PreDispatch: + return _get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + return torch._C._get_dispatch_mode( + torch._C._TorchDispatchModeKey.FUNCTIONAL + ) + + if _get_prev_mode() is None: + self.enter_stack.append(True) + return super().__enter__() + else: + self.enter_stack.append(False) + return self + + def __exit__(self, a, b, c): + is_on_stack = self.enter_stack.pop() + if is_on_stack: + super().__exit__(a, b, c) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + unrecognized_types = [ + t + for t in types + if not issubclass(t, torch._subclasses.FakeTensor) + and t not in [torch.Tensor, FunctionalTensor] + ] + + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + + def _can_decompose(func): + # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832 + # Never decompose dropout in export + if self.export and func == torch.ops.aten.dropout.default: + return False + + # We unconditionally decompose ops that are maybe aliasing or mutating ops + from torch._decomp import _should_decompose_because_unsafe_op + + if _should_decompose_because_unsafe_op(func): + return True + + # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops, + # because we must know statically of an op mutates or aliasing in order to functionalize it properly + # (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today. + # In theory, we could walk this back and avoid decomposing them later if we need to. + alias_info_present = any(arg.alias_info for arg in func._schema.arguments) + if alias_info_present or func._schema.is_mutable: + return True + + # If we are here, it means we are seeing functional composite op. + # For pre-dispatch IR, we don't want to decompose this op + # For post-dispatch IR, we do want to decompose this op. 
it is fine + # to decompose here even if you want to preserve a CIA in post-dispatch export + # because we already override decompose behaviour so it will do the + # right thing. + if self.export: + if self.pre_dispatch: + # If it is CIA custom op, we warn that we are assuming this op is indeed functional. + if func.namespace not in ["aten", "prim"] and func._can_decompose(): + warnings.warn( + f"At pre-dispatch tracing, we assume that any custom op marked with " + f"CompositeImplicitAutograd and have functional schema are safe to not decompose. " + f"Found {func} to be one such op." + ) + return False + return True + + # in normal torch.compile IR, we decompose functional composite ops + return True + + if ( + func not in FunctionalTensor.metadata_fns + and _can_decompose(func) + # Not all funcs from __torch_dispatch__ are actual dispatcher ops, + # e.g. prim.device + and torch._C._dispatch_has_kernel(func.name()) + ): + with self: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + def wrap(x): + # Only wrap our outputs in subclasses if the inner functionalization call + # also wrapped outputs into FunctionalTensorWrappers. + # When can this happen? e.g. `torch.div(2, 2)` + assert not isinstance(x, FunctionalTensor) + if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x): + return FunctionalTensor(x, self) + return x + + def unwrap(x): + return x.elem + + from torch._higher_order_ops.auto_functionalize import ( + can_auto_functionalize, + do_auto_functionalize, + do_auto_functionalize_v2, + ) + + if can_auto_functionalize( + func + ) and not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ): + import torch._inductor.config as inductor_config + + if self.export or not inductor_config.enable_auto_functionalized_v2: + return do_auto_functionalize(self, func, args, kwargs) + else: + return do_auto_functionalize_v2(self, func, args, kwargs) + + from torch._higher_order_ops.effects import handle_effects, has_effects + + if has_effects(func, args, kwargs): + assert not torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), torch._C.DispatchKey.Functionalize + ) + return handle_effects( + self._allow_token_discovery, self._tokens, func, args, kwargs + ) + + args_unwrapped, kwargs_unwrapped = pytree.tree_map_only( + FunctionalTensor, unwrap, (args, kwargs) + ) + + # Expectation: functionalization should not **already** be enabled above our mode. + # Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization + # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper. + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + include_to_set = ( + torch._C._dispatch_tls_local_include_set() + | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + exclude_to_set = ( + torch._C._dispatch_tls_local_exclude_set().remove( + torch._C.DispatchKey.Functionalize + ) + - FunctionalTensor._extra_dispatch_keys + ) + + # All we want to do here is re-use the existing C++ functionalization logic. + # This requires swizzling our TLS dispatch keys so that the Functionalize key is active. 
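+        # Inside the guard below, Functionalize is added to the thread-local include
+        # set and dropped from the exclude set, so the redispatch reaches the C++
+        # functionalization kernels; the keys in FunctionalTensor._extra_dispatch_keys
+        # are also removed from the exclude set so the inner tensors can still
+        # service them.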
+ with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): + try: + # By default for python functionalization (for AOTAutograd), we reapply views. + old_apply_views = torch._functionalize_enable_reapply_views(True) # type: ignore[attr-defined] + + # Sometimes these functions cannot be directly dispatched to functionalize key + # because args are sometimes not functional tensors for some reason? + if func in FunctionalTensor.metadata_fns: + outs_unwrapped = func(*args_unwrapped, **kwargs_unwrapped) + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + else: + # When we dispatch to the C++ functionalization kernel, we might need to jump back to the + # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath + # FunctionalTensorMode. If we call func() directly, we would need to exclude PreDispatch + # from the TLS in order to avoid infinite looping, but this would prevent us from coming + # back to PreDispatch later + outs_unwrapped = func._op_dk( + torch._C.DispatchKey.Functionalize, + *args_unwrapped, + **kwargs_unwrapped, + ) + + if self.export: + if func == torch.ops.aten.dropout.default: + torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] + outs_wrapped = pytree.tree_map_only( + torch.Tensor, wrap, outs_unwrapped + ) + finally: + torch._disable_functionalization() + torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined] + + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + assert is_excluded or not is_included + + if ( + # If no outputs are our functional subclass, then don't try to fix up aliasing + not any( + isinstance(x, FunctionalTensor) + for x in pytree.tree_leaves(outs_wrapped) + ) + # Since lift_fresh lifts its argument into a functional tensor, we can skip the + # aliasing correction step. Otherwise, we would be setting the storage of a + # lifted tensor to that of an unlifted tensor. + # Ref: https://github.com/pytorch/pytorch/issues/111506 + or func == torch.ops.aten.lift_fresh.default + ): + return outs_wrapped + # for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper + if ( + torch.Tag.inplace_view in func.tags + and func is not torch.ops.aten.set_.source_Tensor + ): + with torch.utils._mode_utils.no_dispatch(): + func(*args, **kwargs) + # Wrapper tensor subclasses do not have correct aliasing info! Use this util to manually correct the output aliasing. + # inplace ops like `aten.add_()` are expected to return inputs **directly**, instead of creating fresh tensor objects. + # Use this util to figure out the right thing to return. + # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for. + return return_and_correct_aliasing(func, args, kwargs, outs_wrapped) + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + +@contextlib.contextmanager +def disable_functional_mode(): + return _disable_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) + + +# This is similar to torch.func.functionalize, but: +# - It uses FunctionalTensorMode, and FunctionalTensor (a python subclass). +# One important advantage to using this mode is that it will let us +# run functionalization underneath __torch_dispatch__, +# which we need in AOTAutograd. 
+# - Doing so means that it does not automatically compose with other +# functorch transforms, since these transforms always run above __torch_dispatch__. +# That's why this util lives here, and not in functorch. +def dispatch_functionalize(func, mode: FunctionalTensorMode = FunctionalTensorMode()): + # TODO: pull these from aot autograd + def to_fun(t): + if isinstance(t, torch.Tensor): + return FunctionalTensor.to_functional(t) + return t + + def from_fun(t): + if not isinstance(t, FunctionalTensor): + # quick sanity assert + if isinstance(t, torch.Tensor): + assert not torch._is_functional_tensor(t) + return t + torch._sync(t) + return torch._from_functional_tensor(t.elem) + + def inner(*args, **kwargs): + disable_above = torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + with disable_above, mode: + func_args = pytree.tree_map_only(torch.Tensor, to_fun, args) + func_kwargs = pytree.tree_map_only(torch.Tensor, to_fun, kwargs) + func_outputs = func(*func_args, **func_kwargs) + outputs = pytree.tree_map_only(FunctionalTensor, from_fun, func_outputs) + + return outputs + + return inner + + +class BaseFunctionalizeAPI(ABC): + @abstractmethod + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + pass + + @abstractmethod + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Any: + pass + + @abstractmethod + def functionalize(self, inner_f: Callable) -> Callable: + pass + + @abstractmethod + def redispatch_to_next(self) -> AbstractContextManager: + pass + + @abstractmethod + def replace(self, input_tensor, output_tensor) -> None: + pass + + @abstractmethod + def commit_update(self, tensor) -> None: + pass + + @abstractmethod + def sync(self, tensor) -> None: + pass + + @abstractmethod + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + pass + + +class PythonFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__( + self, mode: Optional[FunctionalTensorMode] = None, pre_dispatch: bool = False + ) -> None: + super().__init__() + self.mode = mode if mode else FunctionalTensorMode() + self.pre_dispatch = pre_dispatch + + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + with self.mode: + return torch.utils._pytree.tree_map_only( + torch.Tensor, FunctionalTensor.to_functional, args + ) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...], list[torch.Tensor]] + ) -> Any: + return torch.utils._pytree.tree_map_only( + FunctionalTensor, FunctionalTensor.from_functional, args + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return dispatch_functionalize(inner_f, self.mode) + + def redispatch_to_next(self) -> AbstractContextManager: + # [NOTE] We don't do anything here because at the time + # we exercise this path, we would have already popped the + # FunctionalTensorMode from mode stack. Since FunctionalTensorMode + # is now stateful, it is better to explicitly pass in correct mode + # directly instead of globally setting it. 
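+        # (Contrast with CppFunctionalizeAPI.redispatch_to_next below, which returns a
+        # guard that excludes the Functionalize dispatch key instead.)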
+ return contextlib.nullcontext() + + def replace(self, input_tensor, output_tensor) -> None: + assert isinstance(input_tensor, FunctionalTensor) + assert not isinstance(output_tensor, FunctionalTensor) + input_tensor.replace_(output_tensor) + + def commit_update(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.commit_update() + + def sync(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.sync() + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + assert isinstance(tensor, FunctionalTensor) + tensor.mark_mutation_hidden_from_autograd() + + +class CppFunctionalizeAPI(BaseFunctionalizeAPI): + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=0) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional(args, reapply_views=_reapply_views()) + + def functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize(inner_f) + + def redispatch_to_next(self) -> AbstractContextManager: + return torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ) + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) + + +class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI): + def __init__(self, interpreter): + self.interpreter = interpreter + + def wrap_tensors(self, args: tuple[Any]) -> tuple[Any]: + from torch._functorch.eager_transforms import _wrap_all_tensors_to_functional + + return _wrap_all_tensors_to_functional(args, level=self.interpreter.level()) + + def unwrap_tensors( + self, args: Union[torch.Tensor, tuple[torch.Tensor, ...]] + ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + from torch._functorch.eager_transforms import ( + _unwrap_all_tensors_from_functional, + ) + + return _unwrap_all_tensors_from_functional( + args, reapply_views=self.interpreter.functionalize_add_back_views() + ) + + def functionalize(self, inner_f: Callable) -> Callable: + return torch.func.functionalize( + inner_f, + remove=( + "mutations_and_views" + if self.interpreter.functionalize_add_back_views() + else "mutations" + ), + ) + + def redispatch_to_next(self) -> AbstractContextManager: + return self.interpreter.lower() + + def replace(self, input_tensor, output_tensor) -> None: + torch._functionalize_replace(input_tensor, output_tensor) + + def commit_update(self, tensor) -> None: + torch._functionalize_commit_update(tensor) + + def sync(self, tensor) -> None: + torch._functionalize_sync(tensor) + + def mark_mutation_hidden_from_autograd(self, tensor) -> None: + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) + + +def mb_unwrap_functional_tensor(tensor: torch.Tensor): + if isinstance(tensor, FunctionalTensor): + return torch._from_functional_tensor(tensor.elem) + return tensor diff --git a/phivenv/Lib/site-packages/torch/_subclasses/meta_utils.py 
b/phivenv/Lib/site-packages/torch/_subclasses/meta_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9597663200e86989eabfcd1ffdd638a3ba967def --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/meta_utils.py @@ -0,0 +1,1937 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import threading +import typing +import warnings +import weakref +from abc import abstractmethod +from contextlib import AbstractContextManager, contextmanager +from dataclasses import dataclass +from typing import ( + Any, + Callable, + ClassVar, + Generic, + NewType, + Optional, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import override, TypedDict, TypeGuard, TypeIs, Unpack + +import torch +from torch._C._autograd import CreationMeta +from torch._C._functorch import ( + _add_batch_dim, + _unwrap_functional_tensor, + _wrap_functional_tensor, + get_unwrapped, + is_batchedtensor, + is_functorch_wrapped_tensor, + is_gradtrackingtensor, + is_legacy_batchedtensor, + maybe_get_bdim, + maybe_get_level, + peek_interpreter_stack, +) +from torch._dispatch.python import enable_python_dispatcher +from torch._logging import trace_structured +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils.weak import WeakIdKeyDictionary + + +if TYPE_CHECKING: + from collections.abc import Generator + + from torch._C._functorch import CInterpreter + from torch._guards import Source + from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode + + # Import here to avoid cycle + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext + + +def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]: + from torch._subclasses.fake_tensor import FakeTensor + + return isinstance(t, FakeTensor) + + +DimList = list +_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor) +_T = TypeVar("_T") +_TensorT = TypeVar("_TensorT", bound=torch.Tensor) +_TensorT_cov = TypeVar("_TensorT_cov", bound=torch.Tensor, covariant=True) + + +def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool: + try: + return t.is_leaf + except RuntimeError: + # inference mode can trigger this + return False + + +def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") + return t.grad + + +def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT: + grad = safe_grad(t) + assert grad is not None + return grad + + +def assert_eq(a: _T, b: _T) -> None: + assert a == b, f"{a} != {b}" + + +tls = threading.local() +# Turns off inference mode for fake tensor propagation. This is turned to True +# only for `torch.compile`. 
Also look at +# _dynamo.config.fake_tensor_disable_inference_mode +tls.disable_inference_mode = False + + +@contextmanager +def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]: + prior = getattr(tls, "disable_inference_mode", False) + tls.disable_inference_mode = True + try: + yield + finally: + tls.disable_inference_mode = prior + + +def assert_metadata_eq( + assert_eq: Callable[[object, object], None], + m1: Union[MetaTensorDesc, torch.Tensor], + m2: torch.Tensor, + *, + skip_symbolic: bool = False, + skip_leaf: bool = False, +) -> None: + m1 = ( + MetaTensorDescriber().describe_tensor(m1) + if isinstance(m1, torch.Tensor) + else m1 + ) + + def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None: + assert_eq(m1.dtype, m2.dtype) + if not skip_symbolic: + assert_eq(m1.shape, m2.shape) + assert_eq(m1.requires_grad, m2.requires_grad) + if not skip_leaf: + assert_eq(m1.is_leaf, m2.is_leaf) + # MetaTensorDesc doesn't store grad_fn; inferred from leaf + # assert_eq(m1.grad_fn is None, m2.grad_fn is None) + assert_eq(m1.is_sparse, m2.is_sparse) + if not getattr(tls, "disable_inference_mode", False): + assert_eq(m1.is_inference, m2.is_inference()) + else: + assert_eq(m1.is_inference, False) + assert_eq(m1.is_conj, m2.is_conj()) + assert_eq(m1.is_neg, m2.is_neg()) + assert_eq(m1.grad is not None, safe_grad(m2) is not None) + if m1.grad is not None: + go(m1.grad, _expect_safe_grad(m2)) + # TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse + # branches (but not ready for prime time yet)... + if m1.is_sparse: + assert_eq(m1.layout, m2.layout) + assert_eq(m1.dense_dim, m2.dense_dim()) + assert_eq(m1.sparse_dim, m2.sparse_dim()) + assert_eq(m1.is_coalesced, m2.is_coalesced()) + elif is_sparse_compressed(m1): + assert_eq(m1.layout, m2.layout) + assert_eq(m1.dense_dim, m2.dense_dim()) + assert_eq(m1.sparse_dim, m2.sparse_dim()) + else: + if not skip_symbolic: + assert_eq(m1.stride, m2.stride()) + assert_eq(m1.storage_offset, m2.storage_offset()) + assert_eq(m1.is_view, m2._is_view()) + if m1.is_view: + assert m1.base is not None + assert m2._base is not None + go(m1.base, m2._base) + # TODO: test if is resizable (no direct query for this atm) + # TODO: audit AutogradMeta to see if it matches + # TODO: test forward AD + + return go(m1, m2) + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]: + return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo + + +def is_sparse_compressed_layout(layout: torch.layout) -> bool: + return layout in { + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + } + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]: + return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout) + + +# TypeGuard (not TypeIs): False does not imply !torch.Tensor +def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]: + return is_sparse_coo(t) or is_sparse_compressed(t) + + +def _checked_cast(ty: type[_T], obj: object) -> _T: + assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}" + return obj + + +def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage: + return base.real_storage # type: ignore[attr-defined] + + +def _set_real_storage( + base: torch.UntypedStorage, real_storage: torch.UntypedStorage +) -> None: + base.real_storage = real_storage # type: ignore[attr-defined] + + +# Don't use id() directly, because those can get reallocated 
over time. +MetaStorageId = NewType("MetaStorageId", int) +MetaTensorId = NewType("MetaTensorId", int) + + +_DescriberId = NewType("_DescriberId", int) +DESCRIBER_NEXT_ID = _DescriberId(0) + + +class MetaTensorDescriber: + """ + Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc + for it, which is enough information to reconstruct a meta tensor/fake tensor + corresponding to a Tensor as faithfully as possible. + + This is a stateful conversion object because we keep track of the IDs + of the tensors/storages passed to us, so we can consistently give + the same ID when we see the same tensor/storage. + """ + + def __init__(self, *, copy_data: bool = False) -> None: + global DESCRIBER_NEXT_ID + self.id = DESCRIBER_NEXT_ID + DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1) + self.next_tensor_id: MetaTensorId = MetaTensorId(0) + self.next_storage_id: MetaStorageId = MetaStorageId(0) + # Tensor -> int + self.lookup_tensor = WeakIdKeyDictionary() + # Storage -> int + self.lookup_storage = WeakIdKeyDictionary() + self.copy_data = copy_data + self.traced_tensors: set[int] = set() + self.traced_storages: set[int] = set() + + def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId: + if t not in self.lookup_tensor: + self.lookup_tensor[t] = self.next_tensor_id + self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1) + return self.lookup_tensor[t] + + def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId: + if s not in self.lookup_storage: + self.lookup_storage[s] = self.next_storage_id + self.next_storage_id = MetaStorageId(self.next_storage_id + 1) + return self.lookup_storage[s] + + def describe_storage( + self, s: torch.UntypedStorage, *, trace: bool = False + ) -> MetaStorageDesc: + r = MetaStorageDesc( + id=self.get_storage_id(s), + size=s.size(), + # NB: We don't do the copy yet; copy happens when we start + # creating the new storages + data=s if self.copy_data else None, + ) + if trace and r.id not in self.traced_storages: + trace_structured( + "describe_storage", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_storages.add(r.id) + return r + + def describe_tensor( + self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False + ) -> MetaTensorDesc: + is_leaf = safe_is_leaf(t) + is_view = t._is_view() + is_sparse = t.is_sparse + layout = t.layout + is_nested = t.is_nested + is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t) + is_functorch_wrapped = is_functorch_wrapped_tensor(t) + is_mkldnn = t.is_mkldnn + is_batchedtensor_v = is_batchedtensor(t) + is_legacy_batchedtensor_v = is_legacy_batchedtensor(t) + is_gradtrackingtensor_v = is_gradtrackingtensor(t) + is_functional = torch._is_functional_tensor(t) + + storage = None + # NB: For compatibility, I default this to zero, as sometimes people + # still have stuffed zero into storage offset even though the tensor + # doesn't meaningfully have an offset + storage_offset = 0 + if not ( + is_sparse + or is_sparse_compressed_layout(layout) + or (is_nested and not is_traceable_wrapper_subclass_v) + or is_mkldnn + # TODO: TBH, functorch wrapped tensors probably should have + # storage associated with them + or is_functorch_wrapped + or is_legacy_batchedtensor_v + ): + # NB: We actually don't use storage to do views, but might as well + # put it in for accuracy + storage = self.describe_storage(t.untyped_storage(), trace=trace) + storage_offset = t.storage_offset() # type: ignore[assignment] + + stride = None + if not ( + is_sparse + or is_sparse_compressed_layout(layout) + or 
(is_nested and not is_traceable_wrapper_subclass_v) + ): + # stride/storage_offset are called from is_functorch_wrapped, + # view_from_base, empty_create_subclass, + # sym_sizes_strides_storage_offset (empty_create) + stride = t.stride() + + # NB: this technically should refer to functorch unwrapped tensor, but + # I am (perhaps abusively) using it to store both the functorch and + # non-functorch functional tensor + unwrapped = None + autograd_meta_from = None + current_level = None + if is_batchedtensor_v or is_gradtrackingtensor_v: + unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) + # xla and lazy tensors present as functional tensors, but we want them + # to be handled specially + elif is_functional and t.device.type not in ("xla", "lazy"): + if t._is_view(): + raise RuntimeError( + "Cannot safely fakify a view because this process drops the view information right now." + ) + if not is_functorch_wrapped: + torch._sync(t) + unwrapped = self.describe_tensor( + torch._from_functional_tensor(t), trace=trace + ) + autograd_meta_from = t + else: + reapply_views = torch._C._functionalization_reapply_views_tls() + # NB: has side effects! + unwrapped = self.describe_tensor( + _unwrap_functional_tensor(t, reapply_views), trace=trace + ) + # TODO: It's pretty suspicious that functional tensors don't have + # valid level and thus we just grab whatever the current level + # is + current_level = torch._C._functorch.current_level() + + maybe_functorch_stack = None + if is_functorch_wrapped: + with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack: + pass + + attrs = None + ctx = None + type_v = None + if is_traceable_wrapper_subclass_v: + assert hasattr(t, "__tensor_flatten__") + raw_attrs, ctx = t.__tensor_flatten__() + attrs = { + attr: self.describe_tensor(getattr(t, attr), trace=trace) + for attr in raw_attrs + } + type_v = type(t) + + from torch.nested._internal.nested_tensor import _tensor_symint_registry + + view_func = ViewFunc.from_tensor(t) + + # TODO: Is it important to enable torch.inference_mode before querying + # these values? + is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False) + r: MetaTensorDesc = MetaTensorDesc( + id=self.get_tensor_id(t), + storage=storage, + is_inference=False if is_inference_mode_disabled else t.is_inference(), + is_leaf=is_leaf, + requires_grad=t.requires_grad, + # NB: ndim should be OK too but there is a disaster at + # python test/dynamo/test_subclasses.py -k test_user_overridden_property_unsupported + # Actually, this means that we have a little bit of a problem + # here, which is that there is some sensitivity to how exactly an + # access is done if you have a __torch_function__ subclass. Maybe + # should disable torch function before doing accesses? 
+ ndim=t.dim(), + dtype=t.dtype, + is_sparse=is_sparse, + is_mkldnn=is_mkldnn, + is_functorch_wrapped=is_functorch_wrapped, + is_batchedtensor=is_batchedtensor_v, + is_legacy_batchedtensor=is_legacy_batchedtensor_v, + is_gradtrackingtensor=is_gradtrackingtensor_v, + is_view=is_view, + is_conj=t.is_conj(), + is_neg=t.is_neg(), + is_parameter=isinstance(t, torch.nn.Parameter), + is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, + is_nested=is_nested, + nested_int=( + _tensor_symint_registry[t].node.nested_int() + if t in _tensor_symint_registry + else None + ), + is_functional=is_functional, + layout=layout, + device=t.device, + size=t.size(), + stride=stride, + storage_offset=storage_offset, + dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())), + sparse_dim=( + t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None + ), + dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None, + is_coalesced=t.is_coalesced() if t.is_sparse else None, + # TODO: I actually think recursing here is correct, but we have at + # least an infinite cycle from base -> values -> base + # https://github.com/pytorch/pytorch/issues/122089 + crow_indices=( + self.describe_tensor(t.crow_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} + else None + ), + col_indices=( + self.describe_tensor(t.col_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} + else None + ), + ccol_indices=( + self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} + else None + ), + row_indices=( + self.describe_tensor(t.row_indices(), recurse=False, trace=trace) + if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} + else None + ), + values=( + self.describe_tensor(t.values(), recurse=False, trace=trace) + if recurse and is_sparse_compressed(t) + else None + ), + grad=( + self.describe_tensor(grad, trace=trace) + if (grad := safe_grad(t)) is not None + else None + ), + creation_meta=( + torch._C._autograd._get_creation_meta(t) if t._is_view() else None + ), + unwrapped=unwrapped, + level=( + maybe_get_level(t) + if is_batchedtensor_v or is_gradtrackingtensor_v + else None + ), + bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, + base=( + self.describe_tensor(t._base, trace=trace) + if recurse and t._is_view() and t._base is not None + else None + ), + fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), + view_func=view_func, + attrs=attrs, + ctx=ctx, + type=type_v, + # NB: even if functorch is enabled, don't actually save the + # interpreter stack here unless we are actually functorch wrapped; + # it's irrelevant for non-functorch stuff + functorch_stack=maybe_functorch_stack, + autograd_meta_from=autograd_meta_from, + current_level=current_level, + data=t if self.copy_data else None, + ) + if trace and r.id not in self.traced_tensors: + trace_structured( + "describe_tensor", + metadata_fn=lambda: r.as_json(self.id), + ) + self.traced_tensors.add(r.id) + return r + + +@dataclass(frozen=True) +class MetaStorageDesc: + id: MetaStorageId + size: int + # NB: this is only populated with copy_data True, it is not directly + # serializable in JSON, you want to do something special here anyway + data: Optional[torch.UntypedStorage] + + def as_json(self, describer_id: _DescriberId) -> dict[str, object]: + return { + "id": self.id, + "describer_id": describer_id, + "size": 
self.size if isinstance(self.size, int) else repr(self.size), + } + + +@dataclass(frozen=True) +class ViewFunc(Generic[_TensorT]): + @abstractmethod + def apply( + self, + t: _TensorT, + new_base: _TensorT, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, + ) -> _TensorT: + ... + + @staticmethod + def from_tensor(t: torch.Tensor) -> ViewFunc: + if _is_fake_tensor(t): + return _FakeTensorViewFunc() + else: + return _CustomViewFunc(t._view_func_unsafe) + + +@dataclass(frozen=True) +class _FakeTensorViewFunc(ViewFunc["FakeTensor"]): + @override + def apply( + self, + t: torch.Tensor, + new_base: torch.Tensor, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None, + ) -> FakeTensor: + return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe( + t, new_base, symint_visitor_fn, tensor_visitor_fn + ) + + +@dataclass(frozen=True) +class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]): + func: Callable[ + [ + torch.Tensor, + Optional[Callable[[int], int]], + Optional[Callable[[torch.Tensor], _TensorT]], + ], + _TensorT, + ] + + @override + def apply( + self, + t: torch.Tensor, + new_base: torch.Tensor, + symint_visitor_fn: Optional[Callable[[int], int]] = None, + tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None, + ) -> _TensorT: + # ignore `t` + return self.func(new_base, symint_visitor_fn, tensor_visitor_fn) + + +# A callback where the device is either optional or required. +# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str]) +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]): + def __call__( + self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str] + ) -> _TensorT_cov: + ... + + +class _MetaTensorCallbackKwargs(TypedDict, total=False): + device: Union[torch.device, str] + + +# A callback where the device may not be provided (is optional). +# All of these satisfy this protocol: +# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta") +# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None) +class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]): + def __call__( + self, + arg: Callable[[], torch.Tensor], + /, + **kwargs: Unpack[_MetaTensorCallbackKwargs], + ) -> _TensorT_cov: + ... + + +@dataclass(frozen=True) +class MetaTensorDesc(Generic[_TensorT]): + id: MetaTensorId + ndim: int + dtype: torch.dtype + device: torch.device + + # NB: Sometimes, size, stride and storage_offset contain SymInt, in which + # case this is NOT serializable. That only happens when you're + # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we + # can get rid of this use case entirely. Notably, even if we are + # fakeifying a real tensor into a fake tensor with symbolic shapes, the + # size here is NOT dynamic + # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic + # goes through this codepath. But it really should not LOL. 
+ # NB: size could potentially be None as you can override it and make it + # throw an error, but we don't currently have any subclasses that do this + # except C++ nested tensor but we're going to have nested int to make this + # defined on NJT + size: tuple[int, ...] + dynamo_dynamic_indices: list[int] + + layout: torch.layout = torch.strided + is_inference: bool = False + is_leaf: bool = False + requires_grad: bool = False + is_sparse: bool = False + is_mkldnn: bool = False + is_functorch_wrapped: bool = False + is_batchedtensor: bool = False + is_legacy_batchedtensor: bool = False + is_gradtrackingtensor: bool = False + is_view: bool = False + is_nested: bool = False + # We eagerly symbolicize the associated nested int for e.g. offsets / lengths + # metadata if that offsets is already associated with a nested int. + # See test_construct_from_jagged_with_input_offsets_mixed_case. + nested_int: Optional[int] = None + is_traceable_wrapper_subclass: bool = False + is_functional: bool = False + is_conj: bool = False + is_neg: bool = False + is_parameter: bool = False + stride: Optional[tuple[int, ...]] = None + storage_offset: int = 0 + # NB: We have a choice whether or not to store the id or a direct pointer + # to the data structure. For ease of use, we store the data structure, + # but this means that when we serialize, we have to swizzle these pointers + # back into ids (so we have accurate aliasing relationships) + storage: Optional[MetaStorageDesc] = None + sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed + dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed + is_coalesced: Optional[bool] = None # is_sparse + crow_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + col_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + ccol_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + row_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed + values: Optional[MetaTensorDesc] = None # is_sparse_compressed + unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped + bdim: Optional[int] = None # is_functorch_wrapped + base: Optional[MetaTensorDesc] = None # is_view + attrs: Optional[dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass + creation_meta: Optional[CreationMeta] = None + grad: Optional[MetaTensorDesc] = None + + # Everything below is NOT serializable, need some more work + + _UNSERIALIZABLE: ClassVar[set[str]] = { + "ctx", + "type", + "fake_mode", + # view_func isn't serializable when it's a _CustomViewFunc + "view_func", + "level", + "current_level", + "functorch_stack", + "autograd_meta_from", + "data", + "nested_int", + } + + ctx: Optional[object] = None # is_traceable_wrapper_subclass + type: Optional[type] = None # is_traceable_wrapper_subclass + fake_mode: Optional[FakeTensorMode] = None + view_func: Optional[ViewFunc] = None + # level looks serializable, but actually it is meaningless without + # the functorch_stack below + level: Optional[int] = None # is_functorch_wrapped + current_level: Optional[int] = None + functorch_stack: Optional[list[CInterpreter]] = None + autograd_meta_from: Optional[torch.Tensor] = None + + # This is only populated on copy_data, and typically is not used at all, + # except for some of our meta-ification paths that don't properly use + # storage (pro-tip: you should use storage) + data: Optional[torch.Tensor] = None + + # Faithfully serializing functorch tensors will not be too difficult. 
+ # We only need to consider grad/vmap interpreters, and their internal + # state is only bools (mostly what the grad enabled/disabled state + # should be in the lower layer). Beyond that, tensors just need to + # precisely indicate which particular interpreter they correspond + # to (we then replace level with a pointer to the interpreter stack.) + # However, this use of functorch is very "non-lexical" so it's not + # entirely clear how to make it all lexical again, so we haven't done + # it for now. + + # NB: This will reference numeric IDs, and it is assumed that you've + # already serialized everything this recursively references + def as_json(self, describer_id: _DescriberId) -> dict[str, object]: + def json(k: str, v: object) -> object: + # Some best-effort debugging serialization for unserializable + # fields (feel free to add other special cases as appropriate) + if k in ["data", "autograd_meta_from"]: + return None # never repr these + if k in MetaTensorDesc._UNSERIALIZABLE: + return repr(v) + if isinstance(v, (torch.device, torch.dtype, torch.layout)): + return repr(v) + if isinstance(v, torch.SymInt): + return repr(v) + if isinstance(v, (tuple, list)): + return [json(k, v1) for v1 in v] + if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): + return v.id + if isinstance(v, CreationMeta): + return str(v) + if k == "attrs" and isinstance(v, dict): + return {k1: v1.id for k1, v1 in v.items()} + return v + + r = { + field.name: json(field.name, getattr(self, field.name)) + for field in dataclasses.fields(self) + if not ( + getattr(self, field.name) is field.default + or ( + field.name == "dynamo_dynamic_indices" + and not getattr(self, field.name) + ) + ) + } + r.update({"describer_id": describer_id}) + return r + + @property + def shape(self) -> tuple[int, ...]: + return self.size + + +# A more faithful reproduction would do a copy on the entire +# storage, but this needs to be done carefully because the +# underlying storage could have larger extent than is implied +# by size/stride. The real fix is to properly call +# meta_storage recursively here. +# +# These "safe" functions are intended to be used under no_dispatch() mode. +# The no_dispatch() here is intended to prevent ambient fake tensor mode from +# fakeifying the operation. But if we are given an honest to goodness +# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way +# to do this would be to not use no_dispatch and instead just disable fake +# tensor mode only (allowing for subclass dispatch to occur) +def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None: + if type(src) is not torch.Tensor: + return + dst.copy_(src) + + +def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]: + if type(src) is not torch.Tensor: + return None + return src.clone() + + +# This is a class for converting multiple tensors into meta tensors which +# share the same view/storage structure. The operation model is you allocate +# one of these, and then call it repeatedly on all the tensors you want to +# convert. It's important to use the same object for tensors you want to +# share storage because this is how we correlate shared storages to the same +# meta storages. This class will hold weak references to cached tensors +# and tensor storages. 
+class MetaConverter(Generic[_TensorT]): + def __init__(self, *, copy_data: bool = False) -> None: + # Maps MetaStorageId to UntypedStorage + self.storage_memo: weakref.WeakValueDictionary[ + MetaStorageId, torch.UntypedStorage + ] = weakref.WeakValueDictionary() + # Maps MetaTensorId to torch.Tensor (typically a meta tensor or + # FakeTensor) + self.tensor_memo: weakref.WeakValueDictionary[ + MetaTensorId, _TensorT + ] = weakref.WeakValueDictionary() + self.hit = 0 + self.miss = 0 + self.del_hook = None + self.arg_cnt = 0 + # Ensures real_storage/real_tensor are populated on the resulting + # metaified storage/tensor. The naming of this attribute is load + # bearing: FakeTensor relies on real tensor being set to exactly this + # value + self.copy_data = copy_data + self.describer = MetaTensorDescriber(copy_data=copy_data) + + def successful(self) -> bool: + return self.hit > 0 and self.miss == 0 + + def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]: + return self.tensor_memo.get(t.id, None) + + def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT: + r = self.tensor_memo.get(t.id, None) + assert r is not None + return r + + def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None: + self.tensor_memo[t.id] = v + + def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]: + return self.storage_memo.get(s.id, None) + + def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None: + self.storage_memo[s.id] = v + + def meta_storage( + self, + s: MetaStorageDesc, + callback: Callable[[Callable[[], torch.Tensor]], _TensorT], + ) -> torch.UntypedStorage: + # If we are fakeifying a tensor that has a secretly-zero-sized storage, + # Need to make sure to resize the meta storage too. + if (memo := self.get_storage_memo(s)) is None: + r_s = callback( + lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"), + ).untyped_storage() + if self.copy_data: + # NB: no_dispatch is needed because internally storage copy is + # implemented as Tensor operations + with torch.no_grad(), no_dispatch(): + assert s.data is not None + _set_real_storage(r_s, s.data.clone()) + self.set_storage_memo(s, r_s) + return r_s + else: + return memo + + @classmethod + def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT: + # TODO: how to check _TensorT? + return typing.cast(_TensorT, t) + + @classmethod + def _identity_callable( + cls, + t: Callable[[], torch.Tensor], + device: Optional[Union[torch.device, str]] = None, + ) -> _TensorT: + return cls._checked_cast_tensor_t(t()) + + @classmethod + def _backward_error(cls, t: _TensorT) -> _TensorT: + errfn = torch._C._functions.DelayedError( + "Internal error: Tried to backward() through example input", + 1, + ) + err = errfn(t) + return typing.cast(_TensorT, err) + + # This function assumes that it's possible to do the conversion + # NB: name here is used in a conventional way by Dynamo; it corresponds + # precisely to the Source.name() of the tensor we're fakeifying and + # corresponds to a valid Python expression. When we construct sub-names + # as part of this process, we will maintain this invariant! (Even though + # other users of this may not need it this property to be upheld.) 
+ def meta_tensor( + self, + t: MetaTensorDesc, + shape_env: Optional[ShapeEnv], + callback_: _MetaTensorCallback[_TensorT], + source: Optional[Source], + symbolic_context: Optional[SymbolicContext], + ) -> _TensorT: + callback: _MetaTensorCallbackOptDevice = functools.partial( + callback_, device=t.device + ) + if source is None: + from torch._dynamo.source import ConstantSource + + # TODO: make a dedicated UnknownSource for this? + source = ConstantSource( + f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" + ) + + # This indicates you set no_dispatch() before calling into this + # function. This is an error: we may be creating fake tensors and + # will perform operations on them which need fake tensor mode to + # be active. You will segfault if you are in a no_dispatch() block. + assert not torch._C._dispatch_tls_local_exclude_set().has( + torch._C.DispatchKey.Python + ) + self.arg_cnt += 1 + + # When we make as_strided calls, we end up generating a guard + # that the new as_strided tensor is in bounds for the old storage + # for the base (since as_strided calls can "bust" out of their + # bounding box.) This guard is unnecessary: if a user is able + # to provide us a tensor with the view base setup this way, we + # don't need to produce a guard, because the fact that they + # were able to produce the view base means its in bounds. + # + # Now, ordinarily, this guard would be harmless. However, the + # generated guard refers to variables bound on the base variable. + # At the moment, Dynamo doesn't actually guard on x._base, because + # according to Voz this results in a lot of spurious invalidations, + # and also if the user doesn't directly make use of _base, its + # pointless anyway (because programs should be parametric over + # whether or not the input tensor is a view or not--unless you're + # mutating the input, but that's a whole 'nother ballgame). So + # for expediency, we suppress these guards so we don't have to + # deal with this (yet, anyway.) + # + # NB: An old version of this code suppressed guards for ALL operations + # happening during meta conversion, not just as_strided calls. + # This is too aggressive: we do duck sizing and 0/1 simplification + # as we allocate variables, and we do need to register guards for + # these cases. 
+ maybe_suppress: Callable[[], Any] = contextlib.nullcontext + if shape_env is not None: + maybe_suppress = shape_env.suppress_guards + + def sym_sizes_strides_storage_offset( + t: MetaTensorDesc, + src: torch._guards.Source, + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + ) -> tuple[tuple[int, ...], tuple[int, ...], int]: + assert t.stride is not None + if shape_env is not None: + fake_mode = t.fake_mode + if fake_mode is not None and fake_mode.shape_env is shape_env: + # Don't reallocate the sizes; the shape envs are the same, + # so reuse the old sizes/strides/etc + return (t.size, t.stride, t.storage_offset) + else: + # TODO: deduplicate this + t_size = tuple( + shape_env._maybe_specialize_sym_int_with_hint(sz) + for sz in t.size + ) + t_stride = tuple( + shape_env._maybe_specialize_sym_int_with_hint(sd) + for sd in t.stride + ) + t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint( + t.storage_offset + ) + return shape_env._create_symbolic_sizes_strides_storage_offset( + t_size, + t_stride, + t_storage_offset, + [d in t.dynamo_dynamic_indices for d in range(t.ndim)], + src, + symbolic_context=symbolic_context, + ) + else: + return (t.size, t.stride, t.storage_offset) + + def empty_create( + inner_t: MetaTensorDesc, + inner_src: torch._guards.Source, + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + ) -> torch.Tensor: + ( + inner_sizes, + inner_strides, + _inner_storage_offset, + ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) + return torch.empty_strided( + inner_sizes, + inner_strides, + dtype=inner_t.dtype, + device="meta", + ) + + # Creates a subclass instance with empty inner tensors according to the specified + # symbolic context. + def empty_create_subclass( + t: MetaTensorDesc, + outer_size: tuple[int, ...], + outer_stride: tuple[int, ...], + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = symbolic_context, + source: Optional[torch._guards.Source] = source, + ) -> _TensorT: + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext + + assert t.attrs is not None + assert t.type is not None + # NB: t.ctx could be None if the subclass in question has no + # meaningful context + + # Note: transform_subclass will use __tensor_unflatten__ to generate + # a fresh subclass wrapper with outer sizes / strides according to the + # outer symbolic context (passed in to this function). Inner size / stride + # / storage offset symbols are allocated according to the appropriate inner + # symbolic contexts, after which the checks in transform_subclass() will + # relate them to the outer metadata as possible. 
+ # + # Morally, the code here is same as transform_subclass, but we've + # written it from scratch to read EmptyCreateSubclass + outer_size = outer_size if outer_size is not None else t.size + outer_stride = outer_stride if outer_stride is not None else t.stride + + assert symbolic_context is None or isinstance( + symbolic_context, SubclassSymbolicContext + ) + + def _empty_create_subclass( + t: MetaTensorDesc, + outer_size: Optional[tuple[int, ...]], + outer_stride: Optional[tuple[int, ...]], + symbolic_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ], + callback: _MetaTensorCallbackOptDevice[_TensorT], + source: torch._guards.Source, + ) -> _TensorT: + # We are hitting plain meta_desc tensor so actually + # create a tensor here. + if t.attrs is None: + return self.meta_tensor( + t, + shape_env, + callback, + source, + symbolic_context, + ) + + inner_tensors = {} + for attr, meta_tensor_desc in t.attrs.items(): + current_context = None + if symbolic_context is not None: + assert isinstance(symbolic_context, SubclassSymbolicContext) + if ( + current_context_ := symbolic_context.inner_contexts[attr] + ) is not None: + current_context = _checked_cast( + torch.fx.experimental.symbolic_shapes.SymbolicContext, + current_context_, + ) + + current_source = AttrSource(source, attr) + inner_callback = functools.partial( + callback, device=meta_tensor_desc.device + ) + new_empty_tensor = _empty_create_subclass( + meta_tensor_desc, + meta_tensor_desc.size, + meta_tensor_desc.stride, + current_context, + inner_callback, + current_source, + ) + inner_tensors[attr] = new_empty_tensor + + assert t.type is not None + return t.type.__tensor_unflatten__( # type: ignore[attr-defined] + inner_tensors, t.ctx, outer_size, outer_stride + ) + + assert source is not None + sub = _empty_create_subclass( + t, outer_size, outer_stride, symbolic_context, callback, source + ) + + # NB: Purposefully guard here to simplify the inner / outer symbols. + # Using sym_eq() for symbolic comparison can result in an expression that's too + # difficult to guard on, so we use == here. + assert sub.shape == outer_size, ( + f"Expected return value from {t.type}__tensor_unflatten__() to have " + f"shape equal to {outer_size}, but got: {sub.shape}" + ) + assert sub.stride() == outer_stride, ( + f"Expected return value from {t.type}__tensor_unflatten__() to have " + f"stride equal to {outer_stride}, but got: {sub.stride()}" + ) + + return sub + + # Returns an all-dynamic symbolic context used for metafying the given tensor with + # fully dynamic dims. This is useful when fake-ifying intermediate tensors in + # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we + # don't want to over-specialize during view replay. 
+ def all_dynamic_symbolic_context( + t: MetaTensorDesc, + source: torch._guards.Source, + shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv], + callback: _MetaTensorCallback[_TensorT], + ) -> torch.fx.experimental.symbolic_shapes.SymbolicContext: + from torch._dynamo.source import AttrSource + from torch.fx.experimental.symbolic_shapes import ( + DimDynamic, + StatelessSymbolicContext, + SubclassSymbolicContext, + ) + + view_base_context: Optional[ + torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = None + if t.is_view: + assert t.base is not None + view_base_context = all_dynamic_symbolic_context( + t.base, AttrSource(source, "_base"), shape_env, callback + ) + + t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext + t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim + if t.is_traceable_wrapper_subclass: + assert t.attrs is not None + inner_contexts: dict[ + str, torch.fx.experimental.symbolic_shapes.SymbolicContext + ] = {} + for attr, inner in t.attrs.items(): + assert isinstance(attr, str) + inner_contexts[attr] = all_dynamic_symbolic_context( + inner, AttrSource(source, attr), shape_env, callback + ) + t_symbolic_context = SubclassSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.ndim, + inner_contexts=inner_contexts, # type: ignore[arg-type] + tensor_source=source, + view_base_context=view_base_context, + ) + else: + t_symbolic_context = StatelessSymbolicContext( + dynamic_sizes=t_dynamic_sizes, + constraint_sizes=[None] * t.ndim, + view_base_context=view_base_context, + ) + + return t_symbolic_context + + # Returns a fake-ified version of an input view tensor t, given an already fake-ified + # base. At a high level, we want two things: + # 1. fake_t should have the same view relationship to the given fake base as the + # input t has to its _base. + # 2. fake_t should have symbolic sizes / strides / storage offset according to the + # appropriate symbolic context (i.e. from the automatic dynamic algorithm). + # + # We currently take different strategies across view types: + # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an + # as_strided() call on the fake-ified base, passing symbolic metadata. + # * For views involving subclasses, perform view replay using view funcs to + # achieve (1). It's necessary for (2) to swap out any closed-over state in + # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this + # avoids specialization (and thus over-eager simplification of symbols) that + # could occur during view replay on the fake-ified base. + # + # Examples: + # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled + # with an as_strided() call on the fake base passing symbolic metadata. + # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg + # is made symbolic to avoid invalid specialization and view replay is then + # done to reconstruct the view. + # * _nested_from_jagged(values, offsets) is a dense -> subclass view + # that returns a subclass instance from a dense values tensor. The offsets + # tensor is closed over in the view func, as it can be considered view metadata. + # First, the offsets tensor is fake-ified according to the inner symbolic + # context and with the correct relationship to the outer size / stride metadata. + # Then view replay is done, swapping in the fake offsets so the view replay output + # is fully fake with no invalid specialization. 
+ def view_from_base( + base: _TensorT, + t: MetaTensorDesc, + shape_env: Optional[ + torch.fx.experimental.symbolic_shapes.ShapeEnv + ] = shape_env, + ) -> _TensorT: + with enable_python_dispatcher(): + # fake-ify t's metadata according to the outer symbolic context + (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( + t, source + ) + if ( + not t.is_traceable_wrapper_subclass + and not is_traceable_wrapper_subclass(base) + ): + # Dense -> Dense view case uses as_strided() to construct view relationship. + # TODO: Change this logic to use view replay for consistency? + # It's likely there is no view func available. + with maybe_suppress(): + return self._checked_cast_tensor_t( + base.as_strided(sizes, strides, storage_offset) + ) + + from torch._dynamo.source import EphemeralSource + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + sym_eq, + ) + + def symint_visitor_fn(s: int) -> int: + nonlocal symbolic_context + from torch.fx.experimental.symbolic_shapes import DimDynamic + + all_static_sizes = ( + symbolic_context is not None + and isinstance(symbolic_context, StatelessSymbolicContext) + and all( + x is DimDynamic.STATIC + for x in symbolic_context.dynamic_sizes + ) + ) + # Can't just rely on shape env being None - dynamo always initializes it + if all_static_sizes or shape_env is None: + return s + + # NB: The symbol here is expected to be simplified out because we a priori + # allocate inner and outer symbols according to the appropriate symbolic + # contexts and prefer those over this symbol during symbol simplification + # (via usage of EphemeralSource below). This -shouldn't- happen, but if + # this symbol somehow leaks out beyond the view tensor's shape metadata, our + # assumption of it being simplified out will fail and it may be guarded on, + # which will hard error. + sym_source = EphemeralSource("symint_visitor_fn") + + symbol = shape_env.create_symbol(s, sym_source, positive=None) + return shape_env.create_symintnode( + symbol, hint=s, source=sym_source + ) + + real_to_fake_mapping = {} + if t.is_traceable_wrapper_subclass: + assert t.attrs is not None + # NB: t.ctx could be None if the subclass in question has no + # meaningful context + assert t.type is not None + + # Fake-ify t naively here; this is only done so we can get fake-ified inner + # tensors with the correct relationships to the outer sizes / strides for use + # in view replay. It's done beforehand here because it's not easy to do when + # visiting tensors one-by-one during view replay. + # + # Example: + # Consider a Dense -> NJT view. NJT has (values, offsets) components and we + # want a view of values with the offsets closed over. As the offsets component + # is needed to describe the output view, it's important that it's fakeified + # correctly. + fake_t: _TensorT = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined] + for attr in attrs: + real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) + + def tensor_visitor_fn( + visited_t: torch.Tensor, + # These arguments are never passed, we just use them to close + # over these relevant values + shape_env: Optional[ + torch.fx.experimental.symbolic_shapes.ShapeEnv + ] = shape_env, + callback: _MetaTensorCallbackOptDevice[_TensorT] = callback, + ) -> torch.Tensor: + # It's possible to close over an undefined tensor (e.g. NJT's lengths). 
+ if visited_t is None: + return None + + # NB: visited_t being a Tensor here is very naughty! Should + # have already been described + + # Fake inner tensors of view subclasses will come from the mapping built above. + visited_id = self.describer.get_tensor_id(visited_t) + fake_visited_t = real_to_fake_mapping.get(visited_id, None) + if fake_visited_t is not None: + return fake_visited_t + + visited_desc = self.describer.describe_tensor(visited_t) + + # For other closed-over tensor state, fake-ify it as all dynamic with an + # ephemeral source. This avoids invalid specialization during view replay. + # If we find that in practice the usage of ephemeral sources isn't enough + # to guarantee that we don't have guards on these symbols, we may need to + # explicitly suppress guards (as is done for _base in the dense -> dense + # view case). + temp_source = EphemeralSource("tensor_visitor_fn") + return self.meta_tensor( + visited_desc, + shape_env, + callback, + temp_source, + all_dynamic_symbolic_context( + visited_desc, temp_source, shape_env, callback + ), + ) + + # Replay the view, swapping out any non-symbolic SymInts or real tensors + # for symbolic SymInts or fake tensors. + assert t.view_func is not None + # NB: we do NOT suppress guards here, we need to remove ephemeral + # sources + fake_t = t.view_func.apply( + t, base, symint_visitor_fn, tensor_visitor_fn + ) + + # Ensure the output has symbolic shapes according to the outer symbolic context. + # These checks should simplify out any symbols created for closed-over view func + # SymInts. + torch._check(sym_eq(fake_t.size(), sizes)) + torch._check(sym_eq(fake_t.stride(), strides)) + torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) + return fake_t + + if self.get_tensor_memo(t) is None: + GRAD_TENSOR_SENTINEL_VALUE = -2 + + with torch.inference_mode(t.is_inference): + if t.is_sparse: + is_leaf = t.is_leaf + + # The lambda function below is similar to + # `t.to(device='meta')` except the latter + # preserves nnz value + r = callback( + lambda: torch.ops.aten._sparse_coo_tensor_with_dims( + t.sparse_dim, + t.dense_dim, + t.size, + dtype=t.dtype, + layout=torch.sparse_coo, + device="meta", + ) + ) + if self.copy_data: + # Pray that sparse clone doesn't lose information + assert t.data is not None + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + r.real_tensor = _safe_clone(t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + # Note [is_coalesced is dispatched] + # Strangely enough, is_coalesced() is a dispatched operator, + # which means that it will get caught by fake tensor mode. + # Ordinarily this would error, but there's some logic in + # fake tensor ensure this doesn't happen. + r._coalesced_(bool(t.is_coalesced)) + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + # This should probably use DelayedError, + # but clone is fine for now for sparse tensors. 
+ # (DelayedError does not work for sparse because it causes + # the Fake sparse tensor to "lose" its fakeness) + r = self._checked_cast_tensor_t(r.clone()) + with torch.enable_grad(): + r._coalesced_(bool(t.is_coalesced)) + elif is_sparse_compressed_layout(t.layout): + is_leaf = t.is_leaf + + if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: + assert t.sparse_dim is not None + assert t.dense_dim is not None + assert t.values is not None + batch_dim = t.ndim - t.sparse_dim - t.dense_dim + blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3] + else: + blocksize = () + if t.layout in {torch.sparse_csr, torch.sparse_bsr}: + assert t.crow_indices is not None + index_dtype = t.crow_indices.dtype + else: + assert t.ccol_indices is not None + index_dtype = t.ccol_indices.dtype + + r = callback( + lambda: torch.ops.aten._sparse_compressed_tensor_with_dims( + 0, + t.dense_dim, + t.shape, + blocksize, + index_dtype, + layout=t.layout, + dtype=t.dtype, + device="meta", + ) + ) + if self.copy_data: + # Pray sparse clone doesn't lose information + assert t.data is not None + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + r.real_tensor = _safe_clone(t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + r = self._backward_error(r) + elif t.is_nested and not t.is_traceable_wrapper_subclass: + # TODO: Handle this better in Dynamo? + # There are checks there now, but this can still be triggered by a dense + # tensor graph input that is a view of a strided NT. + from torch._dynamo.exc import unimplemented + + unimplemented( + "strided nested tensors are not supported by meta conversion" + ) + elif t.is_mkldnn: + is_leaf = t.is_leaf + ( + sizes, + strides, + _storage_offset, + ) = sym_sizes_strides_storage_offset(t, source) + # TODO: This doesn't seem right, where's the MKLDNN'ness + # lol + r = callback( + lambda: torch.empty_strided( + sizes, strides, dtype=t.dtype, device="meta" + ) + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert t.size is not None + assert t.stride is not None + assert _is_fake_tensor(r) + r.real_tensor = torch.empty_strided( + t.size, t.stride, dtype=t.dtype, device=t.device + ) + assert t.data is not None + _safe_copy(r.real_tensor, t.data) + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = True + if t.requires_grad and not is_leaf: + r = self._backward_error(r) + elif t.is_functorch_wrapped: + if t.is_view: + from torch._dynamo.exc import unimplemented + + unimplemented( + "view functorch tensors are not supported by meta conversion" + ) + + # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) + # in a FakeTensor + def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT: + # TODO: why aren't the recursive calls going to + # meta_tensor + r: _TensorT + if t.is_batchedtensor: + assert t.unwrapped is not None + assert t.level is not None + assert t.bdim is not None + ft = _to_fake_tensor(t.unwrapped) + lvl = t.level + bdim = t.bdim + # You cannot create functorch tensors without + # having the ambient functorch interpreter stack + # available, as the level refers to things in the + # stack + with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( + t.functorch_stack + ): + r = self._checked_cast_tensor_t( + _add_batch_dim(ft, bdim, lvl) + ) + elif t.is_gradtrackingtensor: + assert t.unwrapped is not None + assert t.level is not None + 
disable_functorch = torch._C._DisableFuncTorch + with disable_functorch(): + ft = _to_fake_tensor(t.unwrapped) + lvl = t.level + if lvl == GRAD_TENSOR_SENTINEL_VALUE: + r = ft + else: + with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( + t.functorch_stack + ): + r = self._checked_cast_tensor_t( + torch._C._functorch._wrap_for_grad(ft, lvl), + ) + + is_leaf = t.is_leaf + if t.requires_grad and safe_is_leaf(r): + r.requires_grad = True + elif t.requires_grad and not is_leaf: + r = self._backward_error(r) + elif t.is_functional: + assert t.unwrapped is not None + assert t.current_level is not None + ft = self.meta_tensor( + t.unwrapped, + shape_env, + callback, + # NB: reuse these exactly, we treat the + # functional tensor as "invisible". + # TODO: Actually this all probably doesn't + # work, take a closer look. + source, + symbolic_context, + ) + r = self._checked_cast_tensor_t( + _wrap_functional_tensor(ft, t.current_level), + ) + # TODO: is_leaf/requires_grad? + else: + assert t.stride is not None + + sizes = t.size + strides = t.stride + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ), + # device="meta", + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + r.real_tensor = torch.empty_strided( # type: ignore[attr-defined] + t.size, + t.stride, + dtype=t.dtype, + device=t.device, + ) + assert t.data is not None + _safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined] + return r + + r = _to_fake_tensor(t) + + elif t.is_functional and t.device.type not in ["xla", "lazy"]: + assert t.unwrapped is not None + assert not t.is_functorch_wrapped # handled above + unwrapped = self.meta_tensor( + t.unwrapped, + shape_env, + callback, + source, + symbolic_context, + ) + r = self._checked_cast_tensor_t( + torch._to_functional_tensor(unwrapped) + ) + torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined] + + elif t.is_view: + # Construct views in two steps: recursively meta-fy their + # base, and then create view(s) off that. NB: doing it + # directly from storage is WRONG because this won't cause + # version counters to get shared. + + assert t.base is not None + + base_symbolic_context = None + if shape_env and symbolic_context is not None: + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + ) + + assert isinstance(symbolic_context, StatelessSymbolicContext) + # NB: This should generally be set when the input is a view, + # but the exception right now is for fake-ifying grads, which is + # a work in progress. + if symbolic_context.view_base_context is not None: + base_symbolic_context = symbolic_context.view_base_context + + base = self.meta_tensor( + t.base, + shape_env, + callback, + torch._dynamo.source.AttrSource(source, "_base"), + base_symbolic_context, + ) + + def is_c_of_r( + complex_dtype: torch.dtype, real_dtype: torch.dtype + ) -> bool: + return ( + utils.is_complex_dtype(complex_dtype) + and utils.corresponding_real_dtype(complex_dtype) + == real_dtype + ) + + # In some situations, MetaConverter may be called in a + # context where autograd is disabled. For the _is_view + # assert to pass, we have to setup the autograd view + # metadata anyway. Do this by reenabling the + # ADInplaceOrView key. This is kind of a hack. 
+ old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView + ) + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, False + ) + try: + if base.dtype == t.dtype: + pass + elif is_c_of_r(base.dtype, t.dtype): + base = self._checked_cast_tensor_t(torch.view_as_real(base)) + elif is_c_of_r(t.dtype, base.dtype): + base = self._checked_cast_tensor_t( + torch.view_as_complex(base) + ) + else: + # This is not guaranteed to succeed. If it fails, it + # means there is another dtype-converting view function + # that hasn't been handled here + base = self._checked_cast_tensor_t(base.view(t.dtype)) + + # This is very tricky. Naively, you might expect this + # to hold: + # + # if t.requires_grad and not safe_is_leaf(t) + # assert t._base.requires_grad + # + # But it's not true! As you can see in the following + # program: + # + # x = torch.zeros(4) + # y = x.view(1, 4) + # y.requires_grad = True + # z = y.view(1, 1, 4) + # assert z._base is x + # + # So we may have to do *two* views out of the base to + # recreate this situation. + if t.is_leaf: + # Leaf views that track view metadata are created by + # creating a view inside a no_grad block + with torch.no_grad(): + r = view_from_base(base, t) + # As it's a leaf, we can directly assign requires_grad + r.requires_grad = t.requires_grad + else: + if t.base.requires_grad == t.requires_grad: + # Easy case, just run the view op + with torch.enable_grad(): + r = view_from_base(base, t) + + # NB: We don't actually faithfully replicate + # autograd connectivity, but that doesn't matter + # today. See following for more info: + # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 + else: + # Obscure case. Create a leaf view and give it the + # correct requires_grad, then do the final view. + # NB: Can't have a non-leaf without requiring grad! + assert t.requires_grad + with torch.no_grad(), enable_python_dispatcher(): + mid = self._checked_cast_tensor_t( + base.view(base.shape) + ) + mid.requires_grad = t.requires_grad + with torch.enable_grad(): + r = view_from_base(mid, t) + # The CreationMeta influences whether or not inplace + # mutation is an error. So we need to make + # sure we properly propagate this as well. + assert t.creation_meta is not None + torch._C._autograd._set_creation_meta(r, t.creation_meta) + finally: + torch._C._dispatch_tls_set_dispatch_key_excluded( + torch._C.DispatchKey.ADInplaceOrView, old_exclude + ) + + else: + is_leaf = t.is_leaf + + # Graph-Break for wrapped tensors + if ( + not (t.is_batchedtensor or t.is_gradtrackingtensor) + and t.is_functorch_wrapped + ) or t.is_legacy_batchedtensor: + return NotImplemented + + ( + sizes, + strides, + storage_offset, + ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) + + # If we have a subclass that desugars into dense tensors, + # perform our callback on each inner tensor. 
+ if t.is_traceable_wrapper_subclass: + r = empty_create_subclass( + t, outer_size=sizes, outer_stride=strides + ) + else: + r = callback( + lambda: torch.empty_strided( + sizes, + strides, + dtype=t.dtype, + device="meta", + ) + ) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert t.size is not None + assert t.stride is not None + assert _is_fake_tensor(r) + r.real_tensor = torch.empty_strided( + t.size, t.stride, dtype=t.dtype, device=t.device + ) + _safe_copy(r.real_tensor, t.data) + + assert safe_is_leaf(r), "the callback you passed in doesn't detach" + if t.requires_grad: + r.requires_grad = t.requires_grad + if not is_leaf: + # Fake up some autograd history. + # Note: we *used* to call .clone() here to mock up some autograd history. + # This is bad for subclasses. + # Consider the case where you have a wrapper subclass that is contiguous, + # but its inner tensor is noncontiguous(). + # .clone() (or other ops) will have the side effect of changing + # the metadata of the inner tensor. + # So instead, we now have a dedicated fn to set autograd history, + # without inadvertently changing other metadata. + r = self._backward_error(r) + + s = t.storage + assert s is not None + if s.id not in self.storage_memo and ( + r.is_nested + or ( + r.stride() == strides + and r.storage_offset() == storage_offset + ) + ): + # You're normal and happy, install the fresh storage into the memo + self.set_storage_memo(s, r.untyped_storage()) + if self.copy_data: + assert _is_fake_tensor(r) + assert r.real_tensor is not None + _set_real_storage( + r.untyped_storage(), r.real_tensor.untyped_storage() + ) + else: + # You're in crazy town; somehow you gave us a tensor + # that wasn't a view, but had nonzero storage offset, + # nontrivial strides (such that clone() couldn't + # preserve them), or already aliases with another + # tensor's storage. The most typical way to end + # up here is with set_. So use set_ to bludgeon this + # in. + r_s = self.meta_storage(s, callback=callback) + # NB: In principle, this should always work, but there + # is some subtle difference in the autograd metadata + # that means we will backprop the set_ call, even if + # r is declared as an input to grad. + # See https://github.com/pytorch/pytorch/issues/87956 + # for the reproducer. + # NB: The in_kernel_invocation_manager here is necessary + # for fake tensor. If we run the set_ call with fake + # tensor on, r will improperly report that it is NOT a + # meta tensor but a cpu tensor, and then the set_ call + # will fail due to device mismatch. no_dispatch() is + # not enough, because the fake tensor will still claim + # to be a CPU tensor and you'll end up in the CPU + # kernel. Arguably this is a hack; a cleaner way to + # solve this is to have a FakeStorage concept which + # would report it's CPU device--no problem now! But + # this is difficult to do because we don't have storage + # subclasses. 
Relevant test is + # DynamicShapesFunctionTests::test_add_dynamic_shapes in + # test/dynamo/test_dynamic_shapes.py + maybe_fake_mgr: AbstractContextManager[ + None + ] = contextlib.nullcontext() + from torch._subclasses.fake_tensor import ( + in_kernel_invocation_manager, + maybe_get_fake_mode, + ) + + mb_fake_mode = maybe_get_fake_mode(r) + if mb_fake_mode is not None: + maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) + with torch.no_grad(), maybe_suppress(): + with maybe_fake_mgr: + r.set_(r_s, storage_offset, sizes, strides) + if self.copy_data: + with torch.no_grad(), no_dispatch(): + assert _is_fake_tensor(r) + assert r.real_tensor is not None + assert t.stride is not None + r.real_tensor.set_( + _get_real_storage(r_s), + t.storage_offset, + t.size, + t.stride, + ) + + if t.grad is not None: + from torch._dynamo.source import AttrSource + + # TODO: Use a valid grad-specific symbolic context instead of recycling + # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). + r.grad = self.meta_tensor( + t.grad, + shape_env, + callback, + AttrSource(source, "grad"), + symbolic_context, + ) + torch._C._set_conj(r, t.is_conj) + torch._C._set_neg(r, t.is_neg) + # This can be skipped if necessary for performance reasons + skip_leaf = ( + t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE + ) + assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf) + # Thanks to storage resizing, it's possible to end up with a tensor + # that advertises a real size, but has a storage that actually has zero bytes. + # Need to reflect this in the generated FakeTensor. + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if t.storage is not None and guard_or_false(t.storage.size == 0): + r.untyped_storage().resize_(0) + + if t.is_parameter: + r._is_param = True + + # See Note: [Creating symbolic nested int] + if t.nested_int is not None: + assert _is_fake_tensor(r) + r.nested_int_memo = r.fake_mode.create_symbolic_nested_int( + nt_tensor_id=t.nested_int + ) + + self.set_tensor_memo(t, r) + + return self._checked_get_tensor_memo(t) + + def __call__( + self, + t: torch.Tensor, + shape_env: Optional[ShapeEnv] = None, + *, + callback: Optional[_MetaTensorCallback[_TensorT]] = None, + source: Optional[Source] = None, + symbolic_context: Optional[SymbolicContext] = None, + # Controls whether or not we should dump the tensor metadata to structured logs + # when source is not None. Because we refakify after Dynamo is done, + # we don't want to dump info again from AOTAutograd, it is redundant. + trace: bool = True, + ) -> _TensorT: + callback_: _MetaTensorCallback[_TensorT] + if callback is None: + callback_ = self._identity_callable + else: + callback_ = callback + # TODO: zero tensors? We appear to have eliminated them by + # excluding complex for now + + # Filter out cases we don't support + # TODO: This can probably be simplified quite a bit + if isinstance(t, torch.Tensor): + if ( + # Lazy tensors are not supported. 
Note that XLA is + # implemented on top of lazy tensor, not excluded here; we + # have some special handling for it; this is for XLA Dynamo + # integration + t.device.type == "lazy" + or + # Quantization is not supported + t.is_quantized + or + # Views out of sparse tensors not currently supported (plain + # sparse is supported though) + (t._is_view() and t._base is not None and t._base.is_sparse) + ): + self.miss += 1 + return NotImplemented + else: + self.hit += 1 + elif torch.overrides.is_tensor_like(t): + self.miss += 1 + return NotImplemented + else: + # non-Tensor types don't count as hit or miss + return t + + if source is None: + trace = False + + # Describe the tensor. NB: do NOT disable ambient modes, we may need + # to query them when figuring out what to put in here + t_desc = self.describer.describe_tensor(t, trace=trace) + + if trace: + assert source is not None + trace_structured( + "describe_source", + metadata_fn=lambda: { + "describer_id": self.describer.id, + "id": t_desc.id, + "source": source.name(), + }, + ) + + # Do the meta-fication. Here, we disable all the ambient modes, to + # better simulate what it would be like to re-fakeify from a fresh + # process + with contextlib.ExitStack() as exit_stack: + exit_stack.enter_context(torch._dispatch.python.suspend_functionalization()) + st = peek_interpreter_stack() + if st is not None: + exit_stack.enter_context( + torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() + ) + + r = self.meta_tensor( + t_desc, + shape_env, + callback_, + source, + symbolic_context, + ) + + if type(t) is torch.nn.Parameter: + # NB: Cannot directly use Parameter constructor + # because that would force a detach, not desirable + r._is_param = True + + # TODO: return the description for later + return r + + +import torch._prims_common as utils diff --git a/phivenv/Lib/site-packages/torch/_subclasses/schema_check_mode.py b/phivenv/Lib/site-packages/torch/_subclasses/schema_check_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..0b7359e4b3a32915fbfce482cc49859c358c6baa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_subclasses/schema_check_mode.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +from collections import namedtuple +from copy import deepcopy +from itertools import combinations + +import torch +from torch.fx.operator_schemas import normalize_function +from torch.utils import _pytree as pytree +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map + + +# Named Tuples used within SchemaCheckMode +Mutation = namedtuple("Mutation", ["op_name", "arg_name"]) +Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"]) + +# Simplified naming for C++ classes +SchemaArgument = torch._C._SchemaArgument +SchemaArgType = torch._C._SchemaArgType +SchemaInfo = torch._C._SchemaInfo + +# This TorchDispatchMode Subclass is used to verify op schemas +# This TorchDispatchMode Subclass currently: +# - Records the called ops +# - Checks for mutations on all inputs +# - Checks for aliasing on all inputs + + +# move these 2 functions here to avoid numpy dependency in testing/_internal/common_utils.py + + +def is_iterable_of_tensors(iterable): + # Tensor itself is iterable so we check this first + if isinstance(iterable, torch.Tensor): + return False + try: + if len(iterable) == 0: + return False + for t in iter(iterable): + if not isinstance(t, torch.Tensor): + return False + except TypeError: + return False + return True + + +def clone_inputs(args): + 
inputs = [] + + for arg in args: + if isinstance(arg, torch.Tensor): + inputs.append(arg.detach().clone()) + elif is_iterable_of_tensors(arg): + inputs.append([t.detach().clone() for t in arg]) + else: + inputs.append(arg) + + return inputs + + +class SchemaCheckMode(TorchDispatchMode): + def __init__(self) -> None: + # Information recorded for testing purposes. For example: + # - incorrect schemas + # - overly conservative schemas + self.ops = [] + self.mutated = [] + self.aliasing = [] + + def reset_cache(self): + self.ops.clear() + self.mutated.clear() + self.aliasing.clear() + + def display_ops(self): + print(*self.ops, sep=",") + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + def bitwise_equal(lhs, rhs): + if lhs.is_quantized: + # TODO: This is only OK if can't have NaN quantized; idk if + # this is actually true + return torch.equal(lhs, rhs) + else: + return torch.allclose(lhs, rhs, equal_nan=True) + + def has_mutated(before, after, md): + are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor + if ( + are_tensors + and before.layout != torch.sparse_csr + and after.layout != torch.sparse_csr + ): + return not ( + before.size() == after.size() + and bitwise_equal(before, after) + and md[0] == after.stride() + and md[1] == after._typed_storage()._cdata + ) + return False + + def has_aliased(lhs, rhs): + try: + return torch._C._overlaps(lhs, rhs) + except Exception as exception: + if str(exception).startswith("Cannot inspect value of type "): + return False + else: + raise exception + + def standardize_name(name): + return name if name != "self" else "input" + + def unwrap(e): + if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor: + try: + return e.elem + except AttributeError: + return e + return e + + def parse_metadata(e): + if isinstance(e, torch.Tensor): + if not type(e) == torch.Tensor: + try: + current = e.elem + return ( + deepcopy(current.stride()), + current._typed_storage()._cdata, + ) + except AttributeError: + return None + # Sparse CSR tensors do not have strides or storage + elif e.layout != torch.sparse_csr: + return (deepcopy(e.stride()), e._typed_storage()._cdata) + return None + + self.ops.append(func._schema.name) + + # Clone and process arguments and outputs + pre_arguments = normalize_function( + func, args, kwargs, normalize_to_only_use_kwargs=True + ).kwargs + + c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values()))) + cloned_arguments = { + name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args + } + cloned_metadata = { + name: [ + parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name)) + ] + for name in pre_arguments + } + + out = func(*args, **kwargs) + arguments = { + name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments + } + tuple_out = out if isinstance(out, tuple) else (out,) + tuple_out = tree_map(unwrap, tuple_out) + + schema_info = SchemaInfo(func._schema) + schema_info.add_argument_values(pre_arguments) + + # Process arguments with outputs + for i in range(len(func._schema.arguments)): + arg = func._schema.arguments[i] + name = standardize_name(arg.name) + if arguments.get(name) is not None: + before = cloned_arguments.get(name) + md = cloned_metadata.get(name) + after = arguments.get(name) + for j in range(len(tuple_out)): + # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe) + unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split") + if ( + has_aliased(tuple_out[j], after) + and 
func._schema.name not in unsafe_ops + ): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, j), + SchemaArgument(SchemaArgType.input, i), + ): + raise RuntimeError( + f"Argument {name} is not defined to alias output but was aliasing" + ) + else: + self.aliasing.append( + Aliasing(func._schema.name, name, f"output_{j}") + ) + if after is tuple_out[j] and isinstance(after, torch.Tensor): + # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs. + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ) and func not in [ + torch.ops.aten.lift.default, + torch.ops.aten.lift_fresh.default, + ]: + raise RuntimeError( + f"""\ +Dispatcher operators below autograd are not allowed to directly return inputs. +However, we found that `outputs[{str(j)}] is {name}""" + ) + if any( + has_mutated(a, b, c) + for a, b, c in zip( + pytree.tree_leaves(before), pytree.tree_leaves(after), md + ) + ): + if not schema_info.is_mutable( + SchemaArgument(SchemaArgType.input, i) + ): + raise RuntimeError( + f"Argument {name} is not defined as mutable but was mutated" + ) + else: + self.mutated.append(Mutation(func._schema.name, name)) + + # Aliasing between outputs + for i, j in combinations(range(len(func._schema.returns)), 2): + if has_aliased(tuple_out[i], tuple_out[j]): + if not schema_info.may_contain_alias( + SchemaArgument(SchemaArgType.output, i), + SchemaArgument(SchemaArgType.output, j), + ): + raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly") + + return out diff --git a/phivenv/Lib/site-packages/torch/_vendor/__init__.py b/phivenv/Lib/site-packages/torch/_vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1825c674b0fe14320b697dc844d85f3549bb5371 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_vendor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/__init__.py b/phivenv/Lib/site-packages/torch/_vendor/packaging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d2ad1ca0a2bf4d73bb6dc5252c3407dd0f20d14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_vendor/packaging/__init__.py @@ -0,0 +1,15 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. 
+ +__title__ = "packaging" +__summary__ = "Core utilities for Python packages" +__uri__ = "https://github.com/pypa/packaging" + +__version__ = "23.2" + +__author__ = "Donald Stufft and individual contributors" +__email__ = "donald@stufft.io" + +__license__ = "BSD-2-Clause or Apache-2.0" +__copyright__ = "2014 %s" % __author__ diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4370e039aa7e011dc47de6d5bb865e1a56d7aa25 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a95dd9d5a1936eef7cad20f638142db1d27b59 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/_structures.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d02b0b69d715117543ddffbf838bae35d6fd64e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/_vendor/packaging/__pycache__/version.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/_structures.py b/phivenv/Lib/site-packages/torch/_vendor/packaging/_structures.py new file mode 100644 index 0000000000000000000000000000000000000000..dfc91962d80e24f98b76d0da1d765fc78b0a1dcb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_vendor/packaging/_structures.py @@ -0,0 +1,61 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. 
+ + +class InfinityType: + def __repr__(self) -> str: + return "Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return False + + def __le__(self, other: object) -> bool: + return False + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return True + + def __ge__(self, other: object) -> bool: + return True + + def __neg__(self: object) -> "NegativeInfinityType": + return NegativeInfinity + + +Infinity = InfinityType() + + +class NegativeInfinityType: + def __repr__(self) -> str: + return "-Infinity" + + def __hash__(self) -> int: + return hash(repr(self)) + + def __lt__(self, other: object) -> bool: + return True + + def __le__(self, other: object) -> bool: + return True + + def __eq__(self, other: object) -> bool: + return isinstance(other, self.__class__) + + def __gt__(self, other: object) -> bool: + return False + + def __ge__(self, other: object) -> bool: + return False + + def __neg__(self: object) -> InfinityType: + return Infinity + + +NegativeInfinity = NegativeInfinityType() diff --git a/phivenv/Lib/site-packages/torch/_vendor/packaging/version.py b/phivenv/Lib/site-packages/torch/_vendor/packaging/version.py new file mode 100644 index 0000000000000000000000000000000000000000..e1cca483cee045aa1acfa9f5cf27c0331cc532aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/_vendor/packaging/version.py @@ -0,0 +1,563 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. +""" +.. testsetup:: + + from packaging.version import parse, Version +""" + +import itertools +import re +from typing import Any, Callable, NamedTuple, Optional, SupportsInt, Tuple, Union + +from ._structures import Infinity, InfinityType, NegativeInfinity, NegativeInfinityType + +__all__ = ["VERSION_PATTERN", "parse", "Version", "InvalidVersion"] + +LocalType = Tuple[Union[int, str], ...] + +CmpPrePostDevType = Union[InfinityType, NegativeInfinityType, Tuple[str, int]] +CmpLocalType = Union[ + NegativeInfinityType, + Tuple[Union[Tuple[int, str], Tuple[NegativeInfinityType, Union[int, str]]], ...], +] +CmpKey = Tuple[ + int, + Tuple[int, ...], + CmpPrePostDevType, + CmpPrePostDevType, + CmpPrePostDevType, + CmpLocalType, +] +VersionComparisonMethod = Callable[[CmpKey, CmpKey], bool] + + +class _Version(NamedTuple): + epoch: int + release: Tuple[int, ...] + dev: Optional[Tuple[str, int]] + pre: Optional[Tuple[str, int]] + post: Optional[Tuple[str, int]] + local: Optional[LocalType] + + +def parse(version: str) -> "Version": + """Parse the given version string. + + >>> parse('1.0.dev1') + + + :param version: The version string to parse. + :raises InvalidVersion: When the version string is not a valid version. + """ + return Version(version) + + +class InvalidVersion(ValueError): + """Raised when a version string is not a valid version. + + >>> Version("invalid") + Traceback (most recent call last): + ... + packaging.version.InvalidVersion: Invalid version: 'invalid' + """ + + +class _BaseVersion: + _key: Tuple[Any, ...] + + def __hash__(self) -> int: + return hash(self._key) + + # Please keep the duplicated `isinstance` check + # in the six comparisons hereunder + # unless you find a way to avoid adding overhead function calls. 
+ def __lt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key < other._key + + def __le__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key <= other._key + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key == other._key + + def __ge__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key >= other._key + + def __gt__(self, other: "_BaseVersion") -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key > other._key + + def __ne__(self, other: object) -> bool: + if not isinstance(other, _BaseVersion): + return NotImplemented + + return self._key != other._key + + +# Deliberately not anchored to the start and end of the string, to make it +# easier for 3rd party code to reuse +_VERSION_PATTERN = r""" + v? + (?: + (?:(?P[0-9]+)!)? # epoch + (?P[0-9]+(?:\.[0-9]+)*) # release segment + (?P
+        (?P<pre>                                          # pre-release
+            [-_\.]?
+            (?P<pre_l>alpha|a|beta|b|preview|pre|c|rc)
+            [-_\.]?
+            (?P<pre_n>[0-9]+)?
+        )?
+        (?P<post>                                         # post release
+            (?:-(?P<post_n1>[0-9]+))
+            |
+            (?:
+                [-_\.]?
+                (?P<post_l>post|rev|r)
+                [-_\.]?
+                (?P<post_n2>[0-9]+)?
+            )
+        )?
+        (?P<dev>                                          # dev release
+            [-_\.]?
+            (?P<dev_l>dev)
+            [-_\.]?
+            (?P<dev_n>[0-9]+)?
+        )?
+    )
+    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
+"""
+
+VERSION_PATTERN = _VERSION_PATTERN
+"""
+A string containing the regular expression used to match a valid version.
+
+The pattern is not anchored at either end, and is intended for embedding in larger
+expressions (for example, matching a version number as part of a file name). The
+regular expression should be compiled with the ``re.VERBOSE`` and ``re.IGNORECASE``
+flags set.
+
+:meta hide-value:
+"""
+
+
+class Version(_BaseVersion):
+    """This class abstracts handling of a project's versions.
+
+    A :class:`Version` instance is comparison aware and can be compared and
+    sorted using the standard Python interfaces.
+
+    >>> v1 = Version("1.0a5")
+    >>> v2 = Version("1.0")
+    >>> v1
+    <Version('1.0a5')>
+    >>> v2
+    <Version('1.0')>
+    >>> v1 < v2
+    True
+    >>> v1 == v2
+    False
+    >>> v1 > v2
+    False
+    >>> v1 >= v2
+    False
+    >>> v1 <= v2
+    True
+    """
+
+    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
+    _key: CmpKey
+
+    def __init__(self, version: str) -> None:
+        """Initialize a Version object.
+
+        :param version:
+            The string representation of a version which will be parsed and normalized
+            before use.
+        :raises InvalidVersion:
+            If the ``version`` does not conform to PEP 440 in any way then this
+            exception will be raised.
+        """
+
+        # Validate the version and parse it into pieces
+        match = self._regex.search(version)
+        if not match:
+            raise InvalidVersion(f"Invalid version: '{version}'")
+
+        # Store the parsed out pieces of the version
+        self._version = _Version(
+            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
+            release=tuple(int(i) for i in match.group("release").split(".")),
+            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
+            post=_parse_letter_version(
+                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
+            ),
+            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
+            local=_parse_local_version(match.group("local")),
+        )
+
+        # Generate a key which will be used for sorting
+        self._key = _cmpkey(
+            self._version.epoch,
+            self._version.release,
+            self._version.pre,
+            self._version.post,
+            self._version.dev,
+            self._version.local,
+        )
+
+    def __repr__(self) -> str:
+        """A representation of the Version that shows all internal state.
+
+        >>> Version('1.0.0')
+        <Version('1.0.0')>
+        """
+        return f"<Version('{str(self)}')>"
+
+    def __str__(self) -> str:
+        """A string representation of the version that can be rounded-tripped.
+
+        >>> str(Version("1.0a5"))
+        '1.0a5'
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        # Pre-release
+        if self.pre is not None:
+            parts.append("".join(str(x) for x in self.pre))
+
+        # Post-release
+        if self.post is not None:
+            parts.append(f".post{self.post}")
+
+        # Development release
+        if self.dev is not None:
+            parts.append(f".dev{self.dev}")
+
+        # Local version segment
+        if self.local is not None:
+            parts.append(f"+{self.local}")
+
+        return "".join(parts)
+
+    @property
+    def epoch(self) -> int:
+        """The epoch of the version.
+
+        >>> Version("2.0.0").epoch
+        0
+        >>> Version("1!2.0.0").epoch
+        1
+        """
+        return self._version.epoch
+
+    @property
+    def release(self) -> Tuple[int, ...]:
+        """The components of the "release" segment of the version.
+
+        >>> Version("1.2.3").release
+        (1, 2, 3)
+        >>> Version("2.0.0").release
+        (2, 0, 0)
+        >>> Version("1!2.0.0.post0").release
+        (2, 0, 0)
+
+        Includes trailing zeroes but not the epoch or any pre-release / development /
+        post-release suffixes.
+        """
+        return self._version.release
+
+    @property
+    def pre(self) -> Optional[Tuple[str, int]]:
+        """The pre-release segment of the version.
+
+        >>> print(Version("1.2.3").pre)
+        None
+        >>> Version("1.2.3a1").pre
+        ('a', 1)
+        >>> Version("1.2.3b1").pre
+        ('b', 1)
+        >>> Version("1.2.3rc1").pre
+        ('rc', 1)
+        """
+        return self._version.pre
+
+    @property
+    def post(self) -> Optional[int]:
+        """The post-release number of the version.
+
+        >>> print(Version("1.2.3").post)
+        None
+        >>> Version("1.2.3.post1").post
+        1
+        """
+        return self._version.post[1] if self._version.post else None
+
+    @property
+    def dev(self) -> Optional[int]:
+        """The development number of the version.
+
+        >>> print(Version("1.2.3").dev)
+        None
+        >>> Version("1.2.3.dev1").dev
+        1
+        """
+        return self._version.dev[1] if self._version.dev else None
+
+    @property
+    def local(self) -> Optional[str]:
+        """The local version segment of the version.
+
+        >>> print(Version("1.2.3").local)
+        None
+        >>> Version("1.2.3+abc").local
+        'abc'
+        """
+        if self._version.local:
+            return ".".join(str(x) for x in self._version.local)
+        else:
+            return None
+
+    @property
+    def public(self) -> str:
+        """The public portion of the version.
+
+        >>> Version("1.2.3").public
+        '1.2.3'
+        >>> Version("1.2.3+abc").public
+        '1.2.3'
+        >>> Version("1.2.3+abc.dev1").public
+        '1.2.3'
+        """
+        return str(self).split("+", 1)[0]
+
+    @property
+    def base_version(self) -> str:
+        """The "base version" of the version.
+
+        >>> Version("1.2.3").base_version
+        '1.2.3'
+        >>> Version("1.2.3+abc").base_version
+        '1.2.3'
+        >>> Version("1!1.2.3+abc.dev1").base_version
+        '1!1.2.3'
+
+        The "base version" is the public version of the project without any pre or post
+        release markers.
+        """
+        parts = []
+
+        # Epoch
+        if self.epoch != 0:
+            parts.append(f"{self.epoch}!")
+
+        # Release segment
+        parts.append(".".join(str(x) for x in self.release))
+
+        return "".join(parts)
+
+    @property
+    def is_prerelease(self) -> bool:
+        """Whether this version is a pre-release.
+
+        >>> Version("1.2.3").is_prerelease
+        False
+        >>> Version("1.2.3a1").is_prerelease
+        True
+        >>> Version("1.2.3b1").is_prerelease
+        True
+        >>> Version("1.2.3rc1").is_prerelease
+        True
+        >>> Version("1.2.3dev1").is_prerelease
+        True
+        """
+        return self.dev is not None or self.pre is not None
+
+    @property
+    def is_postrelease(self) -> bool:
+        """Whether this version is a post-release.
+
+        >>> Version("1.2.3").is_postrelease
+        False
+        >>> Version("1.2.3.post1").is_postrelease
+        True
+        """
+        return self.post is not None
+
+    @property
+    def is_devrelease(self) -> bool:
+        """Whether this version is a development release.
+
+        >>> Version("1.2.3").is_devrelease
+        False
+        >>> Version("1.2.3.dev1").is_devrelease
+        True
+        """
+        return self.dev is not None
+
+    @property
+    def major(self) -> int:
+        """The first item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").major
+        1
+        """
+        return self.release[0] if len(self.release) >= 1 else 0
+
+    @property
+    def minor(self) -> int:
+        """The second item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").minor
+        2
+        >>> Version("1").minor
+        0
+        """
+        return self.release[1] if len(self.release) >= 2 else 0
+
+    @property
+    def micro(self) -> int:
+        """The third item of :attr:`release` or ``0`` if unavailable.
+
+        >>> Version("1.2.3").micro
+        3
+        >>> Version("1").micro
+        0
+        """
+        return self.release[2] if len(self.release) >= 3 else 0
+
+
+def _parse_letter_version(
+    letter: Optional[str], number: Union[str, bytes, SupportsInt, None]
+) -> Optional[Tuple[str, int]]:
+
+    if letter:
+        # We consider there to be an implicit 0 in a pre-release if there is
+        # not a numeral associated with it.
+        if number is None:
+            number = 0
+
+        # We normalize any letters to their lower case form
+        letter = letter.lower()
+
+        # We consider some words to be alternate spellings of other words and
+        # in those cases we want to normalize the spellings to our preferred
+        # spelling.
+        if letter == "alpha":
+            letter = "a"
+        elif letter == "beta":
+            letter = "b"
+        elif letter in ["c", "pre", "preview"]:
+            letter = "rc"
+        elif letter in ["rev", "r"]:
+            letter = "post"
+
+        return letter, int(number)
+    if not letter and number:
+        # We assume if we are given a number, but we are not given a letter
+        # then this is using the implicit post release syntax (e.g. 1.0-1)
+        letter = "post"
+
+        return letter, int(number)
+
+    return None
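The normalization above collapses alternate spellings and fills in implicit numbers; a short sketch of the observable behaviour through the public ``packaging`` API, which this vendored copy mirrors:

    from packaging.version import Version

    # Alternate pre-release spellings are normalized to "a", "b" or "rc".
    assert Version("1.0alpha1").pre == ("a", 1)
    assert Version("1.0-preview-2").pre == ("rc", 2)

    # "rev"/"r" normalize to "post"; a bare "-N" uses the implicit post syntax.
    assert Version("1.0.rev4").post == 4
    assert Version("1.0-1").post == 1

    # A missing numeral is treated as an implicit 0.
    assert Version("1.0rc").pre == ("rc", 0)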
+
+
+_local_version_separators = re.compile(r"[\._-]")
+
+
+def _parse_local_version(local: Optional[str]) -> Optional[LocalType]:
+    """
+    Takes a string like abc.1.twelve and turns it into ("abc", 1, "twelve").
+    """
+    if local is not None:
+        return tuple(
+            part.lower() if not part.isdigit() else int(part)
+            for part in _local_version_separators.split(local)
+        )
+    return None
+
+
+def _cmpkey(
+    epoch: int,
+    release: Tuple[int, ...],
+    pre: Optional[Tuple[str, int]],
+    post: Optional[Tuple[str, int]],
+    dev: Optional[Tuple[str, int]],
+    local: Optional[LocalType],
+) -> CmpKey:
+
+    # When we compare a release version, we want to compare it with all of the
+    # trailing zeros removed. So we'll reverse the list, drop all the now-leading
+    # zeros until we come to something non-zero, then re-reverse it back into the
+    # correct order, make it a tuple, and use that for our sorting key.
+    _release = tuple(
+        reversed(list(itertools.dropwhile(lambda x: x == 0, reversed(release))))
+    )
+
+    # We need to "trick" the sorting algorithm to put 1.0.dev0 before 1.0a0.
+    # We'll do this by abusing the pre segment, but we _only_ want to do this
+    # if there is not a pre or a post segment. If we have one of those then
+    # the normal sorting rules will handle this case correctly.
+    if pre is None and post is None and dev is not None:
+        _pre: CmpPrePostDevType = NegativeInfinity
+    # Versions without a pre-release (except as noted above) should sort after
+    # those with one.
+    elif pre is None:
+        _pre = Infinity
+    else:
+        _pre = pre
+
+    # Versions without a post segment should sort before those with one.
+    if post is None:
+        _post: CmpPrePostDevType = NegativeInfinity
+
+    else:
+        _post = post
+
+    # Versions without a development segment should sort after those with one.
+    if dev is None:
+        _dev: CmpPrePostDevType = Infinity
+
+    else:
+        _dev = dev
+
+    if local is None:
+        # Versions without a local segment should sort before those with one.
+        _local: CmpLocalType = NegativeInfinity
+    else:
+        # Versions with a local segment need that segment parsed to implement
+        # the sorting rules in PEP440.
+        # - Alpha numeric segments sort before numeric segments
+        # - Alpha numeric segments sort lexicographically
+        # - Numeric segments sort numerically
+        # - Shorter versions sort before longer versions when the prefixes
+        #   match exactly
+        _local = tuple(
+            (i, "") if isinstance(i, int) else (NegativeInfinity, i) for i in local
+        )
+
+    return epoch, _release, _pre, _post, _dev, _local
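Taken together, these comparison keys implement the PEP 440 ordering described in the comments above; a small sketch (again via the public ``packaging`` API) of the resulting order:

    from packaging.version import Version

    ordered = [
        Version("1.0.dev0"),      # dev-only releases sort before pre-releases
        Version("1.0a1"),         # pre-releases sort before the final release
        Version("1.0"),
        Version("1.0+ubuntu.1"),  # a local segment sorts after the bare version
        Version("1.0.post1"),     # post releases sort after the final release
        Version("1.1"),
    ]
    assert ordered == sorted(ordered)

    # Trailing zeros in the release segment are ignored for comparison.
    assert Version("1.0.0") == Version("1.0")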
diff --git a/phivenv/Lib/site-packages/torch/lib/kineto.lib b/phivenv/Lib/site-packages/torch/lib/kineto.lib
new file mode 100644
index 0000000000000000000000000000000000000000..f70c1cd6a77114b22bbed63fa1be677c1530e3d4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/lib/kineto.lib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df1d865c452efc35f9baf212baf738018c4691422826718d18d09fa7502c46fd
+size 57893210
diff --git a/phivenv/Lib/site-packages/torch/lib/libprotobuf-lite.lib b/phivenv/Lib/site-packages/torch/lib/libprotobuf-lite.lib
new file mode 100644
index 0000000000000000000000000000000000000000..ede1f105eaeab5c9dadc042b6cafe480babbd2f2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/lib/libprotobuf-lite.lib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0aae59299b6e96a450237fa112a4b0345038f7a42c8f5f09412f8fed25407509
+size 4661164
diff --git a/phivenv/Lib/site-packages/torch/lib/libprotobuf.lib b/phivenv/Lib/site-packages/torch/lib/libprotobuf.lib
new file mode 100644
index 0000000000000000000000000000000000000000..c2d8cbba6fe3a51d1e1e54d19895ab541981b60e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/lib/libprotobuf.lib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa673cd04fef9b0a5d453776946999ff3be8089296b3c9b0edf66cdd4c678192
+size 32694610
diff --git a/phivenv/Lib/site-packages/torch/lib/libprotoc.lib b/phivenv/Lib/site-packages/torch/lib/libprotoc.lib
new file mode 100644
index 0000000000000000000000000000000000000000..e48e2391acd94f850811dba0cd9d47a3649b61e7
--- /dev/null
+++ b/phivenv/Lib/site-packages/torch/lib/libprotoc.lib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:43944f08849b86f9175d07ab9e292ce918c62984a13818174046d92929ff4794
+size 38476918
diff --git a/phivenv/Lib/site-packages/torchgen/__init__.py b/phivenv/Lib/site-packages/torchgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8b61af2c4b58ff14ab7e3b24bf22e8ec6a95da0
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/__init__.py
@@ -0,0 +1,10 @@
+"""torchgen
+
+This module contains code generation utilities for PyTorch. It is used to
+build PyTorch from source, but may also be used for out-of-tree projects
+that extend PyTorch.
+
+Note well that we provide no BC guarantees for torchgen. If you're interested
+in using torchgen and want the PyTorch team to be aware, please reach out
+on GitHub.
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..539958d2d53be19433c439371a89e65f91ab51f8
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ffa5762bd98ca15ce2f19c21d32b145e2fc40e1
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/code_template.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bbff5cd75bffef9a442b4cab29b9bf4e24efb238
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/context.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a74e770f38dead67f8010ad41a89ac677d44510
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85a707ef36b272db98f970a6f368e9529a0e35e1
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_aoti_c_shim.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37de0ece9697daed4a537cc19c58ce6b0435b71f
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_backend_stubs.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a2d0a3d667a174e7b776a84068f88bb7958936d
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_functionalization_type.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d7a5a81ef759b2319516187aa6e137541629044c
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_lazy_tensor.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aec279893fabd9d0eeaae659f2bfd168b68bcdca
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_schema_utils.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d631e1134cba0dd639f05f113f67d7b66263c45b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/gen_vmap_plumbing.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..01ac54f20ed0964312b67acfeb03ce9103a1b80d
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/local.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..db28fb4cd03dfd9d85591ba1f0b3b1dae47eb634
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/model.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5e2adab4d2fa4fae0542d7bb7728f417054a43f6
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/native_function_generation.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f853dfe1ca143f1c80ad007a5c8174da6666764b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/utils.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..985cf0545d514df2380fa33fd7f904142b0199fe
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/__pycache__/yaml_utils.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/aoti/__init__.py b/phivenv/Lib/site-packages/torchgen/aoti/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d39094050d9f4e95bdd969ea2701bb9b4c9d38c3
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa046ffa316099119f5f0784892ff0f37d82588b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/aoti/__pycache__/fallback_ops.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/aoti/fallback_ops.py b/phivenv/Lib/site-packages/torchgen/aoti/fallback_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc23fcd5c263dbd09c7dcbfae9ebd70e87ad52b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/aoti/fallback_ops.py
@@ -0,0 +1,175 @@
+# Be extra careful when you edit this file, because it affects AOTInductor ABI compatibility. See
+# https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436
+# for details.
+#
+# The inductor_fallback_ops list is based on the fallback ops from torch/_inductor/lowering.py.
+#
+# Generally speaking, it is ok to add a new op to the list, but you need to run
+# `python torchgen/gen.py --update-aoti-c-shim` in order to regenerate C shim header files.
+# But it is NOT ok to remove an existing fallback op from the list, since that will break
+# some existing AOTInductor-compiled models.
+#
+# A fallback op version defaults to 1. If you want to extend an existing fallback op by adding
+# a new argument with a default value, while it is fine in the Python world, it will be BC-breaking
+# when generating C shim. Thus you need to bump up the version number of that fallback op by
+# updating the entry in the inductor_fallback_ops list, adding a new version number with a list
+# of new arguments, and then run `python torchgen/gen.py --update-aoti-c-shim` to regenerate.
+
+inductor_fallback_ops: dict[str, dict[str, list[str]]] = {
+    "aten._adaptive_avg_pool2d_backward.default": {},
+    "aten._adaptive_avg_pool2d.default": {},
+    "aten._adaptive_avg_pool3d_backward.default": {},
+    "aten._adaptive_avg_pool3d.default": {},
+    "aten._addmm_activation.default": {},
+    "aten._cdist_backward.default": {},
+    "aten._cdist_forward.default": {},
+    "aten._cudnn_rnn.default": {},
+    "aten._dyn_quant_matmul_4bit.default": {},
+    "aten._dyn_quant_pack_4bit_weight.default": {},
+    "aten._efficient_attention_backward.default": {},
+    "aten._efficient_attention_forward.default": {},
+    "aten._efficientzerotensor.default": {},
+    "aten._embedding_bag_dense_backward.default": {},
+    "aten._embedding_bag_forward_only.default": {},
+    "aten._embedding_bag_per_sample_weights_backward.default": {},
+    "aten._embedding_bag.default": {},
+    "aten._fft_c2c.default": {},
+    "aten._fft_r2c.default": {},
+    "aten._flash_attention_backward.default": {},
+    "aten._flash_attention_forward.default": {},
+    "aten._fused_moving_avg_obs_fq_helper_functional.default": {},
+    "aten._fused_moving_avg_obs_fq_helper.default": {},
+    "aten._histogramdd_from_bin_cts.default": {},
+    "aten._int_mm.out": {},
+    "aten._pdist_backward.default": {},
+    "aten._pdist_forward.default": {},
+    "aten._scaled_dot_product_cudnn_attention_backward.default": {},
+    "aten._scaled_dot_product_cudnn_attention.default": {},
+    "aten._scaled_dot_product_efficient_attention_backward.default": {},
+    "aten._scaled_dot_product_efficient_attention.default": {},
+    "aten._scaled_dot_product_flash_attention_backward.default": {},
+    "aten._scaled_dot_product_flash_attention_for_cpu_backward.default": {},
+    "aten._scaled_dot_product_flash_attention_for_cpu.default": {},
+    "aten._scaled_dot_product_flash_attention.default": {},
+    "aten._scaled_dot_product_fused_attention_overrideable_backward.default": {},
+    "aten._scaled_dot_product_fused_attention_overrideable.default": {},
+    "aten._scaled_mm.default": {},
+    "aten._scaled_mm.out": {},
+    "aten._segment_reduce_backward.default": {},
+    "aten._thnn_fused_lstm_cell.default": {},
+    "aten._to_sparse.default": {},
+    "aten._trilinear.default": {},
+    "aten._weight_int4pack_mm.default": {},
+    "aten._weight_int8pack_mm.default": {},
+    "aten.abs.default": {},
+    "aten.adaptive_max_pool2d_backward.default": {},
+    "aten.adaptive_max_pool2d.default": {},
+    "aten.adaptive_max_pool3d_backward.default": {},
+    "aten.adaptive_max_pool3d.default": {},
+    "aten.add.Scalar": {},
+    "aten.add.Tensor": {},
+    "aten.addbmm.default": {},
+    "aten.addmm.out": {},
+    "aten.addmv.default": {},
+    "aten.angle.default": {},
+    "aten.avg_pool2d_backward.default": {},
+    "aten.avg_pool2d.default": {},
+    "aten.avg_pool3d_backward.default": {},
+    "aten.avg_pool3d.default": {},
+    "aten.baddbmm.out": {},
+    "aten.bernoulli_.float": {},
+    "aten.bernoulli_.Tensor": {},
+    "aten.bmm.out": {},
+    "aten.bucketize.Tensor": {},
+    "aten.cat.default": {},
+    "aten.cholesky_inverse.default": {},
+    "aten.cholesky_solve.default": {},
+    "aten.convolution_backward.default": {},
+    "aten.convolution.default": {},
+    "aten.cummax.default": {},
+    "aten.cummin.default": {},
+    "aten.cumprod.default": {},
+    "aten.cumsum.default": {},
+    "aten.exponential.default": {},
+    "aten.fill_.Scalar": {},
+    "aten.fractional_max_pool2d_backward.default": {},
+    "aten.fractional_max_pool2d.default": {},
+    "aten.fractional_max_pool3d_backward.default": {},
+    "aten.fractional_max_pool3d.default": {},
+    "aten.gcd.default": {},
+    "aten.geqrf.default": {},
+    "aten.grid_sampler_2d_backward.default": {},
+    "aten.hann_window.default": {},
+    "aten.histc.default": {},
+    "aten.histogram.bin_ct": {},
+    "aten.index_put.default": {},
+    "aten.index_reduce.default": {},
+    "aten.index.Tensor": {},
+    "aten.kthvalue.default": {},
+    "aten.logcumsumexp.default": {},
+    "aten.lu_unpack.default": {},
+    "aten.masked_scatter_backward.default": {},
+    "aten.masked_scatter.default": {},
+    "aten.masked_select.default": {},
+    "aten.max_pool2d_with_indices_backward.default": {},
+    "aten.max_pool2d_with_indices.default": {},
+    "aten.max_pool3d_with_indices_backward.default": {},
+    "aten.max_pool3d_with_indices.default": {},
+    "aten.max_unpool2d.default": {},
+    "aten.max_unpool3d.default": {},
+    "aten.median.default": {},
+    "aten.mm.out": {},
+    "aten.mode.default": {},
+    "aten.mul.Scalar": {},
+    "aten.mul.Tensor": {},
+    "aten.nanmedian.default": {},
+    "aten.narrow.default": {},
+    "aten.native_dropout.default": {},
+    "aten.nonzero.default": {},
+    "aten.normal_functional.default": {},
+    "aten.ormqr.default": {},
+    "aten.pad.default": {},
+    "aten.permute.default": {},
+    "aten.polar.default": {},
+    "aten.pow.Scalar": {},
+    "aten.pow.Tensor_Scalar": {},
+    "aten.pow.Tensor_Tensor": {},
+    "aten.rand.default": {},
+    "aten.rand.generator": {},
+    "aten.randint.default": {},
+    "aten.randint.generator": {},
+    "aten.randint.low_out": {},
+    "aten.randint.low": {},
+    "aten.randn.default": {},
+    "aten.randn.generator": {},
+    "aten.randperm.default": {},
+    "aten.repeat_interleave.Tensor": {},
+    "aten.replication_pad1d_backward.default": {},
+    "aten.replication_pad2d_backward.default": {},
+    "aten.reshape.default": {},
+    "aten.resize_.default": {},
+    "aten.resize_as_.default": {},
+    "aten.scatter_reduce.two_out": {},
+    "aten.scatter.src_out": {},
+    "aten.scatter.value_out": {},
+    "aten.searchsorted.Scalar": {},
+    "aten.searchsorted.Tensor": {},
+    "aten.segment_reduce.default": {},
+    "aten.set_.source_Tensor": {},
+    "aten.slice.Tensor": {},
+    "aten.soft_margin_loss_backward.default": {},
+    "aten.sort.default": {},
+    "aten.sort.stable": {},
+    "aten.squeeze.dim": {},
+    "aten.to_sparse.default": {},
+    "aten.topk.default": {},
+    "aten.triangular_solve.default": {},
+    "aten.uniform.default": {},
+    "aten.upsample_bicubic2d_backward.default": {},
+    "aten.upsample_linear1d_backward.default": {},
+    "aten.upsample_trilinear3d_backward.default": {},
+    "aten.view_as_complex.default": {},
+    "aten.view_as_real.default": {},
+    "aten.view.dtype": {},
+    "aten._weight_int4pack_mm_with_scales_and_zeros.default": {},
+}
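The header comment above describes bumping a fallback op's version when a new argument is added. A purely hypothetical sketch of what such an entry would look like under that scheme (the op name and argument are invented for illustration, not part of the real table):

    # Hypothetical: version 2 of this made-up op adds a defaulted `new_flag`
    # argument, so the entry lists the arguments introduced by v2. After editing,
    # `python torchgen/gen.py --update-aoti-c-shim` regenerates the C shim headers.
    example_entry: dict[str, dict[str, list[str]]] = {
        "aten.some_op.default": {"v2": ["new_flag"]},
    }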
diff --git a/phivenv/Lib/site-packages/torchgen/api/__init__.py b/phivenv/Lib/site-packages/torchgen/api/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61056293771973c7e7ff728f5dda89f559897f63
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..60d237f7d1e1db09c694bdefcb4d0a5919db9be5
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/autograd.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..96b812f9d5593ef4ceedaea464c669882a06ddef
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/cpp.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..620e88ae3af5e190ed3af4f2e8f8f4d68549f100
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/dispatcher.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a92db8c691f04bb4a31e333070a729eb446c8275
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/functionalization.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a22cb5692127347af11aa4b773984e50420e9b1a
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/lazy.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3720984c67ce9525a49e1f82742de8dfde6e3b1b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/meta.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ad1508eec0c7bce37dc5705ed140764466f2febd
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/native.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..122d6b6ec9be4452cfc3af82a773f505dbde8305
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/python.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..61fbb97717435e2b8714397bb4de4f08fad9534b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/structured.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f33d0869db289a24a626da7779290ec2f6d33935
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/translate.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a36d3d4ee6262e755d6da9300455aca39c0774a8
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/ufunc.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..252635bc2864d93aee589b3dc330061f53b9cb58
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/__pycache__/unboxing.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/autograd.py b/phivenv/Lib/site-packages/torchgen/api/autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12d7d1b87e4014035114c46f25dd22b51265849
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/autograd.py
@@ -0,0 +1,874 @@
+from __future__ import annotations
+
+import re
+from dataclasses import dataclass
+from typing import cast, TYPE_CHECKING
+
+from torchgen import local
+from torchgen.api import cpp
+from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
+from torchgen.model import (
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    NativeFunctionsViewGroup,
+    SchemaKind,
+    Type,
+)
+from torchgen.utils import IDENT_REGEX
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# Represents a saved attribute involved in backward calculation.
+# Note that it can be a derived property of an input argument, e.g.:
+# we could save `other.scalar_type()` instead of the entire `other` tensor.
+@dataclass(frozen=True)
+class SavedAttribute:
+    # The NamedCType holds the updated name and cpp type of the attribute.
+    # For the name, a suffix is appended if it's a derived property, e.g.: `other_scalar_type`
+    nctype: NamedCType
+
+    # The expression to read the derived property at save time, e.g.:
+    # `other.scalar_type()`.
+    expr: str
+
+
+# Represents a backward formula that calculates derivatives for one
+# or more tensors.
+@dataclass(frozen=True)
+class Derivative:
+    # The formula string (legit C++ expression).
+    # Note that expressions against input arguments have been replaced with the
+    # corresponding saved attributes.
+    # E.g.:
+    #  raw formula: `mul_tensor_backward(grad, self, other.scalar_type())`
+    #         here: `mul_tensor_backward(grad, self, other_scalar_type)`
+    formula: str
+
+    # The formula string before input argument replacement
+    original_formula: str
+
+    # Names of the arguments for which this formula calculates derivatives.
+    var_names: tuple[str, ...]
+
+    # Saved inputs that are referenced by the formula.
+    saved_inputs: tuple[SavedAttribute, ...]
+
+    # Saved outputs that are referenced by the formula.
+    saved_outputs: tuple[SavedAttribute, ...]
+
+    # Gradients that are referenced by name in the formula.
+    named_gradients: set[str]
+
+
+# Represents a forward formula that calculates forward derivatives
+# for one tensor.
+@dataclass(frozen=True)
+class ForwardDerivative:
+    # The formula string (legit C++ expression).
+    # Note that special keywords such as "linear" or "element_wise" have been
+    # replaced by the automatically generated formula.
+    formula: str
+
+    # Name of the output arguments for which this formula calculates forward
+    # derivatives
+    var_names: tuple[str, ...]
+
+    # Type of the output arguments for which this formula calculates forward
+    # derivatives
+    var_types: tuple[Type, ...]
+
+    # Inputs for which the forward derivatives are required for this formula
+    required_inputs_fw_grad: tuple[str, ...] | None
+
+    # Inputs for which the primal is required for this formula
+    required_inputs_primal: tuple[str, ...] | None
+
+    # Flag to specify if this formula requires the original value of self
+    # This is only used by inplace operations
+    required_original_self_value: bool
+
+    # If this formula is specified in derivatives.yaml or if we are reusing the
+    # out of place formula for inplace
+    is_reusing_outplace_formula: bool
+
+
+# Represents differentiability info for a NativeFunction.
+@dataclass(frozen=True)
+class DifferentiabilityInfo:
+    # The base name read from derivatives.yaml.
+    name: str
+
+    # The matching native function.
+    #
+    # There can be multiple NativeFunction having the same base name:
+    #  - different overloads with different types of input arguments;
+    #  - in-place/out/functional variants of the same function;
+    #
+    # We first use the schema string (under the 'name' key) in derivatives.yaml
+    # to find the NativeFunction having the same schema string.
+    # Then we find the in-place/out/functional variants of the matching function.
+    # Among these variants, we choose the one having the same name as the
+    # derivatives.yaml entry. If there is no exact match, then we choose the
+    # in-place variant.
+    # TODO: maybe the logic to search for all variants is no longer necessary?
+    func: NativeFunction
+
+    # The name of the generated autograd function.
+    # It's set only if we will calculate a derivative, i.e.
+    # 'args_with_derivatives' is not empty.
+    op: str | None
+
+    # The derivatives formulae for this function.
+    # Note that the length of this sequence is the number of differentiable inputs
+    derivatives: Sequence[Derivative]
+
+    # The forward derivatives formulae for this function.
+    # Note that the length of this sequence is the number of differentiable outputs
+    forward_derivatives: Sequence[ForwardDerivative]
+
+    # The union of 'saved_inputs' of all 'derivatives'.
+    all_saved_inputs: Sequence[SavedAttribute]
+
+    # The union of 'saved_outputs' of all 'derivatives'.
+    all_saved_outputs: Sequence[SavedAttribute]
+
+    # All named gradients that are available for use, in the same
+    # order as in the grads vector.
+    available_named_gradients: Sequence[str]
+
+    # The named gradients that are used in any of the derivatives.
+    # Invariant: all(name in available_named_gradients for name in used_named_gradients)
+    used_named_gradients: set[str]
+
+    # The function's input arguments for which it calculates derivatives.
+    # It's the union of 'var_names' of all 'derivatives', sorted by the
+    # argument order in the function schema.
+    args_with_derivatives: Sequence[Binding]
+
+    # Names of arguments whose derivative formula is 'non_differentiable'.
+    non_differentiable_arg_names: Sequence[str]
+
+    # Raw data read from derivatives.yaml.
+    output_differentiability: list[bool] | None
+
+    # output_differentiability in derivatives.yaml can be a list of
+    # conditions that express if the output is differentiable. In this case,
+    # the number of conditions must match the number of outputs
+    # (NB: we only support one condition right now).
+    # output_differentiability gets populated with True for each condition,
+    # while output_differentiability_conditions gets populated with the conditions
+    output_differentiability_conditions: list[str] | None
+
+    @property
+    def has_derivatives(self) -> bool:
+        return len(self.args_with_derivatives) > 0
+
+    # Generates a new DifferentiabilityInfo using the exact same set of derivative information,
+    # but with a new operator name.
+    # This is used when generating "copy" variants of view ops,
+    # which are able to use the exact same derivative formula as the original view op
+    # See Note [Codegen'd {view}_copy Operators]
+    def create_view_copy_from_view_derivative(
+        self, g: NativeFunctionsViewGroup
+    ) -> DifferentiabilityInfo | None:
+        if g.view_copy is None:
+            return None
+        f = g.view_copy
+
+        name_split_by_period = self.name.split(".", maxsplit=2)
+        # Append a "_copy" to the base name of the operator (but keep the overload name the same)
+        view_copy_name = f"{name_split_by_period[0]}_copy." + ".".join(
+            name_split_by_period[1:]
+        )
+        view_copy_op_name = None if self.op is None else f"{self.op}_copy"
+
+        return DifferentiabilityInfo(
+            # Use the "_copy" version of name/func/op
+            name=view_copy_name,
+            func=f,
+            op=view_copy_op_name,
+            # But keep all derivative info the same
+            derivatives=self.derivatives,
+            forward_derivatives=self.forward_derivatives,
+            all_saved_inputs=self.all_saved_inputs,
+            all_saved_outputs=self.all_saved_outputs,
+            available_named_gradients=self.available_named_gradients,
+            used_named_gradients=self.used_named_gradients,
+            args_with_derivatives=self.args_with_derivatives,
+            non_differentiable_arg_names=self.non_differentiable_arg_names,
+            output_differentiability=self.output_differentiability,
+            output_differentiability_conditions=self.output_differentiability_conditions,
+        )
+
+
+def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
+    if info is None:
+        return False
+    for derivative in info.derivatives:
+        formula = derivative.formula
+        if re.search(IDENT_REGEX.format(ident), formula):
+            return True
+    return False
+
+
+def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
+    return uses_ident(info, "retain_variables")
+
+
+def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
+    return uses_ident(info, "grad")
+
+
+# Represents a differentiable `Argument`.
+# How is it different from the `Argument` type?
+# - It's processed Arguments which are differentiable and only used in the
+#   context of the autograd codegen;
+# - It can represent SelfArgument or regular Argument but not TensorOptionsArgument;
+@dataclass(frozen=True)
+class DifferentiableInput:
+    name: str
+    type: Type
+
+    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+    cpp_type: str
+
+
+# Represents a differentiable `Return`.
+# How is it different from the `Return` type?
+# - The name in `Return` is optional. Here it is always populated using the same
+#   `cpp.return_names()` method.
+#   TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant?
+# - It's processed Returns which are differentiable, in compliance with the
+#   `output_differentiability` field defined in derivatives.yaml (if specified),
+#   and are only used in the context of the autograd codegen;
+@dataclass(frozen=True)
+class DifferentiableOutput:
+    name: str
+    type: Type
+
+    # TODO: only to keep it byte-for-byte compatible with the old codegen, should remove.
+    cpp_type: str
+
+
+@dataclass(frozen=True)
+class NativeFunctionWithDifferentiabilityInfo:
+    func: NativeFunction
+    info: dict[str, DifferentiabilityInfo] | None
+    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
+
+
+# TODO: Update comment below since it is out of date.
+def dispatch_strategy(fn: NativeFunctionWithDifferentiabilityInfo) -> str:
+    """How are we going to call the underlying implementation of a
+    declaration?  There are two strategies:
+        - use_derived: we want to call the implementation on CPUDoubleType
+          (or a similar, derived Type instance).  Because these derived
+          instances deal in Tensors, not Variables (it's a completely different
+          object, so it doesn't dispatch back to VariableType), code on
+          this dispatch path needs to wrap/unwrap tensors.  If the
+          derived implementation takes and returns tensors, the
+          implementation is usually differentiable (although we also use
+          the derived dispatch path for non-differentiable functions
+          that we still want to dispatch on the derived Type instance;
+          e.g., size())
+        - use_type: we want to call the implementation on Type, because
+          it is implemented concretely, and the functions it invokes will
+          get dispatched back to VariableType (which will ensure that they
+          are differentiable.)
+    """
+    # fn is derived as long as any of its per-key differentiability infos
+    # has_derivatives. dispatch_strategy() is used to guard generation of fns in VariableType
+    # and ADInplaceOrViewType. We want to generate these functions as long as a
+    # derivative is defined for ANY dispatch key.
+    if fn.func.is_abstract or (
+        fn.info is not None and any(info.has_derivatives for info in fn.info.values())
+    ):
+        # If the function is abstract (not implemented on at::Type), we must
+        # call the implementation on the derived type with unpacked tensors.
+
+        # If the function has a derivative specified and is concrete, we could
+        # call either implementation. We prefer the calling the derived
+        # type's implementation with unpacked tensors because it is more
+        # performant in some cases: any internal calls to other ATen functions
+        # won't have the history tracked.
+
+        # If the function has a type dispatched argument (i.e. is a factory),
+        # we prefer calling the derived type's implementation both because it is
+        # more performant and to ensure factory functions return tensors with _version
+        # of 0 (probably not strictly necessary, but nice to have to keep versions simple
+        # to understand).
+
+        return "use_derived"
+    else:
+        # If the function is concrete (we don't have to override it) and we
+        # didn't declare it in derivatives.yaml, we'll assume that it is
+        # actually implemented out of differentiable functions. (This
+        # assumption might not hold, but then you'll see gradcheck fail.)
+        return "use_type"
+
+
+def is_foreach_func(f: NativeFunction) -> bool:
+    return f.func.name.name.base.startswith("_foreach_")
+
+
+# note(crcrpar): Most foreach functions can reference an out-of-place `torch` function whose schema kind
+# is functional for their backward derivatives (and forward derivatives in the future), i.e.,
+# such a reference can be found in `functional_info_by_signature`. There are, however, some exceptions:
+_foreach_with_inplace_ref = {"_foreach_zero_"}
+_foreach_with_tensor_overload = {
+    "_foreach_add.Tensor",
+    "_foreach_mul.Tensor",
+    "_foreach_div.Tensor",
+}
+# The following do not support the alpha kwarg, which their non-foreach counterparts support.
+_skip_argument_len_check = {
+    "_foreach_add.Scalar",
+    "_foreach_add_.Scalar",
+    "_foreach_add.ScalarList",
+    "_foreach_add_.ScalarList",
+    "_foreach_sub.Scalar",
+    "_foreach_sub_.Scalar",
+    "_foreach_sub.ScalarList",
+    "_foreach_sub_.ScalarList",
+}
+
+
+# Checks if `function_schema` is a native, non-foreach function that `f`, a foreach function,
+# can reference to generate derivatives.
+def is_reference_for_foreach(
+    f: NativeFunction,
+    function_schema: FunctionSchema,
+) -> bool:
+    return (
+        f.func.name.name.base.split("_foreach_")[-1] == function_schema.name.name.base
+        and (
+            not function_schema.name.name.inplace
+            or str(f.func.name) in _foreach_with_inplace_ref
+        )
+        and (
+            str(f.func.name) in _skip_argument_len_check
+            or len(f.func.arguments.flat_non_out)
+            == len(function_schema.arguments.flat_non_out)
+        )
+        and all(
+            ref_arg.type in (arg.type, getattr(arg.type, "elem", None))
+            for arg, ref_arg in zip(
+                f.func.arguments.flat_non_out,
+                function_schema.arguments.flat_non_out,
+            )
+        )
+    )
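+
+
+# For instance, `_foreach_add.List` can reference the functional `add.Tensor`
+# schema: the base names match once the "_foreach_" prefix is stripped, the
+# non-out argument counts agree, and each Tensor[] argument lines up with the
+# corresponding Tensor argument of the reference schema.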
+
+
+# TODO(crcrpar): Avoid hard coding "Default" ideally.
+def gen_foreach_derivativeinfo(
+    foreach_function: NativeFunction,
+    functional_info_by_signature: dict[
+        FunctionSchema, dict[str, DifferentiabilityInfo]
+    ],
+    non_functional_info_by_signature: dict[
+        FunctionSchema, dict[str, DifferentiabilityInfo]
+    ],
+    dispatch_key: str = "Default",
+) -> tuple[DifferentiabilityInfo | None, bool]:
+    """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
+
+    The second return value indicates whether the info is generated in this function.
+    """
+    ref_diff_info: DifferentiabilityInfo | None = None
+
+    for function_schema, diff_info in functional_info_by_signature.items():
+        if not is_reference_for_foreach(foreach_function, function_schema):
+            continue
+        ref_diff_info = diff_info[dispatch_key]
+        if ref_diff_info is not None:
+            break
+    # note(crcrpar): It seems like `zero`'s info isn't available in functional_info_by_signature
+    # while the info of `zero_` is in non_functional_info_by_signature
+    if (
+        ref_diff_info is None
+        and foreach_function.func.kind() == SchemaKind.inplace
+        and str(foreach_function.func.name) in _foreach_with_inplace_ref
+    ):
+        for function_schema, diff_info in non_functional_info_by_signature.items():
+            if not is_reference_for_foreach(foreach_function, function_schema):
+                continue
+            ref_diff_info = diff_info[dispatch_key]
+            if ref_diff_info is not None:
+                break
+    if ref_diff_info is None:
+        return None, False
+
+    # In-place foreach functions reuse the existing (reference) Derivative as is.
+    if foreach_function.func.kind() == SchemaKind.inplace:
+        return ref_diff_info, False
+
+    map_refarg2foreacharg, map_name2arg = {}, {}
+    for i, (arg, ref_arg) in enumerate(
+        zip(
+            foreach_function.func.arguments.flat_non_out,
+            function_schema.arguments.flat_non_out,
+        )
+    ):
+        map_refarg2foreacharg[ref_arg.name] = arg.name
+        map_name2arg[arg.name] = arg
+
+    all_saved_inputs, all_saved_outputs, all_var_names = [], [], []
+    modified_derivative_formulas = []
+    for i, derivative in enumerate(ref_diff_info.derivatives):
+        modified_formula = derivative.formula.replace("grad", "grads[i]").replace(
+            "result", "result[i]"
+        )
+        saved_inputs, saved_outputs = [], []
+        # note(crcrpar): This context seems necessary to call `cpp.argument_type`
+        with local.parametrize(
+            use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
+            use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
+        ):
+            for ref_input in derivative.saved_inputs:
+                ref_input_jit_name = ref_input.expr.split(".")[0]
+                mapped_name = map_refarg2foreacharg[ref_input_jit_name]
+                if isinstance(map_name2arg[mapped_name].type, ListType):
+                    mapped_expr = mapped_name + "[i]"
+                else:
+                    mapped_expr = mapped_name
+                new_expr = ref_input.expr.replace(ref_input_jit_name, mapped_expr)
+                modified_formula = modified_formula.replace(
+                    cast(str, ref_input.nctype.name), new_expr
+                )
+
+                nctype = cpp.argument_type(map_name2arg[mapped_name], binds=mapped_name)
+                canonical_nctype = NamedCType(
+                    nctype.name, nctype.type.remove_const_ref()
+                )
+                saved_inputs.append(
+                    SavedAttribute(nctype=canonical_nctype, expr=mapped_name)
+                )
+            for ref_output in derivative.saved_outputs:
+                if ref_output.nctype.name == "result":
+                    saved_outputs.append(
+                        SavedAttribute(
+                            nctype=NamedCType(
+                                name="result", type=BaseCType(tensorListT)
+                            ),
+                            expr="result",
+                        )
+                    )
+                else:
+                    raise RuntimeError(
+                        f"unexpected saved output for foreach derivative: {ref_output.nctype.name}"
+                    )
+        var_names = [map_refarg2foreacharg[var] for var in derivative.var_names]
+        all_var_names.extend(var_names)
+        all_saved_inputs.extend(saved_inputs)
+        all_saved_outputs.extend(saved_outputs)
+        modified_derivative = Derivative(
+            formula=modified_formula,
+            original_formula=derivative.formula,
+            var_names=tuple(var_names),
+            saved_inputs=tuple(saved_inputs),
+            saved_outputs=tuple(saved_outputs),
+            named_gradients=set(),
+        )
+        modified_derivative_formulas.append(modified_derivative)
+
+    with local.parametrize(
+        use_const_ref_for_mutable_tensors=foreach_function.use_const_ref_for_mutable_tensors,
+        use_ilistref_for_tensor_lists=foreach_function.part_of_structured_group,
+    ):
+        args_with_derivatives = [
+            Binding(
+                name=arg.name,
+                nctype=cpp.argument_type(arg, binds=arg.name),
+                argument=arg,
+                default=None,
+            )
+            for arg in foreach_function.func.arguments.flat_non_out
+            if arg.name in all_var_names
+        ]
+
+    forward_derivatives: list[ForwardDerivative] = []
+    fw_derivative: ForwardDerivative
+    for fw_derivative in ref_diff_info.forward_derivatives:
+        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
+        var_types: list[Type] = list(fw_derivative.var_types)
+        required_inputs_fw_grad: list[str] = []
+        required_inputs_primal: list[str] = []
+        if fw_derivative.required_inputs_fw_grad is not None:
+            required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
+        if fw_derivative.required_inputs_primal:
+            required_inputs_primal = list(fw_derivative.required_inputs_primal)
+        modified_formula = fw_derivative.formula
+
+        # Foreach's result is TensorList
+        if "result" in modified_formula:
+            modified_formula = fw_derivative.formula.replace("result", "result[i]")
+
+        for foreach_arg, ref_arg in zip(
+            foreach_function.func.arguments.flat_non_out,
+            ref_diff_info.func.func.arguments.flat_non_out,
+        ):
+            # Modify reference forward formula
+            if (
+                isinstance(foreach_arg.type, ListType)
+                and not foreach_arg.type.is_tensor_like()
+            ):
+                # Assuming ScalarList
+                modified_formula = modified_formula.replace(
+                    ref_arg.name, foreach_arg.name + "[i]"
+                )
+            elif foreach_arg.type.is_tensor_like():
+                # Assuming TensorList / Tensor
+                # assert isinstance(foreach_arg.type, ListType), f"{foreach_function.func.name}, {foreach_arg.type}"
+                assert isinstance(foreach_arg.type, ListType) or (
+                    foreach_arg.type == BaseType(BaseTy.Tensor)
+                    and str(foreach_function.func.name) in _foreach_with_tensor_overload
+                ), f"{foreach_function.func.name}, {foreach_arg.type}"
+                for suffix in ("_p", "_t"):
+                    curr_expr = ref_arg.name + suffix
+                    if curr_expr in modified_formula:
+                        new_expr = foreach_arg.name + suffix
+                        modified_formula = modified_formula.replace(curr_expr, new_expr)
+            else:
+                # Assuming Scalar
+                if foreach_arg.name != ref_arg.name:
+                    modified_formula = modified_formula.replace(
+                        ref_arg.name, foreach_arg.name
+                    )
+
+            # note(crcrpar): there should exist a cooler way...
+            for i, name in enumerate(var_names):
+                if name == ref_arg.name:
+                    var_names[i] = foreach_arg.name
+                    var_types[i] = foreach_arg.type
+            for i, name in enumerate(required_inputs_fw_grad):
+                if name == ref_arg.name:
+                    required_inputs_fw_grad[i] = foreach_arg.name
+            for i, name in enumerate(required_inputs_primal):
+                if name == ref_arg.name:
+                    required_inputs_primal[i] = foreach_arg.name
+        forward_derivatives.append(
+            ForwardDerivative(
+                formula=modified_formula,
+                var_names=tuple(var_names),
+                var_types=tuple(var_types),
+                required_inputs_fw_grad=tuple(required_inputs_fw_grad),
+                required_inputs_primal=tuple(required_inputs_primal),
+                required_original_self_value=fw_derivative.required_original_self_value,
+                is_reusing_outplace_formula=fw_derivative.is_reusing_outplace_formula,
+            )
+        )
+
+    return (
+        DifferentiabilityInfo(
+            name=foreach_function.func.name.name.base,
+            func=foreach_function,
+            op=f"Foreach{ref_diff_info.op}{foreach_function.func.name.overload_name}",
+            derivatives=modified_derivative_formulas,
+            forward_derivatives=forward_derivatives,
+            all_saved_inputs=tuple(set(all_saved_inputs)),
+            all_saved_outputs=tuple(set(all_saved_outputs)),
+            available_named_gradients=(),
+            used_named_gradients=set(),
+            args_with_derivatives=args_with_derivatives,
+            non_differentiable_arg_names=[],
+            output_differentiability=None,
+            output_differentiability_conditions=None,
+        ),
+        True,
+    )
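+
+
+# Illustrative outcome: when `_foreach_mul.List` reuses the derivative of
+# `mul.Tensor` (e.g. "grad * other"), the generated backward formula becomes
+# "grads[i] * other[i]", with the saved tensor list indexed per element.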
+
+
+def match_differentiability_info(
+    native_functions: list[NativeFunction],
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+) -> list[NativeFunctionWithDifferentiabilityInfo]:
+    """Sets the "derivative" key on declarations to matching autograd function
+    In-place functions will use the out-of-place derivative definition if there
+    is no in-place specific derivative.
+    """
+
+    functional_info_by_signature = {
+        schema.signature(strip_default=True): info_dict
+        for schema, info_dict in differentiability_infos.items()
+        if schema.kind() == SchemaKind.functional
+    }
+    non_functional_info_by_signature = {
+        schema.signature(strip_default=True): info_dict
+        for schema, info_dict in differentiability_infos.items()
+        if schema.kind() != SchemaKind.functional
+    }
+
+    def find_info(
+        f: NativeFunction,
+    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
+        # Don't bother matching info to generated out= variants
+        if "generated" in f.tags and f.func.kind() == SchemaKind.out:
+            return None, False
+
+        # (1) Check for an exact match
+        if f.func in differentiability_infos:
+            return differentiability_infos[f.func], True
+
+        # (2) If no exact match, check if the out-of-place variant
+        # of this operator has a match.
+        # e.g. mul() for mul_() or mul_out()
+        # note(crcrpar): Check whether the function is a foreach one because in-place foreach functions
+        # use the backward defined for the existing native functions instead of their out-of-place counterparts.
+        f_sig = f.func.signature(strip_default=True)
+        if f_sig in functional_info_by_signature and not is_foreach_func(f):
+            return functional_info_by_signature[f_sig], False
+
+        # (3) Some operators have a derivative explicitly defined for the mutable
+        # variant, but get a code-generated out-of-place variant which does *not*
+        # come with a derivative formula.
+        # For the generated out-of-place variant, use the mutable variant's formula
+        # if it exists.
+        if "generated" in f.tags and f_sig in non_functional_info_by_signature:
+            info_dict = non_functional_info_by_signature[f_sig]
+            # See https://github.com/pytorch/pytorch/pull/76320/files#r874816389
+            assert not any(
+                any("self" in str(input.nctype.name) for input in info.all_saved_inputs)
+                for info in info_dict.values()
+            ), f"""\
+Attempted to convert a derivative formula for a mutable operator
+ to be used automatically by its functional variant ("{str(f.func)}").
+ This is not currently supported (we'd need to fix up the formula in the codegen)."""
+            return info_dict, False
+
+        # (4) Generate derivative information of foreach functions if none is defined in `derivatives.yaml`
+        if is_foreach_func(f):
+            assert f.func not in differentiability_infos
+            diff_info, is_generated = gen_foreach_derivativeinfo(
+                f,
+                functional_info_by_signature,
+                non_functional_info_by_signature,
+            )
+            if diff_info is None:
+                return None, False
+            # TODO(crcrpar): Avoid hard coding "Default" ideally.
+            diff_info_dict = {"Default": diff_info}
+            if is_generated:
+                differentiability_infos[f.func] = diff_info_dict
+                functional_info_by_signature[f.func] = diff_info_dict
+            return diff_info_dict, is_generated
+
+        return None, False
+
+    result: list[NativeFunctionWithDifferentiabilityInfo] = []
+    for f in native_functions:
+        info_dict, is_exact_match = find_info(f)
+
+        # Currently, the '.strides()' to 'strides_or_error' replacement does not support
+        # 'self' derivatives of an inplace function, so we must check for this case.
+        if f.func.kind() == SchemaKind.inplace and (info_dict is not None):
+            for info in info_dict.values():
+                for derivative in info.derivatives:
+                    if "self" in derivative.var_names:
+                        for saved_input in derivative.saved_inputs:
+                            assert "strides_or_error" not in saved_input.expr, (
+                                "Calling '.strides()' in the 'self' derivative formula of an "
+                                f"in-place function is not supported: {f.func}"
+                            )
+
+        if not info_dict:
+            result.append(
+                NativeFunctionWithDifferentiabilityInfo(
+                    func=f, info=None, fw_derivatives=None
+                )
+            )
+            continue
+
+        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
+        for key, info in info_dict.items():
+            if not info.forward_derivatives:
+                fw_derivative_dict[key] = []
+                continue
+
+            forward_derivatives = info.forward_derivatives
+
+            # For functions that have a single def for out-of-place and inplace (like abs())
+            if f.func.kind() == SchemaKind.inplace:
+                # For inplace functions there is a little bit of work to do:
+                #  1) Validate the formula and make sure the input that is modified is not used:
+                #    - If there is a formula for the inplace variant of the function (is_exact_match == True) then
+                #      we make sure that the original value of the input that is being modified inplace (self_p) is
+                #      not used in the formula. Note that the formula can use "original_self_p" here and that would
+                #      trigger a clone of the original input.
+                #    - If we are reusing the out of place formula (is_exact_match == False) then we replace every
+                #      occurrence of self_p and self_t by original_self_p and original_self_t. These will be
+                #      populated by cloned version of the original input (either the clone done by the backward AD
+                #      logic if self is also used in a backward formula or a special clone that we add).
+                #  2) At this point, there cannot be a self_p in the formula.
+                #  3) Change "result" into "self_p" as by design, in the inplace function codegen, the result is
+                #     simply called self (as it is modified inplace).
+                #  4) Update the required primals data in case it used to contain "result" but should now contain
+                #     "self"
+                #  5) If it is not an exact match, the user formula is not modifying the existing forward grad
+                #     inplace as it should. So add some code that makes sure that we do so if the forward grad
+                #     already exists.
+
+                assert (
+                    len(info.forward_derivatives) == 1
+                )  # Only single output inplace should exist
+                fw_info = info.forward_derivatives[0]
+                formula = fw_info.formula
+
+                def replace_self_with_original_self(formula: str, postfix: str) -> str:
+                    def repl(m: re.Match[str]) -> str:
+                        return f"{m.group(1)}original_self{postfix}{m.group(2)}"
+
+                    return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
+
+                if re.search(IDENT_REGEX.format("self_p"), formula):
+                    if is_exact_match:
+                        # For manually defined formulas, don't allow the original value to be used
+                        raise RuntimeError(
+                            f'The formula for "{f.func.name}" is using the original value of self '
+                            "that is being modified inplace. This would lead to wrong forward gradients. "
+                            'Please use "result" in the formula only.'
+                        )
+                    else:
+                        # When the original formula is out of place, we save a clone of the primal
+                        # value to be able to access this value if needed
+                        # replace "self_p"/"self_t" from the formula by "original_self_p"/"original_self_t"
+                        formula = replace_self_with_original_self(formula, "_p")
+                        formula = replace_self_with_original_self(formula, "_t")
+
+                # replace "result" from the formula by "self_p"
+                def repl(m: re.Match[str]) -> str:
+                    return f"{m.group(1)}self_p{m.group(2)}"
+
+                formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
+
+                required_primals = fw_info.required_inputs_primal
+                if re.search(IDENT_REGEX.format("self_p"), formula):
+                    required_primals = (
+                        required_primals + ("self",) if required_primals else ("self",)
+                    )
+
+                if not is_exact_match:
+                    # NOTE [In-place forward AD formula Optimization]
+                    #
+                    # This optimization transforms the formula to directly do inplace, i.e.
+                    # instead of self_t.copy_(self_t.op()) we do self_t.op_() when the following are met:
+                    #
+                    # 1) the formula satisfies the pattern: "self_t.op(*args)"
+                    # 2) "op" in (1) needs to be the same as the op the derivative is for
+                    #
+                    # (2) may seem too strict, but currently the only ops that satisfy (1) also satisfy (2)
+                    # If there is a need, we can relax (2) to allow any op that has an in-place variant
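+                    #
+                    # Illustratively, a reused out-of-place formula of the form
+                    # "self_t.mul(other_p)" for an op named "mul" would become
+                    #   "self_t_raw.defined() ? self_t_raw.mul_(other_p) : self_t.mul(other_p)"
+                    # while a formula that does not fit the pattern is wrapped as
+                    #   "self_t_raw.defined() ? self_t_raw.copy_(<formula>) : <formula>"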
+                    is_single_method_on_self_t = False
+                    directly_do_inplace = False
+                    op_name: str | None = None
+                    between_parens: str | None = None
+                    match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
+                    if match:
+                        op_name, between_parens = match.group(1), match.group(2)
+
+                        # We want to...
+                        #   Match: self_t.op1(other_p.op2(arg))
+                        #   Avoid: self_t.op1(args) + self_t.op2(args)
+                        #   Avoid: self_t.op1(other_p.op2(arg)) + self_t.op2(args)
+                        def check_parens_nest_level_gt_zero(s: str) -> bool:
+                            level = 1
+                            for ch in s:
+                                if ch == ")":
+                                    level -= 1
+                                    if level == 0:
+                                        return False
+                                if ch == "(":
+                                    level += 1
+                            return True
+
+                        is_single_method_on_self_t = check_parens_nest_level_gt_zero(
+                            between_parens
+                        )
+                        directly_do_inplace = (
+                            is_single_method_on_self_t and op_name == info.name
+                        )
+
+                    if directly_do_inplace:
+                        assert op_name is not None
+                        assert between_parens is not None
+                        formula = f"self_t_raw.defined() ? self_t_raw.{op_name}_({between_parens}) : {formula}"
+                    else:
+                        # Make sure that the forward grad is modified inplace when the original formula
+                        # is out of place
+                        formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
+
+                required_original_self_value = bool(
+                    re.search(IDENT_REGEX.format("original_self_p"), formula)
+                ) or bool(re.search(IDENT_REGEX.format("original_self_t"), formula))
+
+                forward_derivatives = [
+                    ForwardDerivative(
+                        formula=formula,
+                        var_names=("self",),
+                        var_types=fw_info.var_types,
+                        required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
+                        required_inputs_primal=required_primals,
+                        required_original_self_value=required_original_self_value,
+                        is_reusing_outplace_formula=not is_exact_match,
+                    ),
+                ]
+
+            fw_derivative_dict[key] = forward_derivatives
+
+        result.append(
+            NativeFunctionWithDifferentiabilityInfo(
+                func=f, info=info_dict, fw_derivatives=fw_derivative_dict
+            )
+        )
+
+    return result
+
+
+def is_differentiable(
+    name: str, type: Type, info: DifferentiabilityInfo | None
+) -> bool:
+    return type.is_tensor_like() and (
+        info is None or name not in info.non_differentiable_arg_names
+    )
+
+
+def gen_differentiable_outputs(
+    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
+) -> list[DifferentiableOutput]:
+    f = fn.func
+    info = fn.info[key] if fn.info else None
+    outputs: list[DifferentiableOutput] = [
+        DifferentiableOutput(
+            name=name,
+            type=ret.type,
+            cpp_type=cpp.return_type(ret, symint=True).cpp_type(),
+        )
+        for name, ret in zip(cpp.return_names(f), f.func.returns)
+    ]
+    output_differentiability = info.output_differentiability if info else None
+    if output_differentiability is not None:
+        if len(output_differentiability) != len(outputs):
+            raise RuntimeError(
+                f"The length of output_differentiability ({len(output_differentiability)}), "
+                f"does not match the number of outputs ({len(outputs)})."
+            )
+        differentiable_outputs: list[DifferentiableOutput] = []
+        if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
+            raise RuntimeError(
+                "output_differentiability=False for inplace operation (version_counter won't get updated)"
+            )
+        for differentiable, output in zip(output_differentiability, outputs):
+            if differentiable:
+                differentiable_outputs.append(output)
+        return differentiable_outputs
+    candidate_differentiable_outputs = list(
+        filter(lambda r: is_differentiable(r.name, r.type, info), outputs)
+    )
+    if uses_single_grad(info):
+        return candidate_differentiable_outputs[:1]
+    else:
+        return candidate_differentiable_outputs
diff --git a/phivenv/Lib/site-packages/torchgen/api/cpp.py b/phivenv/Lib/site-packages/torchgen/api/cpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..638abc2a54362e4c14091fc1c8af31b8f0ccc1b5
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/cpp.py
@@ -0,0 +1,469 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing_extensions import assert_never
+
+from torchgen import local
+from torchgen.api.types import (
+    ArgName,
+    ArrayCType,
+    ArrayRefCType,
+    BaseCType,
+    BaseTypeToCppMapping,
+    Binding,
+    boolT,
+    ConstRefCType,
+    CType,
+    dimnameListT,
+    intArrayRefT,
+    iTensorListRefT,
+    ListCType,
+    longT,
+    MutRefCType,
+    NamedCType,
+    OptionalCType,
+    optionalIntArrayRefT,
+    optionalSymIntArrayRefT,
+    scalarT,
+    SpecialArgName,
+    symIntArrayRefT,
+    SymIntT,
+    tensorListT,
+    tensorOptionsT,
+    tensorT,
+    TupleCType,
+    VectorCType,
+    voidT,
+)
+from torchgen.model import (
+    Argument,
+    Arguments,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Return,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# This file describes the translation of JIT schema to the public C++
+# API, which is what people use when they call functions like at::add.
+#
+# Prominent characteristics of the C++ API:
+#
+#   - dtype, layout, device and pin_memory are collected into
+#     a single C++ type TensorOptions  (the native functions API
+#     also has this, but tensor options is really most relevant
+#     for the C++ API; it makes calling kwarg factory functions
+#     pleasant)
+#
+#   - defaulting lives here (in fact, the dispatcher is completely
+#     oblivious of defaults!)
+#
+# BTW: policy on name collisions: we try not to have types with
+# collisions, but functions are fair game to collide
+
+
+def name(
+    func: FunctionSchema,
+    *,
+    faithful_name_for_out_overloads: bool = False,
+    symint_overload: bool = False,
+) -> str:
+    name = str(func.name.name)
+    if symint_overload:
+        name += "_symint"
+    if func.is_out_fn():
+        if faithful_name_for_out_overloads:
+            name += "_outf"
+        else:
+            name += "_out"
+
+    return name
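+
+# For example, "add.out" maps to "add_out" (or "add_outf" when
+# faithful_name_for_out_overloads=True); symint_overload=True inserts a
+# "_symint" suffix on the base name before any out suffix.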
+
+
+# Translation of "value types" in JIT schema to C++ API type.  Value
+# types look the same no matter if they are argument types or return
+# types.  Returns None if the type in question is not a value type.
+def valuetype_type(
+    t: Type,
+    *,
+    binds: ArgName,
+    mutable: bool = True,
+    symint: bool = False,
+) -> NamedCType | None:
+    if isinstance(t, BaseType):
+        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
+            return None
+        elif str(t) == "SymInt":
+            if symint:
+                return NamedCType(binds, BaseCType(SymIntT))
+            else:
+                return NamedCType(binds, BaseCType(longT))
+        # All other BaseType currently map directly to BaseCppTypes.
+        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
+    elif isinstance(t, OptionalType):
+        elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
+        if elem is None:
+            return None
+        return NamedCType(binds, OptionalCType(elem.type))
+    elif isinstance(t, ListType):
+        if str(t.elem) == "bool":
+            assert t.size is not None
+            return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
+        else:
+            return None
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
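+
+# Roughly: "int" maps to int64_t, "bool" to bool, "SymInt" to c10::SymInt when
+# symint=True (int64_t otherwise), and a fixed-size "bool[2]" to
+# ::std::array<bool, 2>; Tensor and Scalar are not value types here and are
+# handled by argumenttype_type below.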
+
+
+# Translation of types occurring in JIT arguments to a C++ argument type.
+# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
+# For example, we'll return std::vector instead of IntArrayRef.
+# See Note [translation from C++ reference to value types]
+def argumenttype_type(
+    t: Type,
+    *,
+    mutable: bool,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = False,
+) -> NamedCType:
+    # If it's a value type, do the value type translation
+    r = valuetype_type(
+        t,
+        binds=binds,
+        mutable=mutable,
+        symint=symint,
+    )
+    if r is not None:
+        return r
+
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.Tensor:
+            if mutable and not local.use_const_ref_for_mutable_tensors():
+                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
+            else:
+                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
+        elif t.name == BaseTy.Scalar:
+            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+        else:
+            raise AssertionError(f"base type should have been value type {t}")
+    elif isinstance(t, OptionalType):
+        if str(t.elem) == "Tensor":
+            if mutable and not local.use_const_ref_for_mutable_tensors():
+                return NamedCType(
+                    binds, MutRefCType(BaseCType(tensorT))
+                )  # TODO: fix this discrepancy
+            else:
+                return NamedCType(
+                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
+                )
+        elif str(t.elem) == "Scalar":
+            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
+        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
+            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
+            if symint:
+                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
+            else:
+                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
+        return NamedCType(binds, OptionalCType(elem.type))
+    elif isinstance(t, ListType):
+        # TODO: remove these special cases, ArrayRef fallthrough works fine
+        if str(t.elem) == "int":
+            if remove_non_owning_ref_types:
+                return NamedCType(binds, VectorCType(BaseCType(longT)))
+            else:
+                return NamedCType(binds, BaseCType(intArrayRefT))
+        if str(t.elem) == "SymInt":
+            if remove_non_owning_ref_types:
+                if symint:
+                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
+                else:
+                    return NamedCType(binds, VectorCType(BaseCType(longT)))
+            else:
+                if symint:
+                    return NamedCType(binds, BaseCType(symIntArrayRefT))
+                else:
+                    return NamedCType(binds, BaseCType(intArrayRefT))
+        if str(t.elem) == "Tensor":
+            if local.use_ilistref_for_tensor_lists():
+                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
+            else:
+                return NamedCType(binds, BaseCType(tensorListT))
+        elif str(t.elem) == "Scalar":
+            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
+        elif str(t.elem) == "Dimname":
+            return NamedCType(binds, BaseCType(dimnameListT))
+        elif str(t.elem) == "Tensor?":
+            return NamedCType(
+                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
+            )
+        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
+        return NamedCType(binds, ArrayRefCType(elem.type))
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# Translate a JIT argument into its C++ type
+def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
+    return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
+
+
+# Translation of a (non-multi) return type from JIT to C++
+# N.B: returntype_type returns a CType, not a NamedCType.
+# This is mostly because of the mismatch between return types and return names.
+# e.g. a function with a return type of 'void' has 0 return names,
+# and a function with a return type of 'std::tuple' has >1 return name.
+def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
+    # placeholder is ignored
+    # NB: symint is ALWAYS respected for return types.  So symint argument
+    # here is IGNORED
+    r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
+    if r is not None:
+        return r.type
+
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.Tensor:
+            if mutable:
+                if local.use_const_ref_for_mutable_tensors():
+                    return ConstRefCType(BaseCType(tensorT))
+                else:
+                    return MutRefCType(BaseCType(tensorT))
+            else:
+                # Note [Tensor Copy Returns]
+                # Currently, we use "Argument.is_write" to determine
+                # whether or not Tensor return types should be copies or references.
+                # If that ever changes, take a look at other locations of this note!
+                return BaseCType(tensorT)
+        elif t.name == BaseTy.Scalar:
+            return BaseCType(scalarT)
+    elif isinstance(t, ListType):
+        assert not mutable, (
+            "Native functions should never return a mutable tensor list. They should return void."
+        )
+        elem = returntype_type(t.elem, mutable=False)
+        assert t.size is None, f"fixed size list returns not supported: {t}"
+        return VectorCType(elem)
+    elif isinstance(t, OptionalType):
+        elem = returntype_type(t.elem, mutable=mutable)
+        if str(t.elem) == "Tensor":
+            return OptionalCType(elem)
+
+    raise AssertionError(f"unrecognized return type {t}")
+
+
+# Translation of a single return to its C++ type
+def return_type(r: Return, *, symint: bool = False) -> CType:
+    return returntype_type(r.type, mutable=r.is_write, symint=symint)
+
+
+# Translation of a full (possibly multi) return from JIT to its C++ type
+def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
+    if len(rs) == 0:
+        return BaseCType(voidT)
+    elif len(rs) == 1:
+        return return_type(rs[0], symint=symint)
+    else:
+        return TupleCType([return_type(r, symint=symint) for r in rs])
+
+
+def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
+    returns: list[str] = []
+    for i, r in enumerate(f.func.returns):
+        # If we have an inplace function, the return argument is
+        # implicitly named self.
+        # TODO: Consider incorporating this into the data model
+        if f.func.name.name.inplace:
+            assert i == 0, "illegal inplace function with multiple returns"
+            name = "self"
+        # If this is an out function, the name is the name of the
+        # corresponding out argument (r.name will get recorded
+        # in field_name later.)
+        elif f.func.is_out_fn():
+            name = f.func.arguments.out[i].name
+        # If the return argument is explicitly named...
+        elif r.name:
+            name_conflict = any(
+                r.name == a.name for a in f.func.schema_order_arguments()
+            )
+            if name_conflict and not f.func.is_out_fn():
+                name = f"{r.name}_return"
+            else:
+                name = r.name
+        # If there is no explicit name, we just use the fallback name ("result" by default),
+        # unless it's a multi-return, in which case it's result0,
+        # result1, etc (zero-indexed)
+        else:
+            name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
+        returns.append(name)
+    return returns
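+
+# For example: an inplace op gets the single return name "self"; an out=
+# overload reuses the names of its out arguments; two unnamed returns become
+# "result0" and "result1".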
+
+
+JIT_TO_CPP_DEFAULT = {
+    "False": "false",
+    "True": "true",
+    "None": "::std::nullopt",  # UGH this one is type directed
+    "Mean": "at::Reduction::Mean",
+    "[]": "{}",
+    "contiguous_format": "c10::MemoryFormat::Contiguous",
+    "long": "at::kLong",
+}
+
+
+# Convert a JIT default into C++ expression representing the default
+def default_expr(d: str, t: Type, *, symint: bool) -> str:
+    if d == "None" and str(t) == "Tensor?":
+        return "{}"
+    if isinstance(t, BaseType) and t.name is BaseTy.str:
+        # Schema allows single quotes but C++ needs double
+        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
+            s = ""
+            i = 1
+            while i + 1 < len(d):
+                if d[i] != "\\":
+                    if d[i] == '"':
+                        s += '\\"'
+                    else:
+                        s += d[i]
+                    i += 1
+                else:
+                    if d[i + 1] == "'":
+                        s += "'"
+                    else:
+                        s += d[i : i + 2]
+                    i += 2
+
+            return f'"{s}"'
+
+    if isinstance(t, OptionalType):
+        if d == "None":
+            return "::std::nullopt"
+
+        return default_expr(d, t.elem, symint=symint)
+
+    if isinstance(t, ListType):
+        if d.startswith("[") and d.endswith("]"):
+            return "{" + d[1:-1] + "}"
+        elif symint and d.isdigit() and str(t.elem) == "SymInt":
+            return f"c10::SymInt({d})"
+        elif t.size is None:
+            # NOTE: Sized lists can have scalar defaults
+            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
+
+    return JIT_TO_CPP_DEFAULT.get(d, d)
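+
+# Schematic examples (types written in their schema spelling):
+#   default_expr("None", "Tensor?")  -> "{}"
+#   default_expr("True", "bool")     -> "true"
+#   default_expr("[0, 1]", "int[]")  -> "{0, 1}"
+#   default_expr("'none'", "str")    -> '"none"'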
+
+
+# Convert an argument into its C++ API form
+
+
+def argument(
+    a: Argument | TensorOptionsArguments | SelfArgument,
+    *,
+    cpp_no_default_args: set[str],
+    method: bool,
+    faithful: bool,
+    symint: bool = False,
+    has_tensor_options: bool,
+) -> list[Binding]:
+    def sub_argument(
+        a: Argument | TensorOptionsArguments | SelfArgument,
+    ) -> list[Binding]:
+        return argument(
+            a,
+            cpp_no_default_args=cpp_no_default_args,
+            method=method,
+            faithful=faithful,
+            symint=symint,
+            has_tensor_options=has_tensor_options,
+        )
+
+    if isinstance(a, Argument):
+        binds: ArgName
+        if a.name == "memory_format" and has_tensor_options:
+            binds = SpecialArgName.possibly_redundant_memory_format
+        else:
+            binds = a.name
+        default: str | None = None
+        if a.name not in cpp_no_default_args and a.default is not None:
+            default = default_expr(a.default, a.type, symint=symint)
+        return [
+            Binding(
+                nctype=argument_type(a, binds=binds, symint=symint),
+                name=a.name,
+                default=default,
+                argument=a,
+            )
+        ]
+    elif isinstance(a, TensorOptionsArguments):
+        if faithful:
+            return (
+                sub_argument(a.dtype)
+                + sub_argument(a.layout)
+                + sub_argument(a.device)
+                + sub_argument(a.pin_memory)
+            )
+        else:
+            default = None
+            # Enforced by NativeFunction.__post_init__
+            assert "options" not in cpp_no_default_args
+            if all(x.default == "None" for x in a.all()):
+                default = "{}"
+            elif a.dtype.default == "long":
+                default = "at::kLong"  # TODO: this is wrong
+            return [
+                Binding(
+                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
+                    name="options",
+                    default=default,
+                    argument=a,
+                )
+            ]
+    elif isinstance(a, SelfArgument):
+        if method:
+            # Caller is responsible for installing implicit this in context!
+            return []
+        else:
+            return sub_argument(a.argument)
+    else:
+        assert_never(a)
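+
+
+# For instance, the gathered (non-faithful) signature exposes TensorOptions as a
+# single `options` binding (defaulted to `{}` when every constituent default is
+# None), while the faithful signature expands it into separate
+# dtype / layout / device / pin_memory bindings.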
+
+
+def arguments(
+    arguments: Arguments,
+    *,
+    faithful: bool,
+    symint: bool = False,
+    method: bool,
+    cpp_no_default_args: set[str],
+) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    if faithful:
+        args.extend(arguments.non_out)
+        args.extend(arguments.out)
+    else:
+        args.extend(arguments.out)
+        args.extend(arguments.non_out)
+    return [
+        r.no_default() if faithful else r
+        for a in args
+        for r in argument(
+            a,
+            faithful=faithful,
+            symint=symint,
+            method=method,
+            has_tensor_options=arguments.tensor_options is not None,
+            cpp_no_default_args=cpp_no_default_args,
+        )
+    ]
diff --git a/phivenv/Lib/site-packages/torchgen/api/dispatcher.py b/phivenv/Lib/site-packages/torchgen/api/dispatcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..02e11146f1fd05982f38fcad69594a06fe366773
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/dispatcher.py
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+import itertools
+from typing import TYPE_CHECKING
+from typing_extensions import assert_never
+
+from torchgen.api import cpp
+from torchgen.api.types import ArgName, Binding, CType, NamedCType
+from torchgen.model import (
+    Argument,
+    FunctionSchema,
+    Return,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+from torchgen.utils import concatMap
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# This file describes the translation of JIT schema to the dispatcher
+# API, the *unboxed* calling convention by which invocations through
+# the dispatcher are made.  Historically, the dispatcher API matched
+# the C++ API, but with the establishment of the boxed API, we've
+# made changes to the dispatcher API so that the unboxed API
+# better aligns with the boxed API.  The dispatcher API hooks heavily
+# into our template based boxing/unboxing machinery, so changes
+# to this convention will usually need template updates too.
+#
+# Prominent characteristics of the dispatcher API:
+#
+#   - dtype, layout, device and pin_memory are represented as separate
+#     arguments.
+#
+
+
+def name(func: FunctionSchema) -> str:
+    return cpp.name(func)
+
+
+def argumenttype_type(
+    t: Type,
+    *,
+    mutable: bool,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
+) -> NamedCType:
+    # This is a faux ami.  If it makes sense in the future to add
+    # more special cases here, or invert things so cpp.argument_type
+    # calls this, or just completely inline the function, please do
+    # it.
+    return cpp.argumenttype_type(
+        t,
+        mutable=mutable,
+        binds=binds,
+        symint=symint,
+        remove_non_owning_ref_types=remove_non_owning_ref_types,
+    )
+
+
+def argument_type(
+    a: Argument,
+    *,
+    binds: ArgName,
+    remove_non_owning_ref_types: bool = False,
+    symint: bool = True,
+) -> NamedCType:
+    return argumenttype_type(
+        a.type,
+        mutable=a.is_write,
+        binds=binds,
+        remove_non_owning_ref_types=remove_non_owning_ref_types,
+        symint=symint,
+    )
+
+
+def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
+    # At present, there is no difference. But there could be!
+    return cpp.returns_type(rs, symint=symint)
+
+
+def jit_arguments(func: FunctionSchema) -> list[Argument]:
+    def to_argument(
+        a: Argument | TensorOptionsArguments | SelfArgument,
+    ) -> list[Argument]:
+        if isinstance(a, Argument):
+            return [a]
+        elif isinstance(a, SelfArgument):
+            return [a.argument]
+        elif isinstance(a, TensorOptionsArguments):
+            return [a.dtype, a.layout, a.device, a.pin_memory]
+        else:
+            assert_never(a)
+
+    return list(
+        concatMap(
+            to_argument,
+            itertools.chain(
+                func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
+            ),
+        )
+    )
+
+
+def argument(
+    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
+) -> Binding:
+    return Binding(
+        nctype=argument_type(
+            a,
+            binds=a.name,
+            remove_non_owning_ref_types=remove_non_owning_ref_types,
+            symint=symint,
+        ),
+        name=a.name,
+        argument=a,
+    )
+
+
+def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
+    return [argument(a, symint=symint) for a in jit_arguments(func)]
diff --git a/phivenv/Lib/site-packages/torchgen/api/functionalization.py b/phivenv/Lib/site-packages/torchgen/api/functionalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2adb7555ead8a7afabacff20e5616ece3c688f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/functionalization.py
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+from torchgen.api import dispatcher
+from torchgen.api.types import (
+    BaseCppType,
+    BaseCType,
+    Binding,
+    boolT,
+    ConstRefCType,
+    CType,
+    longT,
+    NamedCType,
+    tensorT,
+)
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    NativeFunction,
+    NativeFunctionsViewGroup,
+)
+
+
+# This file describes the translation of JIT schema to API's used
+# when creating view lambdas that are used by the functionalization pass.
+# There are two types of lambdas: forward lambdas and reverse lambdas.
+# These API's mostly follow the dispatcher API, with a few quirks:
+# - The lambda capture has to convert reference types to value types
+# - While the forward lambda just directly calls into the at::_ops API
+#   (following the dispatcher convention), the logic here for the reverse lambda
+#   is responsible for generating both the call-site, and the declarations
+#   (which are implemented manually in the at::functionalization::impl namespace).
+
+# The lambdas generated for each view op in the functionalization pass are of the form
+# [capture_arguments](outer_arguments) -> returns_type {
+#     return name(inner_arguments);
+# }
+
+# Define some specific lambda input arguments.
+base_binding = Binding(
+    name="base",
+    nctype=NamedCType(name="base", type=ConstRefCType(BaseCType(tensorT))),
+    argument=Argument(
+        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+    ),
+    default=None,
+)
+mutated_view_binding = Binding(
+    name="mutated_view",
+    nctype=NamedCType(name="mutated_view", type=ConstRefCType(BaseCType(tensorT))),
+    argument=Argument(
+        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+    ),
+    default=None,
+)
+mutated_view_idx_binding = Binding(
+    name="mutated_view_idx",
+    nctype=NamedCType(name="mutated_view_idx", type=BaseCType(longT)),
+    argument=Argument(
+        name="base", type=BaseType(BaseTy.Tensor), default=None, annotation=None
+    ),
+    default=None,
+)
+reapply_views_binding = Binding(
+    name="reapply_views",
+    nctype=NamedCType(name="reapply_views", type=BaseCType(boolT)),
+    argument=Argument(
+        name="reapply_views", type=BaseType(BaseTy.bool), default=None, annotation=None
+    ),
+    default=None,
+)
+
+InverseReturnModeT = BaseCppType("at::functionalization", "InverseReturnMode")
+inverse_return_mode_binding = Binding(
+    name="inverse_return_mode",
+    nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)),
+    argument=Argument(
+        name="inverse_return_mode",
+        # NB: not actually a bool but it doesn't matter because this isn't used
+        type=BaseType(BaseTy.bool),
+        default=None,
+        annotation=None,
+    ),
+    default=None,
+)
+
+
+# The lambda capture itself doesn't have a name.
+# The name returned here corresponds to the name of the inner function called by the lambda.
+def name(
+    g: NativeFunctionsViewGroup,
+    *,
+    is_reverse: bool,
+    include_namespace: bool,
+    reapply_views: bool | None = None,
+) -> str:
+    if reapply_views is None:
+        # reapply_views is only important for the fwd lambda,
+        # since we always plumb the runtime "reapply_views" argument into the reverse function.
+        assert is_reverse
+    if is_reverse:
+        return reverse_name(g.view, include_namespace)
+    # in the forward case, we just directly call into the at::_ops API (so we always need the namespace)
+    assert include_namespace
+    assert g.view_copy is not None
+    api_name = (
+        g.view.func.name.unambiguous_name()
+        if reapply_views
+        else g.view_copy.func.name.unambiguous_name()
+    )
+    return f"at::_ops::{api_name}::call"
+
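+# A minimal illustration of the naming rule above (a sketch, not generated output):
+# for a view group such as `alias` / `alias_copy` with no overload name,
+# name(g, is_reverse=False, include_namespace=True, reapply_views=True) would yield
+# "at::_ops::alias::call", while passing reapply_views=False selects the copy variant
+# and yields "at::_ops::alias_copy::call".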
+
+def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
+    # for the reverse: we plumb the "reapply_views" flag into that function and support
+    # both copy and non-copy variants. (We could avoid doing that, but that would require
+    # writing out twice as many view inverse functions).
+    api_name = f.func.name.unambiguous_name()
+    # in the reverse case, we codegen both the call-sites (which need the full namespace) and the declarations (which don't)
+    if include_namespace:
+        return f"at::functionalization::FunctionalInverses::{api_name}_inverse"
+    else:
+        return f"{api_name}_inverse"
+
+
+def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
+    # Capture arguments include all arguments except `self`.
+    # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
+    # so any reference types (IntArrayRef) need to be converted to value types (vector).
+    args = func.arguments.flat_all
+    assert args[0].type == BaseType(BaseTy.Tensor)
+    non_self_args = args[1:]
+    non_self_value_bindings = [
+        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
+    ]
+
+    all_bindings = [
+        inverse_return_mode_binding if is_reverse else reapply_views_binding
+    ]
+    all_bindings.extend(non_self_value_bindings)
+    return all_bindings
+
+
+def returns_type(func: FunctionSchema) -> CType:
+    # Assertion: all view ops return tensor-like outputs
+    assert len(func.returns) >= 1
+    for ret in func.returns:
+        assert ret.type.is_tensor_like()
+    # However, the return type of the lambda is always an individual tensor.
+    # For multi-tensor outputs, each tensor needs to be tracked individually.
+    return BaseCType(tensorT)
+
+
+def outer_arguments(*, is_reverse: bool) -> list[Binding]:
+    if is_reverse:
+        return [base_binding, mutated_view_binding, mutated_view_idx_binding]
+    else:
+        return [base_binding, mutated_view_idx_binding]
+
+
+def inner_call_index(func: FunctionSchema) -> Binding | None:
+    # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
+    # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
+    if len(func.returns) > 1 or (
+        len(func.returns) == 1 and func.returns[0].type.is_list_like()
+    ):
+        return mutated_view_idx_binding
+    return None
+
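+# For example, a multi-output view op like `split`, whose schema returns `Tensor[]`,
+# hits the list-like branch above, so inner_call_index() returns mutated_view_idx_binding;
+# a single-Tensor view op returns None and no index argument is threaded through.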
+
+def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
+    args = func.arguments.flat_all
+    assert args[0].type == BaseType(BaseTy.Tensor)
+    non_self_args = args[1:]
+    # The forward lambda calls the at::_ops API, while the reverse lambda calls the view inverse API.
+    # Both of these follow the dispatcher API.
+    non_self_bindings = [dispatcher.argument(a) for a in non_self_args]
+    if not is_reverse:
+        # the forward lambda swaps out the original tensor argument with the lambda arg "base"
+        return [base_binding] + non_self_bindings
+    else:
+        # the reverse lambda does the same, but with an additional "mutated_view" arg.
+        # additionally, we have a calling convention: for view ops that return multiple tensor outputs,
+        # their corresponding view_inverse function takes in an additional index argument.
+        index_binding = inner_call_index(func)
+        if index_binding is not None:
+            return [
+                base_binding,
+                mutated_view_binding,
+                inverse_return_mode_binding,
+                index_binding,
+            ] + non_self_bindings
+        else:
+            return [
+                base_binding,
+                mutated_view_binding,
+                inverse_return_mode_binding,
+            ] + non_self_bindings
diff --git a/phivenv/Lib/site-packages/torchgen/api/lazy.py b/phivenv/Lib/site-packages/torchgen/api/lazy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b506225dc029f1c05add9df3837b3dc2d4cea51a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/lazy.py
@@ -0,0 +1,468 @@
+from __future__ import annotations
+
+from typing import Any
+
+from torchgen.api.types import (
+    BaseCppType,
+    BaseCType,
+    boolT,
+    CType,
+    deviceT,
+    doubleT,
+    generatorT,
+    layoutT,
+    ListCType,
+    longT,
+    memoryFormatT,
+    NamedCType,
+    OptionalCType,
+    scalarT,
+    scalarTypeT,
+    stringT,
+    SymIntT,
+    VectorCType,
+)
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    OperatorName,
+    OptionalType,
+    Return,
+    TensorOptionsArguments,
+    Type,
+)
+
+
+_valueT: BaseCppType | None = None
+
+
+# A ValueT is an IR type which represents the computation of a Tensor.  In other
+# words, a PyTorch user will do operations on lazy tensors, and each output lazy
+# tensor internally tracks a ValueT representing the IR node that would have
+# actually produced the value of this tensor for real.
+#
+# This is configurable because different lazy tensor backends (LTC vs XLA) will
+# have different IR representations.  (Though, arguably, after unification they
+# shouldn't!)
+def getValueT() -> BaseCppType:
+    global _valueT
+    if not _valueT:
+        raise NotImplementedError(
+            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
+        )
+
+    return _valueT
+
+
+def setValueT(val: BaseCppType) -> None:
+    global _valueT
+    _valueT = val
+
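+# Rough usage sketch (assumption: run_gen_lazy_tensor() performs the equivalent of
+# this before any codegen that calls getValueT(); the concrete type is backend-specific):
+#
+#   setValueT(BaseCppType("torch::lazy", "Value"))
+#   getValueT()   # now returns the configured IR value type instead of raising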
+
+# this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
+# making it easier to represent special properties of an arg.
+tensorListValueT = BaseCppType("torch::lazy", "Value")
+
+
+def process_ir_type(
+    typ: Type, properties: LazyIrProperties, *, symint: bool
+) -> BaseCType | VectorCType | OptionalCType | ListCType:
+    """
+    This function takes a type from NativeFunctions and converts it for use with
+    lazy tensor codegen.
+
+    Type conversion for lazy currently consists of
+     (1) changing at::Tensors into lazy::Values
+     (2) wrapping everything in a BaseCType
+     (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
+
+    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors).
+    There is special handling for Optional[Tensor], list[Tensor], etc. - hence 'tensor-like'.
+
+    This is incomplete: there are assertions in places where we expect to need to add
+    more types as the codegen is used with more operators.
+    """
+    if isinstance(typ, BaseType):
+        if typ.name == BaseTy.Tensor:
+            return BaseCType(getValueT())
+        elif typ.name == BaseTy.Scalar:
+            if properties.TreatScalarsAsConstants:
+                return BaseCType(scalarT)
+            # at::scalar has special handling,
+            # and is wrapped in a lazy::Value just like at::tensor
+            return BaseCType(getValueT())
+        elif typ.name == BaseTy.ScalarType:
+            return BaseCType(scalarTypeT)
+        elif typ.name == BaseTy.int:
+            return BaseCType(longT)
+        elif typ.name == BaseTy.SymInt:
+            if symint:
+                return BaseCType(getValueT())
+            else:
+                return BaseCType(longT)
+        elif typ.name == BaseTy.bool:
+            return BaseCType(boolT)
+        elif typ.name == BaseTy.float:
+            return BaseCType(doubleT)
+        elif typ.name == BaseTy.str:
+            return BaseCType(stringT)
+        elif typ.name == BaseTy.Device:
+            return BaseCType(deviceT)
+        elif typ.name == BaseTy.Generator:
+            return BaseCType(generatorT)
+        elif typ.name == BaseTy.Layout:
+            return BaseCType(layoutT)
+        elif typ.name == BaseTy.MemoryFormat:
+            return BaseCType(memoryFormatT)
+        else:
+            raise AssertionError(f"TODO add support for type {repr(typ)}")
+    elif isinstance(typ, OptionalType):
+        return OptionalCType(process_ir_type(typ.elem, properties, symint=symint))
+    elif isinstance(typ, ListType):
+        if str(typ.elem) == "Tensor?":
+            # TODO(whc) is this actually correct? or should it use a Vector like above
+            return ListCType(OptionalCType(BaseCType(getValueT())))
+        elif str(typ.elem) == "Tensor":
+            # this is a TensorList which comes in from GetTensorList as a Value
+            return BaseCType(tensorListValueT)
+        elif typ.elem == BaseType(BaseTy.SymInt):
+            # TODO: return a value type.  The problem here is analogous to
+            # the problem with tensorListValueT: if you have SymInt[] you
+            # cannot conveniently save the list of Value directly, as nodes
+            # expect to save values as a vector for ALL arguments.  So you
+            # need a separate IR node that represents all of the size nodes
+            # assembled into a list.  I'm not an LTC dev so I don't want to
+            # figure it out right now.  Y'all figure it out...
+            return VectorCType(BaseCType(longT))
+
+        else:
+            return VectorCType(process_ir_type(typ.elem, properties, symint=symint))
+    else:
+        raise AssertionError(f"unrecognized type {repr(typ)}")
+
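+# A few concrete mappings implied by the branches above (illustrative; assumes
+# TreatScalarsAsConstants is unset and symint=True):
+#
+#   Tensor    -> BaseCType(getValueT())
+#   Scalar    -> BaseCType(getValueT())
+#   int       -> BaseCType(longT)
+#   Tensor?   -> OptionalCType(BaseCType(getValueT()))
+#   Tensor[]  -> BaseCType(tensorListValueT)
+#   int[]     -> VectorCType(BaseCType(longT))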
+
+# TODO: Determining this based off of CType is bad; this should be computed
+# from Type directly; then the same logic as process_ir_type can be used
+#
+# Invariant: passed typ should be an *owning* CType (e.g., we will report
+# that ArrayRef is NOT a value type)
+def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
+    """
+    Given a type, determine if it is a Value-like type.  This is equivalent to
+    being Tensor-like, but assumes the type has already been transformed.
+    """
+    if isinstance(typ, BaseCType):
+        # I am regretting my naming conventions, but now we are wrapping at::scalar in
+        # lazy value, while preserving other 'scalar' types as scalars in the IR
+        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
+        return (
+            typ.type == getValueT()
+            or (typ.type == scalarT and not treat_scalars_as_constants)
+            or typ.type == SymIntT
+        )
+    elif typ == VectorCType(BaseCType(SymIntT)):
+        # TODO: report True for this
+        return False
+    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
+        return isValueType(typ.elem, properties)
+    return False
+
+
+def isSymIntType(typ: Type) -> bool:
+    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
+
+
+def isWrappedScalarType(typ: Type) -> bool:
+    """
+    Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
+    Since we literally change the type from scalarT to valueT, information is lost.
+    This function helps build a list of wrapped scalars to save that information
+    """
+    if isinstance(typ, BaseType):
+        # I am regretting my naming conventions, but now we are wrapping at::scalar in
+        # lazy value, while preserving other 'scalar' types as scalars in the IR
+        return typ.name == BaseTy.Scalar
+    elif isinstance(typ, (OptionalType, ListType)):
+        return isWrappedScalarType(typ.elem)
+    return False
+
+
+# TODO: dedupe with Type.is_generator_like
+def isGeneratorType(typ: Type) -> bool:
+    if isinstance(typ, BaseType):
+        return typ.name == BaseTy.Generator
+    elif isinstance(typ, (OptionalType)):
+        return isGeneratorType(typ.elem)
+    return False
+
+
+# This class caches a few derived properties computed from an Argument
+# and LazyIrProperties
+class LazyArgument:
+    name: str
+    orig_type: Type
+    lazy_type_: CType | None
+    is_wrapped_scalar: bool
+    is_generator: bool
+    # TODO: this is a lie; it is false for symint lists
+    is_symint_or_list: bool
+
+    # Whether or not we are treating this as symint
+    symint: bool
+
+    # true if this argument is or contains a lazy IR value
+    is_lazy_value: bool
+
+    def __init__(
+        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
+    ) -> None:
+        self.name = arg.name
+        self.orig_type = arg.type
+        self.symint = symint
+        self.is_optional = isinstance(arg.type, OptionalType)
+        self.is_generator = isGeneratorType(arg.type)
+        self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
+        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
+        self.is_symint_or_list = symint and (
+            isSymIntType(arg.type)
+            or (isinstance(arg.type, OptionalType) and isSymIntType(arg.type.elem))
+            # TODO: lists of symints are not currently treated as value types
+            # or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
+        )
+
+        self.is_lazy_value = isValueType(self.lazy_type, properties)
+
+    @property
+    def lazy_type(self) -> CType:
+        assert self.lazy_type_ is not None, (
+            f"Attempted to access lazy_type for invalid argument {self.name}"
+        )
+        return self.lazy_type_
+
+
+class LazyIrProperties:
+    """Collection of properties for an IR node
+
+    The property groups are listed below. Each group is mutually
+    exclusive, meaning that only one property from each group can be True
+    at any one time. The properties can be accessed as if they were normal
+    attributes. The mutual exclusivity is automatically handled.
+    """
+
+    Properties: tuple[tuple[str, ...], ...] = (
+        (
+            "ShapePrecompute",  # Assume shape has been precomputed
+            "ShapeCompute",  # Need to compute the shape on construction
+            "ShapeCache",  # Utilize the shape cache to defer computation
+        ),
+        (
+            "Lower",  # Codegen full lower function
+            "LowerDeclOnly",  # Codegen only lower function declaration
+        ),
+        (
+            "CanBeReused",  # Codegen full reuse function
+            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
+        ),
+        (
+            "CreateFn",  # Codegen full create function
+            "CreateFnDeclOnly",  # Codegen only create function declaration
+        ),
+        (
+            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
+        ),
+    )
+
+    def __init__(self, *default_properties: str) -> None:
+        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
+            LazyIrProperties.Properties
+        )
+        self.__dict__["properties"] = properties
+        for p in default_properties:
+            setattr(self, p, True)
+
+    def __getattr__(self, key: str) -> Any:
+        properties = self.__dict__["properties"]
+        for values in LazyIrProperties.Properties:
+            if key in values:
+                return properties[values] == key
+
+        return self.__getattribute__(key)
+
+    def __setattr__(self, key: str, value: Any) -> Any:
+        properties = self.__dict__["properties"]
+        for values in LazyIrProperties.Properties:
+            if key in values:
+                properties[values] = key if value else None
+                return value
+
+        raise KeyError(f"Invalid property: {key}")
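+
+    # Sketch of the mutual-exclusivity behaviour implemented above:
+    #
+    #   props = LazyIrProperties("ShapePrecompute", "Lower")
+    #   props.Lower          # True
+    #   props.LowerDeclOnly  # False
+    #   props.LowerDeclOnly = True   # selecting another member of the "Lower" group...
+    #   props.Lower          # ...turns the previous one off -> False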
+
+
+# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
+# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
+# but carries type information from a native FunctionSchema modified for use with IR nodes,
+# while preserving the original argument names.
+#
+# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
+class LazyIrSchema:
+    # The name of the operator this function schema describes.
+    name: OperatorName
+
+    positional_args: tuple[LazyArgument, ...]
+    keyword_args: tuple[LazyArgument, ...]
+
+    # TODO: Need to handle collisions with argument names at some point
+    returns: tuple[Return, ...]
+
+    # if this schema has a Generator arg, list its orig ctype/name but don't
+    # build a LazyArgument since lazy IR doesn't support it
+    generator_arg: NamedCType | None = None
+
+    # original function schema
+    func: FunctionSchema
+
+    # Whether or not we are code-genning for SymInt
+    symint: bool
+
+    properties: LazyIrProperties = LazyIrProperties(
+        # default properties
+        "ShapePrecompute",
+        "Lower",
+        "CanBeReused",
+    )
+    opkind: str | None = None
+
+    def __init__(
+        self,
+        func: FunctionSchema,
+        properties: LazyIrProperties | None = None,
+        *,
+        symint: bool,
+    ) -> None:
+        if properties:
+            self.properties = properties
+
+        self.func = func
+        self.symint = symint
+        positional_args: list[LazyArgument] = []
+        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
+            if arg_field == "self_arg" and func.arguments.self_arg is not None:
+                arg = func.arguments.self_arg.argument
+                positional_args.append(
+                    LazyArgument(arg, self.properties, symint=symint)
+                )
+            elif getattr(func.arguments, arg_field) is not None:
+                positional_args.extend(
+                    LazyArgument(arg, self.properties, symint=symint)
+                    for arg in getattr(func.arguments, arg_field)
+                )
+        self.positional_args = tuple(positional_args)
+
+        keyword_args: list[LazyArgument] = []
+        for arg_field in [
+            "pre_tensor_options_kwarg_only",
+            "tensor_options",
+            "post_tensor_options_kwarg_only",
+            "out",
+        ]:
+            curr_args = getattr(func.arguments, arg_field)
+            if curr_args is not None:
+                if isinstance(curr_args, TensorOptionsArguments):
+                    curr_args = curr_args.all()
+                for arg in curr_args:
+                    if isGeneratorType(arg.type):
+                        assert self.generator_arg is None, (
+                            "We expect there is only one generator arg"
+                        )
+                        self.generator_arg = NamedCType(
+                            arg.name,
+                            arg.type,  # type:ignore[arg-type]
+                        )
+                keyword_args.extend(
+                    LazyArgument(arg, self.properties, symint=symint)
+                    for arg in curr_args
+                )
+        self.keyword_args = tuple(keyword_args)
+        self.name = func.name
+        self.returns = func.returns
+
+    @property
+    def node_name(self) -> str:
+        """
+        Return camel-case version of op in node.
+
+        Note: This function also appends any `overload_name` in the operation.
+        For example, if the op is `bitwise_and.Tensor`, the returned name
+        will be `BitwiseAndTensor`.
+        """
+        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
+        return "".join(word.capitalize() or "" for word in op_name.split("_"))
+
+    @property
+    def aten_name(self) -> str:
+        return str(self.name.name)
+
+    @property
+    def base_name(self) -> str:
+        return f"{self.name.name.base}"
+
+    def filtered_args(
+        self,
+        positional: bool = True,
+        keyword: bool = True,
+        values: bool = True,
+        scalars: bool = True,
+        generator: bool = True,
+    ) -> list[LazyArgument]:
+        # This function maintains the sorted order of arguments but provides different filtered views.
+        # Some parts of the code care about kwargs vs args (TS lowerings),
+        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
+        # Generators are special cased, as they are needed for fallback/shape-inference but not supported
+        # in TS lowerings and therefore also omitted from lazy IR.
+        args: list[LazyArgument] = []
+        if positional:
+            args.extend(self.positional_args)
+        if keyword:
+            args.extend(self.keyword_args)
+
+        if values and scalars and generator:
+            return args
+        elif values and scalars:
+            return [a for a in args if not a.is_generator]
+        elif values:
+            return [a for a in args if a.is_lazy_value]
+        elif scalars:
+            return [
+                a
+                for a in args
+                if not a.is_lazy_value and (generator or not a.is_generator)
+            ]
+
+        return []
+
+    @property
+    def positional_values(self) -> list[LazyArgument]:
+        return self.filtered_args(
+            positional=True, keyword=False, values=True, scalars=False
+        )
+
+    @property
+    def positional_scalars(self) -> list[LazyArgument]:
+        return self.filtered_args(
+            positional=True, keyword=False, values=False, scalars=True
+        )
+
+    @property
+    def keyword_values(self) -> list[LazyArgument]:
+        return self.filtered_args(
+            positional=False, keyword=True, values=True, scalars=False
+        )
+
+    @property
+    def keyword_scalars(self) -> list[LazyArgument]:
+        return self.filtered_args(
+            positional=False, keyword=True, values=False, scalars=True
+        )
diff --git a/phivenv/Lib/site-packages/torchgen/api/meta.py b/phivenv/Lib/site-packages/torchgen/api/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffafc500dbfdde8b290ef6a2283c57fd602f0606
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/meta.py
@@ -0,0 +1,13 @@
+from torchgen.model import NativeFunctionsGroup
+
+
+# Follows dispatcher calling convention, but:
+#   - Mutable arguments not allowed.  Meta functions are always
+#     written in functional form.  Look at FunctionSchema.signature()
+#   - No tensor returns; instead we return a TensorMeta describing
+#     the tensor in question
+
+
+def name(g: NativeFunctionsGroup) -> str:
+    # use the overload name from the functional version
+    return str(g.functional.func.name).replace(".", "_")
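+
+
+# For example (illustrative), a group whose functional variant is `add.Tensor`
+# would map to the meta function name "add_Tensor".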
diff --git a/phivenv/Lib/site-packages/torchgen/api/native.py b/phivenv/Lib/site-packages/torchgen/api/native.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e868a9254f9031207a5bde709db85a97bdfb72
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/native.py
@@ -0,0 +1,159 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from typing_extensions import assert_never
+
+from torchgen import local
+from torchgen.api import cpp
+from torchgen.api.types import (
+    ArgName,
+    BaseCType,
+    Binding,
+    boolT,
+    ConstRefCType,
+    CType,
+    deviceT,
+    layoutT,
+    ListCType,
+    MutRefCType,
+    NamedCType,
+    OptionalCType,
+    scalarT,
+    scalarTypeT,
+    tensorT,
+)
+from torchgen.model import (
+    Argument,
+    FunctionSchema,
+    Return,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# This file describes the translation of JIT schema to the native functions API.
+# This looks a lot like the C++ API (which makes historical sense, because the
+# idea was you wrote native functions to implement functions in the C++ API),
+# but over time we have evolved the C++ API without actually changing our
+# native:: kernels.  The intention is to make native API and dispatcher API
+# line up as closely as possible, since this results in the least overhead
+# (no translation is needed from dispatcher API to native API).
+#
+# NB: this is symint aware, you will get the non-SymInt variant for some
+# dispatch entries and SymInt for others.
+
+
+def name(func: FunctionSchema) -> str:
+    name = str(func.name.name)
+    # TODO: delete this!
+    if func.is_out_fn():
+        name += "_out"
+    if func.name.overload_name:
+        name += f"_{func.name.overload_name}"
+    return name
+
+
+def argumenttype_type(
+    t: Type, *, mutable: bool, binds: ArgName, symint: bool
+) -> NamedCType:
+    if str(t) == "Tensor?":
+        tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
+        if mutable and not local.use_const_ref_for_mutable_tensors():
+            return NamedCType(binds, MutRefCType(tensor_type))
+        else:
+            return NamedCType(binds, ConstRefCType(tensor_type))
+    elif str(t) == "Tensor?[]":
+        return NamedCType(
+            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
+        )
+    elif str(t) == "Scalar":
+        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+    elif str(t) == "Scalar?":
+        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
+    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
+
+
+def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
+    return cpp.returns_type(rs, symint=symint)
+
+
+def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
+    return argumenttype_type(a.type, mutable=a.is_write, binds=binds, symint=symint)
+
+
+def argument(
+    a: Argument | SelfArgument | TensorOptionsArguments,
+    *,
+    is_out: bool,
+    symint: bool,
+) -> list[Binding]:
+    # Ideally, we NEVER default native functions.  However, there are a number
+    # of functions that call native:: directly and rely on the defaulting
+    # existing.  So for BC, we generate defaults for non-out variants (but not
+    # for out variants, where it is impossible to generate an appropriate
+    # default)
+    should_default = not is_out
+    if isinstance(a, Argument):
+        default: str | None = None
+        if should_default and a.default is not None:
+            default = cpp.default_expr(a.default, a.type, symint=symint)
+        return [
+            Binding(
+                nctype=argument_type(a, binds=a.name, symint=symint),
+                name=a.name,
+                default=default,
+                argument=a,
+            )
+        ]
+    elif isinstance(a, SelfArgument):
+        # Erase SelfArgument from the distinction
+        return argument(a.argument, is_out=is_out, symint=symint)
+    elif isinstance(a, TensorOptionsArguments):
+        default = None
+        if should_default:
+            default = "{}"
+        # TODO: Not sure why the arguments assigned here are for
+        # TensorOptionsArguments and not the constituent pieces.  It seems
+        # to matter
+        return [
+            Binding(
+                nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
+                name="dtype",
+                default=default,
+                argument=a,
+            ),
+            Binding(
+                nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
+                name="layout",
+                default=default,
+                argument=a,
+            ),
+            Binding(
+                nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
+                name="device",
+                default=default,
+                argument=a,
+            ),
+            Binding(
+                nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
+                name="pin_memory",
+                default=default,
+                argument=a,
+            ),
+        ]
+    else:
+        assert_never(a)
+
+
+def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    args.extend(func.arguments.non_out)
+    args.extend(func.arguments.out)
+    return [
+        r for arg in args for r in argument(arg, symint=symint, is_out=func.is_out_fn())
+    ]
diff --git a/phivenv/Lib/site-packages/torchgen/api/python.py b/phivenv/Lib/site-packages/torchgen/api/python.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6b0a5ec735fec95448a08f537512bc5e6900c05
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/python.py
@@ -0,0 +1,1548 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from torchgen.api import cpp
+from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
+from torchgen.gen import pythonify_default
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Return,
+    Type,
+    Variant,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                           Data Models
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# [Notes] python binding codegen
+#
+# The Python binding codegen produces code that takes the input list of
+# PyObjects, finds the matching ATen C++ function using PythonArgParser,
+# converts the PyObjects into C++ types and calls the ATen C++ function:
+#
+# +--------+  parsing   +------------------------+  binding   +-----------------------+
+# | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
+# +--------+            +------------------------+            +-----------------------+
+#
+# The following examples demonstrate the data models the Python binding
+# codegen needs to deal with and the tasks it needs to accomplish. They
+# help explain the purpose of the new data types introduced below.
+#
+#  - Function Schema (source of truth)
+#
+#      aten::empty.names(int[] size, *, Dimname[]? names,
+#                        ScalarType? dtype=None, Layout? layout=None,
+#                        Device? device=None, bool? pin_memory=None,
+#                        MemoryFormat? memory_format=None) -> Tensor
+#
+#  - Python Signature
+#
+#    It's used to generate input schema string for PythonArgParser.
+#    Note: TensorOptions fields are reordered and the additional
+#    'requires_grad' field is added:
+#
+#      empty(IntArrayRef size, *, DimnameList? names,
+#            MemoryFormat? memory_format=None, ScalarType dtype=None,
+#            Layout layout=torch.strided, Device device=None,
+#            bool pin_memory=False, bool requires_grad=False)
+#
+#  - C++ Signature
+#
+#    It's used to generate C++ lambda formals & dispatch call.
+#    Note: the scattered TensorOptions fields are packed into 'options'.
+#
+#      auto dispatch_empty =
+#          [](IntArrayRef size, std::optional<DimnameList> names,
+#             const TensorOptions & options,
+#             std::optional<MemoryFormat> memory_format) -> Tensor {
+#          pybind11::gil_scoped_release no_gil;
+#          return torch::empty(size, names, options, memory_format);
+#      };
+#
+#  - Binding between Python Arguments and C++ Arguments
+#
+#    Given a set of Python Arguments in scope, we need to produce the
+#    binding expressions that translate the Python API into the C++ API:
+#
+#            Python Args               Cpp Args       Binding Exprs
+#     -----------------------------------------------------------------
+#         0: size                      size           '_r.intlist(0)'
+#         1: names                     names          'names' [special init]
+#         2: memory_format -------+
+#         3: dtype         -----+-|--> options        'options' [special packing]
+#         4: layout            /  |
+#         5: device           /   +--> memory_format  '_r.memoryformatOptional(2)'
+#         6: pin_memory      /
+#         7: requires_grad -+
+#
+#    So the full dispatch expression would look like:
+#
+#      dispatch_empty(_r.intlist(0), names, options,
+#                     _r.memoryformatOptional(2))
+#
+#    Where does 'names' come from? It involves special local init:
+#
+#      auto __names = _r.toDimnameListOptional(1);
+#      std::optional<DimnameList> names =
+#          __names ? std::make_optional(DimnameList(__names.value()))
+#                  : std::nullopt;
+#
+#    Where does 'options' come from? It involves special local init
+#    for TensorOptions. Note that Python side has the additional
+#    'requires_grad' field:
+#
+#      const auto options = TensorOptions()
+#          .dtype(_r.scalartype(3))
+#          .device(_r.device(5))
+#          .layout(_r.layoutOptional(4))
+#          .requires_grad(_r.toBool(7))
+#          .pinned_memory(_r.toBool(6));
+#
+#    In some other cases one Python Argument can map to multiple C++
+#    Arguments. For example:
+#
+#     aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
+#       -> (Tensor values, Tensor indices)
+#
+#            Python Args               Cpp Args          Binding Exprs
+#     ---------------------------------------------------------------------
+#                               +----> max               'out[0]'
+#                              /-----> max_values        'out[1]
+#         0: input            /        self              '_r.tensor(0)'
+#         1: dim             /         dim               '_r.dimname(1)'
+#         2: keepdim        /          keepdim           '_r.toBool(2)'
+#         3: out      -----+           [local init] out  '_r.tensorlist_n<2>(3)'
+#
+#    As demonstrated above, the binding can involve reordering,
+#    packing, unpacking and special local inits.
+#
+#
+#  Let's look at a concrete example:
+#
+#      static PythonArgParser parser({
+#        "abs(Tensor input, *, Tensor out=None)",
+#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#         ^
+#         +--- Python Schema, represented by PythonSignature and PythonArgument
+#
+#      }, /*traceable=*/true);
+#
+#      ParsedArgs<2> parsed_args;
+#      auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
+#
+#      ...
+#
+#      if (_r.isNone(1)) {
+#          ~~~~~~~~~~~~  <--- Scattered PythonArgParser output (arg name = 'out')
+#                             represented by PythonArgParserOutputExpr
+#
+#        // aten::abs(Tensor self) -> Tensor
+#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#         ^
+#         +--- NativeFunction schema, base version
+#
+#        auto dispatch_abs = [](const Tensor & self) -> Tensor {
+#                            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#                             ^
+#                             +--- dispatch_lambda_args / dispatch_lambda_return_str
+#                                  generated from NativeFunction / CppSignature
+#                                  (deprecated PythonSignature is special)
+#                                  arguments are represented by DispatchLambdaArgument
+#
+#          pybind11::gil_scoped_release no_gil;
+#          return self.abs();
+#                 ~~~~~~~~~~~  <--- cpp_dispatch_target / cpp_dispatch_exprs
+#                                   generated from NativeFunction / CppSignature
+#        };
+#        return wrap(dispatch_abs(_r.tensor(0)));
+#                                 ~~~~~~~~~~~~~
+#                                  ^
+#                                  +--- dispatch_lambda_exprs
+#                                       binding PythonArgParserOutputExpr (python args)
+#                                       and DispatchLambdaArgument (c++ args)
+#
+#      } else {
+#        // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+#        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#         ^
+#         +--- NativeFunction schema, out-variant
+#
+#        auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
+#          pybind11::gil_scoped_release no_gil;
+#          return at::abs_out(out, self);
+#        };
+#        return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
+#      }
+#
+#
+# [Notes] python interface codegen
+# The python dataclasses below are used to generate both python binding code
+# and pyi type hint signatures.
+# In theory these two should look very similar, but there are a number of differences
+# in how pyi signatures vs. python_arg_parser signatures are generated.
+# These differences have been encapsulated in signature_str() vs. signature_str_pyi()
+# to display the full signatures, and argument_str() vs. argument_str_pyi() to display arguments.
+# For example, only pyi signatures include return types.
+
+
+def format_function_signature(
+    name: str, arguments: Iterable[str] = (), return_type: str | None = None
+) -> str:
+    if not isinstance(arguments, (list, tuple)):
+        arguments = tuple(arguments)
+    return_type = f" -> {return_type}" if return_type is not None else ""
+
+    sig = f"def {name}({', '.join(arguments)}){return_type}: ..."
+    if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
+        return sig
+
+    lines = [
+        f"def {name}(",
+        *(f"    {arg}," for arg in arguments),
+        f"){return_type}: ...",
+    ]
+    sig = "\n".join(lines)
+    if all(len(line) <= 80 for line in lines):
+        return sig
+    # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658
+    # use `skip` instead of `on` + `off`
+    return sig.removesuffix(" ...") + "  # fmt: skip\n    ..."
+
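+# Behaviour sketch of format_function_signature (outputs shown for illustration):
+#
+#   format_function_signature("abs", ["input: Tensor"], "Tensor")
+#       -> "def abs(input: Tensor) -> Tensor: ..."
+#
+# When the one-line form would exceed 80 characters, the arguments are instead
+# emitted one per line and the body placeholder stays as "...".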
+
+@dataclass(frozen=True)
+class PythonReturns:
+    returns: tuple[Return, ...]
+
+
+@dataclass(frozen=True)
+class PythonArgument:
+    name: str
+    type: Type
+    default: str | None
+
+    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
+    #
+    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
+    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    #                            ^
+    #                            +--- default_init str
+    default_init: str | None
+
+    # Compute argument formal for python argument parsing.
+    # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
+    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
+        type_str = (
+            argument_type_str(self.type, symint=symint)
+            .replace("const ", "")
+            .replace(" &", "")
+        )
+
+        name = self.name
+        # s/self/input/ outside method bindings
+        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
+        # for the parse string
+        if name == "self" and type_str in ["Tensor", "Number"] and not method:
+            name = "input"
+
+        # add default
+        if self.default is not None:
+            default = {
+                "nullptr": "None",
+                "::std::nullopt": "None",
+                "std::nullopt": "None",
+                "{}": "None",
+            }.get(self.default, self.default)
+            return f"{type_str} {name}={default}"
+        else:
+            return f"{type_str} {name}"
+
+    def argument_str_pyi(
+        self, *, method: bool = False, deprecated: bool = False
+    ) -> str:
+        type_str = argument_type_str_pyi(self.type)
+
+        name = self.name
+        # s/self/input/ outside method bindings
+        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
+        # for the parse string
+        if name == "self" and type_str == "Tensor" and not method and not deprecated:
+            name = "input"
+
+        if name == "from":  # from is a Python keyword...
+            name += "_"
+
+        # pyi merges the _out and functional variants into the same signature, with an optional out arg
+        if name == "out" and type_str == "Tensor" and not deprecated:
+            type_str = f"{type_str} | None".replace(" | None | None", " | None")
+
+        # pyi deprecated signatures don't get defaults for their out arg
+        treat_as_no_default = (
+            deprecated
+            and isinstance(self, PythonOutArgument)
+            and self.default == "None"
+        )
+
+        # add default
+        if self.default is not None and not treat_as_no_default:
+            if (
+                isinstance(self.type, ListType)
+                and self.type.elem == BaseType(BaseTy.int)
+                and self.default.startswith("{")
+                and self.default.endswith("}")
+            ):
+                default = (
+                    "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
+                )
+            else:
+                default = {
+                    "nullptr": "None",
+                    "::std::nullopt": "None",
+                    "std::nullopt": "None",
+                    "{}": "None",
+                    "c10::MemoryFormat::Contiguous": "contiguous_format",
+                    "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
+                }.get(self.default, self.default)
+            return f"{name}: {type_str} = {default}"
+        else:
+            return f"{name}: {type_str}"
+
+
+@dataclass(frozen=True)
+class PythonOutArgument(PythonArgument):
+    # In the Python signature, multiple output fields are packed into one 'out' argument.
+    # When binding to C++, it's first bound to a local 'out' variable:
+    #   'auto out = _r.tensorlist_n<2>(2);',
+    # then bound to the scattered C++ output arguments as 'out[0]', 'out[1]', etc.
+    # TODO: maybe we don't need to keep the scattered out fields for the python signature?
+    outputs: tuple[PythonArgument, ...]
+
+    @staticmethod
+    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
+        if not outputs:
+            return None
+
+        size = len(outputs)
+        if size == 1:
+            return PythonOutArgument(
+                name=outputs[0].name,
+                type=outputs[0].type,
+                default="None",
+                default_init=None,
+                outputs=outputs,
+            )
+        elif size > 1:
+            if any(not a.type.is_tensor_like() for a in outputs):
+                raise RuntimeError(f"Unsupported output type: {outputs}")
+            return PythonOutArgument(
+                name="out",
+                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
+                type=ListType(BaseType(BaseTy.Tensor), size),
+                default="None",
+                default_init=None,
+                outputs=outputs,
+            )
+        raise AssertionError(r"Unexpected PythonOutArgument size")
+
+
+@dataclass(frozen=True)
+class PythonSignature:
+    # Base operator name, without inplace/outplace suffix.
+    name: str
+
+    # Positional arguments.
+    # TODO: create a dedicated SelfArgument type for 'self'?
+    input_args: tuple[PythonArgument, ...]
+
+    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
+    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
+    input_kwargs: tuple[PythonArgument, ...]
+
+    output_args: PythonOutArgument | None
+
+    # Return types, which are only used by pyi
+    returns: PythonReturns
+
+    # These are scattered kwargs arguments belonging to TensorOptions.
+    # When binding to C++, they are packed into a TensorOptions object 'options'.
+    # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
+    # for out variant), in which case they will be used as scattered fields without
+    # being packed into 'options'.
+    # TODO: maybe create a PythonTensorOptionsArgument?
+    tensor_options_args: tuple[PythonArgument, ...]
+
+    # method or function signature?
+    method: bool
+
+    @property
+    def deprecated(self) -> bool:
+        return False
+
+    def arguments(
+        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
+    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
+        result: list[PythonArgument | PythonOutArgument] = []
+        result.extend(self.input_args)
+        result.extend(self.input_kwargs)
+        if self.output_args is not None and not skip_outputs:
+            result.append(self.output_args)
+        if not skip_tensor_options:
+            result.extend(self.tensor_options_args)
+        return tuple(result)
+
+    def arguments_count(self) -> int:
+        return len(self.arguments())
+
+    def output_idx(self) -> int:
+        return len(self.input_args) + len(self.input_kwargs)
+
+    # [old codegen] Compute the Python function signature for argument parsing,
+    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
+    # this is NOT the same type signature as specified by PEP 484
+    # as understood by mypy; our format was independently developed
+    # and has some quirks to make it more suitable specifically
+    # for error parsing.
+    #
+    # For a translation to mypy-valid type signatures, see
+    # signature_str_pyi().
+    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
+        args = self.arguments(skip_outputs=skip_outputs)
+        schema_formals: list[str] = [
+            a.argument_str(method=self.method, symint=symint) for a in args
+        ]
+        positional_argc = len(self.input_args)
+        if len(schema_formals) > positional_argc:
+            schema_formals.insert(positional_argc, "*")
+
+        return f"{self.name}({', '.join(schema_formals)})"
+
+    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
+        args = self.arguments(skip_outputs=skip_outputs)
+        schema_formals: list[str] = [
+            a.argument_str_pyi(method=self.method) for a in args
+        ]
+        positional_argc = len(self.input_args)
+        if len(schema_formals) > positional_argc:
+            schema_formals.insert(positional_argc, "*")
+
+        # only pyi signatures include returns
+        returns_str = returns_str_pyi(self)
+        # pyi also includes self (with no typing/defaults) for methods
+        if self.method:
+            schema_formals.insert(0, "self")
+        return format_function_signature(self.name, schema_formals, returns_str)
+
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
+        # only pyi uses vararg signatures
+        args = self.arguments(skip_outputs=skip_outputs)
+        schema_formals: list[str] = [
+            a.argument_str_pyi(method=self.method) for a in args
+        ]
+        # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
+        num_args = self.arguments_count()
+        if num_args == 0:
+            return None
+
+        num_positionalargs = len(self.input_args)
+
+        vararg_type = args[0].type
+        if not (
+            isinstance(vararg_type, ListType)
+            and str(vararg_type.elem) in ["int", "SymInt"]
+            and num_positionalargs == 1
+        ):
+            return None
+
+        # Below are the major changes in vararg vs. regular pyi signatures
+        # vararg signatures also omit the asterisk
+        assert isinstance(vararg_type, ListType)
+        schema_formals[0] = (
+            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
+        )
+
+        returns_str = returns_str_pyi(self)
+        # pyi also includes self (with no typing/defaults) for methods
+        if self.method:
+            schema_formals.insert(0, "self")
+        return format_function_signature(self.name, schema_formals, returns_str)
+
+
+# The deprecated python signature involves some special logic, so create a
+# dedicated data model to store these extra properties.
+@dataclass(frozen=True)
+class PythonSignatureDeprecated(PythonSignature):
+    # Schema for the deprecated function
+    deprecated_schema: FunctionSchema
+
+    # The deprecated signature might miss some arguments that the corresponding
+    # C++ signature expects. We need to store the constant default values to pass in.
+    # For example:
+    #   [deprecated signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
+    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+    #   [func call]: self.addmm(mat1, mat2, beta, 1)
+    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
+    deprecated_args_exprs: tuple[str, ...]
+
+    @property
+    def deprecated(self) -> bool:
+        return True
+
+    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
+        return (
+            PythonSignature.signature_str(
+                self, skip_outputs=skip_outputs, symint=symint
+            )
+            + "|deprecated"
+        )
+
+    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
+        args = self.arguments(skip_outputs=skip_outputs)
+        schema_formals: list[str] = [
+            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
+        ]
+        positional_argc = len(self.input_args)
+        if len(schema_formals) > positional_argc:
+            schema_formals.insert(positional_argc, "*")
+
+        returns_str = returns_str_pyi(self)
+        return format_function_signature(self.name, schema_formals, returns_str)
+
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
+        # the codegen doesn't include vararg variants for deprecated signatures
+        return None
+
+
+# This struct is used to hold the PythonSignature and its corresponding
+# NativeFunction BEFORE grouping base and out-variant functions.
+# Why not store NativeFunction in PythonSignature or construct PythonSignature
+# from NativeFunction? Because they are not 1-1 mapped.
+# One native function could have both deprecated and non-deprecated python
+# signatures - NativeFunction doesn't contain information to construct the
+# deprecated python signature.
+# One python signature is used to handle both the base and the out-variant
+# function - see 'PythonSignatureGroup'.
+@dataclass(frozen=True)
+class PythonSignatureNativeFunctionPair:
+    signature: PythonSignature
+    function: NativeFunction
+
+
+# We merge pairs of functions with signatures that are equivalent mod
+# output arguments, and use a single entry in the python_arg_parser sig
+# list for both (output arguments become optional).
+@dataclass(frozen=True)
+class PythonSignatureGroup:
+    # The signature used for Python argument parsing. The outplace signature
+    # is preferred if exists, because it can be used to parse inputs for both
+    # the out-place variant and the base version (with output omitted).
+    signature: PythonSignature
+
+    # The regular ATen declaration (e.g. conv2d)
+    base: NativeFunction
+
+    # The out variant (e.g. conv2d_out)
+    outplace: NativeFunction | None
+
+    @classmethod
+    def from_pairs(
+        cls,
+        functional: PythonSignatureNativeFunctionPair,
+        out: PythonSignatureNativeFunctionPair | None,
+    ) -> PythonSignatureGroup:
+        if out is None:
+            return PythonSignatureGroup(
+                signature=functional.signature,
+                base=functional.function,
+                outplace=None,
+            )
+
+        # prefer the signature with optional out=... arguments because it's the
+        # superset that can be used to parse input for both base and outplace.
+        signature_kwargs = out.signature.__dict__.copy()
+
+        # Out overloads in C++ don't have TensorOptions arguments,
+        # so take these from the functional variant
+        signature_kwargs["tensor_options_args"] = (
+            functional.signature.tensor_options_args
+        )
+
+        return PythonSignatureGroup(
+            signature=type(out.signature)(**signature_kwargs),
+            base=functional.function,
+            outplace=out.function,
+        )
+
+
+# C++ function dispatch is wrapped in a lambda function. The lambda function
+# has almost the same signature as the C++ function, only with some small
+# variants - see details below.
+# This data model is used to represent arguments of the lambda function
+# signature.
+@dataclass(frozen=True)
+class DispatchLambdaArgument:
+    name: str
+    type_str: str
+    is_out_arg: bool
+
+
+# To pass PyObject arguments to the C++ function (via the lambda wrapper),
+# we first need to convert the PyObjects into simple C++ objects. This work
+# is done by PythonArgParser.
+# This data model is used to represent the output of PythonArgParser.
+# It has 1-1 mapping with PythonArgument in PythonSignature.
+@dataclass(frozen=True)
+class PythonArgParserOutputExpr:
+    # argument name
+    name: str
+
+    # RHS expression to reference PythonArgParser output.
+    expr: str
+
+    # In some special cases we need to create a different expr, e.g.:
+    # '_r.isNone(1)' instead of '_r.tensor(1)'.
+    index: int
+
+    # The python argument it maps to.
+    argument: PythonArgument
+
+    @property
+    def is_none_expr(self) -> str:
+        return f"_r.isNone({self.index})"
+
+
+# To pass PythonArgParser output to the lambda wrapper, we need to bind
+# PythonArgParserOutputExpr to DispatchLambdaArgument.
+# They are not always 1-1 mapped, e.g. scattered TensorOptions fields
+# need to be packed into a TensorOptions object, which is the argument
+# that the lambda function wrapper takes.
+@dataclass(frozen=True)
+class DispatchLambdaArgumentExprs:
+    # The exprs that provide the binding for lambda arguments, e.g.:
+    #
+    #   'self' -> '_r.tensor(0)'
+    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
+    #   'options' -> 'options'
+    #
+    # It has 1-1 mapping with DispatchLambdaArgument.
+    exprs: Sequence[str]
+
+    # Special local inits, which might introduce new variables that
+    # the 'exprs' above reference, e.g.:
+    #
+    #   'auto out = _r.tensorlist_n<2>(2);'
+    #
+    inits: Sequence[str]
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                          Helper Functions
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
+    return CppSignatureGroup.from_native_function(f, method=method).signature
+
+
+def has_tensor_options(f: NativeFunction) -> bool:
+    return f.func.arguments.tensor_options is not None
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                          Python Signature
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# 'simple_type' was introduced by the old codegen; it is slightly
+# different from the python schema type, e.g. it doesn't have a '?' suffix
+# for optional Tensor/TensorList and doesn't have a '[size]' suffix for list types.
+def argument_type_str(
+    t: Type, *, simple_type: bool = False, symint: bool = True
+) -> str:
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.int:
+            return "int64_t"
+        elif t.name == BaseTy.float:
+            return "double"
+        elif t.name == BaseTy.str:
+            return "c10::string_view"
+        elif t.name in [
+            BaseTy.Tensor,
+            BaseTy.bool,
+            BaseTy.QScheme,
+            BaseTy.Scalar,
+            BaseTy.ScalarType,
+            BaseTy.Generator,
+            BaseTy.Storage,
+            BaseTy.Layout,
+            BaseTy.Device,
+            BaseTy.DeviceIndex,
+            BaseTy.MemoryFormat,
+            BaseTy.Dimname,
+            BaseTy.Stream,
+            BaseTy.SymInt,
+        ]:
+            # These python schema type names line up with their function schema names
+            return t.name.name
+
+    elif isinstance(t, OptionalType):
+        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
+        return f"{elem}?"
+    elif isinstance(t, ListType):
+        size = t.size if not simple_type else None
+        if str(t.elem) == "bool":
+            assert t.size is not None
+            return f"::std::array"
+        elif str(t.elem) == "int":
+            return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
+        elif str(t.elem) == "SymInt":
+            if symint:
+                return (
+                    f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
+                )
+            else:
+                return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
+        elif str(t.elem) == "Tensor":
+            return f"TensorList[{size}]" if size is not None else "TensorList"
+        elif str(t.elem) == "Scalar":
+            return f"ScalarList[{size}]" if size is not None else "ScalarList"
+        elif str(t.elem) == "Tensor?":
+            if simple_type:
+                return "c10::List<::std::optional<Tensor>>"
+            else:
+                return "const c10::List<::std::optional<Tensor>> &"
+        elif str(t.elem) == "Dimname":
+            return f"DimnameList[{size}]" if size is not None else "DimnameList"
+        elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
+        return f"ArrayRef<{elem}>"
+
+    raise RuntimeError(f"unrecognized type {repr(t)}")
+
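+
+# A minimal usage sketch (illustrative only; not part of upstream torchgen). It
+# shows the strings argument_type_str produces for a few schema types, assuming
+# the torchgen.model constructors used below accept these fields by keyword.
+def _example_argument_type_str() -> None:
+    assert argument_type_str(BaseType(BaseTy.int)) == "int64_t"
+    assert argument_type_str(OptionalType(BaseType(BaseTy.Tensor))) == "Tensor?"
+    int_pair = ListType(elem=BaseType(BaseTy.int), size=2)
+    assert argument_type_str(int_pair) == "IntArrayRef[2]"
+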
+
+def argument_type_size(t: Type) -> int | None:
+    l = t.is_list_like()
+    if l is not None and str(l.elem) != "bool":
+        return l.size
+    else:
+        return None
+
+
+def argument(a: Argument) -> PythonArgument:
+    return PythonArgument(
+        name=a.name,
+        type=a.type,
+        # TODO: directly translate a.default to python default
+        default=(
+            str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
+            if a.default is not None
+            else None
+        ),
+        default_init=None,
+    )
+
+
+# Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
+def signature(
+    f: NativeFunction, *, method: bool = False, pyi: bool = False
+) -> PythonSignature:
+    return signature_from_schema(
+        f.func, category_override=f.category_override, method=method, pyi=pyi
+    )
+
+
+def signature_from_schema(
+    func: FunctionSchema,
+    *,
+    category_override: str | None,
+    method: bool = False,
+    pyi: bool = False,
+) -> PythonSignature:
+    args: list[Argument] = []
+    args.extend(func.arguments.pre_self_positional)
+    # Skip SelfArgument if this is method.
+    if not method and func.arguments.self_arg is not None:
+        args.append(func.arguments.self_arg.argument)
+    args.extend(func.arguments.post_self_positional)
+    args.extend(func.arguments.pre_tensor_options_kwarg_only)
+    # Skip TensorOptionsArguments. Python side TensorOptions
+    # arguments are created based on different rules - see below.
+    args.extend(func.arguments.post_tensor_options_kwarg_only)
+    args.extend(func.arguments.out)
+
+    input_arg_set = {a.name for a in func.arguments.flat_positional}
+    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
+    out_arg_set = {a.name for a in func.arguments.out}
+
+    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
+    input_kwargs = tuple(
+        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
+    )
+    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
+
+    # Reintroduce the scattered fields of TensorOptions for Python.
+    # Compared to the cpp counterpart, the python arguments have a new property
+    # (default_init) and a new argument 'requires_grad', which require some
+    # special handling.
+    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
+    # to the original versions in the yaml, this recreation is a potential
+    # source of drift between eager and JIT. Pull this logic out to a shared place.
+
+    has_tensor_input_arg = any(
+        a.type.is_tensor_like() for a in func.arguments.flat_non_out
+    )
+    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
+        raise ValueError(
+            "argument named requires_grad is reserved, should not explicitly add it in the schema"
+        )
+
+    # [old codegen] this probably won't work if one of the returns is not a tensor,
+    # but it will produce a compile-time error that is obvious.
+    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)
+
+    name: str = cpp.name(func)
+    is_factory_function = category_override == "factory" or (
+        has_tensor_return and not has_tensor_input_arg
+    )
+    is_like_or_new_function = (
+        category_override in ("new", "like")
+        or name.startswith("new_")
+        or name.endswith("_like")
+    )
+    is_dummy_function = category_override == "dummy"
+
+    tensor_options_args: list[PythonArgument] = []
+    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
+
+        def topt_default_init(name: str) -> str | None:
+            topt_args = func.arguments.tensor_options
+            if topt_args is None:
+                return None
+            a = getattr(topt_args, name)
+            if a.default is None or a.default == "None":
+                return None
+            return cpp.default_expr(a.default, a.type, symint=False)
+
+        tensor_options_args.append(
+            PythonArgument(
+                name="dtype",
+                type=OptionalType(BaseType(BaseTy.ScalarType)),
+                default="None",
+                default_init=(
+                    None if is_like_or_new_function else topt_default_init("dtype")
+                ),
+            )
+        )
+        tensor_options_args.append(
+            PythonArgument(
+                name="layout",
+                type=OptionalType(BaseType(BaseTy.Layout)),
+                default="None",
+                default_init=(
+                    None if is_like_or_new_function else topt_default_init("layout")
+                ),
+            )
+        )
+        tensor_options_args.append(
+            PythonArgument(
+                name="device",
+                type=OptionalType(BaseType(BaseTy.Device)),
+                default="None",
+                default_init=(
+                    None
+                    if is_like_or_new_function
+                    else (
+                        topt_default_init("device")
+                        or "torch::tensors::get_default_device()"
+                    )
+                ),
+            )
+        )
+        tensor_options_args.append(
+            PythonArgument(
+                name="pin_memory",
+                type=OptionalType(BaseType(BaseTy.bool)),
+                default="False",
+                default_init=None,
+            )
+        )
+        tensor_options_args.append(
+            PythonArgument(
+                name="requires_grad",
+                type=OptionalType(BaseType(BaseTy.bool)),
+                default="False",
+                default_init=None,
+            )
+        )
+
+    returns = PythonReturns(returns=func.returns)
+
+    return PythonSignature(
+        name=str(func.name.name),
+        input_args=input_args,
+        input_kwargs=input_kwargs,
+        output_args=PythonOutArgument.from_outputs(outputs),
+        tensor_options_args=tuple(tensor_options_args),
+        returns=returns,
+        method=method,
+    )
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                          Python Interface
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
+    if len(returns) <= 1 or all(r.name is None for r in returns):
+        return []
+    else:
+        if any(r.name is None for r in returns):
+            # When building on Windows, `PyStructSequence_UnnamedField` could not be
+            # resolved by the linker for some reason, which causes an error when building:
+            #
+            # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
+            # PyStructSequence_UnnamedField
+            #
+            # Thus, at this point in time, we do not support unnamed
+            # fields in structseq; you must either name all fields,
+            # or none of them.
+            raise ValueError("Unnamed field is not supported by codegen")
+
+        return [str(r.name) for r in returns]
+
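+
+# A minimal usage sketch (illustrative only; not part of upstream torchgen):
+# named multi-returns yield their field names, while a single return yields no
+# structseq fields. It assumes Return accepts (name, type, annotation)
+# positionally, as it is used elsewhere in this file.
+def _example_structseq_fieldnames() -> None:
+    values = Return("values", BaseType(BaseTy.Tensor), None)
+    indices = Return("indices", BaseType(BaseTy.Tensor), None)
+    assert structseq_fieldnames((values, indices)) == ["values", "indices"]
+    assert structseq_fieldnames((values,)) == []
+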
+
+def argument_type_str_pyi(t: Type) -> str:
+    add_optional = False
+    if isinstance(t, OptionalType):
+        t = t.elem
+        add_optional = True
+
+    ret = ""
+    if isinstance(t, BaseType):
+        if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
+            ret = "_int"
+        if t.name == BaseTy.SymInt:
+            ret = "_int | SymInt"
+        elif t.name == BaseTy.float:
+            ret = "_float"
+        elif t.name == BaseTy.str:
+            ret = "str"
+        elif t.name == BaseTy.Scalar:
+            ret = "Number | _complex"
+        elif t.name == BaseTy.ScalarType:
+            ret = "_dtype"
+        elif t.name == BaseTy.bool:
+            ret = "_bool"
+        elif t.name == BaseTy.QScheme:
+            ret = "_qscheme"
+        elif t.name == BaseTy.Layout:
+            ret = "_layout"
+        elif t.name == BaseTy.Device:
+            ret = "DeviceLikeType | None"
+        elif t.name == BaseTy.MemoryFormat:
+            ret = "memory_format"
+        elif t.name == BaseTy.Dimname:
+            ret = "str | EllipsisType | None"
+        elif t.name == BaseTy.Storage:
+            ret = "Storage | UntypedStorage"
+        elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
+            # These python schema type names line up with their function schema names
+            ret = t.name.name
+
+    elif isinstance(t, ListType):
+        if str(t.elem) == "int":
+            ret = "_int | _size" if t.size is not None else "_size"
+        elif t.is_tensor_like():
+            # TODO: this doesn't seem right...
+            # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None
+            # It should probably translate to   tuple[Tensor | None, ...] | list[Tensor | None]
+            add_optional = True
+            ret = (
+                "Tensor | tuple[Tensor, ...] | list[Tensor]"
+                if t.size is not None
+                else "tuple[Tensor, ...] | list[Tensor]"
+            )
+        elif str(t.elem) == "float":
+            ret = "Sequence[_float]"
+        elif str(t.elem) == "SymInt" and t.size is not None:
+            elem = argument_type_str_pyi(t.elem)
+            ret = f"{elem} | Sequence[{elem}]"
+        else:
+            elem = argument_type_str_pyi(t.elem)
+            ret = f"Sequence[{elem}]"
+
+    else:
+        raise RuntimeError(f"unrecognized type {repr(t)}")
+
+    if add_optional:
+        ret = f"{ret} | None".replace(" | None | None", " | None")
+
+    return ret
+
+
+def return_type_str_pyi(t: Type) -> str:
+    # Where arguments are open to accepting Union, return types should return
+    # concrete types
+
+    if isinstance(t, OptionalType):
+        inner = return_type_str_pyi(t.elem)
+        return f"{inner} | None".replace(" | None | None", " | None")
+
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.Device:
+            return "_device"
+        elif t.name == BaseTy.Dimname:
+            return "str | None"
+        else:
+            return argument_type_str_pyi(t)
+
+    if isinstance(t, ListType):
+        inner = return_type_str_pyi(t.elem)
+        return f"tuple[{inner}, ...]"
+
+    return argument_type_str_pyi(t)
+
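+
+# A minimal usage sketch (illustrative only; not part of upstream torchgen). It
+# contrasts the union-friendly argument annotations with the concrete return
+# annotations emitted into the generated .pyi stubs.
+def _example_pyi_type_strs() -> None:
+    assert argument_type_str_pyi(BaseType(BaseTy.int)) == "_int"
+    assert argument_type_str_pyi(OptionalType(BaseType(BaseTy.Tensor))) == "Tensor | None"
+    tensor_list = ListType(elem=BaseType(BaseTy.Tensor), size=None)
+    assert return_type_str_pyi(tensor_list) == "tuple[Tensor, ...]"
+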
+
+def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
+    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
+    structseq_name = signature.name
+    field_names = structseq_fieldnames(signature.returns.returns)
+    if field_names:
+        # These types are structseq objects which act like NamedTuples, but
+        # whose constructor acts like the constructor of tuple. Using typing.NamedTuple
+        # does not allow us to override __init__.
+        seq_type = f"tuple[{', '.join(python_returns)}]"
+        structseq_def_lines = [
+            f"class {structseq_name}({seq_type}):  # fmt: skip",
+        ]
+        for name, ret_type in zip(field_names, python_returns):
+            structseq_def_lines.extend(
+                [
+                    "    @property",
+                    f"    def {name}(self) -> {ret_type}: ...",
+                ]
+            )
+        structseq_def_lines.extend(
+            [
+                "    def __new__(",
+                "        cls,",
+                f"        sequence: {seq_type},",
+                "    ) -> Self:  # fmt: skip",
+                "        ...",
+                f"    n_fields: Final[_int] = {len(field_names)}",
+                f"    n_sequence_fields: Final[_int] = {len(field_names)}",
+                "    n_unnamed_fields: Final[_int] = 0",
+                "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
+                "",  # add an extra newline
+            ]
+        )
+        structseq_def = "\n".join(structseq_def_lines)
+        # Example:
+        # structseq_def = (
+        #     "class max(tuple[Tensor, Tensor]):  # fmt: skip\n"
+        #     "    @property\n"
+        #     "    def values(self) -> Tensor: ...\n"
+        #     "    @property\n"
+        #     "    def indices(self) -> Tensor: ...\n"
+        #     "    def __new__(\n"
+        #     "        cls,\n"
+        #     "        sequence: tuple[Tensor, Tensor],\n"
+        #     "    ) -> Self:  # fmt: skip\n"
+        #     "        ...\n"
+        #     "    n_fields: Final[_int] = 2",
+        #     "    n_sequence_fields: Final[_int] = 2",
+        #     "    n_unnamed_fields: Final[_int] = 0",
+        #     "    def __init_subclass__(cls) -> NoReturn: ...  # prohibit subclassing",
+        # )
+        return structseq_name, structseq_def
+    return None
+
+
+def returns_str_pyi(signature: PythonSignature) -> str:
+    field_names = structseq_fieldnames(signature.returns.returns)
+    if field_names:
+        return f"torch.return_types.{signature.name}"
+
+    python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
+    if len(python_returns) > 1:
+        return "tuple[" + ", ".join(python_returns) + "]"
+    if len(python_returns) == 1:
+        return python_returns[0]
+    return "None"
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                        C++ Function Dispatch
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+# This section provides APIs to generate the code that does C++ function
+# dispatch. The C++ function call is wrapped by a lambda function.
+# For example:
+#
+#    // aten::selu_(Tensor(a!) self) -> Tensor(a!)
+#    auto dispatch_selu_ = [](Tensor self) -> Tensor {
+#      pybind11::gil_scoped_release no_gil;
+#      return at::selu_(self);
+#    };
+#
+# The lambda function's signature follows the C++ signature in common
+# cases, e.g.:
+#
+#   // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+#   [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
+#
+# For the out variant, the 'out' argument's type is changed from 'Tensor &'
+# to 'Tensor'. This is because when calling the lambda we pass in the
+# PythonArgParser output '_r.tensor(3)', which is a stack-allocated object
+# and needs to be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
+#
+#   // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+#   [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
+#
+# For the multi-output case it can keep using reference types because the
+# PythonArgParser output has been unpacked into local variables, e.g.:
+#
+#   // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
+#   //     Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+#   [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor, Tensor>
+#
+# For a deprecated python signature, it should follow the deprecated python arg order.
+# TODO: This is to keep the same byte-for-byte result as the old codegen - maybe unnecessary?
+
+
+def dispatch_lambda_args(
+    ps: PythonSignature, f: NativeFunction, symint: bool = True
+) -> tuple[DispatchLambdaArgument, ...]:
+    if isinstance(ps, PythonSignatureDeprecated):
+        schema = ps.deprecated_schema
+    else:
+        schema = f.func
+
+    # Start with cpp arguments - the dispatch lambda signature always includes 'self'
+    cpp_args = cpp.arguments(
+        arguments=schema.arguments,
+        faithful=False,
+        symint=symint,
+        method=False,
+        cpp_no_default_args=f.cpp_no_default_args,
+    )
+    out_args: set[str] = {a.name for a in schema.arguments.out}
+
+    # Convert from cpp argument to lambda argument
+    def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
+        type_str = cpp_arg.type
+        is_out_arg = cpp_arg.name in out_args
+        if ps.method and cpp_arg.name == "self":
+            # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
+            type_str = "const at::Tensor &"
+        else:
+            # For other cases we need to prevent dangling refs to temps (unless it's
+            # an unpacked scattered output)
+            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
+            # TODO: avoid this special handling?
+            ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
+            if ensure_temp_safe:
+                type_str = {
+                    "at::Tensor &": "at::Tensor",
+                }.get(type_str, type_str)
+        return DispatchLambdaArgument(
+            name=cpp_arg.name,
+            type_str=type_str,
+            is_out_arg=is_out_arg,
+        )
+
+    return tuple(map(dispatch_lambda_arg, cpp_args))
+
+
+# [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
+# it's enough to just extend the list here. Before you do this, make sure
+# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
+SUPPORTED_RETURN_TYPES = {
+    "at::Tensor",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple",
+    "::std::tuple>",
+    "::std::vector",
+    # Needed for flash attention forw/backward
+    "::std::tuple",
+    "at::Scalar",
+    "bool",
+    "int64_t",
+    "void*",
+    "void",
+    "at::QScheme",
+    "double",
+    "at::IntArrayRef",
+    "at::ScalarType",
+    "at::Stream",
+}
+
+
+def dispatch_lambda_return_str(f: NativeFunction) -> str:
+    # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
+    # because the dispatch lambdas take mutable arguments *by value*, not
+    # by reference. If you then return a reference to such an argument, you
+    # will now have a pointer to a dangling stack entry. Not good.
+    #
+    # You want:
+    #
+    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
+    #                                            ^^^^^^
+    #
+    # *not*
+    #
+    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
+    #                                            ^^^^^^^
+    #
+    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
+    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
+    # mutable reference to temporary.  Maybe we could assign it to a
+    # variable itself.)
+    returns_without_annotation = tuple(
+        Return(r.name, r.type, None) for r in f.func.returns
+    )
+    return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
+    if return_str not in SUPPORTED_RETURN_TYPES:
+        raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
+    return return_str
+
+
+def cpp_dispatch_target(f: NativeFunction) -> str:
+    symint = f.func.has_symint()
+    name = cpp.name(f.func, symint_overload=symint)
+    if Variant.method in f.variants:
+        return f"self.{name}"
+    if Variant.function in f.variants:
+        if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
+            namespace = "torch"
+        else:
+            namespace = "at"
+        return f"{namespace}::{name}"
+    raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
+
+
+def cpp_dispatch_exprs(
+    f: NativeFunction,
+    *,
+    python_signature: PythonSignature | None = None,
+) -> tuple[str, ...]:
+    cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
+
+    exprs: tuple[str, ...] = ()
+    if not isinstance(python_signature, PythonSignatureDeprecated):
+        # By default the exprs are consistent with the C++ signature.
+        exprs = tuple(a.name for a in cpp_args)
+    else:
+        # For deprecated python signature we may need fill in some constants.
+        exprs = tuple(
+            filter(
+                lambda n: n != "out" or f.func.is_out_fn(),
+                python_signature.deprecated_args_exprs,
+            )
+        )
+
+    if Variant.method in f.variants:
+        exprs = tuple(filter("self".__ne__, exprs))
+
+    return exprs
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                     Python / C++ Args Binding
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# We explicitly enumerate the PythonArgParser unpacking methods for all
+# supported types. This might be more verbose than necessary, partially
+# because of the irregularity of unpacking method naming, partially
+# because we want to mimic the old codegen behavior - to reject
+# unexpected and/or unsupported cases which the old codegen rejects.
+# For certain cases it is intentionally more restrictive than necessary,
+# e.g. it doesn't accept a doublelist with a definite size.
+def arg_parser_unpack_method(
+    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
+) -> str:
+    has_default_init = default_init is not None
+    if has_default_init and str(t) not in (
+        "ScalarType?",
+        "ScalarType",
+        "Device",
+        "Device?",
+        "Layout",
+        "Layout?",
+        "bool",
+        "bool?",
+    ):
+        raise RuntimeError(f"type '{t}' does not support unpacking with a default")
+
+    if isinstance(t, BaseType):
+        if t.name in [
+            BaseTy.Tensor,
+            BaseTy.Stream,
+            BaseTy.Storage,
+            BaseTy.Scalar,
+            BaseTy.Dimname,
+        ]:
+            # These unpack methods line up with their schema names
+            return t.name.name.lower()
+        elif t.name == BaseTy.ScalarType:
+            return "scalartypeWithDefault" if has_default_init else "scalartype"
+        elif t.name == BaseTy.Device:
+            return "deviceWithDefault" if has_default_init else "device"
+        elif t.name == BaseTy.DeviceIndex:
+            return "toInt64"
+        elif t.name == BaseTy.int:
+            return "toInt64"
+        elif t.name == BaseTy.SymInt:
+            return "toSymInt" if symint else "toInt64"
+        elif t.name == BaseTy.bool:
+            return "toBoolWithDefault" if has_default_init else "toBool"
+        elif t.name == BaseTy.float:
+            return "toDouble"
+        elif t.name == BaseTy.str:
+            return "stringView"
+        elif t.name == BaseTy.Layout:
+            return "layoutWithDefault" if has_default_init else "layout"
+        elif t.name == BaseTy.MemoryFormat:
+            return "memoryformat"
+
+    elif isinstance(t, OptionalType):
+        if str(t.elem) == "Tensor":
+            return "optionalTensor"
+        elif str(t.elem) == "Generator":
+            return "generator"
+        elif str(t.elem) == "Dimname[]":
+            return "toDimnameListOptional"
+        elif not has_default_init and default in (
+            None,
+            "None",
+            "::std::nullopt",
+            "std::nullopt",
+        ):
+            # If default is None: append 'Optional' to elem's unpacking method
+            return (
+                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
+            )
+        else:
+            # Otherwise, load as underlying type with default
+            return arg_parser_unpack_method(
+                t.elem, default, default_init, symint=symint
+            )
+
+    elif isinstance(t, ListType):
+        if str(t.elem) == "Tensor":
+            # accept and use definite size
+            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
+        elif str(t.elem) == "Tensor?":
+            return "list_of_optional_tensors"
+        elif str(t.elem) == "Dimname":
+            # accept definite size
+            return "dimnamelist"
+        elif str(t.elem) == "int":
+            # accept definite size
+            return "intlist"
+        elif str(t.elem) == "float":
+            return "doublelist"
+        elif str(t.elem) == "SymInt":
+            # accept definite size
+            return "symintlist" if symint else "intlist"
+        elif str(t.elem) == "Scalar":
+            return "scalarlist"
+    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
+
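+
+# A minimal usage sketch (illustrative only; not part of upstream torchgen). It
+# shows which PythonArgParser accessor gets selected for a few argument types.
+def _example_arg_parser_unpack_method() -> None:
+    assert arg_parser_unpack_method(BaseType(BaseTy.Tensor), None, None) == "tensor"
+    int_list = ListType(elem=BaseType(BaseTy.int), size=None)
+    assert arg_parser_unpack_method(int_list, None, None) == "intlist"
+    # An optional type with a 'None' default appends 'Optional' to the elem's method.
+    optional_int = OptionalType(BaseType(BaseTy.int))
+    assert arg_parser_unpack_method(optional_int, "None", None) == "toInt64Optional"
+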
+
+# Return RHS expression for python argument using PythonArgParser output.
+# e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
+def arg_parser_output_expr(
+    arg_index: int, a: PythonArgument, *, symint: bool = True
+) -> PythonArgParserOutputExpr:
+    has_default = a.default_init is not None
+    unpack_method = arg_parser_unpack_method(
+        t=a.type, default=a.default, default_init=a.default_init, symint=symint
+    )
+    default = f", {a.default_init}" if has_default else ""
+    expr = f"_r.{unpack_method}({arg_index}{default})"
+
+    return PythonArgParserOutputExpr(
+        name=a.name,
+        expr=expr,
+        index=arg_index,
+        argument=a,
+    )
+
+
+# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
+def arg_parser_output_exprs(
+    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
+) -> dict[str, PythonArgParserOutputExpr]:
+    return {
+        e.name: e
+        for i, a in enumerate(ps.arguments())
+        for e in (arg_parser_output_expr(i, a, symint=symint),)
+    }
+
+
+# argument name to type for scattered tensor options fields
+TENSOR_OPTIONS_FIELDS = {
+    "dtype": "ScalarType?",
+    "device": "Device?",
+    "layout": "Layout?",
+    "pin_memory": "bool?",
+    "requires_grad": "bool?",
+}
+
+
+# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
+def dispatch_lambda_exprs(
+    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
+) -> DispatchLambdaArgumentExprs:
+    # This method binds 'arg_parser_outputs' and 'lambda_args' by producing
+    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
+    # outputs.
+    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
+    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
+    inits: list[str] = []
+    lambda_args_exprs: dict[str, str] = {}
+
+    has_toptions = has_tensor_options(f)
+
+    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
+    for a in ps.arguments(skip_tensor_options=True):
+        name = a.name
+        arg_parser_expr = arg_parser_outputs[a.name].expr
+
+        if has_toptions and name == "self":
+            # TODO: why does this need to be a special case?
+            inits.extend(
+                [
+                    f"auto self = {arg_parser_expr};",
+                ]
+            )
+            lambda_args_exprs[name] = name
+        elif (
+            isinstance(a, PythonOutArgument)
+            and len(a.outputs) > 1
+            and f.func.is_out_fn()
+        ):
+            inits.extend(
+                [
+                    f"auto out = {arg_parser_expr};",
+                ]
+            )
+            for i, out_arg in enumerate(a.outputs):
+                lambda_args_exprs[out_arg.name] = f"out[{i}]"
+        elif str(a.type) == "Dimname[]?":
+            # [old codegen]
+            # TODO: make this part of something more general, or get rid of it.
+            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
+            # optional<vector<T>>, which cannot be implicitly converted to
+            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
+            inits.extend(
+                [
+                    f"auto __{name} = {arg_parser_expr};",
+                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
+                ]
+            )
+            lambda_args_exprs[name] = name
+        else:
+            # default case - directly using PythonArgParser output expr
+            lambda_args_exprs[name] = arg_parser_expr
+
+    # method's self is passed directly to python binding, rather than parsed
+    if ps.method:
+        lambda_args_exprs["self"] = "self"
+
+    # 2. special packing/checking for TensorOptions.
+    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
+    if has_toptions:
+        if f.func.is_out_fn():
+            raise RuntimeError(f"{f.func}: tensor options with output arg")
+        for a in ps.tensor_options_args:
+            if a.name not in TENSOR_OPTIONS_FIELDS:
+                raise RuntimeError(
+                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
+                )
+            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
+                raise RuntimeError(
+                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
+                )
+        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
+            raise RuntimeError(
+                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
+            )
+
+        inits.append(
+            f"""\
+const auto options = TensorOptions()
+    .dtype({arg_parser_outputs["dtype"].expr})
+    .device({arg_parser_outputs["device"].expr})
+    .layout({arg_parser_outputs["layout"].expr})
+    .requires_grad({arg_parser_outputs["requires_grad"].expr})
+    .pinned_memory({arg_parser_outputs["pin_memory"].expr});
+torch::utils::maybe_initialize_device(options);
+"""
+        )
+        lambda_args_exprs["options"] = "options"
+
+    # 3. special case - access scattered TensorOptions fields without packing
+    # TODO: maybe move to the generator side as it's not related to binding.
+    if not has_toptions and tensor_options_args_names:
+        if "dtype" in tensor_options_args_names:
+            # we're an output-arg variant, check these args against output tensor
+            if not f.func.is_out_fn():
+                raise RuntimeError(
+                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
+                )
+            if not all(a in tensor_options_args_names for a in ("layout", "device")):
+                raise RuntimeError(
+                    f"{f.func}: incomplete tensor options for output check"
+                )
+
+            inits.append(
+                f"""\
+check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
+                       {arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
+                       {arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
+"""
+            )
+        # we'll set requires_grad on outgoing tensor
+        if "requires_grad" not in tensor_options_args_names:
+            raise RuntimeError(
+                f'{f.func}: expected "requires_grad" in tensor_options_args, but found [{tensor_options_args_names}]'
+            )
+
+    return DispatchLambdaArgumentExprs(
+        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
+        inits=inits,
+    )
diff --git a/phivenv/Lib/site-packages/torchgen/api/structured.py b/phivenv/Lib/site-packages/torchgen/api/structured.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bf47d5f4e5c7581530a78f68da264dbc9caa10d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/structured.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from typing_extensions import assert_never
+
+from torchgen.api import cpp
+from torchgen.api.types import (
+    ArgName,
+    ArrayRefCType,
+    BaseCType,
+    Binding,
+    ConstRefCType,
+    dimnameListT,
+    intArrayRefT,
+    iOptTensorListRefT,
+    iTensorListRefT,
+    NamedCType,
+    OptionalCType,
+    optionalIntArrayRefT,
+    optionalScalarRefT,
+    optionalTensorRefT,
+    scalarT,
+    tensorT,
+)
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    ListType,
+    NativeFunctionsGroup,
+    OptionalType,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+
+
+# This file describes the translation of JIT schema to the structured functions API.
+# This is similar to the native API, but a number of historical problems with the
+# native API have been fixed.
+
+
+# Translation of types occurring in JIT arguments to a C++ argument type.
+# NB: For now, mutable doesn't do anything; but it could if we make
+# some more nominal types
+def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
+    # If it's a value type, do the value type translation
+    # NB: structured kernels ALWAYS have symint off, since they involve actual
+    # kernels that require real ints.  The one exception is the
+    # CompositeExplicitAutograd and the meta function (which could
+    # hypothetically be SymInt), but for simplicity we plan for these to just
+    # be handled in Python
+    r = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
+    if r is not None:
+        return r
+
+    if isinstance(t, BaseType):
+        if t.name == BaseTy.Tensor:
+            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
+        elif t.name == BaseTy.Scalar:
+            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+        else:
+            raise AssertionError(f"base type should have been value type {t}")
+    elif isinstance(t, OptionalType):
+        if t.elem == BaseType(BaseTy.Tensor):
+            return NamedCType(binds, BaseCType(optionalTensorRefT))
+        elif t.elem == BaseType(BaseTy.Scalar):
+            return NamedCType(binds, BaseCType(optionalScalarRefT))
+        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
+            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
+        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+        return NamedCType(binds, OptionalCType(elem.type))
+    elif isinstance(t, ListType):
+        if t.elem == BaseType(BaseTy.Tensor):
+            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
+        elif t.elem == OptionalType(BaseType(BaseTy.Tensor)):
+            return NamedCType(binds, BaseCType(iOptTensorListRefT))
+        # TODO: delete these special cases; see torchgen.api.cpp--these
+        # must be changed in tandem, but there are problems; see
+        # https://github.com/pytorch/pytorch/pull/51485
+        elif str(t.elem) == "int":
+            return NamedCType(binds, BaseCType(intArrayRefT))
+        elif str(t.elem) == "Dimname":
+            return NamedCType(binds, BaseCType(dimnameListT))
+        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
+        return NamedCType(binds, ArrayRefCType(elem.type))
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
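+
+# A minimal usage sketch (illustrative only; not part of upstream torchgen). It
+# shows the C++ types picked for a plain Tensor argument and for an optional
+# Tensor argument in the structured kernels API.
+def _example_argumenttype_type() -> None:
+    assert argumenttype_type(
+        BaseType(BaseTy.Tensor), mutable=False, binds="x"
+    ) == NamedCType("x", ConstRefCType(BaseCType(tensorT)))
+    assert argumenttype_type(
+        OptionalType(BaseType(BaseTy.Tensor)), mutable=False, binds="x"
+    ) == NamedCType("x", BaseCType(optionalTensorRefT))
+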
+
+def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
+    return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
+
+
+# returns_type intentionally omitted, because structured kernels never "return";
+# instead, they always indirectly report their outputs (in the case of a meta
+# function, by calling set_output; in the case of an impl function, by writing
+# directly into the provided out argument).
+
+
+# Structured kernels are never defaulted
+def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
+    if isinstance(a, Argument):
+        return [
+            Binding(
+                nctype=argument_type(a, binds=a.name),
+                name=a.name,
+                default=None,
+                argument=a,
+            )
+        ]
+    elif isinstance(a, SelfArgument):
+        return argument(a.argument)
+    elif isinstance(a, TensorOptionsArguments):
+        raise AssertionError("structured kernels don't support TensorOptions yet")
+    else:
+        assert_never(a)
+
+
+def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+
+    if g.out.precomputed:
+        # A list of parameters for the impl function with
+        # certain parameters replaced with precomputed counterparts
+        # as specified in native_functions.yaml.
+        non_out_args_replaced: list[
+            Argument | TensorOptionsArguments | SelfArgument
+        ] = []
+        for a in g.out.func.arguments.non_out:
+            if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
+                # If a is in precompute.replace, append the parameters
+                # that should replace it onto non_out_args_replaced.
+                non_out_args_replaced.extend(g.out.precomputed.replace[a.name])
+            else:
+                # If not, push a as it is.
+                non_out_args_replaced.append(a)
+
+        args.extend(non_out_args_replaced)
+        # g.out.precomputed.add is the list of parameters that are added
+        # without replacement after the non out args and just before the out args
+        args.extend(g.out.precomputed.add)
+    else:
+        args.extend(g.out.func.arguments.non_out)
+
+    args.extend(g.out.func.arguments.out)
+    return [r for arg in args for r in argument(arg)]
+
+
+def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    args.extend(g.functional.func.arguments.non_out)
+    return [r for arg in args for r in argument(arg)]
+
+
+def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    args.extend(g.out.func.arguments.out)
+    return [r for arg in args for r in argument(arg)]
diff --git a/phivenv/Lib/site-packages/torchgen/api/translate.py b/phivenv/Lib/site-packages/torchgen/api/translate.py
new file mode 100644
index 0000000000000000000000000000000000000000..e119e436bd296e0d592da40b2682a3e328eb9cad
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/translate.py
@@ -0,0 +1,437 @@
+from __future__ import annotations
+
+from typing import NoReturn, TYPE_CHECKING
+
+from torchgen.api.types import (
+    ArrayRefCType,
+    BaseCType,
+    Binding,
+    boolT,
+    ConstRefCType,
+    deviceT,
+    Expr,
+    intArrayRefT,
+    iOptTensorListRefT,
+    layoutT,
+    ListCType,
+    longT,
+    memoryFormatT,
+    MutRefCType,
+    NamedCType,
+    opmath_t,
+    OptionalCType,
+    optionalIntArrayRefT,
+    optionalScalarRefT,
+    optionalSymIntArrayRefT,
+    optionalTensorRefT,
+    scalar_t,
+    scalarT,
+    scalarTypeT,
+    SpecialArgName,
+    symIntArrayRefT,
+    SymIntT,
+    tensorOptionsT,
+    tensorT,
+    VectorCType,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# This file implements a small program synthesis engine that implements
+# conversions from one API to another.
+#
+# The key data type in this file is NamedCType, short for Named C++ semantic type.  A NamedCType
+# represents a C++ type, plus semantic information about what it represents.
+# For example, consider the argument "bool pin_memory"; its normal C++ type is
+# "bool", but its C++ semantic type also keeps track that this represents a
+# "pin_memory"; you can't just use a random other boolean in a context where you
+# need a "pin_memory"!
+#
+# The translator takes a list of needed NamedCTypes, and then figures out how
+# to construct expressions with these NamedCTypes from the given bindings.  Many
+# of these expressions are trivial (I need a Tensor other; there's a Tensor
+# other in scope); others are more nontrivial and may require packing/unpacking.
+# Some examples of non-trivial action:
+#
+#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
+#     in the context, instead, "options" is, and you need to extract
+#     it from there.  (Gather)
+#
+#   - Need the "context" binding?  Well, maybe "context" isn't available
+#     in the context, and you need to construct it from "dtype", "device",
+#     etc.  (Scatter)
+#
+#   - Need the "memory_format" binding?  Well, actually, it's available
+#     from both "memory_format" and "options", so you had better make sure
+#     they are consistent.  (Join)
+
+options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
+
+out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
+
+longVec_ctype = VectorCType(BaseCType(longT))
+longSymVec_ctype = VectorCType(BaseCType(SymIntT))
+optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
+optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
+optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
+
+
+class UnsatError(RuntimeError):
+    pass
+
+
+# Given a set of in-scope bindings and a set of target bindings, synthesize
+# a list of expressions that uses only the in-scope bindings and has all of
+# the types of the goals (a small usage sketch appears after this function).
+# You may want to use this function if
+# you're generating code for a function like:
+#
+#   void f({args}) {
+#     g({exprs}); // g is a different API
+#   }
+#
+# and you need to generate "exprs".
+#
+# Typically, a list of Bindings is convenient to get (you usually call something
+# like arguments() to get them); but technically you only need less information:
+# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
+# 'goals', an (ordered) list of NamedCType goals is sufficient.  If you are doing
+# something more complicated, e.g., tracking the set of bindings in a context,
+# you may find using these smaller types more convenient.
+def translate(
+    bindings: Sequence[Expr | Binding],
+    goals: Sequence[NamedCType | Binding],
+    *,
+    method: bool = False,
+    allow_expensive_conversions: bool = False,
+) -> list[Expr]:
+    binding_exprs: list[Expr] = []
+    for b in bindings:
+        if isinstance(b, Binding):
+            binding_exprs.append(
+                Expr(
+                    expr=b.name,
+                    type=b.nctype,
+                )
+            )
+        else:
+            binding_exprs.append(b)
+
+    goal_ctypes: list[NamedCType] = []
+    for g in goals:
+        if isinstance(g, Binding):
+            goal_ctypes.append(g.nctype)
+        else:
+            goal_ctypes.append(g)
+
+    # Add all the bindings to the context
+    ctx: dict[NamedCType, str] = {}
+    for b in binding_exprs:
+        ctx[b.type] = b.expr
+
+        # While we're at it, do some simple forward inference, looking through
+        # constructors.
+        #
+        # NB: When should you do forward inference versus backward inference?
+        # The general idea:
+        #
+        #   - Backward inference WHEN the goal gets smaller
+        #   - Forward inference WHEN the hypothesis gets smaller
+        #
+        # This helps ensure termination: backward inference starts with a goal
+        # and tries to make it simpler and simpler until it's trivial; if the
+        # goal can grow in size, we blow up to a really huge goal size.
+        # Similarly, with forward inference we take hypotheses and decompose
+        # them into simpler hypotheses; if hypotheses could expand in size,
+        # we also have potential nontermination.  (In the code below, forward
+        # inference is only ever carried out at a single step, but you could
+        # imagine repeated application of forward inference being profitable.)
+        #
+        # A good starting point in the literature for exploring more about proof
+        # search are these lecture notes
+        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
+        #
+        # TODO: My kingdom for a pattern matcher
+        # https://www.python.org/dev/peps/pep-0634/
+        #
+        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
+        # Fix this by implementing some sort of sharing so that if multiple
+        # goals share the same expression, we only compute it once.  This seems
+        # to matter in practice as compiler is often unwilling to CSE nontrivial
+        # expressions like scalar.to()
+        t = b.type
+        if (
+            isinstance(t, ConstRefCType)
+            and isinstance(t.elem, OptionalCType)
+            and isinstance(t.elem.elem, BaseCType)
+            and str(t.elem.elem.type) == "at::Tensor"
+        ):
+            ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = (
+                f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
+            )
+
+        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
+            ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = (
+                f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
+            )
+
+        if t.type == ConstRefCType(BaseCType(scalarT)):
+            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
+
+        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
+            ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = (
+                f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
+            )
+
+        if t.type == BaseCType(scalar_t):
+            ctx[NamedCType(t.name, BaseCType(opmath_t))] = (
+                f"static_cast<opmath_t>({b.expr})"
+            )
+
+        # [Note: IOptTensorListRef]
+        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
+            ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = (
+                f"at::IOptTensorListRef({b.expr})"
+            )
+
+    # Add implicit bindings if the generated code is inside a Tensor method
+    if method:
+        ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = (
+            "const_cast<Tensor&>(*this)"
+        )
+        ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = (
+            "const_cast<Tensor&>(*this)"
+        )
+        # This is better!  Byte-for-byte compat
+        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
+
+    def unsat(goal: NamedCType) -> NoReturn:
+        ctx_desc = "\n".join(
+            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
+        )
+        raise UnsatError(
+            f"""
+Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
+When I failed, the following bindings were available in the context:
+
+{ctx_desc}
+
+This probably means there is a missing rule in the rules of torchgen.api.translate.
+Check this module for more information.
+"""
+        )
+
+    # A shitty backtracking search implementation.  It's shitty because it
+    # does backtracking via stack (bad idea!) and for the most part tries to
+    # avoid backtracking.  In particular, if
+    # direct=True, we won't try to do any fancy synthesis, just trivial
+    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
+    # existing rules in this function simply try to solve immediately,
+    # and bail if things don't work out.
+    def solve(goal: NamedCType, *, direct: bool) -> str:
+        def direct_solve(goal: NamedCType) -> str:
+            return solve(goal, direct=True)
+
+        if goal in ctx:
+            # Trivial
+            return ctx[goal]
+
+        # const & is satisfied with mutable &
+        if isinstance(goal.type, ConstRefCType):
+            try:
+                # WARNING: not strictly decreasing; be careful not
+                # to add a direct conversion that satisfies
+                # mutable& with const&
+                return solve(
+                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
+                )
+            except UnsatError:
+                pass
+
+        # mutable & is satisfied with value
+        if isinstance(goal.type, MutRefCType):
+            try:
+                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
+            except UnsatError:
+                pass
+
+        # TODO: These are referentially equal, shouldn't have to do this;
+        # ensuring we don't use type synonym IntArrayRef in codegen would
+        # help
+        if goal.type == ArrayRefCType(BaseCType(longT)):
+            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
+
+        if direct:
+            unsat(goal)
+
+        # For now, all of these rules are mutually exclusive.
+        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
+            memory_format = direct_solve(
+                NamedCType(
+                    SpecialArgName.possibly_redundant_memory_format,
+                    OptionalCType(BaseCType(memoryFormatT)),
+                )
+            )
+            # No need to join "memory_format" and "options" if the target API takes "options" directly.
+            # Otherwise it will cause the redundant memory_format error.
+            if options_ctype in goal_ctypes:
+                return memory_format
+            try:
+                options = direct_solve(options_ctype)
+                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
+            except UnsatError:
+                return memory_format
+        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
+            dtype = direct_solve(
+                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
+            )
+            pin_memory = direct_solve(
+                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
+            )
+            device = direct_solve(
+                NamedCType("device", OptionalCType(BaseCType(deviceT)))
+            )
+            layout = direct_solve(
+                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
+            )
+            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
+
+        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
+            try:
+                options = direct_solve(options_ctype)
+                return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
+            except UnsatError:
+                out_tensor = direct_solve(out_tensor_ctype)
+                return f"{out_tensor}.scalar_type()"
+
+        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
+            try:
+                options = direct_solve(options_ctype)
+                return f"{options}.layout_opt()"
+            except UnsatError:
+                out_tensor = direct_solve(out_tensor_ctype)
+                return f"{out_tensor}.layout()"
+
+        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
+            try:
+                options = direct_solve(options_ctype)
+                return f"{options}.device_opt()"
+            except UnsatError:
+                out_tensor = direct_solve(out_tensor_ctype)
+                return f"{out_tensor}.device()"
+
+        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
+            try:
+                options = direct_solve(options_ctype)
+                return f"{options}.pinned_memory_opt()"
+            except UnsatError:
+                # If we're calling a factory op from its out= variant,
+                # we don't actually care about the value of pin_memory.
+                out_tensor = direct_solve(out_tensor_ctype)
+                return "::std::nullopt"
+
+        # We can always do translations from value types to reference types, like vector -> IntArrayRef
+        elif goal.type == BaseCType(intArrayRefT):
+            try:
+                return direct_solve(NamedCType(goal.name, longVec_ctype))
+            except UnsatError:
+                # We can also go SymIntArrayRef -> IntArrayRef
+                symIntArrayRef_type = direct_solve(
+                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
+                )
+                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
+        elif goal.type == BaseCType(symIntArrayRefT):
+            try:
+                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
+                return f"c10::fromIntArrayRefSlow({r})"
+            except UnsatError:
+                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
+        elif goal.type == BaseCType(SymIntT):
+            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
+        elif goal.type == OptionalCType(BaseCType(SymIntT)):
+            argname = direct_solve(
+                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
+            )
+            return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
+        elif goal.type == BaseCType(longT):
+            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
+            return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
+        elif goal.type == OptionalCType(BaseCType(longT)):
+            argname = direct_solve(
+                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
+            )
+            return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
+        elif goal.type == BaseCType(optionalIntArrayRefT):
+            try:
+                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
+            except UnsatError:
+                argname = direct_solve(
+                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
+                )
+                return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
+        elif goal.type == BaseCType(optionalSymIntArrayRefT):
+            # TODO: You might also want to solve this from longSymVec_ctype or
+            # an optional version of it
+            argname = direct_solve(
+                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
+            )
+            return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
+        elif goal.type == BaseCType(optionalScalarRefT):
+            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
+        elif goal.type == BaseCType(optionalTensorRefT):
+            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
+
+        # Note [translation from C++ reference to value types]
+        # The below cases are all for when we have an argument with a reference type,
+        # and a corresponding goal with a value type.
+        # These are needed when we populate the inputs to a lambda capture and we need
+        # to guarantee the lifetime of each captured argument.
+        # We guard it with an explicit kwarg because converting to a value type is expensive
+        # (e.g. it is O(n) to convert an IntArrayRef to a vector),
+        # so the caller of translate() should be explicit that they need it.
+        if allow_expensive_conversions:
+            if goal.type == VectorCType(BaseCType(longT)):
+                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
+                argname = direct_solve(intArrayRef_ctype)
+                return f"{argname}.vec()"
+            if goal.type == VectorCType(BaseCType(SymIntT)):
+                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
+                argname = direct_solve(symIntArrayRef_ctype)
+                return f"{argname}.vec()"
+            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
+                optionalIntArrayRef_ctype = NamedCType(
+                    goal.name, BaseCType(optionalIntArrayRefT)
+                )
+                argname = direct_solve(optionalIntArrayRef_ctype)
+                return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
+            elif goal.type == OptionalCType(BaseCType(scalarT)):
+                optionalScalarRef_ctype = NamedCType(
+                    goal.name, BaseCType(optionalScalarRefT)
+                )
+                argname = direct_solve(optionalScalarRef_ctype)
+                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
+            elif goal.type == OptionalCType(BaseCType(tensorT)):
+                optionalTensorRef_ctype = NamedCType(
+                    goal.name, BaseCType(optionalTensorRefT)
+                )
+                argname = direct_solve(optionalTensorRef_ctype)
+                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
+            # Technically, we also need to handle cases of C++ containers holding reference types.
+            # But there currently aren't any ops that require lambda capture codegen
+            # with arguments like ::std::vector<::std::optional<at::Tensor>>.
+            # If that changes, we'll have to add the translation here.
+
+        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
+        # We could probably generalize this to non-tensor types too.
+        if goal.type == MutRefCType(BaseCType(tensorT)):
+            const_ref_tensor_ctype = NamedCType(
+                goal.name, ConstRefCType(BaseCType(tensorT))
+            )
+            argname = direct_solve(const_ref_tensor_ctype)
+            return f"const_cast({argname})"
+
+        unsat(goal)
+
+    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
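+
+# Illustrative note (not part of the upstream module): the expensive-conversion
+# path above is what lambda-capture codegen relies on when it needs reference
+# types captured by value.  For example, given a context binding of C++ type
+# `at::IntArrayRef size`, a goal of `::std::vector<int64_t> size` is only
+# satisfiable when translate(..., allow_expensive_conversions=True) is passed,
+# and the solved expression is `size.vec()`.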
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/__init__.py b/phivenv/Lib/site-packages/torchgen/api/types/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b48163dbd3706644e162f899267674b86bf053e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/types/__init__.py
@@ -0,0 +1,5 @@
+from torchgen.api.types.types import *
+from torchgen.api.types.types_base import *
+
+
+from torchgen.api.types.signatures import *  # usort: skip
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc204bf3ec6a9790a74702c4d215302ce9027ca6
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3559e889570244db69d80fda8d463957d5bf3d33
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/signatures.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3191d7dbf728304fb2d4d0bd18640d03934c6235
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72f24d4f87b7c0bfede4f67c676c4658bd58d831
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/api/types/__pycache__/types_base.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/signatures.py b/phivenv/Lib/site-packages/torchgen/api/types/signatures.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64f5afc281b1bee7a6484cde4cacaef717d5f71
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/types/signatures.py
@@ -0,0 +1,428 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from torchgen.api.types.types_base import Binding, CType, Expr
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator, Sequence
+
+    from torchgen.model import (
+        BackendIndex,
+        FunctionSchema,
+        NativeFunction,
+        NativeFunctionsGroup,
+        NativeFunctionsViewGroup,
+    )
+
+
+@dataclass(frozen=True)
+class CppSignature:
+    """
+    A CppSignature represents a single overload in the C++ API.  For
+    any given function schema, there may be multiple CppSignatures
+    corresponding to it, based on how we desugar to C++.  See also
+    CppSignatureGroup.
+    """
+
+    # The schema this signature is derived from
+    func: FunctionSchema
+
+    # Is this a C++ signature for a method, i.e. Tensor::my_op(...)?
+    method: bool
+
+    # Is this a faithful C++ signature (i.e. following the JIT schema) or a convenience API
+    # (i.e. with a potential TensorOptions argument and out arguments in the front)
+    faithful: bool
+
+    # Is this a symint C++ signature.  For BC reasons, functions that take
+    # SymInts still present as int64_t in C++, and the SymInt variant is
+    # offered at a different overload name
+    #
+    # NB: If a function RETURNS a SymInt, this is ALWAYS false
+    symint: bool
+
+    # The set of C++ arguments which should not have defaults applied to them
+    cpp_no_default_args: set[str]
+
+    # Is this a fallback C++ binding?  Fallback bindings are enabled by
+    # manual_cpp_binding: True and are alternate, non-public API that
+    # lets manual C++ binding implementers access the binding that would
+    # have been automatically generated
+    fallback_binding: bool = False
+
+    # Return the unpacked argument structure of this signature,
+    # discarding information about which arguments are semantically
+    # related to each other.
+    def arguments(self) -> Sequence[Binding]:
+        return cpp.arguments(
+            self.func.arguments,
+            faithful=self.faithful,
+            symint=self.symint,
+            method=self.method,
+            cpp_no_default_args=self.cpp_no_default_args,
+        )
+
+    def name(self, *, suppress_symint_suffix: bool = False) -> str:
+        n = cpp.name(
+            self.func,
+            faithful_name_for_out_overloads=self.faithful,
+            symint_overload=False if suppress_symint_suffix else self.symint,
+        )
+        if self.fallback_binding:
+            n = f"__dispatch_{n}"
+        return n
+
+    # Render the C++ declaration for this signature
+    def decl(
+        self,
+        *,
+        name: str | None = None,
+        prefix: str = "",
+        is_redispatching_fn: bool = False,
+        suppress_symint_suffix: bool = False,
+    ) -> str:
+        returns_type = cpp.returns_type(
+            self.func.returns, symint=self.symint
+        ).cpp_type()
+        cpp_args = [a.decl() for a in self.arguments()]
+        if is_redispatching_fn:
+            cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
+        cpp_args_str = ", ".join(cpp_args)
+        if name is None:
+            name = prefix + self.name(suppress_symint_suffix=suppress_symint_suffix)
+        return f"{returns_type} {name}({cpp_args_str})"
+
+    # Render the C++ definition for this signature, not including
+    # the body (with curly braces)
+    def defn(
+        self,
+        *,
+        name: str | None = None,
+        prefix: str = "",
+        is_redispatching_fn: bool = False,
+    ) -> str:
+        returns_type = cpp.returns_type(
+            self.func.returns, symint=self.symint
+        ).cpp_type()
+        cpp_args = [a.defn() for a in self.arguments()]
+        if is_redispatching_fn:
+            cpp_args = ["c10::DispatchKeySet dispatchKeySet"] + cpp_args
+        cpp_args_str = ", ".join(cpp_args)
+        if name is None:
+            name = prefix + self.name()
+        return f"{returns_type} {name}({cpp_args_str})"
+
+    def ptr_type(self) -> str:
+        args_types_str = ", ".join(a.type for a in self.arguments())
+        return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_types_str})"
+
+    # Return the C++ function type, e.g., something like int(bool)
+    def type(self) -> str:
+        args_types_str = ", ".join(a.type for a in self.arguments())
+        return f"{cpp.returns_type(self.func.returns, symint=self.symint).cpp_type()} ({args_types_str})"
+
+
+# Represents group of all CppSignatures associated with a
+# FunctionSchema.  Right now, that's the regular, user-visible
+# signature, as well as a "faithful" signature which doesn't
+# have grouping.
+@dataclass(frozen=True)
+class CppSignatureGroup:
+    func: FunctionSchema
+    signature: CppSignature
+    faithful_signature: CppSignature | None
+    symint_signature: CppSignature | None
+    symint_faithful_signature: CppSignature | None
+
+    def most_faithful_signature(self) -> CppSignature:
+        if self.faithful_signature:
+            return self.faithful_signature
+        else:
+            return self.signature
+
+    def signatures(self, *, symint: bool = True) -> Iterator[CppSignature]:
+        yield self.signature
+        if self.faithful_signature:
+            yield self.faithful_signature
+        if symint:
+            if self.symint_signature:
+                yield self.symint_signature
+            if self.symint_faithful_signature:
+                yield self.symint_faithful_signature
+
+    @staticmethod
+    def from_native_function(
+        f: NativeFunction, *, method: bool, fallback_binding: bool = False
+    ) -> CppSignatureGroup:
+        func = f.func
+
+        def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
+            return CppSignature(
+                func=func,
+                faithful=faithful,
+                symint=symint,
+                method=method,
+                fallback_binding=fallback_binding,
+                cpp_no_default_args=f.cpp_no_default_args,
+            )
+
+        def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
+            faithful_signature: CppSignature | None = None
+            if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
+                faithful_signature = make_sig(faithful=True, symint=symint)
+            signature = make_sig(faithful=False, symint=symint)
+            return signature, faithful_signature
+
+        signature, faithful_signature = make_sigs(symint=False)
+        symint_signature: CppSignature | None = None
+        symint_faithful_signature: CppSignature | None = None
+        if func.has_symint():
+            symint_signature, symint_faithful_signature = make_sigs(symint=True)
+
+        return CppSignatureGroup(
+            func=func,
+            signature=signature,
+            faithful_signature=faithful_signature,
+            symint_signature=symint_signature,
+            symint_faithful_signature=symint_faithful_signature,
+        )
+
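+# Illustrative usage sketch (an assumed call pattern, not a definitive API
+# contract): given a parsed NativeFunction `f`, a C++ API declaration can be
+# rendered with
+#
+#     sig_group = CppSignatureGroup.from_native_function(f, method=False)
+#     decl = sig_group.most_faithful_signature().decl()
+#
+# which yields a string of the form "<returns type> <name>(<cpp args>)".
+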
+
+@dataclass(frozen=True)
+class DispatcherSignature:
+    # The schema this signature is derived from
+    func: FunctionSchema
+
+    # Allows you to prepend an arbitrary prefix to the signature name.
+    # This is useful for parts of the codegen that generate wrappers around kernels,
+    # and need to avoid naming collisions.
+    prefix: str = ""
+
+    symint: bool = True
+
+    def arguments(self) -> list[Binding]:
+        return dispatcher.arguments(self.func, symint=self.symint)
+
+    def name(self) -> str:
+        return self.prefix + dispatcher.name(self.func)
+
+    def decl(self, name: str | None = None) -> str:
+        args_str = ", ".join(a.decl() for a in self.arguments())
+        if name is None:
+            name = self.name()
+        return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+    def defn(
+        self, name: str | None = None, *, is_redispatching_fn: bool = False
+    ) -> str:
+        args = [a.defn() for a in self.arguments()]
+        if is_redispatching_fn:
+            args = ["c10::DispatchKeySet dispatchKeySet"] + args
+        args_str = ", ".join(args)
+        if name is None:
+            name = self.name()
+        return f"{self.returns_type().cpp_type()} {name}({args_str})"
+
+    def exprs(self) -> list[Expr]:
+        return [Expr(a.name, a.nctype) for a in self.arguments()]
+
+    def returns_type(self) -> CType:
+        return dispatcher.returns_type(self.func.returns, symint=self.symint)
+
+    def ptr_type(self) -> str:
+        dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
+        return f"{self.returns_type().cpp_type()} (*)({dispatcher_args_types_str})"
+
+    # Return the C++ function type, e.g., something like int(bool)
+    def type(self) -> str:
+        dispatcher_args_types_str = ", ".join(a.type for a in self.arguments())
+        return f"{self.returns_type().cpp_type()} ({dispatcher_args_types_str})"
+
+    @staticmethod
+    def from_schema(
+        func: FunctionSchema, *, prefix: str = "", symint: bool = True
+    ) -> DispatcherSignature:
+        return DispatcherSignature(func, prefix, symint)
+
+
+@dataclass(frozen=True)
+class NativeSignature:
+    # The schema this signature is derived from
+    func: FunctionSchema
+
+    symint: bool
+
+    prefix: str = ""
+
+    def name(self) -> str:
+        return self.prefix + native.name(self.func)
+
+    def decl(self, name: str | None = None) -> str:
+        args_str = ", ".join(a.decl() for a in self.arguments())
+        if name is None:
+            name = self.name()
+        return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
+
+    def defn(self, name: str | None = None) -> str:
+        args_str = ", ".join(a.defn() for a in self.arguments())
+        if name is None:
+            name = self.name()
+        return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
+
+    def ptr_type(self) -> str:
+        # don't include defaults in type signature!
+        args_str = ", ".join(a.defn() for a in self.arguments())
+        return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
+
+    def arguments(self) -> list[Binding]:
+        return native.arguments(self.func, symint=self.symint)
+
+    def returns_type(self) -> CType:
+        return native.returns_type(self.func.returns, symint=self.symint)
+
+    def dispatcher_exprs(self) -> list[Expr]:
+        return translate.translate(
+            self.arguments(), dispatcher.arguments(self.func), method=False
+        )
+
+
+@dataclass(frozen=True)
+class ViewInverseSignature:
+    g: NativeFunctionsViewGroup
+
+    def name(self) -> str:
+        return functionalization.reverse_name(self.g.view, include_namespace=False)
+
+    def decl(self) -> str:
+        return_type = functionalization.returns_type(self.g.view.func)
+        decls = [
+            a.decl()
+            for a in functionalization.inner_arguments(
+                self.g.view.func, is_reverse=True
+            )
+        ]
+        return f"static {return_type.cpp_type()} {self.name()}({', '.join(decls)});"
+
+
+@dataclass(frozen=True)
+class FunctionalizationLambda:
+    g: NativeFunctionsViewGroup
+
+    # are we generating the forward lambda or the reverse lambda?
+    is_reverse: bool
+
+    def captures(self) -> list[Expr]:
+        # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
+        # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
+        # and plumb it into the lambda.
+        outer_ctx = dispatcher.arguments(self.g.view.func) + [
+            functionalization.reapply_views_binding,
+            functionalization.inverse_return_mode_binding,
+        ]
+        capture_bindings = functionalization.capture_arguments(
+            self.g.view.func, is_reverse=self.is_reverse
+        )
+        # allow_expensive_conversions is set because we want to convert
+        # some reference types (IntArrayRef) to value types (vector).
+        capture_exprs = translate.translate(
+            outer_ctx, capture_bindings, method=False, allow_expensive_conversions=True
+        )
+        return capture_exprs
+
+    def decl(self) -> str:
+        return_type = functionalization.returns_type(self.g.view.func)
+        capture_str = ", ".join(
+            f"{val.type.name} = {val.expr}" for val in self.captures()
+        )
+        decls = [
+            a.decl()
+            for a in functionalization.outer_arguments(is_reverse=self.is_reverse)
+        ]
+        return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
+
+    def inner_call(self, *, reapply_views: bool | None = None) -> str:
+        inner_call_name = functionalization.name(
+            self.g,
+            is_reverse=self.is_reverse,
+            include_namespace=True,
+            reapply_views=reapply_views,
+        )
+
+        arg_ctx = functionalization.outer_arguments(is_reverse=self.is_reverse)
+        capture_ctx = functionalization.capture_arguments(
+            self.g.view.func, is_reverse=self.is_reverse
+        )
+        full_ctx = arg_ctx + capture_ctx
+
+        assert self.g.view_copy is not None
+        call_bindings = functionalization.inner_arguments(
+            self.g.view_copy.func, is_reverse=self.is_reverse
+        )
+        maybe_index = functionalization.inner_call_index(self.g.view_copy.func)
+        call_exprs = [
+            e.expr for e in translate.translate(full_ctx, call_bindings, method=False)
+        ]
+        if not self.is_reverse and maybe_index is not None:
+            return f"{inner_call_name}({', '.join(call_exprs)})[{maybe_index.name}];"
+        else:
+            return f"{inner_call_name}({', '.join(call_exprs)});"
+
+    @staticmethod
+    def from_func(
+        g: NativeFunctionsViewGroup, *, is_reverse: bool
+    ) -> FunctionalizationLambda:
+        return FunctionalizationLambda(g, is_reverse)
+
+
+@dataclass(frozen=True)
+class StructuredImplSignature:
+    g: NativeFunctionsGroup
+    name: str
+
+    def defn(self, name: str | None = None) -> str:
+        args_str = ", ".join(a.defn() for a in self.arguments())
+        return f"TORCH_IMPL_FUNC({self.name})({args_str})"
+
+    def arguments(self) -> list[Binding]:
+        return structured.impl_arguments(self.g)
+
+
+# Helper functions
+
+
+def kernel_signature(
+    f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
+) -> NativeSignature | DispatcherSignature:
+    # Note [External Backends Follow Dispatcher API]
+    # Kernel signatures for in-tree backends follow the "native" API,
+    # while kernels for out-of-tree backends follow the dispatcher API.
+    # See the comments in `native.py` for details, but historically there have been
+    # some small differences in schema convention between them and the Dispatcher API.
+    # Any differences that require translating between the two will result in a runtime cost,
+    # so we'd like to keep the differences as small as possible.
+    # With external backends, we'd like to enforce that they write their kernels with schemas
+    # that match the Dispatcher API directly, if they can.
+    meta = backend_index.get_kernel(f)
+    symint = meta is not None and meta.supports_symint()
+    if symint:
+        assert f.func.has_symint(), (
+            f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema"
+        )
+    if backend_index.external:
+        return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint)
+    else:
+        return NativeSignature(f.func, prefix=prefix, symint=symint)
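+
+
+# Illustrative note (assumptions about typical backends, not an exhaustive rule):
+# for an in-tree backend index such as the CPU one, kernel_signature(f, backend_index)
+# returns a NativeSignature, while an external backend (backend_index.external is True)
+# gets a DispatcherSignature, so out-of-tree kernels can be written directly against
+# the dispatcher calling convention (see the note above).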
+
+
+# Functions only, no types
+from torchgen.api import (
+    cpp,
+    dispatcher,
+    functionalization,
+    native,
+    structured,
+    translate,
+)
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/types.py b/phivenv/Lib/site-packages/torchgen/api/types/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e5feaff723f1a262981ba0386fe6119ad36d5ee
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/types/types.py
@@ -0,0 +1,181 @@
+"""
+Where should I add a new type? `types_base.py` vs `types.py`
+
+This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
+
+`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
+
+The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
+contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
+if we want to generate code for another C++ library.
+
+Add new types to `types.py` if these types are ATen/c10 related.
+Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from torchgen.api.types.types_base import (
+    BaseCppType,
+    BaseCType,
+    boolT,
+    byteT,
+    charT,
+    CType,
+    doubleT,
+    floatT,
+    int32T,
+    longT,
+    shortT,
+)
+from torchgen.model import BaseTy, ScalarType
+
+
+TENSOR_LIST_LIKE_CTYPES = [
+    "at::TensorList",
+    "const c10::List<::std::optional> &",
+    "const at::ITensorListRef &",
+]
+
+
+halfT = BaseCppType("at", "Half")
+complexHalfT = BaseCppType(
+    "c10", "complex<c10::Half>"
+)  # stuffing template param here is an abuse
+complexFloatT = BaseCppType("c10", "complex<float>")
+complexDoubleT = BaseCppType("c10", "complex<double>")
+bfloat16T = BaseCppType("at", "BFloat16")
+float8_e5m2T = BaseCppType("at", "Float8_e5m2")
+float8_e5m2fnuzT = BaseCppType("at", "Float8_e5m2fnuz")
+float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
+float8_e4m3fnuzT = BaseCppType("at", "Float8_e4m3fnuz")
+float8_e8m0fnuT = BaseCppType("at", "Float8_e8m0fnu")
+stringT = BaseCppType("c10", "string_view")
+generatorT = BaseCppType("at", "Generator")
+scalarTypeT = BaseCppType("at", "ScalarType")
+tensorT = BaseCppType("at", "Tensor")
+optionalTensorRefT = BaseCppType("at", "OptionalTensorRef")
+tensorListT = BaseCppType("at", "TensorList")
+iTensorListRefT = BaseCppType("at", "ITensorListRef")
+iOptTensorListRefT = BaseCppType("at", "IOptTensorListRef")
+dimnameT = BaseCppType("at", "Dimname")
+dimnameListT = BaseCppType("at", "DimnameList")
+dimVectorT = BaseCppType("at", "DimVector")
+layoutT = BaseCppType("at", "Layout")
+deviceT = BaseCppType("at", "Device")
+deviceIndexT = BaseCppType("at", "DeviceIndex")
+scalarT = BaseCppType("at", "Scalar")
+optionalScalarRefT = BaseCppType("at", "OptionalScalarRef")
+memoryFormatT = BaseCppType("at", "MemoryFormat")
+qschemeT = BaseCppType("at", "QScheme")
+storageT = BaseCppType("at", "Storage")
+streamT = BaseCppType("at", "Stream")
+intArrayRefT = BaseCppType("at", "IntArrayRef")
+optionalIntArrayRefT = BaseCppType("at", "OptionalIntArrayRef")
+optionalSymIntArrayRefT = BaseCppType("at", "OptionalSymIntArrayRef")
+tensorOptionsT = BaseCppType("at", "TensorOptions")
+typeAndSizeT = BaseCppType("torch::autograd::generated", "TypeAndSize")
+tensorGeometryT = BaseCppType("at", "TensorGeometry")
+SymIntT = BaseCppType("c10", "SymInt")
+symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
+
+# Types representing template parameters.  Technically, we probably shouldn't
+# represent them this way in codegen, but it was pretty convenient.
+scalar_t = BaseCppType("", "scalar_t")
+opmath_t = BaseCppType("", "opmath_t")
+
+ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
+    ScalarType.Byte: byteT,
+    ScalarType.Char: charT,
+    ScalarType.Short: shortT,
+    ScalarType.Int: int32T,
+    ScalarType.Long: longT,
+    ScalarType.Half: halfT,
+    ScalarType.Float: floatT,
+    ScalarType.Double: doubleT,
+    ScalarType.ComplexHalf: complexHalfT,
+    ScalarType.ComplexFloat: complexFloatT,
+    ScalarType.ComplexDouble: complexDoubleT,
+    ScalarType.Bool: boolT,
+    ScalarType.Float8_e5m2: float8_e5m2T,
+    ScalarType.Float8_e5m2fnuz: float8_e5m2fnuzT,
+    ScalarType.Float8_e4m3fn: float8_e4m3fnT,
+    ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
+    ScalarType.Float8_e8m0fnu: float8_e8m0fnuT,
+}
+
+BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
+    BaseTy.int: longT,
+    BaseTy.float: doubleT,
+    BaseTy.bool: boolT,
+    BaseTy.str: stringT,
+    BaseTy.Generator: generatorT,
+    BaseTy.ScalarType: scalarTypeT,
+    BaseTy.Tensor: tensorT,
+    BaseTy.Dimname: dimnameT,
+    BaseTy.DimVector: dimVectorT,
+    BaseTy.Layout: layoutT,
+    BaseTy.Device: deviceT,
+    BaseTy.DeviceIndex: deviceIndexT,
+    BaseTy.Scalar: scalarT,
+    BaseTy.MemoryFormat: memoryFormatT,
+    BaseTy.QScheme: qschemeT,
+    BaseTy.Storage: storageT,
+    BaseTy.Stream: streamT,
+    BaseTy.SymInt: SymIntT,
+}
+
+# CTypes encode C++ type structure as needed for translation.
+
+
+@dataclass(frozen=True)
+class OptionalCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"::std::optional<{self.elem.cpp_type()}>"
+
+    def remove_const_ref(self) -> CType:
+        return OptionalCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ListCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"c10::List<{self.elem.cpp_type()}>"
+
+    def remove_const_ref(self) -> CType:
+        return ListCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ArrayRefCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"at::ArrayRef<{self.elem.cpp_type()}>"
+
+    def remove_const_ref(self) -> CType:
+        return ArrayRefCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class VectorizedCType(CType):
+    # This template is explicitly specialized, so the only valid
+    # elems are those we have specializations for (e.g., float, double, ...)
+    # scalar_t is also a common argument here (when we are codegen in
+    # a templated context)
+    elem: BaseCType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        return f"at::vec::Vectorized<{self.elem.cpp_type()}>"
+
+    def remove_const_ref(self) -> CType:
+        return self
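+
+
+# Illustrative sketch of how these CTypes render (not part of the upstream module):
+#
+#     OptionalCType(BaseCType(tensorT)).cpp_type()             # "::std::optional<at::Tensor>"
+#     ListCType(OptionalCType(BaseCType(tensorT))).cpp_type()  # "c10::List<::std::optional<at::Tensor>>"
+#     VectorizedCType(BaseCType(floatT)).cpp_type()            # "at::vec::Vectorized<float>"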
diff --git a/phivenv/Lib/site-packages/torchgen/api/types/types_base.py b/phivenv/Lib/site-packages/torchgen/api/types/types_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..81b100db5f2b2a53839c72a20617a37371d7ca9d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/types/types_base.py
@@ -0,0 +1,238 @@
+"""
+Where should I add a new type? `types_base.py` vs `types.py`
+
+This file defines data model classes for torchgen typing system, as well as some base types such as int32_t.
+
+`types.py` defines ATen Tensor type and some c10 types, along with signatures that use these types.
+
+The difference between these two files is that `types_base.py` should be implementation-agnostic, meaning it shouldn't
+contain any type definition that is tied to a specific C++ library (e.g., ATen), so that it can be easily reused
+if we want to generate code for another C++ library.
+
+Add new types to `types.py` if these types are ATen/c10 related.
+Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import TYPE_CHECKING, Union
+
+
+if TYPE_CHECKING:
+    from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
+
+
+# An ArgName is just the str name of the argument in schema;
+# but in some special circumstances, we may add a little extra
+# context.  The Enum SpecialArgName covers all of these cases;
+# grep for their construction sites to see when they can occur.
+
+
+class SpecialArgName(Enum):
+    possibly_redundant_memory_format = auto()
+
+
+ArgName = Union[str, SpecialArgName]
+
+
+# This class shouldn't be created directly; instead, use/create one of the singletons below.
+@dataclass(frozen=True)
+class BaseCppType:
+    ns: str | None
+    name: str
+
+    def __str__(self) -> str:
+        if self.ns is None or self.ns == "":
+            return self.name
+        return f"{self.ns}::{self.name}"
+
+
+# The set of all non-templated, valid, fully-qualified names of C++ types that are used in the codegen.
+# Templated types get their own dataclass, mainly to make namespace parsing easier.
+byteT = BaseCppType("", "uint8_t")
+charT = BaseCppType("", "int8_t")
+shortT = BaseCppType("", "int16_t")
+# It would be more symmetric for this to be called intT, but it is easy to mix
+# this up with JIT int (which is int64_t in C++), so we intentionally don't
+# define intT to make it obvious when you've stuffed it up
+int32T = BaseCppType("", "int32_t")
+longT = BaseCppType("", "int64_t")
+doubleT = BaseCppType("", "double")
+floatT = BaseCppType("", "float")
+boolT = BaseCppType("", "bool")
+voidT = BaseCppType("", "void")
+
+
+class CType(ABC):
+    @abstractmethod
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_const_ref(self) -> CType:
+        return self
+
+
+@dataclass(frozen=True)
+class BaseCType(CType):
+    type: BaseCppType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        return str(self.type)
+
+    def remove_const_ref(self) -> CType:
+        return self
+
+
+@dataclass(frozen=True)
+class ConstRefCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        if strip_ref:
+            return self.elem.cpp_type(strip_ref=strip_ref)
+        return f"const {self.elem.cpp_type()} &"
+
+    def remove_const_ref(self) -> CType:
+        return self.elem.remove_const_ref()
+
+
+@dataclass(frozen=True)
+class VectorCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"::std::vector<{self.elem.cpp_type()}>"
+
+    def remove_const_ref(self) -> CType:
+        return VectorCType(self.elem.remove_const_ref())
+
+
+@dataclass(frozen=True)
+class ArrayCType(CType):
+    elem: CType
+    size: int
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"::std::array<{self.elem.cpp_type()},{self.size}>"
+
+    def remove_const_ref(self) -> CType:
+        return ArrayCType(self.elem.remove_const_ref(), self.size)
+
+
+@dataclass(frozen=True)
+class TupleCType(CType):
+    elems: list[CType]
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        # Do not pass `strip_ref` recursively.
+        return f"::std::tuple<{','.join([e.cpp_type() for e in self.elems])}>"
+
+    def remove_const_ref(self) -> CType:
+        return TupleCType([e.remove_const_ref() for e in self.elems])
+
+
+@dataclass(frozen=True)
+class MutRefCType(CType):
+    elem: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        if strip_ref:
+            return self.elem.cpp_type(strip_ref=strip_ref)
+        return f"{self.elem.cpp_type()} &"
+
+    def remove_const_ref(self) -> CType:
+        return self.elem.remove_const_ref()
+
+
+# A NamedCType is short for Named C++ semantic type.  A NamedCType represents a C++ type, plus
+# semantic information about what it represents.  For example, consider the
+# argument "bool pin_memory"; its normal C++ type is "bool", but its C++
+# semantic type also keeps track that this represents a "pin_memory"; you can't
+# just use a random other boolean in a context where you need a "pin_memory"!
+#
+
+
+@dataclass(frozen=True)
+class NamedCType:
+    name: ArgName
+    type: CType
+
+    def cpp_type(self, *, strip_ref: bool = False) -> str:
+        return self.type.cpp_type(strip_ref=strip_ref)
+
+    def remove_const_ref(self) -> NamedCType:
+        return NamedCType(self.name, self.type.remove_const_ref())
+
+    def with_name(self, name: str) -> NamedCType:
+        return NamedCType(name, self.type)
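+
+# Illustrative note (not upstream documentation): the C++ type
+# "const ::std::vector<int64_t> &" for an argument named "sizes" can be modeled as
+#
+#     NamedCType("sizes", ConstRefCType(VectorCType(BaseCType(longT))))
+#
+# whose .cpp_type() renders exactly that string, and whose .remove_const_ref()
+# peels the const reference back off, leaving "::std::vector<int64_t>".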
+
+
+# A binding represents any C++ binding site for a formal parameter.
+# We don't distinguish between binding sites for different APIs;
+# instead, all of the important distinctions are encoded in CType,
+# which you can use to figure out if a given Binding is appropriate
+# for use in another context.  (See torchgen.api.translate)
+
+
+@dataclass(frozen=True)
+class Binding:
+    name: str
+    nctype: NamedCType
+    argument: Argument | TensorOptionsArguments | SelfArgument
+    # TODO: maybe don't represent default here
+    default: str | None = None
+
+    def rename(self, name: str) -> Binding:
+        return Binding(
+            name=name,
+            nctype=self.nctype,
+            argument=self.argument,
+            default=self.default,
+        )
+
+    @property
+    def type(self) -> str:
+        return self.nctype.cpp_type()
+
+    def no_default(self) -> Binding:
+        return Binding(
+            name=self.name,
+            nctype=self.nctype,
+            default=None,
+            argument=self.argument,
+        )
+
+    def decl(self, *, func_ptr_cast: bool = False) -> str:
+        mb_default = ""
+        if self.default is not None:
+            mb_default = f"={self.default}"
+
+        # casting only needs to know the type
+        if func_ptr_cast:
+            return f"{self.type}"
+        else:
+            return f"{self.type} {self.name}{mb_default}"
+
+    def defn(self) -> str:
+        return f"{self.type} {self.name}"
+
+    def with_name(self, name: str) -> Binding:
+        return Binding(
+            name=name, nctype=self.nctype, argument=self.argument, default=self.default
+        )
+
+
+# An Expr is a C++ expression.  It has a C++ string representing its syntax,
+# as well as a CType saying what it provides.
+
+
+@dataclass(frozen=True)
+class Expr:
+    expr: str
+    type: NamedCType
diff --git a/phivenv/Lib/site-packages/torchgen/api/ufunc.py b/phivenv/Lib/site-packages/torchgen/api/ufunc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bfc0de5fe6fd3e8f56eb3753bdb8031d8e5c659
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/ufunc.py
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+import torchgen.api.types as api_types
+from torchgen.api import cpp, structured
+from torchgen.api.types import (
+    ArgName,
+    BaseCppType,
+    BaseCType,
+    Binding,
+    ConstRefCType,
+    CType,
+    NamedCType,
+    scalarT,
+)
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    DispatchKey,
+    FunctionSchema,
+    NativeFunctionsGroup,
+    Type,
+)
+
+
+def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
+    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
+    return f"ufunc_{func.name.name}_{dispatch_key}"
+
+
+def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
+    return schema_kernel_name(g.out.func, dispatch_key)
+
+
+# Tensors are omitted (as they are stored in TensorIterator); everything else is
+# passed along (technically, we can pass tensors along too, it just wastes
+# argument registers)
+#
+# NB: used for CPU only
+def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
+    # Dispatch stubs are always plain ints
+    r = cpp.valuetype_type(t, binds=binds, symint=False)
+    if r is not None:
+        return r
+
+    if t == BaseType(BaseTy.Scalar):
+        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
+    elif t == BaseType(BaseTy.Tensor):
+        return None
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
+    if scalar_t == api_types.scalar_t:
+        return api_types.opmath_t
+    raise NotImplementedError
+
+
+# NB: Tensors in constructor are stored in opmath_t, not scalar_t
+# because Tensor in constructor = it's a scalar tensor partially applied =
+# it can be higher precision and we want to compute in that higher precision
+#
+# NB: CUDA only
+def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
+    r = cpp.valuetype_type(t, binds=binds, symint=False)
+    if r is not None:
+        return r
+
+    if t == BaseType(BaseTy.Scalar):
+        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
+    elif t == BaseType(BaseTy.Tensor):
+        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# Only Tensors ever get passed directly to operator()
+#
+# NB: CUDA only
+# (Actually, this works for CPU too)
+def ufunctor_apply_type(
+    t: Type, *, binds: ArgName, scalar_t: BaseCppType
+) -> NamedCType:
+    if t == BaseType(BaseTy.Tensor):
+        return NamedCType(binds, BaseCType(scalar_t))
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+# The actual ufunc template function the user writes.  Everything here
+# is done in the computation type.  compute_t is opmath_t in CUDA and scalar_t
+# in CPU
+def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
+    r = cpp.valuetype_type(t, binds=binds, symint=False)
+    if r is not None:
+        return r
+
+    if t == BaseType(BaseTy.Scalar):
+        return NamedCType(binds, compute_t)
+    elif t == BaseType(BaseTy.Tensor):
+        return NamedCType(binds, compute_t)
+    else:
+        raise AssertionError(f"unrecognized type {repr(t)}")
+
+
+def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
+    return Binding(
+        nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
+        name=a.name,
+        default=None,
+        argument=a,
+    )
+
+
+def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
+    return Binding(
+        nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
+        name=a.name,
+        default=None,
+        argument=a,
+    )
+
+
+def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
+    return Binding(
+        nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
+        name=a.name,
+        default=None,
+        argument=a,
+    )
+
+
+@dataclass(frozen=True)
+class UfunctorBindings:
+    ctor: list[Binding]
+    apply: list[Binding]
+
+
+# ufunctors are a CUDA-only concept representing functors that take some of
+# their arguments on a host-side constructor, and the rest in the device-side
+# apply.  E.g.,
+#
+# template <typename scalar_t>
+# struct CUDAFunctorOnSelf_add {
+#   using opmath_t = at::opmath_type<scalar_t>;
+#   opmath_t other_;
+#   opmath_t alpha_;
+#   CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
+#   __device__ scalar_t operator()(scalar_t self) {
+#     return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
+#   }
+# };
+#
+# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
+# to the operator() definition
+def ufunctor_arguments(
+    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
+) -> UfunctorBindings:
+    ctor = []
+    apply = []
+    for a in g.functional.func.arguments.flat_non_out:
+        if a.type.is_tensor_like():
+            if scalar_tensor_idx == 0:
+                # put it in the ctor anyway
+                ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
+                scalar_tensor_idx = None
+            else:
+                if scalar_tensor_idx is not None:
+                    scalar_tensor_idx -= 1
+                apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
+        else:
+            ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
+    assert scalar_tensor_idx is None
+    return UfunctorBindings(ctor=ctor, apply=apply)
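+
+
+# Illustrative note (tying back to the CUDAFunctorOnSelf_add example above): for
+# add(Tensor self, Tensor other, *, Scalar alpha) with scalar_tensor_idx pointing
+# at `other`, ufunctor_arguments() puts `other` and `alpha` into the ctor bindings
+# and `self` into the apply bindings.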
+
+
+# ufuncs are the inner loop template functions that you wrote in ufunc/add.h
+# which do the actual computation in question.  E.g.,
+#
+# template <typename T>
+# C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
+#   return self + alpha * other;
+# }
+#
+# In this file, we refer to T as compute_t which is bound by caller
+def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
+    return [
+        ufunc_argument(a, compute_t=compute_t)
+        for a in g.functional.func.arguments.flat_non_out
+    ]
+
+
+# Stubs are the DispatchStub trampolines that CPU kernels use to get to their
+# vectorized versions.  E.g.,
+#
+# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
+# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
+def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    # stubs drop all tensor arguments (they are implicit in the TensorIterator
+    # argument) and keep everything else
+    return [
+        r
+        for a in g.out.func.arguments.flat_non_out
+        if not a.type.is_tensor_like()
+        for r in structured.argument(a)
+    ]
diff --git a/phivenv/Lib/site-packages/torchgen/api/unboxing.py b/phivenv/Lib/site-packages/torchgen/api/unboxing.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1249692aeeabbb95f4c4ecb6a8dfdd9015901eb
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/api/unboxing.py
@@ -0,0 +1,241 @@
+from __future__ import annotations
+
+from torchgen.api import cpp
+from torchgen.api.types import Binding, CppSignatureGroup, CType
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Type,
+)
+
+
+# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
+# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is
+# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the
+# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register
+# a fixed set of operators known at compile time and thus can save some time in the runtime initialization phase.
+#
+# Here's an example on how the codegen works:
+#
+# - Function Schema (source of truth)
+#
+#      aten::empty.names(int[] size, *, Dimname[]? names,
+#                        ScalarType? dtype=None, Layout? layout=None,
+#                        Device? device=None, bool? pin_memory=None,
+#                        MemoryFormat? memory_format=None) -> Tensor
+# - Argument Conversion
+#       Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
+#    - int[] size
+#        ```cpp
+#           const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
+#
+#           std::vector<int64_t> size_vec;
+#           for (c10::IValue size_elem: size_list_in) {
+#               int64_t size_base = size_elem.to<int64_t>();
+#               size_vec.push_back(size_base);
+#           }
+#           at::ArrayRef<int64_t> size_list_out(size_vec);
+#                                 ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
+#                                                   Will be passed to unboxed kernel.
+#       ```
+#    - Dimname[]? names
+#       ```cpp
+#           ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
+#           ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
+#           if (names_opt.has_value()) {
+#                         ~~~~~~~~~~~ <-- Unwrapping optional shell
+#               const c10::IValue names_opt_in = names_opt.value();
+#               const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
+#
+#               std::vector<at::Dimname> names_vec;
+#               for (c10::IValue names_elem: names_list_in) {
+#                                ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
+#                   at::Dimname names_base = names_elem.to<at::Dimname>();
+#                   names_vec.push_back(names_base);
+#               }
+#               at::ArrayRef<at::Dimname> names_list_out(names_vec);
+#
+#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
+#           } else {
+#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
+#           }
+#       ```
+#    - ScalarType? dtype (similarly for the rest of the arguments)
+#       ```cpp
+#           ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
+#           ::std::optional<at::ScalarType> dtype_opt_out;
+#           if (dtype_opt.has_value()) {
+#               const c10::IValue dtype_opt_in = dtype_opt.value();
+#               at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
+#                                                        ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
+#                                                                                 directly using ".to<T>()" API.
+#               dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
+#           } else {
+#               dtype_opt_out = ::std::optional<at::ScalarType>();
+#           }
+#       ```
+#
+# - Unboxed Kernel Call
+#   ```cpp
+#       auto result_ = torch::empty(
+#           size_list_out,
+#           names_opt_out,
+#           options,
+#           memory_format_opt_out
+#       );
+#   ```
+#
+# - Push Result Back to Stack
+#   ```cpp
+#       drop(stack, 7);
+#       pack(stack, std::move(result_));
+#   ```
+connector = "\n\t"
+
+
+# Return unboxing function name for a NativeFunction
+def name(f: NativeFunction) -> str:
+    return f.func.name.unambiguous_name()
+
+
+# Convert all the arguments in a NativeFunction to C++ code
+def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
+    # we need the 'self' argument so method needs to be False
+    args = (
+        CppSignatureGroup.from_native_function(f, method=False)
+        .most_faithful_signature()
+        .arguments()
+    )
+    code_list = [
+        f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));"
+        for i in range(len(args))
+    ] + [""]
+    binding_list = []
+    for arg in args:
+        # expecting only Argument
+        if not isinstance(arg.argument, Argument):
+            raise Exception(  # noqa: TRY002
+                f"Unexpected argument type, expecting `Argument` but got {arg}"
+            )
+        argument: Argument = arg.argument
+        unboxed_name, _, code, decl = argumenttype_ivalue_convert(
+            argument.type,
+            argument.name,
+            mutable=argument.is_write,
+        )
+        code_list.extend(decl)
+        code_list.extend(code)
+        binding_list.append(arg.with_name(unboxed_name))
+    return binding_list, code_list
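+
+
+# Illustrative note (hypothetical two-argument op, not real generated output):
+# convert_arguments first emits a stack-peeking preamble along the lines of
+#
+#     c10::IValue self = std::move(peek(stack, 0, 2));
+#     c10::IValue other = std::move(peek(stack, 1, 2));
+#
+# and then appends the per-argument conversion code produced by
+# argumenttype_ivalue_convert below.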
+
+
+# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
+# (1) the C++ code necessary to unbox the argument
+# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
+def argumenttype_ivalue_convert(
+    t: Type, arg_name: str, *, mutable: bool = False
+) -> tuple[str, CType, list[str], list[str]]:
+    # Unboxing is for mobile, which doesn't care about SymInts
+    ctype = cpp.argumenttype_type(
+        t=t, mutable=mutable, binds=arg_name, symint=False
+    ).type
+
+    if isinstance(t, BaseType):
+        out_name = f"{arg_name}_base"
+        code, decl = _gen_code_base_type(
+            arg_name=arg_name, out_name=out_name, ctype=ctype
+        )
+    elif isinstance(t, OptionalType):
+        out_name = f"{arg_name}_opt_out"
+        code, decl = _gen_code_optional_type(
+            arg_name=arg_name,
+            out_name=out_name,
+            t=t,
+            ctype=ctype,
+        )
+    elif isinstance(t, ListType):
+        out_name = f"{arg_name}_list_out"
+        code, decl = _gen_code_list_type(
+            arg_name=arg_name,
+            out_name=out_name,
+            t=t,
+            ctype=ctype,
+        )
+    else:
+        raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")  # noqa: TRY002
+    return out_name, ctype, code, decl
+
+
+def _gen_code_base_type(
+    arg_name: str, out_name: str, ctype: CType
+) -> tuple[list[str], list[str]]:
+    return [
+        f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+    ], []
+
+
+def _gen_code_optional_type(
+    arg_name: str, out_name: str, t: OptionalType, ctype: CType
+) -> tuple[list[str], list[str]]:
+    in_name = f"{arg_name}_opt_in"
+    res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
+    return (
+        f"""
+auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
+{ctype.cpp_type(strip_ref=True)} {out_name};
+if ({arg_name}_opt.has_value()) {{
+    const c10::IValue {in_name} = {arg_name}_opt.value();
+    {connector.join(res_code)}
+    {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name});
+}} else {{
+    {out_name} = {ctype.cpp_type(strip_ref=True)}();
+}}
+        """.split("\n"),
+        decl,
+    )
+
+
+def _gen_code_list_type(
+    arg_name: str, out_name: str, t: ListType, ctype: CType
+) -> tuple[list[str], list[str]]:
+    in_name = f"{arg_name}_list_in"
+    elem_name = f"{arg_name}_elem"
+    code = [f"const c10::List {in_name} = {arg_name}.toList();"]
+    res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name)
+    # handle list type with size, e.g., bool[4]
+    if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size:
+        code.extend(
+            f"""
+{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name});
+            """.split("\n")
+        )
+    # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>>
+    elif isinstance(t.elem, OptionalType):
+        code.extend(
+            f"""
+{ctype.cpp_type(strip_ref=True)} {out_name};
+for (c10::IValue {elem_name}: {in_name}) {{
+    {connector.join(res_code)}
+    {out_name}.push_back({res_name});
+}}
+            """.split("\n")
+        )
+    else:
+        # use ArrayRef as default.
+        vec_name = arg_name + "_vec"
+        # need to bring vector instantiation out of scope so that ArrayRef has valid data
+        decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};")
+        code.extend(
+            f"""
+for (c10::IValue {elem_name}: {in_name}) {{
+    {connector.join(res_code)}
+    {vec_name}.push_back({res_name});
+}}
+{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
+            """.split("\n")
+        )
+    return code, decl
diff --git a/phivenv/Lib/site-packages/torchgen/code_template.py b/phivenv/Lib/site-packages/torchgen/code_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8a37df8c779f8b418cccad7ff337f57e778da47
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/code_template.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import itertools
+import re
+import textwrap
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping, Sequence
+
+
+# Match $identifier or ${identifier} and replace it with the value in env.
+# If the identifier starts a line (preceded only by whitespace)
+# and its value is a list, then it is treated as
+# block substitution: each element of the list is placed on its own line,
+# indented to that depth.
+# If the identifier is on a line starting with non-whitespace and the value is a list,
+# the elements are comma separated; ${,foo} will insert a comma before the list
+# if the list is not empty, and ${foo,} will insert one after.
+
+
+class CodeTemplate:
+    substitution_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
+    substitution = re.compile(substitution_str, re.MULTILINE)
+
+    pattern: str
+    filename: str
+
+    @staticmethod
+    def from_file(filename: str) -> CodeTemplate:
+        with open(filename) as f:
+            return CodeTemplate(f.read(), filename)
+
+    def __init__(self, pattern: str, filename: str = "") -> None:
+        self.pattern = pattern
+        self.filename = filename
+
+    def substitute(
+        self, env: Mapping[str, object] | None = None, **kwargs: object
+    ) -> str:
+        if env is None:
+            env = {}
+
+        def lookup(v: str) -> object:
+            assert env is not None
+            return kwargs[v] if v in kwargs else env[v]
+
+        def indent_lines(indent: str, v: Sequence[object]) -> str:
+            content = "\n".join(
+                itertools.chain.from_iterable(str(e).splitlines() for e in v)
+            )
+            content = textwrap.indent(content, prefix=indent)
+            # Remove trailing whitespace on each line
+            return "\n".join(map(str.rstrip, content.splitlines())).rstrip()
+
+        def replace(match: re.Match[str]) -> str:
+            indent = match.group(1)
+            key = match.group(2)
+            comma_before = ""
+            comma_after = ""
+            if key[0] == "{":
+                key = key[1:-1]
+                if key[0] == ",":
+                    comma_before = ", "
+                    key = key[1:]
+                if key[-1] == ",":
+                    comma_after = ", "
+                    key = key[:-1]
+            v = lookup(key)
+            if indent is not None:
+                if not isinstance(v, list):
+                    v = [v]
+                return indent_lines(indent, v)
+            elif isinstance(v, list):
+                middle = ", ".join([str(x) for x in v])
+                if len(v) == 0:
+                    return middle
+                return comma_before + middle + comma_after
+            else:
+                return str(v)
+
+        return self.substitution.sub(replace, self.pattern)
+
+
+if __name__ == "__main__":
+    c = CodeTemplate(
+        """\
+    int foo($args) {
+
+        $bar
+            $bar
+        $a+$b
+    }
+    int commatest(int a${,stuff})
+    int notest(int a${,empty,})
+    """
+    )
+    print(
+        c.substitute(
+            args=["hi", 8],
+            bar=["what", 7],
+            a=3,
+            b=4,
+            stuff=["things...", "others"],
+            empty=[],
+        )
+    )
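+    # Expected output of the demo above (a hand-traced sketch: the list-valued
+    # $bar is block-substituted at its indentation, $args/$a/$b are inlined,
+    # ${,stuff} gets a leading comma, and ${,empty,} disappears):
+    #
+    #     int foo(hi, 8) {
+    #
+    #         what
+    #         7
+    #             what
+    #             7
+    #         3+4
+    #     }
+    #     int commatest(int a, things..., others)
+    #     int notest(int a)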
diff --git a/phivenv/Lib/site-packages/torchgen/context.py b/phivenv/Lib/site-packages/torchgen/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..a288d96559a5128ce576ba116778e7c26fb513e9
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/context.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import contextlib
+import functools
+from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
+
+import torchgen.local as local
+from torchgen.model import (
+    BackendIndex,
+    DispatchKey,
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+)
+from torchgen.utils import context, S, T
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+
+# Helper functions for defining generators on things in the model
+
+F = TypeVar(
+    "F",
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+    Union[NativeFunction, NativeFunctionsGroup],
+    Union[NativeFunction, NativeFunctionsViewGroup],
+)
+
+F2 = TypeVar(
+    "F2",
+    NativeFunction,
+    NativeFunctionsGroup,
+    Optional[NativeFunction],
+    bool,
+    str,
+)
+
+F3 = TypeVar("F3", tuple[NativeFunction, Any], list[NativeFunction])
+
+
+@contextlib.contextmanager
+def native_function_manager(
+    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
+) -> Iterator[None]:
+    if isinstance(g, NativeFunctionsGroup):
+        # By default, we associate all errors with structured native functions
+        # with the out variant.  In some cases, it might be better to have
+        # a more specific place to hang things; if so, use
+        # native_function_manager again on the inside
+        f = g.out
+    elif isinstance(g, NativeFunctionsViewGroup):
+        # We associate errors with the view operator
+        f = g.view
+    else:
+        f = g
+    with context(lambda: f"in native_functions.yaml line {f.loc}:\n  {f.func}"):
+        with local.parametrize(
+            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
+            use_ilistref_for_tensor_lists=f.part_of_structured_group,
+        ):
+            yield
+
+
+# Given a function that operates on NativeFunction, wrap it into a new function
+# that sets some appropriate context managers for that native function.
+# YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
+# (you will get an error if we try to access the local variables without having
+# set them).
+def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
+    @functools.wraps(func)
+    def wrapper(f: F) -> T:
+        with native_function_manager(f):
+            return func(f)
+
+    return wrapper
+
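+# Illustrative sketch (the emit_decl helper below is hypothetical, not part of
+# torchgen): wrapping a per-function emitter this way makes any exception raised
+# during code generation point at the corresponding native_functions.yaml entry.
+#
+#     @with_native_function
+#     def emit_decl(f: NativeFunction) -> str:
+#         return f"TORCH_API void {f.func.name}();"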
+
+def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
+    @functools.wraps(func)
+    def wrapper(f: F, f2: F2) -> T:
+        # The first native_function is assumed to be the one with the appropriate context.
+        with native_function_manager(f):
+            return func(f, f2)
+
+    return wrapper
+
+
+def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
+    @functools.wraps(func)
+    def wrapper(slf: S, f: F) -> T:
+        with native_function_manager(f):
+            return func(slf, f)
+
+    return wrapper
+
+
+def method_with_nested_native_function(
+    func: Callable[[S, F3], T],
+) -> Callable[[S, F3], T]:
+    @functools.wraps(func)
+    def wrapper(slf: S, f: F3) -> T:
+        with native_function_manager(f[0]):
+            return func(slf, f)
+
+    return wrapper
+
+
+# Convenience decorator for functions that explicitly take in a BackendIndex,
+# instead of indirectly taking one in as a closure
+def with_native_function_and_index(
+    func: Callable[[F, BackendIndex], T],
+) -> Callable[[F, BackendIndex], T]:
+    @functools.wraps(func)
+    def wrapper(f: F, backend_index: BackendIndex) -> T:
+        with native_function_manager(f):
+            return func(f, backend_index)
+
+    return wrapper
+
+
+# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
+def with_native_function_and_indices(
+    func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
+) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
+    @functools.wraps(func)
+    def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
+        with native_function_manager(f):
+            return func(f, backend_indices)
+
+    return wrapper
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__init__.py b/phivenv/Lib/site-packages/torchgen/dest/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3db7b19d69fbcb8ce74e03dba7a53e85f2b17c20
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/__init__.py
@@ -0,0 +1,19 @@
+from torchgen.dest.lazy_ir import (
+    generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
+    GenLazyIR as GenLazyIR,
+    GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
+    GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
+)
+from torchgen.dest.native_functions import (
+    compute_native_function_declaration as compute_native_function_declaration,
+)
+from torchgen.dest.register_dispatch_key import (
+    gen_registration_headers as gen_registration_headers,
+    gen_registration_helpers as gen_registration_helpers,
+    RegisterDispatchKey as RegisterDispatchKey,
+)
+from torchgen.dest.ufunc import (
+    compute_ufunc_cpu as compute_ufunc_cpu,
+    compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
+    compute_ufunc_cuda as compute_ufunc_cuda,
+)
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..948e91a051c6de8aea78f19a86a3b89ae638f6f8
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..218c7e97883ebc58281924a6cf4acd201c815321
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ir.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..048ef4c14a5ab74f106ce02c6aa4c8a24daebeb8
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/lazy_ts_lowering.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9753c67a10fd1c66f257e6f60648061cde53c0c3
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/native_functions.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d7e7e04b8ef1f98880a040e9c7161746082adfb
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/register_dispatch_key.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..206719de855e82cf240899d48d336dc4ca551bde
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/dest/__pycache__/ufunc.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/dest/lazy_ir.py b/phivenv/Lib/site-packages/torchgen/dest/lazy_ir.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad464b9e20ea112aa917795ac5a08ee6d628bf7b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/lazy_ir.py
@@ -0,0 +1,707 @@
+from __future__ import annotations
+
+import itertools
+from abc import ABC
+from dataclasses import dataclass
+from typing import Any
+
+import torchgen.api.dispatcher as dispatcher
+from torchgen.api.lazy import (
+    getValueT,
+    isValueType,
+    LazyArgument,
+    LazyIrProperties,
+    LazyIrSchema,
+    tensorListValueT,
+)
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    deviceT,
+    DispatcherSignature,
+    kernel_signature,
+    NativeSignature,
+    OptionalCType,
+    VectorCType,
+)
+from torchgen.context import method_with_native_function
+from torchgen.dest.lazy_ts_lowering import ts_lowering_body
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BackendMetadata,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    NativeFunctionsGroup,
+)
+
+
+def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
+    """
+    Given a LazyArgument,
+    generate a c++ string for materializing an rvalue of that arg for passing into
+    a lazy Node constructor.
+    """
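+    # For example (derived from the branches below, names illustrative): a plain
+    # Tensor value argument "self" produces "lazy_self->GetIrValue()", and an
+    # optional Tensor "weight" produces
+    # "lazy_weight ? std::make_optional(lazy_weight->GetIrValue()) : ::std::nullopt".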
+
+    # TODO: Matching on CType seems wrong; should be matching on Type
+    if isValueType(arg.lazy_type):
+        if isinstance(arg.lazy_type, BaseCType):
+            if arg.is_wrapped_scalar:
+                return f"node_{arg.name}"
+            elif arg.lazy_type.type is tensorListValueT:
+                return f"lazy_{arg.name}_tensorlist"
+            elif arg.is_symint_or_list:
+                return f"GetSymIntValue({arg.name})"
+            return f"lazy_{arg.name}->GetIrValue()"
+        elif isinstance(arg.lazy_type, OptionalCType):
+            if arg.is_symint_or_list:
+                # TODO: I don't understand when you should put lazy_ in the name
+                # or not
+                return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
+            elif arg.is_wrapped_scalar:
+                return f"node_{arg.name}"
+            return (
+                f"lazy_{arg.name} ? "
+                f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
+                "::std::nullopt"
+            )
+        else:
+            raise AssertionError(
+                f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
+            )
+    else:
+        # NB: this is here because right now we aren't treating SymInt[] as a
+        # value type; when we do this needs to move above
+        # NB: we cannot test arg.lazy_type as we've already specified it is an
+        # int64_t and so we cannot distinguish between SymInt and int64_t
+        if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
+            BaseTy.SymInt
+        ):
+            if arg.symint:
+                return f"GetSymIntArrayRefValue({arg.name})"
+            else:
+                return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
+        elif isinstance(arg.lazy_type, VectorCType) and isinstance(
+            arg.lazy_type.elem, BaseCType
+        ):
+            return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
+        elif (
+            isinstance(arg.lazy_type, OptionalCType)
+            and isinstance(arg.lazy_type.elem, VectorCType)
+            and isinstance(arg.lazy_type.elem.elem, BaseCType)
+        ):
+            return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
+        else:
+            return f"{arg.name}"
+
+
+def node_ctor_inputs(schema: LazyIrSchema) -> str:
+    """
+    Produce a formatted string with the arguments as passed into the constructor of a node class.
+    """
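+    # e.g. for an op like add(Tensor self, Tensor other, Scalar alpha) this joins to
+    # "lazy_self->GetIrValue(), lazy_other->GetIrValue(), alpha" (illustrative only).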
+    node_ctor_values = [
+        node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
+    ]
+    return ", ".join(node_ctor_values)
+
+
+def gen_fallback_code(
+    schema: LazyIrSchema,
+    sig: DispatcherSignature | NativeSignature,
+    overload_name: str,
+) -> str:
+    """
+    Generate code that falls back to eager conditioned on a predicate
+    """
+    dispatcher_sig = DispatcherSignature.from_schema(schema.func)
+    exprs = translate(sig.arguments(), dispatcher_sig.arguments())
+    fallback_args = ",\n                ".join([a.expr for a in exprs])
+    if len(overload_name):
+        aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
+    else:
+        aten_op_str = f"ATEN_OP({schema.aten_name})"
+    return f"""
+        if (force_eager_fallback({aten_symbol(schema)})) {{
+            return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
+                {fallback_args}
+            );
+        }}
+"""
+
+
+def aten_symbol(schema: LazyIrSchema) -> str:
+    missing_interned_strings = {
+        "sigmoid_backward",
+    }
+    if schema.aten_name in missing_interned_strings:
+        return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
+
+    if not schema.aten_name.startswith("at::"):
+        return f"at::aten::{schema.aten_name}"
+    else:
+        return schema.aten_name
+
+
+# Converts all tensor-like arguments to meta tensors. Returns:
+# (1) a string containing all of the logic that does the conversions.
+# (2) a context, to be used by translate(), with all of the relevant bindings.
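+# For a tensor argument "self", for instance, this emits "auto self_meta = to_meta(self);"
+# and the returned context rebinds the argument to "self_meta" (names illustrative).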
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
+    for arg in sig.arguments():
+        if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
+            unwrapped_name = f"{arg.name}_meta"
+            unwrapped_tensor_args.append(
+                f"auto {unwrapped_name} = to_meta({arg.name});"
+            )
+            context.append(arg.with_name(unwrapped_name))
+        else:
+            context.append(arg)
+    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
+    return unwrap_tensor_args_str, context
+
+
+@dataclass(frozen=True)
+class GenLazyIR(ABC):
+    backend_index: BackendIndex
+    backend_name: str
+    node_base: str
+    use_lazy_shape: bool
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
+        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
+        metadata = self.backend_index.get_kernel(
+            f.functional if isinstance(f, NativeFunctionsGroup) else f
+        )
+        schema = LazyIrSchema(
+            func, symint=metadata is not None and metadata.supports_symint()
+        )
+        return self.gen(schema)
+
+    # there is no lowering functionality generated unless this IR base class is subclassed and
+    # implemented as a backend-specific node
+    def lowering_function(self, schema: LazyIrSchema) -> str:
+        return ""
+
+    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+        return ""
+
+    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+        return f"""bool CanBeReused({node_ctor_args}) const {{
+    return false;
+    }}"""
+
+    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
+        value_args = schema.filtered_args(values=True, scalars=False)
+        # backends can customize the way the node base class constructor is called,
+        # as long as all of its arguments can be generated from information available from the schema
+        base_ctor_value_args_list = []
+        for arg in value_args:
+            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
+                base_ctor_value_args_list.append(f"{arg.name}")
+            elif isinstance(arg.lazy_type, OptionalCType):
+                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
+            else:
+                raise AssertionError(
+                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
+                )
+        base_ctor_value_args = ", ".join(base_ctor_value_args_list)
+
+        scalar_args = schema.filtered_args(values=False, scalars=True)
+
+        # Shape construction.
+        # Conditionally build shape depending on specified shape property
+        if schema.properties.ShapePrecompute:
+            shape_ctor_arg = "std::move(shapes),"
+        elif schema.properties.ShapeCompute:
+            shape_args = [a.name for a in value_args]
+            shape_args.extend(a.name for a in scalar_args)
+            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
+        elif schema.properties.ShapeCache:
+            shape_args = [f"operand({i})" for i in range(len(value_args))]
+            shape_args.extend(a.name for a in scalar_args)
+            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
+        else:
+            shape_ctor_arg = ""
+
+        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)
+
+        return f"""{self.node_base}(
+              {schema.node_name}::ClassOpKind(),
+              OpList{{{base_ctor_value_args}}},
+              {shape_ctor_arg}
+              /* num_outputs */ {len(schema.returns)},
+              torch::lazy::MHash({scalar_hashes}))"""
+
+    def gen(self, schema: LazyIrSchema) -> list[str]:
+        opkind = schema.opkind or aten_symbol(schema)
+
+        # for now, we just want one IR class decl and soon after also the method defs
+        # and we use the functional version not out/inplace.
+        all_args = schema.filtered_args()
+        scalar_args = schema.filtered_args(values=False, scalars=True)
+
+        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
+        reuse_ctor_args = ", ".join(ctor_args)
+        if self.use_lazy_shape and schema.properties.ShapePrecompute:
+            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
+        node_ctor_args = ", ".join(ctor_args)
+
+        scalar_initializers = ",\n        ".join(
+            [
+                # This code is just special casing the mapping from string_view -> strings
+                f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
+                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
+                else f"{a.name}({a.name})"
+                for a in scalar_args
+            ]
+        )
+        if len(scalar_initializers):
+            scalar_initializers = f",\n        {scalar_initializers}"
+        scalar_decls = "\n  ".join(
+            [
+                f"std::string {a.name};"
+                if a.lazy_type.cpp_type() == "c10::string_view"
+                else f"::std::optional<std::string> {a.name};"
+                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
+                else f"{a.lazy_type.cpp_type()} {a.name};"
+                for a in scalar_args
+            ]
+        )
+        optional_values = [
+            arg.name
+            for arg in schema.filtered_args(values=True, scalars=False)
+            if isinstance(arg.lazy_type, OptionalCType)
+        ]
+        has_optional_decls = "\n  ".join(
+            [f"bool has_{value}: 1;" for value in optional_values]
+        )
+        has_optional_defs = "\n    ".join(
+            [f"has_{value} = !!{value};" for value in optional_values]
+        )
+        members_to_string = []
+        for arg in scalar_args:
+            if isinstance(arg.lazy_type, OptionalCType):
+                value = f"{arg.name}.value()"
+                if arg.is_generator:
+                    value = '"torch.Generator()"'
+                members_to_string.append(
+                    f"""if ({arg.name}.has_value()) {{
+      ss << ", {arg.name}=" << {value};
+    }} else {{
+      ss << ", {arg.name}=null";
+    }}"""
+                )
+            else:
+                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
+        members_to_string_str = "\n    ".join(members_to_string)
+
+        return [
+            f"""\
+class {schema.node_name} : public {self.node_base} {{
+ public:
+  static torch::lazy::OpKind ClassOpKind() {{
+    return torch::lazy::OpKind({opkind});
+  }}
+
+  {schema.node_name}({node_ctor_args})
+      : {self.node_base_ctor_call(schema)}{scalar_initializers}
+  {{
+    {has_optional_defs}
+  }}
+
+  std::string ToString() const override {{
+    std::stringstream ss;
+    ss << {self.node_base}::ToString();
+    {members_to_string_str}
+    return ss.str();
+  }}
+
+  {self.create_function(schema, reuse_ctor_args)}
+
+  {self.can_be_reused_function(schema, reuse_ctor_args)}
+
+  {self.lowering_function(schema)}
+
+  {scalar_decls}
+  {has_optional_decls}
+
+}};
+
+""",
+        ]
+
+
+@dataclass(frozen=True)
+class GenTSLazyIR(GenLazyIR):
+    def lowering_function(self, schema: LazyIrSchema) -> str:
+        signature = """
+  torch::lazy::TSOpVector Lower(
+      std::shared_ptr<torch::jit::GraphFunction> function,
+      torch::lazy::TSLoweringContext* loctx) const override"""
+
+        if schema.properties.LowerDeclOnly:
+            return f"{signature};"
+        elif schema.properties.Lower:
+            return f"""{signature} {{
+    {ts_lowering_body(schema)}
+  }}
+            """
+        else:
+            return ""
+
+    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+        signature = f"static NodePtr Create({node_ctor_args})"
+        if schema.properties.CreateFnDeclOnly:
+            return f"{signature};"
+        elif not schema.properties.CreateFn:
+            return ""
+        return f"""{signature} {{
+    return ReuseOrMakeNode<{schema.node_name}>(data);
+  }}"""
+
+    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
+        signature = f"bool CanBeReused({node_ctor_args}) const"
+        if schema.properties.CanBeReusedDeclOnly:
+            return f"{signature};"
+        elif not schema.properties.CanBeReused:
+            return ""
+        value_comparison = []
+        for arg in itertools.chain(schema.positional_values, schema.keyword_values):
+            if isinstance(arg.lazy_type, OptionalCType):
+                value_comparison.append(
+                    f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
+                )
+            else:
+                value_comparison.append(f"operand(i++) == {arg.name}")
+        for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
+            if isinstance(arg.lazy_type, OptionalCType):
+                value_comparison.append(
+                    f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
+                )
+            else:
+                value_comparison.append(f"this->{arg.name} == {arg.name}")
+        value_comparison_str = " &&\n        ".join(value_comparison)
+
+        return f"""{signature} {{
+    size_t i = 0;
+    return ({value_comparison_str});
+  }}"""
+
+
+@dataclass(frozen=True)
+class GenLazyNativeFuncDefinition:
+    class_method_name: str
+    backend_index: BackendIndex
+    tensor_class: str
+    gen_forced_fallback_code: bool
+    backend_namespace: str
+    get_tensorlist: str
+    get_tensor_or_wrap_number: str
+    try_get_tensor: str
+    metrics_counter: str
+    create_tensor: str
+    create_from_first_tensor: bool
+    create_aten_from_ltc_tensor: str
+    tuple_aten_from_ltc_tensors: str
+    lazy_tensor_ptr: str
+    get_device_fn: str
+
+    def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        value_args = schema.filtered_args(values=True, scalars=False)
+        # Generates lazy_{name} variables for LazyTensors wrapping input tensors
+        lazy_tensor_decls: list[str] = []
+        for arg in value_args:
+            if arg.is_wrapped_scalar:
+                if isinstance(arg.lazy_type, OptionalCType):
+                    lazy_tensor_decls.append(
+                        f"""auto node_{arg.name} = {arg.name} ?
+                std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
+                    GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
+                ::std::nullopt;"""
+                    )
+                else:
+                    lazy_tensor_decls.append(
+                        f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
+                            GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
+                    )
+            elif arg.is_symint_or_list:
+                continue  # values are extracted in isValueType
+            elif isinstance(arg.lazy_type, BaseCType):
+                if arg.lazy_type.type is tensorListValueT:
+                    lazy_tensor_decls.append(
+                        f"auto lazy_{arg.name}_tensorlist = "
+                        f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
+                    )
+                else:
+                    lazy_tensor_decls.append(
+                        f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
+                        f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
+                    )
+            elif isinstance(arg.lazy_type, OptionalCType):
+                assert arg.lazy_type.elem == BaseCType(getValueT()), arg.lazy_type.elem
+                # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
+                # until we encounter a real world example.
+                lazy_tensor_decls.append(
+                    f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
+                    f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
+                )
+            else:
+                raise AssertionError(
+                    f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
+                )
+        return ("\n        ").join(lazy_tensor_decls)
+
+    def force_eager_fallback(
+        self,
+        func: NativeFunction,
+        schema: LazyIrSchema,
+        metadata: BackendMetadata,
+        sig: DispatcherSignature | NativeSignature,
+    ) -> str:
+        if self.gen_forced_fallback_code:
+            return gen_fallback_code(
+                schema, sig, overload_name=func.func.name.overload_name
+            )
+        return ""
+
+    def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        return f"{self.metrics_counter};"
+
+    def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        value_args = schema.filtered_args(values=True, scalars=False)
+        scalar_args = schema.filtered_args(values=False, scalars=True)
+        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
+        optional_device = OptionalCType(BaseCType(deviceT))
+        optional_devices = [
+            a.name for a in scalar_args if a.lazy_type == optional_device
+        ]
+        assert len(value_types_names) > 0 or len(optional_devices) > 0, (
+            "Expected at least one Value or Device type"
+        )
+        get_device_str = (
+            f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
+        )
+        return f"""auto common_device = {get_device_str};
+        TORCH_INTERNAL_ASSERT(common_device);
+        """
+
+    def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        metadata = self.backend_index.get_kernel(func)
+        assert metadata is not None
+        all_args = schema.filtered_args()
+        returns_length = len(schema.returns)
+        # call the meta kernel if it exists, to compute output shape/dtype for our IR
+        # Note [Generated LTC Shape Functions]
+        # LTC uses meta tensors from core to do shape inference when possible, and otherwise
+        # we generate a shape function declaration that needs to be manually implemented.
+        # How do we detect which ops are eligible to use meta tensors?
+        # In general we should be able to use meta tensors not just on structured operators,
+        # but also on composite operators that are implemented in terms of structured kernels.
+        # We don't currently have a way of knowing at codegen time which ops are implemented that way.
+        # This is the case for all view and view_copy operators however, so we're going to
+        # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
+        is_view_copy_op = "view_copy" in func.tags
+        is_structured = func.structured or func.structured_delegate is not None
+        if is_structured or is_view_copy_op:
+            meta_out = """
+std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
+            if returns_length > 1:
+
+                def this_shape(i: int) -> str:
+                    return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
+
+                shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
+                meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
+
+            # Convert tensor args to the meta device and call it.
+            # (We can't pass in the input tensors directly, because they are "functional wrappers".
+            # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
+            # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
+            dispatcher_sig = DispatcherSignature.from_schema(func.func)
+            meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+            meta_call_args = [
+                e.expr
+                for e in translate(
+                    meta_call_ctx, dispatcher_sig.arguments(), method=False
+                )
+            ]
+            if is_view_copy_op:
+                # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
+                assert func.has_composite_explicit_autograd_non_functional_kernel
+                dispatch_ns = "compositeexplicitautogradnonfunctional"
+            else:
+                dispatch_ns = "meta"
+            aten_name = schema.aten_name
+            # TODO: this is trolling
+            if func.func.has_symint() and metadata.supports_symint():
+                aten_name += "_symint"
+            shape_str = f"""\
+        {meta_conversion_str}
+        auto out_meta = at::{dispatch_ns}::{aten_name}({", ".join(meta_call_args)});
+        {meta_out}"""
+        else:
+            shape_sig = ComputeShapeSignature(
+                metadata.kernel, func, symint=metadata.supports_symint()
+            )
+            shape_str = f"""
+            auto shapes = {shape_sig.shape_call};"""
+
+        shape_str += f"""
+            TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
+
+        # Calculating which dimensions are symbolic
+        func_schema_str = "aten::" + str(func.func)
+        shape_str += f"""
+            if(torch::lazy::symbolicShapeEnabled()){{
+                std::vector<torch::jit::IValue> inputs = {{ {", ".join(str(a.name) for a in all_args)} }};
+                const char* schema_str = "{func_schema_str}";
+                applySymbolicShapesOnLT(schema_str, inputs, shapes);
+            }}
+        """
+        return shape_str
+
+    def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        node_ctor_input_str = node_ctor_inputs(schema)
+        return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
+        if (!node) {{
+            {self.shape_inference(func, schema)}
+            node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
+            CacheNode(node);
+        }}
+        """
+
+    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
+        # xla uses an instance method for tensor creation, for the time being
+        if self.create_from_first_tensor:
+            # TODO(whc) remove this if XLA switches to using static method for creation
+            assert first_tensor_name is not None, (
+                "Requires first tensor to create lazy tensor"
+            )
+            return f"{first_tensor_name}.{self.create_tensor}"
+        return f"{self.backend_namespace}::{self.create_tensor}"
+
+    def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
+        returns_length = len(schema.returns)
+        value_args = schema.filtered_args(values=True, scalars=False)
+        value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
+        first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
+        bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
+                {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
+
+        if returns_length > 1:
+            assert len(value_types_names) > 0, (
+                "Code below assumes there is at least one tensor arg"
+            )
+            bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
+        for (int i = 0; i < {returns_length}; i++) {{
+            lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
+        }}
+        auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
+
+        if schema.name.name.inplace or func.func.is_out_fn():
+            assert returns_length == 1, (
+                "We assumed there was no such case where an op is an in-place variant "
+                f"and has tuple outputs, but got tuple of len {returns_length}."
+            )
+            bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
+        auto& result = {first_tensor_name};"""
+
+        bridge_str += """
+        return result;"""
+        return bridge_str
+
+    @method_with_native_function
+    def __call__(self, func: NativeFunction) -> list[str]:
+        sig = kernel_signature(func, self.backend_index)
+        metadata = self.backend_index.get_kernel(func)
+        assert metadata is not None
+        schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
+        return [
+            f"""\
+    {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
+        {self.force_eager_fallback(func, schema, metadata, sig)}
+        {self.metrics(func, schema)}
+        {self.get_device(func, schema)}
+        {self.lazy_tensor_decls(func, schema)}
+        {self.build_ir_node(func, schema)}
+        {self.return_aten_tensor(func, schema)}
+    }}\n
+    """
+        ]
+
+
+class ComputeShapeSignature:
+    """
+    Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
+    """
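+    # shape_decl renders roughly as
+    # "TORCH_API std::vector<torch::lazy::Shape> compute_shape_<kernel>(<dispatcher args>)"
+    # (placeholders illustrative; see __decl_suffix below).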
+
+    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
+        self.__schema = LazyIrSchema(f.func, symint=symint)
+        self.__dispatch_args = ", ".join(
+            [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
+        )
+        self.__call_args = ", ".join(
+            [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
+        )
+        self.__kernel_name = kernel_name
+
+    def __decl_suffix(self) -> str:
+        return f"{self.__kernel_name}({self.__dispatch_args})"
+
+    def __call_suffix(self) -> str:
+        return f"{self.__kernel_name}({self.__call_args})"
+
+    @property
+    def shape_decl(self) -> str:
+        return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
+
+    @property
+    def shape_call(self) -> str:
+        return f"torch::lazy::compute_shape_{self.__call_suffix()}"
+
+
+@dataclass(frozen=True)
+class GenLazyShapeInferenceDefinition:
+    backend_index: BackendIndex
+    tensor_class: str
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> list[str]:
+        metadata = self.backend_index.get_kernel(f)
+        assert metadata is not None
+
+        # See Note [Generated LTC Shape Functions]
+        is_view_copy_op = "view_copy" in f.tags
+        is_structured = f.structured or f.structured_delegate is not None
+        if is_structured or is_view_copy_op:
+            return []
+        else:
+            shape_sig = ComputeShapeSignature(
+                metadata.kernel, f, symint=metadata.supports_symint()
+            )
+            return ["\n".join([f"{shape_sig.shape_decl};"])]
+
+
+def generate_non_native_lazy_ir_nodes(
+    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
+) -> list[str]:
+    """Generate the non-native lazy IR node classes"""
+    nodes = []
+    for op in non_native:
+        # Set default properties for Non-Native IRs
+        properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
+        for p in op.get("properties", []):
+            setattr(properties, p, True)
+
+        # non-native is assumed to want symint bindings if you wrote symint
+        schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
+        schema.opkind = op.get("opkind")
+        nodes.append(gen_lazy_ir.gen(schema)[0])
+
+    return nodes
diff --git a/phivenv/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py b/phivenv/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efbd63d7e7722d39c314afdf5474f80a5994c28
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/lazy_ts_lowering.py
@@ -0,0 +1,48 @@
+from torchgen.api.lazy import LazyArgument, LazyIrSchema
+from torchgen.api.types import OptionalCType
+
+
+def ts_lowering_body(schema: LazyIrSchema) -> str:
+    # for now, we just want one IR class decl and soon after also the method defs
+    # and we use the functional version not out/inplace.
+    emplace_arguments = []
+
+    def get_value(arg: LazyArgument) -> str:
+        if isinstance(arg.lazy_type, OptionalCType):
+            return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
+        return "loctx->GetOutputOp(operand(i++))"
+
+    for arg in schema.positional_args:
+        if arg.is_lazy_value:
+            emplace_arguments.append(get_value(arg))
+            continue
+        emplace_arguments.append(f'"{arg.name}", {arg.name}')
+
+    emplace_arguments_str = "\n    ".join(
+        [f"arguments.emplace_back({a});" for a in emplace_arguments]
+    )
+    emplace_kwarg_values = [
+        f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
+    ]
+    emplace_kwarg_scalars = [
+        f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
+    ]
+    emplace_kwarguments = "\n    ".join(
+        [
+            f"kwarguments.emplace_back({a});"
+            for a in emplace_kwarg_values + emplace_kwarg_scalars
+        ]
+    )
+    return f"""\
+    std::vector<torch::jit::NamedValue> arguments;
+    std::vector<torch::jit::NamedValue> kwarguments;
+    arguments.reserve({len(emplace_arguments)});
+    kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
+    size_t i = 0;
+    {emplace_arguments_str}
+    {emplace_kwarguments}
+    torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
+    TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
+
+    return {schema.aten_name}_out;
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/dest/native_functions.py b/phivenv/Lib/site-packages/torchgen/dest/native_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..deea43fc5a453126c2e669842d86d1077977d46f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/native_functions.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import torchgen.api.meta as meta
+import torchgen.api.structured as structured
+from torchgen.api.types import kernel_signature
+from torchgen.context import with_native_function_and_index
+from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
+from torchgen.utils import mapMaybe
+
+
+def torch_api_key_word_prefix(backend_index: BackendIndex) -> str:
+    if backend_index.external:
+        return ""
+
+    # Although the Intel GPU ATen library is out-of-tree, it still uses torchgen to produce structured
+    # kernels, and those generated kernels must be visible to that library. We therefore prefix them
+    # with "TORCH_XPU_API" rather than "TORCH_API", because "TORCH_API" carries "hidden" visibility
+    # for out-of-tree backends. Other in-tree backends such as CPU and CUDA keep the "TORCH_API"
+    # prefix, which has "visible" semantics for them.
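+    # For example: an in-tree XPU backend index yields "TORCH_XPU_API ", other
+    # in-tree backends yield "TORCH_API ", and external backends yield "".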
+    device_torch_api_key_word_mapping = {
+        "XPU": "TORCH_XPU_API",
+    }
+
+    return (
+        device_torch_api_key_word_mapping.get(
+            backend_index.dispatch_key.name, "TORCH_API"
+        )
+        + " "
+    )
+
+
+@with_native_function_and_index
+def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
+    sig = kernel_signature(f, backend_index)
+    metadata = backend_index.get_kernel(f)
+    if metadata is None:
+        return None
+    if "legacy::" in metadata.kernel:
+        return None
+    else:
+        prefix = "static" if backend_index.external else "TORCH_API"
+        return f"{prefix} {sig.decl(name=metadata.kernel)};"
+
+
+@with_native_function_and_index
+def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
+    meta_name = meta.name(g)
+    out_args = structured.impl_arguments(g)
+    metadata = backend_index.get_kernel(g)
+    if metadata is None:
+        return []
+    prefix = torch_api_key_word_prefix(backend_index)
+    return [
+        f"""\
+struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{
+void impl({", ".join(a.decl() for a in out_args)});
+}};
+"""
+    ]
+
+
+# Generates NativeFunctions.h, a list of forward declarations of all
+# actual kernel definitions we keep in aten/src/ATen/native/
+@with_native_function_and_index
+def compute_native_function_declaration(
+    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
+) -> list[str]:
+    metadata = backend_index.get_kernel(g)
+    if isinstance(g, NativeFunctionsGroup):
+        if metadata is not None and metadata.structured:
+            if backend_index.external:
+                # Structured hasn't been tested with external backends yet.
+                raise AssertionError(
+                    "Structured external backend functions are not implemented yet."
+                )
+            else:
+                return gen_structured(g, backend_index)
+        else:
+            return list(
+                mapMaybe(lambda f: gen_unstructured(f, backend_index), g.functions())
+            )
+    else:
+        x = gen_unstructured(g, backend_index)
+        return [] if x is None else [x]
diff --git a/phivenv/Lib/site-packages/torchgen/dest/register_dispatch_key.py b/phivenv/Lib/site-packages/torchgen/dest/register_dispatch_key.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c87b565637def335382c183f3b45aa0db8f565
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/register_dispatch_key.py
@@ -0,0 +1,1016 @@
+from __future__ import annotations
+
+import itertools
+import textwrap
+from dataclasses import dataclass
+from typing import Literal, TYPE_CHECKING
+from typing_extensions import assert_never
+
+import torchgen.api.cpp as cpp
+import torchgen.api.meta as meta
+import torchgen.api.structured as structured
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    ConstRefCType,
+    CppSignature,
+    CppSignatureGroup,
+    DispatcherSignature,
+    Expr,
+    kernel_signature,
+    MutRefCType,
+    NamedCType,
+    NativeSignature,
+    tensorT,
+)
+from torchgen.context import method_with_native_function, native_function_manager
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    DeviceCheckType,
+    DispatchKey,
+    gets_generated_out_inplace_wrapper,
+    is_cuda_dispatch_key,
+    NativeFunction,
+    NativeFunctionsGroup,
+    SchemaKind,
+    TensorOptionsArguments,
+)
+from torchgen.utils import mapMaybe, Target
+
+
+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
+def gen_registration_headers(
+    backend_index: BackendIndex,
+    per_operator_headers: bool,
+    rocm: bool,
+) -> list[str]:
+    if per_operator_headers:
+        headers = ["#include <ATen/ops/as_strided_native.h>"]
+    else:
+        headers = ["#include <ATen/NativeFunctions.h>"]
+
+    if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
+        headers.append("#include <ATen/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.CUDA:
+        if rocm:
+            headers.append("#include <ATen/hip/EmptyTensor.h>")
+        else:
+            headers.append("#include <ATen/cuda/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.MPS:
+        headers.append("#include <ATen/mps/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.XPU:
+        # XPU specific, this header resides in third_party/torch-xpu-ops
+        headers.append("#include <ATen/xpu/EmptyTensor.h>")
+    elif backend_index.dispatch_key == DispatchKey.MTIA:
+        headers.append("#include <ATen/mtia/EmptyTensor.h>")
+    elif per_operator_headers:
+        headers += [
+            "#include <ATen/ops/empty.h>",
+            "#include <ATen/ops/empty_strided.h>",
+            "#include <ATen/ops/_copy_from_and_resize.h>",
+            "#include <ATen/ops/_copy_from.h>",
+        ]
+    else:
+        headers.append("#include <ATen/Functions.h>")
+
+    headers.append("#include <ATen/Context.h>")
+    return headers
+
+
+def gen_empty_impl_names(
+    backend_index: BackendIndex,
+) -> tuple[str | None, str | None]:
+    empty_impl = None
+    empty_strided_impl = None
+
+    if backend_index.dispatch_key in (
+        DispatchKey.Meta,
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+        DispatchKey.MPS,
+        DispatchKey.XPU,
+        DispatchKey.MTIA,
+    ):
+        dispatch = str(backend_index.dispatch_key).lower()
+        empty_impl = f"at::detail::empty_{dispatch}"
+        empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
+    elif backend_index.dispatch_key in (
+        DispatchKey.CompositeExplicitAutogradNonFunctional,
+        DispatchKey.QuantizedCPU,
+        DispatchKey.QuantizedCUDA,
+        DispatchKey.XPU,
+    ):
+        empty_impl = "at::empty"
+        empty_strided_impl = "at::empty_strided"
+
+    return empty_impl, empty_strided_impl
+
+
+def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
+    if backend_index.dispatch_key == DispatchKey.Meta:
+        empty_options = "options.device(at::kMeta)"
+    else:
+        empty_options = "options"
+
+    empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
+    if empty_impl is None:
+        return []
+
+    return [
+        f"""
+Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+  if (strides.empty()) {{
+      return {empty_impl}(sizes, {empty_options});
+  }} else {{
+      return {empty_strided_impl}(sizes, strides, {empty_options});
+  }}
+}}
+"""
+    ]
+
+
+def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
+    _, empty_strided_impl = gen_empty_impl_names(backend_index)
+    return (
+        []
+        if empty_strided_impl is None
+        else [
+            f"""
+std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
+  if (out.strides() != strides) {{
+    return {empty_strided_impl}(sizes, strides, options);
+  }}
+  return std::nullopt;
+}}
+"""
+        ]
+    )
+
+
+def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
+    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
+        # The function isn't used by this key (since only functional ops have a kernel for this key),
+        # so we need to not include it to avoid a defined-but-not-used error.
+        return []
+    return [
+        """
+void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
+  TORCH_CHECK(options.dtype() == out.dtype(),
+      "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
+  TORCH_CHECK(options.device() == out.device(),
+      "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
+  const bool resized = at::native::resize_output(out, sizes);
+  // Only restride if a resize occurred; otherwise we ignore the (advisory)
+  // strides from the meta function and directly use the output tensor's
+  // preexisting strides
+  if (resized) {
+    if (!strides.empty()) {
+      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
+      // TODO: avoid the redispatch here
+      out.as_strided_(sizes, strides);
+    } else if (options.memory_format_opt().has_value()) {
+      out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
+    }
+  }
+}
+"""
+    ]
+
+
+def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
+    return [
+        """
+void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
+  // These checks are needed on those operators that:
+  //   1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
+  //   2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
+  // For other operators (e.g. 'add'), 'TensorIterator' already checks
+  // these things separately.
+  TORCH_CHECK(options.dtype() == self.dtype(),
+      "Bad in-place call: ",
+      "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
+  TORCH_CHECK(options.device() == self.device(),
+      "Bad in-place call: ",
+      "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
+  TORCH_CHECK(sizes == self.sizes(),
+      "Bad in-place call: ",
+      "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
+}
+"""
+    ]
+
+
+def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
+    return [
+        'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
+        *gen_create_out_helper(backend_index),
+        *gen_resize_out_helper(backend_index),
+        *gen_check_inplace_helper(backend_index),
+        *gen_maybe_create_proxy_helper(backend_index),
+        "C10_DIAGNOSTIC_POP()",
+    ]
+
+
+# Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
+#
+#   - The primary function of this file is to register all of the
+#     implementations for the given dispatch key to the dispatcher,
+#     so they are available for use in PyTorch.  If dispatch is
+#     None, we generate schema (def) registrations and catchall
+#     registrations.
+#   - The secondary function of this file is to generate a wrapper
+#     around functions.  In CPUType these wrappers do nothing
+#     (and should be removed), but in other cases they handle
+#     DeviceGuard. A small extra benefit of wrappers is they
+#     are not overloaded, so they can be used in the registration
+#     API without having to disambiguate which overload you want
+#     (as would be the case if you directly registered native::
+#     functions).
+#   - The tertiary function of this file is to generate *static*
+#     cpp API bindings which can be used to bypass dispatcher
+#     directly to kernels, but with user-friendly cpp-style API
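+#
+#   A rough sketch of the registration this produces (illustrative names only;
+#   the exact wrapper names and operator strings come from the codegen itself):
+#
+#     TORCH_LIBRARY_IMPL(aten, CPU, m) {
+#       m.impl("some_op", TORCH_FN(wrapper_CPU_some_op));
+#     }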
+@dataclass(frozen=True)
+class RegisterDispatchKey:
+    backend_index: BackendIndex
+
+    target: Literal[
+        Target.ANONYMOUS_DEFINITION,
+        Target.NAMESPACED_DEFINITION,
+        Target.NAMESPACED_DECLARATION,
+        Target.REGISTRATION,
+    ]
+
+    # Selector object to determine which operators to generate
+    # registration code for.
+    selector: SelectiveBuilder
+
+    # Whether or not we are actually code-genning for ROCm
+    rocm: bool
+
+    # Whether or not to generate symint registrations or not.  External users
+    # of codegen who don't care about symints can set this to false to get
+    # non-SymInt codegen
+    symint: bool
+
+    # The class that all unstructured native functions live under. This is used to improve
+    # compiler error messages when a kernel writer adds a native function with the wrong signature.
+    # This is only used in unstructured kernels, since structured kernels already live in a class.
+    # Finally, this field is currently Optional because it is only used by external backends.
+    # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
+    # all of the existing kernel signatures scattered across aten/src/ATen/native.
+    class_method_name: str | None
+
+    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
+    # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
+    skip_dispatcher_op_registration: bool
+
+    @staticmethod
+    def gen_device_check(
+        type: DeviceCheckType, args: list[Argument], method_name: str
+    ) -> str:
+        if type == DeviceCheckType.NoCheck:
+            return "  // No device check\n"
+
+        device_check = "std::optional<Device> common_device = std::nullopt;\n"
+        device_check += "(void)common_device; // Suppress unused variable warning\n"
+        for arg in args:
+            # Only tensor like arguments are eligible
+            if arg.type.is_tensor_like():
+                device_check += f"""
+  c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
+        return device_check
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
+        if isinstance(f, NativeFunctionsGroup):
+            g: NativeFunctionsGroup = f
+            # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
+            # gen_structured() has special logic to handle auto-generated kernels.
+            if g.structured:
+                return self.gen_structured(g)
+            else:
+                return list(
+                    mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
+                )
+        elif isinstance(f, NativeFunction):
+            r = self.gen_unstructured(f)
+            return [] if r is None else [r]
+        else:
+            assert_never(f)
+
+    def wrapper_kernel_sig(
+        self, f: NativeFunction
+    ) -> NativeSignature | DispatcherSignature:
+        # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
+        return DispatcherSignature.from_schema(
+            f.func,
+            prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
+            symint=self.symint,
+        )
+
+    def gen_out_inplace_wrapper(
+        self, f: NativeFunction, g: NativeFunctionsGroup | None
+    ) -> str | None:
+        if g is None:
+            return None
+        k = f.func.kind()
+        if k is SchemaKind.inplace:
+            copy_op = "at::_copy_from"
+        elif k is SchemaKind.out:
+            copy_op = "at::_copy_from_and_resize"
+        else:
+            raise AssertionError("gen_out_inplace_wrapper called on a functional op")
+
+        sig = self.wrapper_kernel_sig(f)
+        name = sig.name()
+
+        func_res = f"{name}_tmp"
+        return_names = cpp.return_names(f)
+        if len(return_names) > 1:
+            updates = "\n  ".join(
+                f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
+                for i, ret_name in enumerate(return_names)
+            )
+            returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
+        elif len(return_names) == 1:
+            ret_name = return_names[0]
+            updates = f"{copy_op}({func_res}, {ret_name});"
+            returns = ret_name
+        else:
+            assert len(f.func.arguments.out) == 1
+            returns = ""
+            out_arg = f.func.arguments.out[0]
+            if out_arg.type.is_list_like():
+                updates = f"""\
+    for (int64_t i = 0; i < {func_res}.size(); ++i) {{
+        {copy_op}({func_res}[i], {out_arg.name}[i]);
+    }}"""
+            else:
+                updates = f"{copy_op}({func_res}, {out_arg.name});"
+
+        functional_sig = self.wrapper_kernel_sig(g.functional)
+        wrapper_name = sig.name()
+
+        return f"""\
+{sig.defn(name=wrapper_name)} {{
+  auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
+  {updates}
+  return {returns};
+}}
+"""
+
+    def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
+        metadata = self.backend_index.get_kernel(g)
+        if self.backend_index.dispatch_key == DispatchKey.Meta:
+            assert not self.backend_index.has_kernel(g.out), (
+                "Do not explicitly specify Meta dispatch key on structured "
+                "functions, they will be automatically generated for you"
+            )
+        elif (
+            self.backend_index.dispatch_key
+            == DispatchKey.CompositeExplicitAutogradNonFunctional
+        ):
+            assert not self.backend_index.has_kernel(g.out), (
+                "Do not explicitly specify CompositeExplicitAutograd dispatch key on structured "
+                "functions, they will be automatically generated for you"
+            )
+        elif metadata is None or not metadata.structured:
+            return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
+        structured_gen = StructuredRegisterDispatchKey(
+            self.backend_index,
+            self.target,
+            self.selector,
+            self.rocm,
+            self.symint,
+            self.class_method_name,
+            self.skip_dispatcher_op_registration,
+            g,
+        )
+        return list(mapMaybe(structured_gen.gen_one, g.functions()))
+
+    def gen_unstructured(
+        self, f: NativeFunction, g: NativeFunctionsGroup | None = None
+    ) -> str | None:
+        with native_function_manager(f):
+            inplace_meta = False
+            gets_out_inplace_wrapper = False
+            if not self.backend_index.has_kernel(f):
+                if (
+                    self.backend_index.dispatch_key == DispatchKey.Meta
+                    and f.func.kind() is SchemaKind.inplace
+                    and
+                    # Defer to composites for meta implementation
+                    not f.has_composite_kernel
+                    and
+                    # Inplace list operations are not supported
+                    len(f.func.returns) == 1
+                ):
+                    inplace_meta = True
+                elif (
+                    not self.backend_index.use_out_as_primary
+                    and g is not None
+                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
+                ):
+                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
+                    gets_out_inplace_wrapper = True
+                else:
+                    return None
+            if f.manual_kernel_registration:
+                return None
+
+            if (
+                self.target is Target.REGISTRATION
+                and not self.selector.is_native_function_selected(f)
+            ):
+                return None
+
+            sig = self.wrapper_kernel_sig(f)
+
+            name = sig.name()
+            returns_type = sig.returns_type().cpp_type()
+            args = sig.arguments()
+            args_str = ", ".join(a.defn() for a in args)
+
+            # See Note [Direct dispatch bindings]
+            cpp_sig_group = CppSignatureGroup.from_native_function(
+                f, method=False, fallback_binding=False
+            )
+
+            # TODO: dedupe this with the structured codegen
+            if self.target is Target.NAMESPACED_DECLARATION:
+                result = ""
+                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+                    result += f"TORCH_API {cpp_sig.decl()};\n"
+                return result
+            elif self.target is Target.NAMESPACED_DEFINITION:
+
+                def generate_defn(cpp_sig: CppSignature) -> str:
+                    return f"""
+{cpp_sig.defn()} {{
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+}}
+"""
+
+                result = ""
+                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+                    result += generate_defn(cpp_sig)
+                return result
+
+            elif self.target is Target.ANONYMOUS_DEFINITION:
+                # short circuit for inplace_meta
+                if inplace_meta:
+                    assert f.func.arguments.self_arg is not None
+                    self_arg_name = f.func.arguments.self_arg.argument.name
+                    # TODO: handle in place on tensor list
+                    return f"""
+{returns_type} {name}({args_str}) {{
+  TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
+    "Cannot inplace into non-meta tensor with meta tensor argument");
+  return {self_arg_name};
+}}
+"""
+
+                # short circuit for generated inplace/out wrappers
+                if gets_out_inplace_wrapper:
+                    return self.gen_out_inplace_wrapper(f, g)
+
+                metadata = self.backend_index.get_kernel(f)
+                if metadata is None:
+                    return None
+                if self.class_method_name is None:
+                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
+                else:
+                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"
+
+                kernel_sig = kernel_signature(f, self.backend_index)
+
+                args_exprs_str = ", ".join(
+                    e.expr
+                    for e in translate(
+                        sig.arguments(), kernel_sig.arguments(), method=False
+                    )
+                )
+
+                device_check = "  // No device check\n"
+                # Backends that require device guards presumably also require device checks.
+                if self.backend_index.device_guard:
+                    device_check_args = itertools.chain(
+                        f.func.arguments.out, f.func.arguments.flat_positional
+                    )
+                    device_check = RegisterDispatchKey.gen_device_check(
+                        f.device_check, list(device_check_args), name
+                    )
+
+                device_guard = "// DeviceGuard omitted"  # default
+                if f.device_guard and self.backend_index.device_guard:
+                    has_tensor_options = any(
+                        isinstance(a, TensorOptionsArguments)
+                        for a in f.func.arguments.non_out
+                    )
+                    if has_tensor_options:
+                        # kernel is creating a tensor
+                        device_guard = """
+  const DeviceGuard device_guard(device_or_default(device));"""
+
+                        # CUDA requires special handling
+                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
+                            device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}"
+                    else:
+                        # kernel is operating on existing tensors
+
+                        # There is a precedence order for which argument we use
+                        # to do the device guard; the following encodes that order.
+                        self_arg = (
+                            [f.func.arguments.self_arg.argument]
+                            if f.func.arguments.self_arg is not None
+                            else []
+                        )
+                        candidate_args = itertools.chain(
+                            self_arg,
+                            f.func.arguments.out,
+                            f.func.arguments.flat_positional,
+                        )
+
+                        # Only tensor like arguments are eligible
+                        device_of = next(
+                            (
+                                f"{a.name}"
+                                for a in candidate_args
+                                if a.type.is_tensor_like()
+                            ),
+                            None,
+                        )
+                        if device_of is not None:
+                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"
+
+                return f"""\
+namespace {{
+
+{returns_type} {name}({args_str}) {{
+  {device_check}
+
+  {device_guard}
+  return {impl_name}({args_exprs_str});
+}}
+
+}} // anonymous namespace
+"""
+
+            elif self.target is Target.REGISTRATION:
+                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
+                    return None
+                else:
+                    payload = f"TORCH_FN({name})"
+                    return f'm.impl("{f.func.name}",\n{payload});\n'
+            else:
+                assert_never(self.target)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                           STRUCTURED
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+@dataclass(frozen=True)
+class StructuredRegisterDispatchKey(RegisterDispatchKey):
+    g: NativeFunctionsGroup
+
+    def gen_class_set_output_functions(
+        self, k: SchemaKind, parent_class: str, generate_super: bool
+    ) -> str:
+        if generate_super:
+            set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
+        else:
+            set_output_super = ""
+
+        def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
+            return f"""
+void set_output_{name}(
+    int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
+    TensorOptions options, DimnameList names
+) override {{
+{textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), "    ")}
+    if (!names.empty()) {{
+      namedinference::propagate_names(outputs_[output_idx], names);
+    }}
+    // super must happen after, so that downstream can use maybe_get_output
+    // to retrieve the output
+{textwrap.indent(set_output_super, "    ")}
+}}
+"""
+
+        return f"""
+{gen_set_output_function("strided", maybe_create_proxy=True)}
+{gen_set_output_function("raw_strided", maybe_create_proxy=False)}
+"""
+
+    def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
+        if self.backend_index.dispatch_key in [
+            DispatchKey.CUDA,
+            DispatchKey.MPS,
+            DispatchKey.XPU,
+            DispatchKey.CompositeExplicitAutogradNonFunctional,
+        ]:
+            maybe_set_guard = """
+auto current_device = guard_.current_device();
+if (C10_UNLIKELY(current_device.has_value())) {
+  TORCH_INTERNAL_ASSERT(*current_device == options.device(),
+    "structured kernels don't support multi-device outputs");
+} else {
+  guard_.reset_device(options.device());
+}
+"""
+            maybe_set_guard_line = maybe_set_guard + "\n"
+        else:
+            maybe_set_guard_line = maybe_set_guard = ""
+
+        if maybe_create_proxy:
+            create_proxy = """
+auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
+if (C10_UNLIKELY(maybe_proxy.has_value())) {
+    proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
+}
+"""
+        else:
+            create_proxy = ""
+
+        if k is SchemaKind.functional:
+            assert self.backend_index.dispatch_key in (
+                DispatchKey.Meta,
+                DispatchKey.CPU,
+                DispatchKey.CUDA,
+                DispatchKey.MPS,
+                DispatchKey.XPU,
+                DispatchKey.MTIA,
+                DispatchKey.CompositeExplicitAutogradNonFunctional,
+            )
+            return f"""{maybe_set_guard_line}
+outputs_[output_idx] = create_out(sizes, strides, options);"""
+        elif k is SchemaKind.inplace:
+            return f"""{maybe_set_guard_line}
+const auto& out = outputs_[output_idx].get();
+check_inplace(out, sizes, options);
+{create_proxy}"""
+        elif k is SchemaKind.out:
+            return f"""{maybe_set_guard_line}
+const auto& out = outputs_[output_idx].get();
+resize_out(out, sizes, strides, options);
+{create_proxy}"""
+        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
+            raise AssertionError(
+                f"{k} structured operators are currently not supported"
+            )
+        else:
+            assert_never(k)
+
+    # returns the definition of a ctor, as well as how to construct
+    # this class to a variable named op
+    def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
+        if k is SchemaKind.functional:
+            return ""
+        elif k is SchemaKind.inplace:
+            # TODO: Make sure out argument is guaranteed to be self
+            return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
+        elif k is SchemaKind.out:
+            out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
+            out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
+            return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
+        elif k is SchemaKind.mutable or k is SchemaKind.scratch:
+            raise AssertionError(
+                f"{k} structured operators are currently not supported"
+            )
+        else:
+            assert_never(k)
+
+    def gen_class(
+        self,
+        f: NativeFunction,
+        k: SchemaKind,
+        *,
+        class_name: str,
+        parent_class: str,
+        generate_super: bool,
+    ) -> str:
+        if k is SchemaKind.functional:
+            output_type = "Tensor"
+            output_value = "outputs_[output_idx]"
+            proxy_field = ""
+        elif k is SchemaKind.inplace:
+            output_type = "std::reference_wrapper"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;"
+        elif k is SchemaKind.out:
+            output_type = "std::reference_wrapper"
+            output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
+            proxy_field = f"std::array<::std::optional, {len(f.func.returns)}> proxy_outputs_;"
+        else:
+            raise RuntimeError(f"Unsupported SchemaKind {k}")
+
+        if self.backend_index.dispatch_key == DispatchKey.CUDA:
+            if self.rocm:
+                guard_field = "c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;"
+            else:
+                guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
+        elif (
+            self.backend_index.dispatch_key
+            == DispatchKey.CompositeExplicitAutogradNonFunctional
+        ):
+            guard_field = "c10::OptionalDeviceGuard guard_;"
+        elif self.backend_index.dispatch_key == DispatchKey.MPS:
+            # TODO: Move to OptionalMPSGuard.
+            guard_field = "c10::OptionalDeviceGuard guard_;"
+        elif self.backend_index.dispatch_key == DispatchKey.XPU:
+            guard_field = "c10::OptionalDeviceGuard guard_;"
+        elif self.backend_index.dispatch_key == DispatchKey.MTIA:
+            guard_field = "c10::OptionalDeviceGuard guard_;"
+        else:
+            guard_field = ""
+
+        indent = " " * 4
+        class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
+        lines = (
+            f"struct {class_name} final : public {parent_class} {{",
+            f"{textwrap.indent(class_ctor_str, indent)}",
+            f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
+            "    const Tensor& maybe_get_output(int64_t output_idx) override {",
+            f"      return {output_value};\n",  # type: ignore[possibly-undefined]  # TODO: audit
+            "    }",
+            # type: ignore[possibly-undefined]  # TODO: audit
+            f"    std::array<{output_type}, {len(f.func.returns)}> outputs_;",
+            f"{textwrap.indent(proxy_field, indent)}",  # type: ignore[possibly-undefined]  # TODO: audit
+            f"{textwrap.indent(guard_field, indent)}",
+            "};",
+        )
+        return "\n".join(line for line in lines if line)
+
+    @method_with_native_function
+    def gen_one(self, f: NativeFunction) -> str | None:
+        assert not f.manual_kernel_registration
+
+        if (
+            self.target is Target.REGISTRATION
+            and not self.selector.is_native_function_selected(f)
+        ):
+            return None
+
+        # TODO: Now, there is something interesting going on here.  In the code below,
+        # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
+        # based on the out implementation.  But in fact, out is definable by
+        # functional too (just not very efficiently), and this is honestly the
+        # MORE likely situation for a backend implementer.  How do we pick?
+        # Well, taking a page from Haskell type classes and default methods,
+        # we could conceivably register a circular definition (out in terms
+        # of functional, and functional in terms of out) and just require
+        # someone to implement one or the other.  We'd have to do a little bit
+        # of work to not register one of these "weak" definitions unless there
+        # is a strong definition somewhere in the DAG!  So it's not implemented yet.
+        if (
+            self.backend_index.dispatch_key
+            == DispatchKey.CompositeExplicitAutogradNonFunctional
+            and f.func.kind() is SchemaKind.out
+        ):
+            # Never generate a default implementation for out, that's what you
+            # have to define as a backend implementer
+            return None
+
+        # Note [Direct dispatch bindings]
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Signature of the non-dispatched function we'll expose in a header
+        # (e.g., at::cpu::add).  We don't generate methods (TODO: do this
+        # when CPUTensor class is a thing); nor do we generate fallback
+        # bindings for manual_cpp_binding functions.
+        cpp_sig_group = CppSignatureGroup.from_native_function(
+            f, method=False, fallback_binding=False
+        )
+
+        # Signature of the wrapper function we'll register to the dispatcher
+        kern = self.backend_index.get_kernel(f)
+        sig = NativeSignature(
+            f.func,
+            prefix=f"wrapper_{self.backend_index.dispatch_key}_",
+            symint=kern is not None and kern.supports_symint(),
+        )
+
+        if self.target is Target.NAMESPACED_DECLARATION:
+            result = ""
+            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+                result += f"TORCH_API {cpp_sig.decl()};\n"
+            return result
+
+        elif self.target is Target.NAMESPACED_DEFINITION:
+
+            def generate_defn(cpp_sig: CppSignature) -> str:
+                return f"""
+{cpp_sig.defn()} {{
+return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
+}}
+"""
+
+            result = ""
+            for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
+                result += generate_defn(cpp_sig)
+            return result
+
+        elif self.target is Target.ANONYMOUS_DEFINITION:
+            k = f.func.kind()
+
+            # Construct the body of the wrapper function with signature sig
+            sig_body = []
+            # We'll use context to keep track of any variables we've brought
+            # into scope while generating code
+            context: list[Binding | Expr] = list(sig.arguments())
+
+            # Initialize the class corresponding to this structured
+            # operator; feeding it the output argument(s) if it is known
+            if self.backend_index.dispatch_key is DispatchKey.Meta:
+                class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
+                parent_class = f"at::meta::structured_{meta.name(self.g)}"
+            elif (
+                self.backend_index.dispatch_key
+                is DispatchKey.CompositeExplicitAutogradNonFunctional
+            ):
+                # TODO: dedup this branch
+                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
+                parent_class = f"at::meta::structured_{meta.name(self.g)}"
+            else:
+                metadata = self.backend_index.get_kernel(self.g)
+                assert metadata is not None
+                class_name = f"structured_{metadata.kernel}_{k.name}"
+                parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
+
+            if self.backend_index.device_guard:
+                device_check_args = itertools.chain(
+                    f.func.arguments.out, f.func.arguments.flat_positional
+                )
+                sig_body.append(
+                    RegisterDispatchKey.gen_device_check(
+                        f.device_check, list(device_check_args), sig.name()
+                    )
+                )
+
+            if k is SchemaKind.functional:
+                sig_body.append(f"{class_name} op;")
+            elif k is SchemaKind.inplace:
+                sig_body.append(f"{class_name} op(self);")
+            elif k is SchemaKind.out:
+                out_args_str = ", ".join(a.name for a in f.func.arguments.out)
+                sig_body.append(f"{class_name} op({out_args_str});")
+
+            # Translate the input native arguments into structured
+            # arguments for the meta call
+            meta_exprs = ", ".join(
+                e.expr
+                for e in translate(
+                    context, structured.meta_arguments(self.g), method=False
+                )
+            )
+
+            if self.g.out.precomputed:
+                # If this function group has precomputed elements, the meta function
+                # returns a struct containing them which must be saved so that it
+                # can be unpacked when generating code to call the impl.
+                sig_body.append(f"auto precompute = op.meta({meta_exprs});")
+
+                # Put all of the contents of the precompute struct into the context
+                # so that translate will be able to return the correct args for the
+                # call to the impl.
+                precomputed_values = [
+                    *self.g.out.precomputed.replace.values(),
+                    self.g.out.precomputed.add,
+                ]
+                for precomputed_elems in precomputed_values:
+                    context.extend(
+                        Expr(
+                            expr=f"precompute.{arg.name}",
+                            type=structured.argument_type(arg, binds=arg.name),
+                        )
+                        for arg in precomputed_elems
+                    )
+
+                # Add a use of the precompute struct so FB internal compilers don't
+                # complain that there is an unused variable.
+                sig_body.append("(void)precompute;")
+            else:
+                sig_body.append(f"op.meta({meta_exprs});")
+
+            # After running meta, op.outputs_ is guaranteed to be valid;
+            # add it to the context
+            out_args = structured.out_arguments(self.g)
+            for i, out_arg in enumerate(out_args):
+                assert ConstRefCType(BaseCType(tensorT)) == out_arg.nctype.type
+
+                if k is SchemaKind.out:
+                    expr = f"op.maybe_get_output({i})"
+                else:
+                    expr = f"op.outputs_[{i}]"
+
+                context.append(
+                    Expr(
+                        expr=expr,
+                        # TODO: Stop hardcoding that the output type is a Tensor.  Note
+                        # that for the codegen here this is fine because outputs_ is
+                        # hardcoded to be tensor already
+                        type=NamedCType(
+                            out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
+                        ),
+                    )
+                )
+
+            # With the expanded context, do the impl call (if not a meta
+            # function)
+            if (
+                self.backend_index.dispatch_key
+                == DispatchKey.CompositeExplicitAutogradNonFunctional
+            ):
+                # TODO: https://github.com/pytorch/pytorch/issues/53023
+                out_sig_group = CppSignatureGroup.from_native_function(
+                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding
+                )
+                out_sig = out_sig_group.most_faithful_signature()
+                api_name = out_sig.name()
+                out_exprs = ", ".join(
+                    e.expr
+                    for e in translate(context, out_sig.arguments(), method=False)
+                )
+                # TODO: I think this means structured won't work with method
+                # only functions (but maybe you're saved by faithful? iunno.)
+                # NB: Originally I wrote this as an at::redispatch call, but
+                # I got in trouble because that meant I needed a DispatchKeySet
+                # in the wrapper function, which meant I needed a DispatchKeySet
+                # in the DispatchKeyFunctions declarations, but the defined API
+                # there does NOT permit a dispatch key set.  I think you can
+                # probably unwind this by calling some function to do the TLS
+                # fetch and get the DispatchKeySet when you don't have it, but
+                # I didn't do it for this version
+                sig_body.append(f"at::{api_name}({out_exprs});")
+            elif self.backend_index.dispatch_key != DispatchKey.Meta:
+                impl_exprs = ", ".join(
+                    e.expr
+                    for e in translate(
+                        context, structured.impl_arguments(self.g), method=False
+                    )
+                )
+                sig_body.append(f"op.impl({impl_exprs});")
+
+            # Go over each output, and check if there is a proxy created for it.
+            # If so, copy it over to the original output.
+            if k is SchemaKind.out or k is SchemaKind.inplace:
+                for i in range(len(f.func.returns)):
+                    sig_body.append(
+                        f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
+                    )
+
+            # Destructively return the final tensors
+            # TODO: Do this in translate instead
+            if k is SchemaKind.functional:
+                if len(f.func.returns) == 1:
+                    ret_expr = "std::move(op.outputs_[0])"  # small optimization
+                else:
+                    moved = ", ".join(
+                        f"std::move(op.outputs_[{i}])"
+                        for i in range(len(f.func.returns))
+                    )
+                    ret_expr = f"std::make_tuple({moved})"
+            elif k is SchemaKind.inplace:
+                ret_expr = "self"
+            elif k is SchemaKind.out:
+                if len(f.func.returns) == 1:
+                    ret_expr = f.func.arguments.out[0].name
+                else:
+                    refs = ", ".join(a.name for a in f.func.arguments.out)
+                    ret_expr = f"std::forward_as_tuple({refs})"
+            sig_body.append(f"return {ret_expr};")  # type: ignore[possibly-undefined]  # TODO: audit
+
+            sig_body_str = "\n".join(sig_body)
+
+            # For an overview of what this template code looks like, see
+            # https://github.com/pytorch/rfcs/pull/9
+            return f"""\
+{
+                self.gen_class(
+                    f,
+                    k,
+                    class_name=class_name,
+                    parent_class=parent_class,
+                    generate_super=self.g.out.structured_inherits is not None,
+                )
+            }
+
+{sig.defn()} {{
+{sig_body_str}
+}}
+"""
+
+        elif self.target is Target.REGISTRATION:
+            return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
+        else:
+            assert_never(self.target)
+            # Silence mypy's "Missing return statement" error
+            return None
diff --git a/phivenv/Lib/site-packages/torchgen/dest/ufunc.py b/phivenv/Lib/site-packages/torchgen/dest/ufunc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc46ac05057f57264fd53b43cef4129d863df260
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/dest/ufunc.py
@@ -0,0 +1,553 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import torchgen.api.ufunc as ufunc
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    CType,
+    Expr,
+    NamedCType,
+    opmath_t,
+    scalar_t,
+    StructuredImplSignature,
+    VectorizedCType,
+)
+from torchgen.context import with_native_function
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    DispatchKey,
+    NativeFunctionsGroup,
+    ScalarType,
+    UfuncKey,
+)
+from torchgen.utils import OrderedSet
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from torchgen.api.ufunc import UfunctorBindings
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                                  CUDA STUFF
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+# NB: not bothering to generate dispatch stub forward declaration in header,
+# we can just paste it wherever necessary
+
+# TODO: use BackendIndex
+# dispatch_key: DispatchKey  # only CPU/CUDA right now
+
+
+# Represents functors for implementing CUDA ufuncs.
+# Functors are templated by scalar_t because when USERS instantiate functors
+# they are templated.  A functor looks something like this:
+#
+#   template <typename scalar_t>
+#   struct CUDAFunctorOnSelf_add {
+#     using opmath_t = at::opmath_type<scalar_t>;
+#     opmath_t other_;
+#     opmath_t alpha_;
+#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
+#         : other_(other), alpha_(alpha) {}
+#     __device__ scalar_t operator()(scalar_t self) {
+#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
+#     }
+#   };
+#
+@dataclass(frozen=True)
+class UfunctorSignature:
+    g: NativeFunctionsGroup
+    scalar_tensor_idx: int | None
+    name: str
+
+    def arguments(self) -> UfunctorBindings:
+        return ufunc.ufunctor_arguments(
+            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
+        )
+
+    def fields(self) -> list[Binding]:
+        # fields are renamed to have a trailing underscore, as is conventional
+        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
+
+    def returns_type(self) -> CType:
+        # TODO: don't hardcode; return type will be inferred based on tags on
+        # the native function
+        return BaseCType(scalar_t)
+
+    def decl_fields(self) -> str:
+        return "\n".join(f"{f.type} {f.name};" for f in self.fields())
+
+    def inline_defn_ctor(self) -> str:
+        args_str = ", ".join(a.decl() for a in self.arguments().ctor)
+        # NB: hypothetically could do this with translate but the
+        # transition here is very regular
+        init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
+        return f"{self.name}({args_str}) : {init_str} {{}}"
+
+    def decl_apply(self) -> str:
+        args_str = ", ".join(a.decl() for a in self.arguments().apply)
+        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
+
+
+@dataclass(frozen=True)
+class UfuncSignature:
+    g: NativeFunctionsGroup
+    name: str
+    compute_t: CType
+
+    def arguments(self) -> list[Binding]:
+        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
+
+    def call(self, ctx: Sequence[Binding | Expr]) -> str:
+        return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+
+# steps:
+#   1. take the functional signature
+#   2. use api.ufunc to convert it to template signature.  this establishes
+#      the type of the template function
+#   3. use api.ufunc (II) to generate a split struct / operator() signature.
+#      this establishes the context in which we call the template signature
+#
+# StructuredImplSignature context
+#   ~> functor constructor sig
+#
+# Functor constructor context
+#   ~> functor fields sig
+#
+# Functor apply context (functor fields + functor apply sig)
+#   ~> template sig
+#
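+#
+# As a rough illustration using the add functor sketched above: the structured
+# impl arguments feed the functor constructor (other, alpha), the constructor
+# populates the fields (other_, alpha_), and the fields together with the
+# operator() argument (self) are translated into the final ufunc::add call.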
+
+
+def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
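+    # Binary ops with exactly two Tensor inputs (e.g. add.Tensor(self, other, alpha))
+    # are eligible for the extra CPU-scalar specializations below
+    # (CUDAFunctorOnSelf / CUDAFunctorOnOther).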
+    num_tensors = sum(
+        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
+    )
+    return num_tensors == 2
+
+
+def compute_ufunc_cuda_functors(
+    g: NativeFunctionsGroup,
+) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
+    # First, build the functors.
+    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
+    ufunctors: list[str] = []
+    loops = g.out.ufunc_inner_loop
+    scalar_tensor_idx_lookup = {
+        UfuncKey.CUDAFunctorOnSelf: 1,
+        UfuncKey.CUDAFunctorOnOther: 0,
+        UfuncKey.CUDAFunctor: None,
+    }
+    if eligible_for_binary_scalar_specialization(g):
+        keys = [
+            UfuncKey.CUDAFunctorOnSelf,
+            UfuncKey.CUDAFunctorOnOther,
+            UfuncKey.CUDAFunctor,
+        ]
+    else:
+        keys = [UfuncKey.CUDAFunctor]
+        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
+            assert k not in loops, f"cannot use {k} on non-binary function"
+    for k in keys:
+        # If the key was directly defined, skip functor codegen; we assume the
+        # user has already done it for us
+        if k in loops:
+            ufunctor_sig = UfunctorSignature(
+                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
+            )
+            for dtype in loops[k].supported_dtypes:
+                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
+            continue
+
+        # Note [ScalarOnly and Generic must match names for CUDA]
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Otherwise, look in ANY of the generic entries.  For simplicity of
+        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
+        # must match  (if they didn't match, we'd have to generate distinct
+        # functors per dtype, which is awful, so we're not going to do it unless
+        # someone really forces us to)
+        ufunc_name = None
+        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
+        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
+            if lk not in loops:
+                continue
+            if ufunc_name is None:
+                ufunc_name = loops[lk].name
+            else:
+                # See Note [ScalarOnly and Generic must match names for CUDA]
+                assert ufunc_name == loops[lk].name, (
+                    "ScalarOnly and Generic must have same ufunc name"
+                )
+            supported_dtypes |= loops[lk].supported_dtypes
+        assert ufunc_name is not None
+
+        name = f"{k}_{ufunc_name}"
+        ufunctor_sig = UfunctorSignature(
+            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
+        )
+        for dtype in supported_dtypes:
+            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
+
+        ufunc_sig = UfuncSignature(
+            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
+        )
+        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
+        ufunctors.append(
+            f"""
+template <typename scalar_t>
+struct {ufunctor_sig.name} {{
+  using opmath_t = at::opmath_type<scalar_t>;
+  {ufunctor_sig.decl_fields()}
+  {ufunctor_sig.inline_defn_ctor()}
+  __device__ {ufunctor_sig.decl_apply()} {{
+    return {ufunc_sig.call(apply_ctx)};
+  }}
+}};
+"""
+        )
+
+    return ufunctor_sigs, "\n".join(ufunctors)
+
+
+@dataclass(frozen=True)
+class BinaryScalarSpecializationConfig:
+    scalar_idx: int
+    ctor_tensor: str
+    ufunc_key: UfuncKey
+
+
+BinaryScalarSpecializationConfigs = [
+    BinaryScalarSpecializationConfig(
+        scalar_idx=0,
+        ctor_tensor="self",
+        ufunc_key=UfuncKey.CUDAFunctorOnOther,
+    ),
+    BinaryScalarSpecializationConfig(
+        scalar_idx=1,
+        ctor_tensor="other",
+        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
+    ),
+]
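+# i.e. if the first input (self) turns out to be a CPU scalar, it is folded into
+# the functor's constructor and the kernel loops over `other` (CUDAFunctorOnOther),
+# and symmetrically for `other` / CUDAFunctorOnSelf.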
+
+
+def compute_ufunc_cuda_dtype_body(
+    g: NativeFunctionsGroup,
+    dtype: ScalarType,
+    inner_loops: dict[UfuncKey, UfunctorSignature],
+    parent_ctx: Sequence[Binding],
+) -> str:
+    body = "using opmath_t = at::opmath_type;"
+    body += "if (false) {}\n"  # for ease of codegen
+    for config in BinaryScalarSpecializationConfigs:
+        if config.ufunc_key not in inner_loops:
+            continue
+        ufunctor_sig = inner_loops[config.ufunc_key]
+        scalar_idx = config.scalar_idx + 1
+        # Make a copy and at the same time widen the type (not permissible
+        # without copy; we don't want to mutate the input argument anyway)
+        ctx: list[Expr | Binding] = list(parent_ctx)
+        ctx.append(
+            Expr(
+                expr=f"iter.scalar_value({scalar_idx})",
+                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
+            )
+        )
+        ufunctor_ctor_exprs_str = ", ".join(
+            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
+        )
+
+        # NB: ufunctor must be allocated before iter.remove_operand is called,
+        # as it relies on iter
+        body += f"""\
+else if (iter.is_cpu_scalar({scalar_idx})) {{
+  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
+  iter.remove_operand({scalar_idx});
+  gpu_kernel(iter, ufunctor);
+}}"""
+
+    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
+    ufunctor_ctor_exprs_str = ", ".join(
+        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
+    )
+    body += f"""
+else {{
+  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
+}}
+    """
+    return body
+
+
+@with_native_function
+def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
+    # First, build the functors, indexing them by dtype
+    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
+
+    # Next, build the conditionals
+    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
+    dtype_cases = []
+    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
+        dtype_cases.append(
+            f"""
+AT_DISPATCH_CASE(at::ScalarType::{dtype},
+  [&]() {{
+    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
+  }}
+)
+"""
+        )
+
+    dtype_cases_str = "\n".join(dtype_cases)
+
+    stub_sig = StubSignature(g)
+
+    return f"""
+{ufunctors}
+
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()}
+
+{stub_sig.kernel_defn()} {{
+  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
+    {dtype_cases_str}
+  );
+}}
+REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name})
+
+{sig.defn()} {{
+  {stub_sig.direct_call(sig.arguments())};
+}}
+"""
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                                   CPU STUFF
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+@dataclass(frozen=True)
+class StubSignature:
+    g: NativeFunctionsGroup
+
+    @property
+    def name(self) -> str:
+        return f"{str(self.g.functional.func.name.name)}_stub"
+
+    @property
+    def kernel_name(self) -> str:
+        return f"{str(self.g.functional.func.name.name)}_kernel"
+
+    @property
+    def type_name(self) -> str:
+        return f"{str(self.g.functional.func.name.name)}_fn"
+
+    def arguments(self) -> list[Binding]:
+        return ufunc.stub_arguments(self.g)
+
+    def type(self) -> str:
+        cpp_args = self.arguments()
+        return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
+
+    def dispatch_decl(self) -> str:
+        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
+
+    def dispatch_defn(self) -> str:
+        return f"DEFINE_DISPATCH({self.name})"
+
+    def kernel_defn(self) -> str:
+        return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
+
+    def type_defn(self) -> str:
+        return f"using {self.type_name} = {self.type()}"
+
+    # must be called from context where this is TensorIteratorBase*
+    def call(self, ctx: Sequence[Binding]) -> str:
+        return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+    # used in CUDA to skip the unnecessary dynamic dispatch
+    def direct_call(self, ctx: Sequence[Binding]) -> str:
+        return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
+
+
+@with_native_function
+def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
+    stub_sig = StubSignature(g)
+    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
+
+    return f"""
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()}
+{stub_sig.dispatch_defn()};
+
+{sig.defn()} {{
+  {stub_sig.call(sig.arguments())};
+}}
+"""
+
+
+def compute_ufunc_cpu_dtype_body(
+    g: NativeFunctionsGroup,
+    dtype: ScalarType,
+    inner_loops: dict[UfuncKey, UfuncSignature],
+    parent_ctx: Sequence[Binding],
+) -> str:
+    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
+    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
+    scalar_loop = inner_loops[UfuncKey.CPUScalar]
+    vec_loop = None
+    if UfuncKey.CPUVector in inner_loops:
+        vec_loop = inner_loops[UfuncKey.CPUVector]
+
+    # NB: We DON'T use translate here, because translate is
+    # incapable of CSE'ing the scalar accesses in case it is also
+    # used by Vectorized; also, the unpacking here is very simple
+    # and only affects Scalar; everything else is implicitly captured
+    # by the lambda
+
+    # Setup scalar in scope
+    body = []
+    ctx = []
+    for b in parent_ctx:
+        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
+            BaseTy.Scalar
+        ):
+            continue
+        body.append(f"auto _s_{b.name} = {b.name}.to();")
+        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
+    if vec_loop is not None:
+        for b in parent_ctx:
+            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
+                BaseTy.Scalar
+            ):
+                continue
+            body.append(
+                f"auto _v_{b.name} = at::vec::Vectorized(_s_{b.name});"
+            )
+            ctx.append(
+                Expr(
+                    f"_v_{b.name}",
+                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
+                )
+            )
+
+    # Setup lambda signature
+    # NB: simplified version of ufunctor_arguments
+    scalar_bindings = []
+    vec_bindings = []
+    for a in g.functional.func.arguments.flat_non_out:
+        if not a.type.is_tensor_like():
+            continue
+        assert a.type == BaseType(BaseTy.Tensor)
+        scalar_bindings.append(
+            Binding(
+                name=a.name,
+                nctype=NamedCType(a.name, BaseCType(scalar_t)),
+                argument=a,
+            )
+        )
+        if vec_loop is not None:
+            vec_bindings.append(
+                Binding(
+                    name=a.name,
+                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
+                    argument=a,
+                )
+            )
+
+    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
+        r: list[Expr | Binding] = []
+        r.extend(ctx)
+        r.extend(b)
+        return r
+
+    body_str = "\n".join(body)
+    if vec_loop is not None:
+        return f"""
+{body_str}
+cpu_kernel_vec(iter,
+  [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
+  [=]({", ".join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
+);
+"""
+    else:
+        return f"""
+{body_str}
+cpu_kernel(iter,
+  [=]({", ".join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
+);
+"""
+
+
+@with_native_function
+def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
+    stub_sig = StubSignature(g)
+
+    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
+    loops = g.out.ufunc_inner_loop
+    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
+    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
+        lks = []
+        # ORDER MATTERS: this specifies overriding precedence
+        if k in loops:  # should happen rarely
+            lks.append(k)
+        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
+            lks.append(UfuncKey.ScalarOnly)
+        if UfuncKey.Generic in loops:
+            lks.append(UfuncKey.Generic)
+        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
+        for lk in lks:
+            for dtype in loops[lk].supported_dtypes:
+                compute_t: CType
+                if k is UfuncKey.CPUScalar:
+                    compute_t = BaseCType(scalar_t)
+                elif k is UfuncKey.CPUVector:
+                    compute_t = VectorizedCType(BaseCType(scalar_t))
+                else:
+                    raise AssertionError
+                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
+                if k not in inner_ufunc_sigs:
+                    inner_ufunc_sigs[k] = UfuncSignature(
+                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
+                    )
+
+    # Build the conditionals
+    dtype_cases = []
+    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
+        dtype_cases.append(
+            f"""
+AT_DISPATCH_CASE(at::ScalarType::{dtype},
+  [&]() {{
+    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
+  }}
+)
+"""
+        )
+
+    dtype_cases_str = "\n".join(dtype_cases)
+    return f"""
+namespace {{
+
+{stub_sig.kernel_defn()} {{
+  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
+    {dtype_cases_str}
+  );
+}}
+
+}} // anonymous namespace
+
+{stub_sig.type_defn()};
+{stub_sig.dispatch_decl()}
+REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name})
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/gen.py b/phivenv/Lib/site-packages/torchgen/gen.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebf41b472c272b701a20b76c881d433ca6a5bca4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen.py
@@ -0,0 +1,3003 @@
+from __future__ import annotations
+
+import argparse
+import functools
+import json
+import keyword
+import os
+from collections import defaultdict, namedtuple, OrderedDict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
+from typing_extensions import assert_never
+
+import yaml
+
+import torchgen.api.dispatcher as dispatcher
+import torchgen.api.meta as meta
+import torchgen.api.native as native
+import torchgen.api.structured as structured
+import torchgen.dest as dest
+from torchgen.api import cpp
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    Binding,
+    CppSignature,
+    CppSignatureGroup,
+    DispatcherSignature,
+    NamedCType,
+    NativeSignature,
+    SpecialArgName,
+)
+from torchgen.context import (
+    method_with_native_function,
+    native_function_manager,
+    with_native_function,
+    with_native_function_and_indices,
+)
+from torchgen.gen_aoti_c_shim import (
+    gen_aoti_c_shim_files,
+    gen_static_dispatch_backend_call_signature,
+)
+from torchgen.gen_functionalization_type import (
+    gen_functionalization_definition,
+    gen_functionalization_registration,
+    gen_functionalization_view_inverse_declaration,
+    GenCompositeViewCopyKernel,
+)
+from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BackendMetadata,
+    BaseOperatorName,
+    DEFAULT_KERNEL_NAMESPACE,
+    dispatch_device_map,
+    DispatchKey,
+    FRAGMENT_NAMESPACES,
+    FunctionSchema,
+    is_cuda_dispatch_key,
+    is_generic_dispatch_key,
+    is_ufunc_dispatch_key,
+    is_xpu_dispatch_key,
+    Location,
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+    OperatorName,
+    OptionalType,
+    SchemaKind,
+    SelfArgument,
+    STRUCTURED_DISPATCH_KEYS,
+    TensorOptionsArguments,
+    Type,
+    Variant,
+    ViewSchemaKind,
+)
+from torchgen.native_function_generation import (
+    add_generated_native_functions,
+    gen_composite_functional_kernel,
+    gen_composite_out_kernel,
+    pre_group_native_functions,
+)
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import (
+    concatMap,
+    context,
+    FileManager,
+    make_file_manager,
+    mapMaybe,
+    NamespaceHelper,
+    Target,
+)
+from torchgen.yaml_utils import YamlDumper, YamlLoader
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+T = TypeVar("T")
+
+# Welcome to the ATen code generator v2!  The ATen code generator is
+# responsible for parsing native_functions.yaml and then generating
+# various generated files (e.g., TypeDefault.cpp) based on the operators
+# defined in this file.  This means that the code generator knows how to
+# parse function schema, and then translate this into various C++ types
+# and boilerplate code.
+#
+# Some things to know about this file when you modify it:
+#
+# - This file has STRICT mypy typechecking.  Typecheck it with
+#   `mypy --config mypy-strict.ini` in the root source directory
+#
+# - Most of the heavy lifting lives in external modules:
+#   - 'model' has the data model for native_functions.yaml.  The classes
+#     in that file represent what you see when you look at
+#     a native_functions.yaml
+#   - 'api' has conversions for how to translate JIT schema into
+#     the various C++ APIs that the codegen interacts with.  There
+#     are in fact THREE different C++ APIs: the public C++ API,
+#     the dispatcher API, and the legacy dispatcher API.  See each
+#     of these respective files for more information
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                         HELPER FUNCTIONS
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+# A custom loader for YAML to let us also keep track of line numbers
+# of each entry in the YAML file
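+# (each mapping it constructs gains a "__line__" key, which is consumed below to
+# build Location objects for error messages)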
+class LineLoader(YamlLoader):
+    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
+        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
+        # Add 1 so line numbering starts at 1
+        mapping["__line__"] = node.start_mark.line + 1
+        return mapping
+
+
+# Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
+ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
+
+
+_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
+_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
+
+
+def file_manager_from_dispatch_key(
+    dispatch_key: DispatchKey,
+    device_fms: dict[str, FileManager],
+    default_fm: FileManager,
+) -> FileManager:
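+    # Pick the device-specific FileManager whose dispatch_device_map predicate
+    # matches this dispatch key; otherwise fall back to the default FileManager.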
+    fm = device_fms.get(
+        next(
+            (
+                device
+                for check, device in dispatch_device_map.items()
+                if check(dispatch_key)
+            ),
+            "",
+        ),
+        default_fm,
+    )
+    return fm
+
+
+def parse_native_yaml_struct(
+    es: object,
+    valid_tags: set[str],
+    ignore_keys: set[DispatchKey] | None = None,
+    path: str = "",
+    skip_native_fns_gen: bool = False,
+) -> ParsedYaml:
+    assert isinstance(es, list)
+    rs: list[NativeFunction] = []
+    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
+    for e in es:
+        assert isinstance(e, dict), f"expected to be dict: {e}"
+        assert isinstance(e.get("__line__"), int), e
+        loc = Location(path, e["__line__"])
+        funcs = e.get("func")
+        assert funcs is not None, f"missed 'func' in {e}"
+        with context(lambda: f"in {loc}:\n  {funcs}"):
+            func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
+            rs.append(func)
+            BackendIndex.grow_index(bs, m)
+    error_check_native_functions(rs)
+    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
+    indices: dict[DispatchKey, BackendIndex] = defaultdict(
+        lambda: BackendIndex(
+            dispatch_key=DispatchKey.Undefined,
+            use_out_as_primary=True,
+            external=False,
+            device_guard=False,
+            # I'm actually not sure about this; undefined could be hit on
+            # empty TensorList, hypothetically that could have sizes in it
+            index={},
+        )
+    )
+    if not skip_native_fns_gen:
+        add_generated_native_functions(rs, bs)
+    for k, v in bs.items():
+        # All structured in-tree operators are implemented in terms of their out operator.
+        indices[k] = BackendIndex(
+            dispatch_key=k,
+            use_out_as_primary=True,
+            external=False,
+            # Only cuda-like devices in tree require device guards
+            device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
+            index=v,
+        )
+    return ParsedYaml(rs, indices)
+
+
+def parse_tags_yaml_struct(es: object, path: str = "") -> set[str]:
+    assert isinstance(es, list)
+    rs: set[str] = set()
+    for e in es:
+        assert isinstance(e.get("__line__"), int), e
+        loc = Location(path, e["__line__"])
+        tags = e.get("tag")
+        with context(lambda: f"in {loc}:\n  {tags}"):
+            e_i = e.copy()
+            name = e_i.pop("tag")
+            desc = e_i.pop("desc", "")
+            # ensure that each tag has a non-empty description
+            assert desc != ""
+            rs.add(name)
+    return rs
+
+
+@functools.cache
+def parse_tags_yaml(path: str) -> set[str]:
+    global _GLOBAL_PARSE_TAGS_YAML_CACHE
+    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
+        with open(path) as f:
+            es = yaml.load(f, Loader=LineLoader)
+            _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
+
+    return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
+
+
+def parse_native_yaml(
+    path: str,
+    tags_yaml_path: str,
+    ignore_keys: set[DispatchKey] | None = None,
+    *,
+    skip_native_fns_gen: bool = False,
+    loaded_yaml: object | None = None,
+) -> ParsedYaml:
+    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
+    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
+        valid_tags = parse_tags_yaml(tags_yaml_path)
+
+        # if a loaded yaml is provided, use that instead of reading from path
+        if loaded_yaml is None:
+            with open(path) as f:
+                es = yaml.load(f, Loader=LineLoader)
+        else:
+            es = loaded_yaml
+
+        _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
+            es,
+            valid_tags,
+            ignore_keys,
+            path=path,
+            skip_native_fns_gen=skip_native_fns_gen,
+        )
+
+    return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
+
+
+# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
+# Assertions here are meant to be performed across NativeFunctions.
+def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
+    func_map: dict[OperatorName, NativeFunction] = {}
+    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
+    for f in funcs:
+        func_map[f.func.name] = f
+        base_func_map[f.func.name.name].append(f)
+    for f in funcs:
+        if f.structured_delegate is not None:
+            delegate_func = func_map.get(f.structured_delegate)
+            assert delegate_func is not None, (
+                f"{f.func.name} is marked as a structured_delegate pointing to "
+                f"{f.structured_delegate}, but {f.structured_delegate} is missing."
+            )
+            assert delegate_func.structured, (
+                f"{f.func.name} is marked as a structured_delegate pointing to "
+                f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
+                f"Consider adding 'structured=True' to the delegated operator"
+            )
+
+        # Check for reserved Python keywords
+        PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist)
+        # List of pre-existing operators that are known to have reserved keywords
+        # Exclusion list is used to suppress the assertion for these operators
+        EXCLUSION_LIST = {
+            ("_has_compatible_shallow_copy_type", "from"),
+            ("random_.from", "from"),
+            ("uniform_", "from"),
+        }
+
+        for arg in f.func.arguments.flat_all:
+            if arg.name in PYTHON_RESERVED_KEYWORDS:
+                if (str(f.func.name), arg.name) not in EXCLUSION_LIST:
+                    raise AssertionError(
+                        f"Argument name '{arg.name}' in function '{f.func.name}' is a reserved Python keyword."
+                    )
+        # See Note [resize_ in Functionalization]
+        # resize_() is technically an inplace view op (and therefore needs the tag),
+        # but it would be overkill to add a true "view" variant of resize.
+        # Instead, resize_() gets special treatment in functionalization,
+        # and we have a resize() op that is non-aliasing + functional.
+        if (
+            "inplace_view" in f.tags
+            and str(f.func.name) != "resize_"
+            and str(f.func.name) != "resize_as_"
+            and str(f.func.name.name) != "set_"
+        ):
+            base_name = f.func.name.name
+            assert base_name.inplace, (
+                f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
+                "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
+            )
+            out_of_place_base_name = BaseOperatorName(
+                base_name.base, False, base_name.dunder_method
+            )
+            assert len(base_func_map[out_of_place_base_name]) > 0, (
+                f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
+                f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one. "
+            )
+
+
+def cpp_string(s: str) -> str:
+    """Convert a python string into a c++ string literal"""
+    s = s.replace("\\", "\\\\")
+    s = s.replace('"', '\\"')
+    s = s.replace("\a", "\\a")
+    s = s.replace("\b", "\\b")
+    s = s.replace("\f", "\\f")
+    s = s.replace("\n", "\\n")
+    s = s.replace("\v", "\\v")
+    s = s.replace("\t", "\\t")
+    return f'"{s}"'
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                        C++ CODE GENERATION
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+# Most functions in this section are curried: they consist of a function
+# that takes some parameters (e.g., what is to be generated) which itself
+# returns a function that actually maps NativeFunction to the code
+# to be generated.  This pattern makes it convenient to use map, concatMap
+# and similar functional combinators.
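+#
+# For example, ComputeFunction below is constructed once and then mapped over
+# every NativeFunction (as done later in gen_aggregated_headers):
+#
+#   decls = list(mapMaybe(ComputeFunction(), native_functions))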
+
+
+def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
+    if len(backends) == 0:
+        return []
+    else:
+        return [backend.dispatch_key for backend in backends] + [
+            DispatchKey.CompositeImplicitAutograd,
+            DispatchKey.CompositeImplicitAutogradNestedTensor,
+            DispatchKey.CompositeExplicitAutograd,
+            DispatchKey.CompositeExplicitAutogradNonFunctional,
+        ]
+
+
+def get_static_dispatch_backend(
+    f: NativeFunction, backend_index: BackendIndex
+) -> DispatchKey | None:
+    if f.structured_delegate is not None or backend_index.has_kernel(f):
+        # TODO: for ops with structured_delegate it should check the dispatch table of
+        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
+        # so we always dispatch to the `backend`, but this could be wrong when we
+        # migrate math/default_backend ops to use structured delegate.
+        return backend_index.dispatch_key
+    elif f.has_composite_explicit_autograd_kernel:
+        return DispatchKey.CompositeExplicitAutograd
+    elif f.has_composite_explicit_autograd_non_functional_kernel:
+        return DispatchKey.CompositeExplicitAutogradNonFunctional
+    elif f.has_composite_implicit_autograd_kernel:
+        return DispatchKey.CompositeImplicitAutograd
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return DispatchKey.CompositeImplicitAutogradNestedTensor
+    return None
+
+
+def static_dispatch_ops_header(
+    f: NativeFunction, backend_index: list[BackendIndex]
+) -> str | None:
+    if backend_index is None or f.manual_kernel_registration:
+        return None
+
+    output = []
+    for index in backend_index:
+        dispatch_key = get_static_dispatch_backend(f, index)
+        if dispatch_key is not None:
+            output.append(
+                f"#include "
+            )
+    return "\n".join(output)
+
+
+def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
+    return [
+        f"#include "
+        for dispatch_key in static_dispatch_keys(backends)
+    ]
+
+
+# Translates arguments of `sig` to CppSignature bindings.
+# Note that we special-case the `memory_format` argument; this case is not yet covered by
+# tools.codegen.api.translate(), since its application is limited to static dispatch.
+def translate_args(
+    sig: CppSignature | DispatcherSignature,
+    cpp_sig: CppSignature,
+) -> str:
+    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
+    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
+        output_bindings: list[Binding] = []
+        for binding in input_bindings:
+            if binding.name == "memory_format":
+                spl_mem_format_binding = Binding(
+                    nctype=NamedCType(
+                        SpecialArgName.possibly_redundant_memory_format,
+                        binding.nctype.type,
+                    ),
+                    name=binding.name,
+                    default=binding.default,
+                    argument=binding.argument,
+                )
+                output_bindings.append(spl_mem_format_binding)
+            else:
+                output_bindings.append(binding)
+        return output_bindings
+
+    src_bindings = list(sig.arguments())
+    goal_bindings = list(cpp_sig.arguments())
+    # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
+    # get memory_format bindings of dispatcher signature to have the same NCType as well
+    for arg in goal_bindings:
+        if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
+            src_bindings = add_spl_memory_format_binding(src_bindings)
+            break
+    exprs = translate(src_bindings, goal_bindings)
+    return ", ".join(a.expr for a in exprs)
+
+
+def generate_static_dispatch_backend_call(
+    sig: CppSignature | DispatcherSignature,
+    f: NativeFunction,
+    backend_index: BackendIndex,
+) -> str:
+    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
+    name = cpp_sig.name()
+    exprs = translate_args(sig, cpp_sig)
+    backend_metadata = backend_index.get_kernel(f)
+    kernel_ns = (
+        backend_metadata.cpp_namespace
+        if backend_metadata and backend_metadata.cpp_namespace
+        else DEFAULT_KERNEL_NAMESPACE
+    )
+    ns = kernel_ns.replace("::native", "")
+    return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"
+
+
+def generate_static_dispatch_fallback_call(
+    sig: CppSignature | DispatcherSignature,
+    f: NativeFunction,
+    backend_indices: list[BackendIndex],
+) -> str:
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    name = cpp_sig.name()
+    exprs = translate_args(sig, cpp_sig)
+    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
+    if f.has_composite_explicit_autograd_kernel:
+        return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
+    elif f.has_composite_explicit_autograd_non_functional_kernel:
+        return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
+    elif f.has_composite_implicit_autograd_kernel:
+        return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
+    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
+        return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
+    else:
+        return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
+{", ".join([str(index.dispatch_key) for index in backend_indices])} ");"""
+
+
+def static_dispatch(
+    sig: CppSignature | DispatcherSignature,
+    f: NativeFunction,
+    backend_indices: list[BackendIndex],
+) -> str:
+    """
+    For a given `NativeFunction`, find the corresponding backend and dispatch to it. If more than one
+    backend exists, fall back to static dispatch by determining the dispatch key from the inputs.
+    Arguments:
+        sig: A CppSignature or DispatcherSignature for this native function we want to use.
+        f: NativeFunction to generate static dispatch.
+        backend_indices: All available backends.
+    Return:
+        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
+    """
+    if len(backend_indices) == 0 or f.manual_kernel_registration:
+        return ""
+
+    keys = [
+        b
+        for b in backend_indices
+        if b.has_kernel(f)
+        or (
+            f.structured_delegate is not None
+            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
+        )
+    ]
+    if len(keys) == 1:
+        return generate_static_dispatch_backend_call(sig, f, keys[0])
+    elif len(keys) == 0:
+        return generate_static_dispatch_fallback_call(sig, f, backend_indices)
+
+    native_tensor_args = [
+        a.name
+        for a in sig.arguments()
+        if isinstance(a.argument, SelfArgument)
+        or isinstance(a.argument, Argument)
+        and a.argument.type.is_tensor_like()
+    ]
+    tensor_args = ", ".join(native_tensor_args)
+    tensor_opts = f.func.arguments.tensor_options
+
+    stmts = []
+    subexprs: list[str] = []
+    if tensor_opts is not None:
+        subexprs.append(
+            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
+        )
+    if tensor_args != "":
+        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
+    stmts.append(f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""")
+    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")
+
+    dispatch_code = []
+    for index in keys:
+        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
+        dispatch_code.append(
+            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
+        )
+
+    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
+    connector = "\n\t\t"
+
+    return f"""
+    {connector.join(stmts)}
+    switch (_dk) {{
+        {connector.join(dispatch_code)}
+        default:
+            {fallback}
+    }}
+    """
+
+
+# Generates RegisterSchema.cpp.  Depending on the selector, either
+# all schemas are registered, or only some are (in the case of
+# selective build)
+@dataclass(frozen=True)
+class RegisterSchema:
+    selector: SelectiveBuilder
+    known_tags: dict[str, int] = field(default_factory=dict)
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        if not self.selector.is_native_function_selected(f):
+            return None
+        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
+        if tags == "{}":
+            return f"m.def({cpp_string(str(f.func))}, {{}});\n"
+        maybe_tags = ""
+        if tags not in self.known_tags:
+            idx = len(self.known_tags)
+            self.known_tags[tags] = idx
+            maybe_tags = f"const std::vector tags_{idx} = {tags};\n"
+        return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"
+
+
+# Generates Operators.h and Operators.cpp.
+# These provide macros that, given an operator and overload name, allow users
+# to access an "un-overloaded" function version of the operator. This
+# is useful for extension writers who (1) want to decltype the operator
+# and (2) don't want to worry about method-only operators.
+@dataclass(frozen=True)
+class ComputeOperators:
+    target: Literal[Target.DECLARATION, Target.DEFINITION]
+    static_dispatch_backend_indices: list[BackendIndex]
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str:
+        sig = DispatcherSignature.from_schema(f.func)
+        name = f.func.name.unambiguous_name()
+
+        if self.target is Target.DECLARATION:
+            # Note [The ATen Operators API]
+            # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
+            # metadata about each operator + entry points into the Dispatcher.
+            # The C++ function, method, and redispatch API's are all implemented as wrappers
+            # into various bits of the structs defined here.
+            #
+            # Important characteristics about the Operators API:
+            # (1) It follows the Dispatcher API.
+            #     This is kind of necessary to avoid overhead.
+            #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
+            #     would need to wrap their arguments into TensorOptions only to unwrap them again.
+            # (2) Overload names are disambiguated.
+            #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
+            #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
+            # (3) No argument defaulting is allowed.
+            #     This is more of an implementation detail to avoid #include cycles,
+            #     since TensorBody.h (which defines the Tensor class) needs to include this file.
+            # (4) manual_cpp_bindings and faithful names are not included in the API.
+            #     This applies to stuff like __dispatch__is_complex(), and add_outf().
+            #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
+            #     They're implemented as wrappers in Functions.h that call into the actual operators
+            #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
+            #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
+            return f"""
+struct TORCH_API {name} {{
+  using schema = {sig.type()};
+  using ptr_schema = schema*;
+  // See Note [static constexpr char* members for windows NVCC]
+  static constexpr const char* name = "aten::{f.func.name.name}";
+  static constexpr const char* overload_name = "{f.func.name.overload_name}";
+  static constexpr const char* schema_str = {cpp_string(str(f.func))};
+  static {sig.defn(name="call", is_redispatching_fn=False)};
+  static {sig.defn(name="redispatch", is_redispatching_fn=True)};
+}};"""
+
+        elif self.target is Target.DEFINITION:
+            defns = f"""
+// aten::{f.func}
+static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
+  return c10::Dispatcher::singleton()
+      .findSchemaOrThrow({name}::name, {name}::overload_name)
+      .typed<{name}::schema>();
+}}
+"""
+            for is_redispatching_fn in [False, True]:
+                if is_redispatching_fn:
+                    dispatcher_exprs_str = ", ".join(
+                        ["dispatchKeySet"] + [a.name for a in sig.arguments()]
+                    )
+                    method_base = "redispatch"
+                else:
+                    dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
+                    method_base = "call"
+
+                dispatcher_call = method_base
+                method_name = f"{name}::{method_base}"
+
+                fn_body = f"""
+    static auto op = create_{name}_typed_handle();
+    return op.{dispatcher_call}({dispatcher_exprs_str});"""
+
+                if (
+                    not is_redispatching_fn
+                    and len(self.static_dispatch_backend_indices) > 0
+                ):
+                    # call() should go through static dispatch
+                    fn_body = static_dispatch(
+                        sig, f, backend_indices=self.static_dispatch_backend_indices
+                    )
+                defns += f"""
+// aten::{f.func}
+{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
+    {fn_body}
+}}
+"""
+            return defns
+        else:
+            assert_never(self.target)
+
+
+# Generates Functions.h, which provides the functional public C++ API,
+# and the scaffolding to call into the dispatcher from these functions.
+@dataclass(frozen=True)
+class ComputeFunction:
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        sig_group = CppSignatureGroup.from_native_function(
+            f, method=False, fallback_binding=f.manual_cpp_binding
+        )
+        has_symint = f.func.has_symint()
+
+        result = ""
+        for sig in sig_group.signatures():
+            # See Note [The ATen Operators API]
+            target_sig = DispatcherSignature.from_schema(f.func)
+            exprs = translate(sig.arguments(), target_sig.arguments())
+            exprs_str = ", ".join([e.expr for e in exprs])
+
+            if sig.symint:
+                intlike_t = "c10::SymInt"
+            else:
+                intlike_t = "int64_t"
+
+            if Variant.function in f.variants:
+                result += f"""
+// aten::{f.func}
+inline {sig.decl()} {{
+    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+}}"""
+
+            # The template function can be used from template situations
+            # where you want to switch between the symint or not version
+            # depending on a template argument
+            #
+            # NB: we ALWAYS generate this even for methods.  But we put it in
+            # this header so it can take advantage of per-op headers
+            if has_symint:
+                result += f"""
+namespace symint {{
+  template <typename T, typename = std::enable_if_t<std::is_same<T, {intlike_t}>::value>>
+  {sig.decl(suppress_symint_suffix=True)} {{
+    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+  }}
+}}
+"""
+        return result
+
+
+# Generates TensorBody.h. This file provides the object-oriented (method-based)
+# public C++ API, and the scaffolding to call into the dispatcher from these functions.
+@dataclass(frozen=True)
+class ComputeTensorMethod:
+    target: Literal[Target.DECLARATION, Target.DEFINITION]
+    static_dispatch_backend_indices: list[BackendIndex]
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        if Variant.method not in f.variants:
+            return None
+
+        assert not f.func.is_out_fn()
+        assert f.func.arguments.self_arg is not None
+
+        sig_group = CppSignatureGroup.from_native_function(
+            f, method=True, fallback_binding=f.manual_cpp_binding
+        )
+
+        if self.target is Target.DECLARATION:
+            result = ""
+            for sig in sig_group.signatures():
+                result += f"{sig.decl()} const;\n"
+            return result
+
+        if self.target is not Target.DEFINITION:
+            assert_never(self.target)
+
+        result = ""
+
+        for sig in sig_group.signatures():
+            target_sig = DispatcherSignature.from_schema(f.func)
+            exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
+            exprs_str = ", ".join([e.expr for e in exprs])
+
+            result += f"""
+// aten::{f.func}
+inline {sig.defn(prefix="Tensor::")} const {{
+    return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
+}}
+"""
+
+        return result
+
+
+# Generates RedispatchFunctions.h.
+# This is similar to the C++ API defined in Functions.h, but provides access
+# to the dispatcher's redispatch API.
+@dataclass(frozen=True)
+class ComputeRedispatchFunction:
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        # We unconditionally generate function variants of the redispatch API.
+        # This is mainly because we can namespace functions separately, but not methods.
+        sig_group = CppSignatureGroup.from_native_function(
+            f, method=False, fallback_binding=f.manual_cpp_binding
+        )
+
+        result = ""
+        for sig in sig_group.signatures():
+            target_sig = DispatcherSignature.from_schema(f.func)
+            exprs = translate(sig.arguments(), target_sig.arguments())
+            exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
+
+            result += f"""
+// aten::{f.func}
+inline {sig.decl(is_redispatching_fn=True)} {{
+    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
+}}
+"""
+
+        return result
+
+
+# Generates ATenOpList.cpp, a runtime accessible list of all aten
+# operators.
+# TODO: This was historically used to help some JIT interop code
+# figure out whether or not to treat aten namespace'd operators
+# one way or another; we should reevaluate whether this is actually needed.
+@with_native_function
+def compute_aten_op(f: NativeFunction) -> str:
+    return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
+
+
+# Generates MetaFunctions.h
+def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
+    if not g.structured:
+        return None
+    with native_function_manager(g.out):
+        name = meta.name(g)
+        args = structured.meta_arguments(g)
+        args_str = ", ".join(a.decl() for a in args)
+        parent_class = g.out.structured_inherits
+        if parent_class is None:
+            parent_class = "at::impl::MetaBase"
+        meta_return = "void"
+        precomputed = g.out.precomputed if g.structured else None
+
+        if precomputed:
+            # Generate the template declaration with one bool parameter for each
+            # precomputed element. Each parameter is true if the corresponding (in
+            # terms of position) precomputed element has been set.
+            precomputed_values = [*precomputed.replace.values(), precomputed.add]
+            precomputed_elements = [
+                elem for replace_list in precomputed_values for elem in replace_list
+            ]
+            precomputed_template_parameters = [
+                elem.name.upper() for elem in precomputed_elements
+            ]
+            precomputed_template_params_str = ", ".join(
+                f"bool {param} = false" for param in precomputed_template_parameters
+            )
+            precompute_template_decl = f"template <{precomputed_template_params_str}>"
+
+            # Generate a string containing declarations of all precomputed elements.
+            precomputed_elements_with_cpp_types = [
+                structured.argument_type(elem, binds=elem.name)
+                for elem in precomputed_elements
+            ]
+
+            precomputed_elements_decl = ";\n".join(
+                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
+                for elem in precomputed_elements_with_cpp_types
+            )
+
+            # Generate "setter" methods for each precomputed element. Each method will return
+            # a new instance of precompute_out with the template parameter that corresponds to
+            # the member set by the method to true (to indicate that it has been set).
+            setter_methods = []
+            for i, elem in enumerate(precomputed_elements):
+                # Generate the signature. The return type will be the same
+                # as the type of `this` but with the template parameter
+                # corresponding to the element set by this method set to true.
+                # The assert generated below will ensure that this template
+                # parameter is false on the type of `this`.
+                return_ty_templates = ", ".join(
+                    precomputed_template_parameters[:i]
+                    + ["true"]
+                    + precomputed_template_parameters[i + 1 :]
+                )
+                return_ty = f"precompute_out<{return_ty_templates}>"
+                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
+                    strip_ref=True
+                )
+                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"
+
+                # Generate an assert which checks that the
+                # template parameter corresponding to the precomputed
+                # element that is set by this method is false on the
+                # class corresponding to the object that `this` points to.
+                # This ensures that each element can be set only once.
+                assert_msg = f'"{elem.name} already set"'
+                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"
+
+                # Generate the new object construction block. All state
+                # except the element that this method sets is copied from the
+                # object that `this` points to. The value for the element that
+                # the method sets is taken from a method parameter.
+                construction_stmts = []
+                construction_stmts.append(f"{return_ty} ret;")
+
+                for j, elem in enumerate(precomputed_elements):
+                    if i == j:
+                        construction_stmts.append(f"ret.{elem.name} = value;")
+                    else:
+                        construction_stmts.append(
+                            f"ret.{elem.name} = this->{elem.name};"
+                        )
+
+                construction_stmts.append("return ret;")
+                construction_block = "\n".join(construction_stmts)
+
+                setter_methods.append(
+                    f"""
+                    {signature} {{
+                        {assert_stmt}
+                        {construction_block}
+                    }}
+                """
+                )
+            setter_methods_decl = "\n".join(setter_methods)
+
+            # Meta should return an instance of the struct containing the precomputed elements.
+            meta_return_template_params = ", ".join(
+                ["true"] * len(precomputed_template_parameters)
+            )
+            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
+            # type (which has a variable number of template parameters).
+            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
+            meta_return = "meta_return_ty"
+            precomputed_decl = f"""
+                {precompute_template_decl}
+                struct TORCH_API precompute_out {{
+                    {setter_methods_decl}
+                    {precomputed_elements_decl};
+            }};"""
+        else:
+            meta_return_typedef = ""
+            precomputed_decl = ""
+
+        return f"""\
+struct TORCH_API structured_{name} : public {parent_class} {{
+    {precomputed_decl}
+    {meta_return_typedef}
+    {meta_return} meta({args_str});
+}};
+"""
+
+
+def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
+    name = str(f.func.name.name)
+    if name.endswith("_like") or name.startswith("new_"):
+        return False
+    if f.func.arguments.tensor_options is None:
+        return False
+    return selector.is_native_function_selected(f)
+
+
+# Generates RegisterBackendSelect.cpp, a series of kernels which provide
+# specialized computation of dispatch key for operator signatures which cannot
+# be easily done automatically using templating.
+@dataclass(frozen=True)
+class ComputeBackendSelect:
+    target: Literal[Target.DEFINITION, Target.REGISTRATION]
+
+    # Selector object to determine which operators to generate
+    # registration code for.
+    selector: SelectiveBuilder
+
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        if not needs_backend_select(f, self.selector):
+            return None
+
+        name = native.name(f.func)
+        # BackendSelect can go to Meta, so it must preserve symints
+        native_sig = NativeSignature(f.func, symint=True)
+
+        native_tensor_args = [
+            a
+            for a in native_sig.arguments()
+            if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
+        ]
+
+        dispatcher_sig = DispatcherSignature.from_schema(f.func)
+
+        sig: NativeSignature | DispatcherSignature
+        sig = dispatcher_sig
+        dispatcher_exprs = dispatcher_sig.exprs()
+        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
+
+        if self.target is Target.DEFINITION:
+            # I don't think there's actually a good reason to generate
+            # these two cases differently.
+            # The first case could probably be improved, though: it calls computeDispatchKeySet(),
+            # which looks at TLS dispatch keys; there should not be any by the time we reach backend select.
+            if native_tensor_args:
+                assert f.func.arguments.has_tensor_arg()
+                tensor_args = ", ".join(a.name for a in native_tensor_args)
+                compute_dk = f"""\
+DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
+DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
+DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
+            else:
+                assert not f.func.arguments.has_tensor_arg()
+                compute_dk = (
+                    f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
+                )
+            return f"""\
+// aten::{f.func}
+C10_ALWAYS_INLINE
+{sig.defn(name)} {{
+  {compute_dk}
+  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
+      _dk, {", ".join(a.expr for a in dispatcher_exprs)});
+}}
+"""
+        elif self.target is Target.REGISTRATION:
+            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
+        else:
+            assert_never(self.target)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                       YAML CODE GENERATION
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def format_yaml(data: object) -> str:
+    # Ignore alias in Dumper
+    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]
+
+    # Support serializing OrderedDict
+    def dict_representer(dumper: Any, data: Any) -> Any:
+        return dumper.represent_dict(data.items())
+
+    YamlDumper.add_representer(OrderedDict, dict_representer)  # type: ignore[no-untyped-call]
+    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
+    # width=1e9 turns off optional line breaks and improves
+    # the portability of the outputted yaml.
+    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
+
+
+# For some reason, some defaults we write to YAML are written as native
+# YAML objects rather than uniformly as strings.  This
+# function detects those cases and converts them into native Python
+# objects.
+def pythonify_default(s: str) -> object:
+    if s == "true":
+        return True
+    elif s == "false":
+        return False
+
+    try:
+        return int(s)
+    except ValueError:
+        try:
+            return float(s)
+        except ValueError:
+            return s
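+
+# e.g. pythonify_default("true") -> True, pythonify_default("1") -> 1,
+# pythonify_default("0.5") -> 0.5, and anything else is returned unchanged.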
+
+
+# What is a dynamic type?  Over time, the semantic meaning of
+# dynamic type has degraded to meaninglessness (in the old days,
+# it captured dtype-ness of types, but that has gone away with
+# the removal of TH).  These days, it's mostly the same thing as
+# the C++ API argument type, except that Tensor and Tensor?
+# arguments simply present as Tensor.
+#
+# TODO: Get rid of dynamic_type, after getting tools/autograd
+# to use the new codegen framework
+def dynamic_type(t: Type) -> str:
+    if isinstance(t, OptionalType):
+        return dynamic_type(t.elem)
+    # Note we don't use t.is_tensor_like() here because it would
+    # also include Tensor[]
+    if str(t) == "Tensor":
+        return "at::Tensor"
+    # This is a legacy concept, so never report SymInt
+    return cpp.argumenttype_type(
+        t, mutable=False, binds="__placeholder__", symint=False
+    ).cpp_type()
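+
+# e.g. both `Tensor` and `Tensor?` report "at::Tensor", while `Tensor[]` and other
+# types fall through to the C++ argument type (with SymInt suppressed).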
+
+
+def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
+    # This is written out explicitly to ensure that Tensor and
+    # namespace are put into the list in the right order
+    method_of = ["Type"]
+    if Variant.method in variants:
+        method_of.append("Tensor")
+    if Variant.function in variants:
+        method_of.append("namespace")
+    return method_of
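+
+# e.g. an op with both Variant.method and Variant.function yields
+# ["Type", "Tensor", "namespace"].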
+
+
+def compute_returns_yaml(
+    f: NativeFunction,
+) -> tuple[list[dict[str, str]], dict[str, str]]:
+    # Note [name and field_name]
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # To understand name_to_field_name, we must first talk about this
+    # schema:
+    #
+    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
+    #
+    # There is something very odd about this schema: it is an out
+    # variant of the function (that is to say, it will convert into
+    # at::lstsq_out() in the C++ API), but the names of the output
+    # return arguments don't match the keyword argument names of
+    # the inputs.  It TURNS OUT that in this situation, the historical
+    # Declarations.yaml we want to output is this (abbreviated to
+    # only show relevant fields):
+    #
+    #   arguments:
+    #     ...
+    #   - field_name: solution
+    #     name: X
+    #   - field_name: QR
+    #     name: qr
+    #     ...
+    #
+    #   returns:
+    #   - field_name: solution
+    #     name: X
+    #   - field_name: QR
+    #     name: qr
+    #
+    # The name of the return fields is stored in 'field_name', and the
+    # name of the arguments is stored in 'name'.  So when we process
+    # arguments, we need a way to get at the corresponding return.  At
+    # the moment, this is most conveniently done by constructing a
+    # mapping from name (the argument concept) to field_name (the
+    # return concept) while processing return arguments, since we don't
+    # directly maintain this correspondence in the modeling of function
+    # schema itself.
+    #
+    # See also https://github.com/pytorch/pytorch/issues/43114
+    name_to_field_name: dict[str, str] = {}
+
+    # Compute the returns field of the YAML entry
+    names = cpp.return_names(f)
+    returns = []
+    for i, (r, name) in enumerate(zip(f.func.returns, names)):
+        ret = {
+            "dynamic_type": dynamic_type(r.type),
+            "name": name,
+            # legacy, report ints
+            "type": cpp.return_type(r, symint=False).cpp_type(),
+        }
+
+        if r.name:
+            # See Note [name and field_name]
+            ret["field_name"] = r.name
+            if f.func.is_out_fn():
+                name_to_field_name[f.func.arguments.out[i].name] = r.name
+
+        returns.append(ret)
+
+    return returns, name_to_field_name
+
+
+# arguments in yaml roughly correspond to the public C++ API
+def compute_cpp_argument_yaml(
+    cpp_a: Binding,
+    *,
+    schema_order: bool,
+    kwarg_only_set: set[str],
+    out_arg_set: set[str],
+    name_to_field_name: dict[str, str],
+) -> object:
+    if isinstance(cpp_a.argument, TensorOptionsArguments):
+        arg: dict[str, object] = {
+            "annotation": None,
+            "dynamic_type": "at::TensorOptions",
+            "is_nullable": False,
+            "name": cpp_a.name,
+            "type": cpp_a.type,
+            "kwarg_only": True,
+        }
+        if cpp_a.default is not None:
+            arg["default"] = cpp_a.default
+        return arg
+    elif isinstance(cpp_a.argument, SelfArgument):
+        raise AssertionError
+    elif isinstance(cpp_a.argument, Argument):
+        return compute_argument_yaml(
+            cpp_a.argument,
+            schema_order=schema_order,
+            kwarg_only_set=kwarg_only_set,
+            out_arg_set=out_arg_set,
+            name_to_field_name=name_to_field_name,
+        )
+
+
+def compute_argument_yaml(
+    a: Argument,
+    *,
+    schema_order: bool,
+    kwarg_only_set: set[str],
+    out_arg_set: set[str],
+    name_to_field_name: dict[str, str],
+) -> object:
+    arg: dict[str, object] = {
+        "annotation": str(a.annotation) if a.annotation else None,
+        "dynamic_type": dynamic_type(a.type),
+        "is_nullable": a.type.is_nullable(),
+        "name": a.name,
+        # legacy, report ints
+        "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
+    }
+    if a.default is not None:
+        arg["default"] = pythonify_default(
+            cpp.default_expr(a.default, a.type, symint=False)
+        )
+    if a.name in kwarg_only_set:
+        arg["kwarg_only"] = True
+    if a.name in out_arg_set:
+        arg["output"] = True
+        arg["allocate"] = True
+        # See Note [name and field_name]
+        if a.name in name_to_field_name:
+            arg["field_name"] = name_to_field_name[a.name]
+    # Historically, booleans don't get their size recorded, because it
+    # is already built into the cpp type (e.g., std::array)
+    l = a.type.is_list_like()
+    if l is not None and l.size is not None and str(l.elem) != "bool":
+        arg["size"] = l.size
+    return arg
+
+
+@with_native_function
+def compute_declaration_yaml(f: NativeFunction) -> object:
+    returns, name_to_field_name = compute_returns_yaml(f)
+
+    # These sets are used to conveniently test if an argument is a
+    # kwarg-only or out argument
+    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
+    out_arg_set = {a.name for a in f.func.arguments.out}
+
+    sig_group = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    cpp_args = sig_group.signature.arguments()
+    arguments = [
+        compute_cpp_argument_yaml(
+            cpp_a,
+            schema_order=False,
+            kwarg_only_set=kwarg_only_set,
+            out_arg_set=out_arg_set,
+            name_to_field_name=name_to_field_name,
+        )
+        for cpp_a in cpp_args
+    ]
+
+    schema_order_jit_arguments = list(f.func.schema_order_arguments())
+
+    schema_order_arguments = [
+        compute_argument_yaml(
+            a,
+            schema_order=True,
+            kwarg_only_set=kwarg_only_set,
+            out_arg_set=out_arg_set,
+            name_to_field_name=name_to_field_name,
+        )
+        for a in schema_order_jit_arguments
+    ]
+
+    cpp_schema_order_types = [
+        # NB: method here doesn't matter
+        r.type
+        for a in schema_order_jit_arguments
+        for r in cpp.argument(
+            a,
+            method=False,
+            cpp_no_default_args=set(),
+            faithful=False,
+            symint=False,
+            has_tensor_options=False,
+        )
+    ]
+
+    # legacy, report ints
+    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
+    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
+
+    is_factory_method = (
+        any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
+        and Variant.method not in f.variants
+    )
+
+    return OrderedDict(
+        [
+            ("name", cpp.name(f.func)),
+            ("operator_name", str(f.func.name.name)),
+            ("overload_name", str(f.func.name.overload_name)),
+            ("manual_kernel_registration", f.manual_kernel_registration),
+            (
+                "category_override",
+                f.category_override if f.category_override is not None else "",
+            ),
+            ("schema_string", f"aten::{f.func}"),
+            ("arguments", arguments),
+            ("schema_order_cpp_signature", schema_order_cpp_signature),
+            ("schema_order_arguments", schema_order_arguments),
+            ("method_of", compute_method_of_yaml(f.variants)),
+            ("mode", "native"),
+            ("python_module", "" if f.python_module is None else f.python_module),
+            ("returns", returns),
+            ("inplace", f.func.name.name.inplace),
+            ("is_factory_method", is_factory_method),
+            ("abstract", f.is_abstract),
+            ("device_guard", f.device_guard),
+            ("with_gil", False),
+            ("deprecated", False),
+            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
+        ]
+    )
+
+
+# See Note [Auto generated composite kernels]
+def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
+    return (f.structured or f.structured_delegate is not None) and (
+        f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
+    )
+
+
+@with_native_function_and_indices
+def compute_registration_declarations(
+    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
+) -> str:
+    name = dispatcher.name(f.func)
+    returns_type = dispatcher.returns_type(f.func.returns).cpp_type()
+    args = dispatcher.arguments(f.func)
+    args_str = ", ".join(a.no_default().decl() for a in args)
+    comment_data: dict[str, str] = {
+        "schema": f"aten::{f.func}",
+        # TODO: What exactly is the semantics of the 'dispatch' field?
+        "dispatch": str(
+            {k for k, v in backend_indices.items() if v.has_kernel(f)}
+            != {DispatchKey.CompositeImplicitAutograd}
+            and {k for k, v in backend_indices.items() if v.has_kernel(f)}
+            != {
+                DispatchKey.CompositeImplicitAutograd,
+                DispatchKey.CompositeImplicitAutogradNestedTensor,
+            }
+        ),
+        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
+    }
+    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
+"""
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                           RUN IT ALL
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def get_custom_build_selector(
+    provided_op_registration_allowlist: list[str] | None,
+    op_selection_yaml_path: str | None,
+) -> SelectiveBuilder:
+    assert not (
+        provided_op_registration_allowlist is not None
+        and op_selection_yaml_path is not None
+    ), (
+        "Both provided_op_registration_allowlist and "
+        + "op_selection_yaml_path can NOT be provided at the "
+        + "same time."
+    )
+
+    op_registration_allowlist: set[str] | None = None
+    if provided_op_registration_allowlist is not None:
+        op_registration_allowlist = set(provided_op_registration_allowlist)
+
+    if op_registration_allowlist is not None:
+        selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
+            op_registration_allowlist,
+            True,
+            False,
+        )
+    elif op_selection_yaml_path is not None:
+        selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
+    else:
+        selector = SelectiveBuilder.get_nop_selector()
+
+    return selector
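+
+# e.g. passing an allowlist like ["aten::add.Tensor"] (illustrative) yields a selector
+# restricted to those ops, while passing neither argument yields the no-op selector
+# that keeps everything.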
+
+
+def get_grouped_by_view_native_functions(
+    native_functions: Sequence[NativeFunction],
+) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
+    def maybe_create_view_group(
+        d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
+    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
+        funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
+        if ViewSchemaKind.aliasing in d:
+            view = d.pop(ViewSchemaKind.aliasing)
+            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
+            view_copy = d.pop(SchemaKind.functional, None)
+
+            funcs.append(
+                NativeFunctionsViewGroup(
+                    view=view,
+                    view_copy=view_copy,
+                    view_inplace=view_inplace,
+                )
+            )
+        # Take the remaining functions that weren't part of the view group
+        # and emit them separately
+        funcs.extend(d.values())
+        return funcs
+
+    grouped_by_views: dict[
+        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
+    ] = defaultdict(dict)
+    for f in native_functions:
+        schema = f.func.view_signature()
+        view_kind: ViewSchemaKind = f.view_schema_kind
+        # We need to group up ops relevant to the same "view", consisting of:
+        # view op (ViewSchemaKind.aliasing)
+        # view_inplace op (ViewSchemaKind.aliasing_inplace)
+        # view_copy op (SchemaKind.functional)
+        if view_kind == ViewSchemaKind.non_aliasing:
+            kind = f.func.kind()
+            assert kind not in grouped_by_views[schema]
+            grouped_by_views[schema][kind] = f
+        else:
+            assert view_kind not in grouped_by_views[schema], (
+                f"{view_kind} already in {grouped_by_views[schema].keys()}"
+            )
+            grouped_by_views[schema][view_kind] = f
+
+    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
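+
+# For example (illustrative), `transpose.int`, `transpose_` and `transpose_copy.int`
+# share a view signature and are grouped into a single NativeFunctionsViewGroup
+# (view / view_inplace / view_copy); non-aliasing ops pass through ungrouped.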
+
+
+def get_grouped_native_functions(
+    native_functions: Sequence[NativeFunction],
+) -> Sequence[NativeFunction | NativeFunctionsGroup]:
+    def flatten_pre_group(
+        d: dict[SchemaKind, NativeFunction],
+    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
+        r = NativeFunctionsGroup.from_dict(d)
+        if r is None:
+            # Invariant: any NativeFunctions that are code-generated
+            # should have been grouped into NativeFunctionsGroup objects
+            assert not any("generated" in f.tags for f in d.values())
+            return list(d.values())
+        else:
+            return [r]
+
+    # TODO: how come ValuesView isn't a Sequence lol
+    pre_grouped_native_functions = pre_group_native_functions(native_functions)
+    return list(
+        concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
+    )
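+
+# For example (illustrative), `add.Tensor`, `add_.Tensor` and `add.out` collapse into
+# one NativeFunctionsGroup; ops without such counterparts are returned individually.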
+
+
+def get_ns_grouped_kernels(
+    *,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    native_function_decl_gen: Callable[
+        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
+    ] = dest.compute_native_function_declaration,
+) -> dict[str, list[str]]:
+    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
+    for f in grouped_native_functions:
+        native_function_namespaces = set()
+        dispatch_keys = set()
+        for dispatch_key, backend_idx in backend_indices.items():
+            backend_metadata = backend_idx.get_kernel(f)
+            if backend_metadata:
+                namespace = backend_metadata.cpp_namespace
+                dispatch_keys.add(dispatch_key)
+                native_function_namespaces.add(namespace)
+            else:
+                namespace = DEFAULT_KERNEL_NAMESPACE
+            assert len(native_function_namespaces) <= 1, (
+                f"Codegen only supports one namespace per operator, got {native_function_namespaces} from {dispatch_keys}"
+            )
+            ns_grouped_kernels[namespace].extend(
+                native_function_decl_gen(f, backend_idx)
+            )
+    return ns_grouped_kernels
+
+
+def get_native_function_declarations_from_ns_grouped_kernels(
+    *,
+    ns_grouped_kernels: dict[str, list[str]],
+) -> list[str]:
+    declarations: list[str] = []
+    newline = "\n"
+    for namespace, kernels in ns_grouped_kernels.items():
+        ns_helper = NamespaceHelper(
+            namespace_str=namespace,
+            entity_name="",
+            max_level=4,
+        )
+        # Convert to a set first to remove duplicate kernel names. Backends are
+        # allowed to repeat kernel names; only generate the declaration once!
+        ordered_kernels = list(OrderedDict.fromkeys(kernels))
+        declarations.extend(
+            f"""
+{ns_helper.prologue}
+{newline.join(ordered_kernels)}
+{ns_helper.epilogue}
+        """.split(newline)
+        )
+    return declarations
+
+
+# Return native function declarations grouped by their namespaces.
+def get_native_function_declarations(
+    *,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    native_function_decl_gen: Callable[
+        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
+    ] = dest.compute_native_function_declaration,
+) -> list[str]:
+    """
+    Generate kernel declarations, in `NativeFunction(s).h`.
+    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
+    :param backend_indices: kernel collections grouped by dispatch key.
+    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
+    :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
+    """
+
+    ns_grouped_kernels = get_ns_grouped_kernels(
+        grouped_native_functions=grouped_native_functions,
+        backend_indices=backend_indices,
+        native_function_decl_gen=native_function_decl_gen,
+    )
+    return get_native_function_declarations_from_ns_grouped_kernels(
+        ns_grouped_kernels=ns_grouped_kernels
+    )
+
+
+def get_kernel_namespace(
+    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
+) -> str:
+    backend_metadata = backend_idx.get_kernel(f)
+    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
+        f"The kernel for function {f.func.name if isinstance(f, NativeFunction) else f.functional.func.name} "
+        f"with dispatch key {backend_idx.dispatch_key}"
+        f" has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
+    )
+    return (
+        backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
+    )
+
+
+# Return native function definitions grouped by dispatch key and custom namespace.
+# Used in RegisterDispatchKey.cpp and etc.
+def get_native_function_definitions(
+    *,
+    fm: FileManager,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_idx: BackendIndex,
+    selector: SelectiveBuilder,
+    rocm: bool,
+    symint: bool,
+    skip_dispatcher_op_registration: bool,
+    gen_dispatch_helpers: bool,
+) -> list[str]:
+    definitions: list[str] = []
+    ns_definitions: dict[str, list[str]] = defaultdict(list)
+    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
+    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
+    newline = "\n"
+    ns_gen = dest.RegisterDispatchKey(
+        backend_idx,
+        Target.NAMESPACED_DEFINITION,
+        selector,
+        rocm=rocm,
+        symint=symint,
+        class_method_name=None,
+        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+    )
+    anonymous_gen = dest.RegisterDispatchKey(
+        backend_idx,
+        Target.ANONYMOUS_DEFINITION,
+        selector,
+        rocm=rocm,
+        symint=symint,
+        class_method_name=None,
+        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+    )
+    reg_gen = dest.RegisterDispatchKey(
+        backend_idx,
+        Target.REGISTRATION,
+        selector,
+        rocm=rocm,
+        symint=symint,
+        class_method_name=None,
+        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+    )
+    for f in grouped_native_functions:
+        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
+            "::native", ""
+        )
+
+        ns_definitions[kernel_namespace].extend(
+            ns_gen(f),
+        )
+        anonymous_definitions[kernel_namespace].extend(
+            anonymous_gen(f),
+        )
+        namespace = (
+            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
+        )
+        if namespace not in registrations[kernel_namespace]:
+            registrations[kernel_namespace] = defaultdict(list)
+        registrations[kernel_namespace][namespace].extend(
+            reg_gen(f),
+        )
+
+    for kernel_namespace in ns_definitions:
+        if len(ns_definitions[kernel_namespace]) == 0:
+            continue
+        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
+        registration_body = ""
+        for namespace in registrations[kernel_namespace]:
+            if not registrations[kernel_namespace][namespace]:
+                continue
+            registration_body += f"""
+TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
+    {newline.join(registrations[kernel_namespace][namespace])}
+}}"""
+        definitions.extend(
+            fm.substitute_with_template(
+                "RegisterDispatchDefinitions.ini",
+                lambda: {
+                    "ns_prologue": ns_helper.prologue,
+                    "ns_epilogue": ns_helper.epilogue,
+                    "dispatch_anonymous_definitions": anonymous_definitions[
+                        kernel_namespace
+                    ],
+                    "static_init_dispatch_registrations": ""
+                    if skip_dispatcher_op_registration
+                    else registration_body,
+                    "deferred_dispatch_registrations": "",
+                    "dispatch_namespace": dispatch_key.lower(),
+                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
+                },
+            ).split(newline)
+        )
+
+    return definitions
+
+
+# Return native function declarations grouped by dispatch key and custom namespace.
+# Used in CPUFunctions_inl.h and etc.
+def get_namespaced_declaration(
+    *,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_idx: BackendIndex,
+    selector: SelectiveBuilder,
+    rocm: bool,
+    symint: bool,
+) -> list[str]:
+    declarations: list[str] = []
+    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
+    newline = "\n"
+    func = dest.RegisterDispatchKey(
+        backend_idx,
+        Target.NAMESPACED_DECLARATION,
+        selector,
+        rocm=rocm,
+        class_method_name=None,
+        skip_dispatcher_op_registration=False,
+        symint=symint,
+    )
+    for f in grouped_native_functions:
+        namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
+            "native", dispatch_key.lower()
+        )
+
+        ns_grouped_kernels[namespace].extend(
+            func(f),
+        )
+
+    for namespace, kernels in ns_grouped_kernels.items():
+        if len(kernels) == 0:
+            continue
+        ns_helper = NamespaceHelper(
+            namespace_str=namespace, entity_name="", max_level=3
+        )
+        ordered_kernels = list(OrderedDict.fromkeys(kernels))
+        declarations.extend(
+            f"""
+{ns_helper.prologue}
+{newline.join(ordered_kernels)}
+{ns_helper.epilogue}
+        """.split(newline)
+        )
+    return declarations
+
+
+# Return native function schema registration code for aten and other namespaces.
+def get_native_function_schema_registrations(
+    *,
+    native_functions: Sequence[NativeFunction],
+    schema_selector: SelectiveBuilder,
+) -> tuple[list[str], str]:
+    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
+    for native_function in native_functions:
+        ns_native_functions[native_function.namespace].append(native_function)
+    schema_registrations = ""
+    aten_schema_registrations = []
+    custom_namespace = None
+    for namespace, funcs in ns_native_functions.items():
+        schema_registrations_body = list(
+            mapMaybe(RegisterSchema(schema_selector), funcs)
+        )
+        # NB: we have to separate aten namespace registration from other namespaces,
+        # because in the template we hardcoded an operator for ATen already.
+        if namespace == "aten":
+            aten_schema_registrations = schema_registrations_body
+        else:
+            custom_namespace = namespace
+            tab = "\t"
+            # if the namespace is predefined, we should define a library fragment
+            # instead of a new library
+            torch_library_macro = (
+                "TORCH_LIBRARY_FRAGMENT"
+                if namespace in FRAGMENT_NAMESPACES
+                else "TORCH_LIBRARY"
+            )
+            schema_registrations += f"""
+{torch_library_macro}({custom_namespace}, m) {{
+  {tab.join(schema_registrations_body)}
+}};"""
+    return (aten_schema_registrations, schema_registrations)
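+
+
+# Illustrative sketch (not part of the upstream file): for a custom namespace the
+# string built above looks roughly like (schema hypothetical):
+#
+#   TORCH_LIBRARY(custom, m) {
+#     m.def("custom_op(Tensor self) -> Tensor");
+#   };
+#
+# with TORCH_LIBRARY_FRAGMENT used instead when the namespace is one of
+# FRAGMENT_NAMESPACES.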
+
+
+def gen_aggregated_headers(
+    *,
+    native_functions: Sequence[NativeFunction],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    structured_native_functions: Sequence[NativeFunctionsGroup],
+    static_dispatch_idx: list[BackendIndex],
+    selector: SelectiveBuilder,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    cpu_fm: FileManager,
+    device_fms: dict[str, FileManager],
+    functions_keys: set[DispatchKey],
+    dispatch_keys: Sequence[DispatchKey],
+    rocm: bool,
+) -> None:
+    # Buck doesn't support dynamic output files, so we aggregate all operator
+    # headers into a single file
+    cpu_fm.write(
+        "NativeMetaFunctions.h",
+        lambda: {
+            "NativeMetaFunctions_includes": [],
+            "NativeMetaFunctions_declarations": list(
+                mapMaybe(compute_meta_function_declaration, structured_native_functions)
+            ),
+        },
+    )
+    method_native_functions = [
+        fn for fn in native_functions if Variant.method in fn.variants
+    ]
+    non_method_native_functions = [
+        fn for fn in native_functions if fn not in method_native_functions
+    ]
+    cpu_fm.write(
+        "MethodOperators.h",
+        lambda: {
+            "MethodOperators_includes": [],
+            "MethodOperators_declarations": list(
+                mapMaybe(
+                    ComputeOperators(
+                        Target.DECLARATION,
+                        static_dispatch_backend_indices=static_dispatch_idx,
+                    ),
+                    method_native_functions,
+                )
+            ),
+        },
+    )
+    cpu_fm.write(
+        "Operators.h",
+        lambda: {
+            "Operators_includes": ["#include "],
+            "Operators_declarations": list(
+                mapMaybe(
+                    ComputeOperators(
+                        Target.DECLARATION,
+                        static_dispatch_backend_indices=static_dispatch_idx,
+                    ),
+                    non_method_native_functions,
+                )
+            ),
+        },
+    )
+    cpu_fm.write(
+        "Functions.h",
+        lambda: {
+            "static_dispatch_extra_headers": static_dispatch_extra_headers(
+                static_dispatch_idx
+            ),
+            "Functions_includes": ["#include "],
+            "Functions_declarations": list(
+                mapMaybe(
+                    ComputeFunction(),
+                    native_functions,
+                )
+            ),
+        },
+    )
+    declarations = get_native_function_declarations(
+        grouped_native_functions=grouped_native_functions,
+        backend_indices=backend_indices,
+    )
+    cpu_fm.write(
+        "NativeFunctions.h",
+        lambda: {
+            "NativeFunctions_includes": ["#include "],
+            "NativeFunctions_declarations": declarations,
+        },
+    )
+
+    for dispatch_key in dispatch_keys:
+        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
+        if dispatch_key in functions_keys:
+            inl_headers = f"#include "
+
+            fm.write_with_template(
+                f"{dispatch_key}Functions.h",
+                "DispatchKeyFunctions.h",
+                lambda: {
+                    "dispatch_key": str(dispatch_key),
+                    "inline_headers": inl_headers,
+                },
+            )
+            fm.write_with_template(
+                f"{dispatch_key}Functions_inl.h",
+                "DispatchKeyFunctions_inl.h",
+                lambda: {
+                    "DispatchKeyFunctions_inl_includes": [],
+                    "dispatch_namespace": dispatch_key.lower(),
+                    "dispatch_namespaced_declarations": get_namespaced_declaration(
+                        grouped_native_functions=grouped_native_functions,
+                        dispatch_key=dispatch_key,
+                        backend_idx=backend_indices[dispatch_key],
+                        selector=selector,
+                        rocm=rocm,
+                        symint=True,
+                    ),
+                },
+            )
+
+        del fm
+
+
+def gen_per_operator_headers(
+    *,
+    native_functions: Sequence[NativeFunction],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    static_dispatch_idx: list[BackendIndex],
+    selector: SelectiveBuilder,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    cpu_fm: FileManager,
+    device_fms: dict[str, FileManager],
+    ops_fm: FileManager,
+    functions_keys: set[DispatchKey],
+    dispatch_keys: Sequence[DispatchKey],
+    rocm: bool,
+) -> None:
+    # For CMake builds, split operator declarations into separate headers in
+    # the ATen/ops folder to split up header dependencies
+    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
+    for fn in native_functions:
+        functions_by_root_name[fn.root_name].append(fn)
+
+    grouped_functions_by_root_name: dict[
+        str, list[NativeFunction | NativeFunctionsGroup]
+    ] = defaultdict(list)
+    for group in grouped_native_functions:
+        name = group.root_name
+        grouped_functions_by_root_name[name].append(group)
+
+    for name, functions in functions_by_root_name.items():
+        ops_fm.write_with_template(
+            f"{name}_ops.h",
+            "Operator.h",
+            lambda: {
+                "declarations": list(
+                    mapMaybe(
+                        ComputeOperators(
+                            Target.DECLARATION,
+                            static_dispatch_backend_indices=static_dispatch_idx,
+                        ),
+                        functions,
+                    )
+                ),
+            },
+        )
+
+        ops_fm.write_with_template(
+            f"{name}.h",
+            "Function.h",
+            lambda: {
+                "static_dispatch_ops_headers": list(
+                    mapMaybe(
+                        lambda fn: static_dispatch_ops_header(
+                            fn, backend_index=static_dispatch_idx
+                        ),
+                        functions,
+                    )
+                ),
+                "operator_includes": f"#include ",
+                "function_definitions": list(
+                    mapMaybe(
+                        ComputeFunction(),
+                        functions,
+                    )
+                ),
+            },
+        )
+
+        grouped_functions = grouped_functions_by_root_name.get(name, [])
+        structured_functions = [
+            fn
+            for fn in grouped_functions
+            if isinstance(fn, NativeFunctionsGroup) and fn.structured
+        ]
+        is_structured = len(structured_functions) > 0
+
+        if is_structured:
+            ops_fm.write_with_template(
+                f"{name}_meta.h",
+                "NativeMetaFunction.h",
+                lambda: {
+                    "meta_function_declarations": list(
+                        mapMaybe(
+                            compute_meta_function_declaration, structured_functions
+                        )
+                    ),
+                },
+            )
+        declarations = get_native_function_declarations(
+            grouped_native_functions=grouped_functions,
+            backend_indices=backend_indices,
+            native_function_decl_gen=dest.compute_native_function_declaration,
+        )
+        ops_fm.write_with_template(
+            f"{name}_native.h",
+            "NativeFunction.h",
+            lambda: {
+                "extra_includes": (
+                    f"#include " if is_structured else []
+                ),
+                "native_function_declarations": declarations,
+            },
+        )
+
+    for category, suffix in [
+        ("Functions", ""),
+        ("Operators", "_ops"),
+        ("NativeMetaFunctions", "_meta"),
+        ("NativeFunctions", "_native"),
+    ]:
+        cpu_fm.write(
+            f"{category}.h",
+            lambda: {
+                f"{category}_includes": [
+                    f"#include "
+                    for name in sorted(functions_by_root_name.keys())
+                ],
+                f"{category}_declarations": [],
+            },
+        )
+
+    for dispatch_key in dispatch_keys:
+        if dispatch_key not in functions_keys:
+            continue
+
+        dispatch_namespace = dispatch_key.lower()
+        dispatch_names = []
+
+        for name, functions in functions_by_root_name.items():
+            grouped_functions = grouped_functions_by_root_name.get(name, [])
+            declarations = list(
+                concatMap(
+                    dest.RegisterDispatchKey(
+                        backend_indices[dispatch_key],
+                        Target.NAMESPACED_DECLARATION,
+                        selector,
+                        rocm=rocm,
+                        symint=True,
+                        class_method_name=None,
+                        skip_dispatcher_op_registration=False,
+                    ),
+                    grouped_functions,
+                )
+            )
+
+            if len(declarations) == 0:
+                continue
+
+            dispatch_names.append(name)
+            ops_fm.write_with_template(
+                f"{name}_{dispatch_namespace}_dispatch.h",
+                "DispatchKeyFunction.h",
+                lambda: {
+                    "dispatch_namespace": dispatch_namespace,
+                    "dispatch_namespaced_declarations": declarations,
+                },
+            )
+
+        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
+        inl_headers = f"#include "
+
+        fm.write_with_template(
+            f"{dispatch_key}Functions.h",
+            "DispatchKeyFunctions.h",
+            lambda: {
+                "dispatch_key": str(dispatch_key),
+                "inline_headers": inl_headers,
+            },
+        )
+        fm.write_with_template(
+            f"{dispatch_key}Functions_inl.h",
+            "DispatchKeyFunctions_inl.h",
+            lambda: {
+                "dispatch_namespace": dispatch_namespace,
+                "DispatchKeyFunctions_inl_includes": [
+                    f"#include "
+                    for name in sorted(dispatch_names)
+                ],
+                "dispatch_namespaced_declarations": [],
+            },
+        )
+        del fm
+
+    cpu_fm.write(
+        "MethodOperators.h",
+        lambda: {
+            "MethodOperators_includes": sorted(
+                f"#include "
+                for name, functions in functions_by_root_name.items()
+                if any(Variant.method in fn.variants for fn in functions)
+            ),
+            "MethodOperators_declarations": [],
+        },
+    )
+
+
+def gen_headers(
+    *,
+    native_functions: Sequence[NativeFunction],
+    valid_tags: set[str],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    structured_native_functions: Sequence[NativeFunctionsGroup],
+    static_dispatch_idx: list[BackendIndex],
+    selector: SelectiveBuilder,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    core_fm: FileManager,
+    cpu_fm: FileManager,
+    device_fms: dict[str, FileManager],
+    ops_fm: FileManager,
+    dispatch_keys: Sequence[DispatchKey],
+    functions_keys: set[DispatchKey],
+    rocm: bool,
+    per_operator_headers: bool,
+) -> None:
+    if per_operator_headers:
+        gen_per_operator_headers(
+            native_functions=native_functions,
+            grouped_native_functions=grouped_native_functions,
+            static_dispatch_idx=static_dispatch_idx,
+            selector=selector,
+            backend_indices=backend_indices,
+            cpu_fm=cpu_fm,
+            device_fms=device_fms,
+            ops_fm=ops_fm,
+            dispatch_keys=dispatch_keys,
+            functions_keys=functions_keys,
+            rocm=rocm,
+        )
+    else:
+        gen_aggregated_headers(
+            native_functions=native_functions,
+            grouped_native_functions=grouped_native_functions,
+            structured_native_functions=structured_native_functions,
+            static_dispatch_idx=static_dispatch_idx,
+            selector=selector,
+            backend_indices=backend_indices,
+            cpu_fm=cpu_fm,
+            device_fms=device_fms,
+            dispatch_keys=dispatch_keys,
+            functions_keys=functions_keys,
+            rocm=rocm,
+        )
+
+    core_fm.write(
+        "TensorBody.h",
+        lambda: {
+            "tensor_method_declarations": list(
+                mapMaybe(
+                    ComputeTensorMethod(
+                        target=Target.DECLARATION,
+                        static_dispatch_backend_indices=static_dispatch_idx,
+                    ),
+                    native_functions,
+                )
+            ),
+            "tensor_method_definitions": list(
+                mapMaybe(
+                    ComputeTensorMethod(
+                        target=Target.DEFINITION,
+                        static_dispatch_backend_indices=static_dispatch_idx,
+                    ),
+                    native_functions,
+                )
+            ),
+        },
+    )
+
+    cpu_fm.write(
+        "RedispatchFunctions.h",
+        lambda: {
+            "function_redispatch_definitions": list(
+                mapMaybe(ComputeRedispatchFunction(), native_functions)
+            ),
+        },
+    )
+
+    cpu_fm.write(
+        "RegistrationDeclarations.h",
+        lambda: {
+            "registration_declarations": [
+                compute_registration_declarations(f, backend_indices)
+                for f in native_functions
+            ],
+        },
+    )
+
+    cpu_fm.write(
+        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
+    )
+
+    def gen_aten_interned_strings() -> dict[str, str]:
+        attrs: set[str] = set()  # All function argument names
+        names = set()  # All ATen function names
+        for func in native_functions:
+            names.add(str(func.func.name.name))
+            # Some operators don't have a functional variant but we still create a
+            # symbol without the underscore
+            names.add(func.func.name.name.base)
+
+            attrs.update(arg.name for arg in func.func.schema_order_arguments())
+
+        # These are keywords in C++, so aren't valid symbol names
+        # https://en.cppreference.com/w/cpp/language/operator_alternative
+        names -= {
+            "and",
+            "and_eq",
+            "bitand",
+            "bitor",
+            "compl",
+            "not",
+            "not_eq",
+            "or",
+            "or_eq",
+            "xor",
+            "xor_eq",
+        }
+
+        return {
+            "aten_symbols": " \\\n".join(
+                [f"_(aten, {name})" for name in sorted(names)]
+            ),
+            "attr_symbols": " \\\n".join(
+                [f"_(attr, {name})" for name in sorted(attrs)]
+            ),
+        }
+
+    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
+
+    def gen_tags_enum() -> dict[str, str]:
+        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
+
+    core_fm.write("enum_tag.h", gen_tags_enum)
+
+
+def gen_source_files(
+    *,
+    native_functions: Sequence[NativeFunction],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    structured_native_functions: Sequence[NativeFunctionsGroup],
+    view_groups: Sequence[NativeFunctionsViewGroup],
+    selector: SelectiveBuilder,
+    static_dispatch_idx: list[BackendIndex],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    aoti_fm: FileManager,
+    core_fm: FileManager,
+    cpu_vec_fm: FileManager,
+    cpu_fm: FileManager,
+    device_fms: dict[str, FileManager],
+    dispatch_keys: Sequence[DispatchKey],
+    functions_keys: set[DispatchKey],
+    rocm: bool,
+    force_schema_registration: bool,
+    per_operator_headers: bool,
+    skip_dispatcher_op_registration: bool,
+    update_aoti_c_shim: bool,
+    aoti_backends: set[DispatchKey],
+    extend_aoti_c_shim: bool,
+) -> None:
+    extra_cuda_headers = """\
+#include 
+#include 
+#include 
+#include """
+    if rocm:
+        extra_cuda_headers = """\
+#include 
+#include 
+#include 
+#include """
+
+    for dispatch_key in dispatch_keys:
+        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
+        if per_operator_headers:
+
+            def operator_headers() -> list[str]:
+                headers = []
+                for g in grouped_native_functions:
+                    is_registered = False
+                    if backend_index.has_kernel(g):
+                        is_registered = True
+                    # The above has_kernel test on a group will only test for
+                    # the existence of out dispatch, because that's how
+                    # structured kernels work. But sometimes functions can be
+                    # grouped but not be structured, and then you need to check
+                    # each individual piece, as they may have manual dispatch
+                    # entries.
+                    elif isinstance(g, NativeFunctionsGroup) and any(
+                        backend_index.has_kernel(fn) for fn in g.functions()
+                    ):
+                        is_registered = True
+                    # TODO: this condition is a bit questionable
+                    # (It has to do with the fact that structured kernels get kernels
+                    # generated for the Meta + CompositeExplicitAutogradNonFunctional keys).
+                    elif g.structured and dispatch_key in (
+                        DispatchKey.Meta,
+                        DispatchKey.CompositeExplicitAutogradNonFunctional,
+                    ):
+                        is_registered = True
+                    if not is_registered:
+                        continue
+
+                    headers.append(f"#include ")
+                    if (
+                        dispatch_key
+                        == DispatchKey.CompositeExplicitAutogradNonFunctional
+                    ):
+                        headers.append(f"#include ")
+                    if dispatch_key in functions_keys:
+                        headers.append(
+                            f"#include "
+                        )
+
+                return sorted(set(headers))
+
+        else:
+
+            def operator_headers() -> list[str]:
+                headers = ["#include "]
+                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
+                    headers.append("#include ")
+                if dispatch_key in functions_keys:
+                    headers.append(f"#include ")
+                return headers
+
+        backend_index = backend_indices[dispatch_key]
+        ns_grouped_native_functions = defaultdict(list)
+        for grouped_native_function in grouped_native_functions:
+            namespace = (
+                grouped_native_function.namespace
+                if isinstance(grouped_native_function, NativeFunction)
+                else grouped_native_function.functional.namespace
+            )
+            ns_grouped_native_functions[namespace].append(grouped_native_function)
+
+        dispatch_namespace = str(dispatch_key).lower()
+
+        # CompositeImplicitAutogradNestedTensor does not currently use the generated helpers;
+        # compilation will fail when the `-Werror=unused-function` flag is set
+        gen_dispatch_helpers: bool = (
+            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
+        )
+
+        register_dispatch_key_base_env = {
+            "extra_cuda_headers": extra_cuda_headers
+            if is_cuda_dispatch_key(dispatch_key)
+            else "",
+            "external_backend_headers": "",
+            "dispatch_headers": dest.gen_registration_headers(
+                backend_index, per_operator_headers, rocm
+            ),
+            # ops_headers *could* be sharded, but doesn't seem necessary?
+            "ops_headers": operator_headers(),
+            "dispatch_helpers": (
+                dest.gen_registration_helpers(backend_index)
+                if gen_dispatch_helpers
+                else []
+            ),
+        }
+
+        def register_dispatch_key_env_callable(
+            gnf: NativeFunction | NativeFunctionsGroup,
+        ) -> dict[str, list[str]]:
+            return {
+                "dispatch_definitions": get_native_function_definitions(
+                    fm=fm,  # noqa: F821
+                    grouped_native_functions=[gnf],
+                    dispatch_key=dispatch_key,
+                    backend_idx=backend_index,
+                    selector=selector,
+                    rocm=rocm,
+                    symint=True,
+                    skip_dispatcher_op_registration=skip_dispatcher_op_registration,
+                    gen_dispatch_helpers=gen_dispatch_helpers,
+                )
+            }
+
+        fm.write_sharded_with_template(
+            f"Register{dispatch_key}.cpp",
+            "RegisterDispatchKey.cpp",
+            grouped_native_functions,
+            key_fn=lambda x: x.root_name,
+            env_callable=register_dispatch_key_env_callable,
+            num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
+            base_env=register_dispatch_key_base_env,
+            sharded_keys={"dispatch_definitions"},
+        )
+
+        for g in structured_native_functions:
+            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
+                continue
+            name = g.functional.func.name.name
+            if dispatch_key is DispatchKey.CPU:
+                assert fm is cpu_fm
+                fm.write_with_template(
+                    f"UfuncCPU_{name}.cpp",
+                    "UfuncCPU.cpp",
+                    lambda: {
+                        "meta_declaration": compute_meta_function_declaration(g),
+                        "native_declaration": dest.compute_native_function_declaration(
+                            g, backend_indices[dispatch_key]
+                        ),
+                        "native_definitions": dest.compute_ufunc_cpu(g),
+                    },
+                )
+                cpu_vec_fm.write_with_template(
+                    f"UfuncCPUKernel_{name}.cpp",
+                    "UfuncCPUKernel.cpp",
+                    lambda: {
+                        "name": name,
+                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
+                    },
+                )
+            elif dispatch_key is DispatchKey.CUDA:
+                cuda_headers = "#include "
+                if rocm:
+                    cuda_headers = "#include "
+                fm.write_with_template(
+                    f"UfuncCUDA_{name}.cu",
+                    "UfuncCUDA.cu",
+                    lambda: {
+                        "name": name,
+                        "cuda_headers": cuda_headers,
+                        "meta_declaration": compute_meta_function_declaration(g),
+                        "native_declaration": dest.compute_native_function_declaration(
+                            g, backend_indices[dispatch_key]
+                        ),
+                        "native_definitions": dest.compute_ufunc_cuda(g),
+                    },
+                )
+            else:
+                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
+
+        del fm
+
+    gen_aoti_c_shim_files(
+        aoti_fm=aoti_fm,
+        aoti_backends=aoti_backends,
+        native_functions=native_functions,
+        backend_indices=backend_indices,
+        structured_native_functions=structured_native_functions,
+        extra_cuda_headers=extra_cuda_headers,
+        update_aoti_c_shim=update_aoti_c_shim,
+        extend_aoti_c_shim=extend_aoti_c_shim,
+    )
+
+    # BackendSelect is generated specially
+    def gen_backend_select() -> dict[str, list[str]]:
+        relevant_fns = [
+            fn for fn in native_functions if needs_backend_select(fn, selector)
+        ]
+        return {
+            "ops_headers": [
+                f"#include " for fn in relevant_fns
+            ],
+            "backend_select_method_definitions": list(
+                mapMaybe(
+                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
+                )
+            ),
+            "backend_select_function_registrations": list(
+                mapMaybe(
+                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
+                )
+            ),
+        }
+
+    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
+
+    schema_selector = selector
+    if force_schema_registration:
+        schema_selector = SelectiveBuilder.get_nop_selector()
+
+    (
+        aten_schema_registrations,
+        schema_registrations,
+    ) = get_native_function_schema_registrations(
+        native_functions=native_functions, schema_selector=schema_selector
+    )
+    cpu_fm.write(
+        "RegisterSchema.cpp",
+        lambda: {
+            "aten_schema_registrations": []
+            if skip_dispatcher_op_registration
+            else aten_schema_registrations,
+            "schema_registrations": []
+            if skip_dispatcher_op_registration
+            else schema_registrations,
+        },
+    )
+
+    def key_func(
+        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+    ) -> str:
+        return fn.root_name
+
+    cpu_fm.write_sharded(
+        "Operators.cpp",
+        native_functions,
+        key_fn=key_func,
+        env_callable=lambda fn: {
+            "operator_headers": [f"#include "],
+            "definitions": [
+                ComputeOperators(
+                    Target.DEFINITION,
+                    static_dispatch_backend_indices=static_dispatch_idx,
+                )(fn)
+            ],
+        },
+        base_env={
+            "static_dispatch_extra_headers": static_dispatch_extra_headers(
+                static_dispatch_idx
+            ),
+        },
+        num_shards=5,
+        sharded_keys={
+            "operator_headers",
+            "definitions",
+            "static_dispatch_extra_headers",
+        },
+    )
+
+    cpu_fm.write("Functions.cpp", dict)
+
+    core_fm.write("TensorMethods.cpp", dict)
+
+    core_fm.write(
+        "ATenOpList.cpp",
+        lambda: {
+            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
+        },
+    )
+
+    def functionalization_env_callable(
+        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+    ) -> dict[str, list[str]]:
+        def gen_op_headers(
+            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+        ) -> list[str]:
+            if isinstance(g, NativeFunctionsViewGroup):
+                # view ops always get a functionalization kernel
+                headers = [
+                    f"#include ",
+                    f"#include ",
+                ]
+                if g.view_copy is not None:
+                    headers += [
+                        f"#include ",
+                        f"#include ",
+                    ]
+                return headers
+            elif isinstance(g, NativeFunctionsGroup):
+                headers = [
+                    f"#include ",
+                    f"#include ",
+                    f"#include ",
+                    f"#include ",
+                ]
+                if g.inplace is not None:
+                    headers += [
+                        f"#include ",
+                        f"#include ",
+                    ]
+                if g.mutable is not None:
+                    headers += [
+                        f"#include ",
+                        f"#include ",
+                    ]
+                return headers
+            else:
+                return [
+                    f"#include ",
+                    f"#include ",
+                ]
+
+        return {
+            "ops_headers": gen_op_headers(g),
+            "func_definitions": gen_functionalization_definition(
+                selector,
+                g,
+            ),
+            "func_registrations": gen_functionalization_registration(
+                selector,
+                g,
+                backend_indices[DispatchKey.CompositeImplicitAutograd],
+            ),
+        }
+
+    all_groups: list[
+        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
+    ] = list(structured_native_functions) + list(
+        view_groups  # type: ignore[assignment, arg-type, operator]
+    )
+    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
+    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
+    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
+    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
+    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
+    structured_map: dict[OperatorName, NativeFunction] = {
+        f.func.name: f
+        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
+    }
+    view_map: dict[OperatorName, NativeFunction] = {
+        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
+    }
+    all_groups.extend(
+        f
+        for f in native_functions
+        if f.func.name not in structured_map and f.func.name not in view_map
+    )
+
+    cpu_fm.write_sharded(
+        "RegisterFunctionalization.cpp",
+        all_groups,
+        key_fn=key_func,
+        env_callable=functionalization_env_callable,
+        num_shards=4,
+        sharded_keys={
+            "ops_headers",
+            "func_definitions",
+            "func_registrations",
+            "func_add_back_views_definitions",
+            "func_add_back_views_registrations",
+        },
+    )
+
+    cpu_fm.write(
+        "FunctionalInverses.h",
+        lambda: {
+            "view_inverse_declarations": list(
+                mapMaybe(
+                    lambda g: gen_functionalization_view_inverse_declaration(
+                        selector, g
+                    ),
+                    view_groups,
+                )
+            )
+        },
+    )
+
+    # Note [view_copy NativeFunctions]
+    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
+    # needs to have a corresponding non-aliasing {view}_copy variant.
+    # Backends that use functionalization and don't know how to handle aliasing ops
+    # are expected to implement kernels for these {view}_copy operators instead.
+    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
+    # so we codegen the following:
+    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
+    #     These are never explicitly invoked by the functionalization pass,
+    #     but they could theoretically be called from user code (I added these kernels for completeness,
+    #     since the ops are part of the public API).
+    # (2) A derivative formula for every {view}_copy operator
+    #     {view}_copy operators can reuse the same derivative formulas as their {view} op counterparts,
+    #     so rather than stamping all of the entries out in derivatives.yaml,
+    #     we codegen them in.
+    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
+    cpu_fm.write(
+        "CompositeViewCopyKernels.cpp",
+        lambda: {
+            "ops_headers": [
+                "\n".join(
+                    f"#include \n"
+                    # NB: this include is important as it ensures we
+                    # set the visibility on generated view_copy kernels
+                    # correctly
+                    f"#include "
+                    for f in (
+                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
+                    )
+                )
+                for g in view_groups
+            ]
+            + [
+                "\n".join(
+                    f"#include \n"
+                    # NB: this include is also important for correct visibility
+                    f"#include "
+                    for f in [g.inplace, g.mutable, g.functional]
+                    if f is not None and "generated" not in f.tags
+                )
+                for g in structured_native_functions
+            ],
+            "CompositeViewCopyKernel_Definitions": list(
+                mapMaybe(
+                    GenCompositeViewCopyKernel(
+                        backend_indices[
+                            DispatchKey.CompositeExplicitAutogradNonFunctional
+                        ]
+                    ),
+                    view_groups,
+                )
+            ),
+            "GeneratedCompositeFunctional_Definitions": list(
+                mapMaybe(
+                    gen_composite_functional_kernel,
+                    structured_native_functions,
+                )
+            ),
+            "GeneratedCompositeOut_Definitions": list(
+                mapMaybe(
+                    gen_composite_out_kernel,
+                    structured_native_functions,
+                )
+            ),
+        },
+    )
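+
+    # Illustrative sketch (not part of the upstream file): per Note [view_copy
+    # NativeFunctions] above, a view op such as aten::expand gets a non-aliasing
+    # aten::expand_copy counterpart, and the generated
+    # CompositeExplicitAutogradNonFunctional kernel is roughly "call the view op,
+    # then clone the result so no aliasing is observable", e.g.:
+    #
+    #   at::Tensor expand_copy(/* same args as expand */) {
+    #       auto output = at::_ops::expand::call(/* forwarded args */);
+    #       return output.clone();
+    #   }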
+
+
+def gen_declarations_yaml(
+    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
+) -> None:
+    cpu_fm.write(
+        "Declarations.yaml",
+        lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
+    )
+
+
+def get_torchgen_root() -> Path:
+    """
+    If you're depending on torchgen out-of-tree, you can use the root to figure
+    out the path to native_functions.yaml
+    """
+    return Path(__file__).parent.resolve()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate ATen source files")
+    parser.add_argument(
+        "-s",
+        "--source-path",
+        help="path to source directory for ATen",
+        default="aten/src/ATen",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dependencies",
+        help="output a list of dependencies into the given file and exit",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        help="run without writing any files (still updates outputs)",
+    )
+    parser.add_argument(
+        "--per-operator-headers",
+        action="store_true",
+        help="generate separate headers per operator in ATen/ops",
+    )
+    parser.add_argument(
+        "-d",
+        "--install-dir",
+        "--install_dir",
+        help="output directory",
+        default="build/aten/src/ATen",
+    )
+    parser.add_argument(
+        "--aoti-install-dir",
+        "--aoti_install_dir",
+        help="output directory for AOTInductor shim",
+        default="torch/csrc/inductor/aoti_torch/generated",
+    )
+    parser.add_argument(
+        "--rocm",
+        action="store_true",
+        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
+    )
+    parser.add_argument(
+        "--mps",
+        action="store_true",
+        help="Generate MPS registration code when set",
+    )
+    parser.add_argument(
+        "--xpu",
+        action="store_true",
+        help="Generate XPU registration code when set",
+    )
+    parser.add_argument(
+        "--mtia",
+        action="store_true",
+        help="Generate MTIA registration code when set",
+    )
+
+    # TODO: --op-registration-whitelist will be removed when all call-sites
+    # for gen.py are moved over to using the operator YAML file for mobile
+    # custom build.
+    parser.add_argument(
+        "--op-registration-whitelist",
+        "--op_registration_whitelist",
+        nargs="*",
+        help="filter op registrations by the whitelist (if set); "
+        "each item is `namespace`::`operator name` without overload name; "
+        "e.g.: aten::empty aten::conv2d ...",
+    )
+    parser.add_argument(
+        "--op-selection-yaml-path",
+        "--op_selection_yaml_path",
+        help="Provide a path to the operator selection (for custom build) YAML "
+        "that contains the information about the set of selected operators "
+        "and their categories (training, ...). Each operator is either a "
+        "full operator name with overload or just a bare operator name. "
+        "The operator names also contain the namespace prefix (e.g. aten::)",
+    )
+    parser.add_argument(
+        "--backend-whitelist",
+        "--backend_whitelist",
+        nargs="*",
+        help="filter dispatch backend by the whitelist (if set), "
+        "e.g.: CPU CUDA QuantizedCPU ...",
+    )
+    parser.add_argument(
+        "--static-dispatch-backend",
+        "--static_dispatch_backend",
+        nargs="*",
+        help="generate static dispatch code for the specific backend (if set)",
+    )
+    parser.add_argument(
+        "--skip-dispatcher-op-registration",
+        "--skip_dispatcher_op_registration",
+        action="store_true",
+        help="Avoid registering operators into the dispatcher.",
+    )
+    parser.add_argument(
+        "--force-schema-registration",
+        "--force_schema_registration",
+        action="store_true",
+        help="force it to generate schema-only registrations for all ops, including"
+        "those that are not listed on --op-registration-whitelist",
+    )
+    parser.add_argument(
+        "--generate",
+        type=str,
+        nargs="*",
+        choices=["headers", "sources", "declarations_yaml"],
+        default=["headers", "sources", "declarations_yaml"],
+        help="Generate only a subset of files",
+    )
+    parser.add_argument(
+        "--update-aoti-c-shim",
+        action="store_true",
+        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
+        "WARNING: Do not use this unless you are sure what you are doing!!!",
+    )
+    parser.add_argument(
+        "--extend-aoti-c-shim",
+        action="store_true",
+        help="This Flag indicates the generation of c shims for out-of-tree ATen ops,"
+        "which is an extension to the In-tree ATen op c shims. This flag needs to be combined with"
+        "---source-path="
+        "--aoti-install-dir=/extend"
+        "   default is torch/csrc/inductor/aoti_torch/generated/extend"
+        "WARNING: Do not use this unless you are sure what you are doing!!!",
+    )
+
+    options = parser.parse_args()
+
+    selector = get_custom_build_selector(
+        options.op_registration_whitelist,
+        options.op_selection_yaml_path,
+    )
+
+    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
+    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
+
+    from torchgen.model import dispatch_keys
+
+    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
+    # for them; this is the set
+    functions_keys = {
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+        DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
+        DispatchKey.CompositeExplicitAutograd,
+        DispatchKey.CompositeExplicitAutogradNonFunctional,
+        DispatchKey.Meta,
+        DispatchKey.MTIA,
+    }
+
+    aoti_backends = {
+        DispatchKey.CPU,
+        DispatchKey.CUDA,
+    }
+
+    # TODO: stop generating CUDA kernels for non-CUDA builds
+    ignore_keys = set()
+
+    if options.mps or options.update_aoti_c_shim:
+        functions_keys.add(DispatchKey.MPS)
+        aoti_backends.add(DispatchKey.MPS)
+    else:
+        ignore_keys.add(DispatchKey.MPS)
+
+        if DispatchKey.MPS in dispatch_keys:
+            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
+
+    if options.xpu or options.update_aoti_c_shim:
+        functions_keys.add(DispatchKey.XPU)
+        aoti_backends.add(DispatchKey.XPU)
+    else:
+        ignore_keys.add(DispatchKey.XPU)
+
+        if DispatchKey.XPU in dispatch_keys:
+            del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
+
+    if not options.mtia:
+        ignore_keys.add(DispatchKey.MTIA)
+
+        if DispatchKey.MTIA in dispatch_keys:
+            del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]
+
+    if options.backend_whitelist:
+        dispatch_keys = [
+            k
+            for k in dispatch_keys
+            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
+        ]
+
+    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
+    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
+    native_functions, backend_indices = (
+        parsed_yaml.native_functions,
+        parsed_yaml.backend_indices,
+    )
+
+    grouped_native_functions = get_grouped_native_functions(native_functions)
+
+    structured_native_functions = [
+        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
+    ]
+    native_functions_with_view_groups = get_grouped_by_view_native_functions(
+        native_functions
+    )
+    view_groups = [
+        g
+        for g in native_functions_with_view_groups
+        if isinstance(g, NativeFunctionsViewGroup)
+    ]
+
+    # NB: It is mandatory to NOT use os.path.join here, as the install directory
+    # will eventually be ingested by cmake, which does not respect Windows style
+    # path slashes.  If you switch this to use os.path.join, you'll get an error
+    # like:
+    #
+    #   Syntax error in cmake code when parsing string
+    #
+    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
+    #
+    #   Invalid character escape '\c'.
+    core_install_dir = f"{options.install_dir}/core"
+    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
+    ops_install_dir = f"{options.install_dir}/ops"
+    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
+
+    aoti_install_dir = f"{options.aoti_install_dir}"
+    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
+
+    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
+    cpu_fm = make_file_manager(options=options)
+    cpu_vec_fm = make_file_manager(options=options)
+    cuda_fm = make_file_manager(options=options)
+    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
+    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
+    device_fms = {"cuda": cuda_fm}
+    if options.xpu:
+        device_fms["xpu"] = make_file_manager(options=options)
+
+    static_dispatch_idx: list[BackendIndex] = []
+    if options.static_dispatch_backend:
+        static_dispatch_idx = [
+            backend_indices[DispatchKey.parse(key)]
+            for key in options.static_dispatch_backend
+        ]
+        for key in options.static_dispatch_backend:
+            dp_key = DispatchKey.parse(key)
+            if dp_key not in functions_keys:
+                functions_keys.add(dp_key)
+
+    if "sources" in options.generate:
+        gen_source_files(
+            native_functions=native_functions,
+            grouped_native_functions=grouped_native_functions,
+            structured_native_functions=structured_native_functions,
+            view_groups=view_groups,
+            selector=selector,
+            static_dispatch_idx=static_dispatch_idx,
+            backend_indices=backend_indices,
+            aoti_fm=aoti_fm,
+            core_fm=core_fm,
+            cpu_vec_fm=cpu_vec_fm,
+            cpu_fm=cpu_fm,
+            device_fms=device_fms,
+            dispatch_keys=dispatch_keys,
+            functions_keys=functions_keys,
+            rocm=options.rocm,
+            force_schema_registration=options.force_schema_registration,
+            per_operator_headers=options.per_operator_headers,
+            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
+            update_aoti_c_shim=options.update_aoti_c_shim,
+            aoti_backends=aoti_backends,
+            extend_aoti_c_shim=options.extend_aoti_c_shim,
+        )
+
+    if "headers" in options.generate:
+        gen_headers(
+            native_functions=native_functions,
+            valid_tags=valid_tags,
+            grouped_native_functions=grouped_native_functions,
+            structured_native_functions=structured_native_functions,
+            static_dispatch_idx=static_dispatch_idx,
+            selector=selector,
+            backend_indices=backend_indices,
+            core_fm=core_fm,
+            cpu_fm=cpu_fm,
+            device_fms=device_fms,
+            ops_fm=ops_fm,
+            dispatch_keys=dispatch_keys,
+            functions_keys=functions_keys,
+            rocm=options.rocm,
+            per_operator_headers=options.per_operator_headers,
+        )
+
+    if "declarations_yaml" in options.generate:
+        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
+
+    if options.output_dependencies:
+        depfile_path = Path(options.output_dependencies).resolve()
+        depfile_name = depfile_path.name
+        depfile_stem = depfile_path.stem
+
+        for fm, prefix in [
+            (cpu_fm, ""),
+            (cpu_vec_fm, "cpu_vec_"),
+            (core_fm, "core_"),
+            (ops_fm, "ops_"),
+        ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
+            varname = prefix + depfile_stem
+            path = depfile_path.parent / (prefix + depfile_name)
+            fm.write_outputs(varname, str(path))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/gen_aoti_c_shim.py b/phivenv/Lib/site-packages/torchgen/gen_aoti_c_shim.py
new file mode 100644
index 0000000000000000000000000000000000000000..613489b35c14e1ab555e9e994b3aa492df71e54b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_aoti_c_shim.py
@@ -0,0 +1,715 @@
+from __future__ import annotations
+
+import difflib
+import os
+import textwrap
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from torchgen.aoti.fallback_ops import inductor_fallback_ops
+from torchgen.api.types import DispatcherSignature
+from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
+from torchgen.context import method_with_native_function
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BaseTy,
+    BaseType,
+    DispatchKey,
+    FunctionSchema,
+    is_cuda_dispatch_key,
+    ListType,
+    NativeFunction,
+    NativeFunctionsGroup,
+    OperatorName,
+    OptionalType,
+    Type,
+)
+from torchgen.utils import FileManager, mapMaybe
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+base_type_to_c_type = {
+    BaseTy.Tensor: "AtenTensorHandle",
+    BaseTy.bool: "int32_t",  # Use int to pass bool
+    BaseTy.int: "int64_t",
+    BaseTy.SymInt: "int64_t",  # Inductor-generated code won't see a SymInt
+    BaseTy.Scalar: "double",  # Use double to pass both integer and floating point
+    BaseTy.float: "double",  # TODO: how about other floating point types?
+    BaseTy.str: "const char*",
+    BaseTy.DeviceIndex: "int32_t",
+    BaseTy.Layout: "int32_t",  # Represent enum as int
+    BaseTy.MemoryFormat: "int32_t",  # Represent enum as int
+    BaseTy.ScalarType: "int32_t",  # Represent enum as int
+    BaseTy.Generator: "AtenGeneratorHandle",
+}
+
+base_type_to_aten_type = {
+    BaseTy.Tensor: "at::Tensor",
+    BaseTy.bool: "bool",
+    BaseTy.int: "int64_t",
+    BaseTy.SymInt: "c10::SymInt",
+    BaseTy.Scalar: "c10::Scalar",
+    BaseTy.float: "double",
+    BaseTy.str: "::std::string_view",
+    BaseTy.DeviceIndex: "c10::DeviceIndex",
+    BaseTy.Layout: "c10::Layout",
+    BaseTy.MemoryFormat: "c10::MemoryFormat",
+    BaseTy.ScalarType: "c10::ScalarType",
+    BaseTy.Generator: "at::Generator",
+}
+
+base_type_to_callsite_expr = {
+    BaseTy.Tensor: "resolve_tensor_dispatch_flags",
+    BaseTy.bool: "",
+    BaseTy.int: "",
+    BaseTy.SymInt: "",
+    BaseTy.Scalar: "",
+    BaseTy.float: "",
+    BaseTy.str: "",
+    BaseTy.DeviceIndex: "static_cast",
+    BaseTy.Layout: "static_cast",
+    BaseTy.MemoryFormat: "static_cast",
+    BaseTy.ScalarType: "static_cast",
+    BaseTy.Generator: "*generator_handle_to_generator_pointer",
+}
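+
+
+# Illustrative example (not part of the upstream file): taken together, the three
+# tables above mean that a schema argument "ScalarType dtype" crosses the C shim
+# ABI as an int32_t named "dtype" and is forwarded to ATen as
+# static_cast<c10::ScalarType>(dtype), while a "Tensor self" crosses as an
+# AtenTensorHandle and is unwrapped with resolve_tensor_dispatch_flags(self).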
+
+
+# convert args to C types, names in declarations, and expressions in function bodies
+def convert_arg_type_and_name(
+    typ: Type,
+    name: str,
+    is_write: bool = False,
+) -> tuple[list[str], list[str], list[str], list[str]]:
+    if isinstance(typ, BaseType):
+        if typ.name in base_type_to_c_type:
+            if typ.name == BaseTy.Tensor and is_write:
+                # For output tensors, our normal call to resolve_tensor_dispatch_flags
+                # results in an rvalue tensor, which can't be passed to at::Tensor&.
+                # Override this case specifically.
+                callsite_expr = [f"*tensor_handle_to_tensor_pointer({name})"]
+            else:
+                callsite_expr = [
+                    f"{base_type_to_callsite_expr[typ.name]}({name})"
+                    if base_type_to_callsite_expr[typ.name]
+                    else name
+                ]
+
+            return (
+                [base_type_to_c_type[typ.name]],
+                [name],
+                [base_type_to_aten_type[typ.name]],
+                callsite_expr,
+            )
+        elif typ.name == BaseTy.Device:
+            return (
+                ["int32_t", "int32_t"],
+                [name, name + "_index_"],
+                ["c10::Device"],
+                [
+                    f"c10::Device(static_cast({name}), static_cast({name}_index_))"
+                ],
+            )
+        else:
+            # TODO: BaseTy.Dimname, etc.
+            raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
+    elif isinstance(typ, OptionalType):
+        c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
+            typ.elem, name
+        )
+        j = 0  # index for names
+        new_aten_types = []
+        new_callsite_exprs = []
+        for aten_type in aten_types:
+            # Use pointer to denote optional type
+            c_types[j] = c_types[j] + "*"
+            if aten_type.startswith("c10::ArrayRef<"):
+                # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
+                new_aten_types.append(f"::std::optional<{aten_type}>")
+                base_type = aten_type[len("c10::ArrayRef<") : -1]
+                new_callsite_exprs.append(
+                    f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
+                )
+                j += 2
+            elif aten_type == "c10::Device":
+                # Device is passed as device_type + device_index
+                new_aten_types.append("::std::optional")
+                new_callsite_exprs.append(
+                    f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
+                )
+                j += 2
+            elif aten_type == "at::Tensor":
+                new_aten_types.append(f"::std::optional<{aten_type}>")
+                new_callsite_exprs.append(f"resolve_tensor_dispatch_flags({names[j]})")
+                j += 1
+            else:
+                new_aten_types.append(f"::std::optional<{aten_type}>")
+                new_callsite_exprs.append(
+                    f"pointer_to_optional<{aten_type}>({names[j]})"
+                )
+                j += 1
+
+        return (
+            c_types,
+            names,
+            new_aten_types,
+            new_callsite_exprs,
+        )
+    elif isinstance(typ, ListType):
+        # Need to explicitly pass the list as pointer + length
+        c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
+        assert len(c_types) == 1, "ListType with unsupported element type " + repr(typ)
+
+        # The list content should never be modified
+        c_types[0] = f"const {c_types[0]}*"
+        c_types.append("int64_t")
+        name = names[0]
+        names.append(name + "_len_")
+
+        atype = aten_types[0]
+        callsite_exprs = []
+        if atype == "bool":
+            # no converter from std::vector to c10::ArrayRef
+            # construct std::array instead
+            assert typ.size is not None
+            callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
+        elif atype == "at::Tensor" and not is_write:
+            callsite_exprs.append(
+                f"resolve_tensor_list_dispatch_flags({name}, {name}_len_)"
+            )
+        elif atype == "::std::optional":
+            # convert from std::vector<::std::optional> to c10::List<::std::optional>
+            callsite_exprs.append(
+                f"c10::List<{atype}>(c10::ArrayRef<{atype}>(resolve_tensor_list_dispatch_flags({name}, {name}_len_)))"
+            )
+        else:
+            callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
+
+        aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
+        return (
+            c_types,
+            names,
+            aten_types,
+            callsite_exprs,
+        )
+    raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
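+
+
+# Illustrative example (not part of the upstream file): for a hypothetical
+# argument "int[]? output_size" (an OptionalType wrapping ListType(int)), the
+# conversion above yields roughly
+#     c_types        = ["const int64_t**", "int64_t"]
+#     names          = ["output_size", "output_size_len_"]
+#     aten_types     = ["::std::optional<c10::ArrayRef<int64_t>>"]
+#     callsite_exprs = ["pointer_to_optional_list<int64_t>(output_size, output_size_len_)"]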
+
+
+def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
+    return [typ + " " + name for typ, name in zip(types, names)]
+
+
+# Generate argument declarations and callsite expressions
+def gen_arguments(
+    flat_arguments: Sequence[Argument], skipped_args: set[str]
+) -> tuple[list[str], list[str]]:
+    types: list[str] = []
+    new_names: list[str] = []
+    callsite_exprs: list[str] = []
+    for arg in flat_arguments:
+        if arg.name in skipped_args:
+            callsite_exprs.append("std::nullopt")
+            continue
+        new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
+            arg.type, arg.name, arg.is_write
+        )
+        types.extend(new_types)
+        new_names.extend(names)
+        callsite_exprs.extend(new_callsite_exprs)
+    return zip_type_and_name(types, new_names), callsite_exprs
+
+
+# Return values are passed out as pointer arguments because all the C shim functions
+# are expected to return AOTITorchError.
+# Generate returns as declarations and callsite expressions
+def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
+    types = []
+    names = []
+    for idx, ret in enumerate(schema.returns):
+        names.append(f"ret{idx}")
+        if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
+            types.append(base_type_to_c_type[ret.type.name] + "*")
+        else:
+            raise NotImplementedError(
+                f"TODO: add support for return type {repr(ret.type)}"
+            )
+
+    def convert_return(typ: BaseType, val: str) -> str:
+        if typ.name == BaseTy.Tensor:
+            return f"new_tensor_handle(std::move({val}))"
+        elif typ.name == BaseTy.SymInt:
+            return f"{val}.expect_int()"
+        elif typ.name == BaseTy.Scalar:
+            return f"{val}.toDouble()"
+        else:
+            return val
+
+    ret_pointer_can_be_null = False
+    unambiguous_name = schema.name.unambiguous_name()
+    for name in [
+        "_scaled_dot_product_flash_attention",
+        "_scaled_dot_product_efficient_attention",
+        "_scaled_dot_product_cudnn_attention",
+        "_scaled_dot_product_fused_attention_overrideable",
+        "convolution_backward",
+    ]:
+        if name in unambiguous_name:
+            ret_pointer_can_be_null = True
+            break
+
+    callsite_exprs: list[str] = []
+    for idx, ret in enumerate(schema.returns):
+        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
+        assert isinstance(ret.type, BaseType)
+        rval = convert_return(ret.type, tmp)
+        if ret_pointer_can_be_null:
+            callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
+        else:
+            callsite_exprs.append(f"*{names[idx]} = {rval};")
+
+    return zip_type_and_name(types, names), callsite_exprs
+
+
+# gen.py generates header first and then src, so caching the result here to avoid duplicate work
+declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
+
+
+def gen_declaration_and_definition(
+    schema: FunctionSchema,
+    device: str,
+    backend_call: str,
+    version_info: dict[str, list[str]],
+) -> tuple[str, str]:
+    base_name = schema.name.unambiguous_name()
+
+    global declaration_definition_cache
+    if (base_name, device, backend_call) in declaration_definition_cache:
+        return declaration_definition_cache[(base_name, device, backend_call)]
+
+    # Check the validity of version_info. The format should look like
+    # {"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}.
+    indexed_version_info: dict[int, list[str]] = {1: []}
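+    # e.g. version_info {"v2": ["alpha"]} becomes {1: [], 2: ["alpha"]}, where "alpha" is a
+    # hypothetical argument added in v2; version 1 is always present with no new args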
+    for ver_str, new_args in sorted(version_info.items()):
+        assert ver_str.startswith("v"), (
+            f"Version number for {base_name} is {ver_str}, not starting with 'v'"
+        )
+        try:
+            ver_id = int(ver_str[1:])
+        except ValueError as e:
+            raise AssertionError(
+                f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'"
+            ) from e
+        assert ver_id not in indexed_version_info, (
+            f"{ver_str} for {base_name} has already been defined"
+        )
+        indexed_version_info[ver_id] = new_args
+
+    declarations: list[str] = []
+    definitions: list[str] = []
+    skipped_args: set[str] = set()
+
+    for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
+        # Iterate in the reverse order, so the latest version of an op will get generated first
+        # with all the arguments included, while a set of to-be-trimmed args is carried down
+        # to generate earlier versions of the op.
+        func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}"
+        if schema.is_out_fn():
+            # out_variant has out arguments in the front, and it's ok to ignore return values
+            # because C shim functions only return AOTITorchError
+            args, callsite_exprs = gen_arguments(
+                [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args
+            )
+            ret_assignments: list[str] = []
+        else:
+            args, callsite_exprs = gen_arguments(
+                schema.arguments.flat_all, skipped_args
+            )
+            # ignore return values for inplace ops
+            ret_declarations, ret_assignments = (
+                ([], []) if schema.name.name.inplace else gen_returns(schema)
+            )
+            args.extend(ret_declarations)
+
+        declaration = textwrap.dedent(
+            f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
+        )
+
+        tmp_result = "auto tmp_result = " if ret_assignments else ""
+        indent = "\t\t"
+        ret_assignments_str = (
+            "\n".join(indent + r for r in ret_assignments) if ret_assignments else ""
+        )
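+        # Sketch of the emitted shim for a hypothetical CPU op "foo" with one Tensor return (illustrative only):
+        #   AOTITorchError aoti_torch_cpu_foo(..., AtenTensorHandle* ret0) {
+        #       AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+        #           auto tmp_result = at::cpu::foo(...);
+        #           *ret0 = new_tensor_handle(std::move(tmp_result));
+        #       });
+        #   }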
+        definition = (
+            textwrap.dedent(f"""
+        {declaration} {{
+            AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
+                {tmp_result}{backend_call}(
+                    {", ".join(callsite_exprs)}
+                );
+        """)
+            + ret_assignments_str
+            + textwrap.dedent("""
+            });
+        }
+        """)
+        )
+        skipped_args.update(new_args)
+        declarations.append(f"AOTI_TORCH_EXPORT {declaration};")
+        definitions.append(definition)
+
+    declaration_definition_cache[(base_name, device, backend_call)] = (
+        "\n".join(declarations),
+        "\n".join(definitions),
+    )
+    return declaration_definition_cache[(base_name, device, backend_call)]
+
+
+def gen_static_dispatch_backend_call_signature(
+    sig: CppSignature | DispatcherSignature,
+    f: NativeFunction,
+) -> CppSignature:
+    sig = DispatcherSignature.from_schema(f.func)
+    cpp_sigs = CppSignatureGroup.from_native_function(
+        f, method=False, fallback_binding=False
+    )
+    if sig.symint and f.func.has_symint():
+        cpp_sig = cpp_sigs.symint_signature
+    else:
+        cpp_sig = cpp_sigs.signature
+    assert cpp_sig is not None
+    return cpp_sig
+
+
+def gen_static_dispatch_backend_call(
+    f: NativeFunction,
+    backend_index: BackendIndex,
+) -> str:
+    sig = DispatcherSignature.from_schema(f.func)
+    cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
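+    # e.g. returns "at::cpu::foo" when the CPU backend provides a kernel for a hypothetical op "foo"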
+    return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
+
+
+def get_backend_index_for_aoti(
+    func: NativeFunction,
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    extend_aoti_c_shim: bool,
+) -> BackendIndex | None:
+    backend_index = None
+    if backend_indices[dispatch_key].has_kernel(func) or (
+        func.structured_delegate is not None
+        and func.structured_delegate in func_group_mapping
+        and backend_indices[dispatch_key].has_kernel(
+            func_group_mapping[func.structured_delegate]
+        )
+    ):
+        backend_index = backend_indices[dispatch_key]
+    else:
+        # for the extended out-of-tree kernels, we don't need to
+        # redundantly create C shim wrappers for other dispatch keys
+        if extend_aoti_c_shim:
+            return backend_index
+
+        elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
+            # We need to create C shim wrappers for CompositeExplicitAutograd kernels
+            backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
+        elif backend_indices[
+            DispatchKey.CompositeExplicitAutogradNonFunctional
+        ].has_kernel(func):
+            # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
+            backend_index = backend_indices[
+                DispatchKey.CompositeExplicitAutogradNonFunctional
+            ]
+        elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
+            backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
+
+    return backend_index
+
+
+def get_header_for_aoti(
+    func: NativeFunction,
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    extend_aoti_c_shim: bool,
+) -> str | None:
+    backend_index = get_backend_index_for_aoti(
+        func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
+    )
+    return (
+        None
+        if backend_index is None
+        else f"#include "
+    )
+
+
+def get_fallback_op_name(func: NativeFunction) -> str:
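+    # e.g. "aten.add.Tensor" for an overloaded op, or "aten.relu.default" when there is no overload name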
+    return (
+        f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
+        if func.func.name.overload_name
+        else f"{func.namespace}.{func.func.name.name}.default"
+    )
+
+
+def gen_c_shim(
+    func: NativeFunction,
+    version_info: dict[str, list[str]],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    header: bool,
+    extend_aoti_c_shim: bool,
+) -> str | None:
+    backend_index = get_backend_index_for_aoti(
+        func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
+    )
+    if backend_index is None:
+        return None
+
+    schema = func.func
+    device = dispatch_key.lower()
+    backend_call = gen_static_dispatch_backend_call(
+        func,
+        backend_index,
+    )
+
+    try:
+        if header:
+            declaration, _ = gen_declaration_and_definition(
+                schema, device, backend_call, version_info
+            )
+            return declaration
+        else:
+            _, definition = gen_declaration_and_definition(
+                schema, device, backend_call, version_info
+            )
+            return definition
+
+    except NotImplementedError:
+        return None
+
+
+@dataclass(frozen=True)
+class ShimGenerator:
+    inductor_fallback_ops: dict[str, dict[str, list[str]]]
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
+    dispatch_key: DispatchKey
+    backend_indices: dict[DispatchKey, BackendIndex]
+    header: bool  # True to generate .h and False to generate .cpp
+    extend_aoti_c_shim: bool
+
+    @method_with_native_function
+    def __call__(
+        self,
+        func: NativeFunction,
+    ) -> str | None:
+        version_info = self.inductor_fallback_ops[get_fallback_op_name(func)]
+        result = gen_c_shim(
+            func,
+            version_info,
+            self.func_group_mapping,
+            self.dispatch_key,
+            self.backend_indices,
+            self.header,
+            self.extend_aoti_c_shim,
+        )
+        return result
+
+
+def gen_aoti_c_shim(
+    native_functions: Sequence[NativeFunction],
+    inductor_fallback_ops: dict[str, dict[str, list[str]]],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
+    dispatch_key: DispatchKey,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    header: bool,
+    extend_aoti_c_shim: bool,
+    includes: str = "",
+) -> str:
+    body = "\n".join(
+        list(
+            mapMaybe(
+                ShimGenerator(
+                    inductor_fallback_ops,
+                    func_group_mapping,
+                    dispatch_key,
+                    backend_indices,
+                    header,
+                    extend_aoti_c_shim,
+                ),
+                native_functions,
+            )
+        )
+    )
+    device = dispatch_key.lower()
+    warning = """
+
+// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
+// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""
+
+    if header:
+        return (
+            warning
+            + textwrap.dedent("""
+
+            #pragma once
+
+            #include <torch/csrc/inductor/aoti_torch/c/shim.h>
+
+            #ifdef __cplusplus
+            extern "C" {
+            #endif
+
+            """)
+            + body
+            + textwrap.dedent("""
+
+            #ifdef __cplusplus
+            } // extern "C"
+            #endif
+            """)
+        )
+    else:
+        return (
+            warning
+            + textwrap.dedent(f"""
+
+            #include <torch/csrc/inductor/aoti_torch/generated/c_shim_{device}.h>
+            #include <torch/csrc/inductor/aoti_torch/utils.h>
+
+            #ifndef AT_PER_OPERATOR_HEADERS
+            #include <ATen/{str(dispatch_key)}Functions.h>
+            #include <ATen/CompositeExplicitAutogradFunctions.h>
+            #include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
+            #include <ATen/CompositeImplicitAutogradFunctions.h>
+            #else
+            """)
+            + includes
+            + textwrap.dedent("""
+            #endif // AT_PER_OPERATOR_HEADERS
+
+            using namespace torch::aot_inductor;
+
+            """)
+            + body
+        )
+
+
+def gen_aoti_c_shim_files(
+    aoti_fm: FileManager,
+    aoti_backends: set[DispatchKey],
+    native_functions: Sequence[NativeFunction],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    structured_native_functions: Sequence[NativeFunctionsGroup],
+    extra_cuda_headers: str,
+    extend_aoti_c_shim: bool,
+    update_aoti_c_shim: bool,
+) -> None:
+    structured_func_group_dict = {}
+    for func_group in structured_native_functions:
+        for func in func_group.functions():
+            if func.structured_delegate is not None:
+                structured_func_group_dict[func.structured_delegate] = func_group
+                break
+
+    for dispatch_key in aoti_backends:
+        fallbacks = {}
+        for func in native_functions:
+            op_name = get_fallback_op_name(func)
+            if op_name in inductor_fallback_ops:
+                fallbacks[op_name] = func
+        fallback_native_functions = tuple(
+            value for _, value in sorted(fallbacks.items())
+        )
+
+        # header files were checked in for ABI-compatibility checking
+        header_file_name = f"c_shim_{dispatch_key.lower()}.h"
+        new_header = gen_aoti_c_shim(
+            fallback_native_functions,
+            inductor_fallback_ops,
+            structured_func_group_dict,
+            dispatch_key,
+            backend_indices,
+            header=True,
+            extend_aoti_c_shim=extend_aoti_c_shim,
+            includes="",
+        )
+        if update_aoti_c_shim:
+            aoti_fm.write(
+                header_file_name,
+                lambda: new_header,
+            )
+        else:
+            try:
+                with open(
+                    os.path.join(aoti_fm.install_dir, header_file_name)
+                ) as old_file:
+                    old_header = old_file.read()
+
+                    if old_header != new_header:
+                        diff = "\n".join(
+                            difflib.unified_diff(
+                                old_header.splitlines(),
+                                new_header.splitlines(),
+                                fromfile="expected",
+                                tofile="actual",
+                                lineterm="",
+                            )
+                        )
+
+                        raise RuntimeError(f"""
+The generated AOTInductor C shim header files have unexpectedly changed. This
+indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
+This is only allowed in a limited number of situations:
+
+1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
+If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
+existing C shim header files.
+
+2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
+change in the AOTInductor land. You need to annotate the new default argument in
+torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
+update the C shim header files by creating different versions of the fallback op. See
+https://github.com/pytorch/pytorch/pull/154848 as an example.
+
+{diff}
+                    """)
+            except FileNotFoundError:
+                print(
+                    f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
+                )
+
+        # cpp files are always generated on-the-fly
+        def headers_for_aoti() -> str:
+            headers = []
+            for func in fallback_native_functions:
+                header = get_header_for_aoti(
+                    func,
+                    structured_func_group_dict,
+                    dispatch_key,
+                    backend_indices,
+                    extend_aoti_c_shim=extend_aoti_c_shim,
+                )
+                if header is not None:
+                    headers.append(header)
+            return "\n".join(sorted(set(headers)))
+
+        extra_headers = extra_cuda_headers if is_cuda_dispatch_key(dispatch_key) else ""
+
+        aoti_fm.write(
+            f"c_shim_{dispatch_key.lower()}.cpp",
+            lambda: gen_aoti_c_shim(
+                fallback_native_functions,
+                inductor_fallback_ops,
+                structured_func_group_dict,
+                dispatch_key,
+                backend_indices,
+                header=False,
+                extend_aoti_c_shim=extend_aoti_c_shim,
+                includes=headers_for_aoti() + "\n" + extra_headers,
+            ),
+        )
diff --git a/phivenv/Lib/site-packages/torchgen/gen_backend_stubs.py b/phivenv/Lib/site-packages/torchgen/gen_backend_stubs.py
new file mode 100644
index 0000000000000000000000000000000000000000..f15bde68f133894e886e05345e4ec1adda278c74
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_backend_stubs.py
@@ -0,0 +1,615 @@
+from __future__ import annotations
+
+import argparse
+import os
+import re
+from collections import Counter, defaultdict, namedtuple
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import yaml
+
+import torchgen.api.dispatcher as dispatcher
+import torchgen.dest as dest
+from torchgen.api.types import DispatcherSignature
+from torchgen.code_template import CodeTemplate
+from torchgen.context import native_function_manager
+from torchgen.gen import get_grouped_native_functions, parse_native_yaml
+from torchgen.model import (
+    BackendIndex,
+    BackendMetadata,
+    DispatchKey,
+    NativeFunction,
+    NativeFunctionsGroup,
+    OperatorName,
+)
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target
+from torchgen.yaml_utils import YamlLoader
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
+# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
+ParsedExternalYaml = namedtuple(
+    "ParsedExternalYaml",
+    ["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"],
+)
+
+
+def parse_backend_yaml(
+    backend_yaml_path: str,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_indices: dict[DispatchKey, BackendIndex],
+) -> ParsedExternalYaml:
+    native_functions_map: dict[OperatorName, NativeFunction] = {
+        f.func.name: f
+        for f in concatMap(
+            lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
+            grouped_native_functions,
+        )
+    }
+
+    with open(backend_yaml_path) as f:
+        yaml_values = yaml.load(f, Loader=YamlLoader)
+    assert isinstance(yaml_values, dict)
+
+    valid_keys = [
+        "backend",
+        "class_name",
+        "cpp_namespace",
+        "extra_headers",
+        "supported",
+        "autograd",
+        "full_codegen",
+        "non_native",
+        "ir_gen",
+        "symint",
+    ]
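+    # A minimal backend yaml accepted here looks roughly like this (illustrative values only):
+    #   backend: XLA
+    #   cpp_namespace: torch_xla
+    #   supported:
+    #   - abs
+    #   - add.Tensor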
+
+    backend = yaml_values.pop("backend", None)
+    assert backend is not None, 'You must provide a value for "backend"'
+
+    class_name = yaml_values.pop("class_name", None)
+
+    cpp_namespace = yaml_values.pop("cpp_namespace", None)
+    assert cpp_namespace is not None, 'You must provide a value for "cpp_namespace"'
+
+    # Mostly just defaulting to false to stick with LazyTensor convention.
+    use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
+    assert isinstance(use_out_as_primary, bool), (
+        f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"
+    )
+
+    use_device_guard = yaml_values.pop("device_guard", False)
+    assert isinstance(use_device_guard, bool), (
+        f"You must provide either True or False for device_guard. Provided: {use_device_guard}"
+    )
+
+    supported = yaml_values.pop("supported", [])
+    if supported is None:
+        supported = []  # Allow an empty list of supported ops
+    assert isinstance(supported, list), (
+        f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
+    )
+
+    symint = yaml_values.pop("symint", [])
+    if symint is None:
+        symint = []  # Allow an empty list of symint ops
+    assert isinstance(symint, list), (
+        f'expected "symint" to be a list, but got: {supported} (of type {type(supported)})'
+    )
+    symint_set = set(symint)
+
+    supported_autograd = yaml_values.pop("autograd", [])
+    assert isinstance(supported_autograd, list), (
+        f'expected "autograd" to be a list, but got: {supported_autograd}'
+    )
+
+    # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
+    full_codegen = yaml_values.pop("full_codegen", [])
+    supported.extend(full_codegen)
+
+    # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
+    yaml_values.pop("non_native", {})
+
+    # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
+    yaml_values.pop("ir_gen", {})
+
+    assert len(yaml_values.keys()) == 0, (
+        f"{backend_yaml_path} contains unexpected keys: {', '.join(yaml_values.keys())}. "
+        f"Only the following keys are supported: {', '.join(valid_keys)}"
+    )
+
+    def create_backend_index(
+        backend_ops: list[str],
+        symint_ops: set[str],
+        dispatch_key: DispatchKey,
+        *,
+        use_out_as_primary: bool,
+        use_device_guard: bool,
+    ) -> BackendIndex:
+        metadata: dict[OperatorName, BackendMetadata] = {}
+        for op in backend_ops:
+            op_name = OperatorName.parse(op)
+            assert op_name in native_functions_map, (
+                f"Found an invalid operator name: {op_name}"
+            )
+            # See Note [External Backends Follow Dispatcher API]
+            kernel_name = dispatcher.name(native_functions_map[op_name].func)
+            if op in symint_ops:
+                kernel_name += "_symint"
+            # TODO: allow structured external backends later.
+            m = BackendMetadata(
+                kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
+            )
+            metadata[op_name] = m
+        return BackendIndex(
+            dispatch_key=dispatch_key,
+            use_out_as_primary=use_out_as_primary,
+            external=True,
+            device_guard=use_device_guard,
+            index=metadata,
+        )
+
+    backend_key: DispatchKey | None = None
+    if len(supported) > 0:
+        with context(
+            lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
+        ):
+            backend_key = DispatchKey.parse(backend)
+
+        backend_idx = create_backend_index(
+            supported,
+            symint_set,
+            backend_key,
+            use_out_as_primary=use_out_as_primary,
+            use_device_guard=use_device_guard,
+        )
+        assert backend_key not in backend_indices
+        backend_indices[backend_key] = backend_idx
+
+    autograd_key: DispatchKey | None = None
+    if len(supported_autograd) > 0:
+        with context(
+            lambda: f'The "autograd" key was specified, which indicates that you would like to override \
+the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
+        ):
+            autograd_key = DispatchKey.parse(f"Autograd{backend}")
+
+        autograd_idx = create_backend_index(
+            supported_autograd,
+            symint_set,
+            autograd_key,
+            use_out_as_primary=use_out_as_primary,
+            use_device_guard=use_device_guard,
+        )
+        assert autograd_key not in backend_indices
+        backend_indices[autograd_key] = autograd_idx
+
+    for g in grouped_native_functions:
+        if isinstance(g, NativeFunction):
+            forward_kernels = (
+                []
+                if backend_key is None
+                else [
+                    m
+                    for m in [backend_indices[backend_key].get_kernel(g)]
+                    if m is not None
+                ]
+            )
+            backward_kernels = (
+                []
+                if autograd_key is None
+                else [
+                    m
+                    for m in [backend_indices[autograd_key].get_kernel(g)]
+                    if m is not None
+                ]
+            )
+        else:
+            forward_kernels = (
+                []
+                if backend_key is None
+                else [
+                    m
+                    for m in [
+                        backend_indices[backend_key].get_kernel(f)
+                        for f in g.functions()
+                    ]
+                    if m is not None
+                ]
+            )
+            backward_kernels = (
+                []
+                if autograd_key is None
+                else [
+                    m
+                    for m in [
+                        backend_indices[autograd_key].get_kernel(f)
+                        for f in g.functions()
+                    ]
+                    if m is not None
+                ]
+            )
+
+        forward_kernels = [f for f in forward_kernels if f is not None]
+        backward_kernels = [f for f in backward_kernels if f is not None]
+        assert len(forward_kernels) == 0 or len(backward_kernels) == 0, (
+            f'Currently, all variants of an op must either be registered to a backend key, or to a backend\'s \
+autograd key. They cannot be mixed and matched. If this is something you need, feel free to create an issue! \
+{forward_kernels[0].kernel} is listed under "supported", but {backward_kernels[0].kernel} is listed under "autograd".'
+        )
+
+    return ParsedExternalYaml(
+        backend_key, autograd_key, class_name, cpp_namespace, backend_indices
+    )
+
+
+def error_on_missing_kernels(
+    native_functions: Sequence[NativeFunction],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    backend_key: DispatchKey,
+    autograd_key: DispatchKey | None,
+    class_name: str,
+    kernel_defn_file_path: str,
+    full_codegen: list[OperatorName] | None = None,
+) -> None:
+    try:
+        with open(kernel_defn_file_path) as f:
+            backend_defns = f.read()
+    except OSError as e:
+        raise AssertionError(
+            f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
+        ) from e
+
+    if full_codegen is None:
+        full_codegen = []
+
+    indices = [backend_indices[backend_key].index] + (
+        [] if autograd_key is None else [backend_indices[autograd_key].index]
+    )
+    # Quick mapping from each OperatorName used by the external backend
+    # to its backend kernel name
+    expected_backend_op_names: dict[OperatorName, str] = dict(
+        list(
+            concatMap(
+                lambda index: [
+                    (op_name, metadata.kernel) for op_name, metadata in index.items()
+                ],
+                indices,
+            )
+        )
+    )
+    expected_backend_native_funcs: list[NativeFunction] = [
+        f
+        for f in native_functions
+        if f.func.name in expected_backend_op_names.keys()
+        and f.func.name not in full_codegen
+    ]
+    expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
+        list
+    )
+    for native_f in expected_backend_native_funcs:
+        expected_backend_kernel_name_counts[
+            expected_backend_op_names[native_f.func.name]
+        ].append(native_f)
+
+    # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
+    # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
+    # here, then we get a nicer error message. If we miss it, you get a linker error.
+    kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
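+    # e.g. with class_name "XLANativeFunctions", this matches "at::Tensor XLANativeFunctions::abs(" and captures "abs"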
+    actual_backend_kernel_name_counts = Counter(
+        # A bit unwieldy (this could probably be moved into regex),
+        # but we don't want to include kernel names that come from function calls,
+        # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
+        # Easy check is to ignore any lines with colons before the class name.
+        [
+            y
+            for (x, y) in re.findall(kernel_defn_regex, backend_defns)
+            if not x.endswith(":")
+        ]
+    )
+
+    missing_kernels_err_msg = ""
+    for expected_name, funcs in expected_backend_kernel_name_counts.items():
+        expected_overload_count = len(funcs)
+        actual_overload_count = actual_backend_kernel_name_counts[expected_name]
+        if expected_overload_count != actual_overload_count:
+
+            def create_decl(f: NativeFunction) -> str:
+                with native_function_manager(f):
+                    return DispatcherSignature.from_schema(f.func).decl()
+
+            expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
+            missing_kernels_err_msg += f"""
+{class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
+but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
+{expected_schemas_str}
+
+"""
+    assert missing_kernels_err_msg == "", missing_kernels_err_msg
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate backend stub files")
+    parser.add_argument(
+        "-s",
+        "--source-yaml",
+        "--source_yaml",
+        help="path to source yaml file containing operator external definitions",
+    )
+    parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
+    parser.add_argument(
+        "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
+    )
+    parser.add_argument(
+        "--impl-path",
+        "--impl_path",
+        type=str,
+        default=None,
+        help="path to the source C++ file containing kernel definitions",
+    )
+    options = parser.parse_args()
+
+    run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path)
+
+
+def gen_dispatchkey_nativefunc_headers(
+    fm: FileManager,
+    class_name: str,
+    cpp_namespace: str,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_dispatch_key: DispatchKey,
+    autograd_dispatch_key: DispatchKey | None,
+    backend_name: str = "",
+) -> None:
+    assert class_name is not None
+    generated_comment = (
+        "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
+    )
+
+    # Convert to a set first to remove duplicate kernel names.
+    # Backends are allowed to repeat kernel names; only generate the declaration once!
+    # Sort for deterministic output.
+    backend_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: dest.compute_native_function_declaration(
+                    f, backend_indices[backend_dispatch_key]
+                ),
+                grouped_native_functions,
+            )
+        )
+    )
+    autograd_declarations = sorted(
+        set(
+            concatMap(
+                lambda f: []
+                if autograd_dispatch_key is None
+                else dest.compute_native_function_declaration(
+                    f, backend_indices[autograd_dispatch_key]
+                ),
+                grouped_native_functions,
+            )
+        )
+    )
+
+    ns_helper = NamespaceHelper(cpp_namespace)
+    fm.write_with_template(
+        f"{backend_dispatch_key}NativeFunctions.h",
+        "DispatchKeyNativeFunctions.h",
+        lambda: {
+            "generated_comment": generated_comment,
+            "namespace_prologue": ns_helper.prologue,
+            "class_name": class_name,
+            "namespace_epilogue": ns_helper.epilogue,
+            "dispatch_declarations": backend_declarations + autograd_declarations,
+            "BackendName": backend_name,
+            "DispatchKey": backend_dispatch_key,
+        },
+    )
+
+
+def gen_dispatcher_registrations(
+    fm: FileManager,
+    output_dir: str,
+    class_name: str,
+    backend_indices: dict[DispatchKey, BackendIndex],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_dispatch_key: DispatchKey,
+    dispatch_key: DispatchKey,
+    selector: SelectiveBuilder,
+    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
+    build_in_tree: bool = False,
+    per_operator_headers: bool = False,
+    backend_name: str = "",
+    eager_registration: bool = True,
+) -> None:
+    headers = [
+        f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
+    ]
+    if build_in_tree:
+        external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers)
+    else:
+        external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers)
+
+    assert class_name is not None
+    backend_index = backend_indices[dispatch_key]
+
+    dispatch_registrations_body = list(
+        concatMap(
+            dest.RegisterDispatchKey(
+                backend_index,
+                Target.REGISTRATION,
+                selector,
+                rocm=False,
+                symint=True,
+                class_method_name=f"{class_name}",
+                skip_dispatcher_op_registration=False,
+            ),
+            grouped_native_functions,
+        )
+    )
+    newline = "\n"
+    ns_helper = NamespaceHelper(namespace_str="at")
+    deferred_dispatch_registrations = ""
+    static_init_dispatch_registrations = ""
+    if eager_registration:
+        static_template = CodeTemplate(
+            """\
+TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
+    $dispatch_registrations_body
+}"""
+        )
+        static_init_dispatch_registrations = static_template.substitute(
+            dispatch_key=dispatch_key,
+            dispatch_registrations_body=dispatch_registrations_body,
+        )
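+        # This expands to roughly (sketch, for a hypothetical XLA backend):
+        #   TORCH_LIBRARY_IMPL(aten, XLA, m) {
+        #       m.impl("abs", ...);  // one registration per kernel
+        #   }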
+    else:
+        deferred_template = CodeTemplate(
+            """\
+TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions();
+TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
+    static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
+    $dispatch_registrations_body
+}"""
+        )
+        deferred_dispatch_registrations = deferred_template.substitute(
+            backend_name=backend_name,
+            dispatch_key=dispatch_key,
+            dispatch_registrations_body=dispatch_registrations_body,
+        )
+
+    fm.write_with_template(
+        f"Register{dispatch_key}.cpp",
+        "RegisterDispatchKey.cpp",
+        lambda: {
+            "extra_cuda_headers": "",
+            "external_backend_headers": external_backend_headers_str,
+            "ops_headers": "#include "
+            if not per_operator_headers
+            else "",
+            "DispatchKey": dispatch_key,
+            "dispatch_namespace": dispatch_key.lower(),
+            "dispatch_headers": dest.gen_registration_headers(
+                backend_index, per_operator_headers=per_operator_headers, rocm=False
+            ),
+            "dispatch_helpers": dest.gen_registration_helpers(backend_index),
+            "dispatch_definitions": fm.substitute_with_template(
+                "RegisterDispatchDefinitions.ini",
+                lambda: {
+                    "ns_prologue": ns_helper.prologue,
+                    "ns_epilogue": ns_helper.epilogue,
+                    "static_init_dispatch_registrations": static_init_dispatch_registrations,
+                    "deferred_dispatch_registrations": deferred_dispatch_registrations,
+                    "dispatch_namespace": dispatch_key.lower(),
+                    "dispatch_namespaced_definitions": "",
+                    "dispatch_anonymous_definitions": list(
+                        concatMap(
+                            dest.RegisterDispatchKey(
+                                backend_index,
+                                Target.ANONYMOUS_DEFINITION,
+                                selector,
+                                rocm=False,
+                                symint=True,
+                                class_method_name=f"{class_name}",
+                                skip_dispatcher_op_registration=False,
+                            ),
+                            grouped_native_functions,
+                        )
+                    ),
+                },
+            ).split(newline),
+        },
+    )
+
+
+def run(
+    source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
+) -> None:
+    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
+    pytorch_root = Path(__file__).absolute().parent.parent
+    template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
+
+    def make_file_manager(install_dir: str) -> FileManager:
+        return FileManager(
+            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
+        )
+
+    fm = make_file_manager(output_dir)
+
+    native_yaml_path = os.path.join(
+        pytorch_root, "aten/src/ATen/native/native_functions.yaml"
+    )
+    tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml")
+    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
+    native_functions, backend_indices = (
+        parsed_yaml.native_functions,
+        parsed_yaml.backend_indices,
+    )
+    grouped_native_functions = get_grouped_native_functions(native_functions)
+    parsed_backend_yaml = parse_backend_yaml(
+        source_yaml, grouped_native_functions, backend_indices
+    )
+    backend_key = parsed_backend_yaml.backend_key
+    autograd_key = parsed_backend_yaml.autograd_key
+    cpp_namespace = parsed_backend_yaml.cpp_namespace
+    class_name = parsed_backend_yaml.class_name
+    backend_indices = parsed_backend_yaml.backend_indices
+
+    selector = SelectiveBuilder.get_nop_selector()
+
+    if backend_key is None:
+        # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet.
+        return
+
+    if class_name is None:
+        # class_name is an optional argument to backend yaml file.
+        # if specified it allows an external backend to override
+        # the name of the class that all generated kernel definitions live under.
+        # if not specified, its value is given as native_function_class_name.
+        class_name = backend_indices[backend_key].native_function_class_name()
+    assert class_name is not None
+
+    if impl_path is not None:
+        error_on_missing_kernels(
+            native_functions,
+            backend_indices,
+            backend_key,
+            autograd_key,
+            class_name,
+            impl_path,
+        )
+
+    gen_dispatchkey_nativefunc_headers(
+        fm,
+        class_name,
+        cpp_namespace,
+        backend_indices,
+        grouped_native_functions,
+        backend_key,
+        autograd_key,
+    )
+
+    for dispatch_key in (
+        [backend_key] if autograd_key is None else [backend_key, autograd_key]
+    ):
+        gen_dispatcher_registrations(
+            fm,
+            output_dir,
+            class_name,
+            backend_indices,
+            grouped_native_functions,
+            backend_key,
+            dispatch_key,
+            selector,
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/gen_functionalization_type.py b/phivenv/Lib/site-packages/torchgen/gen_functionalization_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..5379929591e881cd4ca8d2cc58e41da4de8a8d81
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_functionalization_type.py
@@ -0,0 +1,882 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, TYPE_CHECKING
+
+from torchgen.api import cpp, dispatcher
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    CType,
+    DispatcherSignature,
+    FunctionalizationLambda,
+    iTensorListRefT,
+    NativeSignature,
+    OptionalCType,
+    optionalSymIntArrayRefT,
+    symIntArrayRefT,
+    SymIntT,
+    tensorListT,
+    tensorT,
+    VectorCType,
+    ViewInverseSignature,
+)
+from torchgen.context import (
+    method_with_native_function,
+    native_function_manager,
+    with_native_function,
+    with_native_function_and,
+)
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+    Return,
+    SchemaKind,
+    SelfArgument,
+    TensorOptionsArguments,
+)
+from torchgen.native_function_generation import (
+    INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
+    MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
+    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
+)
+from torchgen.utils import dataclass_repr
+
+
+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
+# Note: [Mutable Ops Not Using Functionalization]
+# Ops in this list currently do not work with functionalization and should be fixed.
+MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
+    OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+    + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+    + INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+    + [
+        # It will be BC-breaking, but we should fix their schemas.
+        # should be inplace?
+        "record_stream",
+        # See Note [resize_ in Functionalization]
+        "resize_",
+        "resize_as_",
+        # This function is used for testing purposes only.
+        "_fill_mem_eff_dropout_mask_",
+    ]
+)
+
+# This file contains codegen that relates to the functionalization pass.
+# It includes:
+# - gen_functionalization_definition
+#     Generates dispatcher kernel definitions for the functionalization pass.
+# - gen_functionalization_registration
+#     Generates dispatcher kernel registrations for the functionalization pass.
+# - gen_functionalization_view_inverse_declaration
+#     Generates a declaration for an "inverse view", for every view op
+#     that is needed in functionalization. We manually implement their definitions.
+# - gen_composite_view_copy_kernel
+#     Generates view_copy() composite kernels for all view_copy operators.
+
+
+# Generates the body of the default composite C++ kernel for a {view}_copy NativeFunction
+# See Note [view_copy NativeFunctions]
+@dataclass(frozen=True)
+class GenCompositeViewCopyKernel:
+    backend_index: BackendIndex
+
+    @method_with_native_function
+    def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
+        if g.view_copy is None:
+            return None
+        elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
+            # If the view_copy doesn't match the standard naming scheme of _copy,
+            # assume it already exists and doesn't need to be generated.
+            # Example: slice_inverse() with the copy variant named slice_scatter()
+            # instead of slice_inverse_copy()
+            return None
+
+        metadata = self.backend_index.get_kernel(g.view_copy)
+        assert metadata is not None
+
+        # We can make view_copy work in more cases by using reshape()
+        # when a normal view call would ordinarily fail.
+        # This also makes LTC more efficient, because they don't need to include
+        # clone() calls in their graph (which is normally needed by reshape).
+        if str(g.view_copy.func.name) == "view_copy":
+            assert metadata.kernel == "view_copy_symint"
+            return """\
+at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
+  c10::SymDimVector shape = infer_size_dv(size, self.sym_numel());
+  if (!at::detail::computeStride(self.sym_sizes(), self.sym_strides(), shape).has_value()) {
+    return self.reshape_symint(size);
+  } else {
+    auto output = at::_ops::view::call(self, size);
+    return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);
+  }
+}
+"""
+        # view_copy is a native signature, since we're generating an at::native:: kernel
+        # Functionalization always operates on symints though
+        view_copy_sig = NativeSignature(
+            g.view_copy.func, symint=metadata.supports_symint()
+        )
+
+        # view is a dispatcher signature, since we're calling into the at::_ops API
+        view_sig = DispatcherSignature(g.view.func)
+
+        view_api_name = g.view.func.name.unambiguous_name()
+        exprs = ", ".join(
+            [e.expr for e in translate(view_copy_sig.arguments(), view_sig.arguments())]
+        )
+
+        # view ops today always return either a Tensor or a list of Tensors
+        assert len(g.view.func.returns) == 1
+        assert g.view.func.returns[0].type == BaseType(
+            BaseTy.Tensor
+        ) or g.view.func.returns[0].type == ListType(BaseType(BaseTy.Tensor), None)
+
+        if g.view.func.returns[0].type == BaseType(BaseTy.Tensor):
+            return_cloned_output = """\
+  return output.clone(/*memory_format=*/at::MemoryFormat::Contiguous);"""
+        else:
+            # If the return type is a list, we need to clone each tensor in the list.
+            return_cloned_output = f"""\
+  {view_copy_sig.returns_type().cpp_type()} out_clone;
+  for (const auto i : c10::irange(output.size())) {{
+    out_clone.push_back(output[i].clone(/*memory_format=*/at::MemoryFormat::Contiguous));
+  }}
+  return out_clone;"""
+
+        # The default generated composite kernel for {view}_copy() operators just clones
+        # the input tensor, and runs the underlying view on the clone.
+        return f"""
+{view_copy_sig.defn(name=metadata.kernel)} {{
+  auto output = at::_ops::{view_api_name}::call({exprs});
+  {return_cloned_output}
+}}
+"""
+
+
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
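+    # e.g. a single return named "out" yields "return out;"; multiple returns are packed
+    # into the op's tuple return type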
+    assert len(rets) == len(names)
+    if len(rets) == 0:
+        return ""
+    elif len(rets) == 1:
+        return f"return {names[0]};"
+    else:
+        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
+
+
+def modifies_arguments(f: NativeFunction) -> bool:
+    return any(
+        a.annotation is not None and a.annotation.is_write
+        for a in f.func.arguments.flat_all
+    )
+
+
+def wrapper_name(func: FunctionSchema) -> str:
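+    # e.g. "add_Tensor" for aten::add.Tensor; an op without an overload name just uses its cpp name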
+    if func.name.overload_name:
+        return f"{cpp.name(func)}_{func.name.overload_name}"
+    else:
+        return cpp.name(func)
+
+
+def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
+    return isinstance(a, SelfArgument) or (
+        isinstance(a, Argument) and a.type.is_tensor_like()
+    )
+
+
+# We need to wrap / unwrap various arguments from the op in the functionalization kernels.
+# Some op schemas include non-owning types though (like TensorList),
+# and when we unwrap them we expect to get out an owning type!
+# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
+def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
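+    # e.g. the non-owning at::TensorList becomes an owning std::vector<at::Tensor>, materialized via .vec()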
+    if t == BaseCType(tensorListT):
+        return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
+    if t == BaseCType(iTensorListRefT):
+        return VectorCType(BaseCType(tensorT)), lambda x: f"{{{x}.begin(), {x}.end()}}"
+    # There are technically other non-owning types out there (like IntArrayRef),
+    # but functionalization only actually cares about the ones involving tensors.
+    return t, lambda x: x
+
+
+# unwraps all tensor-like arguments, returning:
+# (1) a string containing all of the logic that does the unwrapping
+# (2) a context, to be used by translate(), with all of the relevant bindings.
+def unwrap_tensor_args(
+    sig: DispatcherSignature, *, is_view_op: bool
+) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
+    for arg in sig.arguments():
+        if is_tensor_like(arg.argument):
+            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
+            unwrapped_name = f"{arg.name}_"
+            # For most ops, the functionalization needs to sync any pending updates on the input tensors
+            # before calling the operator, since otherwise the operator will act on stale data.
+            # For view ops though, we can continue to defer syncing until the tensor is used by
+            # a non-view operator.
+            maybe_sync_input = (
+                "" if is_view_op else f"at::functionalization::impl::sync({arg.name});"
+            )
+            unwrapped_type, conversion_fn = get_owning_type(
+                arg.nctype.remove_const_ref().type
+            )
+            unwrapped_tensor_args.append(
+                f"""
+      {unwrapped_type.cpp_type()} {unwrapped_name};
+      if (at::functionalization::impl::isFunctionalTensor({arg.name})) {{
+        {maybe_sync_input}
+        {unwrapped_name} = at::functionalization::impl::from_functional_tensor({arg.name});
+      }} else {{
+        {unwrapped_name} = {conversion_fn(arg.name)};
+      }}"""
+            )
+            context.append(arg.with_name(unwrapped_name))
+        else:
+            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
+            context.append(arg)
+    unwrap_tensor_args_str = "\n      ".join(unwrapped_tensor_args)
+    return unwrap_tensor_args_str, context
+
+
+# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
+# (1) a string containing all of the logic that does the conversions.
+# (2) a context, to be used by translate(), with all of the relevant bindings.
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
+    for arg in sig.arguments():
+        if is_tensor_like(arg.argument):
+            # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
+            a_ = arg.name
+            unwrapped_name = f"{arg.name}_meta"
+            unwrapped_tensor_args.append(f"auto {unwrapped_name} = to_meta({a_});")
+            context.append(arg.with_name(unwrapped_name))
+        else:
+            # for non-tensor inputs, we want to pass them directly into the redispatch calls.
+            context.append(arg)
+    unwrap_tensor_args_str = "\n        ".join(unwrapped_tensor_args)
+    return unwrap_tensor_args_str, context
+
+
+# The functionalization codegen currently expects view op schemas to have this form:
+# foo(Tensor(a), ...) -> Tensor(a) (e.g. transpose)
+# foo(Tensor(a!), ...) -> Tensor(a!) (e.g. transpose_)
+def assert_view_op_properties(func: FunctionSchema) -> None:
+    def is_alias(a: Argument) -> bool:
+        return a.annotation is not None
+
+    args = func.arguments.flat_non_out
+    # The first argument is a tensor with an alias semantics (annotations)
+    assert (
+        len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor)
+    ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor,
+but found an argument of type {str(args[0].type)} for operator: {str(func.name)}."""
+    # No other arguments have aliasing semantics
+    assert (
+        is_alias(args[0]) and not any(is_alias(a) for a in args[1:])
+    ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output.
+View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint"""
+
+
+# One-liner expression for checking if an expression expr of type type has any
+# symbolic values.
+def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
+    if type == BaseCType(SymIntT):
+        return f"{expr}.is_symbolic()"
+
+    if isinstance(type, OptionalCType):
+        innerexpr = f"(*{expr})"
+        return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"
+
+    if type == BaseCType(optionalSymIntArrayRefT):
+        return emit_expr_has_symbolic_values(
+            expr, OptionalCType(BaseCType(symIntArrayRefT))
+        )
+
+    if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
+        argname = "arg"
+        lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
+        return (
+            "std::any_of("
+            f"{expr}.begin(), {expr}.end(), "
+            f"[=](auto& {argname}) {{ return {lambda_check}; }})"
+        )
+
+    raise ValueError(
+        "unsupported type for has_symbolic_values check. "
+        "It should be a SymInt or a collection of those. "
+        f"Got: {type.cpp_type()}"
+    )
+
+
+# Detects whether any of the SymInt arguments are, in fact, symbolic values.
+# This is used in the constructor of ViewMeta.
+def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
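+    # For a single SymInt argument "dim", this roughly generates:
+    #   bool has_symbolic_inputs = false;
+    #   has_symbolic_inputs = has_symbolic_inputs | (dim.is_symbolic());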
+    name = "has_symbolic_inputs"
+    statements = [
+        f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
+        for binding in sig.arguments()
+        if (
+            isinstance(binding.argument, Argument)
+            and binding.argument.type.is_symint_like()
+        )
+    ]
+    body = "\n      ".join(statements)
+    return (
+        name,
+        f"""
+      bool {name} = false;
+      {body}""",
+    )
+
+
+# Generates the Functionalization kernel for:
+# - ops that create aliases (e.g. transpose())
+# - ops that are views AND mutations (e.g. transpose_())
+def emit_view_functionalization_body(
+    g: NativeFunctionsViewGroup, *, view_inplace: bool
+) -> str:
+    if view_inplace:
+        # This op is both an inplace op AND a view op.
+        # See Note [Functionalization Pass - Inplace View Ops] for details.
+        # I currently have the view meta call into the out-of-place variant of the view, to avoid
+        # having to define an extra ~20 inplace {view}_inverse_ functions.
+        # Most view ops don't come in NativeFunctionsGroup pairs, because we don't define out= variants for view ops.
+        # I'm assuming that every inplace-view op has a corresponding out-of-place view op,
+        # with the same name but the trailing underscore removed.
+        # This is currently asserted at parse time in gen.py (see error_check_native_functions).
+        assert g.view_inplace is not None
+        f = g.view_inplace
+    else:
+        f = g.view
+
+    assert g.view_copy is not None
+    with native_function_manager(f):
+        call_sig = DispatcherSignature.from_schema(g.view_copy.func)
+
+        # the "view_copy" op name that the functionalization kernels need to call
+        api_name = g.view_copy.func.name.unambiguous_name()
+        # Sometimes the functionalization pass needs to no-op (e.g. if it was passed non-functional tensors)
+        # "no-op"ing in this context is just redispatching to the original op.
+        noop_api_name = f.func.name.unambiguous_name()
+
+        dispatcher_sig = DispatcherSignature.from_schema(f.func)
+        assert_view_op_properties(f.func)
+        view_tensor_name = dispatcher_sig.arguments()[0].name
+
+        return_type = dispatcher_sig.returns_type().remove_const_ref().cpp_type()
+
+        unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
+            dispatcher_sig, is_view_op=True
+        )
+        view_redispatch_args = [
+            e.expr
+            for e in translate(unwrapped_args_ctx, call_sig.arguments(), method=False)
+        ]
+
+        forward_lambda = FunctionalizationLambda.from_func(g, is_reverse=False)
+        reverse_lambda = FunctionalizationLambda.from_func(g, is_reverse=True)
+
+        # The meta API call should use the same arguments, but convert all tensors to meta tensors first.
+        meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+        meta_call_args = [
+            e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
+        ]
+
+        (
+            symbolic_inputs_varname,
+            symbolic_inputs_check,
+        ) = emit_has_symbolic_inputs(call_sig)
+
+        if "inplace_view" in f.tags:
+            # See Note [Functionalization Pass - Inplace View Ops] for more details
+            return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
+        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
+        {unwrap_tensor_args_str}
+        at::AutoDispatchSkipFunctionalize guard;
+        return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
+      }}
+      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
+      auto inverse_return_mode = (
+          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
+            : at::functionalization::InverseReturnMode::NeverView
+      );
+      {symbolic_inputs_check}
+      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
+        {forward_lambda.decl()} {{
+          if (reapply_views) {{
+            return {forward_lambda.inner_call(reapply_views=True)}
+          }} else {{
+            return {forward_lambda.inner_call(reapply_views=False)}
+          }}
+        }},
+        {reverse_lambda.decl()} {{
+          return {reverse_lambda.inner_call()}
+        }},
+        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
+      );
+      auto compute_reference_meta =
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
+      {return_type} reference_tensor_output;
+      if (compute_reference_meta && !disable_meta_reference()) {{
+        {meta_conversion_str}
+        at::AutoDispatchSkipFunctionalize func_guard;
+        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
+        reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
+      }}
+      // This function adds the above view meta to the current tensor and replays them off the base,
+      // mutating the size/stride info of the current FunctionalTensorWrapper.
+      // Because of this, we need to make sure to run the reference shape function above,
+      // BEFORE doing this (otherwise we'll end up running the reference function using the wrong sizes/strides)
+      at::functionalization::impl::mutate_view_meta({view_tensor_name}, view_meta);
+      // See  Note [Propagating strides in the functionalization pass]
+      // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
+      // on a reference implementation here (instead of relying on the output from the forward lambda
+      // having the correct stride info)
+      if (compute_reference_meta && !disable_meta_reference()) {{
+        at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
+      }}
+      return {view_tensor_name};
+    }}
+"""
+
+        else:
+            is_multi_output_view = isinstance(f.func.returns[0].type, ListType)
+            return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      {unwrap_tensor_args_str}
+      if (!at::functionalization::impl::isFunctionalTensor({view_tensor_name})) {{
+        // functionalization is re-entrant, but will no-op if it wasn't passed a FunctionalTensorWrapper.
+        at::AutoDispatchSkipFunctionalize guard;
+        return at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
+      }}
+      auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
+      auto inverse_return_mode = (
+          reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
+            : at::functionalization::InverseReturnMode::NeverView
+      );
+      auto compute_reference_meta =
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
+        {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
+      {return_type} reference_tensor_output;
+      if (compute_reference_meta && !disable_meta_reference()) {{
+        {meta_conversion_str}
+        at::AutoDispatchSkipFunctionalize func_guard;
+        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
+        reference_tensor_output = at::_ops::{noop_api_name}::call({", ".join(meta_call_args)});
+      }}
+      {return_type} tmp_output;
+      {{
+        at::AutoDispatchSkipFunctionalize guard;
+        if (reapply_views) {{
+          tmp_output = at::_ops::{noop_api_name}::call({", ".join(view_redispatch_args)});
+        }} else {{
+          tmp_output = at::_ops::{api_name}::call({", ".join(view_redispatch_args)});
+        }}
+      }}
+      {symbolic_inputs_check}
+      at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
+        {forward_lambda.decl()} {{
+          if (reapply_views) {{
+            return {forward_lambda.inner_call(reapply_views=True)}
+          }} else {{
+            return {forward_lambda.inner_call(reapply_views=False)}
+          }}
+        }},
+        {reverse_lambda.decl()} {{
+          return {reverse_lambda.inner_call()}
+        }},
+        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
+        /*is_multi_output=*/{str(is_multi_output_view).lower()},
+        /*is_as_strided=*/{str(str(f.func.name) == "as_strided").lower()}
+      );
+      auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
+      // See  Note [Propagating strides in the functionalization pass]
+      if (compute_reference_meta && !disable_meta_reference()) {{
+        at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
+      }}
+      return out;
+    }}
+"""
+
+
+def maybe_create_output(f: NativeFunction, var_name: str) -> str:
+    if len(f.func.returns) == 0:
+        return ""
+    return_type = dispatcher.returns_type(f.func.returns).remove_const_ref().cpp_type()
+    return f"{return_type} {var_name} = "
+
+
+# Given a NativeFunction, and a variable name corresponding to the output of redispatching on the function,
+# this returns two lists of names, consisting of:
+# - the names of returns corresponding to the original (mutable) inputs of the outer function
+# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
+def get_mutable_redispatch_return_names(
+    f: NativeFunction, inner_return_var: str
+) -> tuple[list[str], list[str]]:
+    aliased_returns = []
+    non_aliased_returns = []
+    for i, name in enumerate(f.func.aliased_return_names()):
+        if name is not None:
+            aliased_returns.append(name)
+        else:
+            non_aliased_returns.append(
+                inner_return_var
+                if len(f.func.returns) == 1
+                else f"std::get<{i}>({inner_return_var})"
+            )
+    return aliased_returns, non_aliased_returns
+
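+# Illustrative example of get_mutable_redispatch_return_names (hypothetical schema,
+# not taken from native_functions.yaml): for an op
+#   foo_(Tensor(a!) self, Tensor other) -> (Tensor(a!), Tensor)
+# with inner_return_var="tmp_output", aliased_return_names() is ["self", None], so the
+# helper returns (["self"], ["std::get<1>(tmp_output)"]).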
+
+# When functionalization "no-op's" and redispatches on a mutable operator, we need to take care so that:
+#  - For fresh outputs, we return the result of the redispatch (without wrapping outputs)
+#  - For outputs that were aliased to inputs, we return the inputs directly (since some of them might have been wrapped)
+def return_from_mutable_noop_redispatch(
+    f: NativeFunction, inner_return_var: str
+) -> str:
+    aliased, non_aliased = get_mutable_redispatch_return_names(f, inner_return_var)
+    # Just get all of the return names, and immediately return them
+    return return_str(f.func.returns, aliased + non_aliased)
+
+
+def wrap_propagate_mutations_and_return(
+    f: NativeFunction, functional_op: NativeFunction, inner_return_var: str
+) -> str:
+    mutable_arg_names = f.func.arguments.mutable_arg_names()
+    (
+        aliased_outer_rets,
+        non_aliased_outer_rets,
+    ) = get_mutable_redispatch_return_names(f, inner_return_var)
+    _, non_aliased_inner_rets = get_mutable_redispatch_return_names(
+        functional_op, inner_return_var
+    )
+    # The outer function may have a mix of aliased and non-aliased outputs,
+    # but the inner functional op that we're transforming to should only have non-aliased outputs.
+    assert len(mutable_arg_names) + len(non_aliased_outer_rets) == len(
+        non_aliased_inner_rets
+    )
+
+    # First, take all of the newly created outputs from the inner call and wrap them into functional tensors
+    updates = []
+    non_aliased_wrapped_ret_names = []
+    for i, inner_ret in enumerate(
+        non_aliased_inner_rets[: len(non_aliased_outer_rets)]
+    ):
+        ret_name = f"output_{i}"
+        updates.append(
+            f"""\
+  auto output_{i} = at::functionalization::impl::to_functional_tensor({inner_ret});"""
+        )
+        non_aliased_wrapped_ret_names.append(ret_name)
+
+    # Next, take all of the mutated outputs from the inner call corresponding to mutated inputs,
+    # and propagate the mutations
+    for outer_arg, inner_ret in zip(
+        mutable_arg_names, non_aliased_inner_rets[len(non_aliased_outer_rets) :]
+    ):
+        updates.append(
+            f"""\
+  auto {outer_arg}_inner = at::functionalization::impl::from_functional_tensor({outer_arg});
+  at::functionalization::impl::replace_({outer_arg}, {inner_ret});
+  at::functionalization::impl::commit_update({outer_arg});
+  at::functionalization::impl::sync({outer_arg});
+  auto {outer_arg}_inner_updated = at::functionalization::impl::from_functional_tensor({outer_arg});
+  at::functionalization::impl::propagate_xla_data_direct({outer_arg}_inner, {outer_arg}_inner_updated);"""
+        )
+
+    # Finally, we return:
+    # - Any mutable arguments that are also returned
+    # - Any immutable returns that were created wrapping the output from the inner call
+    returns_str = return_str(
+        f.func.returns, aliased_outer_rets + non_aliased_wrapped_ret_names
+    )
+    updates_str = "\n".join(updates)
+    return f"""\
+{updates_str}
+    {returns_str}"""
+
+
+# Generates the Functionalization kernel for:
+# - mutation ops (inplace and out= ops)
+@with_native_function_and
+def emit_inplace_functionalization_body(
+    f: NativeFunction, g: NativeFunctionsGroup
+) -> str:
+    # mutation case
+    assert modifies_arguments(f)
+
+    dispatcher_sig = DispatcherSignature.from_schema(f.func)
+
+    unwrap_tensor_args_str, unwrapped_args_ctx = unwrap_tensor_args(
+        dispatcher_sig, is_view_op=False
+    )
+
+    mutated_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type.is_tensor_like() and a.annotation is not None
+    ]
+    non_mutated_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type.is_tensor_like() and a.annotation is None
+    ]
+    non_mutated_tensor_names = [
+        a.name
+        for a in f.func.arguments.flat_all
+        if a.type == BaseType(BaseTy.Tensor) and a.annotation is None
+    ]
+    # all mutable inputs must be functional tensors in order to participate in functionalization
+    check_all_mutated_args_are_functional = " && ".join(
+        ["true"]
+        + [
+            f"at::functionalization::impl::isFunctionalTensor({a})"
+            for a in mutated_names
+        ]
+    )
+    check_any_non_mutated_args_are_functional = " || ".join(
+        ["false"]
+        + [
+            f"at::functionalization::impl::isFunctionalTensor({a})"
+            for a in non_mutated_names
+        ]
+    )
+
+    check_any_non_mutated_tensors_are_xla = " || ".join(
+        ["false"]
+        + [
+            f"{a}.device().type() == c10::DeviceType::XLA"
+            for a in non_mutated_tensor_names
+        ]
+    )
+    # These are used in the cases where we don't functionalize and redispatch to the inplace op
+    # case 1: we hit an inplace op that doesn't have an out-of-place equivalent
+    # case 2: we hit an inplace op but our inputs are not functional tensors (in which case our kernel just no-ops)
+    inplace_exprs = [
+        e.expr
+        for e in translate(unwrapped_args_ctx, dispatcher_sig.arguments(), method=False)
+    ]
+
+    # call the out-of-place variant of the op
+    return_type = (
+        dispatcher.returns_type(g.functional.func.returns).remove_const_ref().cpp_type()
+    )
+    functional_sig = DispatcherSignature.from_schema(g.functional.func)
+    functional_exprs = [
+        e.expr
+        for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False)
+    ]
+
+    if f.func.is_out_fn():
+        mutable_input_post_processing = "\n".join(
+            [
+                f"""
+      at::functionalization::impl::replace_(
+        {a.name}, {"std::get<" + str(i) + ">(tmp_output)" if len(f.func.returns) > 1 else "tmp_output"});
+      at::functionalization::impl::commit_update({a.name});"""
+                for (i, a) in enumerate(f.func.arguments.out)
+                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
+            ]
+        )
+    else:
+        mutable_input_post_processing = "\n".join(  # noqa: F841
+            [
+                f"""
+      at::functionalization::impl::replace_({a.name}, tmp_output);
+      at::functionalization::impl::commit_update({a.name});"""
+                for a in f.func.arguments.flat_all
+                if a.annotation and a.annotation.is_write and a.type.is_tensor_like()
+            ]
+        )
+
+    meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
+    # We don't want to run the inplace meta func for ops like .set_(), because:
+    # (1) they're unnecessary: inplace meta checks are only useful for ops like add_(),
+    #     where broadcasting will work for the out-of-place case but should fail on the inplace call
+    # (2) They'll also fail without adding extra infra: we'd need to convert the input storage argument
+    #     into a meta storage
+    any_storage_args = any(
+        a.type == BaseType(BaseTy.Storage) for a in f.func.arguments.flat_all
+    )
+
+    return f"""
+    {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
+      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()} && !disable_meta_reference()) {{
+        // Before converting the mutable op to its functional variant, run meta tensors through the original op.
+        // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
+        // (We can only do this for inplace ops today though, because they technically all support meta tensors).
+        {meta_conversion_str}
+        at::AutoDispatchSkipFunctionalize func_guard;
+        c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
+        at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(a.name for a in meta_call_ctx)});
+      }}
+      {unwrap_tensor_args_str}
+      if (!({check_all_mutated_args_are_functional})) {{
+        // We want to disable this check if there are any XLA tensors.
+        // cpu_tensor.copy_(xla_tensor) is valid code.
+        if (!({check_any_non_mutated_tensors_are_xla}) && ({check_any_non_mutated_args_are_functional})) {{
+         // case 1: trying to mutate a non functional tensor with a functional tensor is an error
+         TORCH_INTERNAL_ASSERT(false,
+           "mutating a non-functional tensor with a functional tensor is not allowed.",
+           " Please ensure that all of your inputs are wrapped inside of a functionalize() call.");
+        }} else {{
+         // case 2: arguments are not functional tensors, so we no-op and redispatch.
+         at::AutoDispatchSkipFunctionalize guard;
+         {maybe_create_output(f, "tmp_output")}at::_ops::{f.func.name.unambiguous_name()}::call({", ".join(inplace_exprs)});
+         {return_from_mutable_noop_redispatch(f, "tmp_output")}
+        }}
+      }} else {{
+        {return_type} tmp_output;
+        {{
+          at::AutoDispatchSkipFunctionalize guard;
+          tmp_output = at::_ops::{g.functional.func.name.unambiguous_name()}::call({", ".join(functional_exprs)});
+        }}
+        {wrap_propagate_mutations_and_return(f, g.functional, "tmp_output")}
+      }}
+    }}"""
+
+
+# The below functions generate RegisterFunctionalization.cpp
+# These files provide the kernels that run the functionalization pass, which can be opted into
+# per backend (e.g. XLA or Vulkan), or as a composable transform (functionalize() in functorch).
+
+
+# See Note [Functionalization Pass: View Inverses].
+def gen_functionalization_view_inverse_declaration(
+    selector: SelectiveBuilder, g: NativeFunctionsViewGroup
+) -> str | None:
+    # For every (non-composite) view op, we need a corresponding "inverse view" function.
+    # This generates the declarations so we get a good compiler error when someone adds a new view.
+    @with_native_function
+    def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
+        if g.view.has_composite_implicit_autograd_kernel:
+            return None
+        view_inverse_sig = ViewInverseSignature(g)
+        return view_inverse_sig.decl()
+
+    return emit_decl_helper(g)
+
+
+def gen_functionalization_registration(
+    selector: SelectiveBuilder,
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+    composite_implicit_autograd_index: BackendIndex,
+) -> list[str]:
+    @with_native_function
+    def emit_registration_helper(f: NativeFunction) -> str:
+        assert not f.has_composite_implicit_autograd_kernel
+        registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})"
+        return f'm.impl("{f.func.name}", {registration_str});'
+
+    # Don't generate kernels in mobile build
+    if not selector.include_all_operators:
+        return []
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        # functionalization needs to register kernels for view + view_inplace ops
+        # See Note [Functionalization <> torch.Tensor constructor]
+        if str(g.view.func.name) == "lift_fresh":
+            return []
+        view_str = []
+        if not g.view.has_composite_implicit_autograd_kernel:
+            view_str.append(emit_registration_helper(g.view))
+        if (
+            g.view_inplace is not None
+            and not g.view_inplace.has_composite_implicit_autograd_kernel
+        ):
+            assert g.view_inplace.is_view_op
+            view_str.append(emit_registration_helper(g.view_inplace))
+        return view_str
+
+    elif isinstance(g, NativeFunctionsGroup):
+        # Gets a hand-written functionalization kernel
+        if g.inplace is not None and str(g.inplace.func.name) == "set_.source_Tensor":
+            fns = []
+        else:
+            fns = list(g.functions())
+    else:
+        if str(g.func.name) in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION:
+            return []
+        fns = [g]
+
+    registrations = []
+    for f in fns:
+        if f.has_composite_implicit_autograd_kernel:
+            continue
+        if str(f.func.name) == "lift":
+            # See Note [Functionalization <> torch.Tensor constructor]
+            return []
+        if str(f.func.name) == "resize_":
+            # See Note [resize_ in Functionalization]
+            return []
+        if str(f.func.name.name) != "set_":
+            assert not f.is_view_op
+        # functionalization needs to generate and register kernels for inplace ops.
+        # We *also* need to directly register CompositeImplicitAutograd kernels
+        # so that they decompose properly before functionalization.
+        if modifies_arguments(f):
+            registrations.append(emit_registration_helper(f))
+    return registrations
+
+
+def gen_functionalization_definition(
+    selector: SelectiveBuilder,
+    # Note: Ideally this code should never have to look at NativeFunction
+    # (and instead only need to operate on grouped NativeFunctions).
+    # The only reason currently is because we need to emit direct dispatch registrations
+    # For CompositeImplicitAutograd operators, which are potentially ungrouped.
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+) -> list[str]:
+    # Don't generate kernels in mobile build
+    if not selector.include_all_operators:
+        return []
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        # Case 1: emit view -> view_copy kernels for the functionalization pass
+        view_defs = []
+        if not g.composite:
+            # invariant: NativeFunctionsViewGroup's always have a view_copy operator
+            # if the view is not composite (implicit autograd)
+            assert g.view_copy is not None, dataclass_repr(g, indent=1)
+            view_defs.append(emit_view_functionalization_body(g, view_inplace=False))
+            if g.view_inplace is not None:
+                view_defs.append(emit_view_functionalization_body(g, view_inplace=True))
+        return view_defs
+    elif isinstance(g, NativeFunction):
+        # Invariant: all mutable operators that we need to handle in functionalization
+        # should have been properly grouped up.
+        # TODO: The below ops all have "problematic" schemas that prevent them from
+        # getting functionalized. Instead of bending over backwards to get things to work,
+        # I think we should either:
+        # (1) fix their schemas (BC-breaking)
+        # (2) hand-write their functionalization kernels
+        if (
+            str(g.func.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
+            and str(g.func.name.name) not in MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION
+        ):
+            assert g.has_composite_implicit_autograd_kernel or not modifies_arguments(g)
+        return []
+    else:
+        # Case 2: emit inplace -> out-of-place kernels for the functionalization pass
+        mutation_defs = []
+        mutation_defs.append(emit_inplace_functionalization_body(g.out, g))
+        if g.inplace is not None:
+            mutation_defs.append(emit_inplace_functionalization_body(g.inplace, g))
+        if g.mutable is not None:
+            mutation_defs.append(emit_inplace_functionalization_body(g.mutable, g))
+        return mutation_defs
+    return []
diff --git a/phivenv/Lib/site-packages/torchgen/gen_lazy_tensor.py b/phivenv/Lib/site-packages/torchgen/gen_lazy_tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad785c0b9296dd4d8778ab17133923b43a8c188
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_lazy_tensor.py
@@ -0,0 +1,585 @@
+from __future__ import annotations
+
+import argparse
+import os
+from collections import namedtuple
+from pathlib import Path
+from typing import Any, Callable, TYPE_CHECKING
+
+import yaml
+
+import torchgen.dest as dest
+from torchgen.api.lazy import setValueT
+from torchgen.api.types import BaseCppType
+from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
+from torchgen.gen import get_grouped_native_functions, parse_native_yaml
+from torchgen.gen_backend_stubs import (
+    error_on_missing_kernels,
+    gen_dispatcher_registrations,
+    gen_dispatchkey_nativefunc_headers,
+    parse_backend_yaml,
+)
+from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import FileManager, NamespaceHelper
+from torchgen.yaml_utils import YamlLoader
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Iterator, Sequence
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                        Lazy Tensor Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+# Overview
+# ~~~~~~~~
+#
+# This codegen script builds on existing data models and helpers used
+# by all ATen backends, and adds new functionality specific to lazy
+# tensor backends.
+#
+# Inputs:
+# - <backend>_native_functions.yaml: controls which operators are
+#   supported by the backend.
+#
+# Outputs:
+# (for all backends)
+# <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
+# - opt-in: also generate 'lowering' methods for the TorchScript backend only
+# <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
+# - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
+# <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
+# ops
+#
+# Register<DispatchKey>.cpp registers all op implementations with the dispatcher
+# RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
+#
+# Validation Helpers:
+# - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
+#   implementations in torch/csrc/lazy/core/shape_inference.*
+# - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
+#   (non-codegen) implementation file
+#
+#
+# About the Data Model
+# ~~~~~~~~~~~~~~~~~~~~
+#
+# Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
+# we care about.  In this case, the <backend>_native_functions yaml defines a subset of the core operators
+# (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
+# Backends can list ops in two categories:
+#  - `supported` ops require hand-implementations but still get codegenned declarations and registrations
+#  - `full_codegen` ops get implementations (and IR classes) generated too
+#
+# Each native function is modeled as an object with a schema, and each schema has objects representing their
+# arguments.  Much of the codegen is manipulation of the arguments and their types.  For example, lazy tensor
+# backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
+# types (stringref) with actual string objects, and this is done by manipulating the data model objects.
+# - see api/lazy.py for the lazy data model
+#
+# Once the data model is set up, the rest of this script processes a number of templates for output CPP file
+# and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`.  These
+# helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
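+# For orientation, a backend yaml consumed by this script might look roughly like the
+# sketch below (the op names are placeholders, not a real backend config; see the
+# in-tree TorchScript backend yaml for a complete example):
+#
+#   backend: Lazy
+#   cpp_namespace: torch::lazy
+#   full_codegen:    # IR classes and native function impls are generated
+#     - add.Tensor
+#   supported:       # declarations/registrations generated; impls are handwritten
+#     - clone
+#   # optional sections: 'non_native' and 'ir_gen' (see parse_native_functions_keys)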
+
+# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
+# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
+ParsedExternalYaml = namedtuple(
+    "ParsedExternalYaml",
+    ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
+)
+
+
+def parse_native_functions_keys(
+    backend_yaml_path: str,
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
+    with open(backend_yaml_path) as f:
+        yaml_values = yaml.load(f, Loader=YamlLoader)
+    assert isinstance(yaml_values, dict)
+
+    full_codegen = yaml_values.pop("full_codegen", [])
+    non_native = yaml_values.pop("non_native", [])
+    ir_gen = yaml_values.pop("ir_gen", [])
+    assert isinstance(full_codegen, list)
+    assert isinstance(non_native, list)
+    assert isinstance(ir_gen, list)
+    full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
+    ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
+    return full_codegen_opnames, non_native, ir_gen_opnames
+
+
+def validate_shape_inference_header(
+    shape_inference_hdr: str, expected_shape_infr_decls: list[str]
+) -> None:
+    try:
+        with open(shape_inference_hdr) as f:
+            shape_infr_decls = f.read()
+            shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
+    except OSError as e:
+        raise AssertionError(
+            f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
+        ) from e
+
+    # TODO(whc) add a check for shape inference functions that have meta kernels implemented and should be retired.
+
+    missing_decls = [
+        decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
+    ]
+    if missing_decls:
+        raise Exception(  # noqa: TRY002
+            f"""Missing shape inference function.\n
+Please declare this function in {shape_inference_hdr}:\n
+and implement it in the corresponding shape_inference.cpp file.\n
+{os.linesep.join(missing_decls)}"""
+        )
+
+
+# Some helper functions for the codegen.
+def get_ltc_helper_fns() -> str:
+    return """\
+at::Tensor to_meta(const at::Tensor& tensor) {
+  // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
+  if (!tensor.defined()) return tensor;
+  auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
+/*dtype=*/tensor.scalar_type(), /*layout=*/tensor.layout(), \
+/*device=*/c10::Device(c10::kMeta), /*pin_memory=*/std::nullopt);
+  // needs to handle wrapped numbers, so dtype promotion works properly.
+  if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
+    out.unsafeGetTensorImpl()->set_wrapped_number(true);
+  }
+  return out;
+}
+std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
+  if (tensor.has_value()) {
+    return to_meta(*tensor);
+  }
+  return std::nullopt;
+}
+
+std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
+  std::vector<at::Tensor> outs;
+  outs.reserve(t_list.size());
+  for (const auto& tensor : t_list) {
+    outs.push_back(to_meta(tensor));
+  }
+  return outs;
+}
+"""
+
+
+class default_args:
+    node_base: str = "Node"
+    node_base_hdr: str | None = None
+    shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
+    tensor_class: str = "torch::lazy::LazyTensor"
+    tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
+    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
+    native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
+        GenLazyNativeFuncDefinition
+    )
+    backend_name: str = "TorchScript"
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
+    parser.add_argument(
+        "-s",
+        "--source-yaml",
+        "--source_yaml",
+        help="path to source yaml file containing operator external definitions",
+    )
+    parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
+    parser.add_argument(
+        "--dry-run", "--dry_run", type=bool, default=False, help="run codegen without writing any output files"
+    )
+    parser.add_argument(
+        "--impl-path",
+        "--impl_path",
+        type=str,
+        default=None,
+        help="path to the source C++ file containing kernel definitions",
+    )
+    parser.add_argument(
+        "--gen-ts-lowerings",
+        "--gen_ts_lowerings",
+        action="store_true",
+        help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
+    )
+    parser.add_argument(
+        "--node-base",
+        "--node_base",
+        type=str,
+        default=default_args.node_base,
+        help="Name of backend specific custom Lazy IR Node base class",
+    )
+    parser.add_argument(
+        "--node-base-hdr",
+        "--node_base_hdr",
+        type=str,
+        default=default_args.node_base_hdr,
+        help="Path to header file defining custom Lazy IR Node base class",
+    )
+    parser.add_argument(
+        "--shape-inference-hdr",
+        "--shape_inference_hdr",
+        type=str,
+        default=default_args.shape_inference_hdr,
+        help="Path to header file defining custom Lazy shape inference functions",
+    )
+    parser.add_argument(
+        "--tensor-class",
+        "--tensor_class",
+        type=str,
+        default=default_args.tensor_class,
+        help="Name of backend specific custom Lazy Tensor class",
+    )
+    parser.add_argument(
+        "--tensor-class-hdr",
+        "--tensor_class_hdr",
+        type=str,
+        default=default_args.tensor_class_hdr,
+        help="Path to header file defining custom Lazy Tensor class",
+    )
+    parser.add_argument(
+        "--backend-name",
+        "--backend_name",
+        type=str,
+        default=default_args.backend_name,
+        help="Name of the backend to generate",
+    )
+    options = parser.parse_args()
+
+    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_lazy_tensor.py
+    torch_root = Path(__file__).absolute().parents[2]
+    aten_path = str(torch_root / "aten" / "src" / "ATen")
+    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
+    if options.gen_ts_lowerings:
+        lazy_ir_generator = GenTSLazyIR
+    native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
+        default_args.native_func_definition_generator
+    )
+
+    run_gen_lazy_tensor(
+        aten_path,
+        options.source_yaml,
+        options.output_dir,
+        options.dry_run,
+        options.impl_path,
+        options.node_base,
+        options.node_base_hdr,
+        options.tensor_class,
+        options.tensor_class_hdr,
+        options.shape_inference_hdr,
+        lazy_ir_generator,
+        native_func_definition_generator,
+        options.backend_name,
+    )
+
+
+def run_gen_lazy_tensor(
+    aten_path: str,
+    source_yaml: str,
+    output_dir: str,
+    dry_run: bool,
+    impl_path: str | None,
+    node_base: str = default_args.node_base,
+    node_base_hdr: str | None = default_args.node_base_hdr,
+    tensor_class: str = default_args.tensor_class,
+    tensor_class_hdr: str = default_args.tensor_class_hdr,
+    shape_inference_hdr: str = default_args.shape_inference_hdr,
+    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
+    native_func_definition_generator: type[
+        GenLazyNativeFuncDefinition
+    ] = default_args.native_func_definition_generator,
+    # build_in_tree is true for TS backend and affects include paths
+    build_in_tree: bool = False,
+    # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
+    # it must match how ATen was built
+    per_operator_headers: bool = False,
+    backend_name: str = default_args.backend_name,
+    gen_forced_fallback_code: bool = False,
+    use_lazy_shape: bool = True,
+    # the following arguments are temporary customization points for xla backend migration.
+    # do not rely on them otherwise, they should be removed once migration is complete
+    backend_namespace: str = "torch::lazy",
+    get_tensorlist: str = "GetTensorList",
+    get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
+    try_get_tensor: str = "TryGetLtcTensor",
+    metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
+    create_tensor: str = "LazyTensor::Create",
+    create_from_first_tensor: bool = False,
+    create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
+    tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
+    lazy_value_class: str = "torch::lazy::Value",
+    lazy_tensor_ptr: str = "LazyTensorPtr",
+    get_device_fn: str = "torch::lazy::GetBackendDevice",
+) -> None:
+    lv_tokens = lazy_value_class.split("::")
+    lv_class = lv_tokens[-1]
+    lv_ns = "::".join(lv_tokens[:-1])
+    setValueT(BaseCppType(lv_ns, lv_class))
+    template_dir = os.path.join(aten_path, "templates")
+
+    def make_file_manager(install_dir: str) -> FileManager:
+        return FileManager(
+            install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
+        )
+
+    fm = make_file_manager(output_dir)
+
+    native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
+    tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
+    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
+    native_functions, backend_indices = (
+        parsed_yaml.native_functions,
+        parsed_yaml.backend_indices,
+    )
+    grouped_native_functions = get_grouped_native_functions(native_functions)
+
+    def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
+        """
+        We sort the native function because of the note in concat_map_codegen.
+        TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
+        """
+        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
+        return str(func.name.name)
+
+    grouped_native_functions = sorted(
+        grouped_native_functions, key=sort_native_function
+    )
+
+    parsed_backend_yaml = parse_backend_yaml(
+        source_yaml, grouped_native_functions, backend_indices
+    )
+    backend_key = parsed_backend_yaml.backend_key
+    autograd_key = parsed_backend_yaml.autograd_key
+    cpp_namespace = parsed_backend_yaml.cpp_namespace
+    backend_indices = parsed_backend_yaml.backend_indices
+    # the following 3 keys are all processed differently
+    # for full_codegen, we generate IR, kernels, etc
+    # for ir_gen, we generate only IR
+    # non_native is used to register kernels not declared in
+    # native_functions.yaml
+    full_codegen, non_native, ir_gen = parse_native_functions_keys(
+        source_yaml, grouped_native_functions
+    )
+
+    def concat_map_codegen(
+        func: Callable[[NativeFunction], Sequence[str]],
+        xs: Iterable[NativeFunctionsGroup | NativeFunction],
+        ops_list: list[OperatorName] = full_codegen,
+    ) -> Iterator[str]:
+        """
+        We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
+        only code-gen additional entries for the inplace variant for the native functions.
+        """
+
+        for x in xs:
+            fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
+            for f in fs:
+                if f.func.name in ops_list:
+                    yield from func(f)
+
+    selector = SelectiveBuilder.get_nop_selector()
+
+    assert backend_key is not None
+    class_name = backend_indices[backend_key].native_function_class_name()
+
+    if impl_path is not None:
+        error_on_missing_kernels(
+            native_functions,
+            backend_indices,
+            backend_key,
+            autograd_key,
+            class_name,
+            impl_path,
+            full_codegen,
+        )
+
+    """ Validate Shape Inference Definitions
+
+    Generated lazy native functions all perform shape inference, by first using a meta:: kernel
+    if available for that op, and otherwise using a 'compute_shape_{op}' function instead.  The generator
+    knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
+    so it just has to check whether the op is structured and generate a call for one or the other.  It's up to the dev
+    to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
+    the expected signature which can be copy-pasted into shape_inference.h.
+
+    compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
+    to structured kernels.
+
+    See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
+    """
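+    # For example, a handwritten declaration in shape_inference.h takes roughly the
+    # form below (the op name here is hypothetical):
+    #   TORCH_API std::vector<torch::lazy::Shape> compute_shape_my_op(const at::Tensor& self);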
+    if shape_inference_hdr is not None:
+        expected_shape_infr_decls = list(
+            concat_map_codegen(
+                dest.GenLazyShapeInferenceDefinition(
+                    backend_indices[backend_key], tensor_class
+                ),
+                grouped_native_functions,
+            )
+        )
+
+        validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
+    assert class_name is not None
+
+    # Generate nativefunction declarations
+    # Note: eager registration is set to False for the lazy TS backend, as another LTC backend
+    # may want to register their own lazy kernels instead of registering the TS ones.
+    # The registration will lazily happen when init_ts_backend is called.
+    gen_dispatchkey_nativefunc_headers(
+        fm,
+        class_name,
+        cpp_namespace,
+        backend_indices,
+        grouped_native_functions,
+        backend_key,
+        autograd_key,
+        backend_name,
+    )
+
+    # Generate Dispatcher registrations which hook up the nativefunctions
+    for dispatch_key in (
+        [backend_key] if autograd_key is None else [backend_key, autograd_key]
+    ):
+        gen_dispatcher_registrations(
+            fm,
+            output_dir,
+            class_name,
+            backend_indices,
+            grouped_native_functions,
+            backend_key,
+            dispatch_key,
+            selector,
+            build_in_tree=build_in_tree,
+            per_operator_headers=per_operator_headers,
+            backend_name=backend_name,
+            eager_registration=False,
+        )
+
+    # Generate native function impls that build IR nodes
+    ns_helper = NamespaceHelper(cpp_namespace)
+    fm.write_with_template(
+        f"{backend_key}NativeFunctions.cpp",
+        "DispatchKeyNativeFunctions.cpp",
+        lambda: {
+            "includes": [
+                f"#include <{path}>"
+                for path in [
+                    tensor_class_hdr,
+                    shape_inference_hdr,
+                    "ATen/Functions.h",
+                    "ATen/native/TensorConversions.h",
+                    "ATen/NativeFunctions.h",
+                    "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
+                    "ATen/MetaFunctions.h",
+                    "ATen/Operators.h",
+                    "ATen/native/CPUFallback.h",
+                    "torch/csrc/lazy/core/ir_builder.h",
+                    "torch/csrc/lazy/core/lazy_graph_executor.h",
+                    "torch/csrc/lazy/core/metrics.h",
+                    "torch/csrc/lazy/core/shape.h",
+                    f"{output_dir}/{backend_key}NativeFunctions.h",
+                    f"{output_dir}/LazyIr.h",
+                ]
+                + (
+                    ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
+                    if gen_forced_fallback_code
+                    else []
+                )
+            ],
+            "helper_fns": get_ltc_helper_fns(),
+            "native_functions_include": "",
+            "namespace_prologue": ns_helper.prologue,
+            "namespace_epilogue": ns_helper.epilogue,
+            "native_function_definitions": list(
+                concat_map_codegen(
+                    native_func_definition_generator(
+                        f"{backend_key}NativeFunctions",
+                        backend_indices[backend_key],
+                        tensor_class,
+                        gen_forced_fallback_code,
+                        backend_namespace,
+                        get_tensorlist,
+                        get_tensor_or_wrap_number,
+                        try_get_tensor,
+                        metrics_counter,
+                        create_tensor,
+                        create_from_first_tensor,
+                        create_aten_from_ltc_tensor,
+                        tuple_aten_from_ltc_tensors,
+                        lazy_tensor_ptr,
+                        get_device_fn,
+                    ),
+                    grouped_native_functions,
+                )
+            ),
+        },
+    )
+    # Generate IR node classes
+    lazy_ir_obj = lazy_ir_generator(
+        backend_indices[backend_key], backend_name, node_base, use_lazy_shape
+    )
+
+    fm.write_with_template(
+        "LazyIr.h",
+        "LazyIr.h",
+        lambda: {
+            "lazy_ir_sysinc": [
+                f"#include <{path}>"
+                for path in [
+                    "ATen/core/Formatting.h",
+                    "c10/core/ScalarType.h",
+                    "torch/csrc/lazy/core/hash.h",
+                    "torch/csrc/lazy/core/ir.h",
+                    "torch/csrc/lazy/core/shape.h",
+                    "optional",
+                    "vector",
+                ]
+            ],
+            "lazy_ir_inc": [f'#include "{node_base_hdr}"']
+            if node_base_hdr is not None
+            else [],
+            "ir_declarations": list(
+                concat_map_codegen(
+                    lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
+                )
+            ),
+            "namespace_prologue": ns_helper.prologue,
+            "namespace_epilogue": ns_helper.epilogue,
+        },
+    )
+
+    # Generate Non Native IR Node classes
+    fm.write_with_template(
+        "LazyNonNativeIr.h",
+        "LazyNonNativeIr.h",
+        lambda: {
+            "lazy_non_native_ir_inc": [
+                f"#include <{path}>"
+                for path in [
+                    "torch/csrc/lazy/core/ir.h",
+                    "torch/csrc/lazy/core/ir_builder.h",
+                    "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
+                    "torch/csrc/lazy/core/shape_inference.h",
+                ]
+                + ([node_base_hdr] if node_base_hdr else [])
+                if path
+            ],
+            "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
+                non_native, lazy_ir_obj
+            ),
+            "namespace_prologue": ns_helper.prologue,
+            "namespace_epilogue": ns_helper.epilogue,
+        },
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/gen_schema_utils.py b/phivenv/Lib/site-packages/torchgen/gen_schema_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2d19fb9306c8d1233ced11bc038222c1c735c80
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_schema_utils.py
@@ -0,0 +1,97 @@
+from typing import Any, Optional, Union
+
+from torchgen.model import (
+    Annotation,
+    Argument,
+    Arguments,
+    BaseOperatorName,
+    BaseTy,
+    BaseType,
+    CustomClassType,
+    FunctionSchema,
+    ListType,
+    OperatorName,
+    Return,
+)
+
+
+# Note: These aren't actually used in torchgen, they're some utilities for generating a schema
+# from real arguments. For example, this is used to generate HigherOrderOperators' schema since
+# their schemas can vary for different instances of the same HOP.
+
+
+class TypeGen:
+    convert_to_base_ty = {
+        int: BaseTy.int,
+        float: BaseTy.float,
+        str: BaseTy.str,
+        bool: BaseTy.bool,
+    }
+
+    @staticmethod
+    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
+        import torch
+
+        if isinstance(obj, torch.fx.GraphModule):
+            return BaseType(BaseTy.GraphModule)
+        elif isinstance(obj, torch.Tensor):
+            return BaseType(BaseTy.Tensor)
+        elif isinstance(obj, torch.SymInt):
+            return BaseType(BaseTy.SymInt)
+        elif isinstance(obj, torch.SymBool):
+            return BaseType(BaseTy.SymBool)
+        elif isinstance(obj, torch.ScriptObject):
+            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
+        elif isinstance(obj, (list, tuple)):
+            assert len(obj) > 0
+            all_base_tys = [TypeGen.from_example(x) for x in obj]
+            if len(set(all_base_tys)) > 1:
+                raise RuntimeError(
+                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
+                    "Consider unpacking the argument and giving its elements proper names if possible "
+                    "instead of using *args."
+                )
+            return ListType(all_base_tys[0], len(obj))
+        tp = type(obj)
+        if tp not in TypeGen.convert_to_base_ty:
+            raise RuntimeError(f"unsupported type {tp}")
+        return BaseType(TypeGen.convert_to_base_ty[tp])
+
+
+class ReturnGen:
+    @staticmethod
+    def from_example(
+        name: Optional[str], obj: Any, annotation: Optional[Annotation]
+    ) -> Return:
+        return Return(name, TypeGen.from_example(obj), annotation)
+
+
+class ArgumentGen:
+    @staticmethod
+    def from_example(
+        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
+    ) -> Argument:
+        return Argument(
+            name, TypeGen.from_example(obj), default=default, annotation=annotation
+        )
+
+
+class FunctionSchemaGen:
+    @staticmethod
+    def from_example(
+        op_name: str,
+        example_inputs: tuple[tuple[str, Any], ...],
+        example_outputs: tuple[Any, ...],
+    ) -> FunctionSchema:
+        args = []
+        for name, inp in example_inputs:
+            args.append(ArgumentGen.from_example(name, inp, None, None))
+        # ignore the annotations and other attributes for now, we could add more when needed.
+        arguments = Arguments(
+            tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
+        )
+        returns = tuple(
+            ReturnGen.from_example(None, out, None) for out in example_outputs
+        )
+        op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
+        return FunctionSchema(op_name, arguments, returns)
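+
+
+# Example usage (illustrative sketch; the op name and example values are hypothetical
+# and not part of torchgen):
+#
+#   import torch
+#   schema = FunctionSchemaGen.from_example(
+#       "my_hop",
+#       (("x", torch.randn(3)), ("n", 2)),
+#       (torch.randn(3),),
+#   )
+#   # str(schema) should resemble: "my_hop(Tensor x, int n) -> Tensor"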
diff --git a/phivenv/Lib/site-packages/torchgen/gen_vmap_plumbing.py b/phivenv/Lib/site-packages/torchgen/gen_vmap_plumbing.py
new file mode 100644
index 0000000000000000000000000000000000000000..506fd9cbc59e61e336c7e8ab02e12ca01afde509
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/gen_vmap_plumbing.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+
+import textwrap
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from torchgen.api.translate import translate
+from torchgen.api.types import DispatcherSignature
+from torchgen.context import method_with_native_function
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Return,
+    SchemaKind,
+    Type,
+)
+from torchgen.utils import mapMaybe
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+def is_tensor(typ: Type) -> bool:
+    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
+
+
+def is_optional_tensor(typ: Type) -> bool:
+    return isinstance(typ, OptionalType) and is_tensor(typ.elem)
+
+
+def is_tensor_list(typ: Type) -> bool:
+    return isinstance(typ, ListType) and is_tensor(typ.elem)
+
+
+def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
+    result = f"""\
+    auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
+    return textwrap.dedent(result).split("\n")
+
+
+def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
+    result = f"""\
+    std::optional<Tensor> {name}_value;
+    std::optional<int64_t> {name}_bdim;
+    if ({name}) {{
+        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
+    }}"""
+    return textwrap.dedent(result).split("\n")
+
+
+def gen_unwraps(
+    flat_arguments: Sequence[Argument], cur_level_var: str
+) -> tuple[str, list[str]]:
+    arg_names = [a.name for a in flat_arguments]
+    arg_types = [a.type for a in flat_arguments]
+
+    tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
+    optional_tensors = [
+        name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
+    ]
+
+    unwraps = []
+    for tensor in tensors:
+        unwraps += unwrap_tensor(tensor, cur_level_var)
+
+    for opt_tensor in optional_tensors:
+        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
+    unwrap_code = "\n".join(unwraps)
+
+    unwrapped_arg_list = []
+    for arg in arg_names:
+        if arg in tensors or arg in optional_tensors:
+            unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
+        else:
+            unwrapped_arg_list.append(arg)
+    return unwrap_code, unwrapped_arg_list
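+
+# Illustrative example (hypothetical argument list, not a real schema): for arguments
+# (Tensor self, Tensor other, Scalar alpha) and cur_level_var="cur_level", gen_unwraps
+# returns roughly:
+#   unwrap_code:
+#     auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level);
+#     auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level);
+#   unwrapped_arg_list:
+#     ["self_value", "self_bdim", "other_value", "other_bdim", "alpha"]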
+
+
+def gen_case_where_all_bdims_are_none(
+    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
+) -> str:
+    conditions = []
+    flat_args = schema.arguments.flat_all
+    for arg in flat_args:
+        if not arg.type.is_tensor_like():
+            continue
+        conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")
+
+    sig = DispatcherSignature.from_schema(schema)
+    translated_args = ", ".join(
+        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
+    )
+    return f"""\
+if ({" && ".join(conditions)}) {{
+  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
+}}"""
+
+
+def gen_returns(
+    returns: tuple[Return, ...], cur_level_var: str, results_var: str
+) -> str:
+    idx = 0
+    wrapped_returns = []
+    for ret in returns:
+        if is_tensor(ret.type):
+            wrapped_returns.append(
+                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
+            )
+            idx += 2
+        elif is_tensor_list(ret.type):
+            wrapped_returns.append(
+                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
+            )
+            idx += 2
+        else:
+            wrapped_returns.append(f"std::get<{idx}>({results_var})")
+            idx += 1
+    if len(wrapped_returns) == 1:
+        result = f"return {wrapped_returns[0]};"
+    else:
+        result = f"return std::make_tuple({', '.join(wrapped_returns)});"
+    return result
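+
+# Illustrative example (hypothetical returns, not a real schema): for two Tensor
+# returns with results_var="results" and cur_level_var="cur_level", gen_returns
+# produces roughly:
+#   return std::make_tuple(
+#     makeBatched(std::get<0>(results), std::get<1>(results), cur_level),
+#     makeBatched(std::get<2>(results), std::get<3>(results), cur_level));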
+
+
+def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
+    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
+
+
+def is_mutated_arg(argument: Argument) -> bool:
+    return argument.annotation is not None and argument.annotation.is_write
+
+
+def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
+    # Assumptions:
+    # - only one argument is being modified in-place
+    # - the argument that is being modified in-place is the first argument
+    # - all returns are either Tensor, tuple of Tensor, or TensorList
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    returns = schema.returns
+
+    # Check assumptions. If these are invalid we return None
+    # and punt the work to handle them to the future.
+    assert schema.kind() == SchemaKind.inplace
+    if not is_mutated_arg(schema.arguments.flat_all[0]):
+        return None
+    if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
+        return None
+
+    # Only support cases where all returns are Tensors or vector<Tensor>
+    if len(returns) == 0:
+        return None
+    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
+        return None
+    if not accepts_at_least_one_tensor_input(schema):
+        return None
+
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  batch_rule({", ".join(unwrapped_arg_list)});
+  return {schema.arguments.flat_all[0].name};
+}}"""
+
+
+def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  batch_rule({", ".join(unwrapped_arg_list)});
+}}"""
+
+
+def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
+    schema = native_function.func
+    sig = DispatcherSignature.from_schema(schema)
+    returns = schema.returns
+
+    # Only support cases where all returns are Tensors or vector<Tensor>
+    if not accepts_at_least_one_tensor_input(schema):
+        return None
+    if len(returns) == 0:
+        return gen_vmap_plumbing_no_returns(native_function)
+    return_symint_overrides = [
+        "_scaled_dot_product_flash_attention",
+        "_scaled_dot_product_cudnn_attention",
+    ]
+    if (
+        not all(ret.type.is_tensor_like() for ret in returns)
+        and schema.name.unambiguous_name() not in return_symint_overrides
+    ):
+        return None
+    # in-place views need special handling
+    if "inplace_view" in native_function.tags:
+        return None
+
+    if schema.kind() == SchemaKind.inplace:
+        return gen_vmap_inplace_plumbing(native_function)
+
+    # Don't support these (mutable, out, scratch)
+    if schema.kind() != SchemaKind.functional:
+        return None
+
+    results_var = "results"
+    cur_level_var = "cur_level"
+
+    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
+    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
+
+    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
+    return f"""\
+template <typename batch_rule_t, batch_rule_t batch_rule>
+{sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
+  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+  auto maybe_layer = maybeCurrentDynamicLayer();
+  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
+  int64_t {cur_level_var} = maybe_layer->layerId();
+{textwrap.indent(bdims_all_none_case, "  ")}
+{textwrap.indent(unwraps, "  ")}
+  auto {results_var} = batch_rule({", ".join(unwrapped_arg_list)});
+  {wrapped_returns}
+}}"""
+
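+# Illustrative sketch of the emitted C++ for a hypothetical unary op
+# "sin(Tensor self) -> Tensor" (the unwrap and all-bdims-are-None fast-path
+# bodies come from gen_unwraps / gen_case_where_all_bdims_are_none and are
+# elided here):
+#
+#   template <typename batch_rule_t, batch_rule_t batch_rule>
+#   at::Tensor sin_generated_plumbing(const at::Tensor & self) {
+#     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
+#     auto maybe_layer = maybeCurrentDynamicLayer();
+#     vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
+#     int64_t cur_level = maybe_layer->layerId();
+#     ...  // fast path when no argument is batched at cur_level
+#     ...  // unwrap self into a value/bdim pair
+#     auto results = batch_rule(...);
+#     return makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
+#   }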
+
+@dataclass(frozen=True)
+class ComputeBatchRulePlumbing:
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        result = gen_vmap_plumbing(f)
+        return result
+
+
+def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
+    body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
+    return f"""
+#pragma once
+#include <ATen/Operators.h>
+#include <ATen/functorch/PlumbingHelper.h>
+
+namespace at {{ namespace functorch {{
+
+{body}
+
+}}}} // namespace at::functorch
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/local.py b/phivenv/Lib/site-packages/torchgen/local.py
new file mode 100644
index 0000000000000000000000000000000000000000..b359a29b74e76be5a65ae8588b054d4a9009b890
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/local.py
@@ -0,0 +1,62 @@
+from __future__ import annotations
+
+import threading
+from contextlib import contextmanager
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+
+# Simple dynamic scoping implementation.  The name "parametrize" comes
+# from Racket.
+#
+# WARNING WARNING: LOOKING TO EDIT THIS FILE?  Think carefully about
+# why you need to add a toggle to the global behavior of code
+# generation.  The parameters here should really only be used
+# for "temporary" situations, where we need to temporarily change
+# the codegen in some cases because we cannot conveniently update
+# all call sites, and are slated to be eliminated once all call
+# sites are eliminated.  If you don't have a plan for how to get there,
+# DON'T add a new entry here.
+
+
+class Locals(threading.local):
+    use_const_ref_for_mutable_tensors: bool | None = None
+    use_ilistref_for_tensor_lists: bool | None = None
+
+
+_locals = Locals()
+
+
+def use_const_ref_for_mutable_tensors() -> bool:
+    assert _locals.use_const_ref_for_mutable_tensors is not None, (
+        "need to initialize local.use_const_ref_for_mutable_tensors with "
+        "local.parametrize"
+    )
+    return _locals.use_const_ref_for_mutable_tensors
+
+
+def use_ilistref_for_tensor_lists() -> bool:
+    assert _locals.use_ilistref_for_tensor_lists is not None, (
+        "need to initialize local.use_ilistref_for_tensor_lists with local.parametrize"
+    )
+    return _locals.use_ilistref_for_tensor_lists
+
+
+@contextmanager
+def parametrize(
+    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
+) -> Iterator[None]:
+    old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
+    old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
+    try:
+        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
+        _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
+        yield
+    finally:
+        _locals.use_const_ref_for_mutable_tensors = (
+            old_use_const_ref_for_mutable_tensors
+        )
+        _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists
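+
+# Usage sketch (flag values below are arbitrary, for illustration only):
+#
+#   from torchgen import local
+#
+#   with local.parametrize(
+#       use_const_ref_for_mutable_tensors=False,
+#       use_ilistref_for_tensor_lists=False,
+#   ):
+#       assert local.use_const_ref_for_mutable_tensors() is False
+#
+# Outside any parametrize() block the accessors assert, since the
+# thread-local values are restored to their previous (initially None) state.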
diff --git a/phivenv/Lib/site-packages/torchgen/model.py b/phivenv/Lib/site-packages/torchgen/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d2e4ca34f5eb6600a9f7e113a195c3b4482c2e9
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/model.py
@@ -0,0 +1,2885 @@
+from __future__ import annotations
+
+import dataclasses
+import itertools
+import re
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Callable, Optional, TYPE_CHECKING
+from typing_extensions import assert_never
+
+from torchgen.utils import NamespaceHelper, OrderedSet
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator, Sequence
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                           DATA MODEL
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+# Some general principles for our data model.
+#
+# - Stop using C++ data types as the internal data representation
+#   format.  Instead, the internal data structures are centered
+#   around JIT schema representation.  This avoids a big problem
+#   with the old codegen where we read in all the types from
+#   native_functions.yaml and then immediately had to retranslate
+#   them into C++ types.
+#
+# - More semantic data representation.  Instead of representing
+#   everything as dicts and strings, we define dataclasses for
+#   every interesting entity the code generation has to deal with.
+#   These dataclasses have strong semantic invariants: for example,
+#   we generally require them to roundtrip losslessly into the
+#   form they were parsed from.  These structures are immutable
+#   and you're expected to populate information once during
+#   construction.
+
+
+# Represent a source location; used for better error reporting
+@dataclass(frozen=True)
+class Location:
+    file: str
+    line: int
+
+    def __str__(self) -> str:
+        return f"{self.file}:{self.line}"
+
+
+# Valid values of the 'variants' field in native_functions.yaml
+class Variant(Enum):
+    function = auto()
+    method = auto()
+
+
+# Default kernel namespace
+DEFAULT_KERNEL_NAMESPACE = "at::native"
+
+# NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h
+BACKEND_COMPONENTS = "CPU CUDA HIP XLA MTIA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split()
+FUNCTIONALITY_KEYS = [
+    "",
+    "Quantized",
+    "Sparse",
+    "SparseCsr",
+    "NestedTensor",
+    "Autograd",
+]
+
+# This list guards dispatches that can be used in derivatives.yaml
+# For now we omit AutogradFunctionality and AutogradOther
+AUTOGRAD_KEYS = ["AutogradNestedTensor"] + [
+    "Autograd" + component for component in BACKEND_COMPONENTS
+]
+
+FRAGMENT_NAMESPACES = {"quantized", "quantized_decomposed"}
+
+
+# This doesn't have to be in sync with the header, it only needs to contain
+# entries that we actually use in the codegen or want pyi entries for
+class DispatchKey(Enum):
+    Undefined = 0
+    CatchAll = Undefined
+
+    FPGA = auto()
+    MAIA = auto()
+    Vulkan = auto()
+    Metal = auto()
+    MKLDNN = auto()
+    OpenGL = auto()
+    OpenCL = auto()
+    IDEEP = auto()
+    CustomRNGKeyId = auto()
+    MkldnnCPU = auto()
+    Sparse = auto()
+    SparseCsr = auto()
+    NestedTensor = auto()
+    Dense = auto()
+
+    PythonTLSSnapshot = auto()
+    PreDispatch = auto()
+    PythonDispatcher = auto()
+    Python = auto()
+    FuncTorchDynamicLayerBackMode = auto()
+    ZeroTensor = auto()
+    Conjugate = auto()
+    Negative = auto()
+    BackendSelect = auto()
+    Named = auto()
+    AutogradOther = auto()
+    AutogradFunctionality = auto()
+    AutogradNestedTensor = auto()
+    Tracer = auto()
+    Autocast = auto()
+    AutocastCPU = auto()
+    AutocastCUDA = auto()
+    Batched = auto()
+    VmapMode = auto()
+    FuncTorchGradWrapper = auto()
+    FuncTorchBatched = auto()
+    BatchedNestedTensor = auto()
+    FuncTorchVmapMode = auto()
+    FuncTorchDynamicLayerFrontMode = auto()
+    Functionalize = auto()
+    TESTING_ONLY_GenericWrapper = auto()
+    TESTING_ONLY_GenericMode = auto()
+
+    ADInplaceOrView = auto()
+    Autograd = auto()
+    CompositeImplicitAutograd = auto()
+    CompositeImplicitAutogradNestedTensor = auto()
+    CompositeExplicitAutograd = auto()
+    CompositeExplicitAutogradNonFunctional = auto()
+    FuncTorchBatchedDecomposition = auto()
+
+    # BEGIN autogenerated
+    CPU = auto()
+    CUDA = auto()
+    HIP = auto()
+    XLA = auto()
+    MTIA = auto()
+    MPS = auto()
+    IPU = auto()
+    XPU = auto()
+    HPU = auto()
+    VE = auto()
+    Lazy = auto()
+    Meta = auto()
+    PrivateUse1 = auto()
+    PrivateUse2 = auto()
+    PrivateUse3 = auto()
+    QuantizedCPU = auto()
+    QuantizedCUDA = auto()
+    QuantizedHIP = auto()
+    QuantizedXLA = auto()
+    QuantizedMTIA = auto()
+    QuantizedMPS = auto()
+    QuantizedIPU = auto()
+    QuantizedXPU = auto()
+    QuantizedHPU = auto()
+    QuantizedVE = auto()
+    QuantizedLazy = auto()
+    QuantizedMeta = auto()
+    QuantizedPrivateUse1 = auto()
+    QuantizedPrivateUse2 = auto()
+    QuantizedPrivateUse3 = auto()
+    SparseCPU = auto()
+    SparseCUDA = auto()
+    SparseHIP = auto()
+    SparseXLA = auto()
+    SparseMTIA = auto()
+    SparseMPS = auto()
+    SparseIPU = auto()
+    SparseXPU = auto()
+    SparseHPU = auto()
+    SparseVE = auto()
+    SparseLazy = auto()
+    SparseMeta = auto()
+    SparsePrivateUse1 = auto()
+    SparsePrivateUse2 = auto()
+    SparsePrivateUse3 = auto()
+    SparseCsrCPU = auto()
+    SparseCsrCUDA = auto()
+    SparseCsrHIP = auto()
+    SparseCsrXLA = auto()
+    SparseCsrMTIA = auto()
+    SparseCsrMPS = auto()
+    SparseCsrIPU = auto()
+    SparseCsrXPU = auto()
+    SparseCsrHPU = auto()
+    SparseCsrVE = auto()
+    SparseCsrLazy = auto()
+    SparseCsrMeta = auto()
+    SparseCsrPrivateUse1 = auto()
+    SparseCsrPrivateUse2 = auto()
+    SparseCsrPrivateUse3 = auto()
+    NestedTensorCPU = auto()
+    NestedTensorCUDA = auto()
+    NestedTensorHIP = auto()
+    NestedTensorXLA = auto()
+    NestedTensorMTIA = auto()
+    NestedTensorMPS = auto()
+    NestedTensorIPU = auto()
+    NestedTensorXPU = auto()
+    NestedTensorHPU = auto()
+    NestedTensorVE = auto()
+    NestedTensorLazy = auto()
+    NestedTensorMeta = auto()
+    NestedTensorPrivateUse1 = auto()
+    NestedTensorPrivateUse2 = auto()
+    NestedTensorPrivateUse3 = auto()
+    AutogradCPU = auto()
+    AutogradCUDA = auto()
+    AutogradHIP = auto()
+    AutogradXLA = auto()
+    AutogradMTIA = auto()
+    AutogradMPS = auto()
+    AutogradIPU = auto()
+    AutogradXPU = auto()
+    AutogradHPU = auto()
+    AutogradVE = auto()
+    AutogradLazy = auto()
+    AutogradMeta = auto()
+    AutogradPrivateUse1 = auto()
+    AutogradPrivateUse2 = auto()
+    AutogradPrivateUse3 = auto()
+    # END autogenerated
+
+    def __str__(self) -> str:
+        return self.name
+
+    def lower(self) -> str:
+        return str(self).lower()
+
+    @staticmethod
+    def parse(value: str) -> DispatchKey:
+        for k, v in DispatchKey.__members__.items():
+            if k == value:
+                return v
+        raise AssertionError(f"unknown dispatch key {value}")
+
+
+class _TorchDispatchModeKey(Enum):
+    FAKE = auto()
+    PROXY = auto()
+    FUNCTIONAL = auto()
+
+
+def codegen_per_backend_entries() -> str:
+    r: list[str] = []
+    for fk in FUNCTIONALITY_KEYS:
+        r.extend(f"    {fk}{bc} = auto()" for bc in BACKEND_COMPONENTS)
+    return "\n".join(r)
+
+
+for fk in FUNCTIONALITY_KEYS:
+    for bc in BACKEND_COMPONENTS:
+        if not hasattr(DispatchKey, fk + bc):
+            r = codegen_per_backend_entries()
+            print(r)
+            raise RuntimeError(
+                f"Missing {fk}{bc} from DispatchKey enum.  Here is the autogenerated list we expect to have:\n\n{r}"
+            )
+
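+# For example, functionality "Sparse" combined with backend component "XPU"
+# must appear above as DispatchKey.SparseXPU; if any such per-backend entry is
+# missing, the check above prints the full autogenerated list and raises.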
+
+STRUCTURED_DISPATCH_KEYS = {
+    DispatchKey.MPS,
+    DispatchKey.CUDA,
+    DispatchKey.CPU,
+    DispatchKey.XPU,
+    DispatchKey.MTIA,
+}
+UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}
+
+# Set of supported dispatch keys
+dispatch_keys = [
+    DispatchKey.CPU,
+    DispatchKey.SparseCPU,
+    DispatchKey.SparseCsrCPU,
+    DispatchKey.MkldnnCPU,
+    DispatchKey.CUDA,
+    DispatchKey.MPS,
+    DispatchKey.XPU,
+    DispatchKey.SparseXPU,
+    DispatchKey.SparseCsrXPU,
+    DispatchKey.SparseCUDA,
+    DispatchKey.SparseCsrCUDA,
+    DispatchKey.QuantizedCPU,
+    DispatchKey.QuantizedCUDA,
+    DispatchKey.CompositeImplicitAutograd,
+    DispatchKey.CompositeImplicitAutogradNestedTensor,
+    DispatchKey.CompositeExplicitAutograd,
+    DispatchKey.CompositeExplicitAutogradNonFunctional,
+    DispatchKey.NestedTensorCPU,
+    DispatchKey.NestedTensorCUDA,
+    DispatchKey.NestedTensorXPU,
+    DispatchKey.NestedTensorHPU,
+    # Meta is a magic key: it is automatically generated for structured
+    # kernels
+    DispatchKey.Meta,
+    DispatchKey.SparseMeta,
+    DispatchKey.SparseCsrMeta,
+    DispatchKey.QuantizedMeta,
+    DispatchKey.NestedTensorMeta,
+    DispatchKey.ZeroTensor,
+    DispatchKey.MTIA,
+]
+
+
+# Dispatch keys that "support all backends".  These codegen slightly differently
+# than backend-specific keys.
+def is_generic_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.CompositeExplicitAutograd,
+        DispatchKey.CompositeExplicitAutogradNonFunctional,
+        DispatchKey.CompositeImplicitAutograd,
+        DispatchKey.CompositeImplicitAutogradNestedTensor,
+    }
+
+
+# CUDA specific dispatch keys
+def is_cuda_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.CUDA,
+        DispatchKey.QuantizedCUDA,
+        DispatchKey.SparseCUDA,
+        DispatchKey.SparseCsrCUDA,
+        DispatchKey.NestedTensorCUDA,
+        DispatchKey.AutogradCUDA,
+    }
+
+
+# XPU specific dispatch keys
+def is_xpu_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in {
+        DispatchKey.XPU,
+        DispatchKey.QuantizedXPU,
+        DispatchKey.SparseXPU,
+        DispatchKey.SparseCsrXPU,
+        DispatchKey.NestedTensorXPU,
+        DispatchKey.AutogradXPU,
+    }
+
+
+# Structured kernel generation is only supported for certain key types;
+# otherwise use old-style
+def is_structured_dispatch_key(dk: DispatchKey) -> bool:
+    return dk in STRUCTURED_DISPATCH_KEYS
+
+
+def is_ufunc_dispatch_key(dk: DispatchKey) -> bool:
+    # For now, ufunc dispatch keys coincide with structured keys
+    return dk in UFUNC_DISPATCH_KEYS
+
+
+dispatch_device_map = {is_cuda_dispatch_key: "cuda", is_xpu_dispatch_key: "xpu"}
+
+
+# This is oddly named ScalarType and not DType for symmetry with C++
+class ScalarType(Enum):
+    Byte = auto()
+    Char = auto()
+    Short = auto()
+    Int = auto()
+    Long = auto()
+    Half = auto()
+    Float = auto()
+    Double = auto()
+    ComplexHalf = auto()
+    ComplexFloat = auto()
+    ComplexDouble = auto()
+    Bool = auto()
+    BFloat16 = auto()
+    Float8_e5m2 = auto()
+    Float8_e5m2fnuz = auto()
+    Float8_e4m3fn = auto()
+    Float8_e4m3fnuz = auto()
+    Float8_e8m0fnu = auto()
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def maybe_parse(value: str) -> ScalarType | None:
+        for k, v in ScalarType.__members__.items():
+            if k == value:
+                return v
+        return None
+
+    @staticmethod
+    def parse(value: str) -> ScalarType:
+        mb_r = ScalarType.maybe_parse(value)
+        assert mb_r is not None, f"unknown dtype {value}"
+        return mb_r
+
+    @staticmethod
+    def parse_set(values: str) -> OrderedSet[ScalarType]:
+        dtypes: OrderedSet[ScalarType] = OrderedSet()
+        for value in values.split(", "):
+            if value in DTYPE_CLASSES:
+                dtypes.update(DTYPE_CLASSES[value])
+            else:
+                dtypes.add(ScalarType.parse(value))
+        return dtypes
+
+
+DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
+# NB: Integral doesn't include boolean
+DTYPE_CLASSES["Integral"] = OrderedSet(
+    [
+        ScalarType.Byte,
+        ScalarType.Char,
+        ScalarType.Int,
+        ScalarType.Long,
+        ScalarType.Short,
+    ]
+)
+# NB: Floating doesn't include low precision types
+DTYPE_CLASSES["Floating"] = OrderedSet([ScalarType.Float, ScalarType.Double])
+DTYPE_CLASSES["Complex"] = OrderedSet(
+    [ScalarType.ComplexFloat, ScalarType.ComplexDouble]
+)
+DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"]
+DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"]
+DTYPE_CLASSES["FloatingAndComplex"] = (
+    DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"]
+)
+
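+# Illustrative example: ScalarType.parse_set("Floating, Half") expands the
+# "Floating" class and then adds the single dtype, yielding an OrderedSet of
+# {Float, Double, Half}.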
+
+# Represents the valid entries for ufunc_inner_loop in native_functions.yaml.
+# NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how
+# to process it.  Most logic will ignore keys they don't understand, so your
+# new key will get silently ignored until you hook in logic to deal with it.
+class UfuncKey(Enum):
+    # These are low level keys that represent exactly one particular
+    # instantiation of the kernel produced by codegen
+    CUDAFunctor = auto()
+    CUDAFunctorOnOther = auto()
+    CUDAFunctorOnSelf = auto()
+
+    CPUScalar = auto()
+    CPUVector = auto()
+
+    # These are the ones users will usually specify, and
+    # implicitly "fill in" the low level keys
+    ScalarOnly = auto()  # CUDA*, CPUScalar
+    Generic = auto()  # CUDA*, CPU*
+
+    def __str__(self) -> str:
+        return self.name
+
+    @staticmethod
+    def parse(value: str) -> UfuncKey:
+        for k, v in UfuncKey.__members__.items():
+            if k == value:
+                return v
+        raise AssertionError(f"unknown ufunc key {value}")
+
+
+class DeviceCheckType(Enum):
+    NoCheck = 0
+    ExactSame = 1
+
+
+class ViewSchemaKind(Enum):
+    aliasing = auto()
+    aliasing_inplace = auto()
+    non_aliasing = auto()
+
+
+# The basic input to the code generation is native_functions.yaml.
+# The name "native", BTW, comes from the distinction between native
+# functions and legacy TH functions.  The legacy TH functions are gone,
+# but the "native" descriptor has stuck.
+#
+# NativeFunction models a single entry in native_functions.yaml.  Its
+# fields roughly correspond to what you would see in the YAML itself,
+# but after canonicalization and parsing has occurred.
+#
+# You can see some of the overall design patterns for how we set up
+# dataclasses in this class, but we will defer a complete discussion
+# of this to FunctionSchema.
+@dataclass(frozen=True)
+class NativeFunction:
+    # The namespace for this operator. For example, if we have "at::add"
+    # then the namespace would be "at". This enables ops to be registered
+    # through the same DSL with a custom namespace. If not specified, the
+    # default namespace would be "at".
+    namespace: str
+
+    # The function schema of the operator in question.  This schema
+    # has been parsed; see FunctionSchema for more about its structure.
+    # (This type is quoted as we are forward referencing a type
+    # defined later in the file.  I opted for this ordering of the
+    # classes for expository clarity.)
+    func: FunctionSchema
+
+    # Whether or not to generate mutable tensor arguments like regular
+    # ones
+    use_const_ref_for_mutable_tensors: bool
+
+    # Whether or not to omit automatic generation of a DeviceGuard
+    device_guard: bool
+
+    # How to emit automatic generation of device check
+    device_check: DeviceCheckType
+
+    # What python module to put the function in
+    python_module: str | None
+
+    # TODO: figure out what this does
+    category_override: str | None
+
+    # If no variants are specified in native_functions.yaml, this is
+    # assumed to be {'function'}.
+    variants: set[Variant]
+
+    # Whether or not we should skip generating registrations for
+    # this kernel.  This is a bit of a double-edged sword, as manual
+    # registrations don't participate in codegen-based selective build!
+    manual_kernel_registration: bool
+
+    # Whether or not to skip generating TensorMethod/Functions bindings
+    # for this kernel.  Technically, this doesn't actually skip generating
+    # the binding; instead, the binding gets generated to __dispatch_{funcname}
+    # so you can make use of the normal binding if you need it.
+    manual_cpp_binding: bool
+
+    # The location in the YAML file where this native function entry was
+    # defined.  This is for conveniently reporting error messages!
+    loc: Location
+
+    # A list of operators that are expected to be auto-generated for this NativeFunction.
+    # Note: This list isn't actually directly used by the codegen to generate anything.
+    # Instead, the codegen figures out what operators to generate purely based off of
+    # function schema, and uses the autogen declarations to error check.
+    # We expect every NativeFunction that gets auto-generated to be explicitly called out
+    # in native_functions.yaml
+    autogen: list[OperatorName]
+
+    # If non-empty, this kernel is subject to ufunc codegen.
+    # Sorted by ufunc_key
+    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
+
+    # Whether or not this out function is a "structured kernel".  Structured
+    # kernels are defined a little differently from normal kernels; in
+    # particular, their shape checking logic is defined separately from
+    # the kernel.  Only out functions can be structured; other functions
+    # delegate to the out function using the structured_delegate keyword.
+    # Every structured kernel must have at least an out and a functional
+    # variant.
+    structured: bool
+
+    # Whether or not this non-out function is a structured kernel, defined
+    # in terms of the out kernel referenced by the string here.
+    structured_delegate: OperatorName | None
+
+    # Only valid for structured kernels.  Specifies alternative of what
+    # to inherit from when defining the meta class for the structured
+    # operator.  This will usually be TensorIteratorBase.  This also
+    # changes the semantics of set_output to call the parent class.
+    structured_inherits: str | None
+
+    # Structured kernels can declare elements as "precomputed". These elements
+    # are returned by the meta function in one struct and passed to the impl
+    # function in lieu of certain kernel arguments that these precomputed
+    # elements supersede. Information about the names and types of these
+    # precomputed elements and how they correspond to kernel arguments is stored
+    # in this member, if applicable.
+    precomputed: Precompute | None
+
+    # Argument names whose default  should be excluded from the C++ interface.
+    # Intended for resolving overload ambiguities between signatures.
+    cpp_no_default_args: set[str]
+
+    # Note [Abstract ATen methods]
+    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+    # An abstract ATen method is one whose dispatch differs between
+    # types.  These are implemented in derived types (with a
+    # standard (throwing) definition in Type).  A concrete ATen
+    # method is one which has the same dispatch for all types;
+    # we just implement it in the base Type.  This is exposed
+    # in Declarations.yaml via a field named 'abstract'.
+    is_abstract: bool
+
+    # Whether or not the NativeFunction contains a backend-agnostic kernel
+    has_composite_implicit_autograd_kernel: bool
+    has_composite_implicit_autograd_nested_tensor_kernel: bool
+    has_composite_explicit_autograd_kernel: bool
+    has_composite_explicit_autograd_non_functional_kernel: bool
+
+    # Tags are used to describe semantic information about (groups of) operators
+    # that aren't easily inferable directly from the operator's schema.
+    tags: set[str]
+
+    # NB: The benefit of defining a dataclass is that we automatically get
+    # a constructor defined for all the fields we specify.  No need
+    # to explicitly write it out.
+
+    # We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex.
+    @staticmethod
+    def from_yaml(
+        ei: dict[str, object],
+        loc: Location,
+        valid_tags: set[str],
+        ignore_keys: set[DispatchKey] | None = None,
+    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
+        """
+        Parse a NativeFunction from a dictionary as directly parsed
+        from native_functions.yaml
+        """
+        e = ei.copy()
+
+        funcs = e.pop("func")
+        assert isinstance(funcs, str), f"not a str: {funcs}"
+        # only support one level of namespace. E.g., aten::add
+        namespace_helper = NamespaceHelper.from_namespaced_entity(
+            namespaced_entity=funcs, max_level=1
+        )
+        namespace = namespace_helper.get_cpp_namespace(default="aten")
+        func = FunctionSchema.parse(namespace_helper.entity_name)
+
+        cpp_no_default_args_list = e.pop("cpp_no_default_args", [])
+        assert isinstance(cpp_no_default_args_list, list)
+        cpp_no_default_args = set(cpp_no_default_args_list)
+
+        use_const_ref_for_mutable_tensors = e.pop(
+            "use_const_ref_for_mutable_tensors", False
+        )
+        assert isinstance(use_const_ref_for_mutable_tensors, bool)
+
+        if use_const_ref_for_mutable_tensors:
+            assert not func.arguments.out, (
+                "see https://github.com/pytorch/pytorch/issues/145522"
+            )
+
+        variants_s = e.pop("variants", "function")
+        assert isinstance(variants_s, str)
+        variants: set[Variant] = set()
+        for v in variants_s.split(", "):
+            if v == "function":
+                variants.add(Variant.function)
+            elif v == "method":
+                variants.add(Variant.method)
+            else:
+                raise AssertionError(f"illegal variant {v}")
+
+        manual_kernel_registration = e.pop("manual_kernel_registration", False)
+        assert isinstance(manual_kernel_registration, bool), (
+            f"not a bool: {manual_kernel_registration}"
+        )
+
+        manual_cpp_binding = e.pop("manual_cpp_binding", False)
+        assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}"
+
+        device_guard = e.pop("device_guard", True)
+        assert isinstance(device_guard, bool), f"not a bool: {device_guard}"
+
+        device_check_s = e.pop("device_check", None)
+        assert device_check_s is None or isinstance(device_check_s, str), (
+            f"not a str: {device_check_s}"
+        )
+        assert (
+            device_check_s is None or device_check_s in DeviceCheckType.__members__
+        ), f"illegal device_check: {device_check_s}"
+        device_check: DeviceCheckType
+        if device_check_s is None:
+            device_check = DeviceCheckType.ExactSame
+        else:
+            device_check = DeviceCheckType[device_check_s]
+
+        structured = e.pop("structured", False)
+        assert isinstance(structured, bool), f"not a bool: {structured}"
+
+        structured_delegate_s = e.pop("structured_delegate", None)
+        assert structured_delegate_s is None or isinstance(
+            structured_delegate_s, str
+        ), f"not a str: {structured_delegate_s}"
+        assert structured_delegate_s is None or "::" not in structured_delegate_s, (
+            "namespace is not supported in structured delegate,"
+            " using the same namespace as the native function"
+        )
+        structured_delegate: OperatorName | None = None
+        if structured_delegate_s is not None:
+            structured_delegate = OperatorName.parse(structured_delegate_s)
+
+        structured_inherits = e.pop("structured_inherits", None)
+        assert structured_inherits is None or isinstance(structured_inherits, str), (
+            f"not a str: {structured_inherits}"
+        )
+        assert structured_inherits is None or "::" not in structured_inherits, (
+            "namespace is not supported in structured inherits,"
+            " using the same namespace as the native function"
+        )
+
+        python_module = e.pop("python_module", None)
+        assert python_module is None or isinstance(python_module, str), (
+            f"not a str: {python_module}"
+        )
+        assert python_module is None or Variant.method not in variants, (
+            "functions in modules cannot be methods"
+        )
+
+        category_override = e.pop("category_override", None)
+        assert category_override is None or isinstance(category_override, str), (
+            f"not a str: {category_override}"
+        )
+
+        precomputed_dict = e.pop("precomputed", None)
+        assert precomputed_dict is None or structured is True
+        precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None
+
+        tags_inp = e.pop("tags", [])
+        if isinstance(tags_inp, str):
+            tags_inp = [tags_inp]
+        assert isinstance(tags_inp, list)
+
+        # All aten ops generated by torchgen receive the pt2_compliant tag.
+        if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
+            tags_inp.append("pt2_compliant_tag")
+
+        tags: set[str] = set()
+        for t in tags_inp:
+            assert len(valid_tags) > 0
+            # TODO: verify that the tag is valid and has an entry in tags.yaml
+            if t in valid_tags:
+                tags.add(t)
+            else:
+                raise AssertionError(f"illegal tag {t}")
+
+        from torchgen.api import cpp
+
+        raw_dispatch = e.pop("dispatch", None)
+        assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
+        dispatch: dict[DispatchKey, BackendMetadata] = {}
+        num_dispatch_keys: int = 0
+        if raw_dispatch is not None:
+            assert not manual_kernel_registration, (
+                "cannot specify both manual_kernel_registration and dispatch; with "
+                "manual registration, dispatch has no effect!"
+            )
+            redundant_composite_implicit_autograd = False
+            for ks, v in raw_dispatch.items():
+                if ks == "__line__":
+                    continue  # not worth tracking line numbers for dispatch entries
+                assert isinstance(ks, str), (
+                    f"illegal dispatch key '{ks}' in {raw_dispatch}"
+                )
+                assert isinstance(v, str), (
+                    f"illegal dispatch value '{v}' in {raw_dispatch}"
+                )
+                for k in ks.split(","):
+                    dispatch_key = DispatchKey.parse(k.strip())
+                    num_dispatch_keys += 1
+
+                    if ignore_keys and dispatch_key in ignore_keys:
+                        continue
+                    assert dispatch_key in dispatch_keys, (
+                        f"Dispatch key {dispatch_key} of kernel {v} "
+                        "is not a supported dispatch key."
+                    )
+                    # We only allow at most 3 levels of namespace for kernels.
+                    # We will append "native" to a custom kernel namespace.
+                    namespace_helper = NamespaceHelper.from_namespaced_entity(
+                        v, max_level=3
+                    )
+                    kernel_namespace = namespace_helper.get_cpp_namespace(default="at")
+                    # Why is 'structured' included? External backends (e.g.
+                    # XLA) opt into which ops are structured independently
+                    # of which in-tree ops are structured
+                    dispatch[dispatch_key] = BackendMetadata(
+                        kernel=namespace_helper.entity_name,
+                        structured=structured
+                        and is_structured_dispatch_key(dispatch_key),
+                        cpp_namespace=(kernel_namespace + "::native"),
+                    )
+                    if (
+                        dispatch_key is DispatchKey.CompositeImplicitAutograd
+                        and v == cpp.name(func)
+                    ):
+                        redundant_composite_implicit_autograd = True
+
+            # We count the number of dispatch keys which have not been ignored to prevent a dispatch table
+            # in which all backend keys are ignored (leaving only the necessarily-kept CompositeImplicitAutograd
+            # entry) from being treated as redundant.
+            assert not (
+                num_dispatch_keys == 1 and redundant_composite_implicit_autograd
+            ), (
+                "unnecessary dispatch table for this function; just delete the dispatch "
+                "key entirely"
+            )
+            # if a function is a structured delegate, deleting the dispatch
+            # table is NOT semantics preserving
+            assert (
+                structured_delegate
+                or dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
+                or dispatch[DispatchKey.CompositeImplicitAutograd].supports_symint()
+                or num_dispatch_keys != 1
+            ), (
+                f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} "
+                f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}.  Rename your implementation to the expected "
+                "name, then delete the dispatch table"
+            )
+        elif not structured and structured_delegate is None:
+            name = str(func.name.name)
+            assert not (
+                name.startswith("new_")
+                or name.endswith("_like")
+                # TODO: maybe it's better to test the return
+                or (
+                    func.arguments.tensor_options
+                    and not func.arguments.has_tensor_arg()
+                )
+            ), (
+                f"expected {name} to have a CompositeExplicitAutograd "
+                "dispatch entry, but there was no dispatch table.  Factory functions "
+                "should not have implicit dispatch as they should not be decomposed "
+                "for __torch_dispatch__"
+            )
+            dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(
+                cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
+            )
+
+        composites_in_dispatch = [
+            d
+            for d in dispatch
+            if d == DispatchKey.CompositeExplicitAutograd
+            or d == DispatchKey.CompositeExplicitAutogradNonFunctional
+            or d == DispatchKey.CompositeImplicitAutograd
+            or d == DispatchKey.CompositeImplicitAutogradNestedTensor
+        ]
+
+        assert len(composites_in_dispatch) <= 1 or (
+            len(composites_in_dispatch) == 2
+            and (
+                DispatchKey.CompositeExplicitAutogradNonFunctional
+                not in composites_in_dispatch
+            )
+            and (
+                DispatchKey.CompositeImplicitAutogradNestedTensor
+                in composites_in_dispatch
+            )
+        ), (
+            "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, "
+            "or CompositeImplicitAutograd on a single kernel; each "
+            "strictly subsumes the other.  If you wanted to provide an explicit autograd "
+            "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only"
+        )
+
+        autogen_str = e.pop("autogen", "")
+        assert isinstance(autogen_str, str)
+        autogen = (
+            []
+            if autogen_str == ""
+            else [OperatorName.parse(x) for x in autogen_str.split(", ")]
+        )
+
+        raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {})
+        ufunc_inner_loop = {}
+        if isinstance(raw_ufunc_inner_loop, str):
+            ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse(
+                raw_ufunc_inner_loop, UfuncKey.Generic
+            )
+        elif isinstance(raw_ufunc_inner_loop, dict):
+            for k, vo in raw_ufunc_inner_loop.items():
+                if k == "__line__":
+                    continue
+                assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}"
+                assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {v}"
+                ufunc_key = UfuncKey.parse(k)
+                ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key)
+        else:
+            raise AssertionError(
+                f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}"
+            )
+        # Program the BackendIndex for the implicit dispatch entry from ufunc
+        if ufunc_inner_loop:
+            assert structured, "ufunc must be structured"
+
+            # Delay import ufunc here to avoid circular import issue
+            # See: https://github.com/pytorch/pytorch/issues/81294
+            import torchgen.api.ufunc as ufunc
+
+            for dispatch_key in UFUNC_DISPATCH_KEYS:
+                assert dispatch_key not in dispatch, (
+                    f"ufunc should not have explicit dispatch entry for {dispatch_key}"
+                )
+                dispatch[dispatch_key] = BackendMetadata(
+                    kernel=ufunc.schema_kernel_name(func, dispatch_key),
+                    structured=True,
+                    cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
+                )
+
+        if structured_delegate:
+            # Structured functions MUST have a dispatch table
+            is_abstract = True
+        else:
+            is_abstract = (
+                dispatch.keys() != {DispatchKey.CompositeImplicitAutograd}
+                and dispatch.keys()
+                != {DispatchKey.CompositeImplicitAutogradNestedTensor}
+                and dispatch.keys()
+                != {
+                    DispatchKey.CompositeImplicitAutograd,
+                    DispatchKey.CompositeImplicitAutogradNestedTensor,
+                }
+            )
+
+        has_composite_implicit_autograd_kernel = (
+            DispatchKey.CompositeImplicitAutograd in dispatch
+        )
+        has_composite_implicit_autograd_nested_tensor_kernel = (
+            DispatchKey.CompositeImplicitAutogradNestedTensor in dispatch
+        )
+        has_composite_explicit_autograd_kernel = (
+            DispatchKey.CompositeExplicitAutograd in dispatch
+        )
+        has_composite_explicit_autograd_non_functional_kernel = (
+            DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch
+        )
+
+        # We aren't going to store dispatch metadata inline in NativeFunctions;
+        # instead it is separately indexed by backend (so other backends can
+        # add more dispatch entries after the fact).  Reindex the individual
+        # metadata by OperatorName!
+        backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
+
+        # don't care if it exists or not; make it easier to use this function
+        # with other yaml parsers that aren't setting __line__ in the dict
+        e.pop("__line__", None)
+        assert not e, f"leftover entries: {e}"
+
+        # Asserts that we can't do in post_init, because they rely on backend-specific info
+        if structured_delegate is not None:
+            for key in STRUCTURED_DISPATCH_KEYS:
+                assert key not in dispatch, (
+                    f"if structured_delegate, then must not have {key} in dispatch dictionary "
+                    "(it is delegated!)"
+                )
+
+        return (
+            NativeFunction(
+                func=func,
+                use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
+                variants=variants,
+                structured=structured,
+                structured_delegate=structured_delegate,
+                structured_inherits=structured_inherits,
+                precomputed=precomputed,
+                autogen=autogen,
+                ufunc_inner_loop=ufunc_inner_loop,
+                manual_kernel_registration=manual_kernel_registration,
+                manual_cpp_binding=manual_cpp_binding,
+                python_module=python_module,
+                category_override=category_override,
+                device_guard=device_guard,
+                device_check=device_check,
+                loc=loc,
+                cpp_no_default_args=cpp_no_default_args,
+                is_abstract=is_abstract,
+                has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel,
+                has_composite_implicit_autograd_nested_tensor_kernel=has_composite_implicit_autograd_nested_tensor_kernel,
+                has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel,
+                has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel,
+                tags=tags,
+                namespace=namespace,
+            ),
+            backend_metadata,
+        )
+
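+    # Illustration of from_yaml (hypothetical entry, not taken from
+    # native_functions.yaml): an input dict such as
+    #   {"func": "my_op(Tensor self) -> Tensor", "variants": "function"}
+    # parses into a NativeFunction in the default "aten" namespace with an
+    # implicitly generated CompositeImplicitAutograd dispatch entry keyed by
+    # the op's C++ name.
+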
+    def validate_unstructured(self) -> None:
+        # TODO: probably better to accumulate these errors and report them all
+        # at once
+        assert not self.structured, (
+            "This function is structured, but there was "
+            "no valid functional variant of it."
+        )
+        assert self.structured_delegate, (
+            "This function delegates to another structured out function, "
+            "but no valid function was found (the delegate may not exist, or it has the wrong type)"
+        )
+
+    # __post_init__ functions in dataclasses can be used to do extra
+    # validation after construction.
+    #
+    # Notice that we don't do any type validation here.  In fact, we
+    # rely exclusively on mypy to check if you've done types correctly!
+    # Validation is for nontrivial invariants that cannot be (conveniently)
+    # encoded in the type system.
+    def __post_init__(self) -> None:
+        if self.func.arguments.out:
+            assert self.variants == {Variant.function}, (
+                "Native functions with out arguments MUST "
+                "be declared with only function variant; e.g., variants: function; "
+                "otherwise you will tickle a Python argument binding bug "
+                "(which usually manifests itself as the result variable being undefined.)"
+            )
+        if self.structured:
+            assert self.func.kind() == SchemaKind.out, (
+                "Put structured field on the out= "
+                "variant of a function; did you mean structured_delegate?"
+            )
+            assert self.device_guard, (
+                "device_guard: False is not respected by structured kernels"
+            )
+        if self.structured_delegate:
+            assert self.func.kind() != SchemaKind.out, (
+                "structured_delegate field not allowed "
+                "on out= functions; did you mean structured?"
+            )
+            assert self.device_guard, (
+                "device_guard: False is not respected by structured kernels"
+            )
+        # Technically, with the asserts above, this assert is impossible to
+        # trigger
+        assert not (self.structured and self.structured_delegate), (
+            "Cannot have both structured and structured_delegate on function"
+        )
+        defaulted_arguments = {
+            a.name for a in self.func.schema_order_arguments() if a.default is not None
+        }
+        invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
+        assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}"
+        if self.structured_inherits is not None:
+            assert self.structured, (
+                "structured_inherits must also imply structured: True"
+            )
+        if str(self.func.name).startswith("_foreach"):
+            assert self.device_check == DeviceCheckType.NoCheck, (
+                "foreach kernels fall back to slow path when tensor are on different devices, "
+                "device_check not allowed to be enabled"
+            )
+
+        # NB: if your function accidentally has rand/dropout/... in its name
+        # but is not actually random, feel free to amend this to special-case it
+        if (
+            "rand" in str(self.func.name)
+            or (
+                (
+                    "dropout" in str(self.func.name)
+                    or any(
+                        "dropout" in arg.name for arg in self.func.arguments.flat_all
+                    )
+                )
+                # Backwards of dropout is typically deterministic
+                and "backward" not in str(self.func.name)
+                and str(self.func.name.name) not in ["_cudnn_init_dropout_state"]
+            )
+            or self.func.arguments.has_generator_arg()
+        ):
+            assert "nondeterministic_seeded" in self.tags, str(self.func.name)
+
+    @property
+    def has_composite_kernel(self) -> bool:
+        return (
+            self.has_composite_implicit_autograd_kernel
+            or self.has_composite_explicit_autograd_kernel
+            or self.has_composite_explicit_autograd_non_functional_kernel
+        ) or (
+            self.has_composite_implicit_autograd_kernel
+            and self.has_composite_implicit_autograd_nested_tensor_kernel
+        )
+
+    @property
+    def is_view_op(self) -> bool:
+        rets = self.func.returns
+        is_non_mutating_view = len(rets) > 0 and any(
+            r.annotation is not None and not r.annotation.is_write for r in rets
+        )
+        # See Note [resize_ in Functionalization] for more details
+        is_inplace_view = (
+            "inplace_view" in self.tags
+            and str(self.func.name) != "resize_"
+            and str(self.func.name) != "resize_as_"
+        )
+        is_wildcard_view = any(
+            inp.annotation is not None and "*" in inp.annotation.alias_set_after
+            for inp in self.func.schema_order_arguments()
+        )
+        return is_non_mutating_view or is_inplace_view or is_wildcard_view
+
+    @property
+    def view_schema_kind(self) -> ViewSchemaKind:
+        if self.is_view_op and self.func.name.name.inplace:
+            assert "inplace_view" in self.tags
+            return ViewSchemaKind.aliasing_inplace
+        if self.is_view_op:
+            return ViewSchemaKind.aliasing
+        else:
+            return ViewSchemaKind.non_aliasing
+
+    @property
+    def root_name(self) -> str:
+        return self.func.name.name.base
+
+    @property
+    def part_of_structured_group(self) -> bool:
+        return self.structured or self.structured_delegate is not None
+
+
+class SchemaKind(Enum):
+    functional = auto()
+    inplace = auto()
+    out = auto()
+    mutable = auto()
+    scratch = auto()
+
+
+# A structured kernel is guaranteed to have a functional and out variant, and
+# optionally an inplace variant.
+#
+# NB: we create NativeFunctionsGroup *even if* the function is not
+# actually annotated structured.  Test the structured boolean to see if it
+# actually is structured or not.
+@dataclass(frozen=True)
+class NativeFunctionsGroup:
+    functional: NativeFunction
+    inplace: NativeFunction | None
+    mutable: NativeFunction | None
+    out: NativeFunction
+
+    @property
+    def structured(self) -> bool:
+        # Whether or not the operator has a meta() function. This information is backend-agnostic.
+        return self.out.structured
+
+    def __post_init__(self) -> None:
+        test_sig: FunctionSchema = self.functional.func.signature()
+        for f in self.functions():
+            if test_sig != f.func.signature():
+                raise AssertionError(
+                    "NativeFunctionsGroup constructed from two NativeFunctions "
+                    f"that don't have matching signatures: {test_sig} != {f.func.signature()}"
+                )
+
+            if self.structured != f.part_of_structured_group:
+                raise AssertionError(
+                    "NativeFunctionsGroup constructed from structured and unstructured "
+                    f"functions: {self.out.func.name} and {f.func.name}"
+                )
+        assert self.functional.func.kind() == SchemaKind.functional
+        assert self.out.func.kind() == SchemaKind.out
+        assert self.functional.namespace == self.out.namespace
+        if self.inplace is not None:
+            assert self.inplace.func.kind() == SchemaKind.inplace
+            assert self.inplace.namespace == self.functional.namespace
+
+        if self.mutable is not None:
+            assert self.mutable.func.kind() == SchemaKind.mutable
+            assert self.mutable.namespace == self.functional.namespace
+            # See Note [Overload Ambiguity With Functional Variants]
+            assert self.functional.func.name.name.functional_overload
+
+        if self.structured:
+            # For now, structured composite kernels are not supported (need some
+            # design work to figure out how to make the composite case work)
+            assert (
+                not self.out.has_composite_implicit_autograd_kernel
+                and not self.out.has_composite_implicit_autograd_nested_tensor_kernel
+            )
+
+            assert self.functional.structured_delegate == self.out.func.name, (
+                f"{self.functional.func.name} delegates to {self.functional.structured_delegate} "
+                f"but its actual delegate is {self.out.func.name}"
+            )
+            if self.inplace is not None:
+                assert self.inplace.structured_delegate == self.out.func.name
+
+        generated_fns = sorted(
+            [str(f.func.name) for f in self.functions() if "generated" in f.tags]
+        )
+        generated_fns_str = ", ".join(str(x) for x in generated_fns)
+        expected_generated_fns: set[str] = set()
+        for f in self.functions():
+            expected_generated_fns.update(str(op) for op in f.autogen)
+        expected_generated_fns_str = ", ".join(
+            str(x) for x in sorted(expected_generated_fns)
+        )
+        if len(expected_generated_fns) == 0 and len(generated_fns) > 0:
+            raise RuntimeError(
+                f"The codegen expects to be able to generate '{generated_fns_str}'."
+                " In order to generate them however, we expect them to be called out explicitly in the yaml."
+                f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}"
+            )
+        if expected_generated_fns_str != generated_fns_str:
+            raise RuntimeError(
+                f"The codegen expects to be able to generate '{generated_fns_str}'."
+                f" To do so, it expects a line: 'autogen: {generated_fns_str}'."
+                f" Instead, it found 'autogen: {expected_generated_fns_str}'"
+            )
+
+    def signature(self) -> FunctionSchema:
+        return self.out.func.signature()
+
+    def functions(self) -> Iterator[NativeFunction]:
+        yield self.functional
+        yield self.out
+        if self.inplace is not None:
+            yield self.inplace
+        if self.mutable is not None:
+            yield self.mutable
+
+    @property
+    def root_name(self) -> str:
+        return self.functional.root_name
+
+    @staticmethod
+    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
+        assert d
+        if len(d) == 1:
+            return None
+        d = dict(d)  # non-destructive updates please
+        functional = d.pop(SchemaKind.functional, None)
+        inplace = d.pop(SchemaKind.inplace, None)
+        mutable = d.pop(SchemaKind.mutable, None)
+        out = d.pop(SchemaKind.out, None)
+        assert not d
+        assert functional is not None
+        # There are a few operators which only have functional/inplace variants;
+        # these don't count as structured for our purposes here
+        if out is None:
+            return None
+        # assuming all variants have the same namespace
+        return NativeFunctionsGroup(
+            functional=functional,
+            inplace=inplace,
+            mutable=mutable,
+            out=out,
+        )
+
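+# Illustrative example for NativeFunctionsGroup: the functional, in-place and
+# out= flavors of an op (e.g. add.Tensor, add_.Tensor, add.out) parse to
+# separate NativeFunctions and are grouped by from_dict above into a single
+# NativeFunctionsGroup sharing one signature.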
+
+@dataclass(frozen=True)
+class BackendMetadata:
+    # The name of the backend kernel, for a given operator
+    # for in-tree backends. These names come directly from the 'dispatch' field
+    # in native_functions.yaml. The dispatch entry is optional; in that
+    # case, that is equivalent to having written:
+    #
+    #   dispatch:
+    #       CompositeImplicitAutograd: $operator_name
+    kernel: str
+    # Whether or not the operator has a structured kernel implemented, for this particular backend.
+    # For in-tree backends, they all have the same value for structured - this is listed
+    # in native_functions.yaml.
+    # However, external backends like XLA can independently toggle which ops are structured.
+    structured: bool
+
+    # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE
+    cpp_namespace: str
+
+    def supports_symint(self) -> bool:
+        return "_symint" in self.kernel
+
+
+@dataclass(frozen=True)
+class UfuncInnerLoop:
+    name: str
+    supported_dtypes: OrderedSet[ScalarType]
+    # key is stored here because it affects the semantics of name,
+    # so it's helpful to have them together for further processing
+    ufunc_key: UfuncKey
+
+    @staticmethod
+    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
+        name, supported_dtypes_str = value.split(" ", 1)
+        assert supported_dtypes_str[0] == "("
+        assert supported_dtypes_str[-1] == ")"
+        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
+        for k in supported_dtypes_str[1:-1].split(", "):
+            supported_dtypes |= ScalarType.parse_set(k)
+        return UfuncInnerLoop(
+            name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key
+        )
+
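+# Illustrative example for UfuncInnerLoop.parse: an entry such as
+# "add (AllAndComplex, BFloat16, Half)" parses to name="add" with
+# supported_dtypes equal to the AllAndComplex class plus BFloat16 and Half,
+# and ufunc_key set to whichever UfuncKey the entry was listed under.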
+
+# BackendIndex represents a backend.
+# The BackendIndex encodes per-operator information that is potentially different
+# for each backend. The most obvious example is the name of the kernel
+# (the 'dispatch' entry in native_functions.yaml).
+# However, there can be other examples of different backends having different information.
+# External backends can choose to opt their kernels to be structured independently from in-tree backends,
+# which means that this information isn't inherently tied to a NativeFunction- it's different per backend.
+@dataclass(frozen=True)
+class BackendIndex:
+    dispatch_key: DispatchKey
+    # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others.
+    # All in-tree ops use out kernels, while XLA uses functional kernels.
+    use_out_as_primary: bool
+    # Whether the backend requires a device guard, and device checks.
+    # For in-tree backends, this is currently just CUDA/HIP
+    # For out-of-tree backends, this is currently just Intel XPU
+    device_guard: bool
+    # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
+    external: bool
+    # Other backend-specific information that is on a per-operator basis
+    index: dict[OperatorName, BackendMetadata]
+
+    @staticmethod
+    def grow_index(
+        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+    ) -> None:
+        for k, v in child_index.items():
+            for op_name, metadata in v.items():
+                assert op_name not in parent_index[k], (
+                    f"duplicate operator {op_name} for dispatch key {k}"
+                )
+                parent_index[k][op_name] = metadata
+
+    def primary(self, g: NativeFunctionsGroup) -> NativeFunction:
+        if self.use_out_as_primary:
+            return g.out
+        else:
+            return g.functional
+
+    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
+        m = self.get_kernel(g)
+        return m is not None
+
+    def get_kernel(
+        self, g: NativeFunction | NativeFunctionsGroup
+    ) -> BackendMetadata | None:
+        if isinstance(g, NativeFunction):
+            f = g
+        elif isinstance(g, NativeFunctionsGroup):
+            f = self.primary(g)
+        else:
+            assert_never(g)
+        if f.func.name not in self.index:
+            return None
+        return self.index[f.func.name]
+
+    def native_function_class_name(self) -> str | None:
+        if self.external:
+            return f"{str(self.dispatch_key)}NativeFunctions"
+        else:
+            # TODO: This discrepancy isn't required; we could also generate
+            # a class for in-tree kernels. It'll just require carefully
+            # updating every kernel definition + callsite of every in-tree aten kernel.
+            return None
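+
+    # Illustrative sketch of a minimal in-tree index entry (the kernel name
+    # "add_out_cpu" is hypothetical; DispatchKey.CPU, DEFAULT_KERNEL_NAMESPACE
+    # and OperatorName refer to the definitions in this module):
+    #
+    #   meta = BackendMetadata(
+    #       kernel="add_out_cpu", structured=True, cpp_namespace=DEFAULT_KERNEL_NAMESPACE
+    #   )
+    #   idx = BackendIndex(
+    #       dispatch_key=DispatchKey.CPU,
+    #       use_out_as_primary=True,
+    #       device_guard=False,
+    #       external=False,
+    #       index={OperatorName.parse("add.out"): meta},
+    #   )
+    #   # idx.get_kernel(f) is `meta` for a NativeFunction f whose schema name
+    #   # is add.out, and None otherwise.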
+
+
+# The function schema is undoubtedly the most important data structure
+# in all of the codegen, as it defines the type signature for operators,
+# and most of the code generation we do is type directed (e.g., look at
+# the types, decide what to do.  Think about how we code generate
+# C++ function stubs!)
+#
+# We will also see in this class the general structure for how we model
+# data in this code generation.  A few notable properties to point out
+# ahead of time:
+#
+#   - These dataclasses are a *lossless* representation of the strings
+#     they are parsed from.  In fact, we assert that given the
+#     information stored in the dataclass, we can exactly reconstruct
+#     the string we parsed from (and assert this inside the parse
+#     definition).  There are a few reasons for this:
+#
+#       - If you find that it is difficult to reconstruct the string
+#         given a dataclass, that is a clue that your data
+#         representation is wrong.
+#
+#       - It helps ensure that all relevant information is present
+#         in the dataclass, so that downstream users aren't tempted
+#         to reparse the original string to get some information
+#         that was omitted.
+#
+#       - It forces you to represent the data in-memory in the same way
+#         it is recorded textually, which makes the dataclasses easier
+#         to understand for someone who is familiar with the
+#         textual format.  (As a tradeoff, it means you have to model
+#         the syntax, even when it is inconvenient.  But maybe that means
+#         the syntax is bad!)  If you don't understand the internal
+#         representation, go look at the printing code to see how
+#         it maps onto the surface syntax!
+#
+#       - It makes it easy to test the parsing code, as parsing code
+#         that is inconsistent with the string code will fail early
+#         and loudly.  (As a tradeoff, it makes the parsing code a bit
+#         brittle; in particular, with trivial whitespace changes you
+#         are likely to trigger an assert error.)
+#
+#     In general, try to make the __str__ code as simple as possible
+#     (even at the cost of more complex parsing logic.)  Additionally,
+#     try to minimize redundancy in data representation.  (Precomputed
+#     fields are OK though: they are defined as a simple function on
+#     the canonical representation in question.)
+#
+#   - These dataclasses are all frozen; once constructed their
+#     values never change.  This makes it easy to tell where any
+#     given data came from: just look to the constructor.  As a
+#     tradeoff, you can't easily "decorate" a schema with extra
+#     information from a post-facto analysis.  We impose this
+#     restriction to make these structures more understandable.
+#
+@dataclass(frozen=True)
+class FunctionSchema:
+    # The name of the operator this function schema describes.
+    name: OperatorName
+
+    arguments: Arguments
+
+    # TODO: Need to handle collisions with argument names at some point
+    returns: tuple[Return, ...]
+
+    @property
+    def is_mutable(self) -> bool:
+        def is_write(arg: Argument) -> bool:
+            if arg.annotation is None:
+                return False
+            return arg.annotation.is_write
+
+        # Corresponds to torch._C._FunctionSchema.is_mutable
+        # See aten/src/ATen/core/function_schema.h (keep these in sync)
+        return any(is_write(a) for a in self.arguments.flat_all)
+
+    def schema_order_arguments(self) -> Iterator[Argument]:
+        return itertools.chain(
+            self.arguments.flat_positional,
+            self.arguments.flat_kwarg_only,
+            self.arguments.out,
+        )
+
+    decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
+
+    @staticmethod
+    def parse(func: str) -> FunctionSchema:
+        # We should probably get a proper parser here
+        decls = FunctionSchema.decl_re.findall(func)
+        assert len(decls) == 1, f"Invalid function schema: {func}"
+        ops, args, return_decl = decls[0]
+        name = OperatorName.parse(ops)
+        arguments = Arguments.parse(args)
+        returns = parse_returns(return_decl)
+        r = FunctionSchema(name=name, arguments=arguments, returns=returns)
+        assert str(r) == func, f"{str(r)} != {func}"
+        return r
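+
+    # Illustrative example: parsing is asserted to be lossless, so a schema
+    # string round-trips exactly through parse/__str__, e.g.
+    #
+    #   s = FunctionSchema.parse(
+    #       "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+    #   )
+    #   assert str(s) == "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"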
+
+    def returns_are_aliased(self) -> bool:
+        # We assert earlier that schemas can't have a mix of aliased and non-aliased returns
+        return any(
+            r
+            for r in self.returns
+            if r.annotation is not None and r.annotation.is_write
+        )
+
+    def __post_init__(self) -> None:
+        for arg, ret in zip(self.arguments.out, self.returns):
+            assert arg.annotation == ret.annotation, (
+                "Out arguments must have matching return Tensor; furthermore, "
+                "the ith-argument needs to correspond to the ith return"
+            )
+        # We also enforce that if you have any mutable, positional args, then they are not returned.
+        # This makes it easier to group these functions properly with their functional/out= counterparts.
+        for a in self.arguments.post_self_positional_mutable:
+            assert not any(a.annotation == r.annotation for r in self.returns), (
+                f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}"
+            )
+        # Invariant: we expect out arguments to appear as keyword arguments in the schema.
+        # This means that all mutable returns should be aliased to a keyword argument
+        # (except for "self", which we explicitly don't treat as an out argument because of its use in methods)
+        # See Note [is_out_fn]
+        out_and_self = list(self.arguments.out) + [
+            arg for arg in self.arguments.flat_positional if arg.name == "self"
+        ]
+        mutable_returns = [
+            ret
+            for ret in self.returns
+            if ret.annotation is not None and ret.annotation.is_write
+        ]
+        immutable_returns = [
+            ret
+            for ret in self.returns
+            if ret.annotation is None or not ret.annotation.is_write
+        ]
+        # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)",
+        # because:
+        # (1) It's more annoying to handle properly
+        # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple.
+        # Instead, we expect the (a!) argument to not be returned.
+        assert len(mutable_returns) == 0 or len(immutable_returns) == 0, (
+            f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}"
+        )
+        for ret in mutable_returns:
+            assert any(ret.annotation == arg.annotation for arg in out_and_self), (
+                'All mutable returns must be aliased either to a keyword argument, or to "self". '
+                "Did you forget to mark an out argument as keyword-only?"
+            )
+        if self.arguments.out:
+            # out= ops that return their mutable inputs are only really useful for method chaining.
+            # And method chaining is only really useful if the thing you're returning is a plain Tensor.
+            # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor,
+            # and all other types of out= op schemas should return void.
+            # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that.
+            if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out):
+                assert len(self.returns) == 0, (
+                    "out= ops that accept tensor lists as out arguments "
+                    "are expected to have no return type (since you can't do method chaining on them)"
+                )
+            else:
+                # mutable keyword arguments whose name has _scratch_ prefix are
+                # scratch tensors for memory planning and should not be returned
+                assert len(
+                    [
+                        arg
+                        for arg in self.arguments.out
+                        if not arg.name.startswith("_scratch_")
+                    ]
+                ) == len(self.returns), (
+                    "Must return as many arguments as there are out arguments, or no return at all"
+                )
+
+        if self.name.name.inplace:
+            self_a = self.arguments.self_arg
+            assert (
+                self_a
+                and self_a.argument.annotation
+                and self_a.argument.annotation.is_write
+            )
+            if self_a.argument.type == BaseType(BaseTy.Tensor):
+                # All inplace ops with an ordinary `Tensor self` argument should return self,
+                # to allow for method chaining.
+                assert (
+                    len(self.returns) == 1
+                    and self.returns[0].annotation == self_a.argument.annotation
+                )
+            else:
+                # You can't method chain on non-tensor self arguments though (like a list[Tensor])
+                # so in all other cases we expect the return type to be none.
+                assert len(self.returns) == 0
+
+        if self.arguments.tensor_options is not None:
+            assert self.kind() == SchemaKind.functional, (
+                "Found an operator that is not functional or out variant, but has tensor options arguments."
+                "This is not allowed- tensor options arguments are only allowed for factory functions."
+                f"schema: {str(self)}"
+            )
+        if self.is_functional_fn():
+            assert self.kind() == SchemaKind.functional, (
+                "Found an operator that is not functional, but its overload contains the string 'functional'."
+                "This is a special keyword in the codegen, please use a different overload name."
+                f"schema: {str(self)}"
+            )
+
+    def is_functional_fn(self) -> bool:
+        return "functional" in self.name.overload_name
+
+    def is_out_fn(self) -> bool:
+        # Note [is_out_fn]
+        #
+        # out functions are the variants which take an explicit out= argument
+        # to populate into.  We need to know if a schema corresponds to an
+        # out function for several reasons:
+        #
+        #   - They codegen differently in C++ API
+        #       - codegen to at::add_out rather than at::add
+        #       - out argument is moved to front of C++ argument list
+        #
+        # out functions are DEFINED to be any function with a keyword-only
+        # argument that is mutable.  In principle, this could lead to a
+        # false positive if you define a function that mutates a
+        # kwarg only argument, but this isn't the "true" output of this
+        # function.  A more robust definition that would work in this
+        # case would also look at:
+        #
+        #   - The output types.  Out functions take in the arguments
+        #     they mutate and then return them again; this is sort
+        #     of "definitionally" what makes something an out function.
+        #     Historically, we DO check this for consistency.
+        #   - Correspondence with pure variant.  An out function
+        #     should have a signature equivalent to its pure variant,
+        #     but just with extra kwargs for the output elements.  This
+        #     is difficult to actually check for and historically
+        #     we only do this check in tools/
+        return bool(self.arguments.out)
+
+    def kind(self) -> SchemaKind:
+        """
+        What kind of schema is this?  A functional schema is one
+        that returns a newly allocated output; an inplace schema
+        modifies the self argument inplace; an out schema writes
+        the result into an explicitly provided out argument.
+        """
+        is_out = bool(self.arguments.out)
+        is_scratch = bool(
+            [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")]
+        )
+        is_inplace = self.name.name.inplace
+        is_mutable = any(
+            a.annotation is not None and a.annotation.is_write
+            for a in self.arguments.post_self_positional
+        )
+        assert not (is_out and is_inplace)
+        # out= and inplace schemas can also have post_self_positional mutable args,
+        # but we give precedence to out= and inplace when deciding the schema kind.
+        # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops
+        # to also worry about mutable post_self_positional arguments,
+    # but it seems like a much bigger lift to classify them as having a new schema kind.
+        # The number of ops that fit in this strange category is small enough that
+        # we can probably manually write code for them instead of forcing the codegen to handle them.
+        if is_inplace:
+            return SchemaKind.inplace
+        elif is_scratch:
+            assert is_out, (
+                "invariant: all scratch operators are expected to be out= operators too"
+            )
+            return SchemaKind.scratch
+        elif is_out:
+            assert not is_scratch, (
+                "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!"
+            )  # noqa: B950
+            return SchemaKind.out
+        elif is_mutable:
+            return SchemaKind.mutable
+        else:
+            return SchemaKind.functional
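+
+    # Illustrative examples of how schema strings map onto SchemaKind
+    # (using the familiar add overloads):
+    #
+    #   "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+    #       -> SchemaKind.functional
+    #   "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
+    #       -> SchemaKind.inplace
+    #   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
+    #       -> SchemaKind.out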
+
+    # For every return:
+    # - If the return aliases an input, we return the input name
+    # - Otherwise, we return None.
+    # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
+    def aliased_return_names(self) -> list[str | None]:
+        outs: list[str | None] = []
+        for r in self.returns:
+            aliased_args = [
+                a
+                for a in self.arguments.flat_all
+                if a.annotation is not None and a.annotation == r.annotation
+            ]
+            if len(aliased_args) == 0:
+                outs.append(None)
+            elif len(aliased_args) == 1:
+                outs.append(aliased_args[0].name)
+            else:
+                aliased_names = ", ".join(a.name for a in aliased_args)
+                raise AssertionError(
+                    f"Found a return ({r.name})that aliases multiple inputs ({aliased_names})"
+                )
+        return outs
+
+    def signature(
+        self,
+        *,
+        strip_default: bool = False,
+        strip_view_copy_name: bool = False,
+        keep_return_names: bool = False,
+    ) -> FunctionSchema:
+        """
+                Certain schemas are 'related', in that they are simply
+                inplace/out/functional versions of the same function.  This method
+                factors these schemas into the "core" functional signature which
+                is equal across all versions.
+
+                Here is the normalization that is applied to the schema to convert
+                it to a signature:
+                - The overload name is stripped (name is retained, since
+                  it expresses semantic content about what the function does)
+                - Inplace is set False
+                - Out arguments are stripped
+                - Mutable post_self_positional args are converted to returns
+                - Mutability annotations are stripped  (this is sound
+                  because you cannot overload on mutability annotation)
+                - Return names are stripped since they are not overloadable and
+                  some variants have return names but some not
+                - TensorOptions are dropped
+                  because out= variants of factory functions don't include them
+                  (and we want to be able to pair up factory functions with their out variants)
+
+                Finally, we want to be able to pair up related "view" and their
+                corresponding "view_copy" operators. We do this by optionally
+                stripping the trailing "_copy" from the base name.
+
+                Example of a mutable op before and after:
+
+                f.func (Mutable operator):
+        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
+
+                f.func (Corresponding functional operator):
+        _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)  # noqa: B950
+
+                f.func.signature() output:
+        _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)  # noqa: B950
+        """
+
+        def strip_ret_annotation(r: Return) -> Return:
+            return Return(
+                name=r.name if keep_return_names else None,
+                type=r.type,
+                annotation=None,
+            )
+
+        base_name = self.name.name.base
+        if strip_view_copy_name:
+            if base_name.endswith("_copy"):
+                base_name = base_name.replace("_copy", "")
+            elif base_name.endswith("_scatter"):
+                base_name = base_name.replace("scatter", "inverse")
+
+        # find mutable inputs that are not originally returned, and convert them to returns
+        returns_from_mutable_inputs = tuple(
+            # When we're grouping functions we strip the return names,
+            # but when we're generating the actual functional variants then we follow
+            # a convention for what to name the returns
+            Return(
+                name=f"{a.name}_out" if keep_return_names else None,
+                type=a.type,
+                annotation=None,
+            )
+            for a in itertools.chain(
+                # Order is important here (otherwise e.g. inplace with mutable args
+                # and out= with mutable args won't have the same signature)
+                (
+                    [self.arguments.self_arg.argument]
+                    if self.arguments.self_arg is not None
+                    else []
+                ),
+                self.arguments.out,
+                self.arguments.post_self_positional,
+            )
+            if a.annotation is not None
+            and a.annotation.is_write
+            and not any(a.annotation == r.annotation for r in self.returns)
+        )
+        original_returns = tuple(map(strip_ret_annotation, self.returns))
+        # Ordering is important here. We expect the "mutable input" returns to come last.
+        returns = original_returns + returns_from_mutable_inputs
+
+        args_sig = self.arguments.signature(strip_default=strip_default)
+        # See Note [bernoulli.p schema]
+        if str(self.name) == "bernoulli.p":
+            args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5"))
+
+        return FunctionSchema(
+            name=OperatorName(
+                name=BaseOperatorName(
+                    base=base_name,
+                    inplace=False,
+                    dunder_method=self.name.name.dunder_method,
+                ),
+                overload_name="",  # stripped
+            ),
+            arguments=args_sig,
+            returns=returns,
+        )
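+
+    # Illustrative example: the add variants shown in the kind() examples above
+    # all normalize to the same signature, which is what lets related schemas
+    # be grouped together:
+    #
+    #   f = FunctionSchema.parse("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
+    #   i = FunctionSchema.parse("add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)")
+    #   o = FunctionSchema.parse("add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)")
+    #   assert f.signature() == i.signature() == o.signature()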
+
+    def view_signature(self) -> FunctionSchema:
+        return self.signature(strip_view_copy_name=True)
+
+    def with_name(self, name: OperatorName) -> FunctionSchema:
+        return FunctionSchema(
+            name=name,
+            arguments=self.arguments,
+            returns=self.returns,
+        )
+
+    @property
+    def modifies_arguments(self) -> bool:
+        return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable]
+
+    def has_symint(self) -> bool:
+        return self.arguments.has_symint_arg()
+
+    def __str__(self) -> str:
+        all_arguments_str = str(self.arguments)
+        if len(self.returns) == 1:
+            returns = str(self.returns[0])  # omit parentheses
+        else:
+            returns = "(" + ", ".join(map(str, self.returns)) + ")"
+        return f"{self.name}({all_arguments_str}) -> {returns}"
+
+
+# Here is the rest of the data model, described more briefly.
+
+
+# Simplified version for what actually shows up in built-ins.
+# Look at alias_info.h for expanded syntax.  If you need the structure,
+# you also need to make this structure recursive so it can be lined
+# up with the type components too.  For primitives this isn't really
+# necessary
+@dataclass(frozen=True)
+class Annotation:
+    # Typically only has one element.  Not actually a set so
+    # we can conveniently assume it is canonically ordered
+    alias_set: tuple[str, ...]
+    is_write: bool
+    alias_set_after: tuple[str, ...]
+
+    @staticmethod
+    def parse(ann: str) -> Annotation:
+        # TODO: implement a proper parser if this gets more ugly
+        # Regex Explanation:
+        # Example: "a! -> a|b"
+        # Group #1: alias before optional '|', required. Matches the first
+        #   character 'a' in the example
+        # Group #2: optional alias set after optional '|', matches empty string
+        #   in the example
+        # Group #3: optional "is write" flag, matches '!' in the example.
+        # Group #4: optional section containing arrow, matches " -> a|b" in the
+        #   example.
+        # Group #5: optional alias after set, supports wildcard, matches "a|b"
+        #   in the example.
+        # Group #6: optional sub-section of alias after set, matches "|b" in the
+        #   example.
+        m = re.match(r"^([a-z])(\|[a-z])*(!?)( -> (\*|[a-z](\|[a-z])*))?$", ann)
+
+        assert m is not None, f"unrecognized alias annotation {ann}"
+        before_alias = m.group(1) + (m.group(2) if m.group(2) else "")
+        alias_set = tuple(before_alias.split("|"))
+        is_write = m.group(3) == "!"
+        assert not (is_write and len(alias_set) > 1), (
+            f"alias set larger than 1 is not mutable, got {ann} instead."
+        )
+        after_set = tuple(m.group(5).split("|")) if m.group(5) else ()
+        assert not (len(before_alias) > 1 and len(after_set) > 1), (
+            f"before alias set and after alias set cannot be larger than 1 at the same time, got {ann} instead."
+        )
+        r = Annotation(
+            alias_set=alias_set, is_write=is_write, alias_set_after=after_set
+        )
+        assert str(r) == ann, f"{r} != {ann}"
+        return r
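+
+    # Illustrative examples:
+    #
+    #   Annotation.parse("a!")      # alias_set == ("a",), is_write == True
+    #   Annotation.parse("a -> *")  # alias_set == ("a",), alias_set_after == ("*",)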
+
+    def __str__(self) -> str:
+        alias_set = "|".join(self.alias_set)
+        if self.is_write:
+            alias_set = f"{alias_set}!"
+        alias_set_after = "|".join(self.alias_set_after)
+        if alias_set_after:
+            alias_set = f"{alias_set} -> {alias_set_after}"
+        return alias_set
+
+
+# The base class for the type system.  This is also loosely modeled
+# off of jit_type.h, but we've simplified the hierarchy to focus
+# in on the aspects of the type system that matter for code generation
+# (for example, there's no SingleElementType subclass anymore).
+# You never actually construct a Type; usually it's going to be one
+# of the subclasses.  If Python had ADTs this would be one!
+@dataclass(frozen=True)
+class Type:
+    @staticmethod
+    def parse(t: str) -> Type:
+        r = Type._parse(t)
+        assert str(r) == t, f"{r} != {t}"
+        return r
+
+    @staticmethod
+    def _parse(t: str) -> Type:
+        m = re.match(r"^(.+)\?$", t)
+        if m is not None:
+            return OptionalType(Type.parse(m.group(1)))
+        m = re.match(r"^(.+)\[([0-9]+)?\]$", t)
+        if m is not None:
+            size = int(m.group(2)) if m.group(2) is not None else None
+            return ListType(elem=Type.parse(m.group(1)), size=size)
+
+        # '__torch__.torch.classes.' is the prefix for custom class
+        m = re.match(r"^__torch__\.torch\.classes\.([a-zA-Z0-9_.]+)$", t)
+        if m is not None:
+            return CustomClassType(m.group(1))
+        try:
+            return BaseType(BaseTy[t])
+        except KeyError as e:
+            raise RuntimeError(f"unrecognized type {t}") from e
+
+    def __str__(self) -> str:
+        raise NotImplementedError
+
+    # WARNING: These concepts are not very well-defined.  For example,
+    # is "int?" nullable? How about "int?[]".  They are defined
+    # so we can conveniently generate legacy Declarations.yaml but
+    # really we should probably just remove these at some point
+
+    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
+        raise NotImplementedError
+
+    def is_tensor_like(self) -> bool:
+        return self.is_base_ty_like(BaseTy.Tensor)
+
+    def is_generator_like(self) -> bool:
+        return self.is_base_ty_like(BaseTy.Generator)
+
+    def is_symint_like(self) -> bool:
+        return self.is_base_ty_like(BaseTy.SymInt)
+
+    def is_nullable(self) -> bool:
+        raise NotImplementedError
+
+    def is_list_like(self) -> ListType | None:
+        raise NotImplementedError
+
+
+# Base types are simple, atomic types with no further structure
+class BaseTy(Enum):
+    Generator = auto()
+    ScalarType = auto()
+    Tensor = auto()
+    int = auto()
+    Dimname = auto()
+    DimVector = auto()
+    float = auto()
+    str = auto()
+    bool = auto()
+    Layout = auto()
+    Device = auto()
+    DeviceIndex = auto()
+    Scalar = auto()
+    MemoryFormat = auto()
+    QScheme = auto()
+    Storage = auto()
+    Stream = auto()
+    SymInt = auto()
+    SymBool = auto()
+    GraphModule = auto()
+
+
+@dataclass(frozen=True)
+class BaseType(Type):
+    name: BaseTy
+
+    def __str__(self) -> str:
+        return f"{self.name.name}"
+
+    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
+        return self.name == base_ty
+
+    def is_nullable(self) -> bool:
+        return False
+
+    def is_list_like(self) -> ListType | None:
+        return None
+
+    def is_symint_like(self) -> bool:
+        return self.name == BaseTy.SymInt
+
+
+# Optional types may be specified, or may also be validly given None
+@dataclass(frozen=True)
+class OptionalType(Type):
+    elem: Type
+
+    def __str__(self) -> str:
+        return f"{self.elem}?"
+
+    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
+        return self.elem.is_base_ty_like(base_ty)
+
+    def is_symint_like(self) -> bool:
+        return self.elem.is_symint_like()
+
+    def is_nullable(self) -> bool:
+        return True
+
+    def is_list_like(self) -> ListType | None:
+        return self.elem.is_list_like()
+
+
+# A type representing a PyTorch custom class
+@dataclass(frozen=True)
+class CustomClassType(Type):
+    class_name: str
+
+    def __str__(self) -> str:
+        """
+        Return the class name, prefixed with __torch__.torch.classes
+        """
+        return f"__torch__.torch.classes.{self.class_name}"
+
+    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
+        return False
+
+    def is_symint_like(self) -> bool:
+        return False
+
+    def is_nullable(self) -> bool:
+        """
+        Assume a custom class is not nullable.
+        """
+        return False
+
+    def is_list_like(self) -> ListType | None:
+        return None
+
+
+# List types specify that we may have multiples of an element.  We
+# also support explicit sizes on list types, but these have
+# some nontrivial semantics!  (However, for C++ API purposes, explicit
+# sizes are mostly erased from the type system.)
+#
+# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
+# int[] elaborates differently than bool[3]!
+@dataclass(frozen=True)
+class ListType(Type):
+    elem: Type
+    size: int | None
+
+    def __str__(self) -> str:
+        size = f"{self.size}" if self.size else ""
+        return f"{self.elem}[{size}]"
+
+    def is_base_ty_like(self, base_ty: BaseTy) -> bool:
+        return self.elem.is_base_ty_like(base_ty)
+
+    def is_symint_like(self) -> bool:
+        return self.elem.is_symint_like()
+
+    def is_nullable(self) -> bool:
+        return self.elem.is_nullable()
+
+    def is_list_like(self) -> ListType | None:
+        return self
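+
+# Illustrative examples of how Type.parse composes the classes above:
+#
+#   Type.parse("Tensor?[]") == ListType(elem=OptionalType(BaseType(BaseTy.Tensor)), size=None)
+#   Type.parse("int[2]")    == ListType(elem=BaseType(BaseTy.int), size=2)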
+
+
+@dataclass(frozen=True)
+class Argument:
+    # NB: I didn't put kwarg_only as a boolean field here, unlike
+    # c10::Argument, so that printing works correctly
+
+    name: str
+    type: Type
+    default: str | None
+
+    # The semantics of the annotation field are a little strange.
+    #
+    # Alias annotations parametrize Tensors (since Tensors are the only things
+    # that can alias.)  This motivates why I write Tensor(a!)?  (and not, for
+    # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
+    # which may be optional (i.e., the alias annotation should bind first to
+    # Tensor, before the optional postfix annotation).
+    #
+    # However, despite being a property of Tensor, we (and c10::Argument)
+    # store the annotation at the top level of the Argument, rather than
+    # inside the embedded Tensor type.  In the C++ version of this
+    # class, we then go through great lengths to mimic the type
+    # structure in the annotation structure so we can correlate
+    # annotations with types.
+    #
+    # Now, it turns out, in all applications in code generation, the
+    # structure of annotated types is very simple.  So we just hard
+    # code it here.  But if we ever do get anything more complex, this
+    # model will have to change!
+    annotation: Annotation | None
+
+    @property
+    def alias_info(self) -> Annotation | None:
+        return self.annotation
+
+    @staticmethod
+    def parse(arg: str) -> Argument:
+        name: str
+        default: str | None
+        assert " " in arg, f"illegal argument '{arg}'"
+        if "=" in arg:
+            assert arg.count("=") == 1, f"illegal argument with default value: '{arg}'"
+            type_and_annot_and_name, default = arg.split("=")
+            type_and_annot, name = type_and_annot_and_name.rsplit(" ", 1)
+            name_and_default = f"{name}={default}"
+        else:
+            type_and_annot, name_and_default = arg.rsplit(" ", 1)
+            name = name_and_default
+            default = None
+        # TODO: deduplicate annotation matching with Return
+        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
+        annotation: Annotation | None
+        if match:
+            # If you update this, make sure the __str__ still works too
+            assert match.group(2) in [
+                "",
+                "?",
+                "[]",
+            ], "unrecognized alias analysis form with Tensor"
+            type_s = "Tensor" + match.group(2)
+            annotation = Annotation.parse(match.group(1))
+        else:
+            type_s = type_and_annot
+            annotation = None
+        type = Type.parse(type_s)
+        r = Argument(
+            name=name,
+            type=type,
+            default=default,
+            annotation=annotation,
+        )
+        assert str(r) == arg, f"{str(r)} != {arg}"
+        return r
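+
+    # Illustrative examples:
+    #
+    #   Argument.parse("Tensor(a!) self")  # mutable Tensor argument named "self"
+    #   Argument.parse("int dim=0")        # plain int argument with default "0"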
+
+    @property
+    def is_write(self) -> bool:
+        return self.annotation is not None and self.annotation.is_write
+
+    def __str__(self) -> str:
+        type = f"{self.type}"
+        if self.annotation:
+            assert type in ["Tensor", "Tensor?", "Tensor[]"]
+            type = type.replace("Tensor", f"Tensor({self.annotation})")
+        if self.name is None:
+            return type
+        else:
+            mb_default = ""
+            if self.default:
+                mb_default = f"={self.default}"
+            return f"{type} {self.name}{mb_default}"
+
+
+@dataclass(frozen=True)
+class Return:
+    name: str | None
+    type: Type
+    annotation: Annotation | None
+
+    @property
+    def alias_info(self) -> Annotation | None:
+        return self.annotation
+
+    @staticmethod
+    def parse(arg: str) -> Return:
+        name: str | None
+        if " " in arg:
+            type_and_annot, name = arg.rsplit(" ", 1)
+        else:
+            type_and_annot = arg
+            name = None
+        match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
+        annotation: Annotation | None
+        if match:
+            # If you update this, make sure the __str__ still works too
+            assert match.group(2) in [
+                "",
+                "?",
+                "[]",
+            ], "unrecognized alias analysis form with Tensor"
+            type_s = "Tensor" + match.group(2)
+            annotation = Annotation.parse(match.group(1))
+        else:
+            type_s = type_and_annot
+            annotation = None
+        type = Type.parse(type_s)
+        r = Return(
+            name=name,
+            type=type,
+            annotation=annotation,
+        )
+        assert str(r) == arg, f"{str(r)} != {arg}"
+        return r
+
+    @property
+    def is_write(self) -> bool:
+        return self.annotation is not None and self.annotation.is_write
+
+    def __str__(self) -> str:
+        type = f"{self.type}"
+        if self.annotation:
+            assert type in ["Tensor", "Tensor?", "Tensor[]"]
+            type = type.replace("Tensor", f"Tensor({self.annotation})")
+        if self.name is None:
+            return type
+        else:
+            return f"{type} {self.name}"
+
+
+# Represents the self argument for functions that may be methods
+@dataclass(frozen=True)
+class SelfArgument:
+    argument: Argument
+
+
+# Bundle of arguments that represent a TensorOptions.  This is mostly
+# relevant for the public C++ API but we bake it into the core data
+# model because other APIs often have to interact with it
+@dataclass(frozen=True)
+class TensorOptionsArguments:
+    dtype: Argument
+    layout: Argument
+    device: Argument
+    pin_memory: Argument
+
+    def all(self) -> Sequence[Argument]:
+        return [self.dtype, self.layout, self.device, self.pin_memory]
+
+
+@dataclass(frozen=True)
+class Arguments:
+    # pre_self_positional is usually empty, but is notably non-empty
+    # for where.self, where the condition argument comes before the
+    # self argument
+    pre_self_positional: tuple[Argument, ...]
+    self_arg: SelfArgument | None
+    post_self_positional: tuple[Argument, ...]
+
+    pre_tensor_options_kwarg_only: tuple[Argument, ...]
+    tensor_options: TensorOptionsArguments | None
+    # post_tensor_options is typically memory format, which should be
+    # part of tensor options but isn't right now, and is usually
+    # placed after the tensor options arguments
+    post_tensor_options_kwarg_only: tuple[Argument, ...]
+
+    # Unlike in the previous codegen, we have factored out 'out' arguments
+    # in the canonical representation, removing them from kwarg
+    # arguments.  This choice is justified by numerous downstream
+    # transformations which treat out arguments specially; additionally,
+    # you can see that canonicity is not violated!
+    out: tuple[Argument, ...]  # these are also kwarg-only
+
+    @property
+    def flat_non_out(self) -> Sequence[Argument]:
+        ret: list[Argument] = []
+        ret.extend(self.flat_positional)
+        ret.extend(self.flat_kwarg_only)
+        return ret
+
+    @property
+    def flat_positional(self) -> Sequence[Argument]:
+        ret: list[Argument] = []
+        ret.extend(self.pre_self_positional)
+        if self.self_arg is not None:
+            ret.append(self.self_arg.argument)
+        ret.extend(self.post_self_positional)
+        return ret
+
+    @property
+    def post_self_positional_mutable(self) -> Sequence[Argument]:
+        return [a for a in self.post_self_positional if a.is_write]
+
+    # NB: doesn't contain out arguments
+    @property
+    def flat_kwarg_only(self) -> Sequence[Argument]:
+        ret: list[Argument] = []
+        ret.extend(self.pre_tensor_options_kwarg_only)
+        if self.tensor_options is not None:
+            ret.extend(self.tensor_options.all())
+        ret.extend(self.post_tensor_options_kwarg_only)
+        return ret
+
+    @property
+    def flat_all(self) -> Sequence[Argument]:
+        ret: list[Argument] = []
+        ret.extend(self.flat_positional)
+        ret.extend(self.flat_kwarg_only)
+        ret.extend(self.out)
+        return ret
+
+    @property
+    def non_out(
+        self,
+    ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
+        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
+        ret.extend(self.positional)
+        ret.extend(self.kwarg_only)
+        return ret
+
+    @property
+    def positional(self) -> Sequence[Argument | SelfArgument]:
+        ret: list[Argument | SelfArgument] = []
+        ret.extend(self.pre_self_positional)
+        if self.self_arg is not None:
+            ret.append(self.self_arg)
+        ret.extend(self.post_self_positional)
+        return ret
+
+    @property
+    def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
+        ret: list[Argument | TensorOptionsArguments] = []
+        ret.extend(self.pre_tensor_options_kwarg_only)
+        if self.tensor_options is not None:
+            ret.append(self.tensor_options)
+        ret.extend(self.post_tensor_options_kwarg_only)
+        return ret
+
+    @property
+    def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
+        ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
+        ret.extend(self.positional)
+        ret.extend(self.kwarg_only)
+        ret.extend(self.out)
+        return ret
+
+    def mutable_arg_names(self) -> list[str]:
+        return [
+            a.name
+            for a in self.flat_all
+            if a.annotation is not None and a.annotation.is_write
+        ]
+
+    def has_tensor_arg(self) -> bool:
+        return any(a.type.is_tensor_like() for a in self.flat_non_out)
+
+    def has_symint_arg(self) -> bool:
+        return any(a.type.is_symint_like() for a in self.flat_non_out)
+
+    def has_generator_arg(self) -> bool:
+        return any(a.type.is_generator_like() for a in self.flat_non_out)
+
+    def signature(self, *, strip_default: bool = False) -> Arguments:
+        # dataclasses.replace could be used here, but it is less
+        # type safe so for now I've opted to type everything out
+        def strip_arg_annotation(a: Argument) -> Argument:
+            return Argument(
+                name=a.name,
+                type=a.type,
+                default=a.default if not strip_default else None,
+                annotation=None,
+            )
+
+        return Arguments(
+            pre_self_positional=tuple(
+                map(strip_arg_annotation, self.pre_self_positional)
+            ),
+            self_arg=(
+                SelfArgument(strip_arg_annotation(self.self_arg.argument))
+                if self.self_arg is not None
+                else None
+            ),
+            post_self_positional=tuple(
+                map(strip_arg_annotation, self.post_self_positional)
+            ),
+            # Since TensorOptions are dropped, the post_tensor_options_kwargs are
+            # converted to pre_tensor_options_kwargs
+            pre_tensor_options_kwarg_only=tuple(
+                map(strip_arg_annotation, self.pre_tensor_options_kwarg_only)
+            )
+            + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)),
+            # TensorOptions are dropped in signature,
+            # so we can pair factory functions with their out= variants.
+            tensor_options=None,
+            post_tensor_options_kwarg_only=(),
+            # out arguments are dropped in signature
+            out=(),
+        )
+
+    def remove_self_annotation(self) -> Arguments:
+        assert self.self_arg is not None
+        return dataclasses.replace(
+            self,
+            self_arg=SelfArgument(
+                dataclasses.replace(self.self_arg.argument, annotation=None)
+            ),
+        )
+
+    def with_out_args(self, outs: list[Argument]) -> Arguments:
+        assert len(self.out) == 0
+        return dataclasses.replace(
+            self,
+            out=tuple(outs),
+        )
+
+    @staticmethod
+    def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
+        positional: list[Argument] = []
+        kwarg_only: list[Argument] = []
+        out: list[Argument] = []
+        arguments_acc = positional
+
+        # TODO: Use a real parser here; this will get bamboozled
+        # by signatures that contain things like std::array<bool, 2> (note the space)
+        for arg in args.split(", "):
+            if not arg:
+                continue
+            if arg == "*":
+                assert arguments_acc is positional, (
+                    "invalid syntax: kwarg-only specifier * can only occur once"
+                )
+                arguments_acc = kwarg_only
+                continue
+            parg = Argument.parse(arg)
+            # Currently, we rely directly on the invariant that there are NO
+            # kwarg-only mutating arguments.  If you want to relax this,
+            # we will need a more semantic way of matching that takes
+            # into account return arguments.  In that case, you will have
+            # to manage out computation a level up, in FunctionSchema.  See Note
+            # [is_out_fn]
+            if parg.annotation is not None and parg.annotation.is_write:
+                if arguments_acc is positional:
+                    pass  # do nothing
+                elif arguments_acc is kwarg_only:
+                    arguments_acc = out
+            else:
+                assert arguments_acc is not out
+            arguments_acc.append(parg)
+
+        return positional, kwarg_only, out
+
+    @staticmethod
+    def parse(args: str) -> Arguments:
+        """
+        Input: 'int x, int y, int z'
+        """
+
+        # We do this in two phases.  First we parse into three
+        # main categories: positional, kwarg_only, out.
+        # Then, we reparse positional and kwarg_only to separate
+        # out the self argument and tensor options arguments.
+
+        positional, kwarg_only, out = Arguments._preparse(args)
+
+        # Split self argument
+        self_ix = None
+        for i, a in enumerate(positional):
+            if a.name == "self":
+                self_ix = i
+                break
+        pre_self_positional: list[Argument]
+        self_arg: SelfArgument | None
+        post_self_positional: list[Argument]
+        if self_ix is not None:
+            pre_self_positional = positional[:self_ix]
+            self_arg = SelfArgument(positional[self_ix])
+            post_self_positional = positional[self_ix + 1 :]
+        else:
+            pre_self_positional = []
+            self_arg = None
+            post_self_positional = positional
+
+        # Group tensor options arguments
+        pre_tensor_options_kwarg_only: list[Argument] = []
+        tensor_options: TensorOptionsArguments | None = None
+        post_tensor_options_kwarg_only: list[Argument] = []
+        kwarg_only_acc = pre_tensor_options_kwarg_only
+
+        def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
+            return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
+
+        predicates = [  # order matters
+            pred("dtype", Type.parse("ScalarType")),
+            pred("layout", Type.parse("Layout")),
+            pred("device", Type.parse("Device")),
+            pred("pin_memory", Type.parse("bool")),
+        ]
+
+        i = 0
+        while i < len(kwarg_only):
+            # If there is enough space...
+            if i <= len(kwarg_only) - len(predicates):
+                # And the next len(predicates) arguments look like TensorOptions arguments
+                if all(
+                    p(a)
+                    for p, a in zip(predicates, kwarg_only[i : i + len(predicates)])
+                ):
+                    assert kwarg_only_acc is pre_tensor_options_kwarg_only
+                    # Group them together as one argument
+                    tensor_options = TensorOptionsArguments(
+                        dtype=kwarg_only[i],
+                        layout=kwarg_only[i + 1],
+                        device=kwarg_only[i + 2],
+                        pin_memory=kwarg_only[i + 3],
+                    )
+                    i += len(predicates)
+                    kwarg_only_acc = post_tensor_options_kwarg_only
+                    continue
+            kwarg_only_acc.append(kwarg_only[i])
+            i += 1
+
+        return Arguments(
+            pre_self_positional=tuple(pre_self_positional),
+            self_arg=self_arg,
+            post_self_positional=tuple(post_self_positional),
+            pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only),
+            tensor_options=tensor_options,
+            post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only),
+            out=tuple(out),
+        )
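+
+    # Illustrative example: parsing
+    #
+    #   a = Arguments.parse("Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out")
+    #
+    # buckets "self" into self_arg, "other" into post_self_positional,
+    # "alpha" into the kwarg-only arguments, and the trailing mutable
+    # kwarg "out" into the out arguments.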
+
+    def __str__(self) -> str:
+        all_arguments: list[str] = []
+        all_arguments.extend(map(str, self.flat_positional))
+        if self.flat_kwarg_only or self.out:
+            all_arguments.append("*")
+        all_arguments.extend(map(str, self.flat_kwarg_only))
+        all_arguments.extend(map(str, self.out))
+        return ", ".join(all_arguments)
+
+    def __post_init__(self) -> None:
+        # TODO: These invariants are weirdly asymmetric?
+        # TODO: Fancier types?
+        if self.self_arg is None:
+            assert not self.pre_self_positional
+        if self.tensor_options is None:
+            assert not self.post_tensor_options_kwarg_only
+
+        # We don't allow any of the following to have argument annotations,
+        # to keep things simple.
+        mutable_pre_self_positionals = [
+            a
+            for a in self.pre_self_positional
+            if a.annotation is not None and a.annotation.is_write
+        ]
+        assert len(mutable_pre_self_positionals) == 0, (
+            "mutable pre_self_positional arguments are not currently supported in the schema"
+        )
+
+
+# Names that validly are __iXXX__ indicating inplace operations.
+# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
+# NB: PyTorch hasn't actually implemented all of these
+AUGMENTED_ASSIGNMENT_NAMES = [
+    "add",
+    "sub",
+    "mul",
+    "div",
+    "mod",
+    "pow",
+    "lshift",
+    "rshift",
+    "and",
+    "xor",
+    "or",
+]
+
+
+# A BaseOperatorName is what we think of the operator name, without
+# the overload name.  Unusually, we don't represent this as just a
+# string; instead, we directly represent a few important semantic
+# bits of information we derive from the string: namely whether
+# or not it's inplace (add_) and whether or not it's a double-underscore
+# method (__add__)
+@dataclass(frozen=True)
+class BaseOperatorName:
+    base: str
+    inplace: bool
+    dunder_method: bool
+    # Note [Overload Ambiguity With Functional Variants]
+    # A handful of operators have both a "mutable" and a "functional" variant.
+    # (native_batch_norm is a good example, although this isn't the case today).
+    # For those operators, the mutable and functional variant take in the same set of
+    # arguments, but have different alias annotations.
+    # this makes it ambiguous when you try to resolve an OverloadPacket into an overload,
+    # given a set of input arguments.
+    #
+    # So instead of making the "functional" variant in this case a real overload, e.g:
+    #   native_batch_norm (mutable variant)
+    #   native_batch_norm.functional (functional variant)
+    # we make it a new base operator,
+    #   native_batch_norm_functional (functional variant)
+    #
+    # In an ideal world, we would probably invert this so the operators were:
+    #   native_batch_norm.mutable (mutable variant)
+    #   native_batch_norm (functional variant)
+    #
+    # Doing that is BC-breaking though, so we're stuck with the above modeling.
+    functional_overload: bool = False
+
+    # NB: We don't officially support namespace in FunctionSchema, we treat this prefix
+    # as part of the base operator name, for __str__() to consume.
+    # The canonical input (from the rest of the infra) will not contain namespace, but
+    # we have a usecase in ExecuTorch where we want to support BaseOperatorName with namespace.
+    namespace: Optional[str] = None
+
+    @staticmethod
+    def parse(op: str) -> BaseOperatorName:
+        assert op != ""
+        assert not op.endswith("_out"), (
+            "_out suffix is reserved and not permitted for operator names; "
+            "did you mean to specify an out overload name instead?"
+        )
+        # Extract namespace out. Base operator name may or may not contain namespace.
+        # E.g., aten::__lshift__ is a valid base operator name, __lshift__ is also valid.
+        # We want to split the namespace out from the base operator name.
+        match = re.match(r"^(?:(.*)::)?(.*)$", op)
+        namespace = match.group(1) if match else ""
+        op_without_ns = match.group(2) if match else op
+        m = re.match(r"^__([^_]+)__$", op_without_ns)
+        if m is not None:
+            dunder_method = True
+            base = m.group(1)
+            if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES):
+                inplace = True
+                base = base[1:]
+            else:
+                inplace = False
+                # temporary, this is not intrinsically true but
+                # has been historically true for dunder methods
+                # we support  (but, if we ever got, say, __int__, this would
+                # be wrong!)
+                assert base[0] != "i"
+        else:
+            dunder_method = False
+            base = op_without_ns
+            if base[-1] == "_":
+                inplace = True
+                base = base[:-1]
+            else:
+                inplace = False
+
+        # See Note [Overload Ambiguity With Functional Variants]
+        functional_suffix = "_functional"
+        if base.endswith(functional_suffix):
+            functional_overload = True
+            base = base[: -len(functional_suffix)]
+            # This seems complicated and unnecessary, so banning dunder methods
+            # for now on ops that have a functional + mutable variant (like native_batch_norm).
+            assert not dunder_method and not inplace
+        else:
+            functional_overload = False
+
+        r = BaseOperatorName(
+            base=base,
+            inplace=inplace,
+            dunder_method=dunder_method,
+            functional_overload=functional_overload,
+            namespace=namespace,
+        )
+        assert str(r) == op, f"{str(r)} != {op}"
+        return r
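+
+    # Illustrative examples:
+    #
+    #   BaseOperatorName.parse("add_")              # base == "add", inplace == True
+    #   BaseOperatorName.parse("__iadd__")          # base == "add", inplace == True, dunder_method == True
+    #   BaseOperatorName.parse("aten::__lshift__")  # namespace == "aten", base == "lshift", dunder_method == True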
+
+    def __str__(self) -> str:
+        namespace_prefix = f"{self.namespace}::" if self.namespace else ""
+        if self.dunder_method:
+            i = "i" if self.inplace else ""
+            return f"{namespace_prefix}__{i}{self.base}__"
+        else:
+            i = (
+                "_"
+                if self.inplace
+                else "_functional"
+                if self.functional_overload
+                else ""
+            )
+            return f"{namespace_prefix}{self.base}{i}"
+
+
+# Operator name is the base operator name along with the (typically not
+# user visible) overload string.
+@dataclass(frozen=True)
+class OperatorName:
+    name: BaseOperatorName
+    overload_name: str
+
+    @staticmethod
+    def parse(op_name: str) -> OperatorName:
+        if "." in op_name:
+            name, overload_name = op_name.split(".", 1)
+        else:
+            name = op_name
+            overload_name = ""
+        r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name)
+        assert str(r) == op_name, f"{str(r)} != {op_name}"
+        return r
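+
+    # Illustrative example:
+    #
+    #   n = OperatorName.parse("add.Tensor")
+    #   # str(n) == "add.Tensor", n.overload_name == "Tensor"
+    #   # n.unambiguous_name() == "add_Tensor"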
+
+    def __str__(self) -> str:
+        if self.overload_name:
+            return f"{self.name}.{self.overload_name}"
+        else:
+            return f"{self.name}"
+
+    # NB: This must be synchronized with the naming scheme in
+    # aten/src/ATen/templates/Operators.h
+    # Given a function schema "aten::op.overload(...)",
+    # If there is no overload name, this returns f"{op}"
+    # If there is an overload name, this returns f"{op}_{overload}"
+    def unambiguous_name(self) -> str:
+        if self.overload_name:
+            return f"{self.name}_{self.overload_name}"
+        else:
+            return f"{self.name}"
+
+    def remove_inplace(self) -> OperatorName:
+        return OperatorName(
+            name=BaseOperatorName(
+                base=self.name.base,
+                inplace=False,
+                dunder_method=self.name.dunder_method,
+            ),
+            overload_name=self.overload_name,
+        )
+
+    def with_overload(self, overload: str) -> OperatorName:
+        return OperatorName(
+            name=BaseOperatorName(
+                base=self.name.base,
+                inplace=False,
+                dunder_method=self.name.dunder_method,
+            ),
+            overload_name=overload,
+        )
+
+
+def gets_generated_out_inplace_wrapper(
+    f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex
+) -> bool:
+    return (
+        f.func.kind() is not SchemaKind.functional
+        and not b.has_kernel(f)
+        and b.has_kernel(g.functional)
+    )
+
+
+# NativeFunction objects that are views (f.is_view_op returns True)
+# are added into a `NativeFunctionsViewGroup`, which we can use to
+# easily access the generated (optional) view_copy NativeFunction.
+# It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup.
+# See Note [Codegen'd {view}_copy Operators]
+#
+# One property of this representation is that in order for a view-like op to be part of
+# a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist.
+# There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op,
+# but don't have corresponding aliasing `narrow.out` op.
+# This means that `narrow_copy.out` won't appear as a NativeFunctionsViewGroup.
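+#
+# For illustration (hedged; the overload names are quoted from memory of
+# native_functions.yaml and may be slightly off), a typical group is:
+#   view = slice.Tensor, view_copy = slice_copy.Tensor (tagged "view_copy"), view_inplace = None.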
+@dataclass(frozen=True)
+class NativeFunctionsViewGroup:
+    view: NativeFunction
+    # Note: the {view}_copy operator is optional because we currently don't generate copy variants
+    # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
+    # (we already get them "for free" through decomposition)
+    view_copy: NativeFunction | None
+    # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
+    view_inplace: NativeFunction | None
+
+    def __post_init__(self) -> None:
+        assert self.view.is_view_op
+        if self.view_copy is None:
+            assert not gets_generated_view_copy(self.view), (
+                f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs."
+                " The codegen expects you to add a corresponding operator to native_functions.yaml:"
+                f" {get_view_copy_name(self.view)!s}."
+                " See Note [view_copy NativeFunctions] for details."
+            )
+        else:
+            assert self.view_copy.func.name.name.base.endswith(("_copy", "_scatter"))
+            assert self.view.func.signature() == self.view_copy.func.signature(
+                strip_view_copy_name=True,
+            )
+            assert "view_copy" in self.view_copy.tags, (
+                f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects"
+                " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml."
+                " See Note [view_copy NativeFunction] for details."
+            )
+        if self.view_inplace is not None:
+            assert self.view.func.signature() == self.view_inplace.func.signature()
+
+        if self.view.has_composite_implicit_autograd_kernel:
+            if self.view_inplace is not None:
+                assert self.view_inplace.has_composite_implicit_autograd_kernel, (
+                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
+                    " both have CompositeImplicitAutograd kernels, or both not have composite kernels."
+                )
+        if self.view.has_composite_implicit_autograd_nested_tensor_kernel:
+            if self.view_inplace is not None:
+                assert self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel, (
+                    f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either"
+                    " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels."
+                )
+
+    def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]:
+        yield self.view
+        if self.view_inplace is not None:
+            yield self.view_inplace
+        if self.view_copy is not None and include_copy:
+            yield self.view_copy
+
+    @property
+    def root_name(self) -> str:
+        return self.view.root_name
+
+    @property
+    def composite(self) -> bool:
+        # We currently assert that the "group" is consistent.
+        # If the view op is composite, then its view_inplace op is too.
+        return self.view.has_composite_implicit_autograd_kernel
+
+
+def gets_generated_view_copy(f: NativeFunction) -> bool:
+    # Only aliasing (view) operators get a copy variant.
+    if not f.is_view_op:
+        return False
+    # We don't need to bother generating copy variants for CompositeImplicitAutograd ops,
+    # because we can let them decompose into base view ops.
+    if f.has_composite_implicit_autograd_kernel:
+        return False
+    # We also don't need to generate copy variants for inplace views.
+    if "inplace_view" in f.tags:
+        return False
+    # Assume ops ending in _inverse have manually-defined copy variants
+    # (e.g. slice_inverse() has the copy variant slice_scatter()).
+    # We -could- probably generate these as well, but the codegen will be
+    # slightly different, and hand-writing these few kernels keeps codegen
+    # complexity lower.
+    if f.func.name.name.base.endswith("_inverse"):
+        return False
+    return True
+
+
+# Given a NativeFunction that corresponds to a view op,
+# returns the OperatorName of the corresponding "copy" variant of the op.
+def get_view_copy_name(f: NativeFunction) -> OperatorName:
+    # Right now, when asking for a view op's corresponding "view_copy" name
+    # we assert for sanity that the op is allowed to have a generated view_copy variant.
+    # (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op).
+    # However, narrow_copy() already exists as an op directly in native_functions.yaml.
+    # I'm hardcoding narrow_copy here for now to maintain the assert,
+    # but we could also just get rid of the assert.
+    list_of_ops_with_explicit_view_copy_operators = ["narrow"]
+    if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators:
+        assert gets_generated_view_copy(f)
+
+    base_name = f"{f.func.name.name.base}_copy"
+    view_copy_name = OperatorName(
+        name=BaseOperatorName(
+            base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method
+        ),
+        overload_name=f.func.name.overload_name,
+    )
+    return view_copy_name
+
+
+# Helper functions for parsing argument lists (both inputs and returns)
+
+
+def parse_returns(return_decl: str) -> tuple[Return, ...]:
+    """
+    Input: '()'
+    Output: ()
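+    Input: '(Tensor, Tensor)'
+    Output: a tuple of two parsed Return objects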
+    """
+    if return_decl == "()":
+        return ()
+    if return_decl[0] == "(" and return_decl[-1] == ")":
+        return_decl = return_decl[1:-1]
+    return tuple(Return.parse(arg) for arg in return_decl.split(", "))
+
+
+# A Precompute instance consists of a map from kernel argument name
+# to the list of Argument instances that should replace that
+# kernel argument in the impl function.
+@dataclass(frozen=True)
+class Precompute:
+    # A map from kernel argument name -> a list of precomputed
+    # elements that replaces/supersedes it.
+    replace: dict[str, list[Argument]]
+    # List of precomputed args added without replacement
+    add: list[Argument]
+
+    @staticmethod
+    def parse(src: object) -> Precompute:
+        assert isinstance(src, list)
+
+        # src is a list of strings of the format:
+        #   {kernel param name} -> {replacement decl}[, {replacement decl}, ...]
+        #   [{add decl}[, {add decl}, ...]]
+        # The last line is optional and contains the precomputed parameters that are
+        # added without replacement.
+        # The other lines are parsed to get the names of which precomputed elements
+        # should replace which kernel arguments.
+        add_args = []
+        if " -> " not in src[-1]:
+            add_list = src[-1].split(",")
+            add_args = [Argument.parse(name.strip()) for name in add_list]
+            src = src[:-1]
+
+        replace = {}
+        for raw_replace_item in src:
+            assert isinstance(raw_replace_item, str)
+            assert " -> " in raw_replace_item, (
+                "precomputed parameters without replacement"
+                " are allowed only in the last line"
+            )
+
+            arg, with_list_raw = raw_replace_item.split(" -> ")
+            assert " " not in arg, (
+                f"illegal kernel param name '{arg}' in precomputed parameters'"
+            )
+            with_list = with_list_raw.split(",")
+            with_list_args = [Argument.parse(name.strip()) for name in with_list]
+            replace[arg] = with_list_args
+
+        r = Precompute(replace=replace, add=add_args)
+        assert r.to_list() == src, "r.to_list() != src"
+        return r
+
+    def __post_init__(self) -> None:
+        # The precomputed template parameters are spelled in upper case, so a
+        # name that equals its upper-cased form would be ambiguous.
+        for a in self.add:
+            assert a.name.upper() != a.name
+        for args in self.replace.values():
+            for a in args:
+                assert a.name.upper() != a.name
+
+    def to_list(self) -> list[str]:
+        replace_list = []
+        for kernel_param, replacement_params in self.replace.items():
+            replacements = ", ".join(str(param) for param in replacement_params)
+            replace_list.append(f"{kernel_param} -> {replacements}")
+
+        return replace_list
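+
+
+# Illustrative `precomputed` block in the format Precompute.parse expects
+# (hypothetical argument names; real entries live in native_functions.yaml):
+#
+#   kernel_size -> int kH, int kW
+#   stride -> int dH, int dW
+#   int numel
+#
+# parses to replace = {"kernel_size": [kH, kW], "stride": [dH, dW]} and add = [numel].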
diff --git a/phivenv/Lib/site-packages/torchgen/native_function_generation.py b/phivenv/Lib/site-packages/torchgen/native_function_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c2ec2e7dc47d9b507bb44be87db3b845f355898
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/native_function_generation.py
@@ -0,0 +1,651 @@
+from __future__ import annotations
+
+import string
+from collections import defaultdict
+from typing import TYPE_CHECKING
+
+import torchgen.api.dispatcher as dispatcher
+from torchgen.api.translate import translate
+from torchgen.api.types import Binding, DispatcherSignature, Expr
+from torchgen.context import with_native_function
+from torchgen.model import (
+    Annotation,
+    Argument,
+    BackendIndex,
+    BackendMetadata,
+    BaseOperatorName,
+    BaseTy,
+    BaseType,
+    DEFAULT_KERNEL_NAMESPACE,
+    DeviceCheckType,
+    DispatchKey,
+    FunctionSchema,
+    NativeFunction,
+    NativeFunctionsGroup,
+    OperatorName,
+    Return,
+    SchemaKind,
+    Variant,
+)
+from torchgen.utils import concatMap
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# See Note: [Out ops with functional variants that don't get grouped properly]
+OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
+    # This has a functional variant, but it's currently marked private.
+    # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
+    "adaptive_avg_pool3d_backward.grad_input",
+    # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
+    # Maybe we can kill this operator in favor of convolution_backward?
+    "_slow_conv2d_backward.grad_input",
+]
+
+
+# See Note: [Mutable ops that cannot get an out variant]
+MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
+    # should be out=?
+    "_cummax_helper",
+    # should be out=?
+    "_cummin_helper",
+]
+
+# All of these operators don't have any tensor like returns
+FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
+    "_assert_async",  # no return
+    "_assert_async.msg",  # no return
+    "_assert_tensor_metadata",  # no return
+    "_cslt_sparse_mm_search",  # returns an int
+    "_assert_scalar",  # no return
+    "_dimI",  # returns an int
+    "_dimV",  # returns an int
+    "_has_same_storage_numel",  # returns a boolean
+    "_linalg_check_errors",  # no return
+    "_local_scalar_dense",  # returns a Scalar
+    "_nested_tensor_from_mask_left_aligned",  # returns a boolean
+    "_nnz",  # returns an int
+    "_use_cudnn_ctc_loss",  # returns a boolean
+    "_use_cudnn_ctc_loss.Tensor",  # returns a boolean
+    "_validate_compressed_sparse_indices",  # no return
+    "allclose",  # returns a boolean
+    "dense_dim",  # returns an int
+    "equal",  # returns a boolean
+    "is_coalesced",  # returns an boolean
+    "is_pinned",  # returns a boolean
+    "is_same_size",  # returns a boolean
+    "is_set_to",  # returns a boolean
+    "q_per_channel_axis",  # returns an int
+    "q_scale",  # returns a float
+    "q_zero_point",  # returns an int
+    "qscheme",  # returns a QScheme
+    "record_stream",  # no return
+    "sparse_dim",  # returns an int
+    "sym_constrain_range",  # no return
+    "sym_constrain_range_for_size",  # no return
+    "_nested_tensor_storage_offsets",  # returns a vector of ints
+    "_chunk_grad_outputs_efficient_attention",  # returns a bool
+    "_fused_sdp_choice",  # returns an int
+    "_print",  # no return
+    "_sink_tokens",  # no return
+    "_nested_get_ragged_idx",  # returns an int
+]
+
+INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
+    # polygamma and polygamma.out both exist, but have a
+    # pre-self arg (while polygamma_ does not)
+    # We should either fix this schema so it can be grouped properly,
+    # or allow the codegen to generate new functional/out= NativeFunctions for this op
+    # (which would require changing its overload name to prevent overload ambiguity).
+    "polygamma_"
+]
+
+
+# Groups "similar" NativeFunctions together
+# e.g. add.Tensor, add_.Tensor, add.out.
+# "Similar" NativeFunctions are all expected to have an identical `signature()`,
+# but differing SchemaKinds.
+def pre_group_native_functions(
+    native_functions: Sequence[NativeFunction],
+) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
+    pre_grouped_native_functions: dict[
+        FunctionSchema, dict[SchemaKind, NativeFunction]
+    ] = defaultdict(dict)
+    for f in native_functions:
+        d = pre_grouped_native_functions[f.func.signature()]
+        assert f.func.kind() not in d
+        d[f.func.kind()] = f
+    return pre_grouped_native_functions
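+
+# A hedged illustration of the grouping: add.Tensor, add_.Tensor and add.out all
+# share a signature() and land in one bucket, keyed by SchemaKind.functional,
+# SchemaKind.inplace and SchemaKind.out respectively.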
+
+
+# Returns the out variant overload name given a base function overload name
+def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
+    return "out" if not overload_name else f"{overload_name}_out"
+
+
+# Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
+# Example before:
+#   _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+# Example after:
+#   _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
+def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
+    # Generating an out= schema from an inplace schema.
+    assert func.kind() == SchemaKind.inplace
+    assert func.arguments.self_arg is not None
+    # The new out= schema has:
+    # - a new out argument with the same type as "func" (but with a mutable annotation)
+    # - The returns (if any) now alias the out= argument instead of "func"
+    # - an "out" overload name
+    return FunctionSchema(
+        name=func.name.remove_inplace().with_overload(
+            get_expected_out_variant_overload_name(func.name.overload_name)
+        ),
+        arguments=func.arguments.remove_self_annotation().with_out_args(
+            [
+                Argument(
+                    name="out",
+                    type=func.arguments.self_arg.argument.type,
+                    default=None,
+                    annotation=func.arguments.self_arg.argument.annotation,
+                )
+            ]
+        ),
+        returns=func.returns,
+    )
+
+
+# Helper function: given a functional FunctionSchema, generate its corresponding out= variant
+# Example before:
+#   _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+# Example after:
+#   _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None,
+#       Tensor(a!) out) -> Tensor(a!)
+def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
+    # Generating an out= schema from a functional schema.
+    assert func.kind() == SchemaKind.functional
+
+    new_returns, new_out_args = generate_out_args_from_schema(func)
+    # The new out= schema has:
+    # - one or more new out argument(s) with the same type as returns (but with a mutable annotation)
+    # - The returns now alias the out= arguments
+    # - an "_out" overload name
+    return FunctionSchema(
+        name=func.name.with_overload(
+            get_expected_out_variant_overload_name(func.name.overload_name)
+        ),
+        arguments=func.arguments.signature().with_out_args(
+            new_out_args,
+        ),
+        returns=tuple(new_returns),
+    )
+
+
+# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
+def generate_out_args_from_schema(
+    func: FunctionSchema,
+) -> tuple[list[Return], list[Argument]]:
+    # More of a sanity check - our existing restrictions on schemas should enforce that
+    # mutable schema kinds never return their mutable arguments.
+    assert not any(
+        r.annotation is not None and r.annotation.is_write for r in func.returns
+    )
+
+    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
+    assert len(tensorlike_rets) > 0
+
+    used_annotations = concatMap(
+        lambda a: [] if a.annotation is None else a.annotation.alias_set,
+        func.arguments.flat_all,
+    )
+    valid_annotations = [x for x in string.ascii_lowercase if x not in used_annotations]
+
+    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
+
+    new_out_args: list[Argument] = []
+    # The end result of new_returns is that:
+    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
+    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
+    new_returns: list[Return] = []
+    for i, r in enumerate(func.returns):
+        if r.type.is_tensor_like():
+            new_out = Argument(
+                name="out" if len(func.returns) == 1 else f"out{i}",
+                type=r.type,
+                default=None,
+                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
+            )
+            new_out_args.append(new_out)
+            if all_rets_are_tensors:
+                # The convention for out= schemas is that they only return their out arguments
+                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
+                new_ret = Return(
+                    name=None, type=new_out.type, annotation=new_out.annotation
+                )
+                new_returns.append(new_ret)
+        else:
+            new_returns.append(r)
+    return new_returns, new_out_args
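+
+# Worked example for generate_out_args_from_schema (hedged, hypothetical schema):
+# for a functional op returning `(Tensor, int)`, only the Tensor return is
+# tensor-like, so this produces
+#   new_out_args = [Tensor(a!) out0]
+#   new_returns  = [int]    # the out arg is not re-returned, since not every
+#                           # return was a plain Tensor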
+
+
+# Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
+# Example before:
+#   _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)  # noqa: B950
+# Example after:
+#   _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))  # noqa: B950
+def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
+    # Generating an out= schema from a mutable schema.
+    assert func.kind() == SchemaKind.mutable
+    # The new out= schema has:
+    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
+    #   (if the argument is a tensor then we also return it for method chaining,
+    #   otherwise we return nothing)
+    # - an "out" overload name
+    #
+    # Note that:
+    # (1) This also means that we can *only* generate an out= variant from a mutable schema
+    #     if the mutable schema has at least one tensor-like non-aliasing return.
+    # (2) The generated out= variant still has mutable positional arguments,
+    #     but if necessary we could probably add another out= variant that also
+    #     functionalizes the mutable arguments (a functional_out variant)
+
+    new_returns, new_out_args = generate_out_args_from_schema(func)
+
+    return FunctionSchema(
+        name=func.name.remove_inplace().with_overload(
+            get_expected_out_variant_overload_name(func.name.overload_name)
+        ),
+        arguments=func.arguments.with_out_args(new_out_args),
+        returns=tuple(new_returns),
+    )
+
+
+# This function, given function of one SchemaKind, as well as a target SchemaKind,
+# generates a new NativeFunction with the same properties, but using the target SchemaKind.
+# We only actually generate functions for either functional or out= SchemaKinds.
+# This function returns a tuple, with:
+# - The generated NativeFunction
+# - a dictionary of `BackendIndex` objects, describing which dispatch keys
+#   we will generate kernels for, for the new NativeFunction.
+#   Details are in the function, but we only generate composite kernels (in some cases) today.
+def generate_function(
+    f: NativeFunction, k: SchemaKind
+) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
+    from torchgen.api import cpp
+
+    if k == SchemaKind.functional:
+        assert f.func.kind() != SchemaKind.functional
+        # The new "functional" NativeFunction has:
+        # - any mutable arguments have been converted into (immutable) returns.
+        #   (if a mutable argument was not also a return, it gets converted to one)
+        # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
+        #   See Note [Overload Ambiguity With Functional Variants]
+        # The default grouping logic in signature() actually already does this,
+        # so we can piggy-back off it (but we still want return names)
+        func = f.func.signature(keep_return_names=True).with_name(
+            OperatorName(
+                name=BaseOperatorName(
+                    base=f.func.name.name.base,
+                    inplace=False,
+                    dunder_method=f.func.name.name.dunder_method,
+                    # See Note [Overload Ambiguity With Functional Variants]
+                    functional_overload=f.func.kind() == SchemaKind.mutable,
+                ),
+                overload_name=f.func.name.overload_name,
+            )
+        )
+    elif k == SchemaKind.out:
+        # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
+        # but at least today, there is no good reason to actually use them.
+        # We'll generate a dispatcher entry for them, but won't actually register any kernels for them.
+        if f.func.kind() == SchemaKind.inplace:
+            func = self_to_out_signature(f.func)
+        elif f.func.kind() == SchemaKind.mutable:
+            func = mutable_to_out_signature(f.func)
+        elif f.func.kind() == SchemaKind.functional:
+            func = functional_to_out_signature(f.func)
+        else:
+            raise AssertionError(
+                "We only bother generating out= functions from either inplace or mutable or functional variants"
+            )
+    else:
+        raise AssertionError(
+            "We currently only generate either functional or out= NativeFunctions"
+        )
+
+    # Generated kernel naming convention for out ops: {op_name}_{overload_name}. The reason for this is to
+    # disambiguate operators with the same name but different overload names, e.g., `randn.names_out` and
+    # `randn.generator_with_names_out`.
+    kernel_name = (
+        func.name.unambiguous_name()
+        if func.kind() == SchemaKind.out
+        else cpp.name(func)
+    )
+    if f.func.has_symint():
+        kernel_name += "_symint"
+    backend_metadata = {
+        DispatchKey.CompositeExplicitAutograd: {
+            func.name: BackendMetadata(
+                kernel=kernel_name,
+                structured=False,
+                cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
+            )
+        }
+    }
+    tags = {"generated"} | set(
+        f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
+    )
+
+    return (
+        NativeFunction(
+            func=func,
+            use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
+            # These generated fns aren't meant to be user-friendly; don't generate methods.
+            variants={Variant.function},
+            structured=False,
+            structured_delegate=None,
+            structured_inherits=None,
+            precomputed=None,
+            autogen=[],
+            ufunc_inner_loop={},
+            manual_kernel_registration=False,
+            manual_cpp_binding=False,
+            python_module=None,
+            category_override=None,
+            device_guard=False,
+            device_check=DeviceCheckType.NoCheck,
+            loc=f.loc,
+            cpp_no_default_args=set(),
+            is_abstract=f.is_abstract,
+            has_composite_implicit_autograd_kernel=False,
+            has_composite_implicit_autograd_nested_tensor_kernel=False,
+            has_composite_explicit_autograd_kernel=True,
+            has_composite_explicit_autograd_non_functional_kernel=False,
+            # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
+            # which NativeFunction objects did not come directly from native_functions.yaml.
+            tags=tags,
+            namespace=f.namespace,
+        ),
+        backend_metadata,
+    )
+
+
+# This function is responsible for adding generated NativeFunctions which don't appear
+# explicitly in the codegen.
+# You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
+# torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
+# (Maybe we should make a friendly API for this)
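+# A minimal sketch of that inspection (hedged: assumes parse_native_yaml is
+# importable from torchgen.gen and returns an object exposing .native_functions):
+#
+#   from torchgen.gen import parse_native_yaml
+#   parsed = parse_native_yaml(
+#       "aten/src/ATen/native/native_functions.yaml",
+#       "aten/src/ATen/native/tags.yaml",
+#   )
+#   generated = [f for f in parsed.native_functions if "generated" in f.tags]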
+#
+# Note: this function *mutates* its two inputs,
+# adding the new NativeFunctions / BackendMetadata to them
+def add_generated_native_functions(
+    rs: list[NativeFunction],
+    indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+) -> None:
+    # The main code for generating new NativeFunctions
+    # First we group NativeFunctions by schema kind,
+    # then we detect which ones are missing and generate them.
+    pre_grouped_native_functions = pre_group_native_functions(rs)
+    for d in pre_grouped_native_functions.values():
+        has_functional = SchemaKind.functional in d
+        has_inplace = SchemaKind.inplace in d
+        has_mutable = SchemaKind.mutable in d
+        has_out = SchemaKind.out in d
+        is_core = any("core" in variant.tags for variant in d.values())
+
+        # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
+        # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
+        #     a simple functional variant that the functionalization pass can consume.
+        # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
+        #     variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
+        #     while maintaining the constraint that the out= variant is "required".
+        if has_mutable or has_inplace or has_out or has_functional:
+            # Don't bother generating function trios for native functions that bypass the dispatcher.
+            are_manual = all(f.manual_cpp_binding for f in d.values())
+            # Don't bother generating functional + out= variants for view operators
+            # set_ is technically an inplace_view, but for now it is treated
+            # as a normal inplace op in the codegen
+            has_view_ops = any(
+                f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
+            )
+            # Don't generate the other variants for non-core CompositeImplicitAutograd operators.
+            # We could probably do this, but the main benefit of generating the function triplets
+            # is for transforms that need them, and transforms don't need to act directly
+            # on CompositeImplicitAutograd operators (since we let them decompose).
+            are_composite_implicit = all(
+                f.has_composite_implicit_autograd_kernel for f in d.values()
+            )
+            if are_manual or has_view_ops or are_composite_implicit and not is_core:
+                continue
+            if has_out and len(d.values()) == 1:
+                # Note: [Out ops with functional variants that don't get grouped properly]
+                # In theory we could validly have an out= operator in native_functions.yaml
+                # that has no other variants.
+                # But today, all of the operators where that's the case actually do have
+                # functional variants, that we are just unable to pair up properly.
+                # I think banning this altogether is probably safer
+                # (you can always add a functional variant yourself if you want to add a new out= operator).
+                #
+                # We should probably fix the existing cases; this check is to prevent us from adding more over time.
+                if (
+                    str(d[SchemaKind.out].func.name)
+                    not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+                ):
+                    raise AssertionError(
+                        f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
+                    )
+                continue
+
+            # Some inplace ops that have problematic schemas (that we should fix), which prevent us
+            # from generating out= and functional variants
+            if (
+                has_inplace
+                and str(d[SchemaKind.inplace].func.name)
+                in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
+            ):
+                continue
+
+            base_fn = (
+                d[SchemaKind.mutable]
+                if has_mutable
+                else d[SchemaKind.inplace]
+                if has_inplace
+                else d[SchemaKind.out]
+                if has_out
+                else d[SchemaKind.functional]
+            )
+
+            # Note: [Mutable ops that cannot get an out variant]
+            # We can only generate an out= variant if either:
+            # - the original function has tensor-like returns (since we can convert them to out kwargs)
+            # - or it's inplace (since we can convert `self` to an out kwarg)
+            # There are only two functions that don't fit these criteria today though,
+            # and they both look like they should be fixed to be out= variants,
+            # so it feels safer to ban this schema altogether.
+            base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
+                r.type.is_tensor_like() for r in base_fn.func.returns
+            )
+            # Note: [Loosen the assertion that all functional should have out variant]
+            # By design, all functional operators should have out= variants. The needs_out check
+            # loosens this requirement, so that an out= variant is only generated if there's
+            # an `autogen` block in the native function; in the long run this check should be removed.
+            # FIXME: Remove this after figuring out CI job failures related to min, max, mean
+            needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
+            gets_out_variant = not has_out and base_fn_valid and needs_out
+            if not has_out and not base_fn_valid:
+                if (
+                    str(base_fn.func.name)
+                    not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+                    and str(base_fn.func.name)
+                    not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
+                ):
+                    raise AssertionError(
+                        f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
+This type of operator doesn't have a tensor-like return, which makes it difficult to generate a proper out= variant. If an
+out= variant is not needed, please add the function name to the FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
+                    )
+
+            # Generate an out= variant
+            if gets_out_variant:
+                fn, metadata = generate_function(base_fn, SchemaKind.out)
+                d[SchemaKind.out] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
+
+            # Generate a functional variant, but only do it if the operator got an out= variant
+            # (Functional variants are only useful if we can group up the variants,
+            # which we can only do if they have an out= variant)
+            if not has_functional and (has_out or gets_out_variant):
+                fn, metadata = generate_function(base_fn, SchemaKind.functional)
+                d[SchemaKind.functional] = fn
+                BackendIndex.grow_index(indices, metadata)
+                rs.append(fn)
+
+
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
+    assert len(rets) == len(names)
+    if len(rets) == 0:
+        return ""
+    elif len(rets) == 1:
+        return f"return {names[0]};"
+    else:
+        return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
+
+
+# Given a function, and the name of a variable corresponding to the output of that function,
+# gather up all of the individual returns that are not aliased
+def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
+    aliased_rets = func.aliased_return_names()
+    non_aliased_names = []
+    is_out_var_a_tuple = len(func.returns) > 1
+    for i, r in enumerate(aliased_rets):
+        if r is None:
+            non_aliased_names.append(
+                f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
+            )
+    return non_aliased_names
+
+
+# Generates functional kernels in terms of their inplace/mutable counterparts.
+# We only do this for "generated" NativeFunctions
+@with_native_function
+def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
+    # We should only be generating these for code-generated NativeFunctions
+    if "generated" not in g.functional.tags:
+        return None
+    # And we always write the kernel for a generated op in terms of a non-generated op.
+    if g.inplace is not None and "generated" not in g.inplace.tags:
+        target_f = g.inplace
+    elif g.mutable is not None and "generated" not in g.mutable.tags:
+        target_f = g.mutable
+    else:
+        # We should be guaranteed to have a valid inplace/mutable variant to call into.
+        # See Note: [Mutable Ops Not Using Functionalization]
+        raise AssertionError(str(g.functional.func))
+
+    sig = DispatcherSignature(g.functional.func)
+    target_sig = DispatcherSignature(target_f.func)
+
+    context: list[Binding | Expr] = []
+    clone_mutable_inputs = []
+    cloned_return_names = []
+    # We can't just directly pass all of the arguments from the functional op into the mutating op.
+    # We need to check for which inputs to the mutating operator are mutable,
+    # and clone those inputs first.
+    for a_curr, a_tgt in zip(
+        dispatcher.jit_arguments(g.functional.func),
+        dispatcher.jit_arguments(target_f.func),
+    ):
+        if a_tgt.annotation is not None and a_tgt.annotation.is_write:
+            clone_mutable_inputs.append(
+                f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
+            )
+            context.append(
+                Expr(
+                    expr=f"{a_curr.name}_clone",
+                    type=dispatcher.argument_type(a_curr, binds=a_curr.name),
+                )
+            )
+            # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
+            cloned_return_names.append(f"{a_curr.name}_clone")
+        else:
+            context.append(dispatcher.argument(a_curr))
+    exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
+
+    out_name = "output"
+    maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
+    inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
+    ret_str = return_str(
+        g.functional.func.returns, inner_return_names + cloned_return_names
+    )
+
+    clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
+    return f"""
+{sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
+  {clone_mutable_inputs_str}
+  {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
+  {ret_str}
+}}
+"""
+
+
+# Generates out= kernels in terms of their functional counterparts.
+# We only do this for "generated" NativeFunctions
+@with_native_function
+def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
+    # We should only be generating these for code-generated NativeFunctions
+    if "generated" not in g.out.tags:
+        return None
+    # And we always write the kernel for the out= op in terms of the functional.
+    # Note that the functional op might have also been generated, but we don't have to
+    # worry about cycles, because the generated functional kernels are always implemented
+    # in terms of non-generated kernels (see gen_composite_functional_kernel).
+
+    sig = DispatcherSignature(g.out.func)
+    target_sig = DispatcherSignature(g.functional.func)
+
+    exprs = ", ".join(
+        [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
+    )
+
+    copy_outs = []
+    out_name = "tmp_output"
+    for i, out_arg in enumerate(g.out.func.arguments.out):
+        functional_return_name = (
+            out_name
+            if len(g.functional.func.returns) == 1
+            else f"std::get<{i}>({out_name})"
+        )
+        copy_outs.append(
+            f"""\
+  resize_out_helper({out_arg.name}, {functional_return_name});
+  copy_arg({out_arg.name}, {functional_return_name});"""
+        )
+
+    rets = []
+    # For each return arg in the calling (out=) operator,
+    # If it corresponds to an aliased input, return the input.
+    # Otherwise, return the corresponding output from calling the functional operator.
+    for i, ret_name in enumerate(g.out.func.aliased_return_names()):
+        if ret_name is not None:
+            rets.append(ret_name)
+        else:
+            functional_return_name = (
+                out_name
+                if len(g.functional.func.returns) == 1
+                else f"std::get<{i}>({out_name})"
+            )
+            rets.append(functional_return_name)
+
+    copy_outs_str = "\n".join(copy_outs)
+
+    # Kernel name needs to follow the naming convention defined in `generate_function()`
+    return f"""
+{sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
+  auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
+  {copy_outs_str}
+  {return_str(g.out.func.returns, rets)}
+}}
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/__init__.py b/phivenv/Lib/site-packages/torchgen/operator_versions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ddb6b1899092c4eed76db710c45343aa5f6b122c
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4149f5fa68932a769d62c346f004f72a678f8014
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0036cedae4b52d749fc11c075470916a15ca48d2
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/operator_versions/__pycache__/gen_mobile_upgraders_constant.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py b/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py
new file mode 100644
index 0000000000000000000000000000000000000000..490329157df020aceae97cb34dfd9ec910f9b9cd
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -0,0 +1,386 @@
+#!/usr/bin/env python3
+
+from __future__ import annotations
+
+import os
+from enum import Enum
+from operator import itemgetter
+from pathlib import Path
+from typing import Any
+
+import torch
+from torch.jit.generate_bytecode import generate_upgraders_bytecode
+from torchgen.code_template import CodeTemplate
+from torchgen.operator_versions.gen_mobile_upgraders_constant import (
+    MOBILE_UPGRADERS_HEADER_DESCRIPTION,
+)
+
+
+class ByteCode(Enum):
+    instructions = 1
+    constants = 2
+    types = 3
+    operators = 4
+    register_size = 5
+
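+# Each upgrader entry consumed by write_cpp() below is expected to look roughly
+# like this sketch (hedged; inferred from the table names parsed there, and the
+# upgrader name is hypothetical):
+#   {"div_Tensor_0_3": {"instructions": [...], "constants": [...],
+#                       "types": [...], "operators": [...], "register_size": 2}}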
+
+EXCLUDED_OP_SET = [
+    "aten::full.names",
+    "aten::full.out",
+    "aten::full",
+]
+
+EXCLUE_UPGRADER_SET = ["full_0_4", "full_out_0_4"]
+
+ONE_INSTRUCTION = CodeTemplate(
+    """
+    Instruction{OpCode::${operator_name}, ${X}, ${N}},"""
+)
+
+INSTRUCTION_LIST = CodeTemplate(
+    """std::vector({
+        ${instruction_list}
+    }), // instructions list"""
+)
+
+ONE_CONSTANT = CodeTemplate(
+    """
+    c10::IValue(${constant}),"""
+)
+
+CONSTANT_LIST = CodeTemplate(
+    """std::vector({
+        ${constant_list}
+    }), // constants list"""
+)
+
+CONSTANTS_LIST_EMPTY = """std::vector(), // constants list"""
+
+ONE_TYPE = CodeTemplate("""c10::parseType("${type_str}"),""")
+
+TYPE_LIST = CodeTemplate(
+    """std::vector({
+        ${type_list}
+    }), // types list"""
+)
+
+TYPE_LIST_EMPTY = """std::vector(), // types list"""
+
+ONE_OPERATOTR_STRING = CodeTemplate(
+    """
+    OperatorString({"${operator_name}", "${overload_name}", ${num_of_args}}),"""
+)
+
+OPERATOR_STRING_LIST = CodeTemplate(
+    """
+    std::vector<OperatorString>({
+        ${operator_string_list}
+    }), // operators list"""
+)
+
+ONE_UPGRADER_FUNCTION = CodeTemplate(
+    """
+    mobile::Function::registerFunc(
+        "${upgrader_name}",
+        ${instruction_list},
+        ${constant_list},
+        ${type_list},
+        ${register_size}
+    )"""
+)
+
+ONE_UPGRADER_SRC = CodeTemplate(
+    """
+    ByteCodeFunctionWithOperator({
+        ${bytecode_function},
+        ${operator_string_list}
+    }),"""
+)
+
+
+ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate(
+    """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})"""
+)  # noqa: E501
+
+ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate(
+    """
+    {std::string("${operator_name}"),
+        std::vector<Upgrader>({
+            ${upgrader_list_in_version_map}
+        })},"""
+)
+
+
+OPERATOR_VERSION_MAP = CodeTemplate(
+    """
+const std::unordered_map<std::string, std::vector<Upgrader>>
+getOperatorVersionMapForMobile() {
+  static std::unordered_map<std::string, std::vector<Upgrader>>
+        operatorVersionMapForMobile({
+            ${operator_list_in_version_map}
+      });
+  return operatorVersionMapForMobile;
+}
+"""
+)
+
+
+UPGRADER_CPP_SRC = CodeTemplate(
+    MOBILE_UPGRADERS_HEADER_DESCRIPTION
+    + """
+#include 
+#include 
+#include 
+
+namespace torch {
+namespace jit {
+
+// clang-format off
+
+// From operator_versions_map
+${operator_version_map}
+
+const std::vector& getUpgraderBytecodeList() {
+  auto generate_upgrader_bytecode_list = []() {
+    std::vector<ByteCodeFunctionWithOperator> upgrader_function_list({
+               ${upgrader_bytecode}
+            });
+    for (const auto& upgrader_function : upgrader_function_list) {
+      for (const auto& op : upgrader_function.operators) {
+        upgrader_function.function.append_operator(
+            op.name,
+            op.overload_name,
+            op.num_specified_args);
+      }
+    }
+    return upgrader_function_list;
+  };
+  static std::vector<ByteCodeFunctionWithOperator> upgraderBytecodeList =
+      generate_upgrader_bytecode_list();
+  return upgraderBytecodeList;
+}
+
+// clang-format on
+
+} // namespace jit
+} // namespace torch
+"""
+)
+
+UPGRADER_MOBILE_FILE_NAME = "upgrader_mobile.cpp"
+
+UPGRADER_ELEMENT = CodeTemplate(
+    """\
+Upgrader({${min_version}, ${max_version}, ${operator_name}, ${index}}),
+"""
+)
+
+PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
+    """\
+{
+  std::string(${operator_name}),
+  std::vector<Upgrader>({${upgrader_list}});
+}
+"""
+)
+
+
+def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
+    instruction_list_part = [
+        ONE_INSTRUCTION.substitute(
+            operator_name=instruction[0],
+            X=instruction[1],
+            N=instruction[2],
+        )
+        for instruction in instruction_list_from_yaml
+    ]
+    return INSTRUCTION_LIST.substitute(
+        instruction_list="".join(instruction_list_part).lstrip("\n")
+    )
+
+
+def construct_constants(constants_list_from_yaml: list[Any]) -> str:
+    constants_list_part = []
+    for constant_from_yaml in constants_list_from_yaml:
+        convert_constant = None
+        if isinstance(constant_from_yaml, str):
+            # Add quotes if it's string
+            convert_constant = f'"{constant_from_yaml}"'
+        elif isinstance(constant_from_yaml, bool):
+            convert_constant = "true" if constant_from_yaml else "false"
+        elif constant_from_yaml is None:
+            convert_constant = ""
+        elif isinstance(constant_from_yaml, int):
+            convert_constant = str(constant_from_yaml)
+        else:
+            raise ValueError(
+                f"The type of {constant_from_yaml} is {type(constant_from_yaml)}. "
+                "Please add change in construct_constants function in gen_mobile_upgraders.py."
+            )
+        constants_list_part.append(ONE_CONSTANT.substitute(constant=convert_constant))
+    if len(constants_list_part) == 0:
+        return CONSTANTS_LIST_EMPTY
+    return CONSTANT_LIST.substitute(
+        constant_list="".join(constants_list_part).lstrip("\n")
+    )
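+
+# Illustrative mapping for construct_constants (hedged; the real payload comes from
+# generate_upgraders_bytecode()): ["cpu", True, None, 4] would emit
+#   c10::IValue("cpu"), c10::IValue(true), c10::IValue(), c10::IValue(4),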
+
+
+def construct_operators(operator_list_from_yaml: list[Any]) -> str:
+    operator_list_part = [
+        ONE_OPERATOTR_STRING.substitute(
+            operator_name=operator[0],
+            overload_name=operator[1],
+            num_of_args=operator[2],
+        )
+        for operator in operator_list_from_yaml
+    ]
+    return OPERATOR_STRING_LIST.substitute(
+        operator_string_list="".join(operator_list_part).lstrip("\n")
+    )
+
+
+def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
+    types_tr_list_part = [
+        ONE_TYPE.substitute(type_str=types_tr) for types_tr in types_tr_list_from_yaml
+    ]
+    if len(types_tr_list_part) == 0:
+        return TYPE_LIST_EMPTY
+    return TYPE_LIST.substitute(type_list="".join(types_tr_list_part).lstrip("\n"))
+
+
+def construct_register_size(register_size_from_yaml: int) -> str:
+    if not isinstance(register_size_from_yaml, int):
+        raise ValueError(
+            f"Input register size is {register_size_from_yaml} and"
+            "it's type is {type(register_size_from_yaml)}. An int type is expected."
+        )
+    return str(register_size_from_yaml)
+
+
+def construct_version_maps(
+    upgrader_bytecode_function_to_index_map: dict[str, Any],
+) -> str:
+    version_map = torch._C._get_operator_version_map()
+    sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
+    sorted_version_map = dict(sorted_version_map_)
+
+    operator_list_in_version_map_part = []
+    for op_name in sorted_version_map:
+        upgraders_in_version_map_part = []
+        # TODO: remove the skip after these two operators' schemas are fixed
+        if op_name in EXCLUDED_OP_SET:
+            continue
+        upgrader_ranges = torch._C._get_upgrader_ranges(op_name)
+        upgrader_entries = sorted_version_map[op_name]
+        assert len(upgrader_ranges) == len(upgrader_entries)
+        for idx, upgrader_entry in enumerate(upgrader_entries):
+            upgrader_name = upgrader_entry.upgrader_name
+            bytecode_function_index = upgrader_bytecode_function_to_index_map[
+                upgrader_name
+            ]
+            upgraders_in_version_map_part.append(
+                ONE_UPGRADER_IN_VERSION_MAP.substitute(
+                    upgrader_min_version=upgrader_ranges[idx].min_version,
+                    upgrader_max_version=upgrader_ranges[idx].max_version,
+                    upgrader_name=upgrader_name,
+                    bytecode_func_index=bytecode_function_index,
+                )
+            )
+        operator_list_in_version_map_part.append(
+            ONE_OPERATOR_IN_VERSION_MAP.substitute(
+                operator_name=op_name,
+                upgrader_list_in_version_map="".join(upgraders_in_version_map_part),
+            )
+        )
+    return OPERATOR_VERSION_MAP.substitute(
+        operator_list_in_version_map="".join(operator_list_in_version_map_part).lstrip(
+            "\n"
+        )
+    )
+
+
+def get_upgrader_bytecode_function_to_index_map(
+    upgrader_dict: list[dict[str, Any]],
+) -> dict[str, Any]:
+    upgrader_bytecode_function_to_index_map = {}
+    index = 0
+    for upgrader_bytecode in upgrader_dict:
+        for upgrader_name in upgrader_bytecode.keys():
+            if upgrader_name in EXCLUE_UPGRADER_SET:
+                continue
+            upgrader_bytecode_function_to_index_map[upgrader_name] = index
+            index += 1
+    return upgrader_bytecode_function_to_index_map
+
+
+def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
+    upgrader_bytecode_function_to_index_map = (
+        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
+    )
+    version_map_src = construct_version_maps(upgrader_bytecode_function_to_index_map)
+    all_upgrader_src_string = []
+    for upgrader_bytecode in upgrader_dict:
+        for upgrader_name, bytecode in upgrader_bytecode.items():
+            # TODO: remove the skip after these two operators' schemas are fixed
+            if upgrader_name in EXCLUE_UPGRADER_SET:
+                continue
+            instruction_list_str = ""
+            constant_list_str = ""
+            type_list_str = ""
+            register_size_str = ""
+            operator_list_str = ""
+            for table_name, contents in bytecode.items():
+                element = ByteCode[table_name]
+                if element is ByteCode.instructions:
+                    instruction_list_str = construct_instruction(contents)
+                elif element is ByteCode.constants:
+                    constant_list_str = construct_constants(contents)
+                elif element is ByteCode.operators:
+                    operator_list_str = construct_operators(contents)
+                elif element is ByteCode.types:
+                    type_list_str = construct_types(contents)
+                elif element is ByteCode.register_size:
+                    register_size_str = construct_register_size(contents)
+
+            one_upgrader_function_string = ONE_UPGRADER_FUNCTION.substitute(
+                upgrader_name=upgrader_name,
+                instruction_list=instruction_list_str,
+                constant_list=constant_list_str,
+                type_list=type_list_str,
+                register_size=register_size_str,
+            )
+            one_upgrader_src_string = ONE_UPGRADER_SRC.substitute(
+                bytecode_function=one_upgrader_function_string.lstrip("\n"),
+                operator_string_list=operator_list_str.lstrip("\n"),
+            )
+            all_upgrader_src_string.append(one_upgrader_src_string)
+
+    upgrader_file_content = UPGRADER_CPP_SRC.substitute(
+        operator_version_map=version_map_src,
+        upgrader_bytecode="".join(all_upgrader_src_string).lstrip("\n"),
+    )
+    print("writing file to : ", cpp_path + "/" + UPGRADER_MOBILE_FILE_NAME)
+    with open(os.path.join(cpp_path, UPGRADER_MOBILE_FILE_NAME), "wb") as out_file:
+        out_file.write(upgrader_file_content.encode("utf-8"))
+
+
+def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    sorted_upgrader_list = sorted(
+        upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
+    )
+    return sorted_upgrader_list
+
+
+def main() -> None:
+    upgrader_list = generate_upgraders_bytecode()
+    sorted_upgrader_list = sort_upgrader(upgrader_list)
+    for up in sorted_upgrader_list:
+        print("after sort upgrader : ", next(iter(up)))
+
+    pytorch_dir = Path(__file__).resolve().parents[2]
+    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "mobile"
+    write_cpp(str(upgrader_path), sorted_upgrader_list)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py b/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..923e39c4891e0562df75652d05673c4e393aff1b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/operator_versions/gen_mobile_upgraders_constant.py
@@ -0,0 +1,7 @@
+MOBILE_UPGRADERS_HEADER_DESCRIPTION = """/**
+ * @generated
+ * This is an auto-generated file. Please do not modify it by hand.
+ * To re-generate, please run:
+ * cd ~/pytorch && python torchgen/operator_versions/gen_mobile_upgraders.py
+ */
+"""
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/native_functions.yaml b/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/native_functions.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6002f6c2a5cb85bdc7d0654d854f35a2956ed8e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/native_functions.yaml
@@ -0,0 +1,15855 @@
+# See README.md in this directory for more guidance
+
+# *********NB: _cast_* operators are DEPRECATED and will be removed
+# eventually. These were previously used before TorchScript IR supported
+# representing ScalarTypes. They are now superseded by usage of
+# `aten::to()`. The ops remain here for backward compatibility purposes.
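+# (For illustration, hedged: a deprecated call like `_cast_Float(t)` corresponds to
+# `t.to(torch.float32)` in Python, i.e. the `aten::to.dtype` overload in the IR.)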
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# DEPRECATED. DO NOT USE
+- func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor
+  variants: function
+
+# Computes the gradient of current tensor w.r.t. graph leaves.
+- func: _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+  manual_cpp_binding: True
+  variants: method
+
+# DEPRECATED. Sets the tensor data held by this `Variable` to be the same as
+# `new_data`.  It requires that `new_data` and `Variable` have compatible tensor
+# type, by checking `_has_compatible_shallow_copy_type(this, new_data)`.
+#
+# This function is deprecated because it doesn't really make sense in a world
+# where Variables *are* Tensors (as opposed to them containing tensors, which
+# is what the previous interpretation was.)
+- func: set_data(Tensor(a!) self, Tensor new_data) -> ()
+  manual_cpp_binding: True
+  variants: method
+
+- func: data(Tensor self) -> Tensor
+  manual_cpp_binding: True
+  variants: method
+
+# True if this `Variable` is a leaf and thus does not have a `grad_fn`.
+- func: is_leaf(Tensor self) -> bool
+  manual_cpp_binding: True
+  variants: method
+
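+# A minimal sketch of `is_leaf` using the public API:
+#
+#   >>> import torch
+#   >>> a = torch.randn(2, requires_grad=True)
+#   >>> a.is_leaf            # created directly by the user, so a leaf
+#   True
+#   >>> (a * 2).is_leaf      # produced by an op, so it has a grad_fn
+#   False
+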
+# Returns the output index of this variable from the forward operation that
+# produced it.  Conversely, it returns the input index of the gradient `Node` to
+# which this `Variable` is connected (because in the gradient computation,
+# inputs and outputs switch meaning).  For example:
+#
+#   y0, y1, y2 = f(x)
+#   assert y0.output_nr == 0
+#   assert y1.output_nr == 1
+#   assert y2.output_nr == 2
+#
+- func: output_nr(Tensor self) -> int
+  manual_cpp_binding: True
+  variants: method
+
+- func: _version(Tensor self) -> int
+  manual_cpp_binding: True
+  variants: method
+
+- func: requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)
+  manual_cpp_binding: True
+  variants: method
+
+# Enables .grad attribute for non-leaf Tensors.
+- func: retain_grad(Tensor(a!) self) -> ()
+  manual_cpp_binding: True
+  variants: method
+
+- func: retains_grad(Tensor self) -> bool
+  manual_cpp_binding: True
+  variants: method
+
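+# A minimal sketch of `retain_grad` using the public API:
+#
+#   >>> import torch
+#   >>> x = torch.ones(2, requires_grad=True)
+#   >>> y = x * 3            # non-leaf; .grad is not populated by default
+#   >>> y.retain_grad()
+#   >>> y.sum().backward()
+#   >>> y.grad
+#   tensor([1., 1.])
+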
+- func: _fw_primal(Tensor(a) self, int level) -> Tensor(a)
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: _fw_primal
+
+- func: _make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _make_dual
+
+- func: _unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)
+  variants: function
+
+# NOTE: [_new_zeros_with_same_feature_meta]
+# This function creates a new tensor with the layout and TensorOptions
+# of `other` but also takes into account the batch dimensions of `self`
+#
+# This function has a couple of extra constraints because it is also used for
+# `jvp` in functorch.
+# - It is used for forward AD because of the restriction that the primal and
+#   tangent must have the same layout.
+# - We cannot assume that `self` and `other` have the same sizes or even the
+#   same number of dims, because in the inplace-over-view case `other` is the
+#   base tensor and `self` is the forward grad with respect to the view, which
+#   can have an entirely different shape.
+# - It takes the number of batch dims for `self` because we also handle some
+#   batching logic here. We handle it here instead of in a batching rule
+#   because we'd like to avoid calling as_strided in the batching rule (so as
+#   to enable nested vmap in functorch).
+# - It needs to be CompositeExplicitAutograd for jvp support in functorch:
+#   functorch currently relies on TensorWrapper, which does not have storage,
+#   and CompositeExplicitAutograd makes sure the TensorWrapper is unwrapped.
+# - This function may eventually take another int argument to store the number
+#   of batch dims for `other` once we support that use case.
+- func: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _new_zeros_with_same_feature_meta
+  autogen: _new_zeros_with_same_feature_meta.out
+
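+# A minimal sketch of calling this op through the dispatcher, illustrating the
+# note above (assumes the op is reached via `torch.ops.aten`):
+#
+#   >>> import torch
+#   >>> primal = torch.randn(2, 3)
+#   >>> tangent = torch.randn(2, 3)
+#   >>> # zeros carrying the feature metadata (sizes/strides/options) of `primal`
+#   >>> torch.ops.aten._new_zeros_with_same_feature_meta(tangent, primal).shape
+#   torch.Size([2, 3])
+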
+# This function compares the storage numel of self with that of other, where
+# storage numel is computed as: `other.storage().nbytes() / other.itemsize()`.
+# We create this function for composite compliance purposes. The batching rule
+# always returns true because vmapped as_strided does not support accessing
+# storage locations not indexable by the input tensor.
+# See the note above for more information.
+- func: _has_same_storage_numel(Tensor self, Tensor other) -> bool
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _has_same_storage_numel
+
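+# A small worked example of the "storage numel" quantity described above,
+# computed with the public storage APIs (assumes `untyped_storage()`):
+#
+#   >>> import torch
+#   >>> t = torch.zeros(4, dtype=torch.float32)
+#   >>> t.untyped_storage().nbytes() // t.element_size()   # 16 bytes / 4 bytes
+#   4
+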
+- func: rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)
+  variants: method
+  tags: inplace_view
+
+- func: rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)
+  variants: method
+
+- func: align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)
+  variants: method
+
+- func: align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)
+  variants: method
+
+- func: align_as(Tensor self, Tensor other) -> Tensor
+  variants: method
+
+- func: align_tensors(Tensor[] tensors) -> Tensor[]
+
+# Not assert because it's a keyword; not Assert because FX already
+# took that syntax
+# TODO: need to specify this is side-effectful somehow
+- func: _assert_async(Tensor self) -> ()
+  dispatch:
+    CPU: _assert_async_cpu
+    CUDA: _assert_async_cuda
+
+- func: _assert_async.msg(Tensor self, str assert_msg) -> ()
+  dispatch:
+    CPU: _assert_async_msg_cpu
+    CUDA: _assert_async_msg_cuda
+
+- func: _assert_scalar(Scalar self, str assert_msg) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _assert_scalar
+
+- func: _functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _functional_assert_scalar
+
+- func: _functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor
+  dispatch:
+    CPU: _functional_assert_async_msg_cpu
+
+- func: _assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _assert_tensor_metadata
+    Meta: _assert_tensor_metadata_meta_symint
+
+- func: _print(str s) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _print
+
+- func: sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()
+  dispatch:
+    CompositeExplicitAutograd: sym_constrain_range
+
+- func: sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()
+  dispatch:
+    CompositeExplicitAutograd: sym_constrain_range_for_size
+
+- func: _functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _functional_sym_constrain_range
+
+- func: _functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _functional_sym_constrain_range_for_size
+
+- func: _make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  dispatch:
+    CPU: _make_dep_token_cpu
+
+- func: refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)
+  variants: method
+
+- func: _use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool
+  device_check: NoCheck  # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss
+  dispatch:
+    CUDA: _use_cudnn_ctc_loss
+
+- func: _use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool
+  device_check: NoCheck  # Tensor arguments allowed to be on different devices, see also _cudnn_ctc_loss
+  dispatch:
+    CUDA: _use_cudnn_ctc_loss_tensor
+
+- func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+  device_check: NoCheck  # log_probs is expected to be on CUDA while targets is expected to be on CPU
+  dispatch:
+    CUDA: _cudnn_ctc_loss
+  autogen: _cudnn_ctc_loss.out
+
+- func: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+  device_check: NoCheck  # log_probs is expected to be on CUDA while targets is expected to be on CPU
+  dispatch:
+    CUDA: _cudnn_ctc_loss_tensor
+
+- func: _use_cudnn_rnn_flatten_weight() -> bool
+
+- func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor
+  dispatch:
+    CUDA: _cudnn_rnn_flatten_weight
+  autogen: _cudnn_rnn_flatten_weight.out
+
+- func: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  # rnn_tanh may or may not redispatch to _cudnn_rnn based on algorithm and build. Thus it might hit dispatch or kernel device check.
+  # Disable dispatch time device check for consistent behavior.
+  device_check: NoCheck
+  dispatch:
+    CUDA: _cudnn_rnn
+  autogen: _cudnn_rnn.out
+  tags: nondeterministic_seeded
+
+- func: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+  dispatch:
+    CUDA: _cudnn_rnn_backward
+  autogen: _cudnn_rnn_backward.out
+
+- func: _cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CUDA: _cudnn_init_dropout_state
+  autogen: _cudnn_init_dropout_state.out
+  tags: nondeterministic_seeded
+
+- func: _debug_has_internal_overlap(Tensor self) -> int
+  variants: function
+
+- func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: fused_dropout_cuda
+  tags: nondeterministic_seeded
+  autogen: _fused_dropout.out
+
+- func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: masked_scale_cuda
+  autogen: _masked_scale.out
+
+- func: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: native_dropout_cpu
+    CUDA: native_dropout_cuda
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_nested
+  tags: [nondeterministic_seeded, core]
+  autogen: native_dropout.out
+
+- func: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
+  dispatch:
+    CPU, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: native_dropout_backward
+    CUDA: native_dropout_backward_cuda
+  autogen: native_dropout_backward.out
+  tags: pointwise
+
+- func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)
+
+- func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)
+
+- func: _sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)
+
+- func: _sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!)
+
+- func: _reshape_from_tensor(Tensor self, Tensor shape) -> Tensor
+
+- func: _shape_as_tensor(Tensor self) -> Tensor
+
+- func: dropout(Tensor input, float p, bool train) -> Tensor
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
+
+- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
+
+- func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: alpha_dropout(Tensor input, float p, bool train) -> Tensor
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
+
+- func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor
+  tags: [nondeterministic_seeded, maybe_aliasing_or_mutating]
+
+- func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: abs(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: abs
+    SparseCPU, SparseCUDA: abs_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs
+  tags: [core, pointwise]
+
+- func: abs_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: abs_
+    SparseCPU, SparseCUDA: abs_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs_
+
+- func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: abs_out
+    SparseCPU, SparseCUDA: abs_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out
+  tags: pointwise
+
+# Note [Adding an alias]
+# To add an alias do the following:
+#
+# 1) Copy the original function's native_functions.yaml entry, but replace the
+#      original function's name with the alias's name and delete any dispatch
+#      keys for the alias. Specifying a dispatch key would prevent
+#      autograd from recording the operations the alias performs, which
+#      would stop it from "inheriting" the original operation's autograd behavior.
+# 2) Implement the corresponding functions and have them redispatch to the
+#      original function.
+# 3) Add docstrings to the new function that reference the original function,
+#      and document the method as usual (if it exists.)
+#    (See torch/_torch_docs.py and docs/source/torch.rst if adding a function,
+#     torch/_tensor_docs.py and docs/source/tensors.rst if adding a method,
+#     or module-specific doc bindings (like torch/linalg/__init__.py) if
+#     adding an alias in a namespace.)
+# 4) Update torch/overrides.py consistent with the original function.
+# 5) Update the alias_map in torch/csrc/jit/passes/normalize_ops.cpp.
+# 6) Add an aliases argument to the existing OpInfo/UnaryUfuncInfo entry, or create a
+#      new OpInfo/UnaryUfuncInfo entry, in the op_db list in
+#      torch/testing/_internal/common_methods_invocations.py.
+#
+# See torch.absolute, an alias for torch.abs, as an example.
+# Absolute, alias for abs
+
+- func: absolute(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: absolute_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
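+# A minimal sketch of the alias pattern from Note [Adding an alias]: the
+# `absolute` entries above redispatch to `abs`:
+#
+#   >>> import torch
+#   >>> t = torch.tensor([-1.5, 2.0])
+#   >>> torch.equal(torch.absolute(t), torch.abs(t))
+#   True
+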
+- func: angle(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: angle
+    MPS: angle_mps
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr
+  tags: pointwise
+
+- func: angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: angle_out
+    MPS: angle_out_mps
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: angle_sparse_csr_out
+  tags: pointwise
+
+- func: view_as_real(Tensor(a) self) -> Tensor(a)
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS, Meta: view_as_real
+
+- func: view_as_complex(Tensor(a) self) -> Tensor(a)
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS, Meta: view_as_complex
+
+- func: sgn(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: sgn.out
+  dispatch:
+    SparseCPU, SparseCUDA: sgn_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sgn
+  tags: pointwise
+
+- func: sgn_(Tensor(a!) self) -> Tensor(a!)
+  variants: method
+  structured_delegate: sgn.out
+  dispatch:
+    SparseCPU, SparseCUDA: sgn_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sgn_
+  tags: pointwise
+
+- func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: sgn_out
+    MPS: sgn_out_mps
+    SparseCPU, SparseCUDA: sgn_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sgn_sparse_csr_out
+  tags: pointwise
+
+- func: chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+  variants: method
+
+- func: real(Tensor(a) self) -> Tensor(a)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: imag(Tensor(a) self) -> Tensor(a)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: _conj(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _conj
+
+- func: conj(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  manual_cpp_binding: True
+
+- func: _conj_physical(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _conj_physical
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr
+  autogen: _conj_physical.out
+
+- func: conj_physical(Tensor self) -> Tensor
+  variants: function, method
+  tags: [pointwise, maybe_aliasing_or_mutating]
+
+- func: conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: conj_physical_out
+    MPS: conj_physical_out_mps
+    SparseCPU, SparseCUDA: conj_physical_out_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_out
+  tags: pointwise
+
+- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: conj_physical_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_
+  tags: pointwise
+
+- func: resolve_conj(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+
+- func: resolve_neg(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+
+- func: _neg_view(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _neg_view
+
+- func: acos(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: acos.out
+  tags: [core, pointwise]
+
+- func: acos_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: acos.out
+  tags: pointwise
+
+- func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: acos_out
+  tags: pointwise
+
+# arccos, alias of acos
+- func: arccos(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arccos_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor
+  tags: core
+  autogen: avg_pool1d.out
+
+- func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor
+  tags: core
+  autogen: adaptive_avg_pool1d.out
+
+# Return: (Tensor output, Tensor indices)
+- func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
+
+- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: add.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: add_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr
+    MkldnnCPU: mkldnn_add
+    ZeroTensor: add_zerotensor
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_add_Tensor
+  tags: [core, pointwise]
+
+- func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: add.out
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: add_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_
+    MkldnnCPU: mkldnn_add_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_add__Tensor
+  tags: pointwise
+
+- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  ufunc_inner_loop:
+    Generic: add (AllAndComplex, BFloat16, Half, ComplexHalf)
+    ScalarOnly: add (Bool)
+  dispatch:
+    SparseCPU, SparseMeta: add_out_sparse_cpu
+    SparseCUDA: add_out_sparse_cuda
+    SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu
+    SparseCsrCUDA: add_out_sparse_compressed_cuda
+    MkldnnCPU: mkldnn_add_out
+    MPS: add_out_mps
+  tags: pointwise
+
+- func: _add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  variants: function
+  dispatch:
+    CPU: add_relu
+
+- func: _add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: add_relu_
+
+- func: _add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: add_relu_out
+
+- func: _add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  variants: function
+  dispatch:
+    CPU: add_relu
+
+- func: _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: add_relu_
+  autogen: _add_relu.Scalar_out
+
+# For C++ only, until we have conversion from C++ numbers to Tensor
+- func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: add
+  tags: [core, pointwise]
+
+- func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: add_
+  autogen: add.Scalar_out
+  tags: pointwise
+
+- func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  structured_delegate: addmv.out
+  variants: function, method
+
+- func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+  structured_delegate: addmv.out
+  variants: function, method
+
+- func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: addmv_out_cpu
+    CUDA: addmv_out_cuda
+    MPS: addmv_out_mps
+    XPU: addmv_out_xpu
+    SparseCsrCPU: addmv_out_sparse_compressed
+    SparseCsrCUDA: addmv_out_sparse_compressed_cuda
+
+- func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, CUDA: addr
+    MPS: addr_mps
+    CompositeExplicitAutograd: math_addr
+
+- func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: addr_
+
+- func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: addr_out
+    MPS: addr_out_mps
+    CompositeExplicitAutograd: math_addr_out
+
+- func: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: affine_grid_generator
+  autogen: affine_grid_generator.out
+
+- func: affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor
+  variants: function
+
+- func: _is_all_true(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _is_all_true
+
+- func: _is_any_true(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _is_any_true
+
+# Note: this function is only for testing.
+- func: _test_check_tensor(Tensor self) -> Tensor
+  variants: function
+
+# Note: this function is only for testing.
+- func: _test_functorch_fallback(Tensor self, Tensor other) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _test_functorch_fallback
+  autogen: _test_functorch_fallback.out
+
+- func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: all.out
+  variants: function, method
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_all
+
+
+- func: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: all.dims_out
+  variants: function, method
+  cpp_no_default_args: ['dim']
+  dispatch:
+    CompositeExplicitAutograd: all_dims_default
+
+- func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: all_out
+    MPS: all_out_mps
+
+- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: all_dims_out
+    CompositeExplicitAutograd: all_dims_out_default
+  cpp_no_default_args: ['dim']
+
+- func: all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool
+  variants: function, method
+  tags: data_dependent_output
+  dispatch:
+    CompositeExplicitAutograd: allclose
+
+- func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: any.out
+  variants: function, method
+  tags: core
+
+- func: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: any.dims_out
+  variants: function, method
+  cpp_no_default_args: ['dim']
+  tags: core
+  dispatch:
+    CompositeExplicitAutograd: any_dims_default
+
+- func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: any_out
+    MPS: any_out_mps
+
+- func: any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: any_dims_out
+    CompositeExplicitAutograd: any_dims_out_default
+  cpp_no_default_args: ['dim']
+
+- func: any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: arange
+
+- func: arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: arange
+
+# This operator should be named `arange.start_out` if following the naming convention. However, that
+# name is already taken. Disabled because of CI job failures.
+# FIXME: enable this
+#- func: arange.start_out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+#  dispatch:
+#    CompositeExplicitAutograd: arange_start_out
+
+- func: arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: arange
+  cpp_no_default_args: ['step']
+  tags: core
+
+- func: arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: arange_out
+
+- func: arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: arange_out
+    CUDA: arange_cuda_out
+    MPS: arange_mps_out
+  cpp_no_default_args: ['step']
+
+# This function is a temporary hack to allow tracing of arange-like constructs with dynamic
+# bounds.  Normal arange is not traceable because it does not take any tensor inputs;
+# if the range you need is based on another tensor, calling this function directly will
+# preserve tracing.  Get rid of this when arange can directly take tensors for bounds
+# (so that it can be traced directly).
+- func: _dim_arange(Tensor like, int dim) -> Tensor
+
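+# A minimal sketch of the tracing-friendly pattern described above (assumes
+# the op is reached via `torch.ops.aten`):
+#
+#   >>> import torch
+#   >>> like = torch.zeros(5, 3)
+#   >>> torch.ops.aten._dim_arange(like, 0)   # bound derived from a tensor input
+#   tensor([0, 1, 2, 3, 4])
+#   >>> torch.arange(like.size(0))            # same values, but the bound is a plain int
+#   tensor([0, 1, 2, 3, 4])
+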
+- func: argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+  structured_delegate: argmax.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: core
+
+- func: argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU, CUDA: argmax_out
+    MPS: argmax_out_mps
+
+- func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor
+  structured_delegate: argmin.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: core
+
+- func: argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU, CUDA: argmin_out
+    MPS: argmin_out_mps
+
+- func: acosh(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: acosh.out
+  tags: [core, pointwise]
+
+- func: acosh_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  structured_delegate: acosh.out
+  tags: pointwise
+
+- func: acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: acosh_out
+    MPS: acosh_out_mps
+  tags: pointwise
+# arccosh, alias for acosh
+
+- func: arccosh(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arccosh_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: asinh(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: asinh.out
+  dispatch:
+    SparseCPU, SparseCUDA: asinh_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr
+  tags: [core, pointwise]
+
+- func: asinh_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  structured_delegate: asinh.out
+  dispatch:
+    SparseCPU, SparseCUDA: asinh_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_
+  tags: pointwise
+
+- func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: asinh_out
+    MPS: asinh_out_mps
+    SparseCPU, SparseCUDA: asinh_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asinh_sparse_csr_out
+  tags: pointwise
+
+# arcsinh, alias for asinh
+- func: arcsinh(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arcsinh_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: atanh(Tensor self) -> Tensor
+  structured_delegate: atanh.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: atanh_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr
+  tags: [core, pointwise]
+
+- func: atanh_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: atanh.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: atanh_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_
+  tags: pointwise
+
+- func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: atanh_out
+    MPS: atanh_out_mps
+    SparseCPU, SparseCUDA: atanh_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atanh_sparse_csr_out
+  tags: pointwise
+# arctanh, alias for atanh
+
+- func: arctanh(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arctanh_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    ZeroTensor, CPU, CUDA, MTIA: as_strided_tensorimpl
+    Meta: as_strided_tensorimpl_meta_symint
+    MPS: as_strided_tensorimpl_mps
+    QuantizedCPU, QuantizedCUDA: as_strided_qtensorimpl
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+
+- func: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: as_strided__symint
+
+- func: asin(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: asin.out
+  dispatch:
+    SparseCPU, SparseCUDA: asin_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr
+  tags: [core, pointwise]
+
+- func: asin_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: asin.out
+  dispatch:
+    SparseCPU, SparseCUDA: asin_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_
+  tags: pointwise
+
+- func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: asin_out
+    SparseCPU, SparseCUDA: asin_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: asin_sparse_csr_out
+  tags: pointwise
+
+# arcsin, alias of asin
+- func: arcsin(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arcsin_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: atan(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: atan.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: atan_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr
+  tags: [core, pointwise]
+
+- func: atan_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: atan.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: atan_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_
+  tags: pointwise
+
+- func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: atan_out
+    SparseCPU, SparseCUDA: atan_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: atan_sparse_csr_out
+  tags: pointwise
+
+# arctan, alias of atan
+- func: arctan(Tensor self) -> Tensor
+  variants: function, method
+
+- func: arctan_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: atleast_1d(Tensor self) -> Tensor
+  variants: function
+  tags: maybe_aliasing_or_mutating
+
+- func: atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]
+
+- func: atleast_2d(Tensor self) -> Tensor
+  variants: function
+  tags: maybe_aliasing_or_mutating
+
+- func: atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]
+  variants: function
+
+- func: atleast_3d(Tensor self) -> Tensor
+  variants: function
+  tags: maybe_aliasing_or_mutating
+
+- func: atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]
+  variants: function
+
+- func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  variants: function, method
+  structured_delegate: baddbmm.out
+
+- func: baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+  variants: method
+  structured_delegate: baddbmm.out
+
+- func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU: baddbmm_out_cpu
+    CUDA: baddbmm_out_cuda
+    MPS: baddbmm_out_mps
+    XPU: baddbmm_out_xpu
+    SparseCsrCUDA: baddbmm_out_sparse_csr_cuda
+
+- func: baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _baddbmm_dtype_cuda
+
+- func: baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _baddbmm_out_dtype_cuda
+
+- func: bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: bartlett_window
+  autogen: bartlett_window.out
+
+- func: bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: bartlett_window
+  autogen: bartlett_window.periodic_out
+
+- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
+  tags: maybe_aliasing_or_mutating
+
+- func: quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor
+  dispatch:
+    QuantizedCPU: quantized_batch_norm
+  autogen: quantized_batch_norm.out
+
+- func: _batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)
+  tags: maybe_aliasing_or_mutating
+
+- func: _batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)
+
+# Sample from a Bernoulli distribution, using the values in `self` as probabilities.
+- func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: bernoulli
+  tags: nondeterministic_seeded
+
+- func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: bernoulli_out
+    MPS: bernoulli_out_mps
+
+- func: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: bernoulli_
+    MPS: bernoulli_mps_
+  autogen: bernoulli.Tensor, bernoulli.Tensor_out
+
+- func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: bernoulli_
+    MPS: bernoulli_mps_
+  autogen: bernoulli.float_out
+
+# Note [bernoulli.p schema]
+# We should probably just fix the overload ambiguity by appending a _functional to the C++ API name (BC breaking)
+# This out-of-place version isn't used explicitly, but it is needed by jit.
+# There is no default value on `p` here because it would introduce ambiguity
+# with the `bernoulli(Tensor self, *, Generator? generator=None)` declaration.
+- func: bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: bernoulli
+
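+# A minimal sketch of the two bernoulli flavors discussed above, via the
+# public API:
+#
+#   >>> import torch
+#   >>> probs = torch.tensor([0.1, 0.9, 0.5])
+#   >>> sample = torch.bernoulli(probs)          # probabilities taken from `self`
+#   >>> flips = torch.zeros(3).bernoulli_(0.5)   # scalar probability `p`, in-place
+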
+- func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor
+
+- func: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  variants: function
+  dispatch:
+    CPU: binary_cross_entropy_cpu
+    CUDA: binary_cross_entropy_cuda
+    MPS: binary_cross_entropy_mps
+
+- func: binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  variants: function
+  dispatch:
+    CPU: binary_cross_entropy_out_cpu
+    CUDA: binary_cross_entropy_out_cuda
+    MPS: binary_cross_entropy_out_mps
+
+- func: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+  python_module: nn
+  variants: function
+  dispatch:
+    CPU: binary_cross_entropy_backward_cpu
+    CUDA: binary_cross_entropy_backward_cuda
+    MPS: binary_cross_entropy_backward_mps
+
+- func: binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  variants: function
+  dispatch:
+    CPU: binary_cross_entropy_backward_out_cpu
+    CUDA: binary_cross_entropy_backward_out_cuda
+    MPS: binary_cross_entropy_backward_out_mps
+
+- func: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: binary_cross_entropy_with_logits
+  autogen: binary_cross_entropy_with_logits.out
+
+- func: bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: _bincount_cpu
+    CUDA: _bincount_cuda
+    MPS: _bincount_mps
+  tags: dynamic_output_shape
+  autogen: bincount.out
+
+- func: bitwise_not(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: bitwise_not.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: bitwise_not_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: bitwise_not.out
+  variants: method
+  tags: pointwise
+
+- func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: bitwise_not_out
+  tags: pointwise
+
+- func: copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: copysign_out
+  tags: pointwise
+
+- func: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: copysign.out
+  tags: pointwise
+
+- func: copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: copysign.out
+
+- func: copysign.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: copysign
+  tags: pointwise
+
+- func: copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: copysign_
+
+- func: copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: copysign_out
+  tags: pointwise
+
+- func: _lazy_clone(Tensor self) -> Tensor
+  # Like clone, but the copy takes place lazily: it only happens if either
+  # the input or the output is written to.
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: _lazy_clone
+
+- func: logical_not(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logical_not
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_logical_not
+  tags: [core, pointwise]
+
+- func: logical_not_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: logical_not_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_logical_not_
+  tags: pointwise
+
+- func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: logical_not_out
+    MPS: logical_not_out_mps
+  tags: pointwise
+
+- func: logical_xor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logical_xor
+  tags: [core, pointwise]
+
+- func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: logical_xor_
+  tags: pointwise
+
+- func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: logical_xor_out
+    MPS: logical_xor_out_mps
+  tags: pointwise
+
+- func: logical_and(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logical_and
+  tags: [core, pointwise]
+
+- func: logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: logical_and_
+  tags: pointwise
+
+- func: logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: logical_and_out
+    MPS: logical_and_out_mps
+  tags: pointwise
+
+- func: logical_or(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logical_or
+  tags: [core, pointwise]
+
+- func: logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: logical_or_
+  tags: pointwise
+
+- func: logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: logical_or_out
+    MPS: logical_or_out_mps
+  tags: pointwise
+
+- func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: blackman_window
+  autogen: blackman_window.out
+
+- func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: blackman_window
+  autogen: blackman_window.periodic_out
+
+- func: bmm(Tensor self, Tensor mat2) -> Tensor
+  structured_delegate: bmm.out
+  variants: function, method
+  dispatch:
+    SparseCPU: bmm_sparse_cpu
+    SparseCUDA: bmm_sparse_cuda
+    NestedTensorCPU: bmm_nested
+    NestedTensorCUDA: bmm_nested_cuda
+  tags: core
+
+- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU: bmm_out_cpu
+    CUDA: bmm_out_cuda
+    MPS: bmm_out_mps
+    XPU: bmm_out_xpu
+    SparseCPU: bmm_out_sparse_cpu
+    SparseCUDA: bmm_out_sparse_cuda
+    SparseCsrCUDA: bmm_out_sparse_csr_cuda
+
+- func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _bmm_dtype_cuda
+
+- func: bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _bmm_out_dtype_cuda
+
+- func: broadcast_tensors(Tensor[] tensors) -> Tensor[]
+  device_check: NoCheck
+  device_guard: False
+
+- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: broadcast_to_symint
+
+- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
+  variants: function
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_broadcast_to
+
+- func: cat(Tensor[] tensors, int dim=0) -> Tensor
+  structured_delegate: cat.out
+  dispatch:
+    SparseCPU, SparseCUDA: cat_sparse
+    QuantizedCPU: cat_quantized_cpu
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: cat_nested
+  tags: core
+
+- func: cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  precomputed:
+  - dim -> int dim, int valid, bool all_contiguous, bool all_same_dtype, bool all_same_sizes_and_stride, MemoryFormat memory_format
+  dispatch:
+    CPU: cat_out_cpu
+    CUDA: cat_out_cuda
+    MPS: cat_out_mps
+    QuantizedCPU: cat_out_quantized_cpu
+
+- func: cat.names(Tensor[] tensors, Dimname dim) -> Tensor
+
+- func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+
+# alias for torch.cat
+- func: concat(Tensor[] tensors, int dim=0) -> Tensor
+
+- func: concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: concat.names(Tensor[] tensors, Dimname dim) -> Tensor
+
+- func: concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+
+# alias for torch.cat
+- func: concatenate(Tensor[] tensors, int dim=0) -> Tensor
+
+- func: concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor
+
+- func: concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: block_diag(Tensor[] tensors) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: block_diag
+  autogen: block_diag.out
+
+- func: ceil(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: ceil.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: ceil_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr
+  tags: [core, pointwise]
+
+- func: ceil_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: ceil.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: ceil_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_
+  tags: pointwise
+
+- func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: ceil_out
+    SparseCPU, SparseCUDA: ceil_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ceil_sparse_csr_out
+  tags: pointwise
+
+# alias for torch.linalg.multi_dot
+- func: chain_matmul(Tensor[] matrices) -> Tensor
+  variants: function
+
+# alias for torch.linalg.multi_dot
+- func: chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)
+
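+# A minimal sketch of the alias relationship noted above, via the public API:
+#
+#   >>> import torch
+#   >>> a, b, c = torch.randn(2, 3), torch.randn(3, 4), torch.randn(4, 5)
+#   >>> out = torch.linalg.multi_dot([a, b, c])   # preferred spelling
+#   >>> torch.allclose(out, torch.chain_matmul(a, b, c))
+#   True
+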
+- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  tags: maybe_aliasing_or_mutating
+
+- func: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: chunk
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: chunk_nested_tensor
+
+- func: tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: tensor_split_sections_symint
+
+- func: tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: tensor_split_indices_symint
+
+- func: tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]
+  variants: function, method
+
+- func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ['min']
+  structured_delegate: clamp.out
+  dispatch:
+    QuantizedCPU: clamp_quantized_cpu
+  tags: [core, pointwise]
+
+- func: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+  variants: function, method
+  structured_delegate: clamp.Tensor_out
+  tags: [core, pointwise]
+
+- func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ['min']
+  structured_delegate: clamp.out
+  tags: pointwise
+
+- func: clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+  variants: function, method
+  structured_delegate: clamp.Tensor_out
+  tags: pointwise
+
+- func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  cpp_no_default_args: ['min']
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MTIA: clamp_out
+    MPS: clamp_out_mps
+  tags: pointwise
+
+- func: clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: clamp_Tensor_out
+    MPS: clamp_Tensor_out_mps
+  tags: pointwise
+
+- func: clamp_max(Tensor self, Scalar max) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: clamp_max.out
+  tags: pointwise
+
+- func: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
+  variants: function, method
+  structured_delegate: clamp_max.Tensor_out
+  tags: pointwise
+
+- func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: clamp_max.out
+  tags: pointwise
+
+- func: clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)
+  variants: function, method
+  structured_delegate: clamp_max.Tensor_out
+  tags: pointwise
+
+- func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MTIA: clamp_max_out
+    MPS: clamp_max_out_mps
+  tags: pointwise
+
+- func: clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: clamp_max_Tensor_out
+    MPS: clamp_max_Tensor_out_mps
+  tags: pointwise
+
+- func: clamp_min(Tensor self, Scalar min) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: clamp_min.out
+  tags: pointwise
+
+- func: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
+  variants: function, method
+  structured_delegate: clamp_min.Tensor_out
+  tags: pointwise
+
+- func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: clamp_min.out
+  tags: pointwise
+
+- func: clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)
+  variants: function, method
+  structured_delegate: clamp_min.Tensor_out
+  tags: pointwise
+
+- func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MTIA: clamp_min_out
+    MPS: clamp_min_out_mps
+  tags: pointwise
+
+- func: clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: clamp_min_Tensor_out
+    MPS: clamp_min_Tensor_out_mps
+  tags: pointwise
+
+# clip is an alias for clamp
+- func: clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+  cpp_no_default_args: ['min']
+  variants: function, method
+  tags: pointwise
+
+- func: clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+  variants: function, method
+  tags: pointwise
+
+- func: clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
+  cpp_no_default_args: ['min']
+  variants: function, method
+  tags: pointwise
+
+- func: clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)
+  variants: function, method
+  tags: pointwise
+
+- func: clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)
+  cpp_no_default_args: ['min']
+  tags: pointwise
+
+- func: clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)
+
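+# Elementwise, clamp computes out_i = min(max(self_i, min), max); the clip.* entries above
+# are pure aliases. A small sketch with the public API:
+#   import torch
+#   x = torch.tensor([-1.5, 0.3, 2.0])
+#   torch.clamp(x, min=0.0, max=1.0)   # tensor([0.0000, 0.3000, 1.0000])
+#   x.clip(0.0, 1.0)                   # same result via the alias
+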
+- func: cudnn_is_acceptable(Tensor self) -> bool
+  device_check: NoCheck
+  device_guard: False
+
+- func: complex(Tensor real, Tensor imag) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: complex
+
+- func: complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: complex_out
+
+- func: polar(Tensor abs, Tensor angle) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: polar
+
+- func: polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: polar_out
+
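+# complex builds real + 1j*imag from two real tensors; polar builds
+# abs * (cos(angle) + 1j*sin(angle)). A small sketch:
+#   import torch, math
+#   torch.complex(torch.tensor([1.0]), torch.tensor([2.0]))        # tensor([1.+2.j])
+#   torch.polar(torch.tensor([2.0]), torch.tensor([math.pi / 2]))  # ~tensor([0.+2.j])
+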
+- func: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: constant_pad_nd
+    MPS: constant_pad_nd_mps
+  autogen: constant_pad_nd.out
+  tags: core
+
+- func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
+  variants: method
+  manual_cpp_binding: True
+
+- func: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: convolution
+  autogen: convolution.out
+  tags: core
+
+- func: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CompositeExplicitAutograd, CUDA: convolution_backward
+  autogen: convolution_backward.out
+  tags: core
+
+- func: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: convolution_overrideable
+  autogen: convolution_overrideable.out
+
+- func: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+  dispatch:
+    CompositeExplicitAutograd: convolution_backward_overrideable
+  autogen: convolution_backward_overrideable.out
+
+- func: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _convolution
+  autogen: _convolution.out
+
+- func: _convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor
+
+- func: _convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _convolution_mode_symint
+
+- func: _convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+
+- func: conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv1d_symint
+
+- func: conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv2d_symint
+
+- func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv3d_symint
+
+- func: conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor
+  cpp_no_default_args: ['bias', 'stride', 'padding']
+  dispatch:
+    CompositeImplicitAutograd: conv1d_padding_symint
+
+- func: conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor
+  cpp_no_default_args: ['bias', 'stride', 'padding']
+  dispatch:
+    CompositeImplicitAutograd: conv2d_padding_symint
+
+- func: conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor
+  cpp_no_default_args: ['bias', 'stride', 'padding']
+  dispatch:
+    CompositeImplicitAutograd: conv3d_padding_symint
+
+- func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: conv_tbc
+  autogen: conv_tbc.out
+
+- func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor)
+
+# NB: we inherit the goofy argument order from PyTorch torch.nn.functional
+- func: conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv_transpose1d_symint
+
+- func: conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv_transpose2d_symint
+
+- func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: conv_transpose3d_symint
+
+- func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor
+  variants: function
+  dispatch:
+    Meta: copy_meta
+    CompositeExplicitAutogradNonFunctional: copy
+  tags: core
+
+- func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    MkldnnCPU: copy_mkldnn_
+    SparseCPU, SparseCUDA: copy_sparse_wrapper_
+    CompositeExplicitAutograd: copy_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: copy_sparse_compressed_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: copy_nested_
+  autogen: copy.out
+
+- func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
+  dispatch:
+    MPS: _copy_from_mps
+  autogen: _copy_from.out
+
+# We need this to be able to properly copy from a CPU to an XLA tensor with different sizes.
+# See https://github.com/pytorch/xla/issues/2881
+- func: _copy_from_and_resize(Tensor self, Tensor dst) -> Tensor
+  dispatch:
+    MPS: _copy_from_and_resize_mps
+  autogen: _copy_from_and_resize.out
+
+- func: cos(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: cos.out
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_cos
+  tags: [core, pointwise]
+
+- func: cos_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: cos.out
+  tags: pointwise
+
+- func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: cos_out
+  tags: pointwise
+
+- func: cosh(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: cosh.out
+  tags: [core, pointwise]
+
+- func: cosh_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: cosh.out
+  tags: pointwise
+
+- func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: cosh_out
+  tags: pointwise
+
+- func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
+
+- func: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: count_nonzero_cpu
+    CUDA: count_nonzero_cuda
+    MPS: count_nonzero_mps
+  autogen: count_nonzero.dim_IntList_out
+
+- func: count_nonzero(Tensor self, int? dim=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: count_nonzero
+  autogen: count_nonzero.out
+
+- func: cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor
+  variants: function, method
+
+- func: corrcoef(Tensor self) -> Tensor
+  variants: function, method
+
+- func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
+  dispatch:
+    CUDA: cudnn_affine_grid_generator_forward
+  autogen: cudnn_affine_grid_generator.out
+
+# TODO: Why do I have to call this grad?!
+- func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta
+  dispatch:
+    CUDA: cudnn_affine_grid_generator_backward
+  autogen: cudnn_affine_grid_generator_backward.out
+
+- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: cudnn_batch_norm
+  autogen: cudnn_batch_norm.out
+
+# NB: You can only use this if you called cudnn_batch_norm with training=True.
+- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: cudnn_batch_norm_backward
+  autogen: cudnn_batch_norm_backward.out
+
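+# A hedged sketch of the constraint above (requires a CUDA build with cuDNN; the aten ops
+# are invoked directly only because their schemas are the two entries declared here):
+#   import torch
+#   x = torch.randn(4, 3, 8, 8, device='cuda')
+#   w, b = torch.ones(3, device='cuda'), torch.zeros(3, device='cuda')
+#   rm, rv = torch.zeros(3, device='cuda'), torch.ones(3, device='cuda')
+#   out, save_mean, save_var, reserve = torch.ops.aten.cudnn_batch_norm(
+#       x, w, b, rm, rv, True, 0.1, 1e-5)          # training=True -> reserve is populated
+#   torch.ops.aten.cudnn_batch_norm_backward(
+#       x, torch.ones_like(out), w, rm, rv, save_mean, save_var, 1e-5, reserve)
+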
+- func: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+  dispatch:
+    CUDA: cudnn_convolution
+
+- func: cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CUDA: cudnn_convolution_out
+
+- func: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+  dispatch:
+    CUDA: cudnn_convolution_transpose
+  autogen: cudnn_convolution_transpose.out
+
+- func: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    MPS: _mps_convolution_transpose
+  autogen: _mps_convolution_transpose.out
+
+- func: mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)
+  dispatch:
+    MPS: mps_convolution_transpose_backward
+  autogen: mps_convolution_transpose_backward.out
+
+- func: cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CUDA: cudnn_convolution_relu
+  autogen: cudnn_convolution_relu.out
+
+- func: cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CUDA: cudnn_convolution_add_relu
+  autogen: cudnn_convolution_add_relu.out
+
+# NB: input is special cased in a way I don't quite understand
+- func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
+  dispatch:
+    CUDA: cudnn_grid_sampler_forward
+  autogen: cudnn_grid_sampler.out
+
+- func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid)
+  dispatch:
+    CUDA: cudnn_grid_sampler_backward
+  autogen: cudnn_grid_sampler_backward.out
+
+- func: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: cummax
+
+- func: cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: cummax_out
+
+- func: cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+
+- func: _cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
+  variants: function
+  dispatch:
+    CPU: cummax_helper_cpu
+    CUDA: cummax_helper_cuda
+    MPS: cummax_helper_mps
+
+- func: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: cummin
+
+- func: cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: cummin_out
+
+- func: cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+
+- func: _cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
+  variants: function
+  dispatch:
+    CPU: cummin_helper_cpu
+    CUDA: cummin_helper_cuda
+    MPS: cummin_helper_mps
+
+- func: cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+
+- func: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: cumprod.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+  structured_delegate: cumprod.out
+  variants: method
+
+- func: cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: cumprod_out
+    MPS: cumprod_out_mps
+
+- func: cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+  variants: method
+
+- func: cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+
+- func: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: cumsum.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: core
+
+- func: cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)
+  structured_delegate: cumsum.out
+  variants: method
+
+- func: cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: cumsum_out
+    MPS: cumsum_out_mps
+
+- func: cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)
+  variants: method
+
+- func: cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+
+- func: cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+
+- func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
+
+# convenience function that converts to intlists for you
+- func: ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor
+
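+# A short sketch of the public wrapper that routes to these overloads (shapes follow the
+# torch.nn.functional.ctc_loss docs; the Tensor overload accepts the lengths as tensors):
+#   import torch
+#   import torch.nn.functional as F
+#   T, N, C = 50, 4, 20                                   # time, batch, classes (0 = blank)
+#   log_probs = torch.randn(T, N, C).log_softmax(2)
+#   targets = torch.randint(1, C, (N, 30), dtype=torch.long)
+#   input_lengths = torch.full((N,), T, dtype=torch.long)
+#   target_lengths = torch.randint(10, 30, (N,), dtype=torch.long)
+#   loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0)
+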
+- func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+  dispatch:
+    CPU: ctc_loss_cpu
+    CUDA: ctc_loss_gpu
+    Meta: ctc_loss_meta
+  autogen: _ctc_loss.out
+  tags: dynamic_output_shape  # the shape of second output is data dependent
+
+- func: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: ctc_loss_tensor
+  autogen: _ctc_loss.Tensor_out
+  tags: dynamic_output_shape  # the shape of second output is data dependent
+
+- func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor
+  dispatch:
+    CPU: ctc_loss_backward_cpu
+    CUDA: ctc_loss_backward_gpu
+  autogen: _ctc_loss_backward.out
+
+- func: _ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor
+  dispatch:
+    CPU, CUDA: ctc_loss_backward_tensor
+
+- func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: diag_embed
+  autogen: diag_embed.out
+
+- func: diagflat(Tensor self, int offset=0) -> Tensor
+  variants: function, method
+
+- func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: diagonal
+  tags: core
+
+- func: linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)
+  python_module: linalg
+  variants: function
+
+- func: diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)
+  variants: function, method
+
+- func: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: diagonal_backward_symint
+  autogen: diagonal_backward.out
+
+- func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)
+  variants: method
+
+- func: diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor
+  variants: function, method
+
+- func: diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
+- func: gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]
+  variants: function
+
+- func: gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]
+  variants: function
+
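+# The gradient.* overloads correspond to the public torch.gradient, which estimates the
+# derivative along a dim with second-order central differences (one-sided at the edges).
+# A small sketch:
+#   import torch
+#   (g,) = torch.gradient(torch.tensor([1.0, 2.0, 4.0, 7.0]))
+#   # g == tensor([1.0, 1.5, 2.5, 3.0]) with the default unit spacing
+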
+- func: div.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: div.out
+  dispatch:
+    SparseCPU, SparseCUDA: div_sparse
+    ZeroTensor: div_zerotensor
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_div_Tensor
+  tags: [core, pointwise]
+
+- func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: div.out
+  dispatch:
+    SparseCPU, SparseCUDA: div_sparse_
+  tags: pointwise
+
+- func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: div_out
+    SparseCPU, SparseCUDA: div_out_sparse_zerodim
+  tags: pointwise
+
+- func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: div.out_mode
+  dispatch:
+    SparseCPU, SparseCUDA: div_sparse
+  tags: [core, pointwise]
+
+- func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: div.out_mode
+  dispatch:
+    SparseCPU, SparseCUDA: div_sparse_
+  tags: pointwise
+
+- func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: div_out_mode
+    SparseCPU, SparseCUDA: div_out_sparse_zerodim
+  tags: pointwise
+
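+# rounding_mode selects between true division (None), C-style truncation ('trunc'), and
+# Python-style floor division ('floor'). A small sketch:
+#   import torch
+#   a, b = torch.tensor(-7.0), torch.tensor(2.0)
+#   torch.div(a, b)                          # tensor(-3.5000)
+#   torch.div(a, b, rounding_mode='trunc')   # tensor(-3.)
+#   torch.div(a, b, rounding_mode='floor')   # tensor(-4.)
+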
+# For C++ only, until we have conversion from C++ numbers to Tensor
+- func: div.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: div
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_div_Scalar
+  tags: [core, pointwise]
+
+- func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: div_
+  autogen: div.Scalar_out
+  tags: pointwise
+
+- func: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: div
+  tags: [core, pointwise]
+
+- func: div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: div_
+  autogen: div.Scalar_mode_out
+  tags: pointwise
+
+# divide, alias for div
+- func: divide.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+
+- func: divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: divide.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: function, method
+
+- func: divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+  variants: function, method
+
+- func: divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
+  variants: method
+
+- func: divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
+
+- func: divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+  variants: function, method
+
+- func: divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)
+  variants: method
+
+  # true_divide, an alias for div
+- func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: pointwise
+
+- func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: dot(Tensor self, Tensor tensor) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: dot
+    CUDA: dot_cuda
+    MPS: dot_mps
+
+- func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: dot_out
+
+- func: vdot(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: vdot
+    CUDA: vdot_cuda
+
+- func: vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: vdot_out
+
+- func: einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor
+
+- func: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: embedding_symint
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_embedding
+  autogen: embedding.out
+  tags: core
+
+- func: embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: embedding_backward_symint
+
+- func: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor
+  dispatch:
+    CPU: embedding_dense_backward_cpu
+    CUDA: embedding_dense_backward_cuda
+    MPS: embedding_dense_backward_mps
+  autogen: embedding_dense_backward.out
+  tags: core
+
+- func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
+  dispatch:
+    CPU: embedding_renorm_cpu_
+    CUDA: embedding_renorm_cuda_
+  autogen: embedding_renorm, embedding_renorm.out
+
+- func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor
+
+# NOTE [ embedding_bag Native Functions ]
+# The `_embedding_bag.*` variants assume that input tensors except for `weight`,
+# e.g. `indices` and `offsets` (and `offset2bag`), are contiguous.
+# We really only need to enforce this for `_embedding_bag` (the forward) because
+# the backward inputs are the same as forward ones.
+# The `embedding_bag` wrapper declared below is created to achieve this, e.g.,
+# by applying indices = indices.contiguous().
+# The backward functions apply a check that these input tensors are contiguous.
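+# A small sketch of this via the public functional wrapper
+# (torch.nn.functional.embedding_bag; offsets mark where each bag starts):
+#   import torch
+#   import torch.nn.functional as F
+#   weight = torch.randn(10, 3)                       # embedding table
+#   indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # flattened bags
+#   offsets = torch.tensor([0, 4])                    # bag 0 = indices[0:4], bag 1 = indices[4:8]
+#   out = F.embedding_bag(indices.contiguous(), weight, offsets, mode='sum')   # shape (2, 3)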
+
+
+- func: _embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _embedding_bag_forward_only_cpu
+    CUDA: _embedding_bag_forward_only_cuda
+  autogen: _embedding_bag_forward_only.out
+
+- func: _rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)
+
+# row_stack is the alias of vstack
+- func: row_stack(Tensor[] tensors) -> Tensor
+
+- func: row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)
+
+# To keep backward and forward compatibility, and to avoid ambiguity with the
+# original signature above, scale_grad_by_freq, mode, sparse,
+# per_sample_weights, and include_last_offset parameters do not have default
+# values. Once the original signature is removed, default values can be added.
+- func: embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)
+
+- func: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _embedding_bag_cpu
+    CUDA: _embedding_bag_cuda
+  autogen: _embedding_bag.out
+  tags: core
+
+- func: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CPU, CUDA: _embedding_bag_backward_symint
+
+- func: _embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _embedding_bag_sparse_backward_symint
+
+- func: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  dispatch:
+    CPU: _embedding_bag_dense_backward_cpu
+    CUDA: _embedding_bag_dense_backward_cuda
+  autogen: _embedding_bag_dense_backward.out
+
+- func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor
+  dispatch:
+    CPU: _embedding_bag_per_sample_weights_backward_cpu
+    CUDA: _embedding_bag_per_sample_weights_backward_cuda
+  autogen: _embedding_bag_per_sample_weights_backward.out
+
+- func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: empty_names
+  autogen: empty.names_out
+
+- func: empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  dispatch:
+    CPU: empty_cpu
+    CUDA: empty_cuda
+    MPS: empty_mps
+    Meta: empty_meta_symint
+    MkldnnCPU: empty_mkldnn
+    SparseCPU, SparseCUDA: empty_sparse
+    SparseMeta: empty_sparse_symint
+    SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
+    SparseCsrMeta: empty_sparse_compressed_symint
+    QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
+  tags: core
+
+- func: empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: empty_permuted_symint
+  autogen: empty_permuted.out
+
+# We do not make new_empty a composite that calls into new_empty_strided, as the strided
+# version is significantly more difficult for different backends to implement.
+- func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: new_empty_symint
+  autogen: new_empty.out
+
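+# new_empty allocates a tensor that inherits dtype/device/layout from self unless overridden,
+# with uninitialized contents. A small sketch:
+#   import torch
+#   base = torch.zeros(2, 2, dtype=torch.float64)
+#   t = base.new_empty((3, 4))       # float64, same device as base, arbitrary contents
+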
+- func: new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: method
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: new_empty_strided_symint
+  autogen: new_empty_strided.out
+
+- func: new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: method
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: new_full
+  autogen: new_full.out
+
+- func: new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: method
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: new_zeros
+  autogen: new_zeros.out
+
+- func: new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: method
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: new_ones
+  autogen: new_ones.out
+
+# The other backend overrides exist only to provide a more helpful error message stating that dtype is required.
+- func: _empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor
+  dispatch:
+    CPU: empty_affine_quantized_other_backends_stub
+    QuantizedCPU, QuantizedCUDA: empty_affine_quantized
+  autogen: _empty_affine_quantized.out
+
+# This is a factory function that receives a tensor argument, so it is overridden explicitly;
+# the other backend overrides exist only to provide a more helpful error message stating that dtype is required.
+- func: _empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor
+  category_override: factory
+  dispatch:
+    CPU: empty_per_channel_affine_quantized_other_backends_stub
+    QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
+  autogen: _empty_per_channel_affine_quantized.out
+
+- func: resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: [core, inplace_view]
+  dispatch:
+    Meta: resize__symint
+    CPU: resize_
+    CUDA: resize_cuda_
+    MPS: resize_mps_
+    QuantizedCPU: quantized_resize_cpu_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: resize_sparse_csr_
+  autogen: resize, resize.out
+
+# This is a utility function that enables users to resize the out tensor while registering kernels for out= variants.
+# Eventually, we could consider exposing `resize_output` as a public API and shipping it with Python op registration,
+# to make it easy to register out= variants for ops.
+- func: _resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: function
+  dispatch:
+    Meta: _resize_output_
+  autogen: _resize_output, _resize_output.out
+
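+# The user-visible behavior this supports: an out= tensor with zero elements is resized to the
+# result shape by the out-variant kernel. A small sketch:
+#   import torch
+#   out = torch.empty(0)
+#   torch.add(torch.ones(2, 3), torch.ones(2, 3), out=out)
+#   out.shape                        # torch.Size([2, 3])
+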
+- func: empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  category_override: factory
+  variants: function
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: empty_quantized
+  autogen: empty_quantized.out
+
+- func: empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  device_guard: False
+
+- func: empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: empty_like
+    QuantizedCPU, QuantizedCUDA: empty_like_quantized
+    SparseCPU, SparseCUDA, SparseMeta: empty_like_sparse_coo
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_like_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: empty_like_nested
+  autogen: empty_like.out
+
+- func: empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CPU: empty_strided_cpu
+    CUDA: empty_strided_cuda
+    MPS: empty_strided_mps
+    Meta: empty_strided_meta_symint
+    QuantizedCPU, QuantizedCUDA: empty_strided_unknown_quantized
+  autogen: empty_strided.out
+  tags: core
+
+- func: erf(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erf.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: erf_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr
+  tags: [core, pointwise]
+
+- func: erf_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erf.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: erf_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_
+  tags: pointwise
+
+- func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: erf_out
+    SparseCPU, SparseCUDA: erf_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erf_sparse_csr_out
+  tags: pointwise
+
+- func: erfc(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erfc.out
+  variants: function, method
+  tags: pointwise
+
+- func: erfc_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erfc.out
+  variants: function, method
+  tags: pointwise
+
+- func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: erfc_out
+  tags: pointwise
+
+- func: exp(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: exp.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: exp_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: exp.out
+  variants: function, method
+  tags: pointwise
+
+- func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: exp_out
+  tags: pointwise
+
+- func: exp2(Tensor self) -> Tensor
+  structured_delegate: exp2.out
+  variants: function, method
+  tags: pointwise
+
+- func: exp2_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: exp2.out
+  variants: function, method
+  tags: pointwise
+
+- func: exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: exp2_out
+  tags: pointwise
+
+- func: expm1(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: expm1.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: expm1_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr
+  tags: [core, pointwise]
+
+- func: expm1_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: expm1.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: expm1_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_
+  tags: pointwise
+
+- func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: expm1_out
+    SparseCPU, SparseCUDA: expm1_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: expm1_sparse_csr_out
+  tags: pointwise
+
+- func: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+  variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: expand
+  tags: core
+
+- func: expand_as(Tensor(a) self, Tensor other) -> Tensor(a)
+  variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
+  device_check: NoCheck
+  device_guard: False
+
+# decomposes to eye.m
+- func: eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: eye
+
+- func: eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: eye
+
+- func: eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: eye_out_cpu
+    CUDA: eye_out_cuda
+    MPS: eye_out_mps
+
+- func: eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: eye_out_cpu
+    CUDA: eye_out_cuda
+    MPS: eye_out_mps
+
+- func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)
+  variants: function, method
+
+- func: flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)
+  variants: function, method
+
+- func: flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)
+  variants: function, method
+
+- func: flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)
+  variants: function, method
+
+- func: unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: unflatten_symint
+
+- func: unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: unflatten_dimname_symint
+
+- func: fill.Scalar(Tensor self, Scalar value) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fill
+  tags: core
+
+- func: fill.Tensor(Tensor self, Tensor value) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fill
+
+- func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: fill_
+    MPS: fill_scalar_mps
+    QuantizedCPU, QuantizedCUDA: fill_quantized_
+    Meta: fill_meta_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: fill_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: fill_nested_
+  autogen: fill.Scalar_out
+
+- func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: fill_
+    MPS: fill_tensor_mps_
+    QuantizedCPU, QuantizedCUDA: fill_quantized_
+    Meta: fill_meta_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: fill_nested_
+  autogen: fill.Tensor_out
+
+- func: floor(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: floor.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: floor_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr
+  tags: [core, pointwise]
+
+- func: floor_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: floor.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: floor_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_
+  tags: pointwise
+
+- func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: floor_out
+    SparseCPU, SparseCUDA: floor_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: floor_sparse_csr_out
+  tags: pointwise
+
+- func: floor_divide(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA, MPS: floor_divide
+    SparseCPU, SparseCUDA: floor_divide_sparse
+
+- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA, MPS: floor_divide_
+    SparseCPU, SparseCUDA: floor_divide_sparse_
+
+- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: floor_divide_out
+    SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim
+
+- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: floor_divide
+
+- func: floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: floor_divide_
+  autogen: floor_divide.Scalar_out
+
+- func: frac(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: frac.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: frac_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr
+  tags: pointwise
+
+- func: frac_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: frac.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: frac_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_
+  tags: pointwise
+
+- func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: frac_out
+    MPS: frac_out_mps
+    SparseCPU, SparseCUDA: frac_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: frac_sparse_csr_out
+  tags: pointwise
+
+- func: full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: full
+  autogen: full.names_out
+
+- func: full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: full
+  tags: core
+
+- func: full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: full_out
+
+- func: full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: full_like
+  autogen: full_like.out
+  tags: core
+
+- func: from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CPU: from_file
+  autogen: from_file.out
+
+- func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: gcd_out
+  tags: pointwise
+
+- func: gcd(Tensor self, Tensor other) -> Tensor
+  structured_delegate: gcd.out
+  variants: function, method
+  tags: pointwise
+
+- func: gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: gcd.out
+  variants: function, method
+
+- func: lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: lcm_out
+  tags: pointwise
+
+- func: lcm(Tensor self, Tensor other) -> Tensor
+  structured_delegate: lcm.out
+  variants: function, method
+  tags: pointwise
+
+- func: lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: lcm.out
+  variants: function, method
+
+# NOTE [ grid_sampler Native Functions ]
+# `grid_sampler` is _supposed to_ do all the shape checking and then dispatch to
+# one of `cudnn_grid_sampler`, `grid_sampler_2d`, or `grid_sampler_3d`, each of
+# which has the corresponding backward defined as native functions as well.
+# However, we do shape checking everywhere for now since each of the mentioned
+# functions can be called directly, which will lead to crashes otherwise.
+# See https://github.com/pytorch/pytorch/issues/73187 for more information.
+#
+# There is also _grid_sampler_2d_backward_cpu_fallback which is an
+# implementation detail of grid_sampler_2d and is only exposed here for testing
+# purposes.
+#
+# Additionally, arguments `padding_mode` and `interpolation_mode` are cast to
+# enums defined in `native/GridSampler.h`. `cudnn_grid_sampler` doesn't take in
+# `interpolation_mode` because it only supports Bilinear interpolation mode.
+# Nor does it take in `align_corners` because it only supports the mode
+# `align_corners = True`.
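+# A small sketch via the public wrapper (torch.nn.functional.grid_sample; grid holds
+# normalized (x, y) coordinates in [-1, 1] with shape (N, H_out, W_out, 2)):
+#   import torch
+#   import torch.nn.functional as F
+#   inp = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)   # (N, C, H, W)
+#   theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])        # identity affine map
+#   grid = F.affine_grid(theta, size=(1, 1, 4, 4), align_corners=True)
+#   out = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
+#   # out matches inp (up to rounding) because the sampling grid is the identity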
+- func: grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+
+- func: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  dispatch:
+    CPU, QuantizedCPU: grid_sampler_2d_cpu
+    CUDA: grid_sampler_2d_cuda
+    MPS: grid_sampler_2d_mps
+  autogen: grid_sampler_2d.out
+  tags: core
+
+# `grid_sampler_2d_backward` takes in `output_mask` to optimize performance for
+# the case where `input` doesn't require gradient. Gradient for `grid` is always
+# computed (only `output_mask[0]` is checked by the implementations).
+- func: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)
+  dispatch:
+    CPU: grid_sampler_2d_backward_cpu
+    CUDA: grid_sampler_2d_backward_cuda
+  autogen: grid_sampler_2d_backward.out
+
+# See NOTE [ grid_sample CPU fallback ]
+- func: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _grid_sampler_2d_cpu_fallback
+  autogen: _grid_sampler_2d_cpu_fallback.out
+
+- func: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor)
+
+- func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  dispatch:
+    CPU: grid_sampler_3d_cpu
+    CUDA: grid_sampler_3d_cuda
+  autogen: grid_sampler_3d.out
+
+# `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for
+# the case where `input` doesn't require gradient. Gradient for `grid` is always
+# computed (only `output_mask[0]` is checked by the implementations).
+- func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)
+  dispatch:
+    CPU: grid_sampler_3d_backward_cpu
+    CUDA: grid_sampler_3d_backward_cuda
+  autogen: grid_sampler_3d_backward.out
+
+- func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hann_window
+  autogen: hann_window.out
+
+- func: hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hann_window
+  autogen: hann_window.periodic_out
+
+- func: hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hamming_window
+  autogen: hamming_window.out
+
+- func: hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hamming_window
+  autogen: hamming_window.periodic_out
+
+- func: hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hamming_window
+  autogen: hamming_window.periodic_alpha_out
+
+- func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: hamming_window
+  autogen: hamming_window.periodic_alpha_beta_out
+
+- func: kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: kaiser_window
+  autogen: kaiser_window.out
+
+- func: kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: kaiser_window
+  autogen: kaiser_window.periodic_out
+
+- func: kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: kaiser_window
+  autogen: kaiser_window.beta_out
+
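+# Illustrative note (not part of the generated schema): these factory entries surface
+# as torch.hann_window / torch.hamming_window / torch.kaiser_window. A minimal sketch:
+#
+#   import torch
+#   w = torch.hann_window(400)                              # periodic=True is the default
+#   h = torch.hamming_window(400, periodic=False)           # symmetric Hamming window
+#   k = torch.kaiser_window(256, periodic=True, beta=12.0)  # beta controls side-lobe attenuation
+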
+- func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor
+
+- func: group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor
+
+- func: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: native_group_norm
+    CompositeExplicitAutograd: math_group_norm
+  autogen: native_group_norm.out
+  tags: core
+
+- func: native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: native_group_norm_backward
+  autogen: native_group_norm_backward.out
+  tags: core
+
+# Real to complex forward FFT
+- func: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _fft_r2c_mkl
+    CUDA: _fft_r2c_cufft
+    MPS: _fft_r2c_mps
+  tags: core
+
+- func: _fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: _fft_r2c_mkl_out
+    CUDA: _fft_r2c_cufft_out
+    MPS: _fft_r2c_mps_out
+
+# Complex to real inverse FFT
+- func: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _fft_c2r_mkl
+    CUDA: _fft_c2r_cufft
+    MPS: _fft_c2r_mps
+
+- func: _fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: _fft_c2r_mkl_out
+    CUDA: _fft_c2r_cufft_out
+    MPS: _fft_c2r_mps_out
+
+# Standard complex to complex FFT (forward or backward)
+- func: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _fft_c2c_mkl
+    CUDA: _fft_c2c_cufft
+    MPS: _fft_c2c_mps
+
+- func: _fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: _fft_c2c_mkl_out
+    CUDA: _fft_c2c_cufft_out
+    MPS: _fft_c2c_mps_out
+
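+# Illustrative note (not part of the generated schema): _fft_r2c / _fft_c2r / _fft_c2c are
+# the internal primitives behind the public torch.fft module. A minimal sketch:
+#
+#   import torch
+#   x = torch.randn(1024)
+#   X = torch.fft.rfft(x)                      # real -> complex forward FFT (_fft_r2c)
+#   y = torch.fft.irfft(X, n=1024)             # complex -> real inverse FFT (_fft_c2r)
+#   Z = torch.fft.fft(x.to(torch.complex64))   # complex -> complex FFT (_fft_c2c)
+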
+- func: _validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _validate_compressed_sparse_indices_cpu
+    CUDA: _validate_compressed_sparse_indices_cuda
+
+- func: _cufft_get_plan_cache_size(DeviceIndex device_index) -> int
+
+- func: _cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int
+
+- func: _cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()
+
+- func: _cufft_clear_plan_cache(DeviceIndex device_index) -> ()
+
+- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: index.Tensor_out
+  variants: function, method
+  dispatch:
+    QuantizedCPU: quantized_index
+  tags: [core, dynamic_output_shape]
+  # NB: This function is special-cased in tools/autograd/gen_variable_type.py
+  # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
+  # - Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices)
+  # - Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices)
+
+- func: index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
+  structured_inherits: TensorIteratorBase
+  precomputed:
+  - indices -> DimVector sizes, DimVector strides
+  dispatch:
+    CPU, CUDA, MPS: index_out
+
+# Used by inductor to signal indexing without bounds checks
+# Note that we don't support boolean indexing, to avoid dynamic output shapes
+- func: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _unsafe_index
+
+# Used by inductor to generate masked loads
+# Note that we don't support boolean indexing, to avoid dynamic output shapes
+- func: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _unsafe_masked_index
+
+- func: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _unsafe_masked_index_put_accumulate
+
+- func: index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    CPU, CUDA: index_copy_out
+    MPS: index_copy_out_mps
+
+- func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)
+  variants: method
+  structured_delegate: index_copy.out
+
+- func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
+  variants: function, method
+  structured_delegate: index_copy.out
+
+- func: index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)
+  variants: method
+
+- func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor
+  variants: function, method
+
+- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
+  device_check: NoCheck   # delegate to _index_put_impl_, which leverages TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: index_put_
+  autogen: index_put.out
+  # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
+  # - Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs)
+  # - Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Scalar v)
+  # - Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs)
+  # - Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Scalar v)
+
+- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  device_check: NoCheck   # delegate to _index_put_impl_ after clone, which leverages TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: index_put
+  tags: core
+
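+# Illustrative note (not part of the generated schema): with accumulate=True, index_put_
+# sums values at duplicate indices instead of overwriting them. A minimal sketch:
+#
+#   import torch
+#   t = torch.zeros(5)
+#   idx = torch.tensor([0, 1, 1, 3])
+#   t.index_put_((idx,), torch.ones(4), accumulate=True)   # t is now [1., 2., 0., 1., 0.]
+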
+- func: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  device_check: NoCheck   # delegate to _index_put_impl_ after clone, which leverages TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _unsafe_index_put
+
+- func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: _index_put_impl_
+    QuantizedCPU: _index_put_impl_quantized_cpu_
+    QuantizedCUDA: _index_put_impl_quantized_cuda_
+  autogen: _index_put_impl, _index_put_impl.out
+
+- func: instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor
+  variants: function
+
+- func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor
+  variants: function, method
+
+- func: isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: isin_Tensor_Tensor_out
+    MPS: isin_Tensor_Tensor_out_mps
+
+- func: isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Tensor_Tensor_out
+
+- func: isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA, MPS: isin_Tensor_Scalar_out
+
+- func: isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Tensor_Scalar_out
+
+- func: isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: isin_Scalar_Tensor_out
+    MPS: isin_Scalar_Tensor_out_mps
+
+- func: isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor
+  variants: function
+  structured_delegate: isin.Scalar_Tensor_out
+
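+# Illustrative note (not part of the generated schema): the isin overloads are exposed as
+# torch.isin. A minimal sketch:
+#
+#   import torch
+#   x = torch.tensor([1, 2, 3, 4])
+#   torch.isin(x, torch.tensor([2, 4]))               # tensor([False, True, False, True])
+#   torch.isin(x, torch.tensor([2, 4]), invert=True)  # membership test negated
+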
+- func: isnan(Tensor self) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, MPS, MTIA: isnan
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isnan
+    SparseCPU, SparseCUDA: isnan_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isnan_sparse_csr
+  autogen: isnan.out
+  tags: [core, pointwise]
+
+- func: is_distributed(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: is_floating_point(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: is_complex(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: is_conj(Tensor self) -> bool
+  variants: function, method
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: _is_zerotensor(Tensor self) -> bool
+  variants: function, method
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: is_neg(Tensor self) -> bool
+  variants: function, method
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: isreal(Tensor self) -> Tensor
+  variants: function, method
+
+- func: is_nonzero(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: is_same_size(Tensor self, Tensor other) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: nested_is_same_size
+    CompositeExplicitAutograd: is_same_size
+
+- func: is_signed(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: is_inference(Tensor self) -> bool
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
+
+- func: kron(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+
+- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: kthvalue
+
+- func: kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  dispatch:
+    CPU: kthvalue_out_cpu
+    CUDA: kthvalue_out_cuda
+
+- func: kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+
+- func: kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: layer_norm_symint
+
+- func: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: layer_norm_cpu
+    CUDA: layer_norm_cuda
+    MPS: layer_norm_mps
+    CompositeExplicitAutograd: math_native_layer_norm
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: nested_layer_norm
+  autogen: native_layer_norm.out
+  tags: core
+
+- func: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: layer_norm_backward_cpu
+    CUDA: layer_norm_backward_cuda
+    MPS: layer_norm_backward_mps
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: layer_norm_backward_nested
+  autogen: native_layer_norm_backward.out
+  tags: core
+
+- func: rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: rms_norm_symint
+
+- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
+  dispatch:
+    MPS: _fused_rms_norm_mps
+
+- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: nan_to_num
+    SparseCPU, SparseCUDA: nan_to_num_sparse
+  tags: pointwise
+
+- func: nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: nan_to_num_
+    SparseCPU, SparseCUDA: nan_to_num_sparse_
+  tags: pointwise
+
+- func: nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MTIA: nan_to_num_out
+    MPS: nan_to_num_out_mps
+    SparseCPU, SparseCUDA: nan_to_num_sparse_out
+  tags: pointwise
+
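+# Illustrative note (not part of the generated schema): nan_to_num replaces NaN and
+# +/-inf with finite values (defaults: 0.0 for NaN, the dtype's max/min for +/-inf).
+# A minimal sketch:
+#
+#   import torch
+#   x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.0])
+#   torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6)   # tensor([0., 1e6, -1e6, 1.])
+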
+- func: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: linear
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: nested_linear
+    MPS: _mps_linear
+
+- func: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: nested_linear_backward
+    MPS: mps_linear_backward
+  autogen: linear_backward.out
+
+- func: linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: linear_out
+
+- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_linear
+  autogen: mkldnn_linear.out
+
+- func: mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_linear_backward_input
+  autogen: mkldnn_linear_backward_input.out
+
+- func: mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)
+  dispatch:
+    MkldnnCPU: mkldnn_linear_backward_weights
+  autogen: mkldnn_linear_backward_weights.out
+
+- func: mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    MkldnnCPU: mkldnn_linear_backward
+  autogen: mkldnn_linear_backward.out
+
+- func: _cslt_compress(Tensor input) -> Tensor
+  dispatch:
+    CUDA: _cslt_compress
+
+- func: _cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor
+  dispatch:
+    CUDA: _cslt_sparse_mm
+  tags: needs_fixed_stride_order
+
+- func: _cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int
+  dispatch:
+    CUDA: _cslt_sparse_mm_search
+
+- func: _sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _sparse_semi_structured_tile
+
+- func: _sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: _sparse_semi_structured_apply
+
+- func: _sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor
+  dispatch:
+    CUDA: _sparse_semi_structured_apply_dense
+
+# DEPRECATED: Use torch._sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead
+- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
+  dispatch:
+    CUDA: _sparse_semi_structured_linear
+
+- func: _sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor
+  dispatch:
+    CUDA: _sparse_semi_structured_mm
+
+- func: _sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor
+  dispatch:
+    CUDA: _sparse_semi_structured_addmm
+
+- func: _mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor
+  dispatch:
+    CUDA: _mixed_dtypes_linear
+
+- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
+
+- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
+
+- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)
+
+- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
+
+- func: _wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor
+
+- func: _wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor
+
+- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+
+- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
+
+- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor
+
+- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
+
+- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+
+- func: ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: function, method
+  tags: pointwise
+
+- func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  tags: pointwise
+
+- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: linspace
+
+- func: linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace
+
+- func: linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace
+
+- func: linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace
+
+- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: linspace_out
+    CUDA: linspace_cuda_out
+    MPS: linspace_out_mps
+
+- func: linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace_out
+
+- func: linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace_out
+
+- func: linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: linspace_out
+
+- func: log(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: log_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log.out
+  variants: function, method
+  tags: pointwise
+
+- func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: log_out
+  tags: pointwise
+
+- func: log10(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log10.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: log10_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log10.out
+  variants: function, method
+  tags: pointwise
+
+- func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: log10_out
+  tags: pointwise
+
+- func: log1p(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log1p.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: log1p_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr
+  tags: [core, pointwise]
+
+- func: log1p_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log1p.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: log1p_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_
+  tags: pointwise
+
+- func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: log1p_out
+    SparseCPU, SparseCUDA: log1p_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: log1p_sparse_csr_out
+  tags: pointwise
+
+- func: log2(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log2.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: log2_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: log2.out
+  variants: function, method
+  tags: pointwise
+
+- func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: log2_out
+  tags: pointwise
+
+- func: logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: logaddexp_out
+    MPS: logaddexp_out_mps
+  tags: pointwise
+
+- func: logaddexp(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+  structured_delegate: logaddexp.out
+  tags: pointwise
+
+- func: logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: logaddexp2_out
+    MPS: logaddexp2_out_mps
+  tags: pointwise
+
+- func: logaddexp2(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+  structured_delegate: logaddexp2.out
+  tags: pointwise
+
+- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: xlogy.OutTensor
+  variants: function, method
+  tags: pointwise
+
+- func: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: xlogy
+  tags: pointwise
+
+- func: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: xlogy
+  tags: pointwise
+
+# xlogy: inplace variant
+- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: xlogy.OutTensor
+  tags: pointwise
+
+- func: xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: xlogy_
+
+# xlogy: out variant
+- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  variants: function
+  dispatch:
+    CPU, CUDA: xlogy_out
+    MPS: xlogy_out_mps
+  tags: pointwise
+
+- func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: xlogy_out
+  tags: pointwise
+
+- func: xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: xlogy_out
+  tags: pointwise
+
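+# Illustrative note (not part of the generated schema): xlogy computes self * log(other)
+# with the convention that the result is 0 wherever self == 0, even if other == 0.
+# A minimal sketch:
+#
+#   import torch
+#   x = torch.tensor([0.0, 2.0])
+#   y = torch.tensor([0.0, 3.0])
+#   torch.xlogy(x, y)   # tensor([0.0000, 2.1972]); 0 * log(0) is taken to be 0
+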
+- func: logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: logspace
+
+- func: logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace
+
+- func: logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace
+
+- func: logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace
+
+- func: logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: logspace_out
+    CUDA: logspace_cuda_out
+
+- func: logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace_out
+
+- func: logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace_out
+
+- func: logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)
+  category_override: factory
+  dispatch:
+    CompositeExplicitAutograd: logspace_out
+
+# log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models.
+- func: log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  variants: function, method
+
+- func: log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: log_softmax_out
+
+- func: log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  variants: function, method
+
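+# Illustrative note (not part of the generated schema): because dtype precedes the '*' in
+# the schema above, the generated binding should accept it positionally as well as by
+# keyword (a sketch, assuming the default binding generation):
+#
+#   import torch
+#   x = torch.randn(2, 3)
+#   a = torch.log_softmax(x, 1)                  # dim passed positionally
+#   b = torch.log_softmax(x, 1, torch.float64)   # dtype passed positionally too
+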
+- func: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  structured_delegate: _log_softmax.out
+  tags: core
+
+- func: _log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: log_softmax_cpu_out
+    CUDA: log_softmax_cuda_out
+    MTIA: log_softmax_mtia_out
+    MPS: log_softmax_mps_out
+
+- func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+  structured_delegate: _log_softmax_backward_data.out
+
+- func: _log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: log_softmax_backward_cpu_out
+    CUDA: log_softmax_backward_cuda_out
+    MTIA: log_softmax_backward_mtia_out
+    MPS: log_softmax_backward_mps_out
+
+- func: _logcumsumexp(Tensor self, int dim) -> Tensor
+  dispatch:
+    CPU: _logcumsumexp_cpu
+    CUDA: _logcumsumexp_cuda
+
+- func: _logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: _logcumsumexp_out_cpu
+    CUDA: _logcumsumexp_out_cuda
+
+- func: logcumsumexp(Tensor self, int dim) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logcumsumexp
+
+- func: logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: logcumsumexp_out
+
+- func: logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor
+  variants: function, method
+
+- func: logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: logsumexp
+
+- func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    # calls squeeze
+    CompositeExplicitAutogradNonFunctional: logsumexp_out
+
+- func: logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
+
+- func: matmul(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: matmul
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: matmul_nested
+
+- func: matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: matmul_backward_nested
+  autogen: matmul_backward.out
+
+- func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeImplicitAutograd: matmul_out
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: matmul_out_nested
+
+# Alias to linalg.matrix_power
+- func: matrix_power(Tensor self, int n) -> Tensor
+  variants: function, method
+
+# Alias to linalg.matrix_power
+- func: matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+
+# Alias to linalg.matrix_exp
+- func: matrix_exp(Tensor self) -> Tensor
+  variants: function, method
+
+# This function should be deprecated in favor of differential_analytic_matrix_function in FunctionsManual.cpp
+- func: matrix_exp_backward(Tensor self, Tensor grad) -> Tensor
+
+# DEPRECATED: Use torch.aminmax instead
+- func: _aminmax(Tensor self) -> (Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: _aminmax_all
+  autogen: _aminmax.out
+
+# DEPRECATED: Use torch.aminmax instead
+- func: _aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: _aminmax
+  autogen: _aminmax.dim_out
+
+- func: aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: aminmax.out
+  variants: function, method
+
+- func: aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: aminmax_out
+    MPS: aminmax_out_mps
+
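+# Illustrative note (not part of the generated schema): aminmax computes the minimum and
+# maximum together and is the supported replacement for the deprecated _aminmax above.
+# A minimal sketch:
+#
+#   import torch
+#   x = torch.arange(6.).reshape(2, 3)
+#   mn, mx = torch.aminmax(x)            # tensor(0.), tensor(5.)
+#   mn_d, mx_d = torch.aminmax(x, dim=0) # per-column min and max (dim is keyword-only)
+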
+- func: _compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor
+  dispatch:
+    CPU, CUDA: _compute_linear_combination
+
+- func: _compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: _compute_linear_combination_out
+
+- func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: max.dim_max
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: qmax
+  tags: core
+
+- func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    CPU, CUDA, MTIA: max_out
+    MPS: max_out_mps
+
+- func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+
+- func: value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: value_selecting_reduction_backward_symint
+    NestedTensorCPU, NestedTensorCUDA: value_selecting_reduction_backward_nested_symint
+
+- func: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+  variants: function, method
+  structured_delegate: amax.out
+  tags: core
+
+- func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU, CUDA: amax_out
+    MPS: amax_out_mps
+
+# Return: (Tensor output, Tensor indices)
+- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+
+- func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+
+- func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: max_pool2d
+    MPS: mps_max_pool2d
+
+- func: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    MPS: mps_max_pool2d_backward
+  autogen: max_pool2d_backward.out
+
+- func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_max_pool2d
+  autogen: mkldnn_max_pool2d.out
+
+- func: mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_max_pool2d_backward
+  autogen: mkldnn_max_pool2d_backward.out
+
+- func: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_max_pool3d
+  autogen: mkldnn_max_pool3d.out
+
+- func: mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_max_pool3d_backward
+  autogen: mkldnn_max_pool3d_backward.out
+
+- func: quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    QuantizedCPU: quantized_max_pool1d
+  autogen: quantized_max_pool1d.out
+
+- func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    QuantizedCPU: quantized_max_pool2d
+    QuantizedCUDA: quantized_max_pool2d_cudnn
+  autogen: quantized_max_pool2d.out
+
+- func: quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+  dispatch:
+    QuantizedCPU: quantized_max_pool3d
+  autogen: quantized_max_pool3d.out
+
+- func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+
+# The CPU and GPU dispatch variants are named weirdly here because otherwise there
+# are namespacing issues in C++
+- func: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: mean
+  tags: core
+
+# Under the normal naming convention this would be `mean.out`. However, since `mean.out` already exists, this variant has to use a different name.
+- func: mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: mean_dtype_out
+
+- func: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: mean.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    QuantizedCPU: mean_quantized_cpu
+  tags: core
+
+- func: mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: mean_out
+    MPS: mean_out_mps
+    QuantizedCPU: mean_out_quantized_cpu
+
+- func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # Composite
+  variants: function, method
+
+- func: nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # Composite
+
+- func: median(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: median_cpu
+    CUDA: median_cuda
+    MPS: median_mps
+  autogen: median.out
+
+- func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: median
+
+- func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  dispatch:
+    CPU: median_out_cpu
+    CUDA: median_out_cuda
+    MPS: median_out_mps
+
+- func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+
+- func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: nanmedian(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: nanmedian_cpu
+    CUDA: nanmedian_cuda
+    MPS: nanmedian_mps
+  autogen: nanmedian.out
+
+- func: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: nanmedian
+
+- func: nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  dispatch:
+    CPU: nanmedian_out_cpu
+    CUDA: nanmedian_out_cuda
+    MPS: nanmedian_out_mps
+
+- func: nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+
+- func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: min.dim_min
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: qmin
+  tags: core
+
+- func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    CPU, CUDA, MTIA: min_out
+    MPS: min_out_mps
+
+- func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+
+- func: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+  variants: function, method
+  structured_delegate: amin.out
+  tags: core
+
+- func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU, CUDA: amin_out
+    MPS: amin_out_mps
+
+# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
+# native_functions.yaml
+# https://github.com/pytorch/pytorch/issues/77394
+- func: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    MPS: _mps_convolution
+  autogen: _mps_convolution.out
+
+- func: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    MPS: mps_convolution_backward
+  autogen: mps_convolution_backward.out
+
+- func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: mkldnn_convolution
+  autogen: mkldnn_convolution.out
+
+- func: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: mkldnn_rnn_layer
+    MkldnnCPU: mkldnn_rnn_layer
+  autogen: mkldnn_rnn_layer.out
+
+- func: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: mkldnn_rnn_layer_backward
+  autogen: mkldnn_rnn_layer_backward.out
+
+- func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: miopen_batch_norm
+  autogen: miopen_batch_norm.out
+
+- func: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: miopen_batch_norm_backward
+  autogen: miopen_batch_norm_backward.out
+
+- func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  dispatch:
+    CUDA: miopen_convolution
+  autogen: miopen_convolution.out
+
+- func: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  dispatch:
+    CUDA: miopen_convolution_transpose
+  autogen: miopen_convolution_transpose.out
+
+- func: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  dispatch:
+    CUDA: miopen_depthwise_convolution
+  autogen: miopen_depthwise_convolution.out
+
+- func: miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CUDA: miopen_convolution_relu
+
+- func: miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor
+  dispatch:
+    CUDA: miopen_convolution_add_relu
+
+- func: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: miopen_rnn
+  autogen: miopen_rnn.out
+  tags: nondeterministic_seeded
+
+
+- func: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+  dispatch:
+    CUDA: miopen_rnn_backward
+  autogen: miopen_rnn_backward.out
+
+- func: mm(Tensor self, Tensor mat2) -> Tensor
+  structured_delegate: mm.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: _sparse_mm
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
+  tags: core
+
+- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: mm_out_cpu
+    CUDA: mm_out_cuda
+    MTIA: mm_out_mtia
+    MPS: mm_out_mps
+    XPU: mm_out_xpu
+    SparseCPU, SparseCUDA: _sparse_mm_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm_out
+
+- func: mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
+  dispatch:
+    CUDA: _mm_dtype_cuda
+
+- func: mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CUDA: _mm_dtype_out_cuda
+
+- func: _int_mm(Tensor self, Tensor mat2) -> Tensor
+  dispatch:
+    CPU: _int_mm_cpu
+    CUDA: _int_mm_cuda
+
+- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: _int_mm_out_cpu
+    CUDA: _int_mm_out_cuda
+
+- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
+  dispatch:
+    CUDA: _convert_weight_to_int4pack_cuda
+    MPS: _convert_weight_to_int4pack_mps
+
+- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+  dispatch:
+    MPS: _weight_int4pack_mm_mps
+    CUDA: _weight_int4pack_mm_cuda
+
+- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor
+  dispatch:
+    XPU: _weight_int4pack_mm_xpu
+
+# The int4 weight packing is split between CPU and other devices due to
+# https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.
+- func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor
+  dispatch:
+    CPU: _convert_weight_to_int4pack_cpu
+
+- func: _weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+  dispatch:
+    CPU: _weight_int4pack_mm_cpu
+
+- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor
+  dispatch:
+    CPU: _dyn_quant_pack_4bit_weight_cpu
+
+- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor
+  dispatch:
+    CPU: _dyn_quant_matmul_4bit_cpu
+
+- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
+  dispatch:
+    CPU: _weight_int8pack_mm_cpu
+    MPS: _weight_int8pack_mm_mps
+
+- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
+  python_module: sparse
+
+- func: _sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor
+  python_module: sparse
+
+- func: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
+  dispatch:
+    SparseCPU: sparse_sparse_matmul_cpu
+    SparseCUDA: sparse_sparse_matmul_cuda
+  autogen: _sparse_sparse_matmul.out
+
+- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+  dispatch:
+    CPU, CUDA: mode
+
+- func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  dispatch:
+    CompositeExplicitAutograd: mode_out
+
+- func: mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  variants: function, method
+
+- func: mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: mul.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: mul.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: mul_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr
+    MkldnnCPU: mkldnn_mul
+    ZeroTensor: mul_zerotensor
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul_Tensor
+  tags: [core, pointwise]
+
+- func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: mul.out
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: mul_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr_
+    MkldnnCPU: mkldnn_mul_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul__Tensor
+  tags: pointwise
+
+- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: mul_out
+    SparseCPU: mul_out_sparse_cpu
+    SparseCUDA: mul_out_sparse_cuda
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr
+    MkldnnCPU: mkldnn_mul_out
+  tags: pointwise
+  # For C++ only, until we have conversion from C++ numbers to Tensor
+
+- func: mul.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: mul
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_scalar_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul_Scalar
+  tags: [core, pointwise]
+
+- func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: mul_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul__scalar_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul__Scalar
+  autogen: mul.Scalar_out
+  tags: pointwise
+# multiply, alias for mul
+
+- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+
+- func: multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: multiply.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: function, method
+
+- func: multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: mv(Tensor self, Tensor vec) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: mv
+    SparseCPU, SparseCUDA: mv_sparse
+
+- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: mv_out
+
+- func: mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: mvlgamma_out
+  tags: pointwise
+
+- func: mvlgamma(Tensor self, int p) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: mvlgamma
+  tags: pointwise
+
+- func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: mvlgamma_
+  tags: pointwise
+
+- func: narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU: narrow_copy_dense_cpu
+    SparseCPU, SparseCUDA: narrow_copy_sparse
+    CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
+  tags: view_copy
+
+- func: narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: narrow_copy_dense_cpu_out
+
+- func: narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: narrow_symint
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: narrow_nested_symint
+
+- func: narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: narrow_tensor_symint
+
+- func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_cpu
+    CUDA: batch_norm_cuda
+    MPS: batch_norm_mps
+    MkldnnCPU: mkldnn_batch_norm
+
+- func: native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  dispatch:
+    CUDA: batch_norm_cuda_out
+    MPS: batch_norm_mps_out
+    CPU: batch_norm_cpu_out
+
+# TODO: In 2 weeks, we should make native_batch_norm composite implicit so that the correct schema percolates correctly through our dispatching
+- func: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _batch_norm_legit_cpu
+    CUDA: _batch_norm_legit_cuda
+    MPS: _batch_norm_legit_mps
+    MkldnnCPU: _mkldnn_batch_norm_legit
+  autogen: _native_batch_norm_legit_functional
+  tags: core
+
+# HACK: identical to _native_batch_norm_legit, but training is known to be False,
+# so we know that running stats will not be mutated.
+# The real fix here is batch norm consolidation.
+- func: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CompositeExplicitAutograd: _batch_norm_legit_no_training
+  autogen: _native_batch_norm_legit_no_training.out
+  tags: core
+
+- func: _native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))
+  dispatch:
+    CPU: _batch_norm_legit_cpu_out
+    CUDA: _batch_norm_legit_cuda_out
+    MPS: _batch_norm_legit_mps_out
+
+- func: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _batch_norm_legit_no_stats_cpu
+    CUDA: _batch_norm_legit_no_stats_cuda
+    MPS: _batch_norm_legit_no_stats_mps
+    MkldnnCPU: _mkldnn_batch_norm_legit_no_stats
+  tags: core
+
+- func: _native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  dispatch:
+    CPU: _batch_norm_legit_no_stats_cpu_out
+    CUDA: _batch_norm_legit_no_stats_cuda_out
+    MPS: _batch_norm_legit_no_stats_mps_out
+
+- func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_stats_cuda
+  autogen: batch_norm_stats.out
+
+- func: batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor
+  dispatch:
+    CUDA: batch_norm_elemt_cuda
+
+- func: batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CUDA: batch_norm_elemt_cuda_out
+
+# for backward compatibility
+- func: batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_gather_stats_cuda
+  autogen: batch_norm_gather_stats.out
+
+- func: batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_gather_stats_with_counts_cuda
+  autogen: batch_norm_gather_stats_with_counts.out
+
+- func: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_backward_cpu
+    CUDA: batch_norm_backward_cuda
+    MPS: batch_norm_backward_mps
+    MkldnnCPU: mkldnn_batch_norm_backward
+  autogen: native_batch_norm_backward.out
+
+- func: batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: batch_norm_backward_reduce_cuda
+  autogen: batch_norm_backward_reduce.out
+
+- func: batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor
+  dispatch:
+    CUDA: batch_norm_backward_elemt_cuda
+  autogen: batch_norm_backward_elemt.out
+
+- func: batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)
+  dispatch:
+    CPU: batch_norm_update_stats_cpu
+    CUDA: batch_norm_update_stats_cuda
+  autogen: batch_norm_update_stats.out
+
+- func: is_vulkan_available() -> bool
+
+- func: _nnpack_available() -> bool
+
+- func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _nnpack_spatial_convolution
+  autogen: _nnpack_spatial_convolution.out
+
+- func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: ones
+  autogen: ones.names_out
+
+- func: ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: ones
+
+- func: ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: ones_out
+
+- func: ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: ones_like
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: ones_like
+  autogen: ones_like.out
+
+- func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor
+
+- func: cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor
+
+- func: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _euclidean_dist
+  autogen: _euclidean_dist.out
+
+- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
+  dispatch:
+    CPU, CUDA: _cdist_forward
+    MPS: _cdist_forward_mps
+  autogen: _cdist_forward.out
+  tags: core
+
+- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
+  dispatch:
+    CPU, CUDA: _cdist_backward
+  autogen: _cdist_backward.out
+
+- func: pdist(Tensor self, float p=2) -> Tensor
+
+- func: _pdist_forward(Tensor self, float p=2) -> Tensor
+  dispatch:
+    CPU, CUDA: _pdist_forward
+  autogen: _pdist_forward.out
+  tags: core
+
+- func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor
+  dispatch:
+    CPU, CUDA: _pdist_backward
+  autogen: _pdist_backward.out
+
+- func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor
+  variants: function
+
+- func: permute(Tensor(a) self, int[] dims) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: permute
+    MPS: permute_mps
+    SparseCPU, SparseCUDA: permute_sparse_coo
+  tags: core
+
+- func: movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+  variants: function, method
+
+- func: movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+  variants: function, method
+
+# moveaxis, alias for movedim
+- func: moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)
+  variants: function, method
+
+- func: moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)
+  variants: function, method
+
+# Only exposed from C++ -- in Python,
+# we expose it as an attribute `T`, not a function.
+#
+# I'd like to name this "T" in C++ too, but
+# calling a native function "T" causes undefined
+# behavior on Windows, for reasons I don't understand
+# (maybe related to capital letter collation somehow...)
+- func: numpy_T(Tensor(a) self) -> Tensor(a)
+  variants: method
+
+# Exposed on Python as an attribute 'H'
+- func: matrix_H(Tensor(a) self) -> Tensor(a)
+  variants: method
+
+# Exposed on Python as an attribute 'mT'
+- func: mT(Tensor(a) self) -> Tensor(a)
+  variants: method
+
+# Exposed on Python as an attribute 'mH'
+- func: mH(Tensor(a) self) -> Tensor(a)
+  variants: method
+
+- func: adjoint(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+
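+# NB: an illustrative sketch (not part of the schema, assuming standard torch
+# Python bindings) of how the entries above surface in Python: `T`, `mT`, and `mH`
+# are attributes, while `adjoint()` is callable as a function or method.
+#
+#   import torch
+#   x = torch.randn(2, 3, dtype=torch.complex64)
+#   torch.equal(x.mT, x.transpose(-2, -1))          # True: mT swaps the last two dims
+#   torch.equal(x.mH, x.transpose(-2, -1).conj())   # True: mH is the conjugate transpose
+#   torch.equal(x.adjoint(), x.mH)                  # True: adjoint() matches mH
+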
+- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
+  dispatch:
+    CPU: pixel_shuffle_cpu
+    MPS: pixel_shuffle_mps
+    CompositeExplicitAutogradNonFunctional: math_pixel_shuffle
+  autogen: pixel_shuffle.out
+
+- func: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
+  dispatch:
+    CPU: pixel_unshuffle_cpu
+    MPS: pixel_unshuffle_mps
+    CompositeExplicitAutogradNonFunctional: math_pixel_unshuffle
+  autogen: pixel_unshuffle.out
+
+- func: channel_shuffle(Tensor self, SymInt groups) -> Tensor
+  dispatch:
+    CPU, CUDA: channel_shuffle
+    QuantizedCPU: channel_shuffle_quantized_cpu
+  autogen: channel_shuffle.out
+
+- func: native_channel_shuffle(Tensor self, SymInt groups) -> Tensor
+  dispatch:
+    CPU: channel_shuffle_cpu
+    CompositeImplicitAutograd: math_channel_shuffle
+
+- func: is_pinned(Tensor self, Device? device=None) -> bool
+  variants: method
+  dispatch:
+    # the NestedTensor keys are necessary because NestedTensor has been removed
+    # from the CompositeExplicitAutograd keyset; see Note [NestedTensor Not Included in Backend Keys]
+    CompositeExplicitAutograd, NestedTensorCPU: is_pinned
+    SparseCsrCPU: is_pinned_sparse_compressed
+    SparseCPU: is_pinned_sparse_coo
+
+# TODO: add a copy kwarg that guarantees that the tensor is put into fresh
+# pinned memory
+- func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
+  variants: method
+
+# Unlike pin_memory, this is guaranteed to give a new non-aliasing tensor
+- func: _pin_memory(Tensor self, Device? device=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _pin_memory
+    NestedTensorCPU: _pin_memory_nested
+    SparseCPU: _pin_memory_sparse_coo
+    SparseCsrCPU: _pin_memory_sparse_compressed
+  autogen: _pin_memory.out
+
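+# NB: a minimal usage sketch (not part of the schema, assuming an accelerator such
+# as CUDA is available) of the distinction noted above: `pin_memory()` may return
+# `self` when the tensor is already pinned, while `_pin_memory` always produces a
+# fresh, non-aliasing pinned tensor.
+#
+#   import torch
+#   t = torch.randn(4)                            # pageable CPU memory
+#   p = t.pin_memory()                            # copies into pinned memory
+#   p.is_pinned()                                 # True
+#   p.pin_memory().data_ptr() == p.data_ptr()     # True: already pinned, no new copy
+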
+- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
+  variants: function, method
+
+- func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor
+  variants: function
+
+- func: rad2deg(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: rad2deg
+    SparseCPU, SparseCUDA: rad2deg_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr
+  tags: pointwise
+
+- func: rad2deg_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: rad2deg_
+    SparseCPU, SparseCUDA: rad2deg_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_
+  tags: pointwise
+
+- func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: rad2deg_out
+    SparseCPU, SparseCUDA: rad2deg_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: rad2deg_sparse_csr_out
+  tags: pointwise
+
+- func: deg2rad(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: deg2rad
+    SparseCPU, SparseCUDA: deg2rad_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr
+  tags: pointwise
+
+- func: deg2rad_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: deg2rad_
+    SparseCPU, SparseCUDA: deg2rad_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_
+  tags: pointwise
+
+- func: deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: deg2rad_out
+    SparseCPU, SparseCUDA: deg2rad_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: deg2rad_sparse_csr_out
+  tags: pointwise
+
+- func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: scalar_tensor
+  autogen: scalar_tensor.out
+  tags: core
+
+- func: rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: rand
+  autogen: rand.names_out
+  tags: nondeterministic_seeded
+
+- func: rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: rand
+  autogen: rand.generator_with_names_out
+
+- func: rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: [core, nondeterministic_seeded]
+  dispatch:
+    CompositeExplicitAutograd: rand
+
+- func: rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: rand
+
+- func: rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: rand_out
+
+- func: rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: rand_like
+  autogen: rand_like.out
+
+- func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint
+
+- func: randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint
+
+- func: randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint
+
+- func: randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint
+
+- func: randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint_out
+
+- func: randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint_out
+
+- func: randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint_out
+
+- func: randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randint_out
+
+- func: randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.out
+
+- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.Tensor_out
+
+- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.low_dtype_out
+
+- func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: [core, nondeterministic_seeded]
+  dispatch:
+    CompositeExplicitAutograd: randn
+
+- func: randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randn
+
+- func: randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: randn
+  autogen: randn.names_out
+
+- func: randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: randn
+  autogen: randn.generator_with_names_out
+
+- func: randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+
+- func: randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like
+  autogen: randn_like.out
+
+- func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: [core, nondeterministic_seeded]
+  dispatch:
+    CompositeExplicitAutograd: randperm
+
+- func: randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randperm
+
+- func: randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: randperm_out
+
+- func: randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU: randperm_out_cpu
+    CUDA: randperm_out_cuda
+    MPS: randperm_out_mps
+
+- func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: range
+
+- func: range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: range
+
+- func: range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: range_out_no_step
+
+- func: range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, Meta: range_out
+    CUDA: range_cuda_out
+    MPS: range_mps_out
+  cpp_no_default_args: ['step']
+
+- func: ravel(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+
+- func: reciprocal(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: reciprocal.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: reciprocal_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: reciprocal.out
+  variants: function, method
+  tags: pointwise
+
+- func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MTIA: reciprocal_out
+    MPS: reciprocal_out_mps
+  tags: pointwise
+
+- func: neg(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: neg.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: neg_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_neg
+  tags: [core, pointwise]
+
+- func: neg_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: neg.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: neg_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_neg_
+  tags: pointwise
+
+- func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: neg_out
+    SparseCPU, SparseCUDA: neg_out_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: neg_sparse_csr_out
+  tags: pointwise
+# Alias for neg
+
+- func: negative(Tensor self) -> Tensor
+  variants: function, method
+
+- func: negative_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: repeat(Tensor self, SymInt[] repeats) -> Tensor
+  variants: method  # This is method-only to match the previous tensor API. In the future we could make this a function too.
+  dispatch:
+    CompositeExplicitAutograd: repeat
+    MPS: repeat_mps
+  autogen: repeat.out
+  tags: core
+
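+# NB: an illustrative sketch (not part of the schema, assuming standard torch
+# Python bindings) of the method-only exposure noted above: `repeat` exists as a
+# Tensor method but no `torch.repeat` function variant is generated.
+#
+#   import torch
+#   torch.ones(2).repeat(3)      # tensor([1., 1., 1., 1., 1., 1.])
+#   hasattr(torch, "repeat")     # False: only the method variant is exposed
+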
+- func: repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU: repeat_interleave_cpu
+    CUDA: repeat_interleave_cuda
+    MPS: repeat_interleave_mps
+  tags: dynamic_output_shape
+  autogen: repeat_interleave.Tensor_out
+
+- func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: repeat_interleave_symint
+
+- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: repeat_interleave_symint
+
+- func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: reshape_symint
+    CompositeImplicitAutogradNestedTensor: reshape_nested_symint
+
+- func: _reshape_copy(Tensor self, SymInt[] size) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _reshape_copy_symint
+
+# NOTE [ _reshape_alias ]: this op is meant to be used in the implementation of reshape.
+# It is not user-facing, hence the leading underscore. Please don't use it
+# anywhere else.
+- func: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, Meta, QuantizedCPU, QuantizedCUDA, ZeroTensor, MPS, MTIA: _reshape_alias
+    # We don't need to support mkldnn since this is handled explicitly by the reshape operator.
+
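+# NB: a minimal usage sketch (not part of the schema, assuming standard torch
+# semantics) of the public `reshape` that `_reshape_alias` backs: when the requested
+# shape is compatible with the existing strides, `reshape` returns a view; otherwise
+# it silently copies.
+#
+#   import torch
+#   x = torch.arange(6)
+#   v = x.reshape(2, 3)                 # strides permit a view: shares storage with x
+#   c = x.view(2, 3).t().reshape(6)     # non-contiguous source: reshape copies
+#   v.data_ptr() == x.data_ptr()        # True
+#   c.data_ptr() == x.data_ptr()        # False
+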
+- func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    MkldnnCPU: mkldnn_reshape
+  autogen: _mkldnn_reshape.out
+
+- func: reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: reshape_as
+    CompositeImplicitAutogradNestedTensor: reshape_as_nested
+
+- func: round(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: round.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: round_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr
+  tags: [core, pointwise]
+
+- func: round_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: round.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: round_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_
+  tags: pointwise
+
+- func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: round_out
+    SparseCPU, SparseCUDA: round_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: round_sparse_csr_out
+  tags: pointwise
+
+- func: round.decimals(Tensor self, *, int decimals) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: round.decimals_out
+  variants: function, method
+  tags: pointwise
+
+- func: round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: round.decimals_out
+  variants: function, method
+  tags: pointwise
+
+- func: round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: round_decimals_out
+  tags: pointwise
+
+- func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  tags: [pointwise, nondeterministic_seeded]
+
+- func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  device_check: NoCheck   # TensorIterator
+
+- func: relu(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA, MTIA: relu
+    MPS: relu_mps
+    MkldnnCPU: mkldnn_relu
+    QuantizedCPU: relu_quantized_cpu
+    QuantizedCUDA: relu_quantized_cuda
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_relu
+    SparseCPU, SparseCUDA: relu_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr
+  tags: [core, pointwise]
+
+- func: relu_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA, MTIA: relu_
+    MPS: relu_mps_
+    MkldnnCPU: mkldnn_relu_
+    QuantizedCPU: relu_quantized_cpu_
+    QuantizedCUDA: relu_quantized_cuda_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_relu_
+    SparseCPU, SparseCUDA: relu_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: relu_sparse_csr_
+  autogen: relu.out
+  tags: pointwise
+
+- func: relu6(Tensor self) -> Tensor
+  python_module: nn
+  tags: pointwise
+
+- func: relu6_(Tensor(a!) self) -> Tensor(a!)
+  python_module: nn
+
+- func: prelu(Tensor self, Tensor weight) -> Tensor
+  variants: function, method
+  autogen: prelu.out
+
+- func: _prelu_kernel(Tensor self, Tensor weight) -> Tensor
+  dispatch:
+    CPU, CUDA: _prelu_kernel
+    QuantizedCPU: _prelu_kernel_quantized_cpu
+    MkldnnCPU: mkldnn_prelu
+    MPS: prelu_mps
+
+- func: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
+  dispatch:
+    CPU, CUDA: _prelu_kernel_backward
+    MkldnnCPU: mkldnn_prelu_backward
+    MPS: prelu_backward_mps
+
+- func: gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU: gelu_out_cpu
+    CUDA: gelu_out_cuda
+    MPS: gelu_out_mps
+
+- func: gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)
+  structured_delegate: gelu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    QuantizedCPU: gelu_quantized_cpu_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_gelu_
+
+- func: gelu(Tensor self, *, str approximate='none') -> Tensor
+  structured_delegate: gelu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_gelu
+    QuantizedCPU: gelu_quantized_cpu
+    QuantizedCUDA: gelu_quantized_cuda
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_gelu
+  tags: [core, pointwise]
+
+- func: gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU: gelu_backward_out_cpu
+    CUDA: gelu_backward_out_cuda
+    MPS: gelu_backward_out_mps
+
+- func: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
+  structured_delegate: gelu_backward.grad_input
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_gelu_backward
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: gelu_backwards_nested
+  tags: pointwise
+
+- func: infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor
+  variants: function
+  python_module: nn
+  device_check: NoCheck
+  device_guard: False
+
+- func: hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: hardshrink_out
+
+- func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+  structured_delegate: hardshrink.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: pointwise
+
+- func: hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: hardshrink_backward_out
+
+- func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
+  structured_delegate: hardshrink_backward.grad_input
+  variants: function, method
+
+- func: rsqrt(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: rsqrt.out
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: rsqrt_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: rsqrt.out
+  variants: function, method
+  tags: pointwise
+
+- func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: rsqrt_out
+  tags: pointwise
+
+- func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: select_symint
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: select_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: select_nested
+  tags: core
+
+- func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: select_backward_symint
+  autogen: select_backward.out
+
+- func: _nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _nested_select_backward_symint
+
+- func: selu(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  tags: pointwise
+
+- func: selu_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: celu(Tensor self, Scalar alpha=1.0) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: celu
+  tags: pointwise
+
+- func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: celu_
+  autogen: celu.out
+
+- func: silu(Tensor self) -> Tensor
+  structured_delegate: silu.out
+  python_module: nn
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_silu
+  tags: pointwise
+
+- func: silu_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: silu.out
+  python_module: nn
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_silu_
+  tags: pointwise
+
+- func: silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MTIA: silu_out
+    MPS: silu_out_mps
+  tags: pointwise
+
+- func: silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: silu_backward_out
+    MPS: silu_backward_out_mps
+  tags: pointwise
+
+- func: silu_backward(Tensor grad_output, Tensor self) -> Tensor
+  structured_delegate: silu_backward.grad_input
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: math_silu_backward
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: silu_backward_nested
+  tags: pointwise
+
+- func: mish(Tensor self) -> Tensor
+  structured_delegate: mish.out
+  python_module: nn
+  tags: pointwise
+
+- func: mish_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: mish.out
+  python_module: nn
+
+- func: mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: mish_out
+    MPS: mish_out_mps
+
+- func: mish_backward(Tensor grad_output, Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: mish_backward
+    MPS: mish_backward_mps
+    CompositeImplicitAutograd: math_mish_backward
+
+- func: sigmoid(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sigmoid.out
+  variants: function, method
+  dispatch:
+    QuantizedCPU: sigmoid_quantized_cpu
+    MkldnnCPU: mkldnn_sigmoid
+  tags: [core, pointwise]
+
+- func: sigmoid_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sigmoid.out
+  variants: function, method
+  dispatch:
+    MkldnnCPU: mkldnn_sigmoid_
+  tags: pointwise
+
+- func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: sigmoid_out
+  tags: pointwise
+
+- func: logit(Tensor self, float? eps=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, CUDA, MTIA: logit
+    MPS: logit_mps
+  tags: pointwise
+
+- func: logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)
+  variants: function, method
+  dispatch:
+    CPU, CUDA: logit_
+  tags: pointwise
+
+- func: logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: logit_out
+    MPS: logit_out_mps
+  tags: pointwise
+
+- func: sin(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sin.out
+  variants: function, method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr
+    SparseCPU, SparseCUDA: sin_sparse
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sin
+  tags: [core, pointwise]
+
+- func: sin_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sin.out
+  variants: function, method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_
+    SparseCPU, SparseCUDA: sin_sparse_
+  tags: pointwise
+
+- func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: sin_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sin_sparse_csr_out
+    SparseCPU, SparseCUDA: sin_sparse_out
+  tags: pointwise
+
+- func: sinc(Tensor self) -> Tensor
+  structured_delegate: sinc.out
+  variants: function, method
+  tags: pointwise
+
+- func: sinc_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: sinc.out
+  variants: function, method
+  tags: pointwise
+
+- func: sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: sinc_out
+  tags: pointwise
+
+- func: sinh(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sinh.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sinh_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr
+  tags: [core, pointwise]
+
+- func: sinh_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sinh.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sinh_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_
+  tags: pointwise
+
+- func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: sinh_out
+    SparseCPU, SparseCUDA: sinh_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sinh_sparse_csr_out
+  tags: pointwise
+
+# Returns a copy of this `Variable` that is detached from its autograd graph.
+# This method is OK to call if the `Variable` is a view.
+#
+# NOTE: Previously, if we changed the tensor metadata (e.g. sizes / strides /
+# storage / storage_offset) of a tensor created from `detach()`, that metadata
+# in the original tensor would also be updated. However, the new behavior is that
+# those metadata changes to the detached tensor will not update the original tensor
+# anymore, and in the `detach()` function we need to set `allow_tensor_metadata_change_`
+# to false to make such changes explicitly illegal, in order to prevent users from
+# changing metadata of the detached tensor and expecting the original tensor to also
+# be updated.
+- func: detach(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: detach
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: detach
+
+# Like `detach()`, but modifies this `Variable` in-place. This method may
+# only be called on non-view `Variable`s. You can use `is_view()` to check
+# this. If this `Variable` is a view, it throws an `std::runtime_error()`.
+- func: detach_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: detach_
+
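+# NB: a minimal usage sketch (not part of the schema, assuming standard autograd
+# semantics) of the `detach` notes above: the detached tensor shares storage with
+# the original, so in-place *data* changes are visible on both, while *metadata*
+# changes on the detached tensor are disallowed, and `detach_()` only applies to
+# non-view tensors.
+#
+#   import torch
+#   base = torch.ones(3, requires_grad=True)
+#   d = base.detach()
+#   d.requires_grad          # False: cut off from the autograd graph
+#   d.add_(1)                # storage is shared, so base now holds 2s as well
+#   d.resize_(6)             # raises, per the NOTE above: metadata changes are illegal
+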
+- func: size.int(Tensor self, int dim) -> int
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: size.Dimname(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: sym_size.int(Tensor self, int dim) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: sym_numel(Tensor self) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: sym_storage_offset(Tensor self) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: slice
+  tags: core
+
+# NOTE: The implementation of split_with_sizes bypasses the dispatcher to call this; undo
+# that if adding specific implementations here!
+
+- func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: slice_backward
+  autogen: slice_backward.out
+
+# NB: This op exists to back the implementation of reverse view_funcs for various views (chunk,
+# slice.Tensor, split_with_sizes, et al.). Currently, these are only used during fake-ification
+# of PT2 graph input subclass instances that are views. This means:
+# * This op shouldn't really show up in eager mode (so e.g. XLA shouldn't have to implement it)
+# * This op shouldn't show up in a PT2 graph (so a PT2 backend shouldn't have to implement it)
+# * A subclass will have to implement this to work in PT2 if a subclass view is used as a graph
+#   input AND the view utilizes this op in its inverse. The idea is that slice_inverse() is
+#   easier to implement for a subclass than as_strided()
+- func: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: slice_inverse_symint
+
+- func: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: slice_scatter
+  autogen: slice_scatter.out
+  tags: [core, view_copy]
+
+- func: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: select_scatter_symint
+  autogen: select_scatter.out
+  tags: core
+
+- func: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: diagonal_scatter
+  autogen: diagonal_scatter.out
+
+- func: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: as_strided_scatter_symint
+  autogen: as_strided_scatter.out
+
+- func: smm(Tensor self, Tensor mat2) -> Tensor
+  variants: function, method
+
+# softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models.
+- func: softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  variants: function, method
+
+- func: softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: softmax_out
+
+- func: softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  variants: function, method
+
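+# NB: an illustrative sketch (not part of the schema, assuming standard torch
+# Python bindings) of the positional-dtype allowance mentioned above: `dtype` may
+# be passed positionally or as a keyword.
+#
+#   import torch
+#   x = torch.randn(2, 5)
+#   a = torch.softmax(x, 1, torch.float64)            # positional dtype accepted
+#   b = torch.softmax(x, dim=1, dtype=torch.float64)
+#   torch.equal(a, b)                                 # True
+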
+- func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  structured_delegate: _softmax.out
+  dispatch:
+    MkldnnCPU: mkldnn_softmax
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: softmax_nested
+  tags: core
+
+- func: _softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: softmax_cpu_out
+    CUDA: softmax_cuda_out
+    MPS: softmax_mps_out
+
+- func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+  structured_delegate: _softmax_backward_data.out
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: nested_softmax_backward
+
+- func: _softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: softmax_backward_cpu_out
+    CUDA: softmax_backward_cuda_out
+    MPS: softmax_backward_mps_out
+
+- func: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: unsafe_split
+  autogen: unsafe_split.Tensor_out
+
+- func: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: split
+
+- func: split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: split_symint
+
+- func: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: unsafe_split_with_sizes
+  autogen: unsafe_split_with_sizes.out
+
+- func: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: split_with_sizes
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: split_with_sizes_nested
+  tags: core
+
+- func: hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+  variants: function, method
+
+- func: hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+  variants: function, method
+
+- func: vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+  variants: function, method
+
+- func: vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+  variants: function, method
+
+- func: dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]
+  variants: function, method
+
+- func: dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]
+  variants: function, method
+
+- func: squeeze(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: squeeze
+    QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: squeeze_nested
+
+- func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: squeeze
+    QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: squeeze_dim_nested
+  tags: core
+
+- func: squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+
+- func: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: squeeze
+    QuantizedCPU, QuantizedCUDA: squeeze_quantized
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: squeeze_dim_nested
+  tags: core
+
+- func: squeeze_(Tensor(a!) self) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: squeeze_
+
+- func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: squeeze_
+
+- func: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: squeeze_
+
+- func: squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+
+- func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  variants: function, method
+
+- func: sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: _sspaddmm_out_only_sparse
+    CUDA: _sspaddmm_out_only_sparse_cuda
+    SparseCPU: _sspaddmm_out_cpu
+    SparseCUDA: _sspaddmm_out_cuda
+
+- func: _chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _chunk_cat
+    CUDA: _chunk_cat_cuda
+
+- func: _chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: _chunk_cat_out
+    CUDA: _chunk_cat_out_cuda
+
+- func: stack(Tensor[] tensors, int dim=0) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: stack
+
+- func: stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: stack_out
+
+- func: _stack(Tensor[] tensors, int dim=0) -> Tensor
+  dispatch: # match the backends supported by _cat
+    CPU: _stack_cpu
+    CompositeExplicitAutograd: _stack
+
+- func: _stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch: # match the backends supported by _cat_out
+    CPU: _stack_out_cpu
+    CompositeExplicitAutograd: _stack_out
+
+- func: hstack(Tensor[] tensors) -> Tensor
+
+- func: hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: vstack(Tensor[] tensors) -> Tensor
+
+- func: vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: dstack(Tensor[] tensors) -> Tensor
+
+- func: dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+
+# Overload without center & pad mode, needed for forward-compatibility
+- func: stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
+  variants: function, method
+  cpp_no_default_args: ['hop_length', 'win_length', 'window', 'normalized']
+
+- func: stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor
+  variants: function, method
+
+- func: istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor
+  variants: function, method
+
+- func: stride.int(Tensor self, int dim) -> int
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  manual_cpp_binding: True
+
+- func: stride.Dimname(Tensor self, Dimname dim) -> int
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: sym_stride.int(Tensor self, int dim) -> SymInt
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  tags: core
+  manual_cpp_binding: True
+
+- func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: sum
+    SparseCPU, SparseCUDA, SparseMeta: sum_coo
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
+  autogen: sum.out
+
+- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  # TODO: Align the signature of sum.dim_IntList and _sparse_csr_sum.dim_dtype
+  structured_delegate: sum.IntList_out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    NestedTensorCPU: NestedTensor_sum_dim_CPU
+    SparseCPU, SparseCUDA: sum_sparse_coo
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
+  tags: core
+
+- func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: sum_out
+    MPS: sum_out_mps
+
+- func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+# TODO: this function will be replaced once nested expand semantics have been settled on
+- func: _nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor
+  dispatch:
+    NestedTensorCPU: _nested_sum_backward_cpu
+
+- func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, CUDA: nansum
+    MPS: nansum_mps
+
+- func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: nansum_out
+    MPS: nansum_out_mps
+
+- func: sum_to_size(Tensor self, SymInt[] size) -> Tensor
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: sum_to_size_symint
+
+- func: sqrt(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sqrt.out
+  variants: function, method
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sqrt
+    SparseCPU, SparseCUDA: sqrt_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr
+  tags: [core, pointwise]
+
+- func: sqrt_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sqrt.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sqrt_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_
+  tags: pointwise
+
+- func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: sqrt_out
+    SparseCPU, SparseCUDA: sqrt_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sqrt_sparse_csr_out
+  tags: pointwise
+
+- func: square(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: pointwise
+
+- func: square_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: pointwise
+
+- func: square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  tags: pointwise
+
+- func: std(Tensor self, bool unbiased=True) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ["unbiased"]
+
+- func: std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ["unbiased"]
+
+- func: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: std
+    MPS: std_mps
+    QuantizedCPU: std_quantized_cpu
+
+- func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CPU, CUDA: std_mean
+    MPS: std_mean_mps
+  autogen: std_mean.correction_out
+
+- func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  cpp_no_default_args: ["unbiased"]
+
+- func: std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: std_out
+    QuantizedCPU: std_out_quantized_cpu
+
+- func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ["unbiased"]
+
+- func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  cpp_no_default_args: ["unbiased"]
+
+- func: std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: prod
+    MPS: prod_mps
+  autogen: prod.out
+  tags: core
+
+- func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  structured_delegate: prod.int_out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: core
+
+- func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: prod_out
+    MPS: prod_out_mps
+
+- func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: t(Tensor(a) self) -> Tensor(a)
+  device_check: NoCheck
+  device_guard: False
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: t
+
+- func: t_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck
+  device_guard: False
+  variants: method
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: t_
+
+- func: tan(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: tan.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: tan_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr
+  tags: [core, pointwise]
+
+- func: tan_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: tan.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: tan_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_
+  tags: pointwise
+
+- func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: tan_out
+    SparseCPU, SparseCUDA: tan_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tan_sparse_csr_out
+  tags: pointwise
+
+- func: tanh(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: tanh.out
+  variants: function, method
+  dispatch:
+    QuantizedCPU: tanh_quantized_cpu
+    MkldnnCPU: mkldnn_tanh
+    SparseCPU, SparseCUDA: tanh_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_tanh
+  tags: [core, pointwise]
+
+- func: tanh_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: tanh.out
+  variants: function, method
+  dispatch:
+    MkldnnCPU: mkldnn_tanh_
+    SparseCPU, SparseCUDA: tanh_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_tanh_
+  tags: pointwise
+
+- func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: tanh_out
+    SparseCPU, SparseCUDA: tanh_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: tanh_sparse_csr_out
+  tags: pointwise
+
+- func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor
+  variants: function
+
+- func: tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
+# TODO: namespace threshold in 'nn'
+- func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  structured_delegate: threshold.out
+  dispatch:
+    QuantizedCPU: threshold_quantized_cpu
+  tags: pointwise
+
+- func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  structured_delegate: threshold.out
+
+- func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: threshold_out
+    MPS: threshold_out_mps
+
+- func: threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: threshold_backward_out
+    MPS: threshold_backward_out_mps
+    SparseCPU, SparseCUDA: threshold_backward_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: threshold_backward_sparse_compressed_out
+
+- func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
+  variants: function
+  structured_delegate: threshold_backward.grad_input
+  dispatch:
+    MkldnnCPU: mkldnn_relu_backward
+    SparseCPU, SparseCUDA: threshold_backward_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: threshold_backward_sparse_compressed
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: threshold_backwards_nested
+  tags: pointwise
+
+- func: tile(Tensor self, SymInt[] dims) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeImplicitAutograd: tile_symint
+
+- func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: transpose
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: transpose_nested
+
+- func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: _mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    MkldnnCPU: mkldnn_transpose
+
+- func: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: transpose_
+
+- func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    MkldnnCPU: mkldnn_transpose_
+  autogen: _mkldnn_transpose.out
+
+- func: one_hot(Tensor self, int num_classes=-1) -> Tensor
+  python_module: nn
+  variants: function
+  tags: dynamic_output_shape
+
+- func: flip(Tensor self, int[] dims) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, QuantizedCPU, CUDA, QuantizedCUDA: flip
+    MPS: flip_mps
+  autogen: flip.out
+  tags: core
+
+- func: fliplr(Tensor self) -> Tensor
+  variants: function, method
+
+- func: flipud(Tensor self) -> Tensor
+  variants: function, method
+
+- func: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, MPS: roll
+    CUDA: roll_cuda
+  autogen: roll.out
+
+# The default int[] value [0,1] must not contain a space after the comma, since the codegen parser uses ', ' to split args.
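+# Illustration only (not part of the schema itself): with that ', ' split, a default
+# written as `int[] dims=[0, 1]` would be chopped into the malformed pieces
+# `int[] dims=[0` and `1]`, whereas `int[] dims=[0,1]` parses as a single argument.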
+
+- func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: rot90
+  autogen: rot90.out
+
+- func: trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+
+- func: trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor
+
+- func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor
+
+- func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor
+
+# Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads).
+- func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU, NestedTensorCPU: transform_bias_rescale_qkv_cpu
+    CUDA, NestedTensorCUDA: transform_bias_rescale_qkv_cuda
+  autogen: _transform_bias_rescale_qkv.out
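+# Hedged sketch of the math in the comment above (illustration only; the real
+# kernels also reshape into attention heads and handle nested inputs, which is
+# omitted here, and the helper name below is made up for exposition):
+#
+#   def _transform_bias_rescale_qkv_sketch(qkv, qkv_bias, num_heads):
+#       # qkv: (B, T, 3*D) dense tensor, qkv_bias: (3*D,)
+#       D = qkv.shape[-1] // 3
+#       q, k, v = (qkv + qkv_bias).split(D, dim=-1)   # add in-projection bias, split Q/K/V
+#       q = q / (D // num_heads) ** 0.5               # divide Q by sqrt(D / num_heads)
+#       return q, k, v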
+
+- func: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor
+  dispatch:
+    CPU, CUDA: NestedTensor_nested_tensor_from_mask
+  autogen: _nested_tensor_from_mask.out
+
+- func: _nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool
+  dispatch:
+    CPU, CUDA: NestedTensor_nested_tensor_from_mask_left_aligned
+
+- func: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
+  device_check: NoCheck # cpu_nested_shape_example will always be on CPU
+  dispatch:
+    CPU: nested_from_padded_generic
+    CUDA: nested_from_padded_cuda
+  autogen: _nested_from_padded.out
+
+# These private functions are temporary. They will be updated/deleted when nested tensors switch to using SymInts for their metadata representation
+- func: _nested_tensor_size(Tensor self) -> Tensor
+  variants: method
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _nested_tensor_size
+  autogen: _nested_tensor_size.out
+
+- func: _nested_tensor_strides(Tensor self) -> Tensor
+  variants: method
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _nested_tensor_strides
+  autogen: _nested_tensor_strides.out
+
+- func: _nested_tensor_storage_offsets(Tensor self) -> Tensor
+  variants: method
+  dispatch:
+    NestedTensorCPU, NestedTensorCUDA, NestedTensorMeta: _nested_tensor_storage_offsets
+  autogen: _nested_tensor_storage_offsets.out
+
+# _nested_from_padded is not usable from Python, so
+# _nested_from_padded_and_nested_example is available for testing.
+- func: _nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_from_padded_and_nested_example
+  autogen: _nested_from_padded_and_nested_example.out
+
+# The input arguments' types to this function are temporary. When nested tensors switch to using SymInts for their metadata representation
+# this will need to be updated
+- func: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)
+  variants: function
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: _nested_view_from_buffer
+
+- func: _nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor
+  variants: function
+  device_check: NoCheck
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _nested_view_from_buffer_copy
+  autogen: _nested_view_from_buffer_copy.out
+
+- func: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor
+  variants: function
+  device_check: NoCheck
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _nested_view_from_jagged_copy
+  autogen: _nested_view_from_jagged_copy.out
+
+- func: _nested_get_values(Tensor(a) self) -> Tensor(a)
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_values_copy(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _nested_get_values_copy
+  autogen: _nested_get_values_copy.out
+
+- func: _nested_get_offsets(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+# returns undefined Tensor if no lengths present
+- func: _nested_get_lengths(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_ragged_idx(Tensor self) -> int
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_min_seqlen(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_max_seqlen(Tensor self) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_get_jagged_dummy(Tensor any) -> Tensor
+  category_override: dummy
+  dispatch: {}
+
+- func: _nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor)
+  variants: function
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: _nested_compute_contiguous_strides_offsets
+
+- func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
+  dispatch:
+    # calls unsqueeze
+    CompositeExplicitAutogradNonFunctional: _trilinear
+  autogen: _trilinear.out
+
+- func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
+
+- func: trunc(Tensor self) -> Tensor
+  structured_delegate: trunc.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: trunc_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr
+  tags: [core, pointwise]
+
+- func: trunc_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: trunc.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: trunc_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_
+  tags: pointwise
+
+- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: trunc_out
+    SparseCPU, SparseCUDA: trunc_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_out
+  tags: pointwise
+# Alias for trunc
+
+- func: fix(Tensor self) -> Tensor
+  variants: function, method
+
+- func: fix_(Tensor(a!) self) -> Tensor(a!)
+  variants: function, method
+
+- func: fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: type_as(Tensor self, Tensor other) -> Tensor
+  variants: method
+
+- func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool
+  variants: function
+
+- func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: _unique_cpu
+    CUDA: _unique_cuda
+  autogen: _unique.out
+
+- func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: unique_dim_cpu
+    CUDA: unique_dim_cuda
+  tags: dynamic_output_shape
+  autogen: unique_dim.out
+
+- func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: unique_consecutive_cpu
+    CUDA: unique_consecutive_cuda
+    MPS: unique_consecutive_mps
+  tags: dynamic_output_shape
+  autogen: unique_consecutive.out
+
+- func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: unique_dim_consecutive_cpu
+    CUDA: unique_dim_consecutive_cuda
+    MPS: unique_dim_consecutive_mps
+  tags: dynamic_output_shape
+  autogen: unique_dim_consecutive.out
+
+# _unique and _unique_dim are fragile, and modifying them can easily cause internal breakage;
+# the operator below is a temporary hack for adding return_counts support.
+# Please don't rely on these two operators; they will be removed soon.
+
+- func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: _unique2_cpu
+    CUDA: _unique2_cuda
+    MPS: _unique2_mps
+  tags: dynamic_output_shape
+  autogen: _unique2.out
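+# For reference, these private ops back the public `torch.unique` /
+# `torch.unique_consecutive` Python APIs; exactly which private op is hit is an
+# implementation detail and may change. A typical (illustrative) call:
+#
+#   import torch
+#   x = torch.tensor([1, 3, 2, 3, 1])
+#   values, inverse, counts = torch.unique(
+#       x, sorted=True, return_inverse=True, return_counts=True)
+#   # values=[1, 2, 3], inverse=[0, 2, 1, 2, 0], counts=[2, 1, 2]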
+
+- func: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _unsafe_view
+  autogen: _unsafe_view.out
+
+- func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: unsqueeze
+    SparseCPU, SparseCUDA: unsqueeze_sparse
+    QuantizedCPU, QuantizedCUDA: unsqueeze_quantized
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: unsqueeze_nested
+  tags: core
+
+- func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+  dispatch:
+    CompositeExplicitAutograd: unsqueeze_
+
+- func: vander(Tensor x, int? N=None, bool increasing=False) -> Tensor
+
+- func: var(Tensor self, bool unbiased=True) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ["unbiased"]
+
+- func: var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: core
+  cpp_no_default_args: ["unbiased"]
+
+- func: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA: var
+    MPS: var_mps
+  tags: core
+
+- func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  cpp_no_default_args: ["unbiased"]
+
+- func: var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: var_out
+
+- func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  cpp_no_default_args: ["unbiased"]
+
+- func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  cpp_no_default_args: ["unbiased"]
+
+- func: var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CPU, CUDA: var_mean
+    MPS: var_mean_mps
+  autogen: var_mean.correction_out
+
+- func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  cpp_no_default_args: ["unbiased"]
+
+- func: var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: view_as(Tensor(a) self, Tensor other) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+
+- func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CPU, CUDA, MPS, MTIA: where
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_where
+  tags: [core, pointwise]
+
+- func: where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS, MTIA: where_self_out
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_where_out
+
+- func: where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor
+  variants: function
+
+- func: where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor
+  variants: function, method
+
+- func: where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor
+  variants: function
+
+- func: where(Tensor condition) -> Tensor[]
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor
+  variants: function
+
+# VariableType::_weight_norm does not want to be given a gap in the autograd graph,
+# so we don't define "dispatch" variants for it.
+- func: _weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor
+  variants: function
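+# For reference, weight normalization reparameterizes a weight as
+# w = g * v / ||v||, where the norm is taken over every dimension except `dim`
+# (see `norm_except_dim` just above). A hedged usage sketch via the public wrapper:
+#
+#   import torch
+#   import torch.nn as nn
+#   layer = nn.utils.weight_norm(nn.Linear(3, 4), name='weight', dim=0)
+#   # registers `weight_g` (magnitudes) and `weight_v` (directions); the forward
+#   # pass recomputes `weight` from them, ultimately via this op.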
+
+- func: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: weight_norm_cpu
+    CUDA: weight_norm_cuda
+    MPS: weight_norm_mps
+  autogen: _weight_norm_interface.out
+
+- func: _weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU: weight_norm_backward_cpu
+    CUDA: weight_norm_backward_cuda
+    MPS: weight_norm_backward_mps
+  autogen: _weight_norm_interface_backward.out
+
+- func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)
+  variants: function
+
+- func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: zeros
+  autogen: zeros.names_out
+
+- func: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CPU: _efficientzerotensor
+    CUDA: _efficientzerotensor_cuda
+    MPS: _efficientzerotensor_mps
+    Meta: _efficientzerotensor_meta_symint
+  autogen: _efficientzerotensor.out
+
+- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: zeros_symint
+
+- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: zeros_out
+    SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
+
+- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: zeros_like
+  autogen: zeros_like.out
+
+- func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _standard_gamma_grad_cpu
+    CUDA: _standard_gamma_grad_cuda
+  autogen: _standard_gamma_grad.out
+
+- func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _s_gamma_cpu
+    CUDA: _s_gamma_cuda
+  tags: nondeterministic_seeded
+  autogen: _standard_gamma.out
+
+- func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor
+  dispatch:
+    CPU: _dirichlet_grad_cpu
+    CUDA: _dirichlet_grad_cuda
+  autogen: _dirichlet_grad.out
+
+- func: _sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor
+  tags: nondeterministic_seeded
+  variants: function
+  dispatch:
+    CPU: _s_dirichlet_cpu
+    CUDA: _s_dirichlet_cuda
+  autogen: _sample_dirichlet.out
+
+- func: poisson(Tensor self, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU: _s_poisson_cpu
+    CUDA: _s_poisson_cuda
+  tags: nondeterministic_seeded
+  autogen: poisson.out
+
+- func: binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU: _s_binomial_cpu
+    CUDA: _s_binomial_cuda
+  tags: nondeterministic_seeded
+  autogen: binomial.out
+
+# When more variants get ported to native, this dispatch will get more
+# complicated
+
+- func: native_norm(Tensor self, Scalar p=2) -> Tensor
+  dispatch:
+    SparseCPU, SparseCUDA: norm_sparse
+  autogen: native_norm.out
+
+- func: native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor
+  dispatch:
+    SparseCPU, SparseCUDA: norm_sparse
+  autogen: native_norm.ScalarOpt_dim_dtype_out
+
+- func: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _batch_norm_with_update_cpu
+    CUDA: _batch_norm_with_update_cuda
+    MPS: _batch_norm_with_update_mps
+    MkldnnCPU: _batch_norm_with_update_mkldnn
+  autogen: _batch_norm_with_update_functional
+
+- func: _batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))
+  dispatch:
+    CPU: _batch_norm_with_update_cpu_out
+    CUDA: _batch_norm_with_update_cuda_out
+    MPS: _batch_norm_with_update_mps_out
+
+- func: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CompositeExplicitAutograd: _batch_norm_no_update
+  autogen: _batch_norm_no_update.out
+
+- func: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CPU: _new_batch_norm_backward_cpu
+    CUDA: _new_batch_norm_backward_cuda
+    MPS: _new_batch_norm_backward_mps
+    MkldnnCPU: _new_batch_norm_backward_mkldnn
+
+# TODO: reduce signatures down to one when optional args are available
+- func: _sparse_sum(Tensor self) -> Tensor
+
+- func: _sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _sparse_sum
+  autogen: _sparse_sum.dim_out
+
+- func: _sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor
+
+- func: _sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor
+  dispatch:
+    SparseCPU: _sparse_sum_backward_cpu
+    SparseCUDA: _sparse_sum_backward_cuda
+  autogen: _sparse_sum_backward.out
+
+- func: _sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    SparseCsrCPU: _sparse_csr_sum_cpu
+    SparseCsrCUDA: _sparse_csr_sum_cuda
+  autogen: _sparse_csr_sum.dim_dtype_out
+
+- func: _sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    SparseCsrCPU: _sparse_csr_prod_cpu
+    SparseCsrCUDA: _sparse_csr_prod_cuda
+  autogen: _sparse_csr_prod.dim_dtype_out
+
+- func: _sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  python_module: sparse
+  variants: function
+
+- func: _sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  python_module: sparse
+  variants: function
+
+- func: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  python_module: sparse
+  dispatch:
+    SparseCPU: softmax_sparse_cpu
+    SparseCUDA: softmax_sparse_cuda
+  autogen: _sparse_softmax.out
+
+- func: _sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+  dispatch:
+    SparseCPU: softmax_backward_sparse_cpu
+    SparseCUDA: softmax_backward_sparse_cuda
+  autogen: _sparse_softmax_backward_data.out
+
+- func: _sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  python_module: sparse
+  variants: function
+
+- func: _sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor
+  python_module: sparse
+  variants: function
+
+- func: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  python_module: sparse
+  dispatch:
+    SparseCPU: log_softmax_sparse_cpu
+    SparseCUDA: log_softmax_sparse_cuda
+  autogen: _sparse_log_softmax.out
+
+- func: _sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor
+  dispatch:
+    SparseCPU: log_softmax_backward_sparse_cpu
+    SparseCUDA: log_softmax_backward_sparse_cuda
+  autogen: _sparse_log_softmax_backward_data.out
+
+- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
+  python_module: sparse
+  dispatch:
+    CPU: spdiags
+  autogen: _spdiags.out
+
+- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: norm
+  autogen: norm.ScalarOpt_dtype_out
+
+- func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: norm
+  autogen: norm.Scalar_out
+
+- func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  structured_delegate: norm.dtype_out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_dtype_norm
+
+- func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+  structured_delegate: norm.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_norm
+
+- func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: norm_dtype_out
+    MPS: norm_dtype_out_mps
+
+- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: norm_out
+    MPS: norm_out_mps
+
+# These four redispatch in their implementation, so OK to be CompositeImplicitAutograd
+- func: norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: frexp
+  tags: pointwise
+
+- func: frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)
+  dispatch:
+    CPU, CUDA: frexp_out
+  tags: pointwise
+
+# Deprecated (v.1.12)
+- func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+  variants: function
+
+# Deprecated (v.1.12)
+- func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
+# Deprecated (v.1.12)
+- func: nuclear_norm(Tensor self, bool keepdim=False) -> Tensor
+  variants: function
+
+# Deprecated (v.1.12)
+- func: nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
+# Deprecated (v.1.12)
+- func: nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor
+  variants: function
+
+# Deprecated (v.1.12)
+- func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+
+- func: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: clone
+    SparseCPU, SparseCUDA: clone_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: clone_sparse_compressed
+    MkldnnCPU: mkldnn_clone
+    QuantizedCPU, QuantizedCUDA: quantized_clone
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: clone_nested
+  autogen: clone.out
+  tags: [core, pointwise]
+
+- func: positive(Tensor(a) self) -> Tensor(a)
+  variants: function, method
+  tags: pointwise
+
+- func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: resize_as_
+  autogen: resize_as, resize_as.out
+  tags: inplace_view
+
+- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: resize_as_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: resize_as_sparse_compressed_
+  autogen: resize_as_sparse, resize_as_sparse.out
+
+- func: zero_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA: zero_
+    MPS: zero_mps_
+    Meta: zero_meta_
+    SparseCPU, SparseCUDA, SparseMeta: zero_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: zero_sparse_csr_
+    MkldnnCPU: mkldnn_zero_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: zero_nested_
+  autogen: zero, zero.out
+
+- func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: sub_out
+    MPS: sub_out_mps
+    SparseCPU, SparseCUDA: sub_out_sparse
+  tags: pointwise
+
+- func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: sub.out
+  dispatch:
+    SparseCPU, SparseCUDA: sub_sparse
+    ZeroTensor: sub_zerotensor
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sub_Tensor
+  tags: [core, pointwise]
+
+- func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: sub.out
+  dispatch:
+    SparseCPU, SparseCUDA: sub_sparse_
+  tags: pointwise
+# For C++ only, until we have conversion from C++ numbers to Tensor
+
+- func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: sub
+  tags: [core, pointwise]
+
+- func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: sub_
+  autogen: sub.Scalar_out
+  tags: pointwise
+# subtract, alias for sub
+
+- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+
+- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  variants: function, method
+
+- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  variants: method
+
+# For C++ only, until we have conversion from C++ numbers to Tensor
+- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  variants: function, method
+
+- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
+  variants: method
+
+- func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: rsub
+  autogen: rsub.Tensor_out
+
+- func: heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: heaviside_out
+  tags: pointwise
+
+- func: heaviside(Tensor self, Tensor values) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: heaviside.out
+  tags: pointwise
+
+- func: heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: heaviside.out
+
+# For C++ only, until we have conversion from C++ numbers to Tensor
+- func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: rsub
+  autogen: rsub.Scalar_out
+  tags: pointwise
+
+# Functionally the same as addmm, but we give it a different derivative formula
+# that doesn't propagate gradients to non-present entries on sparse.
+- func: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  python_module: sparse
+  dispatch:
+    CompositeExplicitAutograd: _sparse_addmm
+  autogen: _sparse_addmm.out
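+# Hedged illustration of the comment above: the public wrapper `torch.sparse.addmm`
+# computes beta * self + alpha * (sparse mat1 @ dense mat2), and its backward only
+# produces gradients for the values actually stored in the sparse `mat1`:
+#
+#   import torch
+#   i = torch.tensor([[0, 1], [1, 0]])
+#   v = torch.tensor([1.0, 2.0], requires_grad=True)
+#   mat1 = torch.sparse_coo_tensor(i, v, (2, 2))
+#   out = torch.sparse.addmm(torch.zeros(2, 3), mat1, torch.randn(2, 3))
+#   out.sum().backward()   # v.grad is populated; entries not stored in mat1 get no gradient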
+
+- func: sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  python_module: sparse
+  dispatch:
+    SparseCsrCUDA: sparse_sampled_addmm_out_sparse_csr_cuda
+    SparseCsrCPU: sparse_sampled_addmm_out_sparse_csr_cpu
+
+- func: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  python_module: sparse
+  dispatch:
+    SparseCsrCUDA: sparse_sampled_addmm_sparse_csr_cuda
+    SparseCsrCPU: sparse_sampled_addmm_sparse_csr_cpu
+
+- func: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
+  python_module: sparse
+  dispatch:
+    SparseCsrCPU: _sparse_mm_reduce_impl_sparse_csr_cpu
+
+- func: _sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)
+  python_module: sparse
+  dispatch:
+    SparseCsrCPU: _sparse_mm_reduce_impl_backward_sparse_csr_cpu
+
+- func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: addmm_out_cpu
+    CUDA: addmm_out_cuda
+    MPS: addmm_out_mps
+    XPU: addmm_out_xpu
+    SparseCPU: addmm_out_sparse_dense_cpu
+    SparseCUDA: addmm_out_sparse_dense_cuda
+    SparseCsrCPU: addmm_out_sparse_compressed_cpu
+    SparseCsrCUDA: addmm_out_sparse_compressed_cuda
+
+- func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  structured_delegate: addmm.out
+  variants: function, method
+  dispatch:
+    SparseCPU: addmm_sparse_dense_cpu
+    SparseCUDA: addmm_sparse_dense_cuda
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
+  tags: core
+
+- func: addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  dispatch:
+    CUDA: _addmm_dtype_cuda
+
+- func: addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CUDA: _addmm_dtype_out_cuda
+
+- func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+  structured_delegate: addmm.out
+  variants: method
+  dispatch:
+    # Warning!  For whatever reason, the inplace sparse addmm is NON
+    # broadcasting
+    SparseCPU: s_addmm_sparse_dense_cpu_
+    SparseCUDA: s_addmm_sparse_dense_cuda_
+
+- func: _addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: addmm_activation_out_cpu
+    CUDA: addmm_activation_out_cuda
+    XPU: addmm_activation_out_xpu
+
+- func: _addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
+  structured_delegate: _addmm_activation.out
+  variants: function, method
+
+- func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _scaled_mm_cpu
+    CUDA: _scaled_mm_cuda
+
+- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CPU: _scaled_mm_out_cpu
+    CUDA: _scaled_mm_out_cuda
+
+
+- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _scaled_grouped_mm_cuda
+
+- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _grouped_mm_cuda
+
+# NOTE [ Sparse: autograd and API ]
+#
+#
+# Sparse Tensor Constructors
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# The API entry points for sparse tensor construction should be
+# `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`. Depending on whether the
+# indices and values tensors are given, they eventually dispatch to either
+# `sparse_coo_tensor_with_dims` or `sparse_coo_tensor_with_dims_and_tensors`.
+#
+# The autograd support for the ctor is implemented on `sparse_coo_tensor_with_dims_and_tensors`.
+#
+# The API methods `sparse_coo_tensor` and `_sparse_coo_tensor_unsafe`
+# **must not** have specific type dispatches because otherwise codegen will
+# consider them as abstract methods (see Note [Abstract ATen methods]), dispatch
+# using **Tensor** type, and thus lose autograd tracking on the actual method
+# they dispatch to, e.g., `sparse_coo_tensor_with_dims_and_tensors`.
+#
+#
+# Sparse Methods API Design
+# ~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Goals: 1. Flexible API for users to write custom sparse ops
+#        2. ctor and member accessor with autograd support
+#
+# To achieve 1, we need to provide a set of *dangerous* APIs (dangerous in the
+# sense that misusing them will break the sparse tensor invariants and may result in
+# unexpected behavior, e.g., a crash). These methods are all prefixed with
+# underscore "_" to indicate that they should be used with care. We provide:
+#
+#   + `_indices()`: returns the *raw* indices within the sparse tensor (not just
+#                   sharing storage). Any inplace operation will change the
+#                   actual indices, including t_, set_, as_strided_, resize_,
+#                   etc.
+#   + `_values()`: returns the *raw* values within the sparse tensor. Similar
+#                  semantics as `_indices()`
+#   + `_nnz()`: returns the number of non-zero entries. This will always be
+#               determined by the shapes of indices and values.
+#   + `_coalesced_(bool)`: inplace sets whether the tensor is coalesced, and
+#                          returns itself.
+#
+# These methods are very useful in writing new operations, e.g., a custom
+# autograd Function.
+#
+# We also provide other public *safe* APIs:
+#   + `indices()`: returns a **view** of the indices tensor if the sparse tensor
+#                  is **coalesced**.
+#   + `values()`: returns a **view** of the values tensor if the containing
+#                 sparse tensor is **coalesced**.
+#   + `sparse_dim()`: number of sparse dimensions
+#   + `dense_dim()`: number of dense dimensions
+#   + `is_coalesced()`: whether the sparse tensor is coalesced
+#
+# `_indices()` and `_values()` return the raw indices and values dense
+# tensors within a sparse tensor. They can be quite unsafe with inplace
+# operations like `t_()`, and expose uncoalesced indices and values. The public
+# recommended API is `indices()` and `values()`, both of which first check that
+# the tensor is coalesced and return views on those tensors.
+#
+#
+# Autograd Support
+# ~~~~~~~~~~~~~~~~
+#
+# Autograd is supported on `values()` and sparse tensor ctor with indices and
+# values tensors. E.g., `torch.sparse_coo_tensor(i, v).values().sum()` is
+# differentiable w.r.t. `v`.
+#
+# NB: The `values()` and `_values()` operators are special in that they are
+# layout-aware, i.e., the output depends not just on the data it represents, but
+# also on the input layout details (in this case, the `indices` tensor). See
+# NOTE [ as_strided Backward and layout-aware/agnostic autograd ] in Functions.cpp
+# for discussion on layout-aware vs layout-agnostic autograd. Since PyTorch ops
+# operate in the layout-agnostic mode, similar to `as_strided`, the backward of
+# these two operators needs to consider them in a layout-agnostic way:
+#   + `values()`:
+#     Input is coalesced.
+#     We just pretend to have `input.indices()` as an additional argument
+#     `input_indices`, then forward is similar to
+#     `input.to(kStrided).index_select(input_indices)` regardless of the layout.
+#     Note that `values()` normally is layout-aware even if we constrain
+#     ourselves to sparse inputs, since it may include explicit all-zero value
+#     entries as "present" entries.
+#   + `_values()`:
+#     Input may be uncoalesced.
+#     It is not straightforward to construct a layout-agnostic version because
+#     duplicate indices entries may exist and additional parameterization is
+#     needed to distribute the value into different values entries. Furthermore,
+#     this op is intended to provide ways to write custom sparse ops, rather
+#     than being used in autograd graph, so it is marked as *non-differentiable*
+#     in derivatives.yaml.
+#
+# Before reading the following, see NOTE [ Autograd Variable Views ] in
+# variable.h for details on views that are tracked by autograd, and views that
+# are not.
+#
+# Moreover, these methods return tensors that share storage with inputs, so we
+# mark these methods as view ops to support autograd history tracking.
+# The sparse tensor ctor output should technically be view of both input indices
+# and values tensors, but currently we only support setting as view of a single
+# Variable, so it is only view of the values tensor.
+# TODO: clone indices in sparse tensor ctor.
+#
+# For other methods that return outputs that share storage with inputs, i.e.,
+# `indices()` and `_indices()`, we mark their outputs as non-differentiable, so
+# the view relation is not tracked by autograd, but the version counter is still
+# shared. In other words, their outputs are non-differentiable views of the
+# sparse tensor.
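+#
+# Minimal usage sketch of the API described in this note (illustration only):
+#
+#   import torch
+#   i = torch.tensor([[0, 1, 1], [2, 0, 2]])
+#   v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
+#   s = torch.sparse_coo_tensor(i, v, (2, 3))
+#   s._nnz()                      # 3, determined by the shapes of indices/values
+#   s._indices(); s._values()     # raw (possibly uncoalesced) tensors
+#   c = s.coalesce()
+#   c.indices(); c.values()       # safe views, only valid on a coalesced tensor
+#   c.values().sum().backward()   # differentiable w.r.t. `v`, as described above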
+# FIXME: would be nicer if TensorOptions was optional based; not adding default arguments for options given
+# the default would never make sense.
+
+- func: _sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_compressed_tensor_with_dims
+
+- func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_compressed_tensor
+
+- func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+
+- func: sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_compressed_tensor
+- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+- func: sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+
+- func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _sparse_compressed_tensor_unsafe_symint
+
+- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: _sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+
+- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: sparse_coo_tensor
+  autogen: sparse_coo_tensor.size_out
+
+- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+
+- func: sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+
+- func: _sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _sparse_coo_tensor_unsafe_symint
+
+- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> ()
+
+- func: _validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> ()
+- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
+- func: _validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
+- func: _validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
+- func: _validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()
+
+- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
+  autogen: _sparse_coo_tensor_with_dims.out
+
+- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint
+  autogen: _sparse_coo_tensor_with_dims_and_tensors.out
+
+- func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_
+  autogen: sparse_resize, sparse_resize.out
+
+- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_
+  autogen: sparse_resize_and_clear, sparse_resize_and_clear.out
+
+- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_mask
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
+  autogen: sparse_mask.out
+
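+# Usage sketch for sparse_mask (an illustrative snippet, not part of the schema):
+#   d = torch.arange(6.).reshape(2, 3)
+#   m = torch.eye(2, 3).to_sparse()
+#   d.sparse_mask(m)   # sparse tensor holding d's values at m's specified indices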
+- func: _sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_mask_projection
+  autogen: _sparse_mask_projection.out
+
+- func: _to_cpu(Tensor[] tensors) -> Tensor[]
+  variants: function
+
+- func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
+  variants: method
+
+# Special case of to_dense with custom derivative
+- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_to_dense
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_dense
+    MkldnnCPU: mkldnn_to_dense
+  autogen: _to_dense.out
+
+- func: to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor
+
+- func: sparse_dim(Tensor self) -> int
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr
+    CompositeExplicitAutograd: sparse_dim_default
+  device_check: NoCheck
+  device_guard: False
+
+# legacy method
+- func: _dimI(Tensor self) -> int
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sparse_dim_sparse
+  device_check: NoCheck
+  device_guard: False
+
+- func: dense_dim(Tensor self) -> int
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr
+    CompositeExplicitAutograd: dense_dim_default
+  device_check: NoCheck
+  device_guard: False
+
+# legacy method
+- func: _dimV(Tensor self) -> int
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse
+  device_check: NoCheck
+  device_guard: False
+
+- func: _nnz(Tensor self) -> int
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr
+  device_check: NoCheck
+  device_guard: False
+
+# NOTE: [ coalesce autograd ]
+# coalesce returns self directly for already coalesced sparse tensors.
+# This means coalesce cannot have a derivative registered, otherwise it creates
+# circular references in the autograd graph (see gh-52874).
+# Instead, the derivative is registered on the slow-path "_coalesce"
+- func: coalesce(Tensor(a) self) -> Tensor(a)
+  variants: method
+
+- func: _coalesce(Tensor self) -> Tensor
+  dispatch:
+    SparseCPU: _coalesce_sparse_cpu
+    SparseCUDA: _coalesce_sparse_cuda
+  autogen: _coalesce.out
+
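+# Usage sketch of the coalesce path described in the NOTE above (illustrative only):
+#   i = torch.tensor([[0, 0, 1], [1, 1, 2]])   # index (0, 1) appears twice
+#   v = torch.tensor([3., 4., 5.])
+#   s = torch.sparse_coo_tensor(i, v, (2, 3))
+#   s.is_coalesced()         # False
+#   s.coalesce().values()    # duplicates summed -> tensor([7., 5.])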
+- func: is_coalesced(Tensor self) -> bool
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse
+    CompositeExplicitAutograd: is_coalesced_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: _indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: _indices_sparse
+  device_check: NoCheck
+  device_guard: False
+
+- func: _values(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: _values_sparse
+  device_check: NoCheck
+  device_guard: False
+
+# This method doesn't do any check but only directly sets the flag. So it can be
+# a bit unsafe. Similar to _indices and _values, this is useful for implementing
+# custom sparse operations in Python/C++ extensions.
+- func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_
+  device_check: NoCheck
+  device_guard: False
+  autogen: _coalesced, _coalesced.out
+
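+# Illustration of the unsafe accessors mentioned above; these underscore methods
+# are internal and may change, so treat this as a hypothetical sketch:
+#   s = torch.sparse_coo_tensor(torch.tensor([[0, 0], [1, 1]]),
+#                               torch.tensor([1., 2.]), (2, 2))
+#   s._indices(); s._values()   # usable even though s is not coalesced
+#   # s.indices() / s.values() would raise until s.coalesce() is called
+#   s._coalesced_(True)         # only flips the flag; nothing is verified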
+- func: indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: indices_sparse
+    CompositeExplicitAutograd: indices_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: values(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: values_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: values_nested
+    CompositeExplicitAutograd: values_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: crow_indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: crow_indices_sparse_csr
+    CompositeExplicitAutograd: crow_indices_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: col_indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: col_indices_sparse_csr
+    CompositeExplicitAutograd: col_indices_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: ccol_indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ccol_indices_sparse_csr
+    CompositeExplicitAutograd: ccol_indices_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: row_indices(Tensor(a) self) -> Tensor(a)
+  variants: method
+  dispatch:
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: row_indices_sparse_csr
+    CompositeExplicitAutograd: row_indices_default
+  device_check: NoCheck
+  device_guard: False
+
+- func: hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    SparseCPU: hspmm_out_sparse_cpu
+    SparseCUDA: hspmm_out_sparse_cuda
+
+- func: hspmm(Tensor mat1, Tensor mat2) -> Tensor
+  dispatch:
+    SparseCPU: hspmm_sparse_cpu
+    SparseCUDA: hspmm_sparse_cuda
+
+- func: copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
+  device_check: NoCheck  # Allows copy into different device
+  variants: function
+  dispatch:
+    SparseCPU, SparseCUDA, SparseMeta: copy_sparse_
+  autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out
+
+# Adding AutogradNestedTensor makes this function CompositeImplicit-like for nested tensors
+- func: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: unbind
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_unbind
+
+- func: unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]
+  variants: function, method
+
+- func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+  variants: method
+
+# Special case of to_sparse.sparse_dim with custom derivative
+- func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse
+    SparseCPU, SparseCUDA: sparse_coo_to_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
+  autogen: _to_sparse.sparse_dim_out
+
+- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+  variants: method
+
+# Special case of to_sparse with custom derivative
+- func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse
+    SparseCPU, SparseCUDA: sparse_coo_to_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
+  autogen: _to_sparse.out
+
+- func: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+  variants: method
+
+# Special case of to_sparse_csr with custom derivative
+- func: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse_csr
+    SparseCPU, SparseCUDA: coo_to_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_csr
+  autogen: _to_sparse_csr.out
+
+- func: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+  variants: method
+
+# Special case of to_sparse_csc with custom derivative
+- func: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse_csc
+    SparseCPU, SparseCUDA: coo_to_sparse_csc
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_csc
+  autogen: _to_sparse_csc.out
+
+- func: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  variants: method
+
+# Special case of to_sparse_bsr with custom derivative
+- func: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse_bsr
+    SparseCPU, SparseCUDA: coo_to_sparse_bsr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_bsr
+  autogen: _to_sparse_bsr.out
+
+- func: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  variants: method
+
+# Special case of to_sparse_bsc with custom derivative
+- func: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU, CUDA: dense_to_sparse_bsc
+    SparseCPU, SparseCUDA: coo_to_sparse_bsc
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse_bsc
+  autogen: _to_sparse_bsc.out
+
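+# Layout-conversion sketch for the to_sparse* methods above (illustrative only):
+#   d = torch.eye(4)
+#   coo = d.to_sparse()                         # COO (default layout)
+#   csr = d.to_sparse_csr()                     # compressed sparse row
+#   bsr = csr.to_sparse_bsr(blocksize=(2, 2))   # CSR -> BSR with 2x2 blocks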
+- func: _to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CUDA: _to_sparse_semi_structured
+
+- func: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
+  variants: method
+  dispatch:
+    CPU: dense_to_mkldnn
+  autogen: to_mkldnn.out
+
+- func: mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor
+  variants: function
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_reorder_conv2d_weight
+  autogen: mkldnn_reorder_conv2d_weight.out
+
+- func: mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor
+  variants: function
+  python_module: nn
+  dispatch:
+    MkldnnCPU: mkldnn_reorder_conv3d_weight
+  autogen: mkldnn_reorder_conv3d_weight.out
+
+- func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor
+
+- func: quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: quantize_per_tensor_dynamic
+  autogen: quantize_per_tensor_dynamic.out
+
+- func: quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: quantize_per_tensor
+  autogen: quantize_per_tensor.out
+
+- func: quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: quantize_per_tensor_tensor_qparams
+  autogen: quantize_per_tensor.tensor_qparams_out
+
+- func: quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]
+  variants: function
+  dispatch:
+    CPU: quantize_per_tensor_list_cpu
+  autogen: quantize_per_tensor.tensors_out
+
+- func: quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: quantize_per_channel
+  autogen: quantize_per_channel.out
+
+- func: dequantize.self(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    CPU, CUDA: dequantize_cpu_or_cuda
+    QuantizedCPU, QuantizedCUDA: dequantize_quantized
+  autogen: dequantize.self_out
+
+- func: dequantize.tensors(Tensor[] tensors) -> Tensor[]
+  variants: function
+  dispatch:
+    QuantizedCPU: dequantize_tensors_quantized_cpu
+  autogen: dequantize.tensors_out
+
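+# Round-trip sketch for quantize_per_tensor / dequantize (illustrative only):
+#   x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
+#   q = torch.quantize_per_tensor(x, scale=0.1, zero_point=10, dtype=torch.quint8)
+#   q.int_repr()     # underlying uint8 codes: 0, 10, 20, 30
+#   q.dequantize()   # recovers x up to quantization error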
+- func: q_scale(Tensor self) -> float
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: q_scale_quant
+
+- func: q_zero_point(Tensor self) -> int
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: q_zero_point_quant
+
+- func: q_per_channel_scales(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: q_per_channel_scales
+  autogen: q_per_channel_scales.out
+
+- func: q_per_channel_zero_points(Tensor self) -> Tensor
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: q_per_channel_zero_points
+  autogen: q_per_channel_zero_points.out
+
+- func: q_per_channel_axis(Tensor self) -> int
+  variants: function, method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: q_per_channel_axis
+
+- func: int_repr(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    QuantizedCPU: int_repr_quantized_cpu
+    QuantizedCUDA: int_repr_quantized_cuda
+  autogen: int_repr.out
+
+- func: _make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor
+  dispatch:
+    CPU: make_per_tensor_quantized_tensor_cpu
+    CUDA: make_per_tensor_quantized_tensor_cuda
+  autogen: _make_per_tensor_quantized_tensor.out
+
+- func: _make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor
+  dispatch:
+    CPU: make_per_channel_quantized_tensor_cpu
+    CUDA: make_per_channel_quantized_tensor_cuda
+  autogen: _make_per_channel_quantized_tensor.out
+
+- func: qscheme(Tensor self) -> QScheme
+  variants: method
+  dispatch:
+    QuantizedCPU, QuantizedCUDA: qscheme_quant
+
+- func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
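+# Usage sketch for fake_quantize_per_tensor_affine (illustrative only): the result
+# stays floating point but is rounded to the scale grid and clamped to the range:
+#   x = torch.randn(4)
+#   torch.fake_quantize_per_tensor_affine(x, 0.1, 0, 0, 255)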
+- func: fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  variants: function
+  dispatch:
+    CPU, CUDA: fake_quantize_per_tensor_affine_cachemask
+  autogen: fake_quantize_per_tensor_affine_cachemask.out
+
+- func: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  variants: function
+  dispatch:
+    CPU, CUDA: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams
+  autogen: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out
+
+- func: fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+  variants: function
+
+- func: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: _fake_quantize_learnable_per_tensor_affine
+  autogen: _fake_quantize_learnable_per_tensor_affine.out
+
+- func: _fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU, CUDA: _fake_quantize_learnable_per_tensor_affine_backward
+
+- func: fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  variants: function
+  dispatch:
+    CPU, CUDA: fake_quantize_per_channel_affine_cachemask
+  autogen: fake_quantize_per_channel_affine_cachemask.out
+
+- func: fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor
+  variants: function
+
+- func: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: _fake_quantize_learnable_per_channel_affine
+  autogen: _fake_quantize_learnable_per_channel_affine.out
+
+- func: _fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU, CUDA: _fake_quantize_learnable_per_channel_affine_backward
+
+- func: fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor
+  variants: function
+
+- func: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)
+  dispatch:
+    CPU: fused_moving_avg_obs_fake_quant_cpu
+    CUDA: fused_moving_avg_obs_fake_quant_cuda
+  autogen: _fused_moving_avg_obs_fq_helper_functional, _fused_moving_avg_obs_fq_helper.out
+
+- func: _choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)
+  variants: function
+
+- func: _saturate_weight_to_fp16(Tensor weight) -> Tensor
+  variants: function
+
+- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)
+  variants: function
+
+- func: _autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)
+  variants: method
+  device_guard: False
+
+- func: _autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)
+  variants: method
+  device_guard: False
+
+- func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: _to_copy
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _to_copy_nested
+  autogen: _to_copy.out
+  tags: core
+
+# to(Device) must not exist because all constructors of Device also work for
+# TensorOptions. Otherwise, an ambiguity error is thrown.
+# See NOTE [ TensorOptions Constructors ].
+- func: to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+
+- func: to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+
+- func: to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+
+- func: to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+
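+# Usage sketch for the to overloads above; which overload each call resolves to
+# is an informal reading of the signatures, not a guarantee:
+#   x = torch.ones(2)
+#   x.to(torch.float64)                     # to.dtype
+#   x.to("cpu", torch.float16)              # to.device (Device + ScalarType)
+#   x.to(device="cpu", dtype=torch.int32)   # to.dtype_layout (keyword options)
+#   x.to(torch.empty(2, dtype=torch.int64)) # to.other (match another tensor)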
+- func: meshgrid(Tensor[] tensors) -> Tensor[]
+
+# TODO: Two weeks after this lands, combine these two overloads,
+#       making "indexing" optional. These are temporarily distinct for
+#       forward-compatibility reasons.
+- func: meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]
+
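+# Usage sketch via the public torch.meshgrid wrapper (illustrative only):
+#   xs, ys = torch.meshgrid(torch.arange(3), torch.arange(2), indexing="ij")
+#   xs.shape, ys.shape   # both torch.Size([3, 2])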
+- func: cartesian_prod(Tensor[] tensors) -> Tensor
+  variants: function
+  tags: maybe_aliasing_or_mutating
+
+- func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor
+  variants: function
+
+- func: item(Tensor self) -> Scalar
+  tags: data_dependent_output
+  variants: method
+
+- func: result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType
+  variants: function
+
+- func: result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType
+  variants: function
+
+- func: result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType
+  variants: function
+
+- func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType
+
+- func: can_cast(ScalarType from_, ScalarType to) -> bool
+  variants: function
+
+- func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType
+  variants: function
+
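+# Usage sketch for the type-promotion helpers above (illustrative only):
+#   torch.result_type(torch.ones(1, dtype=torch.int32), 2.0)   # torch.float32
+#   torch.promote_types(torch.uint8, torch.float32)            # torch.float32
+#   torch.can_cast(torch.float64, torch.int32)                 # False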
+# NB: Does NOT check precondition that numel == 1
+- func: _local_scalar_dense(Tensor self) -> Scalar
+  tags: [core, data_dependent_output]
+  dispatch:
+    CPU: _local_scalar_dense_cpu
+    CUDA: _local_scalar_dense_cuda
+    MPS: _local_scalar_dense_mps
+  variants: function
+
+# MPS LSTM implementation
+
+- func: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    MPS: _lstm_mps
+  autogen: _lstm_mps.out
+  tags: nondeterministic_seeded
+
+- func: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+  dispatch:
+    MPS: lstm_mps_backward
+  autogen: lstm_mps_backward.out
+
+
+# Fused RNN kernels
+- func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _thnn_fused_lstm_cell_cuda
+  autogen: _thnn_fused_lstm_cell.out
+
+# NB: The composite version of this function below is a simple wrapper that duplicates some of the outputs
+#     It is necessary to avoid triggering TensorImpl use count checks in debug mode
+# NB: this function is NOT differentiable
+- func: _thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _thnn_fused_lstm_cell_backward_impl_cuda
+  autogen: _thnn_fused_lstm_cell_backward_impl.out
+
+- func: _thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+
+- func: _thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+
+- func: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)
+  dispatch:
+    CUDA: _thnn_fused_gru_cell_cuda
+  autogen: _thnn_fused_gru_cell.out
+
+- func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _thnn_fused_gru_cell_backward_cuda
+  autogen: _thnn_fused_gru_cell_backward.out
+
+- func: _thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+
+# RNN cells and layers
+- func: lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+  tags: nondeterministic_seeded
+
+- func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
+
+- func: gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+
+- func: rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+
+- func: rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor
+
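+# Usage sketch: the public torch.nn.LSTMCell module is, as far as this schema is
+# concerned, a thin wrapper over lstm_cell above (treat the routing as an assumption):
+#   cell = torch.nn.LSTMCell(input_size=10, hidden_size=20)
+#   x = torch.randn(3, 10)
+#   h0, c0 = torch.randn(3, 20), torch.randn(3, 20)
+#   h1, c1 = cell(x, (h0, c0))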
+# Quantized RNN layer registration has been moved to C10 dispatch in `RNN.cpp`
+
+# Quantized RNN layers
+# - func: quantized_lstm(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)
+
+
+# - func: quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)
+
+
+# Quantized GRU layers
+
+# - func: quantized_gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)
+#
+
+# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
+#
+
+# Quantized RNN cells
+- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)
+
+- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
+
+# PackedSequence utilities
+- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
+  dispatch:
+    CompositeExplicitAutograd: _pack_padded_sequence
+  autogen: _pack_padded_sequence.out
+
+- func: _pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: _pack_padded_sequence_backward_symint
+
+- func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)
+
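+# Usage sketch via the public rnn utilities, which are assumed to wrap the
+# _pack_padded_sequence / _pad_packed_sequence primitives above:
+#   from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+#   x = torch.randn(4, 2, 5)                       # (T, B, F); valid lengths 4 and 2
+#   packed = pack_padded_sequence(x, lengths=[4, 2])
+#   padded, lengths = pad_packed_sequence(packed)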
+# wrappers for legacy TH methods
+
+- func: set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, Meta, MPS: set_
+  autogen: set.source_Storage, set.source_Storage_out
+  tags: inplace_view
+
+- func: set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU: set_storage_cpu_
+    Meta: set_storage_meta__symint
+    CUDA: set_storage_cuda_
+    MPS: set_storage_mps_
+    QuantizedCPU, QuantizedCUDA: set_storage_quantized_
+  autogen: set.source_Storage_storage_offset, set.source_Storage_storage_offset_out
+  tags: inplace_view
+
+- func: set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: set__symint
+  tags: inplace_view
+
+- func: set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, Meta, MPS: set_tensor_
+  autogen: set.source_Tensor, set.source_Tensor_out
+  tags: inplace_view
+
+- func: set_(Tensor(a!) self) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU: set_cpu_
+    CUDA: set_cuda_
+    Meta: set_meta_
+    MPS: set_mps_
+  autogen: set, set.out
+  tags: inplace_view
+
+# Not making it CompositeImplicitAutograd because lift
+# should be a primitive w.r.t. functorch
+
+# TODO: this should have a view annotation
+# TODO: shouldn't be a method
+- func: lift(Tensor self) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: lift
+  autogen: lift.out
+
+# lift_fresh is called with an argument that is guaranteed to be
+# fresh (i.e., newly allocated).  This is ONLY called from a
+# torch.tensor call; if you FX trace a lift_fresh, you are obligated
+# to convert this into a lift_fresh_copy (because FX will violate the
+# freshness invariant when tracing).
+- func: lift_fresh(Tensor(a) self) -> Tensor(a)
+  dispatch:
+    CompositeExplicitAutograd: lift_fresh
+
+# Like lift, but it clones the input.
+- func: lift_fresh_copy(Tensor self) -> Tensor
+  tags: view_copy
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: lift_fresh_copy
+  autogen: lift_fresh_copy.out
+
+- func: is_set_to(Tensor self, Tensor tensor) -> bool
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, MPS: is_set_to
+
+- func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU: masked_fill__cpu
+    CUDA: masked_fill__cuda
+    QuantizedCPU: masked_fill__quantized_cpu
+    QuantizedCUDA: masked_fill__quantized_cuda
+    MPS: masked_fill__mps
+  autogen: masked_fill.Scalar_out
+
+- func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: masked_fill
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_masked_fill
+  tags: pointwise
+
+- func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU: masked_fill__cpu
+    CUDA: masked_fill__cuda
+    QuantizedCPU: masked_fill__quantized_cpu
+    QuantizedCUDA: masked_fill__quantized_cuda
+    MPS: masked_fill__mps
+  autogen: masked_fill.Tensor_out
+
+- func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: masked_fill
+
+- func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU: masked_scatter__cpu
+    CUDA: masked_scatter__cuda
+    MPS: masked_scatter__mps
+  autogen: masked_scatter.out
+
+- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: masked_scatter
+  tags: core
+
+- func: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: masked_scatter_backward_symint
+
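+# Usage sketch for masked_fill / masked_scatter (illustrative only):
+#   x = torch.zeros(2, 3)
+#   mask = torch.tensor([[True, False, True], [False, True, False]])
+#   x.masked_fill(mask, 1.0)                   # scalar written at masked positions
+#   x.masked_scatter(mask, torch.arange(3.))   # source values copied at masked positions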
+- func: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor
+  dispatch:
+    CUDA: masked_softmax_cuda
+    CPU: masked_softmax_cpu
+  autogen: _masked_softmax.out
+
+- func: _masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor
+  dispatch:
+    CUDA: masked_softmax_backward_cuda
+    CPU: masked_softmax_backward_cpu
+  autogen: _masked_softmax_backward.out
+
+- func: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    ZeroTensor, Meta, CPU, CUDA, QuantizedCPU, QuantizedCUDA, MPS, MTIA: view
+    MkldnnCPU: mkldnn_view
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: view_nested
+  tags: core
+
+# Warning: If you want to change the name or overload name of this
+# operator, you might also want to change the `isBlockListedSchema`
+# function in `torch/csrc/jit/frontend/schema_matching.cpp`.
+# The name and overload name of this operator are hardcoded in that
+# function in order to work around a bug:
+# https://github.com/pytorch/pytorch/issues/47964
+- func: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: view_dtype
+
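+# Usage sketch for view and view.dtype (illustrative only):
+#   x = torch.arange(6.)
+#   x.view(2, 3)          # same storage, new shape
+#   x.view(torch.int32)   # view.dtype: reinterpret the same bytes (same element size)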
+- func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU, CUDA: put_
+  autogen: put.out
+
+- func: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: put
+
+- func: index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    CPU: index_add_cpu_out
+    CUDA: index_add_cuda_out
+    MPS: index_add_mps_out
+
+- func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)
+  structured_delegate: index_add.out
+  variants: method
+
+- func: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+  structured_delegate: index_add.out
+  variants: function, method
+
+- func: index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+  variants: function, method
+
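+# Usage sketch for index_add (illustrative only):
+#   t = torch.zeros(3, 2)
+#   t.index_add(0, torch.tensor([0, 2]), torch.ones(2, 2), alpha=2.0)
+#   # rows 0 and 2 each receive +2.0; row 1 is untouched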
+- func: index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    CPU: index_reduce_cpu_out
+    CUDA: index_reduce_cuda_out
+
+- func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)
+  structured_delegate: index_reduce.out
+  variants: method
+
+- func: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
+  structured_delegate: index_reduce.out
+  variants: function, method
+
+- func: index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU: index_fill_
+    CUDA: index_fill_
+    MPS: index_fill_mps_
+  autogen: index_fill.int_Scalar_out
+
+- func: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: index_fill
+
+- func: index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA: index_fill_
+    MPS: index_fill_mps_
+  autogen: index_fill.int_Tensor_out
+
+- func: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  dispatch:
+    CompositeExplicitAutograd: index_fill
+
+- func: index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+
+- func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+  structured_delegate: scatter.src_out
+  variants: function, method
+  tags: core
+
+- func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+  structured_delegate: scatter.src_out
+  variants: method
+
+- func: scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA: scatter_src_out
+    MPS: scatter_src_out_mps
+
+- func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+  structured_delegate: scatter.value_out
+  variants: function, method
+  tags: core
+
+- func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
+  structured_delegate: scatter.value_out
+  variants: method
+
+- func: scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA: scatter_value_out
+    MPS: scatter_value_out_mps
+
+- func: scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor
+  structured_delegate: scatter.reduce_out
+  variants: function, method
+
+- func: scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)
+  structured_delegate: scatter.reduce_out
+  variants: method
+
+- func: scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA: scatter_reduce_out
+    MPS: scatter_reduce_out_mps
+
+- func: scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor
+  structured_delegate: scatter.value_reduce_out
+  variants: function, method
+
+- func: scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
+  structured_delegate: scatter.value_reduce_out
+  variants: method
+
+- func: scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA: scatter_value_reduce_out
+    MPS: scatter_value_reduce_out_mps
+
+- func: scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+  variants: function, method
+
+- func: scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor
+  variants: function, method
+
+- func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+  structured_delegate: scatter_add.out
+  variants: function, method
+  tags: core
+
+- func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
+  structured_delegate: scatter_add.out
+  variants: method
+
+- func: scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA: scatter_add
+    MPS: scatter_add_mps_out
+
+- func: scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor
+  variants: function, method
+
+- func: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor
+  structured_delegate: scatter_reduce.two_out
+  variants: function, method
+  tags: core
+
+- func: scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)
+  structured_delegate: scatter_reduce.two_out
+  variants: method
+
+- func: scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: scatter_reduce_two
+
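+# Usage sketch for scatter_add / scatter_reduce (illustrative only):
+#   src = torch.tensor([1., 2., 3., 4.])
+#   index = torch.tensor([0, 1, 0, 1])
+#   torch.zeros(2).scatter_add(0, index, src)                    # tensor([4., 6.])
+#   torch.zeros(2).scatter_reduce(0, index, src, reduce="amax")  # tensor([3., 4.])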
+- func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: eq.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: eq.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  variants: function
+  dispatch:
+    CPU, CUDA, MTIA: bitwise_and_out
+    MPS: bitwise_and_out_mps
+  tags: pointwise
+
+- func: bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_and_out
+  tags: pointwise
+
+- func: bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_and
+  tags: [core, pointwise]
+
+- func: bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_and
+  autogen: bitwise_and.Scalar_Tensor_out
+  tags: pointwise
+
+- func: bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: bitwise_and.Tensor_out
+  tags: [core, pointwise]
+
+- func: bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: bitwise_and_
+  tags: pointwise
+
+- func: bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: bitwise_and.Tensor_out
+  tags: pointwise
+
+- func: __and__.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: __and__.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  variants: function
+  dispatch:
+    CPU, CUDA, MTIA: bitwise_or_out
+    MPS: bitwise_or_out_mps
+  tags: pointwise
+
+- func: bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_or_out
+  tags: pointwise
+
+- func: bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_or
+  tags: [core, pointwise]
+
+- func: bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_or
+  autogen: bitwise_or.Scalar_Tensor_out
+  tags: pointwise
+
+- func: bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: bitwise_or.Tensor_out
+  tags: [core, pointwise]
+
+- func: bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: bitwise_or_
+  tags: pointwise
+
+- func: bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: bitwise_or.Tensor_out
+  tags: pointwise
+
+- func: __or__.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: __or__.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  variants: function
+  dispatch:
+    CPU, CUDA: bitwise_xor_out
+    MPS: bitwise_xor_out_mps
+  tags: pointwise
+
+- func: bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_xor_out
+  tags: pointwise
+
+- func: bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_xor
+  tags: [core, pointwise]
+
+- func: bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_xor
+  autogen: bitwise_xor.Scalar_Tensor_out
+  tags: pointwise
+
+- func: bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: bitwise_xor.Tensor_out
+  tags: [core, pointwise]
+
+- func: bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: bitwise_xor_
+  tags: pointwise
+
+- func: bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: bitwise_xor.Tensor_out
+  tags: pointwise
+
+- func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: pointwise
+
+- func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: pointwise
+
+- func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA, MPS: __lshift__
+  tags: pointwise
+
+- func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA, MPS: __lshift__
+  tags: pointwise
+
+- func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA, MPS: __ilshift__
+  autogen: __lshift__.Scalar_out
+  tags: pointwise
+
+- func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA, MPS: __ilshift__
+  autogen: __lshift__.Tensor_out
+  tags: pointwise
+
+- func: bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: bitwise_left_shift.Tensor_out
+  tags: pointwise
+
+- func: bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: bitwise_left_shift.Tensor_out
+  tags: pointwise
+
+- func: bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: bitwise_left_shift_out
+  tags: pointwise
+
+- func: bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_left_shift
+  tags: pointwise
+
+- func: bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: bitwise_left_shift_
+  tags: pointwise
+
+- func: bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_left_shift_out
+  tags: pointwise
+
+- func: bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_left_shift
+  autogen: bitwise_left_shift.Scalar_Tensor_out
+  tags: pointwise
+
+- func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA, MPS: __rshift__
+  tags: pointwise
+
+- func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA, MPS: __rshift__
+  tags: pointwise
+
+- func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA, MPS: __irshift__
+  autogen: __rshift__.Scalar_out
+
+- func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CPU, CUDA, MPS: __irshift__
+  autogen: __rshift__.Tensor_out
+
+- func: bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: bitwise_right_shift.Tensor_out
+  tags: pointwise
+
+- func: bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: bitwise_right_shift.Tensor_out
+  tags: pointwise
+
+- func: bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: bitwise_right_shift_out
+  tags: pointwise
+
+- func: bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_right_shift
+  tags: pointwise
+
+- func: bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: bitwise_right_shift_
+  tags: pointwise
+
+- func: bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_right_shift_out
+  tags: pointwise
+
+- func: bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: bitwise_right_shift
+  autogen: bitwise_right_shift.Scalar_Tensor_out
+  tags: pointwise
+
+- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+  structured_delegate: tril.out
+  variants: method
+
+- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)
+  structured_delegate: triu.out
+  variants: method
+
+- func: digamma_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: digamma.out
+  variants: method
+  tags: pointwise
+
+- func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: lerp.Scalar_out
+  tags: pointwise
+
+- func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: lerp.Tensor_out
+  tags: pointwise
+
+- func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CPU, CUDA, XPU: addbmm_
+    MPS: addbmm_mps_
+
+- func: addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, XPU: addbmm_out
+    MPS: addbmm_out_mps
+
+- func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA, XPU: addbmm
+    MPS: addbmm_mps
+
+- func: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: random_
+    Meta: random_meta_
+    MPS: random_mps_
+  autogen: random.from, random.from_out
+
+- func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: random_
+    Meta: random_meta_
+    MPS: random_mps_
+  autogen: random.to, random.to_out
+
+- func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: random_
+    MPS: random_mps_
+    Meta: random_meta_
+  autogen: random, random.out
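+# Illustrative note: the random_ overloads above are in-place RNG fills such
+# as `t.random_(0, 10)`; the nondeterministic_seeded tag marks them as
+# RNG-dependent, and autogen produces functional/out counterparts that the
+# functionalization pass can use instead of mutating `self`.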
+
+- func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: uniform_
+    MPS: uniform_mps_
+    Meta: uniform_meta_
+  autogen: uniform, uniform.out
+
+- func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: cauchy_
+  autogen: cauchy, cauchy.out
+
+- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: log_normal_
+  autogen: log_normal, log_normal.out
+
+- func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: exponential_
+    MPS: exponential_mps_
+  autogen: exponential, exponential.out
+
+- func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: geometric_
+  autogen: geometric, geometric.out
+
+# wrappers for TH functions
+- func: diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: diag(Tensor self, int diagonal=0) -> Tensor
+  variants: method, function
+
+- func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor
+  variants: method, function
+
+- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: triu_cpu
+    CUDA: triu_cuda
+    MPS: triu_mps_out
+
+- func: triu(Tensor self, int diagonal=0) -> Tensor
+  structured_delegate: triu.out
+  variants: method, function
+
+- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: tril_cpu
+    CUDA: tril_cuda
+    MPS: tril_mps_out
+
+- func: tril(Tensor self, int diagonal=0) -> Tensor
+  structured_delegate: tril.out
+  variants: method, function
+
+- func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CPU: tril_indices_cpu
+    CUDA: tril_indices_cuda
+    MPS: tril_indices_mps
+  autogen: tril_indices.out
+
+- func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CPU: triu_indices_cpu
+    CUDA: triu_indices_cuda
+    MPS: triu_indices_mps
+  autogen: triu_indices.out
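+# Illustrative note: tril_indices/triu_indices return a 2 x N tensor of
+# row/column coordinates, e.g. torch.tril_indices(3, 3) lists the lower
+# triangle of a 3x3 matrix.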
+
+- func: trace(Tensor self) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU: trace_cpu
+    CUDA: trace_cuda
+    MPS: trace_mps
+  autogen: trace.out
+
+- func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: trace_backward_symint
+
+- func: ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: ne_Scalar_out
+    MPS: ne_scalar_out_mps
+    QuantizedCPU: ne_out_quantized_cpu
+  tags: pointwise
+
+- func: ne.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: ne.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: ne_quantized_cpu
+  tags: [core, pointwise]
+
+- func: ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: ne_Tensor_out
+    MPS: ne_tensor_out_mps
+    QuantizedCPU: ne_out_quantized_cpu
+  tags: pointwise
+
+- func: ne.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: ne.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: ne_quantized_cpu
+  tags: [core, pointwise]
+
+- func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: ne.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: ne.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+# not_equal, alias for torch.ne
+- func: not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: not_equal.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: not_equal.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
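+# Illustrative note: alias entries such as not_equal carry no dispatch section;
+# they are generated as thin wrappers, so torch.not_equal(a, b) behaves the
+# same as torch.ne(a, b).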
+
+- func: eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: eq_Scalar_out
+    MPS: eq_scalar_out_mps
+    QuantizedCPU: eq_out_quantized_cpu
+  tags: pointwise
+
+- func: eq.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: eq.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: eq_quantized_cpu
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: eq_scalar_nested
+  tags: [core, pointwise]
+
+- func: eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: eq_Tensor_out
+    MPS: eq_tensor_out_mps
+    QuantizedCPU: eq_out_quantized_cpu
+  tags: pointwise
+
+- func: eq.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: eq.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: eq_quantized_cpu
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: eq_tensor_nested
+  tags: [core, pointwise]
+
+- func: ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: ge_Scalar_out
+    MPS: ge_scalar_out_mps
+    QuantizedCPU: ge_out_quantized_cpu
+  tags: pointwise
+
+- func: ge.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: ge.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: ge_quantized_cpu
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: ge_scalar_nested
+  tags: [core, pointwise]
+
+- func: ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: ge_Tensor_out
+    MPS: ge_tensor_out_mps
+    QuantizedCPU: ge_out_quantized_cpu
+  tags: pointwise
+
+- func: ge.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: ge.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: ge_quantized_cpu
+  tags: [core, pointwise]
+
+- func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: ge.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: ge.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+# greater_equal, alias for torch.ge
+- func: greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: greater_equal.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: greater_equal.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: le_Scalar_out
+    MPS: le_scalar_out_mps
+    QuantizedCPU: le_out_quantized_cpu
+  tags: pointwise
+
+- func: le.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: le.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: le_quantized_cpu
+  tags: [core, pointwise]
+
+- func: le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: le_Tensor_out
+    MPS: le_tensor_out_mps
+    QuantizedCPU: le_out_quantized_cpu
+  tags: pointwise
+
+- func: le.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: le.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: le_quantized_cpu
+  tags: [core, pointwise]
+
+- func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: le.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: le.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+# less_equal, alias for torch.le
+- func: less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: less_equal.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: less_equal.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: gt_Scalar_out
+    MPS: gt_scalar_out_mps
+    QuantizedCPU: gt_out_quantized_cpu
+  tags: pointwise
+
+- func: gt.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: gt.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: gt_quantized_cpu
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: gt_scalar_nested
+  tags: [core, pointwise]
+
+- func: gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: gt_Tensor_out
+    MPS: gt_tensor_out_mps
+    QuantizedCPU: gt_out_quantized_cpu
+  tags: pointwise
+
+- func: gt.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: gt.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: gt_quantized_cpu
+  tags: [core, pointwise]
+
+- func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: gt.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: gt.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+# greater, alias for torch.gt
+- func: greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: greater.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: greater.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: lt_Scalar_out
+    MPS: lt_scalar_out_mps
+    QuantizedCPU: lt_out_quantized_cpu
+  tags: pointwise
+
+- func: lt.Scalar(Tensor self, Scalar other) -> Tensor
+  structured_delegate: lt.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: lt_quantized_cpu
+  tags: [core, pointwise]
+
+- func: lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: lt_Tensor_out
+    MPS: lt_tensor_out_mps
+    QuantizedCPU: lt_out_quantized_cpu
+  tags: pointwise
+
+- func: lt.Tensor(Tensor self, Tensor other) -> Tensor
+  structured_delegate: lt.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    QuantizedCPU: lt_quantized_cpu
+  tags: [core, pointwise]
+
+- func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  structured_delegate: lt.Scalar_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+- func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: lt.Tensor_out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+
+# less, alias for torch.lt
+- func: less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: less.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+
+- func: less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: less.Tensor(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+
+- func: less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: take_out
+
+- func: take(Tensor self, Tensor index) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA: take
+
+- func: take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
+  variants: method, function
+
+- func: index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, QuantizedCPU: index_select_out_cpu_
+    CUDA, QuantizedCUDA: index_select_out_cuda
+    MPS: index_select_out_mps
+
+- func: index_select(Tensor self, int dim, Tensor index) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU: index_select_cpu_
+    QuantizedCPU: index_select_quantized_cpu_
+    CUDA: index_select_cuda
+    QuantizedCUDA: index_select_quantized_cuda
+    SparseCPU: index_select_sparse_cpu
+    SparseCUDA: index_select_sparse_cuda
+    MPS: index_select_mps
+  tags: core
+
+- func: index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor
+  variants: method, function
+
+- func: index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: index_select_backward_symint
+
+- func: masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: masked_select_out_cpu
+    CUDA: masked_select_out_cuda
+    MPS: masked_select_out_mps
+  tags: dynamic_output_shape
+
+- func: masked_select(Tensor self, Tensor mask) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU: masked_select_cpu
+    CUDA: masked_select_cuda
+    MPS: masked_select_mps
+  tags: dynamic_output_shape
+
+- func: masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+
+- func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: nonzero_out_cpu
+    CUDA: nonzero_out_cuda
+    MPS: nonzero_out_mps
+  tags: dynamic_output_shape
+
+- func: nonzero(Tensor self) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU: nonzero_cpu
+    CUDA: nonzero_cuda
+    MPS: nonzero_mps
+  tags: [dynamic_output_shape, core]
+
+- func: nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: nonzero_static_out_cpu
+    CUDA: nonzero_static_out_cuda
+
+- func: nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU: nonzero_static_cpu
+    CUDA: nonzero_static_cuda
+
+- func: nonzero_numpy(Tensor self) -> Tensor[]
+  variants: method, function
+
+- func: argwhere(Tensor self) -> Tensor
+  variants: method, function
+  tags: dynamic_output_shape
+
+- func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU, CUDA: gather_out
+    MPS: gather_out_mps
+
+- func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+  variants: method, function
+  structured_delegate: gather.out
+  tags: core
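+# Illustrative note: for a 2-D input with dim=1, gather roughly computes
+# out[i][j] = self[i][index[i][j]], e.g. torch.gather(x, 1, idx).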
+
+- func: gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor
+  variants: function
+  device_check: NoCheck
+  device_guard: False
+
+- func: gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
+
+- func: gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+  variants: method, function
+
+- func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor
+
+- func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: addcmul_out
+    MPS: addcmul_out_mps
+  tags: pointwise
+
+- func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+  structured_delegate: addcmul.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+  structured_delegate: addcmul.out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: pointwise
+
+- func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: addcdiv_out
+    MPS: addcdiv_out_mps
+  tags: pointwise
+
+- func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+  structured_delegate: addcdiv.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)
+  structured_delegate: addcdiv.out
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  tags: pointwise
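+# Illustrative note: roughly, addcmul computes self + value * tensor1 * tensor2
+# and addcdiv computes self + value * tensor1 / tensor2 elementwise, e.g.
+# torch.addcdiv(x, a, b, value=0.5).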
+
+- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: cross_entropy_loss_symint
+
+- func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)
+  structured: True
+  dispatch:
+    CPU, CUDA: triangular_solve_out
+    MPS: triangular_solve_mps_out
+    SparseCsrCPU: triangular_solve_out_sparse_csr_cpu
+    SparseCsrCUDA: triangular_solve_out_sparse_csr_cuda
+
+- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
+  structured_delegate: triangular_solve.X
+  variants: method, function
+
+- func: _linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()
+  dispatch:
+    CompositeExplicitAutograd: _linalg_check_errors
+
+- func: linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  dispatch:
+    CPU, CUDA: linalg_solve_triangular_out
+    MPS: linalg_solve_triangular_mps_out
+
+- func: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_solve_triangular
+    MPS: linalg_solve_triangular_mps
+
+- func: linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor
+  python_module: linalg
+  dispatch:
+    CompositeImplicitAutograd: linalg_vander_symint
+
+- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
+
+- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
+  variants: method, function
+
+# swapaxes, alias for transpose
+- func: swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+
+# swapdims, alias for transpose
+- func: swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+
+- func: swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  tags: inplace_view
+
+- func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: cholesky_out
+
+- func: cholesky(Tensor self, bool upper=False) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA, MPS: cholesky
+
+- func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: cholesky_solve_out
+
+- func: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: cholesky_solve
+
+- func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _cholesky_solve_helper_cpu
+    CUDA: _cholesky_solve_helper_cuda
+  autogen: _cholesky_solve_helper.out
+
+- func: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA: cholesky_inverse
+
+- func: cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: cholesky_inverse_out
+
+- func: qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+
+- func: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)
+  variants: method, function
+
+- func: geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)
+  dispatch:
+    CPU, CUDA: geqrf_out
+
+- func: geqrf(Tensor self) -> (Tensor a, Tensor tau)
+  variants: method, function
+  dispatch:
+    CPU, CUDA: geqrf
+
+# orgqr, alias for linalg_householder_product
+- func: orgqr(Tensor self, Tensor input2) -> Tensor
+  variants: method, function
+
+- func: orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: ormqr_out
+
+- func: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA: ormqr
+
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
+  variants: function
+
+- func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
+  variants: method, function
+
+# lu_unpack
+- func: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+  structured_delegate: lu_unpack.out
+  variants: function
+
+- func: lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: lu_unpack_out
+    MPS: lu_unpack_out_mps
+
+# TODO: remove dispatch section when porting TH CUDA to ATen
+- func: multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: multinomial_out
+    MPS: multinomial_out_mps
+
+- func: multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, CUDA: multinomial
+    MPS: multinomial_mps
+  tags: nondeterministic_seeded
+
+- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: lgamma_out
+    MPS: lgamma_out_mps
+  tags: pointwise
+
+- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: lgamma.out
+  variants: method
+  tags: pointwise
+
+- func: lgamma(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: lgamma.out
+  variants: method, function
+  tags: pointwise
+
+- func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: digamma_out
+    MPS: digamma_out_mps
+  tags: pointwise
+
+- func: digamma(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: digamma.out
+  variants: method, function
+  tags: pointwise
+
+- func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: polygamma_out
+    MPS: polygamma_out_mps
+  tags: pointwise
+
+- func: polygamma(int n, Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: polygamma.out
+  variants: method, function
+  tags: pointwise
+
+- func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: polygamma_
+  tags: pointwise
+
+- func: erfinv(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erfinv.out
+  variants: method, function
+  dispatch:
+    SparseCPU, SparseCUDA: erfinv_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
+  tags: pointwise
+
+- func: erfinv_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: erfinv.out
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: erfinv_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
+  tags: pointwise
+
+- func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: erfinv_out
+    SparseCPU, SparseCUDA: erfinv_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
+  tags: pointwise
+
+- func: i0(Tensor self) -> Tensor
+  structured_delegate: i0.out
+  variants: function, method
+  tags: pointwise
+
+- func: i0_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: i0.out
+  variants: function, method
+  tags: pointwise
+
+- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: i0_out
+  tags: pointwise
+
+- func: sign(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sign.out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: sign_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr
+  tags: [core, pointwise]
+
+- func: sign_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: sign.out
+  variants: method
+  dispatch:
+    SparseCPU, SparseCUDA: sign_sparse_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_
+  tags: pointwise
+
+- func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: sign_out
+    MPS: sign_out_mps
+    SparseCPU, SparseCUDA: sign_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sign_sparse_csr_out
+  tags: pointwise
+
+- func: signbit(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: signbit.out
+  dispatch:
+    SparseCPU, SparseCUDA: signbit_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr
+  tags: pointwise
+
+- func: signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU: signbit_out
+    CUDA: signbit_out
+    MPS: signbit_out_mps
+    SparseCPU, SparseCUDA: signbit_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: signbit_sparse_csr_out
+  tags: pointwise
+
+- func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: dist
+  autogen: dist.out
+
+- func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: atan2_out
+    MPS: atan2_out_mps
+  tags: [core, pointwise]
+
+- func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: atan2.out
+  variants: method
+  tags: pointwise
+
+- func: atan2(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: atan2.out
+  variants: method, function
+  tags: [core, pointwise]
+
+# arctan2, alias of atan2
+- func: arctan2(Tensor self, Tensor other) -> Tensor
+  variants: method, function
+
+- func: arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+
+- func: arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  variants: method
+
+- func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: lerp_Scalar
+  tags: pointwise
+
+- func: lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: lerp_Tensor
+    MPS: lerp_Tensor_mps
+  tags: pointwise
+
+- func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: lerp.Scalar_out
+  tags: pointwise
+
+- func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: lerp.Tensor_out
+  tags: pointwise
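+# Illustrative note: lerp computes self + weight * (end - self), with weight
+# either a Scalar or a broadcastable tensor, e.g. torch.lerp(a, b, 0.5).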
+
+- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, MPS: histogram_histc_out
+    CUDA: _histc_out_cuda
+
+- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
+  variants: method, function
+  dispatch:
+    CPU, MPS: histogram_histc
+    CUDA: _histc_cuda
+
+- func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+  dispatch:
+    CPU, MPS: histogram_out
+
+- func: histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+  variants: method, function
+  dispatch:
+    CPU, MPS: histogram
+
+- func: histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
+  dispatch:
+    CPU, MPS: histogram_out
+
+- func: histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)
+  variants: method, function
+  dispatch:
+    CPU, MPS: histogram
+
+- func: _histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]
+  dispatch:
+    CPU, MPS: histogramdd_bin_edges
+  autogen: _histogramdd_bin_edges.out
+
+- func: _histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor
+  dispatch:
+    CPU, MPS: _histogramdd
+  autogen: _histogramdd_from_bin_cts.out
+
+- func: _histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor
+  dispatch:
+    CPU, MPS: _histogramdd
+  autogen: _histogramdd_from_bin_tensors.out
+
+- func: histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+
+- func: histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+
+- func: histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)
+
+- func: fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: fmod_out
+  tags: pointwise
+
+- func: fmod.Scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: fmod
+  tags: [core, pointwise]
+
+- func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: fmod_
+  tags: pointwise
+
+- func: fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: fmod_out
+  tags: pointwise
+
+- func: fmod.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: fmod.Tensor_out
+  variants: method, function
+  tags: [core, pointwise]
+
+- func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: fmod.Tensor_out
+  tags: pointwise
+
+- func: hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: hypot_out
+    MPS: hypot_out_mps
+  tags: pointwise
+
+- func: hypot(Tensor self, Tensor other) -> Tensor
+  structured_delegate: hypot.out
+  variants: method, function
+  tags: pointwise
+
+- func: hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: hypot.out
+  variants: method
+  tags: pointwise
+
+- func: igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: igamma_out
+  tags: pointwise
+
+- func: igamma(Tensor self, Tensor other) -> Tensor
+  structured_delegate: igamma.out
+  variants: method, function
+  tags: pointwise
+
+- func: igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: igamma.out
+  variants: method
+  tags: pointwise
+
+- func: igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: igammac_out
+  tags: pointwise
+
+- func: igammac(Tensor self, Tensor other) -> Tensor
+  structured_delegate: igammac.out
+  variants: method, function
+  tags: pointwise
+
+- func: igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: igammac.out
+  variants: method
+  tags: pointwise
+
+- func: nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: nextafter_out
+  tags: pointwise
+
+- func: nextafter(Tensor self, Tensor other) -> Tensor
+  structured_delegate: nextafter.out
+  variants: method, function
+  tags: pointwise
+
+- func: nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  structured_delegate: nextafter.out
+  variants: method
+  tags: pointwise
+
+- func: remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: remainder_out
+  tags: pointwise
+
+- func: remainder.Scalar(Tensor self, Scalar other) -> Tensor
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: remainder
+  tags: [core, pointwise]
+
+- func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  variants: method
+  dispatch:
+    CompositeExplicitAutograd: remainder_
+  tags: pointwise
+
+- func: remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS, MTIA: remainder_out
+  tags: pointwise
+
+- func: remainder.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: remainder.Tensor_out
+  variants: method, function
+  tags: [core, pointwise]
+
+- func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: remainder.Tensor_out
+  variants: method
+  tags: pointwise
+
+- func: remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: remainder
+  autogen: remainder.Scalar_Tensor_out
+  tags: pointwise
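+# Illustrative note: fmod above follows C fmod (the result takes the sign of
+# self), while remainder follows Python's % (sign of other), so e.g.
+# torch.fmod(torch.tensor(-3.), 2.) gives -1. and torch.remainder of the same
+# inputs gives 1.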
+
+- func: min(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA: min
+    MPS: min_mps
+    QuantizedCPU: min_quantized_cpu
+
+- func: min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: min_unary_out
+    QuantizedCPU: min_quantized_unary_out
+
+- func: fmin(Tensor self, Tensor other) -> Tensor
+  structured_delegate: fmin.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: fmin_out
+  tags: pointwise
+
+- func: max(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CPU, CUDA: max
+    MPS: max_mps
+    QuantizedCPU: max_quantized_cpu
+
+- func: fmax(Tensor self, Tensor other) -> Tensor
+  structured_delegate: fmax.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MPS: fmax_out
+  tags: pointwise
+
+- func: maximum(Tensor self, Tensor other) -> Tensor
+  structured_delegate: maximum.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: [core, pointwise]
+
+- func: maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: maximum_out
+    MPS: maximum_out_mps
+  tags: pointwise
+
+# binary max, alias of maximum
+# NOTE: max is not an alias for maximum, since there is also unary max
+- func: max.other(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
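+# Illustrative note: torch.max(a) hits the unary reduction above, while
+# torch.max(a, b) resolves to this max.other overload and matches
+# torch.maximum(a, b); the dim-wise max.dim overload is declared elsewhere in
+# this file.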
+
+- func: max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: pointwise
+
+- func: max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA: max_unary_out
+    QuantizedCPU: max_quantized_unary_out
+
+- func: minimum(Tensor self, Tensor other) -> Tensor
+  structured_delegate: minimum.out
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: [core, pointwise]
+
+- func: minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CPU, CUDA, MTIA: minimum_out
+    MPS: minimum_out_mps
+  tags: pointwise
+
+# binary min, alias for minimum
+# NOTE: min is not an alias for minimum, since there is also unary min
+- func: min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: pointwise
+
+- func: min.other(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  tags: pointwise
+
+- func: quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+  variants: method, function
+
+- func: quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+
+- func: quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+  variants: method, function
+
+- func: quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+
+- func: nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+  variants: method, function
+
+- func: nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+
+- func: nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor
+  variants: method, function
+
+- func: nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)
+
+- func: sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    CompositeExplicitAutograd: sort_out
+
+- func: sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  structured: True
+  dispatch:
+    CPU, CUDA: sort_stable_out
+    MPS: sort_stable_out_mps
+
+- func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: sort
+  tags: core
+
+- func: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+  structured_delegate: sort.values_stable
+  variants: method, function
+  dispatch:
+    QuantizedCPU: sort_quantized_cpu_stable
+
+- func: sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+
+- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+  variants: method, function
+
+- func: sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
+  variants: method, function
+
+- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: msort(Tensor self) -> Tensor
+  variants: method, function
+
+- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+
+- func: argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function
+
+- func: argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor
+  variants: method, function
+
+- func: topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  structured: True
+  dispatch:
+    CPU: topk_out_cpu
+    CUDA: topk_out_cuda
+    MPS: topk_out_mps
+
+- func: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+  variants: method, function
+  structured_delegate: topk.values
+  dispatch:
+    QuantizedCPU: topk_quantized_cpu
+  tags: core
+
+- func: all(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: all.all_out
+  variants: method, function
+
+- func: all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
+  dispatch:
+    CPU, CUDA: all_all_out
+    MPS: all_all_out_mps
+
+- func: any(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: any.all_out
+  variants: method, function
+  dispatch:
+    SparseCPU, SparseCUDA: any_sparse
+  tags: core
+
+- func: any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  structured: True
+  dispatch:
+    CPU, CUDA: any_all_out
+    MPS: any_all_out_mps
+
+- func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: renorm_out
+    MPS: renorm_out_mps
+
+- func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  variants: method, function
+  structured_delegate: renorm.out
+
+- func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: method
+  structured_delegate: renorm.out
+
+- func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
+  variants: method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CPU, CUDA, Meta, MPS, MTIA: unfold
+    QuantizedCPU, QuantizedCUDA: unfold
+
+- func: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: unfold_backward
+  autogen: unfold_backward.out
+
+- func: equal(Tensor self, Tensor other) -> bool
+  tags: [data_dependent_output, pointwise]
+  variants: method, function
+  dispatch:
+    CPU: cpu_equal
+    CUDA: cuda_equal
+    MPS: mps_equal
+    QuantizedCPU: equal_quantized_cpu
+
+- func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: pow_Tensor_Tensor_out
+    MPS: pow_tensor_tensor_out_mps
+  tags: pointwise
+
+- func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: pow.Tensor_Tensor_out
+  variants: method, function
+  tags: [core, pointwise]
+
+- func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  dispatch:
+    CPU, CUDA: pow_Scalar_out
+    MPS: pow_Scalar_out_mps
+  tags: pointwise
+
+- func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: pow.Scalar_out
+  tags: [core, pointwise]
+
+- func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: pow_Tensor_Scalar_out
+    SparseCPU, SparseCUDA: pow_out_sparse_scalar
+    MPS: pow_tensor_scalar_out_mps
+  tags: pointwise
+
+- func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: pow.Tensor_Scalar_out
+  variants: function, method
+  dispatch:
+    SparseCPU, SparseCUDA: pow_sparse_scalar
+  tags: [core, pointwise]
+
+- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: pow.Tensor_Scalar_out
+  variants: method
+  tags: pointwise
+
+- func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: pow.Tensor_Tensor_out
+  variants: method
+  tags: pointwise
+
+- func: float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  tags: pointwise
+
+- func: float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+  variants: function, method
+  tags: pointwise
+
+- func: float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)
+  tags: pointwise
+
+- func: float_power.Scalar(Scalar self, Tensor exponent) -> Tensor
+  tags: pointwise
+
+- func: float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)
+  tags: pointwise
+
+- func: float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+  variants: function, method
+  tags: pointwise
+
+- func: float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
+  variants: method
+  tags: pointwise
+
+- func: float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)
+  variants: method
+  tags: pointwise
+
+- func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  variants: method
+  dispatch:
+    CPU, CUDA: normal_
+    MPS: normal_mps_
+    Meta: normal_meta_
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: normal_sparse_csr_
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: normal_nested_
+  autogen: normal.out
+
+# Only used by the functionalization pass.
+# Normally, the codegen would be able to generate a normal() NativeFunction,
+# but we can't due to overload ambiguity with normal.Tensor_float.
+- func: normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  tags: nondeterministic_seeded
+  dispatch:
+    CompositeExplicitAutograd: normal_functional
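+# Illustrative note: the ambiguity is roughly that a functional
+# normal(Tensor self, float mean, float std) and normal.Tensor_float(Tensor
+# mean, float std) would both match a call like normal(t, 1.0), hence the
+# distinct normal_functional name.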
+
+- func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU, CUDA: normal_out
+    MPS: normal_mps_out
+    Meta: normal_out_meta
+
+- func: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
+  dispatch:
+    CPU, CUDA: normal
+    MPS: normal_mps
+    Meta: normal_meta
+  tags: nondeterministic_seeded
+
+- func: normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: normal_out
+    Meta: normal_out_meta
+    MPS: normal_mps_out
+  tags: nondeterministic_seeded
+
+- func: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
+  dispatch:
+    CPU, CUDA: normal
+    MPS: normal_mps
+    Meta: normal_meta
+  tags: nondeterministic_seeded
+
+- func: normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: normal_out
+    Meta: normal_out_meta
+    MPS: normal_mps_out
+  tags: nondeterministic_seeded
+
+- func: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
+  dispatch:
+    CPU, CUDA: normal
+    MPS: normal_mps
+    Meta: normal_meta
+  tags: nondeterministic_seeded
+
+- func: normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: normal
+  tags: nondeterministic_seeded
+
+- func: normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: normal_out
+  tags: nondeterministic_seeded
+
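+# Usage sketch (illustrative, not part of this schema file): the normal_ / normal
+# overloads above are reached from Tensor.normal_ and torch.normal, e.g.
+#   g = torch.Generator().manual_seed(0)
+#   torch.empty(3).normal_(mean=0.0, std=1.0, generator=g)    # normal_
+#   torch.normal(torch.zeros(3), torch.ones(3), generator=g)  # normal.Tensor_Tensor
+#   torch.normal(0.0, 1.0, size=(2, 3))                       # normal.float_float
+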
+- func: alias(Tensor(a) self) -> Tensor(a)
+  variants: method, function
+  dispatch:
+    CompositeExplicitAutograd: alias
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: alias_nested
+  tags: core
+
+- func: _amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()
+  variants: function
+  dispatch:
+    CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_
+    CPU: _amp_foreach_non_finite_check_and_unscale_cpu_
+    MPS: _amp_foreach_non_finite_check_and_unscale_mps_
+  autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out
+
+- func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _amp_update_scale_cuda_
+    CPU: _amp_update_scale_cpu_
+    MPS: _amp_update_scale_mps_
+  autogen: _amp_update_scale, _amp_update_scale.out
+
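+# Usage sketch (assumption based on the names, not part of this schema file): the two
+# _amp_* ops above appear to be the private kernels behind torch.amp.GradScaler's
+# unscale/update step; user code only touches the public wrapper, e.g.
+#   scaler = torch.amp.GradScaler("cuda")
+#   scaler.scale(loss).backward()
+#   scaler.step(optimizer)   # step is skipped if non-finite grads were found
+#   scaler.update()          # grows or backs off the loss scale
+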
+    #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor
+    #dispatch:
+    #CPU: _cat_cpu
+    #CUDA: cat_cuda
+    #MPS: cat_mps
+    #QuantizedCPU: cat_quantized_cpu
+
+    #- func: _cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)
+    #dispatch:
+    #CPU: _cat_out_cpu
+    #CUDA: cat_out_cuda
+    #QuantizedCPU: cat_out_quantized_cpu
+
+- func: _foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow
+    CUDA: foreach_tensor_add_scalar_kernel_cuda
+
+- func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
+    CUDA: foreach_tensor_add_scalar_kernel_cuda_
+  autogen: _foreach_add.Scalar_out
+
+- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
+    CUDA: foreach_tensor_add_list_kernel_cuda
+
+- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
+    CUDA: foreach_tensor_add_list_kernel_cuda_
+  autogen: _foreach_add.List_out
+
+- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow
+    CUDA: foreach_tensor_add_scalarlist_kernel_cuda
+
+- func: _foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_add_scalarlist_kernel_cuda_
+  autogen: _foreach_add.ScalarList_out
+
+- func: _foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow
+    CUDA: foreach_tensor_add_tensor_kernel_cuda
+
+- func: _foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
+    CUDA: foreach_tensor_add_tensor_kernel_cuda_
+  autogen: _foreach_add.Tensor_out
+
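+# Usage sketch (illustrative, not part of this schema file): the _foreach_* entries
+# in this section are the private multi-tensor ops exposed as torch._foreach_add and
+# friends (and used by the foreach=True optimizer paths); each call applies one op
+# across a whole list of tensors, e.g. (variable names are illustrative)
+#   params = [torch.randn(3) for _ in range(4)]
+#   grads  = [torch.randn(3) for _ in range(4)]
+#   torch._foreach_add_(params, grads, alpha=-0.1)   # in-place List overload
+#   sums = torch._foreach_add(params, 1.0)           # functional Scalar overload
+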
+- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow
+    CUDA: foreach_tensor_sub_scalar_kernel_cuda
+
+- func: _foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_scalar_kernel_slow_
+    CUDA: foreach_tensor_sub_scalar_kernel_cuda_
+  autogen: _foreach_sub.Scalar_out
+
+- func: _foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow
+    CUDA: foreach_tensor_sub_list_kernel_cuda
+
+- func: _foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_list_kernel_slow_
+    CUDA: foreach_tensor_sub_list_kernel_cuda_
+  autogen: _foreach_sub.List_out
+
+- func: _foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow
+    CUDA: foreach_tensor_sub_scalarlist_kernel_cuda
+
+- func: _foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sub_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_sub_scalarlist_kernel_cuda_
+  autogen: _foreach_sub.ScalarList_out
+
+- func: _foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow
+    CUDA: foreach_tensor_mul_scalar_kernel_cuda
+
+- func: _foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
+    CUDA: foreach_tensor_mul_scalar_kernel_cuda_
+  autogen: _foreach_mul.Scalar_out
+
+- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
+    CUDA: foreach_tensor_mul_list_kernel_cuda
+
+- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
+    CUDA: foreach_tensor_mul_list_kernel_cuda_
+  autogen: _foreach_mul.List_out
+
+- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow
+    CUDA: foreach_tensor_mul_scalarlist_kernel_cuda
+
+- func: _foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_mul_scalarlist_kernel_cuda_
+  autogen: _foreach_mul.ScalarList_out
+
+- func: _foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
+    CUDA: foreach_tensor_mul_tensor_kernel_cuda
+
+- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
+    CUDA: foreach_tensor_mul_tensor_kernel_cuda_
+  autogen: _foreach_mul.Tensor_out
+
+- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow
+    CUDA: foreach_tensor_div_scalar_kernel_cuda
+
+- func: _foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_scalar_kernel_slow_
+    CUDA: foreach_tensor_div_scalar_kernel_cuda_
+  autogen: _foreach_div.Scalar_out
+
+- func: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow
+    CUDA: foreach_tensor_div_list_kernel_cuda
+
+- func: _foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_list_kernel_slow_
+    CUDA: foreach_tensor_div_list_kernel_cuda_
+  autogen: _foreach_div.List_out
+
+- func: _foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow
+    CUDA: foreach_tensor_div_scalarlist_kernel_cuda
+
+- func: _foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_div_scalarlist_kernel_cuda_
+  autogen: _foreach_div.ScalarList_out
+
+- func: _foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow
+    CUDA: foreach_tensor_div_tensor_kernel_cuda
+
+- func: _foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_div_tensor_kernel_slow_
+    CUDA: foreach_tensor_div_tensor_kernel_cuda_
+  autogen: _foreach_div.Tensor_out
+
+- func: _foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow
+    CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda
+
+- func: _foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_
+  autogen: _foreach_clamp_max.Scalar_out
+
+- func: _foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow
+    CUDA: foreach_tensor_clamp_max_list_kernel_cuda
+
+- func: _foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_list_kernel_cuda_
+  autogen: _foreach_clamp_max.List_out
+
+- func: _foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow
+    CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda
+
+- func: _foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_
+  autogen: _foreach_clamp_max.ScalarList_out
+
+- func: _foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow
+    CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda
+
+- func: _foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_
+  autogen: _foreach_clamp_min.Scalar_out
+
+- func: _foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow
+    CUDA: foreach_tensor_clamp_min_list_kernel_cuda
+
+- func: _foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_list_kernel_cuda_
+  autogen: _foreach_clamp_min.List_out
+
+- func: _foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow
+    CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda
+
+- func: _foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_
+  autogen: _foreach_clamp_min.ScalarList_out
+
+# foreach_minimum/maximum dispatches to clamp_max/min
+- func: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow
+    CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda
+
+- func: _foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalar_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_scalar_kernel_cuda_
+  autogen: _foreach_maximum.Scalar_out
+
+# foreach_minimum/maximum dispatches to clamp_max/min
+- func: _foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow
+    CUDA: foreach_tensor_clamp_min_list_kernel_cuda
+
+- func: _foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_list_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_list_kernel_cuda_
+  autogen: _foreach_maximum.List_out
+
+# foreach_minimum/maximum dispatches to clamp_max/min
+- func: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow
+    CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda
+
+- func: _foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_min_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_clamp_min_scalarlist_kernel_cuda_
+  autogen: _foreach_maximum.ScalarList_out
+
+- func: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow
+    CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda
+
+- func: _foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalar_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_scalar_kernel_cuda_
+  autogen: _foreach_minimum.Scalar_out
+
+- func: _foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow
+    CUDA: foreach_tensor_clamp_max_list_kernel_cuda
+
+- func: _foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_list_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_list_kernel_cuda_
+  autogen: _foreach_minimum.List_out
+
+- func: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow
+    CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda
+
+- func: _foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_clamp_max_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_clamp_max_scalarlist_kernel_cuda_
+  autogen: _foreach_minimum.ScalarList_out
+
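+# Usage sketch (illustrative, not part of this schema file): element-wise list
+# max/min against a scalar, tensor list, or scalar list, e.g.
+#   torch._foreach_maximum_(params, 0.0)   # clamp every tensor below at 0, in place
+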
+- func: _foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow
+    CUDA: foreach_tensor_addcdiv_scalar_cuda
+
+- func: _foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow
+    CUDA: foreach_tensor_addcdiv_scalarlist_cuda
+
+- func: _foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow
+    CUDA: foreach_tensor_addcdiv_tensor_cuda
+
+- func: _foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_scalar_slow_
+    CUDA: foreach_tensor_addcdiv_scalar_cuda_
+  autogen: _foreach_addcdiv.Scalar_out
+
+- func: _foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_scalarlist_slow_
+    CUDA: foreach_tensor_addcdiv_scalarlist_cuda_
+  autogen: _foreach_addcdiv.ScalarList_out
+
+- func: _foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcdiv_tensor_slow_
+    CUDA: foreach_tensor_addcdiv_tensor_cuda_
+  autogen: _foreach_addcdiv.Tensor_out
+
+- func: _foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
+    CUDA: foreach_tensor_addcmul_scalar_cuda
+
+- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow
+    CUDA: foreach_tensor_addcmul_scalarlist_cuda
+
+- func: _foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow
+    CUDA: foreach_tensor_addcmul_tensor_cuda
+
+- func: _foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
+    CUDA: foreach_tensor_addcmul_scalar_cuda_
+  autogen: _foreach_addcmul.Scalar_out
+
+- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_scalarlist_slow_
+    CUDA: foreach_tensor_addcmul_scalarlist_cuda_
+  autogen: _foreach_addcmul.ScalarList_out
+
+- func: _foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_addcmul_tensor_slow_
+    CUDA: foreach_tensor_addcmul_tensor_cuda_
+  autogen: _foreach_addcmul.Tensor_out
+
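+# Usage sketch (illustrative, not part of this schema file): fused multi-tensor
+# addcmul/addcdiv of the kind used by foreach Adam-style updates, e.g. (variable
+# names are illustrative)
+#   torch._foreach_addcmul_(exp_avg_sqs, grads, grads, value=1 - 0.999)
+#   torch._foreach_addcdiv_(params, exp_avgs, denoms, value=-0.001)
+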
+- func: _foreach_abs(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_abs_slow
+    CUDA: foreach_tensor_abs_cuda
+
+- func: _foreach_abs_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_abs_slow_
+    CUDA: foreach_tensor_abs_cuda_
+  autogen: _foreach_abs.out
+
+- func: _foreach_acos(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_acos_slow
+    CUDA: foreach_tensor_acos_cuda
+
+- func: _foreach_acos_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_acos_slow_
+    CUDA: foreach_tensor_acos_cuda_
+  autogen: _foreach_acos.out
+
+- func: _foreach_asin(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_asin_slow
+    CUDA: foreach_tensor_asin_cuda
+
+- func: _foreach_asin_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_asin_slow_
+    CUDA: foreach_tensor_asin_cuda_
+  autogen: _foreach_asin.out
+
+- func: _foreach_atan(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_atan_slow
+    CUDA: foreach_tensor_atan_cuda
+
+- func: _foreach_atan_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_atan_slow_
+    CUDA: foreach_tensor_atan_cuda_
+  autogen: _foreach_atan.out
+
+- func: _foreach_ceil(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_ceil_slow
+    CUDA: foreach_tensor_ceil_cuda
+
+- func: _foreach_ceil_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_ceil_slow_
+    CUDA: foreach_tensor_ceil_cuda_
+  autogen: _foreach_ceil.out
+
+- func: _foreach_cos(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_cos_slow
+    CUDA: foreach_tensor_cos_cuda
+
+- func: _foreach_cos_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_cos_slow_
+    CUDA: foreach_tensor_cos_cuda_
+  autogen: _foreach_cos.out
+
+- func: _foreach_cosh(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_cosh_slow
+    CUDA: foreach_tensor_cosh_cuda
+
+- func: _foreach_cosh_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_cosh_slow_
+    CUDA: foreach_tensor_cosh_cuda_
+  autogen: _foreach_cosh.out
+
+- func: _foreach_erf(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_erf_slow
+    CUDA: foreach_tensor_erf_cuda
+
+- func: _foreach_erf_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_erf_slow_
+    CUDA: foreach_tensor_erf_cuda_
+  autogen: _foreach_erf.out
+
+- func: _foreach_erfc(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_erfc_slow
+    CUDA: foreach_tensor_erfc_cuda
+
+- func: _foreach_erfc_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_erfc_slow_
+    CUDA: foreach_tensor_erfc_cuda_
+  autogen: _foreach_erfc.out
+
+- func: _foreach_exp(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_exp_slow
+    CUDA: foreach_tensor_exp_cuda
+
+- func: _foreach_exp_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_exp_slow_
+    CUDA: foreach_tensor_exp_cuda_
+  autogen: _foreach_exp.out
+
+- func: _foreach_expm1(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_expm1_slow
+    CUDA: foreach_tensor_expm1_cuda
+
+- func: _foreach_expm1_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_expm1_slow_
+    CUDA: foreach_tensor_expm1_cuda_
+  autogen: _foreach_expm1.out
+
+- func: _foreach_floor(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_floor_slow
+    CUDA: foreach_tensor_floor_cuda
+
+- func: _foreach_floor_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_floor_slow_
+    CUDA: foreach_tensor_floor_cuda_
+  autogen: _foreach_floor.out
+
+- func: _foreach_frac(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_frac_slow
+    CUDA: foreach_tensor_frac_cuda
+
+- func: _foreach_frac_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_frac_slow_
+    CUDA: foreach_tensor_frac_cuda_
+  autogen: _foreach_frac.out
+
+- func: _foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow
+    CUDA: foreach_tensor_lerp_ternary_cuda
+  autogen: _foreach_lerp.List_out
+
+- func: _foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_ternary_lerp_slow_
+    CUDA: foreach_tensor_lerp_ternary_cuda_
+  autogen: _foreach_lerp.List_out
+
+- func: _foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow
+    CUDA: foreach_tensor_lerp_list_cuda
+  autogen: _foreach_lerp.Scalar_out
+
+- func: _foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_list_kernel_slow_
+    CUDA: foreach_tensor_lerp_list_cuda_
+  autogen: _foreach_lerp.Scalar_out
+
+- func: _foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow
+    CUDA: foreach_tensor_lerp_scalarlist_cuda
+  autogen: _foreach_lerp.ScalarList_out
+
+- func: _foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensors are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lerp_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_lerp_scalarlist_cuda_
+  autogen: _foreach_lerp.ScalarList_out
+
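+# Usage sketch (illustrative, not part of this schema file): multi-tensor lerp; an
+# EMA-style update exp_avg += (1 - beta1) * (grad - exp_avg) can be written as
+#   torch._foreach_lerp_(exp_avgs, grads, 1 - 0.9)   # Scalar weight overload
+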
+- func: _foreach_lgamma(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lgamma_slow
+    CUDA: foreach_tensor_lgamma_cuda
+
+- func: _foreach_lgamma_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_lgamma_slow_
+    CUDA: foreach_tensor_lgamma_cuda_
+  autogen: _foreach_lgamma.out
+
+- func: _foreach_log(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log_slow
+    CUDA: foreach_tensor_log_cuda
+
+- func: _foreach_log_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log_slow_
+    CUDA: foreach_tensor_log_cuda_
+  autogen: _foreach_log.out
+
+- func: _foreach_log10(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log10_slow
+    CUDA: foreach_tensor_log10_cuda
+
+- func: _foreach_log10_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log10_slow_
+    CUDA: foreach_tensor_log10_cuda_
+  autogen: _foreach_log10.out
+
+- func: _foreach_log1p(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log1p_slow
+    CUDA: foreach_tensor_log1p_cuda
+
+- func: _foreach_log1p_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log1p_slow_
+    CUDA: foreach_tensor_log1p_cuda_
+  autogen: _foreach_log1p.out
+
+- func: _foreach_log2(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log2_slow
+    CUDA: foreach_tensor_log2_cuda
+
+- func: _foreach_log2_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_log2_slow_
+    CUDA: foreach_tensor_log2_cuda_
+  autogen: _foreach_log2.out
+
+- func: _foreach_max(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_max_slow
+    CUDA: foreach_tensor_max_cuda
+  autogen: _foreach_max.out
+
+- func: _foreach_neg(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_neg_slow
+    CUDA: foreach_tensor_neg_cuda
+
+- func: _foreach_neg_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_neg_slow_
+    CUDA: foreach_tensor_neg_cuda_
+  autogen: _foreach_neg.out
+
+- func: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_norm_slow
+    CUDA: foreach_tensor_norm_cuda
+  autogen: _foreach_norm.Scalar_out
+
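+# Usage sketch (illustrative, not part of this schema file): per-tensor norms in a
+# single call, as in the foreach path of gradient-norm clipping, e.g.
+#   norms = torch._foreach_norm(grads, 2.0)   # one 0-dim tensor per grad
+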
+- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow
+    CUDA: foreach_tensor_pow_list_kernel_cuda
+
+- func: _foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow
+    CUDA: foreach_tensor_pow_scalar_kernel_cuda
+
+- func: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow
+    CUDA: foreach_tensor_pow_scalarlist_kernel_cuda
+
+- func: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_scalar_pow_list_kernel_slow
+    CUDA: foreach_scalar_pow_list_kernel_cuda
+
+- func: _foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_list_kernel_slow_
+    CUDA: foreach_tensor_pow_list_kernel_cuda_
+  autogen: _foreach_pow.List_out
+
+- func: _foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_scalar_kernel_slow_
+    CUDA: foreach_tensor_pow_scalar_kernel_cuda_
+  autogen: _foreach_pow.Scalar_out
+
+- func: _foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_pow_scalarlist_kernel_slow_
+    CUDA: foreach_tensor_pow_scalarlist_kernel_cuda_
+  autogen: _foreach_pow.ScalarList_out
+
+- func: _foreach_reciprocal(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_reciprocal_slow
+    CUDA: foreach_tensor_reciprocal_cuda
+
+- func: _foreach_reciprocal_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_reciprocal_slow_
+    CUDA: foreach_tensor_reciprocal_cuda_
+  autogen: _foreach_reciprocal.out
+
+- func: _foreach_round(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_round_slow
+    CUDA: foreach_tensor_round_cuda
+
+- func: _foreach_round_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_round_slow_
+    CUDA: foreach_tensor_round_cuda_
+  autogen: _foreach_round.out
+
+- func: _foreach_rsqrt(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow
+    CUDA: foreach_tensor_rsqrt_cuda
+
+- func: _foreach_rsqrt_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_rsqrt_slow_
+    CUDA: foreach_tensor_rsqrt_cuda_
+  autogen: _foreach_rsqrt.out
+
+- func: _foreach_sigmoid(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sigmoid_slow
+    CUDA: foreach_tensor_sigmoid_cuda
+
+- func: _foreach_sigmoid_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sigmoid_slow_
+    CUDA: foreach_tensor_sigmoid_cuda_
+  autogen: _foreach_sigmoid.out
+
+- func: _foreach_sign(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sign_slow
+    CUDA: foreach_tensor_sign_cuda
+
+- func: _foreach_sign_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sign_slow_
+    CUDA: foreach_tensor_sign_cuda_
+  autogen: _foreach_sign.out
+
+- func: _foreach_sin(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sin_slow
+    CUDA: foreach_tensor_sin_cuda
+
+- func: _foreach_sin_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sin_slow_
+    CUDA: foreach_tensor_sin_cuda_
+  autogen: _foreach_sin.out
+
+- func: _foreach_sinh(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sinh_slow
+    CUDA: foreach_tensor_sinh_cuda
+
+- func: _foreach_sinh_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sinh_slow_
+    CUDA: foreach_tensor_sinh_cuda_
+  autogen: _foreach_sinh.out
+
+- func: _foreach_sqrt(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sqrt_slow
+    CUDA: foreach_tensor_sqrt_cuda
+
+- func: _foreach_sqrt_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
+    CUDA: foreach_tensor_sqrt_cuda_
+  autogen: _foreach_sqrt.out
+
+- func: _foreach_tan(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_tan_slow
+    CUDA: foreach_tensor_tan_cuda
+
+- func: _foreach_tan_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_tan_slow_
+    CUDA: foreach_tensor_tan_cuda_
+  autogen: _foreach_tan.out
+
+- func: _foreach_tanh(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_tanh_slow
+    CUDA: foreach_tensor_tanh_cuda
+
+- func: _foreach_tanh_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_tanh_slow_
+    CUDA: foreach_tensor_tanh_cuda_
+  autogen: _foreach_tanh.out
+
+- func: _foreach_trunc(Tensor[] self) -> Tensor[]
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_trunc_slow
+    CUDA: foreach_tensor_trunc_cuda
+
+- func: _foreach_trunc_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_trunc_slow_
+    CUDA: foreach_tensor_trunc_cuda_
+  autogen: _foreach_trunc.out
+
+- func: _foreach_zero_(Tensor(a!)[] self) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_zero_slow_
+    CUDA: foreach_tensor_zero_cuda_
+  autogen: _foreach_zero, _foreach_zero.out
+
+- func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()
+  device_check: NoCheck   # foreach kernels fall back to slow path when tensor are on different devices
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
+    CUDA: foreach_tensor_copy_list_kernel_cuda_
+  autogen: _foreach_copy.out
+
+- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _foreach_copy
+
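+# Usage sketch (illustrative, not part of this schema file): copies a whole list of
+# tensors in one call, e.g. (variable names are illustrative)
+#   torch._foreach_copy_(params, averaged_params, non_blocking=True)
+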
+- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
+  dispatch:
+    CPU: bucketize_cpu
+    CUDA: bucketize_cuda
+    MPS: bucketize_mps
+
+- func: bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: bucketize_out_cpu
+    CUDA: bucketize_out_cuda
+    MPS: bucketize_out_mps
+
+- func: bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
+  dispatch:
+    CPU: bucketize_cpu
+    CUDA: bucketize_cuda
+    MPS: bucketize_mps
+  autogen: bucketize.Scalar_out
+
+- func: searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor
+  dispatch:
+    CPU: searchsorted_cpu
+    CUDA: searchsorted_cuda
+    MPS: searchsorted_mps
+
+- func: searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: searchsorted_out_cpu
+    CUDA: searchsorted_out_cuda
+    MPS: searchsorted_out_mps
+
+- func: searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor
+  dispatch:
+    CPU: searchsorted_cpu
+    CUDA: searchsorted_cuda
+    MPS: searchsorted_mps
+
+- func: searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU: searchsorted_out_cpu
+    CUDA: searchsorted_out_cuda
+    MPS: searchsorted_out_mps
+
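+# Usage sketch (illustrative, not part of this schema file): both ops binary-search
+# a sorted 1-D sequence, e.g.
+#   boundaries = torch.tensor([1., 3., 5., 7.])
+#   v = torch.tensor([0.5, 3., 6.])
+#   torch.bucketize(v, boundaries)                  # tensor([0, 1, 3])
+#   torch.searchsorted(boundaries, v, right=True)   # tensor([0, 2, 3])
+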
+- func: _convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor
+  structured_delegate: _convert_indices_from_coo_to_csr.out
+
+- func: _convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: _convert_indices_from_coo_to_csr_structured_cpu
+    CUDA: _convert_indices_from_coo_to_csr_structured_cuda
+
+- func: _convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor
+  structured_delegate: _convert_indices_from_csr_to_coo.out
+
+- func: _convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  dispatch:
+    CPU: _convert_indices_from_csr_to_coo_structured_cpu
+    CUDA: _convert_indices_from_csr_to_coo_structured_cuda
+
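+# Worked example (illustrative, not part of this schema file): for sorted COO row
+# indices [0, 0, 1, 3] and size=4, the compressed crow_indices are [0, 2, 3, 3, 4]
+# (running counts of entries per row); judging by the names, these private kernels
+# back the sparse COO <-> CSR conversions such as Tensor.to_sparse_csr().
+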
+## NN wrappers
+
+- func: mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: mse_loss_out
+    MPS: mse_loss_out_mps
+
+- func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: mse_loss.out
+  python_module: nn
+
+- func: mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU, CUDA: mse_loss_backward_out
+    MPS: mse_loss_backward_out_mps
+
+- func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: mse_loss_backward
+    MPS: mse_loss_backward_mps
+
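+# Usage sketch (illustrative, not part of this schema file): the forward/backward
+# pairs here are reached through torch.nn.functional, e.g.
+#   loss = torch.nn.functional.mse_loss(pred, target, reduction="mean")
+#   loss.backward()   # routes to mse_loss_backward
+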
+- func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  python_module: nn
+
+- func: multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: multi_margin_loss_cpu_out
+    CUDA: multi_margin_loss_cuda_out
+
+- func: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: multi_margin_loss_cpu
+    CUDA: multi_margin_loss_cuda
+
+- func: multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: multi_margin_loss_cpu_backward_out
+    CUDA: multi_margin_loss_cuda_backward_out
+
+- func: multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: multi_margin_loss_cpu_backward
+    CUDA: multi_margin_loss_cuda_backward
+
+- func: multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  python_module: nn
+
+- func: multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  dispatch:
+    CPU: multilabel_margin_loss_forward_out_cpu
+    CUDA: multilabel_margin_loss_forward_out_cuda
+
+- func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)
+  python_module: nn
+  dispatch:
+    CPU: multilabel_margin_loss_forward_cpu
+    CUDA: multilabel_margin_loss_forward_cuda
+
+- func: multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: multilabel_margin_loss_backward_cpu_out
+    CUDA: multilabel_margin_loss_backward_cuda_out
+
+- func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: multilabel_margin_loss_backward_cpu
+    CUDA: multilabel_margin_loss_backward_cuda
+
+- func: nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: nll_loss_nd_symint
+
+- func: nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: nll_loss_symint
+
+- func: nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: nll_loss_forward_out_cpu
+    CUDA: nll_loss_forward_out_cuda
+    MPS: nll_loss_forward_out_mps
+
+- func: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+  python_module: nn
+  structured_delegate: nll_loss_forward.output
+
+- func: nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: nll_loss_backward_out_cpu
+    CUDA: nll_loss_backward_out_cuda
+    MPS: nll_loss_backward_out_mps
+
+- func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+  python_module: nn
+  structured_delegate: nll_loss_backward.grad_input
+
+- func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: nll_loss2d_symint
+
+- func: nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  dispatch:
+    CPU: nll_loss2d_forward_out_cpu
+    CUDA: nll_loss2d_forward_out_cuda
+    MPS: nll_loss2d_forward_out_mps
+
+- func: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+  python_module: nn
+  dispatch:
+    CPU: nll_loss2d_forward_cpu
+    CUDA: nll_loss2d_forward_cuda
+    MPS: nll_loss2d_forward_mps
+
+- func: nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: nll_loss2d_backward_out_cpu
+    CUDA: nll_loss2d_backward_out_cuda
+    MPS: nll_loss2d_backward_out_mps
+
+- func: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: nll_loss2d_backward_cpu
+    CUDA: nll_loss2d_backward_cuda
+    MPS: nll_loss2d_backward_mps
+
+- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: smooth_l1_loss_out
+    MPS: smooth_l1_loss_out_mps
+
+- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: smooth_l1_loss.out
+  python_module: nn
+
+- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: smooth_l1_loss_backward_out
+    CUDA: smooth_l1_loss_backward_out
+    MPS: smooth_l1_loss_backward_out_mps
+
+- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: smooth_l1_loss_backward
+
+- func: huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU, CUDA: huber_loss_out
+    MPS: huber_loss_out_mps
+
+- func: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: huber_loss
+    MPS: huber_loss_mps
+
+- func: huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU, CUDA: huber_loss_backward_out
+    MPS: huber_loss_backward_out_mps
+
+- func: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: huber_loss_backward
+
+- func: soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: soft_margin_loss_out
+
+- func: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: soft_margin_loss
+
+- func: soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: soft_margin_loss_backward_out
+
+- func: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: soft_margin_loss_backward
+
+- func: elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA: elu_out
+    MPS: elu_out_mps
+
+- func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
+  structured_delegate: elu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  tags: [core, pointwise]
+
+- func: elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: elu_backward_out
+    MPS: elu_backward_out_mps
+
+- func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor
+  structured_delegate: elu_backward.grad_input
+  python_module: nn
+
+- func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
+  structured_delegate: elu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+
+- func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: glu_out
+    MPS: glu_out_mps
+
+- func: glu(Tensor self, int dim=-1) -> Tensor
+  structured_delegate: glu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+
+- func: glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: glu_backward_cpu_out
+    CUDA: glu_backward_cuda_out
+    MPS: glu_backward_mps_out
+
+- func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: glu_backward_cpu
+    CUDA: glu_backward_cuda
+    MPS: glu_backward_mps
+
+- func: glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: glu_jvp
+  autogen: glu_jvp.out
+
+- func: glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: glu_backward_jvp
+  autogen: glu_backward_jvp.out
+
+- func: hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardsigmoid_out
+    QuantizedCPU: hardsigmoid_out_quantized_cpu
+
+- func: hardsigmoid(Tensor self) -> Tensor
+  structured_delegate: hardsigmoid.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    QuantizedCPU: hardsigmoid_quantized_cpu
+  tags: pointwise
+
+- func: hardsigmoid_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: hardsigmoid.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+
+- func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardsigmoid_backward_out
+
+- func: hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor
+  structured_delegate: hardsigmoid_backward.grad_input
+  python_module: nn
+
+- func: hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardtanh_out
+    QuantizedCPU: hardtanh_out_quantized_cpu
+
+- func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardtanh
+    QuantizedCPU: hardtanh_quantized_cpu
+  tags: [pointwise, core]
+
+- func: hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU, CUDA: hardtanh_backward_out
+    MPS: hardtanh_backward_out_mps
+
+- func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA: hardtanh_backward
+    MPS: hardtanh_backward_mps
+
+- func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardtanh_
+    QuantizedCPU: hardtanh_quantized_cpu_
+
+- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardswish_out
+
+- func: hardswish(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardswish
+
+- func: hardswish_(Tensor(a!) self) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardswish_
+
+- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: hardswish_backward
+  autogen: hardswish_backward.out
+
+- func: leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: leaky_relu_out
+    QuantizedCPU: leaky_relu_out_quantized_cpu
+
+- func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
+  structured_delegate: leaky_relu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    QuantizedCPU: leaky_relu_quantized_cpu
+  tags: core
+
+- func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: leaky_relu_backward_out
+
+- func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
+  structured_delegate: leaky_relu_backward.grad_input
+  python_module: nn
+
+- func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
+  structured_delegate: leaky_relu.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    QuantizedCPU: leaky_relu_quantized_cpu_
+
+- func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+
+- func: log_sigmoid(Tensor self) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+
+- func: log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU: log_sigmoid_forward_out_cpu
+    CUDA: log_sigmoid_forward_out_cuda
+    MPS: log_sigmoid_forward_out_mps
+
+- func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU: log_sigmoid_forward_cpu
+    CUDA: log_sigmoid_forward_cuda
+    MPS: log_sigmoid_forward_mps
+
+- func: log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: log_sigmoid_backward_cpu_out
+    CUDA: log_sigmoid_backward_cuda_out
+    MPS: log_sigmoid_backward_mps_out
+
+- func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: log_sigmoid_backward_cpu
+    CUDA: log_sigmoid_backward_cuda
+    MPS: log_sigmoid_backward_mps
+
+- func: rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU: rrelu_with_noise_out_cpu
+    CUDA: rrelu_with_noise_out_cuda
+
+- func: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: rrelu_with_noise_cpu
+    CUDA: rrelu_with_noise_cuda
+  tags: nondeterministic_seeded
+  autogen: rrelu_with_noise_functional
+
+- func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: rrelu_with_noise_backward
+  autogen: rrelu_with_noise_backward.out
+
+- func: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+  python_module: nn
+  tags: nondeterministic_seeded
+  dispatch:
+    CPU: rrelu_with_noise_cpu_
+    CUDA: rrelu_with_noise_cuda_
+
+- func: softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA: softplus_out
+    MPS: softplus_out_mps
+
+- func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
+  structured_delegate: softplus.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  tags: pointwise
+
+- func: softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA: softplus_backward_out
+    MPS: softplus_backward_out_mps
+
+- func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor
+  structured_delegate: softplus_backward.grad_input
+  python_module: nn
+
+- func: softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: softshrink_out
+
+- func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+  structured_delegate: softshrink.out
+  device_check: NoCheck   # TensorIterator
+  python_module: nn
+  tags: pointwise
+
+- func: softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: nn
+  dispatch:
+    CPU, CUDA, MPS: softshrink_backward_out
+
+- func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
+  structured_delegate: softshrink_backward.grad_input
+  python_module: nn
+
+- func: adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: adaptive_avg_pool2d_out_cpu
+    CUDA: adaptive_avg_pool2d_out_cuda
+    MPS: adaptive_avg_pool2d_out_mps
+    MkldnnCPU: mkldnn_adaptive_avg_pool2d_out_stub
+
+- func: adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: adaptive_avg_pool2d_symint
+
+- func: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_adaptive_avg_pool2d
+
+- func: mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    MkldnnCPU: mkldnn_adaptive_avg_pool2d_out
+
+- func: mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
+  dispatch:
+    MkldnnCPU: mkldnn_adaptive_avg_pool2d_backward
+  autogen: mkldnn_adaptive_avg_pool2d_backward.out
+
+- func: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+  dispatch:
+    CPU: adaptive_avg_pool2d_cpu
+    CUDA: adaptive_avg_pool2d_cuda
+    MPS: adaptive_avg_pool2d_mps
+    QuantizedCPU: adaptive_avg_pool2d_quantized_cpu
+    QuantizedCUDA: adaptive_avg_pool2d_quantized_cuda
+  autogen: _adaptive_avg_pool2d.out
+  tags: core
+
+- func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: adaptive_avg_pool2d_backward_cpu
+    CUDA: adaptive_avg_pool2d_backward_cuda
+    MPS: adaptive_avg_pool2d_backward_mps
+  autogen: _adaptive_avg_pool2d_backward.out
+  tags: core
+
+- func: adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: adaptive_avg_pool3d_out_cpu
+    CUDA: adaptive_avg_pool3d_out_cuda
+    QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu
+
+- func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: adaptive_avg_pool3d_symint
+
+- func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+  dispatch:
+    CPU: adaptive_avg_pool3d_cpu
+    CUDA: adaptive_avg_pool3d_cuda
+    QuantizedCPU: adaptive_avg_pool3d_quantized_cpu
+  autogen: _adaptive_avg_pool3d.out
+  tags: core
+
+- func: adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: adaptive_avg_pool3d_backward_out_cpu
+    CUDA: adaptive_avg_pool3d_backward_out_cuda
+
+- func: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: adaptive_avg_pool3d_backward_cpu
+    CUDA: adaptive_avg_pool3d_backward_cuda
+  autogen: _adaptive_avg_pool3d_backward.out
+
+# Return: (Tensor output, Tensor indices)
+- func: adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: adaptive_max_pool2d_out_cpu
+    CUDA: adaptive_max_pool2d_out_cuda
+    MPS: adaptive_max_pool2d_out_mps
+
+# Return: (Tensor output, Tensor indices)
+- func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
+  python_module: nn
+  structured_delegate: adaptive_max_pool2d.out
+
+- func: adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: adaptive_max_pool2d_backward_out_cpu
+    CUDA: adaptive_max_pool2d_backward_out_cuda
+    MPS: adaptive_max_pool2d_backward_out_mps
+
+- func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+  python_module: nn
+  structured_delegate: adaptive_max_pool2d_backward.grad_input
+
+# Return: (Tensor output, Tensor indices)
+- func: adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: adaptive_max_pool3d_out_cpu
+    CUDA: adaptive_max_pool3d_out_cuda
+
+# Return: (Tensor output, Tensor indices)
+- func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
+  python_module: nn
+  structured_delegate: adaptive_max_pool3d.out
+
+- func: adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: adaptive_max_pool3d_backward_out_cpu
+    CUDA: adaptive_max_pool3d_backward_out_cuda
+
+- func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+  python_module: nn
+  structured_delegate: adaptive_max_pool3d_backward.grad_input
+
+- func: avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  precomputed:
+  - kernel_size -> int kH, int kW
+  - stride -> int dH, int dW
+  - padding -> int padH, int padW
+  dispatch:
+    CPU: avg_pool2d_out_cpu
+    CUDA: avg_pool2d_out_cuda
+    MPS: avg_pool2d_out_mps
+    MkldnnCPU: mkldnn_avg_pool2d_out
+
+- func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+  python_module: nn
+  structured_delegate: avg_pool2d.out
+  dispatch:
+    MkldnnCPU: mkldnn_avg_pool2d
+    QuantizedCPU: avg_pool2d_quantized_cpu
+  tags: core
+
+- func: avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: avg_pool2d_backward_out_cpu
+    CUDA: avg_pool2d_backward_out_cuda
+    MPS: avg_pool2d_backward_out_mps
+    MkldnnCPU: mkldnn_avg_pool2d_backward_out
+
+- func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+  python_module: nn
+  structured_delegate: avg_pool2d_backward.grad_input
+  dispatch:
+    MkldnnCPU: mkldnn_avg_pool2d_backward
+  tags: core
+
+- func: avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: avg_pool3d_out_cpu
+    CUDA: avg_pool3d_out_cuda
+    MkldnnCPU: mkldnn_avg_pool3d_out
+
+- func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+  python_module: nn
+  structured_delegate: avg_pool3d.out
+  dispatch:
+    MkldnnCPU: mkldnn_avg_pool3d
+    QuantizedCPU: avg_pool3d_quantized_cpu
+  tags: core
+
+- func: avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: avg_pool3d_backward_out_cpu
+    CUDA: avg_pool3d_backward_out_cuda
+    MkldnnCPU: mkldnn_avg_pool3d_backward_out
+
+- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+  python_module: nn
+  structured_delegate: avg_pool3d_backward.grad_input
+  dispatch:
+    MkldnnCPU: mkldnn_avg_pool3d_backward
+
+# Return: (Tensor output, Tensor indices)
+- func: fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: fractional_max_pool2d_out_cpu
+    CUDA: fractional_max_pool2d_out_cuda
+
+# Return: (Tensor output, Tensor indices)
+- func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
+  python_module: nn
+  structured_delegate: fractional_max_pool2d.output
+
+- func: fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: fractional_max_pool2d_backward_cpu
+    CUDA: fractional_max_pool2d_backward_cuda
+
+- func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
+  python_module: nn
+  structured_delegate: fractional_max_pool2d_backward.grad_input
+
+# Return: (Tensor output, Tensor indices)
+- func: fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  precomputed:
+  - kernel_size -> int poolSizeT, int poolSizeH, int poolSizeW
+  - output_size -> int outputT, int outputH, int outputW
+  - int numBatch, int numPlanes, int inputT, int inputH, int inputW
+  dispatch:
+    CPU: fractional_max_pool3d_out_cpu
+    CUDA: fractional_max_pool3d_out_cuda
+
+# Return: (Tensor output, Tensor indices)
+- func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
+  python_module: nn
+  structured_delegate: fractional_max_pool3d.output
+
+- func: fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: fractional_max_pool3d_backward_out_cpu
+    CUDA: fractional_max_pool3d_backward_out_cuda
+
+- func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: fractional_max_pool3d_backward_cpu
+    CUDA: fractional_max_pool3d_backward_cuda
+
+# Return: (Tensor output, Tensor indices)
+- func: max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: max_pool2d_with_indices_out_cpu
+    CUDA: max_pool2d_with_indices_out_cuda
+    MPS: max_pool2d_with_indices_out_mps
+
+# Return: (Tensor output, Tensor indices)
+- func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+  python_module: nn
+  structured_delegate: max_pool2d_with_indices.out
+  tags: core
+
+- func: max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: max_pool2d_with_indices_backward_out_cpu
+    CUDA: max_pool2d_with_indices_backward_out_cuda
+    MPS: max_pool2d_with_indices_backward_out_mps
+
+- func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
+  python_module: nn
+  structured_delegate: max_pool2d_with_indices_backward.grad_input
+  tags: core
+
+# Return: (Tensor output, Tensor indices)
+- func: max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))
+  python_module: nn
+  dispatch:
+    CPU: max_pool3d_with_indices_out_cpu
+    CUDA: max_pool3d_with_indices_out_cuda
+
+# Return: (Tensor output, Tensor indices)
+- func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+  python_module: nn
+  dispatch:
+    CPU: max_pool3d_with_indices_cpu
+    CUDA: max_pool3d_with_indices_cuda
+  tags: core
+
+- func: max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: max_pool3d_with_indices_backward_out_cpu
+    CUDA: max_pool3d_with_indices_backward_out_cuda
+
+- func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: max_pool3d_with_indices_backward_cpu
+    CUDA: max_pool3d_with_indices_backward_cuda
+
+- func: max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: max_unpooling2d_forward_out_cpu
+    CUDA: max_unpooling2d_forward_out_cuda
+
+- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: max_unpooling2d_forward_cpu
+    CUDA: max_unpooling2d_forward_cuda
+
+- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: max_unpooling3d_forward_out_cpu
+    CUDA: max_unpooling3d_forward_out_cuda
+
+- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: max_unpooling3d_forward_cpu
+    CUDA: max_unpooling3d_forward_cuda
+
+- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: reflection_pad1d_out_cpu
+    QuantizedCPU: reflection_pad1d_out_quantized_cpu
+    CUDA: reflection_pad1d_out_cuda
+    MPS: reflection_pad1d_out_mps
+
+- func: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+  python_module: nn
+  structured_delegate: reflection_pad1d.out
+  tags: core
+
+- func: reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: reflection_pad1d_backward_out_cpu
+    CUDA: reflection_pad1d_backward_out_cuda
+    MPS: reflection_pad1d_backward_out_mps
+
+- func: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+  python_module: nn
+  structured_delegate: reflection_pad1d_backward.grad_input
+
+- func: reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU, QuantizedCPU: reflection_pad2d_out_cpu
+    CUDA: reflection_pad2d_out_cuda
+    MPS: reflection_pad2d_out_mps
+
+- func: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_cpu
+    QuantizedCPU: reflection_pad2d_quantized_cpu
+    CUDA: reflection_pad2d_cuda
+    MPS: reflection_pad2d_mps
+  tags: core
+
+- func: reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_backward_out_cpu
+    CUDA: reflection_pad2d_backward_out_cuda
+    MPS: reflection_pad2d_backward_out_mps
+
+- func: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: reflection_pad2d_backward_cpu
+    CUDA: reflection_pad2d_backward_cuda
+    MPS: reflection_pad2d_backward_mps
+
+- func: reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: reflection_pad3d_out_cpu
+    CUDA: reflection_pad3d_out_cuda
+    MPS: reflection_pad3d_out_mps
+
+- func: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+  python_module: nn
+  structured_delegate: reflection_pad3d.out
+  tags: core
+
+- func: reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: reflection_pad3d_backward_out_cpu
+    CUDA: reflection_pad3d_backward_out_cuda
+    MPS: reflection_pad3d_backward_out_mps
+
+- func: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+  python_module: nn
+  structured_delegate: reflection_pad3d_backward.grad_input
+
+- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: replication_pad1d_out_cpu
+    CUDA: replication_pad1d_out_cuda
+    MPS: replication_pad1d_out_mps
+
+- func: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+  python_module: nn
+  structured_delegate: replication_pad1d.out
+
+- func: replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: replication_pad1d_backward_out_cpu
+    CUDA: replication_pad1d_backward_out_cuda
+    MPS: replication_pad1d_backward_out_mps
+
+- func: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+  python_module: nn
+  structured_delegate: replication_pad1d_backward.grad_input
+
+- func: replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: replication_pad2d_out_cpu
+    CUDA: replication_pad2d_out_cuda
+    MPS: replication_pad2d_out_mps
+
+- func: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+  python_module: nn
+  structured_delegate: replication_pad2d.out
+  tags: core
+
+- func: replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: replication_pad2d_backward_out_cpu
+    CUDA: replication_pad2d_backward_out_cuda
+    MPS: replication_pad2d_backward_out_mps
+
+- func: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: replication_pad2d_backward_cpu
+    CUDA: replication_pad2d_backward_cuda
+    MPS: replication_pad2d_backward_mps
+
+- func: replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: replication_pad3d_out_cpu
+    CUDA: replication_pad3d_out_cuda
+    MPS: replication_pad3d_out_mps
+
+- func: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+  python_module: nn
+  structured_delegate: replication_pad3d.out
+  tags: core
+
+- func: replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: replication_pad3d_backward_out_cpu
+    CUDA: replication_pad3d_backward_out_cuda
+    MPS: replication_pad3d_backward_out_mps
+
+- func: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: replication_pad3d_backward_cpu
+    CUDA: replication_pad3d_backward_cuda
+    MPS: replication_pad3d_backward_mps
+
+- func: _pad_circular(Tensor self, SymInt[] pad) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: _pad_circular_symint
+
+- func: _pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: _pad_enum_symint
+
+- func: pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeImplicitAutograd: pad_symint
+
+- func: upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_linear1d.vec_out
+
+- func: upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_bilinear2d.vec_out
+  tags: core
+
+- func: _upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: _upsample_bilinear2d_aa.vec_out
+
+- func: upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_trilinear3d.vec_out
+
+- func: upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_bicubic2d.vec_out
+
+- func: _upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: _upsample_bicubic2d_aa.vec_out
+
+- func: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_nearest1d.vec_out
+
+- func: _upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: _upsample_nearest_exact1d.vec_out
+
+- func: upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_nearest2d.vec_out
+  tags: core
+
+- func: _upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: _upsample_nearest_exact2d.vec_out
+
+- func: upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: upsample_nearest3d.vec_out
+
+- func: _upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor
+  python_module: nn
+  autogen: _upsample_nearest_exact3d.vec_out
+
+# NOTE: all of the non-"vec" upsample overloads are only kept for backward compatibility.
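+# For illustration only (a sketch, not part of this schema): as the note above
+# implies, user-facing resizing is expected to route through the .vec overloads,
+# e.g. roughly
+#
+#   import torch
+#   import torch.nn.functional as F
+#   x = torch.randn(1, 3, 8, 8)
+#   y = F.interpolate(x, scale_factor=2.0, mode="bilinear", align_corners=False)
+#   # y.shape == torch.Size([1, 3, 16, 16])
+#
+# whereas the per-dimension overloads below take an explicit output_size plus
+# per-axis scale arguments.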
+- func: upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_linear1d_out_cpu
+    CUDA: upsample_linear1d_out_cuda
+    MPS: upsample_linear1d_out_mps
+
+- func: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_linear1d.out
+
+- func: upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_linear1d_backward_out_cpu
+    CUDA: upsample_linear1d_backward_out_cuda
+    MPS: upsample_linear1d_backward_out_mps
+
+- func: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_linear1d_backward.grad_input
+
+- func: upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_bilinear2d_out_cpu
+    CUDA: upsample_bilinear2d_out_cuda
+    MPS: upsample_bilinear2d_out_mps
+
+- func: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_bilinear2d.out
+  dispatch:
+    QuantizedCPU: upsample_bilinear2d_quantized_cpu
+
+- func: upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_bilinear2d_backward_out_cpu
+    CUDA: upsample_bilinear2d_backward_out_cuda
+    MPS: upsample_bilinear2d_backward_out_mps
+
+- func: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_bilinear2d_backward.grad_input
+
+- func: _upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_bilinear2d_aa_out_cpu
+    CUDA: _upsample_bilinear2d_aa_out_cuda
+    MPS: _upsample_bilinear2d_aa_out_mps
+
+- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_bilinear2d_aa.out
+
+- func: _upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_bilinear2d_aa_backward_out_cpu
+    CUDA: _upsample_bilinear2d_aa_backward_out_cuda
+
+- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_bilinear2d_aa_backward.grad_input
+
+- func: upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_bicubic2d_out_cpu
+    CUDA: upsample_bicubic2d_out_cuda
+    MPS: upsample_bicubic2d_out_mps
+
+- func: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_bicubic2d.out
+
+- func: upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_bicubic2d_backward_out_cpu
+    CUDA: upsample_bicubic2d_backward_out_cuda
+    MPS: upsample_bicubic2d_backward_out_mps
+
+- func: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_bicubic2d_backward.grad_input
+
+- func: _upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_bicubic2d_aa_out_cpu
+    CUDA: _upsample_bicubic2d_aa_out_cuda
+    MPS: _upsample_bicubic2d_aa_out_mps
+
+- func: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_bicubic2d_aa.out
+
+- func: _upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_bicubic2d_aa_backward_out_cpu
+    CUDA: _upsample_bicubic2d_aa_backward_out_cuda
+
+- func: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_bicubic2d_aa_backward.grad_input
+
+- func: upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_trilinear3d_out_cpu
+    CUDA: upsample_trilinear3d_out_cuda
+    MPS: upsample_trilinear3d_out_mps
+
+- func: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_trilinear3d.out
+
+- func: upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_trilinear3d_backward_out_cpu
+    CUDA: upsample_trilinear3d_backward_out_cuda
+    MPS: upsample_trilinear3d_backward_out_mps
+
+- func: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_trilinear3d_backward.grad_input
+
+- func: upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest1d_out_cpu
+    CUDA: upsample_nearest1d_out_cuda
+    MPS: upsample_nearest1d_out_mps
+
+- func: _upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact1d_out_cpu
+    CUDA: _upsample_nearest_exact1d_out_cuda
+    MPS: _upsample_nearest_exact1d_out_mps
+
+- func: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest1d.out
+
+- func: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact1d.out
+
+- func: upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest1d_backward_out_cpu
+    CUDA: upsample_nearest1d_backward_out_cuda
+    MPS: upsample_nearest1d_backward_out_mps
+
+- func: _upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact1d_backward_out_cpu
+    CUDA: _upsample_nearest_exact1d_backward_out_cuda
+    MPS: _upsample_nearest_exact1d_backward_out_mps
+
+- func: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest1d_backward.grad_input
+
+- func: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact1d_backward.grad_input
+
+- func: upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest2d_out_cpu
+    CUDA: upsample_nearest2d_out_cuda
+    MPS: upsample_nearest2d_out_mps
+
+- func: _upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact2d_out_cpu
+    CUDA: _upsample_nearest_exact2d_out_cuda
+    MPS: _upsample_nearest_exact2d_out_mps
+
+- func: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest2d.out
+  dispatch:
+    QuantizedCPU: upsample_nearest2d_quantized_cpu
+
+- func: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact2d.out
+  dispatch:
+    QuantizedCPU: _upsample_nearest_exact2d_quantized_cpu
+
+- func: upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest2d_backward_out_cpu
+    CUDA: upsample_nearest2d_backward_out_cuda
+    MPS: upsample_nearest2d_backward_out_mps
+
+- func: _upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact2d_backward_out_cpu
+    CUDA: _upsample_nearest_exact2d_backward_out_cuda
+    MPS: _upsample_nearest_exact2d_backward_out_mps
+
+- func: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest2d_backward.grad_input
+
+- func: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact2d_backward.grad_input
+
+- func: upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest3d_out_cpu
+    CUDA: upsample_nearest3d_out_cuda
+    MPS: upsample_nearest3d_out_mps
+
+- func: _upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact3d_out_cpu
+    CUDA: _upsample_nearest_exact3d_out_cuda
+    MPS: _upsample_nearest_exact3d_out_mps
+
+- func: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest3d.out
+  dispatch:
+    QuantizedCPU: upsample_nearest3d_quantized_cpu
+
+- func: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact3d.out
+  dispatch:
+    QuantizedCPU: _upsample_nearest_exact3d_quantized_cpu
+
+- func: upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: upsample_nearest3d_backward_out_cpu
+    CUDA: upsample_nearest3d_backward_out_cuda
+    MPS: upsample_nearest3d_backward_out_mps
+
+- func: _upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: _upsample_nearest_exact3d_backward_out_cpu
+    CUDA: _upsample_nearest_exact3d_backward_out_cuda
+    MPS: _upsample_nearest_exact3d_backward_out_mps
+
+- func: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: upsample_nearest3d_backward.grad_input
+
+- func: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  python_module: nn
+  structured_delegate: _upsample_nearest_exact3d_backward.grad_input
+
+- func: sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: sigmoid_backward_out
+    MPS: sigmoid_backward_out_mps
+  tags: pointwise
+
+- func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
+  python_module: nn
+  structured_delegate: sigmoid_backward.grad_input
+  tags: pointwise
+
+- func: logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: logit_backward_out
+    MPS: logit_backward_out_mps
+  tags: pointwise
+
+- func: logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor
+  python_module: nn
+  structured_delegate: logit_backward.grad_input
+  tags: pointwise
+
+- func: tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MTIA: tanh_backward_out
+    MPS: tanh_backward_out_mps
+  tags: pointwise
+
+- func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
+  python_module: nn
+  structured_delegate: tanh_backward.grad_input
+
+# What's a thnn_conv_ versus a slow_conv_?
+#
+# Historically, we have inefficient implementations of convolutions
+# coming from the THNN/THCUNN library.  These convolutions typically
+# operated by computing the Toeplitz matrix and then doing a matrix
+# multiply with the input; this is very memory inefficient!  However,
+# occasionally, we really don't have anything better, so it's helpful
+# to have these fallbacks when there is no more optimized implementation
+# in cudnn or mkldnn, etc.  Both thnn_ and slow_ convolutions fall
+# into this bucket.
+#
+# The difference between these two designations is that thnn_ refers
+# to a convolution that is still written in the "legacy" style; that is,
+# C code in the THNN/ or THCUNN/ directory.  A slow_ convolution is
+# one that is written in the native style: modern C++.  Algorithmically,
+# these are the same thing, but we give them different prefixes to
+# make the operational distinction clear.
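+#
+# As a rough, hypothetical sketch (not part of this schema), the Toeplitz /
+# "im2col" strategy described above amounts to unfolding the input and doing
+# one big matrix multiply, e.g.:
+#
+#   import torch
+#   import torch.nn.functional as F
+#
+#   def slow_conv2d_sketch(x, weight, bias=None, stride=1, padding=0):
+#       # x: (N, C_in, H, W), weight: (C_out, C_in, kH, kW)
+#       N, C_in, H, W = x.shape
+#       C_out, _, kH, kW = weight.shape
+#       # Materialize the unfolded (Toeplitz-like) input: (N, C_in*kH*kW, L)
+#       cols = F.unfold(x, (kH, kW), stride=stride, padding=padding)
+#       out = weight.reshape(C_out, -1) @ cols          # (N, C_out, L)
+#       if bias is not None:
+#           out = out + bias.reshape(1, C_out, 1)
+#       H_out = (H + 2 * padding - kH) // stride + 1
+#       W_out = (W + 2 * padding - kW) // stride + 1
+#       return out.reshape(N, C_out, H_out, W_out)
+#
+# The `cols` buffer is C_in*kH*kW times larger than the input, which is the
+# memory inefficiency mentioned above and why these remain fallback kernels.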
+
+- func: slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  structured: True
+  dispatch:
+    CPU: slow_conv_transpose2d_structured_cpu
+    CUDA: slow_conv_transpose2d_structured_cuda
+
+- func: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor
+  python_module: nn
+  structured_delegate: slow_conv_transpose2d.out
+
+- func: slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: slow_conv_transpose3d_out_cpu
+    CUDA: slow_conv_transpose3d_out_cuda
+
+- func: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: slow_conv_transpose3d_cpu
+    CUDA: slow_conv_transpose3d_cuda
+
+- func: thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor
+  python_module: nn
+
+- func: _slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: slow_conv2d_forward_out_cpu
+    CUDA: slow_conv2d_forward_out_cuda
+
+- func: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: slow_conv2d_forward_cpu
+    CUDA: slow_conv2d_forward_cuda
+
+- func: _slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))
+  python_module: nn
+  dispatch:
+    CPU: slow_conv2d_backward_out_cpu
+    CUDA: slow_conv2d_backward_out_cuda
+
+- func: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+  python_module: nn
+  dispatch:
+    CPU: slow_conv2d_backward_cpu
+    CUDA: slow_conv2d_backward_cuda
+  autogen: _slow_conv2d_backward.output_mask_out
+
+- func: _conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CUDA: conv_depthwise2d_cuda_out
+
+- func: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor
+  python_module: nn
+  dispatch:
+    CUDA: conv_depthwise2d_cuda
+
+- func: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor
+  python_module: nn
+  dispatch:
+    CUDA: conv_depthwise3d_cuda
+  autogen: conv_depthwise3d.out
+
+- func: slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+
+- func: slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor
+  python_module: nn
+
+- func: slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: slow_conv3d_forward_out_cpu
+
+- func: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: slow_conv3d_forward_cpu
+
+- func: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: slow_conv_dilated2d_cpu
+    CUDA: slow_conv_dilated2d_cuda
+  autogen: slow_conv_dilated2d.out
+
+- func: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: slow_conv_dilated3d_cpu
+    CUDA: slow_conv_dilated3d_cuda
+  autogen: slow_conv_dilated3d.out
+
+- func: col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: col2im_out_cpu
+    CUDA: col2im_out_cuda
+    MPS: col2im_out_mps
+
+- func: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: col2im_cpu
+    CUDA: col2im_cuda
+    MPS: col2im_mps
+  tags: core
+
+- func: column_stack(Tensor[] tensors) -> Tensor
+
+- func: column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: nn
+  dispatch:
+    CPU: im2col_out_cpu
+    CUDA: im2col_out_cuda
+    MPS: im2col_out_mps
+
+- func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: im2col_cpu
+    CUDA: im2col_cuda
+    MPS: im2col_mps
+
+- func: isfinite(Tensor self) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  tags: pointwise
+
+- func: isinf(Tensor self) -> Tensor
+  variants: function, method
+  device_check: NoCheck
+  device_guard: False
+  dispatch:
+    CompositeExplicitAutograd: isinf
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isinf
+    SparseCPU, SparseCUDA: isinf_sparse
+    SparseMeta: isinf_sparse_meta
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isinf_sparse_csr
+  autogen: isinf.out
+  tags: [core, pointwise]
+
+- func: record_stream(Tensor(a!) self, Stream s) -> ()
+  variants: method
+  dispatch:
+    CUDA: record_stream_cuda
+
+- func: isposinf(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: isposinf.out
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isposinf
+    SparseCPU, SparseCUDA: isposinf_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr
+  tags: pointwise
+
+- func: isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: isposinf_out
+    SparseCPU, SparseCUDA: isposinf_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isposinf_sparse_csr_out
+  tags: pointwise
+
+- func: isneginf(Tensor self) -> Tensor
+  variants: function, method
+  structured_delegate: isneginf.out
+  dispatch:
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_isneginf
+    SparseCPU, SparseCUDA: isneginf_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr
+  tags: pointwise
+
+- func: isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: isneginf_out
+    SparseCPU, SparseCUDA: isneginf_sparse_out
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: isneginf_sparse_csr_out
+  tags: pointwise
+
+# NOTE [_add_batch_dim and _remove_batch_dim]
+# _add_batch_dim and _remove_batch_dim are meant to be used in the implementation
+# of the vmap frontend API (see torch/_vmap_internals.py). They are not
+# user-facing, hence the leading underscore. Please don't use them anywhere else.
+- func: _add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor
+  variants: function
+
+# See NOTE [_add_batch_dim and _remove_batch_dim]
+- func: _remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor
+  variants: function
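+
+# As a hedged illustration of how the vmap frontend uses these two ops (the
+# real logic in torch/_vmap_internals.py also handles pytrees and per-input
+# dims), one mapped input is roughly processed as:
+#
+#   batched_in = torch._add_batch_dim(x, in_dim, level)   # hide the mapped dim
+#   batched_out = func(batched_in)                        # run the user function
+#   out = torch._remove_batch_dim(batched_out, level, batch_size, out_dim)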
+
+## Functions related to the `torch.special` namespace
+# Note [special namespace binding]
+# Functions in the special python module should have their names start with
+#   "special_" underscore and be bound to the desired Python name in
+#   torch/special/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/special.h.
+#   The "special_" names should be hidden from the user and not documented.
+
+- func: special_entr(Tensor self) -> Tensor
+  structured_delegate: special_entr.out
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: special_entr_out
+  tags: pointwise
+
+- func: special_ndtri(Tensor self) -> Tensor
+  structured_delegate: special_ndtri.out
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA: special_ndtri_out
+  tags: pointwise
+
+- func: special_log_ndtr(Tensor self) -> Tensor
+  structured_delegate: special_log_ndtr.out
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA: special_log_ndtr_out
+  tags: pointwise
+
+- func: special_expm1(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_exp2(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_psi(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_digamma(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_gammaln(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_erf(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_erfc(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+
+- func: special_erfcx(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+  structured_delegate: special_erfcx.out
+  tags: pointwise
+
+- func: special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA: special_erfcx_out
+  tags: pointwise
+
+- func: special_erfinv(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+
+- func: special_ndtr(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_xlog1py(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  structured_delegate: special_xlog1py.out
+  tags: pointwise
+
+- func: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_xlog1py
+  tags: pointwise
+
+- func: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_xlog1py
+  tags: pointwise
+
+- func: special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: special_xlog1py_out
+  tags: pointwise
+
+- func: special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_xlog1py_out
+  tags: pointwise
+
+- func: special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_xlog1py_out
+  tags: pointwise
+
+- func: special_xlogy(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+
+- func: special_zeta(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  structured_delegate: special_zeta.out
+  tags: pointwise
+
+- func: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_zeta
+  tags: pointwise
+
+- func: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_zeta
+  tags: pointwise
+
+- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  python_module: special
+  variants: function
+  dispatch:
+    CPU, CUDA, MPS: special_zeta_out
+  tags: pointwise
+
+- func: special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_zeta_out
+  tags: pointwise
+
+- func: special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  python_module: special
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: special_zeta_out
+  tags: pointwise
+
+- func: special_i0(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_i0e(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+  structured_delegate: special_i0e.out
+  tags: pointwise
+
+- func: special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: special_i0e_out
+  tags: pointwise
+
+- func: special_i1(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+  structured_delegate: special_i1.out
+  tags: pointwise
+
+- func: special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: special_i1_out
+  tags: pointwise
+
+- func: special_i1e(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+  structured_delegate: special_i1e.out
+  tags: pointwise
+
+- func: special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  structured: True
+  structured_inherits: TensorIteratorBase
+  dispatch:
+    CPU, CUDA, MPS: special_i1e_out
+  tags: pointwise
+
+- func: special_logit(Tensor self, float? eps=None) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+
+- func: special_polygamma(int n, Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+
+- func: special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+
+- func: special_expit(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_sinc(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_round(Tensor self, *, int decimals=0) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_log1p(Tensor self) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_gammainc(Tensor self, Tensor other) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_gammaincc(Tensor self, Tensor other) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_multigammaln(Tensor self, int p) -> Tensor
+  python_module: special
+  variants: function
+
+- func: special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: special
+  variants: function
+
+- func: special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  python_module: special
+  variants: function
+
+## Functions related to the fast Fourier transform and the torch.fft namespace
+# Note [FFT namespace binding]
+# Functions in the fft python module should have their names start with
+#   "fft_" underscore and be bound to the desired Python name in
+#   torch/fft/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/fft.h.
+#   The "fft_" names should be hidden from the user and not documented.
+#
+# See fft_fft as an example.
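+#
+# Hedged usage sketch of the resulting public API (backed by fft_fft below):
+#
+#   import torch
+#   x = torch.randn(8, dtype=torch.complex64)
+#   X = torch.fft.fft(x, n=16, dim=-1, norm="ortho")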
+
+# torch.fft.fft
+# NOTE: NOT an alias for torch.fft, which has different semantics
+- func: fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fft_symint
+
+- func: fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fft_symint_out
+
+- func: fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifft_symint
+
+- func: fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifft_symint_out
+
+- func: fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfft_symint
+
+- func: fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfft_symint_out
+
+- func: fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfft_symint
+
+- func: fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfft_symint_out
+
+- func: fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfft_symint
+
+- func: fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfft_symint_out
+
+- func: fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfft_symint
+
+- func: fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfft_symint_out
+
+- func: fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fft2_symint
+
+- func: fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fft2_symint_out
+
+- func: fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifft2_symint
+
+- func: fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifft2_symint_out
+
+- func: fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfft2_symint
+
+- func: fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfft2_symint_out
+
+- func: fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfft2_symint
+
+- func: fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfft2_symint_out
+
+- func: fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  use_const_ref_for_mutable_tensors: True
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfft2_symint
+
+- func: fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfft2_symint_out
+
+- func: fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor
+  use_const_ref_for_mutable_tensors: True
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfft2_symint
+
+- func: fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfft2_symint_out
+
+- func: fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fftn_symint
+
+- func: fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_fftn_symint_out
+
+- func: fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifftn_symint
+
+- func: fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ifftn_symint_out
+
+- func: fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfftn_symint
+
+- func: fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_rfftn_symint_out
+
+- func: fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfftn_symint
+
+- func: fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_irfftn_symint_out
+
+- func: fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  use_const_ref_for_mutable_tensors: True
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfftn_symint
+
+- func: fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_hfftn_symint_out
+
+- func: fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor
+  use_const_ref_for_mutable_tensors: True
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfftn_symint
+
+- func: fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: fft_ihfftn_symint_out
+
+- func: fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fft_fftfreq
+
+- func: fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fft_fftfreq_out
+
+- func: fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fft_rfftfreq
+
+- func: fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: fft
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: fft_rfftfreq_out
+
+- func: fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor
+  python_module: fft
+  variants: function
+
+- func: fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor
+  python_module: fft
+  variants: function
+
+## Functions for linear algebra and the torch.linalg namespace
+# Note [linalg namespace binding]
+# Functions in the linalg python module should have their names start with
+#   "linalg_" and be bound to the desired Python name in
+#   torch/linalg/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/linalg.h.
+#   The "linalg_" names should be hidden from the user and not documented.
+#
+# See linalg_det as an example.
+
+# "_ex" stands for experimental
+- func: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
+  python_module: linalg
+  structured_delegate: linalg_cholesky_ex.L
+
+- func: linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)
+  python_module: linalg
+  structured: True
+  dispatch:
+    CPU, CUDA, MPS: linalg_cholesky_ex_out
+
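+# Hedged usage sketch of the "_ex" convention: instead of raising, these
+# variants return an `info` tensor that callers can inspect themselves:
+#
+#   A = torch.eye(3, dtype=torch.float64)
+#   L, info = torch.linalg.cholesky_ex(A, upper=False, check_errors=False)
+#   assert int(info) == 0   # nonzero would mean the factorization failed
+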
+- func: linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor
+  python_module: linalg
+
+- func: linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+  structured_delegate: linalg_cross.out
+  dispatch:
+    ZeroTensor: linalg_cross_zerotensor
+
+- func: linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  structured: True
+  dispatch:
+    CPU, CUDA, MPS: linalg_cross_out
+
+# linalg.lu_factor
+- func: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: linalg_lu_factor
+    MPS: linalg_lu_factor_mps
+
+- func: linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CompositeImplicitAutograd: linalg_lu_factor_out
+    MPS: linalg_lu_factor_out_mps
+
+- func: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+  python_module: linalg
+  structured_delegate: linalg_lu_factor_ex.out
+  variants: function
+
+- func: linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)
+  python_module: linalg
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_lu_factor_ex_out
+    MPS: linalg_lu_factor_ex_out_mps
+
+# linalg.lu
+- func: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)
+  python_module: linalg
+  structured_delegate: linalg_lu.out
+  variants: function
+
+- func: linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)
+  python_module: linalg
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_lu_out
+
+# linalg.lu_solve
+- func: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor
+  python_module: linalg
+  structured_delegate: linalg_lu_solve.out
+  variants: function
+
+- func: linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_lu_solve_out
+
+# linalg.det
+- func: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)
+  structured_delegate: _linalg_det.result
+
+- func: _linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)
+  structured: True
+  dispatch:
+    CPU, CUDA, MPS: _linalg_det_out
+
+- func: linalg_det(Tensor A) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+# torch.det, alias for torch.linalg.det
+- func: det(Tensor self) -> Tensor
+  variants: function, method
+
+- func: linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info)
+  structured_delegate: linalg_ldl_factor_ex.out
+  python_module: linalg
+  variants: function
+
+- func: linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)
+  structured: True
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_ldl_factor_ex_out
+
+- func: linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)
+  python_module: linalg
+  variants: function
+
+- func: linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor
+  structured_delegate: linalg_ldl_solve.out
+  python_module: linalg
+  variants: function
+
+- func: linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_ldl_solve_out
+
+- func: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: linalg_lstsq
+  tags: dynamic_output_shape
+
+- func: linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_lstsq_out
+  tags: dynamic_output_shape
+
+# torch.linalg.matmul, alias for torch.matmul
+- func: linalg_matmul(Tensor self, Tensor other) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_matrix_exp(Tensor self) -> Tensor
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_matrix_exp
+  autogen: linalg_matrix_exp.out
+
+- func: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)
+  structured_delegate: _linalg_slogdet.sign
+
+- func: _linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)
+  structured: True
+  dispatch:
+    CPU, CUDA, MPS: _linalg_slogdet_out
+
+- func: linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet)
+  python_module: linalg
+
+- func: linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+  python_module: linalg
+
+- func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)
+  variants: function, method
+
+- func: slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)
+  variants: function
+
+- func: logdet(Tensor self) -> Tensor
+  variants: function, method
+
+- func: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_eig
+
+- func: linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+  python_module: linalg
+  dispatch:
+    CPU, CUDA: linalg_eig_out
+
+- func: _linalg_eigvals(Tensor self) -> Tensor
+  python_module: linalg
+  dispatch:
+    CPU, CUDA: _linalg_eigvals
+
+- func: linalg_eigvals(Tensor self) -> Tensor
+  python_module: linalg
+
+- func: linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  dispatch:
+    CPU, CUDA: linalg_eigvals_out
+
+# This function exposes the `compute_v` flag, which is then used to implement `linalg.eigh` and
+# `linalg.eigvalsh` as composite functions that call this one
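+# A hedged sketch of that composite relationship (not the literal code):
+#   eigvalsh(A, UPLO) ~ _linalg_eigh(A, UPLO, compute_v=False)[0]   # eigenvalues only
+#   eigh(A, UPLO)     ~ _linalg_eigh(A, UPLO, compute_v=True)       # values and vectors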
+- func: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
+  structured_delegate: _linalg_eigh.eigenvalues
+
+- func: _linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+  structured: True
+  dispatch:
+    CPU, CUDA: _linalg_eigh_out
+
+- func: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
+  python_module: linalg
+
+- func: linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
+  python_module: linalg
+
+- func: linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor
+  python_module: linalg
+
+- func: linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
+  python_module: linalg
+  variants: function
+  dispatch:
+    CPU, CUDA: linalg_householder_product
+
+- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  dispatch:
+    CPU, CUDA: linalg_householder_product_out
+
+- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
+  python_module: linalg
+  structured_delegate: linalg_inv_ex.inverse
+
+- func: linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)
+  python_module: linalg
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_inv_ex_out
+    MPS: linalg_inv_ex_out_mps
+
+- func: linalg_inv(Tensor A) -> Tensor
+  python_module: linalg
+
+- func: linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: inverse(Tensor self) -> Tensor
+  variants: function, method
+
+- func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: inner(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+
+- func: inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: outer(Tensor self, Tensor vec2) -> Tensor
+  variants: function, method
+
+- func: outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+
+# torch.ger, alias for torch.outer
+- func: ger(Tensor self, Tensor vec2) -> Tensor
+  variants: function, method
+
+- func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+  variants: function
+  structured_delegate: linalg_vector_norm.out
+
+- func: linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_vector_norm_out
+    MPS: linalg_vector_norm_out_mps
+
+- func: linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+
+- func: linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  python_module: linalg
+
+- func: linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+# This function exposes the `compute_uv` flag, which is then used to implement `linalg.svd` and
+# `linalg.svdvals` as composite functions that call this one
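+# A hedged sketch of that composite relationship (not the literal code):
+#   svdvals(A) ~ _linalg_svd(A, full_matrices=False, compute_uv=False)[1]   # S only
+#   svd(A)     ~ _linalg_svd(A, full_matrices, compute_uv=True)             # U, S, Vh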
+- func: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+  variants: function
+  structured_delegate: _linalg_svd.U
+
+- func: _linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+  structured: True
+  dispatch:
+    CPU, CUDA: _linalg_svd_out
+
+- func: linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+  python_module: linalg
+  variants: function
+
+- func: linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
+  python_module: linalg
+  variants: function
+
+- func: linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_cond(Tensor self, Scalar? p=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_cond.p_str(Tensor self, str p) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+  dispatch:
+    # calls svd, which calls mH() (view op)
+    # also calls narrow()
+    CompositeExplicitAutogradNonFunctional: linalg_pinv
+
+- func: linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: linalg_pinv_out
+
+- func: linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor
+  cpp_no_default_args: ['atol', 'rtol']
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+  cpp_no_default_args: ['atol', 'rtol']
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
+  structured_delegate: _linalg_solve_ex.result
+
+- func: _linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info)
+  structured: True
+  dispatch:
+    CPU, CUDA: _linalg_solve_ex_out
+    MPS: _linalg_solve_ex_out_mps
+
+- func: linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)
+  python_module: linalg
+
+- func: linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)
+  python_module: linalg
+
+- func: linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor
+  python_module: linalg
+
+- func: _spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor
+  python_module: sparse
+  dispatch:
+    SparseCsrCUDA: _sparse_csr_linear_solve
+
+- func: linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_tensorinv(Tensor self, int ind=2) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
+  python_module: linalg
+  variants: function
+  structured_delegate: linalg_qr.out
+
+- func: linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)
+  python_module: linalg
+  structured: True
+  dispatch:
+    CPU, CUDA: linalg_qr_out
+
+- func: linalg_matrix_power(Tensor self, int n) -> Tensor
+  python_module: linalg
+
+- func: linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+- func: linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor
+  cpp_no_default_args: ['atol', 'rtol']
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)
+  cpp_no_default_args: ['atol', 'rtol']
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+  variants: function
+
+- func: linalg_multi_dot(Tensor[] tensors) -> Tensor
+  python_module: linalg
+
+- func: linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
+## Functions related to the `torch.nested` namespace
+# Note [nested namespace binding]
+# Functions in the nested python module should have their names start with
+#   "nested_" underscore and be bound to the desired Python name in
+#   torch/nested/__init__.py, and the desired C++ name in torch/csrc/api/include/torch/nested.h.
+#   The "nested_" names should be hidden from the user and not documented.
+
+- func: nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor
+  python_module: nested
+  variants: function
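+
+# Hedged usage sketch of the public binding for the entry above:
+#
+#   import torch
+#   nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
+#   padded = torch.nested.to_padded_tensor(nt, padding=0.0)   # shape (2, 4, 3)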
+
+## Functions that are only for testing
+# They are undocumented and should not be used outside of tests.
+- func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
+
+# Note: for testing COW materialization within `at::parallel_for` loop function
+- func: _test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _test_parallel_materialize
+
+# Note: this function is only for testing.
+- func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: _test_optional_intlist
+  autogen: _test_optional_intlist.out
+
+# Note: this function is only for testing.
+- func: _test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: _test_optional_intlist
+  autogen: _test_optional_filled_intlist.out
+
+# Note: this function is only for testing.
+- func: _test_optional_floatlist(Tensor values, float[]? addends) -> Tensor
+  python_module: nn
+  dispatch:
+    CPU: _test_optional_floatlist
+  autogen: _test_optional_floatlist.out
+
+# Note: this function is only for testing.
+- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
+  python_module: nn
+
+# Note: this function is only for testing.
+- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor
+  python_module: nn
+
+# Note: this function is only for testing.
+- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
+  cpp_no_default_args: ['a', 'b']
+  python_module: nn
+
+# Note: this function is only for testing.
+- func: _test_warn_in_autograd(Tensor self) -> Tensor
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: _test_warn_in_autograd
+  autogen: _test_warn_in_autograd.out
+
+# Note: this function is only for testing.
+- func: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
+  dispatch:
+    # the NestedTensor keys are necessary because NestedTensor has been removed
+    # from the CompositeExplicitAutograd keyset; see Note [NestedTensor Not Included in Backend Keys]
+    CompositeExplicitAutograd, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_fullcoverage
+  autogen: _test_autograd_multiple_dispatch.fullcoverage_out
+
+# Note: this function is only for testing.
+- func: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _test_autograd_multiple_dispatch_ntonly
+
+# Note: this function is only for testing.
+- func: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)
+  dispatch:
+    CompositeExplicitAutograd: _test_autograd_multiple_dispatch_view
+
+# Note: this function is only for testing.
+- func: _test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _test_autograd_multiple_dispatch_view_copy
+  tags: view_copy
+  autogen: _test_autograd_multiple_dispatch_view_copy.out
+
+- func: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: segment_reduce_kernel
+  autogen: segment_reduce.out
+
+- func: _segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA: _segment_reduce_backward_kernel
+  autogen: _segment_reduce_backward.out
+
+- func: pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor
+  python_module: nn
+  variants: function
+
+- func: flatten_dense_tensors(Tensor[] tensors) -> Tensor
+  variants: function
+  python_module: nn
+
+- func: unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]
+  variants: function
+  python_module: nn
+
+- func: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _nested_tensor_from_tensor_list
+  autogen: _nested_tensor_from_tensor_list.out
+
+- func: _fw_primal_copy(Tensor self, int level) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _fw_primal_copy
+  tags: view_copy
+  autogen: _fw_primal_copy.out
+
+- func: _make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _make_dual_copy
+  tags: view_copy
+  autogen: _make_dual_copy.out
+
+- func: view_as_real_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: view_as_real_copy
+  tags: view_copy
+  autogen: view_as_real_copy.out
+
+- func: view_as_complex_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: view_as_complex_copy
+  tags: view_copy
+  autogen: view_as_complex_copy.out
+
+- func: _conj_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _conj_copy
+  tags: view_copy
+  autogen: _conj_copy.out
+
+- func: _neg_view_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _neg_view_copy
+  tags: view_copy
+  autogen: _neg_view_copy.out
+
+- func: as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: as_strided_copy_symint
+  tags: view_copy
+  autogen: as_strided_copy.out
+
+- func: _sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _sparse_broadcast_to_copy
+  tags: view_copy
+  autogen: _sparse_broadcast_to_copy.out
+
+- func: diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: diagonal_copy
+  tags: view_copy
+  autogen: diagonal_copy.out
+
+- func: expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: expand_copy_symint
+  tags: view_copy
+  autogen: expand_copy.out
+
+- func: permute_copy(Tensor self, int[] dims) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: permute_copy
+  tags: view_copy
+  autogen: permute_copy.out
+
+- func: _reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _reshape_alias_copy_symint
+  tags: view_copy
+  autogen: _reshape_alias_copy.out
+
+- func: select_copy.int(Tensor self, int dim, SymInt index) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: select_copy_symint
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: select_copy_sparse_csr
+  tags: view_copy
+  autogen: select_copy.int_out
+
+- func: detach_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: detach_copy
+  tags: view_copy
+  autogen: detach_copy.out
+
+- func: slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: slice_copy_Tensor_symint
+  tags: view_copy
+  autogen: slice_copy.Tensor_out
+
+- func: split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: split_copy_Tensor_symint
+  tags: view_copy
+
+- func: split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: split_with_sizes_copy_symint
+  tags: view_copy
+
+- func: squeeze_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: squeeze_copy
+  tags: view_copy
+  autogen: squeeze_copy.out
+
+- func: squeeze_copy.dim(Tensor self, int dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: squeeze_copy_dim
+  tags: view_copy
+  autogen: squeeze_copy.dim_out
+
+- func: squeeze_copy.dims(Tensor self, int[] dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: squeeze_copy_dims
+  tags: view_copy
+  autogen: squeeze_copy.dims_out
+
+- func: t_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: t_copy
+  tags: view_copy
+  autogen: t_copy.out
+
+- func: transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: transpose_copy_int
+  tags: view_copy
+  autogen: transpose_copy.int_out
+
+- func: unsqueeze_copy(Tensor self, int dim) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: unsqueeze_copy
+  tags: view_copy
+  autogen: unsqueeze_copy.out
+
+- func: _indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _indices_copy
+  tags: view_copy
+  autogen: _indices_copy.out
+
+- func: _values_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: _values_copy
+  tags: view_copy
+  autogen: _values_copy.out
+
+- func: indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: indices_copy
+  tags: view_copy
+  autogen: indices_copy.out
+
+- func: values_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: values_copy
+  tags: view_copy
+  autogen: values_copy.out
+
+- func: crow_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: crow_indices_copy
+  tags: view_copy
+  autogen: crow_indices_copy.out
+
+- func: col_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: col_indices_copy
+  tags: view_copy
+  autogen: col_indices_copy.out
+
+- func: ccol_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: ccol_indices_copy
+  tags: view_copy
+  autogen: ccol_indices_copy.out
+
+- func: row_indices_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: row_indices_copy
+  tags: view_copy
+  autogen: row_indices_copy.out
+
+- func: unbind_copy.int(Tensor self, int dim=0) -> Tensor[]
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: unbind_copy_int
+  tags: view_copy
+
+- func: unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: unbind_copy_int_out
+
+- func: split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: split_copy_Tensor_out
+
+- func: split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: split_with_sizes_copy_out
+    CUDA: split_with_sizes_copy_out_cuda
+
+- func: view_copy(Tensor self, SymInt[] size) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: view_copy_symint
+  tags: view_copy
+  autogen: view_copy.out
+
+- func: view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: view_copy_dtype
+  tags: view_copy
+  autogen: view_copy.dtype_out
+
+- func: unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: unfold_copy
+  tags: view_copy
+  autogen: unfold_copy.out
+
+- func: alias_copy(Tensor self) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutogradNonFunctional: alias_copy
+  tags: view_copy
+  autogen: alias_copy.out
+
+- func: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+  variants: method
+  dispatch:
+    NestedTensorCPU: NestedTensor_to_padded_tensor_generic
+    NestedTensorCUDA: NestedTensor_to_padded_tensor_cuda
+  autogen: to_padded_tensor.out
+
+- func: _jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _fbgemm_jagged_to_padded_dense_forward
+    CPU: _jagged_to_padded_dense_forward_cpu
+
+- func: _padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _fbgemm_dense_to_jagged_forward_symint
+    CPU: _padded_dense_to_jagged_forward_cpu
+
+- func: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor
+  variants: function
+  device_check: NoCheck
+  dispatch: {}
+
+- func: _nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor
+  dispatch:
+    NestedTensorCPU: NestedTensor_softmax_dropout
+    NestedTensorCUDA: NestedTensor_softmax_dropout_cuda
+  tags: nondeterministic_seeded
+
+- func: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _safe_softmax
+    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: _safe_softmax
+
+# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is.
+- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor
+  variants: function
+  dispatch:
+    CPU, CUDA, NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: transformer_encoder_layer_forward
+  autogen: _transformer_encoder_layer_fwd.out
+
+- func: _native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)
+  variants: function
+  dispatch:
+    CPU, NestedTensorCPU: native_multi_head_attention_cpu
+    CUDA, NestedTensorCUDA: native_multi_head_attention_cuda
+  autogen: _native_multi_head_attention.out
+
+- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor
+  python_module: nn
+  variants: function
+  autogen: scaled_dot_product_attention.out
+  tags: nondeterministic_seeded
+
+# This aten function is kept so that we can test the choice function from Python
+- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int
+  dispatch:
+    Meta: _fused_sdp_choice_meta
+    CPU, NestedTensorCPU: _fused_sdp_choice_cpp
+    CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda
+    XPU: _fused_sdp_choice_xpu
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)
+  variants: function
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)
+  dispatch:
+    MPS: _scaled_dot_product_attention_math_mps
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+  dispatch:
+    CUDA: _scaled_dot_product_flash_attention_cuda
+    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+  dispatch:
+    CPU: _scaled_dot_product_flash_attention_cpu
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  dispatch:
+    CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable
+    XPU: _scaled_dot_product_fused_attention_overrideable_xpu
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CUDA: _scaled_dot_product_flash_attention_backward_cuda
+    NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested
+
+- func: _scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _scaled_dot_product_flash_attention_cpu_backward
+
+- func: _scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _scaled_dot_product_fused_attention_overrideable_backward
+
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+  dispatch:
+    CUDA: _scaled_dot_product_efficient_attention_cuda
+    NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
+  device_check: NoCheck
+  dispatch:
+    CUDA: _scaled_dot_product_efficient_attention_backward_cuda
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  dispatch:
+    CUDA: _scaled_dot_product_cudnn_attention_cuda
+    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_cuda
+  tags: nondeterministic_seeded
+
+- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+  dispatch:
+    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
+  tags: nondeterministic_seeded
+
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+  variants: function
+  dispatch:
+    CUDA: _flash_attention_forward
+  tags: nondeterministic_seeded
+
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CUDA: _flash_attention_backward
+
+# Returns output, logsumexp if compute_logsumexp
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+  variants: function
+  dispatch:
+    CUDA: _efficient_attention_forward
+  tags: nondeterministic_seeded
+
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CUDA: _efficient_attention_backward
+
+- func: _cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  dispatch:
+    CUDA: _cudnn_attention_forward
+  tags: nondeterministic_seeded
+
+- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: triton_scaled_dot_attention
+  tags: nondeterministic_seeded
+  autogen: _triton_scaled_dot_attention.out
+
+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _fill_mem_eff_dropout_mask_
+  tags: nondeterministic_seeded
+
+- func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: triton_multi_head_attention
+  autogen: _triton_multi_head_attention.out
+
+- func: special_airy_ai(Tensor x) -> Tensor
+  python_module: special
+  structured_delegate: special_airy_ai.out
+  variants: function
+  tags: pointwise
+
+- func: special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA: special_airy_ai_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_j0(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_bessel_j0.out
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_bessel_j0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_j1(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_bessel_j1.out
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_bessel_j1_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_y0(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_bessel_y0.out
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_bessel_y0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_y1(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_bessel_y1.out
+  variants: function
+  tags: pointwise
+
+- func: special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_bessel_y1_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_chebyshev_polynomial_t.out
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_t
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_t
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_chebyshev_polynomial_t_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_t_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_t_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_chebyshev_polynomial_u.out
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_u
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_u
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_chebyshev_polynomial_u_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_u_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_u_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_chebyshev_polynomial_v.out
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_v
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_v
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_chebyshev_polynomial_v_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_v_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_v_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_chebyshev_polynomial_w.out
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_w
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_w
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_chebyshev_polynomial_w_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_w_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_chebyshev_polynomial_w_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_hermite_polynomial_h.out
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_h
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_h
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_hermite_polynomial_h_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_h_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_h_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_hermite_polynomial_he.out
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_he
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_he
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA, MPS: special_hermite_polynomial_he_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_he_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_hermite_polynomial_he_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_laguerre_polynomial_l.out
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_laguerre_polynomial_l
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_laguerre_polynomial_l
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_laguerre_polynomial_l_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_laguerre_polynomial_l_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_laguerre_polynomial_l_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_legendre_polynomial_p.out
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_legendre_polynomial_p
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_legendre_polynomial_p
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_legendre_polynomial_p_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_legendre_polynomial_p_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_legendre_polynomial_p_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_i0(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_modified_bessel_i0.out
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_modified_bessel_i0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_i1(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_modified_bessel_i1.out
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_modified_bessel_i1_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_k0(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_modified_bessel_k0.out
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_modified_bessel_k0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_k1(Tensor self) -> Tensor
+  python_module: special
+  structured_delegate: special_modified_bessel_k1.out
+  variants: function
+  tags: pointwise
+
+- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_modified_bessel_k1_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor
+  python_module: special
+  structured_delegate: special_scaled_modified_bessel_k0.out
+  variants: function
+  tags: pointwise
+
+- func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_scaled_modified_bessel_k0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor
+  python_module: special
+  structured_delegate: special_scaled_modified_bessel_k1.out
+  variants: function
+  tags: pointwise
+
+- func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_scaled_modified_bessel_k1_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_shifted_chebyshev_polynomial_t.out
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_shifted_chebyshev_polynomial_t_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_t_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_shifted_chebyshev_polynomial_u.out
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_shifted_chebyshev_polynomial_u_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_u_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_shifted_chebyshev_polynomial_v.out
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_shifted_chebyshev_polynomial_v_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_v_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+  device_check: NoCheck
+  python_module: special
+  structured_delegate: special_shifted_chebyshev_polynomial_w.out
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck
+  dispatch:
+    CPU, CUDA: special_shifted_chebyshev_polynomial_w_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: special_shifted_chebyshev_polynomial_w_out
+  device_check: NoCheck
+  python_module: special
+  variants: function
+  tags: pointwise
+
+- func: special_spherical_bessel_j0(Tensor x) -> Tensor
+  python_module: special
+  structured_delegate: special_spherical_bessel_j0.out
+  variants: function
+  tags: pointwise
+
+- func: special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CPU, CUDA, MPS: special_spherical_bessel_j0_out
+  python_module: special
+  structured_inherits: TensorIteratorBase
+  structured: True
+  variants: function
+  tags: pointwise
+
+# Aux function used in the test TestPythonDispatch.test_kwarg_only_and_positional_default
+# within test/test_python_dispatch.py
+- func: _foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor
+  dispatch:
+    CPU: foobar
+  autogen: _foobar.out
+
+- func: _fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
+  variants: function
+  dispatch:
+    CPU: _fused_adam_kernel_cpu_
+    CUDA: _fused_adam_kernel_cuda_
+    MPS: _fused_adam_kernel_mps_
+  autogen: _fused_adam, _fused_adam.out
+
+- func: _fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now),
+  # but still skip the device check as the Tensor LR can be on CPU
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _fused_adam_kernel_cpu_
+    CUDA: _fused_adam_kernel_cuda_
+    MPS: _fused_adam_kernel_mps_
+  autogen: _fused_adam.tensor_lr, _fused_adam.tensor_lr_out
+
+- func: _fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
+  variants: function
+  dispatch:
+    CPU: _fused_adamw_kernel_cpu_
+    CUDA: _fused_adamw_kernel_cuda_
+    MPS: _fused_adamw_kernel_mps_
+  autogen: _fused_adamw, _fused_adamw.out
+
+- func: _fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now),
+  # but still skip the device check as the Tensor LR can be on CPU
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _fused_adamw_kernel_cpu_
+    CUDA: _fused_adamw_kernel_cuda_
+    MPS: _fused_adamw_kernel_mps_
+  autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out
+
+- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
+  variants: function
+  dispatch:
+    CPU: _fused_sgd_kernel_cpu_
+    CUDA: _fused_sgd_kernel_cuda_
+    MPS: _fused_sgd_kernel_mps_
+  autogen: _fused_sgd, _fused_sgd.out
+
+- func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
+  # but still skip the device check as the Tensor LR can be on CPU
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _fused_sgd_kernel_cpu_
+    CUDA: _fused_sgd_kernel_cuda_
+    MPS: _fused_sgd_kernel_mps_
+  autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out
+
+- func: _fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  variants: function
+  dispatch:
+    CPU: _fused_adagrad_kernel_cpu_
+  autogen: _fused_adagrad, _fused_adagrad.out
+
+- func: _fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
+  device_check: NoCheck
+  variants: function
+  dispatch:
+    CPU: _fused_adagrad_kernel_cpu_
+  autogen: _fused_adagrad.tensor_lr, _fused_adagrad.tensor_lr_out
+
+# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
+- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
+  variants: function
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/tags.yaml b/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/tags.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbbaf684509a842b023b4b7095d4f59fa528f9e2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/native/tags.yaml
@@ -0,0 +1,95 @@
+# This yaml file contains all the possible tags that can be defined in `tags` in `native_functions.yaml`
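+#
+# For example, an operator opts into a tag through the `tags` field of its entry in
+# `native_functions.yaml`; the `detach_copy` entry there is tagged this way (abridged):
+#
+#   - func: detach_copy(Tensor self) -> Tensor
+#     tags: view_copy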
+
+- tag: inplace_view
+  desc: |
+          This tag indicates if an operator *only* modifies the tensor metadata
+- tag: pt2_compliant_tag
+  desc: |
+          This tag indicates if the operator is guaranteed to
+          work with the PT2 compilation APIs (torch.compile,
+          torch.export, etc). If you add this tag to an
+          operator, please use
+          `torch.testing._internal.optest.opcheck` to test that
+          the operator has been registered correctly and
+          works with torch.compile
+- tag: view_copy
+  desc: |
+          This tag indicates operators that are *_copy* variants
+          of view/aliasing operators. If an operator has a view_copy tag,
+          then it should have the name {op}_copy, where {op} is a view operator.
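+          For example, the view op `transpose.int` has the non-aliasing counterpart
+          `transpose_copy.int`, which carries this tag and returns a fresh tensor
+          rather than a view of its input.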
+- tag: dynamic_output_shape
+  desc: |
+          This tag indicates if an operator's output's shape depends on input Tensor
+          data.
+- tag: data_dependent_output
+  desc: |
+          Operator has a non-Tensor output whose value is dependent on the data
+          of Tensor inputs.  Among other things, this implies that this operator
+          cannot be run with meta tensor (since data is not available), nor
+          can it be symbolically traced.
+- tag: generated
+  desc: |
+          This tag indicates that the operator doesn't have an explicit entry in
+          native_functions.yaml, and instead was generated automatically by the codegen.
+- tag: nondeterministic_seeded
+  desc: |
+          This tag indicates if an operator is nondeterministically seeded
+          (i.e., is random) such that the operator intentionally produces
+          different results when run twice on the same inputs, but this randomness
+          is controlled by a Generator which, if reseeded would give you the
+          same result.
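+          For example, ops that apply dropout internally, such as
+          `scaled_dot_product_attention`, carry this tag.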
+- tag: nondeterministic_bitwise
+  desc: |
+          This tag indicates if an operator doesn't guarantee bitwise equivalence
+          across different runs of an operator with identical inputs.
+- tag: needs_exact_strides
+  desc: |
+          This tag indicates that the operator should be passed Tensors following
+          the same strides as observed in eager when compiled in inductor.
+          Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
+          can apply; if multiple are assigned then we assume the most restrictive one.
+- tag: needs_contiguous_strides
+  desc: |
+          This tag indicates that the operator should be passed contiguous Tensors.
+          Failure to do so will result in undefined behavior.
+- tag: needs_fixed_stride_order
+  desc: |
+          This tag indicates that the operator should be passed Tensors following
+          the same stride permutation as observed in eager when compiled in inductor.
+          Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
+          can apply; if multiple are assigned then we assume the most restrictive one.
+- tag: flexible_layout
+  desc: |
+          This tag indicates that the custom operator can accept inputs with varying
+          strides/storage_offset and that when compiled, Inductor is allowed to change
+          the strides/storage_offset of inputs to the custom operator.
+          Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
+          can apply; if multiple are assigned then we assume the most restrictive one.
+
+# NOTE [Core ATen Ops]
+- tag: core
+  desc: |
+          Core aten ops is a subset of aten ops that remains after aten-to-aten decomposition and
+          functionalization pass. Core aten ops are fully functional and adhere to single static
+          assignment (SSA): this implies there will be no `inplace` or `_out` variants in this opset.
+          This opset is designed to serve as the functional IR to interface with compiler backends.
+          In contrast to primTorch, core aten opset doesn't decompose ops into explicit
+          type promotion and broadcasting ops.
+          Core aten ops is also effectively the opset produced by torchdynamo.export(aten_graph=True),
+          and thus can be used as an opset for export purpose.
+- tag: pointwise
+  desc: |
+          Pointwise operators are operators where each element of the output is computed only by accessing
+          the corresponding element of all the broadcasted inputs. The output shape will be the broadcasted
+          shape of the inputs.
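+          For example, `special_bessel_j0` is tagged pointwise: each output element is a
+          function of the corresponding input element alone.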
+- tag: maybe_aliasing_or_mutating
+  desc: |
+          For some ops, we can't statically determine whether the op is functional or not. Note that this is only
+          relevant to CIA ops that decompose before functionalization/autograd. It is useful to
+          know this information for export as we would want to decompose these ops as they are unsafe to be
+          preserved.
+- tag: cudagraph_unsafe
+  desc: |
+          This operator does not support cudagraphs. The presence of this tag on an operator will cause
+          Inductor to split the graph around this operator. Note that operators without this tag may still
+          not support CUDAGraphs. Inductor may have other hardcoded lists around that.
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..1e2cdef2ba05d346a852d1873251df006a5a6bba
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/ATenOpList.cpp
@@ -0,0 +1,36 @@
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// ${generated_comment}
+
+namespace at {
+
+namespace {
+struct OpNameEquals final {
+  bool operator()(const std::pair& lhs, const std::pair& rhs) const {
+      return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second);
+  }
+};
+
+struct OpNameHash final {
+  size_t operator()(const std::pair& p) const {
+      // hash the pointed-to strings rather than the raw const char* pointers;
+      // hashing the pointers would compare addresses, not the operator names
+      return std::hash()(p.first) ^ (~ std::hash()(p.second));
+  }
+};
+}
+
+bool is_custom_op(const c10::OperatorName& opName) {
+  static std::unordered_set, OpNameHash, OpNameEquals> ops {
+    ${aten_ops}
+    {"", ""}
+  };
+  return ops.count(std::make_pair(
+             opName.name.c_str(), opName.overload_name.c_str())) == 0;
+}
+}
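A brief usage sketch of the generated predicate (hedged: it assumes `${aten_ops}` has been expanded so the built-in (name, overload) pairs are in the set, and it forward-declares `at::is_custom_op` rather than naming the exact generated header):

#include <ATen/core/operator_name.h>
#include <iostream>

namespace at { bool is_custom_op(const c10::OperatorName&); }  // provided by the generated file above

int main() {
  // A built-in op such as aten::add.Tensor is in the generated set, so it is not custom.
  std::cout << at::is_custom_op(c10::OperatorName("aten::add", "Tensor")) << "\n";  // expected: 0
  // An op outside the generated list is reported as custom.
  std::cout << at::is_custom_op(c10::OperatorName("mylib::my_relu", "")) << "\n";   // expected: 1
}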
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a950b6356078215e04328e3d1eca6df19d0060a4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/CompositeViewCopyKernels.cpp
@@ -0,0 +1,73 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+#include 
+$ops_headers
+#endif
+
+namespace at {
+namespace native {
+
+// This file contains a number of kernels for aten functions that are fully code-generated.
+// TODO: rename this file to something more generic.
+
+namespace {
+at::Tensor clone_arg(const at::Tensor& t) {
+    return t.clone();
+}
+
+std::vector clone_arg(const at::TensorList& t_list) {
+    std::vector out(t_list.size());
+    for (const auto& i : c10::irange(t_list.size())) {
+        out[i] = t_list[i].clone();
+    }
+    return out;
+}
+
+// duped with gen_resize_out_helper from structured kernels
+void copy_arg(const at::Tensor& dst, const at::Tensor& src) {
+    TORCH_CHECK(src.dtype() == dst.dtype(),
+        "Expected out tensor to have dtype ", src.dtype(), ", but got ", dst.dtype(), " instead");
+    TORCH_CHECK(src.device() == dst.device(),
+        "Expected out tensor to have device ", src.device(), ", but got ", dst.device(), " instead");
+    dst.copy_(src);
+}
+
+void copy_arg(const at::TensorList& dst, const at::TensorList& src) {
+    TORCH_INTERNAL_ASSERT(dst.size() == src.size());
+    for (const auto& i : c10::irange(dst.size())) {
+        copy_arg(dst[i], src[i]);
+    }
+}
+
+// TODO: this doesn't handle restriding empty tensors correctly; see
+// gen_resize_out_helper for the correct algorithm
+
+void resize_out_helper(const at::Tensor& dst, const at::Tensor& src) {
+    at::native::resize_output(dst, src.sizes());
+}
+
+void resize_out_helper(const at::TensorList& dst, const at::TensorList& src) {
+    TORCH_INTERNAL_ASSERT(dst.size() == src.size());
+    for (const auto& i : c10::irange(dst.size())) {
+        at::native::resize_output(dst[i], src[i].sizes());
+    }
+}
+}
+
+
+${CompositeViewCopyKernel_Definitions}
+
+${GeneratedCompositeFunctional_Definitions}
+
+${GeneratedCompositeOut_Definitions}
+
+} // namespace native
+} // namespace at
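The generated `*_copy` kernels in this file back the functional counterparts of view ops. A hedged sketch of the behavioral difference they implement, using the existing `at::slice` / `at::slice_copy` pair:

#include <ATen/ATen.h>

void view_vs_view_copy() {
  auto t = at::arange(10);
  auto v = at::slice(t, /*dim=*/0, /*start=*/0, /*end=*/5);       // shares storage with t
  auto c = at::slice_copy(t, /*dim=*/0, /*start=*/0, /*end=*/5);  // freshly allocated copy
  t.add_(1);
  // v now reflects the in-place update of t; c is unchanged.
}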
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h
new file mode 100644
index 0000000000000000000000000000000000000000..4bd999183177c647735f6db51ae19619113b2df2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunction.h
@@ -0,0 +1,23 @@
+#pragma once
+// ${generated_comment}
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include 
+
+namespace at {
+
+namespace ${dispatch_namespace} {
+
+${dispatch_namespaced_declarations}
+
+} // namespace ${dispatch_namespace}
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..4532acdf7f790e4dd520bfe4879c693e52f958d4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions.h
@@ -0,0 +1,29 @@
+#include 
+
+// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch]
+// Code introduced to avoid cyclic dependency in static dispatch is no longer
+// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place,
+// to Operators.cpp for supporting multiple backends with multiple kernels.
+//
+// Note [Avoiding Include Cycles In Static Dispatch]
+// In order to avoid #include cycles in the static dispatch build, we've carefully split out
+// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h.
+//
+// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h.
+// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods
+//   all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all
+//   directly inlined into TensorBody.h.
+// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API,
+//   which include functions that have defaultable std::optional arguments.
+//   That requires knowing the full Tensor class definition.
+//
+// We break the cycle by doing the following:
+// - Split out CPUFunctions.h into two files: CPUFunctions.h and CPUFunctions_inl.h
+// - CPUFunctions.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl.h,
+// - CPUFunctions_inl.h includes everything else
+// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class,
+//   and then it includes CPUFunctions_inl.h.
+// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly.
+// - This also means that in the static dispatch build, CPUFunctions.h only needs to
+//   #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h.
+${inline_headers}
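A hedged sketch of the consumer side of the layout described above: with static dispatch enabled, callers include the per-backend header and call the fastpath API in its backend namespace (`at::cpu::add` is assumed from the generated CPUFunctions headers).

#include <ATen/CPUFunctions.h>  // brings in TensorBody.h and then CPUFunctions_inl.h
#include <ATen/ATen.h>

at::Tensor add_on_cpu(const at::Tensor& a, const at::Tensor& b) {
  // Calls the CPU kernel directly, skipping a round trip through the dispatcher.
  return at::cpu::add(a, b, /*alpha=*/1);
}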
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e9fe55a26ba9915b231d541880e3e8c9dd2bcec
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyFunctions_inl.h
@@ -0,0 +1,22 @@
+#pragma once
+// ${generated_comment}
+
+// NB: The implementing C++ file is RegisterDispatchKey.cpp
+
+// The only #includes we need are for custom classes that have defaults in the C++ API
+#include 
+#include 
+#include 
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from                                  \
+  .                   \
+  See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+${DispatchKeyFunctions_inl_includes}
+
+
+${dispatch_namespaced_declarations}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..604a7bcb6275616ebe98756378f7feaffe4a6856
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.cpp
@@ -0,0 +1,13 @@
+// ${generated_comment}
+${includes}
+${native_functions_include}
+
+namespace {
+${helper_fns}
+} // namespace
+
+${namespace_prologue}
+
+${native_function_definitions}
+
+${namespace_epilogue}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..e616b4d7ef360ab4e2223a55a11fe5d423efc8c2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/DispatchKeyNativeFunctions.h
@@ -0,0 +1,19 @@
+#pragma once
+
+// An external backend might generate this file within its own code tree
+// and check all the source files in that tree with clang-format,
+// so disable clang-format here since the backend might use a different config.
+// clang-format off
+
+// ${generated_comment}
+
+#include 
+
+${namespace_prologue}
+
+struct ${class_name} {
+
+${dispatch_declarations}
+
+};
+${namespace_epilogue}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Function.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Function.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2833026e7d0e86110932a3ba5282f81de5bd250
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Function.h
@@ -0,0 +1,27 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+${static_dispatch_ops_headers}
+
+${operator_includes}
+
+namespace at {
+
+${function_definitions}
+
+}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h
new file mode 100644
index 0000000000000000000000000000000000000000..2a817670b7cd6dd80404fc6a512c51e3ed8b8352
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/FunctionalInverses.h
@@ -0,0 +1,33 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+
+namespace at {
+namespace functionalization {
+
+enum class InverseReturnMode {
+  /// Specifies that functional inverses should always return a view.
+  AlwaysView,
+  /// Specifies that functional inverses should always return a non-view / copy.
+  NeverView,
+  /// Specifies that functional inverses should return a view unless a (copying) scatter
+  /// inverse exists, in which case that will be used instead.
+  /// This avoids as_strided() calls that can be difficult for subclasses to handle.
+  ViewOrScatterInverse,
+};
+
+struct FunctionalInverses {
+
+${view_inverse_declarations}
+
+// NB: These are not generated! They're manually implemented in the template.
+// TODO: Change codegen to generate these. See the following link:
+// https://github.com/pytorch/pytorch/blob/main/torchgen/model.py#L2583-L2585
+static at::Tensor chunk_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int chunks, int dim);
+static at::Tensor narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length);
+
+};
+}
+}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a7d887c59f141dd4c063e0b392b67a317987d6d4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.cpp
@@ -0,0 +1,105 @@
+#include 
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+Tensor TensorMaker::make_tensor() {
+   AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
+   tracer::impl::NoTracerDispatchMode tracer_guard{};
+
+   check_size_nonnegative(sizes_);
+
+   TORCH_CHECK_VALUE(
+       !deleter_ || !ctx_,
+       "The deleter and context arguments are mutually exclusive.");
+
+   if (device_ == std::nullopt) {
+     device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
+   }
+
+   if (opts_.device().has_index()) {
+     // clang-format off
+     TORCH_CHECK_VALUE(
+         opts_.device() == *device_,
+         "Specified device ", opts_.device(), " does not match device of data ", *device_);
+     // clang-format on
+   }
+
+   std::size_t size_bytes = computeStorageSize();
+
+   DataPtr data_ptr{};
+   if (deleter_) {
+     data_ptr = makeDataPtrFromDeleter();
+   } else {
+     data_ptr = makeDataPtrFromContext();
+   }
+
+   TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
+   Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
+
+   Tensor tensor = detail::make_tensor(
+       std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
+
+  TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
+  if (strides_) {
+    tensor_impl->set_sizes_and_strides(sizes_, *strides_);
+  } else {
+    tensor_impl->set_sizes_contiguous(sizes_);
+  }
+  if (storage_offset_) {
+    tensor_impl->set_storage_offset(*storage_offset_);
+  }
+
+  tensor_impl->set_requires_grad(opts_.requires_grad());
+
+  return tensor;
+ }
+
+ std::size_t TensorMaker::computeStorageSize() const noexcept {
+   std::size_t itemsize = opts_.dtype().itemsize();
+
+   if (strides_) {
+     auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
+     if (storage_offset_) {
+       storage_size += storage_offset_.value();
+     }
+     return storage_size;
+   }
+
+   std::size_t size = 1;
+   for (std::int64_t s : sizes_) {
+     size *= static_cast(s);
+   }
+   auto storage_size = size * itemsize;
+   if (storage_offset_) {
+     storage_size += storage_offset_.value();
+   }
+   return storage_size;
+ }
+
+ inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
+   return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
+ }
+
+ inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
+   return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
+ }
+
+ IntArrayRef TensorMaker::makeTempSizes() const noexcept {
+   static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
+   if (opts_.has_memory_format()) {
+     MemoryFormat format = *opts_.memory_format_opt();
+     if (format == MemoryFormat::ChannelsLast) {
+       return IntArrayRef(zeros, 4);
+     }
+     if (format == MemoryFormat::ChannelsLast3d) {
+       return IntArrayRef(zeros, 5);
+     }
+   }
+   return IntArrayRef(zeros, 1);
+ }
+
+} // namespace at
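`TensorMaker::make_tensor()` above is normally reached through the fluent `at::for_blob` builder; a minimal sketch of that flow (the buffer, sizes, and strides are illustrative):

#include <ATen/Functions.h>  // at::for_blob / at::from_blob and TensorMaker
#include <vector>

at::Tensor wrap_external_buffer() {
  static std::vector<float> buf(6, 1.0f);          // memory owned elsewhere; outlives the tensor
  return at::for_blob(buf.data(), {2, 3})          // sizes
      .strides({3, 1})                             // row-major strides
      .options(at::TensorOptions().dtype(at::kFloat))
      .make_tensor();                              // runs TensorMaker::make_tensor() above
}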
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..55fa67c5bddb2f4abc730396c3affe4d75479753
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Functions.h
@@ -0,0 +1,143 @@
+#pragma once
+
+// ${generated_comment}
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from  and   \
+  see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS]
+//
+// In ATen, certain generated header files include the definitions of
+// every single operator in PyTorch. Unfortunately this means every
+// time an operator signature is updated or changed in
+// native_functions.yaml, you (and every other PyTorch developer) need
+// to recompile every source file that includes any of these headers.
+//
+// To break up these header dependencies and improve incremental
+// build times for all PyTorch developers, these headers are split
+// into per-operator headers in the `ATen/ops` folder. This limits
+// incremental builds to only changes to methods of `Tensor`, or files
+// that use the specific operator being changed. With `at::sum` as an
+// example, you should include
+//
+//   #include <ATen/ops/sum.h>               // instead of ATen/Functions.h
+//   #include <ATen/ops/sum_native.h>        // instead of ATen/NativeFunctions.h
+//   #include <ATen/ops/sum_ops.h>           // instead of ATen/Operators.h
+//   #include <ATen/ops/sum_cpu_dispatch.h>  // instead of ATen/CPUFunctions.h
+//
+// However, even if you're careful to use this in your own code,
+// `Functions.h` might be included indirectly through another header
+// without you realising. To avoid this, you can add
+//
+//   #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+//
+// to the top of your source file. This way any time the non-specific
+// headers are included, the compiler will error out.
+//
+// Also, be aware that `ops` are not available in all build
+// configurations (namely fb-internal) so you must guard these
+// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g.
+//
+//   #ifndef AT_PER_OPERATOR_HEADERS
+//   #include 
+//   #else
+//   #include 
+//   #endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+${Functions_includes}
+
+namespace at {
+
+${Functions_declarations}
+
+// Special C++ only overloads for std()-like functions (See gh-40287)
+// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
+// So, for example std(0) would select the std(unbiased=False) overload
+TORCH_API inline Tensor var(const Tensor& self, int dim) {
+  return at::var(self, IntArrayRef{dim});
+}
+TORCH_API inline std::tuple var_mean(const Tensor& self, int dim) {
+  return at::var_mean(self, IntArrayRef{dim});
+}
+TORCH_API inline Tensor std(const Tensor& self, int dim) {
+  return at::std(self, IntArrayRef{dim});
+}
+TORCH_API inline std::tuple std_mean(const Tensor& self, int dim) {
+  return at::std_mean(self, IntArrayRef{dim});
+}
+
+inline int64_t numel(const Tensor& tensor) {
+  return tensor.numel();
+}
+
+inline int64_t size(const Tensor& tensor, int64_t dim) {
+  return tensor.size(dim);
+}
+
+inline int64_t stride(const Tensor& tensor, int64_t dim) {
+  return tensor.stride(dim);
+}
+
+inline bool is_complex(const Tensor& tensor) {
+  return tensor.is_complex();
+}
+
+inline bool is_floating_point(const Tensor& tensor) {
+  return tensor.is_floating_point();
+}
+
+inline bool is_signed(const Tensor& tensor) {
+  return tensor.is_signed();
+}
+
+inline bool is_inference(const Tensor& tensor) {
+  return tensor.is_inference();
+}
+
+inline bool _is_zerotensor(const Tensor& tensor) {
+  return tensor._is_zerotensor();
+}
+
+inline bool is_conj(const Tensor& tensor) {
+  return tensor.is_conj();
+}
+
+inline Tensor conj(const Tensor& tensor) {
+  return tensor.conj();
+}
+
+inline bool is_neg(const Tensor& tensor) {
+  return tensor.is_neg();
+}
+
+}
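A compact sketch of the include discipline from the NOTE above, assuming a translation unit that only needs `at::sum`:

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/sum.h>   // only the operator this file actually uses
#endif

at::Tensor total(const at::Tensor& t) {
  return at::sum(t);
}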
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyIr.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyIr.h
new file mode 100644
index 0000000000000000000000000000000000000000..6f3867cbd91d2f6331908372de5f1434107f664d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyIr.h
@@ -0,0 +1,19 @@
+#pragma once
+
+// This file contains autogenerated LazyTensor IR nodes
+${lazy_ir_sysinc}
+${lazy_ir_inc}
+
+${namespace_prologue}
+using at::operator<<;
+
+// kNullValue is used to contribute a static hash value any time
+// a node has an Optional input that is nullopt.  It is important
+// to differentiate between HASH(std::nullopt, something) and HASH(something, std::nullopt),
+// and using kNullValue in the hash function in the order of arguments
+// serves this purpose.
+static const torch::lazy::Value kNullValue = torch::lazy::Value();
+
+${ir_declarations}
+
+${namespace_epilogue}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h
new file mode 100644
index 0000000000000000000000000000000000000000..df0f621c9620d3075a23a1af2da621d79cdb712f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/LazyNonNativeIr.h
@@ -0,0 +1,11 @@
+#pragma once
+
+${lazy_non_native_ir_inc}
+
+// This file contains autogenerated LazyTensor Non Native IR nodes
+
+${namespace_prologue}
+
+${non_native_ir_nodes}
+
+${namespace_epilogue}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h
new file mode 100644
index 0000000000000000000000000000000000000000..b6fe7ba41de0850908d3de589363a58cd97cf0ce
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/MethodOperators.h
@@ -0,0 +1,24 @@
+#pragma once
+
+// ${generated_comment}
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,             \
+  meaning the file will need to be re-compiled every time an operator      \
+  is changed or added. Consider if your change would be better placed in   \
+  another file, or if a more specific header might achieve the same goal.  \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include 
+
+${MethodOperators_includes}
+
+namespace at {
+namespace _ops {
+${MethodOperators_declarations}
+} // namespace _ops
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h
new file mode 100644
index 0000000000000000000000000000000000000000..323849f0dbc73d1340594ffcc5ba692975290a76
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunction.h
@@ -0,0 +1,17 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+${extra_includes}
+
+${native_function_declarations}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..d99b9afa9a9874328512b40a3e07478d48539c93
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeFunctions.h
@@ -0,0 +1,33 @@
+#pragma once
+
+// ${generated_comment}
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the      \
+  file will need to be re-compiled every time an operator is changed or added.  \
+  Consider including a specific operator from  \
+  and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+${NativeFunctions_includes}
+
+${NativeFunctions_declarations}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h
new file mode 100644
index 0000000000000000000000000000000000000000..daf2bdf65bb47fe7609930137a9aece8f75f6839
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunction.h
@@ -0,0 +1,23 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace meta {
+
+${meta_function_declarations}
+
+} // namespace meta
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..ed628e1656dcb541d3bfe07e3df179d0bcbbf3e1
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/NativeMetaFunctions.h
@@ -0,0 +1,19 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+
+${NativeMetaFunctions_includes}
+
+namespace at {
+
+namespace meta {
+
+${NativeMetaFunctions_declarations}
+
+} // namespace meta
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operator.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operator.h
new file mode 100644
index 0000000000000000000000000000000000000000..6963810710aaf73b040f174beec208534328afc6
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operator.h
@@ -0,0 +1,19 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+
+// Forward declarations of any types needed in the operator signatures.
+// We can't directly include these classes because it will cause circular include dependencies.
+// This file is included by TensorBody.h, which defines the Tensor class.
+#include 
+
+namespace at {
+namespace _ops {
+
+${declarations}
+
+}} // namespace at::_ops
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..9eaff50501a2f3487566eedbfe4cfc33b26c3594
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.cpp
@@ -0,0 +1,19 @@
+#include 
+#include 
+
+// ${generated_comment}
+// NOTE See [Sharded File] comment in VariableType
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+${operator_headers}
+#endif
+
+${static_dispatch_extra_headers}
+
+namespace at { namespace _ops {
+
+${definitions}
+
+}} // namespace at::_ops
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.h
new file mode 100644
index 0000000000000000000000000000000000000000..c4ff5dc101c6764563301310036d48a24bc3c6cd
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/Operators.h
@@ -0,0 +1,74 @@
+#pragma once
+
+// ${generated_comment}
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,             \
+  meaning the file will need to be re-compiled every time an operator      \
+  is changed or added. Consider if your change would be better placed in   \
+  another file, or if a more specific header might achieve the same goal.  \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider including a specific operator from    \
+  and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+${Operators_includes}
+
+// Extension writers: do you write wrapper functions? Are you frustrated with
+// resolving overloads of operators? Are you frustrated with dealing with
+// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no
+// further, this is the utility for you.
+//
+// Given an operator schema: aten::op.overload(...
+//
+// Use ATEN_FN2(op, overload) to get a *function* version of the operator
+// that is guaranteed to not be overloaded. This means that you can safely
+// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args.
+//
+// Given an operator schema without an overload name: aten::op(...
+//
+// Use ATEN_FN(op) to get an unambiguous *function* version of the operator.
+//
+// There is some interesting behavior for out= operations.
+// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema;
+// that is, the order of arguments is exactly what it looks like in the schema.
+
+#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call
+#define ATEN_FN(op_name) at::_ops::op_name::call
+
+// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time
+// metadata about a given aten operator.
+// Notable data on the class includes:
+// - ATEN_OP2(add, Tensor)::name // returns the string name: "add"
+// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor"
+// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &)
+// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
+
+#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload
+#define ATEN_OP(op_name) at::_ops::op_name
+
+// WARNING: Please do not call any of the ops in the _ops namespace directly.
+// Use the ATEN_FN macros. We do not guarantee stability of the naming
+// scheme for the functions in at::_ops
+
+// See Note [The ATen Operators API] for details of the at::_ops namespace
+
+namespace at {
+namespace _ops {
+${Operators_declarations}
+} // namespace _ops
+} // namespace at
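A hedged sketch of the macros documented above; the example assumes `<ATen/Operators.h>` (this header) is acceptable for the translation unit, since it is what defines ATEN_FN2/ATEN_OP2 and declares `at::_ops::add_Tensor`:

#include <ATen/Operators.h>   // defines ATEN_FN/ATEN_OP2 and declares at::_ops::*

// ATEN_FN2 names a plain, non-overloaded function, so taking its address is safe.
using AddFn = decltype(&ATEN_FN2(add, Tensor));

at::Tensor faithful_add(const at::Tensor& a, const at::Tensor& b) {
  // Calls at::_ops::add_Tensor::call with the schema-faithful argument order.
  return ATEN_FN2(add, Tensor)(a, b, /*alpha=*/1);
}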
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..800ae1bc4cc0411ba4b673daec15328fa14aae8e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.cpp
@@ -0,0 +1,15 @@
+// ${generated_comment}
+
+#include 
+#include 
+
+#include 
+#include 
+
+namespace at {
+
+namespace redispatch {
+    ${function_redispatch_definitions}
+} // namespace redispatch
+
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..379ad8d5d426d3ad4e965c97085369b567379deb
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RedispatchFunctions.h
@@ -0,0 +1,32 @@
+#pragma once
+
+// ${generated_comment}
+
+#ifdef TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#error This change adds a dependency on all pytorch operators, meaning the     \
+  file will need to be re-compiled every time an operator is changed or added. \
+  Consider using the at::_ops::{name}::redispatch() interface by including     \
+  the specific operator from 
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+
+namespace redispatch {
+    ${function_redispatch_definitions}
+} // namespace redispatch
+
+}
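A hedged sketch of how these redispatch wrappers are typically used: a kernel that already received an explicit `DispatchKeySet` forwards it, instead of letting the dispatcher recompute the key set from the inputs.

#include <ATen/RedispatchFunctions.h>
#include <ATen/core/Tensor.h>

at::Tensor wrapper_add(c10::DispatchKeySet ks, const at::Tensor& a,
                       const at::Tensor& b, const at::Scalar& alpha) {
  // ... custom bookkeeping could go here ...
  // Re-enter the dispatcher with the explicit key set (minus whatever the
  // caller has already masked out).
  return at::redispatch::add(ks, a, b, alpha);
}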
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..250fada8eb1befefe83f1ab0a241e44ede50268b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterBackendSelect.cpp
@@ -0,0 +1,29 @@
+// We register ops with a higher priority dispatch key (BackendSelect) than the usual backend-specific keys (e.g. CPU)
+// which makes calls to the factory functions dispatch to here.
+// We then 'manually' compute a lower-priority to re-dispatch to (e.g. CPU) to get to the eventually correct backend.
+// ${generated_comment}
+
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+
+${ops_headers}
+#endif
+
+namespace at {
+
+namespace {
+
+${backend_select_method_definitions}
+
+TORCH_LIBRARY_IMPL(aten, BackendSelect, m) {
+  ${backend_select_function_registrations};
+}
+
+} // namespace
+} // at
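The wrappers generated here boil down to computing the backend dispatch key from the factory call's TensorOptions and then redispatching; a minimal sketch of that computation using the public `TensorOptions::computeDispatchKey()`:

#include <c10/core/TensorOptions.h>
#include <c10/core/DispatchKey.h>
#include <iostream>

int main() {
  auto opts = c10::TensorOptions().dtype(c10::kFloat).device(c10::kCPU);
  // BackendSelect derives the concrete backend key roughly like this before re-dispatching.
  std::cout << c10::toString(opts.computeDispatchKey()) << "\n";  // expected: "CPU"
}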
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..52adaeec74da5905fbc1ffbfaccd7454577892b9
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterCodegenUnboxedKernels.cpp
@@ -0,0 +1,41 @@
+#include 
+#include 
+#include 
+
+#include 
+
+// ${generated_comment}
+
+// NOTE [Sharded File]: This file is generated in a sharded fashion to speed up
+// incremental rebuilds. See the comment at the top of
+// templates/VariableType.cpp for an analogous, in-depth discussion.
+//
+// Generated by tools/jit/gen_unboxing.py. This file registers all ATen ops into JIT op registry instead of c10
+// dispatcher. JIT op registry only takes boxed kernels, so we are calling unboxing functions in UnboxingFunctions.h
+// to cast arguments into C++ types (instead of IValue) and delegate to unboxed kernels.
+
+namespace torch { namespace jit {
+
+using autograd::Variable;
+using autograd::variable_list;
+using at::Scalar;
+using at::ScalarType;
+using at::Tensor;
+using at::TensorOptions;
+using at::DeviceGuard;
+
+using ::c10::fmap;
+using ::c10::filter;
+
+namespace {
+
+RegisterOperators reg({
+
+    // Generated operators
+    ${unboxed_ops}
+});
+
+} // anon namespace
+
+
+}} // namespace torch::jit
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini
new file mode 100644
index 0000000000000000000000000000000000000000..ed24f73247cc01c432ffff544e5fb34c0890b7fe
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchDefinitions.ini
@@ -0,0 +1,22 @@
+${ns_prologue}
+
+// NB: TORCH_LIBRARY_IMPL must be in an anonymous namespace to avoid
+// ambiguity with conflicting identifiers that may have been defined in
+// at namespace already.
+namespace {
+
+${dispatch_anonymous_definitions}
+
+${static_init_dispatch_registrations}
+
+} // anonymous namespace
+
+${deferred_dispatch_registrations}
+
+namespace ${dispatch_namespace} {
+
+${dispatch_namespaced_definitions}
+
+} // namespace ${dispatch_namespace}
+
+${ns_epilogue}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fbb73f3df2d25fd46448606f297fe3317b8a915f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterDispatchKey.cpp
@@ -0,0 +1,54 @@
+// An external backend might generate this file within its own code tree
+// and check all the source files in that tree with clang-format,
+// so disable clang-format here since the backend might use a different config.
+// clang-format off
+
+// NOTE: This condition is true for all PyTorch internal libraries, it
+//       just excludes external projects such as torch_xla which
+//       re-use some of the PyTorch codegen machinery.
+#if defined(CAFFE2_BUILD_MAIN_LIB)        || \
+    defined(TORCH_CUDA_BUILD_MAIN_LIB)    || \
+    defined(TORCH_HIP_BUILD_MAIN_LIB)     || \
+    defined(TORCH_XPU_BUILD_MAIN_LIB)     || \
+    defined(TORCH_CUDA_CU_BUILD_MAIN_LIB) || \
+    defined(TORCH_CUDA_CPP_BUILD_MAIN_LIB)
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#endif
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+$extra_cuda_headers
+$external_backend_headers
+$dispatch_headers
+$ops_headers
+
+namespace at {
+namespace {
+$dispatch_helpers
+} // namespace
+} // namespace at
+
+// See template file RegisterDispatchDefinitions.ini
+$dispatch_definitions
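For the external-backend case mentioned at the top of this template, the generated registrations reduce to `TORCH_LIBRARY_IMPL` blocks; a hedged, hand-written analogue (the kernel and its CPU fallback are illustrative only):

#include <torch/library.h>
#include <ATen/ATen.h>

// Hypothetical out-of-tree kernel for an external backend (PrivateUse1 is the
// dispatch key reserved for such backends).
at::Tensor my_backend_add(const at::Tensor& a, const at::Tensor& b, const at::Scalar& alpha) {
  return at::add(a.cpu(), b.cpu(), alpha);  // toy implementation: fall back to CPU
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("add.Tensor", TORCH_FN(my_backend_add));
}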
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..07b286af1a4831671c4a5c412016899b35076f6d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterFunctionalization.cpp
@@ -0,0 +1,116 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#include 
+#else
+// needed for the meta tensor calls to get stride info in functionalization
+#include 
+// needed for special handling of copy_().
+// See Note [functionalizating copy_() and not preserving strides]
+#include 
+#include 
+
+$ops_headers
+#endif
+
+namespace at {
+namespace functionalization {
+
+// This keyset is used by functionalization when it calls into meta kernels
+// to accurately propagate stride metadata.
+// Exclude any modes: the purpose of calling into meta kernels is only as an implementation
+// detail to perform shape inference, and we don't want any modal keys to run.
+// Specifically, we want to prevent functionalization and Python modes from running.
+constexpr auto exclude_keys_for_meta_dispatch =
+    c10::functorch_transforms_ks |
+    c10::DispatchKeySet({
+        c10::DispatchKey::FuncTorchDynamicLayerBackMode,
+        c10::DispatchKey::FuncTorchDynamicLayerFrontMode,
+        c10::DispatchKey::Python,
+        c10::DispatchKey::PreDispatch,
+
+    });
+
+// Helper around at::has_internal_overlap.
+// The ATen util is used in hot-path eager mode: it's always fast,
+// but might return TOO_HARD sometimes.
+// During functionalization, we're ok taking a bit longer
+// to detect memory overlap.
+inline bool has_internal_overlap_helper(const at::Tensor t) {
+  auto has_overlap = at::has_internal_overlap(t);
+  if (has_overlap == at::MemOverlap::Yes) return true;
+  if (has_overlap == at::MemOverlap::No) return false;
+  return false;
+}
+
+
+inline Tensor to_meta(const Tensor& t) {
+    if (!t.defined()) return t;
+    return at::native::empty_strided_meta_symint(t.sym_sizes(), t.sym_strides(),
+/*dtype=*/t.scalar_type(), /*layout=*/t.layout(),
+/*device=*/c10::Device(kMeta), /*pin_memory=*/std::nullopt);
+}
+
+inline std::optional to_meta(const std::optional& t) {
+  if (t.has_value()) {
+    return to_meta(*t);
+  }
+  return std::nullopt;
+}
+
+inline std::vector to_meta(at::ITensorListRef t_list) {
+  std::vector outputs;
+  outputs.reserve(t_list.size());
+  for (const auto& tensor : t_list) {
+    outputs.push_back(to_meta(tensor));
+  }
+  return outputs;
+}
+
+inline c10::List to_meta(const c10::List& t_list) {
+  c10::List outputs;
+  outputs.reserve(t_list.size());
+  for (const auto i : c10::irange(t_list.size())) {
+    outputs.push_back(to_meta(t_list[i]));
+  }
+  return outputs;
+}
+
+inline c10::List<::std::optional> to_meta(const c10::List<::std::optional>& t_list) {
+  c10::List<::std::optional> outputs;
+  outputs.reserve(t_list.size());
+  for (const auto i : c10::irange(t_list.size())) {
+    outputs.push_back(to_meta(t_list[i]));
+  }
+  return outputs;
+}
+
+static bool disable_meta_reference() {
+  static auto env = c10::utils::get_env("TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE");
+  return env == "1";
+}
+
+
+${func_definitions}
+
+}  // namespace functionalization
+
+namespace {
+
+TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
+  ${func_registrations};
+}
+
+}  // namespace
+
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..72bc8ed613c3626e9faeebe290832dc868d341f9
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegisterSchema.cpp
@@ -0,0 +1,13 @@
+// ${generated_comment}
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include 
+
+namespace at {
+TORCH_LIBRARY(aten, m) {
+  ${aten_schema_registrations};
+  // Distributed Ops
+  // Implementations located in torch/csrc/jit/runtime/register_distributed_ops.cpp
+  m.def("get_gradients(int context_id) -> Dict(Tensor, Tensor)");
+}
+${schema_registrations}
+}  // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h
new file mode 100644
index 0000000000000000000000000000000000000000..f645f271585b28724829d7ac2672fab582f18dcf
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/RegistrationDeclarations.h
@@ -0,0 +1,4 @@
+// This file contains all native_functions that can be registered to the dispatcher,
+// and the schema string that they should be registered with.
+
+${registration_declarations}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorBody.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorBody.h
new file mode 100644
index 0000000000000000000000000000000000000000..565d13fb2c99fcd78571ee012e32bcec17e454ff
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorBody.h
@@ -0,0 +1,753 @@
+#pragma once
+
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,            \
+  meaning the file will need to be re-compiled every time an operator     \
+  is changed or added. Consider if your change would be better placed in  \
+  another file, or if a more specific header might achieve the same goal. \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+
+#include 
+
+namespace c10{
+template class List;
+template class IListRef;
+}
+namespace at {
+struct Generator;
+struct Type;
+class DeprecatedTypeProperties;
+class Tensor;
+} // namespace at
+namespace at {
+namespace indexing {
+struct TensorIndex;
+} // namespace indexing
+} // namespace at
+
+namespace torch { namespace autograd {
+
+struct Node;
+
+}} // namespace torch::autograd
+
+namespace at {
+
+class OptionalTensorRef;
+class TensorRef;
+class Tensor;
+using TensorList = ArrayRef;
+using ITensorList = c10::IListRef;
+
+using Stream = c10::Stream;
+
+// Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
+// has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
+//
+// For example:
+//
+// void func(Tensor a) {
+//   Tensor b = a;
+//   ...
+// }
+//
+// In this example, when we say Tensor b = a, we are creating a new object that points to the
+// same underlying TensorImpl, and bumps its reference count. When b goes out of scope, the
+// destructor decrements the reference count by calling release() on the TensorImpl it points to.
+// The existing constructors, operator overloads, etc. take care to implement the correct semantics.
+//
+// Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
+// special care must be taken to handle this.
+class TORCH_API Tensor: public TensorBase {
+ protected:
+  // Create a Tensor with a +0 reference count. Special care must be
+  // taken to avoid decrementing this reference count at destruction
+  // time. Intended to support MaybeOwnedTraits.
+  explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {}
+  friend MaybeOwnedTraits;
+  friend OptionalTensorRef;
+  friend TensorRef;
+
+ public:
+  Tensor() = default;
+  // This constructor should not be used by end users and is an implementation
+  // detail invoked by autogenerated code.
+  explicit Tensor(
+      c10::intrusive_ptr tensor_impl)
+      : TensorBase(std::move(tensor_impl)) {}
+  Tensor(const Tensor &tensor) = default;
+  Tensor(Tensor &&tensor) = default;
+
+  // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount
+  explicit Tensor(const TensorBase &base): TensorBase(base) {}
+  /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {}
+
+  // Creates a new wrapper from TensorImpl. Intentionally a free method because
+  // it should be used with care. Checks necessary invariants
+  static Tensor wrap_tensor_impl(
+      c10::intrusive_ptr tensor_impl) {
+    return TensorBase::wrap_tensor_impl(std::move(tensor_impl));
+  }
+
+  Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
+    return TensorBase::contiguous(memory_format);
+  }
+
+  Tensor conj() const {
+    if (!this->is_complex()) {
+      return *this;
+    }
+
+    switch (this->layout()) {
+      case at::kSparse:
+      case at::kSparseCsr:
+      case at::kSparseCsc:
+      case at::kSparseBsr:
+      case at::kSparseBsc:
+        return this->conj_physical();
+      default:
+        return this->_conj();
+    }
+  }
+
+  // Aliased by Dimname overloads, so need explicit using
+  using TensorBase::size;
+  using TensorBase::sym_size;
+  using TensorBase::stride;
+
+  /// Should be used if *this can reasonably be expected to be contiguous and
+  /// performance is important.
+  /// Compared to contiguous, it saves a reference count
+  /// increment/decrement if *this is already contiguous, at the cost
+  /// in all cases of an extra pointer of stack usage, an extra branch
+  /// to access, and an extra branch at destruction time.
+  c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
+
+  // Use .contiguous() instead. Trying to borrow from a prvalue Tensor
+  // will only lead to trouble and dangling references.
+  c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
+
+  // The following overloads are very intriguing.  Consider the following
+  // program:
+  //
+  //    x[1] = 3;
+  //
+  // We would expect that the first entry of x is written to 3.  But how can we
+  // actually achieve this?  x[1] evaluates to a tensor...
+  //
+  // The answer is, using a ref-qualifier.  x[1] is an rvalue, which cannot be
+  // (profitably) assigned to in the traditional sense, so we overload
+  // assignment to mean, "Actually, copy 3 into the tensor data."  This is done
+  // with an rvalue-reference ref-qualified overload (the methods with && at the
+  // end of their type.)
+  //
+  // There's one more fly in the ointment: We also want
+  //
+  //    Tensor x = y;
+  //
+  // to work, and we want it NOT to copy.  So we need a traditional operator=
+  // overload.  But we MUST specify a mutable lvalue ref-qualifier, to
+  // disambiguate the traditional overload from the rvalue-reference
+  // ref-qualified overload.  Otherwise, it will be ambiguous, because
+  // a non ref-qualified method is eligible for all situations.
+
+  // Unfortunately, we have to write these constructors out manually
+  // to work around an MSVC bug:
+  //    error C2580: 'at::Tensor &at::Tensor::operator =(const at::Tensor &) &':
+  //    multiple versions of a defaulted special member functions are not allowed
+  // Tensor& operator=(const Tensor&) & = default;
+  // Tensor& operator=(Tensor&&) & = default;
+
+  // Also MSVC will wrongly issue the following warning with the aforementioned fix
+  //    warning C4522: 'at::Tensor': multiple assignment operators specified
+  // Let's just skip the warning.
+  //
+  // TODO: temporarily disabled
+
+  Tensor& operator=(const TensorBase& x) & noexcept {
+    impl_ = x.getIntrusivePtr();
+    return *this;
+  }
+  Tensor& operator=(TensorBase&& x) & noexcept {
+    impl_ = x.unsafeReleaseIntrusivePtr();
+    return *this;
+  }
+
+  Tensor& operator=(const Tensor &x) & noexcept {
+    return operator=(static_cast(x));
+  }
+  Tensor& operator=(Tensor &&x) & noexcept {
+    return operator=(static_cast(x));
+  }
+
+  Tensor& operator=(const Scalar &v) && {
+    return fill_(v);
+  }
+  Tensor& operator=(const Tensor &rhs) && {
+    return copy_(rhs);
+  }
+  Tensor& operator=(Tensor&& rhs) && {
+    return copy_(rhs);
+  }
+
+  C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().")
+  DeprecatedTypeProperties & type() const {
+    return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+        dispatchKeyToBackend(legacyExtractDispatchKey(key_set())),
+        scalar_type());
+  }
+
+  Tensor toType(ScalarType t) const {
+    return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // TODO: Deprecate me
+  Tensor toBackend(Backend b) const {
+    return to(options().device(backendToDeviceType(b)).layout(layout_from_backend(b)), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  C10_DEPRECATED_MESSAGE("Tensor.is_variable() is deprecated; everything is a variable now. (If you want to assert that variable has been appropriately handled already, use at::impl::variable_excluded_from_dispatch())")
+  bool is_variable() const noexcept {
+    return !at::impl::variable_excluded_from_dispatch();
+  }
+
+  template
+  C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. Please use Tensor.data_ptr() instead.")
+  T * data() const {
+    return data_ptr();
+  }
+
+  template 
+  T item() const;
+
+  template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
+  GenericPackedTensorAccessor packed_accessor() const & {
+    return generic_packed_accessor();
+  }
+  template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
+  GenericPackedTensorAccessor packed_accessor() && = delete;
+
+  Tensor operator~() const {
+    return bitwise_not();
+  }
+  Tensor operator-() const {
+    return neg();
+  }
+  Tensor& operator+=(const Tensor & other) {
+    return add_(other);
+  }
+  Tensor& operator+=(const Scalar & other) {
+    return add_(other);
+  }
+  Tensor& operator-=(const Tensor & other) {
+    return sub_(other);
+  }
+  Tensor& operator-=(const Scalar & other) {
+    return sub_(other);
+  }
+  Tensor& operator*=(const Tensor & other) {
+    return mul_(other);
+  }
+  Tensor& operator*=(const Scalar & other) {
+    return mul_(other);
+  }
+  Tensor& operator/=(const Tensor & other) {
+    return div_(other);
+  }
+  Tensor& operator/=(const Scalar & other) {
+    return div_(other);
+  }
+  Tensor& operator&=(const Tensor & other) {
+    return bitwise_and_(other);
+  }
+  Tensor& operator|=(const Tensor & other) {
+    return bitwise_or_(other);
+  }
+  Tensor& operator^=(const Tensor & other) {
+    return bitwise_xor_(other);
+  }
+  Tensor operator[](const Scalar & index) const {
+    if (!index.isIntegral(false)) {
+      TORCH_CHECK_INDEX(false, "Can only index tensors with integral scalars");
+    }
+    return this->operator[](index.toLong());
+  }
+  Tensor operator[](const Tensor & index) const {
+    // These properties are checked in the Scalar constructor, but we already
+    // check them here to provide more useful diagnostics for the user.
+    if (!index.defined()) {
+      TORCH_CHECK_INDEX(false, "Can only index with tensors that are defined");
+    }
+    if (index.dim() != 0) {
+      TORCH_CHECK_INDEX(false,
+                        "Can only index with tensors that are scalars (zero-dim)");
+    }
+    // The Scalar(Tensor) constructor is explicit, so we need to call it.
+    return this->operator[](index.item());
+  }
+  Tensor operator[](int64_t index) const {
+    return select(0, index);
+  }
+
+  Tensor index(ArrayRef indices) const;
+  Tensor index(std::initializer_list indices) const;
+
+  Tensor & index_put_(ArrayRef indices, Tensor const & rhs);
+  Tensor & index_put_(ArrayRef indices, const Scalar& v);
+  Tensor & index_put_(std::initializer_list indices, Tensor const & rhs);
+  Tensor & index_put_(std::initializer_list indices, const Scalar& v);
+
+  Tensor cpu() const {
+    return to(options().device(c10::DeviceType::CPU), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // TODO: The Python version also accepts arguments
+  Tensor cuda() const {
+    return to(options().device(c10::DeviceType::CUDA), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor hip() const {
+    return to(options().device(c10::DeviceType::HIP), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor ve() const {
+    return to(options().device(c10::DeviceType::VE), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor vulkan() const {
+    return to(options().device(c10::DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor metal() const {
+    return to(options().device(c10::DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  Tensor meta() const {
+    return to(options().device(c10::DeviceType::Meta), /*non_blocking*/ false, /*copy*/ false);
+  }
+
+  // ~~~~~ Autograd API ~~~~~
+
+  /// \fn bool is_leaf() const;
+  ///
+  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
+  ///
+  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
+  /// created by the user. This means that they are not the result of an operation and so
+  /// `grad_fn()` is `nullptr`.
+  ///
+  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
+  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
+  ///
+  /// Example:
+  /// @code
+  /// auto a = torch::rand(10, torch::requires_grad());
+  /// std::cout << a.is_leaf() << std::endl; // prints `true`
+  ///
+  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
+  /// std::cout << b.is_leaf() << std::endl; // prints `false`
+  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
+  ///
+  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
+  /// std::cout << c.is_leaf() << std::endl; // prints `false`
+  /// // c was created by the addition operation
+  ///
+  /// auto d = torch::rand(10).cuda();
+  /// std::cout << d.is_leaf() << std::endl; // prints `true`
+  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
+  ///
+  /// auto e = torch::rand(10).cuda().requires_grad_();
+  /// std::cout << e.is_leaf() << std::endl; // prints `true`
+  /// // e requires gradients and has no operations creating it
+  ///
+  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
+  /// std::cout << f.is_leaf() << std::endl; // prints `true`
+  /// // f requires grad, has no operation creating it
+  /// @endcode
+
+  /// \fn void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const;
+  ///
+  /// Computes the gradient of current tensor with respect to graph leaves.
+  ///
+  /// The graph is differentiated using the chain rule. If the tensor is
+  /// non-scalar (i.e. its data has more than one element) and requires
+  /// gradient, the function additionally requires specifying ``gradient``.
+  /// It should be a tensor of matching type and location, that contains
+  /// the gradient of the differentiated function w.r.t. this Tensor.
+  ///
+  /// This function accumulates gradients in the leaves - you might need to
+  /// zero them before calling it.
+  ///
+  /// \param gradient Gradient w.r.t. the
+  ///     tensor. If it is a tensor, it will be automatically converted
+  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
+  ///     None values can be specified for scalar Tensors or ones that
+  ///     don't require grad. If a None value would be acceptable then
+  ///     this argument is optional.
+  /// \param retain_graph If ``false``, the graph used to compute
+  ///     the grads will be freed. Note that in nearly all cases setting
+  ///     this option to True is not needed and often can be worked around
+  ///     in a much more efficient way. Defaults to the value of
+  ///     ``create_graph``.
+  /// \param create_graph If ``true``, graph of the derivative will
+  ///     be constructed, allowing to compute higher order derivative
+  ///     products. Defaults to ``false``.
+  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
+  ///     provided, the gradient is accumulated into all the leaf Tensors
+  ///     that were used to compute the current tensor.
+  ///     When inputs are provided and a given input is not a leaf,
+  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get these gradients).
+  ///     It is an implementation detail on which the user should not rely.
+  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
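+  ///
+  /// A minimal usage sketch (assumes the `torch::` C++ frontend factory
+  /// functions; illustrative only):
+  /// @code
+  /// auto x = torch::ones({2, 2}, torch::requires_grad());
+  /// auto y = (x * x).sum();   // scalar result, so no explicit `gradient` argument is needed
+  /// y.backward();             // accumulates dy/dx == 2 * x into x.grad()
+  /// std::cout << x.grad() << std::endl;
+  /// @endcode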
+  void backward(const Tensor & gradient={}, std::optional<bool> retain_graph=std::nullopt, bool create_graph=false, std::optional<TensorList> inputs=std::nullopt) const {
+    // NB: Adding this wrapper to _backward here because we'd like our
+    // 'backwards' api to accept the 'inputs' argument optionally. Since code gen
+    // currently does not support optional of TensorList our approach is to replace
+    // backward in native_functions.yaml with _backward and call it here instead.
+    if (inputs.has_value()) {
+      TORCH_CHECK(inputs.value().size() > 0, "'inputs' argument to backward cannot be empty")
+      this->_backward(inputs.value(), gradient, retain_graph, create_graph);
+    } else {
+      this->_backward({}, gradient, retain_graph, create_graph);
+    }
+  }
+
+  /// \fn Tensor detach() const;
+  ///
+  /// Returns a new Tensor, detached from the current graph.
+  /// The result will never require gradient.
+
+  /// \fn Tensor & detach_() const;
+  ///
+  /// Detaches the Tensor from the graph that created it, making it a leaf.
+  /// Views cannot be detached in-place.
+
+  /// \fn void retain_grad() const;
+  ///
+  /// Enables this Tensor to have its :attr:`grad` populated during
+  /// :func:`backward`. This is a no-op for leaf tensors.
+
+  /// \fn bool retains_grad() const;
+  ///
+  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+  /// populated during :func:`backward`, ``false`` otherwise.
+
+  const Tensor& set_requires_grad(bool requires_grad) const {
+    TensorBase::set_requires_grad(requires_grad);
+    return *this;
+  }
+
+  /// Return a mutable reference to the gradient. This is conventionally
+  /// used as `t.grad() = x` to set a gradient to a completely new tensor.
+  /// Note that this function works with a non-const Tensor and is not
+  /// thread safe.
+  Tensor& mutable_grad() const {
+    return impl_->mutable_grad();
+  }
+
+  /// This function returns an undefined tensor by default and returns a defined tensor
+  /// the first time a call to `backward()` computes gradients for this Tensor.
+  /// The attribute will then contain the gradients computed and future calls
+  /// to `backward()` will accumulate (add) gradients into it.
+  const Tensor& grad() const {
+    const Tensor& maybe_grad = impl_->grad();
+    if (!is_leaf() && !retains_grad() && !maybe_grad.defined()) {
+      TORCH_WARN(
+        "The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
+        "attribute won't be populated during autograd.backward(). If you indeed want the .grad "
+        "field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. "
+        "If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor "
+        "instead. See github.com/pytorch/pytorch/pull/30531 for more information.");
+    }
+    return maybe_grad;
+  }
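+
+  // A minimal sketch of the accumulation behaviour described above (assumes the
+  // `torch::` C++ frontend; illustrative only):
+  //
+  //   auto w = torch::zeros(3, torch::requires_grad());
+  //   (w * 2).sum().backward();
+  //   (w * 3).sum().backward();
+  //   // w.grad() now holds 5 in every element: the second backward() added its
+  //   // gradient of 3 on top of the previously accumulated 2.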
+
+  // The Forward AD API functions below are low level and are not to be used by end
+  // users who should use the API provided in torch/csrc/autograd.h
+
+  /// This function returns the forward gradient for this Tensor at the given level.
+  const Tensor& _fw_grad(uint64_t level) const {
+    return impl_->_fw_grad(level, *this);
+  }
+
+  /// This function can be used to set the value of the forward grad.
+  /// Note that the given new_grad might not be used directly if it has different
+  /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
+  /// new_grad content will be copied into a new Tensor
+  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
+    impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
+  }
+
+
+  // STOP.  Thinking of adding a method here, which only makes use
+  // of other ATen methods?  Define it in native_functions.yaml.
+
+  //example
+  //Tensor * add(Tensor & b);
+  ${tensor_method_declarations}
+
+  // Special C++ only overloads for std()-like functions (See gh-40287)
+  // These are needed because int -> bool conversion takes precedence over int -> IntArrayRef
+  // So, for example std(0) would select the std(unbiased=False) overload
+
+  Tensor var(int dim) const {
+    return var(IntArrayRef{dim});
+  }
+
+  Tensor std(int dim) const {
+    return std(IntArrayRef{dim});
+  }
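+
+  // Illustrative consequence of the overloads above (sketch only; `t` stands for
+  // any Tensor):
+  //   auto per_dim = t.std(0);                   // std along dim 0 via std(IntArrayRef{0})
+  //   auto whole   = t.std(/*unbiased=*/false);  // whole-tensor std with the biased estimator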
+
+  // We changed .dtype() to return a TypeMeta in #12766. Ideally, we want the
+  // at::kDouble and its friends to be TypeMeta's, but that hasn't happened yet.
+  // Before that change, we make this method to maintain BC for C++ usage like
+  // `x.to(y.dtype)`.
+  // TODO: remove following two after at::kDouble and its friends are TypeMeta's.
+  inline Tensor to(caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(/*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+  inline Tensor to(Device device, caffe2::TypeMeta type_meta, bool non_blocking=false, bool copy=false) const {
+    return this->to(device, /*scalar_type=*/typeMetaToScalarType(type_meta), non_blocking, copy);
+  }
+
+  template <typename F, typename... Args>
+  decltype(auto) m(F func, Args&&... params) const {
+    return func(*this, std::forward<Args>(params)...);
+  }
+
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::Tensor tensor_data() const {
+    return TensorBase::tensor_data();
+  }
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which create a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::Tensor variable_data() const {
+    return TensorBase::variable_data();
+  }
+
+  // Hooks
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  template <typename T>
+  using hook_return_void_t = std::enable_if_t<std::is_void<std::invoke_result_t<T&, Tensor>>::value, unsigned>;
+  template <typename T>
+  using hook_return_var_t = std::enable_if_t<std::is_same_v<std::invoke_result_t<T&, Tensor>, Tensor>, unsigned>;
+
+  /// Registers a backward hook.
+  ///
+  /// The hook will be called every time a gradient with respect to the Tensor is computed.
+  /// The hook should have one of the following signatures:
+  /// ```
+  /// hook(Tensor grad) -> Tensor
+  /// ```
+  /// ```
+  /// hook(Tensor grad) -> void
+  /// ```
+  /// The hook should not modify its argument, but it can optionally return a new gradient
+  /// which will be used in place of `grad`.
+  ///
+  /// This function returns the index of the hook in the list, which can be used to remove the hook.
+  ///
+  /// Example:
+  /// @code
+  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
+  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
+  /// v.backward(torch::tensor({1., 2., 3.}));
+  /// // This prints:
+  /// // ```
+  /// //  2
+  /// //  4
+  /// //  6
+  /// // [ CPUFloatType{3} ]
+  /// // ```
+  /// std::cout << v.grad() << std::endl;
+  /// v.remove_hook(h);  // removes the hook
+  /// @endcode
+  template <typename T>
+  hook_return_void_t<T> register_hook(T&& hook) const;
+  template <typename T>
+  hook_return_var_t<T> register_hook(T&& hook) const;
+
+  // Variable methods
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  Tensor data() const {
+    return TensorBase::data();
+  }
+
+  void _backward(TensorList inputs, const std::optional<Tensor>& gradient, std::optional<bool> keep_graph, bool create_graph) const;
+
+  const Tensor& requires_grad_(bool _requires_grad=true) const {
+    TensorBase::requires_grad_(_requires_grad);
+    return *this;
+  }
+};
+
+namespace detail {
+// Helper creator for the Tensor class which doesn't require the user to pass
+// in an intrusive_ptr; instead it just converts the argument passed to the
+// requested intrusive_ptr type.
+template <typename T, typename... Args>
+Tensor make_tensor(Args&&... args) {
+  return Tensor(c10::make_intrusive<T>(std::forward<Args>(args)...));
+}
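+
+// Illustrative call shape (hypothetical arguments; the concrete TensorImpl
+// subtype and its constructor arguments depend on the call site):
+//   auto t = at::detail::make_tensor<at::TensorImpl>(
+//       std::move(storage), key_set, dtype);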
+
+} // namespace detail
+
+} // namespace at
+
+
+namespace at {
+${tensor_method_definitions}
+} // namespace at
+
+
+namespace c10 {
+template <>
+struct MaybeOwnedTraits<at::Tensor> {
+  using owned_type = at::Tensor;
+  using borrow_type = at::Tensor;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    // NOTE: this can be implemented without the special
+    // unsafe_borrow_t Tensor constructor as
+    //
+    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
+    //
+    // but that hurts inlining due to the nullptr check in the
+    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
+    // that from.impl_ isn't null because from is a valid Tensor, so
+    // we needn't do the check again. (using __builtin_assume can
+    // avoid this, but wouldn't be portable to MSVC.)
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseTensorImpl();
+    // See above note: this can be implemented with public API
+    // similarly to createBorrow(), but that would hurt inlining.
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& /*borrow*/) {
+    return true;
+  }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<at::Tensor> {
+  using repr_type = at::Tensor;
+  using pointer_type = at::Tensor*;
+  using const_pointer_type = const at::Tensor*;
+
+  static repr_type nullRepr() {
+    return at::Tensor();
+  }
+
+  template <class... Args>
+  static repr_type createInPlace(Args&&... args) {
+    return at::Tensor(std::forward<Args>(args)...);
+  }
+
+  static repr_type moveToRepr(at::Tensor&& x) {
+    return std::move(x);
+  }
+
+  static void destroyOwned(at::Tensor& x) {
+    return ExclusivelyOwnedTraits<at::TensorBase>::destroyOwned(x);
+  }
+
+  static at::Tensor take(at::Tensor& x) {
+    return std::move(x);
+  }
+
+  static pointer_type getImpl(repr_type& x) {
+    return &x;
+  }
+
+  static const_pointer_type getImpl(const repr_type& x) {
+    return &x;
+  }
+};
+} // namespace c10
+
+namespace at {
+
+inline c10::MaybeOwned<Tensor> borrow_from_optional_tensor(
+    const std::optional<Tensor>& opt) {
+  return opt.has_value()
+    ? c10::MaybeOwned<Tensor>::borrowed(*opt)
+    : c10::MaybeOwned<Tensor>::owned(std::in_place);
+}
+
+inline c10::MaybeOwned<Tensor> Tensor::expect_contiguous(MemoryFormat memory_format) const & {
+  if (is_contiguous(memory_format)) {
+    return c10::MaybeOwned<Tensor>::borrowed(*this);
+  } else {
+    return c10::MaybeOwned<Tensor>::owned(__dispatch_contiguous(memory_format));
+  }
+}
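+
+// Illustrative call pattern for expect_contiguous (sketch only):
+//   c10::MaybeOwned<Tensor> c = t.expect_contiguous();
+//   // *c is a cheap borrow of t when t was already contiguous, and an owned
+//   // contiguous copy otherwise; either way it can be used like a const Tensor&.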
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..19c612ad6f041cf8bb9a6a06f23e7b09ff07896f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/TensorMethods.cpp
@@ -0,0 +1,61 @@
+#include 
+#include 
+
+#include 
+
+namespace at {
+
+namespace {
+
+// Verifies the requested type is the same as the Tensor's type.
+void check_type(const TensorBase& tensor, ScalarType type, std::string_view type_name) {
+  TORCH_CHECK(
+      tensor.scalar_type() == type
+      || (isQIntType(tensor.scalar_type())
+          && toUnderlying(tensor.scalar_type()) == type),
+      "expected scalar type ", type_name, " but found ", tensor.scalar_type());
+}
+
+} // namespace
+
+#define DEFINE_CAST(T, name)                                         \
+   template <>                                                       \
+   TORCH_API const T* TensorBase::const_data_ptr<T>() const {        \
+     check_type(*this, ScalarType::name, #name);                     \
+     return this->unsafeGetTensorImpl()->data_ptr_impl<T>();         \
+   }                                                                 \
+                                                                     \
+   template <>                                                       \
+   TORCH_API const T* TensorBase::const_data_ptr<const T>() const {  \
+     check_type(*this, ScalarType::name, #name);                     \
+     return this->unsafeGetTensorImpl()->data_ptr_impl<std::remove_const_t<const T>>(); \
+   }                                                                 \
+                                                                     \
+   template <>                                                       \
+   TORCH_API T* TensorBase::mutable_data_ptr<T>() const {            \
+     check_type(*this, ScalarType::name, #name);                     \
+     return this->unsafeGetTensorImpl()->mutable_data_ptr_impl<T>(); \
+   }                                                                 \
+                                                                     \
+   template <>                                                       \
+   TORCH_API T* TensorBase::data_ptr<T>() const {                    \
+     return mutable_data_ptr<T>();                                   \
+   }                                                                 \
+
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CAST)
+ AT_FORALL_QINT_TYPES(DEFINE_CAST)
+ DEFINE_CAST(uint16_t, UInt16)
+ DEFINE_CAST(uint32_t, UInt32)
+ DEFINE_CAST(uint64_t, UInt64)
+ #undef DEFINE_CAST
+
+ #define DEFINE_ITEM(T, name)      \
+   template <>                     \
+   TORCH_API T Tensor::item<T>() const { \
+     return item().to##name();     \
+   }
+
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_ITEM)
+ #undef DEFINE_ITEM
+
+ } //namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..46e9f4eca41156ec4ea6a962f8d643e292165cd8
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPU.cpp
@@ -0,0 +1,19 @@
+#define TORCH_ASSERT_NO_OPERATORS
+
+#include 
+#include 
+#include 
+
+namespace at {
+
+// NB: this is explicitly copied here (via codegen) rather than
+// included via NativeFunctions.h to avoid recompiling this file when
+// NativeFunctions.h changes
+namespace meta {
+${meta_declaration}
+}
+
+namespace native {
+${native_declaration}
+${native_definitions}
+}} // namespace at::native
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6db4c0280bda7e46a6dd92ec09f3aab60f278c44
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCPUKernel.cpp
@@ -0,0 +1,14 @@
+#define TORCH_ASSERT_NO_OPERATORS
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace at {
+namespace native {
+${native_definitions}
+}} // namespace at::native
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu
new file mode 100644
index 0000000000000000000000000000000000000000..90cbe9d4add4ca094b5fa7661df6bc798556767d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UfuncCUDA.cu
@@ -0,0 +1,21 @@
+#define TORCH_ASSERT_NO_OPERATORS
+
+#include 
+#include 
+#include 
+#include 
+${cuda_headers}
+
+namespace at {
+
+// NB: this is explicitly copied here (via codegen) rather than
+// included via NativeFunctions.h to avoid recompiling this file when
+// NativeFunctions.h changes
+namespace meta {
+${meta_declaration}
+}
+
+namespace native {
+${native_declaration}
+${native_definitions}
+}} // namespace at::native
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b564cac031754439e5b3e2dd0a2a2a694c1af504
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.cpp
@@ -0,0 +1,35 @@
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+namespace at {
+namespace unboxing {
+
+using ::c10::fmap;
+using ::c10::filter;
+using torch::jit::peek;
+using torch::jit::drop;
+using torch::jit::pack;
+using torch::jit::pop;
+
+// Generated function declaration
+${definitions}
+
+} // namespace unboxing
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..698fb032046497eb87882fa57fc71de1fd49537b
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/UnboxingFunctions.h
@@ -0,0 +1,32 @@
+// ${generated_comment}
+
+// Generated by tools/jit/gen_unboxing.py. This file declares code-generated boxed C++ functions for operators,
+// based off of native_functions.yaml (or a similar yaml file with the same syntax). The definition of such a boxed
+// function pops IValues from the stack and then converts them into the correct C++ types based on the given schema. This
+// unboxing logic is an alternative to template-based metaprogramming unboxing.
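+//
+// For orientation, a generated definition conceptually has this shape (purely
+// hypothetical operator `foo(Tensor self, Tensor other) -> Tensor`; the real
+// bodies are emitted by gen_unboxing.py from each operator's schema):
+//
+//   void foo(Stack& stack) {
+//     auto other = torch::jit::pop(stack).toTensor();  // last argument is on top
+//     auto self = torch::jit::pop(stack).toTensor();
+//     torch::jit::push(stack, at::foo(self, other));
+//   }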
+
+#pragma once
+
+#include 
+namespace at {
+namespace unboxing {
+namespace {
+
+template<typename T, size_t N>
+std::array<T, N> as_array(const c10::List<c10::IValue>& list) {
+    std::array<T, N> res;
+    AT_ASSERT(list.size() == N);
+    std::vector<T> vec;
+    for (c10::IValue elem : list) {
+        vec.push_back(elem.to<T>());
+    }
+    std::copy(vec.begin(), vec.end(), res.begin());
+    return res;
+}
+}  // namespace 
+using Stack = std::vector<c10::IValue>;
+// Generated function declaration
+${declarations}
+
+} // namespace unboxing
+} // namespace at
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h
new file mode 100644
index 0000000000000000000000000000000000000000..8af363bc783fe29a3b2c82444c8a06108a38f59f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/aten_interned_strings.h
@@ -0,0 +1,22 @@
+#pragma once
+
+// ${generated_comment}
+
+#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
+#error This change adds a dependency on native_functions.yaml,          \
+  meaning the file will need to be re-compiled every time an operator   \
+  is changed or added. Consider if including  for   \
+  the c10::Symbol class would be sufficient, or if your change would be \
+  better placed in another file.
+#endif
+
+// ATen symbols correspond exactly to operators defined in ATen. Every
+// symbol here corresponds exactly to an ATen operation defined in
+// native_functions.yaml; attributes are in one-to-one correspondence
+// with their ATen name.
+
+#define FORALL_ATEN_BASE_SYMBOLS(_) \
+${aten_symbols}
+
+#define FORALL_ATTR_BASE_SYMBOLS(_) \
+${attr_symbols}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/enum_tag.h b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/enum_tag.h
new file mode 100644
index 0000000000000000000000000000000000000000..39c8c0049e4b9e833481dfdfc896fd889a01ca5a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/ATen/templates/enum_tag.h
@@ -0,0 +1,10 @@
+#pragma once
+
+// ${generated_comment}
+
+namespace at {
+    // Enum of valid tags obtained from the entries in tags.yaml
+    enum class Tag {
+        ${enum_of_valid_tags}
+    };
+}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/BUILD.bazel b/phivenv/Lib/site-packages/torchgen/packaged/autograd/BUILD.bazel
new file mode 100644
index 0000000000000000000000000000000000000000..f4127325e0958e6884843e37efff282ab7af484d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/BUILD.bazel
@@ -0,0 +1,4 @@
+load("//:tools/bazel.bzl", "rules")
+load(":build.bzl", "define_targets")
+
+define_targets(rules = rules)
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/README.md b/phivenv/Lib/site-packages/torchgen/packaged/autograd/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0200bb9d56d4148d8befc61133552989eb09b947
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/README.md
@@ -0,0 +1,3 @@
+If you add a file to this directory, you **MUST** update
+`torch/CMakeLists.txt` and add the file as a dependency to
+the `add_custom_command` call.
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__init__.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39bd15af6cd96c1260b9acfbb8c91b145a0b237b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0559a07721f81045456a7bae3c34076c926f9fd1
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/context.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87bae8028fe78a9883fbdb27f98af7b2c607320d
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_annotated_fn_args.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0bfb15d3dbfa2d8260145bba46e41bb354414b34
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ae0c5e2ea8293685e81913c17b318be56fdeba2c
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_autograd_functions.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..178ca99eba28d033d1be2e8da2603fdbaf95266d
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_inplace_or_view_type.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7980556dbaeb35f5bcafc25ec60d689ce854d812
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_python_functions.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..959f407502697b5d6b968ff51544bdcc1e68763c
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_trace_type.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c2b1aa1947c8772c4afdd14459e3d9aefa4414b
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_factories.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e6dc4adc6d7a6b57c5a8436936c25b82d16cb7bb
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_variable_type.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c8c68d1113298dd786cf3af8202e8fd01486aafc
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/gen_view_funcs.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be581434bfe1fd2b927b1f6220e464dff6cb36d3
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/packaged/autograd/__pycache__/load_derivatives.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/build.bzl b/phivenv/Lib/site-packages/torchgen/packaged/autograd/build.bzl
new file mode 100644
index 0000000000000000000000000000000000000000..08071722d7cbb8af370d679de92c2624b1b205ad
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/build.bzl
@@ -0,0 +1,14 @@
+def define_targets(rules):
+    rules.py_library(
+        name = "autograd",
+        srcs = rules.glob(["*.py"]),
+        data = rules.glob([
+            "*.yaml",
+            "templates/*",
+        ]),
+        visibility = ["//:__subpackages__"],
+        deps = [
+            rules.requirement("PyYAML"),
+            "//torchgen",
+        ],
+    )
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/context.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..1978a8ec7958bcaace429e1ce0adf5c9f9a0c934
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/context.py
@@ -0,0 +1,31 @@
+import functools
+from typing import Callable
+
+from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
+from torchgen.context import native_function_manager
+from torchgen.utils import T
+
+
+# Like tools.api.context.with_native_function, but for
+# NativeFunctionWithDifferentiabilityInfo.
+def with_native_function_with_differentiability_info(
+    func: Callable[[NFWDI], T],
+) -> Callable[[NFWDI], T]:
+    @functools.wraps(func)
+    def wrapper(f: NFWDI) -> T:
+        with native_function_manager(f.func):
+            return func(f)
+
+    return wrapper
+
+
+# Like the above but with an additional dispatch key string argument
+def with_native_function_with_differentiability_info_and_key(
+    func: Callable[[NFWDI, str], T],
+) -> Callable[[NFWDI, str], T]:
+    @functools.wraps(func)
+    def wrapper(f: NFWDI, key: str) -> T:
+        with native_function_manager(f.func):
+            return func(f, key)
+
+    return wrapper
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/deprecated.yaml b/phivenv/Lib/site-packages/torchgen/packaged/autograd/deprecated.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e0998acd3fc869df696acbfda28b6bab8f4779d4
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/deprecated.yaml
@@ -0,0 +1,134 @@
+# Deprecated function signatures. These are exposed in Python, but not included
+# in the error message suggestions.
+
+- name: add(Tensor self, Scalar alpha, Tensor other) -> Tensor
+  aten: add(self, other, alpha)
+
+- name: add_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!)
+  aten: add_(self, other, alpha)
+
+- name: add(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  aten: add_out(out, self, other, alpha)
+
+- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor
+  aten: addbmm(self, batch1, batch2, beta, alpha)
+
+- name: addbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!)
+  aten: addbmm_(self, batch1, batch2, beta, alpha)
+
+- name: addbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addbmm_out(out, self, batch1, batch2, beta, alpha)
+
+- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor
+  aten: addbmm(self, batch1, batch2, beta, 1)
+
+- name: addbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!)
+  aten: addbmm_(self, batch1, batch2, beta, 1)
+
+- name: addbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addbmm_out(out, self, batch1, batch2, beta, 1)
+
+- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor
+  aten: addcdiv(self, tensor1, tensor2, value)
+
+- name: addcdiv_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!)
+  aten: addcdiv_(self, tensor1, tensor2, value)
+
+- name: addcdiv(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addcdiv_out(out, self, tensor1, tensor2, value)
+
+- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor
+  aten: addcmul(self, tensor1, tensor2, value)
+
+- name: addcmul_(Tensor(a!) self, Scalar value, Tensor tensor1, Tensor tensor2) -> Tensor(a!)
+  aten: addcmul_(self, tensor1, tensor2, value)
+
+- name: addcmul(Tensor self, Scalar value, Tensor tensor1, Tensor tensor2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addcmul_out(out, self, tensor1, tensor2, value)
+
+- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor
+  aten: addmm(self, mat1, mat2, beta, alpha)
+
+- name: addmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor(a!)
+  aten: addmm_(self, mat1, mat2, beta, alpha)
+
+- name: addmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addmm_out(out, self, mat1, mat2, beta, alpha)
+
+- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor
+  aten: addmm(self, mat1, mat2, beta, 1)
+
+- name: addmm_(Scalar beta, Tensor(a!) self, Tensor mat1, Tensor mat2) -> Tensor(a!)
+  aten: addmm_(self, mat1, mat2, beta, 1)
+
+- name: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addmm_out(out, self, mat1, mat2, beta, 1)
+
+- name: sspaddmm(Scalar beta, Tensor self, Scalar alpha, Tensor mat1, Tensor mat2) -> Tensor
+  aten: sspaddmm(self, mat1, mat2, beta, alpha)
+
+- name: sspaddmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2) -> Tensor
+  aten: sspaddmm(self, mat1, mat2, beta, 1)
+
+- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor
+  aten: addmv(self, mat, vec, beta, alpha)
+
+- name: addmv_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor mat, Tensor vec) -> Tensor(a!)
+  aten: addmv_(self, mat, vec, beta, alpha)
+
+- name: addmv(Scalar beta, Tensor self, Scalar alpha, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addmv_out(out, self, mat, vec, beta, alpha)
+
+- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec) -> Tensor
+  aten: addmv(self, mat, vec, beta, 1)
+
+- name: addmv_(Scalar beta, Tensor(a!) self, Tensor mat, Tensor vec) -> Tensor(a!)
+  aten: addmv_(self, mat, vec, beta, 1)
+
+- name: addmv(Scalar beta, Tensor self, Tensor mat, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addmv_out(out, self, mat, vec, beta, 1)
+
+- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor
+  aten: addr(self, vec1, vec2, beta, alpha)
+
+- name: addr_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor vec1, Tensor vec2) -> Tensor(a!)
+  aten: addr_(self, vec1, vec2, beta, alpha)
+
+- name: addr(Scalar beta, Tensor self, Scalar alpha, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addr_out(out, self, vec1, vec2, beta, alpha)
+
+- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2) -> Tensor
+  aten: addr(self, vec1, vec2, beta, 1)
+
+- name: addr_(Scalar beta, Tensor(a!) self, Tensor vec1, Tensor vec2) -> Tensor(a!)
+  aten: addr_(self, vec1, vec2, beta, 1)
+
+- name: addr(Scalar beta, Tensor self, Tensor vec1, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: addr_out(out, self, vec1, vec2, beta, 1)
+
+- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor
+  aten: baddbmm(self, batch1, batch2, beta, alpha)
+
+- name: baddbmm_(Scalar beta, Tensor(a!) self, Scalar alpha, Tensor batch1, Tensor batch2) -> Tensor(a!)
+  aten: baddbmm_(self, batch1, batch2, beta, alpha)
+
+- name: baddbmm(Scalar beta, Tensor self, Scalar alpha, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: baddbmm_out(out, self, batch1, batch2, beta, alpha)
+
+- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2) -> Tensor
+  aten: baddbmm(self, batch1, batch2, beta, 1)
+
+- name: baddbmm_(Scalar beta, Tensor(a!) self, Tensor batch1, Tensor batch2) -> Tensor(a!)
+  aten: baddbmm_(self, batch1, batch2, beta, 1)
+
+- name: baddbmm(Scalar beta, Tensor self, Tensor batch1, Tensor batch2, *, Tensor(a!) out) -> Tensor(a!)
+  aten: baddbmm_out(out, self, batch1, batch2, beta, 1)
+
+- name: sub(Tensor self, Scalar alpha, Tensor other) -> Tensor
+  aten: sub(self, other, alpha)
+
+- name: sub_(Tensor(a!) self, Scalar alpha, Tensor other) -> Tensor(a!)
+  aten: sub_(self, other, alpha)
+
+- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  aten: sub_out(out, self, other, alpha)
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/derivatives.yaml b/phivenv/Lib/site-packages/torchgen/packaged/autograd/derivatives.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aee2c70b2891b6ae07ba59d0635ef880a5dfdc38
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/derivatives.yaml
@@ -0,0 +1,3235 @@
+# Defines derivative formulas and Python signatures of methods on Variable
+#
+# Note about possibly confusing nomenclature: An 'output gradient' is the
+# gradient of an output of a forward function. Output gradients are used as
+# the inputs to backward functions. `grads` is a vector of output gradients,
+# and `grad == grads[0]`, in all the derivative formulas in this file.
+# An 'input gradient' is the gradient of an input to a forward function.
+# Input gradients are the outputs of backward functions, corresponding to the
+# input names included in the derivative formulas defined in this file.
+# Also, every time we talk about computing the "gradient" we actually mean computing
+# the vector-Jacobian product using the given 'output gradient' as the vector.
+#
+# Each entry consists of:
+#   - A 'name', which specifies the ATen name of the function you
+#     are defining derivatives for, and an argument specification.
+#   - An optional 'dispatch' entry which can be used to specify
+#     per-autograd dispatch key derivatives. If this entry is not
+#     specified, then the gradient entries will be taken as the
+#     default gradients (i.e. registered for every backward dispatch
+#     key). (see _test_autograd_multiple_dispatch for an example
+#     of how to register separate derivatives for different dispatch keys).
+#     The list of allowed dispatch keys (in addition to 'Default' which
+#     represents the Autograd alias key) is torchgen/model.py:AUTOGRAD_KEYS.
+#   - One or more gradients entries, mapping differentiable input
+#     names to a formula specifying how to compute its gradient.
+#     Note that a single gradient entry can specify the gradient
+#     formula for multiple input names, by specifying a key
+#     "input1, input2" (see atan2 for an example).
+#   - An argument can be flagged as 'non_differentiable'.
+#   - Optional entry with key 'output_differentiability' and value a list of the
+#     same length as the number of outputs from the forward function. The list
+#     should contain only booleans, specifying whether each of the output Tensor
+#     is differentiable.
+#     If it is not specified for a function that returns multiple elements but
+#     uses `grad` instead of `grads[idx]`, then all but the first output will
+#     be marked as non-differentiable.
+#     If none of the outputs are differentiable, you can also add the function
+#     name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list.
+#
+# There are two cases for Tensor and TensorList arguments here:
+#   - If that argument is differentiable, in the sense that a gradient with respect
+#     to that argument could exist. You should either:
+#       - Specify the formula for that gradient
+#       - Specify not_implemented("function_name") as a formula to say that this is not
+#         implemented yet (but might be in the future and the user can request that on an issue)
+#   - If that argument is not differentiable, because it is not a floating point
+#     dtype or the function is not differentiable with respect to that argument,
+#     for example. You should either:
+#       - Do not specify any formula for this argument
+#       - Specify explicitly that this argument is "non_differentiable". Note that in this case,
+#         we trust you that this argument will never have requires_grad=True and it will be silently
+#         ignored if it does.
+#
+# If a function has out-of-place and in-place variants, then the derivative
+# definition for the in-place variant is optional. It will default to the
+# definition for the out-of-place variant. Note that _out variants are never
+# differentiable.
+#
+# Gradient expressions are standard C++ expressions operating on ATen
+# variables.  In a gradient expression, the following variables/functions
+# are in scope:
+#
+#   - 'grad', the gradient of the output (often spelled grad_output
+#     in Python) which we are going to left-multiply.
+#
+#     When a function returns multiple *differentiable* outputs,
+#     you can refer to the gradients of each outputs using 'grads',
+#     e.g., 'grads[0]', 'grads[1]'.
+#
+#     When a function returns multiple *differentiable* outputs that
+#     are named, you can refer to the gradients of each outputs using
+#     'grad_{name}', e.g., 'grad_x', 'grad_y'.
+#
+#     When a function returns *one* differentiable output (the
+#     first output) and some more nondifferentiable outputs,
+#     you MUST refer to the gradient of the differentiable output with
+#     'grad' (this case is special-cased in our code generation).
+#
+#     Note that the number of differentiable outputs can be modified by the
+#     'output_differentiability' entry (see above).
+#
+#     Across a differentiable function's derivatives set, it is not
+#     permitted to mix the use of "grad", "grads", and
+#     "grad_{name}". You must be consistent for that differentiable
+#     function.
+#
+#   - Any of the input arguments, tensor or non-tensor, including
+#     argument names that only appear in Declarations.yaml, e.g. 'output'.
+#
+#   - 'result', representing the result of evaluating the forward
+#     expression for ATen native function declarations. If the forward
+#     expression outputs a tuple, use 'resultX' instead to access the
+#     X-th entry
+#
+#   - 'grad_input_mask', a std::array<bool, n>, specifies which input
+#     gradients are actually needed.  For example, in the entry
+#     `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size
+#     two array, where `grad_input_mask[0]` is true if `input0` requires
+#     grad, and `grad_input_mask[1]` is true if `input1` requires grad.
+#
+#     (NB: if your function computes gradient for a list of tensors,
+#     the `grad_input_mask` will only have a single entry for the list
+#     specifying if either zero or at least one tensor from the list requires
+#     grad.  If we want to support more fine-grained signalling,
+#     we'll need some alternate variable which is not a std::array)
+#
+#   - 'retain_variables', a bool which is true if a user has specified
+#     that saved variables should be retained in case the backwards is
+#     run again later.  This allows an optimization where we can
+#     destroy saved buffers if we know variables are not going to be retained,
+#     e.g., it is used by _cudnn_rnn
+#
+#   - `wrap_opt_if`, is a 2-argument function that accepts a tensor
+#     variable and a boolean condition that dictates whether to save that
+#     variable in a graph. The result of this function is `std::optional<Tensor>`,
+#     and it is `::std::nullopt` when the condition evaluates to `false`,
+#     otherwise it is the variable wrapped in `std::optional<Tensor>`.
+#     For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2])
+#     would mean that `var_0` is saved as long as the second (grad_input_mask[1])
+#     or the third (grad_input_mask[2]) argument requires gradients.
+#     Another interpretation of this expression would read as `var_0` is needed
+#     in the backward computation of the second or the third argument.
+#     NOTE: the usage of `var_i.requires_grad()` in the conditional expression
+#     is not supported, use `grad_input_mask[i]` instead.
+#     NOTE: `wrap_opt_if` could be used to prevent saving redundant variables
+#     with multi-output backward formulas.
+#     See https://github.com/pytorch/pytorch/issues/97575 for more details
+#     on the issue.
+#
+# If you need a complex expression, e.g., with local variables,
+# write a _backward function in torch/csrc/autograd/FunctionsManual.cpp
+# and invoke it from here.  By the way, go read
+# https://github.com/zdevito/ATen/issues/163; this describes an
+# important hazard that occurs when porting backwards from Python to C++
+#
+# Double backwards gradient expressions can be somewhat confusing;
+# the most important thing to remember is: (1) you need to define a
+# derivative formula for every input, including inputs named things
+# like 'grad_output', and (2) the gradient to multiply with is always
+# called 'grad' (even though it really is a grad-grad).
+#
+# You can also add forward derivative definition by defining a formula for
+# a returned value (in general "result" if the name is not specified). This
+# formula works the same way as the backward one and advanced implementations
+# should also be placed in the FunctionsManual file.
+# This formula should compute a single Jacobian vector product using the (primal)
+# value of the argument "foo_p", its forward grad "foo_t" and the result of the
+# function as "result".
+# Note that the forward derivative can be automatically generated in two cases:
+#     - if your function is linear (NOT affine or multi-linear), then you can
+#       specify so by just using the string "auto_linear" for the formula.
+#     - if your function is applied element wise (and has a single input), you
+#       can specify so by just using the string "auto_element_wise" for the formula.
+#
+# Note that to avoid unpacking overhead, functions taking TensorList as inputs
+# will always have their forward grad formula called. This function is responsible
+# for checking whether any computation is needed and should return an undefined Tensor when
+# there is nothing to do. You can check "cat_forward" for a full example.
+#
+# NB: There are a number of gradient definitions in here which are bogus
+# (implemented using zeros_like).  These gradients are (hopefully) not
+# used by our frontend.  You MUST check the frontend code; search for
+# OpName.apply to see if it's still using a legacy Python style API.
+#
+# Note: Returning views.
+# The following cases exist:
+#     - If a function returns no view, it can have arbitrary outputs.
+#     - If a function returns at least one Tensor that is a differentiable view
+#       of one of its input:
+#         - If there is only one differentiable output, this Tensor is marked as a
+#           differentiable view. (alias or transpose for example)
+#         - If there is more than one differentiable output, by default all the views are
+#           marked as differentiable views and created with allow_rebase_history=false.
+#           Meaning that any inplace operation on it will raise an error. (unbind for example)
+#
+#  Notes about undefined output gradients:
+#     All backward functions must support all combinations of undefined output
+#     gradient Tensors, where `grad[i].defined() == false`. Depending on the
+#     number of input and output grads your derivative formula uses, code
+#     generation may automatically add some level of undefined grad support,
+#     according to these three cases:
+#
+#       * 1 input grad and 1 output grad:
+#           Complete undefined grad support is automatically added, so you
+#           shouldn't have to think about it, unless there is a bug in the code
+#           generation.
+#
+#       * 1 input grad and multiple output grads:
+#           Undefined grad support is automatically added ONLY in the case where
+#           all output grads are undefined. You will have to add explicit support
+#           for cases where a subset of output grads is undefined.
+#
+#       * multiple input grads:
+#           No automatic support, so you will need to add it.
+#
+#     If your derivative formula uses more than one output grad, it is usually
+#     preferable to add undefined grad support in the backward function itself
+#     (if you're using one), rather than in the derivative formula in this file.
+#
+#     Undefined Tensors are created with the default constructor `at::Tensor()`.
+#     It is an efficient way to represent a Tensor filled with zeros because
+#     the Tensor holds no sizing information and no Storage data is allocated.
+#     But consequently, Tensor operations cannot be performed on them.
+#     Therefore, your backward function should treat an undefined output grad as
+#     a zero, and it needs to be a special case.
+#
+#     If all output grads are undefined, then it should be correct for the
+#     backward function to return undefined input grads. Since we use the chain
+#     rule, output grads equal to zero should result in input grads equal to zero,
+#     unless there is some rare special case.
+#
+#     If a subset of output grads is undefined, then it may be acceptable for
+#     the backward function to return undefined input grads--it depends on the
+#     specific function, so you'll have to determine that yourself. If returning
+#     an undefined Tensor is correct for a given input grad, it is also logically
+#     correct to return a defined grad full of zeros, but that would not be
+#     preferable since it would be less efficient.
+#
+# NB: The parameter names here MUST be consistent with the parameter names
+# in native_functions.yaml
+- name: abs(Tensor self) -> Tensor
+  self: grad * self.sgn()
+  result: handle_r_to_c(result.scalar_type(), self_t.conj() * self_p.sgn())
+
+- name: acos(Tensor self) -> Tensor
+  self: grad * -((-self * self + 1).rsqrt()).conj()
+  result: auto_element_wise
+
+- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))
+  result: self_t + maybe_multiply(other_t, alpha)
+
+- name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  result: self_t.clone()
+
+- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta.conj())
+  batch1: maybe_multiply(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) }).bmm(batch2.transpose(1, 2).conj()), alpha.conj())
+  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad.unsqueeze(0).expand_symint({ batch1.sym_size(0), batch1.sym_size(1), batch2.sym_size(2) })), alpha.conj())
+  result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p).sum(0), alpha) + maybe_multiply(batch1_p.bmm(batch2_t).sum(0), alpha)
+
+- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (value / tensor2).conj())
+  tensor2: handle_r_to_c(tensor2.scalar_type(), -grad * (value * tensor1 / (tensor2 * tensor2)).conj())
+  result: self_t + maybe_multiply(tensor1_t / tensor2_p, value) - maybe_multiply(tensor2_t * (tensor1_p / tensor2_p) / tensor2_p, value)
+
+- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  tensor1: handle_r_to_c(tensor1.scalar_type(), grad * (tensor2 * value).conj())
+  tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())
+  result: self_t + maybe_multiply(tensor1_t * tensor2_p, value) + maybe_multiply(tensor2_t * tensor1_p, value)
+
+- name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta.conj())
+  mat1: mm_mat1_backward(grad, mat2, mat1.sym_sizes(), mat1.sym_strides(), mat1.layout(), alpha)
+  mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha)
+  result: maybe_multiply(self_t, beta) + maybe_multiply(mat1_t.mm(mat2_p), alpha) + maybe_multiply(mat1_p.mm(mat2_t), alpha)
+
+- name: _sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta)
+  mat1: mm_mat1_sparse_backward(grad, mat1, mat2, alpha)
+  mat2: mm_mat2_backward(grad, mat1, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), alpha)
+
+- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta.conj())
+  mat: maybe_multiply(grad.ger(vec.conj()), alpha.conj())
+  vec: maybe_multiply(mat.t().conj().mv(grad), alpha.conj())
+  result: maybe_multiply(self_t, beta) + maybe_multiply(mat_t.mv(vec_p), alpha) + maybe_multiply(mat_p.mv(vec_t), alpha)
+
+- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta.conj())
+  vec1: maybe_multiply(grad.mv(vec2.conj()), alpha.conj())
+  vec2: maybe_multiply(grad.t().mv(vec1.conj()), alpha.conj())
+  result: maybe_multiply(self_t, beta) + maybe_multiply(vec1_t.outer(vec2_p), alpha) + maybe_multiply(vec1_p.outer(vec2_t), alpha)
+
+- name: affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor
+  theta: affine_grid_generator_backward_symint(grad, size, align_corners)
+  result: auto_linear
+
+- name: alias(Tensor(a) self) -> Tensor(a)
+  self: grad
+  result: self_t
+
+- name: angle(Tensor self) -> Tensor
+  self: angle_backward(grad, self)
+  result: handle_r_to_c(result.scalar_type(), angle_backward(self_t.conj(), self_p).conj())
+
+# The four items below are necessary because TensorIterator doesn't work on
+# Variables (codegen does not unwrap the input Tensor for all() and any() ).
+- name: any(Tensor self) -> Tensor
+  output_differentiability: [False]
+
+- name: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
+- name: any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
+- name: _is_all_true(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: _is_any_true(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: all(Tensor self) -> Tensor
+  output_differentiability: [False]
+
+- name: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
+- name: all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor
+  output_differentiability: [False]
+
+- name: acosh(Tensor self) -> Tensor
+# Save one rsqrt in the real case by using that for x real and positive sqrt(x*y) = sqrt(x)*sqrt(y) (not true in the complex case)
+  self: "self.is_complex() ? grad * ((self + 1).rsqrt() * (self - 1).rsqrt()).conj() : grad * (self * self - 1).rsqrt()"
+  result: auto_element_wise
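+# (Illustrative) The real branch above works because, for real x > 1,
+# (x + 1)^(-1/2) * (x - 1)^(-1/2) == (x * x - 1)^(-1/2), so a single rsqrt suffices;
+# for complex inputs sqrt(a * b) != sqrt(a) * sqrt(b) in general (branch cuts),
+# hence the two-factor form kept in the complex branch.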
+
+- name: acosh_(Tensor(a!) self) -> Tensor(a!)
+  self: not_implemented("inplace version of acosh")
+
+- name: asinh(Tensor self) -> Tensor
+  self: grad * (self.pow(2) + 1).rsqrt().conj()
+  result: auto_element_wise
+
+- name: asinh_(Tensor(a!) self) -> Tensor(a!)
+  self: not_implemented("inplace version of asinh")
+
+- name: atanh(Tensor self) -> Tensor
+  self: grad * 1 / (1 - self.pow(2)).conj()
+  result: auto_element_wise
+
+- name: atanh_(Tensor(a!) self) -> Tensor(a!)
+  self: not_implemented("inplace version of atanh")
+
+- name: as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)
+  self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
+  result: auto_linear
+
+- name: as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)
+  self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
+  result: auto_linear
+
+- name: asin(Tensor self) -> Tensor
+  self: grad * (-self * self + 1).rsqrt().conj()
+  result: auto_element_wise
+
+- name: atan(Tensor self) -> Tensor
+  self: grad / (self * self + 1).conj()
+  result: auto_element_wise
+
+- name: atan2(Tensor self, Tensor other) -> Tensor
+  self, other: atan2_backward(grad, self, other, grad_input_mask)
+  result: (-self_p * other_t + other_p * self_t) / (self_p.pow(2) + other_p.pow(2))
+
+- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self: maybe_multiply(grad, beta.conj())
+  batch1: maybe_multiply(grad.bmm(batch2.transpose(1, 2).conj()), alpha.conj())
+  batch2: maybe_multiply(batch1.transpose(1, 2).conj().bmm(grad), alpha.conj())
+  result: maybe_multiply(self_t, beta) + maybe_multiply(batch1_t.bmm(batch2_p), alpha) + maybe_multiply(batch1_p.bmm(batch2_t), alpha)
+
+- name: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  p: zeros_like(p)
+  result: self_t.zero_()
+
+- name: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: bmm(Tensor self, Tensor mat2) -> Tensor
+  self: grad.bmm(mat2.transpose(1, 2).conj())
+  mat2: self.transpose(1, 2).conj().bmm(grad)
+  result: self_t.bmm(mat2_p) + self_p.bmm(mat2_t)
+
+- name: matmul(Tensor self, Tensor other) -> Tensor
+  self, other: matmul_backward(grad, self, other, grad_input_mask)
+
+- name: cat(Tensor[] tensors, int dim=0) -> Tensor
+  tensors: cat_tensors_backward(grad, to_args_sizes_symint(tensors), to_args_scalartypes(tensors), dim)
+  result: cat_jvp(tensors, dim)
+
+- name: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: ceil(Tensor self) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: cholesky(Tensor self, bool upper=False) -> Tensor
+  self: cholesky_backward(grad, upper, result)
+
+- name: chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]
+  dispatch:
+    Default:
+      # the default case will use the CompositeImplicitAutograd
+      self: not_implemented("chunk")
+    AutogradNestedTensor:
+      self: chunk_backward_nested(grads, self, chunks, dim)
+
+- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
+  self: cholesky_backward(grad, upper, L)
+  L: cholesky_jvp(self_t, L, upper)
+
+- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
+  self, input2: cholesky_solve_backward(grad, self, input2, result, upper, grad_input_mask)
+  result: cholesky_solve_jvp(result, input2_p, input2_t, self_t, upper)
+
+- name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor
+  self: cholesky_inverse_backward(grad, self, upper, result)
+  result: cholesky_inverse_jvp(self_p, self_t, result, upper)
+
+# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
+# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
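+# (Illustrative) For y = clamp(x, lo, hi) the elementwise derivative dy/dx is 1 strictly
+# inside (lo, hi) and 0 strictly outside; exactly at x == lo or x == hi it is undefined,
+# and the formulas below resolve that tie with the subgradient 1, per the note above.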
+- name: clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor
+  self: clamp_backward(grad, self, min, max)
+  min, max: clamp_backward_min_max(grad, self, min, max, grad_input_mask)
+  result: clamp_jvp(self_p, self_t, min_p, min_t, max_p, max_t)
+
+- name: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor
+  self: clamp_backward(grad, self, min, max)
+  result: auto_element_wise
+
+- name: clamp_min(Tensor self, Scalar min) -> Tensor
+  self: where(self >= min, grad, at::scalar_tensor(0., grad.options()))
+  result: auto_element_wise
+
+- name: clamp_min.Tensor(Tensor self, Tensor min) -> Tensor
+  self: where(self >= min, grad, at::scalar_tensor(0., grad.options()))
+  min: where(self < min, grad, at::scalar_tensor(0., grad.options()))
+  result: where(self_p >= min_p, self_t, min_t)
+
+- name: clamp_max(Tensor self, Scalar max) -> Tensor
+  self: where(self <= max, grad, at::scalar_tensor(0., grad.options()))
+  result: auto_element_wise
+
+- name: clamp_max.Tensor(Tensor self, Tensor max) -> Tensor
+  self: where(self <= max, grad, at::scalar_tensor(0., grad.options()))
+  max: where(self > max, grad, at::scalar_tensor(0., grad.options()))
+  result: where(self_p <= max_p, self_t, max_t)
+
+- name: clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor
+  self: grad
+  result: auto_linear
+
+- name: _lazy_clone(Tensor self) -> Tensor
+  self: grad
+  result: auto_linear
+
+- name: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+  self: _to_copy_backward(grad, self.options())
+  result: _to_copy(self_t, dtype, layout, device, pin_memory, non_blocking, memory_format)
+  # The condition is: if dtype is not nullopt, then isDifferentiableType(*dtype)
+  # (If dtype IS nullopt, we rely on the regular check that any input requires grad).
+  output_differentiability: ["!dtype || isDifferentiableType(*dtype)"]
+
+- name: _coalesce(Tensor self) -> Tensor
+  self: grad
+
+- name: complex(Tensor real, Tensor imag) -> Tensor
+  real: at::real(grad)
+  imag: at::imag(grad)
+  result: at::complex(real_t, imag_t)
+
+- name: polar(Tensor abs, Tensor angle) -> Tensor
+  abs, angle: polar_backward(grad, result)
+  result: at::complex(abs_t*angle_p.cos() - angle_t*abs_p*angle_p.sin(), abs_t*angle_p.sin() + angle_t*abs_p*angle_p.cos())
+
+- name: _conj(Tensor(a) self) -> Tensor(a)
+  self: grad.conj()
+  result: self_t.conj()
+
+- name: _neg_view(Tensor(a) self) -> Tensor(a)
+  self: grad.neg()
+  result: self_t._neg_view()
+
+- name: _conj_physical(Tensor self) -> Tensor
+  self: grad.conj_physical()
+  result: self_t.conj_physical()
+
+- name: conj_physical_(Tensor(a!) self) -> Tensor(a!)
+  self: grad.conj_physical()
+  result: self_t.conj_physical_()
+
+- name: copysign.Tensor(Tensor self, Tensor other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, result)
+  other: zeros_like(other)
+  result: copysign_tensor_self_backward(self_t, self_p, result)
+
+- name: copysign.Scalar(Tensor self, Scalar other) -> Tensor
+  self: copysign_tensor_self_backward(grad, self, result)
+  result: auto_element_wise
+
+- name: cos(Tensor self) -> Tensor
+  self: grad * -self.sin().conj()
+  result: auto_element_wise
+
+- name: cosh(Tensor self) -> Tensor
+  self: grad * self.sinh().conj()
+  result: auto_element_wise
+
+- name: count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor
+  output_differentiability: [False]
+
+- name: count_nonzero(Tensor self, int? dim=None) -> Tensor
+  output_differentiability: [False]
+
+- name: linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor
+  self: at::linalg_cross(other.conj(), grad, dim)
+  other: at::linalg_cross(grad, self.conj(), dim)
+  result: "at::linalg_cross(self_t, other_p, dim) + at::linalg_cross(self_p, other_t, dim)"
+
+- name: logcumsumexp(Tensor self, int dim) -> Tensor
+  self: logcumsumexp_backward(grad, self, result, dim)
+  result: logcumsumexp_jvp(self_p, self_t, dim)
+
+- name: cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+  self: cumprod_backward(grad.to(self.scalar_type()), self, dim, result)
+  result: "cumprod_jvp(self_t, self_p, result, dim).to(dtype.has_value() ? *dtype : self_p.scalar_type())"
+
+- name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor
+  self: cumsum_backward(grad.to(self.scalar_type()), dim)
+  result: auto_linear
+
+- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)
+  self: cummaxmin_backward(grad, self, indices, dim)
+  values: self_t.gather(dim, indices)
+
+- name: cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)
+  self: cummaxmin_backward(grad, self, indices, dim)
+  values: self_t.gather(dim, indices)
+
+- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor
+  self, weight, bias: "grad.defined() ? conv_tbc_backward(grad, self, weight, bias, pad) : std::tuple()"
+
+- name: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+  log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)
+
+- name: _ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)
+  log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank, zero_infinity)
+
+- name: deg2rad(Tensor self) -> Tensor
+  self: deg2rad_backward(grad)
+  result: auto_element_wise
+
+- name: _linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)
+  A: linalg_det_backward(grad, result, A, LU, pivots)
+  result: linalg_det_jvp(A_t, result, LU, pivots, A_p.is_contiguous() && !A_p.is_complex())
+  output_differentiability: [True, False, False]
+
+- name: _linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)
+  A: slogdet_backward(grad_sign, grad_logabsdet, A, sign, LU, pivots)
+  sign, logabsdet: slogdet_jvp(LU, pivots, A_t, sign, A_p.is_contiguous() && !A_p.is_complex())
+  output_differentiability: [True, True, False, False]
+
+- name: block_diag(Tensor[] tensors) -> Tensor
+  tensors: block_diag_backward(grad, to_args_sizes(tensors), to_args_scalartypes(tensors))
+  result: block_diag_jvp(tensors)
+
+- name: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor
+  self: grad.diagonal(offset, dim1, dim2)
+  result: auto_linear
+
+- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
+  self: diagonal_backward_symint(grad, self.sym_sizes(), offset, dim1, dim2)
+  result: auto_linear
+
+- name: diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor
+  grad_output: grad.diagonal(offset, dim1, dim2)
+  result: auto_linear
+
+- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
+  self: norm_backward(grad, self - other, p, result)
+  other: -norm_backward(grad, self - other, p, result)
+  result: norm_jvp(self_p - other_p, self_t - other_t, p, result, {}, false)
+
+# The backward formula is done in this order to improve numerical stability
+# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414
+# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later
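+# (Illustrative) The forward-mode line below is just the quotient rule rearranged:
+#   d(self / other) = self_t / other - self * other_t / other^2
+#                   = (self_t - other_t * (self / other)) / other
+#                   = (self_t - other_t * result) / other_p.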
+- name: div.Tensor(Tensor self, Tensor other) -> Tensor
+  self: div_tensor_self_backward(grad, other, self.scalar_type())
+  other: div_tensor_other_backward(grad, self, other)
+  result: (self_t - other_t * result) / other_p
+
+- name: div.Scalar(Tensor self, Scalar other) -> Tensor
+  self: div_tensor_self_backward(grad, other, self.scalar_type())
+  result: self_t / other
+
+- name: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
+  self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode)
+  other: div_tensor_other_backward(grad, self, other, rounding_mode)
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other_p - other_t * (self_p / other_p) / other_p"
+
+- name: div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor
+  self: div_tensor_self_backward(grad, other, self.scalar_type(), rounding_mode)
+  result: "rounding_mode.has_value() ? result.new_zeros_symint(result.sym_sizes()) : self_t / other"
+
+- name: dot(Tensor self, Tensor tensor) -> Tensor
+  self: grad * tensor.conj()
+  tensor: grad * self.conj()
+  result: at::dot(self_t, tensor_p) + at::dot(self_p, tensor_t)
+
+- name: vdot(Tensor self, Tensor other) -> Tensor
+  self: grad.conj() * other
+  other: grad * self
+  result: at::vdot(self_t, other_p) + at::vdot(self_p, other_t)
+
+- name: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
+  self: _fused_dropout_backward(grad, result1, p)
+
+- name: native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)
+  input: "GradMode::is_enabled() ? infinitely_differentiable_native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p)))) : native_dropout_backward(grad, result1, (!train.has_value() || !train.value() ? 1 : (p == 1 ? 0.0 : 1.0 / (1.0 - p))))"
+  result0: "(!train.has_value() || train.value()) ? (p == 1 ? 0.0 : 1.0 / (1.0 - p)) * input_t * result1 : input_t"
+
+- name: native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor
+  grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)"
+  mask: 'not_implemented("native_dropout_backward: mask")'
+
+- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: erf(Tensor self) -> Tensor
+  self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
+  result: auto_element_wise
+
+- name: erfc(Tensor self) -> Tensor
+  self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad
+  result: auto_element_wise
+
+- name: special_erfcx(Tensor self) -> Tensor
+  self: (2.0 * self * result - 2.0 / sqrt(M_PI)) * grad
+  result: auto_element_wise
+
+- name: erfinv(Tensor self) -> Tensor
+  self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad
+  result: auto_element_wise
+
+- name: exp(Tensor self) -> Tensor
+  self: grad * result.conj()
+  result: auto_element_wise
+
+- name: exp2(Tensor self) -> Tensor
+  self: grad * result.conj() * M_LN2
+  result: auto_element_wise
+
+- name: expm1(Tensor self) -> Tensor
+  self: grad * (result.conj() + 1)
+  result: auto_element_wise
+
+# TODO: this derivative is not SymInt safe, need sum_to support
+- name: expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
+  self: at::sum_to(grad, self.sym_sizes())
+  result: auto_linear
+
+- name: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
+
+- name: _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
+
+- name: _fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+  self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_tensor_affine_backward(grad, self, scale, zero_point, quant_min, quant_max, grad_factor) : std::tuple()"
+
+- name: fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)
+  self: fake_quantize_per_channel_affine_cachemask_backward(grad, mask)
+
+- name: _fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor
+  self, scale, zero_point: "grad.defined() ? _fake_quantize_learnable_per_channel_affine_backward(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor) : std::tuple()"
+
+- name: _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)
+  self: fake_quantize_per_tensor_affine_cachemask_backward(grad, mask)
+
+- name: fill.Scalar(Tensor self, Scalar value) -> Tensor
+  self: zeros_like(grad)
+  result: at::fill(self_t, 0)
+
+- name: fill.Tensor(Tensor self, Tensor value) -> Tensor
+  self: zeros_like(grad)
+  value: grad.sum()
+  result: at::fill(self_t, value_t)
+
+- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.fill_(0)
+
+- name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
+  self: zeros_like(grad)
+  value: grad.sum()
+  result: self_t.fill_(value_t)
+
+- name: floor(Tensor self) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: fmod.Scalar(Tensor self, Scalar other) -> Tensor
+  self: grad
+  result: auto_element_wise
+
+- name: fmod.Tensor(Tensor self, Tensor other) -> Tensor
+  self: grad
+  other: -grad * self.div(other, /*rounding_mode=*/"trunc")
+  result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"trunc")
+
+- name: frac(Tensor self) -> Tensor
+  self: grad
+  result: self_t
+
+- name: frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)
+  self: grad / exponent.exp2()
+  mantissa: self_t / exponent.exp2()
+
+- name: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
+  self: gather_backward(grad, self, dim, index, sparse_grad)
+  index: non_differentiable
+  result: auto_linear
+
+- name: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: geqrf(Tensor self) -> (Tensor a, Tensor tau)
+  self: not_implemented("geqrf")
+
+- name: indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: _indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: crow_indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: col_indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: ccol_indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: row_indices(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()"
+
+- name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()"
+
+# See NOTE [ grid_sample CPU fallback ]
+- name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor
+  input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple()"
+
+- name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: hardsigmoid(Tensor self) -> Tensor
+  self: hardsigmoid_backward(grad, self)
+  result: auto_element_wise
+
+- name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
+  output_differentiability: [False]
+
+- name: hardswish(Tensor self) -> Tensor
+  self: hardswish_backward(grad, self)
+  result: auto_element_wise
+
+- name: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
+  grad_output: hardswish_backward(grad, self)
+  self: at::where(at::logical_and(-3.0 < self, self < 3.0), grad * grad_output / 3.0, at::zeros({}, self.options()))
+  result: "hardswish_backward(grad_output_t, self_p)
+         + at::where(at::logical_and(-3.0 < self_p, self_p < 3.0), self_t * grad_output_p / 3.0, at::zeros({}, self_p.options()))"
+
+- name: hypot(Tensor self, Tensor other) -> Tensor
+  self: grad * self / result
+  other: grad * other / result
+  result: self_t * self_p / result + other_t * other_p / result
+
+- name: i0(Tensor self) -> Tensor
+  self: grad * at::special_i1(self)
+  result: auto_element_wise
+
+- name: special_i0e(Tensor self) -> Tensor
+  self: grad * (at::special_i1e(self) - self.sgn() * result)
+  result: auto_element_wise
+
+- name: special_i1(Tensor self) -> Tensor
+  self: i1_backward(grad, self, result)
+  result: auto_element_wise
+
+- name: special_i1e(Tensor self) -> Tensor
+  self: i1e_backward(grad, self, result)
+  result: auto_element_wise
+
+- name: igamma(Tensor self, Tensor other) -> Tensor
+  self: 'not_implemented("igamma: input")'
+  other: grad * exp((self - 1) * log(other) - other - lgamma(self))
+
+- name: igammac(Tensor self, Tensor other) -> Tensor
+  self: 'not_implemented("igammac: input")'
+  other: -grad * exp((self - 1) * log(other) - other - lgamma(self))
+
+- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  self: index_backward(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad)
+  result: auto_linear
+
+- name: _unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  self: at::_unsafe_index_put(grad.new_zeros_symint(self.sym_sizes(), self.options()), indices, grad, true)
+  result: auto_linear
+
+- name: _unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor
+  self: at::_unsafe_masked_index_put_accumulate(grad.new_zeros_symint(self.sym_sizes(), self.options()), mask, indices, grad)
+  mask: non_differentiable
+  result: _unsafe_masked_index(self_t, mask, indices, 0)
+
+- name: _unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor
+  self: grad
+  mask: non_differentiable
+  values: at::_unsafe_masked_index(grad, mask, indices, 0)
+  result: at::_unsafe_masked_index_put_accumulate(self_t, mask, indices, values_t)
+
+- name: index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor
+  self: grad
+  # The case source.dim() == 0  is necessary to support scalar tensors of the form
+  # source.dim() == 0 and index.dim() == 1 and index.size() == (1,),
+  # This is because source is not broadcastable to index, as source.dim() < index.dim()
+  source: "maybe_multiply(source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0)), alpha)"
+  index: non_differentiable
+  result: at::index_add(self_t, dim, index, maybe_multiply(source_t, alpha))
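+  # (Illustrative) For the scalar-source case above, e.g. self of shape [5], dim=0,
+  # index of shape [1] and a 0-d source: the gathered grad slice has shape [1] and expand
+  # can never remove dimensions, so expand_as(source) would fail; squeezing the index
+  # first keeps the gathered gradient shaped like the 0-d source instead.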
+
+- name: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
+  self, source: index_reduce_backward(grad, self, dim, index, source, reduce, include_self, result)
+  index: non_differentiable
+
+- name: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor
+  self: grad.index_fill(dim, index, 0)
+  # The case source.dim() == 0 is necessary to support scalar tensors of the form
+  # source.dim() == 0 and index.dim() == 1 and index.size() == (1,),
+  # This is because source is not broadcastable to index, as source.dim() < index.dim()
+  source: "source.dim() > 0 ? grad.index_select(dim, index).expand_as(source) : grad.index_select(dim, index.squeeze(0))"
+  index: non_differentiable
+  result: self_t.index_copy(dim, index, source_t)
+
+- name: index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+  self: grad.index_fill(dim, index, 0)
+  index: non_differentiable
+  result: self_t.index_fill(dim, index, 0)
+
+- name: index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor
+  self: grad.index_fill(dim, index, 0)
+  value: grad.index_select(dim, std::get<0>(at::_unique(index, /*sorted=*/false))).sum()
+  index: non_differentiable
+  result: self_t.index_fill(dim, index, value_t)
+
+- name: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)"
+  values: grad.index(indices)
+  result: self_t.index_put(indices, values_t, accumulate)
+
+- name: _unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  self: "accumulate ? grad : at::_unsafe_index_put(grad, indices, zeros_like(values), false)"
+  values: at::_unsafe_index(grad, indices)
+  result: at::_unsafe_index_put(self_t, indices, values_t, accumulate)
+
+- name: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
+  self: "accumulate ? grad : grad.index_put(indices, zeros_like(values), false)"
+  values: grad.index(indices)
+  result: at::_index_put_impl_(self_t, indices, values_t, accumulate, unsafe)
+
+- name: index_select(Tensor self, int dim, Tensor index) -> Tensor
+  self: index_select_backward_symint(grad, self.sym_sizes(), dim, index)
+  index: non_differentiable
+  result: auto_linear
+
+- name: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
+  A: -at::matmul(inverse.mH(), at::matmul(grad, inverse.mH()))
+  inverse: -at::matmul(at::matmul(inverse, A_t), inverse)
+  output_differentiability: [True, False]
+
+- name: linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor
+  self: pinv_backward(grad, result, self)
+  result: pinv_jvp(self_p, result, self_t)
+
+- name: isnan(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor
+  self: "weight.isComplex() ? grad * (1 - weight.conj().toComplexDouble()) : grad * (1 - weight.toDouble())"
+  end: grad * weight.conj()
+  result: at::lerp(self_t, end_t, weight)
+
+- name: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor
+  self: grad * (1 - weight).conj()
+  end: grad * weight.conj()
+  weight: grad * (end - self).conj()
+  result: at::lerp(self_t, end_t, weight_p) + weight_t * (end_p - self_p)
+
+- name: lgamma(Tensor self) -> Tensor
+  self: grad * digamma(self)
+  result: auto_element_wise
+
+- name: digamma(Tensor self) -> Tensor
+  self: grad * polygamma(1, self)
+  result: auto_element_wise
+
+- name: polygamma(int n, Tensor self) -> Tensor
+  self: grad * polygamma(n + 1, self)
+  result: auto_element_wise
+
+- name: polygamma_(Tensor(a!) self, int n) -> Tensor(a!)
+  self: grad * polygamma(n + 1, self)
+  result: self_t.mul_(polygamma(n + 1, original_self_p))
+
+- name: log(Tensor self) -> Tensor
+  self: grad.div(self.conj())
+  result: auto_element_wise
+
+- name: log10(Tensor self) -> Tensor
+  self: grad / (self.conj() * 2.3025850929940456)
+  result: auto_element_wise
+
+- name: log1p(Tensor self) -> Tensor
+  self: log1p_backward(grad, self)
+  result: auto_element_wise
+
+- name: log2(Tensor self) -> Tensor
+  self: grad / (self.conj() * 0.6931471805599453)
+  result: auto_element_wise
+
+- name: logaddexp(Tensor self, Tensor other) -> Tensor
+  self: grad / (1 + exp(other - self)).conj()
+  other: grad / (1 + exp(self - other)).conj()
+  result: self_t / (1 + exp(other_p - self_p)) + other_t / (1 + exp(self_p - other_p))
+
+- name: logaddexp2(Tensor self, Tensor other) -> Tensor
+  self: grad / (1 + pow(2, other - self))
+  other: grad / (1 + pow(2, self - other))
+  result: self_t / (1 + pow(2, other_p - self_p)) + other_t / (1 + pow(2, self_p - other_p))
+
+# Note [Gradient formula for xlogy at x = 0, y <= 0]
+# x * log(y) is not defined at y <= 0, so we cannot even talk about differentiability
+# Now, xlogy(0, y) = 0 by definition.
+# This does not make it differentiable as it's not defined in a neighbourhood of a point
+# (0, y) when y <= 0.
+# Now, when a function is non-differentiable, sometimes we return "a relatively sensible value"
+# In this case, as per the discussion in https://github.com/pytorch/pytorch/issues/80770, we choose
+# this value to be zero, which is the directional derivative along the line {x = 0}.
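+# (Illustrative) At (self, other) = (0, -1): xlogy(0, -1) == 0 by the definition above,
+# but log(-1) is undefined over the reals, so the masked_fill below zeroes the gradient
+# w.r.t. self at such points, matching the directional derivative 0 along the line
+# {self = 0}.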
+- name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+  self: at::xlogy(grad, other).masked_fill((self == 0.) & (other <= 0.), 0.)
+  other: grad * self / other
+  result: at::xlogy(self_t, other_p).masked_fill((self_p == 0.) & (other_p <= 0.), 0.) + other_t * self_p / other_p
+
+- name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
+  other: grad * self / other
+  result: auto_element_wise
+
+- name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
+  self: "other.toDouble() > 0.
+          ? at::xlogy(grad,  other)
+          : at::xlogy(grad,  other).masked_fill(self == 0., 0.)"
+  result: auto_element_wise
+
+# See Note [Gradient formula for xlogy at x = 0, y <= 0]
+# Same here but with y <= -1
+- name: special_xlog1py(Tensor self, Tensor other) -> Tensor
+  self: at::special_xlog1py(grad,  other).masked_fill((self == 0.) & (other <= -1.), 0.)
+  other: grad * self / (other + 1)
+  result: at::special_xlog1py(self_t,  other_p).masked_fill((self_p == 0.) & (other_p <= -1.), 0.) + other_t * self_p / (other_p + 1)
+
+- name: special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor
+  other: grad * self / (other + 1)
+  result: auto_element_wise
+
+- name: special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor
+  self: "other.toDouble() > -1.
+          ? at::special_xlog1py(grad,  other)
+          : at::special_xlog1py(grad,  other).masked_fill(self == 0., 0.)"
+  result: auto_element_wise
+
+- name: special_zeta(Tensor self, Tensor other) -> Tensor
+  self: not_implemented("zeta")
+  other:  grad * -self * special_zeta(self + 1., other)
+
+- name: special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor
+  other:  grad * -self * special_zeta(self.toDouble() + 1., other)
+
+- name: special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor
+  self: not_implemented("zeta")
+
+- name: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor
+  self: logsumexp_backward(grad, self, result, dim, keepdim)
+  result: logsumexp_jvp(self_p, self_t, dim, keepdim)
+
+- name: linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
+  self, b: linalg_lstsq_backward(grads[0], grads[1], self, b, solution, grad_input_mask)
+  solution: linalg_lstsq_solution_jvp(self_p, b_p, self_t, b_t)
+  residuals: linalg_lstsq_residuals_jvp(self_p, b_p, self_t, b_t, solution, residuals)
+  output_differentiability: [True, True, False, False]
+
+- name: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)
+  A: lu_factor_ex_backward(grad, LU, pivots, pivot)
+  LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot)
+  output_differentiability: [True, False, False]
+
+- name: linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)
+  A: lu_factor_ex_backward(grad, LU, pivots, pivot)
+  LU: lu_factor_ex_jvp(A_t, LU, pivots, pivot)
+  output_differentiability: [True, False]
+
+- name: linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)
+  A: linalg_lu_backward(grad_L, grad_U, P, L, U, pivot)
+  L: std::get<0>(linalg_lu_jvp(A_t, P, L, U, pivot))
+  U: std::get<1>(linalg_lu_jvp(A_t, P, L, U, pivot))
+  output_differentiability: [False, True, True]
+
+- name: linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor
+  LU: linalg_lu_solve_LU(grad, LU, pivots, result, left, adjoint)
+  B: "at::linalg_lu_solve(LU, pivots, grad, left, !adjoint)"
+  result: linalg_lu_solve_jvp(result, LU_p, pivots, LU_t, B_t, left, adjoint)
+
+- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
+  LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1))
+  LU_pivots: non_differentiable
+  L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)"
+  U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()"
+  output_differentiability: [False, True, True]
+
+- name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor
+  self: grad.masked_fill(mask, 0)
+  mask: non_differentiable
+  result: self_t.masked_fill(mask, 0)
+
+- name: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor
+  self: grad.masked_fill(mask, 0)
+  value: masked_fill_backward(grad, mask)
+  mask: non_differentiable
+  result: self_t.masked_fill(mask, value_t)
+
+- name: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor
+  self: grad.masked_fill(mask, 0)
+  source: masked_scatter_backward_symint(grad, mask, source.sym_sizes())
+  mask: non_differentiable
+  result: self_t.masked_scatter(mask, source_t)
+
+- name: masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor
+  grad_output: zeros_like(grad_output).masked_scatter(mask, grad)
+  mask: non_differentiable
+  result: masked_scatter_backward(grad_output_t, mask, grad_output_t.sizes())
+
+- name: masked_select(Tensor self, Tensor mask) -> Tensor
+  self: masked_select_backward(grad, self, mask)
+  mask: non_differentiable
+  result: auto_linear
+
+- name: linalg_matrix_exp(Tensor self) -> Tensor
+  self: linalg_matrix_exp_differential(self, grad, /*adjoint*/ true)
+  result: linalg_matrix_exp_differential(self_p, self_t, /*adjoint*/ false)
+
+- name: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: max(Tensor self) -> Tensor
+  self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
+
+- name: maximum(Tensor self, Tensor other) -> Tensor
+  self: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0)
+  other: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0)
+  result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p > other_p).to(result.scalar_type())) * (self_t - other_t)
+
+- name: fmax(Tensor self, Tensor other) -> Tensor
+  self: grad.masked_fill((self >= other).logical_or_(other.isnan()).logical_not_(), 0)
+  other: grad.masked_fill((self >= other).logical_or_(other.isnan()), 0)
+  result: other_t + (self_p > other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
+
+- name: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    Default:
+      self: grad.expand_symint(self.sym_sizes()) / self.sym_numel()
+      result: auto_linear
+    AutogradNestedTensor:
+      # TODO: replace this with grad.expand_as(self) / self.sym_numel() when that is supported
+      self: (ones_like(self) * grad) / self.sym_numel()
+      result: auto_linear
+
+- name: mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  self: mean_backward(grad, self.sym_sizes(), dim, self.sym_numel(), keepdim)
+  result: auto_linear
+
+- name: median(Tensor self) -> Tensor
+  self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
+
+- name: nanmedian(Tensor self) -> Tensor
+  self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
+
+# This is in theory incorrect in the following case:
+#   sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
+#                            |                     at middle position of the
+#                            |                     list between two `b`s. E.g.,
+#                            |
+#                            ^the middle position
+# The gradient exists and is essentially 0 in this case.
+#
+# In case where the middle position is at the boundary of `b` range, e.g.,
+#   sorted list: [..., a, b, b, ..., b, b, c, ...]
+#                                       |
+#                                       ^the middle position
+# The backward implementation is correct in the sense that it returns the
+# subgradient on one side.
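+# (Illustrative) e.g. for the sorted input [1, 2, 2, 2, 5] the median is 2, and the
+# formulas below route the whole incoming gradient to the single position returned in
+# `indices`: if the middle position sits strictly inside the run of equal values the
+# true gradient is essentially 0 (first case above); if it sits at the edge of the run,
+# this is the one-sided subgradient described in the second case.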
+- name: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: min(Tensor self) -> Tensor
+  self: evenly_distribute_backward(grad, self, result)
+  result: evenly_read_jvp(self_t, self_p, result)
+
+- name: minimum(Tensor self, Tensor other) -> Tensor
+  self: at::where(self == other, grad / 2, grad).masked_fill_(self > other, 0)
+  other: at::where(self == other, grad / 2, grad).masked_fill_(self < other, 0)
+  result: other_t + at::where(self_p == other_p, at::scalar_tensor(0.5, result.options()), (self_p < other_p).to(result.scalar_type())) * (self_t - other_t)
+
+- name: fmin(Tensor self, Tensor other) -> Tensor
+  self: grad.masked_fill((self <= other).logical_or_(other.isnan()).logical_not_(), 0)
+  other: grad.masked_fill((self <= other).logical_or_(other.isnan()), 0)
+  result: other_t + (self_p <= other_p).logical_or_(other_p.isnan()) * (self_t - other_t)
+
+- name: amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+  self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
+  result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
+
+- name: amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+  self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim)
+  result: amaxamin_jvp(self_p, self_t, result, dim, keepdim)
+
+- name: mm(Tensor self, Tensor mat2) -> Tensor
+  self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1)
+  mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1)
+  result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t)
+
+- name: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
+  self: _grouped_mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), offs, 1)
+  mat2: _grouped_mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), offs, 1)
+
+- name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim)
+  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)
+
+- name: mul.Tensor(Tensor self, Tensor other) -> Tensor
+  self: mul_tensor_backward(grad, other, self.scalar_type())
+  other: mul_tensor_backward(grad, self, other.scalar_type())
+  result: other_t * self_p + self_t * other_p
+
+- name: mul.Scalar(Tensor self, Scalar other) -> Tensor
+  self: mul_tensor_backward(grad, other, self.scalar_type())
+  result: self_t * other
+
+- name: mv(Tensor self, Tensor vec) -> Tensor
+  self: grad.ger(vec.conj())
+  vec: self.conj().t().mv(grad)
+  result: mv(self_t, vec_p) + mv(self_p, vec_t)
+
+- name: mvlgamma(Tensor self, int p) -> Tensor
+  self: mvlgamma_backward(grad, self, p)
+  result: auto_element_wise
+
+- name: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
+  self: grad * at::isfinite(self)
+  result: auto_element_wise
+
+- name: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)
+
+- name: _native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, eps)
+
+- name: _native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*training=*/false, eps, grad_input_mask) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, /*training=*/false, eps)
+
+- name: _native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? native_batch_norm_backward(grad, input, weight, Tensor(), Tensor(), result1, result2, training, eps, grad_input_mask) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, Tensor(), Tensor(), result1, result2, training, eps)
+
+- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
+  save_mean: not_implemented("native_batch_norm_backward save_mean")
+  save_invstd: not_implemented("native_batch_norm_backward save_invstd")
+
+- name: native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? native_layer_norm_backward_symint(grad, input, normalized_shape, result1, result2, weight, bias, grad_input_mask) : std::tuple()"
+  result0: layer_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, normalized_shape)
+
+- name: native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  input, weight, grad_out: layer_norm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, mean, rstd, normalized_shape, grad_input_mask)
+  bias: Tensor()
+  mean: not_implemented("native_layer_norm_backward mean")
+  rstd: not_implemented("native_layer_norm_backward rstd")
+
+- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple())"
+  result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)
+  result1: group_norm_mean_jvp(input_t, result1, group)
+  result2: group_norm_invstd_jvp(input_p, input_t, result1, result2, group)
+
+- name: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  self: zeros_like(self)
+  result: self_t.zero_()
+
+- name: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: zeros_like(self)
+  other: zeros_like(other)
+  result: self_t.zero_()
+
+- name: neg(Tensor self) -> Tensor
+  self: grad.neg()
+  result: auto_element_wise
+
+- name: _batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/true, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, true, eps)
+
+- name: _batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, /*update*/false, eps, grad_input_mask, retain_variables ? result3.clone() : result3) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, false, eps)
+
+- name: batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)
+  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, update, eps, save_mean, save_var, grad_input_mask)
+  save_mean: not_implemented("batch_norm_backward save_mean")
+  save_var: not_implemented("batch_norm_backward save_var")
+  reserve: not_implemented("batch_norm_backward reserve")
+
+- name: nextafter(Tensor self, Tensor other) -> Tensor
+  self: not_implemented("nextafter")
+  other: not_implemented("nextafter")
+
+- name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor
+  self: norm_backward(grad, self, p, result)
+  result: norm_jvp(self_p, self_t, p, result)
+
+- name: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor
+  self: norm_backward(grad, self, p, result, dim, keepdim)
+  result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
+
+- name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
+  self: norm_backward(grad, self.to(grad.scalar_type()), p, result)
+  result: norm_jvp(self_p, self_t, p, result)
+
+- name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor
+  self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim)
+  result: norm_jvp(self_p, self_t, p, result, dim, keepdim)
+
+- name: linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  self: linalg_vector_norm_backward(grad, self, ord, result, dim, keepdim)
+  result: linalg_vector_norm_jvp(self_p, self_t, ord, result, dim, keepdim)
+
+- name: _pdist_forward(Tensor self, float p=2) -> Tensor
+  self: _pdist_backward(grad, self, p, result)
+
+- name: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor
+  grad: not_implemented("_pdist_backward")
+  self: not_implemented("_pdist_backward")
+  pdist: not_implemented("_pdist_backward")
+
+- name: _euclidean_dist(Tensor x1, Tensor x2) -> Tensor
+  x1, x2: _euclidean_dist_backward(grad, x1, x2, result)
+
+- name: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
+  x1: _cdist_backward(grad.contiguous(), x1, x2, p, result)
+  x2: _cdist_backward(grad.mT().contiguous(), x2, x1, p, result.mT().contiguous())
+
+- name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
+  grad: not_implemented("_cdist_backward")
+  x1: not_implemented("_cdist_backward")
+  x2: not_implemented("_cdist_backward")
+  cdist: not_implemented("_cdist_backward")
+
+- name: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor
+  mean: at::zeros_symint(mean.sym_sizes(), grad.options())
+  result: auto_element_wise
+
+- name: normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor
+  std: at::zeros_symint(std.sym_sizes(), grad.options())
+  result: auto_element_wise
+
+- name: normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor
+  mean: at::zeros_symint(mean.sym_sizes(), grad.options())
+  std: at::zeros_symint(std.sym_sizes(), grad.options())
+  result: zeros_like(mean_t)
+
+- name: linalg_householder_product(Tensor input, Tensor tau) -> Tensor
+  input, tau: householder_product_backward(grad, result, input, tau)
+  result: householder_product_jvp(input_t, tau_t, result, input_p, tau_p)
+
+- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor
+  self, input2, input3: ormqr_backward(grad, result, self, input2, input3, left, transpose, grad_input_mask)
+
+- name: permute(Tensor(a) self, int[] dims) -> Tensor(a)
+  self: permute_backwards(grad, dims)
+  result: auto_linear
+
+- name: poisson(Tensor self, Generator? generator=None) -> Tensor
+  self: zeros_like(self)
+  result: auto_element_wise
+
+- name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor
+  self: pow_backward(grad, self, exponent)
+  result: auto_element_wise
+
+- name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor
+  self: pow_backward_self(grad, self, exponent)
+  exponent: pow_backward_exponent(grad, self, exponent, result)
+  result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj()
+
+- name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
+  exponent: pow_backward_exponent(grad, self, exponent, result)
+  result: auto_element_wise
+
+- name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  self: prod_backward(grad, self.to(grad.scalar_type()), result)
+  result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result) * self_t.conj()).sum().conj()
+
+- name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim)
+  result: (prod_backward(at::ones({}, result.options()).expand_as(result), self_p.to(result.scalar_type()), result, dim, keepdim) * self_t.conj()).sum(dim, keepdim).conj()
+
+- name: put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor
+  self: "accumulate ? grad : grad.put(index, zeros_like(source), false)"
+  index: non_differentiable
+  source: grad.take(index).reshape_as(source)
+  result: self_t.put(index, source_t, accumulate)
+
+- name: linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)
+  A: linalg_qr_backward(grad_Q, grad_R, Q, R, mode)
+  Q, R: linalg_qr_jvp(A_t, Q, R, mode)
+
+- name: rad2deg(Tensor self) -> Tensor
+  self: rad2deg_backward(grad)
+  result: auto_element_wise
+
+- name: random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: reciprocal(Tensor self) -> Tensor
+  self: -grad * (result * result).conj()
+  result: auto_element_wise
+
+- name: remainder.Scalar(Tensor self, Scalar other) -> Tensor
+  self: grad
+  result: auto_element_wise
+
+- name: remainder.Tensor(Tensor self, Tensor other) -> Tensor
+  self: grad
+  other: -grad * self.div(other, /*rounding_mode=*/"floor")
+  result: self_t - other_t * self_p.div(other_p, /*rounding_mode=*/"floor")
+
+- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
+  self: renorm_backward(grad, self, p, dim, maxnorm)
+  result: renorm_jvp(self_p, self_t, p, dim, maxnorm)
+
+- name: repeat(Tensor self, SymInt[] repeats) -> Tensor
+  self: repeat_backward(grad, repeats, self.sym_sizes())
+  result: auto_linear
+
+- name: special_entr(Tensor self) -> Tensor
+  self: grad * (-(1 + self.log()))
+  result: auto_element_wise
+
+- name: special_ndtri(Tensor self) -> Tensor
+  self: grad * std::sqrt(2 * M_PI) * (result.square() / 2).exp()
+  result: auto_element_wise
+
+- name: special_log_ndtr(Tensor self) -> Tensor
+  self: grad / std::sqrt(2 * M_PI) * (result + self.pow(2) / 2).neg().exp()
+  result: auto_element_wise
+
+# [Note: Sometimes view derivatives]
+# The following situation applies to other operations as well.
+# TODO: This note is only referenced by to_dense and to_sparse*. Make
+# this more generic if it's been referenced more than once.
+#
+# DO NOT define a backward for reshape!
+# reshape is special in that it sometimes returns a view, and sometimes not.
+# Defining a backward will make codegen spit out the forward call as
+#     as_variable(baseType->reshape(self)),
+# making it impossible (hard) to detect when it is actually a view.
+# - name: reshape(Tensor self, IntArrayRef shape)
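+# A quick sketch of why reshape's view-ness is data-dependent (illustration
+# only, assuming `import torch`; the variable names below are not part of
+# this file):
+#   x = torch.arange(6)
+#   y = x.reshape(2, 3)      # contiguous input -> returns a view
+#   assert y.data_ptr() == x.data_ptr()
+#   z = y.t().reshape(-1)    # layout not viewable as 1-D -> returns a copy
+#   assert z.data_ptr() != y.data_ptr()
+# Since the aliasing can only be known at runtime, autograd has to observe it
+# rather than rely on a fixed backward formula here.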
+
+- name: _reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)
+  self: grad.reshape_symint(self.sym_sizes())
+  result: auto_linear
+
+- name: round(Tensor self) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: round.decimals(Tensor self, *, int decimals) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: rsqrt(Tensor self) -> Tensor
+  self: -0.5 * grad * result.pow(3).conj()
+  result: auto_element_wise
+
+- name: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+  self: grad.scatter(dim, index, 0)
+  index: non_differentiable
+  src: grad.gather(dim, index)
+  result: self_t.scatter(dim, index, src_t)
+
+- name: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
+  self: grad.scatter(dim, index, 0)
+  index: non_differentiable
+  result: self_t.scatter(dim, index, 0)
+
+- name: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
+  self: grad
+  index: non_differentiable
+  src: grad.gather(dim, index)
+  result: scatter_add(self_t, dim, index, src_t)
+
+- name: select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)
+  dispatch:
+    Default:
+      self: select_backward_symint(grad, self.sym_sizes(), dim, index)
+      result: auto_linear
+    AutogradNestedTensor:
+      self: _nested_select_backward_symint(grad, self, dim, index)
+
+- name: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor
+  grad_output: grad.select_symint(dim, index)
+  result: auto_linear
+
+- name: sigmoid(Tensor self) -> Tensor
+  self: sigmoid_backward(grad, result)
+  result: auto_element_wise
+
+- name: logit(Tensor self, float? eps=None) -> Tensor
+  self: "GradMode::is_enabled() ? infinitely_differentiable_logit_backward(grad, self, eps) : logit_backward(grad, self, eps)"
+  result: auto_element_wise
+
+- name: sign(Tensor self) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+- name: sgn(Tensor self) -> Tensor
+  self: sgn_backward(self, grad, result)
+  # Cannot use auto_element_wise here because the Jacobian is *not* Hermitian (in fact, it is symmetric)
+  # The function is not holomorphic, so there's no reason for its Jacobian to be Hermitian
+  # auto_element_wise has a name that's a bit deceiving in the complex case
+  result: sgn_backward(self_p, self_t, result)
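+# Worked Wirtinger derivatives behind the comment above (a sketch, not part of
+# the generated code): for sgn(z) = z / |z| with z != 0,
+#   d(sgn)/dz       =  1 / (2|z|)
+#   d(sgn)/dconj(z) = -z^2 / (2|z|^3)
+# so the JVP is dz / (2|z|) - z^2 * conj(dz) / (2|z|^3). The nonzero conjugate
+# term is what makes the map non-holomorphic, so the element-wise Jacobian is
+# symmetric rather than Hermitian and auto_element_wise cannot be used.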
+
+- name: sin(Tensor self) -> Tensor
+  self: grad * self.cos().conj()
+  result: auto_element_wise
+
+- name: sinc(Tensor self) -> Tensor
+  self: sinc_backward(grad, self)
+  result: auto_element_wise
+
+- name: sinh(Tensor self) -> Tensor
+  self: grad * self.cosh().conj()
+  result: auto_element_wise
+
+- name: slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  self: slice_backward_wrapper(grad, self.sym_sizes(), dim, start, end, step)
+  result: auto_linear
+
+- name: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor
+  grad_output: grad.slice_symint(dim, start, end, step)
+  result: auto_linear
+
+- name: slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)
+  self: grad.slice_symint(dim, start, end, step)
+  src: slice_scatter_symint(grad, zeros_like(self), dim, start, end, step)
+  result: auto_linear
+
+- name: slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor
+  self: slice_scatter_symint(grad, zeros_like(src), dim, start, end, step)
+  src: grad.slice_symint(dim, start, end, step)
+  result: auto_linear
+
+- name: select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor
+  self: select_scatter_symint(grad, zeros_like(src), dim, index)
+  src: grad.select_symint(dim, index)
+  result: auto_linear
+
+- name: diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor
+  self: diagonal_scatter(grad, zeros_like(src), offset, dim1, dim2)
+  src: grad.diagonal(offset, dim1, dim2)
+  result: auto_linear
+
+- name: as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor
+  self: as_strided_scatter_backward(grad, TensorGeometry(self), TensorGeometry(src), size, stride, storage_offset)
+  # See Note [as_strided_scatter backward support]
+  src: grad.contiguous().as_strided_symint(size, stride, storage_offset)
+  result: auto_linear
+
+- name: _linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)
+  A, B: linalg_solve_backward(grad, result, A, LU, pivots, left, grad_input_mask[1])
+  result: "linalg_solve_jvp(A_t, B_t, result, LU, pivots, left, A_p.is_contiguous() && !A_p.is_complex())"
+  output_differentiability: [True, False, False, False]  # LU is an auxiliary tensor not exposed to the user
+
+- name: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
+  output_differentiability: [True, False]
+  values: gather_with_keepdimed_indices(self_t, dim, indices, true)
+
+- name: sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
+  output_differentiability: [True, False]
+  values: gather_with_keepdimed_indices(self_t, dim, indices, true)
+
+- name: split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]
+  self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options())
+  result: auto_linear
+
+- name: unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]
+  self: split_backward(grads, split_size, dim, self.sym_sizes(), self.options())
+  result: auto_linear
+
+- name: split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]
+  dispatch:
+    Default:
+      self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: _nested_split_with_sizes_backward(grads, split_sizes, dim, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), self.options())
+
+- name: unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]
+  self: split_with_sizes_backward(grads, split_sizes, dim, self.sym_sizes(), self.options())
+  result: auto_linear
+
+- name: sqrt(Tensor self) -> Tensor
+  self: grad / (2 * result.conj())
+  result: auto_element_wise
+
+- name: squeeze(Tensor(a) self) -> Tensor(a)
+  self: unsqueeze_to(grad, self.sym_sizes())
+  result: auto_linear
+
+- name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
+  dispatch:
+    Default:
+      self: unsqueeze_to(grad, dim, self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.unsqueeze(dim)
+
+- name: squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)
+  dispatch:
+    Default:
+      self: unsqueeze_to(grad, dim, self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: unsqueeze_multiple(grad, dim, self.dim())
+
+- name: squeeze_(Tensor(a!) self) -> Tensor(a!)
+  self: unsqueeze_to(grad, self.sym_sizes())
+  result: auto_linear
+
+- name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
+  self: unsqueeze_to(grad, dim, self.sym_sizes())
+  result: auto_linear
+
+- name: squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)
+  self: unsqueeze_to(grad, dim, self.sym_sizes())
+  result: auto_linear
+
+- name: std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  self: std_backward(result, grad, self, dim, correction, keepdim)
+  # pointwise (variance) + sum + sqrt
+  result: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result)).masked_fill_(result == 0, 0)
+
+- name: std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  self: std_mean_backward(grads[0], grads[1], self, result0, dim, correction, keepdim)
+  result0: (at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim)) / (2. * result0)).masked_fill_(result0 == 0, 0)
+  # linear
+  result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
+
+- name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  other: handle_r_to_c(other.scalar_type(), maybe_multiply(-grad, alpha.conj()))
+  result: self_t - maybe_multiply(other_t, alpha)
+
+- name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), grad)
+  result: auto_element_wise
+
+- name: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj()))
+  other: handle_r_to_c(other.scalar_type(), grad)
+  result: -maybe_multiply(self_t, alpha) + other_t
+
+- name: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
+  self: handle_r_to_c(self.scalar_type(), maybe_multiply(-grad, alpha.conj()))
+  result: auto_element_wise
+
+- name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    Default:
+      self: grad.expand_symint(self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      # TODO: replace this with grad.expand_as(self) when that is supported
+      self: ones_like(self) * grad
+      result: auto_linear
+
+- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  dispatch:
+    Default:
+      self: sum_backward(grad, self.sym_sizes(), dim, keepdim)
+      result: auto_linear
+    AutogradNestedTensor:
+      # TODO: replace this function once semantics for nested tensor expand have been settled on
+      self: _nested_sum_backward(grad, self, dim, keepdim)
+
+- name: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+  self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
+  result: at::where(self_p.isnan(), 0, self_t).sum(dim, keepdim, dtype)
+
+# We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
+- name: _linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)
+  A: "svd_backward(full_matrices && grad_U.defined() ? grad_U.narrow_symint(-1, 0, S.sym_size(-1)) : grad_U,
+                   grad_S,
+                   full_matrices && grad_Vh.defined() ? grad_Vh.narrow_symint(-2, 0, S.sym_size(-1)) : grad_Vh,
+                   full_matrices ? U.narrow_symint(-1, 0, S.sym_size(-1)) : U,
+                   S,
+                   full_matrices ? Vh.narrow_symint(-2, 0, S.sym_size(-1)) : Vh)"
+  U, S, Vh: linalg_svd_jvp(A_t, U, S, Vh, full_matrices)
+
+- name: _linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)
+  A: linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/true)
+  eigenvalues, eigenvectors: linalg_eig_jvp(A_t, eigenvalues, eigenvectors, /*is_hermitian=*/true)
+
+- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
+  self: handle_r_to_c(self.scalar_type(), linalg_eig_backward(grads[0], grads[1], eigenvalues, eigenvectors, /*is_hermitian=*/false))
+  eigenvalues, eigenvectors: linalg_eig_jvp(self_t, eigenvalues, eigenvectors, /*is_hermitian=*/false)
+
+- name: t(Tensor(a) self) -> Tensor(a)
+  self: grad.t()
+  result: auto_linear
+
+- name: t_(Tensor(a!) self) -> Tensor(a!)
+  self: grad.t()
+  result: auto_linear
+
+- name: one_hot(Tensor self, int num_classes=-1) -> Tensor
+  self: non_differentiable
+
+- name: flip(Tensor self, int[] dims) -> Tensor
+  self: grad.flip(dims)
+  result: auto_linear
+
+- name: roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor
+  self: grad.roll_symint(fmap(reverse_list_symint(shifts), [](c10::SymInt i){return -i;}), reverse_list(dims))
+  result: auto_linear
+
+- name: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
+  self: grad.rot90(-k, dims)
+  result: auto_linear
+
+- name: take(Tensor self, Tensor index) -> Tensor
+  self: take_backward(grad, self, index)
+  index: non_differentiable
+  result: auto_linear
+
+- name: tan(Tensor self) -> Tensor
+  self: grad * (1 + result.pow(2)).conj()
+  result: auto_element_wise
+
+- name: tanh(Tensor self) -> Tensor
+  self: tanh_backward(grad, result)
+  result: auto_element_wise
+
+- name: topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
+  self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), true)
+  output_differentiability: [True, False]
+  values: gather(self_t, dim, indices)
+
+- name: trace(Tensor self) -> Tensor
+  self: trace_backward_symint(grad, self.sym_sizes())
+  result: auto_linear
+
+- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
+  self: grad.transpose(dim0, dim1)
+  result: auto_linear
+
+- name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)
+  self: grad.transpose(dim0, dim1)
+  result: auto_linear
+
+- name: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)
+  self, A: triangular_solve_backward(grad_solution, grad_cloned_coefficient, self, A, solution, upper, transpose, unitriangular, grad_input_mask)
+  solution: triangular_solve_jvp(solution, A_p, A_t, self_t, upper, transpose, unitriangular)
+  cloned_coefficient: A_t
+
+- name: linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor
+  self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask)
+  result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular)
+
+- name: tril(Tensor self, int diagonal=0) -> Tensor
+  self: grad.tril(diagonal)
+  result: auto_linear
+
+- name: triu(Tensor self, int diagonal=0) -> Tensor
+  self: grad.triu(diagonal)
+  result: auto_linear
+
+- name: trunc(Tensor self) -> Tensor
+  self: zeros_like(grad)
+  result: auto_element_wise
+
+# DO NOT define a backward for to_dense
+# See [Note: Sometimes view derivatives]
+# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
+#
+- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
+  self: to_dense_backward(grad, self, masked_grad)
+
+# DO NOT define a backward for to_sparse.sparse_dim
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+#
+- name: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+# DO NOT define a backward for to_sparse
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+#
+- name: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+# DO NOT define a backward for to_sparse_csr
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+#
+- name: _to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+# DO NOT define a backward for to_sparse_csc
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+#
+- name: _to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+# DO NOT define a backward for to_sparse_bsr
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+#
+- name: _to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+# DO NOT define a backward for to_sparse_bsc
+# See [Note: Sometimes view derivatives]
+# - name: to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+#
+- name: _to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor
+  self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
+
+- name: to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor
+  self: to_mkldnn_backward(grad, self)
+
+- name: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
+  self: unfold_backward_symint(grad, self.sym_sizes(), dimension, size, step)
+  result: auto_linear
+
+- name: unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor
+  grad_in: grad.unfold(dim, size, step)
+  result: auto_linear
+
+- name: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: self_t.zero_()
+
+- name: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
+  output_differentiability: [True, False]
+  self: not_implemented("_unique")
+
+- name: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  output_differentiability: [True, False, False]
+  self: not_implemented("unique_dim")
+
+- name: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)
+  output_differentiability: [True, False, False]
+  self: not_implemented("unique_consecutive")
+
+- name: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  output_differentiability: [True, False, False]
+  self: not_implemented("unique_dim_consecutive")
+
+- name: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)
+  output_differentiability: [True, False, False]
+  self: not_implemented("_unique2")
+
+- name: _unsafe_view(Tensor self, SymInt[] size) -> Tensor
+  self: grad.reshape_symint(self.sym_sizes())
+  result: auto_linear
+
+- name: lift(Tensor self) -> Tensor
+  self: grad
+  result: auto_linear
+
+- name: lift_fresh(Tensor(a) self) -> Tensor(a)
+  self: grad
+  result: auto_linear
+
+- name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a)
+  self: grad.squeeze(dim)
+  result: auto_linear
+
+- name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)
+  self: grad.squeeze(dim)
+  result: auto_linear
+
+- name: var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor
+  self: var_backward(grad, self, dim, correction, keepdim)
+  # pointwise + sum
+  result: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
+
+- name: var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)
+  self: var_mean_backward(grads[0], grads[1], self, dim, correction, keepdim)
+  result0: at::real(var_backward(self_t.conj(), self_p, dim, correction, true).sum(dim.value_or(IntArrayRef({})), keepdim))
+  # linear
+  result1: mean(self_t, dim.value_or(IntArrayRef({})), keepdim)
+
+- name: view(Tensor(a) self, SymInt[] size) -> Tensor(a)
+  dispatch:
+    Default:
+      self: grad.reshape_symint(self.sym_sizes())
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.reshape_as(self)
+      result: auto_linear
+
+- name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)
+  output_differentiability: [False]
+
+- name: view_as_real(Tensor(a) self) -> Tensor(a)
+  self: at::view_as_complex(grad.contiguous()) # gx0 + 1j * gx1
+  result: at::view_as_real(self_t)
+
+- name: view_as_complex(Tensor(a) self) -> Tensor(a)
+  self: at::view_as_real(grad.contiguous().resolve_conj()) # [gx, gy]
+  result: at::view_as_complex(self_t)
+
+- name: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor
+  condition: non_differentiable
+  self: where(condition, grad, 0)
+  other: where(condition, 0, grad)
+  result: where(condition, self_t, other_t)
+
+# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen
+# to be running backward with create_graph=True, fall back to a backward function that uses
+# differentiable ops.
+- name: _weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)
+  v, g: "grad.defined() ? (GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_interface_backward(grad.contiguous(), v, g, result1, dim)) : std::tuple()"
+
+- name: zero_(Tensor(a!) self) -> Tensor(a!)
+  self: zeros_like(grad)
+  result: auto_linear
+
+- name: sparse_mask(Tensor self, Tensor mask) -> Tensor
+  self: sparse_mask_backward(grad, mask, self.layout())
+  mask: non_differentiable
+
+- name: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
+  indices: non_differentiable
+  values: grad.sparse_mask(result)._values()
+
+- name: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
+  compressed_indices: non_differentiable
+  plain_indices: non_differentiable
+  # TODO: remove to_dense after gh-107381 is fixed
+  values: grad.to_dense().sparse_mask(result).values()
+
+- name: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor
+  self: at::_sparse_sum_backward(grad, self, dim)
+
+- name: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor
+  self: grad * _standard_gamma_grad(self, result)
+
+- name: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor
+  self: not_implemented("_standard_gamma_grad")
+
+- name: values(Tensor(a) self) -> Tensor(a)
+  dispatch:
+    Default:
+      self: values_backward(grad, self)
+    AutogradNestedTensor:
+      self: at::_nested_view_from_buffer(grad.contiguous(), self._nested_tensor_size(), self._nested_tensor_strides(), self._nested_tensor_storage_offsets())
+
+# Why is _values() not differentiable?
+# See NOTE [ Sparse: autograd and API ]
+- name: _values(Tensor(a) self) -> Tensor(a)
+  output_differentiability: [False]
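+# (Loose sketch of the referenced note: values() above is the autograd-aware
+# accessor -- e.g. x.to_sparse().values() lets gradients flow back to x --
+# while _values() hands out the raw value buffer outside of autograd, which is
+# why it is marked non-differentiable here.)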
+
+# NN
+- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor
+  i1, i2, i3: "_trilinear_backward(grad,
+               wrap_opt_if(i1, grad_input_mask[1] || grad_input_mask[2]),
+               wrap_opt_if(i2, grad_input_mask[0] || grad_input_mask[2]),
+               wrap_opt_if(i3, grad_input_mask[0] || grad_input_mask[1]),
+               expand1, expand2, expand3, sumdim, grad_input_mask)"
+  result: "_trilinear(i1_t, i2_p, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
+           _trilinear(i1_p, i2_t, i3_p, expand1, expand2, expand3, sumdim, unroll_dim) +
+           _trilinear(i1_p, i2_p, i3_t, expand1, expand2, expand3, sumdim, unroll_dim)"
+
+- name: constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor
+  self: constant_pad_nd_backward(grad, pad)
+  result: constant_pad_nd_symint(self_t, pad, 0)
+
+- name: binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+  self: binary_cross_entropy_backward(grad, self, target, weight, reduction)
+  target: binary_cross_entropy_target_backward(grad, self, target, weight, reduction)
+  result: "apply_loss_reduction(
+               binary_cross_entropy_backward(self_t, self_p, target_p, weight, at::Reduction::None)
+             + binary_cross_entropy_target_backward(target_t, self_p, target_p, weight, at::Reduction::None),
+           reduction)"
+
+- name: binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor
+  self: binary_cross_entropy_double_backward(grad_output, grad, self, target, weight, reduction)
+  target: binary_cross_entropy_double_backward_target(grad, grad_output, self, target, weight, reduction)
+  grad_output: binary_cross_entropy_double_backward_grad_output(grad, self, target, weight, reduction)
+  result: " binary_cross_entropy_double_backward(grad_output_p, self_t, self_p, target_p, weight, reduction)
+          + binary_cross_entropy_double_backward_target(target_t, grad_output_p, self_p, target_p, weight, reduction)
+          + binary_cross_entropy_double_backward_grad_output(grad_output_t, self_p, target_p, weight, reduction)"
+
+- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor
+  self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction)
+  target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction)
+  result: "apply_loss_reduction(
+               binary_cross_entropy_with_logits_backward(self_t, self_p, target_p, weight, pos_weight, at::Reduction::None)
+             + binary_cross_entropy_with_logits_target_backward(target_t, self_p, target_p, weight, pos_weight, at::Reduction::None),
+           reduction)"
+
+- name: embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor
+  indices: non_differentiable
+  weight: embedding_backward_symint(grad, indices, weight.sym_size(0), padding_idx, scale_grad_by_freq, sparse)
+  result: auto_linear
+
+- name: embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor
+  grad_output: embedding_dense_double_backward_symint(grad, indices, padding_idx)
+  indices: non_differentiable
+  result: auto_linear
+
+- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)
+  indices: non_differentiable
+  offsets: non_differentiable
+  weight: _embedding_bag_backward_symint(grad, indices, offsets, result1, result2, result3, weight.sym_size(0), scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx)
+  per_sample_weights: _embedding_bag_per_sample_weights_backward(grad, weight, indices, offsets, result1, mode, padding_idx)
+
+- name: _embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  grad: not_implemented("_embedding_bag_backward")
+  indices: non_differentiable
+  offsets: non_differentiable
+  offset2bag: non_differentiable
+  bag_size: non_differentiable
+  maximum_indices: non_differentiable
+  per_sample_weights: not_implemented("_embedding_bag_backward")
+
+- name: _embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor
+  grad: not_implemented("_embedding_bag_dense_backward")
+  indices: non_differentiable
+  offset2bag: non_differentiable
+  bag_size: non_differentiable
+  maximum_indices: non_differentiable
+  per_sample_weights: not_implemented("_embedding_bag_dense_backward")
+
+- name: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)
+  indices: non_differentiable
+  self: not_implemented("embedding_renorm")
+
+- name: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  self: mse_loss_backward(grad, self, target, reduction)
+  target: mse_loss_backward(grad, target, self, reduction)
+  result: apply_loss_reduction(mse_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None).conj() + mse_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None).conj(), reduction)
+
+- name: multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor
+  self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)
+  target: non_differentiable
+
+- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)
+  self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target)
+  target: non_differentiable
+
+- name: nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+  self: nll_loss_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight)
+  target: non_differentiable
+  output: std::get<0>(nll_loss_forward_symint(self_t, target, weight, reduction, ignore_index))
+
+- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)
+  self: nll_loss2d_backward_symint(grad, self, target, weight, reduction, ignore_index, total_weight)
+  target: non_differentiable
+  output: std::get<0>(nll_loss2d_forward_symint(self_t, target, weight, reduction, ignore_index))
+
+- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
+  self: smooth_l1_loss_backward(grad, self, target, reduction, beta)
+  target: smooth_l1_loss_backward(grad, target, self, reduction, beta)
+  result: apply_loss_reduction(smooth_l1_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, beta).conj() + smooth_l1_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, beta).conj(), reduction)
+
+- name: huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor
+  self: huber_loss_backward(grad, self, target, reduction, delta)
+  target: huber_loss_backward(grad, target, self, reduction, delta)
+  result: apply_loss_reduction(huber_loss_backward(self_t.conj(), self_p, target_p, at::Reduction::None, delta).conj() + huber_loss_backward(target_t.conj(), target_p, self_p, at::Reduction::None, delta).conj(), reduction)
+
+- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
+  self: soft_margin_loss_backward(grad, self, target, reduction)
+  result: apply_loss_reduction(soft_margin_loss_backward(self_t.conj(), self_p, target, at::Reduction::None).conj(), reduction)
+
+- name: relu(Tensor self) -> Tensor
+  self: threshold_backward(grad, result, 0)
+  result: auto_element_wise
+
+- name: silu(Tensor self) -> Tensor
+  self: "GradMode::is_enabled() ? infinitely_differentiable_silu_backward(grad, self) : silu_backward(grad, self)"
+  result: auto_element_wise
+
+- name: mish(Tensor self) -> Tensor
+  self: "GradMode::is_enabled() ? infinitely_differentiable_mish_backward(grad, self) : mish_backward(grad, self)"
+  result: auto_element_wise
+
+- name: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor
+  self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ false, self)
+  result: auto_element_wise
+
+- name: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)
+  self: elu_backward(grad, alpha, scale, input_scale, /* is_result */ true, result)
+  result: self_t.copy_(elu_backward(original_self_t, alpha, scale, input_scale, /* is_result */ true, result))
+
+- name: celu(Tensor self, Scalar alpha=1.0) -> Tensor
+  self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ false, self)
+  result: auto_element_wise
+
+- name: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)
+  self: elu_backward(grad, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result)
+  result: self_t.copy_(elu_backward(original_self_t, alpha, 1, 1.0/alpha.toFloat(), /* is_result */ true, result))
+
+- name: gelu(Tensor self, *, str approximate='none') -> Tensor
+  self: gelu_backward(grad, self, approximate)
+  result: auto_element_wise
+
+- name: gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor
+  grad_output: gelu_backward(grad, self, approximate)
+  self: gelu_double_backward(grad, grad_output, self, approximate)
+  result: gelu_backward(grad_output_t, self_p, approximate) + gelu_double_backward(self_t, grad_output_p, self_p, approximate)
+
+- name: glu(Tensor self, int dim=-1) -> Tensor
+  # TODO: glu_backward can benefit from forward result,
+  # and forward ad/forward over reverse ad for that matter
+  self: glu_backward(grad, self, dim)
+  result: glu_jvp(result, self_p, self_t, dim)
+
+- name: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+  self: hardshrink_backward(grad, self, lambd)
+  result: auto_element_wise
+
+- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor
+  grad_out: hardshrink_backward(grad, self, lambd)
+  self: zeros_like(grad)
+  result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_out_t, at::zeros({}, result.options()).expand_as(result))
+
+- name: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
+  self: hardtanh_backward(grad, self, min_val, max_val)
+  result: auto_element_wise
+
+- name: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor
+  self: leaky_relu_backward(grad, self, negative_slope, false)
+  result: auto_element_wise
+
+- name: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)
+  self: leaky_relu_backward(grad, result, negative_slope, true)
+  result: self_t.copy_(leaky_relu_backward(original_self_t.conj(), result, negative_slope, true).conj())
+
+- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
+  self: log_sigmoid_backward(grad, self, buffer)
+  output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj()
+  output_differentiability: [True, False]
+
+- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
+  result: self_t - logsumexp_jvp(self_p, self_t, {dim}, true)
+
+- name: _sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  self: _sparse_log_softmax_backward_data(grad, result, dim, self)
+
+- name: _masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor
+  self: _masked_softmax_backward(grad, result, mask, dim)
+  mask: non_differentiable
+
+- name: _prelu_kernel(Tensor self, Tensor weight) -> Tensor
+  self, weight: "grad.defined() ? _prelu_kernel_backward(grad, self, weight) : std::tuple()"
+  result: at::where(self_p >= 0, self_t, weight_p * self_t + weight_t * self_p)
+
+- name: _prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)
+  grad_output: "grads[0].defined() ?
+                (grads[1].defined() ? at::where(self >= 0, grads[0], grads[0] * weight + grads[1] * self)
+                                    : at::where(self >= 0, grads[0], grads[0] * weight))
+                                    : at::where(self >= 0, at::zeros({}, grad_output.options()), grads[1] * self)"
+  self: "grads[1].defined() ? at::where(self >= 0, at::zeros({}, self.options()), grad_output * grads[1]) : zeros_like(self)"
+  weight: "grads[0].defined() ? at::where(self >= 0, at::zeros({}, weight.options()), grad_output * grads[0]) : zeros_like(self)"
+  result0: at::where(self_p >= 0, grad_output_t, grad_output_t * weight_p + grad_output_p * weight_t)
+  result1: at::where(self_p >= 0, at::zeros({}, self_p.options()), grad_output_p * self_t + grad_output_t * self_p)
+
+- name: rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor
+  self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
+  result: auto_element_wise
+
+- name: rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)
+  self: rrelu_with_noise_backward(grad, result, noise, lower, upper, training, true)
+
+- name: rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)
+  noise: non_differentiable
+  self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
+
+- name: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+  result: result * (self_t - logsumexp_jvp(self_p, self_t, {dim}, true))
+
+- name: _sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  self: _sparse_softmax_backward_data(grad, result, dim, self)
+
+- name: _sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor
+  self: sparse_sparse_matmul_backward(grad, self, other, 0)
+  other: sparse_sparse_matmul_backward(grad, self, other, 1)
+
+- name: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor
+  self: softplus_backward(grad, self, beta, threshold)
+  result: auto_element_wise
+
+- name: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor
+  self: softshrink_backward(grad, self, lambd)
+  result: auto_element_wise
+
+- name: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor
+  self: threshold_backward(grad, self, threshold)
+  result: auto_element_wise
+
+- name: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)
+  self: threshold_backward(grad, self, threshold)
+  result: self_t.copy_(threshold_backward(self_t.conj(), original_self_p, threshold).conj())
+
+- name: reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+  self: reflection_pad1d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+  self: reflection_pad2d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+  self: reflection_pad3d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor
+  self: replication_pad1d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor
+  self: replication_pad2d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor
+  self: replication_pad3d_backward_symint(grad, self, padding)
+  result: auto_linear
+
+- name: upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor
+  self: upsample_linear1d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales)
+  result: auto_linear
+
+- name: upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: upsample_bilinear2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: _upsample_bilinear2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: upsample_bicubic2d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+  self: upsample_nearest1d_backward_symint(grad, output_size, self.sym_sizes(), scales)
+  result: auto_linear
+
+- name: _upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor
+  self: _upsample_nearest_exact1d_backward_symint(grad, output_size, self.sym_sizes(), scales)
+  result: auto_linear
+
+- name: upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: upsample_nearest2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: _upsample_nearest_exact2d_backward_symint(grad, output_size, self.sym_sizes(), scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: upsample_nearest3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  self: _upsample_nearest_exact3d_backward_symint(grad, output_size, self.sym_sizes(), scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
+  self: pixel_unshuffle(grad, upscale_factor)
+  result: auto_linear
+
+- name: pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor
+  self: pixel_shuffle(grad, downscale_factor)
+  result: auto_linear
+
+- name: channel_shuffle(Tensor self, SymInt groups) -> Tensor
+  self: channel_shuffle_symint(grad, grad.sym_size(1) / groups)
+  result: auto_linear
+
+- name: _adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor
+  self: _adaptive_avg_pool2d_backward(grad, self)
+  result: auto_linear
+
+- name: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor
+  self: _adaptive_avg_pool3d_backward(grad, self)
+  result: auto_linear
+
+- name: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
+  self: adaptive_max_pool2d_backward(grad, self, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
+  self: adaptive_max_pool3d_backward(grad, self, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+  self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+  result: auto_linear
+
+- name: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor
+  self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+  result: auto_linear
+
+- name: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)
+  self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)
+  self: fractional_max_pool3d_backward(grad, self, kernel_size, output_size, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+  input, weight, bias: "grad.defined() ? linear_backward(input, grad, weight, grad_input_mask) : std::tuple()"
+
+- name: linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  self, grad_output, weight: linear_double_backward(grads, self, grad_output, weight)
+
+#mps
+- name: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  self: max_pool2d_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode)
+
+- name: _mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  self, weight, bias: "grad.defined() ? mps_convolution_backward_symint(self, grad, weight, padding, stride, dilation, groups, grad_input_mask) : std::tuple()"
+
+- name: mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<c10::SymInt>(padding.size(), 0), groups, grad_input_mask)
+
+- name: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+  self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
+  result0: gather(self_t.flatten(-2), -1, result1.flatten(-2)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
+  self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, result1)
+  result0: gather(self_t.flatten(-3), -1, result1.flatten(-3)).view_as(result1)
+  output_differentiability: [True, False]
+
+- name: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
+  self: max_pool_double_backward(grad, indices, 2)
+  indices: non_differentiable
+  result: auto_linear
+
+- name: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
+  self: max_pool_double_backward(grad, indices, 3)
+  indices: non_differentiable
+  result: auto_linear
+
+- name: convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()"
+  result: convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups)
+
+# TorchScript serializes calls to _convolution so this entry is present until that is changed to use convolution.
+# Note that the benchmark, deterministic, cudnn_enabled, and allow_tf32 flags are queried from the global context
+# by convolution_backward instead of being passed along from the forward pass.
+- name: _convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()"
+  result: _convolution_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32)
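+# The four trailing flags above are not consumed by the backward formula;
+# convolution_backward re-reads them from the global context. As a rough
+# Python-side sketch (assuming the usual torch.backends knobs), they mirror:
+#   torch.backends.cudnn.benchmark
+#   torch.backends.cudnn.deterministic
+#   torch.backends.cudnn.enabled
+#   torch.backends.cudnn.allow_tf32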
+
+- name: convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
+  grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
+  result0: std::get<0>(convolution_backward_symint(grad_output_p, input_p, weight_t, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false})) + std::get<0>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {true, false, false}))
+  result1: std::get<1>(convolution_backward_symint(grad_output_p, input_t, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false})) + std::get<1>(convolution_backward_symint(grad_output_t, input_p, weight_p, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, {false, true, false}))
+  result2: convolution_backward_jvp_grad_bias(grad_output_t, result2)
+
+- name: convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor
+  input, weight, bias: "grad.defined() ? convolution_backward_overrideable_symint(grad, input, weight, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask) : std::tuple()"
+
+- name: convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+  grad_output, input, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, input, stride, padding, dilation, transposed, output_padding, groups, grad_input_mask)
+
+- name: slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()"
+
+- name: slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, 1, grad_input_mask) : std::tuple()"
+
+- name: _slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor
+  self, weight, bias: "grad.defined() ? _slow_conv2d_backward_symint(grad, self, weight, kernel_size, stride, padding, grad_input_mask) : std::tuple()"
+
+- name: _slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)
+  grad_output, self, weight: _convolution_double_backward_symint(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, grad_input_mask)
+
+- name: _conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()"
+
+- name: conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad.contiguous(), self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ {{0, 0, 0}}, /*groups=*/ 1, grad_input_mask) : std::tuple()"
+
+- name: slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, /*dilation=*/ {{1, 1, 1}}, false, /*output_padding=*/ {{0, 0, 0}}, 1, grad_input_mask) : std::tuple()"
+
+- name: slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()"
+
+- name: slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()"
+
+- name: col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+  self: im2col(grad, kernel_size, dilation, padding, stride)
+  result: auto_linear
+
+- name: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor
+  self: col2im_symint(grad, {self.sym_size(-2), self.sym_size(-1)}, kernel_size, dilation, padding, stride)
+  result: auto_linear
+
+- name: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor
+  grad_output: _adaptive_avg_pool2d_symint(grad, {grad_output.sym_size(-2), grad_output.sym_size(-1)})
+  self: zeros_like(self)
+  result: _adaptive_avg_pool2d_backward(grad_output_t, self_p)
+
+- name: _adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor
+  grad_output: _adaptive_avg_pool3d_symint(grad, { grad_output.sym_size(-3), grad_output.sym_size(-2), grad_output.sym_size(-1) })
+  self: zeros_like(self)
+  result: _adaptive_avg_pool3d_backward(grad_output_t, self_p)
+
+- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 2)
+  self: zeros_like(self)
+  result: auto_linear
+
+- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 3)
+  self: zeros_like(self)
+  result: auto_linear
+
+- name: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+  grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+  self: zeros_like(self)
+  result: avg_pool2d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+
+- name: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
+  grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+  self: zeros_like(self)
+  result: avg_pool3d_backward(grad_output_t, self_p, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
+
+- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor
+  grad_output: elu_backward(grad, alpha, scale, input_scale, is_result, self_or_result)
+  self_or_result: elu_double_backward(grad, grad_output, alpha, scale, input_scale, is_result, self_or_result)
+  result: elu_backward(grad_output_t, alpha, scale, input_scale, is_result, self_or_result_p) + elu_double_backward(self_or_result_t, grad_output_p, alpha, scale, input_scale, is_result, self_or_result_p)
+
+- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 2)
+  self: zeros_like(self)
+  result: auto_linear
+
+- name: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 3)
+  self: zeros_like(self)
+  result: auto_linear
+
+- name: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor
+  grad_output: glu_double_backward_grad_output(grad, self, dim)
+  self: glu_double_backward(grad, grad_output, self, dim)
+  result: glu_backward_jvp(result, grad_output_p, self_p, grad_output_t, self_t, dim)
+
+- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor
+  grad_output: hardtanh_backward(grad, self, min_val, max_val)
+  self: zeros_like(grad)
+  result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
+
+- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
+  grad_output: log_sigmoid_backward(grad, self, buffer)
+  self: log_sigmoid_double_backward(grad * grad_output, self)
+  result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p)
+
+- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+  grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
+  output: (-grad_output.sum(dim, true) * output.exp() * grad.to(output.dtype())).to(output.dtype())
+
+- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor
+  # self_is_result is always false here, since the double backward call is an out-of-place call and self is the input itself
+  grad_output: leaky_relu_backward(grad, self, negative_slope, false)
+  self: zeros_like(grad)
+  # leaky_relu_backward(grad_output, self, negative_slope, false)
+  # computes grad_output * at::where(self_p > 0, 1, negative_slope)
+  # so the jvp formula is the following:
+  # grad_output_t * at::where(self_p > 0, self_p.new_ones([]), negative_slope);
+  #
+  # leaky_relu_backward(grad_output, result, negative_slope, true)
+  # computes grad_output * at::where(result > 0, 1, negative_slope)
+  # under the assumption that `negative_slope` is positive (otherwise,
+  # it is not possible to compute the gradient).
+  #
+  # so the jvp formula is the following:
+  # grad_output_t * at::where(result_p > 0, result_p.new_ones([]), negative_slope);
+  # with the assumption that negative_slope is positive.
+  #
+  # Combining the two cases gives the following optimized kernel, which also checks
+  # the assumption that negative_slope is positive when self_is_result is true:
+  result: leaky_relu_backward(grad_output_t, self_p, negative_slope, self_is_result)
+
+# This derivative is MPS-only, and `error_for_max_pool2d_double_backward` just raises an error.
+- name: max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  grad_output: error_for_max_pool2d_double_backward()
+  self: zeros_like(self)
+  result: auto_linear
+
+- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 2)
+  self: zeros_like(self)
+  indices: non_differentiable
+  result: auto_linear
+
+- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor
+  grad_output: max_pool_double_backward(grad, indices, 3)
+  self: zeros_like(self)
+  indices: non_differentiable
+  result: auto_linear
+
+- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+  grad_output: mse_loss_backward(grad, self, target, reduction)
+  self: mse_loss_double_backward(grad * grad_output, self, reduction)
+  target: -mse_loss_double_backward(grad * grad_output, target, reduction)
+  result: "  mse_loss_double_backward(self_t * grad_output_p, self_p, reduction)
+           - mse_loss_double_backward(target_t * grad_output_p, target_p, reduction)
+           + mse_loss_backward(grad_output_t, self_p, target_p, reduction)
+          "
+
+- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+  grad_output: nll_loss_symint(grad, target, weight, reduction, ignore_index)
+  self: zeros_like(grad)
+  target: non_differentiable
+
+- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor
+  grad_output: nll_loss2d_symint(grad, target, weight, reduction, ignore_index)
+  self: zeros_like(grad)
+  target: non_differentiable
+
+- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor
+  # self_is_result is always false here, since the double backward call is an out-of-place call and self is the input itself
+  grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training, false)
+  self: zeros_like(grad)
+  result: rrelu_with_noise_backward(grad_output_t, self_p, noise, lower, upper, training, false)
+
+- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+  grad_output: reflection_pad1d_symint(grad, padding)
+  self: zeros_like(self)
+  result: reflection_pad1d_backward_symint(grad_output_t, self_p, padding)
+
+- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+  grad_output: reflection_pad2d_symint(grad, padding)
+  self: zeros_like(self)
+  result: reflection_pad2d_backward_symint(grad_output_t, self_p, padding)
+
+- name: reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+  grad_output: reflection_pad3d_symint(grad, padding)
+  self: zeros_like(self)
+  result: reflection_pad3d_backward_symint(grad_output_t, self_p, padding)
+
+- name: replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor
+  grad_output: replication_pad1d_symint(grad, padding)
+  self: zeros_like(self)
+  result: replication_pad1d_backward_symint(grad_output_t, self_p, padding)
+
+- name: replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor
+  grad_output: replication_pad2d_symint(grad, padding)
+  self: zeros_like(self)
+  result: replication_pad2d_backward_symint(grad_output_t, self_p, padding)
+
+- name: replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor
+  grad_output: replication_pad3d_symint(grad, padding)
+  self: zeros_like(self)
+  result: replication_pad3d_backward_symint(grad_output_t, self_p, padding)
+
+- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
+  self, mat1, mat2: "sparse_sampled_addmm_backward(grad,
+                                                   self,
+                                                   wrap_opt_if(mat1, grad_input_mask[2]),
+                                                   wrap_opt_if(mat2, grad_input_mask[1]),
+                                                   alpha, beta, grad_input_mask)"
+
+- name: _sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)
+  output_differentiability: [True, False]
+  self, other: "grad.defined() ? _sparse_mm_reduce_impl_backward(self, grad, other, reduce, result1, grad_input_mask) :  std::tuple()"
+
+- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor
+  grad_output: smooth_l1_loss_backward(grad, self, target, reduction, beta)
+  self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
+  target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
+  result: "  smooth_l1_loss_double_backward(self_t * grad_output_p, self_p, target_p, reduction, beta)
+           - smooth_l1_loss_double_backward(target_t * grad_output_p, self_p, target_p, reduction, beta)
+           + smooth_l1_loss_backward(grad_output_t, self_p, target_p, reduction, beta)
+          "
+
+- name: huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor
+  grad_output: huber_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, delta)
+  self: huber_loss_double_backward(grad * grad_output, self, target, reduction, delta)
+  target: -huber_loss_double_backward(grad * grad_output, self, target, reduction, delta)
+
+- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor
+  grad_output: softplus_backward(grad, self, beta, threshold)
+  self: softplus_double_backward(grad * grad_output, self, beta, threshold)
+  result: "softplus_backward(grad_output_t, self_p, beta, threshold)
+         + softplus_double_backward(self_t * grad_output_p, self_p, beta, threshold)"
+
+- name: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
+  grad_output: _softmax_backward_data(grad.to(output.dtype()), output, dim, input_dtype)
+  output: softmax_double_backward(grad.to(output.dtype()), grad_output, dim, output).to(output.dtype())
+
+- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
+  grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
+  self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction)
+
+- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor
+  grad_output: softshrink_backward(grad, self, lambd)
+  self: zeros_like(grad)
+  result: at::where((self_p > lambd).logical_or(self_p < -lambd), grad_output_t, at::zeros({}, result.options()).expand_as(result))
+
+- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor
+  grad_output: threshold_backward(grad, self, threshold)
+  self: zeros_like(grad)
+  result: zeros_like(self_t) + threshold_backward(grad_output_t, self_p, threshold)
+
+- name: upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor
+  grad_output: upsample_linear1d_symint(grad, output_size, align_corners, scales)
+  result: auto_linear
+
+- name: upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: upsample_bilinear2d_symint(grad, output_size, align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: _upsample_bilinear2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: upsample_bicubic2d_symint(grad, output_size, align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+  grad_output: upsample_nearest1d_symint(grad, output_size, scales)
+  result: auto_linear
+
+- name: _upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor
+  grad_output: _upsample_nearest_exact1d_symint(grad, output_size, scales)
+  result: auto_linear
+
+- name: upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: upsample_nearest2d_symint(grad, output_size, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: _upsample_nearest_exact2d_symint(grad, output_size, scales_h, scales_w)
+  result: auto_linear
+
+- name: upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: upsample_nearest3d_symint(grad, output_size, scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: _upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor
+  grad_output: _upsample_nearest_exact3d_symint(grad, output_size, scales_d, scales_h, scales_w)
+  result: auto_linear
+
+- name: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor
+  grad_output: sigmoid_backward(grad, output.conj())
+  output: grad.conj() * grad_output * (-2 * output.conj() + 1)
+  result: sigmoid_backward(grad_output_t, output_p) + output_t.conj() * grad_output_p * (-2 * output_p.conj() + 1)
+
+- name: tanh_backward(Tensor grad_output, Tensor output) -> Tensor
+  grad_output: tanh_backward(grad, output.conj())
+  output: grad.conj() * (-2 * output.conj() * grad_output)
+  result: tanh_backward(grad_output_t, output_p) + output_t.conj() * (-2 * output_p.conj() * grad_output_p)
+
+# cudnn
+- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+  log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
+
+- name: _cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)
+  log_probs: _cudnn_ctc_loss_backward(grad, result0, result1, zero_infinity)
+
+- name: cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+  self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, output_padding, stride, dilation, true, groups, {grad_input_mask[0], grad_input_mask[1]})"
+
+- name: _mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  self, weight: "grad.defined() ? mps_convolution_transpose_backward_symint(self, grad, weight, padding, output_padding, stride, dilation, groups, grad_input_mask) : std::tuple()"
+
+- name: cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor
+  self, weight: "_cudnn_convolution_backward(self, grad, weight, padding, std::vector(padding.size(), 0), stride, dilation, false, groups, {grad_input_mask[0], grad_input_mask[1]})"
+
+- name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output
+  self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple()"
+
+- name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid
+  theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W)
+
+# NB: Why is the backward here so complicated?  CuDNN cannot be used to compute
+# the backward in evaluation mode, because the math for the backward in evaluation
+# mode is different (since the forward math is different), and CuDNN does not
+# support it.  In any case, you shouldn't be using this batch norm in evaluation
+# mode, because it should be merged into the preceding convolution (left for
+# future work).
+# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
+- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? (training ? cudnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon, retain_variables ? result3.clone() : result3) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)
+
+# HACK: save_mean and save_var are going to be passed in as
+# requires_grad variables (even though we'll never backprop through
+# them) so we need to prevent the unpacking from triggering an error.
+- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
+  save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
+  save_var: not_implemented("cudnn_batch_norm_backward save_var")
+  reserveSpace: not_implemented("cudnn_batch_norm_backward reserveSpace")
+  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
+
+# nnpack
+
+- name: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor
+  # NNPACK does not support strided convolutions in the backward path, which is why we use the closest available function that does support them here.
+  input, weight, bias: "grad.defined() ? convolution_backward_symint(grad, input, weight, bias->sym_sizes(), stride, padding, std::vector(padding.size(), 1), false, std::vector(padding.size(), 0), 1, grad_input_mask) : std::tuple()"
+
+# LSTM MPS
+- name: _lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+  output_differentiability: [True, True, True, False, False, False]
+  input, hx, params: "lstm_mps_backward(grads[0], grads[1], grads[2], result3, result4, input, result5, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first)"
+
+- name: lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])
+
+# Only the first three _cudnn_rnn outputs can have gradients.
+# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf)
+- name: _cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  dropout_state: non_differentiable
+  output_differentiability: [True, True, True, False, False]
+  input, hx, cx, weight: "_cudnn_rnn_backward_symint(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
+
+- name: _cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+  dropout_state: non_differentiable
+  input: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  weight: not_implemented_list("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  hx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  cx: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  grad_output: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  grad_hy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+  grad_cy: not_implemented("_cudnn_rnn_backward", kCudnnDoubleBackwardMsg)
+
+# miopen
+
+- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, true, output_padding, groups, grad_input_mask) : std::tuple()"
+
+- name: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()"
+
+- name: miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, false, std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()"
+
+- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)
+  input, weight, bias: "grad.defined() ? (training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()"
+  result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon)
+
+- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)
+  save_mean: not_implemented("miopen_batch_norm_backward save_mean")
+  save_var: not_implemented("miopen_batch_norm_backward save_var")
+  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)
+
+- name: miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)
+  dropout_state: non_differentiable
+  output_differentiability: [True, True, True, False, False]
+  input, hx, cx, weight: "miopen_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"
+
+- name: miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])
+  dropout_state: non_differentiable
+
+- name: mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)
+  output_differentiability: [True, True, True, False]
+  input, weight0, weight1, weight2, weight3, hx_, cx_: "GradMode::is_enabled() ? mkldnn_rnn_layer_differentiable_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3) : mkldnn_rnn_layer_backward(input, weight0, weight1, weight2, weight3, hx_, cx_, result0, result1, result2, grads[0], grads[1], grads[2], reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, result3)"
+
+- name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)
+
+# mkldnn
+- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor
+  self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()"
+
+- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
+  self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
+
+- name: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor
+  self: mkldnn_max_pool2d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode)
+
+- name: mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor
+  self: mkldnn_max_pool3d_backward(grad, result, self, kernel_size, stride, padding, dilation, ceil_mode)
+
+- name: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor
+  self: mkldnn_adaptive_avg_pool2d_backward(grad, self)
+
+- name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor
+  self: grad.reshape_symint(self.sym_sizes())
+
+# NestedTensor
+- name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  list: "grad.defined()? at::unbind(grad) : std::vector(list.size())"
+
+- name: _nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor
+  t: grad.to_padded_tensor_symint(0, t.sym_sizes())
+  mask: non_differentiable
+
+- name: _nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor
+  padded: _nested_from_padded_backward(grad, padded, fuse_transform_0213)
+  cpu_nested_shape_example: non_differentiable
+
+- name: to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor
+  self: "self.layout() == c10::kJagged ? at::_nested_from_padded_tensor_symint(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt, std::optional(at::_nested_get_values(self).sym_size(0))) : at::_nested_from_padded(grad, self._nested_tensor_size())"
+  padding: non_differentiable
+
+- name: _nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor
+  padded: grad.to_padded_tensor_symint(0.0, at::OptionalArrayRef<c10::SymInt>(padded.sym_sizes()))
+  offsets: non_differentiable
+  dummy: non_differentiable
+
+- name: _nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)
+  self: grad.values()
+  nested_size: non_differentiable
+  nested_strides: non_differentiable
+
+- name: _nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)
+  self: grad.values()
+  offsets: non_differentiable
+  lengths: non_differentiable
+  dummy: non_differentiable
+  min_seqlen: non_differentiable
+  max_seqlen: non_differentiable
+
+- name: _nested_get_values(Tensor(a) self) -> Tensor(a)
+  self: "_nested_view_from_jagged(grad, at::_nested_get_offsets(self), at::_nested_get_jagged_dummy(self), at::_nested_get_lengths(self), at::_nested_get_ragged_idx(self), at::_nested_get_min_seqlen(self).defined() ? std::optional(at::_nested_get_min_seqlen(self)) : ::std::nullopt, at::_nested_get_max_seqlen(self).defined() ? std::optional(at::_nested_get_max_seqlen(self)) : ::std::nullopt)"
+
+# Transformer
+- name: _safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor
+  self: _softmax_backward_data(grad, result, dim, self.scalar_type())
+  result: result * (self_t - safe_logsumexp_jvp(self_p, self_t, {dim}, true))
+
+- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+  output_differentiability: [True, False, False, False]
+  query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
+
+- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False, False, False, False, False]
+  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale)
+
+- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
+  output_differentiability: [True, False]
+  query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)
+
+- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False]
+  query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right)
+
+- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
+  output_differentiability: [True, False, False, False, False, False]
+  query, key, value, bias: _efficient_attention_backward_symint(grad, query, key, value, bias, output, cu_seqlens_q, cu_seqlens_k, max_seqlen_batch_q, max_seqlen_batch_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias.requires_grad(), scale)
+
+- name: _scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False, False, False, False, False]
+  query, key, value: _scaled_dot_product_cudnn_attention_backward_symint(grad, query, key, value, output, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale)
+
+- name: _scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+  output_differentiability: [True, False, False, False, False, False, False, False, False]
+  query, key, value, attn_bias: _scaled_dot_product_fused_attention_overrideable_backward_symint(grad, query, key, value, attn_bias, grad_input_mask, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
+
+# fft
+- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
+  self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back()))
+  result: auto_linear
+
+- name: _fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
+  self: fft_c2r_backward(grad, dim, normalization)
+  result: auto_linear
+
+- name: _fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor
+  self: _fft_c2c_symint(grad, dim, normalization, !forward)
+  result: auto_linear
+
+- name: unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]
+  dispatch:
+    Default:
+      self: unbind_backward(grads, dim)
+      result: auto_linear
+    AutogradNestedTensor:
+      self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())"
+      result: auto_linear
+
+- name: stack(Tensor[] tensors, int dim=0) -> Tensor
+  tensors: stack_tensors_backward(grad, dim, to_args_scalartypes(tensors))
+  result: stack_jvp(tensors, dim)
+
+# fused RNN kernels
+
+# Only the first two _thnn_fused_lstm_cell outputs can have gradients.
+# _thnn_fused_lstm_cell outputs: (hy, cy, workspace)
+- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)
+  output_differentiability: [True, True, False]
+  input_gates, hidden_gates, cx, input_bias, hidden_bias: "GradMode::is_enabled() ? _thnn_differentiable_lstm_cell_backward(grads[0], grads[1], input_gates, hidden_gates, input_bias, hidden_bias, cx, result1) : _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())"
+
+- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)
+  input_gates, hidden_gates, hx, input_bias, hidden_bias: "grad.defined() ? (GradMode::is_enabled() ? _thnn_differentiable_gru_cell_backward(grad, input_gates, hidden_gates, hx, input_bias, hidden_bias) : _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())) : std::tuple()"
+
+# PackedSequence helpers
+- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
+  input: _pack_padded_sequence_backward_symint(grad, input.sym_sizes(), result1, batch_first)
+
+# TH wrappers
+- name: eq.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: eq.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: ge.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: ge.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: gt.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: gt.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: le.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: le.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: lt.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: lt.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: ne.Scalar(Tensor self, Scalar other) -> Tensor
+  output_differentiability: [False]
+
+- name: ne.Tensor(Tensor self, Tensor other) -> Tensor
+  output_differentiability: [False]
+
+- name: multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
+  output_differentiability: [False]
+
+- name: nonzero(Tensor self) -> Tensor
+  output_differentiability: [False]
+
+- name: segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor
+  data: _segment_reduce_backward(grad, result, data, reduce, lengths, offsets, axis, initial)
+
+- name: _pin_memory(Tensor self, Device? device=None) -> Tensor
+  self: grad
+
+- name: _new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor
+  self: non_differentiable
+  other: non_differentiable
+  output_differentiability: [False]
+
+- name: _test_warn_in_autograd(Tensor self) -> Tensor
+  self: warn_backwards(grad)
+
+- name: _test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor
+  dispatch:
+    Default:
+      self: grad.expand_symint(self.sym_sizes()) + 1
+      result: auto_linear
+    AutogradNestedTensor:
+      self: grad.mul(grad)
+    AutogradCUDA:
+      self: grad.expand_symint(self.sym_sizes()) * 2
+
+- name: _test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor
+  dispatch:
+    AutogradNestedTensor:
+      self: grad.mul(grad).add(grad)
+
+- name: _test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)
+  dispatch:
+    Default:
+      self: grad.reshape_as(self)
+    AutogradCUDA:
+      self: grad.reshape_as(self) + 1
+
+- name: _efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  output_differentiability: [False]
+
+- name: scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor
+  self, src: scatter_reduce_backward(grad, self, dim, index, src, reduce, include_self, result)
+  index: non_differentiable
+  result: scatter_reduce_jvp(self_p, self_t, dim, index, src_p, src_t, reduce, include_self, result)
+
+- name: special_airy_ai(Tensor x) -> Tensor
+  x: non_differentiable
+
+- name: special_bessel_j0(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_bessel_j1(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_bessel_y0(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_bessel_y1(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_modified_bessel_i0(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_modified_bessel_i1(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_modified_bessel_k0(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_modified_bessel_k1(Tensor self) -> Tensor
+  self: non_differentiable
+
+- name: special_scaled_modified_bessel_k0(Tensor x) -> Tensor
+  x: non_differentiable
+
+- name: special_scaled_modified_bessel_k1(Tensor x) -> Tensor
+  x: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor
+  x: non_differentiable
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor
+  n: non_differentiable
+
+- name: special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor
+  x: non_differentiable
+
+- name: special_spherical_bessel_j0(Tensor x) -> Tensor
+  x: non_differentiable
+
+- name: _reshape_copy(Tensor self, SymInt[] size) -> Tensor
+  self: grad.reshape_symint(self.sym_sizes())
+  result: auto_linear
+
+# note(crcrpar): the `torchgen/api/autograd` logic would unintentionally replace the substrings `self` and `other` within function names.
+- name: _foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]
+  self: div_tensor_self_backward(grads[i], other[i], self[i].scalar_type())
+  other: div_tensor_other_backward(grads[i], self[i], other[i])
+  result: (self_t - other_t * result[i]) / other_p
+
+- name: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
+  self: pow_backward_self(grads[i], self[i], exponent[i])
+  exponent: pow_backward_exponent(grads[i], self[i], exponent[i], result[i])
+  result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result[i])).conj()
+
+- name: _foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]
+  self: pow_backward(grads[i], self[i], exponent[i])
+  result: pow_backward(self_t.conj(), self_p, exponent[i]).conj()
+
+- name: _foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]
+  exponent: pow_backward_exponent(grads[i], self, exponent[i], result[i])
+
+# note(crcrpar): the following definitions seem necessary because the reference native functions
+# for `maximum` and `minimum` don't have an overload that takes a Scalar as their second argument.
+- name: _foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] > scalar, 0)
+  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p < scalar).to(result[i].scalar_type())) * (self_t - scalar)
+
+- name: _foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] > scalars[i], 0)
+  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p < scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
+
+- name: _foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
+  self: at::where(self[i] == scalar, grads[i] / 2, grads[i]).masked_fill_(self[i] < scalar, 0)
+  result: scalar + at::where(self_p == scalar, at::scalar_tensor(0.5, result[i].options()), (self_p > scalar).to(result[i].scalar_type())) * (self_t - scalar)
+
+- name: _foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
+  self: at::where(self[i] == scalars[i], grads[i] / 2, grads[i]).masked_fill_(self[i] < scalars[i], 0)
+  result: scalars[i] + at::where(self_p == scalars[i], at::scalar_tensor(0.5, result[i].options()), (self_p > scalars[i]).to(result[i].scalar_type())) * (self_t - scalars[i])
+
+# note(crcrpar): forward-mode AD is tricky for a simple string replace to handle:
+#   formula.replace("p", "ord") produces `norm_jvord(self_ord, self_t, ord, result)`
+- name: _foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]
+  self: norm_backward(grads[i], self[i], ord, result[i])
+  result: norm_jvp(self_p, self_t, ord, result[i])
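Each entry in the derivatives.yaml hunk above follows the same shape: the `name:` key holds a native function schema, the remaining plain keys map each differentiable input to the C++ expression used as its backward formula, `result:` gives the forward-mode (JVP) formula, and `output_differentiability:` / `dispatch:` carry differentiability masks and per-dispatch-key variants. A minimal sketch of reading that structure with PyYAML follows; the sample entry is copied from the hunk above, the snippet assumes PyYAML is installed, and it is only an illustration of the file layout, not part of the packaged sources.

import yaml

# A derivatives.yaml-style entry, copied from the hunk above.
sample = """
- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor
  grad_output: softplus_backward(grad, self, beta, threshold)
  self: softplus_double_backward(grad * grad_output, self, beta, threshold)
"""

for entry in yaml.safe_load(sample):
    schema = entry.pop("name")          # the native function schema string
    print(schema)
    for arg, formula in entry.items():  # one C++ gradient formula per differentiable input
        print(f"  grad wrt {arg}: {formula}")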
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8cf8ad33079e5728274c19a00f83c8f959a63e5
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_annotated_fn_args.py
@@ -0,0 +1,134 @@
+"""
+For procedural tests needed for __torch_function__, we use this function
+to export method names and signatures as needed by the tests in
+test/test_overrides.py.
+
+python -m tools.autograd.gen_annotated_fn_args \
+       aten/src/ATen/native/native_functions.yaml \
+       aten/src/ATen/native/tags.yaml \
+       $OUTPUT_DIR \
+       tools/autograd
+
+Where $OUTPUT_DIR is where you would like the files to be
+generated.  In the full build system, OUTPUT_DIR is
+torch/testing/_internal/generated
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import textwrap
+from collections import defaultdict
+from typing import Any, TYPE_CHECKING
+
+import torchgen.api.python as python
+from torchgen.context import with_native_function
+from torchgen.gen import parse_native_yaml
+from torchgen.utils import FileManager
+
+from .gen_python_functions import (
+    is_py_fft_function,
+    is_py_linalg_function,
+    is_py_nn_function,
+    is_py_special_function,
+    is_py_torch_function,
+    is_py_variable_method,
+    should_generate_py_binding,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from torchgen.model import Argument, BaseOperatorName, NativeFunction
+
+
+def gen_annotated(
+    native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
+) -> None:
+    native_functions = parse_native_yaml(
+        native_yaml_path, tags_yaml_path
+    ).native_functions
+    mappings = (
+        (is_py_torch_function, "torch._C._VariableFunctions"),
+        (is_py_nn_function, "torch._C._nn"),
+        (is_py_linalg_function, "torch._C._linalg"),
+        (is_py_special_function, "torch._C._special"),
+        (is_py_fft_function, "torch._C._fft"),
+        (is_py_variable_method, "torch.Tensor"),
+    )
+    annotated_args: list[str] = []
+    for pred, namespace in mappings:
+        groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
+        for f in native_functions:
+            if not should_generate_py_binding(f) or not pred(f):
+                continue
+            groups[f.func.name.name].append(f)
+        for group in groups.values():
+            for f in group:
+                annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")
+
+    template_path = os.path.join(autograd_dir, "templates")
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    fm.write_with_template(
+        "annotated_fn_args.py",
+        "annotated_fn_args.py.in",
+        lambda: {
+            "annotated_args": textwrap.indent("\n".join(annotated_args), "    "),
+        },
+    )
+
+
+@with_native_function
+def gen_annotated_args(f: NativeFunction) -> str:
+    def _get_kwargs_func_exclusion_list() -> list[str]:
+        # functions that currently don't work with kwargs in test_overrides.py
+        return [
+            "diagonal",
+            "round_",
+            "round",
+            "scatter_",
+        ]
+
+    def _add_out_arg(
+        out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
+    ) -> None:
+        for arg in args:
+            if arg.default is not None:
+                continue
+            out_arg: dict[str, Any] = {}
+            out_arg["is_kwarg_only"] = str(is_kwarg_only)
+            out_arg["name"] = arg.name
+            out_arg["simple_type"] = python.argument_type_str(
+                arg.type, simple_type=True
+            )
+            size_t = python.argument_type_size(arg.type)
+            if size_t:
+                out_arg["size"] = size_t
+            out_args.append(out_arg)
+
+    out_args: list[dict[str, Any]] = []
+    _add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
+    if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
+        _add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
+
+    return f"{f.func.name.name}: {repr(out_args)},"
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
+    parser.add_argument(
+        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
+    )
+    parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
+    parser.add_argument("out", metavar="OUT", help="path to output directory")
+    parser.add_argument(
+        "autograd", metavar="AUTOGRAD", help="path to template directory"
+    )
+    args = parser.parse_args()
+    gen_annotated(args.native_functions, args.tags, args.out, args.autograd)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd.py
new file mode 100644
index 0000000000000000000000000000000000000000..cee788aa9c825c4ed2916254726bd3e35c0d55fa
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd.py
@@ -0,0 +1,147 @@
+"""
+To run this file by hand from the root of the PyTorch
+repository, run:
+
+python -m tools.autograd.gen_autograd \
+       aten/src/ATen/native/native_functions.yaml \
+       aten/src/ATen/native/tags.yaml \
+       $OUTPUT_DIR \
+       tools/autograd
+
+Where $OUTPUT_DIR is where you would like the files to be
+generated.  In the full build system, OUTPUT_DIR is
+torch/csrc/autograd/generated/
+"""
+
+# gen_autograd.py generates C++ autograd functions and Python bindings.
+#
+# It delegates to the following scripts:
+#
+#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
+#  gen_variable_type.py: generates VariableType.h which contains all tensor methods
+#  gen_python_functions.py: generates Python bindings to THPVariable
+#
+
+from __future__ import annotations
+
+import argparse
+import os
+
+from torchgen.api import cpp
+from torchgen.api.autograd import (
+    match_differentiability_info,
+    NativeFunctionWithDifferentiabilityInfo,
+)
+from torchgen.gen import parse_native_yaml
+from torchgen.selective_build.selector import SelectiveBuilder
+
+from . import gen_python_functions
+from .gen_autograd_functions import (
+    gen_autograd_functions_lib,
+    gen_autograd_functions_python,
+)
+from .gen_inplace_or_view_type import gen_inplace_or_view_type
+from .gen_trace_type import gen_trace_type
+from .gen_variable_factories import gen_variable_factories
+from .gen_variable_type import gen_variable_type
+from .gen_view_funcs import gen_view_funcs
+from .load_derivatives import load_derivatives
+
+
+def gen_autograd(
+    native_functions_path: str,
+    tags_path: str,
+    out: str,
+    autograd_dir: str,
+    operator_selector: SelectiveBuilder,
+    disable_autograd: bool = False,
+) -> None:
+    # Parse and load derivatives.yaml
+    differentiability_infos, used_dispatch_keys = load_derivatives(
+        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
+    )
+
+    template_path = os.path.join(autograd_dir, "templates")
+
+    native_funcs = parse_native_yaml(native_functions_path, tags_path).native_functions
+    fns = sorted(
+        filter(
+            operator_selector.is_native_function_selected_for_training, native_funcs
+        ),
+        key=lambda f: cpp.name(f.func),
+    )
+    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = (
+        match_differentiability_info(fns, differentiability_infos)
+    )
+
+    # Generate VariableType.h/cpp
+    if not disable_autograd:
+        gen_variable_type(
+            out,
+            native_functions_path,
+            tags_path,
+            fns_with_diff_infos,
+            template_path,
+            used_dispatch_keys,
+        )
+
+        gen_inplace_or_view_type(
+            out, native_functions_path, tags_path, fns_with_diff_infos, template_path
+        )
+
+        # operator filter not applied as tracing sources are excluded in selective build
+        gen_trace_type(out, native_funcs, template_path)
+    # Generate Functions.h/cpp
+    gen_autograd_functions_lib(out, differentiability_infos, template_path)
+
+    # Generate variable_factories.h
+    gen_variable_factories(out, native_functions_path, tags_path, template_path)
+
+    # Generate ViewFuncs.h/cpp
+    gen_view_funcs(out, fns_with_diff_infos, template_path)
+
+
+def gen_autograd_python(
+    native_functions_path: str,
+    tags_path: str,
+    out: str,
+    autograd_dir: str,
+) -> None:
+    differentiability_infos, _ = load_derivatives(
+        os.path.join(autograd_dir, "derivatives.yaml"), native_functions_path, tags_path
+    )
+
+    template_path = os.path.join(autograd_dir, "templates")
+
+    # Generate Functions.h/cpp
+    gen_autograd_functions_python(out, differentiability_infos, template_path)
+
+    # Generate Python bindings
+    deprecated_path = os.path.join(autograd_dir, "deprecated.yaml")
+    gen_python_functions.gen(
+        out, native_functions_path, tags_path, deprecated_path, template_path
+    )
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate autograd C++ files script")
+    parser.add_argument(
+        "native_functions", metavar="NATIVE", help="path to native_functions.yaml"
+    )
+    parser.add_argument("tags", metavar="NATIVE", help="path to tags.yaml")
+    parser.add_argument("out", metavar="OUT", help="path to output directory")
+    parser.add_argument(
+        "autograd", metavar="AUTOGRAD", help="path to autograd directory"
+    )
+    args = parser.parse_args()
+    gen_autograd(
+        args.native_functions,
+        args.tags,
+        args.out,
+        args.autograd,
+        SelectiveBuilder.get_nop_selector(),
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cfb494731c07d68429ad15a50ab206acb41994a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_autograd_functions.py
@@ -0,0 +1,1074 @@
+# Generates C++ autograd functions for the derivatives of ATen operations
+#
+# This writes two files:
+#  Functions.h/cpp: subclasses of autograd::Node
+#  python_functions.h/cpp: Python bindings for the above classes
+#
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from torchgen.api.autograd import (
+    Derivative,
+    DifferentiabilityInfo,
+    SavedAttribute,
+    uses_retain_variables,
+    uses_single_grad,
+)
+from torchgen.api.types import (
+    ArrayRefCType,
+    BaseCppType,
+    BaseCType,
+    Binding,
+    boolT,
+    doubleT,
+    intArrayRefT,
+    iTensorListRefT,
+    ListCType,
+    longT,
+    MutRefCType,
+    OptionalCType,
+    optionalIntArrayRefT,
+    optionalSymIntArrayRefT,
+    scalarT,
+    stringT,
+    symIntArrayRefT,
+    SymIntT,
+    TENSOR_LIST_LIKE_CTYPES,
+    tensorListT,
+    tensorT,
+    VectorCType,
+)
+from torchgen.code_template import CodeTemplate
+from torchgen.model import Argument, FunctionSchema
+from torchgen.utils import FileManager
+
+from .gen_inplace_or_view_type import VIEW_FUNCTIONS
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+FUNCTION_DECLARATION = CodeTemplate(
+    """\
+#ifdef _WIN32
+struct ${op} : public ${superclass} {
+  TORCH_API ${op}() = default;
+#else
+struct TORCH_API ${op} : public ${superclass} {
+#endif
+  using ${superclass}::${superclass};
+  variable_list apply(variable_list&& grads) override;
+  std::string name() const override { return "${op}"; }
+  void release_variables() override {
+    ${thread_lock}
+    ${release_variables}
+  }
+  ${will_release_variables}
+  void compiled_args(CompiledNodeArgs& args) const override;
+  variable_list apply_with_saved(const variable_list& inputs, SwapSavedVariables& saved) override;
+  ${saved_variables}
+  ${saved_list_sizes}
+};
+"""
+)
+
+WILL_RELEASE_VARIABLES = CodeTemplate(
+    """\
+bool retain_variables = true;
+void will_release_variables() override {
+  retain_variables = false;
+}
+"""
+)
+
+# We generate e.g. MulBackward0::apply and have that call into
+# MulBackward0_apply_functional. The apply_functional is a pure function,
+# that is, it does not rely on global state. MulBackward0::apply
+# is responsible for querying the autograd engine for which outputs should
+# be computed (needs_input_grad), applying locks,
+# and unpacking saved variables to pass to MulBackward0_apply_functional.
+#
+# needs_input_grad is a mapping from input index to if that input needs
+# gradients computed. For operators that take in List[Tensor], the List[Tensor]
+# is one element in the needs_input_grad that specifies if *any* of the
+# List[Tensor] needs input grad. In theory this could be optimized.
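+#
+# Illustrative example: for an op foo(Tensor self, Tensor other) with derivatives
+# defined for both inputs, needs_input_grad is a std::array<bool, 2> whose
+# element 0 answers "does self need grad?" and element 1 answers "does other
+# need grad?".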
+FUNCTION_DEFINITION = CodeTemplate(
+    """\
+static variable_list ${op}_apply_functional(
+  variable_list&& grads,
+  std::array<bool, ${num_inputs}> needs_input_grad${,apply_functional_args_signature})
+{
+  IndexRangeGenerator gen;
+  ${compute_index_ranges}
+  variable_list grad_inputs(gen.size());
+  ${body}
+  return grad_inputs;
+}
+inline variable_list ${op}_apply_functional_ivalue(const variable_list& grads, const ivalue_list& args)
+{
+#ifdef C10_MOBILE
+  TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile");
+#else
+  auto packed_args = PackedArgs(args);
+  auto needs_input_grad = packed_args.unpack<std::array<bool, ${num_inputs}>>();
+  ${unpack_ivalues}
+  return ${op}_apply_functional(variable_list(grads), needs_input_grad${,apply_functional_args});
+#endif
+}
+
+variable_list ${op}::apply(variable_list&& grads) {
+  ${thread_lock}
+  ${asserts}
+  ${unpacks}
+  ${compute_needs_input_grad}
+  return ${op}_apply_functional(std::move(grads), needs_input_grad${,apply_functional_args});
+}
+
+void ${op}::compiled_args(CompiledNodeArgs& args) const {
+    ${compiled_args}
+}
+variable_list ${op}::apply_with_saved(const variable_list& grads, SwapSavedVariables& saved) {
+#ifdef C10_MOBILE
+  TORCH_INTERNAL_ASSERT(false, "compiled autograd doesn't work on mobile");
+#else
+  ${apply_with_saved_before}
+
+  static bool called = false;
+  if (!called) {
+    called = true;
+    ${compute_schema}
+    const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface();
+    pyinterface->bind_function(saved.get_py_compiler(), name(), ${op}_apply_functional_ivalue, schema);
+  }
+
+  variable_list output_result;
+
+  PackedArgs packed_args;
+  ${asserts}
+  ${unpacks}
+  ${compute_needs_input_grad}
+  packed_args.pack(needs_input_grad);
+  ${get_packed_args}
+
+  output_result = compiled_autograd_apply_functional(packed_args, next_edges(), saved, grads, name());
+
+  ${apply_with_saved_after}
+  return output_result;
+#endif
+}
+
+"""
+)
+
+GRAD_INPUT_MASK = CodeTemplate(
+    """\
+  auto grad_input_mask = std::array<bool, ${n}>{
+    ${masks}
+  };
+"""
+)
+
+COMPUTE_NEEDS_INPUT_GRAD = CodeTemplate(
+    """\
+IndexRangeGenerator gen;
+${compute_index_ranges}
+auto needs_input_grad = std::array<bool, ${n}>{
+  ${masks}
+};\
+"""
+)
+
+
+DERIVATIVE_SINGLE = CodeTemplate(
+    """\
+if (needs_input_grad[/*${name}*/${idx}]) {
+  auto grad_result = ${derivative};
+  copy_range(grad_inputs, ${name}_ix, grad_result);
+}
+"""
+)
+
+# note(crcrpar): the `self` argument and the other optional positional arguments
+# of foreach functions are basically lists of n `Tensor`s, so we iterate over
+# `grads` in order to apply the existing derivative definitions to each
+# `Tensor` of `self` and of the other list arguments.
+DERIVATIVE_SINGLE_FOREACH = CodeTemplate(
+    """\
+if (needs_input_grad[/*${name}*/${idx}]) {  // ${name}
+  std::vector<Tensor> grad_result;
+  grad_result.reserve(grads.size());
+  for (const auto & i : c10::irange(grads.size())) {
+    if (grads[i].defined()) {
+      grad_result.emplace_back(${derivative});
+    } else {
+      grad_result.emplace_back(Tensor());
+    }
+  }
+  copy_range(grad_inputs, ${name}_ix, grad_result);
+}
+"""
+)
+
+DERIVATIVE_MULTI_COPY_RANGE = CodeTemplate(
+    """\
+  if (needs_input_grad[/*${name}*/${idx}]) {
+    copy_range(grad_inputs, ${name}_ix, std::get<${i}>(grad_result));
+  }
+"""
+)
+
+DERIVATIVE_MULTI = CodeTemplate(
+    """\
+if (${needs_input_grad}) {
+  ${grad_input_mask}
+  auto grad_result = ${derivative};
+  ${copy_ranges}
+}
+"""
+)
+
+# Generates python bindings
+#
+# This generates the definitions for:
+#   (1) The PyTypeObject for each backward grad_fn subclassing Node
+#   (2) The entry for PyTypeObject's tp_getset slot (an array of PyGetSetDef structs)
+#       We generate one PyGetSetDef struct for each of grad_fn's saved inputs and outputs
+#       Each PyGetSetDef has a function ptr to a getter, also defined here (3).
+#   (3) Getters for each of grad_fn's saved inputs and outputs.
+#
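+# Illustrative example: for a grad_fn MulBackward0 that saves `self`, this yields
+# a getter THPMulBackward0_self_getter exposed to Python as the attribute
+# `_saved_self` (and a raw variant `_raw_saved_self` for the SavedVariable itself).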
+PY_FUNCTION_DEFINITION = CodeTemplate(
+    """\
+static PyTypeObject ${op}Class;
+addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
+"""
+)
+
+PY_FUNCTION_PROPS_AND_GETTERS = CodeTemplate(
+    """\
+${all_getter_definitions}
+
+static struct PyGetSetDef ${op}_properties[] = {
+  THP_FUNCTION_DEFAULT_PROPERTIES,
+  ${all_getsetdef_structs}
+  {nullptr} /* sentinel */
+};
+
+"""
+)
+
+PY_GETSETDEF_STRUCT = CodeTemplate(
+    """\
+{(char*)"_saved_${name}", (getter)THP${op}_${name}_getter, nullptr, nullptr, nullptr}"""
+)
+
+PY_RAW_GETSETDEF_STRUCT = CodeTemplate(
+    """\
+{(char*)"_raw_saved_${name}", (getter)THP${op}_${name}_raw_getter, nullptr, nullptr, nullptr}"""
+)
+
+# Getter templates
+GETTER_DEFINITION = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  auto prop = static_cast<${op}*>(self->cdata.get())->${name};
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_SAVEDVAR = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_RAW_SAVEDVAR = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto& prop = static_cast<${op}*>(self->cdata.get())->${name}_;
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_VEC_SAVEDVAR = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto *node = static_cast<${op}*>(self->cdata.get());
+  const auto& prop = node->${name}_;
+  if (node->${name}_released_) {
+    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
+    return nullptr;
+  }
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_RAW_VEC_SAVEDVAR = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_raw_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto *node = static_cast<${op}*>(self->cdata.get());
+  const auto& prop = node->${name}_;
+  if (node->${name}_released_) {
+    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
+    return nullptr;
+  }
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_OPT = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
+  if (!opt_prop.has_value()) {
+    Py_RETURN_NONE;
+  }
+  auto prop = opt_prop.value();
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+GETTER_DEFINITION_OPT_ARRAYREF = CodeTemplate(
+    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  auto opt_prop = static_cast<${op}*>(self->cdata.get())->${name};
+  if (!opt_prop.list.has_value()) {
+    Py_RETURN_NONE;
+  }
+  auto prop = opt_prop.list.value();
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+"""
+)
+
+# Getter body
+GETTER_BODY_SAVEDVAR = """\
+return THPVariable_Wrap(prop.unpack(self->cdata));
+"""
+
+GETTER_BODY_RAW_SAVEDVAR = """\
+pybind11::object obj = pybind11::cast(prop, pybind11::return_value_policy::reference);
+return obj.release().ptr();
+"""
+
+GETTER_BODY_VEC_SAVEDVAR = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i: c10::irange(prop.size())) {
+  PyTuple_SetItem(tup, (Py_ssize_t) i, THPVariable_Wrap(prop[i].unpack(self->cdata)));
+}
+return tup;
+"""
+
+GETTER_BODY_RAW_VEC_SAVEDVAR = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i : c10::irange(prop.size())) {
+  pybind11::object obj = pybind11::cast(prop[i], pybind11::return_value_policy::reference);
+  PyTuple_SetItem(tup, (Py_ssize_t) i, obj.release().ptr());
+}
+return tup;
+"""
+
+GETTER_BODY_ARRAYREF_LONG = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i : c10::irange(prop.size())) {
+  PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong((uint64_t) prop[i]));
+}
+return tup;
+"""
+
+GETTER_BODY_ARRAYREF_SYMINT = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i : c10::irange(prop.size())) {
+    auto si = prop[i];
+    if (auto m = si.maybe_as_int()) {
+      PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromUnsignedLong(*m));
+    } else {
+      auto py_symint = py::cast(si).release().ptr();
+      PyTuple_SetItem(tup, (Py_ssize_t) i, py_symint);
+    }
+}
+return tup;
+"""
+
+GETTER_BODY_ARRAYREF_DOUBLE = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i : c10::irange(prop.size())) {
+  PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble((double) prop[i]));
+}
+return tup;
+"""
+
+GETTER_BODY_INT64_T = """\
+return PyLong_FromUnsignedLong((int64_t) prop);
+"""
+
+GETTER_BODY_SYMINT = """\
+if (auto m = prop.maybe_as_int()) {
+  return PyLong_FromUnsignedLong(*m);
+} else {
+  return py::cast(prop).release().ptr();
+}
+"""
+
+GETTER_BODY_DOUBLE = """\
+return PyFloat_FromDouble((double) prop);
+"""
+
+GETTER_BODY_BOOL = """\
+if (prop) {
+  Py_RETURN_TRUE;
+} else {
+  Py_RETURN_FALSE;
+}
+"""
+
+GETTER_BODY_STRING = """\
+return PyUnicode_FromStringAndSize(prop.data(), prop.size());
+"""
+
+GETTER_BODY_SCALAR = """\
+if (prop.isComplex()) {
+  auto cprop = prop.to<c10::complex<double>>();
+  return PyComplex_FromDoubles(cprop.real(), cprop.imag());
+} else if (prop.isFloatingPoint()) {
+  return PyFloat_FromDouble(prop.to<double>());
+} else if (prop.isIntegral(/*includeBool=*/false)) {
+  return PyLong_FromLong(prop.to<int64_t>());
+} else if (prop.isBoolean()) {
+  if (prop.to<bool>()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+} else {
+  PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
+  return nullptr;
+}
+"""
+
+
+GETTER_BODY_VEC_SCALAR = """\
+PyObject* tup = PyTuple_New((Py_ssize_t) prop.size());
+for (auto i: c10::irange(prop.size())) {
+  if (prop[i].isComplex()) {
+    auto cprop = prop[i].to<c10::complex<double>>();
+    PyTuple_SetItem(tup, (Py_ssize_t) i, PyComplex_FromDoubles(cprop.real(), cprop.imag()));
+  } else if (prop[i].isFloatingPoint()) {
+    auto double_prop = prop[i].to<double>();
+    PyTuple_SetItem(tup, (Py_ssize_t) i, PyFloat_FromDouble(double_prop));
+  } else if (prop[i].isIntegral(/*includeBool=*/false)) {
+    auto long_prop = prop[i].to<int64_t>();
+    PyTuple_SetItem(tup, (Py_ssize_t) i, PyLong_FromLong(long_prop));
+  } else if (prop[i].isBoolean()) {
+    if (prop[i].to<bool>()) {
+      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_True);
+    } else {
+      PyTuple_SetItem(tup, (Py_ssize_t) i, Py_False);
+    }
+  } else {
+    PyErr_SetString(PyExc_RuntimeError, "Unknown scalar type");
+    return nullptr;
+  }
+}
+return tup;
+"""
+
+
+MISC_GETTER_DEFS = {
+    OptionalCType(BaseCType(longT)): (GETTER_DEFINITION_OPT, GETTER_BODY_INT64_T),
+    OptionalCType(BaseCType(SymIntT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SYMINT),
+    BaseCType(doubleT): (GETTER_DEFINITION, GETTER_BODY_DOUBLE),
+    OptionalCType(BaseCType(doubleT)): (GETTER_DEFINITION_OPT, GETTER_BODY_DOUBLE),
+    BaseCType(boolT): (GETTER_DEFINITION, GETTER_BODY_BOOL),
+    BaseCType(scalarT): (GETTER_DEFINITION, GETTER_BODY_SCALAR),
+    OptionalCType(BaseCType(scalarT)): (GETTER_DEFINITION_OPT, GETTER_BODY_SCALAR),
+}
+
+# These functions have backwards which cannot be traced, and so must have
+# their backward functions traced opaquely.
+# VIEW_FUNCTIONS are not traceable because they use as_strided, which
+# has an untraceable backwards, see
+# https://github.com/pytorch/pytorch/issues/4250
+# TODO: This is probably not exhaustive, but it's a start
+UNTRACEABLE_FUNCTIONS = VIEW_FUNCTIONS
+
+
+def get_infos_with_derivatives_list(
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+) -> list[DifferentiabilityInfo]:
+    diff_info_list = [
+        info
+        for diffinfo_dict in differentiability_infos.values()
+        for info in diffinfo_dict.values()
+    ]
+
+    return list(filter(lambda info: info.args_with_derivatives, diff_info_list))
+
+
+def gen_autograd_functions_lib(
+    out: str,
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+    template_path: str,
+) -> None:
+    """Functions.h and Functions.cpp body
+
+    These contain the auto-generated subclasses of torch::autograd::Node
+    for every differentiable torch function.
+    """
+
+    # get a flat list of diff infos; we do not need them keyed per FunctionSchema/DispatchKey here.
+    # infos with different dispatch keys but the same name will still be in the same shard.
+    infos = get_infos_with_derivatives_list(differentiability_infos)
+    declarations = [process_function(f, FUNCTION_DECLARATION) for f in infos]
+    definitions = [process_function(f, FUNCTION_DEFINITION) for f in infos]
+
+    file_basename = "Functions"
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    for suffix in [".h", ".cpp"]:
+        fname = file_basename + suffix
+        fm.write_with_template(
+            fname,
+            fname,
+            lambda: {
+                "generated_comment": "@"
+                + f"generated from {fm.template_dir_for_comments()}/{fname}",
+                "autograd_function_declarations": declarations,
+                "autograd_function_definitions": definitions,
+            },
+        )
+
+
+def gen_autograd_functions_python(
+    out: str,
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+    template_path: str,
+) -> None:
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    num_shards = 5
+    fm.write(
+        "python_functions.h",
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
+            "shard_forward_declare": [
+                f"void initialize_autogenerated_functions_{i}(PyObject* module);"
+                for i in range(num_shards)
+            ],
+            "shard_call": [
+                f"initialize_autogenerated_functions_{i}(module);"
+                for i in range(num_shards)
+            ],
+        },
+    )
+
+    # get a flat list of diff infos; we do not need them keyed per FunctionSchema/DispatchKey here.
+    # infos with different dispatch keys but the same name will still be in the same shard.
+    infos = get_infos_with_derivatives_list(differentiability_infos)
+    fm.write_sharded(
+        "python_functions.cpp",
+        infos,
+        key_fn=lambda info: info.name,
+        base_env={
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/python_functions.cpp",
+        },
+        env_callable=lambda info: {
+            "py_function_initializers": [
+                process_function(info, PY_FUNCTION_DEFINITION)
+            ],
+            "py_function_props_and_getters": [
+                process_function(info, PY_FUNCTION_PROPS_AND_GETTERS)
+            ],
+        },
+        num_shards=num_shards,
+        sharded_keys={"py_function_initializers", "py_function_props_and_getters"},
+    )
+
+
+def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str:
+    saved_variables: list[str] = []
+    release_variables: list[str] = []
+    saved_list_sizes: list[str] = []
+    unpack: list[str] = []
+    asserts: list[str] = []
+    compute_index_ranges: list[str] = []
+    getter_definitions: list[str] = []
+    py_getsetdef_structs: list[str] = []
+    compiled_args: list[str] = []
+    apply_with_saved_before: list[str] = []
+    apply_with_saved_after: list[str] = []
+    apply_functional_args: list[str] = []
+    apply_functional_args_ref_types: list[str] = []
+    # Maps the name of an input (to the original forward operator;
+    # examples are "self", "other") to the order in which they appear in the
+    # operator.
+    # For example, if the operator is foo(Tensor self, int64_t k, Tensor other),
+    # the mapping is: {"self": 0, "other": 1}.
+    # We use this mapping to populate needs_input_grad in some order and then grab
+    # values from it.
+    input_name_to_idx: dict[str, int] = {}
+
+    for idx, arg in enumerate(info.args_with_derivatives):
+        if arg.type in TENSOR_LIST_LIKE_CTYPES:
+            size = f"{arg.name}_size_"
+            saved_list_sizes.append(f"size_t {arg.name}_size_;")
+            apply_functional_args.append(f"{arg.name}_size_")
+            apply_functional_args_ref_types.append("size_t")
+        else:
+            size = "1"
+        compute_index_ranges.append(f"auto {arg.name}_ix = gen.range({size});")
+        input_name_to_idx[arg.name] = idx
+
+    def save_var(var: SavedAttribute, is_output: bool) -> None:
+        name = var.nctype.name
+        type = var.nctype.type
+        should_append_getsetdef = True
+        should_append_raw_getsetdef = False
+        visit_name = name
+        uses_cpp_saved_variable_cls = False
+        unpacked_ref_type = None
+
+        if (
+            type == BaseCType(tensorT)
+            or type == OptionalCType(BaseCType(tensorT))
+            or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
+            or (type == BaseCType(scalarT) and is_output)
+        ):
+            uses_cpp_saved_variable_cls = True
+            saved_variables.append(f"SavedVariable {name}_;")
+            release_variables.append(f"{name}_.reset_data();")
+            ptr = "shared_from_this()" if is_output else ""
+            unpack.append(f"auto {name} = {name}_.unpack({ptr});")
+            getter_definitions.append(
+                GETTER_DEFINITION_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_SAVEDVAR
+                )
+            )
+            should_append_raw_getsetdef = True
+            visit_name = f"{name}_"
+            unpacked_ref_type = "Tensor&"
+        elif (
+            type == BaseCType(tensorListT)
+            or type == BaseCType(iTensorListRefT)
+            or type == VectorCType(BaseCType(tensorT))
+        ):
+            # note(crcrpar): [nuanced return type of out-of-place foreach functions]
+            # When an out-of-place foreach function whose return signature is `Tensor[]`
+            # spells out its backward definitions in `derivatives.yaml`, and some of them depend on
+            # `result`, `result`'s type is interpreted and treated as `std::vector<Tensor>`.
+            # An out-of-place foreach whose backwards rely on their output doesn't suffer from this
+            # difference if the definitions are codegen'ed.
+            # This special case is needed for `_foreach_pow.List` and `_foreach_pow.ScalarAndTensor`
+            # as of https://github.com/pytorch/pytorch/pull/105504.
+            if type == VectorCType(BaseCType(tensorT)):
+                assert (
+                    info.func.func.name.name.base.startswith("_foreach") and is_output
+                )
+            uses_cpp_saved_variable_cls = True
+            saved_variables.append(f"std::vector {name}_;")
+            saved_variables.append(f"bool {name}_released_ = false;")
+            # Just clear() is sufficient, we don't need to loop and clear each variable.
+            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
+            release_variables.append(f"{name}_.clear();")
+            release_variables.append(f"{name}_released_ = true;")
+            ptr = "shared_from_this()" if is_output else "nullptr"
+            unpack.append(f"auto {name} = unpack_list({name}_, {ptr});")
+            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
+            getter_definitions.append(
+                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
+                )
+            )
+            should_append_raw_getsetdef = True
+            visit_name = f"{name}_"
+            unpacked_ref_type = "std::vector&"
+        elif type == ListCType(OptionalCType(BaseCType(tensorT))):
+            uses_cpp_saved_variable_cls = True
+            saved_variables.append(f"std::vector {name}_;")
+            saved_variables.append(f"bool {name}_released_ = false;")
+            # Just clear() is sufficient, we don't need to loop and clear each variable.
+            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
+            release_variables.append(f"{name}_.clear();")
+            release_variables.append(f"{name}_released_ = true;")
+            unpack.append(f"auto {name} = unpack_opt_list({name}_);")
+            asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
+            getter_definitions.append(
+                GETTER_DEFINITION_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_VEC_SAVEDVAR
+                )
+            )
+            getter_definitions.append(
+                GETTER_DEFINITION_RAW_VEC_SAVEDVAR.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_RAW_VEC_SAVEDVAR
+                )
+            )
+            should_append_raw_getsetdef = True
+            visit_name = f"{name}_"
+            unpacked_ref_type = "torch::List>&"
+        elif type == BaseCType(intArrayRefT):
+            saved_variables.append(f"std::vector {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
+        elif type == BaseCType(symIntArrayRefT):
+            saved_variables.append(f"std::vector {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
+                )
+            )
+        elif type == BaseCType(optionalIntArrayRefT):
+            saved_variables.append(f"c10::OptionalArray {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
+        elif type == BaseCType(optionalSymIntArrayRefT):
+            saved_variables.append(f"c10::OptionalArray {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
+                )
+            )
+        elif type == OptionalCType(BaseCType(intArrayRefT)):
+            saved_variables.append(f"c10::OptionalArray {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_LONG
+                )
+            )
+        elif type == OptionalCType(BaseCType(symIntArrayRefT)):
+            saved_variables.append(f"c10::OptionalArray {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_SYMINT
+                )
+            )
+        elif type == OptionalCType(ArrayRefCType(BaseCType(doubleT))):
+            saved_variables.append(f"c10::OptionalArray {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT_ARRAYREF.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_ARRAYREF_DOUBLE
+                )
+            )
+        elif type == BaseCType(longT):
+            saved_variables.append(f"{type.cpp_type()} {name} = 0;")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_INT64_T
+                )
+            )
+        elif type == BaseCType(SymIntT):
+            saved_variables.append(f"c10::SymInt {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_SYMINT
+                )
+            )
+        elif type == BaseCType(stringT):
+            saved_variables.append(f"std::string {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_STRING
+                )
+            )
+        elif type == OptionalCType(BaseCType(stringT)):
+            saved_variables.append(f"std::optional {name};")
+            getter_definitions.append(
+                GETTER_DEFINITION_OPT.substitute(
+                    op=info.op, name=name, body=GETTER_BODY_STRING
+                )
+            )
+        elif type == ArrayRefCType(
+            elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
+        ):
+            saved_variables.append(f"std::vector {name};")
+            unpacked_ref_type = "std::vector&"
+            saved_variables.append(f"bool {name}_released_ = false;")
+            # Just clear() is sufficient, we don't need to loop and clear each variable.
+            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
+            release_variables.append(f"{name}.clear();")
+            # release_variables.append(f"{name}_released_ = true;")
+            # unpack.append(f"auto {name} = unpack_list({name}_);")
+            # asserts.append(f"TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);")
+            getter_definitions.append(
+                CodeTemplate(
+                    """\
+static PyObject* THP${op}_${name}_getter(THPCppFunction *self, void *_unused) {
+  HANDLE_TH_ERRORS
+  const auto *node = static_cast<${op}*>(self->cdata.get());
+  const auto& prop = node->${name};
+  if (node->${name}_released_) {
+    PyErr_SetString(PyExc_RuntimeError, ERR_BACKWARD_TWICE);
+    return nullptr;
+  }
+  ${body}
+  END_HANDLE_TH_ERRORS
+}
+                            """
+                ).substitute(
+                    op=info.op,
+                    name=name,
+                    body=GETTER_BODY_VEC_SCALAR,
+                )
+            )
+        else:
+            # Check for indicators that you're putting a non-owning reference
+            # into the saved variable field.  If this is spuriously firing,
+            # edit this field.  Otherwise, you probably need to add a case
+            # above.
+            assert (
+                "ref" not in type.cpp_type().lower()
+                and "view" not in type.cpp_type().lower()
+                and "*" not in type.cpp_type()
+                and "&" not in type.cpp_type()
+            ), f"{type.cpp_type()} looks like it contains a non-owning reference"
+            saved_variables.append(f"{type.cpp_type()} {name};")
+
+            if type in MISC_GETTER_DEFS:
+                getter_def, body = MISC_GETTER_DEFS[type]
+                getter_definitions.append(
+                    getter_def.substitute(op=info.op, name=name, body=body)
+                )
+            else:
+                # Types we don't expose python bindings to yet:
+                #   TypeAndSize, at::ScalarType, TensorOptions, TensorGeometry,
+                #   std::vector<std::vector<SymInt>>, std::vector<ScalarType>
+                should_append_getsetdef = False
+
+        if should_append_getsetdef:
+            py_getsetdef_structs.append(
+                PY_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
+            )
+        if should_append_raw_getsetdef:
+            py_getsetdef_structs.append(
+                PY_RAW_GETSETDEF_STRUCT.substitute(op=info.op, name=name)
+            )
+
+        if uses_cpp_saved_variable_cls:
+            compiled_args.append(
+                f"args.collect({visit_name}, {'true' if is_output else 'false'});"
+            )
+        else:
+            compiled_args.append(f"args.collect({visit_name});")
+        apply_with_saved_before.append(f"saved.before({visit_name});")
+        apply_with_saved_after.append(f"saved.after({visit_name});")
+
+        if unpacked_ref_type is None:
+            unpacked_ref_type = f"{saved_variables[-1].split(' ')[0]}&"
+        apply_functional_args.append(str(name))
+        apply_functional_args_ref_types.append(unpacked_ref_type)
+
+    for var in sorted(info.all_saved_inputs, key=lambda sa: str(sa.nctype.name)):
+        save_var(var, is_output=False)
+    for var in sorted(info.all_saved_outputs, key=lambda sa: str(sa.nctype.name)):
+        save_var(var, is_output=True)
+
+    # lock the mutex when we release variables and in Node::apply to protect thread safety
+    # see Note [Thread Safety on Autograd Node]
+    if len(release_variables) > 0:
+        thread_lock = "std::lock_guard lock(mutex_);"
+    else:
+        thread_lock = ""
+
+    if uses_retain_variables(info):
+        apply_functional_args.append("retain_variables")
+        apply_functional_args_ref_types.append("bool")
+        will_release_variables = WILL_RELEASE_VARIABLES.substitute()
+    else:
+        will_release_variables = ""
+
+    body: list[str] = []
+
+    if uses_single_grad(info):
+        body.append("const auto& grad = grads[0];")
+    else:
+        # Generate aliases for gradients named for returned values.
+        body.extend(
+            f"const auto& {name} = grads[{info.available_named_gradients.index(name)}];"
+            for name in sorted(info.used_named_gradients)
+        )
+
+    def emit_derivative(
+        derivative: Derivative,
+        args_with_derivatives: Sequence[Binding],
+    ) -> tuple[bool, str]:
+        formula = derivative.formula
+        var_names = derivative.var_names
+
+        if len(var_names) == 1:
+            checks_any_grad_defined = False
+            if "not_implemented" not in formula:
+                matching_args = [
+                    arg for arg in args_with_derivatives if arg.name == var_names[0]
+                ]
+                if len(matching_args) == 1:
+                    # We can add undefined grad support if the input variable is a Tensor
+                    arg = matching_args[0]
+                    if isinstance(arg.argument, Argument) and str(
+                        arg.argument.type
+                    ) in ("Tensor", "Tensor?"):
+                        formula = "any_grad_defined ? (" + formula + ") : Tensor()"
+                        checks_any_grad_defined = True
+            if info.name.startswith("_foreach_"):
+                derivative_template = DERIVATIVE_SINGLE_FOREACH
+            else:
+                derivative_template = DERIVATIVE_SINGLE
+            return (
+                checks_any_grad_defined,
+                derivative_template.substitute(
+                    name=var_names[0],
+                    derivative=formula,
+                    idx=input_name_to_idx[var_names[0]],
+                ),
+            )
+
+        else:
+            if "grad_input_mask" in formula:
+                masks = [
+                    f"needs_input_grad[{input_name_to_idx[name]}],"
+                    for name in var_names
+                ]
+                grad_input_mask = GRAD_INPUT_MASK.substitute(
+                    n=len(var_names), masks=masks
+                )
+            else:
+                grad_input_mask = ""
+            needs_input_grad = [
+                f"needs_input_grad[{input_name_to_idx[name]}]" for name in var_names
+            ]
+            needs_input_grad = " || ".join(needs_input_grad)
+            copy_ranges: list[str] = []
+            for i, n in enumerate(var_names):
+                copy_ranges.append(
+                    DERIVATIVE_MULTI_COPY_RANGE.substitute(
+                        name=n, i=i, idx=input_name_to_idx[n]
+                    )
+                )
+            return False, DERIVATIVE_MULTI.substitute(
+                needs_input_grad=needs_input_grad,
+                copy_ranges=copy_ranges,
+                derivative=formula,
+                grad_input_mask=grad_input_mask,
+            )
+
+    masks = []
+
+    need_any_grad_defined_var = False
+    for derivative in info.derivatives:
+        checks_any_grad_defined, derivative_text = emit_derivative(
+            derivative, info.args_with_derivatives
+        )
+        body.append(derivative_text)
+        need_any_grad_defined_var |= checks_any_grad_defined
+
+    for name in input_name_to_idx:
+        masks.append(f"task_should_compute_output({{ {name}_ix }}),")
+
+    # Since single-output derivative formulas need to check if grads are
+    # defined, only perform the check once, before all the formulas
+    if need_any_grad_defined_var:
+        body.insert(
+            -len(info.derivatives),
+            "bool any_grad_defined = any_variable_defined(grads);",
+        )
+
+    if info.name in UNTRACEABLE_FUNCTIONS:
+        superclass = "Node"
+    else:
+        superclass = "TraceableFunction"
+
+    all_getsetdef_structs = (
+        ",\n".join(py_getsetdef_structs) + "," if len(py_getsetdef_structs) != 0 else ""
+    )
+    all_getter_definitions = "\n".join(getter_definitions)
+
+    compute_needs_input_grad = COMPUTE_NEEDS_INPUT_GRAD.substitute(
+        n=len(masks), compute_index_ranges=compute_index_ranges, masks=masks
+    )
+    apply_functional_args_signature = [
+        f"{T} {x}"
+        for T, x in zip(apply_functional_args_ref_types, apply_functional_args)
+    ]
+    get_packed_args = "\n".join(
+        f"packed_args.pack({name});" for name in apply_functional_args
+    )
+    unpack_ivalues = []
+    for typ, name in zip(apply_functional_args_ref_types, apply_functional_args):
+        typ = typ.removesuffix("&")
+        unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();")
+
+    schema_args = [f"std::array"]
+    for typ in apply_functional_args_ref_types:
+        typ = typ.removesuffix("&")
+        typ = typ.removeprefix("const")
+        schema_args.append(typ.strip())
+    compute_schema = ["std::vector schema = {"]
+    for schema_arg in schema_args:
+        compute_schema.append(
+            f"  torch::dynamo::autograd::IValuePacker<{schema_arg}>::packed_type(),"
+        )
+    compute_schema.append("};")
+
+    return template.substitute(
+        unpacks="\n".join(unpack),
+        op=info.op,
+        compute_schema="\n".join(compute_schema),
+        apply_functional_args=apply_functional_args,
+        apply_functional_args_signature=apply_functional_args_signature,
+        compute_needs_input_grad=compute_needs_input_grad,
+        num_inputs=len(input_name_to_idx),
+        unpack_ivalues="\n".join(unpack_ivalues),
+        compute_index_ranges=compute_index_ranges,
+        saved_variables=saved_variables,
+        release_variables=release_variables,
+        saved_list_sizes=saved_list_sizes,
+        asserts=asserts,
+        thread_lock=thread_lock,
+        will_release_variables=will_release_variables,
+        body=body,
+        superclass=superclass,
+        all_getter_definitions=all_getter_definitions,
+        all_getsetdef_structs=all_getsetdef_structs,
+        compiled_args=compiled_args,
+        apply_with_saved_before=apply_with_saved_before,
+        apply_with_saved_after=apply_with_saved_after,
+        get_packed_args=get_packed_args,
+    )
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..099c528b7e6b2beb161f2d5f669f715310bfc970
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_inplace_or_view_type.py
@@ -0,0 +1,673 @@
+# Generates ADInplaceOrViewType.h/cpp
+#
+# NOTE: If any changes are being made to the ADInplaceOrView codegen please also check
+# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
+# The fallback is expected to mimic this codegen, so we should keep the two in sync.
+
+from __future__ import annotations
+
+from torchgen.api import cpp
+from torchgen.api.autograd import (
+    dispatch_strategy,
+    gen_differentiable_outputs,
+    NativeFunctionWithDifferentiabilityInfo,
+)
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    boolT,
+    ConstRefCType,
+    CType,
+    DispatcherSignature,
+    intArrayRefT,
+    longT,
+    OptionalCType,
+    symIntArrayRefT,
+    SymIntT,
+    tensorT,
+)
+from torchgen.code_template import CodeTemplate
+from torchgen.context import with_native_function
+from torchgen.model import (
+    NativeFunction,
+    SchemaKind,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+from torchgen.utils import FileManager
+
+from .context import with_native_function_with_differentiability_info
+from .gen_trace_type import (
+    get_return_value,
+    MANUAL_AUTOGRAD,
+    tie_return_values,
+    type_wrapper_name,
+)
+
+
+# See NOTE [ Autograd View Variables ] in variable.h for details.
+# If you update list VIEW_FUNCTIONS or RETURNS_VIEWS_OF_INPUT,
+# you **MUST** also update the public list of view ops accordingly in
+# docs/source/tensor_view.rst. Note that not all ATen functions are exposed publicly,
+# e.g. alias & sparse_coo_tensor_with_dims_and_tensors.
+#
+# A map: function name => name of the argument that all outputs are view of
+
+VIEW_FUNCTIONS_WITH_METADATA_CHANGE = [
+    "view_as_complex",
+    "view_as_real",
+    "_conj",
+    "_neg_view",
+    "_nested_get_values",
+    "_nested_view_from_buffer",
+    "_nested_view_from_jagged",
+]
+
+VIEW_FUNCTIONS = {
+    "numpy_T": "self",
+    "alias": "self",
+    "as_strided": "self",
+    "diagonal": "self",
+    "expand": "self",
+    "permute": "self",
+    "select": "self",
+    "slice": "self",
+    "slice_inverse": "self",
+    "split": "self",
+    "split_with_sizes": "self",
+    "squeeze": "self",
+    "t": "self",
+    "transpose": "self",
+    "unfold": "self",
+    "unsqueeze": "self",
+    "flatten": "self",
+    "view": "self",
+    "unbind": "self",
+    "_indices": "self",
+    "_values": "self",
+    "indices": "self",
+    "values": "self",
+    "crow_indices": "self",
+    "col_indices": "self",
+    "ccol_indices": "self",
+    "row_indices": "self",
+    # sparse_coo ctor output should really be views of both indices and values,
+    # but we only support making a view of a single variable, and indices is
+    # discrete anyway.
+    # FIXME: clone indices on construction.
+    "sparse_coo_tensor_with_dims_and_tensors": "values",
+    "_reshape_alias": "self",
+    "_test_autograd_multiple_dispatch_view": "self",
+}
+
+for key in VIEW_FUNCTIONS_WITH_METADATA_CHANGE:
+    VIEW_FUNCTIONS[key] = "self"
+
+# note: some VIEW_FUNCTIONS are just compositions of the view functions above
+# this list contains both the root view functions and any that are purely composed
+# of viewing functions, and is used by the JIT to determine when an operator
+# may return a view of its inputs; however they may sometimes return a copy.
+# (e.g. `contiguous`)
+RETURNS_VIEWS_OF_INPUT = set(VIEW_FUNCTIONS.keys()).union(
+    {
+        "chunk",
+        "detach",
+        "contiguous",
+        "reshape",
+        "reshape_as",
+        "expand_as",
+        "view_as",
+        "real",
+        "imag",
+        "narrow",
+        "movedim",
+        "tensor_split",
+        "swapdims",
+        "swapaxes",
+        "mT",
+        "mH",
+        "adjoint",
+        "matrix_H",
+    }
+)
+
+# These are the functions we consider views for the purposes of validating
+# StorageImpl and TensorImpl in gen_variable_type.
+# `_unsafe_view` is not included in VIEW_FUNCTIONS above because it is not a
+# view for the purposes of ADInplaceOrView kernel, we do not want to call as_view
+# See NOTE [Unsafe View] for more info.
+ALL_VIEW_FUNCTIONS = {
+    **VIEW_FUNCTIONS,
+    "_unsafe_view": "self",
+}
+
+ARRAYREF_TO_VEC = CodeTemplate(
+    """\
+auto ${vec} = ${arg}.vec();
+"""
+)
+
+OPTIONAL_TO_VAL = CodeTemplate(
+    """\
+auto ${val} = ${arg}.value_or(${default});
+"""
+)
+
+CALL_DISPATCH = CodeTemplate(
+    """\
+at::_ops::${unambiguous_name}::call(${unpacked_args})"""
+)
+
+REVERSE_VIEW_DISPATCH = CodeTemplate(
+    """\
+${reverse_name}(${unpacked_args})"""
+)
+
+MULTI_OUTPUT_VIEW_ITERATION = CodeTemplate(
+    """\
+for (auto ${view_idx} : c10::irange(${var}.size())) {
+  ${body}
+}
+"""
+)
+
+SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE = CodeTemplate(
+    """\
+std::unique_ptr<ViewFunc> func(nullptr);
+std::function<at::Tensor(const at::Tensor&)> rev_func=nullptr;
+if (${is_view_with_metadata_change} ||
+    !self.unsafeGetTensorImpl()->support_as_strided() ||
+    self.unsafeGetTensorImpl()->is_python_dispatch() ||
+    c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
+  ${replay_view_func}
+  ${reverse_replay_view_func}
+}
+"""
+)
+
+REPLAY_VIEW_FUNC = CodeTemplate(
+    """\
+func = std::make_unique<${view_func_name}>(${view_func_args});
+"""
+)
+
+REVERSE_REPLAY_VIEW_LAMBDA_FUNC = CodeTemplate(
+    """\
+rev_func = [=](const at::Tensor& ${input_view}) {
+  return ${reverse_replay_view_call};
+};
+"""
+)
+
+METHOD_DEFINITION = CodeTemplate(
+    """\
+${return_type} ${type_wrapper_name}(${formals}) {
+  ${type_definition_body}
+}
+"""
+)
+
+WRAPPER_REGISTRATION = CodeTemplate(
+    """\
+m.impl("${unqual_operator_name_with_overload}",
+       TORCH_FN(${class_type}::${type_wrapper_name})
+);
+"""
+)
+
+AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION = CodeTemplate(
+    """\
+m.impl("${unqual_operator_name_with_overload}", torch::autograd::autogradNotImplementedFallback());
+"""
+)
+
+INPLACE_REDISPATCH = CodeTemplate(
+    """\
+{
+  at::AutoDispatchBelowADInplaceOrView guard;
+  at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
+}
+"""
+)
+
+ASSIGN_RETURN_VALUE = CodeTemplate(
+    """\
+${return_values} = ${rhs_value};
+"""
+)
+
+VIEW_REDISPATCH = CodeTemplate(
+    """\
+${assign_return_values} ([&]() {
+  at::AutoDispatchBelowADInplaceOrView guard;
+  return at::_ops::${unambiguous_name}::redispatch(${unpacked_args});
+})();
+"""
+)
+
+TMP_VAR = "_tmp"
+
+
+# FIXME: Ideally these functions should be methods on Type class, but we have a
+#        comment in codegen/model.py there saying these concepts are not well defined.
+#        Thus we put a version that commonly used by autograd codegen here.
+def is_tensor_type(t: Type) -> bool:
+    # TODO: Should handle optional here?
+    return t.is_tensor_like() and t.is_list_like() is None
+
+
+def is_tensor_list_type(t: Type) -> bool:
+    # TODO: Should handle optional here?
+    return t.is_tensor_like() and t.is_list_like() is not None
+
+
+UNPACK_TENSOR = CodeTemplate(
+    """\
+auto${ref} ${arg_name}_ = unpack${suffix}(${arg_name}, "${arg_name}", ${arg_pos});"""
+)
+
+
+def unpacked_name(arg_name: str) -> str:
+    return arg_name + "_"
+
+
+# e.g. select.int -> select_copy_int_inverse()
+def inverse_view_name(f: NativeFunction) -> str:
+    copy_variant = f"{f.root_name}_copy"
+    overload = f"{f.func.name.overload_name}"
+    if overload != "":
+        overload = "_" + overload
+    return f"{copy_variant}{overload}_inverse"
+
+
+def extract_bindings(f: NativeFunction) -> list[Binding]:
+    return [
+        r
+        for a in f.func.schema_order_arguments()
+        for r in cpp.argument(
+            a,
+            method=False,
+            symint=True,
+            cpp_no_default_args=set(),
+            faithful=False,
+            has_tensor_options=False,
+        )
+    ]
+
+
+@with_native_function
+def unpack_args(f: NativeFunction) -> tuple[list[str], list[Binding]]:
+    body: list[str] = []
+    unpacked_bindings: list[Binding] = []
+
+    for i, binding in enumerate(extract_bindings(f)):
+        assert not isinstance(binding.argument, SelfArgument)
+        if isinstance(binding.argument, TensorOptionsArguments):
+            raise RuntimeError("VariableKernel shouldn't take TensorOptions")
+
+        is_nullable = binding.argument.type.is_nullable()
+        if not binding.argument.type.is_tensor_like() or is_nullable:
+            unpacked_bindings.append(binding)
+            continue
+
+        is_tensor_list = is_tensor_list_type(binding.argument.type)
+        ref = (not is_nullable) and not is_tensor_list
+        suffix = "_opt" if is_nullable and not is_tensor_list else ""
+        body.append(
+            UNPACK_TENSOR.substitute(
+                arg_name=binding.name,
+                arg_pos=i,
+                suffix=suffix,
+                ref="&" if ref else "",
+            )
+        )
+        unpacked_bindings.append(
+            Binding(
+                name=unpacked_name(binding.name),
+                nctype=binding.nctype,
+                argument=binding.argument,
+                default=binding.default,
+            )
+        )
+
+    return body, unpacked_bindings
+
+
+def get_base_name(f: NativeFunction) -> str:
+    return f.func.name.name.base  # TODO: should be str(f.func.name.name)?
+
+
+def get_view_info(f: NativeFunction) -> str | None:
+    base_name = get_base_name(f)
+    view_info = VIEW_FUNCTIONS.get(base_name, None)
+    if view_info is None and base_name in RETURNS_VIEWS_OF_INPUT:
+        view_info = "self"
+    return view_info
+
+
+def emit_view_func(
+    f: NativeFunction, bindings: list[Binding], view_idx: str | None = None
+) -> str:
+    """Generate an additional lambda function to recover views in backward when as_strided is not supported.
+    See Note [View + Inplace update for base tensor] and [View + Inplace update for view tensor] for more details.
+    """
+    # TODO: Clean this logic up if we get rid of reverse view funcs or reify them.
+    input_base = "input_base"
+    replay_view_func = ""
+    updated_args: list[str] = []
+    known_view_arg_simple_types: list[CType] = [
+        BaseCType(longT),
+        OptionalCType(BaseCType(longT)),
+        BaseCType(SymIntT),
+        OptionalCType(BaseCType(SymIntT)),
+        BaseCType(boolT),
+        BaseCType(intArrayRefT),
+        BaseCType(symIntArrayRefT),
+        ConstRefCType(BaseCType(tensorT)),
+        ConstRefCType(OptionalCType(BaseCType(tensorT))),
+    ]
+    for binding in bindings:
+        arg, arg_type = binding.name, binding.nctype.type
+        if arg == "self":
+            updated_args.append(input_base)
+            continue
+        if arg_type not in known_view_arg_simple_types:
+            known_types_str = ", ".join([str(t) for t in known_view_arg_simple_types])
+            raise TypeError(
+                f"You are adding an {arg_type} {arg} argument to op {cpp.name(f.func)} in addition to known types: "
+                f"{known_types_str}. Please update the list or materialize it so that it can be closed "
+                "over by value, also add a test in pytorch/xla/test/test_operations.py where this code "
+                "is exercised."
+            )
+        if arg_type == BaseCType(intArrayRefT) or arg_type == BaseCType(
+            symIntArrayRefT
+        ):
+            # It's not safe to close over IntArrayRef by value, since this is a
+            # reference type, so materialize a vector to close over by value
+            arg_vec = arg + "_vec"
+            replay_view_func += ARRAYREF_TO_VEC.substitute(arg=arg, vec=arg_vec)
+            updated_args.append(arg_vec)
+        elif arg_type == OptionalCType(BaseCType(longT)):
+            # Materialize int64_t? to int64_t
+            arg_value = arg + "_val"
+            replay_view_func += OPTIONAL_TO_VAL.substitute(
+                arg=arg, val=arg_value, default="0"
+            )
+            updated_args.append(arg_value)
+        elif arg_type == ConstRefCType(BaseCType(tensorT)) or arg_type == ConstRefCType(
+            OptionalCType(BaseCType(tensorT))
+        ):
+            # NB: Closing over a tensor. If a user modifies this tensor, this will be silently
+            # incorrect. The proper thing to do is to store the version counter and copy on write.
+            updated_args.append(arg)
+        else:
+            updated_args.append(arg)
+
+    from .gen_view_funcs import view_func_name
+
+    view_func_args = [b.name for b in bindings if b.name != "self"]
+    if view_idx is not None:
+        view_func_args.append(f"{view_idx}")
+    replay_view_func += REPLAY_VIEW_FUNC.substitute(
+        view_func_name=view_func_name(f, include_namespace=True),
+        view_func_args=view_func_args,
+    )
+
+    input_view = "input_view"
+    reverse_unpacked_args = [
+        "self",
+        f"{input_view}",
+        # inverse_return_mode=
+        "at::functionalization::InverseReturnMode::AlwaysView",
+        *(() if view_idx is None else (f"{view_idx}",)),
+        # skip input_base arg
+        *updated_args[1:],
+    ]
+
+    from torchgen.api.functionalization import reverse_name
+
+    reverse_replay_view_call = REVERSE_VIEW_DISPATCH.substitute(
+        reverse_name=reverse_name(f, include_namespace=True),
+        unpacked_args=reverse_unpacked_args,
+    )
+    reverse_replay_view_func = REVERSE_REPLAY_VIEW_LAMBDA_FUNC.substitute(
+        input_view=input_view, reverse_replay_view_call=reverse_replay_view_call
+    )
+
+    is_view_with_metadata_change = (
+        "true" if cpp.name(f.func) in VIEW_FUNCTIONS_WITH_METADATA_CHANGE else "false"
+    )
+
+    return SETUP_REPLAY_VIEW_IF_NOT_SUPPORT_AS_STRIDED_OR_VIEW_WITH_METADATA_CHANGE.substitute(
+        is_view_with_metadata_change=is_view_with_metadata_change,
+        replay_view_func=replay_view_func,
+        reverse_replay_view_func=reverse_replay_view_func,
+    )
+
+
+def emit_view_body(
+    fn: NativeFunctionWithDifferentiabilityInfo, var: str
+) -> tuple[str, str]:
+    # See NOTE [ Autograd View Variables ] in variable.h for details.
+    f = fn.func
+    base_name = get_base_name(f)
+    view_info = get_view_info(f)
+    call = ""
+    differentiable_outputs = gen_differentiable_outputs(fn)
+    differentiable_output_vars = {r.name for r in differentiable_outputs}
+    if not isinstance(view_info, str):
+        raise TypeError(
+            f"The view info should be a string for {base_name}, but it is: {view_info}"
+        )
+    if len(differentiable_output_vars) == 0:
+        # no output is differentiable (.indices() for SparseTensors for example)
+        rhs_value = (
+            f"as_view({view_info}, {var}, "
+            f"/* is_bw_differentiable */ false, /* is_fw_differentiable */ false)"
+        )
+    elif len(differentiable_output_vars) == 1:
+        # Single differentiable output (Tensor or Tensor[])
+        return_info = differentiable_outputs[0]
+        # We only support simple Tensor or a TensorList for functions that return views
+        if not is_tensor_type(return_info.type) and not is_tensor_list_type(
+            return_info.type
+        ):
+            raise RuntimeError(
+                f"{base_name} that return differentiable views can only return Tensor or Tensor[]"
+            )
+
+        # See Note [ View + Inplace detection]
+        def get_creation_meta_in_mode(original: str) -> str:
+            creation_meta_with_grad_mode = f"(at::GradMode::is_enabled() ? {original} : CreationMeta::NO_GRAD_MODE)"
+            return f"InferenceMode::is_enabled() ? CreationMeta::INFERENCE_MODE : {creation_meta_with_grad_mode}"
+
+        # Only allow rebasing of the history if we return a single Tensor
+        # If we are in a no grad block, raise a warning
+        # See NOTE [ View + Inplace detection ] for more details about this logic
+        if is_tensor_list_type(return_info.type):
+            creation_meta = get_creation_meta_in_mode("CreationMeta::MULTI_OUTPUT_NODE")
+            view_idx = "view_idx"
+            view_func = emit_view_func(
+                f, extract_bindings(f), view_idx=view_idx
+            ).strip()
+            as_view_call = (
+                f"as_view(/* base */ {view_info}, /* output */ {var}[{view_idx}], "
+                "/* is_bw_differentiable */ true, /* is_fw_differentiable */ true, "
+                "/* view_func */ std::move(func), /* rev_view_func */ rev_func, "
+                f"/* creation_meta */ {creation_meta});"
+            )
+            call += MULTI_OUTPUT_VIEW_ITERATION.substitute(
+                var=var, view_idx=view_idx, body=f"{view_func}\n{as_view_call}"
+            )
+            rhs_value = f"std::move({var})"
+        else:
+            call += emit_view_func(f, extract_bindings(f), view_idx=None)
+            creation_meta = get_creation_meta_in_mode("CreationMeta::DEFAULT")
+            rhs_value = (
+                f"as_view(/* base */ {view_info}, /* output */ {var}, /* is_bw_differentiable */ true, "
+                "/* is_fw_differentiable */ true, "
+                f"/* view_func */ std::move(func), /* rev_view_func */ rev_func, /* creation_meta */ {creation_meta})"
+            )
+    else:
+        # This could be supported but we don't need it at the moment, so keeping things simple.
+        raise RuntimeError(
+            "Function that return multiple differentiable output "
+            "when at least one of them is view is not supported."
+        )
+    return call, rhs_value
+
+
+def modifies_arguments(f: NativeFunction) -> bool:
+    return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
+
+
+@with_native_function_with_differentiability_info
+def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> list[str]:
+    f = fn.func
+    inplace_view_body: list[str] = []
+
+    dispatcher_sig = DispatcherSignature.from_schema(f.func)
+    dispatcher_exprs = dispatcher_sig.exprs()
+
+    # code-generated ADInplaceOrView kernels plumb and recompute dispatch keys directly through the kernel for performance.
+    # See Note [Plumbing Keys Through The Dispatcher] for details.
+    dispatch_key_set = "ks & c10::after_ADInplaceOrView_keyset"
+    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
+
+    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
+    # We could probably work harder to ensure that the fast variants are called instead, but the perf benefit would be minimal.
+    if modifies_arguments(f):  # inplace op
+        inplace_view_body.append(
+            INPLACE_REDISPATCH.substitute(
+                unambiguous_name=f.func.name.unambiguous_name(),
+                unpacked_args=redispatch_args,
+            )
+        )
+        for r in cpp.return_names(f):
+            inplace_view_body.append(f"increment_version({r});")
+    else:
+        assert get_view_info(f) is not None
+        inplace_view_body.append(
+            VIEW_REDISPATCH.substitute(
+                assign_return_values="auto " + TMP_VAR + " = ",
+                unambiguous_name=f.func.name.unambiguous_name(),
+                unpacked_args=redispatch_args,
+            )
+        )
+        call, rhs_value = emit_view_body(fn, TMP_VAR)
+        inplace_view_body.append(call)
+        assert rhs_value is not None
+        inplace_view_body.append(
+            ASSIGN_RETURN_VALUE.substitute(
+                return_values=tie_return_values(f), rhs_value=rhs_value
+            )
+        )
+    if f.func.returns:
+        inplace_view_body.append(f"return {get_return_value(f)};")
+    return inplace_view_body
+
+
+@with_native_function
+def gen_formals(f: NativeFunction) -> str:
+    return ", ".join(
+        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
+        # See Note [Plumbing Keys Through The Dispatcher] for details.
+        ["c10::DispatchKeySet ks"]
+        + [
+            f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}"
+            for a in f.func.schema_order_arguments()
+        ]
+    )
+
+
+@with_native_function_with_differentiability_info
+def inplace_or_view_method_definition(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> str | None:
+    f = fn.func
+    if get_view_info(f) is None and (
+        # For functions that modify their inputs but don't return them,
+        # we can't give them autograd support.
+        # See https://github.com/pytorch/pytorch/issues/53796
+        not modifies_arguments(f) or len(f.func.returns) == 0
+    ):
+        return None
+    return METHOD_DEFINITION.substitute(
+        return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
+        type_wrapper_name=type_wrapper_name(f),
+        formals=gen_formals(f),
+        type_definition_body=emit_inplace_or_view_body(fn),
+    )
+
+
+@with_native_function_with_differentiability_info
+def inplace_or_view_method_registration(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> str | None:
+    f = fn.func
+    if get_view_info(f) is None and (
+        not modifies_arguments(f) or len(f.func.returns) == 0
+    ):
+        return None
+    return WRAPPER_REGISTRATION.substitute(
+        unqual_operator_name_with_overload=f.func.name,
+        type_wrapper_name=type_wrapper_name(f),
+        class_type="ADInplaceOrView",
+    )
+
+
+def use_derived(fn: NativeFunctionWithDifferentiabilityInfo) -> bool:
+    f = fn.func
+    name = cpp.name(f.func)
+    return name not in MANUAL_AUTOGRAD and dispatch_strategy(fn) == "use_derived"
+
+
+def gen_inplace_or_view_type_env(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> dict[str, list[str]]:
+    definition = inplace_or_view_method_definition(fn)
+    registration = inplace_or_view_method_registration(fn)
+
+    return {
+        "ops_headers": (
+            [f"#include "]
+            if definition is not None
+            else []
+        ),
+        "inplace_or_view_method_definitions": [definition]
+        if definition is not None
+        else [],
+        "inplace_or_view_wrapper_registrations": [registration]
+        if registration is not None
+        else [],
+    }
+
+
+def gen_inplace_or_view_type(
+    out: str,
+    native_yaml_path: str,
+    tags_yaml_path: str,
+    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
+    template_path: str,
+) -> None:
+    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
+    # template regarding sharding of the generated files.
+
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    fm.write_sharded(
+        "ADInplaceOrViewType.cpp",
+        [fn for fn in fns_with_infos if use_derived(fn)],
+        key_fn=lambda fn: fn.func.root_name,
+        base_env={
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/ADInplaceOrViewType.cpp",
+        },
+        env_callable=gen_inplace_or_view_type_env,
+        num_shards=2,
+        sharded_keys={
+            "ops_headers",
+            "inplace_or_view_method_definitions",
+            "inplace_or_view_wrapper_registrations",
+        },
+    )
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_python_functions.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_python_functions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4adbfbd773be95d5e703c6db8507f640ae8d4f1
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_python_functions.py
@@ -0,0 +1,1404 @@
+# Generates Python bindings for ATen functions
+#
+# The bindings are generated as methods on python_variable or functions on the
+# torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse
+# or torch._C._special objects.
+#
+
+# Code tries to stick to the following rules:
+#
+# - templates should be colocated with the functions that use them.
+#   no templates are currently shared between functions, but if that
+#   happens, maybe put the template with the first one
+#
+# - don't use environment dictionaries when calling template.substitute().
+#   pass named arguments directly for everything; otherwise it's much too
+#   hard to track what's actually being used and by whom
+#
+# - colocate any new hacks/adjustments with existing ones of the same kind.
+#   ideally in a data structure rather than code if possible. See e.g.
+#   SCHEMA_DEFAULT_CONVERSION_HACKS, etc.
+#
+# - similarly, conversions from one format to another should ideally happen
+#   all at once in a single place.
+#
+# - no nontrivial nested functions. couple-liners are ok but please no more.
+#   especially avoid functions that read/write outer variables defined far away.
+#
+# - raise RuntimeError instead of asserting, and put as much
+#   information as is available into the message. I.e. no need to
+#   plumb in new params whose only purpose is to fill out an error
+#   message, but use what's there
+#
+
+from __future__ import annotations
+
+import itertools
+import re
+from collections import defaultdict
+from typing import Callable, TYPE_CHECKING
+
+import yaml
+
+from torchgen.api import cpp
+from torchgen.api.python import (
+    arg_parser_output_exprs,
+    cpp_dispatch_exprs,
+    cpp_dispatch_target,
+    dispatch_lambda_args,
+    dispatch_lambda_exprs,
+    dispatch_lambda_return_str,
+    has_tensor_options,
+    PythonSignature,
+    PythonSignatureDeprecated,
+    PythonSignatureGroup,
+    PythonSignatureNativeFunctionPair,
+    signature,
+    signature_from_schema,
+    structseq_fieldnames,
+)
+from torchgen.code_template import CodeTemplate
+from torchgen.context import with_native_function
+from torchgen.gen import cpp_string, parse_native_yaml, parse_tags_yaml
+from torchgen.model import (
+    Argument,
+    BaseOperatorName,
+    FunctionSchema,
+    NativeFunction,
+    SchemaKind,
+    Type,
+    Variant,
+)
+from torchgen.utils import FileManager, split_name_params
+from torchgen.yaml_utils import YamlLoader
+
+from .gen_inplace_or_view_type import is_tensor_list_type
+from .gen_trace_type import should_trace
+
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable, Sequence
+
+
+#
+# declarations blocklist
+# We skip codegen for these functions, for various reasons.
+# Future PRs will categorize this list and eliminate or hoist
+# them out of eager-only codegen.
+# See https://github.com/pytorch/pytorch/issues/30788
+#
+
+# These functions require manual Python bindings or are not exposed to Python
+_SKIP_PYTHON_BINDINGS = [
+    "alias",
+    "contiguous",
+    "is_cuda",
+    "is_sparse",
+    "is_sparse_csr",
+    "size",
+    "stride",
+    "sym_size",
+    "sym_stride",
+    "sym_storage_offset",
+    "sym_numel",
+    ".*_backward",
+    ".*_backward_(out|input|weight|bias)",
+    ".*_forward",
+    ".*_forward_out",
+    ".*_jvp",
+    "_unsafe_view",
+    "tensor",
+    "_?sparse_(coo|compressed|csr|csc|bsr|bsc)_tensor.*",
+    "_range.*",
+    "_sparse_add_out",
+    "_sparse_div.*",
+    "_sparse_mul.*",
+    "_sparse_sub.*",
+    "_sparse_dense_add_out",
+    "index",
+    "index_out",
+    "unique_dim_consecutive",
+    "_cumsum.*",
+    "_cumprod.*",
+    "_sum.*",
+    "_prod.*",
+    "_th_.*",
+    "_thnn_.*",
+    "range.*",
+    "_solve.*",
+    "_inverse.*",
+    "_cholesky.*",
+    "_triangular_solve.*",
+    "_qr.*",
+    "_svd.*",
+    "slice",
+    "item",
+    "_local_scalar_dense",
+    "to",
+    "_to_copy",
+    "_to_copy_out",
+    "_reshape_copy",
+    "_reshape_copy_out",
+    "copy_sparse_to_sparse_",
+    "copy_",
+    "_foreach_copy",
+    "numpy_T",
+    "matrix_H",
+    "mT",
+    "mH",  # these need to be an attributes in Python, not functions
+    "nonzero(_(out|numpy))?",
+    "set_data",
+    ".*_overrideable",  # overridable functions for backend extension
+    "data",
+    "is_leaf",
+    "output_nr",
+    "_version",
+    "requires_grad_",
+    "retains_grad",
+    "set_",
+    "_fw_primal",
+    "fake_quantize_per_tensor_affine_cachemask",
+    "fake_quantize_per_channel_affine_cachemask",
+    "_new_zeros_with_same_feature_meta",
+    "_has_same_storage_numel",  # used for forward AD internals
+    "_reshape_alias",
+    "replace_",  # only used by the functionalization pass, doesn't need to be exposed to python
+    "copy",  # only used by the functionalization pass
+    "fill.Tensor",  # only used by the functionalization pass
+    "fill.Scalar",  # only used by the functionalization pass
+    "lift.*",
+    "normal_functional",  # only used by the functionalization pass
+    "nbytes",
+    "itemsize",
+    "_batch_norm_with_update",
+    "_batch_norm_with_update_out",
+    "_batch_norm_no_update",
+]
+
+SKIP_PYTHON_BINDINGS = [
+    re.compile(rf"^{pattern}$") for pattern in _SKIP_PYTHON_BINDINGS
+]
+
+# These function signatures are not exposed to Python. Note that this signature
+# list does not support regex.
+SKIP_PYTHON_BINDINGS_SIGNATURES = [
+    "add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
+    "add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
+    "sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
+    "sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)",
+    "mul.Scalar(Tensor self, Scalar other) -> Tensor",
+    "mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
+    "div.Scalar(Tensor self, Scalar other) -> Tensor",
+    "div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
+]
+
+
+@with_native_function
+def should_generate_py_binding(f: NativeFunction) -> bool:
+    # NativeFunctions that are entirely code-generated should not get python bindings
+    # because these codegen implementations are often inefficient. A handful of
+    # view_copy-style ops were exposed accidentally when they were handwritten; now
+    # that we are moving them to codegen, we need to keep them exposed in Python for
+    # backward-compatibility reasons.
+    if "generated" in f.tags and "view_copy" not in f.tags:
+        return False
+
+    name = cpp.name(f.func)
+    for skip_regex in SKIP_PYTHON_BINDINGS:
+        if skip_regex.match(name):
+            return False
+
+    signature = str(f.func)
+    for pattern in SKIP_PYTHON_BINDINGS_SIGNATURES:
+        if pattern == signature:
+            return False
+    return True
+
+
+def get_pycname(name: BaseOperatorName) -> str:
+    return f"THPVariable_{name}"
+
+
+def is_noarg(overloads: Sequence[PythonSignatureNativeFunctionPair]) -> bool:
+    return len(overloads) == 1 and overloads[0].signature.arguments_count() == 0
+
+
+def is_py_variable_method(f: NativeFunction) -> bool:
+    return f.python_module is None and Variant.method in f.variants
+
+
+def is_py_torch_function(f: NativeFunction) -> bool:
+    return f.python_module is None and Variant.function in f.variants
+
+
+def is_py_nn_function(f: NativeFunction) -> bool:
+    return f.python_module == "nn"
+
+
+def is_py_fft_function(f: NativeFunction) -> bool:
+    return f.python_module == "fft"
+
+
+def is_py_linalg_function(f: NativeFunction) -> bool:
+    return f.python_module == "linalg"
+
+
+def is_py_nested_function(f: NativeFunction) -> bool:
+    return f.python_module == "nested"
+
+
+def is_py_sparse_function(f: NativeFunction) -> bool:
+    return f.python_module == "sparse"
+
+
+def is_py_special_function(f: NativeFunction) -> bool:
+    return f.python_module == "special"
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                            Main Function
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def gen(
+    out: str,
+    native_yaml_path: str,
+    tags_yaml_path: str,
+    deprecated_yaml_path: str,
+    template_path: str,
+    *,
+    symint: bool = True,
+) -> None:
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    native_functions = parse_native_yaml(
+        native_yaml_path, tags_yaml_path
+    ).native_functions
+    native_functions = list(filter(should_generate_py_binding, native_functions))
+
+    methods = load_signatures(native_functions, deprecated_yaml_path, method=True)
+    create_python_bindings(
+        fm,
+        methods,
+        is_py_variable_method,
+        None,
+        "python_variable_methods.cpp",
+        method=True,
+        symint=symint,
+    )
+
+    # NOTE: num_shards here must be synced with gatherTorchFunctions in
+    #       torch/csrc/autograd/python_torch_functions_manual.cpp
+    functions = load_signatures(native_functions, deprecated_yaml_path, method=False)
+    create_python_bindings_sharded(
+        fm,
+        functions,
+        is_py_torch_function,
+        "torch",
+        "python_torch_functions.cpp",
+        method=False,
+        num_shards=3,
+        symint=symint,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_nn_function,
+        "torch.nn",
+        "python_nn_functions.cpp",
+        method=False,
+        symint=symint,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_fft_function,
+        "torch.fft",
+        "python_fft_functions.cpp",
+        method=False,
+        symint=symint,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_linalg_function,
+        "torch.linalg",
+        "python_linalg_functions.cpp",
+        method=False,
+        symint=symint,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_nested_function,
+        "torch.nested",
+        "python_nested_functions.cpp",
+        method=False,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_sparse_function,
+        "torch.sparse",
+        "python_sparse_functions.cpp",
+        method=False,
+        symint=symint,
+    )
+
+    create_python_bindings(
+        fm,
+        functions,
+        is_py_special_function,
+        "torch.special",
+        "python_special_functions.cpp",
+        method=False,
+        symint=symint,
+    )
+
+    # Currently, we only use `functions` to generate `return_types` bindings.
+    # All methods which return a structseq have a function variant at this point.
+    # If any method-only operator returning a structseq is added in the future,
+    # we will have to address that.
+    create_python_return_type_bindings(
+        fm, functions, lambda fn: True, "python_return_types.cpp"
+    )
+    create_python_return_type_bindings_header(
+        fm, functions, lambda fn: True, "python_return_types.h"
+    )
+
+    valid_tags = parse_tags_yaml(tags_yaml_path)
+
+    def gen_tags_enum() -> dict[str, str]:
+        return {
+            "enum_of_valid_tags": (
+                "".join(
+                    [f'\n.value("{tag}", at::Tag::{tag})' for tag in sorted(valid_tags)]
+                )
+            )
+        }
+
+    fm.write("python_enum_tag.cpp", gen_tags_enum)
+
+
+def group_filter_overloads(
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    pred: Callable[[NativeFunction], bool],
+) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]:
+    grouped: dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] = (
+        defaultdict(list)
+    )
+    for pair in pairs:
+        if pred(pair.function):
+            grouped[pair.function.func.name.name].append(pair)
+    return grouped
+
+
+def create_python_bindings(
+    fm: FileManager,
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    pred: Callable[[NativeFunction], bool],
+    module: str | None,
+    filename: str,
+    *,
+    method: bool,
+    symint: bool = True,
+) -> None:
+    """Generates Python bindings to ATen functions"""
+    py_methods: list[str] = []
+    ops_headers: list[str] = []
+    py_method_defs: list[str] = []
+    py_forwards: list[str] = []
+
+    grouped = group_filter_overloads(pairs, pred)
+
+    for name in sorted(grouped.keys(), key=str):
+        overloads = grouped[name]
+        py_methods.append(
+            method_impl(name, module, overloads, method=method, symint=symint)
+        )
+        py_method_defs.append(method_def(name, module, overloads, method=method))
+        py_forwards.extend(forward_decls(name, overloads, method=method))
+        ops_headers.append(f"#include ")
+
+    fm.write_with_template(
+        filename,
+        filename,
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/{filename}",
+            "ops_headers": ops_headers,
+            "py_forwards": py_forwards,
+            "py_methods": py_methods,
+            "py_method_defs": py_method_defs,
+        },
+    )
+
+
+def create_python_return_type_bindings(
+    fm: FileManager,
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    pred: Callable[[NativeFunction], bool],
+    filename: str,
+) -> None:
+    """
+    Generate the functions that initialize and return the named tuple for native functions
+    returning a named tuple, plus their registration invocations, in `python_return_types.cpp`.
+    """
+    py_return_types_definition: list[str] = []
+    py_return_types_registrations: list[str] = []
+
+    grouped = group_filter_overloads(pairs, pred)
+
+    for name in sorted(grouped.keys(), key=str):
+        overloads = grouped[name]
+        definitions, registrations = generate_return_type_definition_and_registrations(
+            overloads
+        )
+        py_return_types_definition.append(
+            "" if not definitions else "\n".join(definitions)
+        )
+        py_return_types_registrations.append(
+            "" if not registrations else "\n".join(registrations)
+        )
+
+    fm.write_with_template(
+        filename,
+        filename,
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/{filename}",
+            "py_return_types": py_return_types_definition,
+            "py_return_types_registrations": py_return_types_registrations,
+        },
+    )
+
+
+def create_python_return_type_bindings_header(
+    fm: FileManager,
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    pred: Callable[[NativeFunction], bool],
+    filename: str,
+) -> None:
+    """
+    Generate declarations of the functions that initialize and return the named tuple for
+    native functions returning a named tuple, for use by the map in `python_return_types.cpp`.
+    """
+    py_return_types_declarations: list[str] = []
+
+    grouped = group_filter_overloads(pairs, pred)
+
+    for name in sorted(grouped.keys(), key=str):
+        overloads = grouped[name]
+        declarations = generate_return_type_declarations(overloads)
+        py_return_types_declarations.append(
+            "" if not declarations else "\n".join(declarations)
+        )
+
+    fm.write_with_template(
+        filename,
+        filename,
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/{filename}",
+            "py_return_types_declarations": py_return_types_declarations,
+        },
+    )
+
+
+def create_python_bindings_sharded(
+    fm: FileManager,
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    pred: Callable[[NativeFunction], bool],
+    module: str | None,
+    filename: str,
+    *,
+    method: bool,
+    num_shards: int,
+    symint: bool = True,
+) -> None:
+    """Generates Python bindings to ATen functions"""
+    grouped = group_filter_overloads(pairs, pred)
+
+    def key_func(
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
+    ) -> str:
+        return kv[0].base
+
+    def env_func(
+        kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]],
+    ) -> dict[str, list[str]]:
+        name, fn_pairs = kv
+        return {
+            "ops_headers": [f"#include "],
+            "py_forwards": list(forward_decls(name, fn_pairs, method=method)),
+            "py_methods": [
+                method_impl(name, module, fn_pairs, method=method, symint=symint)
+            ],
+            "py_method_defs": [method_def(name, module, fn_pairs, method=method)],
+        }
+
+    fm.write_sharded(
+        filename,
+        grouped.items(),
+        base_env={
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/{filename}",
+        },
+        key_fn=key_func,
+        env_callable=env_func,
+        num_shards=num_shards,
+        sharded_keys={"ops_headers", "py_forwards", "py_methods", "py_method_defs"},
+    )
+
+
+def load_signatures(
+    native_functions: list[NativeFunction],
+    deprecated_yaml_path: str,
+    *,
+    method: bool,
+    skip_deprecated: bool = False,
+    pyi: bool = False,
+) -> Sequence[PythonSignatureNativeFunctionPair]:
+    @with_native_function
+    def gen_signature_pairs(f: NativeFunction) -> PythonSignatureNativeFunctionPair:
+        return PythonSignatureNativeFunctionPair(
+            signature=signature(f, method=method, pyi=pyi),
+            function=f,
+        )
+
+    pairs = list(map(gen_signature_pairs, native_functions))
+    deprecated = load_deprecated_signatures(
+        pairs, deprecated_yaml_path, method=method, pyi=pyi
+    )
+    return pairs if skip_deprecated else pairs + deprecated
+
+
+def load_deprecated_signatures(
+    pairs: Sequence[PythonSignatureNativeFunctionPair],
+    deprecated_yaml_path: str,
+    *,
+    method: bool,
+    pyi: bool,
+) -> list[PythonSignatureNativeFunctionPair]:
+    # The deprecated.yaml doesn't have complete type information, so we need to
+    # find and leverage the original ATen signature (to which the deprecated op delegates
+    # the call) to generate the full Python signature.
+    # We join the deprecated and the original signatures using type-only form.
+
+    # group the original ATen signatures by name
+    grouped: dict[str, list[PythonSignatureNativeFunctionPair]] = defaultdict(list)
+    for pair in pairs:
+        grouped[pair.signature.name].append(pair)
+
+    # find matching original signatures for each deprecated signature
+    results: list[PythonSignatureNativeFunctionPair] = []
+
+    with open(deprecated_yaml_path) as f:
+        deprecated_defs = yaml.load(f, Loader=YamlLoader)
+
+    for deprecated in deprecated_defs:
+        schema = FunctionSchema.parse(deprecated["name"])
+        aten_name, call_args = split_name_params(deprecated["aten"])
+        is_out = aten_name.endswith("_out")
+        if is_out:
+            aten_name = aten_name.replace("_out", "")
+
+        # HACK: these are fixed constants that get passed to the aten function.
+        # The type must be known ahead of time.
+        known_constants = {
+            "1": Type.parse("Scalar"),
+        }
+        schema_args_by_name = {a.name: a for a in schema.arguments.flat_all}
+        for name in call_args:
+            assert name in schema_args_by_name or name in known_constants, (
+                f"deprecation definition: Unrecognized value {name}"
+            )
+
+        # Map deprecated signature arguments to their aten signature and test
+        # if the types and alias annotation match.
+        def is_schema_compatible(
+            aten_schema: FunctionSchema,
+        ) -> bool:
+            arguments: Iterable[Argument]
+            if is_out:
+                arguments = itertools.chain(
+                    aten_schema.arguments.out, aten_schema.arguments.flat_non_out
+                )
+            else:
+                arguments = aten_schema.arguments.flat_all
+
+            for i, arg in enumerate(arguments):
+                if i < len(call_args):
+                    arg_name = call_args[i]
+                    if arg_name in known_constants:
+                        schema_type = known_constants[arg_name]
+                        schema_annotation = None
+                    else:
+                        schema_arg = schema_args_by_name[arg_name]
+                        schema_type = schema_arg.type
+                        schema_annotation = schema_arg.annotation
+
+                    if schema_type != arg.type or schema_annotation != arg.annotation:
+                        return False
+                else:
+                    if arg.default is None:
+                        return False
+
+            return len(schema.returns) == len(aten_schema.returns) and all(
+                a == b for a, b in zip(schema.returns, aten_schema.returns)
+            )
+
+        any_schema_found = False
+        for pair in grouped[aten_name]:
+            if not is_schema_compatible(pair.function.func):
+                continue
+            any_schema_found = True
+
+            python_sig = signature_from_schema(
+                schema,
+                category_override=pair.function.category_override,
+                method=method,
+                pyi=pyi,
+            )
+
+            results.append(
+                PythonSignatureNativeFunctionPair(
+                    signature=PythonSignatureDeprecated(
+                        name=python_sig.name,
+                        input_args=python_sig.input_args,
+                        input_kwargs=python_sig.input_kwargs,
+                        output_args=python_sig.output_args,
+                        tensor_options_args=python_sig.tensor_options_args,
+                        method=python_sig.method,
+                        deprecated_schema=schema,
+                        deprecated_args_exprs=tuple(call_args),
+                        returns=python_sig.returns,
+                    ),
+                    function=pair.function,
+                )
+            )
+        assert any_schema_found, (
+            f"No native function with name {aten_name} matched signature:\n  {str(schema)}"
+        )
+
+    return results
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                         Named Tuple Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+@with_native_function
+def gen_structseq_typename_key(f: NativeFunction) -> str:
+    name = cpp.name(f.func)
+    fieldnames = structseq_fieldnames(f.func.returns)
+    return "_".join([name] + fieldnames)
+
+
+def emit_structseq_call(
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+) -> tuple[list[str], dict[str, str]]:
+    """
+    Generate the block of named tuple typedef initializations, and add typeref snippets
+    to the declarations that use them.
+    """
+    typenames: dict[
+        str, str
+    ] = {}  # map from unique name + field name lists to typedef name
+    typedefs: list[str] = []  # typedef declarations and init code
+
+    for overload in overloads:
+        fieldnames = structseq_fieldnames(overload.function.func.returns)
+        if not fieldnames:
+            continue
+
+        name = cpp.name(overload.function.func)  # use @with_native_function?
+        tn_key = gen_structseq_typename_key(overload.function)
+        typename = typenames.get(tn_key)
+        if typename is None:
+            typename = f"NamedTuple{'' if not typedefs else len(typedefs)}"
+            typenames[tn_key] = typename
+            typedefs.append(
+                f"""\
+static PyTypeObject* {typename} = generated::get_{name}_structseq();"""
+            )
+
+    return typedefs, typenames
+
+
+def generate_return_type_definition_and_registrations(
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+) -> tuple[list[str], list[str]]:
+    """
+    Generate the block of functions in `python_return_types.cpp` that initialize and
+    return the named tuple for a native function returning a named tuple, along with
+    the registration invocations in the same file.
+    """
+    typenames: dict[
+        str, str
+    ] = {}  # map from unique name + field name lists to typedef name
+    definitions: list[str] = []  # function definition to register the typedef
+    registrations: list[str] = []  # register call for the typedef
+
+    for overload in overloads:
+        fieldnames = structseq_fieldnames(overload.function.func.returns)
+        if not fieldnames:
+            continue
+
+        fields = ", ".join(f'{{"{fn}", ""}}' for fn in fieldnames)
+
+        name = cpp.name(overload.function.func)  # use @with_native_function?
+        tn_key = gen_structseq_typename_key(overload.function)
+        typename = typenames.get(tn_key)
+
+        if typename is None:
+            typename = f"{name}NamedTuple{'' if not definitions else len(definitions)}"
+            typenames[tn_key] = typename
+            definitions.append(
+                f"""\
+PyTypeObject* get_{name}_structseq() {{
+    static PyStructSequence_Field NamedTuple_fields[] = {{ {fields},  {{nullptr}} }};
+    static PyTypeObject {typename};
+    static bool is_initialized = false;
+    static PyStructSequence_Desc desc = {{ "torch.return_types.{name}", nullptr, NamedTuple_fields, {len(fieldnames)} }};
+    if (!is_initialized) {{
+        PyStructSequence_InitType(&{typename}, &desc);
+        {typename}.tp_repr = (reprfunc)torch::utils::returned_structseq_repr;
+        is_initialized = true;
+    }}
+    return &{typename};
+}}
+"""
+            )
+            registrations.append(
+                f'addReturnType(return_types_module, "{name}", generated::get_{name}_structseq());'
+            )
+
+    return definitions, registrations
+
+
+def generate_return_type_declarations(
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+) -> list[str]:
+    """
+    Generate the block of function declarations in `python_return_types.h` for the
+    functions that initialize and return a named tuple for a native function.
+    """
+    typenames: dict[
+        str, str
+    ] = {}  # map from unique name + field name lists to typedef name
+    declarations: list[str] = []  # function declaration to register the typedef
+
+    for overload in overloads:
+        fieldnames = structseq_fieldnames(overload.function.func.returns)
+        if not fieldnames:
+            continue
+
+        name = cpp.name(overload.function.func)  # use @with_native_function?
+        tn_key = gen_structseq_typename_key(overload.function)
+        typename = typenames.get(tn_key)
+
+        if typename is None:
+            typename = (
+                f"{name}NamedTuple{'' if not declarations else len(declarations)}"
+            )
+            typenames[tn_key] = typename
+            declarations.append(f"PyTypeObject* get_{name}_structseq();")
+
+    return declarations
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                         Method Impl Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+# python binding for all overloads of a particular function/method
+PY_VARIABLE_METHOD_VARARGS = CodeTemplate(
+    r"""\
+// ${name}
+static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
+{
+  ${method_header}
+  static PythonArgParser parser({
+    ${signatures}
+  }, /*traceable=*/${traceable});
+
+  ParsedArgs<${max_args}> parsed_args;
+  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
+  ${check_has_torch_function}
+  switch (_r.idx) {
+    ${dispatch}
+  }
+  ${method_footer}
+}
+
+"""
+)
+
+# handler for a single parsed signature - may be a single overload or
+# a pair of overloads whose signatures differ only in output params
+# (plugged into PY_VARIABLE_METHOD_VARARGS as an item in ${dispatch})
+PY_VARIABLE_CASE = CodeTemplate(
+    """\
+case ${overload_index}: {
+  ${body}
+}
+"""
+)
+
+# python binding for single-overload function/method
+PY_VARIABLE_METHOD_VARARGS_SINGLETON = CodeTemplate(
+    """\
+// ${name}
+static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs)
+{
+  ${method_header}
+  static PythonArgParser parser({
+    ${signatures}
+  }, /*traceable=*/${traceable});
+
+  ParsedArgs<${max_args}> parsed_args;
+  auto _r = parser.parse(${self_}, args, kwargs, parsed_args);
+  ${check_has_torch_function}
+  ${dispatch}
+  ${method_footer}
+}
+
+"""
+)
+
+# python binding for a method with no args, shortcuts parsing
+PY_VARIABLE_METHOD_NOARGS = CodeTemplate(
+    """\
+// ${name}
+static PyObject * ${pycname}(PyObject* self_, PyObject* args)
+{
+  ${method_header}
+  ${check_has_torch_function}
+  ${dispatch}
+  ${method_footer}
+}
+
+"""
+)
+
+
+def method_impl(
+    name: BaseOperatorName,
+    module: str | None,
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+    *,
+    method: bool,
+    symint: bool = True,
+) -> str:
+    """
+    Generate a python binding for all overloads of an op.
+    """
+    pycname = get_pycname(name)
+    noarg = is_noarg(overloads)
+    structseq_inits, structseq_typenames = emit_structseq_call(overloads)
+
+    method_header = ["HANDLE_TH_ERRORS"]
+    method_header += structseq_inits
+    method_header += (
+        ["const Tensor& self = THPVariable_Unpack(self_);"] if method else []
+    )
+
+    method_footer = ([] if noarg else ["Py_RETURN_NONE;"]) + ["END_HANDLE_TH_ERRORS"]
+
+    traceable = "true" if all(should_trace(o.function) for o in overloads) else "false"
+
+    grouped_overloads: Sequence[PythonSignatureGroup] = group_overloads(
+        overloads, symint=symint
+    )
+    is_singleton = len(grouped_overloads) == 1
+    signatures: list[str] = []
+    dispatch: list[str] = []
+    for overload_index, overload in enumerate(grouped_overloads):
+        signature = overload.signature.signature_str(symint=symint)
+        signatures.append(f"{cpp_string(str(signature))},")
+        dispatch_body = emit_dispatch_case(overload, structseq_typenames, symint=symint)
+        dispatch.append(
+            PY_VARIABLE_CASE.substitute(
+                overload_index=overload_index, body=dispatch_body
+            )
+            if not is_singleton
+            else dispatch_body
+        )
+
+    if noarg:
+        template = PY_VARIABLE_METHOD_NOARGS
+    elif is_singleton:
+        template = PY_VARIABLE_METHOD_VARARGS_SINGLETON
+    else:
+        template = PY_VARIABLE_METHOD_VARARGS
+
+    return template.substitute(
+        name=name,
+        pycname=pycname,
+        method_header=method_header,
+        max_args=max(o.signature.arguments_count() for o in overloads),
+        signatures=signatures,
+        traceable=traceable,
+        check_has_torch_function=gen_has_torch_function_check(
+            name=name,
+            module=module,
+            noarg=noarg,
+            method=method,
+        ),
+        dispatch=dispatch,
+        method_footer=method_footer,
+        self_="self_" if method else "nullptr",
+    )
+
+
+def gen_has_torch_function_check(
+    name: BaseOperatorName, module: str | None, *, noarg: bool, method: bool
+) -> str:
+    if noarg:
+        if method:
+            return f"""\
+if(check_has_torch_function(self_)) {{
+  return handle_torch_function(self_, "{name}");
+}}
+"""
+        else:
+            return ""
+
+    self_ = "self_" if method else "nullptr"
+    namespace = (
+        {
+            "torch": "THPVariableFunctionsModule",
+            "torch.nn": "THPNNVariableFunctionsModule",
+            "torch.fft": "THPFFTVariableFunctionsModule",
+            "torch.linalg": "THPLinalgVariableFunctionsModule",
+            "torch.nested": "THPNestedVariableFunctionsModule",
+            "torch.sparse": "THPSparseVariableFunctionsModule",
+            "torch.special": "THPSpecialVariableFunctionsModule",
+        }[module]
+        if module
+        else "THPVariableClass"
+    )
+
+    return f"""\
+if(_r.has_torch_function()) {{
+  return handle_torch_function(_r, {self_}, args, kwargs, {namespace}, "{module or "torch.Tensor"}");
+}}
+"""
+
+
+# handler for output/no-output overload pair
+PY_VARIABLE_OUT = CodeTemplate(
+    """\
+if (_r.isNone(${out_idx})) {
+  ${call_dispatch}
+} else {
+  ${call_dispatch_out}
+}
+"""
+)
+
+
+def emit_dispatch_case(
+    overload: PythonSignatureGroup,
+    structseq_typenames: dict[str, str],
+    *,
+    symint: bool = True,
+) -> str:
+    """
+    Emit dispatch code for a single parsed signature. This corresponds to either
+    a single native function, or a pair that differ only in output params. In the
+    latter case, a single python signature is used for both and dispatching
+    switches on the presence/absence of passed output args.
+    """
+    if overload.outplace is not None:
+        # dispatch output and no-output variants, branch on _r.isNone()
+        return PY_VARIABLE_OUT.substitute(
+            out_idx=overload.signature.output_idx(),
+            call_dispatch=emit_single_dispatch(
+                overload.signature, overload.base, structseq_typenames, symint=symint
+            ),
+            call_dispatch_out=emit_single_dispatch(
+                overload.signature,
+                overload.outplace,
+                structseq_typenames,
+                symint=symint,
+            ),
+        )
+    else:
+        # no-output version only
+        return emit_single_dispatch(
+            overload.signature, overload.base, structseq_typenames, symint=symint
+        )
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                    Forward Declarations Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def forward_decls(
+    name: BaseOperatorName,
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+    *,
+    method: bool,
+) -> tuple[str, ...]:
+    if method:
+        return ()
+
+    pycname = get_pycname(name)
+    if is_noarg(overloads):
+        return (
+            f"""\
+static PyObject * {pycname}(PyObject* self_, PyObject* args);
+""",
+        )
+    else:
+        return (
+            f"""\
+static PyObject * {pycname}(PyObject* self_, PyObject* args, PyObject* kwargs);
+""",
+        )
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#              Method Def (Binding Table Entry) Codegen
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def method_def(
+    name: BaseOperatorName,
+    module: str | None,
+    overloads: Sequence[PythonSignatureNativeFunctionPair],
+    *,
+    method: bool,
+) -> str:
+    """
+    Generate method def entry.
+    """
+    pycname = get_pycname(name)
+
+    if name.dunder_method:
+        # PyMethodDef entry for binary op, throws not implemented error
+        pycname = f"TypeError_to_NotImplemented_<{pycname}>"
+
+    if is_noarg(overloads):
+        flags = "METH_NOARGS" if method else "METH_VARARGS | METH_KEYWORDS"
+    else:
+        pycname = f"castPyCFunctionWithKeywords({pycname})"
+        flags = "METH_VARARGS | METH_KEYWORDS"
+
+    if module == "torch":
+        flags += " | METH_STATIC"
+
+    return f'{{"{name}", {pycname}, {flags}, nullptr}},'
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                   Overload Sorting and Grouping
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def group_overloads(
+    overloads: Sequence[PythonSignatureNativeFunctionPair], *, symint: bool = True
+) -> Sequence[PythonSignatureGroup]:
+    bases: dict[str, PythonSignatureNativeFunctionPair] = {}
+    outplaces: dict[str, PythonSignatureNativeFunctionPair] = {}
+
+    # first group by signature ignoring out arguments
+    for overload in overloads:
+        sig = overload.signature.signature_str(skip_outputs=True, symint=symint)
+        if overload.function.func.is_out_fn():
+            if sig in outplaces:
+                raise RuntimeError(
+                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
+                    f"Existing definition:\n- {outplaces[sig].function.func}."
+                )
+            outplaces[sig] = overload
+        else:
+            if sig in bases:
+                raise RuntimeError(
+                    f"Found duplicated function definition:\n- {overload.function.func}.\n"
+                    f"Existing definition:\n- {bases[sig].function.func}."
+                )
+            bases[sig] = overload
+
+    for sig, out in outplaces.items():
+        if sig not in bases:
+            candidates: list[str] = []
+            for overload in overloads:
+                if (
+                    str(overload.function.func.name.name)
+                    == str(out.function.func.name.name)
+                    and not overload.function.func.is_out_fn()
+                    and not overload.signature.deprecated
+                ):
+                    candidates.append(
+                        overload.signature.signature_str(
+                            skip_outputs=True, symint=symint
+                        )
+                    )
+            out_sig = out.signature.signature_str(symint=symint)
+            raise RuntimeError(
+                f"While identifying overloads, we found an out schema {out_sig} without a corresponding non-out variant. "
+                f"We expected the non-out variant to have schema: \n- {sig}\nPlease check that you spelled the schema "
+                "correctly in native_functions.yaml. We discovered the following candidate(s): \n"
+                + "\n".join(f"- {candidate}" for candidate in candidates)
+            )
+
+    grouped = [
+        PythonSignatureGroup.from_pairs(
+            functional=base,
+            out=outplaces.get(sig),
+        )
+        for sig, base in bases.items()
+    ]
+    return sort_overloads(grouped, symint=symint)
+
+
+# This function declares a partial order on declarations, and sorts them according
+# to its linear extension. This is necessary because there's some ambiguity in the
+# choice of overload, and we want to control which one is tried first.
+#
+# See Note[Order of overloads matters]
+#
+# A few examples of ambiguous python signature pairs.
+#
+#   All parameters have the same type, except one takes Tensor while the other takes
+#   Scalar. A numeric PyObject can be cast into a Tensor, and a zero-dim Tensor
+#   object can be accepted as a Scalar-type parameter (see python_arg_parser.cpp).
+#   Therefore, the same input arguments might be accepted by either python signature.
+#   We want to always parse the one taking Tensor first.
+#
+#     bitwise_and(Tensor input, Tensor other, *, Tensor out=None)
+#     bitwise_and(Tensor input, Scalar other, *, Tensor out=None)
+#
+#   If they have a different number of parameters then they are not ambiguous - but
+#   the difference in the output param can be ignored, as it's optional.
+#
+#     multiply(Tensor input, Tensor other, *, Tensor out=None)
+#     multiply(Tensor input, Scalar other)
+#
+#   Both positional args and keyword-only args are considered together.
+#
+#     subtract(Tensor other, *, Scalar alpha=1)
+#     subtract(Scalar other, Scalar alpha=1)
+#
+# A few ambiguous cases which it does NOT handle yet.
+#
+#   If there is any difference in other parameters besides the Tensor/Scalar
+#   difference, then they are not considered ambiguous by this method anymore.
+#   However, the difference could be too trivial to disambiguate.
+#
+#     foo(Tensor input, Scalar other, Scalar bar)
+#     foo(Tensor input, Tensor other, double bar)
+#
+#   If they take a different number of parameters then they are not considered
+#   ambiguous anymore, even if the difference is only in optional kwargs.
+#
+#     foo(Scalar other, Scalar alpha=1)
+#     foo(Tensor other, *, Scalar alpha=1, Scalar beta=1)
+#
+
+
+def sort_overloads(
+    grouped_overloads: Sequence[PythonSignatureGroup], *, symint: bool = True
+) -> Sequence[PythonSignatureGroup]:
+    # NB: Smaller here means lower priority
+
+    def is_arg_smaller(t1: Type, t2: Type) -> bool:
+        return (
+            str(t1) == "Scalar"
+            and str(t2) == "Tensor"
+            or str(t1) == "Scalar?"
+            and str(t2) == "Tensor?"
+            or "Dimname" in str(t1)
+            and "Dimname" not in str(t2)
+            or
+            # In the discussion https://github.com/pytorch/pytorch/issues/54555 it has been
+            # discussed why it is important to prioritize int/int? over int[]
+            str(t1) == "int[]"
+            and (str(t2) == "int" or str(t2) == "int?")
+            or
+            # TensorList currently throws an error during argument parsing, that's why it needs to be
+            # last in signature ordering. See discussion: https://github.com/pytorch/pytorch/issues/58087
+            str(t1) == "Tensor[]"
+            and str(t2).find("[]") != -1
+            or
+            # Prioritize IntArrayRef overload over SymIntArrayRef
+            str(t1) == "SymInt[]"
+            and str(t2) == "int[]"
+            or
+            # Make sure both int and SymInt are sorted consistently w.r.t. Tensor since Tensor can be implicitly
+            # converted to either int or SymInt.  Prioritize the Tensor overload since it otherwise gets shadowed.
+            (str(t1) == "SymInt" or str(t1) == "int")
+            and str(t2) == "Tensor"
+        )
+
+    def is_smaller(s1: PythonSignature, s2: PythonSignature) -> bool:
+        """Returns True if s1 < s2 in the partial order."""
+        args1, args2 = s1.arguments(skip_outputs=True), s2.arguments(skip_outputs=True)
+        if len(args1) != len(args2):
+            return False
+        # TODO: should use some canonical form instead of 'str(arg.type)' - see comments
+        # above. The old codegen used the deprecated 'dynamic_type(arg.type)', which
+        # ignores the optional annotation, i.e. 'Scalar' and 'Scalar?'.
+        equal = all(arg1.type == arg2.type for arg1, arg2 in zip(args1, args2))
+        smaller_or_equal = all(
+            str(arg1.type) == str(arg2.type) or is_arg_smaller(arg1.type, arg2.type)
+            for arg1, arg2 in zip(args1, args2)
+        )
+        return smaller_or_equal and not equal
+
+    # First sort by signature
+    grouped_overloads = sorted(
+        grouped_overloads, key=lambda x: x.signature.signature_str(symint=symint)
+    )
+
+    # Construct the relation graph
+    larger_than: dict[int, set[int]] = defaultdict(set)
+    for i1, overload1 in enumerate(grouped_overloads):
+        for i2, overload2 in enumerate(grouped_overloads):
+            if is_smaller(overload1.signature, overload2.signature):
+                larger_than[i1].add(i2)
+
+    if not larger_than:
+        return list(grouped_overloads)
+
+    # Use a topological sort to sort overloads according to the partial order.
+    N = len(grouped_overloads)
+    sorted_ids: list[int] = list(filter(lambda x: x not in larger_than, range(N)))
+
+    for idx in range(N):
+        # The size of sorted_ids will grow to N eventually.
+        i = sorted_ids[idx]
+        for j in sorted(larger_than.keys()):
+            larger = larger_than[j]
+            larger.discard(i)
+            if not larger:
+                del larger_than[j]
+                sorted_ids.append(j)
+
+    return [grouped_overloads[x] for x in sorted_ids]
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+#
+#                       Codegen API Integration
+#
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
+
+
+def emit_single_dispatch(
+    ps: PythonSignature,
+    f: NativeFunction,
+    structseq_typenames: dict[str, str],
+    *,
+    symint: bool = True,
+) -> str:
+    """
+    Emit dispatch code for a single native function.
+    """
+
+    @with_native_function
+    def go(f: NativeFunction) -> str:
+        # header comments
+        if isinstance(ps, PythonSignatureDeprecated):
+            schema_comment = f"// [deprecated] aten::{ps.deprecated_schema}"
+        else:
+            schema_comment = f"// aten::{f.func}"
+
+        # dispatch lambda signature
+        name = cpp.name(f.func)
+        lambda_formals = ", ".join(
+            f"{a.type_str} {a.name}" for a in dispatch_lambda_args(ps, f, symint=symint)
+        )
+        lambda_return = dispatch_lambda_return_str(f)
+
+        # dispatch lambda body
+        dispatch_callee = cpp_dispatch_target(f)
+        dispatch_args = ", ".join(cpp_dispatch_exprs(f, python_signature=ps))
+
+        # from arg parser outputs to dispatch lambda arguments
+        parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
+        lambda_arg_exprs = dispatch_lambda_exprs(ps, f, symint=symint)
+        inits = "\n".join(lambda_arg_exprs.inits)
+        lambda_args = ", ".join(lambda_arg_exprs.exprs)
+
+        # scatter fields
+        # TODO: Checking `ps.method and ('requires_grad' in parser_outputs)` is a hacky
+        #       solution for enabling the 'requires_grad' argument for tensor methods
+        #       new_full, new_empty, and new_zeros. A much better but more difficult to
+        #       implement solution involves refactoring according to Ed's description here:
+        #       https://github.com/pytorch/pytorch/issues/36455#issuecomment-614767589
+        need_set_requires_grad = ps.tensor_options_args and (
+            not has_tensor_options(f)
+            or (ps.method and ("requires_grad" in parser_outputs))
+        )
+        set_requires_grad = (
+            f".set_requires_grad({parser_outputs['requires_grad'].expr})"
+            if need_set_requires_grad
+            else ""
+        )
+
+        if lambda_return == "void":
+            # Make in-place foreach return `self` at python-binding level.
+            # ref: https://github.com/pytorch/pytorch/pull/118622#pullrequestreview-1904804954
+            self_arg = f.func.arguments.self_arg
+            return_stmt: str
+            if (
+                str(f.func.name).startswith("_foreach_")
+                and f.func.kind() == SchemaKind.inplace
+            ):
+                # note(crcrpar): `_foreach_pow.ScalarAndTensor` does NOT have its in-place
+                # variant and is unlikely to have one in the future. Thus it's safe to have the following assert.
+                assert self_arg is not None and is_tensor_list_type(
+                    self_arg.argument.type
+                )
+                return_stmt = """PyObject* self_tensorlist = _r.args[0];
+Py_INCREF(self_tensorlist);
+return self_tensorlist;
+"""
+            else:
+                return_stmt = "Py_RETURN_NONE;"
+            return f"""\
+{schema_comment}
+{inits}
+auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
+  pybind11::gil_scoped_release no_gil;
+  {dispatch_callee}({dispatch_args});
+}};
+dispatch_{name}({lambda_args}){set_requires_grad};
+{return_stmt}
+"""
+        else:
+            typename = structseq_typenames.get(gen_structseq_typename_key(f))
+            structseq_typeref = f"{typename}, " if typename is not None else ""
+            return f"""\
+{schema_comment}
+{inits}
+auto dispatch_{name} = []({lambda_formals}) -> {lambda_return} {{
+  pybind11::gil_scoped_release no_gil;
+  return {dispatch_callee}({dispatch_args});
+}};
+return wrap({structseq_typeref}dispatch_{name}({lambda_args}){set_requires_grad});
+"""
+
+    return go(f)
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_trace_type.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_trace_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..3351e23efe45dd3a312738d6ffb1c58617e1eb28
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_trace_type.py
@@ -0,0 +1,539 @@
+from __future__ import annotations
+
+import itertools
+from typing import TYPE_CHECKING
+
+from torchgen.api import cpp
+from torchgen.api.types import DispatcherSignature
+from torchgen.code_template import CodeTemplate
+from torchgen.context import with_native_function
+from torchgen.model import Argument, NativeFunction, SchemaKind, TensorOptionsArguments
+from torchgen.utils import FileManager
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# Note [Manual Backend kernels]
+# For these ops, we want to manually register to dispatch key Backend and
+# skip codegen-ed registration to all keys before Backend.
+# For codegen this means:
+#   - op set below must match ops with manual_kernel_registration=True in native_functions.yaml
+#     where we skip codegen backend kernels
+#   - all ops below are part of MANUAL_AUTOGRAD to skip codegen Autograd kernel registration
+#   - all ops below are part of MANUAL_TRACER to skip codegen Tracer kernel registration
+# Note: we still register to dispatch key Profiler for these ops, keeping it untouched for now.
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_BACKEND = {
+    "options",
+    "data",
+    "set_data",
+    "is_leaf",
+    "output_nr",
+    "_version",
+    "retain_grad",
+    "_backward",
+    "requires_grad_",
+}
+
+# For these ops we want to skip the codegen-ed registration to both Autograd and Tracer keys.
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_AUTOGRAD_AND_TRACER = {
+    "resize_",
+    "resize_as_",
+    "detach",
+    "detach_",
+    "copy_",
+    "_fw_primal",
+    "_make_dual",
+}
+
+# Currently MANUAL_AUTOGRAD and MANUAL_TRACER share the same set of ops:
+#   union(MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER)
+# You can find the manual registration in torch/csrc/autograd/VariableTypeManual.cpp
+MANUAL_AUTOGRAD = MANUAL_TRACER = MANUAL_BACKEND | MANUAL_AUTOGRAD_AND_TRACER
+
+# These functions we don't want to record for tracing, because we always want
+# to trace their constituent parts.  This is a temporary hack in lieu
+# of proper scopes, where subsequent compilation passes can ask for the unfolding
+# on demand.  Only concrete ATen methods can be disabled this way; it will have
+# NO EFFECT otherwise.
+DONT_RECORD_TRACE = {
+    "convolution",
+    "conv1d",
+    "conv2d",
+    "conv3d",
+    "conv_transpose1d",
+    "conv_transpose2d",
+    "conv_transpose3d",
+    "lstm_cell",
+    "gru_cell",
+    "rnn_tanh_cell",
+    "rnn_relu_cell",
+    # FIXME: figure out a better way when we support sparse tensors in jit
+    "_coalesced",
+}
+
+
+def should_trace(f: NativeFunction) -> bool:
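+    # Descriptive note (added for clarity): an op is traced only when it takes no
+    # Storage/Type arguments, returns at least one Tensor-like value, and its base
+    # name is not listed in DONT_RECORD_TRACE above.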
+    # Operations involving Storage or Type are not traceable at the moment
+    if any(
+        str(arg.type) in {"Storage", "Type"} for arg in f.func.schema_order_arguments()
+    ):
+        return False
+    # We can't trace functions which don't have any Tensor or TensorList returns
+    if not any(r.type.is_tensor_like() for r in f.func.returns):
+        return False
+    return f.func.name.name.base not in DONT_RECORD_TRACE
+
+
+SELECT = CodeTemplate(
+    """\
+
+if (${cond}) {
+  ${true}
+} else {
+  ${false}
+}
+"""
+)
+
+OP_NAME = CodeTemplate(
+    """\
+op_name = c10::Symbol::fromQualString("aten::${trace_name}");
+"""
+)
+
+# These functions have their names renamed when they are recorded in the trace.
+RENAME_TRACE = {
+    "zero": "zeros_like",  # replacing aten::zero_ with aten::zeros_like
+    "fill": "full_like",  # replacing aten::fill_ with aten::full_like
+}
+
+
+def format_trace_op_name(f: NativeFunction) -> str:
+    # TODO: byte-for-byte compatible with old codegen behavior - should clean up
+    if (
+        f.func.kind() in (SchemaKind.functional, SchemaKind.out)
+        or f.func.name.name.dunder_method
+    ):
+        # special case for *_out functions: the in-place and out-of-place ops
+        # are overloaded with the same name in the JIT
+        trace_name = str(f.func.name.name)
+        trace_name = RENAME_TRACE.get(trace_name, trace_name)
+        return OP_NAME.substitute(trace_name=trace_name)
+
+    # otherwise, this is an in-place op and we need to emit both in- and
+    # out-of-place versions
+    outplace_trace_name = f.func.name.name.base
+    inplace_trace_name = cpp.name(f.func)
+    outplace_trace_name = RENAME_TRACE.get(outplace_trace_name, outplace_trace_name)
+    inplace_trace_name = RENAME_TRACE.get(inplace_trace_name, inplace_trace_name)
+
+    return SELECT.substitute(
+        cond="tracer_state->force_outplace",
+        true=OP_NAME.substitute(trace_name=outplace_trace_name),
+        false=OP_NAME.substitute(trace_name=inplace_trace_name),
+    )
+
+
+ADD_TRACE_INPUT = CodeTemplate("""jit::tracer::addInputs(node, "${name}", ${input});""")
+
+
+def format_trace_inputs(f: NativeFunction) -> str:
+    def dispatch_trace_input(arg: Argument | TensorOptionsArguments) -> Sequence[str]:
+        if isinstance(arg, TensorOptionsArguments):
+            name = "options"
+            return [
+                ADD_TRACE_INPUT.substitute(
+                    name=name, input="c10::optTypeMetaToScalarType(options.dtype_opt())"
+                ),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.layout()"),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.device()"),
+                ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()"),
+            ]
+        else:
+            name = arg.name
+            if str(arg.type) == "Tensor?[]":
+                return [f'jit::tracer::addInputs(node, "{name}", {name});']
+            else:
+                return [ADD_TRACE_INPUT.substitute(name=name, input=name)]
+
+    args: list[Argument | TensorOptionsArguments] = list(
+        f.func.schema_order_arguments()
+    )
+
+    if f.func.is_out_fn():
+        # *_out functions take the result as a separate argument, but we don't want to
+        # trace that argument directly. Instead, we trace its TensorOptions.
+        # So first, we need to remove the out argument from the list of arguments to trace.
+        num_out_args = len(f.func.arguments.out)
+        args = args[:-num_out_args]
+
+    trace_inputs = itertools.chain.from_iterable(
+        dispatch_trace_input(arg) for arg in args
+    )
+
+    if f.func.is_out_fn():
+        # for *_out functions, handle the result argument differently for inplace/outplace.
+        # For inplace: just add the input to the end to conform to the JIT schema.
+        inplace = [
+            ADD_TRACE_INPUT.substitute(
+                name=f.func.arguments.out[i].name, input=f.func.arguments.out[i].name
+            )
+            for i in range(num_out_args)
+        ]
+
+        # for outplace: do nothing, except if the function is a factory.
+        # Factories are a bit special because their out-of-place overloads
+        # take an extra TensorOptions argument, which is missing in the _out function
+        has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
+        has_tensor_input_arg = any(
+            a.type.is_tensor_like() for a in f.func.arguments.flat_non_out
+        )
+        is_factory_method = f.category_override == "factory" or (
+            has_tensor_return and not has_tensor_input_arg
+        )
+
+        # HACK: preserve old codegen behavior - the old codegen set the `is_factory_method`
+        # flag for the whole family of ops with the same basename if any of them is a
+        # factory method. In most cases the whole family of ops are indeed all factory
+        # methods - 'normal' is the only exception. So we handle it specially here to avoid
+        # cloning the old logic.
+        if f.func.name.name.base == "normal":
+            is_factory_method = True
+
+        if is_factory_method:
+            outplace = [
+                ADD_TRACE_INPUT.substitute(
+                    name="out",
+                    input="c10::optTypeMetaToScalarType(out.options().dtype_opt())",
+                ),
+                ADD_TRACE_INPUT.substitute(name="out", input="out.options().layout()"),
+                ADD_TRACE_INPUT.substitute(name="out", input="out.options().device()"),
+                ADD_TRACE_INPUT.substitute(
+                    name="out", input="out.options().pinned_memory()"
+                ),
+            ]
+        else:
+            outplace = []
+
+        trace_inputs = itertools.chain(
+            trace_inputs,
+            [
+                SELECT.substitute(
+                    cond="tracer_state->force_outplace",
+                    true="\n".join(outplace),
+                    false="\n".join(inplace),
+                )
+            ],
+        )
+
+    return "\n".join(trace_inputs)
+
+
+# `torch.jit.trace` has an undocumented keyword argument `_force_outplace`,
+# which forces the JIT to replace functions with their out-of-place variants
+# (for example `aten::add_` becomes `aten::add`).
+#
+# This replacement is implemented in place with minimal modification of the
+# argument stack (it assumes that the out-of-place call takes the same
+# arguments as the in-place version).
+#
+# However, no such substitutions are available for the `aten::fill_`
+# and `aten::zero_` operators, because `aten::fill` and `aten::zero` were
+# never implemented. So the JIT tracing hack replaces `aten::zero_` with
+# `aten::zeros_like` and `aten::fill_` with `aten::full_like`.
+#
+# But since the replacements can take different arguments, we also have
+# to hack into the stack and add the missing ones.
+#
+# A possible alternative would be:
+#
+#  - Add `aten::fill` and `aten::zero`
+#
+#  - Or keep `aten::zeros_like` arguments aligned with `aten::zero_`
+# arguments (inside of the `native_functions.yaml`)
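+#
+# Illustrative example (not part of the original comment): when tracing
+# `x.zero_()` with `_force_outplace=True`, the node is recorded as
+# `aten::zeros_like`, whose schema also expects dtype/layout/device/pin_memory
+# and a memory_format argument, so the extra "options" and "memory_format"
+# inputs below are added to the node.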
+RENAME_TRACE_ADD_ARGS = {
+    "fill": """\
+    jit::tracer::addInputs(node, "options", ::std::optional());
+    jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
+    jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
+    jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
+    ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
+    jit::tracer::addInputs(node, "memory_format", memory_format);
+""",
+    "zero": """\
+    jit::tracer::addInputs(node, "options", ::std::optional());
+    jit::tracer::addInputs(node, "options", layout_or_default(::std::nullopt));
+    jit::tracer::addInputs(node, "options", device_or_default(::std::nullopt));
+    jit::tracer::addInputs(node, "options", pinned_memory_or_default(::std::nullopt));
+    ::std::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
+    jit::tracer::addInputs(node, "memory_format", memory_format);
+""",
+}
+
+INPLACE_GUARD = CodeTemplate(
+    """\
+jit::tracer::ensureUniqueIfOutOfPlaced("${name}", ${mutable_input});
+"""
+)
+
+PRE_RECORD_TRACE = CodeTemplate(
+    """\
+torch::jit::Node* node = nullptr;
+std::shared_ptr<jit::tracer::TracingState> tracer_state;
+if (jit::tracer::isTracing()) {
+  tracer_state = jit::tracer::getTracingState();
+  at::Symbol op_name;
+  ${set_op_name}
+  node = tracer_state->createNode(op_name, /*num_outputs=*/0);
+  jit::tracer::recordSourceLocation(node);
+  ${add_trace_inputs}
+  tracer_state->insertNode(node);
+  ${inplace_guard}
+  jit::tracer::setTracingState(nullptr);
+}
+"""
+)
+
+
+def format_prerecord_trace(f: NativeFunction) -> str:
+    if not should_trace(f):
+        return ""
+
+    # TODO: clean up old codegen behavior
+    is_inplace = (
+        f.func.kind() in (SchemaKind.inplace, SchemaKind.out)
+        and not f.func.name.name.dunder_method
+    )
+    add_args = (
+        RENAME_TRACE_ADD_ARGS.get(f.func.name.name.base, "") if is_inplace else ""
+    )
+    additional_inputs = (
+        SELECT.substitute(
+            cond="tracer_state->force_outplace",
+            true=add_args,
+            false="",
+        )
+        if add_args
+        else ""
+    )
+
+    return PRE_RECORD_TRACE.substitute(
+        set_op_name=format_trace_op_name(f),
+        add_trace_inputs=format_trace_inputs(f) + additional_inputs,
+        inplace_guard=INPLACE_GUARD.substitute(
+            name=cpp.name(f.func),
+            mutable_input=f.func.arguments.out[0].name
+            if f.func.arguments.out
+            else "self",
+        )
+        if is_inplace
+        else "",
+    )
+
+
+POST_RECORD_TRACE = CodeTemplate(
+    """\
+if (tracer_state) {
+  jit::tracer::setTracingState(std::move(tracer_state));
+  ${add_trace_outputs}
+}
+"""
+)
+
+
+def format_postrecord_trace(f: NativeFunction) -> str:
+    if not should_trace(f):
+        return ""
+
+    # For outplacing ops, *_out overloads require special handling to move the
+    # output *argument* to a return value
+    if f.func.is_out_fn():
+        output_names_outplace = [arg.name for arg in f.func.arguments.out]
+        output_names_inplace = cpp.return_names(f)
+
+        # Code size optimization: the common case is that the return value is
+        # the same for both variants
+        if output_names_outplace == output_names_inplace:
+            outputs = [
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace
+            ]
+            return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
+
+        selection = SELECT.substitute(
+            cond="force_outplace",
+            true="\n".join(
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_outplace
+            ),
+            false="\n".join(
+                f"jit::tracer::addOutput(node, {n});" for n in output_names_inplace
+            ),
+        )
+        return POST_RECORD_TRACE.substitute(add_trace_outputs=selection)
+    else:
+        output_names = cpp.return_names(f)
+        outputs = [f"jit::tracer::addOutput(node, {n});" for n in output_names]
+        return POST_RECORD_TRACE.substitute(add_trace_outputs=outputs)
+
+
+def tie_return_values(f: NativeFunction) -> str:
+    if len(f.func.returns) == 1:
+        return f"auto {f.func.returns[0].name or 'result'}"
+    names = cpp.return_names(f)
+    return f"auto [{', '.join(names)}]"
+
+
+def get_return_value(f: NativeFunction) -> str:
+    names = cpp.return_names(f)
+    if len(f.func.returns) == 1:
+        return names[0]
+    if f.func.kind() == SchemaKind.out:
+        return f"std::forward_as_tuple({', '.join(names)})"
+    else:
+        moved = ", ".join(f"std::move({name})" for name in names)
+        return f"std::make_tuple({moved})"
+
+
+TRACE_DISPATCH = CodeTemplate(
+    """\
+${assign_return_values}at::_ops::${unambiguous_name}::redispatch(${unpacked_args});"""
+)
+
+
+def emit_trace_body(f: NativeFunction) -> list[str]:
+    trace_body: list[str] = []
+
+    trace_body.append(format_prerecord_trace(f))
+
+    dispatcher_sig = DispatcherSignature.from_schema(f.func)
+    dispatcher_exprs = dispatcher_sig.exprs()
+
+    # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
+    # See Note [Plumbing Keys Through The Dispatcher] for details.
+    dispatch_key_set = "ks & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Tracer)"
+    redispatch_args = ", ".join([dispatch_key_set] + [a.expr for a in dispatcher_exprs])
+
+    assign_return_values = (
+        f"{tie_return_values(f)} = "
+        if f.func.kind() in [SchemaKind.functional, SchemaKind.mutable]
+        and f.func.returns
+        else ""
+    )
+
+    # Note that this calls the slow, dispatching variants of manual_cpp_binding ops.
+    # We could probably work harder to ensure that the fast variants are
+    # called instead, but the perf benefit would be minimal.
+    trace_body.append(
+        TRACE_DISPATCH.substitute(
+            assign_return_values=assign_return_values,
+            unambiguous_name=f.func.name.unambiguous_name(),
+            unpacked_args=redispatch_args,
+        )
+    )
+
+    trace_body.append(format_postrecord_trace(f))
+    if f.func.returns:
+        trace_body.append(f"return {get_return_value(f)};")
+    return trace_body
+
+
+METHOD_DEFINITION = CodeTemplate(
+    """\
+${return_type} ${type_wrapper_name}(${formals}) {
+  ${type_definition_body}
+}
+"""
+)
+
+
+def type_wrapper_name(f: NativeFunction, key: str = "Default") -> str:
+    if f.func.name.overload_name:
+        name = f"{cpp.name(f.func)}_{f.func.name.overload_name}"
+    else:
+        name = cpp.name(f.func)
+
+    # The key argument is only used in gen_variable_type where we need fns per autograd dispatch key.
+    # In gen_trace_type and gen_inplace_view_type where only one fn per native_fn must be generated,
+    # the key argument should not be passed.
+    # We do not append key if it is Default so that generated functions from
+    # before per-dispatch-key derivatives were added retain the same names.
+    if key != "Default":
+        name = name + f"_{key}"
+    return name
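+    # Illustrative examples (hypothetical inputs, added for clarity):
+    #   add.Tensor with the default key     -> "add_Tensor"
+    #   add.Tensor with key "AutogradCUDA"  -> "add_Tensor_AutogradCUDA"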
+
+
+@with_native_function
+def method_definition(f: NativeFunction) -> str:
+    assert cpp.name(f.func) not in MANUAL_TRACER
+
+    formals = ", ".join(
+        # code-generated tracing kernels plumb and recompute dispatch keys directly through the kernel for performance.
+        # See Note [Plumbing Keys Through The Dispatcher] for details.
+        ["c10::DispatchKeySet ks"]
+        + [
+            f"{cpp.argument_type(a, binds='__placeholder__', symint=True).cpp_type()} {a.name}"
+            for a in f.func.schema_order_arguments()
+        ]
+    )
+
+    return METHOD_DEFINITION.substitute(
+        return_type=cpp.returns_type(f.func.returns, symint=True).cpp_type(),
+        type_wrapper_name=type_wrapper_name(f),
+        formals=formals,
+        type_definition_body=emit_trace_body(f),
+    )
+
+
+WRAPPER_REGISTRATION = CodeTemplate(
+    """\
+m.impl("${name}",
+       TORCH_FN(${class_type}::${type_wrapper_name})
+);
+"""
+)
+
+
+@with_native_function
+def method_registration(f: NativeFunction) -> str:
+    assert cpp.name(f.func) not in MANUAL_TRACER
+
+    return WRAPPER_REGISTRATION.substitute(
+        name=f.func.name,
+        type_wrapper_name=type_wrapper_name(f),
+        class_type="TraceType",
+    )
+
+
+def gen_trace_type_func(fn: NativeFunction) -> dict[str, list[str]]:
+    return {
+        "ops_headers": [f"#include "],
+        "trace_method_definitions": [method_definition(fn)],
+        "trace_wrapper_registrations": [method_registration(fn)],
+    }
+
+
+def gen_trace_type(
+    out: str, native_functions: list[NativeFunction], template_path: str
+) -> None:
+    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
+    # template regarding sharding of the generated files.
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    fm.write_sharded(
+        "TraceType.cpp",
+        [fn for fn in native_functions if cpp.name(fn.func) not in MANUAL_TRACER],
+        key_fn=lambda fn: fn.root_name,
+        base_env={
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/TraceType.cpp",
+        },
+        env_callable=gen_trace_type_func,
+        num_shards=5,
+        sharded_keys={
+            "ops_headers",
+            "trace_method_definitions",
+            "trace_wrapper_registrations",
+        },
+    )
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_factories.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_factories.py
new file mode 100644
index 0000000000000000000000000000000000000000..4446e8a615eadd5c8d5ca394f8b9dadbe26cd0ad
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_factories.py
@@ -0,0 +1,116 @@
+# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables.
+#
+# This writes one file: variable_factories.h
+
+from __future__ import annotations
+
+import re
+
+import torchgen.api.python as python
+from torchgen.api import cpp
+from torchgen.api.types import CppSignatureGroup
+from torchgen.context import with_native_function
+from torchgen.gen import parse_native_yaml
+from torchgen.model import NativeFunction, TensorOptionsArguments, Variant
+from torchgen.utils import FileManager, mapMaybe
+
+
+OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>")
+TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)")
+
+
+# Add 'at::' to types defined in the ATen namespace, e.g. Tensor, TensorList, IntArrayRef, etc.
+# TODO: maybe update the cpp argument API to take an optional namespace argument?
+def fully_qualified_type(argument_type: str) -> str:
+    def maybe_optional_type(type: str, is_opt: bool) -> str:
+        return f"std::optional<{type}>" if is_opt else type
+
+    opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type)
+    is_opt = opt_match is not None
+    if opt_match:
+        argument_type = argument_type[opt_match.start(1) : opt_match.end(1)]
+    match = TYPE_PATTERN.match(argument_type)
+    if match is None:
+        return maybe_optional_type(argument_type, is_opt)
+    index = match.start(1)
+    qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}"
+    return maybe_optional_type(qualified_type, is_opt)
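+    # Illustrative examples (added for clarity, not in the original source):
+    #   fully_qualified_type("Tensor")                 -> "at::Tensor"
+    #   fully_qualified_type("std::optional<Tensor>")  -> "std::optional<at::Tensor>"
+    #   fully_qualified_type("int64_t")                -> "int64_t"  (no ATen type matched)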
+
+
+def gen_variable_factories(
+    out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str
+) -> None:
+    native_functions = parse_native_yaml(
+        native_yaml_path, tags_yaml_path
+    ).native_functions
+    factory_functions = [fn for fn in native_functions if is_factory_function(fn)]
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    fm.write_with_template(
+        "variable_factories.h",
+        "variable_factories.h",
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/variable_factories.h",
+            "ops_headers": [
+                f"#include " for fn in factory_functions
+            ],
+            "function_definitions": list(mapMaybe(process_function, factory_functions)),
+        },
+    )
+
+
+@with_native_function
+def is_factory_function(f: NativeFunction) -> bool:
+    if Variant.function not in f.variants:
+        return False
+
+    name = cpp.name(f.func)
+    has_tensor_options = python.has_tensor_options(f)
+    return has_tensor_options or name.endswith("_like")
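+    # Illustrative: `empty` qualifies via its TensorOptions arguments and
+    # `zeros_like` via the `_like` suffix; an op with neither (e.g. `add`) is skipped.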
+
+
+@with_native_function
+def process_function(f: NativeFunction) -> str | None:
+    name = cpp.name(f.func)
+    has_tensor_options = python.has_tensor_options(f)
+    is_factory = has_tensor_options or name.endswith("_like")
+
+    if Variant.function not in f.variants or not is_factory:
+        return None
+
+    cpp_sigs = CppSignatureGroup.from_native_function(f, method=False)
+    sigs = [cpp_sigs.signature]
+    if cpp_sigs.symint_signature is not None:
+        sigs.append(cpp_sigs.symint_signature)
+    r = ""
+    for sig in sigs:
+        formals: list[str] = []
+        exprs: list[str] = []
+        requires_grad = "false"
+        for arg in sig.arguments():
+            qualified_type = fully_qualified_type(arg.type)
+            if arg.default:
+                formals.append(f"{qualified_type} {arg.name} = {arg.default}")
+            else:
+                formals.append(f"{qualified_type} {arg.name}")
+
+            if isinstance(arg.argument, TensorOptionsArguments):
+                # note: we remove the requires_grad setting from the TensorOptions because
+                # it is ignored anyways (and we actually have an assertion that it isn't set
+                # which would fail otherwise). We handle requires_grad explicitly here
+                # instead of passing it through to the kernel.
+                exprs.append(
+                    f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)"
+                )
+                # Manually set the requires_grad bit on the result tensor.
+                requires_grad = f"{arg.name}.requires_grad()"
+            else:
+                exprs.append(arg.name)
+
+        r += f"""\
+inline at::Tensor {sig.name()}({", ".join(formals)}) {{
+  at::AutoDispatchBelowADInplaceOrView guard;
+  return autograd::make_variable(at::{sig.name()}({", ".join(exprs)}), /*requires_grad=*/{requires_grad});
+}}
+"""
+    return r
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_type.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_type.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa5ee0ba3df00b44c885c7e46ba2af76b00d097d
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_variable_type.py
@@ -0,0 +1,2186 @@
+# Generates VariableType.h/cpp
+#
+# **If any changes are being made to the VariableType codegen please also check
+# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
+#
+# VariableType is a subclass of at::Type that provides the binding code
+# necessary to provide a differentiable version of ATen operators. There are a
+# number of different things we could mean:
+#
+#   - Given a non-differentiable forward implementation, we might
+#     directly associate it with a backward implementation to make
+#     it differentiable.  This is the common case.
+#
+#   - Some functions don't need a backwards implementation, because
+#     backpropagation will never propagate beyond them.  There are a
+#     number of different reasons why this may be the case:
+#
+#       - The function has no differentiable inputs
+#       - The function's output is not differentiable
+#       - The function has no data dependency on its input
+#
+#   - Some functions don't need a backwards implementation because they
+#     are implemented as a composition of other (differentiable) ATen
+#     functions.  These are dispatched directly to the Type superclass,
+#     which will in turn dispatch back to VariableType for its
+#     differentiable subcomponents.
+#
+
+from __future__ import annotations
+
+import re
+from typing import Callable, TYPE_CHECKING
+
+from torchgen.api import cpp
+from torchgen.api.autograd import (
+    DifferentiableInput,
+    dispatch_strategy,
+    ForwardDerivative,
+    gen_differentiable_outputs,
+    is_differentiable,
+    NativeFunctionWithDifferentiabilityInfo,
+    SavedAttribute,
+)
+from torchgen.api.types import (
+    ArrayRefCType,
+    BaseCppType,
+    BaseCType,
+    Binding,
+    intArrayRefT,
+    iTensorListRefT,
+    ListCType,
+    MutRefCType,
+    OptionalCType,
+    scalarT,
+    SpecialArgName,
+    stringT,
+    symIntArrayRefT,
+    TENSOR_LIST_LIKE_CTYPES,
+    tensorListT,
+    tensorT,
+    TupleCType,
+    VectorCType,
+)
+from torchgen.code_template import CodeTemplate
+from torchgen.context import (
+    native_function_manager,
+    with_native_function,
+    with_native_function_and,
+)
+from torchgen.model import (
+    Argument,
+    BaseType,
+    ListType,
+    NativeFunction,
+    SchemaKind,
+    SelfArgument,
+    TensorOptionsArguments,
+)
+from torchgen.utils import FileManager, mapMaybe
+
+from .context import with_native_function_with_differentiability_info_and_key
+from .gen_inplace_or_view_type import (
+    ALL_VIEW_FUNCTIONS,
+    ASSIGN_RETURN_VALUE,
+    AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION,
+    gen_formals,
+    get_base_name,
+    get_view_info,
+    is_tensor_list_type,
+    is_tensor_type,
+    METHOD_DEFINITION,
+    modifies_arguments,
+    TMP_VAR,
+    unpack_args,
+    unpacked_name,
+    use_derived,
+    WRAPPER_REGISTRATION,
+)
+from .gen_trace_type import (
+    get_return_value,
+    MANUAL_AUTOGRAD_AND_TRACER,
+    MANUAL_BACKEND,
+    tie_return_values,
+    type_wrapper_name,
+)
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# We don't set or modify grad_fn on these methods. Generally, they return
+# tensors that have requires_grad=False. In-place functions listed here will
+# not examine or modify requires_grad or grad_fn.
+# NB: this does NOT include overload name
+DONT_REQUIRE_DERIVATIVE = {
+    # These only depend on the input Tensor's shape and device, not the data
+    "empty_like",
+    "ones_like",
+    "full_like",
+    "zeros_like",
+    "rand_like",
+    "randn_like",
+    "new_empty",
+    "new_empty_strided",
+    "new_full",
+    "new_zeros",
+    "new_ones",
+    # These are only implemented on integral types
+    "__and__",
+    "__iand__",
+    "__ilshift__",
+    "__ior__",
+    "__irshift__",
+    "__ixor__",
+    "__lshift__",
+    "__or__",
+    "__rshift__",
+    "__xor__",
+    # These work on integral data types, and hence don't require derivative
+    "_sobol_engine_draw",
+    "_sobol_engine_ff",
+    "_sobol_engine_scramble_",
+    "_sobol_engine_initialize_state_",
+    # This is an unsafe method that is meant to be out of reach of autograd.
+    "_coalesced_",
+    # Quantize functions should not record gradients
+    "quantize_per_tensor",
+    "quantize_per_channel",
+    # Functions that return integers should not have output that require gradients
+    "argmax",
+    "argmin",
+    "argsort",
+    "searchsorted",
+    "bucketize",
+    # Functions that return booleans are not differentiable
+    "isnan",
+    "isposinf",
+    "isneginf",
+    "isinf",
+    "signbit",
+    "isin",
+    "allclose",
+    # Functions that return None are not differentiable
+    "record_stream",
+    # These functions are not differentiable
+    "logical_and",
+    "logical_xor",
+    "logical_not",
+    "logical_or",
+    # This function returns nested_tensor shape as a tensor that is non-differentiable
+    "_nested_tensor_size",
+    "_nested_tensor_strides",
+    "_nested_tensor_storage_offsets",
+}
+
+# The C -> R functions at the time of adding this are still being audited and tested
+# but will not error out.
+# C -> C, R -> C functions for which backward is correctly implemented and tested
+GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
+    "fill",
+    "t",
+    "t_copy",
+    "view",
+    "reshape",
+    "reshape_as",
+    "view_as",
+    "view_copy",
+    "roll",
+    "clone",
+    "block_diag",
+    "diag_embed",
+    "repeat",
+    "expand",
+    "expand_copy",
+    "flip",
+    "fliplr",
+    "flipud",
+    "rot90",
+    "nanmean",
+    "nansum",
+    "transpose",
+    "transpose_copy",
+    "permute",
+    "permute_copy",
+    "squeeze",
+    "squeeze_copy",
+    "unsqueeze",
+    "unsqueeze_copy",
+    "resize",
+    "resize_as",
+    "tril",
+    "triu",
+    "chunk",
+    "zero_",
+    "eq_",
+    "ne_",
+    "add",
+    "__radd__",
+    "sum",
+    "_conj",
+    "sin",
+    "cos",
+    "mul",
+    "sinc",
+    "sinh",
+    "cosh",
+    "__rmul__",
+    "sgn",
+    "asin",
+    "acos",
+    "sub",
+    "div",
+    "cat",
+    "view_as_complex",
+    "index_put",
+    "neg",
+    "complex",
+    "select",
+    "where",
+    "as_strided",
+    "as_strided_copy",
+    "as_strided_scatter",
+    "slice",
+    "constant_pad_nd",
+    "unbind",
+    "unbind_copy",
+    "split",
+    "split_with_sizes",
+    "unsafe_split",
+    "split_with_sizes_backward",
+    "dot",
+    "vdot",
+    "cholesky",
+    "triangular_solve",
+    "mm",
+    "_unsafe_view",
+    "mv",
+    "outer",
+    "bmm",
+    "diagonal",
+    "alias",
+    "atan",
+    "log",
+    "log10",
+    "log1p",
+    "log2",
+    "logaddexp",
+    "logsumexp",
+    "logcumsumexp",
+    "reciprocal",
+    "tan",
+    "pow",
+    "rsqrt",
+    "tanh",
+    "tanh_backward",
+    "asinh",
+    "acosh",
+    "atanh",
+    "take",
+    "fill_",
+    "exp",
+    "exp2",
+    "expm1",
+    "nonzero",
+    "mean",
+    "std_mean",
+    "var_mean",
+    "inverse",
+    "solve",
+    "linalg_cholesky",
+    "addcmul",
+    "addcdiv",
+    "matrix_exp",
+    "linalg_matrix_exp",
+    "_linalg_eigh",
+    "cholesky_solve",
+    "linalg_qr",
+    "_linalg_svd",
+    "_fft_c2c",
+    "_fft_r2c",
+    "linalg_solve",
+    "sqrt",
+    "stack",
+    "gather",
+    "index_select",
+    "index_add_",
+    "linalg_inv",
+    "linalg_inv_ex",
+    "baddbmm",
+    "addbmm",
+    "addmm",
+    "addmv",
+    "addr",
+    "linalg_householder_product",
+    "ormqr",
+    "reflection_pad1d",
+    "reflection_pad2d",
+    "reflection_pad3d",
+    "linalg_cholesky_ex",
+    "linalg_eig",
+    "diagonal_copy",
+    "diagonal_scatter",
+    "alias_copy",
+    "select_backward",
+    "diagonal_backward",
+    "slice_backward",
+    "reflection_pad1d_backward",
+    "reflection_pad2d_backward",
+    "reflection_pad3d_backward",
+    "_sparse_sparse_matmul",
+    "replication_pad1d",
+    "replication_pad2d",
+    "replication_pad3d",
+    "put",
+    "put_",
+    "_to_copy",
+    "replication_pad1d_backward",
+    "replication_pad2d_backward",
+    "replication_pad3d_backward",
+    "diag",
+    "masked_scatter",
+    "masked_select",
+    "index_add",
+    "index_fill",
+    "trace",
+    "polar",
+    "cumsum",
+    "rsub",
+    "eig",
+    "lerp",
+    "linalg_vector_norm",
+    "cumprod",
+    "prod",
+    "index_copy",
+    "lu",
+    "unfold",
+    "unfold_backward",
+    "index",
+    "masked_fill",
+    "masked_scatter_backward",
+    "linalg_cross",
+    "lu_unpack",
+    "renorm",
+    "_conj_physical",
+    "linalg_lu_factor_ex",
+    "scatter",
+    "scatter_add",
+    "sigmoid",
+    "sigmoid_backward",
+    "sparse_mask",
+    "trapezoid",
+    "cumulative_trapezoid",
+    "conj_physical_",
+    "_neg_view",
+    "_reshape_alias",
+    "_reshape_copy",
+    "_linalg_det",
+    "lu_solve",
+    "linalg_solve_triangular",
+    "linalg_pinv",
+    "linalg_lstsq",
+    "unfold_copy",
+    "col2im",
+    "im2col",
+    "cholesky_inverse",
+    "to_sparse",
+    "sparse_sampled_addmm",
+    "linalg_lu",
+    "pixel_shuffle",
+    "pixel_unshuffle",
+    "channel_shuffle",
+    "linalg_lu_solve",
+    "_linalg_slogdet",
+    "_linalg_solve_ex",
+    "_unsafe_index",
+    "_unsafe_index_put",
+    "_unsafe_masked_index",
+    "_unsafe_masked_index_put_accumulate",
+}
+
+GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
+    "_to_dense",
+    "_coalesce",
+    "coalesce",
+    "values",
+    "_sparse_coo_tensor_with_dims_and_tensors",
+    "_sparse_addmm",
+}
+
+GRADIENT_IMPLEMENTED_FOR_COMPLEX.update(GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX)
+
+# Some operators invalidate the grad_accumulator. Let's reset it.
+RESET_GRAD_ACCUMULATOR = {"set_", "resize_"}
+
+# NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
+#
+# We check the following properties:
+#   1) A function should never change the input tensors' underlying c10::TensorImpl
+#      pointers or c10::Storage pointers, even if it modifies its input tensors (via
+#      inplace or out-variants)
+# If the function does not modify its arguments, we also check the following properties
+# pertaining to its output:
+#   2) Its TensorImpl has use_count of 1
+#   3) If the function is a view function, it has the same StorageImpl as that of
+#      the input it is aliased with. Otherwise, its StorageImpl has use_count of 1
+#
+# The following code templates implement the checks for this invariant:
+SAVE_TENSOR_STORAGE = CodeTemplate(
+    """\
+auto ${tensor_name}_storage_saved =
+  ${tensor_name}.has_storage() ? ::std::optional<Storage>(${tensor_name}.storage()) : ::std::nullopt;
+"""
+)
+
+
+# If tensor_name == out_tensor_name, used to enforce (1), otherwise used for (2)
+ENFORCE_SAME_TENSOR_STORAGE = CodeTemplate(
+    """\
+if (${tensor_name}_storage_saved.has_value() &&
+    !at::impl::dispatch_mode_enabled() &&
+    !at::impl::tensor_has_dispatch(${tensor_name}) &&
+    !at::impl::tensor_has_dispatch(${out_tensor_name}))
+  TORCH_INTERNAL_ASSERT(${tensor_name}_storage_saved.value().is_alias_of(${out_tensor_name}.storage()));
+"""
+)
+
+SAVE_TENSORLIST_STORAGE = CodeTemplate(
+    """\
+std::vector<::std::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
+for (const Tensor& tensor : ${tensorlist_name})
+  ${tensorlist_name}_storage_saved.push_back(
+    tensor.has_storage() ? ::std::optional<Storage>(tensor.storage()) : ::std::nullopt);
+"""
+)
+
+ENFORCE_SAME_TENSORLIST_STORAGE = CodeTemplate(
+    """\
+for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
+  if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
+    TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(${tensorlist_name}[i].storage()));
+}
+"""
+)
+
+SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
+    """\
+std::vector<::std::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
+for (const ::std::optional<Tensor>& tensor : ${tensorlist_name})
+  ${tensorlist_name}_storage_saved.push_back(
+    tensor.has_value() && tensor->has_storage() ? ::std::optional<Storage>(tensor->storage()) : ::std::nullopt);
+"""
+)
+
+ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate(
+    """\
+for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
+  if (${tensorlist_name}_storage_saved[i].has_value() && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
+    TORCH_INTERNAL_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
+        static_cast<::std::optional<Tensor>>(${tensorlist_name}[i])->storage()));
+}
+"""
+)
+
+SAVE_TENSOR_IMPL = CodeTemplate(
+    """\
+c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
+if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
+"""
+)
+
+ENFORCE_SAME_TENSOR_IMPL = CodeTemplate(
+    """\
+if (${tensor_name}_impl_saved && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
+  TORCH_INTERNAL_ASSERT(${tensor_name}_impl_saved == ${tensor_name}.getIntrusivePtr());
+"""
+)
+
+ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE = CodeTemplate(
+    """\
+if (!at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name}))
+  TORCH_INTERNAL_ASSERT(${tensor_name}.use_count() <= 1, "function: ${fn_name}");
+"""
+)
+
+ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE = CodeTemplate(
+    """\
+if (${tensor_name}.has_storage() && !at::impl::dispatch_mode_enabled() && !at::impl::tensor_has_dispatch(${tensor_name})) {
+  TORCH_INTERNAL_ASSERT(${tensor_name}.storage().use_count() == 1, "function: ${fn_name}");
+}
+"""
+)
+
+SAVE_TENSORLIST_IMPL = CodeTemplate(
+    """\
+std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
+for (size_t i=0; i<${tensorlist_name}.size(); i++)
+  if (${tensorlist_name}[i].defined()) ${tensorlist_name}_impl_saved[i] = ${tensorlist_name}[i].getIntrusivePtr();
+"""
+)
+
+ENFORCE_SAME_TENSORLIST_IMPL = CodeTemplate(
+    """\
+for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
+  if (${tensorlist_name}_impl_saved[i] && !at::impl::tensorlist_has_dispatch(${tensorlist_name}))
+    TORCH_INTERNAL_ASSERT(${tensorlist_name}_impl_saved[i] == ${tensorlist_name}[i].getIntrusivePtr());
+}
+"""
+)
+
+SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate(
+    """\
+std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  ::std::optional<Tensor> t = ${tensorlist_name}[i];
+  if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
+}
+"""
+)
+
+ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate(
+    """\
+for (size_t i=0; i<${tensorlist_name}.size() && !at::impl::dispatch_mode_enabled(); i++) {
+  if (${tensorlist_name}_impl_saved[i])
+    TORCH_INTERNAL_ASSERT(
+      ${tensorlist_name}_impl_saved[i] == static_cast<::std::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
+}
+"""
+)
+
+# The following list contains functions that we don't enforce the invariant on.
+DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
+    # These functions are expected to change impl or storage of input tensors
+    "set_",
+    "_cudnn_rnn_flatten_weight",
+    "_unsafe_masked_index",
+    "_unsafe_masked_index_put_accumulate",
+}
+DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
+    # These non-inplace, non-out functions return tensors with use_count > 1.
+    # Therefore, they MAY (but do not necessarily) return one of their inputs as-is.
+    # See https://github.com/pytorch/pytorch/issues/60426 for more information
+    "_embedding_bag",
+    "_embedding_bag_forward_only",
+    "q_per_channel_scales",
+    "q_per_channel_zero_points",
+    "lu_unpack",
+    "_cudnn_rnn_backward",
+    # The below failed StorageImpl use_count check but we skip tensor_impl check
+    # just in case
+    "_cudnn_rnn",
+    "dequantize_self",
+    # lift() should never actually be called with a requires_grad=True tensor,
+    "lift",
+    "lift_fresh",
+    "lift_fresh_copy",
+    # Nested Tensors related functions
+    # _nested_tensor_size() should never actually be called with requires_grad=True tensor
+    "_nested_tensor_size",
+    "_nested_tensor_strides",
+    "_nested_tensor_storage_offsets",
+}
+
+DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
+    # These non-view functions return tensors with storage use_count != 1
+    "_slow_conv2d_forward",
+    "slow_conv3d_forward",
+    "channel_shuffle",
+    # If an input is returned as-is in output, we cannot guarantee its storage_impl
+    # use count to be 1 either.
+    *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,
+}
+# END CHECKS FOR [ TensorImpl and Storage Pointer Sanity Checks ]
+
+DECLARE_GRAD_FN = CodeTemplate(
+    """\
+std::shared_ptr<${op}> grad_fn;
+"""
+)
+
+DECLARE_VECTOR_OF_GRAD_FN = CodeTemplate(
+    """\
+std::vector<std::shared_ptr<${op}>> grad_fns;
+"""
+)
+
+SETUP_ANY_REQUIRES_GRAD = CodeTemplate(
+    """\
+[[maybe_unused]] auto _any_requires_grad = compute_requires_grad( ${args_with_derivatives} );
+${extra_differentiability_conditions}
+"""
+)
+
+SETUP_DERIVATIVE = CodeTemplate(
+    """\
+if (_any_requires_grad) {
+  ${setup}
+}
+"""
+)
+
+SETUP_NONE_REQUIRES_GRAD = CodeTemplate(
+    """\
+if (compute_requires_grad( ${args_to_check} )) {
+  throw_error_out_requires_grad("${base_name}");
+}
+"""
+)
+
+ASSIGN_GRAD_FN = CodeTemplate(
+    """\
+grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
+grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
+"""
+)
+
+# note(crcrpar): `compute_requires_grad` in the template below is supplied with arguments indexed with `i`
+# while the `SETUP_ANY_REQUIRES_GRAD` above takes whole tensors and scalars.
+ASSIGN_VECTOR_OF_GRAD_FN = CodeTemplate(
+    """\
+for (const auto& i : c10::irange( ${irange} )) {
+  const auto ith_requires_grad = compute_requires_grad(${args_with_derivatives});
+  check_inplace(self[i], ith_requires_grad);
+  grad_fns.push_back([&]() -> std::shared_ptr<${op}> {
+      if (!ith_requires_grad) {
+          return nullptr;
+      } else {
+          auto grad_fn = std::shared_ptr<${op}>(new ${op}(${op_ctor}), deleteNode);
+          grad_fn->set_next_edges(collect_next_edges( ${args_with_derivatives} ));
+          return grad_fn;
+      }
+  }());
+}
+"""
+)
+
+CALL_REDISPATCH = CodeTemplate(
+    """\
+at::redispatch::${api_name}(${unpacked_args})"""
+)
+# If the non-variable operation has return values, we use the `tmp` variable to hold the
+# values temporarily and pass the values to the return variables outside of the
+# `at::AutoDispatchBelowAutograd` guard block.
+DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP = CodeTemplate(
+    """\
+auto ${tmp_var} = ([&]() {
+  if (${any_has_forward_grad}) {
+    static c10::OperatorName full_name("aten::${op_name}", "${op_overload}");
+    static ::std::optional<c10::OperatorHandle> opt_op = c10::Dispatcher::singleton().findSchema(full_name);
+    return impl::run_jit_decomposition_with_args_for_jvp<${return_types}>("${op_name}", *opt_op, ks, ${arg_names});
+  } else {
+    ${guard}
+    return ${base_type_call};
+  }
+})();
+"""
+)
+
+DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES = CodeTemplate(
+    """\
+auto ${tmp_var} = ([&]() {
+  ${guard}
+  return ${base_type_call};
+})();
+"""
+)
+
+DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES = CodeTemplate(
+    """\
+{
+  ${guard}
+  ${base_type_call};
+}
+"""
+)
+
+SET_HISTORY = CodeTemplate(
+    """\
+if (grad_fn) {
+    ${fn}_history(${differentiable_outputs}, grad_fn);
+}
+"""
+)
+
+LOOP_OVER_VECTOR_OF_GRAD_FNS = CodeTemplate(
+    """\
+if (!grad_fns.empty()) {
+    ${preamble}
+    for (const auto& i : c10::irange(grad_fns.size())) {
+        auto grad_fn = grad_fns[i];
+        if (grad_fn != nullptr) {
+            ${statements}
+        }
+    }
+}
+"""
+)
+
+CONDITIONAL = CodeTemplate(
+    """\
+if (${cond}) {
+  ${statements}
+}
+"""
+)
+
+RUN_ONLY_IN_DEBUG_MODE = CodeTemplate(
+    """\
+#ifndef NDEBUG
+${statements}
+#endif
+"""
+)
+
+FW_DERIVATIVE_CHECK_TEMPLATE = CodeTemplate(
+    """\
+isFwGradDefined(${req_inp})\
+"""
+)
+FW_DERIVATIVE_SIZE_CHECK_TEMPLATE = CodeTemplate(
+    """\
+TORCH_CHECK(
+    self.size() == ${inp_name}.size(),
+      "Tensor lists must have the same number of tensors, got ",
+    self.size(),
+      " and ",
+    ${inp_name}.size());
+"""
+)
+
+FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE = CodeTemplate(
+    """\
+isFwGradDefinedTensorList(${req_inp})\
+"""
+)
+
+FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE = CodeTemplate(
+    """\
+auto ${inp_name}_t_raw = toNonOptFwGrad(${inp});
+auto ${inp_name}_tensor = toNonOptTensor(${inp});
+auto ${inp_name}_t = (${inp_name}_t_raw.defined() || !${inp_name}_tensor.defined())
+  ? ${inp_name}_t_raw : at::${zeros_fn}(${inp_name}_tensor.sym_sizes(), ${inp_name}_tensor.options());
+"""
+)
+
+FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate(
+    """\
+auto ${inp_name}_p = toNonOptPrimal(${inp});
+"""
+)
+
+FW_DERIVATIVE_SETTER_TENSOR = CodeTemplate(
+    """\
+if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}.defined()) {
+  // The hardcoded 0 here will need to be updated once we support multiple levels.
+  ${out_arg}._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
+}
+"""
+)
+
+FW_DERIVATIVE_SETTER_TENSOR_FOREACH = CodeTemplate(
+    """\
+for (const auto& i : c10::irange(${out_arg}_new_fw_grad_opts.size())) {
+  auto& ${out_arg}_new_fw_grad_opt = ${out_arg}_new_fw_grad_opts[i];
+  if (${out_arg}_new_fw_grad_opt.has_value() && ${out_arg}_new_fw_grad_opt.value().defined() && ${out_arg}[i].defined()) {
+    // The hardcoded 0 here will need to be updated once we support multiple levels.
+    ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad_opt.value(), /* level */ 0, /* is_inplace_op */ ${is_inplace});
+  }
+}
+"""
+)
+
+FW_DERIVATIVE_SETTER_MULTI_OUTPUT = CodeTemplate(
+    """\
+if (${all_res}_new_fw_grad_opt.has_value() && std::get<${idx}>(${all_res}_new_fw_grad_opt.value()).defined()
+    && ${out_arg}.defined()) {
+  ${out_arg}._set_fw_grad(std::get<${idx}>(${all_res}_new_fw_grad_opt.value()), /* level */ 0, /* is_inplace_op */ false);
+}
+"""
+)
+
+FW_DERIVATIVE_SETTER_TENSOR_LIST = CodeTemplate(
+    """\
+if (${out_arg}_new_fw_grad_opt.has_value()) {
+  auto ${out_arg}_new_fw_grad = ${out_arg}_new_fw_grad_opt.value();
+  TORCH_INTERNAL_ASSERT(${out_arg}.size() == ${out_arg}_new_fw_grad.size());
+  for (const auto i : c10::irange(${out_arg}.size())) {
+    if (${out_arg}_new_fw_grad[i].defined() && ${out_arg}[i].defined()) {
+      // The hardcoded 0 here will need to be updated once we support multiple levels.
+      ${out_arg}[i]._set_fw_grad(${out_arg}_new_fw_grad[i], /* level */ 0, /* is_inplace_op */ ${is_inplace});
+    }
+  }
+}
+"""
+)
+
+FW_DERIVATIVE_TEMPLATE = CodeTemplate(
+    """\
+${fw_grad_opt_definition}
+if (${requires_fw_grad}) {
+    ${unpacked_arguments}
+    ${out_arg}_new_fw_grad_opt = ${formula};
+}
+"""
+)
+
+FW_DERIVATIVE_FOREACH_TEMPLATE = CodeTemplate(
+    """\
+${fw_grad_opt_definition}
+for (const auto& i : c10::irange(${vector_of_optional_tensor}.size())) {
+  if (${any_has_forward_grad_for_current_index}) {
+      ${unpacked_arguments}
+      ${vector_of_optional_tensor}[i] = ${formula};
+  }
+}
+"""
+)
+
+FW_DERIVATIVE_FORBID_TEMPLATE = CodeTemplate(
+    """\
+TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
+"""
+)
+
+FW_DERIVATIVE_FORBID_LIST_TEMPLATE = CodeTemplate(
+    """\
+for (const auto& _t: ${arg}) {
+    TORCH_CHECK_NOT_IMPLEMENTED(!(${cond}), "Trying to use forward AD with ${name} that does not support it ${msg}");
+}
+"""
+)
+
+
+def gen_variable_type(
+    out: str,
+    native_yaml_path: str,
+    tags_yaml_path: str,
+    fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo],
+    template_path: str,
+    used_keys: set[str],
+) -> None:
+    """VariableType.h and VariableType.cpp body
+
+    This is the at::Type subclass for differentiable tensors. The
+    implementation of each function dispatches to the base tensor type to
+    compute the output. The grad_fn is attached to differentiable functions.
+    """
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    fm.write(
+        "VariableType.h",
+        lambda: {
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/VariableType.h"
+        },
+    )
+
+    # helper that generates a TORCH_LIBRARY_IMPL macro for each
+    # dispatch key that appears in derivatives.yaml
+    def wrapper_registrations(used_keys: set[str]) -> str:
+        library_impl_macro_list: list[str] = []
+        for key in sorted(used_keys):
+            dispatch_key = key
+            if key == "Default":
+                dispatch_key = "Autograd"
+            library_impl_macro = (
+                f"TORCH_LIBRARY_IMPL(aten, {dispatch_key}, m) "
+                + "{\n"
+                + "${"
+                + f"wrapper_registrations_{key}"
+                + "}\n}"
+            )
+            library_impl_macro_list += [library_impl_macro]
+        return "\n\n".join(library_impl_macro_list)
+
+    # Generate a new template from VariableType.cpp which replaces ${wrapper_registrations}
+    # with per key TORCH_LIBRARY_IMPL macros for each key that appears in derivatives.yaml
+    fm1 = FileManager(
+        install_dir=out + "/templates", template_dir=template_path, dry_run=False
+    )
+    fm1.write(
+        "VariableType.cpp",
+        lambda: {
+            "type_derived_method_definitions": "\n\n".join(
+                [
+                    "${" + f"type_derived_method_definitions_{key}" + "}"
+                    for key in sorted(used_keys)
+                ]
+            ),
+            "wrapper_registrations": wrapper_registrations(used_keys),
+        },
+    )
+
+    # Generate final VariableType_*.cpp files from the generated template
+    fm2 = FileManager(install_dir=out, template_dir=out + "/templates", dry_run=False)
+
+    sharded_keys = set(
+        [f"type_derived_method_definitions_{key}" for key in sorted(used_keys)]
+        + [f"wrapper_registrations_{key}" for key in sorted(used_keys)]
+    )
+    # NOTE: see Note [Sharded File] at the top of the VariableType.cpp
+    # template regarding sharding of the generated files.
+    fm2.write_sharded(
+        "VariableType.cpp",
+        [fn for fn in fns_with_diff_infos if use_derived(fn)],
+        key_fn=lambda fn: cpp.name(fn.func.func),
+        base_env={
+            "generated_comment": "@"
+            + f"generated from {fm.template_dir_for_comments()}/VariableType.cpp",
+        },
+        env_callable=gen_variable_type_func,
+        num_shards=5,
+        sharded_keys=sharded_keys,
+    )
+
+
+@with_native_function_and
+def gen_wrapper_registration(f: NativeFunction, key: str = "Default") -> str:
+    return WRAPPER_REGISTRATION.substitute(
+        unqual_operator_name_with_overload=f.func.name,
+        type_wrapper_name=type_wrapper_name(f, key),
+        class_type="VariableType",
+    )
+
+
+def gen_variable_type_func(
+    fn: NativeFunctionWithDifferentiabilityInfo,
+) -> dict[str, list[str]]:
+    f = fn.func
+    result = {}
+    with native_function_manager(f):
+        name = cpp.name(f.func)
+        formals = gen_formals(f)
+
+        if (
+            fn.info is None
+            and str(f.func.name.name) not in RESET_GRAD_ACCUMULATOR
+            and get_base_name(f) not in DONT_REQUIRE_DERIVATIVE
+            and len(gen_differentiable_outputs(fn)) > 0
+            and cpp.name(f.func) not in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE
+            and type_wrapper_name(f) not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
+            and type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT
+        ):
+            # NOTE: [ Registering AutogradNotImplemented boxed kernel ]
+            #
+            # When there is no derivatives.yaml entry, we register a generic boxed
+            # NotImplemented kernel to set grad_fn to be NotImplemented, so that forward
+            # proceeds as usual but an error is properly produced on backward.
+            # TODO: it would be nice to not have these special cases
+            #
+            # There are several cases where we still let codegen handle it:
+            # 1) ops that need to reset the grad accumulator (we let codegen handle this
+            #    case because the list is (currently) only accessible in Python).
+            # 2) User explicitly specifies DONT_REQUIRE_DERIVATIVE. This basically makes
+            #    autograd a fallthrough with NDEBUG checks. This can be useful for when all
+            #    outputs are integral.
+            # 3) When there are no differentiable outputs. This is similar to (2).
+            # 4) There are certain ops where we skip certain NDEBUG checks. This is similar
+            #    to (1).
+            type_definition = ""
+            wrapper_registration = AUTOGRAD_NOT_IMPLEMENTED_REGISTRATION.substitute(
+                unqual_operator_name_with_overload=f.func.name
+            )
+            result["type_derived_method_definitions_Default"] = [type_definition]
+            result["wrapper_registrations_Default"] = [wrapper_registration]
+        else:
+            if not fn.info:
+                key = "Default"
+                type_definition = METHOD_DEFINITION.substitute(
+                    return_type=cpp.returns_type(
+                        f.func.returns, symint=True
+                    ).cpp_type(),
+                    type_wrapper_name=type_wrapper_name(f, key),
+                    type_definition_body=emit_body(fn, key),
+                    formals=formals,
+                )
+                wrapper_registration = gen_wrapper_registration(f, key)
+                result[f"type_derived_method_definitions_{key}"] = [type_definition]
+                result[f"wrapper_registrations_{key}"] = [wrapper_registration]
+            else:
+                for key in fn.info.keys():
+                    type_definition = METHOD_DEFINITION.substitute(
+                        return_type=cpp.returns_type(
+                            f.func.returns, symint=True
+                        ).cpp_type(),
+                        type_wrapper_name=type_wrapper_name(f, key),
+                        type_definition_body=emit_body(fn, key),
+                        formals=formals,
+                    )
+                    wrapper_registration = gen_wrapper_registration(f, key)
+                    result[f"type_derived_method_definitions_{key}"] = [type_definition]
+                    result[f"wrapper_registrations_{key}"] = [wrapper_registration]
+    # See Note [Manual Backend kernels]
+    assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
+    # If you want to register a kernel to Autograd, you must make the op abstract.
+    # In other words, this op must have dispatch section in native_functions.yaml.
+    if name in MANUAL_AUTOGRAD_AND_TRACER or (
+        fn.info and any(info.has_derivatives for info in fn.info.values())
+    ):
+        msg = (
+            f"There's a formula for {name}(or its functional variant) in derivatives.yaml. "
+            f"It's required to add a dispatch section for it with explicit supported backends e.g CPU/CUDA "
+            f"or CompositeExplicitAutograd in native_functions.yaml. Please see "
+            f"https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#choosing-the-right-dispatch-keyword "
+            f"for instructions to choose the right dispatch keyword."
+        )
+        assert f.is_abstract, msg
+
+    return result
+
+
+_foreach_ops_without_differentiability_info = {
+    # No reference backward available due to the lack of `{maximum, minimum}(tensor, scalar)`.
+    ("_foreach_maximum", "Scalar"),
+    ("_foreach_maximum", "ScalarList"),
+    ("_foreach_minimum", "Scalar"),
+    ("_foreach_minimum", "ScalarList"),
+    # No reference backward available as addcdiv/addcmul don't support Tensor as scaling factor.
+    ("_foreach_addcdiv", "Tensor"),
+    ("_foreach_addcmul", "Tensor"),
+    ("_foreach_copy", ""),
+}
+
+_foreach_ops_with_different_arity = {
+    # These ops lack the `alpha` scaling factor that is applied to the right-hand side argument.
+    ("_foreach_add", "Scalar"),
+    ("_foreach_add", "ScalarList"),
+    ("_foreach_sub", "Scalar"),
+    ("_foreach_sub", "ScalarList"),
+}
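+# Illustrative example of the arity mismatch above (hedged sketch of the schemas involved):
+#   _foreach_add.Scalar(Tensor[] self, Scalar scalar)        # no `alpha`
+#   add.Scalar(Tensor self, Scalar other, Scalar alpha=1)    # reference has `alpha`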
+
+
+@with_native_function_with_differentiability_info_and_key
+def emit_body(
+    fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
+) -> list[str]:
+    assert dispatch_strategy(fn) == "use_derived"
+    f = fn.func
+    info = fn.info[key] if fn.info else None
+    fw_derivatives = fn.fw_derivatives.get(key, []) if fn.fw_derivatives else []
+
+    name = cpp.name(f.func)
+    inplace = f.func.kind() == SchemaKind.inplace
+    is_out_fn = f.func.kind() == SchemaKind.out
+    returns_void = len(f.func.returns) == 0
+    base_name = get_base_name(f)
+    view_info = get_view_info(f)
+
+    is_foreach = name.startswith("_foreach")
+    is_inplace_foreach = is_foreach and inplace
+    if is_inplace_foreach:
+        inplace_foreacharg2refarg: dict[Argument, Argument] = {}
+        refargname2inplace_foreacharg: dict[str, Argument] = {}
+        base_name_and_overload_name = (f.func.name.name.base, f.func.name.overload_name)
+        if info is None:
+            assert (
+                base_name_and_overload_name
+                in _foreach_ops_without_differentiability_info
+            ), (
+                f"{'.'.join(base_name_and_overload_name)} should have a differentiability info"
+            )
+        else:
+            assert (
+                len(f.func.arguments.flat_non_out)
+                == len(info.func.func.arguments.flat_non_out)
+            ) or (base_name_and_overload_name in _foreach_ops_with_different_arity), (
+                f"{'.'.join(base_name_and_overload_name)} has {len(f.func.arguments.flat_non_out)} args "
+                f"but the reference has {len(info.func.func.arguments.flat_non_out)}"
+            )
+            for foreach_arg, ref_arg in zip(
+                f.func.arguments.flat_non_out, info.func.func.arguments.flat_non_out
+            ):
+                foreach_arg_type = foreach_arg.type
+                if isinstance(foreach_arg_type, ListType):
+                    foreach_arg_type = foreach_arg_type.elem
+                assert foreach_arg_type == ref_arg.type
+                inplace_foreacharg2refarg[foreach_arg] = ref_arg
+                refargname2inplace_foreacharg[ref_arg.name] = foreach_arg
+
+    def gen_differentiable_input(
+        arg: Argument | SelfArgument | TensorOptionsArguments,
+    ) -> DifferentiableInput | None:
+        if isinstance(arg, TensorOptionsArguments):
+            return None
+        a: Argument = arg.argument if isinstance(arg, SelfArgument) else arg
+
+        # TODO: `cpp_type` is only to keep it byte-for-byte compatible with the old codegen, should remove.
+        # NB: This is not a clone of cpp.argument() - TensorOptionsArguments / faithful / binds are
+        # not handled properly as they are irrelevant for this codegen.
+        cpp_type = cpp.argument_type(a, binds=a.name, symint=True).cpp_type()
+
+        if not is_differentiable(a.name, a.type, info):
+            return None
+        return DifferentiableInput(
+            name=a.name,
+            type=a.type,
+            cpp_type=cpp_type,
+        )
+
+    @with_native_function
+    def gen_differentiable_inputs(f: NativeFunction) -> list[DifferentiableInput]:
+        arguments = list(f.func.arguments.non_out)
+        if is_inplace_foreach and info is not None:
+            for i, arg in enumerate(f.func.arguments.flat_non_out):
+                if arg in inplace_foreacharg2refarg:
+                    # note(crcrpar): From what I understand, what matters is only the name.
+                    # Thus, originally the argument was replaced only when the names differed.
+                    # TODO(crcrpar): Make it simpler.
+                    mapped_arg = inplace_foreacharg2refarg[arg]
+                    arguments[i] = Argument(
+                        mapped_arg.name,
+                        mapped_arg.type,
+                        mapped_arg.default,
+                        mapped_arg.annotation,
+                    )
+        return list(mapMaybe(gen_differentiable_input, arguments))
+
+    def find_args_with_derivatives(
+        differentiable_inputs: list[DifferentiableInput],
+    ) -> list[DifferentiableInput]:
+        """Find arguments that have derivative definitions"""
+        if info is None or not info.has_derivatives:
+            return differentiable_inputs
+        names = {name for d in info.derivatives for name in d.var_names}
+        differentiable = [arg for arg in differentiable_inputs if arg.name in names]
+        if len(differentiable) != len(names):
+            missing = names - {arg.name for arg in differentiable}
+            raise RuntimeError(
+                f"Missing arguments for derivatives: {missing} in {info.name}"
+            )
+        return differentiable
+
+    differentiable_inputs = gen_differentiable_inputs(f)
+    args_with_derivatives = find_args_with_derivatives(differentiable_inputs)
+    differentiable_outputs = gen_differentiable_outputs(fn, key)
+
+    undifferentiable = (base_name in DONT_REQUIRE_DERIVATIVE) or (
+        name in DONT_REQUIRE_DERIVATIVE
+    )
+
+    requires_derivative = (
+        (not undifferentiable)
+        and (len(differentiable_inputs) > 0)
+        and (
+            (len(differentiable_outputs) > 0)
+            # note(crcrpar): In-place foreach functions are void functions.
+            or is_inplace_foreach
+        )
+    )
+
+    if (
+        info is not None
+        and info.has_derivatives
+        and not requires_derivative
+        # out= ops are allowed to have zero returns, which causes requires_derivative to be False;
+        # we shouldn't error out though (out= ops for autograd just redispatch)
+        and len(f.func.returns) > 0
+    ):
+        raise RuntimeError(
+            f"ERROR: derivative ignored for {name} -- specified an autograd function without derivative"
+        )
+
+    # note(crcrpar): In-place foreach functions do not support forward AD
+    if requires_derivative and len(fw_derivatives) > 0 and not is_inplace_foreach:
+        assert sum(len(derivative.var_names) for derivative in fw_derivatives) == len(
+            differentiable_outputs
+        ), (
+            "Expected the number of forward derivatives implemented to match the "
+            "number of differentiable outputs. NB: This only applies when at least "
+            "one forward derivative is implemented. Not implementing any forward "
+            "derivatives is also okay, and we would require inputs to the op to "
+            "not have associated tangents in that case."
+        )
+
+    try_jit_decomposition = (
+        requires_derivative
+        and len(fw_derivatives) == 0
+        and (not modifies_arguments(f))
+        and (not returns_void)
+    )
+
+    def emit_save_inputs() -> list[str]:
+        setup: list[str] = []
+        if info is None or not info.has_derivatives:
+            return setup
+
+        has_tensorlist_arg = any(
+            is_tensor_list_type(arg.type) for arg in args_with_derivatives
+        )
+
+        # We don't want to save tensors if we know that they will never be used
+        # when computing the derivative, so we add guards to those statements
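+        # (Hypothetical illustration of such a guarded save in the generated kernel:
+        #   if (grad_fn->should_compute_output(1)) {
+        #     grad_fn->other_ = SavedVariable(other, false);
+        #   }
+        # where `other` and the edge offset are made-up names for this example.)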
+        def guard_for(arg: SavedAttribute) -> str | None:
+            assert info is not None
+
+            # It's hard to determine the edge offset if we have TensorLists
+            # NOTE(crcrpar): in-place foreach functions' arguments include tensorlist
+            # but their derivatives don't use it, so let them bypass this check.
+            if has_tensorlist_arg and (not is_inplace_foreach):
+                return None
+
+            # Empirical evaluation of the cases where we insert those guards in
+            # backward shows that they are somewhat useless. E.g. there's no need
+            # to guard on some values captured from forward, because they must have
+            # required grad if the backward function gets executed at all. I don't
+            # have any good ideas for detecting those cases, so I simply disabled the
+            # checks.
+            if "backward" in info.name:
+                return None
+
+            # If there's a single derivative we could compute, we already have
+            # a requires_grad check that is sufficient
+            if len(args_with_derivatives) <= 1:
+                return None
+
+            # We really only care about trimming down the amount of tensors we save
+            if arg.nctype.type != BaseCType(tensorT):
+                return None
+
+            # We want to emit simple guards, so we only allow that if checking one
+            # input is enough to determine whether we need that value
+            used_in = [d for d in info.derivatives if arg in d.saved_inputs]
+            assert len(used_in) > 0
+            if len(used_in) != 1:
+                return None
+            derivative = used_in[0]
+
+            # Case with multioutput formulas
+            # TODO: process all derivative formulas!!!
+            if len(derivative.var_names) != 1:
+                wrap_opt_if_start = derivative.formula.find(
+                    f"wrap_opt_if({arg.nctype.name}"
+                )
+                if wrap_opt_if_start == -1:
+                    return None
+
+                wrap_opt_if_match = re.match(
+                    rf"wrap_opt_if\({arg.nctype.name},(.*?)\)",
+                    derivative.formula[wrap_opt_if_start:],
+                )
+                assert wrap_opt_if_match is not None
+
+                # Condition is between 'wrap_opt_if(var_name,' and ')'.
+                condition_slice = slice(len(rf"wrap_opt_if\({arg.nctype.name},"), -1)
+                wrap_opt_if_condition = wrap_opt_if_match.group(0)[
+                    condition_slice
+                ].strip()
+                # replace 'grad_input_mask[num]' with 'grad_fn->should_compute_output(num)'
+                wrap_opt_if_condition = re.sub(
+                    r"grad_input_mask\[(\d+)\]",
+                    r"grad_fn->should_compute_output(\1)",
+                    wrap_opt_if_condition,
+                )
+                return f"{wrap_opt_if_condition}"
+
+            # Figure out the offset of the edge that uses this variable
+            derivative_var_name = derivative.var_names[0]
+            for edge_off, a in enumerate(args_with_derivatives):
+                if a.name == derivative_var_name:
+                    break
+            else:
+                raise AssertionError
+            return f"grad_fn->should_compute_output({edge_off})"
+
+        if is_inplace_foreach:
+            save_input_stmts = save_variables(info.all_saved_inputs, False, guard_for)
+            if save_input_stmts:
+                setup.append(
+                    LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
+                        preamble="", statements=save_input_stmts
+                    )
+                )
+        else:
+            setup.extend(save_variables(info.all_saved_inputs, False, guard_for))
+            for arg in args_with_derivatives:
+                if is_tensor_list_type(arg.type):
+                    setup.append(f"grad_fn->{arg.name}_size_ = {arg.name}.size();")
+        return setup
+
+    def setup_derivative(differentiable_inputs: list[DifferentiableInput]) -> list[str]:
+        body: list[str] = []
+        if is_out_fn:
+            # For out functions, ensure that no input or output requires grad
+            body.append(DECLARE_GRAD_FN.substitute(op="Node"))
+            body.append(
+                SETUP_NONE_REQUIRES_GRAD.substitute(
+                    base_name=base_name,
+                    args_to_check=[arg.name for arg in differentiable_inputs],
+                )
+            )
+            body.append(
+                SETUP_NONE_REQUIRES_GRAD.substitute(
+                    base_name=base_name,
+                    args_to_check=[arg.name for arg in differentiable_outputs],
+                )
+            )
+            return body
+
+        op = info.op if info is not None and info.has_derivatives else "NotImplemented"
+        setup = []
+        if not is_inplace_foreach:
+            setup.extend(
+                ASSIGN_GRAD_FN.substitute(
+                    op=op,
+                    op_ctor=""
+                    if info is not None and info.has_derivatives
+                    else f'"{cpp.name(f.func)}"',
+                    args_with_derivatives=[arg.name for arg in args_with_derivatives],
+                ).split("\n")
+            )
+        else:
+            # note(crcrpar): Assuming in-place foreach function's self_arg is always TensorList.
+            list_like_arg = "self"
+            args = [arg.name for arg in args_with_derivatives]
+            for i, arg in enumerate(args):
+                if is_inplace_foreach and info is not None:
+                    if arg in refargname2inplace_foreacharg:
+                        foreach_arg = refargname2inplace_foreacharg[arg]
+                        args[i] = foreach_arg.name + (
+                            "[i]" if isinstance(foreach_arg.type, ListType) else ""
+                        )
+                else:
+                    if arg == list_like_arg:
+                        args[i] = arg + "[i]"
+            setup.extend(
+                ASSIGN_VECTOR_OF_GRAD_FN.substitute(
+                    op=op,
+                    op_ctor=""
+                    if info is not None and info.has_derivatives
+                    else f'"{cpp.name(f.func)}"',
+                    args_with_derivatives=args,
+                    irange=f"{list_like_arg}.size()",
+                ).split("\n")
+            )
+        setup.extend(emit_save_inputs())
+
+        body.extend(
+            emit_check_no_requires_grad(differentiable_inputs, args_with_derivatives)
+        )
+        declare_grad_fn_template = (
+            DECLARE_GRAD_FN if not is_inplace_foreach else DECLARE_VECTOR_OF_GRAD_FN
+        )
+        body.append(declare_grad_fn_template.substitute(op=op))
+        body.append(SETUP_DERIVATIVE.substitute(setup=setup))
+        return body
+
+    def emit_check_if_in_complex_autograd_allowlist() -> list[str]:
+        body: list[str] = []
+        if base_name in GRADIENT_IMPLEMENTED_FOR_COMPLEX:
+            return body
+        for arg in differentiable_outputs:
+            name = arg.name
+            # TODO: should be `arg.type.is_tensor_like()`?
+            if arg.cpp_type == "at::Tensor" or arg.cpp_type in TENSOR_LIST_LIKE_CTYPES:
+                body.append(f'throw_error_for_complex_autograd({name}, "{base_name}");')
+        return body
+
+    def emit_check_no_requires_grad(
+        tensor_args: list[DifferentiableInput],
+        args_with_derivatives: list[DifferentiableInput],
+    ) -> list[str]:
+        """Checks that arguments without derivatives don't require grad"""
+        body: list[str] = []
+        for arg in tensor_args:
+            if arg in args_with_derivatives:
+                continue
+            arg_name = arg.name
+            if info and arg_name in info.non_differentiable_arg_names:
+                continue
+            if arg_name == "output":
+                # Double-backwards definitions sometimes take in 'input' and
+                # 'output', but only define the derivative for input.
+                continue
+            body.append(f'check_no_requires_grad({arg_name}, "{arg_name}", "{name}");')
+        return body
+
+    def emit_original_self_definition() -> list[str]:
+        body: list[str] = []
+        if inplace:
+            if is_inplace_foreach:
+                body.append(
+                    "std::vector<::std::optional> original_selfs(self.size());"
+                )
+            else:
+                body.append("::std::optional original_self;")
+
+            all_forward_grad_cond = []
+            for derivative in fw_derivatives:
+                if derivative.required_original_self_value:
+                    all_forward_grad_cond.append(
+                        get_any_has_forward_grad_name(derivative.var_names)
+                    )
+
+            if all_forward_grad_cond:
+                if not is_inplace_foreach:
+                    body.append(f"if ({' || '.join(all_forward_grad_cond)}) {{")
+                    body.append("  original_self = self.clone();")
+                    body.append("}")
+                else:
+                    current_all_forward_grad_cond = [
+                        f"{cond}[i]" for cond in all_forward_grad_cond
+                    ]
+                    body.append("for (const auto& i : c10::irange(self.size())) {")
+                    body.append(
+                        f"  if ({' || '.join(current_all_forward_grad_cond)}) {{"
+                    )
+                    body.append("    original_selfs[i] = self[i].clone();")
+                    body.append("  }")
+                    body.append("}")
+
+        return body
+
+    def save_variables(
+        saved_variables: Sequence[SavedAttribute],
+        is_output: bool,
+        guard_for: Callable[[SavedAttribute], str | None] = lambda name: None,
+    ) -> Sequence[str]:
+        # assign the saved variables to the generated grad_fn
+        stmts: list[str] = []
+        for arg in sorted(saved_variables, key=lambda sa: str(sa.nctype.name)):
+            name = (
+                arg.nctype.name.name
+                if isinstance(arg.nctype.name, SpecialArgName)
+                else arg.nctype.name
+            )
+            foreacharg: Argument | None = None
+            is_foreacharg_list_type: bool = False
+            type = arg.nctype.type
+            expr = arg.expr
+            stmts_prepend = None
+            if is_inplace_foreach and info is not None:
+                # todo(crcrpar): See if we can add a check such as `assert foreacharg is not None`;
+                # for now, that assert would fail.
+                name_to_query = name.split("_scalar_type")[0]
+                if name_to_query in refargname2inplace_foreacharg:
+                    foreacharg = refargname2inplace_foreacharg[name_to_query]
+                    is_foreacharg_list_type = isinstance(foreacharg.type, ListType)
+                if foreacharg is not None:
+                    name_in_expr = (
+                        f"{foreacharg.name}{'[i]' if is_foreacharg_list_type else ''}"
+                    )
+                    src_name = name
+                    if "_scalar_type" in src_name:
+                        split_src_name = src_name.split("_scalar_type")
+                        assert len(split_src_name) == 2
+                        src_name = split_src_name[0]
+                    expr = expr.replace(src_name, name_in_expr)
+            if (
+                type == BaseCType(tensorT)
+                or type == OptionalCType(BaseCType(tensorT))
+                or type == MutRefCType(OptionalCType(BaseCType(tensorT)))
+                or (is_output and type == BaseCType(scalarT))
+            ):
+                # note(crcrpar): Here `expr` is generated from scratch, `arg.expr` is ignored.
+                var = name
+                name += "_"
+                if var == "self" and inplace:
+                    original_self_var = (
+                        "original_self"
+                        if not is_inplace_foreach
+                        else "original_selfs[i]"
+                    )
+                    self_var = var if not is_inplace_foreach else var + "[i]"
+                    stmts_prepend = f"if (!{original_self_var}.has_value()) {original_self_var} = {self_var}.clone()"
+                    var = f"{original_self_var}.value()"
+                    assert not is_output
+                if inplace and is_output:
+                    assert name == "result_"
+                    var = (
+                        "self[i]"
+                        if is_inplace_foreach or is_foreacharg_list_type
+                        else "self"
+                    )
+                    is_inplace_view = f"{var}.is_view()"
+                    expr = f"SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})"
+                else:
+                    expr = f"SavedVariable({var}, {str(is_output).lower()})"
+                    if foreacharg is not None and "original_selfs" not in expr:
+                        expr = expr.replace(src_name, name_in_expr)
+            elif (
+                type == BaseCType(tensorListT)
+                or type == ListCType(OptionalCType(BaseCType(tensorT)))
+                or type == BaseCType(iTensorListRefT)
+                or type == VectorCType(BaseCType(tensorT))
+            ):
+                # See Note [nuanced return type of out-of-place foreach functions]
+                if type == VectorCType(BaseCType(tensorT)):
+                    assert is_foreach and is_output
+                expr = f"make_saved_variable_list({name}, {str(is_foreach and is_output).lower()})"
+                name += "_"
+            elif type == BaseCType(intArrayRefT):
+                expr = expr + ".vec()"
+            elif type == BaseCType(symIntArrayRefT):
+                expr = expr + ".vec()"
+            elif type == BaseCType(stringT):
+                expr = f"std::string({expr})"
+            elif type == OptionalCType(BaseCType(stringT)):
+                expr = f"{expr}.has_value() ? ::std::optional(std::string({expr}.value())) : ::std::nullopt"
+            elif type == ArrayRefCType(
+                elem=BaseCType(type=BaseCppType(ns="at", name="Scalar"))
+            ):
+                expr = expr + ".vec()"
+
+            guard = guard_for(arg)
+            if guard is None:
+                if stmts_prepend:
+                    stmts.append(f"{stmts_prepend};")
+                stmts.append(f"grad_fn->{name} = {expr};")
+            else:
+                stmts.append(f"if ({guard}) {{")
+                if stmts_prepend:
+                    stmts.append(f"  {stmts_prepend};")
+                stmts.append(f"  grad_fn->{name} = {expr};")
+                stmts.append("}")
+        return stmts
+
+    # Generates a Dispatcher::redispatch() call into the dispatcher. We do this mainly for performance reasons:
+    #  - Pre-compute the full DispatchKeySet. This saves the dispatcher from having to read from TLS.
+    #  - redispatch() avoids a redundant call to RecordFunction, which was already called right before
+    #    we entered this autograd kernel.
+    def emit_dispatch_call(
+        f: NativeFunction, input_base: str, unpacked_args: Sequence[str]
+    ) -> str:
+        """Dispatch call via function in a namespace or method on Tensor."""
+        # code-generated autograd kernels plumb and recompute dispatch keys directly through the kernel for performance.
+        # Ops also always have a function variant of the redispatch API.
+        # See Note [Plumbing Keys Through The Dispatcher] for details.
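+        # (Hedged illustration: for an op like mul.Tensor the emitted call typically
+        # looks like `at::redispatch::mul(ks & c10::after_autograd_keyset, self_, other_)`,
+        # with the exact shape coming from the CALL_REDISPATCH template.)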
+        dispatch_key_set = "ks & c10::after_autograd_keyset"
+        call = CALL_REDISPATCH.substitute(
+            api_name=cpp.name(
+                f.func,
+                faithful_name_for_out_overloads=True,
+                symint_overload=f.func.has_symint(),
+            ),
+            unpacked_args=[dispatch_key_set] + list(unpacked_args),
+        )
+        return call
+
+    def wrap_output(
+        f: NativeFunction, unpacked_bindings: list[Binding], var: str
+    ) -> str:
+        call = ""
+        rhs_value: str | None = None
+        if not any(r.type.is_tensor_like() for r in f.func.returns):
+            rhs_value = var
+        else:
+            rhs_value = f"std::move({var})"
+        assert rhs_value is not None
+        call += ASSIGN_RETURN_VALUE.substitute(
+            return_values=tie_return_values(f), rhs_value=rhs_value
+        )
+        return call
+
+    def check_tensorimpl_and_storage(
+        call: str, unpacked_bindings: list[Binding]
+    ) -> str:
+        # See NOTE [ TensorImpl and Storage Pointer Sanity Checks ]
+        stmts_before_call: list[str] = []
+        stmts_after_call: list[str] = []
+
+        if cpp.name(f.func) in DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE:
+            return call
+
+        # Check properties of inputs (enforce (1))
+        for unpacked_binding in unpacked_bindings:
+            arg = unpacked_binding.name
+            noref_cpp_type = unpacked_binding.nctype.type.remove_const_ref()
+            if noref_cpp_type == BaseCType(tensorListT) or noref_cpp_type == BaseCType(
+                iTensorListRefT
+            ):
+                stmts_before_call += [
+                    SAVE_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
+            elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
+                stmts_before_call += [
+                    SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(
+                        tensorlist_name=arg
+                    ),
+                    ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(
+                        tensorlist_name=arg
+                    ),
+                ]
+            elif noref_cpp_type == BaseCType(tensorT):
+                stmts_before_call += [
+                    SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
+                    SAVE_TENSOR_IMPL.substitute(tensor_name=arg),
+                ]
+                stmts_after_call += [
+                    ENFORCE_SAME_TENSOR_STORAGE.substitute(
+                        tensor_name=arg, out_tensor_name=arg
+                    ),
+                    ENFORCE_SAME_TENSOR_IMPL.substitute(tensor_name=arg),
+                ]
+
+        assert (stmts_before_call and stmts_after_call) or (
+            not stmts_before_call and not stmts_after_call
+        )
+
+        # Check properties of outputs (enforce (2), (3))
+        if f.func.kind() not in (SchemaKind.inplace, SchemaKind.out):
+            base_name = f.func.name.name.base  # TODO: should be str(f.func.name.name)?
+            aliased_arg_name = ALL_VIEW_FUNCTIONS.get(base_name, None)
+            if aliased_arg_name is not None:
+                aliased_arg_name = unpacked_name(aliased_arg_name)
+            for i, (ret, ret_name) in enumerate(
+                zip(f.func.returns, cpp.return_names(f))
+            ):
+                noref_cpp_type = cpp.return_type(ret, symint=True).remove_const_ref()
+                if noref_cpp_type == BaseCType(tensorT):
+                    if aliased_arg_name is not None:
+                        assert i == 0, (
+                            "Expect non-CompositeImplicitAutograd view function {base} to return single output"
+                        )
+                        stmts_after_call += [
+                            ENFORCE_SAME_TENSOR_STORAGE.substitute(
+                                tensor_name=aliased_arg_name, out_tensor_name=ret_name
+                            )
+                        ]
+                    else:
+                        if (
+                            type_wrapper_name(f)
+                            not in DONT_ENFORCE_STORAGE_IMPL_USE_COUNT
+                        ):
+                            stmts_after_call += [
+                                ENFORCE_TENSOR_STORAGE_USE_COUNT_EQUALS_ONE.substitute(
+                                    tensor_name=ret_name, fn_name=type_wrapper_name(f)
+                                )
+                            ]
+
+                    if type_wrapper_name(f) not in DONT_ENFORCE_TENSOR_IMPL_USE_COUNT:
+                        stmts_after_call += [
+                            ENFORCE_TENSOR_IMPL_USE_COUNT_LT_OR_EQ_ONE.substitute(
+                                tensor_name=ret_name, fn_name=type_wrapper_name(f)
+                            )
+                        ]
+
+                # Currently we don't have any functions that return the following types, but
+                # we should update the checks once we do
+                elif noref_cpp_type == ListCType(OptionalCType(BaseCType(tensorT))):
+                    raise AssertionError(
+                        f"Please add use_count checks for {noref_cpp_type}"
+                    )
+                elif noref_cpp_type == BaseCType(tensorListT):
+                    raise AssertionError(
+                        f"Please add use_count checks for {noref_cpp_type}"
+                    )
+
+        if stmts_before_call and stmts_after_call:
+            call = (
+                RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_before_call)
+                + call
+                + RUN_ONLY_IN_DEBUG_MODE.substitute(statements=stmts_after_call)
+            )
+        return call
+
+    def emit_call(
+        f: NativeFunction, unpacked_bindings: list[Binding], try_jit_decomposition: bool
+    ) -> str:
+        # We only care about adding `at::AutoDispatchBelowAutograd` guard for non-variable dispatch
+        # (which corresponds to 'use_derived' strategy). The purpose of this guard is to make sure
+        # the baseType operations still dispatch to non-Variable type, even if the arguments passed
+        # in are now Variables.
+        # See NOTE [ Treating Variables as non-Variables in type dispatch ] for details.
+        unpacked_args = [b.name for b in unpacked_bindings]
+        base_type_call = emit_dispatch_call(f, "self_", unpacked_args)
+
+        if get_view_info(f) is not None or modifies_arguments(f):
+            guard = "at::AutoDispatchBelowAutograd guard;"
+        else:
+            guard = "at::AutoDispatchBelowADInplaceOrView guard;"
+
+        any_has_forward_grad = (
+            get_any_has_fw_grad_cond(derivative=None)
+            if requires_derivative
+            else "false"
+        )
+        return_types = ", ".join(
+            [cpp.return_type(a, symint=True).cpp_type() for a in f.func.returns]
+        )
+        if len(f.func.returns) > 1:
+            return_types = f"std::tuple<{return_types}>"
+
+        arg_names = [
+            a.name
+            for a in cpp.arguments(
+                f.func.arguments,
+                faithful=True,
+                symint=True,
+                method=False,
+                cpp_no_default_args=set(),
+            )
+        ]
+
+        if not modifies_arguments(f) and not returns_void:
+            if try_jit_decomposition:
+                call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES_JVP_DECOMP.substitute(
+                    base_type_call=base_type_call,
+                    tmp_var=TMP_VAR,
+                    guard=guard,
+                    any_has_forward_grad=any_has_forward_grad,
+                    op_name=cpp.name(f.func),
+                    op_overload=f.func.name.overload_name,
+                    return_types=return_types,
+                    arg_names=arg_names,
+                )
+            else:
+                call = DISPATCH_TO_NON_VAR_TYPE_WITH_TMP_RETURN_VALUES.substitute(
+                    base_type_call=base_type_call,
+                    tmp_var=TMP_VAR,
+                    guard=guard,
+                )
+
+            call += wrap_output(f, unpacked_bindings, TMP_VAR)
+        else:
+            assert not try_jit_decomposition
+            call = DISPATCH_TO_NON_VAR_TYPE_WITHOUT_RETURN_VALUES.substitute(
+                base_type_call=base_type_call, guard=guard
+            )
+        call = check_tensorimpl_and_storage(call, unpacked_bindings)
+        return call
+
+    def emit_history() -> str:
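+        # (Hedged sketch: for non-foreach ops SET_HISTORY typically renders as
+        #   if (grad_fn) { set_history(flatten_tensor_args( result ), grad_fn); }
+        # or uses rebase_history when the op mutates its inputs.)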
+        fn = "rebase" if modifies_arguments(f) and view_info is None else "set"
+        output_names = [r.name for r in differentiable_outputs]
+        # TODO: flatten allocates a std::vector, which could be expensive
+        outs = CodeTemplate("flatten_tensor_args( ${outs} )").substitute(
+            outs=output_names if not is_inplace_foreach else "self"
+        )
+        if not is_inplace_foreach:
+            return SET_HISTORY.substitute(fn=fn, differentiable_outputs=outs)
+        else:
+            return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
+                preamble=(
+                    f"auto differentiable_outputs = {outs};\n"
+                    f"TORCH_INTERNAL_ASSERT(differentiable_outputs.size() == grad_fns.size());"
+                ),
+                statements=f"{fn}_history(differentiable_outputs[i], grad_fns[i]);",
+            )
+
+    def emit_save_outputs() -> str:
+        if is_out_fn:
+            # out functions don't currently support differentiation
+            return ""
+        if info is not None and info.has_derivatives:
+            stmts = save_variables(info.all_saved_outputs, True)
+            if len(stmts) == 0:
+                return ""
+            if not is_inplace_foreach:
+                return CONDITIONAL.substitute(cond="grad_fn", statements=stmts)
+            else:
+                return LOOP_OVER_VECTOR_OF_GRAD_FNS.substitute(
+                    preamble="", statements=stmts
+                )
+        return ""
+
+    def emit_any_requires_grad() -> list[str]:
+        extra_condition = ""
+        if info and info.output_differentiability_conditions:
+            assert len(info.output_differentiability_conditions) == 1
+            extra_condition = f"_any_requires_grad &= ({info.output_differentiability_conditions[0]});"
+        names_of_args_with_derivatives = [arg.name for arg in args_with_derivatives]
+        if is_inplace_foreach and info is not None:
+            for i, arg in enumerate(names_of_args_with_derivatives):
+                for f_arg, r_arg in inplace_foreacharg2refarg.items():
+                    if arg == r_arg.name:
+                        names_of_args_with_derivatives[i] = f_arg.name
+        return [
+            SETUP_ANY_REQUIRES_GRAD.substitute(
+                args_with_derivatives=names_of_args_with_derivatives,
+                extra_differentiability_conditions=extra_condition,
+            )
+        ]
+
+    def get_any_has_forward_grad_name(var_names: tuple[str, ...]) -> str:
+        if len(var_names) == 1:
+            return f"_any_has_forward_grad_{var_names[0]}"
+        else:
+            return f"_any_has_forward_grad_{'_'.join(var_names)}"
+
+    def emit_any_has_forward_grad() -> list[str]:
+        content: list[str] = []
+        if not is_foreach:
+            for derivative in fw_derivatives:
+                requires_fw_grad = get_any_has_fw_grad_cond(derivative=derivative)
+                if info and info.output_differentiability_conditions:
+                    assert len(info.output_differentiability_conditions) == 1
+                    requires_fw_grad = f"({info.output_differentiability_conditions[0]}) && {requires_fw_grad}"
+                content.append(
+                    f"[[maybe_unused]] auto {get_any_has_forward_grad_name(derivative.var_names)} = {requires_fw_grad};"
+                )
+        else:
+            for derivative in fw_derivatives:
+                bool_vector_name = get_any_has_forward_grad_name(derivative.var_names)
+                cur_derivative_conditions = []
+                for inp in differentiable_inputs:
+                    if derivative.required_inputs_fw_grad is None:
+                        continue
+                    if inp.name not in derivative.required_inputs_fw_grad:
+                        continue
+                    inp_name = (
+                        inp.name
+                        if not inplace
+                        else refargname2inplace_foreacharg[inp.name].name
+                    )
+                    inp_type = (
+                        inp.type
+                        if not inplace
+                        else refargname2inplace_foreacharg[inp.name].type
+                    )
+                    is_list_type = is_tensor_list_type(inp_type)
+                    if is_list_type:
+                        if inp_name != "self":
+                            content.append(
+                                FW_DERIVATIVE_SIZE_CHECK_TEMPLATE.substitute(
+                                    inp_name=inp_name
+                                )
+                            )
+                        cur_derivative_conditions.append(
+                            FW_DERIVATIVE_CHECK_TEMPLATE.substitute(
+                                req_inp=inp_name + "[i]"
+                            )
+                        )
+                    else:
+                        cur_derivative_conditions.append(
+                            FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name)
+                        )
+
+                content.append(f"std::vector {bool_vector_name}(self.size());")
+                content.append("for (const auto& i : c10::irange(self.size())) {")
+                content.append(
+                    f"  {bool_vector_name}[i] = {' || '.join(cur_derivative_conditions)};"
+                )
+                content.append("}")
+        return content
+
+    def emit_check_inplace() -> list[str]:
+        if not inplace:
+            return []
+        return [
+            f"check_inplace({arg.name}, _any_requires_grad);"
+            for arg in differentiable_outputs
+        ]
+
+    def emit_fw_derivatives() -> list[str]:
+        content: list[str] = []
+        fw_grad_setters: list[str] = []
+        for derivative in fw_derivatives:
+            res = derivative.var_names
+            if f.func.name.name.inplace:
+                assert len(res) == 1, (
+                    "Expected number of outputs to be 1 if function is inplace"
+                )
+                # TODO update this when inplace namings are unified
+                res = ("self",)
+
+            assert derivative.required_inputs_fw_grad is not None
+
+            unpacked_arguments = ""
+            for inp in differentiable_inputs:
+                inp_name = inp.name
+                is_input_tensorlist = is_foreach and is_tensor_list_type(
+                    inp.type
+                    if not inplace
+                    else refargname2inplace_foreacharg[inp.name].type
+                )
+                input_suffix = "[i]" if is_input_tensorlist else ""
+                if is_inplace_foreach:
+                    if inp.name in refargname2inplace_foreacharg:
+                        inp_name = refargname2inplace_foreacharg[inp.name].name
+                zeros_fn = (
+                    "zeros_symint"
+                    if inplace and inp.name == "self"
+                    else "_efficientzerotensor_symint"
+                )
+                if inp.name in derivative.required_inputs_fw_grad:
+                    unpacked_arguments += (
+                        FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
+                            inp_name=inp.name,
+                            inp=inp_name + input_suffix,
+                            zeros_fn=zeros_fn,
+                        )
+                    )
+                if inp.name in (derivative.required_inputs_primal or []):
+                    unpacked_arguments += (
+                        FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
+                            inp_name=inp.name,
+                            inp=inp_name + input_suffix,
+                        )
+                    )
+            if derivative.required_original_self_value:
+                input_suffix = "s[i]" if is_inplace_foreach else ""
+                unpacked_arguments += FW_DERIVATIVE_DEFINED_GRAD_TEMPLATE.substitute(
+                    inp_name="original_self",
+                    inp="original_self" + input_suffix,
+                    zeros_fn=zeros_fn,
+                )
+                unpacked_arguments += FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute(
+                    inp_name="original_self",
+                    inp="original_self" + input_suffix,
+                )
+            elif inplace and derivative.is_reusing_outplace_formula:
+                # The gradient wasn't already cloned, do it if grad mode is enabled
+                unpacked_arguments += (
+                    "self_t = GradMode::is_enabled() ? self_t.clone() : self_t;"
+                )
+
+            if inplace:
+                is_inplace_str = "true"
+            else:
+                is_inplace_str = "false"
+
+            requires_fw_grad = get_any_has_forward_grad_name(derivative.var_names)
+
+            if all(
+                (isinstance(var_type, BaseType) and var_type.is_tensor_like())
+                for var_type in derivative.var_types
+            ):
+                # Is there a way to get from BaseType to BaseCType?
+                if len(derivative.var_types) == 1:
+                    opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
+                    if not is_foreach:
+                        fw_grad_setters.append(
+                            FW_DERIVATIVE_SETTER_TENSOR.substitute(
+                                out_arg=res[0], is_inplace=is_inplace_str
+                            )
+                        )
+                    else:
+                        assert res[0] == ("result" if not inplace else "self")
+                        fw_grad_setters.append(
+                            FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
+                                out_arg=res[0], is_inplace=is_inplace_str
+                            )
+                        )
+                    requires_fw_grad += f" && ({derivative.var_names[0]}.defined())"
+                else:
+                    tuple_type = TupleCType(
+                        [BaseCType(tensorT)] * len(derivative.var_types)
+                    )
+                    opt_res_grad_type = OptionalCType(tuple_type).cpp_type()
+                    for idx, single_res in enumerate(res):
+                        fw_grad_setters.append(
+                            FW_DERIVATIVE_SETTER_MULTI_OUTPUT.substitute(
+                                idx=idx, all_res="_".join(res), out_arg=single_res
+                            )
+                        )
+            elif (
+                isinstance(derivative.var_types[0], ListType)
+                and derivative.var_types[0].is_tensor_like()
+            ):
+                assert len(derivative.var_types) == 1, (
+                    "Expected number of outputs to be 1 if function returns ListType"
+                )
+                if not is_foreach:
+                    opt_res_grad_type = OptionalCType(
+                        VectorCType(BaseCType(tensorT))
+                    ).cpp_type()
+                    fw_grad_setters.append(
+                        FW_DERIVATIVE_SETTER_TENSOR_LIST.substitute(
+                            out_arg=res[0], is_inplace=is_inplace_str
+                        )
+                    )
+                else:
+                    # TODO(crcrpar): Should this (= the foreach specific logic) be refactored somehow?
+                    # Only out-place foreach functions that have entries in `tools/autograd/derivatives.yaml`
+                    # can reach here.
+                    opt_res_grad_type = OptionalCType(BaseCType(tensorT)).cpp_type()
+                    fw_grad_setters.append(
+                        FW_DERIVATIVE_SETTER_TENSOR_FOREACH.substitute(
+                            out_arg=res[0], is_inplace=is_inplace_str
+                        )
+                    )
+            else:
+                raise RuntimeError("Unsupported output type for forward derivative")
+
+            if not is_foreach:
+                fw_grad_opt_definition = f"{opt_res_grad_type} {'_'.join(res)}_new_fw_grad_opt = ::std::nullopt;"
+                # View ops create fw_grad that already is a view of the base's fw_grad so just use that
+                content.append(
+                    FW_DERIVATIVE_TEMPLATE.substitute(
+                        fw_grad_opt_definition=fw_grad_opt_definition,
+                        requires_fw_grad=requires_fw_grad,
+                        formula=derivative.formula,
+                        out_arg="_".join(res),
+                        unpacked_arguments=unpacked_arguments,
+                    )
+                )
+            else:
+                # note(crcrpar): Assuming `self` is TensorList.
+                fw_grad_opt_definition = (
+                    f"std::vector<{opt_res_grad_type}> {'_'.join(res)}_new_fw_grad_opts"
+                    "(self.size(), ::std::nullopt);"
+                )
+                foreach_forward_grad_formula = derivative.formula
+                _foreach_arg: Argument | DifferentiableInput
+                if inplace:
+                    for _foreach_arg, _ref_arg in inplace_foreacharg2refarg.items():
+                        # note(crcrpar): Massage only Scalar and ArrayRef here.
+                        if not (
+                            is_tensor_type(_foreach_arg.type)
+                            or is_tensor_list_type(_foreach_arg.type)
+                        ):
+                            pattern = _foreach_arg.name
+                            if isinstance(_foreach_arg.type, ListType):
+                                pattern += "[i]"
+                            foreach_forward_grad_formula = (
+                                foreach_forward_grad_formula.replace(
+                                    _ref_arg.name, pattern
+                                )
+                            )
+                else:
+                    if (
+                        "result" in foreach_forward_grad_formula
+                        and "result[i]" not in foreach_forward_grad_formula
+                    ):
+                        foreach_forward_grad_formula = (
+                            foreach_forward_grad_formula.replace("result", "result[i]")
+                        )
+
+                content.append(
+                    FW_DERIVATIVE_FOREACH_TEMPLATE.substitute(
+                        fw_grad_opt_definition=fw_grad_opt_definition,
+                        vector_of_optional_tensor=f"{'_'.join(res)}_new_fw_grad_opts",
+                        any_has_forward_grad_for_current_index=" || ".join(
+                            get_any_has_forward_grad_name(derivative.var_names) + "[i]"
+                            for derivative in fw_derivatives
+                        ),
+                        formula=foreach_forward_grad_formula,
+                        unpacked_arguments=unpacked_arguments,
+                    )
+                )
+
+        # Set all the grads at the end to avoid: https://github.com/pytorch/pytorch/issues/67367
+        content.append("\n".join(fw_grad_setters))
+        return content
+
+    def get_any_has_fw_grad_cond(derivative: ForwardDerivative | None) -> str:
+        #
+        # Produces a condition string (e.g., "isFwGradDefined(grad_output) || isFwGradDefined(output)")
+        #
+        if derivative is None:
+            # (1) If a derivative is NOT provided, cond will check fw_grad of ALL differentiable inputs
+            # - Used in the out_fn case when we want to forbid fw derivatives
+            # - Used in the case where the fw_derivative is not defined, but we want
+            #   to check whether there is a decomposition registered for jvp
+            to_check: list[str] = []
+            for inp in list(
+                mapMaybe(
+                    gen_differentiable_input,
+                    f.func.arguments.non_out + list(f.func.arguments.out),  # type: ignore[operator]
+                )
+            ):
+                if is_tensor_type(inp.type):
+                    to_check.append(
+                        FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp.name)
+                    )
+                elif is_tensor_list_type(inp.type):
+                    to_check.append(
+                        FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE.substitute(
+                            req_inp=inp.name
+                        )
+                    )
+                else:
+                    raise RuntimeError(
+                        f'Unsupported input type for "{name}" when forbidding forward AD usage.'
+                    )
+            return f"({' || '.join(to_check)})"
+        else:
+            # (2) If derivative is provided, use that information to determine which inputs
+            #     to check fw_grad for
+            assert derivative.required_inputs_fw_grad is not None
+
+            if len(derivative.required_inputs_fw_grad) == 0:
+                # Handle functions like stack
+                # For these, we don't unpack anything and always call the user function
+                if not (
+                    len(differentiable_inputs) == 1
+                    and is_tensor_list_type(differentiable_inputs[0].type)
+                ):
+                    raise RuntimeError(
+                        f'No differentiable input to "{name}" is a differentiable Tensor (as the provided '
+                        "forward AD formula does not use any input tangent) even though a forward gradient "
+                        "formula has been defined for it. This case should only happen for function that "
+                        "take a single TensorList as input. All other cases are not supported right now."
+                    )
+                any_has_fw_grad = "true"
+            else:
+                any_has_fw_grad = " || ".join(
+                    [
+                        (
+                            FW_DERIVATIVE_TENSORLIST_CHECK_TEMPLATE
+                            if is_tensor_list_type(inp.type)
+                            else FW_DERIVATIVE_CHECK_TEMPLATE
+                        ).substitute(req_inp=inp.name)
+                        for inp in differentiable_inputs
+                        if inp.name in derivative.required_inputs_fw_grad
+                    ]
+                )
+                any_has_fw_grad = f"({any_has_fw_grad})"
+
+            return any_has_fw_grad
+
+    def emit_forbid_fw_derivatives(is_out_fn: bool = False) -> str:
+        if is_out_fn:
+            msg = "because it is an out= function"
+        else:
+            msg = (
+                "because it has not been implemented yet.\\nPlease file an issue "
+                "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
+                "so that we can prioritize its implementation."
+            )
+        cond = get_any_has_fw_grad_cond(derivative=None)
+        return (
+            FW_DERIVATIVE_FORBID_TEMPLATE.substitute(cond=cond, name=name, msg=msg)
+            if cond != ""
+            else ""
+        )
+
+    body: list[str] = []
+    unpack_args_stats, unpacked_bindings = unpack_args(f)
+
+    body.extend(unpack_args_stats)
+    if requires_derivative:
+        body.extend(emit_any_requires_grad())
+        body.extend(emit_any_has_forward_grad())
+        body.extend(emit_check_inplace())
+        body.extend(emit_original_self_definition())
+        body.extend(setup_derivative(differentiable_inputs))
+
+    body.append(emit_call(f, unpacked_bindings, try_jit_decomposition))
+    if requires_derivative:
+        # set_flags has to appear after version_counter, because rebase_history
+        # requires that the counter is incremented before it is called
+        body.append(emit_history())
+        body.extend(emit_check_if_in_complex_autograd_allowlist())
+
+    if is_out_fn:
+        body.append(emit_forbid_fw_derivatives(is_out_fn=True))
+    else:
+        if requires_derivative and not try_jit_decomposition:
+            if len(fw_derivatives) > 0:
+                body.extend(emit_fw_derivatives())
+            else:
+                body.append(emit_forbid_fw_derivatives())
+
+    if requires_derivative:
+        # Save only after the forward AD has been set up
+        body.append(emit_save_outputs())
+
+    if str(f.func.name.name) in RESET_GRAD_ACCUMULATOR:
+        # `inplace` implies that there is exactly one output named `self`,
+        # so we can keep the generated code easy. If you need to
+        # `reset_grad_accumulator` in an operator that's not `inplace`, you can
+        # remove this assert but the code generation will get more elaborate
+        assert inplace
+        body.append("reset_grad_accumulator(self);")
+    if not returns_void:
+        body.append(f"return {get_return_value(f)};")
+    return body
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_view_funcs.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_view_funcs.py
new file mode 100644
index 0000000000000000000000000000000000000000..44c2ab0a8de9bcccfa237a4746f1c6b412cdc27a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/gen_view_funcs.py
@@ -0,0 +1,339 @@
+# Generates ViewFuncs.h/cpp
+#
+# NOTE: If any changes are being made to the ViewFunc codegen please also check
+# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
+# The fallback is expected to mimic this codegen, so we should keep the two in sync.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torchgen.api.dispatcher as dispatcher
+from torchgen.api.translate import translate
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    NamedCType,
+    SymIntT,
+    tensorT,
+    VectorCType,
+)
+from torchgen.code_template import CodeTemplate
+from torchgen.model import Argument, NativeFunction, OptionalType
+from torchgen.utils import FileManager
+
+from .gen_inplace_or_view_type import (
+    CALL_DISPATCH,
+    extract_bindings,
+    get_view_info,
+    modifies_arguments,
+    use_derived,
+)
+
+
+if TYPE_CHECKING:
+    from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo
+
+
+FUNCTION_DECLARATION = CodeTemplate(
+    """\
+#define ${uppercase_op}_AVAILABLE
+struct ${op} : public ${superclass} {
+  ${op}(${constructor_args}) ${initializer_list}
+  {}
+  virtual ~${op}() override = default;
+  virtual std::vector<c10::SymInt> get_symints() const override;
+  virtual size_t num_symints() const override;
+  virtual std::vector<at::Tensor> get_tensors() const override;
+  virtual size_t num_tensors() const override;
+  virtual at::Tensor operator()(const at::Tensor&) const override;
+  virtual std::unique_ptr<ViewFunc> clone_and_set(
+      std::optional<std::vector<c10::SymInt>> = ::std::nullopt,
+      std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override;
+
+protected:
+  virtual void set_symints(std::vector<c10::SymInt>) override;
+  virtual void set_tensors(std::vector<at::Tensor>) override;
+
+private:
+  ${state}
+};
+
+"""
+)
+
+FUNCTION_DEFINITION = CodeTemplate(
+    """\
+std::vector<c10::SymInt> ${op}::get_symints() const {
+  ${get_symints}
+}
+
+size_t ${op}::num_symints() const {
+  return static_cast<size_t>(${num_symints});
+}
+
+void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
+  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
+  ${set_symints}
+}
+
+std::vector<at::Tensor> ${op}::get_tensors() const {
+  ${get_tensors}
+}
+
+size_t ${op}::num_tensors() const {
+  return static_cast<size_t>(${num_tensors});
+}
+
+void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
+  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
+  ${set_tensors}
+}
+
+at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
+  return ${op_call};
+}
+
+std::unique_ptr<ViewFunc> ${op}::clone_and_set(
+    std::optional<std::vector<c10::SymInt>> ${symints_vec},
+    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
+  auto output = std::make_unique<${op}>(${clone_args});
+  if (${symints_vec}.has_value()) {
+    output->set_symints(std::move(*(${symints_vec})));
+  }
+  if (${tensors_vec}.has_value()) {
+    output->set_tensors(std::move(*(${tensors_vec})));
+  }
+  return output;
+}
+
+"""
+)
+
+
+# e.g. as_strided -> AsStridedViewFunc for camel case or
+# as_strided_view_func otherwise
+def view_func_name(
+    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
+) -> str:
+    name = f.func.name.unambiguous_name()
+    view_func_name = f"{name.replace('.', '_')}_view_func"
+    if camel_case:
+        is_private = view_func_name.startswith("_")
+        view_func_name = "".join(
+            [p.title() for p in view_func_name.replace(".", "_").split("_")]
+        )
+        if is_private:
+            # put the leading underscore back in
+            view_func_name = f"_{view_func_name}"
+    namespace = "torch::autograd::generated::" if include_namespace else ""
+    return f"{namespace}{view_func_name}"
+
+
+def is_symint_or_tensor(arg: Argument) -> bool:
+    return arg.type.is_tensor_like() or arg.type.is_symint_like()
+
+
+def remove_const_ref(binding: Binding) -> Binding:
+    return Binding(
+        name=binding.name,
+        nctype=binding.nctype.remove_const_ref(),
+        argument=binding.argument,
+        default=binding.default,
+    )
+
+
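+# Whether the view op returns multiple tensors (a Tensor[] return, e.g. chunk / split);
+# such ops need an extra view_idx to identify which output a given ViewFunc reconstructs
+# (see the view_idx handling in process_function below).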
+def returns_multi_tensor(fn: NativeFunction) -> bool:
+    returns = fn.func.returns
+    assert len(returns) == 1
+    returns_list_like = returns[0].type.is_list_like() is not None
+    returns_tensor_like = returns[0].type.is_tensor_like()
+    return returns_list_like and returns_tensor_like
+
+
+# Generates strings with logic for getting / setting state of a particular type.
+#
+# Args:
+#   bindings (list): List of state bindings of interest (may be empty)
+#   state_vec_type (NamedCType): Type of vector to either return or copy from
+#
+# Returns:
+#   tuple: (list of getter logic strings, list of setter logic strings, string
+#     with num items expression)
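+#
+# Rough illustration (assuming a single scalar SymInt binding named "size"): the getter body
+# becomes roughly "std::vector<c10::SymInt> symints; symints.reserve(1); symints.push_back(size);
+# return symints;", the setter body becomes "auto i = 0; size = symints[i];", and the
+# num-items expression is "1".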
+def generate_state_getter_setter(
+    bindings: list[Binding],
+    state_vec_type: NamedCType,
+) -> tuple[list[str], list[str], str]:
+    getter_logic = []
+    setter_logic = []
+
+    state_vec = state_vec_type.name
+    getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
+    if len(bindings) > 0:
+        setter_logic.append("auto i = 0;")
+
+    num_exprs = []
+    for i, b in enumerate(bindings):
+        assert isinstance(b.argument, Argument)
+        if b.argument.type.is_list_like():
+            # Handle list-likes.
+            num_expr = f"{b.name}.size()"
+            num_exprs.append(num_expr)
+            getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
+            setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
+        elif isinstance(b.argument.type, OptionalType):
+            # Handle optionals.
+            num_expr = f"({b.name}.has_value() ? 1 : 0)"
+            num_exprs.append(num_expr)
+            conditional = f"if({b.name}.has_value())"
+            getter = (
+                f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
+            )
+            setter = f"{conditional} {b.name} = {state_vec}[i];"
+        else:
+            num_expr = "1"
+            num_exprs.append(num_expr)
+            getter = f"{state_vec}.push_back({b.name});"
+            setter = f"{b.name} = {state_vec}[i];"
+
+        getter_logic.append(getter)
+        setter_logic.append(setter)
+        if i < len(bindings) - 1:
+            setter_logic.append(f"i += {num_expr};")
+
+    # Reserve / assert based on the total number of items expression.
+    num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
+    if len(bindings) > 0:
+        getter_logic.insert(1, f"{state_vec}.reserve({num_items});")
+
+    getter_logic.append(f"return {state_vec};")
+
+    return getter_logic, setter_logic, num_items
+
+
+def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
+    bindings = extract_bindings(fn)
+    non_self_bindings = [b for b in bindings if b.name != "self"]
+
+    non_self_args = fn.func.arguments.flat_all[1:]
+    non_self_value_bindings = [
+        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
+    ]
+
+    # Generate constructor / clone args for the generated struct.
+    constructor_args = [b.defn() for b in non_self_bindings]
+    clone_args = [b.name for b in non_self_bindings]
+
+    # Generate state variable declarations for the generated struct.
+    state_variables = [
+        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
+    ]
+
+    # Generate initializer list expressions for the generated struct.
+    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
+    # vectors.
+    init_exprs = translate(
+        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
+    )
+    initializers = []
+    for b, init_expr in zip(non_self_bindings, init_exprs):
+        name = b.nctype.name
+        assert isinstance(name, str)
+        initializers.append(f"{name}({init_expr.expr})")
+
+    # Generate call to underlying view op
+    call_input_name = "input_base"
+    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
+    op_call = CALL_DISPATCH.substitute(
+        unambiguous_name=fn.func.name.unambiguous_name(),
+        unpacked_args=op_call_args,
+    )
+
+    # Multi-output views additionally require a view_idx for disambiguation.
+    if returns_multi_tensor(fn):
+        view_idx_name = "view_idx"
+        view_idx_typename = "int64_t"
+        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
+        constructor_args.append(view_idx_decl)
+        clone_args.append(view_idx_name)
+        state_variables.append(f"{view_idx_decl};")
+        initializers.append(f"{view_idx_name}({view_idx_name})")
+        op_call += f"[{view_idx_name}]"
+
+    # Generate initializer list for the generated struct.
+    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""
+
+    # Generate getter / setter logic for any symints.
+    symint_bindings = [
+        b
+        for b in non_self_bindings
+        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
+    ]
+    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
+    get_symints, set_symints, num_symints = generate_state_getter_setter(
+        symint_bindings, symints_vec_type
+    )
+
+    # Generate getter / setter logic for any tensors.
+    tensor_bindings = [
+        b
+        for b in non_self_bindings
+        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
+    ]
+    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
+    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
+        tensor_bindings, tensors_vec_type
+    )
+
+    return template.substitute(
+        op=view_func_name(fn),
+        uppercase_op=view_func_name(fn, camel_case=False).upper(),
+        superclass="torch::autograd::ViewFunc",
+        initializer_list=initializer_list,
+        state=state_variables,
+        constructor_args=constructor_args,
+        clone_args=clone_args,
+        symints_vec=symints_vec_type.name,
+        get_symints=get_symints,
+        set_symints=set_symints,
+        num_symints=num_symints,
+        tensors_vec=tensors_vec_type.name,
+        get_tensors=get_tensors,
+        set_tensors=set_tensors,
+        num_tensors=num_tensors,
+        call_input_name=call_input_name,
+        op_call=op_call,
+    )
+
+
+def gen_view_funcs(
+    out: str,
+    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
+    template_path: str,
+) -> None:
+    # don't need the info parts, just the function
+    fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
+    # only want out-of-place views
+    view_fns = [
+        fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
+    ]
+
+    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
+    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
+    ops_headers = [f"#include " for fn in view_fns]
+
+    file_basename = "ViewFuncs"
+    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
+    for suffix in [".h", ".cpp"]:
+        fname = file_basename + suffix
+        fm.write_with_template(
+            fname,
+            fname,
+            lambda: {
+                "generated_comment": "@"
+                + f"generated from {fm.template_dir_for_comments()}/{fname}",
+                "view_func_declarations": declarations,
+                "view_func_definitions": definitions,
+                "ops_headers": ops_headers,
+            },
+        )
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/load_derivatives.py b/phivenv/Lib/site-packages/torchgen/packaged/autograd/load_derivatives.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a10df4996f13d25265c390567472b5cc1ac0533
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/load_derivatives.py
@@ -0,0 +1,1019 @@
+# Parses derivatives.yaml into autograd functions
+#
+# Each autograd function is represented by `DifferentiabilityInfo` containing
+# a list of `Derivative`. See `torchgen.api.autograd` for the data models.
+
+from __future__ import annotations
+
+import re
+from collections import Counter, defaultdict
+from typing import Any, TYPE_CHECKING
+
+import yaml
+
+from torchgen.api import cpp
+from torchgen.api.autograd import (
+    Derivative,
+    DifferentiabilityInfo,
+    ForwardDerivative,
+    SavedAttribute,
+)
+from torchgen.api.types import (
+    BaseCType,
+    Binding,
+    boolT,
+    CppSignatureGroup,
+    layoutT,
+    longT,
+    NamedCType,
+    OptionalCType,
+    scalarTypeT,
+    SpecialArgName,
+    stringT,
+    symIntArrayRefT,
+    SymIntT,
+    tensorGeometryT,
+    tensorOptionsT,
+    typeAndSizeT,
+    VectorCType,
+)
+from torchgen.context import with_native_function
+from torchgen.gen import get_grouped_by_view_native_functions, parse_native_yaml
+from torchgen.model import (
+    AUTOGRAD_KEYS,
+    FunctionSchema,
+    NativeFunction,
+    NativeFunctionsViewGroup,
+    OperatorName,
+    SchemaKind,
+    Type,
+    Variant,
+)
+from torchgen.utils import concatMap, IDENT_REGEX, split_name_params
+from torchgen.yaml_utils import YamlLoader
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+DerivativeRet = tuple[dict[FunctionSchema, dict[str, DifferentiabilityInfo]], set[str]]
+
+_GLOBAL_LOAD_DERIVATIVE_CACHE: dict[tuple[str, str], DerivativeRet] = {}
+
+_VALID_AUTOGRAD_KEYS = set(AUTOGRAD_KEYS)
+
+
+# This function directly adds per-dispatchkey derivative entries for {view}_copy variants of each view op.
+# Since every {view} and {view}_copy op shares the same derivative formula,
+# we generate them here instead of duplicating them in the yaml.
+# See Note [Codegen'd {view}_copy Operators]
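+# For instance (illustrative pair): a derivative entry written for `squeeze` is reused for the
+# generated `squeeze_copy`, unless a manual entry for the copy variant already exists in the yaml.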
+def add_view_copy_derivatives(
+    infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+    view_groups: list[NativeFunctionsViewGroup],
+) -> None:
+    # Get the map from each view op's name to its corresponding view group
+    view_name_to_group: dict[OperatorName, NativeFunctionsViewGroup] = {
+        g.view.func.name: g for g in view_groups
+    }
+
+    view_infos = {}
+
+    for info_dispatch_dict in infos.values():
+        # maybe_view_group only needs to be calculated once per info_dispatch_dict
+        maybe_view_group = None
+        view_copy_differentiability_infos = {}
+        for dispatch_key, info in info_dispatch_dict.items():
+            maybe_view_group = view_name_to_group.get(info.func.func.name, None)
+            if maybe_view_group is not None and maybe_view_group.view_copy is not None:
+                view_copy_info = info.create_view_copy_from_view_derivative(
+                    maybe_view_group
+                )
+                if view_copy_info is not None:
+                    fn_schema = view_copy_info.func.func
+                    view_copy_differentiability_infos[dispatch_key] = view_copy_info
+            else:
+                break
+        # prefer manually-defined derivatives if any
+        if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos:
+            assert fn_schema is not None
+            view_infos[fn_schema] = view_copy_differentiability_infos
+
+    infos.update(view_infos)
+
+
+def load_derivatives(
+    derivatives_yaml_path: str, native_yaml_path: str, tags_yaml_path: str
+) -> DerivativeRet:
+    # Do some caching as this is a deterministic function
+    global _GLOBAL_LOAD_DERIVATIVE_CACHE
+    key = (derivatives_yaml_path, native_yaml_path)
+    if key not in _GLOBAL_LOAD_DERIVATIVE_CACHE:
+        with open(derivatives_yaml_path) as f:
+            definitions = yaml.load(f, Loader=YamlLoader)
+
+        funcs = parse_native_yaml(native_yaml_path, tags_yaml_path).native_functions
+        # From the parsed native functions, separate out the (generated) view_copy functions,
+        # so we can generate derivatives for them separately.
+        native_functions_with_view_groups = get_grouped_by_view_native_functions(funcs)
+        native_functions = concatMap(
+            lambda g: [g]
+            if isinstance(g, NativeFunction)
+            else list(g.functions(include_copy=True)),
+            native_functions_with_view_groups,
+        )
+        view_groups = [
+            g
+            for g in native_functions_with_view_groups
+            if isinstance(g, NativeFunctionsViewGroup)
+        ]
+
+        # What's the difference between a function schema and a signature?
+        # A function schema is the complete declaration, including mutability annotations, default values, etc.
+        # A signature is the canonical schema for a group of functions (in-place/out/functional variants)
+        # that are semantically related; an illustrative example follows.
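+        # Illustrative example (operator names only, schemas abbreviated): "add.Tensor",
+        # "add_.Tensor" and "add.out" each have their own schema, but they share a single
+        # signature and therefore land in the same functions_by_signature bucket.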
+        functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = (
+            defaultdict(list)
+        )
+        functions_by_schema: dict[str, NativeFunction] = {}
+        for function in native_functions:
+            functions_by_signature[function.func.signature()].append(function)
+            assert str(function.func) not in functions_by_schema
+            functions_by_schema[str(function.func)] = function
+
+        # Keep track of how many of which ops we've seen so we can
+        # disambiguate them with a numeric suffix.
+        op_counter = Counter[str]()
+
+        # infos is a dict that maps FunctionSchema -> a dict of per dispatch key DifferentiabilityInfos
+        # this is useful because in tools/autograd/gen_autograd.py:match_differentiability_info
+        # we ultimately need to categorize the DifferentiabilityInfos by FunctionSchema
+        infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] = {}
+        used_dispatch_keys: set[str] = set()
+        for defn_dict in definitions:
+            # Ensure that the old derivatives.yaml schema with no dispatch key can be loaded.
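+            # e.g. (illustrative) an entry like {"name": "foo(Tensor self) -> Tensor", "self": "grad"}
+            # is rewritten below to {"name": ..., "dispatch": {"Default": {"self": "grad"}}}.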
+            if "dispatch" not in defn_dict:
+                specification = defn_dict.pop("name")
+                output_differentiability = defn_dict.pop(
+                    "output_differentiability", None
+                )
+                defn_dict = {"name": specification, "dispatch": {"Default": defn_dict}}
+                if output_differentiability:
+                    defn_dict["output_differentiability"] = output_differentiability
+            name, per_dispatch_diffinfos = create_differentiability_info(
+                defn_dict,
+                functions_by_signature,
+                functions_by_schema,
+                op_counter,
+                used_dispatch_keys,
+            )
+            infos[name] = per_dispatch_diffinfos
+
+        add_view_copy_derivatives(infos, view_groups)
+
+        # cache both the loaded infos as well as a set of all the dispatch keys/aliases
+        # that appear in derivatives.yaml. used_dispatch_keys is useful for generating
+        # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used
+        _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys
+
+    return _GLOBAL_LOAD_DERIVATIVE_CACHE[key]
+
+
+# TODO: Why is this going through CppSignatureGroup, that doesn't make sense...
+@with_native_function
+def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
+    sigs = CppSignatureGroup.from_native_function(f, method=False)
+    if sigs.symint_signature is not None:
+        return sigs.symint_signature.arguments()
+    else:
+        return sigs.signature.arguments()
+
+
+def create_derivative(
+    f: NativeFunction,
+    formula: str,
+    var_names: tuple[str, ...],
+    available_named_gradients: Sequence[str],
+) -> Derivative:
+    original_formula = formula
+    arguments: list[NamedCType] = [
+        a.nctype.remove_const_ref() for a in cpp_arguments(f)
+    ]
+
+    return_names = tuple(n if n != "self" else "result" for n in cpp.return_names(f))
+    return_types = tuple(
+        cpp.return_type(r, symint=True).remove_const_ref() for r in f.func.returns
+    )
+
+    named_returns = [
+        NamedCType(name, type) for name, type in zip(return_names, return_types)
+    ]
+
+    formula, saved_inputs = saved_variables(formula, arguments, var_names)
+    formula, saved_outputs = saved_variables(formula, named_returns, var_names)
+
+    used_named_gradients = {
+        name
+        for name in available_named_gradients
+        if re.search(IDENT_REGEX.format(name), formula)
+    }
+
+    # Check that the referenced derivatives in the formula are in bounds
+    for i in used_gradient_indices(formula):
+        if i >= len(f.func.returns):
+            raise RuntimeError(
+                f"Out of bounds grads access: derivative formula for {cpp.name(f.func)} "
+                f"used grads[{i}], but the forward only returns {len(f.func.returns)} outputs."
+            )
+
+    return Derivative(
+        formula=formula,
+        original_formula=original_formula,
+        var_names=var_names,
+        saved_inputs=saved_inputs,
+        saved_outputs=saved_outputs,
+        named_gradients=used_named_gradients,
+    )
+
+
+def create_forward_derivative(
+    f: NativeFunction, formula: str, names: tuple[str, ...]
+) -> ForwardDerivative:
+    var_names = names
+    var_types: tuple[Type, ...] | None = None
+    for r in f.func.returns:
+        if r.name in var_names:
+            if var_types is None:
+                var_types = ()
+            var_types = var_types + (r.type,)
+
+    # Handle default return names
+    if var_types is None:
+        if var_names == ("result",):
+            assert len(f.func.returns) == 1
+            var_types = (f.func.returns[0].type,)
+        else:
+            for var_name in var_names:
+                res = re.findall(r"^result(\d+)$", var_name)
+                if len(res) == 1:
+                    if var_types is None:
+                        var_types = ()
+                    arg_idx = int(res[0])
+                    var_types = var_types + (f.func.returns[arg_idx].type,)
+
+    assert var_types is not None, "No matching output for forward derivative definition"
+    return ForwardDerivative(
+        formula=formula,
+        var_names=var_names,
+        var_types=var_types,
+        required_inputs_fw_grad=None,
+        required_inputs_primal=None,
+        required_original_self_value=False,
+        is_reusing_outplace_formula=False,
+    )
+
+
+def postprocess_forward_derivatives(
+    f: NativeFunction,
+    defn_name: str,
+    all_arg_names: list[str],
+    derivatives: list[Derivative],
+    forward_derivatives: list[ForwardDerivative],
+    args_with_derivatives: Sequence[Binding],
+) -> list[ForwardDerivative]:
+    def find_required_inputs(formula: str, postfix: str) -> tuple[str, ...]:
+        is_foreach = f.func.name.name.base.startswith("_foreach_")
+        required_inputs = set()
+        for arg in args_with_derivatives:
+            if (
+                arg.type in ("at::TensorList", "const at::ITensorListRef &")
+                and not is_foreach
+            ):
+                # The functions taking TensorList handle everything internally
+                continue
+            arg_name = arg.name
+
+            found = re.search(IDENT_REGEX.format(arg_name), formula)
+            if found:
+                raise RuntimeError(
+                    f"The forward formula for {defn_name} is using the base name of the {arg_name} "
+                    f"argument which is ambiguous. You should use {arg_name}_p to access the primal "
+                    f"value and {arg_name}_t to access the tangent."
+                )
+
+            found = re.search(IDENT_REGEX.format(arg_name + postfix), formula)
+            if found:
+                required_inputs.add(arg_name)
+
+        return tuple(required_inputs)
+
+    updated_derivatives: list[ForwardDerivative] = []
+
+    for defn in forward_derivatives:
+        formula = defn.formula
+        required_inputs_tangent = find_required_inputs(formula, "_t")
+        if formula == "auto_element_wise":
+            assert f.func.kind() != SchemaKind.inplace, (
+                f"Cannot use auto_element_wise with {f.func.name} because it is an in-place variant"
+            )
+            if (
+                (not len(args_with_derivatives) == 1)
+                or len(forward_derivatives) > 1
+                or len(forward_derivatives[0].var_names) > 1
+            ):
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as element_wise but this only "
+                    "works for functions with a single differentiable input and a "
+                    "single differentiable output."
+                )
+            if not len(derivatives) == 1:
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as element_wise but it does not "
+                    "defines the gradient formula for its argument which is required."
+                )
+            # This transformation is based on the observation that for element-wise functions, the Jacobian
+            # matrix is diagonal and thus doing J * v is the same as (v^T J)^T (in practice, we ignore the transpositions)
+            # For the complex case, we use hermitian transpose and get (v.conj() J).conj()
+            # So here we are going to reuse the backward formula and replace three things:
+            # 1) all occurrences of "grad" with "foo_t.conj()", where foo is the name of the unique differentiable input.
+            # 2) all usage of an original input "foo" with its primal value "foo_p".
+            # 3) conjugate the final result
+            # For example, for abs, the backward formula is:
+            #   grad * self.sgn()
+            # And this function generates a forward formula that is:
+            #   (self_t.conj() * self_p.sgn()).conj()
+
+            backward_formula = derivatives[0].original_formula
+            input_name = args_with_derivatives[0].name
+
+            # Do replacement 1) of the grad
+            def repl(m: Any) -> str:
+                return f"{m.group(1)}{input_name}_t.conj(){m.group(2)}"
+
+            fw_formula = re.sub(IDENT_REGEX.format("grad"), repl, backward_formula)
+
+            # Do replacement 2) of the input variables
+            for arg in args_with_derivatives:
+                arg_name = arg.name
+
+                def repl(m: Any) -> str:
+                    return f"{m.group(1)}{arg_name}_p{m.group(2)}"
+
+                fw_formula = re.sub(IDENT_REGEX.format(arg_name), repl, fw_formula)
+
+            # Do the final conjugate 3)
+            fw_formula = f"({fw_formula}).conj()"
+
+            # Since there is a single differentiable input and we necessarily need its tangent, we can
+            # simply require all differentiable inputs' tangents.
+            required_inputs_tangent = tuple(all_arg_names)
+            formula = fw_formula
+        elif formula == "auto_linear":
+            if (
+                len(forward_derivatives) > 1
+                or len(forward_derivatives[0].var_names) > 1
+            ):
+                raise RuntimeError(
+                    f"Derivative definition of {defn_name} in derivatives.yaml defines the "
+                    "forward definition of gradient as linear but this only works "
+                    "for functions with a single differentiable output."
+                )
+            # This transformation is based on the observation that linear functions can be written as:
+            #   y = f(x) = A * x
+            # For some matrix A and the Jacobian of the function f is also A.
+            # So doing J * v = A * v = f(v).
+            # Hence to do the jvp, we simply need to evaluate the function at the point v instead of x.
+            # We do this by calling the forward again, replacing any occurrence of the differentiable
+            # input "foo" with its tangent "foo_t".
+            # Note that multiple inputs are not a problem as long as the function is truly linear with
+            # respect to the vector where all the differentiable inputs are stacked.
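+            # Illustrative example (hypothetical linear op): for "foo(Tensor x, Tensor weight)"
+            # with both inputs differentiable, the generated forward formula is simply
+            # "at::foo(x_t, weight_t)" (or "x_t.foo(weight_t)" if only a method variant exists).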
+
+            diff_arg_names = [arg.name for arg in args_with_derivatives]
+            assert len(diff_arg_names) > 0
+
+            # Do replacement of input variables
+            new_args = []
+            for arg_name in all_arg_names:
+                if arg_name in diff_arg_names:
+                    arg_name = arg_name + "_t"
+                new_args.append(arg_name)
+
+            # TODO we are trolling
+            if f.func.has_symint():
+                defn_name += "_symint"
+
+            # Call into the forward again. We need two cases here to handle both Tensor methods and at:: functions.
+            if Variant.function in f.variants:
+                fw_formula = f"at::{defn_name}({', '.join(new_args)})"
+            else:
+                assert Variant.method in f.variants
+                fw_formula = f"{new_args[0]}.{defn_name}({', '.join(new_args[1:])})"
+
+            # All of the input tangents are always used so all of them are required here.
+            required_inputs_tangent = tuple(diff_arg_names)
+            formula = fw_formula
+
+        # At this point, the formula is final and is not modified anymore.
+
+        # In the forward formula we use the primal values rather than the original input Tensors.
+        # This call inspects the formula to find which inputs' primals are used.
+        required_inputs_primal = find_required_inputs(formula, "_p")
+
+        updated_derivatives.append(
+            ForwardDerivative(
+                formula=formula,
+                var_names=defn.var_names,
+                var_types=defn.var_types,
+                required_inputs_fw_grad=required_inputs_tangent,
+                required_inputs_primal=required_inputs_primal,
+                required_original_self_value=False,
+                is_reusing_outplace_formula=False,
+            )
+        )
+
+    return updated_derivatives
+
+
+def is_forward_derivative_definition(
+    all_arg_names: list[str], names: tuple[str, ...]
+) -> bool:
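+    # An entry defines a forward derivative when its name(s) refer to outputs
+    # (e.g. "result") rather than input arguments; only the first name is inspected here.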
+    for name in names:
+        return name not in all_arg_names
+    raise RuntimeError("Expected `names` to be non-empty")
+
+
+def create_differentiability_info(
+    defn_dict: dict[Any, Any],
+    functions_by_signature: dict[FunctionSchema, list[NativeFunction]],
+    functions_by_schema: dict[str, NativeFunction],
+    op_counter: Counter[str],
+    used_dispatch_keys: set[str],
+) -> tuple[FunctionSchema, dict[str, DifferentiabilityInfo]]:
+    """Processes a single entry `defn` in derivatives.yaml"""
+
+    def canonical_function(
+        functions: Sequence[NativeFunction], name: str
+    ) -> NativeFunction:
+        for f in functions:
+            if (
+                not f.func.is_functional_fn()
+                and not f.func.is_out_fn()
+                and name == str(f.func.name.name)
+            ):
+                return f
+        # some functions only have in-place variants
+        assert name + "_" == cpp.name(functions[0].func)
+        return functions[0]
+
+    def split_names(raw_names: str) -> tuple[str, ...]:
+        """Given "foo, bar", return ["foo", "bar"]."""
+        return tuple(x.strip() for x in raw_names.split(","))
+
+    def check_grad_usage(defn_name: str, derivatives: Sequence[Derivative]) -> None:
+        """
+        Check for some subtle mistakes one might make when writing derivatives.
+        These mistakes will compile, but will be latent until a function is
+        used with double backwards.
+        """
+
+        uses_grad = False  # true if any derivative uses "grad"
+        num_grads_uses = 0  # count of uses of "grads" or "grads[INDEX]"
+        uses_named_grads = False  # true if any derivative uses "grad_{name}"
+        used_grads_indices: list[int] = []  # which indices of grads are used
+        for d in derivatives:
+            formula = d.formula
+            uses_grad = uses_grad or bool(
+                re.findall(IDENT_REGEX.format("grad"), formula)
+            )
+            num_grads_uses += len(re.findall(IDENT_REGEX.format("grads"), formula))
+            uses_named_grads = uses_named_grads or bool(d.named_gradients)
+            used_grads_indices.extend(used_gradient_indices(formula))
+        # This is a basic sanity check: the number of places we see
+        # "grads" should be no fewer than the number of indices we see
+        # inside "grads". They may not be equal because we may use
+        # "grads" without an index.
+        assert num_grads_uses >= len(used_grads_indices)
+        # Thus if the number is equal, every use of grads is also
+        # indexed.
+        only_used_grads_indices = num_grads_uses == len(used_grads_indices)
+
+        if uses_grad and num_grads_uses > 0:
+            raise RuntimeError(
+                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
+                "mixes use of 'grad' and 'grads'. Consider replacing "
+                "occurrences of 'grad' with 'grads[0]'"
+            )
+
+        if only_used_grads_indices and set(used_grads_indices) == {0}:
+            raise RuntimeError(
+                f"Derivative definition of {defn_name} in derivatives.yaml solely "
+                "refers to 'grads[0]'.  If the first output is indeed the "
+                "only differentiable output, replace 'grads[0]' with 'grad'; "
+                "otherwise, there is a likely error in your derivatives "
+                "declaration."
+            )
+
+        if uses_named_grads and (uses_grad or num_grads_uses > 0):
+            raise RuntimeError(
+                f"Derivative definition of {defn_name} in derivatives.yaml illegally "
+                'mixes use of "grad_RETURN_NAME" and "grad" or "grads[x]". Use '
+                "only one method for identifying gradients."
+            )
+
+    @with_native_function
+    def set_up_derivatives(
+        f: NativeFunction,
+    ) -> tuple[
+        Sequence[Derivative],
+        Sequence[ForwardDerivative],
+        Sequence[Binding],
+        Sequence[str],
+        Sequence[str],
+    ]:
+        # Set up the derivative information
+        derivatives: list[Derivative] = []
+        forward_derivatives: list[ForwardDerivative] = []
+        non_differentiable_arg_names: list[str] = []
+        args_with_derivatives_set: set[str] = set()
+
+        all_arg_names = [a.name for a in cpp_arguments(f)]
+        all_ret_names = [
+            r.name for r in f.func.returns
+        ]  # only used for the assert below
+        # output_differentiability is captured from the enclosed
+        # scope. Don't modify it.
+        #
+        # If it is not present, then no output is explicitly
+        # undifferentiable.
+        #
+        # It may be present and shorter than the length of return
+        # values. If that's the case, any return value that does not
+        # have a corresponding entry is considered not differentiable.
+        differentiability = output_differentiability or [True] * len(f.func.returns)
+        # A return is available as a named gradient ...
+        available_named_gradients = [
+            f"grad_{ret.name}"
+            for ret, differentiable in zip(f.func.returns, differentiability)
+            # if it has not been explicitly made undifferentiable
+            if differentiable
+            # and if it has a name
+            and ret.name is not None
+            # and if its type is differentiable
+            and ret.type.is_tensor_like()
+        ]
+
+        for raw_names in sorted(defn.keys()):
+            formula = defn[raw_names]
+            names = split_names(raw_names)
+
+            for name in names:
+                assert not (name in all_arg_names and name in all_ret_names), (
+                    f"While processing the derivative formula for '{f.func.name}' wrt '{name}', "
+                    f"expected '{name}' to not be both an input arg and named return. "
+                )
+
+            if is_forward_derivative_definition(all_arg_names, names):
+                forward_derivatives.append(create_forward_derivative(f, formula, names))
+            else:
+                if formula.lower().strip() == "non_differentiable":
+                    non_differentiable_arg_names += names
+                else:
+                    derivative = create_derivative(
+                        f, formula, names, available_named_gradients
+                    )
+                    derivatives.append(derivative)
+                    args_with_derivatives_set |= set(names)
+
+        overlap = args_with_derivatives_set.intersection(non_differentiable_arg_names)
+        if overlap:
+            raise RuntimeError(
+                f"derivatives definition for {defn} have overlapped non_differentiable "
+                f"and differentiable variables: {overlap}"
+            )
+
+        # Next, let us determine the list of inputs in order.
+        # TODO: do we need to eagerly calculate and save it here? Can it be derived
+        # from NativeFunction and `derivatives` at call sites instead?
+        args_with_derivatives = [
+            a for a in cpp_arguments(f) if a.name in args_with_derivatives_set
+        ]
+
+        # Postprocess forward derivatives definitions now that we know the differentiable arguments
+        forward_derivatives = postprocess_forward_derivatives(
+            f,
+            defn_name,
+            all_arg_names,
+            derivatives,
+            forward_derivatives,
+            args_with_derivatives,
+        )
+
+        # Test to see if the use of 'grads' makes sense.
+        check_grad_usage(defn_name, derivatives)
+
+        return (
+            derivatives,
+            forward_derivatives,
+            args_with_derivatives,
+            non_differentiable_arg_names,
+            available_named_gradients,
+        )
+
+    # NB: Removes 'name' from defn dictionary
+    specification = defn_dict.pop("name")
+    defn_name, _ = split_name_params(specification)
+    # NB: Removes 'output_differentiability' from defn dictionary
+    #     `None` means all differentiable.
+    output_differentiability = defn_dict.pop("output_differentiability", None)
+    output_differentiability_conditions = None
+    if output_differentiability and any(
+        isinstance(diff, str) for diff in output_differentiability
+    ):
+        if len(output_differentiability) != 1:
+            raise RuntimeError(
+                f"Not supported: for {specification},"
+                f"output_differentiability must either be "
+                f"list[bool] or a list[str] where each str is a "
+                f"condition. In the case where it is a condition, "
+                f"we only support single-output functions. "
+                f"Please file us an issue. "
+            )
+        output_differentiability_conditions = output_differentiability
+        output_differentiability = [True]
+
+    schema_function = functions_by_schema.get(specification)
+    if not schema_function:
+        avail = "\n".join(
+            k for k, v in functions_by_schema.items() if cpp.name(v.func) == defn_name
+        )
+        raise RuntimeError(
+            f"could not find ATen function for schema: {specification} "
+            f".  Available signatures:\n{avail}"
+        )
+
+    # now map this to the legacy schema; this isn't technically necessary, but we'd need some logic here
+    # to map in-place schemas to the out-of-place variants.
+    # TODO: maybe the logic to handle the legacy schema is no longer necessary?
+    signature = schema_function.func.signature()
+    functions = functions_by_signature[signature]
+    if len(functions) == 0:
+        avail = "\n".join(
+            str(k)
+            for k, v in functions_by_signature.items()
+            if cpp.name(k) == defn_name
+        )
+        raise RuntimeError(
+            f"could not find ATen function for legacy signature: {signature} "
+            f"corresponding to schema {specification}.  Please report a bug to PyTorch. "
+            f"Available signatures:\n{avail}"
+        )
+
+    canonical = canonical_function(functions, defn_name)
+    if "grad_input_mask" in (a.name for a in cpp_arguments(canonical)):
+        raise RuntimeError(
+            f"Schema for {defn_name} has an argument named grad_input_mask, "
+            "but this name would be shadowed by our codegen. "
+            "Please use a different name in native_functions.yaml."
+        )
+
+    if "result" in (a.name for a in cpp_arguments(canonical)):
+        raise RuntimeError(
+            f"Schema for {defn_name} has an argument named result, "
+            "but this is only allowed for outputs."
+            "Please use a different name in native_functions.yaml."
+        )
+
+    diffinfo_dict = {}
+    for key, defn in defn_dict["dispatch"].items():
+        if key != "Default" and key not in _VALID_AUTOGRAD_KEYS:
+            raise RuntimeError(
+                f"Invalid dispatch key {key} in derivatives.yaml for {specification},"
+                f" expected key to be one of {_VALID_AUTOGRAD_KEYS}"
+            )
+        if key not in used_dispatch_keys:
+            used_dispatch_keys.add(key)
+
+        (
+            derivatives,
+            forward_derivatives,
+            args_with_derivatives,
+            non_differentiable_arg_names,
+            available_named_gradients,
+        ) = set_up_derivatives(canonical)
+
+        used_named_gradients: set[str] = set()
+        for d in derivatives:
+            used_named_gradients |= d.named_gradients
+
+        # only assign an op name if we are actually going to calculate a derivative
+        op = None
+        if args_with_derivatives:
+            op_prefix = _create_op_prefix(defn_name)
+            if key != "Default":
+                op_prefix = op_prefix + key
+            op = f"{op_prefix}{op_counter[op_prefix]}"
+            op_counter[op_prefix] += 1
+
+        diffinfo_dict[key] = DifferentiabilityInfo(
+            name=defn_name,
+            func=canonical,
+            op=op,
+            derivatives=derivatives,
+            forward_derivatives=forward_derivatives,
+            all_saved_inputs=dedup_vars(
+                [v for d in derivatives for v in d.saved_inputs]
+            ),
+            all_saved_outputs=dedup_vars(
+                [v for d in derivatives for v in d.saved_outputs]
+            ),
+            available_named_gradients=available_named_gradients,
+            used_named_gradients=used_named_gradients,
+            args_with_derivatives=args_with_derivatives,
+            non_differentiable_arg_names=non_differentiable_arg_names,
+            output_differentiability=output_differentiability,
+            output_differentiability_conditions=output_differentiability_conditions,
+        )
+
+    return canonical.func, diffinfo_dict
+
+
+GRAD_INDEX_REGEX = r"(?:^|\W)grads\[(\d+)\]"
+
+
+def used_gradient_indices(formula: str) -> list[int]:
+    """Determine a list of gradient indices (the i in grads[i]) that
+    are used by the formula.
+
+    >>> used_gradient_indices("foo(grads[0], grads[1])")
+    [0, 1]
+    """
+    return [int(i) for i in re.findall(GRAD_INDEX_REGEX, formula)]
+
+
+def saved_variables(
+    formula: str,
+    nctypes: list[NamedCType],
+    var_names: tuple[str, ...],
+) -> tuple[str, tuple[SavedAttribute, ...]]:
+    def stride_expr(name: str) -> str:
+        assert var_names == (name,), (
+            'Replacement for ".strides()" is currently only supported for single derivatives of the same tensor '
+            'that ".strides()" is being called on.'
+        )
+        return f'strides_or_error({name}, "{name}")'
+
+    REPLACEMENTS: list[tuple[str, dict[str, Any]]] = [
+        # replace self.sym_sizes() with self_sym_sizes
+        (
+            r"{}.sym_sizes\(\)",
+            {
+                "suffix": "_sym_sizes",
+                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
+            },
+        ),
+        # replace self->sym_sizes() with self_sym_sizes_opt
+        (
+            r"{}->sym_sizes\(\)",
+            {
+                "suffix": "_sym_sizes_opt",
+                "nctype": lambda name: NamedCType(
+                    name, OptionalCType(BaseCType(symIntArrayRefT))
+                ),
+                "expr": lambda name: f"{name}.has_value() ? std::optional({name}->sym_sizes()) : std::nullopt",
+            },
+        ),
+        # replace self.sym_blocksize() with self_sym_blocksize_opt
+        (
+            r"{}.sym_blocksize\(\)",
+            {
+                "suffix": "_self_sym_blocksize_opt",
+                "nctype": lambda name: NamedCType(
+                    name, OptionalCType(BaseCType(symIntArrayRefT))
+                ),
+                "expr": lambda name: f"at::sparse_csr::getSymIntBlockSize({name})",
+            },
+        ),
+        # replace self.options() with self_options
+        (
+            r"{}.options\(\)",
+            {
+                "suffix": "_options",
+                "nctype": lambda name: NamedCType(name, BaseCType(tensorOptionsT)),
+            },
+        ),
+        # replace zeros_like(self) with self_info
+        (
+            r"zeros_like\({}\)",
+            {
+                "suffix": "_info",
+                "nctype": lambda name: NamedCType(name, BaseCType(typeAndSizeT)),
+                "expr": lambda name: name,  # at save-time
+                "res": lambda name: name + "_info.zeros()",  # at eval-time
+            },
+        ),
+        # replace self.sym_size(2) with self_sym_size_2
+        (
+            r"{}.sym_size\((-?\w+)\)",
+            {
+                "suffix": lambda m: f"_sym_argsize_{m.groups()[0].replace('-', 'minus_')}",
+                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
+            },
+        ),
+        # replace self.numel() with self_numel
+        (
+            r"{}.numel\(\)",
+            {
+                "suffix": "_numel",
+                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
+            },
+        ),
+        # replace self.sym_numel() with self_sym_numel
+        (
+            r"{}.sym_numel\(\)",
+            {
+                "suffix": "_sym_numel",
+                "nctype": lambda name: NamedCType(name, BaseCType(SymIntT)),
+            },
+        ),
+        # replace to_args_sizes(self) with self_args_sizes
+        (
+            r"to_args_sizes\({}\)",
+            {
+                "suffix": "_args_sizes",
+                "nctype": lambda name: NamedCType(
+                    name, VectorCType(VectorCType(BaseCType(longT)))
+                ),
+            },
+        ),
+        # replace to_args_sizes_symint(self) with self_args_sizes
+        (
+            r"to_args_sizes_symint\({}\)",
+            {
+                "suffix": "_args_sizes_symint",
+                "nctype": lambda name: NamedCType(
+                    name, VectorCType(VectorCType(BaseCType(SymIntT)))
+                ),
+            },
+        ),
+        # replace to_args_scalartypes(self) with self_args_scalartypes
+        (
+            r"to_args_scalartypes\({}\)",
+            {
+                "suffix": "_args_scalartypes",
+                "nctype": lambda name: NamedCType(
+                    name, VectorCType(BaseCType(scalarTypeT))
+                ),
+            },
+        ),
+        # replace TensorGeometry(self) with self_geometry
+        (
+            r"TensorGeometry\({}\)",
+            {
+                "suffix": "_geometry",
+                "nctype": lambda name: NamedCType(name, BaseCType(tensorGeometryT)),
+            },
+        ),
+        (
+            r"{}.scalar_type\(\)",
+            {
+                "suffix": "_scalar_type",
+                "nctype": lambda name: NamedCType(name, BaseCType(scalarTypeT)),
+            },
+        ),
+        # replace self.dim() with self_dim
+        (
+            r"{}.dim\(\)",
+            {
+                "suffix": "_dim",
+                "nctype": lambda name: NamedCType(name, BaseCType(longT)),
+            },
+        ),
+        # replace self.sym_strides() with self_sym_strides
+        (
+            r"{}.sym_strides\(\)",
+            {
+                "suffix": "_sym_strides",
+                "nctype": lambda name: NamedCType(name, BaseCType(symIntArrayRefT)),
+                "expr": stride_expr,
+            },
+        ),
+        # replace self.layout() with self_layout
+        (
+            r"{}.layout\(\)",
+            {
+                "suffix": "_layout",
+                "nctype": lambda name: NamedCType(name, BaseCType(layoutT)),
+            },
+        ),
+        # replace self.is_conj() with self_conjugate
+        (
+            r"{}.is_conj\(\)",
+            {
+                "suffix": "_conjugate",
+                "nctype": lambda name: NamedCType(name, BaseCType(boolT)),
+            },
+        ),
+    ]
+
+    # find which arguments need to be saved
+    saved: list[SavedAttribute] = []
+
+    if ".sizes()" in formula or "->sizes()" in formula:
+        raise RuntimeError(
+            ".sizes() is not supported in derivative formulas. Instead, please use the SymInt version,"
+            + f".sym_sizes(), which returned a c10::SymIntArrayRef. formula={formula}"
+        )
+    if re.search(r"\.size\([-]?\d+\)", formula) or re.search(
+        r"->size\([-]?\d+\)", formula
+    ):
+        raise RuntimeError(
+            ".size(int) is not supported in derivative formulas. Instead, please use the SymInt version,"
+            + f".sym_size(int), which returned a c10::SymIntArrayRef. formula={formula}"
+        )
+    if ".strides()" in formula or "->strides()" in formula:
+        raise RuntimeError(
+            ".strides() is not supported in derivative formulas. Instead, please use the SymInt version,"
+            + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}"
+        )
+    for nctype in nctypes:
+        name = (
+            nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name
+        )
+        # First search the formula for expressions which can be evaluated
+        # when the autograd Function is created to avoid saving variables
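+        # For example (illustrative): with name == "self", a formula containing
+        # "self.sym_sizes()" is rewritten to reference "self_sym_sizes", and a
+        # SavedAttribute of type c10::SymIntArrayRef is recorded instead of saving
+        # the whole tensor.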
+        for regex, info in REPLACEMENTS:
+
+            def repl(m: re.Match[str]) -> str:
+                suffix: str = (
+                    info["suffix"](m) if callable(info["suffix"]) else info["suffix"]
+                )
+                expr: str = info["expr"](name) if "expr" in info else m.group(0)
+                saved.append(
+                    SavedAttribute(
+                        nctype=info["nctype"](name + suffix),
+                        expr=expr,
+                    )
+                )
+                if "res" in info:
+                    replacement: str = info["res"](name)
+                    return replacement
+                return name + suffix
+
+            formula = re.sub(regex.format(name), repl, formula)
+
+        # std::optional<std::string> values stored in Backward nodes must be
+        # converted to an optional string_view before being passed into
+        # the backward function
+        if nctype.type == OptionalCType(BaseCType(stringT)):
+            formula = re.sub(
+                rf"\b{name}\b",
+                f"{name}.has_value() ? std::optional({name}.value()) : std::nullopt",
+                formula,
+            )
+
+        # Find any variables which remain in the formula and save them
+        if re.search(IDENT_REGEX.format(name), formula):
+            saved.append(
+                SavedAttribute(
+                    nctype=nctype,
+                    expr=name,
+                )
+            )
+
+    return formula, tuple(saved)
+
+
+def _create_op_prefix(name: str) -> str:
+    r"""Takes a native function name converts to an op prefix name.
+
+    Note that the "name" parameter must be the native function name
+    without the optional variant suffix, so "add" instead of
+    "add.out".
+
+    OP names correspond to classes, hence the change to title case.
+
+    Example::
+
+        >>> _create_op_prefix("add")
+        'AddBackward'
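+
+        >>> # a hypothetical "_forward" name collapses the suffix (illustrative):
+        >>> _create_op_prefix("glu_forward")
+        'GluBackward'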
+    """
+    camel_case = "".join([p.title() for p in name.split("_")])
+    return (camel_case + "Backward").replace("ForwardBackward", "Backward")
+
+
+def dedup_vars(vars: Sequence[SavedAttribute]) -> Sequence[SavedAttribute]:
+    seen: set[str] = set()
+    saved: list[SavedAttribute] = []
+    for var in vars:
+        name = (
+            var.nctype.name.name
+            if isinstance(var.nctype.name, SpecialArgName)
+            else var.nctype.name
+        )
+        if name in seen:
+            continue
+        seen.add(name)
+        saved.append(var)
+    return saved
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f02837e9d4825a409eb4061195de1e4b8d21b928
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ADInplaceOrViewType.cpp
@@ -0,0 +1,38 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include "torch/csrc/autograd/VariableTypeUtils.h"
+#include "torch/csrc/autograd/generated/ViewFuncs.h"
+
+#include 
+#include 
+#include 
+
+// ${generated_comment}
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using namespace at;
+using torch::autograd::CreationMeta;
+using torch::autograd::as_view;
+using torch::autograd::increment_version;
+
+namespace torch {
+
+namespace ADInplaceOrView {
+
+namespace {
+${inplace_or_view_method_definitions}
+}  // namespace
+}  // namespace ADInplaceOrView
+
+namespace {
+
+TORCH_LIBRARY_IMPL(aten, ADInplaceOrView, m) {
+  ${inplace_or_view_wrapper_registrations};
+}
+
+}  // namespace
+} // namespace torch
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..378e9f2743f166c69b49f6f17f29d91898afb370
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.cpp
@@ -0,0 +1,44 @@
+#include "torch/csrc/autograd/FunctionsManual.h"
+#include "torch/csrc/dynamo/compiled_autograd.h"
+
+// ${generated_comment}
+
+// The manual function definitions that used to be here are now in torch/csrc/autograd/FunctionsManual.cpp
+// This speeds up re-compilation and allow to share these implementations so that they can be
+// used for forward mode AD formulas as well.
+
+using namespace torch::autograd::generated::details;
+using at::Tensor;
+using at::Scalar;
+using at::IntArrayRef;
+using at::TensorList;
+
+namespace torch::autograd::generated {
+
+static at::IValue compute_output_metadata(const torch::autograd::edge_list& next_edges) {
+  auto output_metadata = torch::dynamo::autograd::IValuePacker<
+      std::vector<std::optional<InputMetadata>>>::pack(
+              torch::dynamo::autograd::get_input_metadata(next_edges));
+  return output_metadata;
+}
+
+static C10_NOINLINE variable_list compiled_autograd_apply_functional(
+    const PackedArgs& packed_args,
+    const edge_list& next_edges,
+    SwapSavedVariables& saved,
+    const variable_list& grads,
+    const std::string& name) {
+  auto output_metadata = compute_output_metadata(next_edges);
+  const auto& pyinterface = torch::dynamo::autograd::getPyCompilerInterface();
+  return pyinterface->call_function(
+      saved.get_py_compiler(),
+      "apply_functional",
+      name,
+      grads,
+      packed_args.vec(),
+      output_metadata);
+}
+
+${autograd_function_definitions}
+
+} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..41b51c9963dce410ded912dece616bebe15825b2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/Functions.h
@@ -0,0 +1,51 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+
+#include "torch/csrc/autograd/function.h"
+#include "torch/csrc/autograd/variable.h"
+#include "torch/csrc/autograd/saved_variable.h"
+#include 
+
+#include 
+
+namespace torch { namespace autograd { namespace generated {
+
+using at::Scalar;
+using at::Tensor;
+using at::IntArrayRef;
+using at::ArrayRef;
+using at::Type;
+using at::TensorGeometry;
+using at::ScalarType;
+using std::optional;
+using c10::fmap;
+
+inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
+  // NB: we must explicitly do the conversion in the lambda, otherwise template
+  // deduction will give a Tensor of Variable which is not convertible
+  return fmap(xs, [&saved_for](const SavedVariable& x) {
+    // TODO(crcrpar): Use `std::move(saved_for)` to avoid incrementing refcount, which would need refactoring.
+    return static_cast<Tensor>(x.unpack(saved_for));
+  });
+}
+
+inline c10::List<std::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs, std::shared_ptr<Node> saved_for = nullptr) {
+  torch::List<std::optional<Tensor>> result;
+  result.reserve(xs.size());
+  for (const SavedVariable& v : xs) {
+    auto var = v.unpack(saved_for);
+    result.push_back(var.defined() ? std::optional<Tensor>(var) : ::std::nullopt);
+  }
+  return result;
+}
+
+using torch::autograd::TypeAndSize;
+
+${autograd_function_declarations}
+
+}}} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8b0feeb21cc2666248ef283bfbd6a2355b957863
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/TraceType.cpp
@@ -0,0 +1,40 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+#include "torch/csrc/jit/frontend/tracer.h"
+
+#include 
+
+#include "torch/csrc/autograd/function.h"
+
+#include "ATen/quantized/Quantizer.h"
+
+// ${generated_comment}
+
+// See the `Tracer` section in `torch/csrc/jit/OVERVIEW.md`.
+// NOTE See [Sharded File] comment in VariableType
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using namespace at;
+
+namespace torch {
+
+namespace TraceType {
+
+namespace {
+${trace_method_definitions}
+}  // namespace
+}  // namespace TraceType
+
+namespace {
+
+TORCH_LIBRARY_IMPL(aten, Tracer, m) {
+  ${trace_wrapper_registrations};
+}
+
+}  // namespace
+
+} // namespace torch
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4bd3e317493cd254a7023707637e86ef7b99b061
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.cpp
@@ -0,0 +1,65 @@
+#include "torch/csrc/autograd/VariableTypeUtils.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/FunctionsManual.h"
+
+#include 
+#include 
+#include 
+#include 
+
+#include 
+
+
+// ${generated_comment}
+
+// NOTE [Sharded File]: on this file's split-into-shards state
+//
+// Back in the good old days, VariableType.cpp was generated as one
+// file with every function in it, and everything was great and
+// simple.
+//
+// However, this file was also very large (over 36,000 lines), and
+// compiling it was very slow, and in fact was a significant
+// bottleneck for incremental rebuilds. To address this, we now
+// generate the file split across multiple shards, named
+// VariableType_0.cpp and so on, which can be compiled in parallel.
+//
+// For ease of inspection and debugging, so that it's not necessary to
+// go rooting around in multiple files, we also generate all the
+// functions together in VariableTypeEverything.cpp. This generated
+// file is only for convenience; it's not actually used in the
+// build. If the file you're looking at now is one of the shards, you
+// may want to switch over to the Everything variant to make your
+// grepping smoother.
+
+using namespace at;
+using namespace torch::autograd::generated;
+using namespace torch::autograd::generated::details;
+
+
+namespace torch::autograd {
+
+namespace VariableType {
+namespace{
+[[maybe_unused]] void reset_grad_accumulator(Variable& self) {
+  AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self);
+  if (meta != nullptr) {
+    meta->grad_accumulator_.reset();
+  }
+}
+}
+
+namespace {
+
+
+${type_derived_method_definitions}
+}
+}
+
+namespace {
+
+${wrapper_registrations}
+
+}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.h
new file mode 100644
index 0000000000000000000000000000000000000000..f854a863bb68981b4138f47062e999d0b59341a2
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/VariableType.h
@@ -0,0 +1,55 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+
+#include 
+
+#include 
+#include 
+
+#include <cstddef> // for size_t
+#include <functional> // for function
+#include <memory> // for unique_ptr
+#include 
+#include 
+
+namespace at {
+  struct Quantizer;
+}
+
+namespace torch { namespace autograd {
+
+using Variable = at::Tensor;
+using at::Context;
+using at::Device;
+using at::Dimname;
+using at::DimnameList;
+using at::Generator;
+using at::IntArrayRef;
+using at::MemoryFormat;
+using at::QScheme;
+using at::Scalar;
+using at::ScalarType;
+using at::Storage;
+using at::Tensor;
+using at::TensorList;
+using at::TensorOptions;
+using at::Quantizer;
+using std::optional;
+
+namespace VariableType {
+  TORCH_API std::vector<at::DeprecatedTypeProperties*> allCUDATypes();
+  TORCH_API std::vector<at::DeprecatedTypeProperties*> allXPUTypes();
+  TORCH_API std::vector<at::DeprecatedTypeProperties*> allCPUTypes();
+  TORCH_API std::vector<at::DeprecatedTypeProperties*> allPrivateUser1Types();
+
+  at::Tensor & unpack(Tensor & t, const char * name, int pos);
+  const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
+  at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
+  std::vector<at::Tensor> unpack(const at::ITensorListRef& tl, const char *name, int pos);
+}
+
+}} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2c7bac1bcc471be9c740c3304f33bae5d2e2ef9a
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.cpp
@@ -0,0 +1,14 @@
+#include 
+
+// ${generated_comment}
+
+using at::Tensor;
+using at::Scalar;
+using at::IntArrayRef;
+using at::TensorList;
+
+namespace torch::autograd::generated {
+
+${view_func_definitions}
+
+} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h
new file mode 100644
index 0000000000000000000000000000000000000000..69701a0fc5f91aa06ac2ce208e8a1c2386008ac9
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/ViewFuncs.h
@@ -0,0 +1,28 @@
+#pragma once
+
+// ${generated_comment}
+
+#include 
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+namespace torch::autograd::generated {
+
+using at::Scalar;
+using at::Tensor;
+using at::IntArrayRef;
+using at::ArrayRef;
+using at::Type;
+using at::ScalarType;
+using std::optional;
+using c10::fmap;
+
+${view_func_declarations}
+
+} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in
new file mode 100644
index 0000000000000000000000000000000000000000..bd219be4268759a52e0bceb9548616ba0fffacc8
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/annotated_fn_args.py.in
@@ -0,0 +1,11 @@
+"""
+This file is needed for generating procedural tests required for
+testing __torch_function__. See tests/test_overrides.py.
+"""
+
+# flake8: noqa
+import torch
+
+annotated_args = {
+${annotated_args}
+}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8f96a3a3663dd15a056adc68b75d8f01fdc565ff
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_enum_tag.cpp
@@ -0,0 +1,15 @@
+#include 
+#include 
+#include 
+#include 
+
+namespace py = pybind11;
+namespace torch {
+    namespace autograd {
+    void initEnumTag(PyObject* module) {
+        auto m = py::handle(module).cast<py::module>();
+        py::enum_<at::Tag>(m, "Tag")
+        ${enum_of_valid_tags};
+        m.doc() = "An Enum that contains tags that can be assigned to an operator registered in C++.";
+    }
+}}
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2ff38cb7b71de1a8b0b2a25c42f5fc836fced426
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_fft_functions.cpp
@@ -0,0 +1,81 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_fft_functions.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/out_types.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/device_lazy_init.h"
+
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Device;
+using at::Layout;
+using at::Scalar;
+using at::ScalarType;
+using at::Backend;
+using at::OptionalDeviceGuard;
+using at::DeviceGuard;
+using at::TensorOptions;
+using at::IntArrayRef;
+using at::Generator;
+using at::TensorList;
+using at::Dimname;
+using at::DimnameList;
+
+using torch::utils::check_out_type_matches;
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef fft_functions[] = {
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPFFTVariableFunctionsModule = NULL;
+
+void initFFTFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._fft",
+     NULL,
+     -1,
+     fft_functions
+  };
+  PyObject* fft = PyModule_Create(&def);
+  THPFFTVariableFunctionsModule = fft;
+  if (!fft) {
+    throw python_error();
+  }
+  // steals a reference to fft
+  if (PyModule_AddObject(module, "_fft", fft) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3e7d1ee94c3972b0861395f90a84276b57892097
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.cpp
@@ -0,0 +1,37 @@
+#include 
+
+// ${generated_comment}
+
+#include 
+#include 
+
+#include 
+#include "torch/csrc/autograd/generated/Functions.h"
+#include "torch/csrc/autograd/python_cpp_function.h"
+#include 
+#include 
+#include 
+#include 
+#include 
+
+// NOTE: See [Sharded File] comment in VariableType
+
+namespace torch::autograd::generated {
+
+template<typename C>
+static void addClass(PyObject* module, PyTypeObject& type, const char* name,
+  PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL)
+{
+  _initFunctionPyTypeObject(type, name, function_properties, function_methods);
+  Py_INCREF(&type);
+  PyModule_AddObject(module, name, (PyObject*)&type);
+  registerCppFunction(typeid(C), &type);
+}
+
+${py_function_props_and_getters}
+
+void initialize_autogenerated_functions${shard_id}(PyObject* module) {
+  ${py_function_initializers}
+}
+
+} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.h
new file mode 100644
index 0000000000000000000000000000000000000000..92919a630ca201ca05ce1090e07389a5dcca6453
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_functions.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include 
+
+// ${generated_comment}
+
+// Python bindings for automatically generated autograd functions
+
+namespace torch { namespace autograd { namespace generated {
+
+${shard_forward_declare}
+
+inline void initialize_autogenerated_functions(PyObject* module) {
+  ${shard_call}
+}
+
+}}} // namespace torch::autograd::generated
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ba116e6167092fefbbd71d5528700e4e7a34cd40
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_linalg_functions.cpp
@@ -0,0 +1,68 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_linalg_functions.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Scalar;
+using at::ScalarType;
+using at::MemoryFormat;
+using at::Generator;
+using at::IntArrayRef;
+using at::TensorList;
+
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef linalg_functions[] = {
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPLinalgVariableFunctionsModule = NULL;
+
+void initLinalgFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._linalg",
+     NULL,
+     -1,
+     linalg_functions
+  };
+  PyObject* linalg = PyModule_Create(&def);
+  THPLinalgVariableFunctionsModule = linalg;
+  if (!linalg) {
+    throw python_error();
+  }
+  // steals a reference to linalg
+  if (PyModule_AddObject(module, "_linalg", linalg) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..162904598d52dc4007a1cf29cd798cc4ef5b29dc
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nested_functions.cpp
@@ -0,0 +1,81 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_nested_functions.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/out_types.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/device_lazy_init.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Device;
+using at::Layout;
+using at::Scalar;
+using at::ScalarType;
+using at::Backend;
+using at::OptionalDeviceGuard;
+using at::DeviceGuard;
+using at::TensorOptions;
+using at::IntArrayRef;
+using at::OptionalIntArrayRef;
+using at::Generator;
+using at::TensorList;
+using at::Dimname;
+using at::DimnameList;
+
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef nested_functions[] = {
+  {NULL, NULL, 0, NULL},
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPNestedVariableFunctionsModule = NULL;
+
+void initNestedFunctions(PyObject* module) {
+  nested_functions[0] = get_nested_functions_manual()[0];
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._nested",
+     NULL,
+     -1,
+     nested_functions
+  };
+  PyObject* nested = PyModule_Create(&def);
+  THPNestedVariableFunctionsModule = nested;
+  if (!nested) {
+    throw python_error();
+  }
+  // steals a reference to nested
+  if (PyModule_AddObject(module, "_nested", nested) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..55325801981f89fa2a8e909b87b73c2d0310c00c
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_nn_functions.cpp
@@ -0,0 +1,113 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_nn_functions.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/tensor_memoryformats.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Scalar;
+using at::MemoryFormat;
+using at::Generator;
+using at::IntArrayRef;
+using at::ArrayRef;
+
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+static PyObject* THPNNVariableFunctionsModule = nullptr;
+
+static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+    "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+    "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+  });
+  ParsedArgs<5> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, args, kwargs, THPNNVariableFunctionsModule, "torch.nn", "_parse_to");
+  }
+  auto parsed = parse_to_conversion(r, /*allow_copy*/ false); // we don't want copy for nn.Module.to
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  auto opt_memory_format = std::get<4>(parsed);
+  auto tuple = THPObjectPtr{PyTuple_New(4)};
+  if (!tuple) throw python_error();
+  if (device) {
+    PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
+  } else {
+    Py_INCREF(Py_None);
+    PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
+  }
+  if (scalarType) {
+    PyTuple_SET_ITEM(tuple.get(), 1, Py_NewRef(torch::getTHPDtype(*scalarType)));
+  } else {
+    Py_INCREF(Py_None);
+    PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
+  }
+  PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
+  if (opt_memory_format.has_value()) {
+    PyTuple_SET_ITEM(tuple.get(), 3, Py_NewRef(torch::utils::getTHPMemoryFormat(opt_memory_format.value())));
+  } else {
+    Py_INCREF(Py_None);
+    PyTuple_SET_ITEM(tuple.get(), 3, Py_None);
+  }
+  return tuple.release();
+  END_HANDLE_TH_ERRORS
+}
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef nn_functions[] = {
+  {"_parse_to", castPyCFunctionWithKeywords(THPVariable__parse_to),
+    METH_VARARGS | METH_KEYWORDS, nullptr},
+  ${py_method_defs}
+  {nullptr}
+};
+
+void initNNFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._nn",
+     nullptr,
+     -1,
+     nn_functions
+  };
+  PyObject* nn = PyModule_Create(&def);
+  THPNNVariableFunctionsModule = nn;
+  if (!nn) {
+    throw python_error();
+  }
+  // steals a reference to nn
+  if (PyModule_AddObject(module, "_nn", nn) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e721644565a8c47a0c2081050907008e4b2055c1
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.cpp
@@ -0,0 +1,52 @@
+#include 
+
+#include 
+#include 
+#include 
+
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/Exceptions.h"
+
+namespace torch { namespace autograd { namespace generated {
+
+${py_return_types}
+
+}}}
+
+namespace torch::autograd {
+
+static void addReturnType(
+    PyObject* module,
+    const char* name,
+    PyTypeObject* type) {
+  // hold onto the TypeObject for the unlikely case of user
+  // deleting or overriding it.
+  Py_INCREF(type);
+  if (PyModule_AddObject(
+          module,
+          name,
+          (PyObject*)type) != 0) {
+    Py_DECREF(type);
+    throw python_error();
+  }
+}
+
+void initReturnTypes(PyObject* module) {
+  static struct PyModuleDef def = {
+      PyModuleDef_HEAD_INIT, "torch._C._return_types", nullptr, -1, {}};
+  PyObject* return_types_module = PyModule_Create(&def);
+  if (!return_types_module) {
+    throw python_error();
+  }
+
+  ${py_return_types_registrations}
+
+  // steals a reference to return_types on success
+  if (PyModule_AddObject(module, "_return_types", return_types_module) != 0) {
+    Py_DECREF(return_types_module);
+    throw python_error();
+  }
+}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..24c18b92ee7308b3a9ff556ab890a26674913302
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_return_types.h
@@ -0,0 +1,14 @@
+#pragma once
+
+namespace torch {
+namespace autograd {
+namespace generated {
+
+${py_return_types_declarations}
+
+}
+
+void initReturnTypes(PyObject* module);
+
+} // namespace autograd
+} // namespace torch
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e565b71f76082b3946b74d645124c3a8b30fbef3
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_sparse_functions.cpp
@@ -0,0 +1,67 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_sparse_functions.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Scalar;
+using at::ScalarType;
+using at::MemoryFormat;
+using at::Generator;
+using at::IntArrayRef;
+using at::TensorList;
+
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef sparse_functions[] = {
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPSparseVariableFunctionsModule = NULL;
+
+void initSparseFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._sparse",
+     NULL,
+     -1,
+     sparse_functions
+  };
+  PyObject* sparse = PyModule_Create(&def);
+  THPSparseVariableFunctionsModule = sparse;
+  if (!sparse) {
+    throw python_error();
+  }
+  // steals a reference to sparse
+  if (PyModule_AddObject(module, "_sparse", sparse) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..beeac9258b891d6ed1ab1abf221acf04e2f5b8b5
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_special_functions.cpp
@@ -0,0 +1,79 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/autograd/python_special_functions.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/out_types.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/device_lazy_init.h"
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+using at::Tensor;
+using at::Device;
+using at::Layout;
+using at::Scalar;
+using at::ScalarType;
+using at::Backend;
+using at::OptionalDeviceGuard;
+using at::DeviceGuard;
+using at::TensorOptions;
+using at::IntArrayRef;
+using at::Generator;
+using at::TensorList;
+using at::Dimname;
+using at::DimnameList;
+
+using torch::utils::check_out_type_matches;
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef special_functions[] = {
+  ${py_method_defs}
+  {NULL}
+};
+
+static PyObject* THPSpecialVariableFunctionsModule = NULL;
+
+void initSpecialFunctions(PyObject* module) {
+  static struct PyModuleDef def = {
+     PyModuleDef_HEAD_INIT,
+     "torch._C._special",
+     NULL,
+     -1,
+     special_functions
+  };
+  PyObject* special = PyModule_Create(&def);
+  THPSpecialVariableFunctionsModule = special;
+  if (!special) {
+    throw python_error();
+  }
+  // steals a reference to special
+  if (PyModule_AddObject(module, "_special", special) != 0) {
+    throw python_error();
+  }
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..900bd621bb5c6914c13c5bdd52bfe1c121640fd3
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_torch_functions.cpp
@@ -0,0 +1,93 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+// Python bindings for torch.* functions implemented through ATen.
+//
+// The functions are bound as static methods on a class
+// torch._C._VariableFunctions which is also aliased as Variable._torch
+// and also copied into 'torch' module.
+
+#include 
+
+// Undefine the copysign macro so that at::copysign works as intended with MSVC
+// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
+#ifdef _MSC_VER
+#undef copysign
+#endif // _MSC_VER
+
+#include "torch/csrc/autograd/python_torch_functions.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/Dtype.h"
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/utils/out_types.h"
+#include "torch/csrc/utils/pybind.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/tensor_layouts.h"
+#include "torch/csrc/utils/tensor_new.h"
+#include "torch/csrc/utils/tensor_numpy.h"
+#include "torch/csrc/jit/frontend/tracer.h"
+#include "torch/csrc/autograd/generated/variable_factories.h"
+#include "torch/csrc/utils/structseq.h"
+#include "torch/csrc/utils/device_lazy_init.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#endif
+
+#include 
+#include 
+#include 
+#include 
+
+using at::Tensor;
+using at::Device;
+using at::Layout;
+using at::Scalar;
+using at::ScalarType;
+using at::Backend;
+using at::OptionalDeviceGuard;
+using at::DeviceGuard;
+using at::TensorOptions;
+using at::IntArrayRef;
+using at::Generator;
+using at::TensorList;
+using at::Dimname;
+using at::DimnameList;
+using at::ArrayRef;
+
+using torch::utils::check_out_type_matches;
+using namespace torch::autograd::utils;
+
+// NOTE: See [Sharded File] comment in VariableType
+
+namespace torch::autograd {
+
+// generated forward declarations start here
+
+${py_forwards}
+
+static PyMethodDef torch_functions_shard[] = {
+  ${py_method_defs}
+};
+
+void gatherTorchFunctions${shard_id}(std::vector<PyMethodDef> &torch_functions) {
+  constexpr size_t num_functions = sizeof(torch_functions_shard) / sizeof(torch_functions_shard[0]);
+  torch_functions.insert(
+    torch_functions.end(),
+    torch_functions_shard,
+    torch_functions_shard + num_functions);
+}
+
+// generated methods start here
+
+${py_methods}
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..52ed9ef6e673590b7eb1d802ae17cec2fa4fc442
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/python_variable_methods.cpp
@@ -0,0 +1,1338 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+// ${generated_comment}
+
+#include 
+
+// Undefine the copysign macro so that at::copysign works as intended with MSVC
+// https://github.com/python/cpython/blob/c60394c7fc9cc09b16e9675a3eeb5844b6d8523f/PC/pyconfig.h#L196
+#ifdef _MSC_VER
+#undef copysign
+#endif // _MSC_VER
+
+#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/Exceptions.h"
+#include "torch/csrc/Size.h"
+#include "torch/csrc/autograd/generated/VariableType.h"
+#include "torch/csrc/autograd/python_variable.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
+#include "torch/csrc/autograd/utils/error_messages.h"
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/jit/frontend/tracer.h"
+#ifdef USE_CUDA
+#include "torch/csrc/cuda/Event.h"
+#endif
+#include "torch/csrc/utils/device_lazy_init.h"
+#include 
+#include "torch/csrc/utils/object_ptr.h"
+#include "torch/csrc/utils/pycfunction_helpers.h"
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/python_numbers.h"
+#include "torch/csrc/utils/python_strings.h"
+#include "torch/csrc/utils/tensor_apply.h"
+#include "torch/csrc/utils/tensor_list.h"
+#include "torch/csrc/utils/tensor_new.h"
+#include "torch/csrc/utils/tensor_numpy.h"
+#include "torch/csrc/utils/tensor_types.h"
+#include "torch/csrc/autograd/generated/python_return_types.h"
+
+#include 
+#include 
+#include 
+#include "c10/core/Stream.h"
+
+#include 
+#include 
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include 
+#else
+$ops_headers
+#include 
+#endif
+
+using at::device_of;
+using at::OptionalDeviceGuard;
+using at::Scalar;
+using at::ScalarType;
+using at::Tensor;
+using c10::Stream;
+using namespace torch::autograd::utils;
+
+namespace torch::autograd {
+
+static PyObject * THPVariable__is_view(PyObject *self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "_is_view", args);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  if (self_.is_view()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object bc no support for first-class functions in native_functions.yaml
+// See: ATen/native/README.md for more context
+static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    auto args = py::make_tuple(py::handle(arg));
+    return handle_torch_function(self, "apply_", args.ptr());
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  if (self_.requires_grad()) {
+    throw std::runtime_error(
+        "Can't call apply_() on Variable that requires grad. Use "
+        "var.detach().apply_() instead.");
+  }
+  return THPVariable_Wrap(torch::utils::apply_(self_, arg));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_size(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "size(int64_t? dim=None)",
+    "size(Dimname dim)",
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+  if (r.idx == 0) {
+    if (!r.toInt64Optional(0).has_value()) {
+      return THPSize_NewFromSymSizes(self_);
+    }
+    if (jit::tracer::isTracing()) {
+      // will error out if a tensor has symints
+      return wrap(jit::tracer::getSizeOf(self_, r.toInt64(0)));
+    } else {
+      return torch::toPyObject(self_.sym_size(r.toInt64(0)));
+    }
+  } else if (r.idx == 1) {
+    if (jit::tracer::isTracing()) {
+      TORCH_INTERNAL_ASSERT(false, "NYI: Named tensors w/ JIT");
+    }
+    return wrap(self_.size(r.dimname(0)));
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "stride(int64_t? dim=None)",
+    "stride(Dimname dim)",
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  if (r.idx == 0) {
+    if (r.toInt64Optional(0).has_value()) {
+      return torch::toPyObject(self_.sym_stride(r.toInt64(0)));
+    }
+    // yes, this is called strides in ATen.
+    at::SymIntArrayRef strides = self_.sym_strides();
+    // we can't do the normal wrapping here because IntArrayRef maps to both
+    // torch.Size and tuple in python
+    // TODO: consider factoring this out
+    THPObjectPtr tuple(PyTuple_New(static_cast<Py_ssize_t>(strides.size())));
+    if (!tuple) throw python_error();
+    for (size_t i = 0; i != strides.size(); i++) {
+      PyObject* s = torch::toPyObject(strides[i]);
+      if (!s) throw python_error();
+      PyTuple_SET_ITEM(tuple.get(), i, s);
+    }
+    return tuple.release();
+  } else if (r.idx == 1) {
+    return wrap(self_.stride(r.dimname(0)));
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self_)) {
+    return handle_torch_function(self_, "get_device", args, nullptr);
+  }
+  auto& self = THPVariable_Unpack(self_);
+  return wrap(self.get_device());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self_)) {
+    return handle_torch_function(self_, "has_names", args);
+  }
+  auto& self = THPVariable_Unpack(self_);
+  return wrap(self.has_names());
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self_)) {
+    return handle_torch_function(self_, "data_ptr", args);
+  }
+  auto& self = THPVariable_Unpack(self_);
+  return wrap(self.data_ptr());
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self_)) {
+    return handle_torch_function(self_, "storage_offset");
+  }
+  auto& self = THPVariable_Unpack(self_);
+  return py::cast(self.sym_storage_offset()).release().ptr();
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_dim(PyObject* self, PyObject* args)
+{
+   HANDLE_TH_ERRORS
+   if (check_has_torch_function(self)) {
+     return handle_torch_function(self, "dim", args);
+   }
+   auto& self_ = THPVariable_Unpack(self);
+   return THPUtils_packInt64(self_.dim());
+   END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_numel(PyObject* self, PyObject* args)
+{
+   HANDLE_TH_ERRORS
+   if (check_has_torch_function(self)) {
+     return handle_torch_function(self, "numel", args);
+   }
+   auto& self_ = THPVariable_Unpack(self);
+   if (jit::tracer::isTracing()) {
+     return wrap(jit::tracer::getNumelOf(self_));
+   } else {
+     return py::cast(self_.sym_numel()).release().ptr();
+   }
+   END_HANDLE_TH_ERRORS
+}
+
+static Tensor dispatch_contiguous(const Tensor & self, at::MemoryFormat memory_format) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  return self.contiguous(memory_format);
+}
+
+static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "contiguous(*, MemoryFormat memory_format=contiguous_format)",
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto& self_ = THPVariable_Unpack(self);
+  auto memory_format = r.memoryformat(0);
+  // avoids touching the GIL or current device if self is already contiguous
+  if (self_.is_contiguous(memory_format)) {
+    // NOTE: this logic is duplicated from VariableType.cpp. Since we need to
+    // record this call to contiguous() in the trace regardless of whether
+    // we actually call contiguous here, we need to record this information
+    // manually.
+    if (jit::tracer::isTracing()) {
+      const auto& tracer_state = jit::tracer::getTracingState();
+      auto op_name = c10::Symbol::fromQualString("aten::contiguous");
+      auto node = tracer_state->createNode(op_name, /*num_outputs=*/0);
+      jit::tracer::recordSourceLocation(node);
+      jit::tracer::addInputs(node, "self", self_);
+      jit::tracer::addInputs(node, "memory_format", memory_format);
+      tracer_state->insertNode(node);
+      jit::tracer::addOutput(node, self_);
+    }
+    Py_INCREF(self);
+    return self;
+  }
+  return THPVariable_Wrap(dispatch_contiguous(self_, memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static Tensor dispatch_copy_(const Tensor & self, const Tensor & other, bool non_blocking) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  return self.copy_(other, non_blocking);
+}
+
+static void maybe_warn_requires_grad(const Tensor & self) {
+  if (at::GradMode::is_enabled() && self.requires_grad()) {
+    TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n"
+                    "Consider using tensor.detach() first.");
+  }
+}
+
+ static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "copy_(Tensor other, bool non_blocking=False)",
+    "copy_(Tensor other, bool async=False)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  return THPVariable_Wrap(dispatch_copy_(self_, r.tensor(0), r.toBool(1)));
+  END_HANDLE_TH_ERRORS
+}
+
+template<typename T>
+static T dispatch_to(const Tensor & self) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  TORCH_CHECK_VALUE(self.sym_numel() == 1, "only one element tensors can be converted to Python scalars");
+  return self.template item<T>();
+}
+
+static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "__float__", args);
+  }
+  jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
+  auto& self_ = THPVariable_Unpack(self);
+  maybe_warn_requires_grad(self_);
+  return wrap(dispatch_to<double>(self_));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "__complex__", args);
+  }
+  jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW);
+  auto& self_ = THPVariable_Unpack(self);
+  maybe_warn_requires_grad(self_);
+  return wrap(dispatch_to<c10::complex<double>>(self_));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "__int__", args);
+  }
+  jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW);
+  auto& self_ = THPVariable_Unpack(self);
+  if (isFloatingType(self_.scalar_type())) {
+    // we can't dispatch to item here because we want to avoid ATen overflow checks;
+    // the python integral type (long in python2) can't overflow.
+    return THPUtils_packDoubleAsInt(dispatch_to<double>(self_));
+  } else {
+    return wrap(dispatch_to<int64_t>(self_));
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+// This is the __index__ function in Python which is similar to __int__, but
+// called when used as a slice.
+static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "__index__", args);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  // TODO: change the condition to `self_.dim() != 0` once we expose scalars
+  // in PyTorch.
+  if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true) || self_.sym_numel() != 1) {
+    throw TypeError("only integer tensors of a single element can be converted to an index");
+  }
+  return wrap(dispatch_to<int64_t>(self_));
+  END_HANDLE_TH_ERRORS
+}
+
+static Tensor dispatch_invert(const Tensor & self) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  return self.bitwise_not();
+}
+
+static PyObject * THPVariable_invert(PyObject* self, PyObject* args) {
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "__invert__", args);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  if (!isIntegralType(self_.scalar_type(), /*includeBool=*/true)) {
+    throw TypeError("~ (operator.invert) is only implemented on integer and Boolean-type tensors");
+  }
+  return THPVariable_Wrap(dispatch_invert(self_));
+  END_HANDLE_TH_ERRORS
+}
+
+static Tensor dispatch_to(const Tensor & self, Device device, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
+  pybind11::gil_scoped_release no_gil;
+  // NOTE: this is where we record aten::to in the graph during tracing. However, the behavior of aten::to
+  // is different with respect to TensorOptions fields that are not present: aten::to inherits fields that
+  // are missing from the self argument while the tracer assumes that they should be populated with the
+  // default values (eg. float for scalar type). By explicitly copying over the tensor options here we fully
+  // specify all tensor options and thus record the proper trace
+  return self.to(self.options().device(device).memory_format(optional_memory_format), non_blocking, copy);
+}
+
+static Tensor dispatch_to(const Tensor & self, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
+  pybind11::gil_scoped_release no_gil;
+  return self.to(self.options().memory_format(optional_memory_format), non_blocking, copy);
+}
+
+static Tensor dispatch_to(const Tensor & self, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
+  pybind11::gil_scoped_release no_gil;
+  // TODO: Make this call the TensorOptions version, maybe?
+  return self.to(dtype, non_blocking, copy, optional_memory_format);
+}
+
+static Tensor dispatch_to(const Tensor & self, Device device, ScalarType dtype, bool non_blocking, bool copy, std::optional<c10::MemoryFormat> optional_memory_format) {
+  pybind11::gil_scoped_release no_gil;
+  // TODO: Make this call the TensorOptions version, maybe?
+  return self.to(device, dtype, non_blocking, copy, optional_memory_format);
+}
+
+static PyObject * THPVariable_cpu(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+   HANDLE_TH_ERRORS
+   static PythonArgParser parser({
+     "cpu(*, MemoryFormat? memory_format=None)"
+   });
+   auto& self_ = THPVariable_Unpack(self);
+   ParsedArgs<1> parsed_args;
+   auto r = parser.parse(self, args, kwargs, parsed_args);
+
+   if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+    }
+
+   auto opt_memory_format = r.memoryformatOptional(0);
+   return THPVariable_Wrap(dispatch_to(self_, at::Device(at::DeviceType::CPU), false, false, opt_memory_format));
+   END_HANDLE_TH_ERRORS
+}
+
+static Tensor dispatch_nonzero(const Tensor & self) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  return self.nonzero();
+}
+
+static std::vector<Tensor> dispatch_nonzero_numpy(const Tensor & self) {
+  pybind11::gil_scoped_release no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  return self.nonzero_numpy();
+}
+
+static PyObject * THPVariable_nonzero(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "nonzero()",
+    "nonzero(*, bool as_tuple)",
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  if (r.idx == 0 || (r.idx == 1 && !r.toBool(0))) {
+    return wrap(dispatch_nonzero(self_));
+  } else {
+    return wrap(dispatch_nonzero_numpy(self_));
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "cuda(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "cuda(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto device = r.isNone(0) ? at::Device(at::DeviceType::CUDA) : r.device(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
+  TORCH_CHECK(device.is_cuda(), "Invalid device, must be cuda device");
+  torch::utils::device_lazy_init(at::kCUDA);
+  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_mtia(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "mtia(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "mtia(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto device = r.isNone(0) ? at::Device(at::DeviceType::MTIA) : r.device(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
+  TORCH_CHECK(device.is_mtia(), "Invalid device, must be MTIA device");
+  torch::utils::device_lazy_init(at::kMTIA);
+  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_xpu(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "xpu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "xpu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto device = r.isNone(0) ? at::Device(at::DeviceType::XPU) : r.device(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
+  TORCH_CHECK(device.is_xpu(), "Invalid device, must be xpu device");
+  torch::utils::device_lazy_init(at::kXPU);
+  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_ipu(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "ipu(Device? device=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "ipu(Device? device=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto device = r.isNone(0) ? at::Device(at::DeviceType::IPU) : r.device(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
+  TORCH_CHECK(device.is_ipu(), "Invalid device, must be ipu device");
+  return THPVariable_Wrap(dispatch_to(self_, device, r.toBool(1), false, opt_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_to_type(PyObject* self, ScalarType scalarType, std::optional<c10::MemoryFormat> optional_memory_format) {
+  HANDLE_TH_ERRORS
+  auto& self_ = THPVariable_Unpack(self);
+  return THPVariable_Wrap(dispatch_to(self_, scalarType, false, false, optional_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_byte(PyObject* self, PyObject* args, PyObject* kwargs)  {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "byte(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Byte, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_char(PyObject* self, PyObject* args, PyObject* kwargs)  {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "char(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Char, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_double(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "double(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Double, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_float(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "float(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Float, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_cdouble(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "cdouble(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::ComplexDouble, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_cfloat(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "cfloat(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::ComplexFloat, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_half(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "half(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Half, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_int(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "int(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Int, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_long(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "long(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Long, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_short(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "short(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Short, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_bool(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "bool(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::Bool, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "bfloat16(*, MemoryFormat? memory_format=None)"
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  auto opt_memory_format = r.memoryformatOptional(0);
+  return THPVariable_to_type(self, ScalarType::BFloat16, opt_memory_format);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_element_size(PyObject* self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "element_size", args);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  return THPUtils_packInt64(self_.element_size());
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object bc PyObjects not declarable in native_functions.yaml
+// See: ATen/native/README.md for more context
+static PyObject * THPVariable_numpy(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "numpy(*, bool force=False)"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  jit::tracer::warn("Converting a tensor to a NumPy array", jit::tracer::WARN_PYTHON_DATAFLOW);
+  return torch::utils::tensor_to_numpy(self_, r.toBool(0));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_requires_grad_(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "requires_grad_(bool requires_grad=True)",
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  // temporary hack to improve functorch UX.
+  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
+  if (functorch_tls) {
+    functorch_tls->checkSupportsInplaceRequiresGrad();
+  }
+
+  auto requires_grad = r.toBool(0);
+  // should we throw if requires_grad is true?  var.requires_grad = True throws here
+  // but it's nice to let this be a no-op.
+  if (!self_.is_leaf() && !requires_grad) {
+    throw std::runtime_error(autograd::utils::requires_grad_leaf_error(requires_grad));
+  }
+  if (requires_grad && ! isDifferentiableType(at::typeMetaToScalarType(self_.dtype()))) {
+    throw std::runtime_error("only Tensors of floating point dtype can require gradients");
+  }
+  self_.set_requires_grad(requires_grad);
+  return THPVariable_Wrap(self_);
+  END_HANDLE_TH_ERRORS
+}
+
+static inline bool dispatch_is_contiguous(const Tensor & self, MemoryFormat memory_format) {
+  return self.is_contiguous(memory_format);
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "is_contiguous(*, MemoryFormat memory_format=contiguous_format)",
+  });
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(self_, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self_, args, kwargs, PyObject_Type(self_), "torch.Tensor");
+  }
+
+  auto memory_format = r.memoryformat(0);
+  auto& self = THPVariable_Unpack(self_);
+  return wrap(dispatch_is_contiguous(self, memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object to avoid dispatch overhead
+static PyObject * THPVariable_item(PyObject* self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "item", args);
+  }
+  jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW);
+  auto& self_ = THPVariable_Unpack(self);
+  auto dispatch_item_ = [](const Tensor& self) -> at::Scalar {
+    pybind11::gil_scoped_release no_gil;
+    return self.item();
+  };
+  return py::cast(dispatch_item_(self_)).release().ptr();
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object bc no support for first class functions in native_functions.yaml
+// See: ATen/native/README.md for more context
+static PyObject * THPVariable_map_(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({ "map_(Tensor other, PyObject* callable)" });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<2> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  Variable other = r.tensor(0);
+  if (self_.requires_grad() || other.requires_grad()) {
+    throw std::runtime_error(
+        "Can't call map_() on Variable that requires grad. Use "
+        "var.detach().map_() instead.");
+  }
+  TORCH_CHECK(
+      !self_.unsafeGetTensorImpl()->is_python_dispatch() && !other.unsafeGetTensorImpl()->is_python_dispatch(),
+      ".map_ is not supported for tensor subclasses.");
+
+  return THPVariable_Wrap(torch::utils::map_(self_, other, r.pyobject(1)));
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object bc no support for first class functions in native_functions.yaml
+// See: ATen/native/README.md for more context
+static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({ "map2_(Tensor x, Tensor y, PyObject* callable)" });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  Variable x = r.tensor(0);
+  Variable y = r.tensor(1);
+  if (self_.requires_grad() || x.requires_grad() || y.requires_grad()) {
+    throw std::runtime_error(
+        "Can't call map2_() on Variable that requires grad. Use "
+        "var.detach().map2_() instead.");
+  }
+  TORCH_CHECK(
+      !x.unsafeGetTensorImpl()->is_python_dispatch() && !y.unsafeGetTensorImpl()->is_python_dispatch(),
+      ".map2_ is not supported for tensor subclasses.");
+  return THPVariable_Wrap(torch::utils::map2_(self_, x, y, r.pyobject(2)));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "new", args, kwargs);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  OptionalDeviceGuard device_guard(device_of(self_));
+  return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "new_tensor", args, kwargs);
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  OptionalDeviceGuard device_guard(device_of(self_));
+  return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_storage(PyObject* self, PyObject* arg)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "untyped_storage");
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  return createPyObject(self_.storage());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+    "to(ScalarType dtype, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+    "to(Tensor tensor, bool non_blocking=False, bool copy=False, *, MemoryFormat? memory_format=None)",
+  });
+  ParsedArgs<5> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+  if (r.has_torch_function()) {
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+  auto parsed = parse_to_conversion(r, /*allow_copy*/ true);
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  auto copy = std::get<3>(parsed);
+  auto opt_memory_format = std::get<4>(parsed);
+  auto& self_ = THPVariable_Unpack(self);
+  torch::utils::maybe_initialize_device(device);
+  if (!device && !scalarType && !copy && !opt_memory_format.has_value()) {
+    Py_INCREF(self);
+    return self;
+  } else if (!device && !scalarType) {
+    return THPVariable_Wrap(
+        dispatch_to(self_, non_blocking, copy, opt_memory_format));
+  } else if (!device) {
+    return THPVariable_Wrap(dispatch_to(self_, *scalarType, non_blocking, copy, opt_memory_format));
+  } else if (!scalarType) {
+    return THPVariable_Wrap(dispatch_to(self_, *device, non_blocking, copy, opt_memory_format));
+  } else {
+    return THPVariable_Wrap(dispatch_to(self_, *device, *scalarType, non_blocking, copy, opt_memory_format));
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// implemented on the python object b/c arbitrarily nested list not declarable in native_functions.yaml
+// See: ATen/native/README.md for more context
+static PyObject * THPVariable_tolist(PyObject* self, PyObject* args)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function(self)) {
+    return handle_torch_function(self, "tolist", args);
+  }
+  jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW);
+  auto self_ = THPVariable_Unpack(self);
+  return torch::utils::tensor_to_list(self_);
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser({
+    "type(PyObject* dtype=None, bool non_blocking=False, *, MemoryFormat? memory_format=None)",
+    "type(PyObject* dtype=None, bool async=False, *, MemoryFormat? memory_format=None)|deprecated"
+  });
+  auto& self_ = THPVariable_Unpack(self);
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(self, args, kwargs, parsed_args);
+
+  if(r.has_torch_function()){
+    return handle_torch_function(r, self, args, kwargs, THPVariableClass, "torch.Tensor");
+  }
+
+  if (r.isNone(0)) {
+    return THPUtils_packString(torch::utils::options_to_string(self_.options()));
+  }
+  auto obj = r.pyobject(0);
+  auto opt_memory_format = r.memoryformatOptional(2);
+  std::string type_name;
+  bool is_dtype = false;
+  if (PyType_Check(obj)) {
+    if (obj == THPVariableClass) {
+      type_name = "torch.Tensor";
+    } else {
+      type_name = ((PyTypeObject*)obj)->tp_name;
+    }
+  } else if (THPUtils_checkString(obj)) {
+    type_name = THPUtils_unpackString(obj);
+  } else if (THPDtype_Check(obj)) {
+    is_dtype = true;
+  } else {
+    throw TypeError("dtype must be a type, str, or dtype object");
+  }
+  Device device = self_.device();
+  if (is_dtype) {
+    auto scalar_type = r.scalartype(0);
+    return THPVariable_Wrap(dispatch_to(self_, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
+  }
+  at::TensorOptions options = torch::utils::options_from_string(type_name);
+  auto scalar_type = at::typeMetaToScalarType(options.dtype());
+  auto device_type = options.device().type();
+  if (device_type != device.type()) {
+    device = at::Device(device_type);
+  }
+  torch::utils::maybe_initialize_device(device);
+  return THPVariable_Wrap(dispatch_to(self_, device, scalar_type, /*non_blocking=*/ r.toBool(1), /*copy=*/ false, opt_memory_format));
+  END_HANDLE_TH_ERRORS
+}
+
+// generated methods start here
+
+${py_methods}
+
+static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) {
+  if (check_has_torch_function(self)) {
+    HANDLE_TH_ERRORS
+    return handle_torch_function(self, "__bool__", args);
+    END_HANDLE_TH_ERRORS
+  }
+  jit::tracer::warn("Converting a tensor to a Python boolean", jit::tracer::WARN_PYTHON_DATAFLOW);
+  return THPVariable_is_nonzero(self, args);
+}
+
+static PyObject * THPVariable___eq__(PyObject* self_, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+#ifdef USE_NUMPY
+  if (torch::utils::is_numpy_available()) {
+    static PythonArgParser parser({
+      "__eq__(PyObject* other)",
+    }, /*traceable=*/true);
+
+    ParsedArgs<1> parsed_args;
+    auto _r = parser.parse(self_, args, kwargs, parsed_args);
+    if(_r.has_torch_function()) {
+      return handle_torch_function(_r, self_, args, kwargs, THPVariableClass, "torch.Tensor");
+    }
+    switch (_r.idx) {
+      case 0: {
+        auto other = _r.pyobject(0);
+        if (PyArray_Check(other)) {
+          auto other_tensor = torch::utils::tensor_from_numpy(other);
+          auto dispatch_eq = [](const at::Tensor & self, const at::Tensor & other) -> at::Tensor {
+            pybind11::gil_scoped_release no_gil;
+            return self.eq(other);
+          };
+          const Tensor& self = THPVariable_Unpack(self_);
+          return wrap(dispatch_eq(self, other_tensor));
+        }
+      }
+    }
+  }
+#endif
+  return THPVariable_eq(self_, args, kwargs);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// Wrapper converts a raised TypeError into returning NotImplemented
+// Used to implement binary arithmetic operators
+template <PyObject* (*Func)(PyObject*, PyObject*, PyObject*)>
+static PyObject * TypeError_to_NotImplemented_(PyObject* self, PyObject* args, PyObject* kwargs) {
+
+  PyObject* ret = Func(self, args, kwargs);
+  if (!ret && PyErr_ExceptionMatches(PyExc_TypeError)) {
+    PyErr_Clear();
+    Py_INCREF(Py_NotImplemented);
+    ret = Py_NotImplemented;
+  }
+  return ret;
+}
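+
+// For example (illustrative, assuming the generated method THPVariable_add is
+// emitted below via ${py_methods}): the "__add__" entry in `variable_methods`
+// would instantiate this wrapper around THPVariable_add, so that an operand the
+// tensor method cannot handle makes the operator return NotImplemented and
+// Python falls back to the other operand's reflected method (e.g. __radd__)
+// instead of raising a TypeError.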
+
+// set_ has to be defined in the template because the c10::Storage object
+// does not have a type, and we need to make sure the Python storage object's
+// type matches the tensor's type
+static PyObject* THPVariable_set_(
+    PyObject* self_,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  const Tensor& self = THPVariable_Unpack(self_);
+  static PythonArgParser parser(
+      {
+          "set_()",
+          "set_(Storage source)",
+          "set_(Storage source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)",
+          "set_(Tensor source)",
+          "set_(Tensor source, SymInt storage_offset, SymIntArrayRef size, SymIntArrayRef stride=None)",
+      },
+      /*traceable=*/false);
+
+  ParsedArgs<4> parsed_args;
+  auto _r = parser.parse(args, kwargs, parsed_args);
+
+  switch (_r.idx) {
+    case 0: {
+      // aten::set_(Tensor(a!) self) -> Tensor(a!)
+      auto dispatch_set_ = [](const Tensor& self) -> Tensor {
+        pybind11::gil_scoped_release no_gil;
+        return self.set_();
+      };
+      return wrap(dispatch_set_(self));
+    }
+    case 1: {
+      // aten::set_.source_Storage(Tensor(a!) self, Storage source) ->
+      // Tensor(a!)
+      at::ScalarType storage_scalar_type{};
+      bool is_typed_storage = true;
+      at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
+      TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
+        "Expected a Storage of type ", self.dtype(),
+        " or an UntypedStorage, but got type ", storage_scalar_type,
+        " for argument 1 'storage'");
+      auto dispatch_set_ = [](const Tensor& self, Storage source) -> Tensor {
+        pybind11::gil_scoped_release no_gil;
+        return self.set_(std::move(source));
+      };
+      return wrap(dispatch_set_(self, storage));
+    }
+    case 2: {
+      // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage
+      // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
+      at::ScalarType storage_scalar_type{};
+      bool is_typed_storage = true;
+      at::Storage storage = _r.storage(0, storage_scalar_type, is_typed_storage);
+      TORCH_CHECK(storage_scalar_type == self.dtype() || !is_typed_storage,
+        "Expected a Storage of type ", self.dtype(),
+        " or an UntypedStorage, but got type ", storage_scalar_type,
+        " for argument 1 'storage'");
+      auto dispatch_set_ = [](const Tensor& self,
+                              Storage source,
+                              c10::SymInt storage_offset,
+                              c10::SymIntArrayRef size,
+                              c10::SymIntArrayRef stride) -> Tensor {
+        pybind11::gil_scoped_release no_gil;
+        return self.set__symint(std::move(source), std::move(storage_offset), size, stride);
+      };
+      return wrap(dispatch_set_(
+          self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3)));
+    }
+    case 3: {
+      // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)
+      auto dispatch_set_ = [](const Tensor& self, const Tensor& source) -> Tensor {
+        TORCH_CHECK(source.dtype() == self.dtype(), "Could not set tensor of type ", source.dtype(), " to a tensor of type ", self.dtype());
+        pybind11::gil_scoped_release no_gil;
+        return self.set_(source);
+      };
+      return wrap(dispatch_set_(self, _r.tensor(0)));
+    }
+    case 4: {
+      // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor
+      // source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)
+      at::Tensor storage = _r.tensor(0);
+      auto dispatch_set_ = [](const Tensor& self,
+                              const Tensor& source,
+                              c10::SymInt storage_offset,
+                              c10::SymIntArrayRef size,
+                              c10::SymIntArrayRef stride) -> Tensor {
+        pybind11::gil_scoped_release no_gil;
+        return self.set__symint(source, std::move(storage_offset), size, stride);
+      };
+      return wrap(dispatch_set_(
+          self, storage, _r.toSymInt(1), _r.symintlist(2), _r.symintlist(3)));
+    }
+  }
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+// XXX: ops that are bound here are not exposed to the C++ api nor the JIT.
+// Any new ops added here should be accompanied with a comment why they are not
+// being registered through native_functions.yaml, and be tagged cpp / JIT
+PyMethodDef variable_methods[] = {
+  // These magic methods are all implemented on python object to wrap NotImplementedError
+  {"__add__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__radd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__iadd__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__rmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__mul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__imul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__sub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__isub__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__div__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__truediv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__floordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__idiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__ifloordiv__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__mod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__imod__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__eq__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__ne__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__lt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__le__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__gt__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__ge__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__rand__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__ror__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__rxor__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"__bool__", THPVariable_bool_scalar, METH_NOARGS, nullptr},
+  {"__float__", THPVariable_float_scalar, METH_NOARGS, nullptr},
+  {"__complex__", THPVariable_complex_scalar, METH_NOARGS, nullptr},
+  {"__int__", THPVariable_integral_scalar, METH_NOARGS, nullptr},
+  {"__long__", THPVariable_integral_scalar, METH_NOARGS, nullptr},
+  {"__index__", THPVariable_index_scalar, METH_NOARGS, nullptr},
+  {"__nonzero__", THPVariable_bool_scalar, METH_NOARGS, nullptr},
+  {"__invert__", THPVariable_invert, METH_NOARGS, nullptr},
+  {"__matmul__", castPyCFunctionWithKeywords(TypeError_to_NotImplemented_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"_is_view", THPVariable__is_view, METH_NOARGS, nullptr},
+  {"apply_", THPVariable_apply_, METH_O, nullptr},
+  {"bfloat16", castPyCFunctionWithKeywords(THPVariable_bfloat16), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"byte", castPyCFunctionWithKeywords(THPVariable_byte), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"char", castPyCFunctionWithKeywords(THPVariable_char), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"contiguous", castPyCFunctionWithKeywords(THPVariable_contiguous), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"copy_", castPyCFunctionWithKeywords(THPVariable_copy_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"cpu", castPyCFunctionWithKeywords(THPVariable_cpu), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"cuda", castPyCFunctionWithKeywords(THPVariable_cuda), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"mtia", castPyCFunctionWithKeywords(THPVariable_mtia), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"xpu", castPyCFunctionWithKeywords(THPVariable_xpu), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"ipu", castPyCFunctionWithKeywords(THPVariable_ipu), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"data_ptr", THPVariable_data_ptr, METH_NOARGS, nullptr},
+  {"dim", THPVariable_dim, METH_NOARGS, nullptr},
+  {"has_names", THPVariable_has_names, METH_NOARGS, nullptr},
+  {"double", castPyCFunctionWithKeywords(THPVariable_double), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"cdouble", castPyCFunctionWithKeywords(THPVariable_cdouble), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"element_size", THPVariable_element_size, METH_NOARGS, nullptr},
+  {"float", castPyCFunctionWithKeywords(THPVariable_float), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"cfloat", castPyCFunctionWithKeywords(THPVariable_cfloat), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"get_device", THPVariable_get_device, METH_NOARGS, nullptr},
+  {"bool", castPyCFunctionWithKeywords(THPVariable_bool), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"half", castPyCFunctionWithKeywords(THPVariable_half), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"int", castPyCFunctionWithKeywords(THPVariable_int), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"is_contiguous", castPyCFunctionWithKeywords(THPVariable_is_contiguous), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"item", THPVariable_item, METH_NOARGS, nullptr},
+  {"long", castPyCFunctionWithKeywords(THPVariable_long), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"map_", castPyCFunctionWithKeywords(THPVariable_map_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"map2_", castPyCFunctionWithKeywords(THPVariable_map2_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"ndimension", THPVariable_dim, METH_NOARGS, nullptr},
+  {"nelement", THPVariable_numel, METH_NOARGS, nullptr},
+  {"new", castPyCFunctionWithKeywords(THPVariable_new), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"new_tensor", castPyCFunctionWithKeywords(THPVariable_new_tensor), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"nonzero", castPyCFunctionWithKeywords(THPVariable_nonzero), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"numel", THPVariable_numel, METH_NOARGS, nullptr},
+  {"numpy", castPyCFunctionWithKeywords(THPVariable_numpy), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"requires_grad_", castPyCFunctionWithKeywords(THPVariable_requires_grad_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"set_", castPyCFunctionWithKeywords(THPVariable_set_), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"short", castPyCFunctionWithKeywords(THPVariable_short), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"size", castPyCFunctionWithKeywords(THPVariable_size), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"untyped_storage", THPVariable_storage, METH_NOARGS, nullptr},
+  {"storage_offset", THPVariable_storage_offset, METH_NOARGS, nullptr},
+  {"stride", castPyCFunctionWithKeywords(THPVariable_stride), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"to", castPyCFunctionWithKeywords(THPVariable_to), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"tolist", THPVariable_tolist, METH_NOARGS, nullptr},
+  {"type", castPyCFunctionWithKeywords(THPVariable_type), METH_VARARGS | METH_KEYWORDS, nullptr},
+  ${py_method_defs}
+  {nullptr}
+};
+
+} // namespace torch::autograd
diff --git a/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/variable_factories.h b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/variable_factories.h
new file mode 100644
index 0000000000000000000000000000000000000000..225ad79f0947c69b8799f7fb79d4cc37bc39c09f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/packaged/autograd/templates/variable_factories.h
@@ -0,0 +1,135 @@
+#pragma once
+
+// ${generated_comment}
+
+#include <ATen/core/Tensor.h>
+#include <ATen/TracerMode.h>
+#include <ATen/core/grad_mode.h>
+#include <c10/util/ArrayRef.h>
+#include <c10/core/MemoryFormat.h>
+#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
+#include <torch/csrc/autograd/variable.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/from_blob.h>
+$ops_headers
+#endif
+
+#include <functional>
+#include <initializer_list>
+#include <utility>
+
+namespace torch {
+
+/// NOTE: Currently `torch::tensor(...)` doesn't support mixed data types
+/// (i.e. `torch::tensor({{bool, 2.0}})` doesn't work). We might be able to
+/// support it in the future by iterating over all sub-lists to find
+/// the largest data type that can represent all of the elements, or by using
+/// variadic templates.
+///
+/// NOTE: C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` /
+/// (nested) braced-init-list of floating-point types always produces a tensor of dtype
+/// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior.
+///
+/// NOTE: C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` /
+/// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong`
+/// (aka. int64_t), matching Python `torch.tensor` behavior.
+///
+/// NOTE: The following dtypes are not supported by `torch::tensor` currently:
+/// - `unsigned int`
+/// - `unsigned long int`
+/// - `unsigned long long int`
+/// - `long long int`
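+///
+/// Illustrative examples of the dtype rules described above (assuming the
+/// default dtype has not been changed from float):
+///
+///   torch::tensor({1.0, 2.0});        // floating-point list -> torch::get_default_dtype() (kFloat)
+///   torch::tensor({1, 2, 3});         // integer list -> at::kLong
+///   torch::tensor({true, false});     // bool list -> at::kBool
+///   torch::tensor({1, 2}, at::kInt);  // an explicit dtype in TensorOptions overrides these defaults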
+inline at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const at::TensorOptions& options = {}) {
+  return autograd::make_variable(
+    // note: we remove the requires_grad setting from the TensorOptions because
+    // it is ignored anyways (and we actually have an assertion that it isn't set
+    // which would fail otherwise). We handle requires_grad explicitly here
+    // instead of passing it through to the kernel.
+    tensor_data_container.convert_to_tensor(options.requires_grad(::std::nullopt)),
+    options.requires_grad());
+}
+
+/// A generic deleter function.
+using Deleter = std::function<void(void*)>;
+using at::MemoryFormat;
+
+/// Exposes the given `data` as a `Tensor` without taking ownership of the
+/// original data. `sizes` should specify the shape of the tensor, `strides` the
+/// stride in each dimension. The `deleter` function (a
+/// `std::function<void(void*)>`) will be called on the `data` when the Tensor
+/// data would normally be deallocated. The `TensorOptions` specify additional
+/// configuration options for the returned tensor, such as what type to
+/// interpret the `data` as.
+inline at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    const Deleter& deleter,
+    const at::TensorOptions& options = at::TensorOptions()) {
+  at::Tensor tensor = ([&]() {
+    at::AutoDispatchBelowAutograd guard;  // TODO: remove
+    at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    return at::from_blob(data, sizes, strides, deleter, options.requires_grad(::std::nullopt));
+  })();
+  return autograd::make_variable(tensor, options.requires_grad());
+}
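+
+// Illustrative usage of the deleter-taking overload above (the buffer name and
+// shape are hypothetical): wrap a caller-owned buffer and free it once the
+// tensor's storage is released.
+//
+//   float* buf = new float[6];
+//   at::Tensor t = torch::from_blob(
+//       buf, /*sizes=*/{2, 3}, /*strides=*/{3, 1},
+//       [](void* p) { delete[] static_cast<float*>(p); },
+//       at::TensorOptions().dtype(at::kFloat));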
+
+/// Exposes the given `data` as a `Tensor` without taking ownership of the
+/// original data. `sizes` should specify the shape of the tensor, `strides` the
+/// stride in each dimension. The `TensorOptions`
+/// specify additional configuration options for the returned tensor, such as
+/// what type to interpret the `data` as.
+inline at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    at::IntArrayRef strides,
+    const at::TensorOptions& options = at::TensorOptions()) {
+  at::Tensor tensor = ([&]() {
+    at::AutoDispatchBelowAutograd guard;  // TODO: remove
+    at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    return at::from_blob(data, sizes, strides, options.requires_grad(::std::nullopt));
+  })();
+  return autograd::make_variable(tensor, options.requires_grad());
+}
+
+/// Exposes the given `data` as a `Tensor` without taking ownership of the
+/// original data. `sizes` should specify the shape of the tensor. The `deleter`
+/// function (a `std::function<void(void*)>`) will be called on the `data` when
+/// the Tensor data would normally be deallocated. The `TensorOptions` specify
+/// additional configuration options for the returned tensor, such as what type
+/// to interpret the `data` as.
+inline at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    const Deleter& deleter,
+    const at::TensorOptions& options = at::TensorOptions()) {
+  at::Tensor tensor = ([&]() {
+    at::AutoDispatchBelowAutograd guard;  // TODO: remove
+    at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    return at::from_blob(data, sizes, deleter, options.requires_grad(::std::nullopt));
+  })();
+  return autograd::make_variable(tensor, options.requires_grad());
+}
+
+/// Exposes the given `data` as a `Tensor` without taking ownership of the
+/// original data. `sizes` should specify the shape of the tensor. The
+/// `TensorOptions` specify additional configuration options for the returned
+/// tensor, such as what type to interpret the `data` as.
+inline at::Tensor from_blob(
+    void* data,
+    at::IntArrayRef sizes,
+    const at::TensorOptions& options = at::TensorOptions()) {
+  at::Tensor tensor = ([&]() {
+    at::AutoDispatchBelowAutograd guard;  // TODO: remove
+    at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    return at::from_blob(data, sizes, options.requires_grad(::std::nullopt));
+  })();
+  return autograd::make_variable(tensor, options.requires_grad());
+}
+
+${function_definitions}
+
+} // namespace torch
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/__init__.py b/phivenv/Lib/site-packages/torchgen/selective_build/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d8a8217475bcc8f9b21e18dfeaa68d8b94303b80
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/operator.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/operator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5a7786cfe94850c509c058f9476b8e9e2e3b8173
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/operator.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/selector.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/selector.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e91a82fa57941dc6330c2f12795100d53afa99b2
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/selective_build/__pycache__/selector.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/operator.py b/phivenv/Lib/site-packages/torchgen/selective_build/operator.py
new file mode 100644
index 0000000000000000000000000000000000000000..be35a83cb083527841e04fbde6a365de0fc62c5f
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/selective_build/operator.py
@@ -0,0 +1,171 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+
+# This class holds information about a single operator used to determine
+# the outcome of a selective/custom PyTorch build that doesn't include
+# registration code for all the supported operators. This is done to
+# reduce the size of the generated binary so that it can be deployed in
+# situations where binary size comes at a premium.
+#
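+# A sketch (for illustration only) of a single entry under the "operators" key of
+# the selective build YAML that from_yaml_dict() consumes; the operator name and
+# values below are made up:
+#
+#   aten::add.Tensor:
+#     is_root_operator: true
+#     is_used_for_training: false
+#     include_all_overloads: false
+#     debug_info:
+#       - "model_a.ptl"
+#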
+@dataclass(frozen=True)
+class SelectiveBuildOperator:
+    # The name of the operator. This includes the aten::, etc... prefix
+    # The operator name may or may not have the overload name. If this
+    # operator name does not specify an overload name, the way to determine
+    # if this entry refers to the family of operators with this base name
+    # or just the operator with this name is to look at the value of the
+    # 'include_all_overloads' flag in this class.
+    name: str
+
+    # True if this is a root operator (i.e. called directly from a
+    # TorchScript model, etc...). An operator is considered to be a
+    # root operator if it is called directly from any one of the models
+    # that this instance of the pytorch library was built for. Hence, it
+    # may not be a root operator in all of the models that are used in
+    # this instance of the pytorch library.
+    is_root_operator: bool
+
+    # Is this operator used for on-device training? If True, then we need to
+    # use the information to generate code in VariableType_N.cpp for registration
+    # of training related operators. Again, this is True if this operator
+    # is used for training in one or more models used by this instance of the
+    # pytorch library.
+    is_used_for_training: bool
+
+    # If True, it indicates that this operator instance (object) refers to an
+    # operator without the overload name and should apply to all overloads
+    # which have this operator name as the base name. This flag is applicable
+    # only for objects that have operator names without a DOT (period) character
+    # in them.
+    #
+    # Note: This flag is a temporary workaround to grandfather in the current
+    # static selective (custom) build mechanism, which largely ignores overload
+    # names when determining whether to select operators for registration
+    # purposes.
+    include_all_overloads: bool
+
+    # Debug Information at the operator level
+    _debug_info: tuple[str, ...] | None
+
+    @staticmethod
+    def from_yaml_dict(
+        op_name: str, op_info: dict[str, object]
+    ) -> SelectiveBuildOperator:
+        allowed_keys = {
+            "name",
+            "is_root_operator",
+            "is_used_for_training",
+            "include_all_overloads",
+            "debug_info",
+        }
+
+        if len(set(op_info.keys()) - allowed_keys) > 0:
+            raise Exception(  # noqa: TRY002
+                "Got unexpected top level keys: {}".format(
+                    ",".join(set(op_info.keys()) - allowed_keys),
+                )
+            )
+
+        if "name" in op_info:
+            assert op_name == op_info["name"]
+
+        is_root_operator = op_info.get("is_root_operator", True)
+        assert isinstance(is_root_operator, bool)
+
+        is_used_for_training = op_info.get("is_used_for_training", True)
+        assert isinstance(is_used_for_training, bool)
+
+        include_all_overloads = op_info.get("include_all_overloads", True)
+        assert isinstance(include_all_overloads, bool)
+
+        debug_info: tuple[str, ...] | None = None
+        if "debug_info" in op_info:
+            di_list = op_info["debug_info"]
+            assert isinstance(di_list, list)
+            debug_info = tuple(str(x) for x in di_list)
+
+        return SelectiveBuildOperator(
+            name=op_name,
+            is_root_operator=is_root_operator,
+            is_used_for_training=is_used_for_training,
+            include_all_overloads=include_all_overloads,
+            _debug_info=debug_info,
+        )
+
+    @staticmethod
+    def from_legacy_operator_name_without_overload(
+        name: str,
+    ) -> SelectiveBuildOperator:
+        return SelectiveBuildOperator(
+            name=name,
+            is_root_operator=True,
+            is_used_for_training=True,
+            include_all_overloads=True,
+            _debug_info=None,
+        )
+
+    def to_dict(self) -> dict[str, object]:
+        ret: dict[str, object] = {
+            "is_root_operator": self.is_root_operator,
+            "is_used_for_training": self.is_used_for_training,
+            "include_all_overloads": self.include_all_overloads,
+        }
+        if self._debug_info is not None:
+            ret["debug_info"] = self._debug_info
+
+        return ret
+
+
+def merge_debug_info(
+    lhs: tuple[str, ...] | None,
+    rhs: tuple[str, ...] | None,
+) -> tuple[str, ...] | None:
+    # Ensure that when merging, each entry shows up just once.
+    if lhs is None and rhs is None:
+        return None
+
+    return tuple(set((lhs or ()) + (rhs or ())))
+
+
+def combine_operators(
+    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
+) -> SelectiveBuildOperator:
+    if str(lhs.name) != str(rhs.name):
+        raise Exception(  # noqa: TRY002
+            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
+        )
+
+    return SelectiveBuildOperator(
+        name=lhs.name,
+        # Consider this operator to be a root operator if it is a
+        # root operator in any of the models used in this instance of
+        # the pytorch library.
+        is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
+        # Consider this operator to be a training operator if it is
+        # an operator used for training in any of the models used
+        # in this instance of the pytorch library.
+        is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
+        include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
+        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
+    )
+
+
+def merge_operator_dicts(
+    lhs: dict[str, SelectiveBuildOperator],
+    rhs: dict[str, SelectiveBuildOperator],
+) -> dict[str, SelectiveBuildOperator]:
+    operators: dict[str, SelectiveBuildOperator] = {}
+    for op_name, op in list(lhs.items()) + list(rhs.items()):
+        new_op = op
+        if op_name in operators:
+            new_op = combine_operators(operators[op_name], op)
+
+        operators[op_name] = new_op
+
+    return operators
+
+
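+# e.g. strip_operator_overload_name("aten::add.Tensor") == "aten::add"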
+def strip_operator_overload_name(op_name: str) -> str:
+    return op_name.split(".")[0]
diff --git a/phivenv/Lib/site-packages/torchgen/selective_build/selector.py b/phivenv/Lib/site-packages/torchgen/selective_build/selector.py
new file mode 100644
index 0000000000000000000000000000000000000000..55c1b85a2dbe573d75234c705ce72a98a501c34e
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/selective_build/selector.py
@@ -0,0 +1,352 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Iterable
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import yaml
+
+from torchgen.selective_build.operator import (
+    merge_debug_info,
+    merge_operator_dicts,
+    SelectiveBuildOperator,
+    strip_operator_overload_name,
+)
+
+
+if TYPE_CHECKING:
+    from torchgen.model import NativeFunction
+
+
+# A SelectiveBuilder holds information extracted from the selective build
+# YAML specification.
+#
+# It includes information about the build's selectivity, the debug_info
+# associated with this selective build (opaque string), and the set of
+# operators that should be included in the build.
+#
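+# A sketch (for illustration only) of a YAML config accepted by from_yaml_str()
+# / from_yaml_path(); every name and value below is made up:
+#
+#   include_all_operators: false
+#   include_all_non_op_selectives: false
+#   operators:
+#     aten::add.Tensor:
+#       is_root_operator: true
+#       is_used_for_training: false
+#       include_all_overloads: false
+#   kernel_metadata:
+#     add_kernel:
+#       - "float"
+#       - "int"
+#   custom_classes:
+#     - "my_namespace::MyCustomClass"
+#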
+@dataclass(frozen=True)
+class SelectiveBuilder:
+    # If true, then the build is not selective, and includes all
+    # operators.
+    include_all_operators: bool
+
+    # Debug Information at the selective/custom build level.
+    _debug_info: tuple[str, ...] | None
+
+    # A dictionary of operator -> operator metadata.
+    operators: dict[str, SelectiveBuildOperator]
+
+    # A dictionary of selected kernel tags and dtypes. Typically a
+    # PyTorch Operator Kernel (function) may have many code paths
+    # that are specialized for many Tensor dtypes, so it's not
+    # one per kernel function, but there could be many per kernel
+    # function. The tag isn't a kernel function name, but some fragment
+    # of the kernel function implementation itself.
+    kernel_metadata: dict[str, list[str]]
+
+    # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
+    # dtypes for tensor-like input args).
+    # This is from selective.yaml
+    et_kernel_metadata: dict[str, list[str]]
+
+    # A set of all the custom torch bind classes used by the selected models
+    # Stored as a set internally to remove duplicates proactively, but written
+    # as a list to yamls
+    custom_classes: set[str]
+
+    # A set of all the build features used by the selected models
+    # Stored as a set internally to remove duplicates proactively, but written
+    # as a list to yamls
+    build_features: set[str]
+
+    # If true, then fragments for all dtypes for all kernel functions
+    # are included as well as all custom classes. This is typically set when any one of the
+    # operator lists is generated from a mechanism other than
+    # tracing based selective build.
+    include_all_non_op_selectives: bool
+
+    @staticmethod
+    def get_nop_selector() -> SelectiveBuilder:
+        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})
+
+    @staticmethod
+    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
+        valid_top_level_keys = {
+            "include_all_non_op_selectives",
+            "include_all_operators",
+            "debug_info",
+            "operators",
+            "kernel_metadata",
+            "et_kernel_metadata",
+            "custom_classes",
+            "build_features",
+        }
+        top_level_keys = set(data.keys())
+        if len(top_level_keys - valid_top_level_keys) > 0:
+            raise Exception(  # noqa: TRY002
+                "Got unexpected top level keys: {}".format(
+                    ",".join(top_level_keys - valid_top_level_keys),
+                )
+            )
+        include_all_operators = data.get("include_all_operators", False)
+        assert isinstance(include_all_operators, bool)
+
+        debug_info = None
+        if "debug_info" in data:
+            di_list = data["debug_info"]
+            assert isinstance(di_list, list)
+
+            debug_info = tuple(str(x) for x in di_list)
+
+        operators = {}
+        operators_dict = data.get("operators", {})
+        assert isinstance(operators_dict, dict)
+
+        for k, v in operators_dict.items():
+            operators[k] = SelectiveBuildOperator.from_yaml_dict(k, v)
+
+        kernel_metadata = {}
+        kernel_metadata_dict = data.get("kernel_metadata", {})
+        assert isinstance(kernel_metadata_dict, dict)
+
+        for k, v in kernel_metadata_dict.items():
+            kernel_metadata[str(k)] = [str(dtype) for dtype in v]
+
+        et_kernel_metadata = data.get("et_kernel_metadata", {})
+        assert isinstance(et_kernel_metadata, dict)
+
+        custom_classes = data.get("custom_classes", [])
+        assert isinstance(custom_classes, Iterable)
+        custom_classes = set(custom_classes)
+
+        build_features = data.get("build_features", [])
+        assert isinstance(build_features, Iterable)
+        build_features = set(build_features)
+
+        include_all_non_op_selectives = data.get("include_all_non_op_selectives", False)
+        assert isinstance(include_all_non_op_selectives, bool)
+
+        return SelectiveBuilder(
+            include_all_operators,
+            debug_info,
+            operators,
+            kernel_metadata,
+            et_kernel_metadata,
+            custom_classes,  # type: ignore[arg-type]
+            build_features,  # type: ignore[arg-type]
+            include_all_non_op_selectives,
+        )
+
+    @staticmethod
+    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
+        contents = yaml.safe_load(config_contents)
+        return SelectiveBuilder.from_yaml_dict(contents)
+
+    @staticmethod
+    def from_yaml_path(config_path: str) -> SelectiveBuilder:
+        with open(config_path) as f:
+            contents = yaml.safe_load(f)
+            return SelectiveBuilder.from_yaml_dict(contents)
+
+    @staticmethod
+    def from_legacy_op_registration_allow_list(
+        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
+    ) -> SelectiveBuilder:
+        operators = {}
+        for op in allow_list:
+            operators[op] = {
+                "name": op,
+                "is_root_operator": is_root_operator,
+                "is_used_for_training": is_used_for_training,
+                "include_all_overloads": True,
+            }
+        return SelectiveBuilder.from_yaml_dict(
+            {
+                "operators": operators,
+                "include_all_non_op_selectives": True,
+            }
+        )
+
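+    # For example (illustration only): if `operators` has an entry for "aten::add"
+    # with include_all_overloads=True, then querying the overloaded name
+    # "aten::add.Tensor" is also treated as selected, because the part after the
+    # '.' is stripped before the fallback lookup below.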
+    def is_operator_selected(self, name: str) -> bool:
+        if self.include_all_operators:
+            return True
+
+        if name in self.operators:
+            return True
+        name = strip_operator_overload_name(name)
+        return name in self.operators and self.operators[name].include_all_overloads
+
+    def is_native_function_selected(self, func: NativeFunction) -> bool:
+        op_name = op_name_from_native_function(func)
+        return self.is_operator_selected(op_name)
+
+    def is_operator_selected_for_training(self, name: str) -> bool:
+        if not self.is_operator_selected(name):
+            return False
+        if self.include_all_operators:
+            return True
+
+        not_training_op = SelectiveBuildOperator(
+            name="",
+            is_root_operator=False,
+            is_used_for_training=False,
+            include_all_overloads=False,
+            _debug_info=None,
+        )
+        op = not_training_op
+        if name in self.operators:
+            op = self.operators[name]
+
+        name = strip_operator_overload_name(name)
+        base_op = not_training_op
+        if name in self.operators:
+            base_op = self.operators[name]
+
+        return op.is_used_for_training or (
+            base_op.include_all_overloads and base_op.is_used_for_training
+        )
+
+    def is_native_function_selected_for_training(self, func: NativeFunction) -> bool:
+        op_name = op_name_from_native_function(func)
+        return self.is_operator_selected_for_training(op_name)
+
+    def is_root_operator(self, name: str) -> bool:
+        if not self.is_operator_selected(name):
+            return False
+        if self.include_all_operators:
+            return True
+
+        if name in self.operators:
+            op: SelectiveBuildOperator = self.operators[name]
+            return op.is_root_operator
+        name = strip_operator_overload_name(name)
+        if name not in self.operators:
+            return False
+        base_op: SelectiveBuildOperator = self.operators[name]
+        return base_op.include_all_overloads and base_op.is_root_operator
+
+    def is_kernel_dtype_selected(self, kernel_tag: str, dtype: str) -> bool:
+        if self.include_all_operators or self.include_all_non_op_selectives:
+            return True
+
+        return (
+            kernel_tag in self.kernel_metadata
+            and dtype in self.kernel_metadata[kernel_tag]
+        )
+
+    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
+        """
+        Return a list of kernel keys that cover the used ops
+        """
+        # If no kernel metadata, either it's implied by include_all_operators=True or the op is not used.
+        if op_name not in self.et_kernel_metadata:
+            return kernel_key if self.include_all_operators else []
+        # Otherwise, only return the specific kernel keys.
+
+        result_set = set()
+
+        for model_kernel_keys in self.et_kernel_metadata[op_name]:
+            key_found = False
+            for key in kernel_key:
+                # Don't compare the version for now
+                if (
+                    key != "default"
+                    and key.split("/")[1] == model_kernel_keys.split("/")[1]
+                ):
+                    result_set.add(key)
+                    key_found = True
+                    break
+            if not key_found:
+                if "default" not in kernel_key:
+                    raise Exception("Missing kernel for the model")  # noqa: TRY002
+                else:
+                    result_set.add("default")
+
+        return list(result_set)
+
+    def to_dict(self) -> dict[str, object]:
+        ret: dict[str, object] = {
+            "include_all_non_op_selectives": self.include_all_non_op_selectives,
+            "include_all_operators": self.include_all_operators,
+        }
+        operators = {}
+        for op_name, op in self.operators.items():
+            operators[op_name] = op.to_dict()
+        ret["operators"] = operators
+
+        if self._debug_info is not None:
+            ret["debug_info"] = sorted(self._debug_info)
+
+        ret["kernel_metadata"] = {
+            k: sorted(v) for (k, v) in self.kernel_metadata.items()
+        }
+
+        ret["et_kernel_metadata"] = self.et_kernel_metadata
+
+        ret["custom_classes"] = sorted(self.custom_classes)
+
+        ret["build_features"] = sorted(self.build_features)
+
+        return ret
+
+
+def merge_kernel_metadata(
+    lhs: dict[str, list[str]],
+    rhs: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    kernel_metadata: dict[str, list[str]] = {}
+    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
+        dtypes_copy = set(dtypes)
+        if tag_name in kernel_metadata:
+            dtypes_copy |= set(kernel_metadata[tag_name])
+
+        kernel_metadata[tag_name] = list(dtypes_copy)
+
+    return kernel_metadata
+
+
+def merge_et_kernel_metadata(
+    lhs: dict[str, list[str]],
+    rhs: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
+    for op in list(lhs.keys()) + list(rhs.keys()):
+        merge_et_kernel_metadata[op].update(lhs.get(op, []))
+        merge_et_kernel_metadata[op].update(rhs.get(op, []))
+
+    return {op: sorted(val) for op, val in merge_et_kernel_metadata.items()}
+
+
+def combine_selective_builders(
+    lhs: SelectiveBuilder, rhs: SelectiveBuilder
+) -> SelectiveBuilder:
+    include_all_operators = lhs.include_all_operators or rhs.include_all_operators
+    debug_info = merge_debug_info(lhs._debug_info, rhs._debug_info)
+    operators = merge_operator_dicts(lhs.operators, rhs.operators)
+    kernel_metadata = merge_kernel_metadata(lhs.kernel_metadata, rhs.kernel_metadata)
+    et_kernel_metadata = merge_et_kernel_metadata(
+        lhs.et_kernel_metadata, rhs.et_kernel_metadata
+    )
+    include_all_non_op_selectives = (
+        lhs.include_all_non_op_selectives or rhs.include_all_non_op_selectives
+    )
+    custom_classes = lhs.custom_classes.union(rhs.custom_classes)
+    build_features = lhs.build_features.union(rhs.build_features)
+    return SelectiveBuilder(
+        include_all_operators,
+        debug_info,
+        operators,
+        kernel_metadata,
+        et_kernel_metadata,
+        custom_classes,
+        build_features,
+        include_all_non_op_selectives,
+    )
+
+
+def op_name_from_native_function(f: NativeFunction) -> str:
+    # This was originally read from the 'operator_name_with_overload' field in the
+    # declaration dict, which was the part before the first '(' in 'schema_string'.
+    return f"{f.namespace}::{f.func.name}"
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/__init__.py b/phivenv/Lib/site-packages/torchgen/static_runtime/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7207c49574f797ff9c744cf536a345eb1149bd84
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/config.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9bd504e6b6b6cd1ab49f67469a5f1a2bc200c10
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/config.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..68ee6976c209bc24cba6a1b54d3509bda422d9b8
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/gen_static_runtime_ops.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-39.pyc b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..534757b17dae982b0b28e7c659b5c66eec7e6b45
Binary files /dev/null and b/phivenv/Lib/site-packages/torchgen/static_runtime/__pycache__/generator.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/config.py b/phivenv/Lib/site-packages/torchgen/static_runtime/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..887949880af482391ba7186a455968f500cb7868
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/static_runtime/config.py
@@ -0,0 +1,388 @@
+from __future__ import annotations
+
+from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
+
+
+def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
+    if isinstance(g, NativeFunctionsGroup):
+        return str(g.functional.func.name.name.base)
+    else:
+        return str(g.view.root_name)
+
+
+is_hand_written_ops_ = frozenset(
+    (
+        "abs",
+        "add",
+        "addmm",
+        "all",
+        "any",
+        "argmin",
+        "bmm",
+        "clamp",
+        "clamp_min",
+        "cumsum",
+        "div",
+        "fmod",
+        "index_select",
+        "leaky_relu",
+        "linear",
+        "log",
+        "matmul",
+        "mul",
+        "narrow_copy",
+        "nonzero",
+        "pow",
+        "remainder",
+        "sigmoid",
+        "sign",
+        "sub",
+        "tanh",
+        "detach",
+        "expand_as",
+        "flatten",
+        "narrow",
+        "reshape_as",
+        "select",
+        "slice",
+        "softmax",
+        "split",
+        "squeeze",
+        "transpose",
+        "view",
+        "where",
+    )
+)
+
+
+def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
+    name_base = func_name_base_str(g)
+    return name_base in is_hand_written_ops_
+
+
+def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
+    assert index == 0 or index == 1
+    if op_name == "addr":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+            arg_map["vec1"] = "at::rand({6})"
+            arg_map["vec2"] = "at::rand({6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+            arg_map["vec1"] = "at::rand({22})"
+            arg_map["vec2"] = "at::rand({22})"
+        return
+    if op_name == "mv":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+            arg_map["vec"] = "at::rand({6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+            arg_map["vec"] = "at::rand({22})"
+        return
+    if op_name == "addbmm":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+        return
+    if op_name == "cross":
+        if index == 0:
+            arg_map["self"] = "at::rand({3, 3, 3})"
+            arg_map["other"] = "at::rand({3, 3, 3})"
+        else:
+            arg_map["self"] = "at::rand({22, 3, 22})"
+            arg_map["other"] = "at::rand({22, 3, 22})"
+        return
+    if op_name == "take":
+        if index == 0:
+            arg_map["index"] = "at::randint(0, 216, {20}, torch::kInt64)"
+        else:
+            arg_map["index"] = "at::randint(0, 1000, {100}, torch::kInt64)"
+        return
+    if op_name == "take_along_dim":
+        if index == 0:
+            arg_map["indices"] = "at::argsort(self0, 1, true)"
+        else:
+            arg_map["indices"] = "at::argsort(self1, 1, true)"
+        return
+    if op_name == "masked_select":
+        if index == 0:
+            arg_map["mask"] = "at::randn({6, 6, 6}) > 0.5"
+        else:
+            arg_map["mask"] = "at::rand({22, 22, 22}) > 0.5"
+        return
+    if op_name == "orgqr":
+        if index == 0:
+            arg_map["input2"] = "at::rand({6, 6})"
+        else:
+            arg_map["input2"] = "at::rand({22, 22})"
+        return
+    if op_name == "ormqr":
+        if index == 0:
+            arg_map["input2"] = "at::rand({6, 6})"
+        else:
+            arg_map["input2"] = "at::rand({22, 22})"
+        return
+    if op_name == "quantile":
+        if index == 0:
+            arg_map["q"] = "at::rand({6})"
+            arg_map["interpolation"] = '"linear"'
+        else:
+            arg_map["q"] = "at::rand({22})"
+            arg_map["interpolation"] = '"linear"'
+        return
+    if op_name == "nanquantile":
+        if index == 0:
+            arg_map["q"] = "at::rand({6})"
+            arg_map["interpolation"] = '"linear"'
+        else:
+            arg_map["q"] = "at::rand({22})"
+            arg_map["interpolation"] = '"linear"'
+        return
+    if op_name == "multi_margin_loss":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+            arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+            arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({22})"
+        return
+    if op_name == "multilabel_margin_loss":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+            arg_map["target"] = "at::randint(6, {6, 6}, torch::kInt64)"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+            arg_map["target"] = "at::randint(22, {22, 22}, torch::kInt64)"
+        return
+    if op_name == "nll_loss":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6})"
+            arg_map["target"] = "at::randint(6, {6}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22})"
+            arg_map["target"] = "at::randint(22, {22}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({22})"
+        return
+    if op_name == "nll_loss2d":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6, 6, 6})"
+            arg_map["target"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({6})"
+        else:
+            arg_map["self"] = "at::rand({22, 22, 22, 22})"
+            arg_map["target"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({22})"
+        return
+    if op_name in (
+        "fft_fft",
+        "fft_ifft",
+        "fft_rfft",
+        "fft_irfft",
+        "fft_hfft",
+        "fft_ihfft",
+    ):
+        arg_map["norm"] = '"forward"'
+        return
+    if op_name == "linalg_tensorinv":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6, 6, 6})"
+            arg_map["ind"] = "2"
+        else:
+            arg_map["self"] = "at::rand({22, 22, 22, 22})"
+            arg_map["ind"] = "2"
+        return
+    if op_name == "addmv":
+        if index == 0:
+            arg_map["self"] = "at::rand({2})"
+            arg_map["mat"] = "at::rand({2, 2})"
+            arg_map["vec"] = "at::rand({2})"
+        else:
+            arg_map["self"] = "at::rand({35})"
+            arg_map["mat"] = "at::rand({35, 35})"
+            arg_map["vec"] = "at::rand({35})"
+        return
+    if op_name == "acosh":
+        if index == 0:
+            arg_map["self"] = "at::rand({2, 2, 2}) + at::ones({2, 2, 2})"
+        else:
+            arg_map["self"] = "at::rand({5, 5, 5}) + at::ones({5, 5, 5})"
+        return
+    if op_name == "adaptive_max_pool2d_backward":
+        if index == 0:
+            arg_map["grad_output"] = "at::rand({2, 2, 2}, at::kFloat)"
+            arg_map["self"] = "at::rand({2, 2, 2}, at::kFloat)"
+            arg_map["indices"] = "at::randint(0, 1, {2, 2, 2}, at::kLong)"
+        else:
+            arg_map["grad_output"] = "at::rand({3, 3, 3}, at::kFloat)"
+            arg_map["self"] = "at::rand({3, 3, 3}, at::kFloat)"
+            arg_map["indices"] = "at::randint(0, 1, {3, 3, 3}, at::kLong)"
+        return
+    if op_name == "adaptive_max_pool3d_backward":
+        if index == 0:
+            arg_map["grad_output"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
+            arg_map["self"] = "at::rand({2, 2, 2, 2}, at::kFloat)"
+            arg_map["indices"] = "at::randint(0, 1, {2, 2, 2, 2}, at::kLong)"
+        else:
+            arg_map["grad_output"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
+            arg_map["self"] = "at::rand({3, 3, 3, 3}, at::kFloat)"
+            arg_map["indices"] = "at::randint(0, 1, {3, 3, 3, 3}, at::kLong)"
+        return
+    if op_name == "bitwise_left_shift":
+        if index == 0:
+            arg_map["self"] = "at::randint(1, 1 << 4, {6, 6, 6}, at::kInt)"
+            arg_map["other"] = "at::randint(1, 26, {6, 6, 6}, at::kInt)"
+        else:
+            arg_map["self"] = "at::randint(1, 1 << 4, {22, 22, 22}, at::kInt)"
+            arg_map["other"] = "at::randint(1, 26, {22, 22, 22}, at::kInt)"
+        return
+    if op_name == "bitwise_right_shift":
+        if index == 0:
+            arg_map["self"] = "at::randint(1 << 21, 1 << 30, {6, 6, 6}, at::kInt)"
+            arg_map["other"] = "at::randint(1, 22, {6, 6, 6}, at::kInt)"
+        else:
+            arg_map["self"] = "at::randint(1 << 21, 1 << 30, {22, 22, 22}, at::kInt)"
+            arg_map["other"] = "at::randint(1, 22, {22, 22, 22}, at::kInt)"
+        return
+    if op_name == "gather":
+        if index == 0:
+            arg_map["self"] = "at::randint(1, 100, {2,2,2}, at::kInt)"
+            arg_map["dim"] = "1"
+            arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
+            arg_map["sparse_grad"] = "false"
+        else:
+            arg_map["self"] = "at::randint(1, 100, {5,5,5}, at::kInt)"
+            arg_map["dim"] = "1"
+            arg_map["index"] = "at::randint(0, 4, {5,5,5}, torch::kInt64)"
+            arg_map["sparse_grad"] = "false"
+        return
+    if op_name == "gelu":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 6, 6})"
+            arg_map["approximate"] = '"tanh"'
+        else:
+            arg_map["self"] = "at::rand({22, 22, 22})"
+            arg_map["approximate"] = '"tanh"'
+        return
+    if op_name == "gelu_backward":
+        if index == 0:
+            arg_map["grad_output"] = "at::rand({6, 6, 6})"
+            arg_map["self"] = "at::rand({6, 6, 6})"
+            arg_map["approximate"] = '"tanh"'
+        else:
+            arg_map["grad_output"] = "at::rand({22, 22, 22})"
+            arg_map["self"] = "at::rand({22, 22, 22})"
+            arg_map["approximate"] = '"tanh"'
+        return
+    if op_name == "index_add":
+        if index == 0:
+            arg_map["self"] = "at::rand({2})"
+            arg_map["dim"] = "0"
+            arg_map["index"] = "at::randint(0, 1, {2}, at::kInt)"
+            arg_map["source"] = "at::rand({2})"
+            arg_map["alpha"] = "2"
+        else:
+            arg_map["self"] = "at::rand({16})"
+            arg_map["dim"] = "0"
+            arg_map["index"] = "at::randint(0, 10, {16}, at::kInt)"
+            arg_map["source"] = "at::rand({16})"
+            arg_map["alpha"] = "2"
+        return
+    if op_name == "index_copy":
+        if index == 0:
+            arg_map["self"] = "at::rand({2})"
+            arg_map["dim"] = "0"
+            arg_map["index"] = "at::randint(0, 1, {2}, at::kLong)"
+            arg_map["source"] = "at::rand({2})"
+        else:
+            arg_map["self"] = "at::rand({32})"
+            arg_map["dim"] = "0"
+            arg_map["index"] = "at::randint(0, 10, {32}, at::kLong)"
+            arg_map["source"] = "at::rand({32})"
+        return
+    if op_name == "linalg_cross":
+        if index == 0:
+            arg_map["self"] = "at::rand({6, 3, 6})"
+            arg_map["other"] = "at::rand({6, 3, 6})"
+            arg_map["dim"] = "1"
+        else:
+            arg_map["self"] = "at::rand({22, 3, 22})"
+            arg_map["other"] = "at::rand({22, 3, 22})"
+            arg_map["dim"] = "1"
+        return
+    if op_name == "nll_loss_backward":
+        if index == 0:
+            arg_map["grad_output"] = "at::rand({})"
+            arg_map["self"] = "at::rand({6})"
+            arg_map["target"] = "at::randint(0, 5, {6}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({6})"
+            arg_map["reduction"] = "1"
+            arg_map["ignore_index"] = "1"
+            arg_map["total_weight"] = "at::rand({})"
+        else:
+            arg_map["grad_output"] = "at::rand({})"
+            arg_map["self"] = "at::rand({36})"
+            arg_map["target"] = "at::randint(0, 11, {36}, torch::kInt64)"
+            arg_map["weight"] = "at::rand({36})"
+            arg_map["reduction"] = "1"
+            arg_map["ignore_index"] = "1"
+            arg_map["total_weight"] = "at::rand({})"
+        return
+    if op_name in ["scatter", "scatter_add", "_scatter_reduce"]:
+        if index == 0:
+            arg_map["self"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
+            arg_map["index"] = "at::randint(0, 1, {2,2,2}, torch::kInt64)"
+            arg_map["src"] = "at::randint(1, 100, {2,2,2}, torch::kInt64)"
+        else:
+            arg_map["self"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
+            arg_map["index"] = "at::randint(0, 1, {5,5,5}, torch::kInt64)"
+            arg_map["src"] = "at::randint(1, 100, {5,5,5}, torch::kInt64)"
+        if "reduce" in arg_map:
+            arg_map["reduce"] = '"sum"' if op_name == "_scatter_reduce" else '"add"'
+        return
+    if op_name == "scatter_reduce":
+        arg_map["reduce"] = '"mean"'
+        if index == 0:
+            arg_map["index"] = "at::randint(6, {6, 6, 6}, torch::kInt64)"
+        else:
+            arg_map["index"] = "at::randint(22, {22, 22, 22}, torch::kInt64)"
+        return
+    if op_name == "special_zeta":
+        if index == 0:
+            arg_map["self"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
+            arg_map["other"] = "at::rand({2,2,2}, at::kDouble) + at::ones({2,2,2})"
+        else:
+            arg_map["self"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
+            arg_map["other"] = "at::rand({5,5,5}, at::kDouble) + at::ones({5,5,5})"
+        return
+    if op_name == "_convert_indices_from_csr_to_coo":
+        if index == 0:
+            arg_map["crow_indices"] = "torch::tensor({1}, torch::kInt32)"
+            arg_map["col_indices"] = "torch::tensor({0, 1, 0}, torch::kInt32)"
+            arg_map["out_int32"] = "false"
+        else:
+            arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)"
+            arg_map["col_indices"] = (
+                "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)"
+            )
+            arg_map["out_int32"] = "false"
+        return
+    if op_name == "_convert_indices_from_coo_to_csr":
+        if index == 0:
+            arg_map["self"] = "at::randint(0, 3, {2}, at::kInt)"
+            arg_map["size"] = "10"
+            arg_map["out_int32"] = "false"
+        else:
+            arg_map["self"] = "at::randint(0, 3, {12}, at::kInt)"
+            arg_map["size"] = "24"
+            arg_map["out_int32"] = "false"
+        return
+    if op_name in ("diagonal", "linalg_diagonal"):
+        arg_map["offset"] = "0"
+        arg_map["dim1"] = "2"
+        arg_map["dim2"] = "1"
+        return
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py b/phivenv/Lib/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4efa2ba74057549ba6a30a6eb95d880cdbb9018
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/static_runtime/gen_static_runtime_ops.py
@@ -0,0 +1,231 @@
+from __future__ import annotations
+
+import argparse
+import itertools
+import os
+from typing import TYPE_CHECKING, TypeVar, Union
+
+from libfb.py.log import set_simple_logging  # type: ignore[import]
+
+from torchgen import gen
+from torchgen.context import native_function_manager
+from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
+from torchgen.static_runtime import config, generator
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+# Given a list of `grouped_native_functions` sorted by their op names, return a list of
+# lists each of which groups ops that share the base name. For example, `mean` and
+# `mean.dim` are grouped together by this function.
+
+NativeGroupT = TypeVar(
+    "NativeGroupT",
+    bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
+)
+
+
+def group_functions_by_op_name(
+    grouped_native_functions: Sequence[NativeGroupT],
+) -> Sequence[Sequence[NativeGroupT]]:
+    if not grouped_native_functions:
+        return []
+    groups = []
+
+    def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
+        with native_function_manager(g):
+            return generator.is_supported(g)
+
+    eligible_ops = (g for g in grouped_native_functions if is_supported(g))
+    groups = [
+        list(group)
+        for k, group in (
+            itertools.groupby(
+                eligible_ops,
+                key=config.func_name_base_str,
+            )
+        )
+    ]
+
+    return groups
+
+
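+# Illustrative sketch, not part of the original module: the grouping above
+# behaves like itertools.groupby keyed on the base op name, so overloads such
+# as "mean" and "mean.dim" land in the same group. The strings below are a
+# simplified stand-in for real NativeFunctionsGroup objects.
+def _example_grouping_by_base_name() -> None:
+    ops = ["mean", "mean.dim", "mm", "mm.out"]
+    grouped = [
+        list(group)
+        for _, group in itertools.groupby(ops, key=lambda name: name.split(".")[0])
+    ]
+    assert grouped == [["mean", "mean.dim"], ["mm", "mm.out"]]
+
+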
+def clang_format(cpp_file_path: str) -> None:
+    import subprocess
+
+    subprocess.check_call(["clang-format", "-i", cpp_file_path])
+
+
+def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
+    code = "\n".join(cpp_ops)
+    generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
+// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
+#include 
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace torch {{
+namespace jit {{
+
+{code}
+
+}} // namespace jit
+}} // namespace torch
+"""
+    with open(file_path, "w") as f:
+        f.write(generated)
+    clang_format(file_path)
+
+
+def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
+    code = "\n".join(cpp_ops)
+    generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
+// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
+#include 
+#include 
+#include 
+
+#include "test_utils.h"
+
+using namespace caffe2;
+using namespace torch;
+using namespace torch::jit;
+using namespace torch::jit::test;
+using c10::IValue;
+
+{code}
+
+"""
+    with open(file_path, "w") as f:
+        f.write(generated)
+    clang_format(file_path)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate ATen source files")
+    parser.add_argument(
+        "-s",
+        "--source-path",
+        help="path to source directory for ATen",
+        default="caffe2/aten/src/ATen",
+    )
+    parser.add_argument(
+        "-p",
+        "--generated-ops-cpp-path",
+        help="path to directory to generate op dispatcher .cpp file",
+        default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
+    )
+    parser.add_argument(
+        "-t",
+        "--generated-ops-test-cpp-path",
+        help="path to directory to generate op dispatcher .cpp file",
+        default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
+    )
+    options = parser.parse_args()
+    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
+    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
+    parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path)
+    native_functions, backend_indices = (
+        parsed_yaml.native_functions,
+        parsed_yaml.backend_indices,
+    )
+
+    op_generator = generator.GenOpDispatcher()
+    test_case_generator = generator.GenOpTestCase()
+
+    native_functions_groups = [
+        g
+        for g in gen.get_grouped_native_functions(native_functions)
+        if isinstance(g, NativeFunctionsGroup)
+    ]
+
+    supported_functions_groups = group_functions_by_op_name(native_functions_groups)
+
+    out_variant_op_result = [
+        op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
+        for groups in supported_functions_groups
+    ]
+    out_variant_test_result = [
+        test_case_generator.out_variant(groups) for groups in supported_functions_groups
+    ]
+
+    native_functions_view_groups = [
+        g
+        for g in gen.get_grouped_by_view_native_functions(native_functions)
+        if isinstance(g, NativeFunctionsViewGroup)
+    ]
+
+    supported_functions_view_groups = group_functions_by_op_name(
+        native_functions_view_groups
+    )
+
+    view_op_result = [
+        op_generator.view(groups, backend_indices[DispatchKey.CPU])
+        for groups in supported_functions_view_groups
+    ]
+    view_test_result = [
+        test_case_generator.view(groups) for groups in supported_functions_view_groups
+    ]
+
+    op_result = out_variant_op_result + ["\n\n"] + view_op_result
+    test_result = out_variant_test_result + ["\n\n"] + view_test_result
+
+    write_cpp(op_result, options.generated_ops_cpp_path)
+    write_test_cpp(test_result, options.generated_ops_test_cpp_path)
+
+    print(
+        f"\ntotal grouped native ops: {len(gen.get_grouped_native_functions(native_functions)):d}"
+    )
+
+    print(f"grouped native ops with out variant: {len(native_functions_groups):d}")
+    supported_functions_num = sum(len(groups) for groups in supported_functions_groups)
+    print(f"generated functions groups with out variant: {supported_functions_num:d}")
+
+    print(f"\nview grouped native ops: {len(native_functions_view_groups):d}")
+    supported_view_functions_num = sum(
+        len(groups) for groups in supported_functions_view_groups
+    )
+    print(f"generated functions view groups: {supported_view_functions_num:d}")
+
+    print(
+        f"\noverall generated : {supported_functions_num + supported_view_functions_num:d}"
+    )
+
+
+if __name__ == "__main__":
+    set_simple_logging(escape_newlines=False)
+    main()
diff --git a/phivenv/Lib/site-packages/torchgen/static_runtime/generator.py b/phivenv/Lib/site-packages/torchgen/static_runtime/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..0673fd9d3bc0b663dcde2cc5fc9e22d322933f0c
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/static_runtime/generator.py
@@ -0,0 +1,814 @@
+from __future__ import annotations
+
+import json
+import logging
+import math
+from typing import TYPE_CHECKING
+
+import torchgen.api.cpp as cpp
+from torchgen.context import native_function_manager
+from torchgen.model import (
+    Argument,
+    BackendIndex,
+    BaseTy,
+    BaseType,
+    FunctionSchema,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+    OptionalType,
+    SelfArgument,
+    TensorOptionsArguments,
+    Type,
+)
+from torchgen.static_runtime import config
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
+logger: logging.Logger = logging.getLogger()
+
+
+def has_alias(
+    arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
+) -> bool:
+    for arg in arguments:
+        annotation = getattr(arg, "annotation", None)
+        if not annotation:
+            continue
+        alias_set = getattr(annotation, "alias_set", ())
+        if alias_set:
+            return True
+    return False
+
+
+BLOCKED_OPS = frozenset(
+    (
+        # non cpu ops
+        "sparse_sampled_addmm",
+        "hspmm",
+        "linalg_svdvals",
+        # sparse ops
+        "sspaddmm",
+        "coalesce",
+        "_indices",
+        "indices",
+        "_values",
+        "values",
+        "crow_indices",
+        "col_indices",
+        # deprecated ops
+        "floor_divide",
+        "ger",
+        # buggy ops
+        "conj_physical",  # P495807361
+        "binary_cross_entropy",  # P496394764
+        "arccosh",
+        # uncommon ops
+        "cholesky",
+        "lu_solve",
+        "linalg_cholesky",
+        "linalg_householder_product",
+        "linalg_ldl_solve",
+        "_compute_linear_combination",
+        # training related ops
+        "_make_dual",
+        # cannot call directly
+        "_fw_primal",
+        # no documentation
+        "_index_reduce",
+        # TODO: these ones got added recently and need manual inspection
+        "_new_zeros_with_same_feature_meta",
+        "_conj_physical",
+        "binary_cross_entropy_with_logits",
+        "bincount",
+        "conv_tbc",
+        "copy",
+        "_copy_from",
+        "_copy_from_and_resize",
+        "count_nonzero",
+        "cudnn_affine_grid_generator",
+        "cudnn_affine_grid_generator_backward",
+        "cudnn_grid_sampler",
+        "diag_embed",
+        "embedding",
+        "embedding_dense_backward",
+        "_embedding_bag_dense_backward",
+        "_embedding_bag_per_sample_weights_backward",
+        "grid_sampler_2d",
+        "_grid_sampler_2d_cpu_fallback",
+        "grid_sampler_3d",
+        "isnan",
+        "mkldnn_linear",
+        "median",
+        "nanmedian",
+        "_sparse_sparse_matmul",
+        "batch_norm_backward_elemt",
+        "_euclidean_dist",
+        "pixel_shuffle",
+        "pixel_unshuffle",
+        "channel_shuffle",
+        "_reshape_nested_backward",
+        "relu",
+        "prelu",
+        "celu",
+        "slice_scatter",
+        "select_scatter",
+        "diagonal_scatter",
+        "sum",
+        "_mkldnn_transpose",
+        "_nested_tensor_from_mask",
+        "_nested_from_padded",
+        "_nested_tensor_size",
+        "_nested_from_padded_and_nested_example",
+        "_standard_gamma_grad",
+        "_dirichlet_grad",
+        "native_norm",
+        "_sparse_softmax",
+        "_sparse_softmax_backward_data",
+        "_sparse_log_softmax",
+        "_sparse_log_softmax_backward_data",
+        "zero",
+        "_sparse_addmm",
+        "sparse_mask",
+        "_sparse_mask_projection",
+        "_to_dense",
+        "_coalesce",
+        "_coalesced",
+        "copy_sparse_to_sparse",
+        "to_sparse",
+        "to_sparse_csr",
+        "to_sparse_csc",
+        "to_mkldnn",
+        "quantize_per_tensor_dynamic",
+        "quantize_per_channel",
+        "q_per_channel_scales",
+        "q_per_channel_zero_points",
+        "int_repr",
+        "_make_per_channel_quantized_tensor",
+        "set",
+        "lift",
+        "lift_fresh",
+        "lift_fresh_copy",
+        "masked_scatter",
+        "_masked_softmax",
+        "_masked_softmax_backward",
+        "put",
+        "index_reduce",
+        "trace",
+        "_cholesky_solve_helper",
+        "dist",
+        "max",
+        "_torch_cuda_cu_linker_symbol_op",
+        "glu_jvp",
+        "glu_backward_jvp",
+        "hardswish_backward",
+        "rrelu_with_noise_backward",
+        "mkldnn_adaptive_avg_pool2d_backward",
+        "_adaptive_avg_pool2d_backward",
+        "_adaptive_avg_pool3d_backward",
+        "isinf",
+        "linalg_lu_solve",
+        "linalg_vecdot",
+        "linalg_matrix_exp",
+        "linalg_eigvalsh",
+        "_test_warn_in_autograd",
+        "_test_autograd_multiple_dispatch_view",
+        "_test_autograd_multiple_dispatch_view_copy",
+        "_segment_reduce",
+        "_segment_reduce_backward",
+        "_fw_primal_copy",
+        "_make_dual_copy",
+        "view_as_real_copy",
+        "view_as_complex_copy",
+        "_conj_copy",
+        "_neg_view_copy",
+        "diagonal_copy",
+        "detach_copy",
+        "squeeze_copy",
+        "t_copy",
+        "unsqueeze_copy",
+        "_indices_copy",
+        "_values_copy",
+        "indices_copy",
+        "values_copy",
+        "crow_indices_copy",
+        "col_indices_copy",
+        "ccol_indices",
+        "ccol_indices_copy",
+        "row_indices",
+        "row_indices_copy",
+        "unfold_copy",
+        "alias_copy",
+        "_triton_multi_head_attention",
+        "special_airy_ai",
+        "special_bessel_j0",
+        "special_bessel_j1",
+        "special_bessel_y0",
+        "special_bessel_y1",
+        "special_chebyshev_polynomial_t",
+        "special_chebyshev_polynomial_u",
+        "special_chebyshev_polynomial_v",
+        "special_chebyshev_polynomial_w",
+        "special_hermite_polynomial_h",
+        "special_hermite_polynomial_he",
+        "special_laguerre_polynomial_l",
+        "special_legendre_polynomial_p",
+        "special_modified_bessel_i0",
+        "special_modified_bessel_i1",
+        "special_modified_bessel_k0",
+        "special_modified_bessel_k1",
+        "special_scaled_modified_bessel_k0",
+        "special_scaled_modified_bessel_k1",
+        "special_shifted_chebyshev_polynomial_t",
+        "special_shifted_chebyshev_polynomial_u",
+        "special_shifted_chebyshev_polynomial_v",
+        "special_shifted_chebyshev_polynomial_w",
+        "special_spherical_bessel_j0",
+        "_foobar",
+        "_nested_tensor_strides",
+        "_nested_tensor_storage_offsets",
+        "_nested_get_values",  # no CPU backend
+        "_nested_get_values_copy",  # no CPU backend
+        "_nested_view_from_jagged",  # testing needs to be patched
+        "_nested_view_from_jagged_copy",  # testing needs to be patched
+        "_nested_view_from_buffer",  # testing needs to be patched
+        "_nested_view_from_buffer_copy",  # testing needs to be patched
+        "_int_mm",  # testing needs to be patched
+        "_to_sparse_csc",  # testing needs to be patched
+        "_to_sparse_csr",  # testing needs to be patched
+        "segment_reduce",  # testing needs to be patched
+    )
+)
+
+
+def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
+    base_op_name = ""
+    func = None
+    if isinstance(g, NativeFunctionsViewGroup):
+        base_op_name = g.view.root_name
+        func = g.view.func
+    else:
+        base_op_name = g.out.func.name.name.base
+        func = g.out.func
+    if config.is_hand_written(g):
+        logger.info("HAND WRITTEN: %s", base_op_name)
+        return False
+    if base_op_name in BLOCKED_OPS:
+        logger.info("BLOCKED: %s", base_op_name)
+        return False
+    for arg in func.schema_order_arguments():
+        maybe_method = ivalue_type_conversion_method(arg.type)
+        if not maybe_method:
+            # Type conversion is not supported yet.
+            logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
+            return False
+
+    if isinstance(g, NativeFunctionsViewGroup):
+        # TODO: stop doing type tests by converting to C++ and then testing
+        # the string, just test the dang thing directly
+        if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
+            # Returns a non-Tensor value.
+            logger.info("NON-TENSOR RET TYPE: %s", str(func))
+            return False
+        return True
+
+    # For out variant ops, we need to check the arguments of its functional func.
+    for arg in g.functional.func.schema_order_arguments():
+        maybe_method = ivalue_type_conversion_method(arg.type)
+        if not maybe_method:
+            # Type conversion is not supported yet.
+            logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func)
+            return False
+
+    if not g.structured:
+        # In case of unstructured op, we check if it has out variant implementation.
+        # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
+        # parameter.
+        if (
+            not hasattr(g, "out")
+            or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
+            or not str(func.name).endswith(".out")
+        ):
+            return False
+    # TODO: stop type testing by converting to C++
+    if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
+        logger.info("NON_TENSOR RET TYPE: %s", func)
+        return False
+    if has_alias(func.arguments.non_out):
+        # This op may create an alias of inputs.
+        logger.info("INPUTS ALIAS: %s", base_op_name)
+        return False
+    return True
+
+
+def ivalue_type_conversion_method(
+    arg_type: BaseType | OptionalType | Type,
+) -> tuple[bool, str] | None:
+    """
+    Return a (is_reference, method_call) pair describing how to convert a
+    `c10::IValue` into the value expected for `arg_type`. For example, for
+    `arg_type` == BaseTy.Tensor this returns (True, "toTensor()"); the method
+    call is appended to the IValue's variable name to get the value of the
+    expected type.
+    """
+    type_conversion_methods = {
+        BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
+        BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
+        BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
+        BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
+        BaseTy.ScalarType: (
+            (False, "toScalarType()"),
+            (False, "toOptional<at::ScalarType>()"),
+        ),
+        BaseTy.str: (
+            (False, "toStringView()"),
+            (False, "toOptional()"),
+            (False, "toOptional<::std::string_view>()"),
+        ),
+    }
+
+    base_ty_object = None
+    if isinstance(arg_type, BaseType):
+        base_ty_object = arg_type.name
+    elif isinstance(arg_type, OptionalType):
+        if not isinstance(arg_type.elem, BaseType):
+            # ListType is currently unsupported.
+            return None
+        base_ty_object = arg_type.elem.name
+    else:
+        return None
+
+    if base_ty_object not in type_conversion_methods:
+        return None
+    methods = type_conversion_methods[base_ty_object]
+    if isinstance(arg_type, BaseType):
+        return methods[0]
+    return methods[1]
+
+
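+# Illustrative sketch, not part of the original module, of how the helper above
+# is used:
+def _example_ivalue_conversion_method() -> None:
+    # A plain Tensor argument is bound by const reference via .toTensor().
+    assert ivalue_type_conversion_method(BaseType(BaseTy.Tensor)) == (
+        True,
+        "toTensor()",
+    )
+    # An optional argument (e.g. `int?`) falls through to the optional accessor.
+    maybe_method = ivalue_type_conversion_method(OptionalType(BaseType(BaseTy.int)))
+    assert maybe_method is not None and maybe_method[0] is False
+
+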
+should_use_int_tensor_ops_ = frozenset(
+    (
+        "bitwise_not",
+        "bitwise_and",
+        "bitwise_or",
+        "bitwise_xor",
+        "bitwise_left_shift",
+        "bitwise_right_shift",
+        "gcd",
+        "lcm",
+        "scatter",
+        "gather",
+        "_convert_indices_from_coo_to_csr",
+        "_convert_indices_from_csr_to_coo",
+    )
+)
+should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
+
+
+def should_use_int_tensor(op_name: str) -> bool:
+    return op_name in should_use_int_tensor_ops_
+
+
+def should_use_complex_tensor(op_name: str) -> bool:
+    return op_name in should_use_complex_tensor_ops_
+
+
+test_tensor_dim_ops_1_ = frozenset(
+    (
+        "addmv",
+        "index_add",
+        "_convert_indices_from_coo_to_csr",
+        "_convert_indices_from_csr_to_coo",
+        "nll_loss_backward",
+        "dot",
+        "vdot",
+        "outer",
+        "ger",
+    )
+)
+test_tensor_dim_ops_2_ = frozenset(
+    ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
+)
+
+
+def test_tensor_dim(op_name: str) -> int:
+    if op_name in test_tensor_dim_ops_1_:
+        return 1
+    if op_name in test_tensor_dim_ops_2_:
+        return 2
+    return 3
+
+
+test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
+test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
+
+
+def test_tensor_shape(op_name: str) -> str:
+    if op_name in test_tensor_shape_json:
+        return test_tensor_shape_json[op_name]
+    else:
+        return ""
+
+
+def test_value_expression(
+    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
+) -> str:
+    tensor_size_ex = test_tensor_shape(op_name)
+    if tensor_size_ex == "":
+        num_tensors = 16 if index == 0 else 64
+        num_dim = test_tensor_dim(op_name)
+        size_per_dim = math.ceil(num_tensors / float(num_dim))
+        size_per_dim += size_per_dim % 2
+        tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
+    if should_use_int_tensor(op_name):
+        tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
+    elif should_use_complex_tensor(op_name):
+        tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
+    else:
+        tensor_expression = f"at::rand({tensor_size_ex})"
+
+    value_expressions = {
+        BaseTy.Tensor: tensor_expression,
+        BaseTy.int: "1",
+        BaseTy.bool: "false",
+        BaseTy.Scalar: "2",
+        BaseTy.ScalarType: "at::ScalarType::Float",
+        BaseTy.str: '"floor"',
+    }
+
+    base_ty_object = None
+    if isinstance(arg_type, BaseType):
+        base_ty_object = arg_type.name
+    else:
+        assert isinstance(arg_type, OptionalType) and isinstance(
+            arg_type.elem, BaseType
+        )
+        base_ty_object = arg_type.elem.name
+    assert base_ty_object in value_expressions, "not expected type"
+    value_expression = value_expressions[base_ty_object]
+    return value_expression
+
+
+def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
+    assert not schema.is_out_fn()
+    schema_name = schema.name.name.base
+    arg_map = {}
+    for arg in schema.schema_order_arguments():
+        test_value_exp = test_value_expression(arg.type, index, schema_name)
+        arg_map[arg.name] = test_value_exp
+    config.override_test_values(arg_map, schema_name, index)
+    arg_populations = []
+    for arg_name, arg_value in arg_map.items():
+        arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
+    return ";\n    ".join(arg_populations) + ";"
+
+
+def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
+    assert not schema.is_out_fn()
+    return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
+
+
+generate_test_ir_arguments_base_ty_to_type_str_ = {
+    BaseTy.Tensor: "Tensor",
+    BaseTy.int: "int",
+    BaseTy.float: "float",
+    BaseTy.str: "str",
+    BaseTy.Scalar: "int",
+    BaseTy.ScalarType: "int",
+    BaseTy.bool: "bool",
+}
+
+
+def generate_test_ir_arguments(
+    schema: FunctionSchema,
+) -> list[tuple[str, str | None]]:
+    def ir_argument(arg: Argument) -> tuple[str, str | None]:
+        t = arg.type
+        add_optional = False
+        if isinstance(t, OptionalType):
+            t = t.elem
+            add_optional = True
+        assert isinstance(t, BaseType)
+        type_str = None
+        if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
+            type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
+        if type_str and add_optional:
+            type_str = f"{type_str}?"
+        return ("%" + arg.name, type_str)
+
+    return [ir_argument(arg) for arg in schema.schema_order_arguments()]
+
+
+def generate_arg_extraction(schema: FunctionSchema) -> str:
+    arg_populations = []
+    for i, arg in enumerate(schema.schema_order_arguments()):
+        maybe_method = ivalue_type_conversion_method(arg.type)
+        assert maybe_method
+        is_reference, type_conversion_method = maybe_method
+        reference = "&" if is_reference else ""
+        arg_populations.append(
+            f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
+        )
+    return ";\n    ".join(arg_populations) + ";"
+
+
+def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
+    kernel = backend_index.get_kernel(g.functional)
+    if g.structured or kernel is None:
+        return cpp.name(g.functional.func)
+    return kernel.kernel
+
+
+def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
+    kernel = backend_index.get_kernel(g.out)
+    if g.structured or kernel is None:
+        return cpp.name(g.out.func)
+    return kernel.kernel
+
+
+def generate_non_out_variant_call(
+    g: NativeFunctionsGroup, backend_index: BackendIndex
+) -> str:
+    schema = g.functional.func
+    assert not schema.is_out_fn()
+    kernel_name = get_kernel_name(g, backend_index)
+    arg_names = (arg.name for arg in schema.schema_order_arguments())
+    namespace_name = "cpu" if g.structured else "native"
+    return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
+
+
+def generate_call_to_view_ops(
+    g: NativeFunctionsViewGroup, backend_index: BackendIndex
+) -> str:
+    schema = g.view.func
+    kernel_name = cpp.name(schema)
+    kernel = backend_index.get_kernel(g.view)
+    if kernel:
+        kernel_name = kernel.kernel
+    arg_names = (arg.name for arg in schema.schema_order_arguments())
+    namespace_name = "native"
+    return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
+
+
+def generate_out_variant_call(
+    g: NativeFunctionsGroup, backend_index: BackendIndex
+) -> str:
+    schema = g.out.func
+    assert schema.is_out_fn()
+    arg_names = []
+    kernel_name = get_out_kernel_name(g, backend_index)
+    if g.structured:
+        # structured op starts with the output tensor argument.
+        arg_names = [out_arg.name for out_arg in schema.arguments.out]
+    else:
+        arg_names = []
+    for arg in schema.arguments.non_out:
+        if isinstance(arg, SelfArgument):
+            arg_names.append(arg.argument.name)
+        else:
+            assert isinstance(arg, Argument)
+            arg_names.append(arg.name)
+    if not g.structured:
+        assert len(schema.arguments.out) == 1
+        arg_names.append(schema.arguments.out[0].name)
+    cpp_arg_names = ",".join(arg_names)
+    namespace_name = "cpu" if g.structured else "native"
+    return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
+
+
+no_memory_resize_ops = frozenset(
+    (
+        "isin.Scalar_Tensor",
+        "index_add",
+        "dot",
+        "vdot",
+        "nuclear_norm",
+        "histc",
+        "l1_loss",
+        "multi_margin_loss",
+        "multilabel_margin_loss",
+        "nll_loss",
+        "nll_loss2d",
+        "prod",
+    )
+)
+
+
+def should_check_resize(schema: FunctionSchema) -> bool:
+    schema_str = str(schema)
+    type_variant_op_name = schema_str[: schema_str.find("(")]
+    return type_variant_op_name not in no_memory_resize_ops
+
+
+def op_name_from_group(g: NativeFunctionsGroup) -> str:
+    return g.functional.func.name.name.base
+
+
+class GenOpDispatcher:
+    def out_variant(
+        self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
+    ) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsGroup)
+                generated_type_variant = self.out_variant_op_generator(g, backend_index)
+                generated_type_variants.append(generated_type_variant)
+        op_name = op_name_from_group(groups[0])
+        body = "\n".join(generated_type_variants)
+        generated = f"""
+REGISTER_OPERATOR_FUNCTOR(
+    aten::{op_name},
+    aten_{op_name},
+    [](Node* n) -> SROperator {{
+      {body}
+      LogAndDumpSchema(n);
+      return nullptr;
+    }})
+"""
+        return generated
+
+    def view(
+        self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
+    ) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsViewGroup)
+                generated_type_variant = self.view_op_generator(g, backend_index)
+                generated_type_variants.append(generated_type_variant)
+        op_name = config.func_name_base_str(groups[0])
+        body = "\n".join(generated_type_variants)
+        generated = f"""
+REGISTER_NATIVE_OPERATOR_FUNCTOR(
+    aten::{op_name},
+    aten_{op_name},
+    [](Node* n) -> SROperator {{
+      {body}
+      LogAndDumpSchema(n);
+      return nullptr;
+    }});
+"""
+        return generated
+
+    def out_variant_op_generator(
+        self, g: NativeFunctionsGroup, backend_index: BackendIndex
+    ) -> str:
+        functional = g.functional
+        schema = str(functional.func)
+        populated_argument = generate_arg_extraction(g.functional.func)
+        functional_variant_call = generate_non_out_variant_call(g, backend_index)
+        assert len(g.out.func.arguments.out) == 1
+        out_variable_name = str(g.out.func.arguments.out[0].name)
+        out_variant_call = generate_out_variant_call(g, backend_index)
+        generated = f"""
+      if (n->matches(torch::schema("aten::{schema}"))) {{
+        return [](ProcessedNode* p_node) {{
+          {populated_argument}
+          if (p_node->Output(0).isNone()) {{
+            p_node->Output(0) = {functional_variant_call};
+            return;
+          }}
+          auto& {out_variable_name} = p_node->Output(0).toTensor();
+          fastResizeToZero({out_variable_name});
+          {out_variant_call};
+        }};
+      }}"""
+        return generated
+
+    def view_op_generator(
+        self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
+    ) -> str:
+        schema = str(g.view.func)
+        populated_argument = generate_arg_extraction(g.view.func)
+        functional_variant_call = generate_call_to_view_ops(g, backend_index)
+        generated = f"""
+      if (n->matches(torch::schema("aten::{schema}"))) {{
+        return [](ProcessedNode* p_node) {{
+          {populated_argument}
+            p_node->Output(0) = {functional_variant_call};
+        }};
+      }}"""
+        return generated
+
+
+class GenOpTestCase:
+    def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsGroup)
+                generated_type_variant = self.out_variant_op_test_case_generator(g)
+                generated_type_variants.append(generated_type_variant)
+        return "\n".join(generated_type_variants)
+
+    def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
+        if not groups:
+            return ""
+        generated_type_variants = []
+        for g in groups:
+            with native_function_manager(g):
+                assert is_supported(g)
+                assert isinstance(g, NativeFunctionsViewGroup)
+                generated_type_variant = self.view_op_test_case_generator(g)
+                generated_type_variants.append(generated_type_variant)
+        return "\n".join(generated_type_variants)
+
+    def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
+        schema = g.functional.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
+        op_name = op_name_from_group(g)
+        assert type_variant_op_name.startswith(op_name)
+
+        arg_types = generate_test_ir_arguments(schema)
+        arg_declarations = ", ".join(
+            (
+                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
+                for arg_name, arg_type in arg_types
+            )
+        )
+        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
+        assert (
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
+        )
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        test_value_definitions2 = generate_test_value_definitions(schema, 1)
+        test_value_names2 = generate_test_value_names(schema, 1)
+        check_resize = "true" if should_check_resize(schema) else "false"
+        generated = f"""
+TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
+  const std::string script = R"IR(
+    graph({arg_declarations}):
+        %bias: None = prim::Constant()
+        %ret = aten::{op_name}({arg_names})
+        %cloned = aten::clone(%ret, %bias)
+        return (%cloned)
+  )IR";
+
+  {test_value_definitions}
+  std::vector<IValue> args{{{test_value_names}}};
+  testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
+
+  {test_value_definitions2}
+  std::vector<IValue> args2{{{test_value_names2}}};
+  testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
+
+}}
+"""
+        return generated
+
+    def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
+        schema = g.view.func
+        schema_str = str(schema)
+        assert schema_str.find("(") > 0
+        type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
+        op_name = g.view.root_name
+        assert type_variant_op_name.startswith(op_name)
+
+        arg_types = generate_test_ir_arguments(schema)
+        arg_declarations = ", ".join(
+            (
+                arg_name if arg_type is None else f"{arg_name}: {arg_type}"
+                for arg_name, arg_type in arg_types
+            )
+        )
+        arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
+        assert (
+            len(schema.returns) == 1
+            and isinstance(schema.returns[0].type, BaseType)
+            and schema.returns[0].type.name is BaseTy.Tensor
+        )
+        test_value_definitions = generate_test_value_definitions(schema, 0)
+        test_value_names = generate_test_value_names(schema, 0)
+        generated = f"""
+TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
+  const std::string script = R"IR(
+    graph({arg_declarations}):
+        %bias: None = prim::Constant()
+        %ret = aten::{op_name}({arg_names})
+        %cloned = aten::clone(%ret, %bias)
+        return (%cloned)
+  )IR";
+
+  {test_value_definitions}
+  std::vector<IValue> args{{{test_value_names}}};
+  testStaticRuntime(script, args);
+}}
+"""
+
+        return generated
diff --git a/phivenv/Lib/site-packages/torchgen/utils.py b/phivenv/Lib/site-packages/torchgen/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..521beb8f16075abef695971bdb814c16cebd6240
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/utils.py
@@ -0,0 +1,568 @@
+from __future__ import annotations
+
+import contextlib
+import functools
+import hashlib
+import os
+import re
+import sys
+import textwrap
+from dataclasses import fields, is_dataclass
+from enum import auto, Enum
+from pathlib import Path
+from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
+from typing_extensions import assert_never, deprecated, Self
+
+from torchgen.code_template import CodeTemplate
+
+
+if TYPE_CHECKING:
+    from argparse import Namespace
+    from collections.abc import Iterable, Iterator, Sequence
+
+
+TORCHGEN_ROOT = Path(__file__).absolute().parent
+REPO_ROOT = TORCHGEN_ROOT.parent
+
+
+# Many of these functions share logic for defining both the definition
+# and declaration (for example, the function signature is the same), so
+# we organize them into one function that takes a Target to say which
+# code we want.
+#
+# This is an OPEN enum (we may add more cases to it in the future), so be sure
+# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
+# what targets are valid for your use.
+class Target(Enum):
+    # top level namespace (not including at)
+    DEFINITION = auto()
+    DECLARATION = auto()
+    # TORCH_LIBRARY(...) { ... }
+    REGISTRATION = auto()
+    # namespace { ... }
+    ANONYMOUS_DEFINITION = auto()
+    # namespace cpu { ... }
+    NAMESPACED_DEFINITION = auto()
+    NAMESPACED_DECLARATION = auto()
+
+
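+# Minimal sketch, not part of the original module, of how a codegen helper
+# typically narrows Target with Literal; the helper and the emitted strings
+# below are hypothetical.
+def _example_emit(target: Literal[Target.DECLARATION, Target.DEFINITION]) -> str:
+    if target is Target.DECLARATION:
+        return "void example_kernel();"
+    elif target is Target.DEFINITION:
+        return "void example_kernel() {}"
+    else:
+        assert_never(target)
+
+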
+# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
+# occurrence of a parameter in the derivative formula
+IDENT_REGEX = r"(^|\W){}($|\W)"
+
+
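+# Illustrative check, not part of the original module; the formula strings are
+# made-up examples.
+def _example_ident_regex() -> None:
+    pattern = IDENT_REGEX.format("grad")
+    # "grad" is found as a whole identifier ...
+    assert re.search(pattern, "grad * self") is not None
+    # ... but not as a prefix of a longer identifier.
+    assert re.search(pattern, "grads[0] * self") is None
+
+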
+# TODO: Use a real parser here; this will get bamboozled
+def split_name_params(schema: str) -> tuple[str, list[str]]:
+    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
+    if m is None:
+        raise RuntimeError(f"Unsupported function schema: {schema}")
+    name, _, params = m.groups()
+    return name, params.split(", ")
+
+
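+# Illustrative call, not part of the original module; the schema string is a
+# made-up example.
+def _example_split_name_params() -> None:
+    name, params = split_name_params("add.Tensor(Tensor self, Tensor other)")
+    # The overload suffix is dropped from the name; parameters come back as strings.
+    assert name == "add"
+    assert params == ["Tensor self", "Tensor other"]
+
+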
+T = TypeVar("T")
+S = TypeVar("S")
+
+# These two functions purposely return generators in analogy to map()
+# so that you don't mix up when you need to list() them
+
+
+# Map over function that may return None; omit Nones from output sequence
+def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
+    for x in xs:
+        r = func(x)
+        if r is not None:
+            yield r
+
+
+# Map over function that returns sequences and cat them all together
+def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
+    for x in xs:
+        yield from func(x)
+
+
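+# Illustrative sketch, not part of the original module: both helpers return
+# generators, so wrap them in list() when a concrete sequence is needed.
+def _example_map_helpers() -> None:
+    evens = list(mapMaybe(lambda x: x * 10 if x % 2 == 0 else None, range(5)))
+    assert evens == [0, 20, 40]
+    doubled = list(concatMap(lambda s: [s, s], ["a", "b"]))
+    assert doubled == ["a", "a", "b", "b"]
+
+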
+# Conveniently add error context to exceptions raised.  Lets us
+# easily say that an error occurred while processing a specific
+# context.
+@contextlib.contextmanager
+def context(msg_fn: Callable[[], str]) -> Iterator[None]:
+    try:
+        yield
+    except Exception as e:
+        # TODO: this does the wrong thing with KeyError
+        msg = msg_fn()
+        msg = textwrap.indent(msg, "  ")
+        msg = f"{e.args[0]}\n{msg}" if e.args else msg
+        e.args = (msg,) + e.args[1:]
+        raise
+
+
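+# Illustrative sketch, not part of the original module; the messages are
+# made-up examples.
+def _example_context() -> None:
+    try:
+        with context(lambda: "while processing the hypothetical op `foo`"):
+            raise ValueError("bad schema")
+    except ValueError as e:
+        # The original message is kept and the indented context is appended.
+        assert "bad schema" in str(e)
+        assert "while processing" in str(e)
+
+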
+if TYPE_CHECKING:
+    # A little trick from https://github.com/python/mypy/issues/6366
+    # for getting mypy to do exhaustiveness checking
+    # TODO: put this somewhere else, maybe
+    @deprecated("Use typing_extensions.assert_never instead")
+    def assert_never(x: NoReturn) -> NoReturn:  # type: ignore[misc] # noqa: F811
+        raise AssertionError(f"Unhandled type: {type(x).__name__}")
+
+
+@functools.cache
+def _read_template(template_fn: str) -> CodeTemplate:
+    return CodeTemplate.from_file(template_fn)
+
+
+# String hash that's stable across different executions, unlike builtin hash
+def string_stable_hash(s: str) -> int:
+    sha1 = hashlib.sha1(s.encode("latin1"), usedforsecurity=False).digest()
+    return int.from_bytes(sha1, byteorder="little")
+
+
+# A small abstraction for writing out generated files and keeping track
+# of what files have been written (so you can write out a list of output
+# files)
+class FileManager:
+    def __init__(
+        self,
+        install_dir: str | Path,
+        template_dir: str | Path,
+        dry_run: bool,
+    ) -> None:
+        self.install_dir = Path(install_dir)
+        self.template_dir = Path(template_dir)
+        self.files: set[Path] = set()
+        self.dry_run = dry_run
+
+    @property
+    def filenames(self) -> frozenset[str]:
+        return frozenset({file.as_posix() for file in self.files})
+
+    def _write_if_changed(self, filename: str | Path, contents: str) -> None:
+        file = Path(filename)
+        old_contents: str | None = None
+        try:
+            old_contents = file.read_text(encoding="utf-8")
+        except OSError:
+            pass
+        if contents != old_contents:
+            # Create output directory if it doesn't exist
+            file.parent.mkdir(parents=True, exist_ok=True)
+            file.write_text(contents, encoding="utf-8")
+
+    # Read from the template file and substitute using the environment returned
+    # by env_callable (which may be a dict or a str).
+    def substitute_with_template(
+        self,
+        template_fn: str | Path,
+        env_callable: Callable[[], str | dict[str, Any]],
+    ) -> str:
+        assert not Path(template_fn).is_absolute(), (
+            f"template_fn must be relative: {template_fn}"
+        )
+        template_path = self.template_dir / template_fn
+        env = env_callable()
+        if isinstance(env, dict):
+            if "generated_comment" not in env:
+                generator_default = TORCHGEN_ROOT / "gen.py"
+                try:
+                    generator = Path(
+                        sys.modules["__main__"].__file__ or generator_default
+                    ).absolute()
+                except (KeyError, AttributeError):
+                    generator = generator_default.absolute()
+
+                try:
+                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
+                except ValueError:
+                    generator_path = generator.name
+
+                env = {
+                    **env,  # copy the original dict instead of mutating it
+                    "generated_comment": (
+                        "@" + f"generated by {generator_path} from {template_fn}"
+                    ),
+                }
+            template = _read_template(template_path)
+            substitute_out = template.substitute(env)
+            # Ensure an extra blank line between the class/function definition
+            # and the docstring of the previous class/function definition.
+            # NB: It is generally not recommended to have docstrings in pyi stub
+            #     files. But if there are any, we need to ensure that the file
+            #     is properly formatted.
+            return re.sub(
+                r'''
+                (""")\n+             # match triple quotes
+                (
+                    (\s*@.+\n)*     # match decorators if any
+                    \s*(class|def)  # match class/function definition
+                )
+                ''',
+                r"\g<1>\n\n\g<2>",
+                substitute_out,
+                flags=re.VERBOSE,
+            )
+        if isinstance(env, str):
+            return env
+        assert_never(env)
+
+    def write_with_template(
+        self,
+        filename: str | Path,
+        template_fn: str | Path,
+        env_callable: Callable[[], str | dict[str, Any]],
+    ) -> None:
+        filename = Path(filename)
+        assert not filename.is_absolute(), f"filename must be relative: {filename}"
+        file = self.install_dir / filename
+        assert file not in self.files, f"duplicate file write {file}"
+        self.files.add(file)
+        if not self.dry_run:
+            substitute_out = self.substitute_with_template(
+                template_fn=template_fn,
+                env_callable=env_callable,
+            )
+            self._write_if_changed(filename=file, contents=substitute_out)
+
+    def write(
+        self,
+        filename: str | Path,
+        env_callable: Callable[[], str | dict[str, Any]],
+    ) -> None:
+        self.write_with_template(filename, filename, env_callable)
+
+    def write_sharded(
+        self,
+        filename: str | Path,
+        items: Iterable[T],
+        *,
+        key_fn: Callable[[T], str],
+        env_callable: Callable[[T], dict[str, list[str]]],
+        num_shards: int,
+        base_env: dict[str, Any] | None = None,
+        sharded_keys: set[str],
+    ) -> None:
+        self.write_sharded_with_template(
+            filename,
+            filename,
+            items,
+            key_fn=key_fn,
+            env_callable=env_callable,
+            num_shards=num_shards,
+            base_env=base_env,
+            sharded_keys=sharded_keys,
+        )
+
+    def write_sharded_with_template(
+        self,
+        filename: str | Path,
+        template_fn: str | Path,
+        items: Iterable[T],
+        *,
+        key_fn: Callable[[T], str],
+        env_callable: Callable[[T], dict[str, list[str]]],
+        num_shards: int,
+        base_env: dict[str, Any] | None = None,
+        sharded_keys: set[str],
+    ) -> None:
+        file = Path(filename)
+        assert not file.is_absolute(), f"filename must be relative: {filename}"
+        everything: dict[str, Any] = {"shard_id": "Everything"}
+        shards: list[dict[str, Any]] = [
+            {"shard_id": f"_{i}"} for i in range(num_shards)
+        ]
+        all_shards = [everything] + shards
+
+        if base_env is not None:
+            for shard in all_shards:
+                shard.update(base_env)
+
+        for key in sharded_keys:
+            for shard in all_shards:
+                if key in shard:
+                    assert isinstance(shard[key], list), (
+                        "sharded keys in base_env must be a list"
+                    )
+                    shard[key] = shard[key].copy()
+                else:
+                    shard[key] = []
+
+        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
+            for k, v in from_.items():
+                assert k in sharded_keys, f"undeclared sharded key {k}"
+                into[k] += v
+
+        if self.dry_run:
+            # Dry runs don't write any templates, so incomplete environments are fine
+            items = ()
+
+        for item in items:
+            key = key_fn(item)
+            sid = string_stable_hash(key) % num_shards
+            env = env_callable(item)
+
+            merge_env(shards[sid], env)
+            merge_env(everything, env)
+
+        for shard in all_shards:
+            shard_id = shard["shard_id"]
+            self.write_with_template(
+                file.with_stem(f"{file.stem}{shard_id}"),
+                template_fn,
+                lambda: shard,
+            )
+
+        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
+        self.files.discard(self.install_dir / file.with_stem(f"{file.stem}Everything"))
+
+    def write_outputs(self, variable_name: str, filename: str | Path) -> None:
+        """Write a file containing the list of all outputs which are generated by this script."""
+        content = "\n".join(
+            (
+                "set(",
+                variable_name,
+                # Use POSIX paths to avoid invalid escape sequences on Windows
+                *(f'    "{file.as_posix()}"' for file in sorted(self.files)),
+                ")",
+            )
+        )
+        self._write_if_changed(filename, content)
+
+    def template_dir_for_comments(self) -> str:
+        """
+        This needs to be deterministic. The template dir is an absolute path
+        that varies across builds. So, just use the path relative to this file,
+        which will point to the codegen source but will be stable.
+        """
+        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
+
+
+# Helper function to generate file manager
+def make_file_manager(
+    options: Namespace,
+    install_dir: str | Path | None = None,
+) -> FileManager:
+    template_dir = os.path.join(options.source_path, "templates")
+    install_dir = install_dir if install_dir else options.install_dir
+    return FileManager(
+        install_dir=install_dir,
+        template_dir=template_dir,
+        dry_run=options.dry_run,
+    )
+
+
+# Helper function to create a pretty representation for dataclasses
+def dataclass_repr(
+    obj: Any,
+    indent: int = 0,
+    width: int = 80,
+) -> str:
+    # the built-in pprint module supports dataclasses from Python 3.10
+    if sys.version_info >= (3, 10):
+        from pprint import pformat
+
+        return pformat(obj, indent, width)
+
+    return _pformat(obj, indent=indent, width=width)
+
+
+def _pformat(
+    obj: Any,
+    indent: int,
+    width: int,
+    curr_indent: int = 0,
+) -> str:
+    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"
+
+    class_name = obj.__class__.__name__
+    # update current indentation level with class name
+    curr_indent += len(class_name) + 1
+
+    fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]
+
+    fields_str = []
+    for name, attr in fields_list:
+        # update the current indent level with the field name
+        # dict, list, set and tuple also add indent as done in pprint
+        _curr_indent = curr_indent + len(name) + 1
+        if is_dataclass(attr):
+            str_repr = _pformat(attr, indent, width, _curr_indent)
+        elif isinstance(attr, dict):
+            str_repr = _format_dict(attr, indent, width, _curr_indent)
+        elif isinstance(attr, (list, set, tuple)):
+            str_repr = _format_list(attr, indent, width, _curr_indent)
+        else:
+            str_repr = repr(attr)
+
+        fields_str.append(f"{name}={str_repr}")
+
+    indent_str = curr_indent * " "
+    body = f",\n{indent_str}".join(fields_str)
+    return f"{class_name}({body})"
+
+
+def _format_dict(
+    attr: dict[Any, Any],
+    indent: int,
+    width: int,
+    curr_indent: int,
+) -> str:
+    curr_indent += indent + 3
+    dict_repr = []
+    for k, v in attr.items():
+        k_repr = repr(k)
+        v_str = (
+            _pformat(v, indent, width, curr_indent + len(k_repr))
+            if is_dataclass(v)
+            else repr(v)
+        )
+        dict_repr.append(f"{k_repr}: {v_str}")
+
+    return _format(dict_repr, indent, width, curr_indent, "{", "}")
+
+
+def _format_list(
+    attr: list[Any] | set[Any] | tuple[Any, ...],
+    indent: int,
+    width: int,
+    curr_indent: int,
+) -> str:
+    curr_indent += indent + 1
+    list_repr = [
+        _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
+        for l in attr
+    ]
+    start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
+    return _format(list_repr, indent, width, curr_indent, start, end)
+
+
+def _format(
+    fields_str: list[str],
+    indent: int,
+    width: int,
+    curr_indent: int,
+    start: str,
+    end: str,
+) -> str:
+    delimiter, curr_indent_str = "", ""
+    # if it exceeds the max width then we place one element per line
+    if len(repr(fields_str)) >= width:
+        delimiter = "\n"
+        curr_indent_str = " " * curr_indent
+
+    indent_str = " " * indent
+    body = f", {delimiter}{curr_indent_str}".join(fields_str)
+    return f"{start}{indent_str}{body}{end}"
+
+
+class NamespaceHelper:
+    """A helper for constructing the namespace open and close strings for a nested set of namespaces.
+
+    e.g. for namespace_str torch::lazy,
+
+    prologue:
+    namespace torch {
+    namespace lazy {
+
+    epilogue:
+    } // namespace lazy
+    } // namespace torch
+    """
+
+    def __init__(
+        self,
+        namespace_str: str,
+        entity_name: str = "",
+        max_level: int = 2,
+    ) -> None:
+        # cpp_namespace can be a colon joined string such as torch::lazy
+        cpp_namespaces = namespace_str.split("::")
+        assert len(cpp_namespaces) <= max_level, (
+            f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
+        )
+        self.cpp_namespace_ = namespace_str
+        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
+        self.epilogue_ = "\n".join(
+            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
+        )
+        self.namespaces_ = cpp_namespaces
+        self.entity_name_ = entity_name
+
+    @staticmethod
+    def from_namespaced_entity(
+        namespaced_entity: str,
+        max_level: int = 2,
+    ) -> NamespaceHelper:
+        """
+        Generate a helper from nested namespaces as well as the class/function name, e.g. "torch::lazy::add".
+        """
+        names = namespaced_entity.split("::")
+        entity_name = names[-1]
+        namespace_str = "::".join(names[:-1])
+        return NamespaceHelper(
+            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
+        )
+
+    @property
+    def prologue(self) -> str:
+        return self.prologue_
+
+    @property
+    def epilogue(self) -> str:
+        return self.epilogue_
+
+    @property
+    def entity_name(self) -> str:
+        return self.entity_name_
+
+    # Only allow certain level of namespaces
+    def get_cpp_namespace(self, default: str = "") -> str:
+        """
+        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
+        Return default if namespace string is empty.
+        """
+        return self.cpp_namespace_ if self.cpp_namespace_ else default
+
+
+class OrderedSet(Generic[T]):
+    storage: dict[T, Literal[None]]
+
+    def __init__(self, iterable: Iterable[T] | None = None) -> None:
+        if iterable is None:
+            self.storage = {}
+        else:
+            self.storage = dict.fromkeys(iterable)
+
+    def __contains__(self, item: T) -> bool:
+        return item in self.storage
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self.storage.keys())
+
+    def update(self, items: OrderedSet[T]) -> None:
+        self.storage.update(items.storage)
+
+    def add(self, item: T) -> None:
+        self.storage[item] = None
+
+    def copy(self) -> OrderedSet[T]:
+        ret: OrderedSet[T] = OrderedSet()
+        ret.storage = self.storage.copy()
+        return ret
+
+    @staticmethod
+    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
+        ret = args[0].copy()
+        for s in args[1:]:
+            ret.update(s)
+        return ret
+
+    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
+        return OrderedSet.union(self, other)
+
+    def __ior__(self, other: OrderedSet[T]) -> Self:
+        self.update(other)
+        return self
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, OrderedSet):
+            return self.storage == other.storage
+        else:
+            return set(self.storage.keys()) == other
diff --git a/phivenv/Lib/site-packages/torchgen/yaml_utils.py b/phivenv/Lib/site-packages/torchgen/yaml_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..53379216638e79cbdb9a21409812ad948c3f7cf5
--- /dev/null
+++ b/phivenv/Lib/site-packages/torchgen/yaml_utils.py
@@ -0,0 +1,26 @@
+# Safely load fast C Yaml loader/dumper if they are available
+try:
+    from yaml import CSafeLoader as Loader
+except ImportError:
+    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]
+
+try:
+    from yaml import CSafeDumper as Dumper
+except ImportError:
+    from yaml import SafeDumper as Dumper  # type: ignore[assignment, misc]
+YamlDumper = Dumper
+
+
+# A custom loader for YAML that errors on duplicate keys.
+# This doesn't happen by default: see https://github.com/yaml/pyyaml/issues/165
+class YamlLoader(Loader):
+    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
+        mapping = []
+        for key_node, value_node in node.value:
+            key = self.construct_object(key_node, deep=deep)  # type: ignore[no-untyped-call]
+            assert key not in mapping, (
+                f"Found a duplicate key in the yaml. key={key}, line={node.start_mark.line}"
+            )
+            mapping.append(key)
+        mapping = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
+        return mapping
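+
+
+# Illustrative usage (not part of the original file), with `import yaml`:
+#   yaml.load("a: 1\nb: 2", Loader=YamlLoader)  # -> {"a": 1, "b": 2}
+#   yaml.load("a: 1\na: 2", Loader=YamlLoader)  # -> AssertionError (duplicate key)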
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/INSTALLER b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/INSTALLER
new file mode 100644
index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/INSTALLER
@@ -0,0 +1 @@
+pip
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/LICENCE b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/LICENCE
new file mode 100644
index 0000000000000000000000000000000000000000..a8922b182e80d9bcb955e8b8ae2bd9a017d72977
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/LICENCE
@@ -0,0 +1,49 @@
+`tqdm` is a product of collaborative work.
+Unless otherwise stated, all authors (see commit logs) retain copyright
+for their respective work, and release the work under the MIT licence
+(text below).
+
+Exceptions or notable authors are listed below
+in reverse chronological order:
+
+* files: *
+  MPL-2.0 2015-2024 (c) Casper da Costa-Luis
+  [casperdcl](https://github.com/casperdcl).
+* files: tqdm/_tqdm.py
+  MIT 2016 (c) [PR #96] on behalf of Google Inc.
+* files: tqdm/_tqdm.py README.rst .gitignore
+  MIT 2013 (c) Noam Yorav-Raphael, original author.
+
+[PR #96]: https://github.com/tqdm/tqdm/pull/96
+
+
+Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
+-----------------------------------------------
+
+This Source Code Form is subject to the terms of the
+Mozilla Public License, v. 2.0.
+If a copy of the MPL was not distributed with this project,
+You can obtain one at https://mozilla.org/MPL/2.0/.
+
+
+MIT License (MIT)
+-----------------
+
+Copyright (c) 2013 noamraph
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/METADATA b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/METADATA
new file mode 100644
index 0000000000000000000000000000000000000000..181b4dc8b2f8697d1c0374a612ffd8b2f2db346a
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/METADATA
@@ -0,0 +1,1594 @@
+Metadata-Version: 2.1
+Name: tqdm
+Version: 4.67.1
+Summary: Fast, Extensible Progress Meter
+Maintainer-email: tqdm developers 
+License: MPL-2.0 AND MIT
+Project-URL: homepage, https://tqdm.github.io
+Project-URL: repository, https://github.com/tqdm/tqdm
+Project-URL: changelog, https://tqdm.github.io/releases
+Project-URL: wiki, https://github.com/tqdm/tqdm/wiki
+Keywords: progressbar,progressmeter,progress,bar,meter,rate,eta,console,terminal,time
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Environment :: Console
+Classifier: Environment :: MacOS X
+Classifier: Environment :: Other Environment
+Classifier: Environment :: Win32 (MS Windows)
+Classifier: Environment :: X11 Applications
+Classifier: Framework :: IPython
+Classifier: Framework :: Jupyter
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Education
+Classifier: Intended Audience :: End Users/Desktop
+Classifier: Intended Audience :: Other Audience
+Classifier: Intended Audience :: System Administrators
+Classifier: License :: OSI Approved :: MIT License
+Classifier: License :: OSI Approved :: Mozilla Public License 2.0 (MPL 2.0)
+Classifier: Operating System :: MacOS
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Operating System :: Microsoft
+Classifier: Operating System :: Microsoft :: MS-DOS
+Classifier: Operating System :: Microsoft :: Windows
+Classifier: Operating System :: POSIX
+Classifier: Operating System :: POSIX :: BSD
+Classifier: Operating System :: POSIX :: BSD :: FreeBSD
+Classifier: Operating System :: POSIX :: Linux
+Classifier: Operating System :: POSIX :: SunOS/Solaris
+Classifier: Operating System :: Unix
+Classifier: Programming Language :: Python
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.7
+Classifier: Programming Language :: Python :: 3.8
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: Implementation
+Classifier: Programming Language :: Python :: Implementation :: IronPython
+Classifier: Programming Language :: Python :: Implementation :: PyPy
+Classifier: Programming Language :: Unix Shell
+Classifier: Topic :: Desktop Environment
+Classifier: Topic :: Education :: Computer Aided Instruction (CAI)
+Classifier: Topic :: Education :: Testing
+Classifier: Topic :: Office/Business
+Classifier: Topic :: Other/Nonlisted Topic
+Classifier: Topic :: Software Development :: Build Tools
+Classifier: Topic :: Software Development :: Libraries
+Classifier: Topic :: Software Development :: Libraries :: Python Modules
+Classifier: Topic :: Software Development :: Pre-processors
+Classifier: Topic :: Software Development :: User Interfaces
+Classifier: Topic :: System :: Installation/Setup
+Classifier: Topic :: System :: Logging
+Classifier: Topic :: System :: Monitoring
+Classifier: Topic :: System :: Shells
+Classifier: Topic :: Terminals
+Classifier: Topic :: Utilities
+Requires-Python: >=3.7
+Description-Content-Type: text/x-rst
+License-File: LICENCE
+Requires-Dist: colorama; platform_system == "Windows"
+Provides-Extra: dev
+Requires-Dist: pytest>=6; extra == "dev"
+Requires-Dist: pytest-cov; extra == "dev"
+Requires-Dist: pytest-timeout; extra == "dev"
+Requires-Dist: pytest-asyncio>=0.24; extra == "dev"
+Requires-Dist: nbval; extra == "dev"
+Provides-Extra: discord
+Requires-Dist: requests; extra == "discord"
+Provides-Extra: slack
+Requires-Dist: slack-sdk; extra == "slack"
+Provides-Extra: telegram
+Requires-Dist: requests; extra == "telegram"
+Provides-Extra: notebook
+Requires-Dist: ipywidgets>=6; extra == "notebook"
+
+|Logo|
+
+tqdm
+====
+
+|Py-Versions| |Versions| |Conda-Forge-Status| |Docker| |Snapcraft|
+
+|Build-Status| |Coverage-Status| |Branch-Coverage-Status| |Codacy-Grade| |Libraries-Rank| |PyPI-Downloads|
+
+|LICENCE| |OpenHub-Status| |binder-demo| |awesome-python|
+
+``tqdm`` derives from the Arabic word *taqaddum* (تقدّم) which can mean "progress,"
+and is an abbreviation for "I love you so much" in Spanish (*te quiero demasiado*).
+
+Instantly make your loops show a smart progress meter - just wrap any
+iterable with ``tqdm(iterable)``, and you're done!
+
+.. code:: python
+
+    from tqdm import tqdm
+    for i in tqdm(range(10000)):
+        ...
+
+``76%|████████████████████████        | 7568/10000 [00:33<00:10, 229.00it/s]``
+
+``trange(N)`` can also be used as a convenient shortcut for
+``tqdm(range(N))``.
+
+|Screenshot|
+    |Video| |Slides| |Merch|
+
+It can also be executed as a module with pipes:
+
+.. code:: sh
+
+    $ seq 9999999 | tqdm --bytes | wc -l
+    75.2MB [00:00, 217MB/s]
+    9999999
+
+    $ tar -zcf - docs/ | tqdm --bytes --total `du -sb docs/ | cut -f1` \
+        > backup.tgz
+     32%|██████████▍                      | 8.89G/27.9G [00:42<01:31, 223MB/s]
+
+Overhead is low -- about 60ns per iteration (80ns with ``tqdm.gui``), and is
+unit tested against performance regression.
+By comparison, the well-established
+`ProgressBar `__ has
+an 800ns/iter overhead.
+
+In addition to its low overhead, ``tqdm`` uses smart algorithms to predict
+the remaining time and to skip unnecessary iteration displays, which allows
+for a negligible overhead in most cases.
+
+``tqdm`` works on any platform
+(Linux, Windows, Mac, FreeBSD, NetBSD, Solaris/SunOS),
+in any console or in a GUI, and is also friendly with IPython/Jupyter notebooks.
+
+``tqdm`` does not require any dependencies (not even ``curses``!), just
+Python and an environment supporting ``carriage return \r`` and
+``line feed \n`` control characters.
+
+------------------------------------------
+
+.. contents:: Table of contents
+   :backlinks: top
+   :local:
+
+
+Installation
+------------
+
+Latest PyPI stable release
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+|Versions| |PyPI-Downloads| |Libraries-Dependents|
+
+.. code:: sh
+
+    pip install tqdm
+
+Latest development release on GitHub
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+|GitHub-Status| |GitHub-Stars| |GitHub-Commits| |GitHub-Forks| |GitHub-Updated|
+
+Pull and install pre-release ``devel`` branch:
+
+.. code:: sh
+
+    pip install "git+https://github.com/tqdm/tqdm.git@devel#egg=tqdm"
+
+Latest Conda release
+~~~~~~~~~~~~~~~~~~~~
+
+|Conda-Forge-Status|
+
+.. code:: sh
+
+    conda install -c conda-forge tqdm
+
+Latest Snapcraft release
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+|Snapcraft|
+
+There are 3 channels to choose from:
+
+.. code:: sh
+
+    snap install tqdm  # implies --stable, i.e. latest tagged release
+    snap install tqdm  --candidate  # master branch
+    snap install tqdm  --edge  # devel branch
+
+Note that ``snap`` binaries are purely for CLI use (not ``import``-able), and
+automatically set up ``bash`` tab-completion.
+
+Latest Docker release
+~~~~~~~~~~~~~~~~~~~~~
+
+|Docker|
+
+.. code:: sh
+
+    docker pull tqdm/tqdm
+    docker run -i --rm tqdm/tqdm --help
+
+Other
+~~~~~
+
+There are other (unofficial) places where ``tqdm`` may be downloaded, particularly for CLI use:
+
+|Repology|
+
+.. |Repology| image:: https://repology.org/badge/tiny-repos/python:tqdm.svg
+   :target: https://repology.org/project/python:tqdm/versions
+
+Changelog
+---------
+
+The list of all changes is available either on GitHub's Releases:
+|GitHub-Status|, on the
+`wiki `__, or on the
+`website `__.
+
+
+Usage
+-----
+
+``tqdm`` is very versatile and can be used in a number of ways.
+The three main ones are given below.
+
+Iterable-based
+~~~~~~~~~~~~~~
+
+Wrap ``tqdm()`` around any iterable:
+
+.. code:: python
+
+    from tqdm import tqdm
+    from time import sleep
+
+    text = ""
+    for char in tqdm(["a", "b", "c", "d"]):
+        sleep(0.25)
+        text = text + char
+
+``trange(i)`` is a special optimised instance of ``tqdm(range(i))``:
+
+.. code:: python
+
+    from tqdm import trange
+
+    for i in trange(100):
+        sleep(0.01)
+
+Instantiation outside of the loop allows for manual control over ``tqdm()``:
+
+.. code:: python
+
+    pbar = tqdm(["a", "b", "c", "d"])
+    for char in pbar:
+        sleep(0.25)
+        pbar.set_description("Processing %s" % char)
+
+Manual
+~~~~~~
+
+Manual control of ``tqdm()`` updates using a ``with`` statement:
+
+.. code:: python
+
+    with tqdm(total=100) as pbar:
+        for i in range(10):
+            sleep(0.1)
+            pbar.update(10)
+
+If the optional variable ``total`` (or an iterable with ``len()``) is
+provided, predictive stats are displayed.
+
+``with`` is also optional (you can just assign ``tqdm()`` to a variable,
+but in this case don't forget to ``del`` or ``close()`` it at the end):
+
+.. code:: python
+
+    pbar = tqdm(total=100)
+    for i in range(10):
+        sleep(0.1)
+        pbar.update(10)
+    pbar.close()
+
+Module
+~~~~~~
+
+Perhaps the most wonderful use of ``tqdm`` is in a script or on the command
+line. Simply inserting ``tqdm`` (or ``python -m tqdm``) between pipes will pass
+through all ``stdin`` to ``stdout`` while printing progress to ``stderr``.
+
+The example below demonstrates counting the number of lines in all Python files
+in the current directory, with timing information included.
+
+.. code:: sh
+
+    $ time find . -name '*.py' -type f -exec cat \{} \; | wc -l
+    857365
+
+    real    0m3.458s
+    user    0m0.274s
+    sys     0m3.325s
+
+    $ time find . -name '*.py' -type f -exec cat \{} \; | tqdm | wc -l
+    857366it [00:03, 246471.31it/s]
+    857365
+
+    real    0m3.585s
+    user    0m0.862s
+    sys     0m3.358s
+
+Note that the usual arguments for ``tqdm`` can also be specified.
+
+.. code:: sh
+
+    $ find . -name '*.py' -type f -exec cat \{} \; |
+        tqdm --unit loc --unit_scale --total 857366 >> /dev/null
+    100%|█████████████████████████████████| 857K/857K [00:04<00:00, 246Kloc/s]
+
+Backing up a large directory?
+
+.. code:: sh
+
+    $ tar -zcf - docs/ | tqdm --bytes --total `du -sb docs/ | cut -f1` \
+      > backup.tgz
+     44%|██████████████▊                   | 153M/352M [00:14<00:18, 11.0MB/s]
+
+This can be beautified further:
+
+.. code:: sh
+
+    $ BYTES=$(du -sb docs/ | cut -f1)
+    $ tar -cf - docs/ \
+      | tqdm --bytes --total "$BYTES" --desc Processing | gzip \
+      | tqdm --bytes --total "$BYTES" --desc Compressed --position 1 \
+      > ~/backup.tgz
+    Processing: 100%|██████████████████████| 352M/352M [00:14<00:00, 30.2MB/s]
+    Compressed:  42%|█████████▎            | 148M/352M [00:14<00:19, 10.9MB/s]
+
+Or done on a file level using 7-zip:
+
+.. code:: sh
+
+    $ 7z a -bd -r backup.7z docs/ | grep Compressing \
+      | tqdm --total $(find docs/ -type f | wc -l) --unit files \
+      | grep -v Compressing
+    100%|██████████████████████████▉| 15327/15327 [01:00<00:00, 712.96files/s]
+
+Pre-existing CLI programs already outputting basic progress information will
+benefit from ``tqdm``'s ``--update`` and ``--update_to`` flags:
+
+.. code:: sh
+
+    $ seq 3 0.1 5 | tqdm --total 5 --update_to --null
+    100%|████████████████████████████████████| 5.0/5 [00:00<00:00, 9673.21it/s]
+    $ seq 10 | tqdm --update --null  # 1 + 2 + ... + 10 = 55 iterations
+    55it [00:00, 90006.52it/s]
+
+FAQ and Known Issues
+--------------------
+
+|GitHub-Issues|
+
+The most common issues relate to excessive output on multiple lines, instead
+of a neat one-line progress bar.
+
+- Consoles in general: require support for carriage return (``CR``, ``\r``).
+
+  * Some cloud logging consoles which don't support ``\r`` properly
+    (`cloudwatch `__,
+    `K8s `__) may benefit from
+    ``export TQDM_POSITION=-1``.
+
+- Nested progress bars:
+
+  * Consoles in general: require support for moving cursors up to the
+    previous line. For example,
+    `IDLE `__,
+    `ConEmu `__ and
+    `PyCharm `__ (also
+    `here `__,
+    `here `__, and
+    `here `__)
+    lack full support.
+  * Windows: additionally may require the Python module ``colorama``
+    to ensure nested bars stay within their respective lines.
+
+- Unicode:
+
+  * Environments which report that they support unicode will have solid smooth
+    progressbars. The fallback is an ``ascii``-only bar.
+  * Windows consoles often only partially support unicode and thus
+    `often require explicit ascii=True `__
+    (also `here `__). This is due to
+    either normal-width unicode characters being incorrectly displayed as
+    "wide", or some unicode characters not rendering.
+
+- Wrapping generators:
+
+  * Generator wrapper functions tend to hide the length of iterables.
+    ``tqdm`` does not.
+  * Replace ``tqdm(enumerate(...))`` with ``enumerate(tqdm(...))`` or
+    ``tqdm(enumerate(x), total=len(x), ...)``.
+    The same applies to ``numpy.ndenumerate``.
+  * Replace ``tqdm(zip(a, b))`` with ``zip(tqdm(a), b)`` or even
+    ``zip(tqdm(a), tqdm(b))``.
+  * The same applies to ``itertools``.
+  * Some useful convenience functions can be found under ``tqdm.contrib``.
+
+- `No intermediate output in docker-compose `__:
+  use ``docker-compose run`` instead of ``docker-compose up`` and ``tty: true``.
+
+- Overriding defaults via environment variables:
+  e.g. in CI/cloud jobs, ``export TQDM_MININTERVAL=5`` to avoid log spam.
+  This override logic is handled by the ``tqdm.utils.envwrap`` decorator
+  (useful independent of ``tqdm``); a short sketch is shown below.
+
+If you come across any other difficulties, browse and file |GitHub-Issues|.
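+
+To illustrate the environment-variable overrides mentioned above, here is a
+minimal sketch (assuming the variable is set before ``tqdm`` is imported):
+
+.. code:: python
+
+    import os
+    os.environ["TQDM_MININTERVAL"] = "5"  # same effect as passing mininterval=5
+
+    from tqdm import tqdm
+    for _ in tqdm(range(10**6)):  # display refreshes at most every 5 seconds
+        pass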
+
+Documentation
+-------------
+
+|Py-Versions| |README-Hits| (Since 19 May 2016)
+
+.. code:: python
+
+    class tqdm():
+      """
+      Decorate an iterable object, returning an iterator which acts exactly
+      like the original iterable, but prints a dynamically updating
+      progressbar every time a value is requested.
+      """
+
+      @envwrap("TQDM_")  # override defaults via env vars
+      def __init__(self, iterable=None, desc=None, total=None, leave=True,
+                   file=None, ncols=None, mininterval=0.1,
+                   maxinterval=10.0, miniters=None, ascii=None, disable=False,
+                   unit='it', unit_scale=False, dynamic_ncols=False,
+                   smoothing=0.3, bar_format=None, initial=0, position=None,
+                   postfix=None, unit_divisor=1000, write_bytes=False,
+                   lock_args=None, nrows=None, colour=None, delay=0):
+
+Parameters
+~~~~~~~~~~
+
+* iterable  : iterable, optional  
+    Iterable to decorate with a progressbar.
+    Leave blank to manually manage the updates.
+* desc  : str, optional  
+    Prefix for the progressbar.
+* total  : int or float, optional  
+    The number of expected iterations. If unspecified,
+    len(iterable) is used if possible. If float("inf") or as a last
+    resort, only basic progress statistics are displayed
+    (no ETA, no progressbar).
+    If ``gui`` is True and this parameter needs subsequent updating,
+    specify an initial arbitrary large positive number,
+    e.g. 9e9.
+* leave  : bool, optional  
+    If [default: True], keeps all traces of the progressbar
+    upon termination of iteration.
+    If ``None``, will leave only if ``position`` is ``0``.
+* file  : ``io.TextIOWrapper`` or ``io.StringIO``, optional  
+    Specifies where to output the progress messages
+    (default: sys.stderr). Uses ``file.write(str)`` and ``file.flush()``
+    methods.  For encoding, see ``write_bytes``.
+* ncols  : int, optional  
+    The width of the entire output message. If specified,
+    dynamically resizes the progressbar to stay within this bound.
+    If unspecified, attempts to use environment width. The
+    fallback is a meter width of 10 and no limit for the counter and
+    statistics. If 0, will not print any meter (only stats).
+* mininterval  : float, optional  
+    Minimum progress display update interval [default: 0.1] seconds.
+* maxinterval  : float, optional  
+    Maximum progress display update interval [default: 10] seconds.
+    Automatically adjusts ``miniters`` to correspond to ``mininterval``
+    after long display update lag. Only works if ``dynamic_miniters``
+    or monitor thread is enabled.
+* miniters  : int or float, optional  
+    Minimum progress display update interval, in iterations.
+    If 0 and ``dynamic_miniters``, will automatically adjust to equal
+    ``mininterval`` (more CPU efficient, good for tight loops).
+    If > 0, will skip display of specified number of iterations.
+    Tweak this and ``mininterval`` to get very efficient loops.
+    If your progress is erratic with both fast and slow iterations
+    (network, skipping items, etc) you should set miniters=1.
+* ascii  : bool or str, optional  
+    If unspecified or False, use unicode (smooth blocks) to fill
+    the meter. The fallback is to use ASCII characters " 123456789#".
+* disable  : bool, optional  
+    Whether to disable the entire progressbar wrapper
+    [default: False]. If set to None, disable on non-TTY.
+* unit  : str, optional  
+    String that will be used to define the unit of each iteration
+    [default: it].
+* unit_scale  : bool or int or float, optional  
+    If 1 or True, the number of iterations will be reduced/scaled
+    automatically and a metric prefix following the
+    International System of Units standard will be added
+    (kilo, mega, etc.) [default: False]. If any other non-zero
+    number, will scale ``total`` and ``n``.
+* dynamic_ncols  : bool, optional  
+    If set, constantly alters ``ncols`` and ``nrows`` to the
+    environment (allowing for window resizes) [default: False].
+* smoothing  : float, optional  
+    Exponential moving average smoothing factor for speed estimates
+    (ignored in GUI mode). Ranges from 0 (average speed) to 1
+    (current/instantaneous speed) [default: 0.3].
+* bar_format  : str, optional  
+    Specify a custom bar string formatting. May impact performance.
+    [default: '{l_bar}{bar}{r_bar}'], where
+    l_bar='{desc}: {percentage:3.0f}%|' and
+    r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
+    '{rate_fmt}{postfix}]'
+    Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
+    percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
+    rate, rate_fmt, rate_noinv, rate_noinv_fmt,
+    rate_inv, rate_inv_fmt, postfix, unit_divisor,
+    remaining, remaining_s, eta.
+    Note that a trailing ": " is automatically removed after {desc}
+    if the latter is empty.
+* initial  : int or float, optional  
+    The initial counter value. Useful when restarting a progress
+    bar [default: 0]. If using float, consider specifying ``{n:.3f}``
+    or similar in ``bar_format``, or specifying ``unit_scale``.
+* position  : int, optional  
+    Specify the line offset to print this bar (starting from 0).
+    Automatic if unspecified.
+    Useful to manage multiple bars at once (eg, from threads).
+* postfix  : dict or ``*``, optional  
+    Specify additional stats to display at the end of the bar.
+    Calls ``set_postfix(**postfix)`` if possible (dict).
+* unit_divisor  : float, optional  
+    [default: 1000], ignored unless ``unit_scale`` is True.
+* write_bytes  : bool, optional  
+    Whether to write bytes. If (default: False) will write unicode.
+* lock_args  : tuple, optional  
+    Passed to ``refresh`` for intermediate output
+    (initialisation, iterating, and updating).
+* nrows  : int, optional  
+    The screen height. If specified, hides nested bars outside this
+    bound. If unspecified, attempts to use environment height.
+    The fallback is 20.
+* colour  : str, optional  
+    Bar colour (e.g. 'green', '#00ff00').
+* delay  : float, optional  
+    Don't display until [default: 0] seconds have elapsed.
+
+Extra CLI Options
+~~~~~~~~~~~~~~~~~
+
+* delim  : chr, optional  
+    Delimiting character [default: '\n']. Use '\0' for null.
+    N.B.: on Windows systems, Python converts '\n' to '\r\n'.
+* buf_size  : int, optional  
+    String buffer size in bytes [default: 256]
+    used when ``delim`` is specified.
+* bytes  : bool, optional  
+    If true, will count bytes, ignore ``delim``, and default
+    ``unit_scale`` to True, ``unit_divisor`` to 1024, and ``unit`` to 'B'.
+* tee  : bool, optional  
+    If true, passes ``stdin`` to both ``stderr`` and ``stdout``.
+* update  : bool, optional  
+    If true, will treat input as newly elapsed iterations,
+    i.e. numbers to pass to ``update()``. Note that this is slow
+    (~2e5 it/s) since every input must be decoded as a number.
+* update_to  : bool, optional  
+    If true, will treat input as total elapsed iterations,
+    i.e. numbers to assign to ``self.n``. Note that this is slow
+    (~2e5 it/s) since every input must be decoded as a number.
+* null  : bool, optional  
+    If true, will discard input (no stdout).
+* manpath  : str, optional  
+    Directory in which to install tqdm man pages.
+* comppath  : str, optional  
+    Directory in which to place tqdm completion.
+* log  : str, optional  
+    CRITICAL|FATAL|ERROR|WARN(ING)|[default: 'INFO']|DEBUG|NOTSET.
+
+Returns
+~~~~~~~
+
+* out  : decorated iterator.  
+
+.. code:: python
+
+    class tqdm():
+      def update(self, n=1):
+          """
+          Manually update the progress bar, useful for streams
+          such as reading files.
+          E.g.:
+          >>> t = tqdm(total=filesize) # Initialise
+          >>> for current_buffer in stream:
+          ...    ...
+          ...    t.update(len(current_buffer))
+          >>> t.close()
+          The last line is highly recommended, but possibly not necessary if
+          ``t.update()`` will be called in such a way that ``filesize`` will be
+          exactly reached and printed.
+
+          Parameters
+          ----------
+          n  : int or float, optional
+              Increment to add to the internal counter of iterations
+              [default: 1]. If using float, consider specifying ``{n:.3f}``
+              or similar in ``bar_format``, or specifying ``unit_scale``.
+
+          Returns
+          -------
+          out  : bool or None
+              True if a ``display()`` was triggered.
+          """
+
+      def close(self):
+          """Cleanup and (if leave=False) close the progressbar."""
+
+      def clear(self, nomove=False):
+          """Clear current bar display."""
+
+      def refresh(self, nolock=False, lock_args=None):
+          """
+          Force refresh the display of this bar.
+
+          Parameters
+          ----------
+          nolock  : bool, optional
+              If ``True``, does not lock.
+              If [default: ``False``]: calls ``acquire()`` on internal lock.
+          lock_args  : tuple, optional
+              Passed to internal lock's ``acquire()``.
+              If specified, will only ``display()`` if ``acquire()`` returns ``True``.
+          """
+
+      def unpause(self):
+          """Restart tqdm timer from last print time."""
+
+      def reset(self, total=None):
+          """
+          Resets to 0 iterations for repeated use.
+
+          Consider combining with ``leave=True``.
+
+          Parameters
+          ----------
+          total  : int or float, optional. Total to use for the new bar.
+          """
+
+      def set_description(self, desc=None, refresh=True):
+          """
+          Set/modify description of the progress bar.
+
+          Parameters
+          ----------
+          desc  : str, optional
+          refresh  : bool, optional
+              Forces refresh [default: True].
+          """
+
+      def set_postfix(self, ordered_dict=None, refresh=True, **tqdm_kwargs):
+          """
+          Set/modify postfix (additional stats)
+          with automatic formatting based on datatype.
+
+          Parameters
+          ----------
+          ordered_dict  : dict or OrderedDict, optional
+          refresh  : bool, optional
+              Forces refresh [default: True].
+          kwargs  : dict, optional
+          """
+
+      @classmethod
+      def write(cls, s, file=sys.stdout, end="\n"):
+          """Print a message via tqdm (without overlap with bars)."""
+
+      @property
+      def format_dict(self):
+          """Public API for read-only member access."""
+
+      def display(self, msg=None, pos=None):
+          """
+          Use ``self.sp`` to display ``msg`` in the specified ``pos``.
+
+          Consider overloading this function when inheriting to use e.g.:
+          ``self.some_frontend(**self.format_dict)`` instead of ``self.sp``.
+
+          Parameters
+          ----------
+          msg  : str, optional. What to display (default: ``repr(self)``).
+          pos  : int, optional. Position to ``moveto``
+            (default: ``abs(self.pos)``).
+          """
+
+      @classmethod
+      @contextmanager
+      def wrapattr(cls, stream, method, total=None, bytes=True, **tqdm_kwargs):
+          """
+          stream  : file-like object.
+          method  : str, "read" or "write". The result of ``read()`` and
+              the first argument of ``write()`` should have a ``len()``.
+
+          >>> with tqdm.wrapattr(file_obj, "read", total=file_obj.size) as fobj:
+          ...     while True:
+          ...         chunk = fobj.read(chunk_size)
+          ...         if not chunk:
+          ...             break
+          """
+
+      @classmethod
+      def pandas(cls, *targs, **tqdm_kwargs):
+          """Registers the current `tqdm` class with `pandas`."""
+
+    def trange(*args, **tqdm_kwargs):
+        """Shortcut for `tqdm(range(*args), **tqdm_kwargs)`."""
+
+Convenience Functions
+~~~~~~~~~~~~~~~~~~~~~
+
+.. code:: python
+
+    def tqdm.contrib.tenumerate(iterable, start=0, total=None,
+                                tqdm_class=tqdm.auto.tqdm, **tqdm_kwargs):
+        """Equivalent of `numpy.ndenumerate` or builtin `enumerate`."""
+
+    def tqdm.contrib.tzip(iter1, *iter2plus, **tqdm_kwargs):
+        """Equivalent of builtin `zip`."""
+
+    def tqdm.contrib.tmap(function, *sequences, **tqdm_kwargs):
+        """Equivalent of builtin `map`."""
+
+Submodules
+~~~~~~~~~~
+
+.. code:: python
+
+    class tqdm.notebook.tqdm(tqdm.tqdm):
+        """IPython/Jupyter Notebook widget."""
+
+    class tqdm.auto.tqdm(tqdm.tqdm):
+        """Automatically chooses beween `tqdm.notebook` and `tqdm.tqdm`."""
+
+    class tqdm.asyncio.tqdm(tqdm.tqdm):
+      """Asynchronous version."""
+      @classmethod
+      def as_completed(cls, fs, *, loop=None, timeout=None, total=None,
+                       **tqdm_kwargs):
+          """Wrapper for `asyncio.as_completed`."""
+
+    class tqdm.gui.tqdm(tqdm.tqdm):
+        """Matplotlib GUI version."""
+
+    class tqdm.tk.tqdm(tqdm.tqdm):
+        """Tkinter GUI version."""
+
+    class tqdm.rich.tqdm(tqdm.tqdm):
+        """`rich.progress` version."""
+
+    class tqdm.keras.TqdmCallback(keras.callbacks.Callback):
+        """Keras callback for epoch and batch progress."""
+
+    class tqdm.dask.TqdmCallback(dask.callbacks.Callback):
+        """Dask callback for task progress."""
+
+
+``contrib``
++++++++++++
+
+The ``tqdm.contrib`` package also contains experimental modules:
+
+- ``tqdm.contrib.itertools``: Thin wrappers around ``itertools``
+- ``tqdm.contrib.concurrent``: Thin wrappers around ``concurrent.futures``
+- ``tqdm.contrib.slack``: Posts to `Slack `__ bots
+- ``tqdm.contrib.discord``: Posts to `Discord `__ bots
+- ``tqdm.contrib.telegram``: Posts to `Telegram `__ bots
+- ``tqdm.contrib.bells``: Automagically enables all optional features
+
+  * ``auto``, ``pandas``, ``slack``, ``discord``, ``telegram``
+
+Examples and Advanced Usage
+---------------------------
+
+- See the `examples `__
+  folder;
+- import the module and run ``help()``;
+- consult the `wiki `__;
+
+  * this has an
+    `excellent article `__
+    on how to make a **great** progressbar;
+
+- check out the `slides from PyData London `__, or
+- run the |binder-demo|.
+
+Description and additional stats
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Custom information can be displayed and updated dynamically on ``tqdm`` bars
+with the ``desc`` and ``postfix`` arguments:
+
+.. code:: python
+
+    from tqdm import tqdm, trange
+    from random import random, randint
+    from time import sleep
+
+    with trange(10) as t:
+        for i in t:
+            # Description will be displayed on the left
+            t.set_description('GEN %i' % i)
+            # Postfix will be displayed on the right,
+            # formatted automatically based on argument's datatype
+            t.set_postfix(loss=random(), gen=randint(1,999), str='h',
+                          lst=[1, 2])
+            sleep(0.1)
+
+    with tqdm(total=10, bar_format="{postfix[0]} {postfix[1][value]:>8.2g}",
+              postfix=["Batch", {"value": 0}]) as t:
+        for i in range(10):
+            sleep(0.1)
+            t.postfix[1]["value"] = i / 2
+            t.update()
+
+Points to remember when using ``{postfix[...]}`` in the ``bar_format`` string:
+
+- ``postfix`` also needs to be passed as an initial argument in a compatible
+  format, and
+- ``postfix`` will be auto-converted to a string if it is a ``dict``-like
+  object. To prevent this behaviour, insert an extra item into the dictionary
+  where the key is not a string.
+
+Additional ``bar_format`` parameters may also be defined by overriding
+``format_dict``, and the bar itself may be modified using ``ascii``:
+
+.. code:: python
+
+    from tqdm import tqdm
+    class TqdmExtraFormat(tqdm):
+        """Provides a `total_time` format parameter"""
+        @property
+        def format_dict(self):
+            d = super().format_dict
+            total_time = d["elapsed"] * (d["total"] or 0) / max(d["n"], 1)
+            d.update(total_time=self.format_interval(total_time) + " in total")
+            return d
+
+    for i in TqdmExtraFormat(
+          range(9), ascii=" .oO0",
+          bar_format="{total_time}: {percentage:.0f}%|{bar}{r_bar}"):
+        if i == 4:
+            break
+
+.. code::
+
+    00:00 in total: 44%|0000.     | 4/9 [00:00<00:00, 962.93it/s]
+
+Note that ``{bar}`` also supports a format specifier ``[width][type]``.
+
+- ``width``
+
+  * unspecified (default): automatic to fill ``ncols``
+  * ``int >= 0``: fixed width overriding ``ncols`` logic
+  * ``int < 0``: subtract from the automatic default
+
+- ``type``
+
+  * ``a``: ascii (``ascii=True`` override)
+  * ``u``: unicode (``ascii=False`` override)
+  * ``b``: blank (``ascii="  "`` override)
+
+This means a fixed bar with right-justified text may be created by using:
+``bar_format="{l_bar}{bar:10}|{bar:-10b}right-justified"``
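+
+As a further minimal sketch (illustrative), a 10-character bar forced to ASCII
+regardless of terminal unicode support:
+
+.. code:: python
+
+    from time import sleep
+    from tqdm import tqdm
+
+    for _ in tqdm(range(100), bar_format="{l_bar}{bar:10a}{r_bar}"):
+        sleep(0.01)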
+
+Nested progress bars
+~~~~~~~~~~~~~~~~~~~~
+
+``tqdm`` supports nested progress bars. Here's an example:
+
+.. code:: python
+
+    from tqdm.auto import trange
+    from time import sleep
+
+    for i in trange(4, desc='1st loop'):
+        for j in trange(5, desc='2nd loop'):
+            for k in trange(50, desc='3rd loop', leave=False):
+                sleep(0.01)
+
+For manual control over positioning (e.g. for multi-processing use),
+you may specify ``position=n`` where ``n=0`` for the outermost bar,
+``n=1`` for the next, and so on.
+However, it's best to check if ``tqdm`` can work without manual ``position``
+first.
+
+.. code:: python
+
+    from time import sleep
+    from tqdm import trange, tqdm
+    from multiprocessing import Pool, RLock, freeze_support
+
+    L = list(range(9))
+
+    def progresser(n):
+        interval = 0.001 / (n + 2)
+        total = 5000
+        text = f"#{n}, est. {interval * total:<04.2}s"
+        for _ in trange(total, desc=text, position=n):
+            sleep(interval)
+
+    if __name__ == '__main__':
+        freeze_support()  # for Windows support
+        tqdm.set_lock(RLock())  # for managing output contention
+        p = Pool(initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),))
+        p.map(progresser, L)
+
+Note that in Python 3, ``tqdm.write`` is thread-safe:
+
+.. code:: python
+
+    from time import sleep
+    from tqdm import tqdm, trange
+    from concurrent.futures import ThreadPoolExecutor
+
+    L = list(range(9))
+
+    def progresser(n):
+        interval = 0.001 / (n + 2)
+        total = 5000
+        text = f"#{n}, est. {interval * total:<04.2}s"
+        for _ in trange(total, desc=text):
+            sleep(interval)
+        if n == 6:
+            tqdm.write("n == 6 completed.")
+            tqdm.write("`tqdm.write()` is thread-safe in py3!")
+
+    if __name__ == '__main__':
+        with ThreadPoolExecutor() as p:
+            p.map(progresser, L)
+
+Hooks and callbacks
+~~~~~~~~~~~~~~~~~~~
+
+``tqdm`` can easily support callbacks/hooks and manual updates.
+Here's an example with ``urllib``:
+
+**``urllib.urlretrieve`` documentation**
+
+    | [...]
+    | If present, the hook function will be called once
+    | on establishment of the network connection and once after each block read
+    | thereafter. The hook will be passed three arguments; a count of blocks
+    | transferred so far, a block size in bytes, and the total size of the file.
+    | [...]
+
+.. code:: python
+
+    import urllib, os
+    from tqdm import tqdm
+    urllib = getattr(urllib, 'request', urllib)
+
+    class TqdmUpTo(tqdm):
+        """Provides `update_to(n)` which uses `tqdm.update(delta_n)`."""
+        def update_to(self, b=1, bsize=1, tsize=None):
+            """
+            b  : int, optional
+                Number of blocks transferred so far [default: 1].
+            bsize  : int, optional
+                Size of each block (in tqdm units) [default: 1].
+            tsize  : int, optional
+                Total size (in tqdm units). If [default: None] remains unchanged.
+            """
+            if tsize is not None:
+                self.total = tsize
+            return self.update(b * bsize - self.n)  # also sets self.n = b * bsize
+
+    eg_link = "https://caspersci.uk.to/matryoshka.zip"
+    with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
+                  desc=eg_link.split('/')[-1]) as t:  # all optional kwargs
+        urllib.urlretrieve(eg_link, filename=os.devnull,
+                           reporthook=t.update_to, data=None)
+        t.total = t.n
+
+Inspired by `twine#242 `__.
+Functional alternative in
+`examples/tqdm_wget.py `__.
+
+It is recommended to use ``miniters=1`` whenever there are potentially
+large differences in iteration speed (e.g. downloading a file over
+a patchy connection).
+
+**Wrapping read/write methods**
+
+To measure throughput through a file-like object's ``read`` or ``write``
+methods, use ``CallbackIOWrapper``:
+
+.. code:: python
+
+    from tqdm.auto import tqdm
+    from tqdm.utils import CallbackIOWrapper
+
+    with tqdm(total=file_obj.size,
+              unit='B', unit_scale=True, unit_divisor=1024) as t:
+        fobj = CallbackIOWrapper(t.update, file_obj, "read")
+        while True:
+            chunk = fobj.read(chunk_size)
+            if not chunk:
+                break
+        t.reset()
+        # ... continue to use `t` for something else
+
+Alternatively, use the even simpler ``wrapattr`` convenience function,
+which would condense both the ``urllib`` and ``CallbackIOWrapper`` examples
+down to:
+
+.. code:: python
+
+    import urllib, os
+    from tqdm import tqdm
+
+    eg_link = "https://caspersci.uk.to/matryoshka.zip"
+    response = getattr(urllib, 'request', urllib).urlopen(eg_link)
+    with tqdm.wrapattr(open(os.devnull, "wb"), "write",
+                       miniters=1, desc=eg_link.split('/')[-1],
+                       total=getattr(response, 'length', None)) as fout:
+        for chunk in response:
+            fout.write(chunk)
+
+The ``requests`` equivalent is nearly identical:
+
+.. code:: python
+
+    import requests, os
+    from tqdm import tqdm
+
+    eg_link = "https://caspersci.uk.to/matryoshka.zip"
+    response = requests.get(eg_link, stream=True)
+    with tqdm.wrapattr(open(os.devnull, "wb"), "write",
+                       miniters=1, desc=eg_link.split('/')[-1],
+                       total=int(response.headers.get('content-length', 0))) as fout:
+        for chunk in response.iter_content(chunk_size=4096):
+            fout.write(chunk)
+
+**Custom callback**
+
+``tqdm`` is known for intelligently skipping unnecessary displays. To make a
+custom callback take advantage of this, simply use the return value of
+``update()``. This is set to ``True`` if a ``display()`` was triggered.
+
+.. code:: python
+
+    from tqdm.auto import tqdm as std_tqdm
+
+    def external_callback(*args, **kwargs):
+        ...
+
+    class TqdmExt(std_tqdm):
+        def update(self, n=1):
+            displayed = super().update(n)
+            if displayed:
+                external_callback(**self.format_dict)
+            return displayed
+
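+As a quick usage sketch (``TqdmExt`` being the illustrative subclass defined
+above), the wrapper is used exactly like a regular ``tqdm``:
+
+.. code:: python
+
+    for _ in TqdmExt(range(100), mininterval=0.5):
+        ...  # `external_callback` only fires when a display was actually drawn
+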
+``asyncio``
+~~~~~~~~~~~
+
+Note that ``break`` isn't currently caught by asynchronous iterators.
+This means that ``tqdm`` cannot clean up after itself in this case:
+
+.. code:: python
+
+    from tqdm.asyncio import tqdm
+
+    async for i in tqdm(range(9)):
+        if i == 2:
+            break
+
+Instead, either call ``pbar.close()`` manually or use the context manager syntax:
+
+.. code:: python
+
+    from tqdm.asyncio import tqdm
+
+    with tqdm(range(9)) as pbar:
+        async for i in pbar:
+            if i == 2:
+                break
+
+Pandas Integration
+~~~~~~~~~~~~~~~~~~
+
+Due to popular demand we've added support for ``pandas`` -- here's an example
+for ``DataFrame.progress_apply`` and ``DataFrameGroupBy.progress_apply``:
+
+.. code:: python
+
+    import pandas as pd
+    import numpy as np
+    from tqdm import tqdm
+
+    df = pd.DataFrame(np.random.randint(0, 100, (100000, 6)))
+
+    # Register `pandas.DataFrame.progress_apply` and `pandas.Series.progress_map` with `tqdm`
+    # (can use `tqdm.gui.tqdm`, `tqdm.notebook.tqdm`, optional kwargs, etc.)
+    tqdm.pandas(desc="my bar!")
+
+    # Now you can use `progress_apply` instead of `apply`
+    # and `progress_map` instead of `map`
+    df.progress_apply(lambda x: x**2)
+    # can also groupby:
+    # df.groupby(0).progress_apply(lambda x: x**2)
+
+In case you're interested in how this works (and how to modify it for your
+own callbacks), see the
+`examples `__
+folder or import the module and run ``help()``.
+
+Keras Integration
+~~~~~~~~~~~~~~~~~
+
+A ``keras`` callback is also available:
+
+.. code:: python
+
+    from tqdm.keras import TqdmCallback
+
+    ...
+
+    model.fit(..., verbose=0, callbacks=[TqdmCallback()])
+
+Dask Integration
+~~~~~~~~~~~~~~~~
+
+A ``dask`` callback is also available:
+
+.. code:: python
+
+    from tqdm.dask import TqdmCallback
+
+    with TqdmCallback(desc="compute"):
+        ...
+        arr.compute()
+
+    # or use callback globally
+    cb = TqdmCallback(desc="global")
+    cb.register()
+    arr.compute()
+
+IPython/Jupyter Integration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+IPython/Jupyter is supported via the ``tqdm.notebook`` submodule:
+
+.. code:: python
+
+    from tqdm.notebook import trange, tqdm
+    from time import sleep
+
+    for i in trange(3, desc='1st loop'):
+        for j in tqdm(range(100), desc='2nd loop'):
+            sleep(0.01)
+
+In addition to ``tqdm`` features, the submodule provides a native Jupyter
+widget (compatible with IPython v1-v4 and Jupyter), fully working nested bars
+and colour hints (blue: normal, green: completed, red: error/interrupt,
+light blue: no ETA), as demonstrated below.
+
+|Screenshot-Jupyter1|
+|Screenshot-Jupyter2|
+|Screenshot-Jupyter3|
+
+The ``notebook`` version supports percentage or pixels for overall width
+(e.g.: ``ncols='100%'`` or ``ncols='480px'``).
+
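+For instance (a minimal sketch; the width value here is arbitrary):
+
+.. code:: python
+
+    from tqdm.notebook import tqdm
+
+    for _ in tqdm(range(100), ncols='70%'):  # or e.g. ncols='480px'
+        pass
+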
+It is also possible to let ``tqdm`` automatically choose between
+console or notebook versions by using the ``autonotebook`` submodule:
+
+.. code:: python
+
+    from tqdm.autonotebook import tqdm
+    tqdm.pandas()
+
+Note that this will issue a ``TqdmExperimentalWarning`` if run in a notebook
+since it is not possible to reliably distinguish between ``jupyter notebook``
+and ``jupyter console``. Use ``auto`` instead of ``autonotebook`` to suppress
+this warning.
+
+Note that notebooks will display the bar in the cell where it was created.
+This may be a different cell from the one where it is used.
+If this is not desired, either
+
+- delay the creation of the bar to the cell where it must be displayed, or
+- create the bar with ``display=False``, and in a later cell call
+  ``display(bar.container)``:
+
+.. code:: python
+
+    from tqdm.notebook import tqdm
+    pbar = tqdm(..., display=False)
+
+.. code:: python
+
+    # different cell
+    display(pbar.container)
+
+The ``keras`` callback has a ``display()`` method which can be used likewise:
+
+.. code:: python
+
+    from tqdm.keras import TqdmCallback
+    cbk = TqdmCallback(display=False)
+
+.. code:: python
+
+    # different cell
+    cbk.display()
+    model.fit(..., verbose=0, callbacks=[cbk])
+
+Another possibility is to have a single bar (near the top of the notebook)
+which is constantly re-used (using ``reset()`` rather than ``close()``).
+For this reason, the notebook version (unlike the CLI version) does not
+automatically call ``close()`` upon ``Exception``.
+
+.. code:: python
+
+    from tqdm.notebook import tqdm
+    pbar = tqdm()
+
+.. code:: python
+
+    # different cell
+    iterable = range(100)
+    pbar.reset(total=len(iterable))  # initialise with new `total`
+    for i in iterable:
+        pbar.update()
+    pbar.refresh()  # force print final status but don't `close()`
+
+Custom Integration
+~~~~~~~~~~~~~~~~~~
+
+To change the default arguments (such as making ``dynamic_ncols=True``),
+simply use built-in Python magic:
+
+.. code:: python
+
+    from functools import partial
+    from tqdm import tqdm as std_tqdm
+    tqdm = partial(std_tqdm, dynamic_ncols=True)
+
+For further customisation,
+``tqdm`` may be inherited from to create custom callbacks (as with the
+``TqdmUpTo`` example `above <#hooks-and-callbacks>`__) or for custom frontends
+(e.g. GUIs such as notebook or plotting packages). In the latter case:
+
+1. Define ``__init__()`` to call ``super().__init__(..., gui=True)`` to disable
+   terminal ``status_printer`` creation.
+2. Redefine: ``close()``, ``clear()``, ``display()``.
+
+Consider overloading ``display()`` to use e.g.
+``self.frontend(**self.format_dict)`` instead of ``self.sp(repr(self))``.
+
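+A minimal sketch of such a frontend subclass is given below; the ``frontend()``
+hook is a stand-in for a real GUI/widget update and is not part of ``tqdm``
+itself:
+
+.. code:: python
+
+    from tqdm.std import tqdm as std_tqdm
+
+    class TqdmFrontend(std_tqdm):
+        """Forward progress to a custom frontend instead of the terminal."""
+        def __init__(self, *args, **kwargs):
+            kwargs['gui'] = True  # skip terminal `status_printer` creation
+            super().__init__(*args, **kwargs)
+
+        def frontend(self, n=0, total=None, **_):
+            print(f"frontend update: {n}/{total}")  # stand-in for a widget call
+
+        def display(self, *args, **kwargs):
+            self.frontend(**self.format_dict)
+
+        def clear(self, *args, **kwargs):
+            pass  # nothing to erase in this toy frontend
+
+        def close(self):
+            if not self.disable:
+                self.display()  # show the final state
+            super().close()     # let `tqdm` handle its own bookkeeping
+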
+Some submodule examples of inheritance:
+
+- `tqdm/notebook.py `__
+- `tqdm/gui.py `__
+- `tqdm/tk.py `__
+- `tqdm/contrib/slack.py `__
+- `tqdm/contrib/discord.py `__
+- `tqdm/contrib/telegram.py `__
+
+Dynamic Monitor/Meter
+~~~~~~~~~~~~~~~~~~~~~
+
+You can use a ``tqdm`` as a meter which is not monotonically increasing.
+This could be because ``n`` decreases (e.g. a CPU usage monitor) or ``total``
+changes.
+
+One example would be recursively searching for files. The ``total`` is the
+number of objects found so far, while ``n`` is the number of those objects which
+are files (rather than folders):
+
+.. code:: python
+
+    from tqdm import tqdm
+    import os.path
+
+    def find_files_recursively(path, show_progress=True):
+        files = []
+        # total=1 assumes `path` is a file
+        t = tqdm(total=1, unit="file", disable=not show_progress)
+        if not os.path.exists(path):
+            raise IOError("Cannot find:" + path)
+
+        def append_found_file(f):
+            files.append(f)
+            t.update()
+
+        def list_found_dir(path):
+            """returns os.listdir(path) assuming os.path.isdir(path)"""
+            listing = os.listdir(path)
+            # subtract 1 since a "file" we found was actually this directory
+            t.total += len(listing) - 1
+            # fancy way to give info without forcing a refresh
+            t.set_postfix(dir=path[-10:], refresh=False)
+            t.update(0)  # may trigger a refresh
+            return listing
+
+        def recursively_search(path):
+            if os.path.isdir(path):
+                for f in list_found_dir(path):
+                    recursively_search(os.path.join(path, f))
+            else:
+                append_found_file(path)
+
+        recursively_search(path)
+        t.set_postfix(dir=path)
+        t.close()
+        return files
+
+Using ``update(0)`` is a handy way to let ``tqdm`` decide when to trigger a
+display refresh to avoid console spamming.
+
+Writing messages
+~~~~~~~~~~~~~~~~
+
+This is a work in progress (see
+`#737 `__).
+
+Since ``tqdm`` uses a simple printing mechanism to display progress bars,
+you should not write any message in the terminal using ``print()`` while
+a progressbar is open.
+
+To write messages in the terminal without any collision with ``tqdm`` bar
+display, a ``.write()`` method is provided:
+
+.. code:: python
+
+    from tqdm.auto import tqdm, trange
+    from time import sleep
+
+    bar = trange(10)
+    for i in bar:
+        # Print using tqdm class method .write()
+        sleep(0.1)
+        if not (i % 3):
+            tqdm.write("Done task %i" % i)
+        # Can also use bar.write()
+
+By default, this will print to standard output ``sys.stdout``, but you can
+specify any file-like object using the ``file`` argument. For example, this
+can be used to redirect the messages to a log file or class.
+
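+For example (a small sketch; ``progress.log`` is just an arbitrary file name):
+
+.. code:: python
+
+    from tqdm import tqdm
+
+    with open("progress.log", "w") as fp:
+        for i in tqdm(range(3)):
+            tqdm.write(f"finished task {i}", file=fp)  # log file, not the console
+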
+Redirecting writing
+~~~~~~~~~~~~~~~~~~~
+
+If using a library that can print messages to the console, editing the library
+by replacing ``print()`` with ``tqdm.write()`` may not be desirable.
+In that case, redirecting ``sys.stdout`` to ``tqdm.write()`` is an option.
+
+To redirect ``sys.stdout``, create a file-like class that will write
+any input string to ``tqdm.write()``, and supply the arguments
+``file=sys.stdout, dynamic_ncols=True``.
+
+A reusable canonical example is given below:
+
+.. code:: python
+
+    from time import sleep
+    import contextlib
+    import sys
+    from tqdm import tqdm
+    from tqdm.contrib import DummyTqdmFile
+
+
+    @contextlib.contextmanager
+    def std_out_err_redirect_tqdm():
+        orig_out_err = sys.stdout, sys.stderr
+        try:
+            sys.stdout, sys.stderr = map(DummyTqdmFile, orig_out_err)
+            yield orig_out_err[0]
+        # Relay exceptions
+        except Exception as exc:
+            raise exc
+        # Always restore sys.stdout/err if necessary
+        finally:
+            sys.stdout, sys.stderr = orig_out_err
+
+    def some_fun(i):
+        print("Fee, fi, fo,".split()[i])
+
+    # Redirect stdout to tqdm.write() (don't forget the `as save_stdout`)
+    with std_out_err_redirect_tqdm() as orig_stdout:
+        # tqdm needs the original stdout
+        # and dynamic_ncols=True to autodetect console width
+        for i in tqdm(range(3), file=orig_stdout, dynamic_ncols=True):
+            sleep(.5)
+            some_fun(i)
+
+    # After the `with`, printing is restored
+    print("Done!")
+
+Redirecting ``logging``
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Similar to ``sys.stdout``/``sys.stderr`` as detailed above, console ``logging``
+may also be redirected to ``tqdm.write()``.
+
+Warning: if also redirecting ``sys.stdout``/``sys.stderr``, make sure to
+redirect ``logging`` first if needed.
+
+Helper methods are available in ``tqdm.contrib.logging``. For example:
+
+.. code:: python
+
+    import logging
+    from tqdm import trange
+    from tqdm.contrib.logging import logging_redirect_tqdm
+
+    LOG = logging.getLogger(__name__)
+
+    if __name__ == '__main__':
+        logging.basicConfig(level=logging.INFO)
+        with logging_redirect_tqdm():
+            for i in trange(9):
+                if i == 4:
+                    LOG.info("console logging redirected to `tqdm.write()`")
+        # logging restored
+
+Monitoring thread, intervals and miniters
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+``tqdm`` implements a few tricks to increase efficiency and reduce overhead.
+
+- Avoid unnecessary frequent bar refreshing: ``mininterval`` defines how long
+  to wait between each refresh. ``tqdm`` always gets updated in the background,
+  but it will display only every ``mininterval``.
+- Reduce number of calls to check system clock/time.
+- ``mininterval`` is more intuitive to configure than ``miniters``.
+  A clever adjustment system, ``dynamic_miniters``, will automatically adjust
+  ``miniters`` to the number of iterations that fit within ``mininterval``.
+  Essentially, ``tqdm`` will check if it's time to print without actually
+  checking time. This behaviour can still be bypassed by manually setting
+  ``miniters``.
+
+However, consider a case with a combination of fast and slow iterations.
+After a few fast iterations, ``dynamic_miniters`` will set ``miniters`` to a
+large number. When iteration rate subsequently slows, ``miniters`` will
+remain large and thus reduce display update frequency. To address this:
+
+- ``maxinterval`` defines the maximum time between display refreshes.
+  A concurrent monitoring thread checks for overdue updates and forces one
+  where necessary.
+
+The monitoring thread should not have a noticeable overhead, and guarantees
+updates at least every 10 seconds by default.
+This value can be directly changed by setting the ``monitor_interval`` of
+any ``tqdm`` instance (i.e. ``t = tqdm.tqdm(...); t.monitor_interval = 2``).
+The monitor thread may be disabled application-wide by setting
+``tqdm.tqdm.monitor_interval = 0`` before instantiation of any ``tqdm`` bar.
+
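+A small illustration of these knobs (the values below are arbitrary):
+
+.. code:: python
+
+    from tqdm import tqdm
+
+    # refresh at most every 0.5s; if iterations stall, the monitoring thread
+    # still forces a refresh at least every 2s
+    for _ in tqdm(range(10**6), mininterval=0.5, maxinterval=2):
+        pass
+
+    # disable the monitoring thread application-wide
+    # (must be set before any bar is created)
+    tqdm.monitor_interval = 0
+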
+
+Merch
+-----
+
+You can buy `tqdm branded merch `__ now!
+
+Contributions
+-------------
+
+|GitHub-Commits| |GitHub-Issues| |GitHub-PRs| |OpenHub-Status| |GitHub-Contributions| |CII Best Practices|
+
+All source code is hosted on `GitHub `__.
+Contributions are welcome.
+
+See the
+`CONTRIBUTING `__
+file for more information.
+
+Developers who have made significant contributions, ranked by *SLoC*
+(surviving lines of code,
+`git fame `__ ``-wMC --excl '\.(png|gif|jpg)$'``),
+are:
+
+==================== ======================================================== ==== ================================
+Name                 ID                                                       SLoC Notes
+==================== ======================================================== ==== ================================
+Casper da Costa-Luis `casperdcl `__             ~80% primary maintainer |Gift-Casper|
+Stephen Larroque     `lrq3000 `__                 ~9%  team member
+Martin Zugnoni       `martinzugnoni `__     ~3%
+Daniel Ecer          `de-code `__                 ~2%
+Richard Sheridan     `richardsheridan `__ ~1%
+Guangshuo Chen       `chengs `__                   ~1%
+Helio Machado        `0x2b3bfa0 `__             ~1%
+Kyle Altendorf       `altendky `__               <1%
+Noam Yorav-Raphael   `noamraph `__               <1%  original author
+Matthew Stevens      `mjstevens777 `__       <1%
+Hadrien Mary         `hadim `__                     <1%  team member
+Mikhail Korobov      `kmike `__                     <1%  team member
+==================== ======================================================== ==== ================================
+
+Ports to Other Languages
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+A list is available on
+`this wiki page `__.
+
+
+LICENCE
+-------
+
+Open Source (OSI approved): |LICENCE|
+
+Citation information: |DOI|
+
+|README-Hits| (Since 19 May 2016)
+
+.. |Logo| image:: https://tqdm.github.io/img/logo.gif
+.. |Screenshot| image:: https://tqdm.github.io/img/tqdm.gif
+.. |Video| image:: https://tqdm.github.io/img/video.jpg
+   :target: https://tqdm.github.io/video
+.. |Slides| image:: https://tqdm.github.io/img/slides.jpg
+   :target: https://tqdm.github.io/PyData2019/slides.html
+.. |Merch| image:: https://tqdm.github.io/img/merch.jpg
+   :target: https://tqdm.github.io/merch
+.. |Build-Status| image:: https://img.shields.io/github/actions/workflow/status/tqdm/tqdm/test.yml?branch=master&label=tqdm&logo=GitHub
+   :target: https://github.com/tqdm/tqdm/actions/workflows/test.yml
+.. |Coverage-Status| image:: https://img.shields.io/coveralls/github/tqdm/tqdm/master?logo=coveralls
+   :target: https://coveralls.io/github/tqdm/tqdm
+.. |Branch-Coverage-Status| image:: https://codecov.io/gh/tqdm/tqdm/branch/master/graph/badge.svg
+   :target: https://codecov.io/gh/tqdm/tqdm
+.. |Codacy-Grade| image:: https://app.codacy.com/project/badge/Grade/3f965571598f44549c7818f29cdcf177
+   :target: https://www.codacy.com/gh/tqdm/tqdm/dashboard
+.. |CII Best Practices| image:: https://bestpractices.coreinfrastructure.org/projects/3264/badge
+   :target: https://bestpractices.coreinfrastructure.org/projects/3264
+.. |GitHub-Status| image:: https://img.shields.io/github/tag/tqdm/tqdm.svg?maxAge=86400&logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/releases
+.. |GitHub-Forks| image:: https://img.shields.io/github/forks/tqdm/tqdm.svg?logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/network
+.. |GitHub-Stars| image:: https://img.shields.io/github/stars/tqdm/tqdm.svg?logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/stargazers
+.. |GitHub-Commits| image:: https://img.shields.io/github/commit-activity/y/tqdm/tqdm.svg?logo=git&logoColor=white
+   :target: https://github.com/tqdm/tqdm/graphs/commit-activity
+.. |GitHub-Issues| image:: https://img.shields.io/github/issues-closed/tqdm/tqdm.svg?logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/issues?q=
+.. |GitHub-PRs| image:: https://img.shields.io/github/issues-pr-closed/tqdm/tqdm.svg?logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/pulls
+.. |GitHub-Contributions| image:: https://img.shields.io/github/contributors/tqdm/tqdm.svg?logo=github&logoColor=white
+   :target: https://github.com/tqdm/tqdm/graphs/contributors
+.. |GitHub-Updated| image:: https://img.shields.io/github/last-commit/tqdm/tqdm/master.svg?logo=github&logoColor=white&label=pushed
+   :target: https://github.com/tqdm/tqdm/pulse
+.. |Gift-Casper| image:: https://img.shields.io/badge/dynamic/json.svg?color=ff69b4&label=gifts%20received&prefix=%C2%A3&query=%24..sum&url=https%3A%2F%2Fcaspersci.uk.to%2Fgifts.json
+   :target: https://cdcl.ml/sponsor
+.. |Versions| image:: https://img.shields.io/pypi/v/tqdm.svg
+   :target: https://tqdm.github.io/releases
+.. |PyPI-Downloads| image:: https://img.shields.io/pypi/dm/tqdm.svg?label=pypi%20downloads&logo=PyPI&logoColor=white
+   :target: https://pepy.tech/project/tqdm
+.. |Py-Versions| image:: https://img.shields.io/pypi/pyversions/tqdm.svg?logo=python&logoColor=white
+   :target: https://pypi.org/project/tqdm
+.. |Conda-Forge-Status| image:: https://img.shields.io/conda/v/conda-forge/tqdm.svg?label=conda-forge&logo=conda-forge
+   :target: https://anaconda.org/conda-forge/tqdm
+.. |Snapcraft| image:: https://img.shields.io/badge/snap-install-82BEA0.svg?logo=snapcraft
+   :target: https://snapcraft.io/tqdm
+.. |Docker| image:: https://img.shields.io/badge/docker-pull-blue.svg?logo=docker&logoColor=white
+   :target: https://hub.docker.com/r/tqdm/tqdm
+.. |Libraries-Rank| image:: https://img.shields.io/librariesio/sourcerank/pypi/tqdm.svg?logo=koding&logoColor=white
+   :target: https://libraries.io/pypi/tqdm
+.. |Libraries-Dependents| image:: https://img.shields.io/librariesio/dependent-repos/pypi/tqdm.svg?logo=koding&logoColor=white
+    :target: https://github.com/tqdm/tqdm/network/dependents
+.. |OpenHub-Status| image:: https://www.openhub.net/p/tqdm/widgets/project_thin_badge?format=gif
+   :target: https://www.openhub.net/p/tqdm?ref=Thin+badge
+.. |awesome-python| image:: https://awesome.re/mentioned-badge.svg
+   :target: https://github.com/vinta/awesome-python
+.. |LICENCE| image:: https://img.shields.io/pypi/l/tqdm.svg
+   :target: https://raw.githubusercontent.com/tqdm/tqdm/master/LICENCE
+.. |DOI| image:: https://img.shields.io/badge/DOI-10.5281/zenodo.595120-blue.svg
+   :target: https://doi.org/10.5281/zenodo.595120
+.. |binder-demo| image:: https://mybinder.org/badge_logo.svg
+   :target: https://mybinder.org/v2/gh/tqdm/tqdm/master?filepath=DEMO.ipynb
+.. |Screenshot-Jupyter1| image:: https://tqdm.github.io/img/jupyter-1.gif
+.. |Screenshot-Jupyter2| image:: https://tqdm.github.io/img/jupyter-2.gif
+.. |Screenshot-Jupyter3| image:: https://tqdm.github.io/img/jupyter-3.gif
+.. |README-Hits| image:: https://cgi.cdcl.ml/hits?q=tqdm&style=social&r=https://github.com/tqdm/tqdm&l=https://tqdm.github.io/img/favicon.png&f=https://tqdm.github.io/img/logo.gif
+   :target: https://cgi.cdcl.ml/hits?q=tqdm&a=plot&r=https://github.com/tqdm/tqdm&l=https://tqdm.github.io/img/favicon.png&f=https://tqdm.github.io/img/logo.gif&style=social
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/RECORD b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/RECORD
new file mode 100644
index 0000000000000000000000000000000000000000..b8560998169d00101b7753855574bccc0fd7253e
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/RECORD
@@ -0,0 +1,74 @@
+../../Scripts/tqdm.exe,sha256=lybo3UyStlQyedxKJa92pAimFGF3PHvXFFDXkP9OoRk,106338
+tqdm-4.67.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+tqdm-4.67.1.dist-info/LICENCE,sha256=3DMlLoKQFeOxUAhvubOkD2rW-zLC9GEM6BL6Z301mGo,1985
+tqdm-4.67.1.dist-info/METADATA,sha256=aIoWMt9SWhmP7FLc_vsSRtMerO6cA1qsrC1-r42P9mk,57675
+tqdm-4.67.1.dist-info/RECORD,,
+tqdm-4.67.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+tqdm-4.67.1.dist-info/entry_points.txt,sha256=ReJCH7Ui3Zyh6M16E4OhsZ1oU7WtMXCfbtoyBhGO29Y,39
+tqdm-4.67.1.dist-info/top_level.txt,sha256=NLiUJNfmc9At15s7JURiwvqMEjUi9G5PMGRrmMYzNSM,5
+tqdm/__init__.py,sha256=9mQNYSSqP99JasubEC1POJLMmhkkBH6cJZxPIR5G2pQ,1572
+tqdm/__main__.py,sha256=bYt9eEaoRQWdejEHFD8REx9jxVEdZptECFsV7F49Ink,30
+tqdm/__pycache__/__init__.cpython-39.pyc,,
+tqdm/__pycache__/__main__.cpython-39.pyc,,
+tqdm/__pycache__/_dist_ver.cpython-39.pyc,,
+tqdm/__pycache__/_main.cpython-39.pyc,,
+tqdm/__pycache__/_monitor.cpython-39.pyc,,
+tqdm/__pycache__/_tqdm.cpython-39.pyc,,
+tqdm/__pycache__/_tqdm_gui.cpython-39.pyc,,
+tqdm/__pycache__/_tqdm_notebook.cpython-39.pyc,,
+tqdm/__pycache__/_tqdm_pandas.cpython-39.pyc,,
+tqdm/__pycache__/_utils.cpython-39.pyc,,
+tqdm/__pycache__/asyncio.cpython-39.pyc,,
+tqdm/__pycache__/auto.cpython-39.pyc,,
+tqdm/__pycache__/autonotebook.cpython-39.pyc,,
+tqdm/__pycache__/cli.cpython-39.pyc,,
+tqdm/__pycache__/dask.cpython-39.pyc,,
+tqdm/__pycache__/gui.cpython-39.pyc,,
+tqdm/__pycache__/keras.cpython-39.pyc,,
+tqdm/__pycache__/notebook.cpython-39.pyc,,
+tqdm/__pycache__/rich.cpython-39.pyc,,
+tqdm/__pycache__/std.cpython-39.pyc,,
+tqdm/__pycache__/tk.cpython-39.pyc,,
+tqdm/__pycache__/utils.cpython-39.pyc,,
+tqdm/__pycache__/version.cpython-39.pyc,,
+tqdm/_dist_ver.py,sha256=m5AdYI-jB-v6P0VJ_70isH_p24EzSOGSwVvuAZmkmKY,23
+tqdm/_main.py,sha256=9ySvgmi_2Sw4CAo5UDW0Q2dxfTryboEWGHohfCJz0sA,283
+tqdm/_monitor.py,sha256=Uku-DPWgzJ7dO5CK08xKJK-E_F6qQ-JB3ksuXczSYR0,3699
+tqdm/_tqdm.py,sha256=LfLCuJ6bpsVo9xilmtBXyEm1vGnUCFrliW85j3J-nD4,283
+tqdm/_tqdm_gui.py,sha256=03Hc8KayxJveieI5-0-2NGiDpLvw9jZekofJUV7CCwk,287
+tqdm/_tqdm_notebook.py,sha256=BuHiLuxu6uEfZFaPJW3RPpPaxaVctEQA3kdSJSDL1hw,307
+tqdm/_tqdm_pandas.py,sha256=c9jptUgigN6axRDhRd4Rif98Tmxeopc1nFNFhIpbFUE,888
+tqdm/_utils.py,sha256=_4E73bfDj4f1s3sM42NLHNrZDOkijZoWq-n6xWLkdZ8,553
+tqdm/asyncio.py,sha256=Kp2rSkNRf9KRqa3d9YpgeZQ7L7EZf2Ki4bSc7UPIyoo,2757
+tqdm/auto.py,sha256=nDZflj6p2zKkjBCNBourrhS81zYfZy1_dQvbckrdW8o,871
+tqdm/autonotebook.py,sha256=Yb9F5uaiBPhfbDDFpbtoG8I2YUw3uQJ89rUDLbfR6ws,956
+tqdm/cli.py,sha256=SbKlN8QyZ2ogenqt-wT_p6_sx2OOdCjCyhoZBFnlmyI,11010
+tqdm/completion.sh,sha256=j79KbSmpIj_E11jfTfBXrGnUTzKXVpQ1vGVQvsyDRl4,946
+tqdm/contrib/__init__.py,sha256=OgSwVXm-vlDJ-2imtoQ9z8qdom4snMSRztH72KMA82A,2494
+tqdm/contrib/__pycache__/__init__.cpython-39.pyc,,
+tqdm/contrib/__pycache__/bells.cpython-39.pyc,,
+tqdm/contrib/__pycache__/concurrent.cpython-39.pyc,,
+tqdm/contrib/__pycache__/discord.cpython-39.pyc,,
+tqdm/contrib/__pycache__/itertools.cpython-39.pyc,,
+tqdm/contrib/__pycache__/logging.cpython-39.pyc,,
+tqdm/contrib/__pycache__/slack.cpython-39.pyc,,
+tqdm/contrib/__pycache__/telegram.cpython-39.pyc,,
+tqdm/contrib/__pycache__/utils_worker.cpython-39.pyc,,
+tqdm/contrib/bells.py,sha256=Yx1HqGCmHrESCAO700j5wE__JCleNODJxedh1ijPLD0,837
+tqdm/contrib/concurrent.py,sha256=K1yjloKS5WRNFyjLRth0DmU5PAnDbF0A-GD27N-J4a8,3986
+tqdm/contrib/discord.py,sha256=MtVIL1s_dxH21G4sL8FBgQ4Wei23ho9Ek5T-AommvNc,5243
+tqdm/contrib/itertools.py,sha256=WdKKQU5eSzsqHu29SN_oH12huYZo0Jihqoi9-nVhwz4,774
+tqdm/contrib/logging.py,sha256=NsYtnKttj2mMrGm58mEdo5a9DP_2vv8pZyrimSuWulA,3760
+tqdm/contrib/slack.py,sha256=eP_Mr5sQonYniHxxQNGue3jk2JkIPmPWFZqIYxnOui0,4007
+tqdm/contrib/telegram.py,sha256=vn_9SATMbbwn2PAbzSDyOX6av3eBB01QBug11P4H-Og,5008
+tqdm/contrib/utils_worker.py,sha256=HJP5Mz1S1xyzEke2JaqJ2sYLHXADYoo2epT5AzQ38eA,1207
+tqdm/dask.py,sha256=9Ei58eVqTossRLhAfWyUFCduXYKjmLmwkaXIy-CHYfs,1319
+tqdm/gui.py,sha256=STIB3K8iDzDgkNUqWIpvcI_u0OGtbGNy5NwpALXhfWs,5479
+tqdm/keras.py,sha256=op9sBkb6q6c6dw2wJ0SD2ZwpPK7yM1Vbg4l1Qiy3MIo,4373
+tqdm/notebook.py,sha256=GtZ3IapLL1v8WNDaTSvPw0bJGTyfp71Vfz5HDnAzx1M,10895
+tqdm/rich.py,sha256=YyMPkEHVyYUVUR3adJKbVX26iTmNKpNMf3DEqmm-m60,5021
+tqdm/std.py,sha256=tWjz6-QCa92aqYjz7PIdkLUCAfiy-lJZheBtZyIIyO0,57461
+tqdm/tk.py,sha256=Gu0uwXwLCGPRGHORdi3WvBLGiseUp_xxX_h_gp9VpK0,6701
+tqdm/tqdm.1,sha256=aILyUPk2S4OPe_uWy2P4AMjUf0oQ6PUW0nLYXB-BWwI,7889
+tqdm/utils.py,sha256=6E0BQw3Sg7uGWKBM_cDn3P42tXswRhzkggbhBgLDjl8,11821
+tqdm/version.py,sha256=-1yWjfu3P0eghVsysHH07fbzdiADNRdzRtYPqOaqR2A,333
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/WHEEL b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/WHEEL
new file mode 100644
index 0000000000000000000000000000000000000000..ae527e7d64811439e61b93aa375defb30e06edfe
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/WHEEL
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: setuptools (75.6.0)
+Root-Is-Purelib: true
+Tag: py3-none-any
+
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/entry_points.txt b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/entry_points.txt
new file mode 100644
index 0000000000000000000000000000000000000000..540e60f4e073bc53a5f0a521a3639e0d80780af4
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/entry_points.txt
@@ -0,0 +1,2 @@
+[console_scripts]
+tqdm = tqdm.cli:main
diff --git a/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/top_level.txt b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/top_level.txt
new file mode 100644
index 0000000000000000000000000000000000000000..78620c472c9d799a14ccb02a0233f4669b3bcdcb
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm-4.67.1.dist-info/top_level.txt
@@ -0,0 +1 @@
+tqdm
diff --git a/phivenv/Lib/site-packages/tqdm/__init__.py b/phivenv/Lib/site-packages/tqdm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8081f77b8812f3b42d7949daa4195d2c35dc70ac
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/__init__.py
@@ -0,0 +1,38 @@
+from ._monitor import TMonitor, TqdmSynchronisationWarning
+from ._tqdm_pandas import tqdm_pandas
+from .cli import main  # TODO: remove in v5.0.0
+from .gui import tqdm as tqdm_gui  # TODO: remove in v5.0.0
+from .gui import trange as tgrange  # TODO: remove in v5.0.0
+from .std import (
+    TqdmDeprecationWarning, TqdmExperimentalWarning, TqdmKeyError, TqdmMonitorWarning,
+    TqdmTypeError, TqdmWarning, tqdm, trange)
+from .version import __version__
+
+__all__ = ['tqdm', 'tqdm_gui', 'trange', 'tgrange', 'tqdm_pandas',
+           'tqdm_notebook', 'tnrange', 'main', 'TMonitor',
+           'TqdmTypeError', 'TqdmKeyError',
+           'TqdmWarning', 'TqdmDeprecationWarning',
+           'TqdmExperimentalWarning',
+           'TqdmMonitorWarning', 'TqdmSynchronisationWarning',
+           '__version__']
+
+
+def tqdm_notebook(*args, **kwargs):  # pragma: no cover
+    """See tqdm.notebook.tqdm for full documentation"""
+    from warnings import warn
+
+    from .notebook import tqdm as _tqdm_notebook
+    warn("This function will be removed in tqdm==5.0.0\n"
+         "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`",
+         TqdmDeprecationWarning, stacklevel=2)
+    return _tqdm_notebook(*args, **kwargs)
+
+
+def tnrange(*args, **kwargs):  # pragma: no cover
+    """Shortcut for `tqdm.notebook.tqdm(range(*args), **kwargs)`."""
+    from warnings import warn
+
+    from .notebook import trange as _tnrange
+    warn("Please use `tqdm.notebook.trange` instead of `tqdm.tnrange`",
+         TqdmDeprecationWarning, stacklevel=2)
+    return _tnrange(*args, **kwargs)
diff --git a/phivenv/Lib/site-packages/tqdm/__main__.py b/phivenv/Lib/site-packages/tqdm/__main__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e28416e104515e90fca4b69cc60d0c61fd15d61
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/__main__.py
@@ -0,0 +1,3 @@
+from .cli import main
+
+main()
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c931434f16995ef46f16ccaa8c20205b5ba80022
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/__main__.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/__main__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a92d4cbc8f5ca5d06e03194dc67c8a3dd4894501
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/__main__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_dist_ver.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_dist_ver.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..08cf527917456a4da393a094780f58a99cb4b590
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_dist_ver.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_main.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_main.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72137ac24b341ae912d1d47dc168a1023a8782d4
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_main.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_monitor.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_monitor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5894d6ec56e2f0cba0b5772a650c84dff9585e2f
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_monitor.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..35f88122a8cefa91ea667e27851908516fc36583
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_gui.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_gui.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c86b97528bfc05b92e817e7d478a75a10ddc5176
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_gui.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_notebook.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_notebook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13c2ce62de55cdefcf891f23e42c79377e67f990
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_notebook.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_pandas.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_pandas.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ea36639e60a70b94c45baaad0c44b7499b89b526
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_tqdm_pandas.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..65934908e4fe3a3ea50573931e511244d6d40511
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/_utils.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/asyncio.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/asyncio.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0a83769a06dbfef7fc3a58352c0783146f659af2
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/asyncio.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/auto.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/auto.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a31555b3d5fe329540f0c139c2010e1c103d47f
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/auto.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/autonotebook.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/autonotebook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b0811b0480067acf1f8c0945037f710e5feb47b8
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/autonotebook.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/cli.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/cli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c2f5253424a9ddd946f6b5a2484f622799295599
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/cli.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/dask.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/dask.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2250aa404f35d325a377b42f9ab73e97c48410d
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/dask.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/gui.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/gui.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c30b1a304f6024350056480dc647addbd6be833f
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/gui.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/keras.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/keras.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b34a7f3a62d9c13eb64f2491b74d00b8a55baf8c
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/keras.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/notebook.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/notebook.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..947666d9bcd4958f2d28b8eb0760cf25306b673e
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/notebook.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/rich.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/rich.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ade8460b2b77a00a46e861bcbd1257fa768ab9e9
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/rich.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/std.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/std.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..457fa6d3a76897d02a03e4554c6bee07f6898767
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/std.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/tk.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/tk.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9b55d706e63ab6872bc304dbedca60a36eb0f80f
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/tk.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca6479eb68b8459972a0cf59311c549ca633555e
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/utils.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/__pycache__/version.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/__pycache__/version.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80fd94f157b96bf166cfd160301dba2ec609cede
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/__pycache__/version.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/_dist_ver.py b/phivenv/Lib/site-packages/tqdm/_dist_ver.py
new file mode 100644
index 0000000000000000000000000000000000000000..61af7d5bb0b25d8dc934b45b18ea35bd32dbb465
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_dist_ver.py
@@ -0,0 +1 @@
+__version__ = '4.67.1'
diff --git a/phivenv/Lib/site-packages/tqdm/_main.py b/phivenv/Lib/site-packages/tqdm/_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..04fdeeff17b5cc84b210f445b54b87d5b99e3748
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_main.py
@@ -0,0 +1,9 @@
+from warnings import warn
+
+from .cli import *  # NOQA
+from .cli import __all__  # NOQA
+from .std import TqdmDeprecationWarning
+
+warn("This function will be removed in tqdm==5.0.0\n"
+     "Please use `tqdm.cli.*` instead of `tqdm._main.*`",
+     TqdmDeprecationWarning, stacklevel=2)
diff --git a/phivenv/Lib/site-packages/tqdm/_monitor.py b/phivenv/Lib/site-packages/tqdm/_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f71aa56817ca77eba5df4a2dd11cb0c4a9a7ea1c
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_monitor.py
@@ -0,0 +1,95 @@
+import atexit
+from threading import Event, Thread, current_thread
+from time import time
+from warnings import warn
+
+__all__ = ["TMonitor", "TqdmSynchronisationWarning"]
+
+
+class TqdmSynchronisationWarning(RuntimeWarning):
+    """tqdm multi-thread/-process errors which may cause incorrect nesting
+    but otherwise no adverse effects"""
+    pass
+
+
+class TMonitor(Thread):
+    """
+    Monitoring thread for tqdm bars.
+    Monitors if tqdm bars are taking too much time to display
+    and readjusts miniters automatically if necessary.
+
+    Parameters
+    ----------
+    tqdm_cls  : class
+        tqdm class to use (can be core tqdm or a submodule).
+    sleep_interval  : float
+        Time to sleep between monitoring checks.
+    """
+    _test = {}  # internal vars for unit testing
+
+    def __init__(self, tqdm_cls, sleep_interval):
+        Thread.__init__(self)
+        self.daemon = True  # kill thread when main killed (KeyboardInterrupt)
+        self.woken = 0  # last time woken up, to sync with monitor
+        self.tqdm_cls = tqdm_cls
+        self.sleep_interval = sleep_interval
+        self._time = self._test.get("time", time)
+        self.was_killed = self._test.get("Event", Event)()
+        atexit.register(self.exit)
+        self.start()
+
+    def exit(self):
+        self.was_killed.set()
+        if self is not current_thread():
+            self.join()
+        return self.report()
+
+    def get_instances(self):
+        # returns a copy of started `tqdm_cls` instances
+        return [i for i in self.tqdm_cls._instances.copy()
+                # Avoid race by checking that the instance started
+                if hasattr(i, 'start_t')]
+
+    def run(self):
+        cur_t = self._time()
+        while True:
+            # After processing and before sleeping, notify that we woke
+            # Need to be done just before sleeping
+            self.woken = cur_t
+            # Sleep some time...
+            self.was_killed.wait(self.sleep_interval)
+            # Quit if killed
+            if self.was_killed.is_set():
+                return
+            # Then monitor!
+            # Acquire lock (to access _instances)
+            with self.tqdm_cls.get_lock():
+                cur_t = self._time()
+                # Check tqdm instances are waiting too long to print
+                instances = self.get_instances()
+                for instance in instances:
+                    # Check event in loop to reduce blocking time on exit
+                    if self.was_killed.is_set():
+                        return
+                    # Only if mininterval > 1 (else iterations are just slow)
+                    # and last refresh exceeded maxinterval
+                    if (
+                        instance.miniters > 1
+                        and (cur_t - instance.last_print_t) >= instance.maxinterval
+                    ):
+                        # force bypassing miniters on next iteration
+                        # (dynamic_miniters adjusts mininterval automatically)
+                        instance.miniters = 1
+                        # Refresh now! (works only for manual tqdm)
+                        instance.refresh(nolock=True)
+                    # Remove accidental long-lived strong reference
+                    del instance
+                if instances != self.get_instances():  # pragma: nocover
+                    warn("Set changed size during iteration" +
+                         " (see https://github.com/tqdm/tqdm/issues/481)",
+                         TqdmSynchronisationWarning, stacklevel=2)
+                # Remove accidental long-lived strong references
+                del instances
+
+    def report(self):
+        return not self.was_killed.is_set()
diff --git a/phivenv/Lib/site-packages/tqdm/_tqdm.py b/phivenv/Lib/site-packages/tqdm/_tqdm.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc4962774a4651db7a739a3f143633b6215a9bd
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_tqdm.py
@@ -0,0 +1,9 @@
+from warnings import warn
+
+from .std import *  # NOQA
+from .std import __all__  # NOQA
+from .std import TqdmDeprecationWarning
+
+warn("This function will be removed in tqdm==5.0.0\n"
+     "Please use `tqdm.std.*` instead of `tqdm._tqdm.*`",
+     TqdmDeprecationWarning, stacklevel=2)
diff --git a/phivenv/Lib/site-packages/tqdm/_tqdm_gui.py b/phivenv/Lib/site-packages/tqdm/_tqdm_gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..f32aa894f54b3a5b47a0fbf4263c2fd20df56c9d
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_tqdm_gui.py
@@ -0,0 +1,9 @@
+from warnings import warn
+
+from .gui import *  # NOQA
+from .gui import __all__  # NOQA
+from .std import TqdmDeprecationWarning
+
+warn("This function will be removed in tqdm==5.0.0\n"
+     "Please use `tqdm.gui.*` instead of `tqdm._tqdm_gui.*`",
+     TqdmDeprecationWarning, stacklevel=2)
diff --git a/phivenv/Lib/site-packages/tqdm/_tqdm_notebook.py b/phivenv/Lib/site-packages/tqdm/_tqdm_notebook.py
new file mode 100644
index 0000000000000000000000000000000000000000..f225fbf5b52d04987ccf68f4d5ee4b735e3158b0
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_tqdm_notebook.py
@@ -0,0 +1,9 @@
+from warnings import warn
+
+from .notebook import *  # NOQA
+from .notebook import __all__  # NOQA
+from .std import TqdmDeprecationWarning
+
+warn("This function will be removed in tqdm==5.0.0\n"
+     "Please use `tqdm.notebook.*` instead of `tqdm._tqdm_notebook.*`",
+     TqdmDeprecationWarning, stacklevel=2)
diff --git a/phivenv/Lib/site-packages/tqdm/_tqdm_pandas.py b/phivenv/Lib/site-packages/tqdm/_tqdm_pandas.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4fe6efdc603579e7f8acfa27ac10dccdf3e94ce
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_tqdm_pandas.py
@@ -0,0 +1,24 @@
+import sys
+
+__author__ = "github.com/casperdcl"
+__all__ = ['tqdm_pandas']
+
+
+def tqdm_pandas(tclass, **tqdm_kwargs):
+    """
+    Registers the given `tqdm` instance with
+    `pandas.core.groupby.DataFrameGroupBy.progress_apply`.
+    """
+    from tqdm import TqdmDeprecationWarning
+
+    if isinstance(tclass, type) or (getattr(tclass, '__name__', '').startswith(
+            'tqdm_')):  # delayed adapter case
+        TqdmDeprecationWarning(
+            "Please use `tqdm.pandas(...)` instead of `tqdm_pandas(tqdm, ...)`.",
+            fp_write=getattr(tqdm_kwargs.get('file', None), 'write', sys.stderr.write))
+        tclass.pandas(**tqdm_kwargs)
+    else:
+        TqdmDeprecationWarning(
+            "Please use `tqdm.pandas(...)` instead of `tqdm_pandas(tqdm(...))`.",
+            fp_write=getattr(tclass.fp, 'write', sys.stderr.write))
+        type(tclass).pandas(deprecated_t=tclass)
diff --git a/phivenv/Lib/site-packages/tqdm/_utils.py b/phivenv/Lib/site-packages/tqdm/_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..385e849e106d1319fe21045f14eb0aa6552fb153
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/_utils.py
@@ -0,0 +1,11 @@
+from warnings import warn
+
+from .std import TqdmDeprecationWarning
+from .utils import (  # NOQA, pylint: disable=unused-import
+    CUR_OS, IS_NIX, IS_WIN, RE_ANSI, Comparable, FormatReplace, SimpleTextIOWrapper,
+    _environ_cols_wrapper, _is_ascii, _is_utf, _screen_shape_linux, _screen_shape_tput,
+    _screen_shape_windows, _screen_shape_wrapper, _supports_unicode, _term_move_up, colorama)
+
+warn("This function will be removed in tqdm==5.0.0\n"
+     "Please use `tqdm.utils.*` instead of `tqdm._utils.*`",
+     TqdmDeprecationWarning, stacklevel=2)
diff --git a/phivenv/Lib/site-packages/tqdm/asyncio.py b/phivenv/Lib/site-packages/tqdm/asyncio.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d00a0a2e755f36068d079ccc12ca84d86ff42be
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/asyncio.py
@@ -0,0 +1,93 @@
+"""
+Asynchronous progressbar decorator for iterators.
+Includes a default `range` iterator printing to `stderr`.
+
+Usage:
+>>> from tqdm.asyncio import trange, tqdm
+>>> async for i in trange(10):
+...     ...
+"""
+import asyncio
+from sys import version_info
+
+from .std import tqdm as std_tqdm
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['tqdm_asyncio', 'tarange', 'tqdm', 'trange']
+
+
+class tqdm_asyncio(std_tqdm):
+    """
+    Asynchronous-friendly version of tqdm.
+    """
+    def __init__(self, iterable=None, *args, **kwargs):
+        super().__init__(iterable, *args, **kwargs)
+        self.iterable_awaitable = False
+        if iterable is not None:
+            if hasattr(iterable, "__anext__"):
+                self.iterable_next = iterable.__anext__
+                self.iterable_awaitable = True
+            elif hasattr(iterable, "__next__"):
+                self.iterable_next = iterable.__next__
+            else:
+                self.iterable_iterator = iter(iterable)
+                self.iterable_next = self.iterable_iterator.__next__
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            if self.iterable_awaitable:
+                res = await self.iterable_next()
+            else:
+                res = self.iterable_next()
+            self.update()
+            return res
+        except StopIteration:
+            self.close()
+            raise StopAsyncIteration
+        except BaseException:
+            self.close()
+            raise
+
+    def send(self, *args, **kwargs):
+        return self.iterable.send(*args, **kwargs)
+
+    @classmethod
+    def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
+        """
+        Wrapper for `asyncio.as_completed`.
+        """
+        if total is None:
+            total = len(fs)
+        kwargs = {}
+        if version_info[:2] < (3, 10):
+            kwargs['loop'] = loop
+        yield from cls(asyncio.as_completed(fs, timeout=timeout, **kwargs),
+                       total=total, **tqdm_kwargs)
+
+    @classmethod
+    async def gather(cls, *fs, loop=None, timeout=None, total=None, **tqdm_kwargs):
+        """
+        Wrapper for `asyncio.gather`.
+        """
+        async def wrap_awaitable(i, f):
+            return i, await f
+
+        ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
+        res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
+                                                 total=total, **tqdm_kwargs)]
+        return [i for _, i in sorted(res)]
+
+
+def tarange(*args, **kwargs):
+    """
+    A shortcut for `tqdm.asyncio.tqdm(range(*args), **kwargs)`.
+    """
+    return tqdm_asyncio(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_asyncio
+trange = tarange
diff --git a/phivenv/Lib/site-packages/tqdm/auto.py b/phivenv/Lib/site-packages/tqdm/auto.py
new file mode 100644
index 0000000000000000000000000000000000000000..206c4409d5269594bdbab3a092ef6e09e7c01947
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/auto.py
@@ -0,0 +1,40 @@
+"""
+Enables multiple commonly used features.
+
+Method resolution order:
+
+- `tqdm.autonotebook` without import warnings
+- `tqdm.asyncio`
+- `tqdm.std` base class
+
+Usage:
+>>> from tqdm.auto import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+import warnings
+
+from .std import TqdmExperimentalWarning
+
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore", category=TqdmExperimentalWarning)
+    from .autonotebook import tqdm as notebook_tqdm
+
+from .asyncio import tqdm as asyncio_tqdm
+from .std import tqdm as std_tqdm
+
+if notebook_tqdm != std_tqdm:
+    class tqdm(notebook_tqdm, asyncio_tqdm):  # pylint: disable=inconsistent-mro
+        pass
+else:
+    tqdm = asyncio_tqdm
+
+
+def trange(*args, **kwargs):
+    """
+    A shortcut for `tqdm.auto.tqdm(range(*args), **kwargs)`.
+    """
+    return tqdm(range(*args), **kwargs)
+
+
+__all__ = ["tqdm", "trange"]
diff --git a/phivenv/Lib/site-packages/tqdm/autonotebook.py b/phivenv/Lib/site-packages/tqdm/autonotebook.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09f2ec4b8c95f12b8c7b7774f84d5ec55826334
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/autonotebook.py
@@ -0,0 +1,29 @@
+"""
+Automatically choose between `tqdm.notebook` and `tqdm.std`.
+
+Usage:
+>>> from tqdm.autonotebook import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+import sys
+from warnings import warn
+
+try:
+    get_ipython = sys.modules['IPython'].get_ipython
+    if 'IPKernelApp' not in get_ipython().config:  # pragma: no cover
+        raise ImportError("console")
+    from .notebook import WARN_NOIPYW, IProgress
+    if IProgress is None:
+        from .std import TqdmWarning
+        warn(WARN_NOIPYW, TqdmWarning, stacklevel=2)
+        raise ImportError('ipywidgets')
+except Exception:
+    from .std import tqdm, trange
+else:  # pragma: no cover
+    from .notebook import tqdm, trange
+    from .std import TqdmExperimentalWarning
+    warn("Using `tqdm.autonotebook.tqdm` in notebook mode."
+         " Use `tqdm.tqdm` instead to force console mode"
+         " (e.g. in jupyter console)", TqdmExperimentalWarning, stacklevel=2)
+__all__ = ["tqdm", "trange"]
diff --git a/phivenv/Lib/site-packages/tqdm/cli.py b/phivenv/Lib/site-packages/tqdm/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54a7fc8599fe0dfef12cd53b76b27ae51b68b4b
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/cli.py
@@ -0,0 +1,324 @@
+"""
+Module version for monitoring CLI pipes (`... | python -m tqdm | ...`).
+"""
+import logging
+import re
+import sys
+from ast import literal_eval as numeric
+from textwrap import indent
+
+from .std import TqdmKeyError, TqdmTypeError, tqdm
+from .version import __version__
+
+__all__ = ["main"]
+log = logging.getLogger(__name__)
+
+
+def cast(val, typ):
+    log.debug((val, typ))
+    if " or " in typ:
+        for t in typ.split(" or "):
+            try:
+                return cast(val, t)
+            except TqdmTypeError:
+                pass
+        raise TqdmTypeError(f"{val} : {typ}")
+
+    # sys.stderr.write('\ndebug | `val:type`: `' + val + ':' + typ + '`.\n')
+    if typ == 'bool':
+        if (val == 'True') or (val == ''):
+            return True
+        if val == 'False':
+            return False
+        raise TqdmTypeError(val + ' : ' + typ)
+    if typ == 'chr':
+        if len(val) == 1:
+            return val.encode()
+        if re.match(r"^\\\w+$", val):
+            return eval(f'"{val}"').encode()
+        raise TqdmTypeError(f"{val} : {typ}")
+    if typ == 'str':
+        return val
+    if typ == 'int':
+        try:
+            return int(val)
+        except ValueError as exc:
+            raise TqdmTypeError(f"{val} : {typ}") from exc
+    if typ == 'float':
+        try:
+            return float(val)
+        except ValueError as exc:
+            raise TqdmTypeError(f"{val} : {typ}") from exc
+    raise TqdmTypeError(f"{val} : {typ}")
+
+
+def posix_pipe(fin, fout, delim=b'\\n', buf_size=256,
+               callback=lambda float: None, callback_len=True):
+    """
+    Params
+    ------
+    fin  : binary file with `read(buf_size : int)` method
+    fout  : binary file with `write` (and optionally `flush`) methods.
+    callback  : function(float), e.g.: `tqdm.update`
+    callback_len  : If (default: True) do `callback(len(buffer))`.
+      Otherwise, do `callback(data) for data in buffer.split(delim)`.
+    """
+    fp_write = fout.write
+
+    if not delim:
+        while True:
+            tmp = fin.read(buf_size)
+
+            # flush at EOF
+            if not tmp:
+                getattr(fout, 'flush', lambda: None)()
+                return
+
+            fp_write(tmp)
+            callback(len(tmp))
+        # return
+
+    buf = b''
+    len_delim = len(delim)
+    # n = 0
+    while True:
+        tmp = fin.read(buf_size)
+
+        # flush at EOF
+        if not tmp:
+            if buf:
+                fp_write(buf)
+                if callback_len:
+                    # n += 1 + buf.count(delim)
+                    callback(1 + buf.count(delim))
+                else:
+                    for i in buf.split(delim):
+                        callback(i)
+            getattr(fout, 'flush', lambda: None)()
+            return  # n
+
+        while True:
+            i = tmp.find(delim)
+            if i < 0:
+                buf += tmp
+                break
+            fp_write(buf + tmp[:i + len(delim)])
+            # n += 1
+            callback(1 if callback_len else (buf + tmp[:i]))
+            buf = b''
+            tmp = tmp[i + len_delim:]
+
+
+# ((opt, type), ... )
+RE_OPTS = re.compile(r'\n {4}(\S+)\s{2,}:\s*([^,]+)')
+# better split method assuming no positional args
+RE_SHLEX = re.compile(r'\s*(?<!\S)--?([^\s=]+)(\s+|=|$)')
+    # d = RE_OPTS.sub(r'  --\1=<\1>  : \2', d)
+    split = RE_OPTS.split(d)
+    opt_types_desc = zip(split[1::3], split[2::3], split[3::3])
+    d = ''.join(('\n  --{0}  : {2}{3}' if otd[1] == 'bool' else
+                 '\n  --{0}=<{1}>  : {2}{3}').format(
+                     otd[0].replace('_', '-'), otd[0], *otd[1:])
+                for otd in opt_types_desc if otd[0] not in UNSUPPORTED_OPTS)
+
+    help_short = "Usage:\n  tqdm [--help | options]\n"
+    d = help_short + """
+Options:
+  -h, --help     Print this help and exit.
+  -v, --version  Print version and exit.
+""" + d.strip('\n') + '\n'
+
+    # opts = docopt(d, version=__version__)
+    if any(v in argv for v in ('-v', '--version')):
+        sys.stdout.write(__version__ + '\n')
+        sys.exit(0)
+    elif any(v in argv for v in ('-h', '--help')):
+        sys.stdout.write(d + '\n')
+        sys.exit(0)
+    elif argv and argv[0][:2] != '--':
+        sys.stderr.write(f"Error:Unknown argument:{argv[0]}\n{help_short}")
+
+    argv = RE_SHLEX.split(' '.join(["tqdm"] + argv))
+    opts = dict(zip(argv[1::3], argv[3::3]))
+
+    log.debug(opts)
+    opts.pop('log', True)
+
+    tqdm_args = {'file': fp}
+    try:
+        for (o, v) in opts.items():
+            o = o.replace('-', '_')
+            try:
+                tqdm_args[o] = cast(v, opt_types[o])
+            except KeyError as e:
+                raise TqdmKeyError(str(e))
+        log.debug('args:' + str(tqdm_args))
+
+        delim_per_char = tqdm_args.pop('bytes', False)
+        update = tqdm_args.pop('update', False)
+        update_to = tqdm_args.pop('update_to', False)
+        if sum((delim_per_char, update, update_to)) > 1:
+            raise TqdmKeyError("Can only have one of --bytes --update --update_to")
+    except Exception:
+        fp.write("\nError:\n" + help_short)
+        stdin, stdout_write = sys.stdin, sys.stdout.write
+        for i in stdin:
+            stdout_write(i)
+        raise
+    else:
+        buf_size = tqdm_args.pop('buf_size', 256)
+        delim = tqdm_args.pop('delim', b'\\n')
+        tee = tqdm_args.pop('tee', False)
+        manpath = tqdm_args.pop('manpath', None)
+        comppath = tqdm_args.pop('comppath', None)
+        if tqdm_args.pop('null', False):
+            class stdout(object):
+                @staticmethod
+                def write(_):
+                    pass
+        else:
+            stdout = sys.stdout
+            stdout = getattr(stdout, 'buffer', stdout)
+        stdin = getattr(sys.stdin, 'buffer', sys.stdin)
+        if manpath or comppath:
+            try:  # py<3.9
+                import importlib_resources as resources
+            except ImportError:
+                from importlib import resources
+            from pathlib import Path
+
+            def cp(name, dst):
+                """copy resource `name` to `dst`"""
+                fi = resources.files('tqdm') / name
+                dst.write_bytes(fi.read_bytes())
+                log.info("written:%s", dst)
+            if manpath is not None:
+                cp('tqdm.1', Path(manpath) / 'tqdm.1')
+            if comppath is not None:
+                cp('completion.sh', Path(comppath) / 'tqdm_completion.sh')
+            sys.exit(0)
+        if tee:
+            stdout_write = stdout.write
+            fp_write = getattr(fp, 'buffer', fp).write
+
+            class stdout(object):  # pylint: disable=function-redefined
+                @staticmethod
+                def write(x):
+                    with tqdm.external_write_mode(file=fp):
+                        fp_write(x)
+                    stdout_write(x)
+        if delim_per_char:
+            tqdm_args.setdefault('unit', 'B')
+            tqdm_args.setdefault('unit_scale', True)
+            tqdm_args.setdefault('unit_divisor', 1024)
+            log.debug(tqdm_args)
+            with tqdm(**tqdm_args) as t:
+                posix_pipe(stdin, stdout, '', buf_size, t.update)
+        elif delim == b'\\n':
+            log.debug(tqdm_args)
+            write = stdout.write
+            if update or update_to:
+                with tqdm(**tqdm_args) as t:
+                    if update:
+                        def callback(i):
+                            t.update(numeric(i.decode()))
+                    else:  # update_to
+                        def callback(i):
+                            t.update(numeric(i.decode()) - t.n)
+                    for i in stdin:
+                        write(i)
+                        callback(i)
+            else:
+                for i in tqdm(stdin, **tqdm_args):
+                    write(i)
+        else:
+            log.debug(tqdm_args)
+            with tqdm(**tqdm_args) as t:
+                callback_len = False
+                if update:
+                    def callback(i):
+                        t.update(numeric(i.decode()))
+                elif update_to:
+                    def callback(i):
+                        t.update(numeric(i.decode()) - t.n)
+                else:
+                    callback = t.update
+                    callback_len = True
+                posix_pipe(stdin, stdout, delim, buf_size, callback, callback_len)
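A minimal sketch of the `posix_pipe` helper defined above, driven with in-memory streams rather than real stdio (the import path `tqdm.cli` is assumed from this file's location):

```python
import io

from tqdm.cli import posix_pipe  # helper defined in the file above

src = io.BytesIO(b"alpha\nbeta\ngamma\n")
dst = io.BytesIO()
chunks = []

# callback_len=False: the callback receives each delimited chunk itself
posix_pipe(src, dst, b"\n", 4, chunks.append, False)

assert dst.getvalue() == b"alpha\nbeta\ngamma\n"
print(chunks)  # [b'alpha', b'beta', b'gamma']
```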
diff --git a/phivenv/Lib/site-packages/tqdm/completion.sh b/phivenv/Lib/site-packages/tqdm/completion.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9f61c7f14bb8c1f6099b9eb75dce28ece6a7ae96
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/completion.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+_tqdm(){
+  local cur prv
+  cur="${COMP_WORDS[COMP_CWORD]}"
+  prv="${COMP_WORDS[COMP_CWORD - 1]}"
+
+  case ${prv} in
+  --bar_format|--buf_size|--colour|--comppath|--delay|--delim|--desc|--initial|--lock_args|--manpath|--maxinterval|--mininterval|--miniters|--ncols|--nrows|--position|--postfix|--smoothing|--total|--unit|--unit_divisor)
+    # await user input
+    ;;
+  "--log")
+    COMPREPLY=($(compgen -W 'CRITICAL FATAL ERROR WARN WARNING INFO DEBUG NOTSET' -- ${cur}))
+    ;;
+  *)
+    COMPREPLY=($(compgen -W '--ascii --bar_format --buf_size --bytes --colour --comppath --delay --delim --desc --disable --dynamic_ncols --help --initial --leave --lock_args --log --manpath --maxinterval --mininterval --miniters --ncols --nrows --null --position --postfix --smoothing --tee --total --unit --unit_divisor --unit_scale --update --update_to --version --write_bytes -h -v' -- ${cur}))
+    ;;
+  esac
+}
+complete -F _tqdm tqdm
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__init__.py b/phivenv/Lib/site-packages/tqdm/contrib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d059461f91fb79115263c16314c3487e16ab98c2
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/__init__.py
@@ -0,0 +1,92 @@
+"""
+Thin wrappers around common functions.
+
+Subpackages contain potentially unstable extensions.
+"""
+from warnings import warn
+
+from ..auto import tqdm as tqdm_auto
+from ..std import TqdmDeprecationWarning, tqdm
+from ..utils import ObjectWrapper
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['tenumerate', 'tzip', 'tmap']
+
+
+class DummyTqdmFile(ObjectWrapper):
+    """Dummy file-like that will write to tqdm"""
+
+    def __init__(self, wrapped):
+        super().__init__(wrapped)
+        self._buf = []
+
+    def write(self, x, nolock=False):
+        nl = b"\n" if isinstance(x, bytes) else "\n"
+        pre, sep, post = x.rpartition(nl)
+        if sep:
+            blank = type(nl)()
+            tqdm.write(blank.join(self._buf + [pre, sep]),
+                       end=blank, file=self._wrapped, nolock=nolock)
+            self._buf = [post]
+        else:
+            self._buf.append(x)
+
+    def __del__(self):
+        if self._buf:
+            blank = type(self._buf[0])()
+            try:
+                tqdm.write(blank.join(self._buf), end=blank, file=self._wrapped)
+            except (OSError, ValueError):
+                pass
+
+
+def builtin_iterable(func):
+    """Returns `func`"""
+    warn("This function has no effect, and will be removed in tqdm==5.0.0",
+         TqdmDeprecationWarning, stacklevel=2)
+    return func
+
+
+def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto, **tqdm_kwargs):
+    """
+    Equivalent of `numpy.ndenumerate` or builtin `enumerate`.
+
+    Parameters
+    ----------
+    tqdm_class  : [default: tqdm.auto.tqdm].
+    """
+    try:
+        import numpy as np
+    except ImportError:
+        pass
+    else:
+        if isinstance(iterable, np.ndarray):
+            return tqdm_class(np.ndenumerate(iterable), total=total or iterable.size,
+                              **tqdm_kwargs)
+    return enumerate(tqdm_class(iterable, total=total, **tqdm_kwargs), start)
+
+
+def tzip(iter1, *iter2plus, **tqdm_kwargs):
+    """
+    Equivalent of builtin `zip`.
+
+    Parameters
+    ----------
+    tqdm_class  : [default: tqdm.auto.tqdm].
+    """
+    kwargs = tqdm_kwargs.copy()
+    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
+    for i in zip(tqdm_class(iter1, **kwargs), *iter2plus):
+        yield i
+
+
+def tmap(function, *sequences, **tqdm_kwargs):
+    """
+    Equivalent of builtin `map`.
+
+    Parameters
+    ----------
+    tqdm_class  : [default: tqdm.auto.tqdm].
+    """
+    for i in tzip(*sequences, **tqdm_kwargs):
+        yield function(*i)
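A minimal sketch of the thin wrappers above (`tenumerate`, `tzip`, `tmap`), which simply add a progress bar to the corresponding builtins:

```python
from tqdm.contrib import tenumerate, tmap, tzip

words = ["a", "bb", "ccc"]

for i, w in tenumerate(words):      # enumerate() with a progress bar
    pass

lengths = list(tmap(len, words))    # map() with a progress bar
pairs = list(tzip(words, lengths))  # zip() with a progress bar
print(lengths, pairs)
```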
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..768eca627b79032ace4775f97fe5013b306cf91d
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/bells.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/bells.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..10555901612bb67f7bab29da0c6a1f73d0cbb1d8
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/bells.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/concurrent.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/concurrent.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..574e027a43d1af533b62f3029813532b236ba107
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/concurrent.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/discord.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/discord.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6365042f5d243c552fadb20e698707bef3d4c41e
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/discord.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/itertools.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/itertools.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fc4f0f04c5a5c5d29c995005b64a4fb35f41ccbf
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/itertools.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/logging.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/logging.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1c4be8fe586883a061d7ab2b5ab74857d2f2de98
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/logging.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/slack.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/slack.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cca6bfe0fca058ab56615b34dc02799eb1c2af6
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/slack.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/telegram.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/telegram.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d617ee236ecb63e32533be06a2124f9df79e0145
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/telegram.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/utils_worker.cpython-39.pyc b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/utils_worker.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..88b326f52ad9d84485df9d71520e452746cbd8eb
Binary files /dev/null and b/phivenv/Lib/site-packages/tqdm/contrib/__pycache__/utils_worker.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/bells.py b/phivenv/Lib/site-packages/tqdm/contrib/bells.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8f4b9ecd894f1edfaa08d9fe730b8d7c8b93e0
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/bells.py
@@ -0,0 +1,26 @@
+"""
+Even more features than `tqdm.auto` (all the bells & whistles):
+
+- `tqdm.auto`
+- `tqdm.tqdm.pandas`
+- `tqdm.contrib.telegram`
+    + uses `${TQDM_TELEGRAM_TOKEN}` and `${TQDM_TELEGRAM_CHAT_ID}`
+- `tqdm.contrib.discord`
+    + uses `${TQDM_DISCORD_TOKEN}` and `${TQDM_DISCORD_CHANNEL_ID}`
+"""
+__all__ = ['tqdm', 'trange']
+import warnings
+from os import getenv
+
+if getenv("TQDM_SLACK_TOKEN") and getenv("TQDM_SLACK_CHANNEL"):
+    from .slack import tqdm, trange
+elif getenv("TQDM_TELEGRAM_TOKEN") and getenv("TQDM_TELEGRAM_CHAT_ID"):
+    from .telegram import tqdm, trange
+elif getenv("TQDM_DISCORD_TOKEN") and getenv("TQDM_DISCORD_CHANNEL_ID"):
+    from .discord import tqdm, trange
+else:
+    from ..auto import tqdm, trange
+
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore", category=FutureWarning)
+    tqdm.pandas()
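A hypothetical sketch of driving `tqdm.contrib.bells` via environment variables; the token and channel values below are placeholders, and the Slack branch additionally requires `slack-sdk` to be installed:

```python
import os

# placeholders only; set real credentials in your shell instead
os.environ["TQDM_SLACK_TOKEN"] = "xoxb-placeholder"
os.environ["TQDM_SLACK_CHANNEL"] = "C0123456789"

# the flavour (slack/telegram/discord/auto) is chosen at import time
from tqdm.contrib.bells import trange

for _ in trange(10, desc="demo"):
    pass
```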
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/concurrent.py b/phivenv/Lib/site-packages/tqdm/contrib/concurrent.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd81d622a1309df179042159a56cef4f8c309224
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/concurrent.py
@@ -0,0 +1,105 @@
+"""
+Thin wrappers around `concurrent.futures`.
+"""
+from contextlib import contextmanager
+from operator import length_hint
+from os import cpu_count
+
+from ..auto import tqdm as tqdm_auto
+from ..std import TqdmWarning
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['thread_map', 'process_map']
+
+
+@contextmanager
+def ensure_lock(tqdm_class, lock_name=""):
+    """get (create if necessary) and then restore `tqdm_class`'s lock"""
+    old_lock = getattr(tqdm_class, '_lock', None)  # don't create a new lock
+    lock = old_lock or tqdm_class.get_lock()  # maybe create a new lock
+    lock = getattr(lock, lock_name, lock)  # maybe subtype
+    tqdm_class.set_lock(lock)
+    yield lock
+    if old_lock is None:
+        del tqdm_class._lock
+    else:
+        tqdm_class.set_lock(old_lock)
+
+
+def _executor_map(PoolExecutor, fn, *iterables, **tqdm_kwargs):
+    """
+    Implementation of `thread_map` and `process_map`.
+
+    Parameters
+    ----------
+    tqdm_class  : [default: tqdm.auto.tqdm].
+    max_workers  : [default: min(32, cpu_count() + 4)].
+    chunksize  : [default: 1].
+    lock_name  : [default: "":str].
+    """
+    kwargs = tqdm_kwargs.copy()
+    if "total" not in kwargs:
+        kwargs["total"] = length_hint(iterables[0])
+    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
+    max_workers = kwargs.pop("max_workers", min(32, cpu_count() + 4))
+    chunksize = kwargs.pop("chunksize", 1)
+    lock_name = kwargs.pop("lock_name", "")
+    with ensure_lock(tqdm_class, lock_name=lock_name) as lk:
+        # share lock in case workers are already using `tqdm`
+        with PoolExecutor(max_workers=max_workers, initializer=tqdm_class.set_lock,
+                          initargs=(lk,)) as ex:
+            return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
+
+
+def thread_map(fn, *iterables, **tqdm_kwargs):
+    """
+    Equivalent of `list(map(fn, *iterables))`
+    driven by `concurrent.futures.ThreadPoolExecutor`.
+
+    Parameters
+    ----------
+    tqdm_class  : optional
+        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
+    max_workers  : int, optional
+        Maximum number of workers to spawn; passed to
+        `concurrent.futures.ThreadPoolExecutor.__init__`.
+        [default: min(32, cpu_count() + 4)].
+    """
+    from concurrent.futures import ThreadPoolExecutor
+    return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
+
+
+def process_map(fn, *iterables, **tqdm_kwargs):
+    """
+    Equivalent of `list(map(fn, *iterables))`
+    driven by `concurrent.futures.ProcessPoolExecutor`.
+
+    Parameters
+    ----------
+    tqdm_class  : optional
+        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
+    max_workers  : int, optional
+        Maximum number of workers to spawn; passed to
+        `concurrent.futures.ProcessPoolExecutor.__init__`.
+        [default: min(32, cpu_count() + 4)].
+    chunksize  : int, optional
+        Size of chunks sent to worker processes; passed to
+        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
+    lock_name  : str, optional
+        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
+    """
+    from concurrent.futures import ProcessPoolExecutor
+    if iterables and "chunksize" not in tqdm_kwargs:
+        # default `chunksize=1` has poor performance for large iterables
+        # (most time spent dispatching items to workers).
+        longest_iterable_len = max(map(length_hint, iterables))
+        if longest_iterable_len > 1000:
+            from warnings import warn
+            warn("Iterable length %d > 1000 but `chunksize` is not set."
+                 " This may seriously degrade multiprocess performance."
+                 " Set `chunksize=1` or more." % longest_iterable_len,
+                 TqdmWarning, stacklevel=2)
+    if "lock_name" not in tqdm_kwargs:
+        tqdm_kwargs = tqdm_kwargs.copy()
+        tqdm_kwargs["lock_name"] = "mp_lock"
+    return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
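A minimal sketch of the executor wrappers above; `process_map` needs the usual `__main__` guard on spawn-based platforms, and `chunksize` matters for long iterables as warned in the code:

```python
from tqdm.contrib.concurrent import process_map, thread_map

def square(x):
    return x * x

if __name__ == "__main__":
    # threads: good for I/O-bound work
    a = thread_map(square, range(100), max_workers=4)
    # processes: good for CPU-bound work; chunksize avoids dispatch overhead
    b = process_map(square, range(10_000), chunksize=100)
    print(a[:3], b[:3])
```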
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/discord.py b/phivenv/Lib/site-packages/tqdm/contrib/discord.py
new file mode 100644
index 0000000000000000000000000000000000000000..574baa84bbbeb5afce4a49f23edac894d680ca82
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/discord.py
@@ -0,0 +1,156 @@
+"""
+Sends updates to a Discord bot.
+
+Usage:
+>>> from tqdm.contrib.discord import tqdm, trange
+>>> for i in trange(10, token='{token}', channel_id='{channel_id}'):
+...     ...
+
+![screenshot](https://tqdm.github.io/img/screenshot-discord.png)
+"""
+from os import getenv
+from warnings import warn
+
+from requests import Session
+from requests.utils import default_user_agent
+
+from ..auto import tqdm as tqdm_auto
+from ..std import TqdmWarning
+from ..version import __version__
+from .utils_worker import MonoWorker
+
+__author__ = {"github.com/": ["casperdcl", "guigoruiz1"]}
+__all__ = ['DiscordIO', 'tqdm_discord', 'tdrange', 'tqdm', 'trange']
+
+
+class DiscordIO(MonoWorker):
+    """Non-blocking file-like IO using a Discord Bot."""
+    API = "https://discord.com/api/v10"
+    UA = f"tqdm (https://tqdm.github.io, {__version__}) {default_user_agent()}"
+
+    def __init__(self, token, channel_id):
+        """Creates a new message in the given `channel_id`."""
+        super().__init__()
+        self.token = token
+        self.channel_id = channel_id
+        self.session = Session()
+        self.text = self.__class__.__name__
+        self.message_id
+
+    @property
+    def message_id(self):
+        if hasattr(self, '_message_id'):
+            return self._message_id
+        try:
+            res = self.session.post(
+                f'{self.API}/channels/{self.channel_id}/messages',
+                headers={'Authorization': f'Bot {self.token}', 'User-Agent': self.UA},
+                json={'content': f"`{self.text}`"}).json()
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            if res.get('error_code') == 429:
+                warn("Creation rate limit: try increasing `mininterval`.",
+                     TqdmWarning, stacklevel=2)
+            else:
+                self._message_id = res['id']
+                return self._message_id
+
+    def write(self, s):
+        """Replaces internal `message_id`'s text with `s`."""
+        if not s:
+            s = "..."
+        s = s.replace('\r', '').strip()
+        if s == self.text:
+            return  # avoid duplicate message Bot error
+        message_id = self.message_id
+        if message_id is None:
+            return
+        self.text = s
+        try:
+            future = self.submit(
+                self.session.patch,
+                f'{self.API}/channels/{self.channel_id}/messages/{message_id}',
+                headers={'Authorization': f'Bot {self.token}', 'User-Agent': self.UA},
+                json={'content': f"`{self.text}`"})
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            return future
+
+    def delete(self):
+        """Deletes internal `message_id`."""
+        try:
+            future = self.submit(
+                self.session.delete,
+                f'{self.API}/channels/{self.channel_id}/messages/{self.message_id}',
+                headers={'Authorization': f'Bot {self.token}', 'User-Agent': self.UA})
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            return future
+
+
+class tqdm_discord(tqdm_auto):
+    """
+    Standard `tqdm.auto.tqdm` but also sends updates to a Discord Bot.
+    May take a few seconds to create (`__init__`).
+
+    - create a discord bot (not public, no requirement of OAuth2 code
+      grant, only send message permissions) & invite it to a channel:
+      <https://discordpy.readthedocs.io/en/latest/discord.html>
+    - copy the bot `{token}` & `{channel_id}` and paste below
+
+    >>> from tqdm.contrib.discord import tqdm, trange
+    >>> for i in tqdm(iterable, token='{token}', channel_id='{channel_id}'):
+    ...     ...
+    """
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        token  : str, required. Discord bot token
+            [default: ${TQDM_DISCORD_TOKEN}].
+        channel_id  : int, required. Discord channel ID
+            [default: ${TQDM_DISCORD_CHANNEL_ID}].
+
+        See `tqdm.auto.tqdm.__init__` for other parameters.
+        """
+        if not kwargs.get('disable'):
+            kwargs = kwargs.copy()
+            self.dio = DiscordIO(
+                kwargs.pop('token', getenv('TQDM_DISCORD_TOKEN')),
+                kwargs.pop('channel_id', getenv('TQDM_DISCORD_CHANNEL_ID')))
+        super().__init__(*args, **kwargs)
+
+    def display(self, **kwargs):
+        super().display(**kwargs)
+        fmt = self.format_dict
+        if fmt.get('bar_format', None):
+            fmt['bar_format'] = fmt['bar_format'].replace(
+                '<bar/>', '{bar:10u}').replace('{bar}', '{bar:10u}')
+        else:
+            fmt['bar_format'] = '{l_bar}{bar:10u}{r_bar}'
+        self.dio.write(self.format_meter(**fmt))
+
+    def clear(self, *args, **kwargs):
+        super().clear(*args, **kwargs)
+        if not self.disable:
+            self.dio.write("")
+
+    def close(self):
+        if self.disable:
+            return
+        super().close()
+        if not (self.leave or (self.leave is None and self.pos == 0)):
+            self.dio.delete()
+
+
+def tdrange(*args, **kwargs):
+    """Shortcut for `tqdm.contrib.discord.tqdm(range(*args), **kwargs)`."""
+    return tqdm_discord(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_discord
+trange = tdrange
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/itertools.py b/phivenv/Lib/site-packages/tqdm/contrib/itertools.py
new file mode 100644
index 0000000000000000000000000000000000000000..e67651a41a6b8760d9b928ea48239e4611d70315
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/itertools.py
@@ -0,0 +1,35 @@
+"""
+Thin wrappers around `itertools`.
+"""
+import itertools
+
+from ..auto import tqdm as tqdm_auto
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['product']
+
+
+def product(*iterables, **tqdm_kwargs):
+    """
+    Equivalent of `itertools.product`.
+
+    Parameters
+    ----------
+    tqdm_class  : [default: tqdm.auto.tqdm].
+    """
+    kwargs = tqdm_kwargs.copy()
+    tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
+    try:
+        lens = list(map(len, iterables))
+    except TypeError:
+        total = None
+    else:
+        total = 1
+        for i in lens:
+            total *= i
+        kwargs.setdefault("total", total)
+    with tqdm_class(**kwargs) as t:
+        it = itertools.product(*iterables)
+        for i in it:
+            yield i
+            t.update()
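A minimal sketch of `tqdm.contrib.itertools.product`; the total is inferred from `len()` of each iterable when available, otherwise the bar runs without a total:

```python
from tqdm.contrib.itertools import product

# 10 * 3 = 30 combinations, tracked by a single bar
for x, y in product(range(10), "abc", desc="grid"):
    pass
```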
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/logging.py b/phivenv/Lib/site-packages/tqdm/contrib/logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..e06febe37b5d70b5296804c55dca48a397c250e3
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/logging.py
@@ -0,0 +1,126 @@
+"""
+Helper functionality for interoperability with stdlib `logging`.
+"""
+import logging
+import sys
+from contextlib import contextmanager
+
+try:
+    from typing import Iterator, List, Optional, Type  # noqa: F401
+except ImportError:
+    pass
+
+from ..std import tqdm as std_tqdm
+
+
+class _TqdmLoggingHandler(logging.StreamHandler):
+    def __init__(
+        self,
+        tqdm_class=std_tqdm  # type: Type[std_tqdm]
+    ):
+        super().__init__()
+        self.tqdm_class = tqdm_class
+
+    def emit(self, record):
+        try:
+            msg = self.format(record)
+            self.tqdm_class.write(msg, file=self.stream)
+            self.flush()
+        except (KeyboardInterrupt, SystemExit):
+            raise
+        except:  # noqa pylint: disable=bare-except
+            self.handleError(record)
+
+
+def _is_console_logging_handler(handler):
+    return (isinstance(handler, logging.StreamHandler)
+            and handler.stream in {sys.stdout, sys.stderr})
+
+
+def _get_first_found_console_logging_handler(handlers):
+    for handler in handlers:
+        if _is_console_logging_handler(handler):
+            return handler
+
+
+@contextmanager
+def logging_redirect_tqdm(
+    loggers=None,  # type: Optional[List[logging.Logger]],
+    tqdm_class=std_tqdm  # type: Type[std_tqdm]
+):
+    # type: (...) -> Iterator[None]
+    """
+    Context manager redirecting console logging to `tqdm.write()`, leaving
+    other logging handlers (e.g. log files) unaffected.
+
+    Parameters
+    ----------
+    loggers  : list, optional
+      Which loggers to redirect (default: [logging.root]).
+    tqdm_class  : optional
+
+    Example
+    -------
+    ```python
+    import logging
+    from tqdm import trange
+    from tqdm.contrib.logging import logging_redirect_tqdm
+
+    LOG = logging.getLogger(__name__)
+
+    if __name__ == '__main__':
+        logging.basicConfig(level=logging.INFO)
+        with logging_redirect_tqdm():
+            for i in trange(9):
+                if i == 4:
+                    LOG.info("console logging redirected to `tqdm.write()`")
+        # logging restored
+    ```
+    """
+    if loggers is None:
+        loggers = [logging.root]
+    original_handlers_list = [logger.handlers for logger in loggers]
+    try:
+        for logger in loggers:
+            tqdm_handler = _TqdmLoggingHandler(tqdm_class)
+            orig_handler = _get_first_found_console_logging_handler(logger.handlers)
+            if orig_handler is not None:
+                tqdm_handler.setFormatter(orig_handler.formatter)
+                tqdm_handler.stream = orig_handler.stream
+            logger.handlers = [
+                handler for handler in logger.handlers
+                if not _is_console_logging_handler(handler)] + [tqdm_handler]
+        yield
+    finally:
+        for logger, original_handlers in zip(loggers, original_handlers_list):
+            logger.handlers = original_handlers
+
+
+@contextmanager
+def tqdm_logging_redirect(
+    *args,
+    # loggers=None,  # type: Optional[List[logging.Logger]]
+    # tqdm=None,  # type: Optional[Type[tqdm.tqdm]]
+    **kwargs
+):
+    # type: (...) -> Iterator[None]
+    """
+    Convenience shortcut for:
+    ```python
+    with tqdm_class(*args, **tqdm_kwargs) as pbar:
+        with logging_redirect_tqdm(loggers=loggers, tqdm_class=tqdm_class):
+            yield pbar
+    ```
+
+    Parameters
+    ----------
+    tqdm_class  : optional, (default: tqdm.std.tqdm).
+    loggers  : optional, list.
+    **tqdm_kwargs  : passed to `tqdm_class`.
+    """
+    tqdm_kwargs = kwargs.copy()
+    loggers = tqdm_kwargs.pop('loggers', None)
+    tqdm_class = tqdm_kwargs.pop('tqdm_class', std_tqdm)
+    with tqdm_class(*args, **tqdm_kwargs) as pbar:
+        with logging_redirect_tqdm(loggers=loggers, tqdm_class=tqdm_class):
+            yield pbar
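A minimal sketch of the `tqdm_logging_redirect` shortcut above, which creates the bar and redirects root-logger console output in one context manager:

```python
import logging

from tqdm.contrib.logging import tqdm_logging_redirect

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)

with tqdm_logging_redirect(total=9, desc="work") as pbar:
    for i in range(9):
        if i == 4:
            LOG.info("log lines are printed above the bar")
        pbar.update()
```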
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/slack.py b/phivenv/Lib/site-packages/tqdm/contrib/slack.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bca8ee98904ce869a4f8d6417bbdc4f00b38751
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/slack.py
@@ -0,0 +1,120 @@
+"""
+Sends updates to a Slack app.
+
+Usage:
+>>> from tqdm.contrib.slack import tqdm, trange
+>>> for i in trange(10, token='{token}', channel='{channel}'):
+...     ...
+
+![screenshot](https://tqdm.github.io/img/screenshot-slack.png)
+"""
+import logging
+from os import getenv
+
+try:
+    from slack_sdk import WebClient
+except ImportError:
+    raise ImportError("Please `pip install slack-sdk`")
+
+from ..auto import tqdm as tqdm_auto
+from .utils_worker import MonoWorker
+
+__author__ = {"github.com/": ["0x2b3bfa0", "casperdcl"]}
+__all__ = ['SlackIO', 'tqdm_slack', 'tsrange', 'tqdm', 'trange']
+
+
+class SlackIO(MonoWorker):
+    """Non-blocking file-like IO using a Slack app."""
+    def __init__(self, token, channel):
+        """Creates a new message in the given `channel`."""
+        super().__init__()
+        self.client = WebClient(token=token)
+        self.text = self.__class__.__name__
+        try:
+            self.message = self.client.chat_postMessage(channel=channel, text=self.text)
+        except Exception as e:
+            tqdm_auto.write(str(e))
+            self.message = None
+
+    def write(self, s):
+        """Replaces internal `message`'s text with `s`."""
+        if not s:
+            s = "..."
+        s = s.replace('\r', '').strip()
+        if s == self.text:
+            return  # skip duplicate message
+        message = self.message
+        if message is None:
+            return
+        self.text = s
+        try:
+            future = self.submit(self.client.chat_update, channel=message['channel'],
+                                 ts=message['ts'], text='`' + s + '`')
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            return future
+
+
+class tqdm_slack(tqdm_auto):
+    """
+    Standard `tqdm.auto.tqdm` but also sends updates to a Slack app.
+    May take a few seconds to create (`__init__`).
+
+    - create a Slack app with the `chat:write` scope & invite it to a
+      channel: <https://api.slack.com/authentication/basics>
+    - copy the bot `{token}` & `{channel}` and paste below
+    >>> from tqdm.contrib.slack import tqdm, trange
+    >>> for i in tqdm(iterable, token='{token}', channel='{channel}'):
+    ...     ...
+    """
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        token  : str, required. Slack token
+            [default: ${TQDM_SLACK_TOKEN}].
+        channel  : int, required. Slack channel
+            [default: ${TQDM_SLACK_CHANNEL}].
+        mininterval  : float, optional.
+          Minimum update interval in seconds [default: 1.5] to avoid rate limits.
+
+        See `tqdm.auto.tqdm.__init__` for other parameters.
+        """
+        if not kwargs.get('disable'):
+            kwargs = kwargs.copy()
+            logging.getLogger("HTTPClient").setLevel(logging.WARNING)
+            self.sio = SlackIO(
+                kwargs.pop('token', getenv("TQDM_SLACK_TOKEN")),
+                kwargs.pop('channel', getenv("TQDM_SLACK_CHANNEL")))
+            kwargs['mininterval'] = max(1.5, kwargs.get('mininterval', 1.5))
+        super().__init__(*args, **kwargs)
+
+    def display(self, **kwargs):
+        super().display(**kwargs)
+        fmt = self.format_dict
+        if fmt.get('bar_format', None):
+            fmt['bar_format'] = fmt['bar_format'].replace(
+                '<bar/>', '`{bar:10}`').replace('{bar}', '`{bar:10u}`')
+        else:
+            fmt['bar_format'] = '{l_bar}`{bar:10}`{r_bar}'
+        if fmt['ascii'] is False:
+            fmt['ascii'] = [":black_square:", ":small_blue_diamond:", ":large_blue_diamond:",
+                            ":large_blue_square:"]
+            fmt['ncols'] = 336
+        self.sio.write(self.format_meter(**fmt))
+
+    def clear(self, *args, **kwargs):
+        super().clear(*args, **kwargs)
+        if not self.disable:
+            self.sio.write("")
+
+
+def tsrange(*args, **kwargs):
+    """Shortcut for `tqdm.contrib.slack.tqdm(range(*args), **kwargs)`."""
+    return tqdm_slack(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_slack
+trange = tsrange
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/telegram.py b/phivenv/Lib/site-packages/tqdm/contrib/telegram.py
new file mode 100644
index 0000000000000000000000000000000000000000..019151800bc0c4c4fc543314b6398aa602b0692a
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/telegram.py
@@ -0,0 +1,153 @@
+"""
+Sends updates to a Telegram bot.
+
+Usage:
+>>> from tqdm.contrib.telegram import tqdm, trange
+>>> for i in trange(10, token='{token}', chat_id='{chat_id}'):
+...     ...
+
+![screenshot](https://tqdm.github.io/img/screenshot-telegram.gif)
+"""
+from os import getenv
+from warnings import warn
+
+from requests import Session
+
+from ..auto import tqdm as tqdm_auto
+from ..std import TqdmWarning
+from .utils_worker import MonoWorker
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['TelegramIO', 'tqdm_telegram', 'ttgrange', 'tqdm', 'trange']
+
+
+class TelegramIO(MonoWorker):
+    """Non-blocking file-like IO using a Telegram Bot."""
+    API = 'https://api.telegram.org/bot'
+
+    def __init__(self, token, chat_id):
+        """Creates a new message in the given `chat_id`."""
+        super().__init__()
+        self.token = token
+        self.chat_id = chat_id
+        self.session = Session()
+        self.text = self.__class__.__name__
+        self.message_id
+
+    @property
+    def message_id(self):
+        if hasattr(self, '_message_id'):
+            return self._message_id
+        try:
+            res = self.session.post(
+                self.API + '%s/sendMessage' % self.token,
+                data={'text': '`' + self.text + '`', 'chat_id': self.chat_id,
+                      'parse_mode': 'MarkdownV2'}).json()
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            if res.get('error_code') == 429:
+                warn("Creation rate limit: try increasing `mininterval`.",
+                     TqdmWarning, stacklevel=2)
+            else:
+                self._message_id = res['result']['message_id']
+                return self._message_id
+
+    def write(self, s):
+        """Replaces internal `message_id`'s text with `s`."""
+        if not s:
+            s = "..."
+        s = s.replace('\r', '').strip()
+        if s == self.text:
+            return  # avoid duplicate message Bot error
+        message_id = self.message_id
+        if message_id is None:
+            return
+        self.text = s
+        try:
+            future = self.submit(
+                self.session.post, self.API + '%s/editMessageText' % self.token,
+                data={'text': '`' + s + '`', 'chat_id': self.chat_id,
+                      'message_id': message_id, 'parse_mode': 'MarkdownV2'})
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            return future
+
+    def delete(self):
+        """Deletes internal `message_id`."""
+        try:
+            future = self.submit(
+                self.session.post, self.API + '%s/deleteMessage' % self.token,
+                data={'chat_id': self.chat_id, 'message_id': self.message_id})
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            return future
+
+
+class tqdm_telegram(tqdm_auto):
+    """
+    Standard `tqdm.auto.tqdm` but also sends updates to a Telegram Bot.
+    May take a few seconds to create (`__init__`).
+
+    - create a bot <https://core.telegram.org/bots#6-botfather>
+    - copy its `{token}`
+    - add the bot to a chat and send it a message such as `/start`
+    - go to <https://api.telegram.org/bot{token}/getUpdates> to find out
+      the `{chat_id}`
+    - paste the `{token}` & `{chat_id}` below
+
+    >>> from tqdm.contrib.telegram import tqdm, trange
+    >>> for i in tqdm(iterable, token='{token}', chat_id='{chat_id}'):
+    ...     ...
+    """
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        token  : str, required. Telegram token
+            [default: ${TQDM_TELEGRAM_TOKEN}].
+        chat_id  : str, required. Telegram chat ID
+            [default: ${TQDM_TELEGRAM_CHAT_ID}].
+
+        See `tqdm.auto.tqdm.__init__` for other parameters.
+        """
+        if not kwargs.get('disable'):
+            kwargs = kwargs.copy()
+            self.tgio = TelegramIO(
+                kwargs.pop('token', getenv('TQDM_TELEGRAM_TOKEN')),
+                kwargs.pop('chat_id', getenv('TQDM_TELEGRAM_CHAT_ID')))
+        super().__init__(*args, **kwargs)
+
+    def display(self, **kwargs):
+        super().display(**kwargs)
+        fmt = self.format_dict
+        if fmt.get('bar_format', None):
+            fmt['bar_format'] = fmt['bar_format'].replace(
+                '<bar/>', '{bar:10u}').replace('{bar}', '{bar:10u}')
+        else:
+            fmt['bar_format'] = '{l_bar}{bar:10u}{r_bar}'
+        self.tgio.write(self.format_meter(**fmt))
+
+    def clear(self, *args, **kwargs):
+        super().clear(*args, **kwargs)
+        if not self.disable:
+            self.tgio.write("")
+
+    def close(self):
+        if self.disable:
+            return
+        super().close()
+        if not (self.leave or (self.leave is None and self.pos == 0)):
+            self.tgio.delete()
+
+
+def ttgrange(*args, **kwargs):
+    """Shortcut for `tqdm.contrib.telegram.tqdm(range(*args), **kwargs)`."""
+    return tqdm_telegram(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_telegram
+trange = ttgrange
diff --git a/phivenv/Lib/site-packages/tqdm/contrib/utils_worker.py b/phivenv/Lib/site-packages/tqdm/contrib/utils_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a03a2a8930001e37938836196e0d15b649b07a8
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/contrib/utils_worker.py
@@ -0,0 +1,38 @@
+"""
+IO/concurrency helpers for `tqdm.contrib`.
+"""
+from collections import deque
+from concurrent.futures import ThreadPoolExecutor
+
+from ..auto import tqdm as tqdm_auto
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['MonoWorker']
+
+
+class MonoWorker(object):
+    """
+    Supports one running task and one waiting task.
+    The waiting task is the most recent submitted (others are discarded).
+    """
+    def __init__(self):
+        self.pool = ThreadPoolExecutor(max_workers=1)
+        self.futures = deque([], 2)
+
+    def submit(self, func, *args, **kwargs):
+        """`func(*args, **kwargs)` may replace currently waiting task."""
+        futures = self.futures
+        if len(futures) == futures.maxlen:
+            running = futures.popleft()
+            if not running.done():
+                if len(futures):  # clear waiting
+                    waiting = futures.pop()
+                    waiting.cancel()
+                futures.appendleft(running)  # re-insert running
+        try:
+            waiting = self.pool.submit(func, *args, **kwargs)
+        except Exception as e:
+            tqdm_auto.write(str(e))
+        else:
+            futures.append(waiting)
+            return waiting
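A minimal sketch of `MonoWorker`: because only one running and one waiting task are kept, rapid submissions to a slow endpoint overwrite the pending update instead of queueing up:

```python
import time

from tqdm.contrib.utils_worker import MonoWorker

def slow_send(msg):
    time.sleep(0.5)  # stand-in for a slow network call
    return msg

worker = MonoWorker()
futures = [worker.submit(slow_send, f"update {i}") for i in range(10)]

# intermediate submissions are cancelled, not queued
done = [f.result() for f in futures if f is not None and not f.cancelled()]
print(done)
```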
diff --git a/phivenv/Lib/site-packages/tqdm/dask.py b/phivenv/Lib/site-packages/tqdm/dask.py
new file mode 100644
index 0000000000000000000000000000000000000000..57f1b668f59dc5991019eee34c7df3232a2c2cd7
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/dask.py
@@ -0,0 +1,44 @@
+from functools import partial
+
+from dask.callbacks import Callback
+
+from .auto import tqdm as tqdm_auto
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['TqdmCallback']
+
+
+class TqdmCallback(Callback):
+    """Dask callback for task progress."""
+    def __init__(self, start=None, pretask=None, tqdm_class=tqdm_auto,
+                 **tqdm_kwargs):
+        """
+        Parameters
+        ----------
+        tqdm_class  : optional
+            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
+        tqdm_kwargs  : optional
+            Any other arguments used for all bars.
+        """
+        super().__init__(start=start, pretask=pretask)
+        if tqdm_kwargs:
+            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
+        self.tqdm_class = tqdm_class
+
+    def _start_state(self, _, state):
+        self.pbar = self.tqdm_class(total=sum(
+            len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))
+
+    def _posttask(self, *_, **__):
+        self.pbar.update()
+
+    def _finish(self, *_, **__):
+        self.pbar.close()
+
+    def display(self):
+        """Displays in the current cell in Notebooks."""
+        container = getattr(self.pbar, 'container', None)
+        if container is None:
+            return
+        from .notebook import display
+        display(container)
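A minimal sketch of the Dask callback above (assumes `dask[array]` is installed); `TqdmCallback` can be used as a context manager or registered globally like any `dask.callbacks.Callback`:

```python
import dask.array as da

from tqdm.dask import TqdmCallback

with TqdmCallback(desc="compute"):
    da.ones((2000, 2000), chunks=(500, 500)).mean().compute()

# or register once for every subsequent compute()
cb = TqdmCallback(desc="global")
cb.register()
```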
diff --git a/phivenv/Lib/site-packages/tqdm/gui.py b/phivenv/Lib/site-packages/tqdm/gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb52fb91a8661f4c73edd352bbc6f21b877dcfee
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/gui.py
@@ -0,0 +1,179 @@
+"""
+Matplotlib GUI progressbar decorator for iterators.
+
+Usage:
+>>> from tqdm.gui import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+# future division is important to divide integers and get as
+# a result precise floating numbers (instead of truncated int)
+import re
+from warnings import warn
+
+# to inherit from the tqdm class
+from .std import TqdmExperimentalWarning
+from .std import tqdm as std_tqdm
+
+# import compatibility functions and utilities
+
+__author__ = {"github.com/": ["casperdcl", "lrq3000"]}
+__all__ = ['tqdm_gui', 'tgrange', 'tqdm', 'trange']
+
+
+class tqdm_gui(std_tqdm):  # pragma: no cover
+    """Experimental Matplotlib GUI version of tqdm!"""
+    # TODO: @classmethod: write() on GUI?
+    def __init__(self, *args, **kwargs):
+        from collections import deque
+
+        import matplotlib as mpl
+        import matplotlib.pyplot as plt
+        kwargs = kwargs.copy()
+        kwargs['gui'] = True
+        colour = kwargs.pop('colour', 'g')
+        super().__init__(*args, **kwargs)
+
+        if self.disable:
+            return
+
+        warn("GUI is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
+        self.mpl = mpl
+        self.plt = plt
+
+        # Remember if external environment uses toolbars
+        self.toolbar = self.mpl.rcParams['toolbar']
+        self.mpl.rcParams['toolbar'] = 'None'
+
+        self.mininterval = max(self.mininterval, 0.5)
+        self.fig, ax = plt.subplots(figsize=(9, 2.2))
+        # self.fig.subplots_adjust(bottom=0.2)
+        total = self.__len__()  # avoids TypeError on None #971
+        if total is not None:
+            self.xdata = []
+            self.ydata = []
+            self.zdata = []
+        else:
+            self.xdata = deque([])
+            self.ydata = deque([])
+            self.zdata = deque([])
+        self.line1, = ax.plot(self.xdata, self.ydata, color='b')
+        self.line2, = ax.plot(self.xdata, self.zdata, color='k')
+        ax.set_ylim(0, 0.001)
+        if total is not None:
+            ax.set_xlim(0, 100)
+            ax.set_xlabel("percent")
+            self.fig.legend((self.line1, self.line2), ("cur", "est"),
+                            loc='center right')
+            # progressbar
+            self.hspan = plt.axhspan(0, 0.001, xmin=0, xmax=0, color=colour)
+        else:
+            # ax.set_xlim(-60, 0)
+            ax.set_xlim(0, 60)
+            ax.invert_xaxis()
+            ax.set_xlabel("seconds")
+            ax.legend(("cur", "est"), loc='lower left')
+        ax.grid()
+        # ax.set_xlabel('seconds')
+        ax.set_ylabel((self.unit if self.unit else "it") + "/s")
+        if self.unit_scale:
+            plt.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
+            ax.yaxis.get_offset_text().set_x(-0.15)
+
+        # Remember if external environment is interactive
+        self.wasion = plt.isinteractive()
+        plt.ion()
+        self.ax = ax
+
+    def close(self):
+        if self.disable:
+            return
+
+        self.disable = True
+
+        with self.get_lock():
+            self._instances.remove(self)
+
+        # Restore toolbars
+        self.mpl.rcParams['toolbar'] = self.toolbar
+        # Return to non-interactive mode
+        if not self.wasion:
+            self.plt.ioff()
+        if self.leave:
+            self.display()
+        else:
+            self.plt.close(self.fig)
+
+    def clear(self, *_, **__):
+        pass
+
+    def display(self, *_, **__):
+        n = self.n
+        cur_t = self._time()
+        elapsed = cur_t - self.start_t
+        delta_it = n - self.last_print_n
+        delta_t = cur_t - self.last_print_t
+
+        # Inline due to multiple calls
+        total = self.total
+        xdata = self.xdata
+        ydata = self.ydata
+        zdata = self.zdata
+        ax = self.ax
+        line1 = self.line1
+        line2 = self.line2
+        hspan = getattr(self, 'hspan', None)
+        # instantaneous rate
+        y = delta_it / delta_t
+        # overall rate
+        z = n / elapsed
+        # update line data
+        xdata.append(n * 100.0 / total if total else cur_t)
+        ydata.append(y)
+        zdata.append(z)
+
+        # Discard old values
+        # xmin, xmax = ax.get_xlim()
+        # if (not total) and elapsed > xmin * 1.1:
+        if (not total) and elapsed > 66:
+            xdata.popleft()
+            ydata.popleft()
+            zdata.popleft()
+
+        ymin, ymax = ax.get_ylim()
+        if y > ymax or z > ymax:
+            ymax = 1.1 * y
+            ax.set_ylim(ymin, ymax)
+            ax.figure.canvas.draw()
+
+        if total:
+            line1.set_data(xdata, ydata)
+            line2.set_data(xdata, zdata)
+            if hspan:
+                hspan.set_xy((0, ymin))
+                hspan.set_height(ymax - ymin)
+                hspan.set_width(n / total)
+        else:
+            t_ago = [cur_t - i for i in xdata]
+            line1.set_data(t_ago, ydata)
+            line2.set_data(t_ago, zdata)
+
+        d = self.format_dict
+        # remove {bar}
+        d['bar_format'] = (d['bar_format'] or "{l_bar}{r_bar}").replace(
+            "{bar}", "")
+        msg = self.format_meter(**d)
+        if '<bar/>' in msg:
+            msg = "".join(re.split(r'\|?<bar/>\|?', msg, maxsplit=1))
+        ax.set_title(msg, fontname="DejaVu Sans Mono", fontsize=11)
+        self.plt.pause(1e-9)
+
+
+def tgrange(*args, **kwargs):
+    """Shortcut for `tqdm.gui.tqdm(range(*args), **kwargs)`."""
+    return tqdm_gui(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_gui
+trange = tgrange
diff --git a/phivenv/Lib/site-packages/tqdm/keras.py b/phivenv/Lib/site-packages/tqdm/keras.py
new file mode 100644
index 0000000000000000000000000000000000000000..cce9467c51a95388aaa502d1da9a42f3ebf0af24
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/keras.py
@@ -0,0 +1,122 @@
+from copy import copy
+from functools import partial
+
+from .auto import tqdm as tqdm_auto
+
+try:
+    import keras
+except (ImportError, AttributeError) as e:
+    try:
+        from tensorflow import keras
+    except ImportError:
+        raise e
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['TqdmCallback']
+
+
+class TqdmCallback(keras.callbacks.Callback):
+    """Keras callback for epoch and batch progress."""
+    @staticmethod
+    def bar2callback(bar, pop=None, delta=(lambda logs: 1)):
+        def callback(_, logs=None):
+            n = delta(logs)
+            if logs:
+                if pop:
+                    logs = copy(logs)
+                    [logs.pop(i, 0) for i in pop]
+                bar.set_postfix(logs, refresh=False)
+            bar.update(n)
+
+        return callback
+
+    def __init__(self, epochs=None, data_size=None, batch_size=None, verbose=1,
+                 tqdm_class=tqdm_auto, **tqdm_kwargs):
+        """
+        Parameters
+        ----------
+        epochs  : int, optional
+        data_size  : int, optional
+            Number of training pairs.
+        batch_size  : int, optional
+            Number of training pairs per batch.
+        verbose  : int
+            0: epoch, 1: batch (transient), 2: batch. [default: 1].
+            Will be set to `0` unless both `data_size` and `batch_size`
+            are given.
+        tqdm_class  : optional
+            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
+        tqdm_kwargs  : optional
+            Any other arguments used for all bars.
+        """
+        if tqdm_kwargs:
+            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
+        self.tqdm_class = tqdm_class
+        self.epoch_bar = tqdm_class(total=epochs, unit='epoch')
+        self.on_epoch_end = self.bar2callback(self.epoch_bar)
+        if data_size and batch_size:
+            self.batches = batches = (data_size + batch_size - 1) // batch_size
+        else:
+            self.batches = batches = None
+        self.verbose = verbose
+        if verbose == 1:
+            self.batch_bar = tqdm_class(total=batches, unit='batch', leave=False)
+            self.on_batch_end = self.bar2callback(
+                self.batch_bar, pop=['batch', 'size'],
+                delta=lambda logs: logs.get('size', 1))
+
+    def on_train_begin(self, *_, **__):
+        params = self.params.get
+        auto_total = params('epochs', params('nb_epoch', None))
+        if auto_total is not None and auto_total != self.epoch_bar.total:
+            self.epoch_bar.reset(total=auto_total)
+
+    def on_epoch_begin(self, epoch, *_, **__):
+        if self.epoch_bar.n < epoch:
+            ebar = self.epoch_bar
+            ebar.n = ebar.last_print_n = ebar.initial = epoch
+        if self.verbose:
+            params = self.params.get
+            total = params('samples', params(
+                'nb_sample', params('steps', None))) or self.batches
+            if self.verbose == 2:
+                if hasattr(self, 'batch_bar'):
+                    self.batch_bar.close()
+                self.batch_bar = self.tqdm_class(
+                    total=total, unit='batch', leave=True,
+                    unit_scale=1 / (params('batch_size', 1) or 1))
+                self.on_batch_end = self.bar2callback(
+                    self.batch_bar, pop=['batch', 'size'],
+                    delta=lambda logs: logs.get('size', 1))
+            elif self.verbose == 1:
+                self.batch_bar.unit_scale = 1 / (params('batch_size', 1) or 1)
+                self.batch_bar.reset(total=total)
+            else:
+                raise KeyError('Unknown verbosity')
+
+    def on_train_end(self, *_, **__):
+        if hasattr(self, 'batch_bar'):
+            self.batch_bar.close()
+        self.epoch_bar.close()
+
+    def display(self):
+        """Displays in the current cell in Notebooks."""
+        container = getattr(self.epoch_bar, 'container', None)
+        if container is None:
+            return
+        from .notebook import display
+        display(container)
+        batch_bar = getattr(self, 'batch_bar', None)
+        if batch_bar is not None:
+            display(batch_bar.container)
+
+    @staticmethod
+    def _implements_train_batch_hooks():
+        return True
+
+    @staticmethod
+    def _implements_test_batch_hooks():
+        return True
+
+    @staticmethod
+    def _implements_predict_batch_hooks():
+        return True
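A minimal sketch of the Keras callback above (assumes TensorFlow/Keras is installed); passing `verbose=0` to `fit` silences Keras' built-in bar so only the tqdm bars are shown:

```python
import numpy as np
from tensorflow import keras

from tqdm.keras import TqdmCallback

model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
x = np.random.rand(256, 8).astype("float32")
y = np.random.rand(256, 1).astype("float32")

# verbose=2 keeps a persistent per-epoch batch bar
model.fit(x, y, epochs=3, batch_size=32, verbose=0,
          callbacks=[TqdmCallback(verbose=2)])
```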
diff --git a/phivenv/Lib/site-packages/tqdm/notebook.py b/phivenv/Lib/site-packages/tqdm/notebook.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b91bdd43183998fcb99e92dd4597ff7fc6c3fb
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/notebook.py
@@ -0,0 +1,317 @@
+"""
+IPython/Jupyter Notebook progressbar decorator for iterators.
+Includes a default `range` iterator printing to `stderr`.
+
+Usage:
+>>> from tqdm.notebook import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+# import compatibility functions and utilities
+import re
+import sys
+from html import escape
+from weakref import proxy
+
+# to inherit from the tqdm class
+from .std import tqdm as std_tqdm
+
+if True:  # pragma: no cover
+    # import IPython/Jupyter base widget and display utilities
+    IPY = 0
+    try:  # IPython 4.x
+        import ipywidgets
+        IPY = 4
+    except ImportError:  # IPython 3.x / 2.x
+        IPY = 32
+        import warnings
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                'ignore', message=".*The `IPython.html` package has been deprecated.*")
+            try:
+                import IPython.html.widgets as ipywidgets  # NOQA: F401
+            except ImportError:
+                pass
+
+    try:  # IPython 4.x / 3.x
+        if IPY == 32:
+            from IPython.html.widgets import HTML
+            from IPython.html.widgets import FloatProgress as IProgress
+            from IPython.html.widgets import HBox
+            IPY = 3
+        else:
+            from ipywidgets import HTML
+            from ipywidgets import FloatProgress as IProgress
+            from ipywidgets import HBox
+    except ImportError:
+        try:  # IPython 2.x
+            from IPython.html.widgets import HTML
+            from IPython.html.widgets import ContainerWidget as HBox
+            from IPython.html.widgets import FloatProgressWidget as IProgress
+            IPY = 2
+        except ImportError:
+            IPY = 0
+            IProgress = None
+            HBox = object
+
+    try:
+        from IPython.display import display  # , clear_output
+    except ImportError:
+        pass
+
+__author__ = {"github.com/": ["lrq3000", "casperdcl", "alexanderkuk"]}
+__all__ = ['tqdm_notebook', 'tnrange', 'tqdm', 'trange']
+WARN_NOIPYW = ("IProgress not found. Please update jupyter and ipywidgets."
+               " See https://ipywidgets.readthedocs.io/en/stable"
+               "/user_install.html")
+
+
+class TqdmHBox(HBox):
+    """`ipywidgets.HBox` with a pretty representation"""
+    def _json_(self, pretty=None):
+        pbar = getattr(self, 'pbar', None)
+        if pbar is None:
+            return {}
+        d = pbar.format_dict
+        if pretty is not None:
+            d["ascii"] = not pretty
+        return d
+
+    def __repr__(self, pretty=False):
+        pbar = getattr(self, 'pbar', None)
+        if pbar is None:
+            return super().__repr__()
+        return pbar.format_meter(**self._json_(pretty))
+
+    def _repr_pretty_(self, pp, *_, **__):
+        pp.text(self.__repr__(True))
+
+
+class tqdm_notebook(std_tqdm):
+    """
+    Experimental IPython/Jupyter Notebook widget using tqdm!
+    """
+    @staticmethod
+    def status_printer(_, total=None, desc=None, ncols=None):
+        """
+        Manage the printing of an IPython/Jupyter Notebook progress bar widget.
+        """
+        # Fallback to text bar if there's no total
+        # DEPRECATED: replaced with an 'info' style bar
+        # if not total:
+        #    return super(tqdm_notebook, tqdm_notebook).status_printer(file)
+
+        # fp = file
+
+        # Prepare IPython progress bar
+        if IProgress is None:  # #187 #451 #558 #872
+            raise ImportError(WARN_NOIPYW)
+        if total:
+            pbar = IProgress(min=0, max=total)
+        else:  # No total? Show info style bar with no progress tqdm status
+            pbar = IProgress(min=0, max=1)
+            pbar.value = 1
+            pbar.bar_style = 'info'
+            if ncols is None:
+                pbar.layout.width = "20px"
+
+        ltext = HTML()
+        rtext = HTML()
+        if desc:
+            ltext.value = desc
+        container = TqdmHBox(children=[ltext, pbar, rtext])
+        # Prepare layout
+        if ncols is not None:  # use default style of ipywidgets
+            # ncols could be 100, "100px", "100%"
+            ncols = str(ncols)  # ipywidgets only accepts string
+            try:
+                if int(ncols) > 0:  # isnumeric and positive
+                    ncols += 'px'
+            except ValueError:
+                pass
+            pbar.layout.flex = '2'
+            container.layout.width = ncols
+            container.layout.display = 'inline-flex'
+            container.layout.flex_flow = 'row wrap'
+
+        return container
+
+    def display(self, msg=None, pos=None,
+                # additional signals
+                close=False, bar_style=None, check_delay=True):
+        # Note: contrary to native tqdm, msg='' does NOT clear bar
+        # goal is to keep all infos if error happens so user knows
+        # at which iteration the loop failed.
+
+        # Clear previous output (really necessary?)
+        # clear_output(wait=1)
+
+        if not msg and not close:
+            d = self.format_dict
+            # remove {bar}
+            d['bar_format'] = (d['bar_format'] or "{l_bar}{r_bar}").replace(
+                "{bar}", "")
+            msg = self.format_meter(**d)
+
+        ltext, pbar, rtext = self.container.children
+        pbar.value = self.n
+
+        if msg:
+            msg = msg.replace(' ', u'\u2007')  # fix html space padding
+            # html escape special characters (like '&')
+            if '<bar/>' in msg:
+                left, right = map(escape, re.split(r'\|?<bar/>\|?', msg, maxsplit=1))
+            else:
+                left, right = '', escape(msg)
+
+            # Update description
+            ltext.value = left
+            # never clear the bar (signal: msg='')
+            if right:
+                rtext.value = right
+
+        # Change bar style
+        if bar_style:
+            # Hack-ish way to avoid the 'danger' bar_style being overridden by
+            # 'success', because the bar gets closed after the error...
+            if pbar.bar_style != 'danger' or bar_style != 'success':
+                pbar.bar_style = bar_style
+
+        # Special signal to close the bar
+        if close and pbar.bar_style != 'danger':  # hide only if no error
+            try:
+                self.container.close()
+            except AttributeError:
+                self.container.visible = False
+            self.container.layout.visibility = 'hidden'  # IPYW>=8
+
+        if check_delay and self.delay > 0 and not self.displayed:
+            display(self.container)
+            self.displayed = True
+
+    @property
+    def colour(self):
+        if hasattr(self, 'container'):
+            return self.container.children[-2].style.bar_color
+
+    @colour.setter
+    def colour(self, bar_color):
+        if hasattr(self, 'container'):
+            self.container.children[-2].style.bar_color = bar_color
+
+    def __init__(self, *args, **kwargs):
+        """
+        Supports the usual `tqdm.tqdm` parameters as well as those listed below.
+
+        Parameters
+        ----------
+        display  : Whether to call `display(self.container)` immediately
+            [default: True].
+        """
+        kwargs = kwargs.copy()
+        # Setup default output
+        file_kwarg = kwargs.get('file', sys.stderr)
+        if file_kwarg is sys.stderr or file_kwarg is None:
+            kwargs['file'] = sys.stdout  # avoid the red block in IPython
+
+        # Initialize parent class + avoid printing by using gui=True
+        kwargs['gui'] = True
+        # convert disable = None to False
+        kwargs['disable'] = bool(kwargs.get('disable', False))
+        colour = kwargs.pop('colour', None)
+        display_here = kwargs.pop('display', True)
+        super().__init__(*args, **kwargs)
+        if self.disable or not kwargs['gui']:
+            self.disp = lambda *_, **__: None
+            return
+
+        # Get bar width
+        self.ncols = '100%' if self.dynamic_ncols else kwargs.get("ncols", None)
+
+        # Replace with IPython progress bar display (with correct total)
+        unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
+        total = self.total * unit_scale if self.total else self.total
+        self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
+        self.container.pbar = proxy(self)
+        self.displayed = False
+        if display_here and self.delay <= 0:
+            display(self.container)
+            self.displayed = True
+        self.disp = self.display
+        self.colour = colour
+
+        # Print initial bar state
+        if not self.disable:
+            self.display(check_delay=False)
+
+    def __iter__(self):
+        try:
+            it = super().__iter__()
+            for obj in it:
+                # return super(tqdm...) will not catch exception
+                yield obj
+        # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt
+        except:  # NOQA
+            self.disp(bar_style='danger')
+            raise
+        # NB: don't `finally: close()`
+        # since this could be a shared bar which the user will `reset()`
+
+    def update(self, n=1):
+        try:
+            return super().update(n=n)
+        # NB: except ... [ as ...] breaks IPython async KeyboardInterrupt
+        except:  # NOQA
+            # cannot catch KeyboardInterrupt when using manual tqdm
+            # as the interrupt will most likely happen on another statement
+            self.disp(bar_style='danger')
+            raise
+        # NB: don't `finally: close()`
+        # since this could be a shared bar which the user will `reset()`
+
+    def close(self):
+        if self.disable:
+            return
+        super().close()
+        # Try to detect if there was an error or KeyboardInterrupt
+        # in manual mode: if n < total, something probably went wrong
+        if self.total and self.n < self.total:
+            self.disp(bar_style='danger', check_delay=False)
+        else:
+            if self.leave:
+                self.disp(bar_style='success', check_delay=False)
+            else:
+                self.disp(close=True, check_delay=False)
+
+    def clear(self, *_, **__):
+        pass
+
+    def reset(self, total=None):
+        """
+        Resets to 0 iterations for repeated use.
+
+        Consider combining with `leave=True`.
+
+        Parameters
+        ----------
+        total  : int or float, optional. Total to use for the new bar.
+        """
+        if self.disable:
+            return super().reset(total=total)
+        _, pbar, _ = self.container.children
+        pbar.bar_style = ''
+        if total is not None:
+            pbar.max = total
+            if not self.total and self.ncols is None:  # no longer unknown total
+                pbar.layout.width = None  # reset width
+        return super().reset(total=total)
+
+
+def tnrange(*args, **kwargs):
+    """Shortcut for `tqdm.notebook.tqdm(range(*args), **kwargs)`."""
+    return tqdm_notebook(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_notebook
+trange = tnrange
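+
+# Minimal usage sketch (not part of tqdm itself): in a Jupyter cell the widget
+# bar is driven exactly like the console bar, e.g.
+#     from tqdm.notebook import trange
+#     for _ in trange(100, desc="processing"):
+#         pass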
diff --git a/phivenv/Lib/site-packages/tqdm/rich.py b/phivenv/Lib/site-packages/tqdm/rich.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d392edaf115a93f7c145de52cbe8978dcf1ede8
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/rich.py
@@ -0,0 +1,151 @@
+"""
+`rich.progress` decorator for iterators.
+
+Usage:
+>>> from tqdm.rich import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+from warnings import warn
+
+from rich.progress import (
+    BarColumn, Progress, ProgressColumn, Text, TimeElapsedColumn, TimeRemainingColumn, filesize)
+
+from .std import TqdmExperimentalWarning
+from .std import tqdm as std_tqdm
+
+__author__ = {"github.com/": ["casperdcl"]}
+__all__ = ['tqdm_rich', 'trrange', 'tqdm', 'trange']
+
+
+class FractionColumn(ProgressColumn):
+    """Renders completed/total, e.g. '0.5/2.3 G'."""
+    def __init__(self, unit_scale=False, unit_divisor=1000):
+        self.unit_scale = unit_scale
+        self.unit_divisor = unit_divisor
+        super().__init__()
+
+    def render(self, task):
+        """Calculate common unit for completed and total."""
+        completed = int(task.completed)
+        total = int(task.total)
+        if self.unit_scale:
+            unit, suffix = filesize.pick_unit_and_suffix(
+                total,
+                ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
+                self.unit_divisor,
+            )
+        else:
+            unit, suffix = filesize.pick_unit_and_suffix(total, [""], 1)
+        precision = 0 if unit == 1 else 1
+        return Text(
+            f"{completed/unit:,.{precision}f}/{total/unit:,.{precision}f} {suffix}",
+            style="progress.download")
+
+
+class RateColumn(ProgressColumn):
+    """Renders human readable transfer speed."""
+    def __init__(self, unit="", unit_scale=False, unit_divisor=1000):
+        self.unit = unit
+        self.unit_scale = unit_scale
+        self.unit_divisor = unit_divisor
+        super().__init__()
+
+    def render(self, task):
+        """Show data transfer speed."""
+        speed = task.speed
+        if speed is None:
+            return Text(f"? {self.unit}/s", style="progress.data.speed")
+        if self.unit_scale:
+            unit, suffix = filesize.pick_unit_and_suffix(
+                speed,
+                ["", "K", "M", "G", "T", "P", "E", "Z", "Y"],
+                self.unit_divisor,
+            )
+        else:
+            unit, suffix = filesize.pick_unit_and_suffix(speed, [""], 1)
+        precision = 0 if unit == 1 else 1
+        return Text(f"{speed/unit:,.{precision}f} {suffix}{self.unit}/s",
+                    style="progress.data.speed")
+
+
+class tqdm_rich(std_tqdm):  # pragma: no cover
+    """Experimental rich.progress GUI version of tqdm!"""
+    # TODO: @classmethod: write()?
+    def __init__(self, *args, **kwargs):
+        """
+        This class accepts the following parameters *in addition* to
+        the parameters accepted by `tqdm`.
+
+        Parameters
+        ----------
+        progress  : tuple, optional
+            arguments for `rich.progress.Progress()`.
+        options  : dict, optional
+            keyword arguments for `rich.progress.Progress()`.
+        """
+        kwargs = kwargs.copy()
+        kwargs['gui'] = True
+        # convert disable = None to False
+        kwargs['disable'] = bool(kwargs.get('disable', False))
+        progress = kwargs.pop('progress', None)
+        options = kwargs.pop('options', {}).copy()
+        super().__init__(*args, **kwargs)
+
+        if self.disable:
+            return
+
+        warn("rich is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
+        d = self.format_dict
+        if progress is None:
+            progress = (
+                "[progress.description]{task.description}"
+                "[progress.percentage]{task.percentage:>4.0f}%",
+                BarColumn(bar_width=None),
+                FractionColumn(
+                    unit_scale=d['unit_scale'], unit_divisor=d['unit_divisor']),
+                "[", TimeElapsedColumn(), "<", TimeRemainingColumn(),
+                ",", RateColumn(unit=d['unit'], unit_scale=d['unit_scale'],
+                                unit_divisor=d['unit_divisor']), "]"
+            )
+        options.setdefault('transient', not self.leave)
+        self._prog = Progress(*progress, **options)
+        self._prog.__enter__()
+        self._task_id = self._prog.add_task(self.desc or "", **d)
+
+    def close(self):
+        if self.disable:
+            return
+        self.display()  # print 100%, vis #1306
+        super().close()
+        self._prog.__exit__(None, None, None)
+
+    def clear(self, *_, **__):
+        pass
+
+    def display(self, *_, **__):
+        if not hasattr(self, '_prog'):
+            return
+        self._prog.update(self._task_id, completed=self.n, description=self.desc)
+
+    def reset(self, total=None):
+        """
+        Resets to 0 iterations for repeated use.
+
+        Parameters
+        ----------
+        total  : int or float, optional. Total to use for the new bar.
+        """
+        if hasattr(self, '_prog'):
+            self._prog.reset(total=total)
+        super().reset(total=total)
+
+
+def trrange(*args, **kwargs):
+    """Shortcut for `tqdm.rich.tqdm(range(*args), **kwargs)`."""
+    return tqdm_rich(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_rich
+trange = trrange
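+
+# Minimal usage sketch (not part of tqdm itself), assuming `rich` is installed;
+# `options` is forwarded to `rich.progress.Progress`:
+#     from tqdm.rich import tqdm
+#     for _ in tqdm(range(100), options={"transient": True}):
+#         pass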
diff --git a/phivenv/Lib/site-packages/tqdm/std.py b/phivenv/Lib/site-packages/tqdm/std.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91ad3090392916fc2bc1e34bc471e43212fe699
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/std.py
@@ -0,0 +1,1524 @@
+"""
+Customisable progressbar decorator for iterators.
+Includes a default `range` iterator printing to `stderr`.
+
+Usage:
+>>> from tqdm import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+import sys
+from collections import OrderedDict, defaultdict
+from contextlib import contextmanager
+from datetime import datetime, timedelta, timezone
+from numbers import Number
+from time import time
+from warnings import warn
+from weakref import WeakSet
+
+from ._monitor import TMonitor
+from .utils import (
+    CallbackIOWrapper, Comparable, DisableOnWriteError, FormatReplace, SimpleTextIOWrapper,
+    _is_ascii, _screen_shape_wrapper, _supports_unicode, _term_move_up, disp_len, disp_trim,
+    envwrap)
+
+__author__ = "https://github.com/tqdm/tqdm#contributions"
+__all__ = ['tqdm', 'trange',
+           'TqdmTypeError', 'TqdmKeyError', 'TqdmWarning',
+           'TqdmExperimentalWarning', 'TqdmDeprecationWarning',
+           'TqdmMonitorWarning']
+
+
+class TqdmTypeError(TypeError):
+    pass
+
+
+class TqdmKeyError(KeyError):
+    pass
+
+
+class TqdmWarning(Warning):
+    """base class for all tqdm warnings.
+
+    Used for non-external-code-breaking errors, such as garbled printing.
+    """
+    def __init__(self, msg, fp_write=None, *a, **k):
+        if fp_write is not None:
+            fp_write("\n" + self.__class__.__name__ + ": " + str(msg).rstrip() + '\n')
+        else:
+            super().__init__(msg, *a, **k)
+
+
+class TqdmExperimentalWarning(TqdmWarning, FutureWarning):
+    """beta feature, unstable API and behaviour"""
+    pass
+
+
+class TqdmDeprecationWarning(TqdmWarning, DeprecationWarning):
+    # not suppressed if raised
+    pass
+
+
+class TqdmMonitorWarning(TqdmWarning, RuntimeWarning):
+    """tqdm monitor errors which do not affect external functionality"""
+    pass
+
+
+def TRLock(*args, **kwargs):
+    """threading RLock"""
+    try:
+        from threading import RLock
+        return RLock(*args, **kwargs)
+    except (ImportError, OSError):  # pragma: no cover
+        pass
+
+
+class TqdmDefaultWriteLock(object):
+    """
+    Provide a default write lock for thread and multiprocessing safety.
+    Works only on platforms supporting `fork` (so Windows is excluded).
+    You must initialise a `tqdm` or `TqdmDefaultWriteLock` instance
+    before forking in order for the write lock to work.
+    On Windows, you need to supply the lock from the parent to the children as
+    an argument to joblib or the parallelism lib you use.
+    """
+    # global thread lock so no setup required for multithreading.
+    # NB: Do not create multiprocessing lock as it sets the multiprocessing
+    # context, disallowing `spawn()`/`forkserver()`
+    th_lock = TRLock()
+
+    def __init__(self):
+        # Create global parallelism locks to avoid racing issues with parallel
+        # bars; works only if fork is available (Linux/macOS, but not Windows)
+        cls = type(self)
+        root_lock = cls.th_lock
+        if root_lock is not None:
+            root_lock.acquire()
+        cls.create_mp_lock()
+        self.locks = [lk for lk in [cls.mp_lock, cls.th_lock] if lk is not None]
+        if root_lock is not None:
+            root_lock.release()
+
+    def acquire(self, *a, **k):
+        for lock in self.locks:
+            lock.acquire(*a, **k)
+
+    def release(self):
+        for lock in self.locks[::-1]:  # Release in inverse order of acquisition
+            lock.release()
+
+    def __enter__(self):
+        self.acquire()
+
+    def __exit__(self, *exc):
+        self.release()
+
+    @classmethod
+    def create_mp_lock(cls):
+        if not hasattr(cls, 'mp_lock'):
+            try:
+                from multiprocessing import RLock
+                cls.mp_lock = RLock()
+            except (ImportError, OSError):  # pragma: no cover
+                cls.mp_lock = None
+
+    @classmethod
+    def create_th_lock(cls):
+        assert hasattr(cls, 'th_lock')
+        warn("create_th_lock not needed anymore", TqdmDeprecationWarning, stacklevel=2)
+
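+# Usage sketch (not part of tqdm itself): as the docstring above notes, on
+# platforms without `fork` the parent's lock has to be handed to workers, e.g.
+#     from multiprocessing import Pool
+#     pool = Pool(initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),))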
+
+class Bar(object):
+    """
+    `str.format`-able bar with format specifiers: `[width][type]`
+
+    - `width`
+      + unspecified (default): use `self.default_len`
+      + `int >= 0`: overrides `self.default_len`
+      + `int < 0`: subtract from `self.default_len`
+    - `type`
+      + `a`: ascii (`charset=self.ASCII` override)
+      + `u`: unicode (`charset=self.UTF` override)
+      + `b`: blank (`charset="  "` override)
+    """
+    ASCII = " 123456789#"
+    UTF = u" " + u''.join(map(chr, range(0x258F, 0x2587, -1)))
+    BLANK = "  "
+    COLOUR_RESET = '\x1b[0m'
+    COLOUR_RGB = '\x1b[38;2;%d;%d;%dm'
+    COLOURS = {'BLACK': '\x1b[30m', 'RED': '\x1b[31m', 'GREEN': '\x1b[32m',
+               'YELLOW': '\x1b[33m', 'BLUE': '\x1b[34m', 'MAGENTA': '\x1b[35m',
+               'CYAN': '\x1b[36m', 'WHITE': '\x1b[37m'}
+
+    def __init__(self, frac, default_len=10, charset=UTF, colour=None):
+        if not 0 <= frac <= 1:
+            warn("clamping frac to range [0, 1]", TqdmWarning, stacklevel=2)
+            frac = max(0, min(1, frac))
+        assert default_len > 0
+        self.frac = frac
+        self.default_len = default_len
+        self.charset = charset
+        self.colour = colour
+
+    @property
+    def colour(self):
+        return self._colour
+
+    @colour.setter
+    def colour(self, value):
+        if not value:
+            self._colour = None
+            return
+        try:
+            if value.upper() in self.COLOURS:
+                self._colour = self.COLOURS[value.upper()]
+            elif value[0] == '#' and len(value) == 7:
+                self._colour = self.COLOUR_RGB % tuple(
+                    int(i, 16) for i in (value[1:3], value[3:5], value[5:7]))
+            else:
+                raise KeyError
+        except (KeyError, AttributeError):
+            warn("Unknown colour (%s); valid choices: [hex (#00ff00), %s]" % (
+                 value, ", ".join(self.COLOURS)),
+                 TqdmWarning, stacklevel=2)
+            self._colour = None
+
+    def __format__(self, format_spec):
+        if format_spec:
+            _type = format_spec[-1].lower()
+            try:
+                charset = {'a': self.ASCII, 'u': self.UTF, 'b': self.BLANK}[_type]
+            except KeyError:
+                charset = self.charset
+            else:
+                format_spec = format_spec[:-1]
+            if format_spec:
+                N_BARS = int(format_spec)
+                if N_BARS < 0:
+                    N_BARS += self.default_len
+            else:
+                N_BARS = self.default_len
+        else:
+            charset = self.charset
+            N_BARS = self.default_len
+
+        nsyms = len(charset) - 1
+        bar_length, frac_bar_length = divmod(int(self.frac * N_BARS * nsyms), nsyms)
+
+        res = charset[-1] * bar_length
+        if bar_length < N_BARS:  # whitespace padding
+            res = res + charset[frac_bar_length] + charset[0] * (N_BARS - bar_length - 1)
+        return self.colour + res + self.COLOUR_RESET if self.colour else res
+
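+# Illustration (not part of tqdm itself) of the `[width][type]` format spec
+# implemented above:
+#     f"{Bar(0.5):10a}"  -> '#####     '  (10-char ASCII bar, half full)
+#     f"{Bar(0.5):-2u}"  -> an 8-character unicode bar (default_len 10 minus 2)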
+
+class EMA(object):
+    """
+    Exponential moving average: smoothing to give progressively lower
+    weights to older values.
+
+    Parameters
+    ----------
+    smoothing  : float, optional
+        Smoothing factor in range [0, 1], [default: 0.3].
+        Increase to give more weight to recent values.
+        Ranges from 0 (yields old value) to 1 (yields new value).
+    """
+    def __init__(self, smoothing=0.3):
+        self.alpha = smoothing
+        self.last = 0
+        self.calls = 0
+
+    def __call__(self, x=None):
+        """
+        Parameters
+        ----------
+        x  : float
+            New value to include in EMA.
+        """
+        beta = 1 - self.alpha
+        if x is not None:
+            self.last = self.alpha * x + beta * self.last
+            self.calls += 1
+        return self.last / (1 - beta ** self.calls) if self.calls else self.last
+
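+# Worked sketch (not part of tqdm itself) of the debiasing above, assuming
+# smoothing=0.5 (so beta=0.5):
+#     ema = EMA(0.5); ema(10)  -> 10.0   (raw last 5.0  / (1 - 0.5**1))
+#     ema(20)                  -> ~16.67 (raw last 12.5 / (1 - 0.5**2))
+# i.e. early estimates are not dragged towards the zero initialisation.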
+
+class tqdm(Comparable):
+    """
+    Decorate an iterable object, returning an iterator which acts exactly
+    like the original iterable, but prints a dynamically updating
+    progressbar every time a value is requested.
+
+    Parameters
+    ----------
+    iterable  : iterable, optional
+        Iterable to decorate with a progressbar.
+        Leave blank to manually manage the updates.
+    desc  : str, optional
+        Prefix for the progressbar.
+    total  : int or float, optional
+        The number of expected iterations. If unspecified,
+        len(iterable) is used if possible. If float("inf") or as a last
+        resort, only basic progress statistics are displayed
+        (no ETA, no progressbar).
+        If `gui` is True and this parameter needs subsequent updating,
+        specify an initial arbitrary large positive number,
+        e.g. 9e9.
+    leave  : bool, optional
+        If [default: True], keeps all traces of the progressbar
+        upon termination of iteration.
+        If `None`, will leave only if `position` is `0`.
+    file  : `io.TextIOWrapper` or `io.StringIO`, optional
+        Specifies where to output the progress messages
+        (default: sys.stderr). Uses `file.write(str)` and `file.flush()`
+        methods.  For encoding, see `write_bytes`.
+    ncols  : int, optional
+        The width of the entire output message. If specified,
+        dynamically resizes the progressbar to stay within this bound.
+        If unspecified, attempts to use environment width. The
+        fallback is a meter width of 10 and no limit for the counter and
+        statistics. If 0, will not print any meter (only stats).
+    mininterval  : float, optional
+        Minimum progress display update interval [default: 0.1] seconds.
+    maxinterval  : float, optional
+        Maximum progress display update interval [default: 10] seconds.
+        Automatically adjusts `miniters` to correspond to `mininterval`
+        after long display update lag. Only works if `dynamic_miniters`
+        or monitor thread is enabled.
+    miniters  : int or float, optional
+        Minimum progress display update interval, in iterations.
+        If 0 and `dynamic_miniters`, will automatically adjust to equal
+        `mininterval` (more CPU efficient, good for tight loops).
+        If > 0, will skip display of specified number of iterations.
+        Tweak this and `mininterval` to get very efficient loops.
+        If your progress is erratic with both fast and slow iterations
+        (network, skipping items, etc) you should set miniters=1.
+    ascii  : bool or str, optional
+        If unspecified or False, use unicode (smooth blocks) to fill
+        the meter. The fallback is to use ASCII characters " 123456789#".
+    disable  : bool, optional
+        Whether to disable the entire progressbar wrapper
+        [default: False]. If set to None, disable on non-TTY.
+    unit  : str, optional
+        String that will be used to define the unit of each iteration
+        [default: it].
+    unit_scale  : bool or int or float, optional
+        If 1 or True, the number of iterations will be reduced/scaled
+        automatically and a metric prefix following the
+        International System of Units standard will be added
+        (kilo, mega, etc.) [default: False]. If any other non-zero
+        number, will scale `total` and `n`.
+    dynamic_ncols  : bool, optional
+        If set, constantly alters `ncols` and `nrows` to the
+        environment (allowing for window resizes) [default: False].
+    smoothing  : float, optional
+        Exponential moving average smoothing factor for speed estimates
+        (ignored in GUI mode). Ranges from 0 (average speed) to 1
+        (current/instantaneous speed) [default: 0.3].
+    bar_format  : str, optional
+        Specify a custom bar string formatting. May impact performance.
+        [default: '{l_bar}{bar}{r_bar}'], where
+        l_bar='{desc}: {percentage:3.0f}%|' and
+        r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
+            '{rate_fmt}{postfix}]'
+        Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
+            percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
+            rate, rate_fmt, rate_noinv, rate_noinv_fmt,
+            rate_inv, rate_inv_fmt, postfix, unit_divisor,
+            remaining, remaining_s, eta.
+        Note that a trailing ": " is automatically removed after {desc}
+        if the latter is empty.
+    initial  : int or float, optional
+        The initial counter value. Useful when restarting a progress
+        bar [default: 0]. If using float, consider specifying `{n:.3f}`
+        or similar in `bar_format`, or specifying `unit_scale`.
+    position  : int, optional
+        Specify the line offset to print this bar (starting from 0)
+        Automatic if unspecified.
+        Useful to manage multiple bars at once (eg, from threads).
+    postfix  : dict or *, optional
+        Specify additional stats to display at the end of the bar.
+        Calls `set_postfix(**postfix)` if possible (dict).
+    unit_divisor  : float, optional
+        [default: 1000], ignored unless `unit_scale` is True.
+    write_bytes  : bool, optional
+        Whether to write bytes. If (default: False) will write unicode.
+    lock_args  : tuple, optional
+        Passed to `refresh` for intermediate output
+        (initialisation, iterating, and updating).
+    nrows  : int, optional
+        The screen height. If specified, hides nested bars outside this
+        bound. If unspecified, attempts to use environment height.
+        The fallback is 20.
+    colour  : str, optional
+        Bar colour (e.g. 'green', '#00ff00').
+    delay  : float, optional
+        Don't display until [default: 0] seconds have elapsed.
+    gui  : bool, optional
+        WARNING: internal parameter - do not use.
+        Use tqdm.gui.tqdm(...) instead. If set, will attempt to use
+        matplotlib animations for a graphical output [default: False].
+
+    Returns
+    -------
+    out  : decorated iterator.
+    """
+
+    monitor_interval = 10  # set to 0 to disable the thread
+    monitor = None
+    _instances = WeakSet()
+
+    @staticmethod
+    def format_sizeof(num, suffix='', divisor=1000):
+        """
+        Formats a number (greater than unity) with SI Order of Magnitude
+        prefixes.
+
+        Parameters
+        ----------
+        num  : float
+            Number ( >= 1) to format.
+        suffix  : str, optional
+            Post-postfix [default: ''].
+        divisor  : float, optional
+            Divisor between prefixes [default: 1000].
+
+        Returns
+        -------
+        out  : str
+            Number with Order of Magnitude SI unit postfix.
+        """
+        for unit in ['', 'k', 'M', 'G', 'T', 'P', 'E', 'Z']:
+            if abs(num) < 999.5:
+                if abs(num) < 99.95:
+                    if abs(num) < 9.995:
+                        return f'{num:1.2f}{unit}{suffix}'
+                    return f'{num:2.1f}{unit}{suffix}'
+                return f'{num:3.0f}{unit}{suffix}'
+            num /= divisor
+        return f'{num:3.1f}Y{suffix}'
+
+    @staticmethod
+    def format_interval(t):
+        """
+        Formats a number of seconds as a clock time, [H:]MM:SS
+
+        Parameters
+        ----------
+        t  : int
+            Number of seconds.
+
+        Returns
+        -------
+        out  : str
+            [H:]MM:SS
+        """
+        mins, s = divmod(int(t), 60)
+        h, m = divmod(mins, 60)
+        return f'{h:d}:{m:02d}:{s:02d}' if h else f'{m:02d}:{s:02d}'
+
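+    # Quick sanity-check sketch (not part of tqdm itself) for the two
+    # formatters above:
+    #     tqdm.format_sizeof(123456)  -> '123k'
+    #     tqdm.format_interval(3661)  -> '1:01:01'
+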
+    @staticmethod
+    def format_num(n):
+        """
+        Intelligent scientific notation (.3g).
+
+        Parameters
+        ----------
+        n  : int or float or Numeric
+            A Number.
+
+        Returns
+        -------
+        out  : str
+            Formatted number.
+        """
+        f = f'{n:.3g}'.replace('e+0', 'e+').replace('e-0', 'e-')
+        n = str(n)
+        return f if len(f) < len(n) else n
+
+    @staticmethod
+    def status_printer(file):
+        """
+        Manage the printing and in-place updating of a line of characters.
+        Note that if the string is longer than a line, then in-place
+        updating may not work (it will print a new line at each refresh).
+        """
+        fp = file
+        fp_flush = getattr(fp, 'flush', lambda: None)  # pragma: no cover
+        if fp in (sys.stderr, sys.stdout):
+            getattr(sys.stderr, 'flush', lambda: None)()
+            getattr(sys.stdout, 'flush', lambda: None)()
+
+        def fp_write(s):
+            fp.write(str(s))
+            fp_flush()
+
+        last_len = [0]
+
+        def print_status(s):
+            len_s = disp_len(s)
+            fp_write('\r' + s + (' ' * max(last_len[0] - len_s, 0)))
+            last_len[0] = len_s
+
+        return print_status
+
+    @staticmethod
+    def format_meter(n, total, elapsed, ncols=None, prefix='', ascii=False, unit='it',
+                     unit_scale=False, rate=None, bar_format=None, postfix=None,
+                     unit_divisor=1000, initial=0, colour=None, **extra_kwargs):
+        """
+        Return a string-based progress bar given some parameters
+
+        Parameters
+        ----------
+        n  : int or float
+            Number of finished iterations.
+        total  : int or float
+            The expected total number of iterations. If meaningless (None),
+            only basic progress statistics are displayed (no ETA).
+        elapsed  : float
+            Number of seconds passed since start.
+        ncols  : int, optional
+            The width of the entire output message. If specified,
+            dynamically resizes `{bar}` to stay within this bound
+            [default: None]. If `0`, will not print any bar (only stats).
+            The fallback is `{bar:10}`.
+        prefix  : str, optional
+            Prefix message (included in total width) [default: ''].
+            Use as {desc} in bar_format string.
+        ascii  : bool, optional or str, optional
+            If not set, use unicode (smooth blocks) to fill the meter
+            [default: False]. The fallback is to use ASCII characters
+            " 123456789#".
+        unit  : str, optional
+            The iteration unit [default: 'it'].
+        unit_scale  : bool or int or float, optional
+            If 1 or True, the number of iterations will be printed with an
+            appropriate SI metric prefix (k = 10^3, M = 10^6, etc.)
+            [default: False]. If any other non-zero number, will scale
+            `total` and `n`.
+        rate  : float, optional
+            Manual override for iteration rate.
+            If [default: None], uses n/elapsed.
+        bar_format  : str, optional
+            Specify a custom bar string formatting. May impact performance.
+            [default: '{l_bar}{bar}{r_bar}'], where
+            l_bar='{desc}: {percentage:3.0f}%|' and
+            r_bar='| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, '
+              '{rate_fmt}{postfix}]'
+            Possible vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt,
+              percentage, elapsed, elapsed_s, ncols, nrows, desc, unit,
+              rate, rate_fmt, rate_noinv, rate_noinv_fmt,
+              rate_inv, rate_inv_fmt, postfix, unit_divisor,
+              remaining, remaining_s, eta.
+            Note that a trailing ": " is automatically removed after {desc}
+            if the latter is empty.
+        postfix  : *, optional
+            Similar to `prefix`, but placed at the end
+            (e.g. for additional stats).
+            Note: postfix is usually a string (not a dict) for this method,
+            and will if possible be set to postfix = ', ' + postfix.
+            However other types are supported (#382).
+        unit_divisor  : float, optional
+            [default: 1000], ignored unless `unit_scale` is True.
+        initial  : int or float, optional
+            The initial counter value [default: 0].
+        colour  : str, optional
+            Bar colour (e.g. 'green', '#00ff00').
+
+        Returns
+        -------
+        out  : Formatted meter and stats, ready to display.
+        """
+
+        # sanity check: total
+        if total and n >= (total + 0.5):  # allow float imprecision (#849)
+            total = None
+
+        # apply custom scale if necessary
+        if unit_scale and unit_scale not in (True, 1):
+            if total:
+                total *= unit_scale
+            n *= unit_scale
+            if rate:
+                rate *= unit_scale  # by default rate = self.avg_dn / self.avg_dt
+            unit_scale = False
+
+        elapsed_str = tqdm.format_interval(elapsed)
+
+        # if unspecified, attempt to use rate = average speed
+        # (we allow manual override since predicting time is an arcane art)
+        if rate is None and elapsed:
+            rate = (n - initial) / elapsed
+        inv_rate = 1 / rate if rate else None
+        format_sizeof = tqdm.format_sizeof
+        rate_noinv_fmt = ((format_sizeof(rate) if unit_scale else f'{rate:5.2f}')
+                          if rate else '?') + unit + '/s'
+        rate_inv_fmt = (
+            (format_sizeof(inv_rate) if unit_scale else f'{inv_rate:5.2f}')
+            if inv_rate else '?') + 's/' + unit
+        rate_fmt = rate_inv_fmt if inv_rate and inv_rate > 1 else rate_noinv_fmt
+
+        if unit_scale:
+            n_fmt = format_sizeof(n, divisor=unit_divisor)
+            total_fmt = format_sizeof(total, divisor=unit_divisor) if total is not None else '?'
+        else:
+            n_fmt = str(n)
+            total_fmt = str(total) if total is not None else '?'
+
+        try:
+            postfix = ', ' + postfix if postfix else ''
+        except TypeError:
+            pass
+
+        remaining = (total - n) / rate if rate and total else 0
+        remaining_str = tqdm.format_interval(remaining) if rate else '?'
+        try:
+            eta_dt = (datetime.now() + timedelta(seconds=remaining)
+                      if rate and total else datetime.fromtimestamp(0, timezone.utc))
+        except OverflowError:
+            eta_dt = datetime.max
+
+        # format the stats displayed to the left and right sides of the bar
+        if prefix:
+            # old prefix setup work around
+            bool_prefix_colon_already = (prefix[-2:] == ": ")
+            l_bar = prefix if bool_prefix_colon_already else prefix + ": "
+        else:
+            l_bar = ''
+
+        r_bar = f'| {n_fmt}/{total_fmt} [{elapsed_str}<{remaining_str}, {rate_fmt}{postfix}]'
+
+        # Custom bar formatting
+        # Populate a dict with all available progress indicators
+        format_dict = {
+            # slight extension of self.format_dict
+            'n': n, 'n_fmt': n_fmt, 'total': total, 'total_fmt': total_fmt,
+            'elapsed': elapsed_str, 'elapsed_s': elapsed,
+            'ncols': ncols, 'desc': prefix or '', 'unit': unit,
+            'rate': inv_rate if inv_rate and inv_rate > 1 else rate,
+            'rate_fmt': rate_fmt, 'rate_noinv': rate,
+            'rate_noinv_fmt': rate_noinv_fmt, 'rate_inv': inv_rate,
+            'rate_inv_fmt': rate_inv_fmt,
+            'postfix': postfix, 'unit_divisor': unit_divisor,
+            'colour': colour,
+            # plus more useful definitions
+            'remaining': remaining_str, 'remaining_s': remaining,
+            'l_bar': l_bar, 'r_bar': r_bar, 'eta': eta_dt,
+            **extra_kwargs}
+
+        # total is known: we can predict some stats
+        if total:
+            # fractional and percentage progress
+            frac = n / total
+            percentage = frac * 100
+
+            l_bar += f'{percentage:3.0f}%|'
+
+            if ncols == 0:
+                return l_bar[:-1] + r_bar[1:]
+
+            format_dict.update(l_bar=l_bar)
+            if bar_format:
+                format_dict.update(percentage=percentage)
+
+                # auto-remove colon for empty `{desc}`
+                if not prefix:
+                    bar_format = bar_format.replace("{desc}: ", '')
+            else:
+                bar_format = "{l_bar}{bar}{r_bar}"
+
+            full_bar = FormatReplace()
+            nobar = bar_format.format(bar=full_bar, **format_dict)
+            if not full_bar.format_called:
+                return nobar  # no `{bar}`; nothing else to do
+
+            # Formatting progress bar space available for bar's display
+            full_bar = Bar(frac,
+                           max(1, ncols - disp_len(nobar)) if ncols else 10,
+                           charset=Bar.ASCII if ascii is True else ascii or Bar.UTF,
+                           colour=colour)
+            if not _is_ascii(full_bar.charset) and _is_ascii(bar_format):
+                bar_format = str(bar_format)
+            res = bar_format.format(bar=full_bar, **format_dict)
+            return disp_trim(res, ncols) if ncols else res
+
+        elif bar_format:
+            # user-specified bar_format but no total
+            l_bar += '|'
+            format_dict.update(l_bar=l_bar, percentage=0)
+            full_bar = FormatReplace()
+            nobar = bar_format.format(bar=full_bar, **format_dict)
+            if not full_bar.format_called:
+                return nobar
+            full_bar = Bar(0,
+                           max(1, ncols - disp_len(nobar)) if ncols else 10,
+                           charset=Bar.BLANK, colour=colour)
+            res = bar_format.format(bar=full_bar, **format_dict)
+            return disp_trim(res, ncols) if ncols else res
+        else:
+            # no total: no progressbar, ETA, just progress stats
+            return (f'{(prefix + ": ") if prefix else ""}'
+                    f'{n_fmt}{unit} [{elapsed_str}, {rate_fmt}{postfix}]')
+
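+    # Standalone sketch (not part of tqdm itself): format_meter() needs no
+    # instance; with an unknown total only the basic stats are rendered, e.g.
+    #     tqdm.format_meter(n=300, total=None, elapsed=12)
+    #     -> '300it [00:12, 25.00it/s]'
+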
+    def __new__(cls, *_, **__):
+        instance = object.__new__(cls)
+        with cls.get_lock():  # also constructs lock if non-existent
+            cls._instances.add(instance)
+            # create monitoring thread
+            if cls.monitor_interval and (cls.monitor is None
+                                         or not cls.monitor.report()):
+                try:
+                    cls.monitor = TMonitor(cls, cls.monitor_interval)
+                except Exception as e:  # pragma: nocover
+                    warn("tqdm:disabling monitor support"
+                         " (monitor_interval = 0) due to:\n" + str(e),
+                         TqdmMonitorWarning, stacklevel=2)
+                    cls.monitor_interval = 0
+        return instance
+
+    @classmethod
+    def _get_free_pos(cls, instance=None):
+        """Skips specified instance."""
+        positions = {abs(inst.pos) for inst in cls._instances
+                     if inst is not instance and hasattr(inst, "pos")}
+        return min(set(range(len(positions) + 1)).difference(positions))
+
+    @classmethod
+    def _decr_instances(cls, instance):
+        """
+        Remove from list and reposition another unfixed bar
+        to fill the new gap.
+
+        This means that by default (where all nested bars are unfixed),
+        order is not maintained but screen flicker/blank space is minimised.
+        (tqdm<=4.44.1 moved ALL subsequent unfixed bars up.)
+        """
+        with cls._lock:
+            try:
+                cls._instances.remove(instance)
+            except KeyError:
+                # if not instance.gui:  # pragma: no cover
+                #     raise
+                pass  # py2: maybe magically removed already
+            # else:
+            if not instance.gui:
+                last = (instance.nrows or 20) - 1
+                # find unfixed (`pos >= 0`) overflow (`pos >= nrows - 1`)
+                instances = list(filter(
+                    lambda i: hasattr(i, "pos") and last <= i.pos,
+                    cls._instances))
+                # set first found to current `pos`
+                if instances:
+                    inst = min(instances, key=lambda i: i.pos)
+                    inst.clear(nolock=True)
+                    inst.pos = abs(instance.pos)
+
+    @classmethod
+    def write(cls, s, file=None, end="\n", nolock=False):
+        """Print a message via tqdm (without overlap with bars)."""
+        fp = file if file is not None else sys.stdout
+        with cls.external_write_mode(file=file, nolock=nolock):
+            # Write the message
+            fp.write(s)
+            fp.write(end)
+
+    @classmethod
+    @contextmanager
+    def external_write_mode(cls, file=None, nolock=False):
+        """
+        Disable tqdm within context and refresh tqdm when exits.
+        Useful when writing to standard output stream
+        """
+        fp = file if file is not None else sys.stdout
+
+        try:
+            if not nolock:
+                cls.get_lock().acquire()
+            # Clear all bars
+            inst_cleared = []
+            for inst in getattr(cls, '_instances', []):
+                # Clear instance if in the target output file
+                # or if write output + tqdm output are both either
+                # sys.stdout or sys.stderr (because both are mixed in terminal)
+                if hasattr(inst, "start_t") and (inst.fp == fp or all(
+                        f in (sys.stdout, sys.stderr) for f in (fp, inst.fp))):
+                    inst.clear(nolock=True)
+                    inst_cleared.append(inst)
+            yield
+            # Force refresh display of bars we cleared
+            for inst in inst_cleared:
+                inst.refresh(nolock=True)
+        finally:
+            if not nolock:
+                cls._lock.release()
+
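+    # Usage sketch (not part of tqdm itself): `tqdm.write()` (above) keeps
+    # messages from colliding with active bars, e.g.
+    #     for i in tqdm(range(10)):
+    #         tqdm.write(f"finished step {i}")
+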
+    @classmethod
+    def set_lock(cls, lock):
+        """Set the global lock."""
+        cls._lock = lock
+
+    @classmethod
+    def get_lock(cls):
+        """Get the global lock. Construct it if it does not exist."""
+        if not hasattr(cls, '_lock'):
+            cls._lock = TqdmDefaultWriteLock()
+        return cls._lock
+
+    @classmethod
+    def pandas(cls, **tqdm_kwargs):
+        """
+        Registers the current `tqdm` class with
+            pandas.core.
+            ( frame.DataFrame
+            | series.Series
+            | groupby.(generic.)DataFrameGroupBy
+            | groupby.(generic.)SeriesGroupBy
+            ).progress_apply
+
+        A new instance will be created every time `progress_apply` is called,
+        and each instance will automatically `close()` upon completion.
+
+        Parameters
+        ----------
+        tqdm_kwargs  : arguments for the tqdm instance
+
+        Examples
+        --------
+        >>> import pandas as pd
+        >>> import numpy as np
+        >>> from tqdm import tqdm
+        >>> from tqdm.gui import tqdm as tqdm_gui
+        >>>
+        >>> df = pd.DataFrame(np.random.randint(0, 100, (100000, 6)))
+        >>> tqdm.pandas(ncols=50)  # can use tqdm_gui, optional kwargs, etc
+        >>> # Now you can use `progress_apply` instead of `apply`
+        >>> df.groupby(0).progress_apply(lambda x: x**2)
+
+        References
+        ----------
+        <https://stackoverflow.com/questions/18603270/\
+        progress-indicator-during-pandas-operations-python>
+        """
+        from warnings import catch_warnings, simplefilter
+
+        from pandas.core.frame import DataFrame
+        from pandas.core.series import Series
+        try:
+            with catch_warnings():
+                simplefilter("ignore", category=FutureWarning)
+                from pandas import Panel
+        except ImportError:  # pandas>=1.2.0
+            Panel = None
+        Rolling, Expanding = None, None
+        try:  # pandas>=1.0.0
+            from pandas.core.window.rolling import _Rolling_and_Expanding
+        except ImportError:
+            try:  # pandas>=0.18.0
+                from pandas.core.window import _Rolling_and_Expanding
+            except ImportError:  # pandas>=1.2.0
+                try:  # pandas>=1.2.0
+                    from pandas.core.window.expanding import Expanding
+                    from pandas.core.window.rolling import Rolling
+                    _Rolling_and_Expanding = Rolling, Expanding
+                except ImportError:  # pragma: no cover
+                    _Rolling_and_Expanding = None
+        try:  # pandas>=0.25.0
+            from pandas.core.groupby.generic import SeriesGroupBy  # , NDFrameGroupBy
+            from pandas.core.groupby.generic import DataFrameGroupBy
+        except ImportError:  # pragma: no cover
+            try:  # pandas>=0.23.0
+                from pandas.core.groupby.groupby import DataFrameGroupBy, SeriesGroupBy
+            except ImportError:
+                from pandas.core.groupby import DataFrameGroupBy, SeriesGroupBy
+        try:  # pandas>=0.23.0
+            from pandas.core.groupby.groupby import GroupBy
+        except ImportError:  # pragma: no cover
+            from pandas.core.groupby import GroupBy
+
+        try:  # pandas>=0.23.0
+            from pandas.core.groupby.groupby import PanelGroupBy
+        except ImportError:
+            try:
+                from pandas.core.groupby import PanelGroupBy
+            except ImportError:  # pandas>=0.25.0
+                PanelGroupBy = None
+
+        tqdm_kwargs = tqdm_kwargs.copy()
+        deprecated_t = [tqdm_kwargs.pop('deprecated_t', None)]
+
+        def inner_generator(df_function='apply'):
+            def inner(df, func, *args, **kwargs):
+                """
+                Parameters
+                ----------
+                df  : (DataFrame|Series)[GroupBy]
+                    Data (may be grouped).
+                func  : function
+                    To be applied on the (grouped) data.
+                **kwargs  : optional
+                    Transmitted to `df.apply()`.
+                """
+
+                # Precompute total iterations
+                total = tqdm_kwargs.pop("total", getattr(df, 'ngroups', None))
+                if total is None:  # not grouped
+                    if df_function == 'applymap':
+                        total = df.size
+                    elif isinstance(df, Series):
+                        total = len(df)
+                    elif (_Rolling_and_Expanding is None or
+                          not isinstance(df, _Rolling_and_Expanding)):
+                        # DataFrame or Panel
+                        axis = kwargs.get('axis', 0)
+                        if axis == 'index':
+                            axis = 0
+                        elif axis == 'columns':
+                            axis = 1
+                        # when axis=0, total is shape[axis1]
+                        total = df.size // df.shape[axis]
+
+                # Init bar
+                if deprecated_t[0] is not None:
+                    t = deprecated_t[0]
+                    deprecated_t[0] = None
+                else:
+                    t = cls(total=total, **tqdm_kwargs)
+
+                if len(args) > 0:
+                    # *args intentionally not supported (see #244, #299)
+                    TqdmDeprecationWarning(
+                        "Except func, normal arguments are intentionally" +
+                        " not supported by" +
+                        " `(DataFrame|Series|GroupBy).progress_apply`." +
+                        " Use keyword arguments instead.",
+                        fp_write=getattr(t.fp, 'write', sys.stderr.write))
+
+                try:  # pandas>=1.3.0
+                    from pandas.core.common import is_builtin_func
+                except ImportError:
+                    is_builtin_func = df._is_builtin_func
+                try:
+                    func = is_builtin_func(func)
+                except TypeError:
+                    pass
+
+                # Define bar updating wrapper
+                def wrapper(*args, **kwargs):
+                    # update tbar correctly
+                    # it seems `pandas apply` calls `func` twice
+                    # on the first column/row to decide whether it can
+                    # take a fast or slow code path; so stop when t.total==t.n
+                    t.update(n=1 if not t.total or t.n < t.total else 0)
+                    return func(*args, **kwargs)
+
+                # Apply the provided function (in **kwargs)
+                # on the df using our wrapper (which provides bar updating)
+                try:
+                    return getattr(df, df_function)(wrapper, **kwargs)
+                finally:
+                    t.close()
+
+            return inner
+
+        # Monkeypatch pandas to provide easy methods
+        # Enable custom tqdm progress in pandas!
+        Series.progress_apply = inner_generator()
+        SeriesGroupBy.progress_apply = inner_generator()
+        Series.progress_map = inner_generator('map')
+        SeriesGroupBy.progress_map = inner_generator('map')
+
+        DataFrame.progress_apply = inner_generator()
+        DataFrameGroupBy.progress_apply = inner_generator()
+        DataFrame.progress_applymap = inner_generator('applymap')
+        DataFrame.progress_map = inner_generator('map')
+        DataFrameGroupBy.progress_map = inner_generator('map')
+
+        if Panel is not None:
+            Panel.progress_apply = inner_generator()
+        if PanelGroupBy is not None:
+            PanelGroupBy.progress_apply = inner_generator()
+
+        GroupBy.progress_apply = inner_generator()
+        GroupBy.progress_aggregate = inner_generator('aggregate')
+        GroupBy.progress_transform = inner_generator('transform')
+
+        if Rolling is not None and Expanding is not None:
+            Rolling.progress_apply = inner_generator()
+            Expanding.progress_apply = inner_generator()
+        elif _Rolling_and_Expanding is not None:
+            _Rolling_and_Expanding.progress_apply = inner_generator()
+
+    # override defaults via env vars
+    @envwrap("TQDM_", is_method=True, types={'total': float, 'ncols': int, 'miniters': float,
+                                             'position': int, 'nrows': int})
+    def __init__(self, iterable=None, desc=None, total=None, leave=True, file=None,
+                 ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
+                 ascii=None, disable=False, unit='it', unit_scale=False,
+                 dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
+                 position=None, postfix=None, unit_divisor=1000, write_bytes=False,
+                 lock_args=None, nrows=None, colour=None, delay=0.0, gui=False,
+                 **kwargs):
+        """see tqdm.tqdm for arguments"""
+        if file is None:
+            file = sys.stderr
+
+        if write_bytes:
+            # Despite coercing unicode into bytes, py2 sys.std* streams
+            # should have bytes written to them.
+            file = SimpleTextIOWrapper(
+                file, encoding=getattr(file, 'encoding', None) or 'utf-8')
+
+        file = DisableOnWriteError(file, tqdm_instance=self)
+
+        if disable is None and hasattr(file, "isatty") and not file.isatty():
+            disable = True
+
+        if total is None and iterable is not None:
+            try:
+                total = len(iterable)
+            except (TypeError, AttributeError):
+                total = None
+        if total == float("inf"):
+            # Infinite iterations, behave same as unknown
+            total = None
+
+        if disable:
+            self.iterable = iterable
+            self.disable = disable
+            with self._lock:
+                self.pos = self._get_free_pos(self)
+                self._instances.remove(self)
+            self.n = initial
+            self.total = total
+            self.leave = leave
+            return
+
+        if kwargs:
+            self.disable = True
+            with self._lock:
+                self.pos = self._get_free_pos(self)
+                self._instances.remove(self)
+            raise (
+                TqdmDeprecationWarning(
+                    "`nested` is deprecated and automated.\n"
+                    "Use `position` instead for manual control.\n",
+                    fp_write=getattr(file, 'write', sys.stderr.write))
+                if "nested" in kwargs else
+                TqdmKeyError("Unknown argument(s): " + str(kwargs)))
+
+        # Preprocess the arguments
+        if (
+            (ncols is None or nrows is None) and (file in (sys.stderr, sys.stdout))
+        ) or dynamic_ncols:  # pragma: no cover
+            if dynamic_ncols:
+                dynamic_ncols = _screen_shape_wrapper()
+                if dynamic_ncols:
+                    ncols, nrows = dynamic_ncols(file)
+            else:
+                _dynamic_ncols = _screen_shape_wrapper()
+                if _dynamic_ncols:
+                    _ncols, _nrows = _dynamic_ncols(file)
+                    if ncols is None:
+                        ncols = _ncols
+                    if nrows is None:
+                        nrows = _nrows
+
+        if miniters is None:
+            miniters = 0
+            dynamic_miniters = True
+        else:
+            dynamic_miniters = False
+
+        if mininterval is None:
+            mininterval = 0
+
+        if maxinterval is None:
+            maxinterval = 0
+
+        if ascii is None:
+            ascii = not _supports_unicode(file)
+
+        if bar_format and ascii is not True and not _is_ascii(ascii):
+            # Convert bar format into unicode since terminal uses unicode
+            bar_format = str(bar_format)
+
+        if smoothing is None:
+            smoothing = 0
+
+        # Store the arguments
+        self.iterable = iterable
+        self.desc = desc or ''
+        self.total = total
+        self.leave = leave
+        self.fp = file
+        self.ncols = ncols
+        self.nrows = nrows
+        self.mininterval = mininterval
+        self.maxinterval = maxinterval
+        self.miniters = miniters
+        self.dynamic_miniters = dynamic_miniters
+        self.ascii = ascii
+        self.disable = disable
+        self.unit = unit
+        self.unit_scale = unit_scale
+        self.unit_divisor = unit_divisor
+        self.initial = initial
+        self.lock_args = lock_args
+        self.delay = delay
+        self.gui = gui
+        self.dynamic_ncols = dynamic_ncols
+        self.smoothing = smoothing
+        self._ema_dn = EMA(smoothing)
+        self._ema_dt = EMA(smoothing)
+        self._ema_miniters = EMA(smoothing)
+        self.bar_format = bar_format
+        self.postfix = None
+        self.colour = colour
+        self._time = time
+        if postfix:
+            try:
+                self.set_postfix(refresh=False, **postfix)
+            except TypeError:
+                self.postfix = postfix
+
+        # Init the iterations counters
+        self.last_print_n = initial
+        self.n = initial
+
+        # if nested, at initial sp() call we replace '\r' by '\n' to
+        # not overwrite the outer progress bar
+        with self._lock:
+            # mark fixed positions as negative
+            self.pos = self._get_free_pos(self) if position is None else -position
+
+        if not gui:
+            # Initialize the screen printer
+            self.sp = self.status_printer(self.fp)
+            if delay <= 0:
+                self.refresh(lock_args=self.lock_args)
+
+        # Init the time counter
+        self.last_print_t = self._time()
+        # NB: Avoid race conditions by setting start_t at the very end of init
+        self.start_t = self.last_print_t
+
+    def __bool__(self):
+        if self.total is not None:
+            return self.total > 0
+        if self.iterable is None:
+            raise TypeError('bool() undefined when iterable == total == None')
+        return bool(self.iterable)
+
+    def __len__(self):
+        return (
+            self.total if self.iterable is None
+            else self.iterable.shape[0] if hasattr(self.iterable, "shape")
+            else len(self.iterable) if hasattr(self.iterable, "__len__")
+            else self.iterable.__length_hint__() if hasattr(self.iterable, "__length_hint__")
+            else getattr(self, "total", None))
+
+    def __reversed__(self):
+        try:
+            orig = self.iterable
+        except AttributeError:
+            raise TypeError("'tqdm' object is not reversible")
+        else:
+            self.iterable = reversed(self.iterable)
+            return self.__iter__()
+        finally:
+            self.iterable = orig
+
+    def __contains__(self, item):
+        contains = getattr(self.iterable, '__contains__', None)
+        return contains(item) if contains is not None else item in self.__iter__()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            self.close()
+        except AttributeError:
+            # maybe eager thread cleanup upon external error
+            if (exc_type, exc_value, traceback) == (None, None, None):
+                raise
+            warn("AttributeError ignored", TqdmWarning, stacklevel=2)
+
+    def __del__(self):
+        self.close()
+
+    def __str__(self):
+        return self.format_meter(**self.format_dict)
+
+    @property
+    def _comparable(self):
+        return abs(getattr(self, "pos", 1 << 31))
+
+    def __hash__(self):
+        return id(self)
+
+    def __iter__(self):
+        """Backward-compatibility to use: for x in tqdm(iterable)"""
+
+        # Inlining instance variables as locals (speed optimisation)
+        iterable = self.iterable
+
+        # If the bar is disabled, then just walk the iterable
+        # (note: keep this check outside the loop for performance)
+        if self.disable:
+            for obj in iterable:
+                yield obj
+            return
+
+        mininterval = self.mininterval
+        last_print_t = self.last_print_t
+        last_print_n = self.last_print_n
+        min_start_t = self.start_t + self.delay
+        n = self.n
+        time = self._time
+
+        try:
+            for obj in iterable:
+                yield obj
+                # Update and possibly print the progressbar.
+                # Note: does not call self.update(1) for speed optimisation.
+                n += 1
+
+                if n - last_print_n >= self.miniters:
+                    cur_t = time()
+                    dt = cur_t - last_print_t
+                    if dt >= mininterval and cur_t >= min_start_t:
+                        self.update(n - last_print_n)
+                        last_print_n = self.last_print_n
+                        last_print_t = self.last_print_t
+        finally:
+            self.n = n
+            self.close()
+
+    def update(self, n=1):
+        """
+        Manually update the progress bar, useful for streams
+        such as reading files.
+        E.g.:
+        >>> t = tqdm(total=filesize) # Initialise
+        >>> for current_buffer in stream:
+        ...    ...
+        ...    t.update(len(current_buffer))
+        >>> t.close()
+        The last line is highly recommended, but possibly not necessary if
+        `t.update()` will be called in such a way that `filesize` will be
+        exactly reached and printed.
+
+        Parameters
+        ----------
+        n  : int or float, optional
+            Increment to add to the internal counter of iterations
+            [default: 1]. If using float, consider specifying `{n:.3f}`
+            or similar in `bar_format`, or specifying `unit_scale`.
+
+        Returns
+        -------
+        out  : bool or None
+            True if a `display()` was triggered.
+        """
+        if self.disable:
+            return
+
+        if n < 0:
+            self.last_print_n += n  # for auto-refresh logic to work
+        self.n += n
+
+        # check counter first to reduce calls to time()
+        if self.n - self.last_print_n >= self.miniters:
+            cur_t = self._time()
+            dt = cur_t - self.last_print_t
+            if dt >= self.mininterval and cur_t >= self.start_t + self.delay:
+                cur_t = self._time()
+                dn = self.n - self.last_print_n  # >= n
+                if self.smoothing and dt and dn:
+                    # EMA (not just overall average)
+                    self._ema_dn(dn)
+                    self._ema_dt(dt)
+                self.refresh(lock_args=self.lock_args)
+                if self.dynamic_miniters:
+                    # If no `miniters` was specified, adjust automatically to the
+                    # maximum iteration rate seen so far between two prints.
+                    # e.g.: After running `tqdm.update(5)`, subsequent
+                    # calls to `tqdm.update()` will only cause an update after
+                    # at least 5 more iterations.
+                    if self.maxinterval and dt >= self.maxinterval:
+                        self.miniters = dn * (self.mininterval or self.maxinterval) / dt
+                    elif self.smoothing:
+                        # EMA miniters update
+                        self.miniters = self._ema_miniters(
+                            dn * (self.mininterval / dt if self.mininterval and dt
+                                  else 1))
+                    else:
+                        # max iters between two prints
+                        self.miniters = max(self.miniters, dn)
+
+                # Store old values for next call
+                self.last_print_n = self.n
+                self.last_print_t = cur_t
+                return True
+
+    def close(self):
+        """Cleanup and (if leave=False) close the progressbar."""
+        if self.disable:
+            return
+
+        # Prevent multiple closures
+        self.disable = True
+
+        # decrement instance pos and remove from internal set
+        pos = abs(self.pos)
+        self._decr_instances(self)
+
+        if self.last_print_t < self.start_t + self.delay:
+            # haven't ever displayed; nothing to clear
+            return
+
+        # GUI mode
+        if getattr(self, 'sp', None) is None:
+            return
+
+        # annoyingly, _supports_unicode isn't good enough
+        def fp_write(s):
+            self.fp.write(str(s))
+
+        try:
+            fp_write('')
+        except ValueError as e:
+            if 'closed' in str(e):
+                return
+            raise  # pragma: no cover
+
+        leave = pos == 0 if self.leave is None else self.leave
+
+        with self._lock:
+            if leave:
+                # stats for overall rate (no weighted average)
+                self._ema_dt = lambda: None
+                self.display(pos=0)
+                fp_write('\n')
+            else:
+                # clear previous display
+                if self.display(msg='', pos=pos) and not pos:
+                    fp_write('\r')
+
+    def clear(self, nolock=False):
+        """Clear current bar display."""
+        if self.disable:
+            return
+
+        if not nolock:
+            self._lock.acquire()
+        pos = abs(self.pos)
+        if pos < (self.nrows or 20):
+            self.moveto(pos)
+            self.sp('')
+            self.fp.write('\r')  # place cursor back at the beginning of line
+            self.moveto(-pos)
+        if not nolock:
+            self._lock.release()
+
+    def refresh(self, nolock=False, lock_args=None):
+        """
+        Force refresh the display of this bar.
+
+        Parameters
+        ----------
+        nolock  : bool, optional
+            If `True`, does not lock.
+            If [default: `False`]: calls `acquire()` on internal lock.
+        lock_args  : tuple, optional
+            Passed to internal lock's `acquire()`.
+            If specified, will only `display()` if `acquire()` returns `True`.
+        """
+        if self.disable:
+            return
+
+        if not nolock:
+            if lock_args:
+                if not self._lock.acquire(*lock_args):
+                    return False
+            else:
+                self._lock.acquire()
+        self.display()
+        if not nolock:
+            self._lock.release()
+        return True
+
+    def unpause(self):
+        """Restart tqdm timer from last print time."""
+        if self.disable:
+            return
+        cur_t = self._time()
+        self.start_t += cur_t - self.last_print_t
+        self.last_print_t = cur_t
+
+    def reset(self, total=None):
+        """
+        Resets to 0 iterations for repeated use.
+
+        Consider combining with `leave=True`.
+
+        Parameters
+        ----------
+        total  : int or float, optional. Total to use for the new bar.
+        """
+        self.n = 0
+        if total is not None:
+            self.total = total
+        if self.disable:
+            return
+        self.last_print_n = 0
+        self.last_print_t = self.start_t = self._time()
+        self._ema_dn = EMA(self.smoothing)
+        self._ema_dt = EMA(self.smoothing)
+        self._ema_miniters = EMA(self.smoothing)
+        self.refresh()
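+        # Example (illustrative names): reuse a single bar across epochs, e.g.
+        #     bar = tqdm(total=n_batches, leave=True)
+        #     for epoch in range(epochs):
+        #         for _ in range(n_batches): bar.update()
+        #         bar.reset()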
+
+    def set_description(self, desc=None, refresh=True):
+        """
+        Set/modify description of the progress bar.
+
+        Parameters
+        ----------
+        desc  : str, optional
+        refresh  : bool, optional
+            Forces refresh [default: True].
+        """
+        self.desc = desc + ': ' if desc else ''
+        if refresh:
+            self.refresh()
+
+    def set_description_str(self, desc=None, refresh=True):
+        """Set/modify description without ': ' appended."""
+        self.desc = desc or ''
+        if refresh:
+            self.refresh()
+
+    def set_postfix(self, ordered_dict=None, refresh=True, **kwargs):
+        """
+        Set/modify postfix (additional stats)
+        with automatic formatting based on datatype.
+
+        Parameters
+        ----------
+        ordered_dict  : dict or OrderedDict, optional
+        refresh  : bool, optional
+            Forces refresh [default: True].
+        kwargs  : dict, optional
+        """
+        # Sort in alphabetical order to be more deterministic
+        postfix = OrderedDict([] if ordered_dict is None else ordered_dict)
+        for key in sorted(kwargs.keys()):
+            postfix[key] = kwargs[key]
+        # Preprocess stats according to datatype
+        for key in postfix.keys():
+            # Number: limit the length of the string
+            if isinstance(postfix[key], Number):
+                postfix[key] = self.format_num(postfix[key])
+            # Else for any other type, try to get the string conversion
+            elif not isinstance(postfix[key], str):
+                postfix[key] = str(postfix[key])
+            # Else if it's a string, don't need to preprocess anything
+        # Stitch together to get the final postfix
+        self.postfix = ', '.join(key + '=' + postfix[key].strip()
+                                 for key in postfix.keys())
+        if refresh:
+            self.refresh()
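+        # Example (illustrative): numeric values are shortened via `format_num`
+        # and everything else is stringified, so e.g.
+        #     t.set_postfix(loss=0.123456, stage="train")
+        # renders the postfix roughly as "loss=0.123, stage=train".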
+
+    def set_postfix_str(self, s='', refresh=True):
+        """
+        Postfix without dictionary expansion, similar to prefix handling.
+        """
+        self.postfix = str(s)
+        if refresh:
+            self.refresh()
+
+    def moveto(self, n):
+        # TODO: private method
+        self.fp.write('\n' * n + _term_move_up() * -n)
+        getattr(self.fp, 'flush', lambda: None)()
+
+    @property
+    def format_dict(self):
+        """Public API for read-only member access."""
+        if self.disable and not hasattr(self, 'unit'):
+            return defaultdict(lambda: None, {
+                'n': self.n, 'total': self.total, 'elapsed': 0, 'unit': 'it'})
+        if self.dynamic_ncols:
+            self.ncols, self.nrows = self.dynamic_ncols(self.fp)
+        return {
+            'n': self.n, 'total': self.total,
+            'elapsed': self._time() - self.start_t if hasattr(self, 'start_t') else 0,
+            'ncols': self.ncols, 'nrows': self.nrows, 'prefix': self.desc,
+            'ascii': self.ascii, 'unit': self.unit, 'unit_scale': self.unit_scale,
+            'rate': self._ema_dn() / self._ema_dt() if self._ema_dt() else None,
+            'bar_format': self.bar_format, 'postfix': self.postfix,
+            'unit_divisor': self.unit_divisor, 'initial': self.initial,
+            'colour': self.colour}
+
+    def display(self, msg=None, pos=None):
+        """
+        Use `self.sp` to display `msg` in the specified `pos`.
+
+        Consider overloading this function when inheriting to use e.g.:
+        `self.some_frontend(**self.format_dict)` instead of `self.sp`.
+
+        Parameters
+        ----------
+        msg  : str, optional. What to display (default: `repr(self)`).
+        pos  : int, optional. Position to `moveto`
+          (default: `abs(self.pos)`).
+        """
+        if pos is None:
+            pos = abs(self.pos)
+
+        nrows = self.nrows or 20
+        if pos >= nrows - 1:
+            if pos >= nrows:
+                return False
+            if msg or msg is None:  # override at `nrows - 1`
+                msg = " ... (more hidden) ..."
+
+        if not hasattr(self, "sp"):
+            raise TqdmDeprecationWarning(
+                "Please use `tqdm.gui.tqdm(...)`"
+                " instead of `tqdm(..., gui=True)`\n",
+                fp_write=getattr(self.fp, 'write', sys.stderr.write))
+
+        if pos:
+            self.moveto(pos)
+        self.sp(self.__str__() if msg is None else msg)
+        if pos:
+            self.moveto(-pos)
+        return True
+
+    @classmethod
+    @contextmanager
+    def wrapattr(cls, stream, method, total=None, bytes=True, **tqdm_kwargs):
+        """
+        stream  : file-like object.
+        method  : str, "read" or "write". The result of `read()` and
+            the first argument of `write()` should have a `len()`.
+
+        >>> with tqdm.wrapattr(file_obj, "read", total=file_obj.size) as fobj:
+        ...     while True:
+        ...         chunk = fobj.read(chunk_size)
+        ...         if not chunk:
+        ...             break
+        """
+        with cls(total=total, **tqdm_kwargs) as t:
+            if bytes:
+                t.unit = "B"
+                t.unit_scale = True
+                t.unit_divisor = 1024
+            yield CallbackIOWrapper(t.update, stream, method)
+
+
+def trange(*args, **kwargs):
+    """Shortcut for tqdm(range(*args), **kwargs)."""
+    return tqdm(range(*args), **kwargs)
diff --git a/phivenv/Lib/site-packages/tqdm/tk.py b/phivenv/Lib/site-packages/tqdm/tk.py
new file mode 100644
index 0000000000000000000000000000000000000000..788303c8687e007338ce816bf9afeec8581f0188
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/tk.py
@@ -0,0 +1,196 @@
+"""
+Tkinter GUI progressbar decorator for iterators.
+
+Usage:
+>>> from tqdm.tk import trange, tqdm
+>>> for i in trange(10):
+...     ...
+"""
+import re
+import sys
+import tkinter
+import tkinter.ttk as ttk
+from warnings import warn
+
+from .std import TqdmExperimentalWarning, TqdmWarning
+from .std import tqdm as std_tqdm
+
+__author__ = {"github.com/": ["richardsheridan", "casperdcl"]}
+__all__ = ['tqdm_tk', 'ttkrange', 'tqdm', 'trange']
+
+
+class tqdm_tk(std_tqdm):  # pragma: no cover
+    """
+    Experimental Tkinter GUI version of tqdm!
+
+    Note: Window interactivity suffers if `tqdm_tk` is not running within
+    a Tkinter mainloop and values are generated infrequently. In this case,
+    consider calling `tqdm_tk.refresh()` frequently in the Tk thread.
+    """
+
+    # TODO: @classmethod: write()?
+
+    def __init__(self, *args, **kwargs):
+        """
+        This class accepts the following parameters *in addition* to
+        the parameters accepted by `tqdm`.
+
+        Parameters
+        ----------
+        grab  : bool, optional
+            Grab the input across all windows of the process.
+        tk_parent  : `tkinter.Wm`, optional
+            Parent Tk window.
+        cancel_callback  : Callable, optional
+            Create a cancel button and set `cancel_callback` to be called
+            when the cancel or window close button is clicked.
+        """
+        kwargs = kwargs.copy()
+        kwargs['gui'] = True
+        # convert disable = None to False
+        kwargs['disable'] = bool(kwargs.get('disable', False))
+        self._warn_leave = 'leave' in kwargs
+        grab = kwargs.pop('grab', False)
+        tk_parent = kwargs.pop('tk_parent', None)
+        self._cancel_callback = kwargs.pop('cancel_callback', None)
+        super().__init__(*args, **kwargs)
+
+        if self.disable:
+            return
+
+        if tk_parent is None:  # Discover parent widget
+            try:
+                tk_parent = tkinter._default_root
+            except AttributeError:
+                raise AttributeError(
+                    "`tk_parent` required when using `tkinter.NoDefaultRoot()`")
+            if tk_parent is None:  # use new default root window as display
+                self._tk_window = tkinter.Tk()
+            else:  # some other windows already exist
+                self._tk_window = tkinter.Toplevel()
+        else:
+            self._tk_window = tkinter.Toplevel(tk_parent)
+
+        warn("GUI is experimental/alpha", TqdmExperimentalWarning, stacklevel=2)
+        self._tk_dispatching = self._tk_dispatching_helper()
+
+        self._tk_window.protocol("WM_DELETE_WINDOW", self.cancel)
+        self._tk_window.wm_title(self.desc)
+        self._tk_window.wm_attributes("-topmost", 1)
+        self._tk_window.after(0, lambda: self._tk_window.wm_attributes("-topmost", 0))
+        self._tk_n_var = tkinter.DoubleVar(self._tk_window, value=0)
+        self._tk_text_var = tkinter.StringVar(self._tk_window)
+        pbar_frame = ttk.Frame(self._tk_window, padding=5)
+        pbar_frame.pack()
+        _tk_label = ttk.Label(pbar_frame, textvariable=self._tk_text_var,
+                              wraplength=600, anchor="center", justify="center")
+        _tk_label.pack()
+        self._tk_pbar = ttk.Progressbar(
+            pbar_frame, variable=self._tk_n_var, length=450)
+        if self.total is not None:
+            self._tk_pbar.configure(maximum=self.total)
+        else:
+            self._tk_pbar.configure(mode="indeterminate")
+        self._tk_pbar.pack()
+        if self._cancel_callback is not None:
+            _tk_button = ttk.Button(pbar_frame, text="Cancel", command=self.cancel)
+            _tk_button.pack()
+        if grab:
+            self._tk_window.grab_set()
+
+    def close(self):
+        if self.disable:
+            return
+
+        self.disable = True
+
+        with self.get_lock():
+            self._instances.remove(self)
+
+        def _close():
+            self._tk_window.after('idle', self._tk_window.destroy)
+            if not self._tk_dispatching:
+                self._tk_window.update()
+
+        self._tk_window.protocol("WM_DELETE_WINDOW", _close)
+
+        # if leave is set but we are self-dispatching, the left window is
+        # totally unresponsive unless the user manually dispatches
+        if not self.leave:
+            _close()
+        elif not self._tk_dispatching:
+            if self._warn_leave:
+                warn("leave flag ignored if not in tkinter mainloop",
+                     TqdmWarning, stacklevel=2)
+            _close()
+
+    def clear(self, *_, **__):
+        pass
+
+    def display(self, *_, **__):
+        self._tk_n_var.set(self.n)
+        d = self.format_dict
+        # remove {bar}
+        d['bar_format'] = (d['bar_format'] or "{l_bar}<bar/>{r_bar}").replace(
+            "{bar}", "<bar/>")
+        msg = self.format_meter(**d)
+        if '<bar/>' in msg:
+            msg = "".join(re.split(r'\|?<bar/>\|?', msg, maxsplit=1))
+        self._tk_text_var.set(msg)
+        if not self._tk_dispatching:
+            self._tk_window.update()
+
+    def set_description(self, desc=None, refresh=True):
+        self.set_description_str(desc, refresh)
+
+    def set_description_str(self, desc=None, refresh=True):
+        self.desc = desc
+        if not self.disable:
+            self._tk_window.wm_title(desc)
+            if refresh and not self._tk_dispatching:
+                self._tk_window.update()
+
+    def cancel(self):
+        """
+        `cancel_callback()` followed by `close()`
+        when close/cancel buttons clicked.
+        """
+        if self._cancel_callback is not None:
+            self._cancel_callback()
+        self.close()
+
+    def reset(self, total=None):
+        """
+        Resets to 0 iterations for repeated use.
+
+        Parameters
+        ----------
+        total  : int or float, optional. Total to use for the new bar.
+        """
+        if hasattr(self, '_tk_pbar'):
+            if total is None:
+                self._tk_pbar.configure(maximum=100, mode="indeterminate")
+            else:
+                self._tk_pbar.configure(maximum=total, mode="determinate")
+        super().reset(total=total)
+
+    @staticmethod
+    def _tk_dispatching_helper():
+        """determine if Tkinter mainloop is dispatching events"""
+        codes = {tkinter.mainloop.__code__, tkinter.Misc.mainloop.__code__}
+        for frame in sys._current_frames().values():
+            while frame:
+                if frame.f_code in codes:
+                    return True
+                frame = frame.f_back
+        return False
+
+
+def ttkrange(*args, **kwargs):
+    """Shortcut for `tqdm.tk.tqdm(range(*args), **kwargs)`."""
+    return tqdm_tk(range(*args), **kwargs)
+
+
+# Aliases
+tqdm = tqdm_tk
+trange = ttkrange
diff --git a/phivenv/Lib/site-packages/tqdm/tqdm.1 b/phivenv/Lib/site-packages/tqdm/tqdm.1
new file mode 100644
index 0000000000000000000000000000000000000000..b90ab4b9ebdd183c98ee8ae0c7f0a65ac676e3b7
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/tqdm.1
@@ -0,0 +1,314 @@
+.\" Automatically generated by Pandoc 1.19.2
+.\"
+.TH "TQDM" "1" "2015\-2021" "tqdm User Manuals" ""
+.hy
+.SH NAME
+.PP
+tqdm \- fast, extensible progress bar for Python and CLI
+.SH SYNOPSIS
+.PP
+tqdm [\f[I]options\f[]]
+.SH DESCRIPTION
+.PP
+See <https://github.com/tqdm/tqdm>.
+Can be used as a pipe:
+.IP
+.nf
+\f[C]
+$\ #\ count\ lines\ of\ code
+$\ cat\ *.py\ |\ tqdm\ |\ wc\ \-l
+327it\ [00:00,\ 981773.38it/s]
+327
+
+$\ #\ find\ all\ files
+$\ find\ .\ \-name\ "*.py"\ |\ tqdm\ |\ wc\ \-l
+432it\ [00:00,\ 833842.30it/s]
+432
+
+#\ ...\ and\ more\ info
+$\ find\ .\ \-name\ \[aq]*.py\[aq]\ \-exec\ wc\ \-l\ \\{}\ \\;\ \\
+\ \ |\ tqdm\ \-\-total\ 432\ \-\-unit\ files\ \-\-desc\ counting\ \\
+\ \ |\ awk\ \[aq]{\ sum\ +=\ $1\ };\ END\ {\ print\ sum\ }\[aq]
+counting:\ 100%|█████████|\ 432/432\ [00:00<00:00,\ 794361.83files/s]
+131998
+\f[]
+.fi
+.SH OPTIONS
+.TP
+.B \-h, \-\-help
+Print this help and exit.
+.RS
+.RE
+.TP
+.B \-v, \-\-version
+Print version and exit.
+.RS
+.RE
+.TP
+.B \-\-desc=\f[I]desc\f[]
+str, optional.
+Prefix for the progressbar.
+.RS
+.RE
+.TP
+.B \-\-total=\f[I]total\f[]
+int or float, optional.
+The number of expected iterations.
+If unspecified, len(iterable) is used if possible.
+If float("inf") or as a last resort, only basic progress statistics are
+displayed (no ETA, no progressbar).
+If \f[C]gui\f[] is True and this parameter needs subsequent updating,
+specify an initial arbitrary large positive number, e.g.
+9e9.
+.RS
+.RE
+.TP
+.B \-\-leave
+bool, optional.
+If [default: True], keeps all traces of the progressbar upon termination
+of iteration.
+If \f[C]None\f[], will leave only if \f[C]position\f[] is \f[C]0\f[].
+.RS
+.RE
+.TP
+.B \-\-ncols=\f[I]ncols\f[]
+int, optional.
+The width of the entire output message.
+If specified, dynamically resizes the progressbar to stay within this
+bound.
+If unspecified, attempts to use environment width.
+The fallback is a meter width of 10 and no limit for the counter and
+statistics.
+If 0, will not print any meter (only stats).
+.RS
+.RE
+.TP
+.B \-\-mininterval=\f[I]mininterval\f[]
+float, optional.
+Minimum progress display update interval [default: 0.1] seconds.
+.RS
+.RE
+.TP
+.B \-\-maxinterval=\f[I]maxinterval\f[]
+float, optional.
+Maximum progress display update interval [default: 10] seconds.
+Automatically adjusts \f[C]miniters\f[] to correspond to
+\f[C]mininterval\f[] after long display update lag.
+Only works if \f[C]dynamic_miniters\f[] or monitor thread is enabled.
+.RS
+.RE
+.TP
+.B \-\-miniters=\f[I]miniters\f[]
+int or float, optional.
+Minimum progress display update interval, in iterations.
+If 0 and \f[C]dynamic_miniters\f[], will automatically adjust to equal
+\f[C]mininterval\f[] (more CPU efficient, good for tight loops).
+If > 0, will skip display of specified number of iterations.
+Tweak this and \f[C]mininterval\f[] to get very efficient loops.
+If your progress is erratic with both fast and slow iterations (network,
+skipping items, etc) you should set miniters=1.
+.RS
+.RE
+.TP
+.B \-\-ascii=\f[I]ascii\f[]
+bool or str, optional.
+If unspecified or False, use unicode (smooth blocks) to fill the meter.
+The fallback is to use ASCII characters " 123456789#".
+.RS
+.RE
+.TP
+.B \-\-disable
+bool, optional.
+Whether to disable the entire progressbar wrapper [default: False].
+If set to None, disable on non\-TTY.
+.RS
+.RE
+.TP
+.B \-\-unit=\f[I]unit\f[]
+str, optional.
+String that will be used to define the unit of each iteration [default:
+it].
+.RS
+.RE
+.TP
+.B \-\-unit\-scale=\f[I]unit_scale\f[]
+bool or int or float, optional.
+If 1 or True, the number of iterations will be reduced/scaled
+automatically and a metric prefix following the International System of
+Units standard will be added (kilo, mega, etc.) [default: False].
+If any other non\-zero number, will scale \f[C]total\f[] and \f[C]n\f[].
+.RS
+.RE
+.TP
+.B \-\-dynamic\-ncols
+bool, optional.
+If set, constantly alters \f[C]ncols\f[] and \f[C]nrows\f[] to the
+environment (allowing for window resizes) [default: False].
+.RS
+.RE
+.TP
+.B \-\-smoothing=\f[I]smoothing\f[]
+float, optional.
+Exponential moving average smoothing factor for speed estimates (ignored
+in GUI mode).
+Ranges from 0 (average speed) to 1 (current/instantaneous speed)
+[default: 0.3].
+.RS
+.RE
+.TP
+.B \-\-bar\-format=\f[I]bar_format\f[]
+str, optional.
+Specify a custom bar string formatting.
+May impact performance.
+[default: \[aq]{l_bar}{bar}{r_bar}\[aq]], where l_bar=\[aq]{desc}:
+{percentage:3.0f}%|\[aq] and r_bar=\[aq]| {n_fmt}/{total_fmt}
+[{elapsed}<{remaining}, \[aq] \[aq]{rate_fmt}{postfix}]\[aq] Possible
+vars: l_bar, bar, r_bar, n, n_fmt, total, total_fmt, percentage,
+elapsed, elapsed_s, ncols, nrows, desc, unit, rate, rate_fmt,
+rate_noinv, rate_noinv_fmt, rate_inv, rate_inv_fmt, postfix,
+unit_divisor, remaining, remaining_s, eta.
+Note that a trailing ": " is automatically removed after {desc} if the
+latter is empty.
+.RS
+.RE
+.TP
+.B \-\-initial=\f[I]initial\f[]
+int or float, optional.
+The initial counter value.
+Useful when restarting a progress bar [default: 0].
+If using float, consider specifying \f[C]{n:.3f}\f[] or similar in
+\f[C]bar_format\f[], or specifying \f[C]unit_scale\f[].
+.RS
+.RE
+.TP
+.B \-\-position=\f[I]position\f[]
+int, optional.
+Specify the line offset to print this bar (starting from 0).
+Automatic if unspecified.
+Useful to manage multiple bars at once (eg, from threads).
+.RS
+.RE
+.TP
+.B \-\-postfix=\f[I]postfix\f[]
+dict or *, optional.
+Specify additional stats to display at the end of the bar.
+Calls \f[C]set_postfix(**postfix)\f[] if possible (dict).
+.RS
+.RE
+.TP
+.B \-\-unit\-divisor=\f[I]unit_divisor\f[]
+float, optional.
+[default: 1000], ignored unless \f[C]unit_scale\f[] is True.
+.RS
+.RE
+.TP
+.B \-\-write\-bytes
+bool, optional.
+Whether to write bytes.
+If (default: False) will write unicode.
+.RS
+.RE
+.TP
+.B \-\-lock\-args=\f[I]lock_args\f[]
+tuple, optional.
+Passed to \f[C]refresh\f[] for intermediate output (initialisation,
+iterating, and updating).
+.RS
+.RE
+.TP
+.B \-\-nrows=\f[I]nrows\f[]
+int, optional.
+The screen height.
+If specified, hides nested bars outside this bound.
+If unspecified, attempts to use environment height.
+The fallback is 20.
+.RS
+.RE
+.TP
+.B \-\-colour=\f[I]colour\f[]
+str, optional.
+Bar colour (e.g.
+\[aq]green\[aq], \[aq]#00ff00\[aq]).
+.RS
+.RE
+.TP
+.B \-\-delay=\f[I]delay\f[]
+float, optional.
+Don\[aq]t display until [default: 0] seconds have elapsed.
+.RS
+.RE
+.TP
+.B \-\-delim=\f[I]delim\f[]
+chr, optional.
+Delimiting character [default: \[aq]\\n\[aq]].
+Use \[aq]\\0\[aq] for null.
+N.B.: on Windows systems, Python converts \[aq]\\n\[aq] to
+\[aq]\\r\\n\[aq].
+.RS
+.RE
+.TP
+.B \-\-buf\-size=\f[I]buf_size\f[]
+int, optional.
+String buffer size in bytes [default: 256] used when \f[C]delim\f[] is
+specified.
+.RS
+.RE
+.TP
+.B \-\-bytes
+bool, optional.
+If true, will count bytes, ignore \f[C]delim\f[], and default
+\f[C]unit_scale\f[] to True, \f[C]unit_divisor\f[] to 1024, and
+\f[C]unit\f[] to \[aq]B\[aq].
+.RS
+.RE
+.TP
+.B \-\-tee
+bool, optional.
+If true, passes \f[C]stdin\f[] to both \f[C]stderr\f[] and
+\f[C]stdout\f[].
+.RS
+.RE
+.TP
+.B \-\-update
+bool, optional.
+If true, will treat input as newly elapsed iterations, i.e.
+numbers to pass to \f[C]update()\f[].
+Note that this is slow (~2e5 it/s) since every input must be decoded as
+a number.
+.RS
+.RE
+.TP
+.B \-\-update\-to
+bool, optional.
+If true, will treat input as total elapsed iterations, i.e.
+numbers to assign to \f[C]self.n\f[].
+Note that this is slow (~2e5 it/s) since every input must be decoded as
+a number.
+.RS
+.RE
+.TP
+.B \-\-null
+bool, optional.
+If true, will discard input (no stdout).
+.RS
+.RE
+.TP
+.B \-\-manpath=\f[I]manpath\f[]
+str, optional.
+Directory in which to install tqdm man pages.
+.RS
+.RE
+.TP
+.B \-\-comppath=\f[I]comppath\f[]
+str, optional.
+Directory in which to place tqdm completion.
+.RS
+.RE
+.TP
+.B \-\-log=\f[I]log\f[]
+str, optional.
+CRITICAL|FATAL|ERROR|WARN(ING)|[default: \[aq]INFO\[aq]]|DEBUG|NOTSET.
+.RS
+.RE
+.SH AUTHORS
+tqdm developers <https://github.com/tqdm>.
diff --git a/phivenv/Lib/site-packages/tqdm/utils.py b/phivenv/Lib/site-packages/tqdm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af3ec7ded55daa98e1f268a3ee891e9a6bd72974
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/utils.py
@@ -0,0 +1,399 @@
+"""
+General helpers required for `tqdm.std`.
+"""
+import os
+import re
+import sys
+from functools import partial, partialmethod, wraps
+from inspect import signature
+# TODO consider using wcswidth third-party package for 0-width characters
+from unicodedata import east_asian_width
+from warnings import warn
+from weakref import proxy
+
+_range, _unich, _unicode, _basestring = range, chr, str, str
+CUR_OS = sys.platform
+IS_WIN = any(CUR_OS.startswith(i) for i in ['win32', 'cygwin'])
+IS_NIX = any(CUR_OS.startswith(i) for i in ['aix', 'linux', 'darwin', 'freebsd'])
+RE_ANSI = re.compile(r"\x1b\[[;\d]*[A-Za-z]")
+
+try:
+    if IS_WIN:
+        import colorama
+    else:
+        raise ImportError
+except ImportError:
+    colorama = None
+else:
+    try:
+        colorama.init(strip=False)
+    except TypeError:
+        colorama.init()
+
+
+def envwrap(prefix, types=None, is_method=False):
+    """
+    Override parameter defaults via `os.environ[prefix + param_name]`.
+    Maps UPPER_CASE env vars to lower_case param names.
+    camelCase isn't supported (because Windows ignores case).
+
+    Precedence (highest first):
+
+    - call (`foo(a=3)`)
+    - environ (`FOO_A=2`)
+    - signature (`def foo(a=1)`)
+
+    Parameters
+    ----------
+    prefix  : str
+        Env var prefix, e.g. "FOO_"
+    types  : dict, optional
+        Fallback mappings `{'param_name': type, ...}` if types cannot be
+        inferred from function signature.
+        Consider using `types=collections.defaultdict(lambda: ast.literal_eval)`.
+    is_method  : bool, optional
+        Whether to use `functools.partialmethod`. If (default: False) use `functools.partial`.
+
+    Examples
+    --------
+    ```
+    $ cat foo.py
+    from tqdm.utils import envwrap
+    @envwrap("FOO_")
+    def test(a=1, b=2, c=3):
+        print(f"received: a={a}, b={b}, c={c}")
+
+    $ FOO_A=42 FOO_C=1337 python -c 'import foo; foo.test(c=99)'
+    received: a=42, b=2, c=99
+    ```
+    """
+    if types is None:
+        types = {}
+    i = len(prefix)
+    env_overrides = {k[i:].lower(): v for k, v in os.environ.items() if k.startswith(prefix)}
+    part = partialmethod if is_method else partial
+
+    def wrap(func):
+        params = signature(func).parameters
+        # ignore unknown env vars
+        overrides = {k: v for k, v in env_overrides.items() if k in params}
+        # infer overrides' `type`s
+        for k in overrides:
+            param = params[k]
+            if param.annotation is not param.empty:  # typehints
+                for typ in getattr(param.annotation, '__args__', (param.annotation,)):
+                    try:
+                        overrides[k] = typ(overrides[k])
+                    except Exception:
+                        pass
+                    else:
+                        break
+            elif param.default is not None:  # type of default value
+                overrides[k] = type(param.default)(overrides[k])
+            else:
+                try:  # `types` fallback
+                    overrides[k] = types[k](overrides[k])
+                except KeyError:  # keep unconverted (`str`)
+                    pass
+        return part(func, **overrides)
+    return wrap
+
+
+class FormatReplace(object):
+    """
+    >>> a = FormatReplace('something')
+    >>> f"{a:5d}"
+    'something'
+    """  # NOQA: P102
+    def __init__(self, replace=''):
+        self.replace = replace
+        self.format_called = 0
+
+    def __format__(self, _):
+        self.format_called += 1
+        return self.replace
+
+
+class Comparable(object):
+    """Assumes child has self._comparable attr/@property"""
+    def __lt__(self, other):
+        return self._comparable < other._comparable
+
+    def __le__(self, other):
+        return (self < other) or (self == other)
+
+    def __eq__(self, other):
+        return self._comparable == other._comparable
+
+    def __ne__(self, other):
+        return not self == other
+
+    def __gt__(self, other):
+        return not self <= other
+
+    def __ge__(self, other):
+        return not self < other
+
+
+class ObjectWrapper(object):
+    def __getattr__(self, name):
+        return getattr(self._wrapped, name)
+
+    def __setattr__(self, name, value):
+        return setattr(self._wrapped, name, value)
+
+    def wrapper_getattr(self, name):
+        """Actual `self.getattr` rather than self._wrapped.getattr"""
+        try:
+            return object.__getattr__(self, name)
+        except AttributeError:  # py2
+            return getattr(self, name)
+
+    def wrapper_setattr(self, name, value):
+        """Actual `self.setattr` rather than self._wrapped.setattr"""
+        return object.__setattr__(self, name, value)
+
+    def __init__(self, wrapped):
+        """
+        Thin wrapper around a given object
+        """
+        self.wrapper_setattr('_wrapped', wrapped)
+
+
+class SimpleTextIOWrapper(ObjectWrapper):
+    """
+    Change only `.write()` of the wrapped object by encoding the passed
+    value and passing the result to the wrapped object's `.write()` method.
+    """
+    # pylint: disable=too-few-public-methods
+    def __init__(self, wrapped, encoding):
+        super().__init__(wrapped)
+        self.wrapper_setattr('encoding', encoding)
+
+    def write(self, s):
+        """
+        Encode `s` and pass to the wrapped object's `.write()` method.
+        """
+        return self._wrapped.write(s.encode(self.wrapper_getattr('encoding')))
+
+    def __eq__(self, other):
+        return self._wrapped == getattr(other, '_wrapped', other)
+
+
+class DisableOnWriteError(ObjectWrapper):
+    """
+    Disable the given `tqdm_instance` upon `write()` or `flush()` errors.
+    """
+    @staticmethod
+    def disable_on_exception(tqdm_instance, func):
+        """
+        Quietly set `tqdm_instance.miniters=inf` if `func` raises `errno=5`.
+        """
+        tqdm_instance = proxy(tqdm_instance)
+
+        def inner(*args, **kwargs):
+            try:
+                return func(*args, **kwargs)
+            except OSError as e:
+                if e.errno != 5:
+                    raise
+                try:
+                    tqdm_instance.miniters = float('inf')
+                except ReferenceError:
+                    pass
+            except ValueError as e:
+                if 'closed' not in str(e):
+                    raise
+                try:
+                    tqdm_instance.miniters = float('inf')
+                except ReferenceError:
+                    pass
+        return inner
+
+    def __init__(self, wrapped, tqdm_instance):
+        super().__init__(wrapped)
+        if hasattr(wrapped, 'write'):
+            self.wrapper_setattr(
+                'write', self.disable_on_exception(tqdm_instance, wrapped.write))
+        if hasattr(wrapped, 'flush'):
+            self.wrapper_setattr(
+                'flush', self.disable_on_exception(tqdm_instance, wrapped.flush))
+
+    def __eq__(self, other):
+        return self._wrapped == getattr(other, '_wrapped', other)
+
+
+class CallbackIOWrapper(ObjectWrapper):
+    def __init__(self, callback, stream, method="read"):
+        """
+        Wrap a given `file`-like object's `read()` or `write()` to report
+        lengths to the given `callback`
+        """
+        super().__init__(stream)
+        func = getattr(stream, method)
+        if method == "write":
+            @wraps(func)
+            def write(data, *args, **kwargs):
+                res = func(data, *args, **kwargs)
+                callback(len(data))
+                return res
+            self.wrapper_setattr('write', write)
+        elif method == "read":
+            @wraps(func)
+            def read(*args, **kwargs):
+                data = func(*args, **kwargs)
+                callback(len(data))
+                return data
+            self.wrapper_setattr('read', read)
+        else:
+            raise KeyError("Can only wrap read/write methods")
+
+
+def _is_utf(encoding):
+    try:
+        u'\u2588\u2589'.encode(encoding)
+    except UnicodeEncodeError:
+        return False
+    except Exception:
+        try:
+            return encoding.lower().startswith('utf-') or ('U8' == encoding)
+        except Exception:
+            return False
+    else:
+        return True
+
+
+def _supports_unicode(fp):
+    try:
+        return _is_utf(fp.encoding)
+    except AttributeError:
+        return False
+
+
+def _is_ascii(s):
+    if isinstance(s, str):
+        for c in s:
+            if ord(c) > 255:
+                return False
+        return True
+    return _supports_unicode(s)
+
+
+def _screen_shape_wrapper():  # pragma: no cover
+    """
+    Return a function which returns console dimensions (width, height).
+    Supported: linux, osx, windows, cygwin.
+    """
+    _screen_shape = None
+    if IS_WIN:
+        _screen_shape = _screen_shape_windows
+        if _screen_shape is None:
+            _screen_shape = _screen_shape_tput
+    if IS_NIX:
+        _screen_shape = _screen_shape_linux
+    return _screen_shape
+
+
+def _screen_shape_windows(fp):  # pragma: no cover
+    try:
+        import struct
+        from ctypes import create_string_buffer, windll
+        from sys import stdin, stdout
+
+        io_handle = -12  # assume stderr
+        if fp == stdin:
+            io_handle = -10
+        elif fp == stdout:
+            io_handle = -11
+
+        h = windll.kernel32.GetStdHandle(io_handle)
+        csbi = create_string_buffer(22)
+        res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
+        if res:
+            (_bufx, _bufy, _curx, _cury, _wattr, left, top, right, bottom,
+             _maxx, _maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw)
+            return right - left, bottom - top  # +1
+    except Exception:  # nosec
+        pass
+    return None, None
+
+
+def _screen_shape_tput(*_):  # pragma: no cover
+    """cygwin xterm (windows)"""
+    try:
+        import shlex
+        from subprocess import check_call  # nosec
+        return [int(check_call(shlex.split('tput ' + i))) - 1
+                for i in ('cols', 'lines')]
+    except Exception:  # nosec
+        pass
+    return None, None
+
+
+def _screen_shape_linux(fp):  # pragma: no cover
+
+    try:
+        from array import array
+        from fcntl import ioctl
+        from termios import TIOCGWINSZ
+    except ImportError:
+        return None, None
+    else:
+        try:
+            rows, cols = array('h', ioctl(fp, TIOCGWINSZ, '\0' * 8))[:2]
+            return cols, rows
+        except Exception:
+            try:
+                return [int(os.environ[i]) - 1 for i in ("COLUMNS", "LINES")]
+            except (KeyError, ValueError):
+                return None, None
+
+
+def _environ_cols_wrapper():  # pragma: no cover
+    """
+    Return a function which returns console width.
+    Supported: linux, osx, windows, cygwin.
+    """
+    warn("Use `_screen_shape_wrapper()(file)[0]` instead of"
+         " `_environ_cols_wrapper()(file)`", DeprecationWarning, stacklevel=2)
+    shape = _screen_shape_wrapper()
+    if not shape:
+        return None
+
+    @wraps(shape)
+    def inner(fp):
+        return shape(fp)[0]
+
+    return inner
+
+
+def _term_move_up():  # pragma: no cover
+    return '' if (os.name == 'nt') and (colorama is None) else '\x1b[A'
+
+
+def _text_width(s):
+    return sum(2 if east_asian_width(ch) in 'FW' else 1 for ch in str(s))
+
+
+def disp_len(data):
+    """
+    Returns the real on-screen length of a string which may contain
+    ANSI control codes and wide chars.
+    """
+    return _text_width(RE_ANSI.sub('', data))
+
+
+def disp_trim(data, length):
+    """
+    Trim a string which may contain ANSI control characters.
+    """
+    if len(data) == disp_len(data):
+        return data[:length]
+
+    ansi_present = bool(RE_ANSI.search(data))
+    while disp_len(data) > length:  # carefully delete one char at a time
+        data = data[:-1]
+    if ansi_present and bool(RE_ANSI.search(data)):
+        # assume ANSI reset is required
+        return data if data.endswith("\033[0m") else data + "\033[0m"
+    return data
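+
+
+# Note (illustrative): ANSI escape codes count as zero columns and east-Asian
+# wide characters as two, e.g. disp_len("\x1b[31m你好\x1b[0m") == 4.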
diff --git a/phivenv/Lib/site-packages/tqdm/version.py b/phivenv/Lib/site-packages/tqdm/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..11cbaea79d1f4f46f9ae4bea542d7c66ded96e34
--- /dev/null
+++ b/phivenv/Lib/site-packages/tqdm/version.py
@@ -0,0 +1,9 @@
+"""`tqdm` version detector. Precedence: installed dist, git, 'UNKNOWN'."""
+try:
+    from ._dist_ver import __version__
+except ImportError:
+    try:
+        from setuptools_scm import get_version
+        __version__ = get_version(root='..', relative_to=__file__)
+    except (ImportError, LookupError):
+        __version__ = "UNKNOWN"
diff --git a/phivenv/Lib/site-packages/transformers/__init__.py b/phivenv/Lib/site-packages/transformers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..47d915737677c3a2b03e7797bcd3c3943f950db1
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/__init__.py
@@ -0,0 +1,967 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and
+# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are
+# only there for type checking. The `_import_structure` is a dictionary submodule to list of object names, and is used
+# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
+# in the namespace without actually importing anything (and especially none of the backends).
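+# For example, the entry `"configuration_utils": ["PretrainedConfig"]` below means
+# `from transformers import PretrainedConfig` is resolved lazily: the submodule
+# `transformers.configuration_utils` is only imported when the name is first accessed.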
+
+__version__ = "4.56.1"
+
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+# Check the dependencies satisfy the minimal versions required.
+from . import dependency_versions_check
+from .utils import (
+    OptionalDependencyNotAvailable,
+    _LazyModule,
+    is_essentia_available,
+    is_g2p_en_available,
+    is_librosa_available,
+    is_mistral_common_available,
+    is_pretty_midi_available,
+)
+
+# Note: the following symbols are deliberately exported with `as`
+# so that mypy, pylint or other static linters can recognize them,
+# given that they are not exported using `__all__` in this file.
+from .utils import is_bitsandbytes_available as is_bitsandbytes_available
+from .utils import is_flax_available as is_flax_available
+from .utils import is_keras_nlp_available as is_keras_nlp_available
+from .utils import is_scipy_available as is_scipy_available
+from .utils import is_sentencepiece_available as is_sentencepiece_available
+from .utils import is_speech_available as is_speech_available
+from .utils import is_tensorflow_text_available as is_tensorflow_text_available
+from .utils import is_tf_available as is_tf_available
+from .utils import is_timm_available as is_timm_available
+from .utils import is_tokenizers_available as is_tokenizers_available
+from .utils import is_torch_available as is_torch_available
+from .utils import is_torchaudio_available as is_torchaudio_available
+from .utils import is_torchvision_available as is_torchvision_available
+from .utils import is_vision_available as is_vision_available
+from .utils import logging as logging
+from .utils.import_utils import define_import_structure
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# Base objects, independent of any specific backend
+_import_structure = {
+    "audio_utils": [],
+    "commands": [],
+    "configuration_utils": ["PretrainedConfig"],
+    "convert_graph_to_onnx": [],
+    "convert_slow_tokenizers_checkpoints_to_fast": [],
+    "convert_tf_hub_seq_to_seq_bert_to_pytorch": [],
+    "data": [
+        "DataProcessor",
+        "InputExample",
+        "InputFeatures",
+        "SingleSentenceClassificationProcessor",
+        "SquadExample",
+        "SquadFeatures",
+        "SquadV1Processor",
+        "SquadV2Processor",
+        "glue_compute_metrics",
+        "glue_convert_examples_to_features",
+        "glue_output_modes",
+        "glue_processors",
+        "glue_tasks_num_labels",
+        "squad_convert_examples_to_features",
+        "xnli_compute_metrics",
+        "xnli_output_modes",
+        "xnli_processors",
+        "xnli_tasks_num_labels",
+    ],
+    "data.data_collator": [
+        "DataCollator",
+        "DataCollatorForLanguageModeling",
+        "DataCollatorForMultipleChoice",
+        "DataCollatorForPermutationLanguageModeling",
+        "DataCollatorForSeq2Seq",
+        "DataCollatorForSOP",
+        "DataCollatorForTokenClassification",
+        "DataCollatorForWholeWordMask",
+        "DataCollatorWithFlattening",
+        "DataCollatorWithPadding",
+        "DefaultDataCollator",
+        "default_data_collator",
+    ],
+    "data.metrics": [],
+    "data.processors": [],
+    "debug_utils": [],
+    "dependency_versions_check": [],
+    "dependency_versions_table": [],
+    "dynamic_module_utils": [],
+    "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
+    "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
+    "file_utils": [],
+    "generation": [
+        "AsyncTextIteratorStreamer",
+        "CompileConfig",
+        "GenerationConfig",
+        "TextIteratorStreamer",
+        "TextStreamer",
+        "WatermarkingConfig",
+    ],
+    "hf_argparser": ["HfArgumentParser"],
+    "hyperparameter_search": [],
+    "image_transforms": [],
+    "integrations": [
+        "is_clearml_available",
+        "is_comet_available",
+        "is_dvclive_available",
+        "is_neptune_available",
+        "is_optuna_available",
+        "is_ray_available",
+        "is_ray_tune_available",
+        "is_sigopt_available",
+        "is_swanlab_available",
+        "is_tensorboard_available",
+        "is_trackio_available",
+        "is_wandb_available",
+    ],
+    "loss": [],
+    "modelcard": ["ModelCard"],
+    # Losses
+    "modeling_tf_pytorch_utils": [
+        "convert_tf_weight_name_to_pt_weight_name",
+        "load_pytorch_checkpoint_in_tf2_model",
+        "load_pytorch_model_in_tf2_model",
+        "load_pytorch_weights_in_tf2_model",
+        "load_tf2_checkpoint_in_pytorch_model",
+        "load_tf2_model_in_pytorch_model",
+        "load_tf2_weights_in_pytorch_model",
+    ],
+    # Models
+    "onnx": [],
+    "pipelines": [
+        "AudioClassificationPipeline",
+        "AutomaticSpeechRecognitionPipeline",
+        "CsvPipelineDataFormat",
+        "DepthEstimationPipeline",
+        "DocumentQuestionAnsweringPipeline",
+        "FeatureExtractionPipeline",
+        "FillMaskPipeline",
+        "ImageClassificationPipeline",
+        "ImageFeatureExtractionPipeline",
+        "ImageSegmentationPipeline",
+        "ImageTextToTextPipeline",
+        "ImageToImagePipeline",
+        "ImageToTextPipeline",
+        "JsonPipelineDataFormat",
+        "KeypointMatchingPipeline",
+        "MaskGenerationPipeline",
+        "NerPipeline",
+        "ObjectDetectionPipeline",
+        "PipedPipelineDataFormat",
+        "Pipeline",
+        "PipelineDataFormat",
+        "QuestionAnsweringPipeline",
+        "SummarizationPipeline",
+        "TableQuestionAnsweringPipeline",
+        "Text2TextGenerationPipeline",
+        "TextClassificationPipeline",
+        "TextGenerationPipeline",
+        "TextToAudioPipeline",
+        "TokenClassificationPipeline",
+        "TranslationPipeline",
+        "VideoClassificationPipeline",
+        "VisualQuestionAnsweringPipeline",
+        "ZeroShotAudioClassificationPipeline",
+        "ZeroShotClassificationPipeline",
+        "ZeroShotImageClassificationPipeline",
+        "ZeroShotObjectDetectionPipeline",
+        "pipeline",
+    ],
+    "processing_utils": ["ProcessorMixin"],
+    "quantizers": [],
+    "testing_utils": [],
+    "tokenization_utils": ["PreTrainedTokenizer"],
+    "tokenization_utils_base": [
+        "AddedToken",
+        "BatchEncoding",
+        "CharSpan",
+        "PreTrainedTokenizerBase",
+        "SpecialTokensMixin",
+        "TokenSpan",
+    ],
+    "trainer_callback": [
+        "DefaultFlowCallback",
+        "EarlyStoppingCallback",
+        "PrinterCallback",
+        "ProgressCallback",
+        "TrainerCallback",
+        "TrainerControl",
+        "TrainerState",
+    ],
+    "trainer_utils": [
+        "EvalPrediction",
+        "IntervalStrategy",
+        "SchedulerType",
+        "enable_full_determinism",
+        "set_seed",
+    ],
+    "training_args": ["TrainingArguments"],
+    "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
+    "training_args_tf": ["TFTrainingArguments"],
+    "utils": [
+        "CONFIG_NAME",
+        "MODEL_CARD_NAME",
+        "PYTORCH_PRETRAINED_BERT_CACHE",
+        "PYTORCH_TRANSFORMERS_CACHE",
+        "SPIECE_UNDERLINE",
+        "TF2_WEIGHTS_NAME",
+        "TF_WEIGHTS_NAME",
+        "TRANSFORMERS_CACHE",
+        "WEIGHTS_NAME",
+        "TensorType",
+        "add_end_docstrings",
+        "add_start_docstrings",
+        "is_apex_available",
+        "is_av_available",
+        "is_bitsandbytes_available",
+        "is_datasets_available",
+        "is_faiss_available",
+        "is_flax_available",
+        "is_keras_nlp_available",
+        "is_matplotlib_available",
+        "is_phonemizer_available",
+        "is_psutil_available",
+        "is_py3nvml_available",
+        "is_pyctcdecode_available",
+        "is_sacremoses_available",
+        "is_safetensors_available",
+        "is_scipy_available",
+        "is_sentencepiece_available",
+        "is_sklearn_available",
+        "is_speech_available",
+        "is_tensorflow_text_available",
+        "is_tf_available",
+        "is_timm_available",
+        "is_tokenizers_available",
+        "is_torch_available",
+        "is_torch_hpu_available",
+        "is_torch_mlu_available",
+        "is_torch_musa_available",
+        "is_torch_neuroncore_available",
+        "is_torch_npu_available",
+        "is_torchvision_available",
+        "is_torch_xla_available",
+        "is_torch_xpu_available",
+        "is_vision_available",
+        "logging",
+    ],
+    "utils.quantization_config": [
+        "AqlmConfig",
+        "AutoRoundConfig",
+        "AwqConfig",
+        "BitNetQuantConfig",
+        "BitsAndBytesConfig",
+        "CompressedTensorsConfig",
+        "EetqConfig",
+        "FbgemmFp8Config",
+        "FineGrainedFP8Config",
+        "GPTQConfig",
+        "HiggsConfig",
+        "HqqConfig",
+        "Mxfp4Config",
+        "QuantoConfig",
+        "QuarkConfig",
+        "FPQuantConfig",
+        "SpQRConfig",
+        "TorchAoConfig",
+        "VptqConfig",
+    ],
+    "video_utils": [],
+}
+
+# tokenizers-backed objects
+try:
+    if not is_tokenizers_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_tokenizers_objects
+
+    _import_structure["utils.dummy_tokenizers_objects"] = [
+        name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
+    ]
+else:
+    # Fast tokenizers structure
+    _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]
+
+
+try:
+    if not (is_sentencepiece_available() and is_tokenizers_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_sentencepiece_and_tokenizers_objects
+
+    _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [
+        name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["convert_slow_tokenizer"] = [
+        "SLOW_TO_FAST_CONVERTERS",
+        "convert_slow_tokenizer",
+    ]
+
+try:
+    if not (is_mistral_common_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_mistral_common_objects
+
+    _import_structure["utils.dummy_mistral_common_objects"] = [
+        name for name in dir(dummy_mistral_common_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["tokenization_mistral_common"] = ["MistralCommonTokenizer"]
+
+# Vision-specific objects
+try:
+    if not is_vision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_vision_objects
+
+    _import_structure["utils.dummy_vision_objects"] = [
+        name for name in dir(dummy_vision_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["image_processing_base"] = ["ImageProcessingMixin"]
+    _import_structure["image_processing_utils"] = ["BaseImageProcessor"]
+    _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
+
+try:
+    if not is_torchvision_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_torchvision_objects
+
+    _import_structure["utils.dummy_torchvision_objects"] = [
+        name for name in dir(dummy_torchvision_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"]
+    _import_structure["video_processing_utils"] = ["BaseVideoProcessor"]
+
+# PyTorch-backed objects
+try:
+    if not is_torch_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_pt_objects
+
+    _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
+else:
+    _import_structure["model_debugging_utils"] = [
+        "model_addition_debugger_context",
+    ]
+    _import_structure["activations"] = []
+    _import_structure["cache_utils"] = [
+        "CacheLayerMixin",
+        "DynamicLayer",
+        "StaticLayer",
+        "SlidingWindowLayer",
+        "ChunkedSlidingLayer",
+        "QuantoQuantizedLayer",
+        "HQQQuantizedLayer",
+        "Cache",
+        "DynamicCache",
+        "EncoderDecoderCache",
+        "HQQQuantizedCache",
+        "HybridCache",
+        "HybridChunkedCache",
+        "OffloadedCache",
+        "OffloadedStaticCache",
+        "QuantizedCache",
+        "QuantoQuantizedCache",
+        "SinkCache",
+        "SlidingWindowCache",
+        "StaticCache",
+    ]
+    _import_structure["data.datasets"] = [
+        "GlueDataset",
+        "GlueDataTrainingArguments",
+        "LineByLineTextDataset",
+        "LineByLineWithRefDataset",
+        "LineByLineWithSOPTextDataset",
+        "SquadDataset",
+        "SquadDataTrainingArguments",
+        "TextDataset",
+        "TextDatasetForNextSentencePrediction",
+    ]
+    _import_structure["generation"].extend(
+        [
+            "AlternatingCodebooksLogitsProcessor",
+            "BayesianDetectorConfig",
+            "BayesianDetectorModel",
+            "BeamScorer",
+            "BeamSearchScorer",
+            "ClassifierFreeGuidanceLogitsProcessor",
+            "ConstrainedBeamSearchScorer",
+            "Constraint",
+            "ConstraintListState",
+            "DisjunctiveConstraint",
+            "EncoderNoRepeatNGramLogitsProcessor",
+            "EncoderRepetitionPenaltyLogitsProcessor",
+            "EosTokenCriteria",
+            "EpsilonLogitsWarper",
+            "EtaLogitsWarper",
+            "ExponentialDecayLengthPenalty",
+            "ForcedBOSTokenLogitsProcessor",
+            "ForcedEOSTokenLogitsProcessor",
+            "GenerationMixin",
+            "HammingDiversityLogitsProcessor",
+            "InfNanRemoveLogitsProcessor",
+            "LogitNormalization",
+            "LogitsProcessor",
+            "LogitsProcessorList",
+            "MaxLengthCriteria",
+            "MaxTimeCriteria",
+            "MinLengthLogitsProcessor",
+            "MinNewTokensLengthLogitsProcessor",
+            "MinPLogitsWarper",
+            "NoBadWordsLogitsProcessor",
+            "NoRepeatNGramLogitsProcessor",
+            "PhrasalConstraint",
+            "PrefixConstrainedLogitsProcessor",
+            "RepetitionPenaltyLogitsProcessor",
+            "SequenceBiasLogitsProcessor",
+            "StoppingCriteria",
+            "StoppingCriteriaList",
+            "StopStringCriteria",
+            "SuppressTokensAtBeginLogitsProcessor",
+            "SuppressTokensLogitsProcessor",
+            "SynthIDTextWatermarkDetector",
+            "SynthIDTextWatermarkingConfig",
+            "SynthIDTextWatermarkLogitsProcessor",
+            "TemperatureLogitsWarper",
+            "TopKLogitsWarper",
+            "TopPLogitsWarper",
+            "TypicalLogitsWarper",
+            "UnbatchedClassifierFreeGuidanceLogitsProcessor",
+            "WatermarkDetector",
+            "WatermarkLogitsProcessor",
+            "WhisperTimeStampLogitsProcessor",
+        ]
+    )
+
+    # PyTorch domain libraries integration
+    _import_structure["integrations.executorch"] = [
+        "TorchExportableModuleWithStaticCache",
+        "convert_and_export_with_cache",
+    ]
+
+    _import_structure["modeling_flash_attention_utils"] = []
+    _import_structure["modeling_layers"] = ["GradientCheckpointingLayer"]
+    _import_structure["modeling_outputs"] = []
+    _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"]
+    _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"]
+    _import_structure["masking_utils"] = ["AttentionMaskInterface"]
+    _import_structure["optimization"] = [
+        "Adafactor",
+        "get_constant_schedule",
+        "get_constant_schedule_with_warmup",
+        "get_cosine_schedule_with_warmup",
+        "get_cosine_with_hard_restarts_schedule_with_warmup",
+        "get_inverse_sqrt_schedule",
+        "get_linear_schedule_with_warmup",
+        "get_polynomial_decay_schedule_with_warmup",
+        "get_scheduler",
+        "get_wsd_schedule",
+    ]
+    _import_structure["pytorch_utils"] = [
+        "Conv1D",
+        "apply_chunking_to_forward",
+        "prune_layer",
+        "infer_device",
+    ]
+    _import_structure["sagemaker"] = []
+    _import_structure["time_series_utils"] = []
+    _import_structure["trainer"] = ["Trainer"]
+    _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
+    _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]
+
+# TensorFlow-backed objects
+try:
+    if not is_tf_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_tf_objects
+
+    _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")]
+else:
+    _import_structure["activations_tf"] = []
+    _import_structure["generation"].extend(
+        [
+            "TFForcedBOSTokenLogitsProcessor",
+            "TFForcedEOSTokenLogitsProcessor",
+            "TFForceTokensLogitsProcessor",
+            "TFGenerationMixin",
+            "TFLogitsProcessor",
+            "TFLogitsProcessorList",
+            "TFLogitsWarper",
+            "TFMinLengthLogitsProcessor",
+            "TFNoBadWordsLogitsProcessor",
+            "TFNoRepeatNGramLogitsProcessor",
+            "TFRepetitionPenaltyLogitsProcessor",
+            "TFSuppressTokensAtBeginLogitsProcessor",
+            "TFSuppressTokensLogitsProcessor",
+            "TFTemperatureLogitsWarper",
+            "TFTopKLogitsWarper",
+            "TFTopPLogitsWarper",
+        ]
+    )
+    _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
+    _import_structure["modeling_tf_outputs"] = []
+    _import_structure["modeling_tf_utils"] = [
+        "TFPreTrainedModel",
+        "TFSequenceSummary",
+        "TFSharedEmbeddings",
+        "shape_list",
+    ]
+    _import_structure["optimization_tf"] = [
+        "AdamWeightDecay",
+        "GradientAccumulator",
+        "WarmUp",
+        "create_optimizer",
+    ]
+    _import_structure["tf_utils"] = []
+
+
+# FLAX-backed objects
+try:
+    if not is_flax_available():
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils import dummy_flax_objects
+
+    _import_structure["utils.dummy_flax_objects"] = [
+        name for name in dir(dummy_flax_objects) if not name.startswith("_")
+    ]
+else:
+    _import_structure["generation"].extend(
+        [
+            "FlaxForcedBOSTokenLogitsProcessor",
+            "FlaxForcedEOSTokenLogitsProcessor",
+            "FlaxForceTokensLogitsProcessor",
+            "FlaxGenerationMixin",
+            "FlaxLogitsProcessor",
+            "FlaxLogitsProcessorList",
+            "FlaxLogitsWarper",
+            "FlaxMinLengthLogitsProcessor",
+            "FlaxTemperatureLogitsWarper",
+            "FlaxSuppressTokensAtBeginLogitsProcessor",
+            "FlaxSuppressTokensLogitsProcessor",
+            "FlaxTopKLogitsWarper",
+            "FlaxTopPLogitsWarper",
+            "FlaxWhisperTimeStampLogitsProcessor",
+        ]
+    )
+    _import_structure["modeling_flax_outputs"] = []
+    _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"]
+
+# Direct imports for type-checking
+if TYPE_CHECKING:
+    # All modeling imports
+    from .cache_utils import Cache as Cache
+    from .cache_utils import ChunkedSlidingLayer as ChunkedSlidingLayer
+    from .cache_utils import DynamicCache as DynamicCache
+    from .cache_utils import DynamicLayer as DynamicLayer
+    from .cache_utils import EncoderDecoderCache as EncoderDecoderCache
+    from .cache_utils import HQQQuantizedCache as HQQQuantizedCache
+    from .cache_utils import HQQQuantizedLayer as HQQQuantizedLayer
+    from .cache_utils import HybridCache as HybridCache
+    from .cache_utils import MambaCache as MambaCache
+    from .cache_utils import OffloadedCache as OffloadedCache
+    from .cache_utils import OffloadedStaticCache as OffloadedStaticCache
+    from .cache_utils import QuantizedCache as QuantizedCache
+    from .cache_utils import QuantoQuantizedCache as QuantoQuantizedCache
+    from .cache_utils import QuantoQuantizedLayer as QuantoQuantizedLayer
+    from .cache_utils import SinkCache as SinkCache
+    from .cache_utils import SlidingWindowCache as SlidingWindowCache
+    from .cache_utils import SlidingWindowLayer as SlidingWindowLayer
+    from .cache_utils import StaticCache as StaticCache
+    from .cache_utils import StaticLayer as StaticLayer
+    from .configuration_utils import PretrainedConfig as PretrainedConfig
+    from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS as SLOW_TO_FAST_CONVERTERS
+    from .convert_slow_tokenizer import convert_slow_tokenizer as convert_slow_tokenizer
+
+    # Data
+    from .data import DataProcessor as DataProcessor
+    from .data import InputExample as InputExample
+    from .data import InputFeatures as InputFeatures
+    from .data import SingleSentenceClassificationProcessor as SingleSentenceClassificationProcessor
+    from .data import SquadExample as SquadExample
+    from .data import SquadFeatures as SquadFeatures
+    from .data import SquadV1Processor as SquadV1Processor
+    from .data import SquadV2Processor as SquadV2Processor
+    from .data import glue_compute_metrics as glue_compute_metrics
+    from .data import glue_convert_examples_to_features as glue_convert_examples_to_features
+    from .data import glue_output_modes as glue_output_modes
+    from .data import glue_processors as glue_processors
+    from .data import glue_tasks_num_labels as glue_tasks_num_labels
+    from .data import squad_convert_examples_to_features as squad_convert_examples_to_features
+    from .data import xnli_compute_metrics as xnli_compute_metrics
+    from .data import xnli_output_modes as xnli_output_modes
+    from .data import xnli_processors as xnli_processors
+    from .data import xnli_tasks_num_labels as xnli_tasks_num_labels
+    from .data.data_collator import DataCollator as DataCollator
+    from .data.data_collator import DataCollatorForLanguageModeling as DataCollatorForLanguageModeling
+    from .data.data_collator import DataCollatorForMultipleChoice as DataCollatorForMultipleChoice
+    from .data.data_collator import (
+        DataCollatorForPermutationLanguageModeling as DataCollatorForPermutationLanguageModeling,
+    )
+    from .data.data_collator import DataCollatorForSeq2Seq as DataCollatorForSeq2Seq
+    from .data.data_collator import DataCollatorForSOP as DataCollatorForSOP
+    from .data.data_collator import DataCollatorForTokenClassification as DataCollatorForTokenClassification
+    from .data.data_collator import DataCollatorForWholeWordMask as DataCollatorForWholeWordMask
+    from .data.data_collator import DataCollatorWithFlattening as DataCollatorWithFlattening
+    from .data.data_collator import DataCollatorWithPadding as DataCollatorWithPadding
+    from .data.data_collator import DefaultDataCollator as DefaultDataCollator
+    from .data.data_collator import default_data_collator as default_data_collator
+    from .data.datasets import GlueDataset as GlueDataset
+    from .data.datasets import GlueDataTrainingArguments as GlueDataTrainingArguments
+    from .data.datasets import LineByLineTextDataset as LineByLineTextDataset
+    from .data.datasets import LineByLineWithRefDataset as LineByLineWithRefDataset
+    from .data.datasets import LineByLineWithSOPTextDataset as LineByLineWithSOPTextDataset
+    from .data.datasets import SquadDataset as SquadDataset
+    from .data.datasets import SquadDataTrainingArguments as SquadDataTrainingArguments
+    from .data.datasets import TextDataset as TextDataset
+    from .data.datasets import TextDatasetForNextSentencePrediction as TextDatasetForNextSentencePrediction
+    from .feature_extraction_sequence_utils import SequenceFeatureExtractor as SequenceFeatureExtractor
+
+    # Feature Extractor
+    from .feature_extraction_utils import BatchFeature as BatchFeature
+    from .feature_extraction_utils import FeatureExtractionMixin as FeatureExtractionMixin
+
+    # Generation
+    from .generation import AlternatingCodebooksLogitsProcessor as AlternatingCodebooksLogitsProcessor
+    from .generation import AsyncTextIteratorStreamer as AsyncTextIteratorStreamer
+    from .generation import BayesianDetectorConfig as BayesianDetectorConfig
+    from .generation import BayesianDetectorModel as BayesianDetectorModel
+    from .generation import BeamScorer as BeamScorer
+    from .generation import BeamSearchScorer as BeamSearchScorer
+    from .generation import ClassifierFreeGuidanceLogitsProcessor as ClassifierFreeGuidanceLogitsProcessor
+    from .generation import CompileConfig as CompileConfig
+    from .generation import ConstrainedBeamSearchScorer as ConstrainedBeamSearchScorer
+    from .generation import Constraint as Constraint
+    from .generation import ConstraintListState as ConstraintListState
+    from .generation import DisjunctiveConstraint as DisjunctiveConstraint
+    from .generation import EncoderNoRepeatNGramLogitsProcessor as EncoderNoRepeatNGramLogitsProcessor
+    from .generation import EncoderRepetitionPenaltyLogitsProcessor as EncoderRepetitionPenaltyLogitsProcessor
+    from .generation import EosTokenCriteria as EosTokenCriteria
+    from .generation import EpsilonLogitsWarper as EpsilonLogitsWarper
+    from .generation import EtaLogitsWarper as EtaLogitsWarper
+    from .generation import ExponentialDecayLengthPenalty as ExponentialDecayLengthPenalty
+    from .generation import FlaxForcedBOSTokenLogitsProcessor as FlaxForcedBOSTokenLogitsProcessor
+    from .generation import FlaxForcedEOSTokenLogitsProcessor as FlaxForcedEOSTokenLogitsProcessor
+    from .generation import FlaxForceTokensLogitsProcessor as FlaxForceTokensLogitsProcessor
+    from .generation import FlaxGenerationMixin as FlaxGenerationMixin
+    from .generation import FlaxLogitsProcessor as FlaxLogitsProcessor
+    from .generation import FlaxLogitsProcessorList as FlaxLogitsProcessorList
+    from .generation import FlaxLogitsWarper as FlaxLogitsWarper
+    from .generation import FlaxMinLengthLogitsProcessor as FlaxMinLengthLogitsProcessor
+    from .generation import FlaxSuppressTokensAtBeginLogitsProcessor as FlaxSuppressTokensAtBeginLogitsProcessor
+    from .generation import FlaxSuppressTokensLogitsProcessor as FlaxSuppressTokensLogitsProcessor
+    from .generation import FlaxTemperatureLogitsWarper as FlaxTemperatureLogitsWarper
+    from .generation import FlaxTopKLogitsWarper as FlaxTopKLogitsWarper
+    from .generation import FlaxTopPLogitsWarper as FlaxTopPLogitsWarper
+    from .generation import FlaxWhisperTimeStampLogitsProcessor as FlaxWhisperTimeStampLogitsProcessor
+    from .generation import ForcedBOSTokenLogitsProcessor as ForcedBOSTokenLogitsProcessor
+    from .generation import ForcedEOSTokenLogitsProcessor as ForcedEOSTokenLogitsProcessor
+    from .generation import GenerationConfig as GenerationConfig
+    from .generation import GenerationMixin as GenerationMixin
+    from .generation import HammingDiversityLogitsProcessor as HammingDiversityLogitsProcessor
+    from .generation import InfNanRemoveLogitsProcessor as InfNanRemoveLogitsProcessor
+    from .generation import LogitNormalization as LogitNormalization
+    from .generation import LogitsProcessor as LogitsProcessor
+    from .generation import LogitsProcessorList as LogitsProcessorList
+    from .generation import MaxLengthCriteria as MaxLengthCriteria
+    from .generation import MaxTimeCriteria as MaxTimeCriteria
+    from .generation import MinLengthLogitsProcessor as MinLengthLogitsProcessor
+    from .generation import MinNewTokensLengthLogitsProcessor as MinNewTokensLengthLogitsProcessor
+    from .generation import MinPLogitsWarper as MinPLogitsWarper
+    from .generation import NoBadWordsLogitsProcessor as NoBadWordsLogitsProcessor
+    from .generation import NoRepeatNGramLogitsProcessor as NoRepeatNGramLogitsProcessor
+    from .generation import PhrasalConstraint as PhrasalConstraint
+    from .generation import PrefixConstrainedLogitsProcessor as PrefixConstrainedLogitsProcessor
+    from .generation import RepetitionPenaltyLogitsProcessor as RepetitionPenaltyLogitsProcessor
+    from .generation import SequenceBiasLogitsProcessor as SequenceBiasLogitsProcessor
+    from .generation import StoppingCriteria as StoppingCriteria
+    from .generation import StoppingCriteriaList as StoppingCriteriaList
+    from .generation import StopStringCriteria as StopStringCriteria
+    from .generation import SuppressTokensAtBeginLogitsProcessor as SuppressTokensAtBeginLogitsProcessor
+    from .generation import SuppressTokensLogitsProcessor as SuppressTokensLogitsProcessor
+    from .generation import SynthIDTextWatermarkDetector as SynthIDTextWatermarkDetector
+    from .generation import SynthIDTextWatermarkingConfig as SynthIDTextWatermarkingConfig
+    from .generation import SynthIDTextWatermarkLogitsProcessor as SynthIDTextWatermarkLogitsProcessor
+    from .generation import TemperatureLogitsWarper as TemperatureLogitsWarper
+    from .generation import TextIteratorStreamer as TextIteratorStreamer
+    from .generation import TextStreamer as TextStreamer
+    from .generation import TFForcedBOSTokenLogitsProcessor as TFForcedBOSTokenLogitsProcessor
+    from .generation import TFForcedEOSTokenLogitsProcessor as TFForcedEOSTokenLogitsProcessor
+    from .generation import TFForceTokensLogitsProcessor as TFForceTokensLogitsProcessor
+    from .generation import TFGenerationMixin as TFGenerationMixin
+    from .generation import TFLogitsProcessor as TFLogitsProcessor
+    from .generation import TFLogitsProcessorList as TFLogitsProcessorList
+    from .generation import TFLogitsWarper as TFLogitsWarper
+    from .generation import TFMinLengthLogitsProcessor as TFMinLengthLogitsProcessor
+    from .generation import TFNoBadWordsLogitsProcessor as TFNoBadWordsLogitsProcessor
+    from .generation import TFNoRepeatNGramLogitsProcessor as TFNoRepeatNGramLogitsProcessor
+    from .generation import TFRepetitionPenaltyLogitsProcessor as TFRepetitionPenaltyLogitsProcessor
+    from .generation import TFSuppressTokensAtBeginLogitsProcessor as TFSuppressTokensAtBeginLogitsProcessor
+    from .generation import TFSuppressTokensLogitsProcessor as TFSuppressTokensLogitsProcessor
+    from .generation import TFTemperatureLogitsWarper as TFTemperatureLogitsWarper
+    from .generation import TFTopKLogitsWarper as TFTopKLogitsWarper
+    from .generation import TFTopPLogitsWarper as TFTopPLogitsWarper
+    from .generation import TopKLogitsWarper as TopKLogitsWarper
+    from .generation import TopPLogitsWarper as TopPLogitsWarper
+    from .generation import TypicalLogitsWarper as TypicalLogitsWarper
+    from .generation import (
+        UnbatchedClassifierFreeGuidanceLogitsProcessor as UnbatchedClassifierFreeGuidanceLogitsProcessor,
+    )
+    from .generation import WatermarkDetector as WatermarkDetector
+    from .generation import WatermarkingConfig as WatermarkingConfig
+    from .generation import WatermarkLogitsProcessor as WatermarkLogitsProcessor
+    from .generation import WhisperTimeStampLogitsProcessor as WhisperTimeStampLogitsProcessor
+    from .hf_argparser import HfArgumentParser as HfArgumentParser
+    from .image_processing_base import ImageProcessingMixin as ImageProcessingMixin
+    from .image_processing_utils import BaseImageProcessor as BaseImageProcessor
+    from .image_processing_utils_fast import BaseImageProcessorFast as BaseImageProcessorFast
+    from .image_utils import ImageFeatureExtractionMixin as ImageFeatureExtractionMixin
+
+    # Integrations
+    from .integrations import is_clearml_available as is_clearml_available
+    from .integrations import is_comet_available as is_comet_available
+    from .integrations import is_dvclive_available as is_dvclive_available
+    from .integrations import is_neptune_available as is_neptune_available
+    from .integrations import is_optuna_available as is_optuna_available
+    from .integrations import is_ray_available as is_ray_available
+    from .integrations import is_ray_tune_available as is_ray_tune_available
+    from .integrations import is_sigopt_available as is_sigopt_available
+    from .integrations import is_swanlab_available as is_swanlab_available
+    from .integrations import is_tensorboard_available as is_tensorboard_available
+    from .integrations import is_trackio_available as is_trackio_available
+    from .integrations import is_wandb_available as is_wandb_available
+    from .integrations.executorch import TorchExportableModuleWithStaticCache as TorchExportableModuleWithStaticCache
+    from .integrations.executorch import convert_and_export_with_cache as convert_and_export_with_cache
+    from .keras_callbacks import KerasMetricCallback as KerasMetricCallback
+    from .keras_callbacks import PushToHubCallback as PushToHubCallback
+    from .masking_utils import AttentionMaskInterface as AttentionMaskInterface
+    from .model_debugging_utils import model_addition_debugger_context as model_addition_debugger_context
+
+    # Model Cards
+    from .modelcard import ModelCard as ModelCard
+    from .modeling_flax_utils import FlaxPreTrainedModel as FlaxPreTrainedModel
+    from .modeling_layers import GradientCheckpointingLayer as GradientCheckpointingLayer
+    from .modeling_rope_utils import ROPE_INIT_FUNCTIONS as ROPE_INIT_FUNCTIONS
+    from .modeling_rope_utils import dynamic_rope_update as dynamic_rope_update
+
+    # TF 2.0 <=> PyTorch conversion utilities
+    from .modeling_tf_pytorch_utils import (
+        convert_tf_weight_name_to_pt_weight_name as convert_tf_weight_name_to_pt_weight_name,
+    )
+    from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model as load_pytorch_checkpoint_in_tf2_model
+    from .modeling_tf_pytorch_utils import load_pytorch_model_in_tf2_model as load_pytorch_model_in_tf2_model
+    from .modeling_tf_pytorch_utils import load_pytorch_weights_in_tf2_model as load_pytorch_weights_in_tf2_model
+    from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model as load_tf2_checkpoint_in_pytorch_model
+    from .modeling_tf_pytorch_utils import load_tf2_model_in_pytorch_model as load_tf2_model_in_pytorch_model
+    from .modeling_tf_pytorch_utils import load_tf2_weights_in_pytorch_model as load_tf2_weights_in_pytorch_model
+    from .modeling_tf_utils import TFPreTrainedModel as TFPreTrainedModel
+    from .modeling_tf_utils import TFSequenceSummary as TFSequenceSummary
+    from .modeling_tf_utils import TFSharedEmbeddings as TFSharedEmbeddings
+    from .modeling_tf_utils import shape_list as shape_list
+    from .modeling_utils import AttentionInterface as AttentionInterface
+    from .modeling_utils import PreTrainedModel as PreTrainedModel
+    from .models import *
+    from .models.timm_wrapper import TimmWrapperImageProcessor as TimmWrapperImageProcessor
+
+    # Optimization
+    from .optimization import Adafactor as Adafactor
+    from .optimization import get_constant_schedule as get_constant_schedule
+    from .optimization import get_constant_schedule_with_warmup as get_constant_schedule_with_warmup
+    from .optimization import get_cosine_schedule_with_warmup as get_cosine_schedule_with_warmup
+    from .optimization import (
+        get_cosine_with_hard_restarts_schedule_with_warmup as get_cosine_with_hard_restarts_schedule_with_warmup,
+    )
+    from .optimization import get_inverse_sqrt_schedule as get_inverse_sqrt_schedule
+    from .optimization import get_linear_schedule_with_warmup as get_linear_schedule_with_warmup
+    from .optimization import get_polynomial_decay_schedule_with_warmup as get_polynomial_decay_schedule_with_warmup
+    from .optimization import get_scheduler as get_scheduler
+    from .optimization import get_wsd_schedule as get_wsd_schedule
+
+    # TensorFlow optimization
+    from .optimization_tf import AdamWeightDecay as AdamWeightDecay
+    from .optimization_tf import GradientAccumulator as GradientAccumulator
+    from .optimization_tf import WarmUp as WarmUp
+    from .optimization_tf import create_optimizer as create_optimizer
+
+    # Pipelines
+    from .pipelines import AudioClassificationPipeline as AudioClassificationPipeline
+    from .pipelines import AutomaticSpeechRecognitionPipeline as AutomaticSpeechRecognitionPipeline
+    from .pipelines import CsvPipelineDataFormat as CsvPipelineDataFormat
+    from .pipelines import DepthEstimationPipeline as DepthEstimationPipeline
+    from .pipelines import DocumentQuestionAnsweringPipeline as DocumentQuestionAnsweringPipeline
+    from .pipelines import FeatureExtractionPipeline as FeatureExtractionPipeline
+    from .pipelines import FillMaskPipeline as FillMaskPipeline
+    from .pipelines import ImageClassificationPipeline as ImageClassificationPipeline
+    from .pipelines import ImageFeatureExtractionPipeline as ImageFeatureExtractionPipeline
+    from .pipelines import ImageSegmentationPipeline as ImageSegmentationPipeline
+    from .pipelines import ImageTextToTextPipeline as ImageTextToTextPipeline
+    from .pipelines import ImageToImagePipeline as ImageToImagePipeline
+    from .pipelines import ImageToTextPipeline as ImageToTextPipeline
+    from .pipelines import JsonPipelineDataFormat as JsonPipelineDataFormat
+    from .pipelines import KeypointMatchingPipeline as KeypointMatchingPipeline
+    from .pipelines import MaskGenerationPipeline as MaskGenerationPipeline
+    from .pipelines import NerPipeline as NerPipeline
+    from .pipelines import ObjectDetectionPipeline as ObjectDetectionPipeline
+    from .pipelines import PipedPipelineDataFormat as PipedPipelineDataFormat
+    from .pipelines import Pipeline as Pipeline
+    from .pipelines import PipelineDataFormat as PipelineDataFormat
+    from .pipelines import QuestionAnsweringPipeline as QuestionAnsweringPipeline
+    from .pipelines import SummarizationPipeline as SummarizationPipeline
+    from .pipelines import TableQuestionAnsweringPipeline as TableQuestionAnsweringPipeline
+    from .pipelines import Text2TextGenerationPipeline as Text2TextGenerationPipeline
+    from .pipelines import TextClassificationPipeline as TextClassificationPipeline
+    from .pipelines import TextGenerationPipeline as TextGenerationPipeline
+    from .pipelines import TextToAudioPipeline as TextToAudioPipeline
+    from .pipelines import TokenClassificationPipeline as TokenClassificationPipeline
+    from .pipelines import TranslationPipeline as TranslationPipeline
+    from .pipelines import VideoClassificationPipeline as VideoClassificationPipeline
+    from .pipelines import VisualQuestionAnsweringPipeline as VisualQuestionAnsweringPipeline
+    from .pipelines import ZeroShotAudioClassificationPipeline as ZeroShotAudioClassificationPipeline
+    from .pipelines import ZeroShotClassificationPipeline as ZeroShotClassificationPipeline
+    from .pipelines import ZeroShotImageClassificationPipeline as ZeroShotImageClassificationPipeline
+    from .pipelines import ZeroShotObjectDetectionPipeline as ZeroShotObjectDetectionPipeline
+    from .pipelines import pipeline as pipeline
+    from .processing_utils import ProcessorMixin as ProcessorMixin
+    from .pytorch_utils import Conv1D as Conv1D
+    from .pytorch_utils import apply_chunking_to_forward as apply_chunking_to_forward
+    from .pytorch_utils import prune_layer as prune_layer
+
+    # Tokenization
+    from .tokenization_utils import PreTrainedTokenizer as PreTrainedTokenizer
+    from .tokenization_utils_base import AddedToken as AddedToken
+    from .tokenization_utils_base import BatchEncoding as BatchEncoding
+    from .tokenization_utils_base import CharSpan as CharSpan
+    from .tokenization_utils_base import PreTrainedTokenizerBase as PreTrainedTokenizerBase
+    from .tokenization_utils_base import SpecialTokensMixin as SpecialTokensMixin
+    from .tokenization_utils_base import TokenSpan as TokenSpan
+    from .tokenization_utils_fast import PreTrainedTokenizerFast as PreTrainedTokenizerFast
+
+    # Trainer
+    from .trainer import Trainer as Trainer
+
+    # Trainer callbacks
+    from .trainer_callback import DefaultFlowCallback as DefaultFlowCallback
+    from .trainer_callback import EarlyStoppingCallback as EarlyStoppingCallback
+    from .trainer_callback import PrinterCallback as PrinterCallback
+    from .trainer_callback import ProgressCallback as ProgressCallback
+    from .trainer_callback import TrainerCallback as TrainerCallback
+    from .trainer_callback import TrainerControl as TrainerControl
+    from .trainer_callback import TrainerState as TrainerState
+    from .trainer_pt_utils import torch_distributed_zero_first as torch_distributed_zero_first
+    from .trainer_seq2seq import Seq2SeqTrainer as Seq2SeqTrainer
+    from .trainer_utils import EvalPrediction as EvalPrediction
+    from .trainer_utils import IntervalStrategy as IntervalStrategy
+    from .trainer_utils import SchedulerType as SchedulerType
+    from .trainer_utils import enable_full_determinism as enable_full_determinism
+    from .trainer_utils import set_seed as set_seed
+    from .training_args import TrainingArguments as TrainingArguments
+    from .training_args_seq2seq import Seq2SeqTrainingArguments as Seq2SeqTrainingArguments
+    from .training_args_tf import TFTrainingArguments as TFTrainingArguments
+
+    # Files and general utilities
+    from .utils import CONFIG_NAME as CONFIG_NAME
+    from .utils import MODEL_CARD_NAME as MODEL_CARD_NAME
+    from .utils import PYTORCH_PRETRAINED_BERT_CACHE as PYTORCH_PRETRAINED_BERT_CACHE
+    from .utils import PYTORCH_TRANSFORMERS_CACHE as PYTORCH_TRANSFORMERS_CACHE
+    from .utils import SPIECE_UNDERLINE as SPIECE_UNDERLINE
+    from .utils import TF2_WEIGHTS_NAME as TF2_WEIGHTS_NAME
+    from .utils import TF_WEIGHTS_NAME as TF_WEIGHTS_NAME
+    from .utils import TRANSFORMERS_CACHE as TRANSFORMERS_CACHE
+    from .utils import WEIGHTS_NAME as WEIGHTS_NAME
+    from .utils import TensorType as TensorType
+    from .utils import add_end_docstrings as add_end_docstrings
+    from .utils import add_start_docstrings as add_start_docstrings
+    from .utils import is_apex_available as is_apex_available
+    from .utils import is_av_available as is_av_available
+    from .utils import is_datasets_available as is_datasets_available
+    from .utils import is_faiss_available as is_faiss_available
+    from .utils import is_matplotlib_available as is_matplotlib_available
+    from .utils import is_phonemizer_available as is_phonemizer_available
+    from .utils import is_psutil_available as is_psutil_available
+    from .utils import is_py3nvml_available as is_py3nvml_available
+    from .utils import is_pyctcdecode_available as is_pyctcdecode_available
+    from .utils import is_sacremoses_available as is_sacremoses_available
+    from .utils import is_safetensors_available as is_safetensors_available
+    from .utils import is_sklearn_available as is_sklearn_available
+    from .utils import is_torch_hpu_available as is_torch_hpu_available
+    from .utils import is_torch_mlu_available as is_torch_mlu_available
+    from .utils import is_torch_musa_available as is_torch_musa_available
+    from .utils import is_torch_neuroncore_available as is_torch_neuroncore_available
+    from .utils import is_torch_npu_available as is_torch_npu_available
+    from .utils import is_torch_xla_available as is_torch_xla_available
+    from .utils import is_torch_xpu_available as is_torch_xpu_available
+    from .utils import logging as logging
+
+    # Quantization configs
+    from .utils.quantization_config import AqlmConfig as AqlmConfig
+    from .utils.quantization_config import AutoRoundConfig as AutoRoundConfig
+    from .utils.quantization_config import AwqConfig as AwqConfig
+    from .utils.quantization_config import BitNetQuantConfig as BitNetQuantConfig
+    from .utils.quantization_config import BitsAndBytesConfig as BitsAndBytesConfig
+    from .utils.quantization_config import CompressedTensorsConfig as CompressedTensorsConfig
+    from .utils.quantization_config import EetqConfig as EetqConfig
+    from .utils.quantization_config import FbgemmFp8Config as FbgemmFp8Config
+    from .utils.quantization_config import FineGrainedFP8Config as FineGrainedFP8Config
+    from .utils.quantization_config import FPQuantConfig as FPQuantConfig
+    from .utils.quantization_config import GPTQConfig as GPTQConfig
+    from .utils.quantization_config import HiggsConfig as HiggsConfig
+    from .utils.quantization_config import HqqConfig as HqqConfig
+    from .utils.quantization_config import QuantoConfig as QuantoConfig
+    from .utils.quantization_config import QuarkConfig as QuarkConfig
+    from .utils.quantization_config import SpQRConfig as SpQRConfig
+    from .utils.quantization_config import TorchAoConfig as TorchAoConfig
+    from .utils.quantization_config import VptqConfig as VptqConfig
+    from .video_processing_utils import BaseVideoProcessor as BaseVideoProcessor
+
+else:
+    import sys
+
+    _import_structure = {k: set(v) for k, v in _import_structure.items()}
+
+    import_structure = define_import_structure(Path(__file__).parent / "models", prefix="models")
+    import_structure[frozenset({})].update(_import_structure)
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+        extra_objects={"__version__": __version__},
+    )
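+
+    # Illustrative sketch (not part of the library) of what the lazy module buys us,
+    # assuming the relevant backends are installed: attribute access on the package
+    # triggers the real submodule import only on first use, e.g.
+    #
+    #   import transformers
+    #   tok_cls = transformers.AutoTokenizer      # resolved lazily via _LazyModule
+    #   model_cls = transformers.PreTrainedModel  # same, pulled in on first access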
+
+
+if not is_tf_available() and not is_torch_available() and not is_flax_available():
+    logger.warning_advice(
+        "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. "
+        "Models won't be available and only tokenizers, configuration "
+        "and file/data utilities can be used."
+    )
diff --git a/phivenv/Lib/site-packages/transformers/activations.py b/phivenv/Lib/site-packages/transformers/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..813cd2c3c811afd72473869468c54526942cc839
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/activations.py
@@ -0,0 +1,324 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from collections import OrderedDict
+
+import torch
+from torch import Tensor, nn
+
+from .utils import logging
+from .utils.import_utils import is_torchdynamo_compiling
+
+
+logger = logging.get_logger(__name__)
+
+
+class PytorchGELUTanh(nn.Module):
+    """
+    A fast C implementation of the tanh approximation of the GeLU activation function. See
+    https://huggingface.co/papers/1606.08415.
+
+    This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
+    match due to rounding errors.
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return nn.functional.gelu(input, approximate="tanh")
+
+
+class NewGELUActivation(nn.Module):
+    """
+    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+    the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
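+
+# Illustrative sanity check (not part of the library): the two tanh-based variants
+# above compute the same approximation, so they agree to within rounding error,
+# which is what the PytorchGELUTanh docstring claims.
+#
+#   x = torch.randn(8)
+#   torch.allclose(PytorchGELUTanh()(x), NewGELUActivation()(x), atol=1e-5)  # True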
+
+
+class GELUActivation(nn.Module):
+    """
+    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
+    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
+    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
+    Also see the Gaussian Error Linear Units paper: https://huggingface.co/papers/1606.08415
+    """
+
+    def __init__(self, use_gelu_python: bool = False):
+        super().__init__()
+        if use_gelu_python:
+            self.act = self._gelu_python
+        else:
+            self.act = nn.functional.gelu
+
+    def _gelu_python(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
+
+
+class FastGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
+
+
+class QuickGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return input * torch.sigmoid(1.702 * input)
+
+
+class ClippedGELUActivation(nn.Module):
+    """
+    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
+    it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
+    https://huggingface.co/papers/2004.09602.
+
+    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+    initially created.
+
+    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
+    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://huggingface.co/papers/1606.08415
+    """
+
+    def __init__(self, min: float, max: float):
+        if min > max:
+            raise ValueError(f"min should be < max (got min: {min}, max: {max})")
+
+        super().__init__()
+        self.min = min
+        self.max = max
+
+    def forward(self, x: Tensor) -> Tensor:
+        return torch.clip(gelu(x), self.min, self.max)
+
+
+class AccurateGELUActivation(nn.Module):
+    """
+    Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
+    https://github.com/hendrycks/GELUs
+
+    Implemented along with MEGA (Moving Average Equipped Gated Attention)
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.precomputed_constant = math.sqrt(2 / math.pi)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
+
+
+class MishActivation(nn.Module):
+    """
+    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra, https://huggingface.co/papers/1908.08681). Also
+    visit the official repository for the paper: https://github.com/digantamisra98/Mish
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.act = nn.functional.mish
+
+    def _mish_python(self, input: Tensor) -> Tensor:
+        return input * torch.tanh(nn.functional.softplus(input))
+
+    def forward(self, input: Tensor) -> Tensor:
+        return self.act(input)
+
+
+class LinearActivation(nn.Module):
+    """
+    Applies the linear activation function, i.e. forwarding input directly to output.
+    """
+
+    def forward(self, input: Tensor) -> Tensor:
+        return input
+
+
+class LaplaceActivation(nn.Module):
+    """
+    Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
+    https://huggingface.co/papers/2209.10655
+
+    Inspired by squared relu, but with bounded range and gradient for better stability
+    """
+
+    def forward(self, input, mu=0.707107, sigma=0.282095):
+        input = (input - mu).div(sigma * math.sqrt(2.0))
+        return 0.5 * (1.0 + torch.erf(input))
+
+
+class ReLUSquaredActivation(nn.Module):
+    """
+    Applies the relu^2 activation introduced in https://huggingface.co/papers/2109.08668v2
+    """
+
+    def forward(self, input):
+        relu_applied = nn.functional.relu(input)
+        squared = torch.square(relu_applied)
+        return squared
+
+
+class ClassInstantier(OrderedDict):
+    def __getitem__(self, key):
+        content = super().__getitem__(key)
+        cls, kwargs = content if isinstance(content, tuple) else (content, {})
+        return cls(**kwargs)
+
+
+class XIELUActivation(nn.Module):
+    """
+    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+
+    If the user has installed the nickjbrowning/XIELU wheel, we import the xIELU CUDA kernel.
+    Otherwise, we emit a single warning and fall back to the Python implementation.
+    """
+
+    def __init__(
+        self,
+        alpha_p_init=0.8,
+        alpha_n_init=0.8,
+        beta=0.5,
+        eps=-1e-6,
+        dtype=torch.bfloat16,
+        with_vector_loads=False,
+    ):
+        super().__init__()
+        self.alpha_p = nn.Parameter(torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(0))
+        self.alpha_n = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1).unsqueeze(0)
+        )
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA fully implemented
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += f" Could not enable torch._dynamo for xIELU ({err}) - this may result in slower performance."
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            logger.warning_once(
+                "CUDA-fused xIELU not available (%s) – falling back to a Python version.\n"
+                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+                str(err),
+            )
+
+    def _xielu_python(self, x: Tensor) -> Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
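+        # Piecewise xIELU (see the paper linked above): for x > 0 a quadratic branch
+        # alpha_p * x^2 + beta * x; for x <= 0 a saturating exponential branch
+        # alpha_n * (expm1(min(x, eps)) - x) + beta * x. The softplus reparameterization
+        # keeps alpha_p positive and alpha_n above beta.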
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: Tensor) -> Tensor:
+        """Firewall function to prevent torch.compile from seeing .item() calls"""
+        original_shape = x.shape
+        # CUDA kernel expects 3D tensors, reshape if needed
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "Warning: xIELU input tensor expects 3 dimensions but got (shape: %s). Reshaping to (shape: %s).",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: Tensor) -> Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not is_torchdynamo_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once("torch._dynamo is compiling, using Python version of xIELU.")
+        return self._xielu_python(input)
+
+
+ACT2CLS = {
+    "gelu": GELUActivation,
+    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
+    "gelu_fast": FastGELUActivation,
+    "gelu_new": NewGELUActivation,
+    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
+    "gelu_pytorch_tanh": PytorchGELUTanh,
+    "gelu_accurate": AccurateGELUActivation,
+    "laplace": LaplaceActivation,
+    "leaky_relu": nn.LeakyReLU,
+    "linear": LinearActivation,
+    "mish": MishActivation,
+    "quick_gelu": QuickGELUActivation,
+    "relu": nn.ReLU,
+    "relu2": ReLUSquaredActivation,
+    "relu6": nn.ReLU6,
+    "sigmoid": nn.Sigmoid,
+    "silu": nn.SiLU,
+    "swish": nn.SiLU,
+    "tanh": nn.Tanh,
+    "prelu": nn.PReLU,
+    "xielu": XIELUActivation,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
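+
+# Illustrative sketch of the ClassInstantier lookup (not part of the library): plain
+# entries in ACT2CLS are instantiated with no arguments, tuple entries with their
+# kwargs, and every lookup builds a fresh module instance.
+#
+#   act = ACT2FN["gelu_10"]     # ClippedGELUActivation(min=-10, max=10)
+#   act is ACT2FN["gelu_10"]    # False: each lookup instantiates a new module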
+
+
+def get_activation(activation_string):
+    if activation_string in ACT2FN:
+        return ACT2FN[activation_string]
+    else:
+        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
+
+
+# For backwards compatibility with: from activations import gelu_python
+gelu_python = get_activation("gelu_python")
+gelu_new = get_activation("gelu_new")
+gelu = get_activation("gelu")
+gelu_fast = get_activation("gelu_fast")
+quick_gelu = get_activation("quick_gelu")
+silu = get_activation("silu")
+mish = get_activation("mish")
+linear_act = get_activation("linear")
diff --git a/phivenv/Lib/site-packages/transformers/activations_tf.py b/phivenv/Lib/site-packages/transformers/activations_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dccf6c4f46b8fe1f98d7e57bd8611f660ed19f4
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/activations_tf.py
@@ -0,0 +1,147 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import tensorflow as tf
+from packaging.version import parse
+
+
+try:
+    import tf_keras as keras
+except (ModuleNotFoundError, ImportError):
+    import keras
+
+    if parse(keras.__version__).major > 2:
+        raise ValueError(
+            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
+            "Transformers. Please install the backwards-compatible tf-keras package with "
+            "`pip install tf-keras`."
+        )
+
+
+def _gelu(x):
+    """
+    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+    https://huggingface.co/papers/1606.08415
+    """
+    x = tf.convert_to_tensor(x)
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
+
+    return x * cdf
+
+
+def _gelu_new(x):
+    """
+    Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://huggingface.co/papers/1606.08415
+
+    Args:
+        x: float Tensor to perform activation
+
+    Returns:
+        `x` with the GELU activation applied.
+    """
+    x = tf.convert_to_tensor(x)
+    pi = tf.cast(math.pi, x.dtype)
+    coeff = tf.cast(0.044715, x.dtype)
+    cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
+
+    return x * cdf
+
+
+def mish(x):
+    x = tf.convert_to_tensor(x)
+
+    return x * tf.tanh(tf.math.softplus(x))
+
+
+def gelu_fast(x):
+    x = tf.convert_to_tensor(x)
+    coeff1 = tf.cast(0.044715, x.dtype)
+    coeff2 = tf.cast(0.7978845608, x.dtype)
+
+    return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
+
+
+def quick_gelu(x):
+    x = tf.convert_to_tensor(x)
+    coeff = tf.cast(1.702, x.dtype)
+    return x * tf.math.sigmoid(coeff * x)
+
+
+def gelu_10(x):
+    """
+    Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purposes, as
+    it allows mapping 2 negative values in the GeLU spectrum. For more information on this trick, please refer to
+    https://huggingface.co/papers/2004.09602
+
+    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
+    initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
+    https://huggingface.co/papers/1606.08415
+    """
+    return tf.clip_by_value(_gelu(x), -10, 10)
+
+
+def glu(x, axis=-1):
+    """
+    Gated Linear Unit. Implementation as defined in the original paper (see https://huggingface.co/papers/1612.08083), where
+    the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
+
+    Args:
+        `x`: float Tensor to perform activation
+        `axis`: dimension across which `x` is split in half
+
+    Returns:
+        `x` with the GLU activation applied (with its size halved across the dimension `axis`).
+    """
+    a, b = tf.split(x, 2, axis=axis)
+    return a * tf.math.sigmoid(b)
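+
+# Illustrative usage (not part of the library): GLU halves the size of the split axis.
+#
+#   x = tf.random.normal((2, 8))
+#   glu(x, axis=-1).shape   # TensorShape([2, 4])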
+
+
+if parse(tf.version.VERSION) >= parse("2.4"):
+
+    def approximate_gelu_wrap(x):
+        return keras.activations.gelu(x, approximate=True)
+
+    gelu = keras.activations.gelu
+    gelu_new = approximate_gelu_wrap
+else:
+    gelu = _gelu
+    gelu_new = _gelu_new
+
+
+ACT2FN = {
+    "gelu": gelu,
+    "gelu_10": gelu_10,
+    "gelu_fast": gelu_fast,
+    "gelu_new": gelu_new,
+    "glu": glu,
+    "mish": mish,
+    "quick_gelu": quick_gelu,
+    "relu": keras.activations.relu,
+    "sigmoid": keras.activations.sigmoid,
+    "silu": keras.activations.swish,
+    "swish": keras.activations.swish,
+    "tanh": keras.activations.tanh,
+}
+
+
+def get_tf_activation(activation_string):
+    if activation_string in ACT2FN:
+        return ACT2FN[activation_string]
+    else:
+        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
diff --git a/phivenv/Lib/site-packages/transformers/audio_utils.py b/phivenv/Lib/site-packages/transformers/audio_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa9a7d3d8757f6e64c92e66f9c4c77dcba6e3c36
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/audio_utils.py
@@ -0,0 +1,1224 @@
+# Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
+and remove unnecessary dependencies.
+"""
+
+import base64
+import importlib
+import io
+import os
+import warnings
+from io import BytesIO
+from typing import Any, Optional, Sequence, Union
+
+import numpy as np
+import requests
+from packaging import version
+
+from .utils import (
+    is_librosa_available,
+    is_numpy_array,
+    is_soundfile_available,
+    is_torch_tensor,
+    is_torchcodec_available,
+    requires_backends,
+)
+
+
+if is_soundfile_available():
+    import soundfile as sf
+
+if is_librosa_available():
+    import librosa
+
+    # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+    import soxr
+
+if is_torchcodec_available():
+    TORCHCODEC_VERSION = version.parse(importlib.metadata.version("torchcodec"))
+
+AudioInput = Union[np.ndarray, "torch.Tensor", Sequence[np.ndarray], Sequence["torch.Tensor"]]  # noqa: F821
+
+
+def load_audio(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
+    """
+    Loads `audio` to an np.ndarray object.
+
+    Args:
+        audio (`str` or `np.ndarray`):
+            The audio to be loaded to the numpy array format.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate to be used when loading the audio. It should be the same as the
+            sampling rate that the model you will be using was trained with.
+        timeout (`float`, *optional*):
+            The timeout value in seconds for the URL request.
+
+    Returns:
+        `np.ndarray`: A numpy array representing the audio.
+    """
+    if isinstance(audio, str):
+        # Try to load with `torchcodec` but do not enforce users to install it. If not found
+        # fallback to `librosa`. If using an audio-only model, most probably `torchcodec` won't be
+        # needed. Do not raise any errors if not installed or versions do not match
+        if is_torchcodec_available() and TORCHCODEC_VERSION >= version.parse("0.3.0"):
+            audio = load_audio_torchcodec(audio, sampling_rate=sampling_rate)
+        else:
+            audio = load_audio_librosa(audio, sampling_rate=sampling_rate, timeout=timeout)
+    elif isinstance(audio, np.ndarray):
+        audio = audio
+    else:
+        raise TypeError(
+            "Incorrect format used for `audio`. Should be an url linking to an audio, a local path, or numpy array."
+        )
+    return audio
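+
+# Illustrative usage (not part of the library; the file name below is hypothetical):
+#
+#   waveform = load_audio("speech_sample.wav", sampling_rate=16000)
+#   # -> 1-D float numpy array resampled to 16 kHz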
+
+
+def load_audio_torchcodec(audio: Union[str, np.ndarray], sampling_rate=16000) -> np.ndarray:
+    """
+    Loads `audio` to an np.ndarray object using `torchcodec`.
+
+    Args:
+        audio (`str` or `np.ndarray`):
+            The audio to be loaded to the numpy array format.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate to be used when loading the audio. It should be the same as the
+            sampling rate that the model you will be using was trained with.
+
+    Returns:
+        `np.ndarray`: A numpy array representing the audio.
+    """
+    # Lazy import so that issues in torchcodec compatibility don't crash the whole library
+    requires_backends(load_audio_torchcodec, ["torchcodec"])
+    from torchcodec.decoders import AudioDecoder
+
+    # Set `num_channels` to `1` which is what most models expects and the default in librosa
+    decoder = AudioDecoder(audio, sample_rate=sampling_rate, num_channels=1)
+    audio = decoder.get_all_samples().data[0].numpy()  # NOTE: feature extractors don't accept torch tensors
+    return audio
+
+
+def load_audio_librosa(audio: Union[str, np.ndarray], sampling_rate=16000, timeout=None) -> np.ndarray:
+    """
+    Loads `audio` to an np.ndarray object using `librosa`.
+
+    Args:
+        audio (`str` or `np.ndarray`):
+            The audio to be loaded to the numpy array format.
+        sampling_rate (`int`, *optional*, defaults to 16000):
+            The sampling rate to be used when loading the audio. It should be the same as the
+            sampling rate that the model you will be using was trained with.
+        timeout (`float`, *optional*):
+            The timeout value in seconds for the URL request.
+
+    Returns:
+        `np.ndarray`: A numpy array representing the audio.
+    """
+    requires_backends(load_audio_librosa, ["librosa"])
+
+    # Load audio from URL (e.g. https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/translate_to_chinese.wav)
+    if audio.startswith("http://") or audio.startswith("https://"):
+        audio = librosa.load(BytesIO(requests.get(audio, timeout=timeout).content), sr=sampling_rate)[0]
+    elif os.path.isfile(audio):
+        audio = librosa.load(audio, sr=sampling_rate)[0]
+    return audio
+
+
+def load_audio_as(
+    audio: str,
+    return_format: str,
+    timeout: Optional[int] = None,
+    force_mono: bool = False,
+    sampling_rate: Optional[int] = None,
+) -> Union[str, dict[str, Any], io.BytesIO, None]:
+    """
+    Load audio from either a local file path or URL and return it in the specified format.
+
+    Args:
+        audio (`str`): Either a local file path or a URL to an audio file
+        return_format (`str`): Format to return the audio in:
+            - "base64": Base64 encoded string
+            - "dict": Dictionary with data and format
+            - "buffer": BytesIO object
+        timeout (`int`, *optional*): Timeout for URL requests in seconds
+        force_mono (`bool`): Whether to convert stereo audio to mono
+        sampling_rate (`int`, *optional*): If provided, the audio will be resampled to the specified sampling rate.
+
+    Returns:
+        `Union[str, Dict[str, Any], io.BytesIO, None]`:
+            - `str`: Base64 encoded audio data (if return_format="base64")
+            - `dict`: Dictionary with 'data' (base64 encoded audio data) and 'format' keys (if return_format="dict")
+            - `io.BytesIO`: BytesIO object containing audio data (if return_format="buffer")
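+
+    Example (an illustrative sketch; `"sample.wav"` is a placeholder path, not a file shipped with the library):
+
+    >>> audio_b64 = load_audio_as("sample.wav", return_format="base64", force_mono=True)  # doctest: +SKIP
+    >>> audio_dict = load_audio_as("sample.wav", return_format="dict", sampling_rate=16000)  # doctest: +SKIP
+    >>> sorted(audio_dict.keys())  # doctest: +SKIP
+    ['data', 'format']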
+    """
+    # TODO: @eustlb, we actually don't need librosa but soxr is installed with librosa
+    requires_backends(load_audio_as, ["librosa"])
+
+    if return_format not in ["base64", "dict", "buffer"]:
+        raise ValueError(f"Invalid return_format: {return_format}. Must be 'base64', 'dict', or 'buffer'")
+
+    try:
+        # Load audio bytes from URL or file
+        audio_bytes = None
+        if audio.startswith(("http://", "https://")):
+            response = requests.get(audio, timeout=timeout)
+            response.raise_for_status()
+            audio_bytes = response.content
+        elif os.path.isfile(audio):
+            with open(audio, "rb") as audio_file:
+                audio_bytes = audio_file.read()
+        else:
+            raise ValueError(f"File not found: {audio}")
+
+        # Process audio data
+        with io.BytesIO(audio_bytes) as audio_file:
+            with sf.SoundFile(audio_file) as f:
+                audio_array = f.read(dtype="float32")
+                original_sr = f.samplerate
+                audio_format = f.format
+                if sampling_rate is not None and sampling_rate != original_sr:
+                    # Resample audio to target sampling rate
+                    audio_array = soxr.resample(audio_array, original_sr, sampling_rate, quality="HQ")
+                else:
+                    sampling_rate = original_sr
+
+        # Convert to mono if needed
+        if force_mono and audio_array.ndim != 1:
+            audio_array = audio_array.mean(axis=1)
+
+        buffer = io.BytesIO()
+        sf.write(buffer, audio_array, sampling_rate, format=audio_format.upper())
+        buffer.seek(0)
+
+        if return_format == "buffer":
+            return buffer
+        elif return_format == "base64":
+            return base64.b64encode(buffer.read()).decode("utf-8")
+        elif return_format == "dict":
+            return {
+                "data": base64.b64encode(buffer.read()).decode("utf-8"),
+                "format": audio_format.lower(),
+            }
+
+    except Exception as e:
+        raise ValueError(f"Error loading audio: {e}")
+
+
+def is_valid_audio(audio):
+    return is_numpy_array(audio) or is_torch_tensor(audio)
+
+
+def is_valid_list_of_audio(audio):
+    return audio and all(is_valid_audio(audio_i) for audio_i in audio)
+
+
+def make_list_of_audio(
+    audio: Union[list[AudioInput], AudioInput],
+) -> AudioInput:
+    """
+    Ensure that the output is a list of audio.
+    Args:
+        audio (`Union[list[AudioInput], AudioInput]`):
+            The input audio.
+    Returns:
+        list: A list of audio.
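+
+    Example (an illustrative sketch using a dummy waveform):
+
+    >>> import numpy as np
+    >>> audio = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz
+    >>> audio_list = make_list_of_audio(audio)  # a single audio is wrapped into a list
+    >>> same_list = make_list_of_audio([audio])  # a list of audio is returned unchanged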
+    """
+    # If it's a list of audios, it's already in the right format
+    if isinstance(audio, (list, tuple)) and is_valid_list_of_audio(audio):
+        return audio
+
+    # If it's a single audio, convert it to a list
+    if is_valid_audio(audio):
+        return [audio]
+
+    raise ValueError("Invalid input type. Must be a single audio or a list of audio")
+
+
+def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+    """
+    Convert frequency from hertz to mels.
+
+    Args:
+        freq (`float` or `np.ndarray`):
+            The frequency, or multiple frequencies, in hertz (Hz).
+        mel_scale (`str`, *optional*, defaults to `"htk"`):
+            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+    Returns:
+        `float` or `np.ndarray`: The frequencies on the mel scale.
+    """
+
+    if mel_scale not in ["slaney", "htk", "kaldi"]:
+        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+    if mel_scale == "htk":
+        return 2595.0 * np.log10(1.0 + (freq / 700.0))
+    elif mel_scale == "kaldi":
+        return 1127.0 * np.log(1.0 + (freq / 700.0))
+
+    min_log_hertz = 1000.0
+    min_log_mel = 15.0
+    logstep = 27.0 / np.log(6.4)
+    mels = 3.0 * freq / 200.0
+
+    if isinstance(freq, np.ndarray):
+        log_region = freq >= min_log_hertz
+        mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
+    elif freq >= min_log_hertz:
+        mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
+
+    return mels
+
+
+def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
+    """
+    Convert frequency from mels to hertz.
+
+    Args:
+        mels (`float` or `np.ndarray`):
+            The frequency, or multiple frequencies, in mels.
+        mel_scale (`str`, *optional*, defaults to `"htk"`):
+            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+
+    Returns:
+        `float` or `np.ndarray`: The frequencies in hertz.
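+
+    Example (an illustrative sketch; `hertz_to_mel` above is the inverse conversion for the same `mel_scale`):
+
+    >>> import numpy as np
+    >>> freqs = np.array([100.0, 1000.0, 4000.0])
+    >>> mels = hertz_to_mel(freqs, mel_scale="htk")
+    >>> recovered = mel_to_hertz(mels, mel_scale="htk")  # round-trips back to `freqs` up to float precision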
+    """
+
+    if mel_scale not in ["slaney", "htk", "kaldi"]:
+        raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
+
+    if mel_scale == "htk":
+        return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
+    elif mel_scale == "kaldi":
+        return 700.0 * (np.exp(mels / 1127.0) - 1.0)
+
+    min_log_hertz = 1000.0
+    min_log_mel = 15.0
+    logstep = np.log(6.4) / 27.0
+    freq = 200.0 * mels / 3.0
+
+    if isinstance(mels, np.ndarray):
+        log_region = mels >= min_log_mel
+        freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
+    elif mels >= min_log_mel:
+        freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
+
+    return freq
+
+
+def hertz_to_octave(
+    freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
+):
+    """
+    Convert frequency from hertz to fractional octave numbers.
+    Adapted from *librosa*.
+
+    Args:
+        freq (`float` or `np.ndarray`):
+            The frequency, or multiple frequencies, in hertz (Hz).
+        tuning (`float`, defaults to `0.`):
+            Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
+        bins_per_octave (`int`, defaults to `12`):
+            Number of bins per octave.
+
+    Returns:
+        `float` or `np.ndarray`: The frequencies on the octave scale.
+    """
+    stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
+    octave = np.log2(freq / (float(stuttgart_pitch) / 16))
+    return octave
+
+
+def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
+    """
+    Creates a triangular filter bank.
+
+    Adapted from *torchaudio* and *librosa*.
+
+    Args:
+        fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
+            Discrete frequencies of the FFT bins in Hz.
+        filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
+            Center frequencies of the triangular filters to create, in Hz.
+
+    Returns:
+        `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
+    """
+    filter_diff = np.diff(filter_freqs)
+    slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+    down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+    up_slopes = slopes[:, 2:] / filter_diff[1:]
+    return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
+
+
+def chroma_filter_bank(
+    num_frequency_bins: int,
+    num_chroma: int,
+    sampling_rate: int,
+    tuning: float = 0.0,
+    power: Optional[float] = 2.0,
+    weighting_parameters: Optional[tuple[float, float]] = (5.0, 2.0),
+    start_at_c_chroma: Optional[bool] = True,
+):
+    """
+    Creates a chroma filter bank, i.e. a linear transformation to project spectrogram bins onto chroma bins.
+
+    Adapted from *librosa*.
+
+    Args:
+        num_frequency_bins (`int`):
+            Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
+        num_chroma (`int`):
+            Number of chroma bins (i.e. pitch classes).
+        sampling_rate (`int`):
+            Sample rate of the audio waveform.
+        tuning (`float`):
+            Tuning deviation from A440 in fractions of a chroma bin.
+        power (`float`, *optional*, defaults to 2.0):
+            If 2.0, normalizes each column by its L2 norm. If 1.0, normalizes each column by its L1 norm.
+        weighting_parameters (`tuple[float, float]`, *optional*, defaults to `(5., 2.)`):
+            If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
+            the second element being the Gaussian half-width.
+        start_at_c_chroma (`bool`, *optional*, defaults to `True`):
+            If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
+    Returns:
+        `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
+    """
+    # Get the FFT bins, not counting the DC component
+    frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
+
+    freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
+
+    # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
+    # (so chroma is 50% rotated from bin 1, and bin width is broad)
+    freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
+
+    bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
+
+    chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
+
+    num_chroma2 = np.round(float(num_chroma) / 2)
+
+    # Project into range -num_chroma/2 .. num_chroma/2
+    # add on fixed offset of 10*num_chroma to ensure all values passed to
+    # rem are positive
+    chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
+
+    # Gaussian bumps - 2*D to make them narrower
+    chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
+
+    # normalize each column
+    if power is not None:
+        chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
+
+    # Maybe apply scaling for fft bins
+    if weighting_parameters is not None:
+        center, half_width = weighting_parameters
+        chroma_filters *= np.tile(
+            np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
+            (num_chroma, 1),
+        )
+
+    if start_at_c_chroma:
+        chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
+
+    # remove aliasing columns, copy to ensure row-contiguity
+    return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
+
+
+def mel_filter_bank(
+    num_frequency_bins: int,
+    num_mel_filters: int,
+    min_frequency: float,
+    max_frequency: float,
+    sampling_rate: int,
+    norm: Optional[str] = None,
+    mel_scale: str = "htk",
+    triangularize_in_mel_space: bool = False,
+) -> np.ndarray:
+    """
+    Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
+    various implementations exist, which differ in the number of filters, the shape of the filters, the way the filters
+    are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
+    features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
+
+    Different banks of mel filters were introduced in the literature. The following variations are supported:
+
+    - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
+      bandwidth of `[0, 4600]` Hz.
+    - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
+      bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
+    - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
+      speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
+    - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
+      12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
+
+    This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
+    `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
+
+    Args:
+        num_frequency_bins (`int`):
+            Number of frequency bins (should be the same as `n_fft // 2 + 1` where `n_fft` is the size of the Fourier Transform used to compute the spectrogram).
+        num_mel_filters (`int`):
+            Number of mel filters to generate.
+        min_frequency (`float`):
+            Lowest frequency of interest in Hz.
+        max_frequency (`float`):
+            Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
+        sampling_rate (`int`):
+            Sample rate of the audio waveform.
+        norm (`str`, *optional*):
+            If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
+        mel_scale (`str`, *optional*, defaults to `"htk"`):
+            The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
+        triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
+            If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
+            should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
+
+    Returns:
+        `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
+        projection matrix to go from a spectrogram to a mel spectrogram.
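+
+    Example (an illustrative sketch for a 512-point FFT at 16 kHz; the parameter values are arbitrary):
+
+    >>> mel_filters = mel_filter_bank(
+    ...     num_frequency_bins=257,  # n_fft // 2 + 1 with n_fft = 512
+    ...     num_mel_filters=80,
+    ...     min_frequency=0.0,
+    ...     max_frequency=8000.0,
+    ...     sampling_rate=16000,
+    ...     norm="slaney",
+    ...     mel_scale="slaney",
+    ... )
+    >>> mel_filters.shape
+    (257, 80)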
+    """
+    if norm is not None and norm != "slaney":
+        raise ValueError('norm must be one of None or "slaney"')
+
+    if num_frequency_bins < 2:
+        raise ValueError(f"Require num_frequency_bins: {num_frequency_bins} >= 2")
+
+    if min_frequency > max_frequency:
+        raise ValueError(f"Require min_frequency: {min_frequency} <= max_frequency: {max_frequency}")
+
+    # center points of the triangular mel filters
+    mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
+    mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
+    mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
+    filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
+
+    if triangularize_in_mel_space:
+        # frequencies of FFT bins in Hz, but filters triangularized in mel space
+        fft_bin_width = sampling_rate / ((num_frequency_bins - 1) * 2)
+        fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
+        filter_freqs = mel_freqs
+    else:
+        # frequencies of FFT bins in Hz
+        fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
+
+    mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
+
+    if norm is not None and norm == "slaney":
+        # Slaney-style mel is scaled to be approx constant energy per channel
+        enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
+        mel_filters *= np.expand_dims(enorm, 0)
+
+    if (mel_filters.max(axis=0) == 0.0).any():
+        warnings.warn(
+            "At least one mel filter has all zero values. "
+            f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
+            f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
+        )
+
+    return mel_filters
+
+
+def optimal_fft_length(window_length: int) -> int:
+    """
+    Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
+    already a power of two, rounds it up to the next power of two.
+
+    The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
+    of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
+    is faster than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
+    it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
+    """
+    return 2 ** int(np.ceil(np.log2(window_length)))
+
+
+def window_function(
+    window_length: int,
+    name: str = "hann",
+    periodic: bool = True,
+    frame_length: Optional[int] = None,
+    center: bool = True,
+) -> np.ndarray:
+    """
+    Returns an array containing the specified window. This window is intended to be used with `stft`.
+
+    The following window types are supported:
+
+        - `"boxcar"`: a rectangular window
+        - `"hamming"`: the Hamming window
+        - `"hann"`: the Hann window
+        - `"povey"`: the Povey window
+
+    Args:
+        window_length (`int`):
+            The length of the window in samples.
+        name (`str`, *optional*, defaults to `"hann"`):
+            The name of the window function.
+        periodic (`bool`, *optional*, defaults to `True`):
+            Whether the window is periodic or symmetric.
+        frame_length (`int`, *optional*):
+            The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
+            than the frame length, so that it will be zero-padded.
+        center (`bool`, *optional*, defaults to `True`):
+            Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
+
+    Returns:
+        `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
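+
+    Example (an illustrative sketch: a 400-sample Hann window zero-padded to a 512-sample frame):
+
+    >>> window = window_function(window_length=400, name="hann", frame_length=512)
+    >>> window.shape
+    (512,)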
+    """
+    length = window_length + 1 if periodic else window_length
+
+    if name == "boxcar":
+        window = np.ones(length)
+    elif name in ["hamming", "hamming_window"]:
+        window = np.hamming(length)
+    elif name in ["hann", "hann_window"]:
+        window = np.hanning(length)
+    elif name in ["povey"]:
+        window = np.power(np.hanning(length), 0.85)
+    else:
+        raise ValueError(f"Unknown window function '{name}'")
+
+    if periodic:
+        window = window[:-1]
+
+    if frame_length is None:
+        return window
+
+    if window_length > frame_length:
+        raise ValueError(
+            f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
+        )
+
+    padded_window = np.zeros(frame_length)
+    offset = (frame_length - window_length) // 2 if center else 0
+    padded_window[offset : offset + window_length] = window
+    return padded_window
+
+
+# TODO This method does not support batching yet as we are mainly focused on inference.
+def spectrogram(
+    waveform: np.ndarray,
+    window: np.ndarray,
+    frame_length: int,
+    hop_length: int,
+    fft_length: Optional[int] = None,
+    power: Optional[float] = 1.0,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    onesided: bool = True,
+    dither: float = 0.0,
+    preemphasis: Optional[float] = None,
+    mel_filters: Optional[np.ndarray] = None,
+    mel_floor: float = 1e-10,
+    log_mel: Optional[str] = None,
+    reference: float = 1.0,
+    min_value: float = 1e-10,
+    db_range: Optional[float] = None,
+    remove_dc_offset: Optional[bool] = None,
+    dtype: np.dtype = np.float32,
+) -> np.ndarray:
+    """
+    Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
+
+    This function can create the following kinds of spectrograms:
+
+      - amplitude spectrogram (`power = 1.0`)
+      - power spectrogram (`power = 2.0`)
+      - complex-valued spectrogram (`power = None`)
+      - log spectrogram (use `log_mel` argument)
+      - mel spectrogram (provide `mel_filters`)
+      - log-mel spectrogram (provide `mel_filters` and `log_mel`)
+
+    How this works:
+
+      1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
+         - hop_length` samples.
+      2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
+      3. The DFT is taken of each windowed frame.
+      4. The results are stacked into a spectrogram.
+
+    We make a distinction between the following "blocks" of sample data, each of which may have a different length:
+
+      - The analysis frame. This is the size of the time slices that the input waveform is split into.
+      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
+      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
+
+    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
+    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
+    typically the next power of two.
+
+    Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
+    `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
+    can be constructed.
+
+    Args:
+        waveform (`np.ndarray` of shape `(length,)`):
+            The input waveform. This must be a single real-valued, mono waveform.
+        window (`np.ndarray` of shape `(frame_length,)`):
+            The windowing function to apply, including zero-padding if necessary. The actual window length may be
+            shorter than `frame_length`, but we're assuming the array has already been zero-padded.
+        frame_length (`int`):
+            The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
+            allow smaller sizes.
+        hop_length (`int`):
+            The stride between successive analysis frames in samples.
+        fft_length (`int`, *optional*):
+            The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
+            For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
+        power (`float`, *optional*, defaults to 1.0):
+            If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
+            complex numbers.
+        center (`bool`, *optional*, defaults to `True`):
+            Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
+            `t` will start at time `t * hop_length`.
+        pad_mode (`str`, *optional*, defaults to `"reflect"`):
+            Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
+            (pad with edge values), `"reflect"` (pads with mirrored values).
+        onesided (`bool`, *optional*, defaults to `True`):
+            If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
+            frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering, i.e. adds small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering sampled from a normal distribution centered
+            around 0.0 with standard deviation 4.0; 0.0 means no dithering.
+            Dithering has a similar effect to `mel_floor`: it reduces the high log-mel filter bank
+            values for signals with hard-zero sections, e.g. when a VAD cutoff is present in the signal.
+        preemphasis (`float`, *optional*):
+            Coefficient for the first-order high-pass (pre-emphasis) filter applied to each frame before the DFT.
+        mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
+            The mel filter bank. If supplied, applies this filter bank to create a mel spectrogram.
+        mel_floor (`float`, *optional*, defaults to 1e-10):
+            Minimum value of mel frequency banks.
+        log_mel (`str`, *optional*):
+            How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
+            the natural logarithm), `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
+            used when `power` is not `None`.
+        reference (`float`, *optional*, defaults to 1.0):
+            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+            the loudest part to 0 dB. Must be greater than zero.
+        min_value (`float`, *optional*, defaults to `1e-10`):
+            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+            `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
+            amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
+        db_range (`float`, *optional*):
+            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+        remove_dc_offset (`bool`, *optional*):
+            Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
+            order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
+        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+            Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
+            `np.complex64`.
+
+    Returns:
+        `np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
+        `(num_mel_filters, length)` for a mel spectrogram.
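+
+    Example (an illustrative sketch: a log-mel spectrogram of one second of a 440 Hz tone at 16 kHz; it assumes
+    `window_function` and `mel_filter_bank` from this module are in scope, and the parameter values are arbitrary):
+
+    >>> import numpy as np
+    >>> waveform = np.sin(2 * np.pi * 440.0 * np.arange(16000) / 16000)
+    >>> mel_filters = mel_filter_bank(
+    ...     num_frequency_bins=257, num_mel_filters=80, min_frequency=0.0,
+    ...     max_frequency=8000.0, sampling_rate=16000, norm="slaney", mel_scale="slaney",
+    ... )
+    >>> log_mel = spectrogram(
+    ...     waveform,
+    ...     window_function(400, "hann", frame_length=512),
+    ...     frame_length=512,
+    ...     hop_length=160,
+    ...     power=2.0,
+    ...     mel_filters=mel_filters,
+    ...     log_mel="log10",
+    ... )
+    >>> log_mel.shape  # (num_mel_filters, num_frames)
+    (80, 101)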
+    """
+    window_length = len(window)
+
+    if fft_length is None:
+        fft_length = frame_length
+
+    if frame_length > fft_length:
+        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
+
+    if window_length != frame_length:
+        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
+
+    if hop_length <= 0:
+        raise ValueError("hop_length must be greater than zero")
+
+    if waveform.ndim != 1:
+        raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
+
+    if np.iscomplexobj(waveform):
+        raise ValueError("Complex-valued input waveforms are not currently supported")
+
+    if power is None and mel_filters is not None:
+        raise ValueError(
+            "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrograms. "
+            "Specify `power` to fix this issue."
+        )
+
+    # center pad the waveform
+    if center:
+        padding = [(int(frame_length // 2), int(frame_length // 2))]
+        waveform = np.pad(waveform, padding, mode=pad_mode)
+
+    # promote to float64, since np.fft uses float64 internally
+    waveform = waveform.astype(np.float64)
+    window = window.astype(np.float64)
+
+    # split waveform into frames of frame_length size
+    num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
+
+    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
+    spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
+
+    # rfft is faster than fft
+    fft_func = np.fft.rfft if onesided else np.fft.fft
+    buffer = np.zeros(fft_length)
+
+    timestep = 0
+    for frame_idx in range(num_frames):
+        buffer[:frame_length] = waveform[timestep : timestep + frame_length]
+
+        if dither != 0.0:
+            buffer[:frame_length] += dither * np.random.randn(frame_length)
+
+        if remove_dc_offset:
+            buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
+
+        if preemphasis is not None:
+            buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
+            buffer[0] *= 1 - preemphasis
+
+        buffer[:frame_length] *= window
+
+        spectrogram[frame_idx] = fft_func(buffer)
+        timestep += hop_length
+
+    # note: ** is much faster than np.power
+    if power is not None:
+        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
+
+    spectrogram = spectrogram.T
+
+    if mel_filters is not None:
+        spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
+
+    if power is not None and log_mel is not None:
+        if log_mel == "log":
+            spectrogram = np.log(spectrogram)
+        elif log_mel == "log10":
+            spectrogram = np.log10(spectrogram)
+        elif log_mel == "dB":
+            if power == 1.0:
+                spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
+            elif power == 2.0:
+                spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
+            else:
+                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
+        else:
+            raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+        spectrogram = np.asarray(spectrogram, dtype)
+
+    return spectrogram
+
+
+def spectrogram_batch(
+    waveform_list: list[np.ndarray],
+    window: np.ndarray,
+    frame_length: int,
+    hop_length: int,
+    fft_length: Optional[int] = None,
+    power: Optional[float] = 1.0,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    onesided: bool = True,
+    dither: float = 0.0,
+    preemphasis: Optional[float] = None,
+    mel_filters: Optional[np.ndarray] = None,
+    mel_floor: float = 1e-10,
+    log_mel: Optional[str] = None,
+    reference: float = 1.0,
+    min_value: float = 1e-10,
+    db_range: Optional[float] = None,
+    remove_dc_offset: Optional[bool] = None,
+    dtype: np.dtype = np.float32,
+) -> list[np.ndarray]:
+    """
+    Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
+    This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
+
+    It supports generating various types of spectrograms:
+
+        - amplitude spectrogram (`power = 1.0`)
+        - power spectrogram (`power = 2.0`)
+        - complex-valued spectrogram (`power = None`)
+        - log spectrogram (use `log_mel` argument)
+        - mel spectrogram (provide `mel_filters`)
+        - log-mel spectrogram (provide `mel_filters` and `log_mel`)
+
+    How this works:
+
+        1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
+            - hop_length` samples.
+        2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
+        3. The DFT is taken of each windowed frame.
+        4. The results are stacked into a spectrogram.
+
+    We make a distinction between the following "blocks" of sample data, each of which may have a different length:
+
+      - The analysis frame. This is the size of the time slices that the input waveform is split into.
+      - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
+      - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
+
+    In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
+    padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
+    typically the next power of two.
+
+    Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
+
+    Args:
+        waveform_list (`list[np.ndarray]` with arrays of shape `(length,)`):
+            The list of input waveforms, each a single-channel (mono) signal.
+        window (`np.ndarray` of shape `(frame_length,)`):
+            The windowing function to apply, including zero-padding if necessary.
+        frame_length (`int`):
+            The length of each frame for analysis.
+        hop_length (`int`):
+            The step size between successive frames.
+        fft_length (`int`, *optional*):
+            The size of the FFT buffer, defining frequency bin resolution.
+        power (`float`, *optional*, defaults to 1.0):
+            Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
+        center (`bool`, *optional*, defaults to `True`):
+            Whether to center-pad the waveform frames.
+        pad_mode (`str`, *optional*, defaults to `"reflect"`):
+            The padding strategy when `center` is `True`.
+        onesided (`bool`, *optional*, defaults to `True`):
+            If True, returns a one-sided spectrogram for real input signals.
+        dither (`float`, *optional*, defaults to 0.0):
+            Adds dithering, i.e. adds small Gaussian noise to each frame.
+            E.g. use 4.0 to add dithering sampled from a normal distribution centered
+            around 0.0 with standard deviation 4.0; 0.0 means no dithering.
+        preemphasis (`float`, *optional*):
+            Applies a pre-emphasis filter to each frame.
+        mel_filters (`np.ndarray`, *optional*):
+            Mel filter bank for converting to mel spectrogram.
+        mel_floor (`float`, *optional*, defaults to 1e-10):
+            Floor value for mel spectrogram to avoid log(0).
+        log_mel (`str`, *optional*):
+            Specifies log scaling strategy; options are None, "log", "log10", "dB".
+        reference (`float`, *optional*, defaults to 1.0):
+            Reference value for dB conversion in log_mel.
+        min_value (`float`, *optional*, defaults to 1e-10):
+            Minimum floor value for log scale conversions.
+        db_range (`float`, *optional*):
+            Dynamic range for dB scale spectrograms.
+        remove_dc_offset (`bool`, *optional*):
+            Whether to remove the DC offset from each frame.
+        dtype (`np.dtype`, *optional*, defaults to `np.float32`):
+            Data type of the output spectrogram.
+
+    Returns:
+        list[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
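+
+    Example (an illustrative sketch for two mono waveforms of different lengths; it assumes `window_function` from
+    this module is in scope, and the parameter values are arbitrary):
+
+    >>> import numpy as np
+    >>> waveforms = [np.zeros(16000, dtype=np.float32), np.zeros(8000, dtype=np.float32)]
+    >>> specs = spectrogram_batch(
+    ...     waveforms,
+    ...     window_function(400, "hann"),
+    ...     frame_length=400,
+    ...     hop_length=160,
+    ...     fft_length=512,
+    ...     power=2.0,
+    ... )
+    >>> [spec.shape for spec in specs]  # batch padding frames are stripped per waveform
+    [(257, 101), (257, 51)]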
+    """
+    window_length = len(window)
+
+    if fft_length is None:
+        fft_length = frame_length
+
+    if frame_length > fft_length:
+        raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
+
+    if window_length != frame_length:
+        raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
+
+    if hop_length <= 0:
+        raise ValueError("hop_length must be greater than zero")
+
+    # Check the dimensions of each waveform, and whether it is complex
+    for waveform in waveform_list:
+        if waveform.ndim != 1:
+            raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
+        if np.iscomplexobj(waveform):
+            raise ValueError("Complex-valued input waveforms are not currently supported")
+    # Center pad the waveform
+    if center:
+        padding = [(int(frame_length // 2), int(frame_length // 2))]
+        waveform_list = [
+            np.pad(
+                waveform,
+                padding,
+                mode=pad_mode,
+            )
+            for waveform in waveform_list
+        ]
+    original_waveform_lengths = [
+        len(waveform) for waveform in waveform_list
+    ]  # these lengths will be used to remove padding later
+
+    # Batch pad the waveform
+    max_length = max(original_waveform_lengths)
+    padded_waveform_batch = np.array(
+        [
+            np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
+            for waveform in waveform_list
+        ],
+        dtype=dtype,
+    )
+
+    # Promote to float64, since np.fft uses float64 internally
+    padded_waveform_batch = padded_waveform_batch.astype(np.float64)
+    window = window.astype(np.float64)
+
+    # Split waveform into frames of frame_length size
+    num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
+    # these lengths will be used to remove padding later
+    true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
+    num_batches = padded_waveform_batch.shape[0]
+
+    num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
+    spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
+
+    # rfft is faster than fft
+    fft_func = np.fft.rfft if onesided else np.fft.fft
+    buffer = np.zeros((num_batches, fft_length))
+
+    for frame_idx in range(num_frames):
+        timestep = frame_idx * hop_length
+        buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
+
+        if dither != 0.0:
+            buffer[:, :frame_length] += dither * np.random.randn(*buffer[:, :frame_length].shape)
+
+        if remove_dc_offset:
+            buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
+
+        if preemphasis is not None:
+            buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
+            buffer[:, 0] *= 1 - preemphasis
+
+        buffer[:, :frame_length] *= window
+
+        spectrogram[:, frame_idx] = fft_func(buffer)
+
+    # Note: ** is much faster than np.power
+    if power is not None:
+        spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
+
+    # Apply mel filters if provided
+    if mel_filters is not None:
+        result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
+        spectrogram = np.maximum(mel_floor, result)
+
+    # Convert to log scale if specified
+    if power is not None and log_mel is not None:
+        if log_mel == "log":
+            spectrogram = np.log(spectrogram)
+        elif log_mel == "log10":
+            spectrogram = np.log10(spectrogram)
+        elif log_mel == "dB":
+            if power == 1.0:
+                spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
+            elif power == 2.0:
+                spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
+            else:
+                raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
+        else:
+            raise ValueError(f"Unknown log_mel option: {log_mel}")
+
+        spectrogram = np.asarray(spectrogram, dtype)
+
+    spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
+
+    return spectrogram_list
+
+
+def power_to_db(
+    spectrogram: np.ndarray,
+    reference: float = 1.0,
+    min_value: float = 1e-10,
+    db_range: Optional[float] = None,
+) -> np.ndarray:
+    """
+    Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
+    logarithm properties for numerical stability.
+
+    The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
+    linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
+    This means that large variations in energy may not sound all that different if the sound is loud to begin with.
+    This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
+
+    Based on the implementation of `librosa.power_to_db`.
+
+    Args:
+        spectrogram (`np.ndarray`):
+            The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
+        reference (`float`, *optional*, defaults to 1.0):
+            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+            the loudest part to 0 dB. Must be greater than zero.
+        min_value (`float`, *optional*, defaults to `1e-10`):
+            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+            `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
+        db_range (`float`, *optional*):
+            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+    Returns:
+        `np.ndarray`: the spectrogram in decibels
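+
+    Example (an illustrative sketch with made-up power values):
+
+    >>> import numpy as np
+    >>> power_spec = np.array([[1.0, 1e-1], [1e-2, 1e-12]])
+    >>> db = power_to_db(power_spec, reference=1.0, db_range=80.0)
+    >>> # 10 * log10 of the clipped values, then limited to an 80 dB range below the peak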
+    """
+    if reference <= 0.0:
+        raise ValueError("reference must be greater than zero")
+    if min_value <= 0.0:
+        raise ValueError("min_value must be greater than zero")
+
+    reference = max(min_value, reference)
+
+    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+    spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
+
+    if db_range is not None:
+        if db_range <= 0.0:
+            raise ValueError("db_range must be greater than zero")
+        spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
+
+    return spectrogram
+
+
+def power_to_db_batch(
+    spectrogram: np.ndarray,
+    reference: float = 1.0,
+    min_value: float = 1e-10,
+    db_range: Optional[float] = None,
+) -> np.ndarray:
+    """
+    Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
+    using basic logarithm properties for numerical stability.
+
+    This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
+
+    Args:
+        spectrogram (`np.ndarray`):
+            The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
+            Note that a power spectrogram has the amplitudes squared!
+        reference (`float`, *optional*, defaults to 1.0):
+            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+            the loudest part to 0 dB. Must be greater than zero.
+        min_value (`float`, *optional*, defaults to `1e-10`):
+            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+            `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
+        db_range (`float`, *optional*):
+            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+    Returns:
+        `np.ndarray`: the batch of spectrograms in decibels
+    """
+    if reference <= 0.0:
+        raise ValueError("reference must be greater than zero")
+    if min_value <= 0.0:
+        raise ValueError("min_value must be greater than zero")
+
+    reference = max(min_value, reference)
+
+    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+    spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
+
+    if db_range is not None:
+        if db_range <= 0.0:
+            raise ValueError("db_range must be greater than zero")
+        # Apply db_range clipping per batch item
+        max_values = spectrogram.max(axis=(1, 2), keepdims=True)
+        spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
+
+    return spectrogram
+
+
+def amplitude_to_db(
+    spectrogram: np.ndarray,
+    reference: float = 1.0,
+    min_value: float = 1e-5,
+    db_range: Optional[float] = None,
+) -> np.ndarray:
+    """
+    Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
+    basic logarithm properties for numerical stability.
+
+    The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
+    linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
+    This means that large variations in energy may not sound all that different if the sound is loud to begin with.
+    This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
+
+    Args:
+        spectrogram (`np.ndarray`):
+            The input amplitude (mel) spectrogram.
+        reference (`float`, *optional*, defaults to 1.0):
+            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+            the loudest part to 0 dB. Must be greater than zero.
+        min_value (`float`, *optional*, defaults to `1e-5`):
+            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+            `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
+        db_range (`float`, *optional*):
+            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+    Returns:
+        `np.ndarray`: the spectrogram in decibels
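+
+    Example (an illustrative sketch with made-up amplitude values):
+
+    >>> import numpy as np
+    >>> amp_spec = np.array([[1.0, 0.1], [0.01, 0.0]])
+    >>> db = amplitude_to_db(amp_spec, reference=1.0)  # 20 * log10 of the values, clipped below at `min_value`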
+    """
+    if reference <= 0.0:
+        raise ValueError("reference must be greater than zero")
+    if min_value <= 0.0:
+        raise ValueError("min_value must be greater than zero")
+
+    reference = max(min_value, reference)
+
+    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+    spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
+
+    if db_range is not None:
+        if db_range <= 0.0:
+            raise ValueError("db_range must be greater than zero")
+        spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
+
+    return spectrogram
+
+
+def amplitude_to_db_batch(
+    spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
+) -> np.ndarray:
+    """
+    Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
+    using basic logarithm properties for numerical stability.
+
+    The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
+
+    Args:
+        spectrogram (`np.ndarray`):
+            The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
+        reference (`float`, *optional*, defaults to 1.0):
+            Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
+            the loudest part to 0 dB. Must be greater than zero.
+        min_value (`float`, *optional*, defaults to `1e-5`):
+            The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
+            `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
+        db_range (`float`, *optional*):
+            Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
+            peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
+
+    Returns:
+        `np.ndarray`: the batch of spectrograms in decibels
+    """
+    if reference <= 0.0:
+        raise ValueError("reference must be greater than zero")
+    if min_value <= 0.0:
+        raise ValueError("min_value must be greater than zero")
+
+    reference = max(min_value, reference)
+
+    spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
+    spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
+
+    if db_range is not None:
+        if db_range <= 0.0:
+            raise ValueError("db_range must be greater than zero")
+        # Apply db_range clipping per batch item
+        max_values = spectrogram.max(axis=(1, 2), keepdims=True)
+        spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
+
+    return spectrogram
diff --git a/phivenv/Lib/site-packages/transformers/cache_utils.py b/phivenv/Lib/site-packages/transformers/cache_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e993693c93e4b02e5d9d5c829a2dc62ef831f321
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/cache_utils.py
@@ -0,0 +1,1516 @@
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+from typing import Any, Optional
+
+import torch
+
+from .configuration_utils import PretrainedConfig
+from .utils import (
+    is_hqq_available,
+    is_quanto_greater,
+    is_torch_greater_or_equal,
+    is_torchdynamo_compiling,
+    logging,
+)
+
+
+if is_hqq_available():
+    from hqq.core.quantize import Quantizer as HQQQuantizer
+
+_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
+
+
+logger = logging.get_logger(__name__)
+
+
+class CacheLayerMixin(ABC):
+    """Base, abstract class for a single layer's cache."""
+
+    is_compileable = False
+
+    def __init__(self):
+        self.keys, self.values = None, None
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}"
+
+    @abstractmethod
+    def lazy_initialization(self, key_states: torch.Tensor): ...
+
+    @abstractmethod
+    def update(
+        self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
+    ) -> tuple[torch.Tensor, torch.Tensor]: ...
+
+    @abstractmethod
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...
+
+    @abstractmethod
+    def get_seq_length(self) -> int: ...
+
+    @abstractmethod
+    def get_max_cache_shape(self) -> int: ...
+
+    def offload(self):
+        """Offload this layer's data to the CPU."""
+        if self.keys is not None:
+            self.keys = self.keys.to("cpu", non_blocking=True)
+            self.values = self.values.to("cpu", non_blocking=True)
+
+    def prefetch(self):
+        """In case of layer offloading, this allows moving the data back to the layer's device ahead of time."""
+        if self.keys is not None and self.keys.device != self.device:
+            self.keys = self.keys.to(self.device, non_blocking=True)
+            self.values = self.values.to(self.device, non_blocking=True)
+
+    def reset(self) -> None:
+        """Resets the cache values while preserving the objects"""
+        if self.keys is not None:
+            self.keys.zero_()
+            self.values.zero_()
+        # This attribute is set on several Layers
+        if hasattr(self, "cumulative_length"):
+            self.cumulative_length = 0
+
+    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+        """Reorders this layer's cache for beam search."""
+        if self.get_seq_length() > 0:
+            self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
+            self.values = self.values.index_select(0, beam_idx.to(self.values.device))
+
+
+class DynamicLayer(CacheLayerMixin):
+    """
+    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
+    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
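+
+    Example (an illustrative sketch with random tensors; the shapes are arbitrary):
+
+    >>> import torch
+    >>> layer = DynamicLayer()
+    >>> keys = torch.randn(1, 8, 4, 64)  # [batch_size, num_heads, seq_len, head_dim]
+    >>> values = torch.randn(1, 8, 4, 64)
+    >>> cached_keys, cached_values = layer.update(keys, values)
+    >>> layer.get_seq_length()
+    4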
+    """
+
+    is_sliding = False
+
+    def lazy_initialization(self, key_states: torch.Tensor):
+        self.dtype, self.device = key_states.dtype, key_states.device
+        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
+        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+
+        self.keys = torch.cat([self.keys, key_states], dim=-2)
+        self.values = torch.cat([self.values, value_states], dim=-2)
+        return self.keys, self.values
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the mask"""
+        kv_offset = 0
+        query_length = cache_position.shape[0]
+        past_seen_tokens = self.get_seq_length()
+        kv_length = query_length + past_seen_tokens
+        return kv_length, kv_offset
+
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        if self.keys is None or self.keys.numel() == 0:
+            return 0
+        return self.keys.shape[-2]
+
+    def get_max_cache_shape(self) -> int:
+        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
+        return -1
+
+    def crop(self, max_length: int) -> None:
+        """
+        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
+        to remove `max_length` tokens.
+        """
+        if max_length < 0:
+            max_length = self.get_seq_length() - abs(max_length)
+
+        if self.get_seq_length() <= max_length:
+            return
+
+        self.keys = self.keys[..., :max_length, :]
+        self.values = self.values[..., :max_length, :]
+
+    def batch_repeat_interleave(self, repeats: int) -> None:
+        """Repeat the cache `repeats` times in the batch dimension."""
+        if self.get_seq_length() > 0:
+            self.keys = self.keys.repeat_interleave(repeats, dim=0)
+            self.values = self.values.repeat_interleave(repeats, dim=0)
+
+    def batch_select_indices(self, indices: torch.Tensor) -> None:
+        """Only keep the `indices` in the batch dimension of the cache."""
+        if self.get_seq_length() > 0:
+            self.keys = self.keys[indices, ...]
+            self.values = self.values[indices, ...]
+
+
+class DynamicSlidingWindowLayer(DynamicLayer):
+    """
+    A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
+    It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+    """
+
+    is_sliding = True
+
+    def __init__(self, sliding_window: int):
+        super().__init__()
+        self.sliding_window = sliding_window
+        self.cumulative_length = 0
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+
+        self.cumulative_length += key_states.shape[-2]
+
+        # Compute the full states
+        full_key_states = torch.cat([self.keys, key_states], dim=-2)
+        full_value_states = torch.cat([self.values, value_states], dim=-2)
+        # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
+        self.keys = full_key_states[:, :, -self.sliding_window + 1 :, :]
+        self.values = full_value_states[:, :, -self.sliding_window + 1 :, :]
+
+        # Return the full states
+        return full_key_states, full_value_states
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the attention mask"""
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+
+        kv_offset = torch.clamp(first_cache_position - self.sliding_window + 1, min=0)
+
+        if self.get_seq_length() >= self.sliding_window:
+            kv_length = self.sliding_window - 1 + query_length
+        else:
+            kv_length = self.get_seq_length() + query_length
+
+        return kv_length, kv_offset
+
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        return self.cumulative_length
+
+    def get_max_cache_shape(self) -> int:
+        """Return the maximum cache shape of the cache"""
+        return self.sliding_window
+
+    def crop(self, max_length: int) -> None:
+        """
+        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+        negative to remove `max_length` tokens.
+        """
+        if self.get_seq_length() >= self.sliding_window:
+            raise ValueError(
+                "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its"
+                "sliding window (otherwise some states are lost)"
+            )
+        super().crop(max_length)
+        self.cumulative_length = self.keys.shape[-2]
+
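+# A minimal sketch (illustrative only, same lazy-initialization assumption as in the `DynamicLayer` sketch above):
+# a `DynamicSlidingWindowLayer` returns the *full* states on each `update`, but only keeps the last
+# `sliding_window - 1` positions in memory:
+#
+#     layer = DynamicSlidingWindowLayer(sliding_window=4)
+#     k = v = torch.randn(1, 2, 6, 8)              # prompt longer than the window
+#     full_k, full_v = layer.update(k, v)
+#     assert full_k.shape[-2] == 6                 # the caller still sees all 6 positions
+#     assert layer.keys.shape[-2] == 3             # only sliding_window - 1 = 3 positions are cached
+#     assert layer.get_seq_length() == 6           # the cumulative length keeps counting seen tokens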
+
+class StaticLayer(CacheLayerMixin):
+    """
+    A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
+    It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
+
+    Args:
+        max_cache_len (`int`):
+            Maximum number of tokens that can be stored, used for tensor preallocation.
+    """
+
+    is_compileable = True
+    is_sliding = False
+
+    def __init__(self, max_cache_len: int):
+        super().__init__()
+        self.max_cache_len = max_cache_len
+
+    def lazy_initialization(self, key_states: torch.Tensor):
+        """
+        Lazy initialization of the keys and values tensors. This allows all properties (dtype, device,
+        num_heads in the case of TP, etc.) to be inferred directly at runtime, which is extremely practical as it
+        avoids moving devices and dtypes later on for each `update` (which could also break the static dynamo addresses).
+
+        If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
+        function ahead of time (this is required for `torch.export`, for example). Note that for `compile`, as we
+        internally don't compile the prefill, this is guaranteed to have been called already when compiling.
+        Compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
+        is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
+        i.e. `mode="reduce-overhead"`, is known to fail). It will generally work correctly, but prefill should
+        not be compiled anyway for performance reasons!
+        """
+        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
+        self.dtype, self.device = key_states.dtype, key_states.device
+
+        self.keys = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+        self.values = torch.zeros(
+            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+        )
+        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
+        # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
+        # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
+        # prefill explicitly, but this should be avoided!)
+        if not is_torchdynamo_compiling():
+            torch._dynamo.mark_static_address(self.keys)
+            torch._dynamo.mark_static_address(self.values)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+
+        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
+        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
+        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
+        cache_position = (
+            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
+        )
+
+        # Update the cache
+        try:
+            self.keys.index_copy_(2, cache_position, key_states)
+            self.values.index_copy_(2, cache_position, value_states)
+        except NotImplementedError:
+            # Fallback for devices like MPS where index_copy_ might not be supported.
+            self.keys[:, :, cache_position] = key_states
+            self.values[:, :, cache_position] = value_states
+        return self.keys, self.values
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the attention mask"""
+        kv_offset = 0
+        kv_length = self.max_cache_len
+        return kv_length, kv_offset
+
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        return (self.keys[0, 0].any(dim=-1)).sum() if self.keys is not None else 0
+
+    def get_max_cache_shape(self) -> int:
+        """Return the maximum cache shape of the cache"""
+        return self.max_cache_len
+
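+# A minimal sketch (illustrative only) of how `StaticLayer` behaves: the backing tensors are allocated lazily on the
+# first `update` with the full `max_cache_len`, and new states are written in-place at the provided `cache_position`
+# indices, so the returned tensors always have the preallocated length:
+#
+#     layer = StaticLayer(max_cache_len=16)
+#     k = v = torch.randn(1, 2, 4, 8)
+#     keys, values = layer.update(k, v, {"cache_position": torch.arange(4)})
+#     assert keys.shape[-2] == 16                  # the full preallocated length is always returned
+#     assert layer.get_mask_sizes(torch.arange(4)) == (16, 0)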
+
+class SlidingWindowLayer(StaticLayer):
+    """
+    A static cache layer that stores the key and value states as static tensors of shape
+    `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. It lazily allocates its full backing
+    tensors, and then mutates them in-place. Built for `torch.compile` support.
+
+    Args:
+        max_cache_len (`int`):
+            Maximum number of tokens that can be stored, used for tensor preallocation.
+        sliding_window (`int`):
+            The size of the sliding window.
+    """
+
+    is_sliding = True
+
+    def __init__(self, max_cache_len: int, sliding_window: int):
+        effective_max_cache_len = min(sliding_window, max_cache_len)
+        super().__init__(max_cache_len=effective_max_cache_len)
+        self.cumulative_length = 0
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+
+        cache_position = cache_kwargs.get("cache_position")
+
+        is_full = self.cumulative_length >= self.max_cache_len
+        # Update it now that we saved the value above
+        self.cumulative_length += key_states.shape[-2]
+
+        # Handle prefill phase when prompt length > sliding_window_size.
+        # Note that we store cropped key/value states in the cache but return the full key/value states.
+        if cache_position.shape[0] > self.max_cache_len:
+            self.keys.copy_(key_states[:, :, -self.max_cache_len :, :])
+            self.values.copy_(value_states[:, :, -self.max_cache_len :, :])
+            # Return the full states here
+            return key_states, value_states
+
+        # Here we only assume decoding stage, i.e. 1 token at a time
+        if is_full:
+            # Roll all values to the left by 1 position
+            new_keys = self.keys.roll(-1, dims=-2)
+            new_values = self.values.roll(-1, dims=-2)
+            # Overwrite the last position with new states
+            # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
+            index = torch.tensor([-1], dtype=int, device=self.device)
+            new_keys[:, :, index] = key_states
+            new_values[:, :, index] = value_states
+
+            # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
+            self.keys.copy_(new_keys)
+            self.values.copy_(new_values)
+        else:
+            try:
+                self.keys.index_copy_(2, cache_position, key_states)
+                self.values.index_copy_(2, cache_position, value_states)
+            except NotImplementedError:
+                self.keys[:, :, cache_position] = key_states
+                self.values[:, :, cache_position] = value_states
+
+        return self.keys, self.values
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the attention mask"""
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+
+        kv_offset = torch.clamp(first_cache_position - self.max_cache_len + 1, min=0)
+        # This is not general (see HybridChunkedCache for the whole general case), but it's what the cache returns
+        kv_length = max(query_length, self.max_cache_len)
+        return kv_length, kv_offset
+
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        return self.cumulative_length
+
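+# A minimal sketch (illustrative only) of the two `SlidingWindowLayer.update` regimes: a prefill longer than the
+# window only keeps the last `sliding_window` states (while returning the full prompt states), and once full, each
+# decoding step rolls the window left by one position and overwrites the last slot:
+#
+#     layer = SlidingWindowLayer(max_cache_len=1024, sliding_window=4)
+#     k = v = torch.randn(1, 2, 6, 8)
+#     full_k, _ = layer.update(k, v, {"cache_position": torch.arange(6)})          # prefill: returns all 6 positions
+#     new_k = new_v = torch.randn(1, 2, 1, 8)
+#     keys, _ = layer.update(new_k, new_v, {"cache_position": torch.tensor([6])})  # decoding step
+#     assert keys.shape[-2] == 4                   # decoding always returns the fixed-size window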
+
+class ChunkedSlidingLayer(SlidingWindowLayer):
+    """
+    An extended SlidingWindowLayer that supports prefill chunking, originally implemented for Llama 4.
+    """
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+
+        cache_position = cache_kwargs.get("cache_position")
+
+        cumulative_length = self.cumulative_length
+        is_full = cumulative_length >= self.max_cache_len
+        # Update it now that we saved the value above
+        self.cumulative_length += key_states.shape[-2]
+
+        if is_full:
+            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
+            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
+            # Fast decoding path -> here, as the effective size is still the sliding window, it is extremely important
+            # to return `self.keys` and `self.values`, as they have the fixed address in memory
+            # (the values are the same as the full states, but not the address!!)
+            if key_states.shape[-2] == 1:
+                self.keys.copy_(full_key_states)
+                self.values.copy_(full_value_states)
+                return self.keys, self.values
+        elif not is_full and cumulative_length + key_states.shape[2] > self.max_cache_len:
+            # Fast prefill path, no need to cat() in this case, as the cache is currently empty
+            if cumulative_length == 0:
+                full_key_states = key_states
+                full_value_states = value_states
+            else:
+                full_key_states = torch.cat((self.keys[:, :, :cumulative_length, :], key_states), dim=-2)
+                full_value_states = torch.cat((self.values[:, :, :cumulative_length, :], value_states), dim=-2)
+        else:
+            try:
+                self.keys.index_copy_(2, cache_position, key_states)
+                self.values.index_copy_(2, cache_position, value_states)
+            except NotImplementedError:
+                self.keys[:, :, cache_position] = key_states
+                self.values[:, :, cache_position] = value_states
+            return self.keys, self.values
+
+        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
+        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
+        # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
+        # which is outside the window
+        return full_key_states, full_value_states
+
+    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
+        """Return the length and offset of the cache, used to generate the attention mask"""
+        query_length = cache_position.shape[0]
+        first_cache_position = cache_position[0]
+        sliding_window = self.max_cache_len
+
+        kv_offset = torch.clamp(first_cache_position - sliding_window + 1, min=0)
+        # This is the true general case for any Cache using local attention (sliding or chunked)
+        if first_cache_position >= sliding_window:
+            # Here the Cache is already full
+            kv_length = sliding_window + query_length - 1
+        elif first_cache_position < sliding_window and first_cache_position + query_length > sliding_window:
+            # Here the Cache becomes full with the new input
+            kv_length = first_cache_position + query_length
+        else:
+            # Here the Cache is still smaller than the local size, but we return the local size as it's static
+            kv_length = sliding_window
+        return kv_length, kv_offset
+
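+# A minimal sketch (illustrative only) of chunked prefill with `ChunkedSlidingLayer`: when a new chunk makes the
+# cache exceed the window, `update` returns the concatenation of what was already cached and the new chunk, while
+# only the last `sliding_window` positions stay in the static buffers:
+#
+#     layer = ChunkedSlidingLayer(max_cache_len=1024, sliding_window=4)
+#     k1 = v1 = torch.randn(1, 2, 3, 8)
+#     layer.update(k1, v1, {"cache_position": torch.arange(3)})                    # first chunk fits in the window
+#     k2 = v2 = torch.randn(1, 2, 3, 8)
+#     full_k, _ = layer.update(k2, v2, {"cache_position": torch.arange(3, 6)})     # second chunk overflows it
+#     assert full_k.shape[-2] == 6                 # the full context is returned for attention
+#     assert layer.keys.shape[-2] == 4             # but only the last 4 positions are kept in the buffer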
+
+class QuantizedLayer(DynamicLayer):
+    """
+    A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for the key and value caches by
+    applying quantization.
+
+    The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
+    is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
+    precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
+    for both Keys and Values, in contrast to what was described in the paper.
+    """
+
+    def __init__(
+        self,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        super().__init__()
+        self.nbits = nbits
+        self.axis_key = axis_key
+        self.axis_value = axis_value
+        self.q_group_size = q_group_size
+        self.residual_length = residual_length
+        self.cumulative_length = 0
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Update the key and value caches in-place, and return the necessary keys and value states.
+
+        Args:
+            key_states (`torch.Tensor`): The new key states to cache.
+            value_states (`torch.Tensor`): The new value states to cache.
+            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
+
+        Returns:
+            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
+        """
+        self.cumulative_length += key_states.shape[-2]
+
+        # Lazy initialization
+        if self.keys is None:
+            self.lazy_initialization(key_states)
+            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
+            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
+            return key_states, value_states
+
+        dequant_keys = self._dequantize(self._quantized_keys)
+        dequant_values = self._dequantize(self._quantized_values)
+        keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
+        values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)
+        if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
+            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
+            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
+            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
+        else:
+            self.keys = torch.cat([self.keys, key_states], dim=-2)
+            self.values = torch.cat([self.values, value_states], dim=-2)
+
+        return keys_to_return, values_to_return
+
+    @abstractmethod
+    def _quantize(self, tensor, axis): ...
+
+    @abstractmethod
+    def _dequantize(self, q_tensor): ...
+
+    def get_seq_length(self) -> int:
+        """Returns the sequence length of the cached states."""
+        return self.cumulative_length
+
+
+class QuantoQuantizedLayer(QuantizedLayer):
+    def __init__(
+        self,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        super().__init__(
+            nbits=nbits,
+            axis_key=axis_key,
+            axis_value=axis_value,
+            q_group_size=q_group_size,
+            residual_length=residual_length,
+        )
+
+        # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
+        if is_quanto_greater("0.2.5", accept_dev=True):
+            from optimum.quanto import MaxOptimizer, qint2, qint4
+        else:
+            raise ImportError(
+                "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
+            )
+
+        if self.nbits not in [2, 4]:
+            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
+
+        if self.axis_key not in [0, -1]:
+            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
+
+        if self.axis_value not in [0, -1]:
+            raise ValueError(
+                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
+            )
+
+        self.qtype = qint4 if self.nbits == 4 else qint2
+        self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization
+
+    def _quantize(self, tensor, axis):
+        from optimum.quanto import quantize_weight
+
+        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
+        qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
+        return qtensor
+
+    def _dequantize(self, qtensor):
+        return qtensor.dequantize()
+
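+# A minimal sketch (illustrative only, requires `optimum-quanto >= 0.2.5`): a quantized layer keeps the most recent
+# states in full precision and folds them into the quantized storage once `residual_length` is exceeded; `update`
+# always returns the dequantized past, the full-precision residual and the new states concatenated together:
+#
+#     layer = QuantoQuantizedLayer(nbits=4, residual_length=128)
+#     prompt_k = prompt_v = torch.randn(1, 2, 32, 64)
+#     k, v = layer.update(prompt_k, prompt_v)      # first call quantizes the prompt and returns it as-is
+#     step_k = step_v = torch.randn(1, 2, 1, 64)
+#     k, v = layer.update(step_k, step_v)          # later calls return dequantized past + residual + new states
+#     assert k.shape[-2] == 33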
+
+class HQQQuantizedLayer(QuantizedLayer):
+    def __init__(
+        self,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        super().__init__(
+            nbits=nbits,
+            axis_key=axis_key,
+            axis_value=axis_value,
+            q_group_size=q_group_size,
+            residual_length=residual_length,
+        )
+
+        if not is_hqq_available():
+            raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`")
+
+        if self.nbits not in [1, 2, 3, 4, 8]:
+            raise ValueError(
+                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
+            )
+
+        if self.axis_key not in [0, 1]:
+            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
+
+        if self.axis_value not in [0, 1]:
+            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
+
+        self.quantizer = HQQQuantizer
+
+    def _quantize(self, tensor, axis):
+        qtensor, meta = self.quantizer.quantize(
+            tensor,
+            axis=axis,
+            device=self.keys.device,
+            compute_dtype=self.keys.dtype,
+            nbits=self.nbits,
+            group_size=self.q_group_size,
+        )
+        meta["compute_dtype"] = self.keys.dtype
+        self.quantizer.cuda(qtensor, meta=meta, device=self.keys.device)  # Move to device and cast to dtype
+        meta["scale"] = meta["scale"].to(qtensor.device)
+        meta["zero"] = meta["zero"].to(qtensor.device)
+        return qtensor, meta
+
+    def _dequantize(self, qtensor):
+        quant_tensor, meta = qtensor
+        tensor = self.quantizer.dequantize(quant_tensor, meta)
+        return tensor
+
+
+class Cache:
+    """
+    A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
+    the Cache of each layer.
+
+    Args:
+        layers (`list[CacheLayerMixin]`, *optional*):
+            A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
+            be used.
+        layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
+            Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
+            and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
+            list of layers.
+        offloading (`bool`, *optional*, defaults to `False`):
+            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
+            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+    """
+
+    def __init__(
+        self,
+        layers: Optional[list[CacheLayerMixin]] = None,
+        layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
+        offloading: bool = False,
+        offload_only_non_sliding: bool = True,
+    ):
+        if layers is not None and layer_class_to_replicate is not None:
+            raise ValueError(
+                "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
+                "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
+                "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
+            )
+        if layers is None and layer_class_to_replicate is None:
+            raise ValueError(
+                "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
+            )
+        self.layers = layers if layers is not None else []
+        self.layer_class_to_replicate = layer_class_to_replicate
+        self.offloading = offloading
+        if self.offloading:
+            self.only_non_sliding = offload_only_non_sliding
+            self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(layers={self.layers})"
+
+    def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
+        """
+        Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
+        which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
+        Note that we use a non-default stream for this, to avoid blocking.
+        """
+        if only_non_sliding:
+            # Try to find next non-sliding, starting at `layer_idx`
+            try:
+                layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
+            # In this case, we need to circle back to the beginning
+            except ValueError:
+                layer_idx = self.is_sliding.index(False)
+        else:
+            layer_idx = layer_idx if layer_idx < len(self.layers) else 0
+
+        # Prefetch
+        with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
+            self.layers[layer_idx].prefetch()
+
+    def offload(self, layer_idx: int, only_non_sliding: bool = True):
+        """
+        Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
+        non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
+        computation in the layer's `update` methods are finished.
+        """
+        if not (only_non_sliding and self.is_sliding[layer_idx]):
+            self.layers[layer_idx].offload()
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[dict[str, Any]] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`dict[str, Any]`, *optional*):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # In this case, the `layers` were not provided, so we must append new layers up to `layer_idx`
+        if self.layer_class_to_replicate is not None:
+            while len(self.layers) <= layer_idx:
+                self.layers.append(self.layer_class_to_replicate())
+
+        if self.offloading:
+            # Wait for the stream to finish if needed, and start prefetching the next layer
+            torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
+            self.prefetch(layer_idx + 1, self.only_non_sliding)
+
+        keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
+
+        if self.offloading:
+            self.offload(layer_idx, self.only_non_sliding)
+
+        return keys, values
+
+    def early_initialization(
+        self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
+    ):
+        """
+        Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
+        This is useful for our `export` recipes, as `export` needs everything in advance.
+        """
+        # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
+        # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
+        # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
+        fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
+        # Init all layers
+        for layer in self.layers:
+            layer.lazy_initialization(fake_keys_tensor)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cache for the given layer."""
+        if layer_idx >= len(self.layers):
+            return 0
+        return self.layers[layer_idx].get_seq_length()
+
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        """
+        Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
+        the given layer at `layer_idx`.
+        The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
+        """
+        # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
+        # simply the shape of `cache_position`
+        if layer_idx >= len(self.layers):
+            return cache_position.shape[0], 0
+        return self.layers[layer_idx].get_mask_sizes(cache_position)
+
+    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
+        """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
+        # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
+        # as DynamicLayer does
+        if layer_idx >= len(self.layers):
+            return -1
+        return self.layers[layer_idx].get_max_cache_shape()
+
+    def reset(self):
+        """Recursively reset all layers tensors"""
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].reset()
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorder the cache for beam search"""
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].reorder_cache(beam_idx)
+
+    def crop(self, max_length: int):
+        """Crop the cache to the given length"""
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].crop(max_length)
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat and interleave the cache"""
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Select indices from the cache"""
+        for layer_idx in range(len(self.layers)):
+            self.layers[layer_idx].batch_select_indices(indices)
+
+    @property
+    def max_batch_size(self) -> int:
+        """Return the maximum batch size of the cache"""
+        values = [layer.max_batch_size for layer in self.layers]
+        if len(set(values)) > 1:
+            raise ValueError(f"Max batch size is not consistent across layers: {values}")
+        return values[0]
+
+    @property
+    def max_cache_len(self) -> int:
+        """Return the maximum cache length of the cache"""
+        values = [layer.max_cache_len for layer in self.layers]
+        return max(values)
+
+    @property
+    def is_compileable(self) -> bool:
+        """Return whether the cache is compileable"""
+        # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
+        if len(self.layers) == 0:
+            return False
+        return all(layer.is_compileable for layer in self.layers)
+
+    @property
+    def is_sliding(self) -> list[bool]:
+        """Return whether the layers of the cache are sliding window"""
+        return [getattr(layer, "is_sliding", False) for layer in self.layers]
+
+    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self.layers):
+            return self.layers[layer_idx].keys, self.layers[layer_idx].values
+        # elif len(self.layers) == 0:
+        #     return None, None
+        else:
+            raise KeyError(
+                f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
+            )
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
+
+    def __len__(self):
+        """
+        This value corresponds to the number of layers in the model.
+        """
+        # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
+        # forward through all the layers
+        return len(self.layers)
+
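+# A minimal sketch (illustrative only) of the two ways to build a `Cache` container: from a fixed list of layers, or
+# from a layer class that is replicated lazily as higher `layer_idx` values are updated. `num_layers` below is a
+# placeholder for the model's number of decoder layers:
+#
+#     static_cache = Cache(layers=[StaticLayer(max_cache_len=32) for _ in range(num_layers)])
+#     lazy_cache = Cache(layer_class_to_replicate=DynamicLayer)
+#     k = v = torch.randn(1, 2, 5, 8)
+#     lazy_cache.update(k, v, layer_idx=2)         # layers 0, 1 and 2 are created on the fly
+#     assert len(lazy_cache) == 3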
+
+class DynamicCache(Cache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+    It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
+    in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
+    If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
+    memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+
+    See `Cache` for details on common methods that are implemented by all cache classes.
+
+    Args:
+        ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+            It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
+            `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
+            for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
+            Note: it also needs to be the first positional argument to work correctly.
+        config (`PretrainedConfig`, *optional*):
+            The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
+            or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
+            `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
+        offloading (`bool`, *optional*, defaults to `False`):
+            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+        offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
+            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
+
+    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+
+    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+
+    >>> # Prepare a cache class and pass it to model's forward
+    >>> past_key_values = DynamicCache(config=model.config)
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    >>> outputs.past_key_values # access cache filled with key/values from generation
+    ```
+    """
+
+    def __init__(
+        self,
+        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
+        config: Optional[PretrainedConfig] = None,
+        offloading: bool = False,
+        offload_only_non_sliding: bool = False,
+    ):
+        layers = []
+        # If a config is passed, use it to infer the layer types and initialize accordingly
+        if config is not None:
+            config = config.get_text_config()
+            sliding_window = getattr(config, "sliding_window", None) or getattr(config, "attention_chunk_size", None)
+            layer_types = getattr(config, "layer_types", None)
+            if layer_types is None:
+                layer_types = [
+                    "sliding_attention" if sliding_window is not None else "full_attention"
+                    for _ in range(config.num_hidden_layers)
+                ]
+            # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+            if hasattr(config, "num_kv_shared_layers"):
+                layer_types = layer_types[: -config.num_kv_shared_layers]
+
+            for layer_type in layer_types:
+                if layer_type in ("sliding_attention", "chunked_attention"):
+                    layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
+                else:
+                    layers.append(DynamicLayer())
+
+        # In this case, use the passed data to already fill in the Cache
+        if ddp_cache_data is not None:
+            # Init all the layers with the data
+            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
+                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
+                if config is None:
+                    layers.append(DynamicLayer())
+                # Update the layer with the data
+                _, _ = layers[layer_idx].update(key_states, value_states)
+
+        # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
+        if len(layers) == 0:
+            super().__init__(
+                layer_class_to_replicate=DynamicLayer,
+                offloading=offloading,
+                offload_only_non_sliding=offload_only_non_sliding,
+            )
+        else:
+            super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
+
+    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility.
+        """
+        legacy_cache = ()
+        for layer in self.layers:
+            legacy_cache += ((layer.keys, layer.values),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
+        """
+        Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
+        backward compatibility.
+        """
+        cache = cls()
+        if past_key_values is None:
+            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
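+# A minimal sketch (illustrative only) of the legacy-format round trip supported by `DynamicCache`:
+#
+#     cache = DynamicCache()
+#     cache.update(torch.randn(1, 2, 5, 8), torch.randn(1, 2, 5, 8), layer_idx=0)
+#     legacy = cache.to_legacy_cache()             # tuple of (keys, values) tuples, one per layer
+#     restored = DynamicCache.from_legacy_cache(legacy)
+#     assert restored.get_seq_length() == cache.get_seq_length()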
+
+class StaticCache(Cache):
+    """
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
+    for potential hybrid cache structure, and initialize each layer accordingly.
+
+    See `Cache` for details on common methods that are implemented by all cache classes.
+
+    Args:
+        config (`PretrainedConfig`):
+            The config of the model for which this Cache will be used. It will be used to check for sliding
+            or hybrid layer structure, and initialize each layer accordingly.
+        max_cache_len (`int`):
+            The maximum number of tokens that this Cache should hold.
+        offloading (`bool`, *optional*, defaults to `False`):
+            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
+        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
+            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
+            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
+
+    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+    >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
+
+    >>> # Prepare a cache class and pass it to model's forward
+    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
+    >>> max_generated_length = inputs.input_ids.shape[1] + 10
+    >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    >>> outputs.past_key_values # access cache filled with key/values from generation
+    StaticCache()
+    ```
+    """
+
+    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_cache_len: int,
+        offloading: bool = False,
+        offload_only_non_sliding: bool = True,
+        **kwargs,
+    ):
+        config = config.get_text_config()
+        layer_types = getattr(config, "layer_types", None)
+        # If `layer_types` is not explicitly provided, infer if the model is fully sliding
+        if layer_types is None:
+            if getattr(config, "sliding_window", None) is not None:
+                layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
+            elif getattr(config, "attention_chunk_size", None) is not None:
+                layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
+            else:
+                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
+        # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
+        if hasattr(config, "num_kv_shared_layers"):
+            layer_types = layer_types[: -config.num_kv_shared_layers]
+
+        layers = []
+        for layer_type in layer_types:
+            if layer_type == "sliding_attention":
+                layer = SlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
+            elif layer_type == "chunked_attention":
+                layer = ChunkedSlidingLayer(max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size)
+            else:
+                layer = StaticLayer(max_cache_len=max_cache_len)
+            layers.append(layer)
+
+        super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
+
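+# Note (illustrative only): the layer class for each position is chosen from `config.layer_types`. For a hypothetical
+# config with `layer_types = ["full_attention", "sliding_attention"]` and `sliding_window = 1024`,
+# `StaticCache(config, max_cache_len=4096)` would build one `StaticLayer(max_cache_len=4096)` and one
+# `SlidingWindowLayer(max_cache_len=4096, sliding_window=1024)`, as reflected by:
+#
+#     cache = StaticCache(config=config, max_cache_len=4096)
+#     assert cache.is_sliding == [False, True]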
+
+class QuantizedCache(Cache):
+    """
+    A quantizer cache similar to what is described in the
+    [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
+    It allows the model to generate longer sequence length without allocating too much memory for keys and values
+    by applying quantization.
+    The cache has two types of storage, one for original precision and one for the
+    quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
+    length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
+    The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
+    described in the paper.
+
+    See `Cache` for details on common methods that are implemented by all cache classes.
+
+    Args:
+        backend (`str`):
+            The quantization backend to use. One of `("quanto", "hqq")`.
+        config (`PretrainedConfig`):
+            The config of the model for which this Cache will be used.
+        nbits (`int`, *optional*, defaults to 4):
+            The number of bits for quantization.
+        axis_key (`int`, *optional*, defaults to 0):
+            The axis on which to quantize the keys.
+        axis_value (`int`, *optional*, defaults to 0):
+            The axis on which to quantize the values.
+        q_group_size (`int`, *optional*, defaults to 64):
+            Quantization is done per-channel according to a set `q_group_size` for both keys and values.
+        residual_length (`int`, *optional*, defaults to 128):
+            Maximum capacity for the original precision cache
+    """
+
+    def __init__(
+        self,
+        backend: str,
+        config: PretrainedConfig,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        if backend == "quanto":
+            layer_class = QuantoQuantizedLayer
+        elif backend == "hqq":
+            layer_class = HQQQuantizedLayer
+        else:
+            raise ValueError(f"Unknown quantization backend `{backend}`")
+
+        config = config.get_text_config(decoder=True)
+        layers = [
+            layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
+            for _ in range(config.num_hidden_layers)
+        ]
+        super().__init__(layers=layers)
+
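+# A minimal usage sketch (illustrative only, requires the corresponding backend package to be installed; `model` and
+# `inputs` are placeholders for any decoder-only model and its tokenized inputs):
+#
+#     past_key_values = QuantizedCache(backend="quanto", config=model.config, nbits=4, residual_length=128)
+#     outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)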
+
+class EncoderDecoderCache(Cache):
+    """
+    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+    cross-attention caches.
+
+    See `Cache` for details on common methods that are implemented by all cache classes.
+
+    Args:
+        caches (`Iterable`):
+            Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
+            second one for cross-attention. Can optionally also be an iterable of length 1, containing a
+            `tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
+
+    >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
+    >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
+
+    >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
+
+    >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
+    >>> self_attention_cache = DynamicCache(config=model.config)
+    >>> cross_attention_cache = DynamicCache(config=model.config)
+    >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
+    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
+    >>> outputs.past_key_values # access cache filled with key/values from generation
+    EncoderDecoderCache()
+    ```
+    """
+
+    def __init__(self, *caches) -> None:
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        if len(caches) == 1:
+            self.self_attention_cache = DynamicCache()
+            self.cross_attention_cache = DynamicCache()
+            # Populate cache from the iterable
+            for layer_idx, key_value_states in enumerate(caches[0]):
+                key_states, value_states = key_value_states[:2]
+                self.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
+        elif len(caches) == 2:
+            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
+                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
+            self.self_attention_cache = caches[0]
+            self.cross_attention_cache = caches[1]
+        # Error case
+        else:
+            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
+
+        self.is_updated = {}
+        for layer_idx in range(len(self.cross_attention_cache)):
+            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+    def __repr__(self) -> str:
+        return (
+            f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
+            f"{self.cross_attention_cache})"
+        )
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (
+                self.self_attention_cache.layers[layer_idx].keys,
+                self.self_attention_cache.layers[layer_idx].values,
+                self.cross_attention_cache.layers[layer_idx].keys,
+                self.cross_attention_cache.layers[layer_idx].values,
+            )
+
+    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (
+                self.self_attention_cache.layers[layer_idx].keys,
+                self.self_attention_cache.layers[layer_idx].values,
+                self.cross_attention_cache.layers[layer_idx].keys,
+                self.cross_attention_cache.layers[layer_idx].values,
+            )
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.self_attention_cache)
+
+    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
+        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        if len(self.cross_attention_cache) > 0:
+            for self_attn, cross_attn in zip(
+                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
+            ):
+                legacy_cache += (self_attn + cross_attn,)
+        else:
+            legacy_cache = self.self_attention_cache.to_legacy_cache()
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
+    ) -> "EncoderDecoderCache":
+        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
+        cache = cls(DynamicCache(), DynamicCache())
+        if past_key_values is None:
+            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
+        else:
+            for layer_idx, key_value_states in enumerate(past_key_values):
+                key_states, value_states = key_value_states[:2]
+                cache.self_attention_cache.update(key_states, value_states, layer_idx)
+                if len(key_value_states) > 2:
+                    key_states, value_states = key_value_states[2:]
+                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
+                    cache.is_updated[layer_idx] = True
+        return cache
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        return self.self_attention_cache.get_seq_length(layer_idx)
+
+    def reset(self):
+        self.self_attention_cache.reset()
+        self.cross_attention_cache.reset()
+        for layer_idx in self.is_updated:
+            self.is_updated[layer_idx] = False
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        self.self_attention_cache.reorder_cache(beam_idx)
+        self.cross_attention_cache.reorder_cache(beam_idx)
+
+    def check_dynamic_cache(self, method: str):
+        if not (
+            isinstance(self.self_attention_cache, DynamicCache)
+            and isinstance(self.cross_attention_cache, DynamicCache)
+        ):
+            raise ValueError(
+                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
+                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
+            )
+
+    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
+    def crop(self, maximum_length: int):
+        """
+        Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
+        """
+        self.check_dynamic_cache(self.crop.__name__)
+        self.self_attention_cache.crop(maximum_length)
+
+    def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
+        """
+        Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`
+        """
+        self.check_dynamic_cache(self.batch_split.__name__)
+        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
+        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
+
+        out = []
+        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
+            out.append(EncoderDecoderCache(self_attn, cross_attn))
+        return out
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
+        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
+        self.self_attention_cache.batch_repeat_interleave(repeats)
+        self.cross_attention_cache.batch_repeat_interleave(repeats)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
+        self.check_dynamic_cache(self.batch_select_indices.__name__)
+        self.self_attention_cache.batch_select_indices(indices)
+        self.cross_attention_cache.batch_select_indices(indices)
+
+    def get_max_cache_shape(self) -> int:
+        """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
+        return self.self_attention_cache.get_max_cache_shape()
+
+    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
+        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
+
+    @property
+    def is_sliding(self):
+        return self.self_attention_cache.is_sliding
+
+    @property
+    def is_compileable(self) -> bool:
+        return self.self_attention_cache.is_compileable
+
+
+### Deprecated classes
+
+
+class OffloadedCache(DynamicCache):
+    def __init__(self) -> None:
+        logger.warning_once(
+            "`OffloadedCache` is deprecated and will be removed in version v4.59 "
+            "Use `DynamicCache(offloading=True)` instead"
+        )
+        super().__init__(offloading=True)
+
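+# Migration sketch for the deprecated wrappers in this section (illustrative placeholder arguments;
+# the replacements are the ones named in each deprecation warning):
+#     OffloadedCache()                  ->  DynamicCache(offloading=True)
+#     OffloadedStaticCache(cfg, n)      ->  StaticCache(config=cfg, max_cache_len=n, offloading=True)
+#     QuantoQuantizedCache(cfg, ...)    ->  QuantizedCache(backend="quanto", ...)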
+
+class OffloadedStaticCache(StaticCache):
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+        logger.warning_once(
+            "`OffloadedStaticCache` is deprecated and will be removed in version v4.59 "
+            "Use `StaticCache(..., offloading=True)` instead"
+        )
+        super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
+
+
+class SlidingWindowCache(StaticCache):
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+        logger.warning_once(
+            "`SlidingWindowCache` is deprecated and will be removed in version v4.59 "
+            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+        )
+        super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class HybridCache(StaticCache):
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+        logger.warning_once(
+            "`HybridCache` is deprecated and will be removed in version v4.59 "
+            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+        )
+        super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class HybridChunkedCache(StaticCache):
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+        logger.warning_once(
+            "`HybridChunkedCache` is deprecated and will be removed in version v4.59 "
+            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
+        )
+        super().__init__(config=config, max_cache_len=max_cache_len)
+
+
+class OffloadedHybridCache(StaticCache):
+    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
+        logger.warning_once(
+            "`OffloadedHybridCache` is deprecated and will be removed in version v4.59 "
+            "Use `StaticCache(..., offload=True)` instead which will correctly infer the type of each layer."
+        )
+        super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
+
+
+class QuantoQuantizedCache(QuantizedCache):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        logger.warning_once(
+            "`QuantoQuantizedCache` is deprecated and will be removed in version v4.59 "
+            "Use `QuantizedCache(backend='quanto', ...)` instead."
+        )
+        super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length)
+
+
+class HQQQuantizedCache(QuantizedCache):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        nbits: int = 4,
+        axis_key: int = 0,
+        axis_value: int = 0,
+        q_group_size: int = 64,
+        residual_length: int = 128,
+    ):
+        logger.warning_once(
+            "`HQQQuantizedCache` is deprecated and will be removed in version v4.59 "
+            "Use `QuantizedCache(backend='hqq', ...)` instead."
+        )
+        super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length)
+
+
+class SinkCache(Cache):
+    """
+    `SinkCache` is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
+    See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
+    general `custom_generate` usage.
+    """
+
+    # TODO (joao, manuel): Remove this class in v4.59.0
+    def __init__(self, **kwargs) -> None:
+        raise NotImplementedError(
+            "`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
+            "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
+        )
diff --git a/phivenv/Lib/site-packages/transformers/commands/__init__.py b/phivenv/Lib/site-packages/transformers/commands/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa5d95a85b538171ec9cf4fa16e892df1efdef6b
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/commands/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseTransformersCLICommand(ABC):
+    @staticmethod
+    @abstractmethod
+    def register_subcommand(parser: ArgumentParser):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def run(self):
+        raise NotImplementedError()
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d85ed801da9025f3a6098785c56444841edf21c6
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/__init__.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..62907e4c560293853bcbbf815e60775b28f7d76e
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_fast_image_processor.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0c9b1e52df424496c4233604485151a1ba797fef
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/add_new_model_like.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/chat.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/chat.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b0b82c2824af01e22217bff096e4192eb5e6b6d
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/chat.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/convert.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/convert.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a1aed1fa965f7248f6ab24e597308e92d51c4744
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/convert.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/download.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/download.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93133f3ffc25ab5ad34b917d669d0874897774fc
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/download.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/env.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/env.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53cc40e21dad38929dfbd78941b084a4ffd7787b
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/env.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/run.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/run.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..056de99488cf54cad3b752bcb9bf6fa87c401a5b
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/run.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/serving.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/serving.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..83aac801a8abdc44c19e264c269684ec98594598
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/serving.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/train.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/train.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..28cd13d358f3cda3310e688be4c338c59ec40cd5
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/train.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-39.pyc b/phivenv/Lib/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..466e3a475983d9b7ab86226a25e828795904f77c
Binary files /dev/null and b/phivenv/Lib/site-packages/transformers/commands/__pycache__/transformers_cli.cpython-39.pyc differ
diff --git a/phivenv/Lib/site-packages/transformers/commands/add_fast_image_processor.py b/phivenv/Lib/site-packages/transformers/commands/add_fast_image_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..90c911525f7773a34ca9f97f180d5b127649f3c2
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/commands/add_fast_image_processor.py
@@ -0,0 +1,530 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import re
+from argparse import ArgumentParser, Namespace
+from datetime import date
+from pathlib import Path
+
+from ..utils import logging
+from . import BaseTransformersCLICommand
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+CURRENT_YEAR = date.today().year
+TRANSFORMERS_PATH = Path(__file__).parent.parent
+REPO_PATH = TRANSFORMERS_PATH.parent.parent
+
+
+def add_fast_image_processor_to_model_init(
+    fast_image_processing_module_file: str, fast_image_processor_name, model_name: str
+):
+    """
+    Add the fast image processor to the __init__.py file of the model.
+    """
+    with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "r", encoding="utf-8") as f:
+        content = f.read()
+
+    fast_image_processing_module_file = fast_image_processing_module_file.split(os.sep)[-1].replace(".py", "")
+
+    if "import *" in content:
+        # we have an init file in the updated format
+        # get the indented block after if TYPE_CHECKING: and before else:, append the new import, sort the imports and write the updated content
+        # Step 1: Find the block
+        block_regex = re.compile(
+            r"if TYPE_CHECKING:\n(?P.*?)(?=\s*else:)",
+            re.DOTALL,
+        )
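+        # The block being matched looks roughly like this in new-style model __init__.py files
+        # (illustrative module names):
+        #     if TYPE_CHECKING:
+        #         from .configuration_dummy import *
+        #         from .image_processing_dummy import *
+        #     else:
+        #         ...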
+        match = block_regex.search(content)
+
+        if not match:
+            raise ValueError("Couldn't find the 'if TYPE_CHECKING' block.")
+
+        block_content = match.group("if_block")  # The captured import block
+
+        # Step 2: Parse existing entries
+        entries = block_content.split("\n")
+        indent = " " * (len(entries[0]) - len(entries[0].lstrip()))
+        new_entry = f"{indent}from .{fast_image_processing_module_file} import *"
+        if new_entry not in entries:
+            entries.append(new_entry)
+        entries.sort()
+        updated_block = "\n".join(entry for entry in entries)
+
+        # Replace the original block in the content
+        updated_content = content[: match.start("if_block")] + updated_block + content[match.end("if_block") :]
+    else:
+        # we have an init file in the old format
+
+        # add "is_torchvision_available" import to from ...utils import (
+        # Regex to match import statements from transformers.utils
+        pattern = r"""
+            from\s+\.\.\.utils\s+import\s+
+            (?:                                   # Non-capturing group for either:
+                ([\w, ]+)                         # 1. Single-line imports (e.g., 'a, b')
+                |                                 # OR
+                \((.*?)\)                         # 2. Multi-line imports (e.g., '(a, ... b)')
+            )
+        """
+        regex = re.compile(pattern, re.VERBOSE | re.DOTALL)
+
+        def replacement_function(match):
+            # Extract existing imports
+            imports = (match.group(1) or match.group(2)).split(",")
+            imports = imports[:-1] if imports[-1] == "\n" else imports
+            imports = [imp.strip() for imp in imports]
+
+            # Add the new import if not already present
+            if "is_torchvision_available" not in imports:
+                imports.append("is_torchvision_available")
+                imports.sort()
+
+            # Convert to multi-line import in all cases
+            updated_imports = "(\n    " + ",\n    ".join(imports) + ",\n)"
+
+            return f"from ...utils import {updated_imports}"
+
+        # Replace all matches in the file content
+        updated_content = regex.sub(replacement_function, content)
+
+        vision_import_structure_block = f'    _import_structure["{fast_image_processing_module_file[:-5]}"] = ["{fast_image_processor_name[:-4]}"]\n'
+
+        added_import_structure_block = (
+            "try:\n    if not is_torchvision_available():\n"
+            "        raise OptionalDependencyNotAvailable()\n"
+            "except OptionalDependencyNotAvailable:\n"
+            "    pass\n"
+            "else:\n"
+            f'    _import_structure["{fast_image_processing_module_file}"] = ["{fast_image_processor_name}"]\n'
+        )
+
+        if vision_import_structure_block not in updated_content:
+            raise ValueError("Couldn't find the 'vision _import_structure block' block.")
+
+        if added_import_structure_block not in updated_content:
+            updated_content = updated_content.replace(
+                vision_import_structure_block, vision_import_structure_block + "\n" + added_import_structure_block
+            )
+
+        vision_import_statement_block = (
+            f"        from .{fast_image_processing_module_file[:-5]} import {fast_image_processor_name[:-4]}\n"
+        )
+
+        added_import_statement_block = (
+            "    try:\n        if not is_torchvision_available():\n"
+            "            raise OptionalDependencyNotAvailable()\n"
+            "    except OptionalDependencyNotAvailable:\n"
+            "        pass\n"
+            "    else:\n"
+            f"        from .{fast_image_processing_module_file} import {fast_image_processor_name}\n"
+        )
+
+        if vision_import_statement_block not in updated_content:
+            raise ValueError("Couldn't find the 'vision _import_structure block' block.")
+
+        if added_import_statement_block not in updated_content:
+            updated_content = updated_content.replace(
+                vision_import_statement_block, vision_import_statement_block + "\n" + added_import_statement_block
+            )
+
+    # write the updated content
+    with open(TRANSFORMERS_PATH / "models" / model_name / "__init__.py", "w", encoding="utf-8") as f:
+        f.write(updated_content)
+
+
+def add_fast_image_processor_to_auto(image_processor_name: str, fast_image_processor_name: str):
+    """
+    Add the fast image processor to the auto module.
+    """
+    with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # get all lines containing the image processor name
+    updated_content = content.replace(
+        f'("{image_processor_name}",)', f'("{image_processor_name}", "{fast_image_processor_name}")'
+    )
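+    # e.g. (illustrative entry): in IMAGE_PROCESSOR_MAPPING_NAMES,
+    #     ("beit", ("BeitImageProcessor",))  becomes  ("beit", ("BeitImageProcessor", "BeitImageProcessorFast"))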
+
+    # write the updated content
+    with open(TRANSFORMERS_PATH / "models" / "auto" / "image_processing_auto.py", "w", encoding="utf-8") as f:
+        f.write(updated_content)
+
+
+def add_fast_image_processor_to_doc(fast_image_processor_name: str, model_name: str):
+    """
+    Add the fast image processor to the model's doc file.
+    """
+    doc_source = REPO_PATH / "docs" / "source"
+    # find the doc files
+    doc_files = list(doc_source.glob(f"*/model_doc/{model_name}.md"))
+    if not doc_files:
+        # try again with "-"
+        doc_files = list(doc_source.glob(f"*/model_doc/{model_name.replace('_', '-')}.md"))
+    if not doc_files:
+        raise ValueError(f"No doc files found for {model_name}")
+
+    base_doc_string = (
+        f"## {fast_image_processor_name[:-4]}\n\n[[autodoc]] {fast_image_processor_name[:-4]}\n    - preprocess"
+    )
+    fast_doc_string = f"## {fast_image_processor_name}\n\n[[autodoc]] {fast_image_processor_name}\n    - preprocess"
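+    # The inserted doc section ends up as (illustrative, for a hypothetical `DummyImageProcessorFast`):
+    #     ## DummyImageProcessorFast
+    #
+    #     [[autodoc]] DummyImageProcessorFast
+    #         - preprocess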
+
+    for doc_file in doc_files:
+        with open(doc_file, "r", encoding="utf-8") as f:
+            content = f.read()
+
+        if fast_doc_string not in content:
+            # add the fast image processor to the doc
+            updated_content = content.replace(
+                base_doc_string,
+                base_doc_string + "\n\n" + fast_doc_string,
+            )
+
+            # write the updated content
+            with open(doc_file, "w", encoding="utf-8") as f:
+                f.write(updated_content)
+
+
+def add_fast_image_processor_to_tests(fast_image_processor_name: str, model_name: str):
+    """
+    Add the fast image processor to the image processing tests.
+    """
+    tests_path = REPO_PATH / "tests" / "models" / model_name
+    test_file = tests_path / f"test_image_processing_{model_name}.py"
+    if not os.path.exists(test_file):
+        logger.warning(f"No test file found for {model_name}. Skipping.")
+        return
+
+    with open(test_file, "r", encoding="utf-8") as f:
+        content = f.read()
+
+    # add is_torchvision_available import to the imports
+    # Regex to match import statements from transformers.utils
+    pattern = r"""
+        from\s+transformers\.utils\s+import\s+
+        (?:                                   # Non-capturing group for either:
+            ([\w, ]+)                         # 1. Single-line imports (e.g., 'a, b')
+            |                                 # OR
+            \((.*?)\)                         # 2. Multi-line imports (e.g., '(a, ... b)')
+        )
+    """
+    regex = re.compile(pattern, re.VERBOSE | re.DOTALL)
+
+    def replacement_function(match):
+        # Extract existing imports
+        existing_imports = (match.group(1) or match.group(2)).split(",")
+        existing_imports = existing_imports[:-1] if existing_imports[-1] == "\n" else existing_imports
+        existing_imports = [imp.strip() for imp in existing_imports]
+
+        # Add the new import if not already present
+        if "is_torchvision_available" not in existing_imports:
+            existing_imports.append("is_torchvision_available")
+            existing_imports.sort()
+
+        # Rebuild the import statement
+        if match.group(1):  # Single-line import
+            updated_imports = ", ".join(existing_imports)
+        else:  # Multi-line import
+            updated_imports = "(\n    " + ",\n    ".join(existing_imports) + ",\n)"
+
+        return f"from transformers.utils import {updated_imports}"
+
+    # Replace all matches in the file content
+    updated_content = regex.sub(replacement_function, content)
+
+    # add the fast image processor to the imports
+    base_import_string = f"    from transformers import {fast_image_processor_name[:-4]}"
+    fast_import_string = (
+        f"    if is_torchvision_available():\n        from transformers import {fast_image_processor_name}"
+    )
+    if fast_import_string not in updated_content:
+        updated_content = updated_content.replace(base_import_string, base_import_string + "\n\n" + fast_import_string)
+
+    # get line starting with "    image_processing_class = " and add a line after it starting with "    fast_image_processing_class = "
+    image_processing_class_line = re.search(r"    image_processing_class = .*", updated_content)
+    if not image_processing_class_line:
+        logger.warning(f"Couldn't find the 'image_processing_class' line in {test_file}. Skipping.")
+        return
+
+    fast_image_processing_class_line = (
+        f"    fast_image_processing_class = {fast_image_processor_name} if is_torchvision_available() else None"
+    )
+    if "    fast_image_processing_class = " not in updated_content:
+        updated_content = updated_content.replace(
+            image_processing_class_line.group(0),
+            image_processing_class_line.group(0) + "\n" + fast_image_processing_class_line,
+        )
+
+    # write the updated content
+    with open(test_file, "w", encoding="utf-8") as f:
+        f.write(updated_content)
+
+
+def get_fast_image_processing_content_header(content: str) -> str:
+    """
+    Get the header of the slow image processor file.
+    """
+    # get all the commented lines at the beginning of the file
+    content_header = re.search(r"^# coding=utf-8\n(#[^\n]*\n)*", content, re.MULTILINE)
+    if not content_header:
+        logger.warning("Couldn't find the content header in the slow image processor file. Using a default header.")
+        return (
+            f"# coding=utf-8\n"
+            f"# Copyright {CURRENT_YEAR} The HuggingFace Team. All rights reserved.\n"
+            f"#\n"
+            f'# Licensed under the Apache License, Version 2.0 (the "License");\n'
+            f"# you may not use this file except in compliance with the License.\n"
+            f"# You may obtain a copy of the License at\n"
+            f"#\n"
+            f"#     http://www.apache.org/licenses/LICENSE-2.0\n"
+            f"#\n"
+            f"# Unless required by applicable law or agreed to in writing, software\n"
+            f'# distributed under the License is distributed on an "AS IS" BASIS,\n'
+            f"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+            f"# See the License for the specific language governing permissions and\n"
+            f"# limitations under the License.\n"
+            f"\n"
+        )
+    content_header = content_header.group(0)
+    # replace the year in the copyright
+    content_header = re.sub(r"# Copyright (\d+)\s", f"# Copyright {CURRENT_YEAR} ", content_header)
+    # get the line starting with """Image processor in content if it exists
+    match = re.search(r'^"""Image processor.*$', content, re.MULTILINE)
+    if match:
+        content_header += match.group(0).replace("Image processor", "Fast Image processor")
+
+    return content_header
+
+
+def write_default_fast_image_processor_file(
+    fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
+):
+    """
+    Write a default fast image processor file. Used when encountering a problem while parsing the slow image processor file.
+    """
+    imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n\n\n"
+    content_header = get_fast_image_processing_content_header(content_base_file)
+    content_base_file = (
+        f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
+        "    # To be implemented\n"
+        "    resample = None\n"
+        "    image_mean = None\n"
+        "    image_std = None\n"
+        "    size = None\n"
+        "    default_to_square = None\n"
+        "    crop_size = None\n"
+        "    do_resize = None\n"
+        "    do_center_crop = None\n"
+        "    do_rescale = None\n"
+        "    do_normalize = None\n"
+        "    do_convert_rgb = None\n\n\n"
+        f'__all__ = ["{fast_image_processor_name}"]\n'
+    )
+
+    content = content_header + imports + content_base_file
+
+    with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
+        f.write(content)
+
+
+def add_fast_image_processor_file(
+    fast_image_processing_module_file: str, fast_image_processor_name: str, content_base_file: str
+):
+    """
+    Add the fast image processor file to the model's folder.
+    """
+    # if the file already exists, do nothing
+    if os.path.exists(fast_image_processing_module_file):
+        print(f"{fast_image_processing_module_file} already exists. Skipping.")
+        return
+
+    regex = rf"class {fast_image_processor_name[:-4]}.*?(\n\S|$)"
+    match = re.search(regex, content_base_file, re.DOTALL)
+    if not match:
+        print(f"Couldn't find the {fast_image_processor_name[:-4]} class in {fast_image_processing_module_file}")
+        print("Creating a new file with the default content.")
+        return write_default_fast_image_processor_file(
+            fast_image_processing_module_file, fast_image_processor_name, content_base_file
+        )
+    # Exclude the last unindented line
+    slow_class_content = match.group(0).rstrip()
+    # get default args:
+    # find the __init__ block, which starts with `def __init__` and ends at the next `def`
+    match = re.search(r"def __init__.*?def ", slow_class_content, re.DOTALL)
+    if not match:
+        print(
+            f"Couldn't find the __init__ block for {fast_image_processor_name[:-4]} in {fast_image_processing_module_file}"
+        )
+        print("Creating a new file with the default content.")
+        return write_default_fast_image_processor_file(
+            fast_image_processing_module_file, fast_image_processor_name, content_base_file
+        )
+    init = match.group(0)
+    init_signature_block = init.split(")")[0]
+    arg_names = init_signature_block.split(":")
+    arg_names = [arg_name.split("\n")[-1].strip() for arg_name in arg_names]
+    # get the default values
+    default_args = re.findall(r"= (.*?)(?:,|\))", init_signature_block)
+
+    # build default args dict
+    default_args_dict = dict(zip(arg_names, default_args))
+    pattern_default_size = r"size = size if size is not None else\s+(.*)"
+    match_default_size = re.findall(pattern_default_size, init)
+    default_args_dict["size"] = match_default_size[0] if match_default_size else None
+    pattern_default_crop_size = r"crop_size = crop_size if crop_size is not None else\s+(.*)"
+    match_default_crop_size = re.findall(pattern_default_crop_size, init)
+    default_args_dict["crop_size"] = match_default_crop_size[0] if match_default_crop_size else None
+    pattern_default_image_mean = r"self.image_mean = image_mean if image_mean is not None else\s+(.*)"
+    match_default_image_mean = re.findall(pattern_default_image_mean, init)
+    default_args_dict["image_mean"] = match_default_image_mean[0] if match_default_image_mean else None
+    pattern_default_image_std = r"self.image_std = image_std if image_std is not None else\s+(.*)"
+    match_default_image_std = re.findall(pattern_default_image_std, init)
+    default_args_dict["image_std"] = match_default_image_std[0] if match_default_image_std else None
+    default_args_dict["default_to_square"] = False if "(size, default_to_square=False" in init else None
+
+    content_header = get_fast_image_processing_content_header(content_base_file)
+    content_base_file = (
+        f"@auto_docstring\n"
+        f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
+        "    # This generated class can be used as a starting point for the fast image processor.\n"
+        "    # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
+        "    # only the default values should be set in the class.\n"
+        "    # If the image processor requires more complex augmentations, methods from BaseImageProcessorFast can be overridden.\n"
+        "    # In most cases, only the `_preprocess` method should be overridden.\n\n"
+        "    # For an example of a fast image processor requiring more complex augmentations, see `LlavaNextImageProcessorFast`.\n\n"
+        "    # Default values should be checked against the slow image processor\n"
+        "    # None values left after checking can be removed\n"
+        f"    resample = {default_args_dict.get('resample')}\n"
+        f"    image_mean = {default_args_dict.get('image_mean')}\n"
+        f"    image_std = {default_args_dict.get('image_std')}\n"
+        f"    size = {default_args_dict.get('size')}\n"
+        f"    default_to_square = {default_args_dict.get('default_to_square')}\n"
+        f"    crop_size = {default_args_dict.get('crop_size')}\n"
+        f"    do_resize = {default_args_dict.get('do_resize')}\n"
+        f"    do_center_crop = {default_args_dict.get('do_center_crop')}\n"
+        f"    do_rescale = {default_args_dict.get('do_rescale')}\n"
+        f"    do_normalize = {default_args_dict.get('do_normalize')}\n"
+        f"    do_convert_rgb = {default_args_dict.get('do_convert_rgb')}\n\n\n"
+        f'__all__ = ["{fast_image_processor_name}"]\n'
+    )
+
+    imports = "\n\nfrom ...image_processing_utils_fast import BaseImageProcessorFast\n"
+    image_utils_imports = []
+    if default_args_dict.get("resample") is not None and "PILImageResampling" in default_args_dict.get("resample"):
+        image_utils_imports.append("PILImageResampling")
+    if default_args_dict.get("image_mean") is not None and not any(
+        char.isdigit() for char in default_args_dict.get("image_mean")
+    ):
+        image_utils_imports.append(default_args_dict.get("image_mean"))
+    if default_args_dict.get("image_std") is not None and not any(
+        char.isdigit() for char in default_args_dict.get("image_std")
+    ):
+        image_utils_imports.append(default_args_dict.get("image_std"))
+
+    if image_utils_imports:
+        # sort imports
+        image_utils_imports.sort()
+        imports += f"from ...image_utils import {', '.join(image_utils_imports)}\n"
+
+    imports += "from ...utils import auto_docstring\n"
+
+    content = content_header + imports + "\n\n" + content_base_file
+
+    with open(fast_image_processing_module_file, "w", encoding="utf-8") as f:
+        f.write(content)
+
+
+def add_fast_image_processor(model_name: str):
+    """
+    Add the necessary references to the fast image processor in the transformers package,
+    and create the fast image processor file in the model's folder.
+    """
+    model_module = TRANSFORMERS_PATH / "models" / model_name
+    image_processing_module_file = list(model_module.glob("image_processing*.py"))
+    if not image_processing_module_file:
+        raise ValueError(f"No image processing module found in {model_module}")
+    elif len(image_processing_module_file) > 1:
+        for file_name in image_processing_module_file:
+            if not str(file_name).endswith("_fast.py"):
+                image_processing_module_file = str(file_name)
+                break
+    else:
+        image_processing_module_file = str(image_processing_module_file[0])
+
+    with open(image_processing_module_file, "r", encoding="utf-8") as f:
+        content_base_file = f.read()
+
+    # regex to find object starting with "class " and ending with "ImageProcessor", including "ImageProcessor" in the match
+    image_processor_name = re.findall(r"class (\w*ImageProcessor)", content_base_file)
+    if not image_processor_name:
+        raise ValueError(f"No ImageProcessor class found in {image_processing_module_file}")
+    elif len(image_processor_name) > 1:
+        raise ValueError(f"Multiple ImageProcessor classes found in {image_processing_module_file}")
+
+    image_processor_name = image_processor_name[0]
+    fast_image_processor_name = image_processor_name + "Fast"
+    fast_image_processing_module_file = image_processing_module_file.replace(".py", "_fast.py")
+
+    print(f"Adding {fast_image_processor_name} to {fast_image_processing_module_file}")
+
+    add_fast_image_processor_to_model_init(
+        fast_image_processing_module_file=fast_image_processing_module_file,
+        fast_image_processor_name=fast_image_processor_name,
+        model_name=model_name,
+    )
+
+    add_fast_image_processor_to_auto(
+        image_processor_name=image_processor_name,
+        fast_image_processor_name=fast_image_processor_name,
+    )
+
+    add_fast_image_processor_to_doc(
+        fast_image_processor_name=fast_image_processor_name,
+        model_name=model_name,
+    )
+
+    add_fast_image_processor_to_tests(
+        fast_image_processor_name=fast_image_processor_name,
+        model_name=model_name,
+    )
+
+    add_fast_image_processor_file(
+        fast_image_processing_module_file=fast_image_processing_module_file,
+        fast_image_processor_name=fast_image_processor_name,
+        content_base_file=content_base_file,
+    )
+
+
+def add_new_model_like_command_factory(args: Namespace):
+    return AddFastImageProcessorCommand(model_name=args.model_name)
+
+
+class AddFastImageProcessorCommand(BaseTransformersCLICommand):
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        add_fast_image_processor_parser = parser.add_parser("add-fast-image-processor")
+        add_fast_image_processor_parser.add_argument(
+            "--model-name",
+            type=str,
+            required=True,
+            help="The name of the folder containing the model's implementation.",
+        )
+        add_fast_image_processor_parser.set_defaults(func=add_new_model_like_command_factory)
+
+    def __init__(self, model_name: str, *args):
+        self.model_name = model_name
+
+    def run(self):
+        add_fast_image_processor(model_name=self.model_name)
diff --git a/phivenv/Lib/site-packages/transformers/commands/add_new_model_like.py b/phivenv/Lib/site-packages/transformers/commands/add_new_model_like.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90a55b44bf85cfaccf06107d492b202e1a9db45
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/commands/add_new_model_like.py
@@ -0,0 +1,783 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import difflib
+import os
+import re
+import subprocess
+import textwrap
+from argparse import ArgumentParser, Namespace
+from datetime import date
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+
+from ..models.auto.configuration_auto import CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING
+from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING_NAMES
+from ..models.auto.image_processing_auto import IMAGE_PROCESSOR_MAPPING_NAMES
+from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES
+from ..models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES
+from ..models.auto.video_processing_auto import VIDEO_PROCESSOR_MAPPING_NAMES
+from ..utils import is_libcst_available
+from . import BaseTransformersCLICommand
+from .add_fast_image_processor import add_fast_image_processor
+
+
+# We protect this import to avoid requiring it for all `transformers` CLI commands - however it is actually
+# strictly required for this one (we need it both for modular and for the following Visitor)
+if is_libcst_available():
+    import libcst as cst
+    from libcst import CSTVisitor
+    from libcst import matchers as m
+
+    class ClassFinder(CSTVisitor):
+        """
+        A visitor to find all classes in a python module.
+        """
+
+        def __init__(self):
+            self.classes: list = []
+            self.public_classes: list = []
+            self.is_in_class = False
+
+        def visit_ClassDef(self, node: cst.ClassDef) -> None:
+            """Record class names. We assume classes always only appear at top-level (i.e. no class definition in function or similar)"""
+            self.classes.append(node.name.value)
+            self.is_in_class = True
+
+        def leave_ClassDef(self, node: cst.ClassDef):
+            self.is_in_class = False
+
+        def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):
+            """Record all public classes inside the `__all__` assignment."""
+            simple_top_level_assign_structure = m.SimpleStatementLine(
+                body=[m.Assign(targets=[m.AssignTarget(target=m.Name())])]
+            )
+            if not self.is_in_class and m.matches(node, simple_top_level_assign_structure):
+                assigned_variable = node.body[0].targets[0].target.value
+                if assigned_variable == "__all__":
+                    elements = node.body[0].value.elements
+                    self.public_classes = [element.value.value for element in elements]
+
+
+CURRENT_YEAR = date.today().year
+TRANSFORMERS_PATH = Path(__file__).parent.parent
+REPO_PATH = TRANSFORMERS_PATH.parent.parent
+
+COPYRIGHT = f"""
+# coding=utf-8
+# Copyright {CURRENT_YEAR} the HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""".lstrip()
+
+
+class ModelInfos:
+    """
+    Retrieve basic information about an existing model's classes.
+    """
+
+    def __init__(self, lowercase_name: str):
+        # Just to make sure it's indeed lowercase
+        self.lowercase_name = lowercase_name.lower().replace(" ", "_").replace("-", "_")
+        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
+            self.lowercase_name = self.lowercase_name.replace("_", "-")
+        if self.lowercase_name not in CONFIG_MAPPING_NAMES:
+            raise ValueError(f"{lowercase_name} is not a valid model name")
+
+        self.paper_name = MODEL_NAMES_MAPPING[self.lowercase_name]
+        self.config_class = CONFIG_MAPPING_NAMES[self.lowercase_name]
+        self.camelcase_name = self.config_class.replace("Config", "")
+
+        # Get tokenizer class
+        if self.lowercase_name in TOKENIZER_MAPPING_NAMES:
+            self.tokenizer_class, self.fast_tokenizer_class = TOKENIZER_MAPPING_NAMES[self.lowercase_name]
+            self.fast_tokenizer_class = (
+                None if self.fast_tokenizer_class == "PreTrainedTokenizerFast" else self.fast_tokenizer_class
+            )
+        else:
+            self.tokenizer_class, self.fast_tokenizer_class = None, None
+
+        self.image_processor_class, self.fast_image_processor_class = IMAGE_PROCESSOR_MAPPING_NAMES.get(
+            self.lowercase_name, (None, None)
+        )
+        self.video_processor_class = VIDEO_PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
+        self.feature_extractor_class = FEATURE_EXTRACTOR_MAPPING_NAMES.get(self.lowercase_name, None)
+        self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)
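+        # e.g. (illustrative, assuming "bert" is registered in the auto mappings): ModelInfos("bert")
+        # exposes camelcase_name="Bert", config_class="BertConfig", tokenizer_class="BertTokenizer".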
+
+
+def add_content_to_file(file_name: Union[str, os.PathLike], new_content: str, add_after: str):
+    """
+    A utility to add some content inside a given file.
+
+    Args:
+        file_name (`str` or `os.PathLike`):
+            The name of the file in which we want to insert some content.
+        new_content (`str`):
+            The content to add.
+        add_after (`str`):
+            The new content is added just after the first instance matching it.
+    """
+    with open(file_name, "r", encoding="utf-8") as f:
+        old_content = f.read()
+
+    before, after = old_content.split(add_after, 1)
+    new_content = before + add_after + new_content + after
+
+    with open(file_name, "w", encoding="utf-8") as f:
+        f.write(new_content)
+
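+# Usage sketch for `add_content_to_file` (illustrative arguments, mirroring the calls made further
+# down in this module):
+#     add_content_to_file(
+#         TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py",
+#         new_content='        ("new_model", "NewModelConfig"),\n',
+#         add_after="        # Add configs here\n",
+#     )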
+
+def add_model_to_auto_mappings(
+    old_model_infos: ModelInfos,
+    new_lowercase_name: str,
+    new_model_paper_name: str,
+    filenames_to_add: list[tuple[str, bool]],
+):
+    """
+    Add a model to all the relevant mappings in the auto module.
+
+    Args:
+        old_model_infos (`ModelInfos`):
+            The structure containing the class information of the old model.
+        new_lowercase_name (`str`):
+            The new lowercase model name.
+        new_model_paper_name (`str`):
+            The fully cased name (as in the official paper name) of the new model.
+        filenames_to_add (`list[tuple[str, bool]]`):
+            A list of tuples of all potential filenames to add for a new model, along with a boolean flag describing if we
+            should add this file or not. For example, [(`modeling_xxx.py`, True), (`configuration_xxx.py`, True), (`tokenization_xxx.py`, False),...]
+    """
+    new_cased_name = "".join(x.title() for x in new_lowercase_name.replace("-", "_").split("_"))
+    old_lowercase_name = old_model_infos.lowercase_name
+    old_cased_name = old_model_infos.camelcase_name
+    filenames_to_add = [
+        (filename.replace(old_lowercase_name, "auto"), to_add) for filename, to_add in filenames_to_add[1:]
+    ]
+    # fast tokenizer/image processor have the same auto mappings as normal ones
+    corrected_filenames_to_add = []
+    for file, to_add in filenames_to_add:
+        if re.search(r"(?:tokenization)|(?:image_processing)_auto_fast.py", file):
+            previous_file, previous_to_add = corrected_filenames_to_add[-1]
+            corrected_filenames_to_add[-1] = (previous_file, previous_to_add or to_add)
+        else:
+            corrected_filenames_to_add.append((file, to_add))
+
+    # Add the config mappings directly as the handling for config is a bit different
+    add_content_to_file(
+        TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py",
+        new_content=f'        ("{new_lowercase_name}", "{new_cased_name}Config"),\n',
+        add_after="CONFIG_MAPPING_NAMES = OrderedDict[str, str](\n    [\n        # Add configs here\n",
+    )
+    add_content_to_file(
+        TRANSFORMERS_PATH / "models" / "auto" / "configuration_auto.py",
+        new_content=f'        ("{new_lowercase_name}", "{new_model_paper_name}"),\n',
+        add_after="MODEL_NAMES_MAPPING = OrderedDict[str, str](\n    [\n        # Add full (and cased) model names here\n",
+    )
+
+    for filename, to_add in corrected_filenames_to_add:
+        if to_add:
+            # The auto mapping
+            filename = filename.replace("_fast.py", ".py")
+            with open(TRANSFORMERS_PATH / "models" / "auto" / filename) as f:
+                file = f.read()
+            # The regex has to be a bit complex like this as the tokenizer mapping has new lines everywhere
+            matching_lines = re.findall(
+                rf'( {{8,12}}\(\s*"{old_lowercase_name}",.*?\),\n)(?: {{4,12}}\(|\])', file, re.DOTALL
+            )
+            for match in matching_lines:
+                add_content_to_file(
+                    TRANSFORMERS_PATH / "models" / "auto" / filename,
+                    new_content=match.replace(old_lowercase_name, new_lowercase_name).replace(
+                        old_cased_name, new_cased_name
+                    ),
+                    add_after=match,
+                )
+
+
+def create_doc_file(new_paper_name: str, public_classes: list[str]):
+    """
+    Create a new doc file to fill for the new model.
+
+    Args:
+        new_paper_name (`str`):
+            The fully cased name (as in the official paper name) of the new model.
+        public_classes (`list[str]`):
+            A list of all the public classes that the model will have in the library.
+    """
+    added_note = (
+        "\n\n⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that "
+        "may not be rendered properly in your Markdown viewer.\n\n-->\n\n"
+    )
+    copyright_for_markdown = re.sub(r"# ?", "", COPYRIGHT).replace("coding=utf-8\n", "
+"""
+
+AUTOGENERATED_KERAS_COMMENT = """
+
+"""
+
+
+TASK_TAG_TO_NAME_MAPPING = {
+    "fill-mask": "Masked Language Modeling",
+    "image-classification": "Image Classification",
+    "image-segmentation": "Image Segmentation",
+    "multiple-choice": "Multiple Choice",
+    "object-detection": "Object Detection",
+    "question-answering": "Question Answering",
+    "summarization": "Summarization",
+    "table-question-answering": "Table Question Answering",
+    "text-classification": "Text Classification",
+    "text-generation": "Causal Language Modeling",
+    "text2text-generation": "Sequence-to-sequence Language Modeling",
+    "token-classification": "Token Classification",
+    "translation": "Translation",
+    "zero-shot-classification": "Zero Shot Classification",
+    "automatic-speech-recognition": "Automatic Speech Recognition",
+    "audio-classification": "Audio Classification",
+}
+
+
+METRIC_TAGS = [
+    "accuracy",
+    "bleu",
+    "f1",
+    "matthews_correlation",
+    "pearsonr",
+    "precision",
+    "recall",
+    "rouge",
+    "sacrebleu",
+    "spearmanr",
+    "wer",
+]
+
+
+def _listify(obj):
+    if obj is None:
+        return []
+    elif isinstance(obj, str):
+        return [obj]
+    else:
+        return obj
+
+
+def _insert_values_as_list(metadata, name, values):
+    if values is None:
+        return metadata
+    if isinstance(values, str):
+        values = [values]
+    values = [v for v in values if v is not None]
+    if len(values) == 0:
+        return metadata
+    metadata[name] = values
+    return metadata
+
+
+def infer_metric_tags_from_eval_results(eval_results):
+    if eval_results is None:
+        return {}
+    result = {}
+    for key in eval_results:
+        if key.lower().replace(" ", "_") in METRIC_TAGS:
+            result[key.lower().replace(" ", "_")] = key
+        elif key.lower() == "rouge1":
+            result["rouge"] = key
+    return result
+
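+# e.g. (illustrative): infer_metric_tags_from_eval_results({"Accuracy": 0.91, "Rouge1": 20.4})
+# returns {"accuracy": "Accuracy", "rouge": "Rouge1"} (metric tag -> original eval_results key).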
+
+def _insert_value(metadata, name, value):
+    if value is None:
+        return metadata
+    metadata[name] = value
+    return metadata
+
+
+def is_hf_dataset(dataset):
+    if not is_datasets_available():
+        return False
+
+    from datasets import Dataset, IterableDataset
+
+    return isinstance(dataset, (Dataset, IterableDataset))
+
+
+def _get_mapping_values(mapping):
+    result = []
+    for v in mapping.values():
+        if isinstance(v, (tuple, list)):
+            result += list(v)
+        else:
+            result.append(v)
+    return result
+
+
+@dataclass
+class TrainingSummary:
+    model_name: str
+    language: Optional[Union[str, list[str]]] = None
+    license: Optional[str] = None
+    tags: Optional[Union[str, list[str]]] = None
+    finetuned_from: Optional[str] = None
+    tasks: Optional[Union[str, list[str]]] = None
+    dataset: Optional[Union[str, list[str]]] = None
+    dataset_tags: Optional[Union[str, list[str]]] = None
+    dataset_args: Optional[Union[str, list[str]]] = None
+    dataset_metadata: Optional[dict[str, Any]] = None
+    eval_results: Optional[dict[str, float]] = None
+    eval_lines: Optional[list[str]] = None
+    hyperparameters: Optional[dict[str, Any]] = None
+    source: Optional[str] = "trainer"
+
+    def __post_init__(self):
+        # Infer default license from the checkpoint used, if possible.
+        if (
+            self.license is None
+            and not is_offline_mode()
+            and self.finetuned_from is not None
+            and len(self.finetuned_from) > 0
+        ):
+            try:
+                info = model_info(self.finetuned_from)
+                for tag in info.tags:
+                    if tag.startswith("license:"):
+                        self.license = tag[8:]
+            except (
+                requests.exceptions.HTTPError,
+                requests.exceptions.ConnectionError,
+                HFValidationError,
+                OfflineModeIsEnabled,
+            ):
+                pass
+
+    def create_model_index(self, metric_mapping):
+        model_index = {"name": self.model_name}
+
+        # Dataset mapping tag -> name
+        dataset_names = _listify(self.dataset)
+        dataset_tags = _listify(self.dataset_tags)
+        dataset_args = _listify(self.dataset_args)
+        dataset_metadata = _listify(self.dataset_metadata)
+        if len(dataset_args) < len(dataset_tags):
+            dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
+        dataset_mapping = dict(zip(dataset_tags, dataset_names))
+        dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
+        dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
+
+        task_mapping = {
+            task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
+        }
+
+        model_index["results"] = []
+
+        if len(task_mapping) == 0 and len(dataset_mapping) == 0:
+            return [model_index]
+        if len(task_mapping) == 0:
+            task_mapping = {None: None}
+        if len(dataset_mapping) == 0:
+            dataset_mapping = {None: None}
+
+        # One entry per dataset and per task
+        all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
+        for task_tag, ds_tag in all_possibilities:
+            result = {}
+            if task_tag is not None:
+                result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
+
+            if ds_tag is not None:
+                metadata = dataset_metadata_mapping.get(ds_tag, {})
+                result["dataset"] = {
+                    "name": dataset_mapping[ds_tag],
+                    "type": ds_tag,
+                    **metadata,
+                }
+                if dataset_arg_mapping[ds_tag] is not None:
+                    result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
+
+            if len(metric_mapping) > 0:
+                result["metrics"] = []
+                for metric_tag, metric_name in metric_mapping.items():
+                    result["metrics"].append(
+                        {
+                            "name": metric_name,
+                            "type": metric_tag,
+                            "value": self.eval_results[metric_name],
+                        }
+                    )
+
+            # Remove partial results to avoid the model card being rejected.
+            if "task" in result and "dataset" in result and "metrics" in result:
+                model_index["results"].append(result)
+            else:
+                logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
+
+        return [model_index]
+
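+    # The list returned by `create_model_index` follows the Hub's `model-index` card metadata
+    # shape, roughly (illustrative):
+    #     [{"name": ..., "results": [{"task": {...}, "dataset": {...}, "metrics": [...]}]}]
+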
+    def create_metadata(self):
+        metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
+
+        metadata = {}
+        metadata = _insert_value(metadata, "library_name", "transformers")
+        metadata = _insert_values_as_list(metadata, "language", self.language)
+        metadata = _insert_value(metadata, "license", self.license)
+        if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
+            metadata = _insert_value(metadata, "base_model", self.finetuned_from)
+        metadata = _insert_values_as_list(metadata, "tags", self.tags)
+        metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
+        metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
+        metadata["model-index"] = self.create_model_index(metric_mapping)
+
+        return metadata
+
+    def to_model_card(self):
+        model_card = ""
+
+        metadata = yaml.dump(self.create_metadata(), sort_keys=False)
+        if len(metadata) > 0:
+            model_card = f"---\n{metadata}---\n"
+
+        # Now the model card for realsies.
+        if self.source == "trainer":
+            model_card += AUTOGENERATED_TRAINER_COMMENT
+        else:
+            model_card += AUTOGENERATED_KERAS_COMMENT
+
+        model_card += f"\n# {self.model_name}\n\n"
+
+        if self.finetuned_from is None:
+            model_card += "This model was trained from scratch on "
+        else:
+            model_card += (
+                "This model is a fine-tuned version of"
+                f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
+            )
+
+        if self.dataset is None or (isinstance(self.dataset, list) and len(self.dataset) == 0):
+            model_card += "an unknown dataset."
+        else:
+            if isinstance(self.dataset, str):
+                model_card += f"the {self.dataset} dataset."
+            elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
+                model_card += f"the {self.dataset[0]} dataset."
+            else:
+                model_card += (
+                    ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
+                )
+
+        if self.eval_results is not None:
+            model_card += "\nIt achieves the following results on the evaluation set:\n"
+            model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
+        model_card += "\n"
+
+        model_card += "\n## Model description\n\nMore information needed\n"
+        model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
+        model_card += "\n## Training and evaluation data\n\nMore information needed\n"
+
+        model_card += "\n## Training procedure\n"
+        model_card += "\n### Training hyperparameters\n"
+        if self.hyperparameters is not None:
+            model_card += "\nThe following hyperparameters were used during training:\n"
+            model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
+            model_card += "\n"
+        else:
+            model_card += "\nMore information needed\n"
+
+        if self.eval_lines is not None:
+            model_card += "\n### Training results\n\n"
+            model_card += make_markdown_table(self.eval_lines)
+            model_card += "\n"
+
+        model_card += "\n### Framework versions\n\n"
+        model_card += f"- Transformers {__version__}\n"
+
+        if self.source == "trainer" and is_torch_available():
+            import torch
+
+            model_card += f"- Pytorch {torch.__version__}\n"
+        elif self.source == "keras" and is_tf_available():
+            import tensorflow as tf
+
+            model_card += f"- TensorFlow {tf.__version__}\n"
+        if is_datasets_available():
+            import datasets
+
+            model_card += f"- Datasets {datasets.__version__}\n"
+        if is_tokenizers_available():
+            import tokenizers
+
+            model_card += f"- Tokenizers {tokenizers.__version__}\n"
+
+        return model_card
+
+    @classmethod
+    def from_trainer(
+        cls,
+        trainer,
+        language=None,
+        license=None,
+        tags=None,
+        model_name=None,
+        finetuned_from=None,
+        tasks=None,
+        dataset_tags=None,
+        dataset_metadata=None,
+        dataset=None,
+        dataset_args=None,
+    ):
+        # Infer default from dataset
+        one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
+        if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
+            default_tag = one_dataset.builder_name
+            # Those are not real datasets from the Hub so we exclude them.
+            if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+                if dataset_metadata is None:
+                    dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
+                if dataset_tags is None:
+                    dataset_tags = [default_tag]
+                if dataset_args is None:
+                    dataset_args = [one_dataset.config_name]
+
+        if dataset is None and dataset_tags is not None:
+            dataset = dataset_tags
+
+        # Infer default finetuned_from
+        if (
+            finetuned_from is None
+            and hasattr(trainer.model.config, "_name_or_path")
+            and not os.path.isdir(trainer.model.config._name_or_path)
+        ):
+            finetuned_from = trainer.model.config._name_or_path
+
+        # Infer default task tag:
+        if tasks is None:
+            model_class_name = trainer.model.__class__.__name__
+            for task, mapping in TASK_MAPPING.items():
+                if model_class_name in _get_mapping_values(mapping):
+                    tasks = task
+
+        if model_name is None:
+            model_name = Path(trainer.args.output_dir).name
+        if len(model_name) == 0:
+            model_name = finetuned_from
+
+        # Add `generated_from_trainer` to the tags
+        if tags is None:
+            tags = ["generated_from_trainer"]
+        elif isinstance(tags, str) and tags != "generated_from_trainer":
+            tags = [tags, "generated_from_trainer"]
+        elif "generated_from_trainer" not in tags:
+            tags.append("generated_from_trainer")
+
+        _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
+        hyperparameters = extract_hyperparameters_from_trainer(trainer)
+
+        return cls(
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset=dataset,
+            dataset_tags=dataset_tags,
+            dataset_args=dataset_args,
+            dataset_metadata=dataset_metadata,
+            eval_results=eval_results,
+            eval_lines=eval_lines,
+            hyperparameters=hyperparameters,
+        )
+
+    @classmethod
+    def from_keras(
+        cls,
+        model,
+        model_name,
+        keras_history=None,
+        language=None,
+        license=None,
+        tags=None,
+        finetuned_from=None,
+        tasks=None,
+        dataset_tags=None,
+        dataset=None,
+        dataset_args=None,
+    ):
+        # Infer default from dataset
+        if dataset is not None:
+            if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
+                default_tag = dataset.builder_name
+                # Those are not real datasets from the Hub so we exclude them.
+                if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+                    if dataset_tags is None:
+                        dataset_tags = [default_tag]
+                    if dataset_args is None:
+                        dataset_args = [dataset.config_name]
+
+        if dataset is None and dataset_tags is not None:
+            dataset = dataset_tags
+
+        # Infer default finetuned_from
+        if (
+            finetuned_from is None
+            and hasattr(model.config, "_name_or_path")
+            and not os.path.isdir(model.config._name_or_path)
+        ):
+            finetuned_from = model.config._name_or_path
+
+        # Infer default task tag:
+        if tasks is None:
+            model_class_name = model.__class__.__name__
+            for task, mapping in TASK_MAPPING.items():
+                if model_class_name in _get_mapping_values(mapping):
+                    tasks = task
+
+        # Add `generated_from_keras_callback` to the tags
+        if tags is None:
+            tags = ["generated_from_keras_callback"]
+        elif isinstance(tags, str) and tags != "generated_from_keras_callback":
+            tags = [tags, "generated_from_keras_callback"]
+        elif "generated_from_keras_callback" not in tags:
+            tags.append("generated_from_keras_callback")
+
+        if keras_history is not None:
+            _, eval_lines, eval_results = parse_keras_history(keras_history)
+        else:
+            eval_lines = []
+            eval_results = {}
+        hyperparameters = extract_hyperparameters_from_keras(model)
+
+        return cls(
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+            eval_results=eval_results,
+            eval_lines=eval_lines,
+            hyperparameters=hyperparameters,
+            source="keras",
+        )
+
+
+def parse_keras_history(logs):
+    """
+    Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
+    passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
+    """
+    if hasattr(logs, "history"):
+        # This looks like a `History` object
+        if not hasattr(logs, "epoch"):
+            # This history looks empty, return empty results
+            return None, [], {}
+        logs.history["epoch"] = logs.epoch
+        logs = logs.history
+    else:
+        # The training logs are a list of dicts; invert them into a dict of lists to match a History object
+        logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
+
+    lines = []
+    for i in range(len(logs["epoch"])):
+        epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
+        values = {}
+        for k, v in epoch_dict.items():
+            if k.startswith("val_"):
+                k = "validation_" + k[4:]
+            elif k != "epoch":
+                k = "train_" + k
+            splits = k.split("_")
+            name = " ".join([part.capitalize() for part in splits])
+            values[name] = v
+        lines.append(values)
+
+    eval_results = lines[-1]
+
+    return logs, lines, eval_results
+
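+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical run of
+# `parse_keras_history` on a hand-made list of per-epoch log dicts (the second input form it
+# accepts). The `_demo_*` names are illustrative only, and the guard keeps the snippet inert on import.
+if __name__ == "__main__":
+    _demo_logs = [
+        {"epoch": 0, "loss": 0.9, "val_loss": 0.8},
+        {"epoch": 1, "loss": 0.5, "val_loss": 0.6},
+    ]
+    _, _demo_lines, _demo_eval = parse_keras_history(_demo_logs)
+    # _demo_lines[0] -> {"Epoch": 0, "Train Loss": 0.9, "Validation Loss": 0.8}
+    # _demo_eval     -> the last epoch's entry: {"Epoch": 1, "Train Loss": 0.5, "Validation Loss": 0.6}
+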
+
+def parse_log_history(log_history):
+    """
+    Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.
+    """
+    idx = 0
+    while idx < len(log_history) and "train_runtime" not in log_history[idx]:
+        idx += 1
+
+    # If there are no training logs
+    if idx == len(log_history):
+        idx -= 1
+        while idx >= 0 and "eval_loss" not in log_history[idx]:
+            idx -= 1
+
+        if idx >= 0:
+            return None, None, log_history[idx]
+        else:
+            return None, None, None
+
+    # From now on we can assume we have training logs:
+    train_log = log_history[idx]
+    lines = []
+    training_loss = "No log"
+    for i in range(idx):
+        if "loss" in log_history[i]:
+            training_loss = log_history[i]["loss"]
+        if "eval_loss" in log_history[i]:
+            metrics = log_history[i].copy()
+            _ = metrics.pop("total_flos", None)
+            epoch = metrics.pop("epoch", None)
+            step = metrics.pop("step", None)
+            _ = metrics.pop("eval_runtime", None)
+            _ = metrics.pop("eval_samples_per_second", None)
+            _ = metrics.pop("eval_steps_per_second", None)
+            _ = metrics.pop("eval_jit_compilation_time", None)
+            values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
+            for k, v in metrics.items():
+                if k == "eval_loss":
+                    values["Validation Loss"] = v
+                else:
+                    splits = k.split("_")
+                    name = " ".join([part.capitalize() for part in splits[1:]])
+                    values[name] = v
+            lines.append(values)
+
+    idx = len(log_history) - 1
+    while idx >= 0 and "eval_loss" not in log_history[idx]:
+        idx -= 1
+
+    if idx > 0:
+        eval_results = {}
+        for key, value in log_history[idx].items():
+            if key.startswith("eval_"):
+                key = key[5:]
+            if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
+                camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
+                eval_results[camel_cased_key] = value
+        return train_log, lines, eval_results
+    else:
+        return train_log, lines, None
+
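+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical `log_history`
+# showing what `parse_log_history` extracts; the `_demo_*` names are illustrative only and the
+# guard keeps the snippet inert on import.
+if __name__ == "__main__":
+    _demo_history = [
+        {"loss": 1.2, "epoch": 1.0, "step": 10},
+        {"eval_loss": 0.9, "eval_accuracy": 0.8, "epoch": 1.0, "step": 10},
+        {"train_runtime": 5.0, "epoch": 1.0, "step": 10},
+    ]
+    _train_log, _lines, _eval = parse_log_history(_demo_history)
+    # _lines[0] -> {"Training Loss": 1.2, "Epoch": 1.0, "Step": 10, "Validation Loss": 0.9, "Accuracy": 0.8}
+    # _eval     -> {"Loss": 0.9, "Accuracy": 0.8}
+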
+
+def extract_hyperparameters_from_keras(model):
+    from .modeling_tf_utils import keras
+
+    hyperparameters = {}
+    if hasattr(model, "optimizer") and model.optimizer is not None:
+        hyperparameters["optimizer"] = model.optimizer.get_config()
+    else:
+        hyperparameters["optimizer"] = None
+    hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
+
+    return hyperparameters
+
+
+def _maybe_round(v, decimals=4):
+    if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
+        return f"{v:.{decimals}f}"
+    return str(v)
+
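+# --- Editorial sketch (not part of the upstream module): how `_maybe_round` behaves on a few
+# hypothetical values; guarded so it never runs on import.
+if __name__ == "__main__":
+    assert _maybe_round(0.123456789) == "0.1235"  # floats with more than 4 decimals are truncated
+    assert _maybe_round(2.5) == "2.5"  # short floats are left untouched
+    assert _maybe_round(7) == "7"  # non-floats are simply stringified
+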
+
+def _regular_table_line(values, col_widths):
+    values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
+    return "".join(values_with_space) + "|\n"
+
+
+def _second_table_line(col_widths):
+    values = ["|:" + "-" * w + ":" for w in col_widths]
+    return "".join(values) + "|\n"
+
+
+def make_markdown_table(lines):
+    """
+    Create a nice Markdown table from the results in `lines`.
+    """
+    if lines is None or len(lines) == 0:
+        return ""
+    col_widths = {key: len(str(key)) for key in lines[0]}
+    for line in lines:
+        for key, value in line.items():
+            if col_widths[key] < len(_maybe_round(value)):
+                col_widths[key] = len(_maybe_round(value))
+
+    table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
+    table += _second_table_line(list(col_widths.values()))
+    for line in lines:
+        table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
+    return table
+
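+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical call to
+# `make_markdown_table`; the exact column padding depends on the values passed. Guarded so it
+# never runs on import.
+if __name__ == "__main__":
+    _demo_table = make_markdown_table(
+        [{"Training Loss": 0.9, "Epoch": 1.0, "Validation Loss": 0.812345678}]
+    )
+    # Three lines: a header row, a `|:---:|` separator row, and one data row in which
+    # 0.812345678 is rendered as 0.8123 by `_maybe_round`.
+    print(_demo_table)
+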
+
+_TRAINING_ARGS_KEYS = [
+    "learning_rate",
+    "train_batch_size",
+    "eval_batch_size",
+    "seed",
+]
+
+
+def extract_hyperparameters_from_trainer(trainer):
+    hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
+
+    if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
+        hyperparameters["distributed_type"] = (
+            "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
+        )
+    if trainer.args.world_size > 1:
+        hyperparameters["num_devices"] = trainer.args.world_size
+    if trainer.args.gradient_accumulation_steps > 1:
+        hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
+
+    total_train_batch_size = (
+        trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
+    )
+    if total_train_batch_size != hyperparameters["train_batch_size"]:
+        hyperparameters["total_train_batch_size"] = total_train_batch_size
+    total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
+    if total_eval_batch_size != hyperparameters["eval_batch_size"]:
+        hyperparameters["total_eval_batch_size"] = total_eval_batch_size
+
+    if trainer.args.optim:
+        optimizer_name = trainer.args.optim
+        optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
+
+        if "adam" in optimizer_name.lower():
+            hyperparameters["optimizer"] = (
+                f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
+                f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
+            )
+        else:
+            hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
+
+    hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
+    if trainer.args.warmup_ratio != 0.0:
+        hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
+    if trainer.args.warmup_steps != 0.0:
+        hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
+    if trainer.args.max_steps != -1:
+        hyperparameters["training_steps"] = trainer.args.max_steps
+    else:
+        hyperparameters["num_epochs"] = trainer.args.num_train_epochs
+
+    if trainer.args.fp16:
+        if trainer.use_apex:
+            hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
+        else:
+            hyperparameters["mixed_precision_training"] = "Native AMP"
+
+    if trainer.args.label_smoothing_factor != 0.0:
+        hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
+
+    return hyperparameters
diff --git a/phivenv/Lib/site-packages/transformers/modeling_attn_mask_utils.py b/phivenv/Lib/site-packages/transformers/modeling_attn_mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be1b3ed25531a6bd67aa143b0883815b358678d
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_attn_mask_utils.py
@@ -0,0 +1,487 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+IMPORTANT NOTICE: Every class and function in this file is deprecated in favor of using the much more general
+`masking_utils.py` primitives. New code should not rely on it, it is only kept for backward compatibility for now,
+and will be removed in the future.
+"""
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+
+from .utils.import_utils import is_torchdynamo_compiling
+
+
+@dataclass
+class AttentionMaskConverter:
+    """
+    A utility attention mask class that allows one to:
+        - Create a causal 4d mask
+        - Create a causal 4d mask with a sliding window
+        - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
+          key_value_length) that can be multiplied with attention scores
+
+    Examples:
+
+    ```python
+    >>> import torch
+    >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    >>> converter = AttentionMaskConverter(True)
+    >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
+    tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+            [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
+            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00, -3.4028e+38],
+            [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00]]]])
+    ```
+
+    Parameters:
+        is_causal (`bool`):
+            Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
+
+        sliding_window (`int`, *optional*):
+            Optionally, sliding-window masks can be created if `sliding_window` is set to a positive integer.
+    """
+
+    is_causal: bool
+    sliding_window: int
+
+    def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
+        self.is_causal = is_causal
+        self.sliding_window = sliding_window
+
+        if self.sliding_window is not None and self.sliding_window <= 0:
+            raise ValueError(
+                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+            )
+
+    def to_causal_4d(
+        self,
+        batch_size: int,
+        query_length: int,
+        key_value_length: int,
+        dtype: torch.dtype,
+        device: Union[torch.device, "str"] = "cpu",
+    ) -> Optional[torch.Tensor]:
+        """
+        Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
+        bias to upper right hand triangular matrix (causal mask).
+        """
+        if not self.is_causal:
+            raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
+
+        # If shape is not cached, create a new causal mask and cache it
+        input_shape = (batch_size, query_length)
+        past_key_values_length = key_value_length - query_length
+
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        causal_4d_mask = None
+        if input_shape[-1] > 1 or self.sliding_window is not None:
+            causal_4d_mask = self._make_causal_mask(
+                input_shape,
+                dtype,
+                device=device,
+                past_key_values_length=past_key_values_length,
+                sliding_window=self.sliding_window,
+            )
+
+        return causal_4d_mask
+
+    def to_4d(
+        self,
+        attention_mask_2d: torch.Tensor,
+        query_length: int,
+        dtype: torch.dtype,
+        key_value_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        """
+        Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+        key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+        causal, a causal mask will be added.
+        """
+        input_shape = (attention_mask_2d.shape[0], query_length)
+
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        causal_4d_mask = None
+        if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
+            if key_value_length is None:
+                raise ValueError(
+                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                )
+
+            past_key_values_length = key_value_length - query_length
+            causal_4d_mask = self._make_causal_mask(
+                input_shape,
+                dtype,
+                device=attention_mask_2d.device,
+                past_key_values_length=past_key_values_length,
+                sliding_window=self.sliding_window,
+            )
+        elif self.sliding_window is not None:
+            raise NotImplementedError("Sliding window is currently only implemented for causal masking")
+
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
+            attention_mask_2d.device
+        )
+
+        if causal_4d_mask is not None:
+            expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
+
+        # expanded_attn_mask + causal_4d_mask can cause some overflow
+        expanded_4d_mask = expanded_attn_mask
+
+        return expanded_4d_mask
+
+    @staticmethod
+    def _make_causal_mask(
+        input_ids_shape: torch.Size,
+        dtype: torch.dtype,
+        device: torch.device,
+        past_key_values_length: int = 0,
+        sliding_window: Optional[int] = None,
+    ):
+        """
+        Make the causal mask used for uni-directional (causal) self-attention.
+        """
+        bsz, tgt_len = input_ids_shape
+        mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+        mask_cond = torch.arange(mask.size(-1), device=device)
+        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+
+        mask = mask.to(dtype)
+
+        if past_key_values_length > 0:
+            mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+        # add lower triangular sliding window mask if necessary
+        if sliding_window is not None:
+            diagonal = past_key_values_length - sliding_window - 1
+
+            context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
+            # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
+            # See https://github.com/pytorch/pytorch/issues/127571
+            if is_torchdynamo_compiling():
+                mask = mask.clone()
+            mask.masked_fill_(context_mask, torch.finfo(dtype).min)
+
+        return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+    @staticmethod
+    def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.size()
+        tgt_len = tgt_len if tgt_len is not None else src_len
+
+        expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+        inverted_mask = torch.tensor(1.0, dtype=dtype) - expanded_mask
+
+        return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+    @staticmethod
+    def _unmask_unattended(
+        expanded_mask: torch.FloatTensor,
+        min_dtype: float,
+    ):
+        # fmt: off
+        """
+        Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
+        using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        Details: https://github.com/pytorch/pytorch/issues/110213
+
+        `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
+        `attention_mask` is [bsz, src_seq_len].
+
+        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+
+        For example, if `expanded_mask` is (e.g. here left-padding case)
+        ```
+        [[[[0, 0, 0],
+           [0, 0, 0],
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[0, 0, 0],
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        then the modified `expanded_mask` will be
+        ```
+        [[[[1, 1, 1],   <-- modified
+           [1, 1, 1],   <-- modified
+           [0, 0, 1]]],
+         [[[1, 0, 0],
+           [1, 1, 0],
+           [1, 1, 1]]],
+         [[[1, 1, 1],   <-- modified
+           [0, 1, 0],
+           [0, 1, 1]]]]
+        ```
+        """
+        # fmt: on
+        if expanded_mask.dtype == torch.bool:
+            raise ValueError(
+                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+            )
+
+        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+    @staticmethod
+    def _ignore_causal_mask_sdpa(
+        attention_mask: Optional[torch.Tensor],
+        inputs_embeds: torch.Tensor,
+        past_key_values_length: int,
+        sliding_window: Optional[int] = None,
+        is_training: bool = False,
+    ) -> bool:
+        """
+        Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
+        ignored when PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
+
+        In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
+        `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
+        allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
+        passed).
+        """
+
+        _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
+        key_value_length = query_length + past_key_values_length
+
+        is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+        ignore_causal_mask = False
+
+        if attention_mask is None:
+            # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
+            # shape, thus SDPA's `is_causal` argument is rightfully updated
+            # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
+            # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
+            # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
+            # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
+            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
+            #
+            # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
+            # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
+            if (
+                (is_training or not is_tracing)
+                and (query_length == 1 or key_value_length == query_length)
+                and (sliding_window is None or key_value_length < sliding_window)
+            ):
+                ignore_causal_mask = True
+        elif sliding_window is None or key_value_length < sliding_window:
+            if len(attention_mask.shape) == 4:
+                return False
+            elif not is_tracing and torch.all(attention_mask == 1):
+                if query_length == 1 or key_value_length == query_length:
+                    # For query_length == 1, causal attention and bi-directional attention are the same.
+                    ignore_causal_mask = True
+
+                # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
+                # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
+                # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
+                # Reference: https://github.com/pytorch/pytorch/issues/108108
+                # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
+
+        return ignore_causal_mask
+
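+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical use of
+# `to_causal_4d` with a sliding window, complementing the `to_4d` example in the class
+# docstring. The `_demo_*` names are illustrative only; guarded so it never runs on import.
+if __name__ == "__main__":
+    _converter = AttentionMaskConverter(is_causal=True, sliding_window=2)
+    _demo_mask = _converter.to_causal_4d(batch_size=1, query_length=4, key_value_length=4, dtype=torch.float32)
+    # `_demo_mask` has shape (1, 1, 4, 4); future positions carry the large negative bias, and so
+    # do keys more than 2 positions in the past (e.g. key 0 for query 3) due to the sliding window.
+    print(_demo_mask)
+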
+
+def _prepare_4d_causal_attention_mask(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, tuple, list],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`
+
+    Args:
+        attention_mask (`torch.Tensor` or `None`):
+            A 2D attention mask of shape `(batch_size, key_value_length)`
+        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+            The input shape should be a tuple that defines `(batch_size, query_length)`.
+        inputs_embeds (`torch.Tensor`):
+            The embedded inputs as a torch Tensor.
+        past_key_values_length (`int`):
+            The length of the key value cache.
+        sliding_window (`int`, *optional*):
+            If the model uses windowed attention, a sliding window should be passed.
+    """
+    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    # 4d mask is passed through the layers
+    if attention_mask is not None and len(attention_mask.shape) == 2:
+        attention_mask = attn_mask_converter.to_4d(
+            attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
+        )
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            # if the 4D mask has correct shape - invert it and fill with negative infinity
+            inverted_mask = 1.0 - attention_mask
+            attention_mask = inverted_mask.masked_fill(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
+    else:
+        attention_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+
+    return attention_mask
+
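+# --- Editorial sketch (not part of the upstream module): a small, hypothetical end-to-end call
+# to `_prepare_4d_causal_attention_mask` on a left-padded batch; guarded so it never runs on import.
+if __name__ == "__main__":
+    _demo_embeds = torch.zeros(2, 3, 8)  # (batch_size, query_length, hidden_size)
+    _demo_mask_2d = torch.tensor([[0, 1, 1], [1, 1, 1]])  # left padding on the first example
+    _demo_mask_4d = _prepare_4d_causal_attention_mask(_demo_mask_2d, (2, 3), _demo_embeds, past_key_values_length=0)
+    # `_demo_mask_4d` has shape (2, 1, 3, 3): an additive mask with roughly dtype.min at future
+    # and padded positions, and 0.0 at positions that may be attended to.
+    print(_demo_mask_4d.shape)
+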
+
+# Adapted from _prepare_4d_causal_attention_mask
+def _prepare_4d_causal_attention_mask_for_sdpa(
+    attention_mask: Optional[torch.Tensor],
+    input_shape: Union[torch.Size, tuple, list],
+    inputs_embeds: torch.Tensor,
+    past_key_values_length: int,
+    sliding_window: Optional[int] = None,
+):
+    """
+    Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
+
+    In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
+    `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
+    allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
+    """
+    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+    key_value_length = input_shape[-1] + past_key_values_length
+
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
+    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
+    # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+    is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+    ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
+        attention_mask=attention_mask,
+        inputs_embeds=inputs_embeds,
+        past_key_values_length=past_key_values_length,
+        sliding_window=sliding_window,
+    )
+
+    if ignore_causal_mask:
+        expanded_4d_mask = None
+    elif attention_mask is None:
+        expanded_4d_mask = attn_mask_converter.to_causal_4d(
+            input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+        )
+    else:
+        if attention_mask.dim() == 4:
+            expanded_4d_mask = attention_mask
+        else:
+            expanded_4d_mask = attn_mask_converter.to_4d(
+                attention_mask,
+                input_shape[-1],
+                dtype=inputs_embeds.dtype,
+                key_value_length=key_value_length,
+            )
+
+        # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        if not is_tracing and expanded_4d_mask.device.type == "cuda":
+            expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
+                expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
+            )
+
+    return expanded_4d_mask
+
+
+def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`
+
+    Args:
+        mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)`
+        dtype (`torch.dtype`):
+            The torch dtype the created mask shall have.
+        tgt_len (`int`):
+            The target length or query length the created mask shall have.
+    """
+    return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+    `(batch_size, key_value_length)`
+
+    Args:
+        mask (`torch.Tensor`):
+            A 2D attention mask of shape `(batch_size, key_value_length)`
+        dtype (`torch.dtype`):
+            The torch dtype the created mask shall have.
+        tgt_len (`int`):
+            The target length or query length the created mask shall have.
+    """
+    _, key_value_length = mask.shape
+    tgt_len = tgt_len if tgt_len is not None else key_value_length
+
+    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
+
+    # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
+    if not is_tracing and torch.all(mask == 1):
+        return None
+    else:
+        return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
+
+
+def _create_4d_causal_attention_mask(
+    input_shape: Union[torch.Size, tuple, list],
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+    sliding_window: Optional[int] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
+
+    Args:
+        input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
+            The input shape should be a tuple that defines `(batch_size, query_length)`.
+        dtype (`torch.dtype`):
+            The torch dtype the created mask shall have.
+        device (`torch.device`):
+            The torch device the created mask shall have.
+        sliding_window (`int`, *optional*):
+            If the model uses windowed attention, a sliding window should be passed.
+    """
+    attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
+
+    key_value_length = past_key_values_length + input_shape[-1]
+    attention_mask = attn_mask_converter.to_causal_4d(
+        input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
+    )
+
+    return attention_mask
diff --git a/phivenv/Lib/site-packages/transformers/modeling_flash_attention_utils.py b/phivenv/Lib/site-packages/transformers/modeling_flash_attention_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..37554773a85fe7b85ccee107f58cdd48bca3ff3f
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_flash_attention_utils.py
@@ -0,0 +1,668 @@
+# Copyright 2025 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+import os
+from functools import partial
+from typing import Optional, TypedDict
+
+import torch
+import torch.nn.functional as F
+
+from .utils import (
+    is_flash_attn_2_available,
+    is_flash_attn_3_available,
+    is_flash_attn_greater_or_equal_2_10,
+    is_torch_npu_available,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+# TODO Deprecate when all models have the attention interface
+def flash_attn_supports_top_left_mask():
+    if is_flash_attn_3_available():
+        return False
+    if is_flash_attn_2_available():
+        return not is_flash_attn_greater_or_equal_2_10()
+
+    from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask
+
+    return is_npu_fa2_top_left_aligned_causal_mask()
+
+
+# TODO Deprecate when all models have the attention interface
+def is_flash_attn_available():
+    return is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available()
+
+
+# `globals()` is not compatible with dynamo, hence we have to define them in the global scope ourselves
+_flash_fn = None
+_flash_varlen_fn = None
+_pad_fn = None
+_unpad_fn = None
+
+# function that processes kwargs, generalized to handle any supported kwarg within the function
+_process_flash_kwargs_fn = None
+# exceptions where hf API doesn't match the original flash attention API
+_hf_api_to_flash_mapping = {
+    "dropout": "dropout_p",
+    "sliding_window": "window_size",
+}
+
+
+def _lazy_imports(implementation: Optional[str]):
+    """
+    Lazy loads the respective flash attention implementations.
+
+    Return:
+        flash_attn_func: The base flash attention function.
+        flash_attn_varlen_func: The flash attention function supporting variable sequence lengths,
+                                e.g. for padding-free training.
+        pad_input: The function that re-pads the packed outputs into a padded batch based on the indices.
+        unpad_input: The function that unpads a padded batch into one packed sequence and returns the respective kwargs.
+    """
+    is_fa2 = is_flash_attn_2_available()
+    is_fa3 = is_flash_attn_3_available()
+
+    pad_input, unpad_input = _pad_input, _unpad_input
+
+    if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input, unpad_input
+    elif is_torch_npu_available():
+        # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
+        # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
+        from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+        from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
+    else:
+        if implementation == "flash_attention_3" or (implementation is None and is_fa3):
+            from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
+        # Kernels fallback
+        else:
+            flash_attn_func = getattr(implementation, "flash_attn_func", None)
+            flash_attn_varlen_func = getattr(implementation, "flash_attn_varlen_func", None)
+            if flash_attn_varlen_func is None or flash_attn_func is None:
+                raise ValueError(
+                    f"Could not find the currently requested flash attention implementation at `{implementation}`."
+                    f"Make sure that you request a valid kernel from the hub, e.g. `kernels-community/flash-attn`."
+                )
+
+    return flash_attn_func, flash_attn_varlen_func, pad_input, unpad_input
+
+
+def _lazy_define_process_function(flash_function):
+    """
+    Depending on the version and kernel some features are not supported. Due to limitations in
+    `torch.compile`, we opt to statically type which (optional) kwarg parameters are supported
+    within `_process_flash_attention_kwargs`.
+
+    NOTE: While all supported kwargs are marked as `True`, everything else is marked as `False`.
+          This might be confusing for kwargs that we use in any case, e.g. `is_causal`.
+    """
+
+    flash_parameters = inspect.signature(flash_function).parameters
+    process_parameters = inspect.signature(_process_flash_attention_kwargs).parameters
+
+    supports_mapping = {}
+    for param in process_parameters:
+        fa_param = _hf_api_to_flash_mapping.get(param, param)
+        supports_mapping[fa_param] = fa_param in flash_parameters
+
+    return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping)
+
+
+def lazy_import_flash_attention(implementation: Optional[str]):
+    """
+    Lazily import flash attention and return the respective functions + flags.
+
+    NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
+    work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
+    """
+    global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
+    if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
+        _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn = _lazy_imports(implementation)
+
+    global _process_flash_kwargs_fn
+    if _process_flash_kwargs_fn is None:
+        _process_flash_kwargs_fn = _lazy_define_process_function(_flash_varlen_fn)
+
+    return (_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn), _process_flash_kwargs_fn
+
+
+def _index_first_axis(tensor, indices):
+    """
+    A local implementation of the PyTorch indexing operation `tensor[indices]` on the first axis,
+    after flattening the first two dimensions of the tensor. This is functionally equivalent to
+    FA2's `index_first_axis` and replaces the need to import it.
+    """
+    # The input tensor is expected to be of shape (batch, seq_len, ...). We flatten the first
+    # two dimensions to get (total_tokens, ...) before indexing.
+    reshaped_tensor = tensor.reshape(-1, *tensor.shape[2:])
+    return reshaped_tensor[indices]
+
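+# --- Editorial sketch (not part of the upstream module): a tiny, hypothetical illustration of
+# `_index_first_axis`; guarded so it never runs on import.
+if __name__ == "__main__":
+    _demo_tokens = torch.arange(12).reshape(2, 3, 2)  # (batch=2, seq_len=3, dim=2)
+    _demo_picked = _index_first_axis(_demo_tokens, torch.tensor([0, 4]))
+    # The tensor is first flattened to (6, 2), so rows 0 and 4 correspond to the tokens at
+    # (batch 0, position 0) and (batch 1, position 1).
+    print(_demo_picked)
+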
+
+def _unpad_input(hidden_states, attention_mask, unused_mask=None):
+    """
+    unpad_input function for flash attention variants that do not ship one in their own package, e.g. FA3.
+
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+
+    return (
+        _index_first_axis(hidden_states, indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
+def _pad_input(hidden_states, indices, batch, seqlen):
+    """
+    pad_input function for flash attention variants that do not ship one in their own package, e.g. FA3.
+
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    dim = hidden_states.shape[1:]
+    output = torch.zeros((batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype)
+    output[indices] = hidden_states
+    return output.view(batch, seqlen, *dim)
+
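+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical round trip
+# through `_unpad_input` and `_pad_input`; guarded so it never runs on import.
+if __name__ == "__main__":
+    _demo_hidden = torch.arange(12.0).reshape(2, 3, 2)  # (batch=2, seqlen=3, dim=2)
+    _demo_attn_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
+    _demo_packed, _demo_idx, _demo_cu, _demo_max, _demo_used = _unpad_input(_demo_hidden, _demo_attn_mask)
+    _demo_repadded = _pad_input(_demo_packed, _demo_idx, batch=2, seqlen=3)
+    # `_demo_repadded` matches `_demo_hidden` everywhere except the masked position, which is
+    # filled with zeros after the round trip.
+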
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
+    """
+    Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+    Arguments:
+        attention_mask (`torch.Tensor`):
+            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+    Return:
+        indices (`torch.Tensor`):
+            The indices of non-masked tokens from the flattened input sequence.
+        cu_seqlens (`torch.Tensor`):
+            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        max_seqlen_in_batch (`int`):
+            Maximum sequence length in batch.
+    """
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
+    # this might cause a graph break
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
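+# --- Editorial sketch (not part of the upstream module): a minimal, hypothetical call to
+# `_get_unpad_data` on a right-padded mask; guarded so it never runs on import.
+if __name__ == "__main__":
+    _demo_attn = torch.tensor([[1, 1, 0], [1, 1, 1]])
+    _demo_indices, _demo_cu_seqlens, _demo_max_len = _get_unpad_data(_demo_attn)
+    # _demo_indices    -> tensor([0, 1, 3, 4, 5])  (flattened positions of the valid tokens)
+    # _demo_cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32)
+    # _demo_max_len    -> 3
+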
+
+def _upad_input(
+    query_layer: torch.Tensor,
+    key_layer: torch.Tensor,
+    value_layer: torch.Tensor,
+    attention_mask: torch.Tensor,
+    query_length: int,
+    unpad_input_func,
+):
+    """
+    Unpads query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.
+    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+    tensors for query, key, value tensors.
+
+    Arguments:
+        query_layer (`torch.Tensor`):
+            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+        key_layer (`torch.Tensor`):
+            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+        value_layer (`torch.Tensor`):
+            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+        attention_mask (`torch.Tensor`):
+            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+        query_length (`int`):
+            Target length.
+        unpad_input_func:
+            The function to use for unpadding the input tensors.
+
+    Return:
+        query_layer (`torch.Tensor`):
+            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+        key_layer (`torch.Tensor`):
+            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+        value_layer (`torch.Tensor`):
+            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+        indices_q (`torch.Tensor`):
+            The indices of non-masked tokens from the flattened input target sequence.
+        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+    """
+    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+
+    # With static caches, the k/v states may be larger than the mask -> we need to slice them to avoid generating garbage
+    # It's a bit of an anti-pattern, but otherwise we silently compute wrong attention scores
+    if key_layer.shape[1] > (seq_len := attention_mask.shape[-1]):
+        key_layer, value_layer = key_layer[:, :seq_len, :, :], value_layer[:, :seq_len, :, :]
+
+    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+    key_layer = _index_first_axis(key_layer, indices_k)
+    value_layer = _index_first_axis(value_layer, indices_k)
+    if query_length == kv_seq_len:
+        query_layer = _index_first_axis(query_layer, indices_k)
+        cu_seqlens_q = cu_seqlens_k
+        max_seqlen_in_batch_q = max_seqlen_in_batch_k
+        indices_q = indices_k
+    elif query_length == 1:
+        max_seqlen_in_batch_q = 1
+        cu_seqlens_q = torch.arange(
+            batch_size + 1, dtype=torch.int32, device=query_layer.device
+        )  # There is a memcpy here, that is very bad.
+        indices_q = cu_seqlens_q[:-1]
+        query_layer = query_layer.squeeze(1)
+    else:
+        # The -q_len: slice assumes left padding.
+        attention_mask = attention_mask[:, -query_length:]
+        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input_func(query_layer, attention_mask)
+
+    return (
+        query_layer,
+        key_layer,
+        value_layer,
+        indices_q,
+        (cu_seqlens_q, cu_seqlens_k),
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+    )
+
+
+def prepare_fa_kwargs_from_position_ids(position_ids):
+    """
+    This function returns all the necessary kwargs to call `flash_attn_varlen_func` extracted from position_ids.
+
+    Arguments:
+        position_ids (`torch.Tensor`):
+            Int tensor of position indices with shape (batch_size, sequence_length); each packed sub-sequence restarts at 0.
+
+    Return:
+        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+            The cumulative sequence lengths for the target (query) and source (key, value), used to index into
+            ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query,
+            `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+    """
+    tensor_kwargs = {"dtype": torch.int32, "device": position_ids.device}
+
+    position_ids = position_ids.view(-1)
+    indices_q = (position_ids == 0).nonzero().view(-1)
+
+    cu_seq_lens_q = torch.cat(
+        (
+            indices_q.to(**tensor_kwargs),
+            torch.tensor(position_ids.size(), **tensor_kwargs),
+        )
+    )
+    cu_seq_lens_k = cu_seq_lens_q
+
+    # https://github.com/Dao-AILab/flash-attention/blob/2dd8078adc1d9b74e315ee99718c0dea0de8eeb6/flash_attn/flash_attn_interface.py#L1423-L1424
+    # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
+    # for some models (e.g. qwen2-vl).
+    max_length_q = cu_seq_lens_q.diff().max()
+    # NOTE: With torch compile, this will cause a graph break if you don't set
+    # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
+    # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
+    # This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
+    # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
+    max_length_q = max_length_q.item()
+    max_length_k = max_length_q
+
+    return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
+
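+# --- Editorial sketch (not part of the upstream module): a small, hypothetical call to
+# `prepare_fa_kwargs_from_position_ids` for a packed batch holding two sequences of lengths
+# 3 and 2; guarded so it never runs on import.
+if __name__ == "__main__":
+    _demo_pos_ids = torch.tensor([[0, 1, 2, 0, 1]])
+    (_demo_cu_q, _demo_cu_k), (_demo_max_q, _demo_max_k) = prepare_fa_kwargs_from_position_ids(_demo_pos_ids)
+    # _demo_cu_q == _demo_cu_k   -> tensor([0, 3, 5], dtype=torch.int32)
+    # _demo_max_q == _demo_max_k -> 3 (length of the longest packed sub-sequence)
+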
+
+def _prepare_from_posids(query, key, value, position_ids):
+    """
+    This function returns necessary arguments to call `flash_attn_varlen_func`.
+    All three query, key, value states will be flattened.
+    Cumulative lengths of each examples in the batch will be extracted from position_ids.
+    NOTE: ideally cumulative lengths should be prepared at the data collator stage
+
+    Arguments:
+        query (`torch.Tensor`):
+            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+        key (`torch.Tensor`):
+            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+        value (`torch.Tensor`):
+            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+        position_ids (`torch.Tensor`):
+            Int tensor of position indices with shape (batch_size, sequence_length); each packed sub-sequence restarts at 0.
+
+    Return:
+        query (`torch.Tensor`):
+            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+        key (`torch.Tensor`):
+            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+        value (`torch.Tensor`):
+            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+        (cu_seqlens_q, cu_seqlens_k) (`tuple[int]`):
+            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`tuple[int]`):
+            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+    """
+    query = query.contiguous().view(-1, query.size(-2), query.size(-1))
+    key = key.contiguous().view(-1, key.size(-2), key.size(-1))
+    value = value.contiguous().view(-1, value.size(-2), value.size(-1))
+
+    (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = prepare_fa_kwargs_from_position_ids(position_ids)
+
+    return (query, key, value, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k))
+
+
+def _is_packed_sequence(position_ids, batch_size):
+    """
+    Check whether the position ids indicate packed sequences:
+        1. Position ids exist
+        2. Only flattened sequences (batch size 1) are supported
+        3. Compile-friendly equivalent of `not (torch.diff(position_ids, dim=-1) >= 0).all()`, i.e. the batch
+           contains multiple increasing sub-sequences
+    """
+    if position_ids is None:
+        return False
+
+    increasing_position_sequences = (
+        torch.arange(position_ids.shape[1], device=position_ids.device) + position_ids.min()
+    )
+    return batch_size == 1 and (increasing_position_sequences - position_ids).abs().sum().bool()
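+
+# Illustrative example (editor's sketch): with batch_size == 1,
+#     position_ids = [[0, 1, 2, 0, 1]]  ->  truthy (two packed sub-sequences)
+#     position_ids = [[0, 1, 2, 3, 4]]  ->  falsy  (a single increasing sequence)
+# because only a packed batch deviates from a single monotonically increasing ramp.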
+
+
+def fa_peft_integration_check(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    target_dtype: Optional[torch.dtype] = None,
+):
+    """
+    PEFT usually casts the layer norms to float32 for training stability reasons,
+    so the input hidden states get silently cast to float32. Hence, we need to
+    cast them back to float16 / bfloat16 to make sure everything works as expected.
+    This might slow down training & inference, so it is recommended not to cast the LayerNorms!
+    """
+    if target_dtype and q.dtype == torch.float32:
+        logger.warning_once(f"Casting fp32 inputs back to {target_dtype} for flash-attn compatibility.")
+        q, k, v = q.to(target_dtype), k.to(target_dtype), v.to(target_dtype)
+    return q, k, v
+
+
+class FlashAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for Flash Attention with Compile.
+
+    Attributes:
+        cu_seq_lens_q (`torch.LongTensor`, *optional*):
+            Cumulative sequence lengths for the query state.
+        cu_seq_lens_k (`torch.LongTensor`, *optional*):
+            Cumulative sequence lengths for the key state.
+        max_length_q (`int`, *optional*):
+            Maximum sequence length for query state.
+        max_length_k (`int`, *optional*):
+            Maximum sequence length for key state.
+    """
+
+    cu_seq_lens_q: Optional[torch.LongTensor]
+    cu_seq_lens_k: Optional[torch.LongTensor]
+    max_length_q: Optional[int]
+    max_length_k: Optional[int]
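+
+# Editor's note (illustrative, not library code): since `FlashAttentionKwargs` is a `TypedDict`, a data
+# collator that packs sequences can pre-compute these values and forward them through the model's
+# `**kwargs`, so `_flash_attention_forward` does not have to re-derive them from position ids, e.g.
+#     fa_kwargs = FlashAttentionKwargs(cu_seq_lens_q=cu, cu_seq_lens_k=cu, max_length_q=m, max_length_k=m)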
+
+
+def _process_flash_attention_kwargs(
+    query_length: int,
+    key_length: int,
+    is_causal: bool,
+    dropout: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    use_top_left_mask: bool = False,
+    softcap: Optional[float] = None,
+    deterministic: Optional[bool] = None,
+    s_aux: Optional[torch.Tensor] = None,
+    supports_mapping: Optional[dict[str, bool]] = None,
+    **kwargs,
+):
+    """
+    Returns a set of kwargs that are passed down to the corresponding flash attention function based on
+    the requested features and whether they are supported, which depends on the version and kernel
+    implementation dynamically configured in `lazy_import_flash_attention`. The (un)supported features can be
+    inspected in `supports_mapping`; see `_lazy_define_process_function` for more details.
+
+    Args:
+        query_length (`int`):
+            Length of the query states
+        key_length (`int`):
+            Length of the key states
+        is_causal (`bool`):
+            Whether we perform causal (decoder) attention or full attention.
+        dropout (`float`):
+            Attention dropout.
+        softmax_scale (`float`, *optional*):
+            The scaling of QK^T before applying softmax. Defaults to `1 / sqrt(head_dim)`.
+        sliding_window (`int`, *optional*):
+            The size of the sliding window, i.e. we look at a max of `sliding_window` tokens back.
+        use_top_left_mask (`bool`):
+            Deprecated behavior of older versions of flash attention requiring different masking.
+        softcap (`float`, *optional*):
+            Softcap for the attention logits, used e.g. in gemma2.
+        deterministic (`bool`, *optional*):
+            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+        s_aux (`torch.Tensor`, *optional*):
+            Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
+    Return:
+        flash_kwargs (`dict`):
+            A dict of kwargs that are requested and supported.
+    """
+    flash_kwargs = {
+        "causal": is_causal and not (use_top_left_mask and query_length == 1),
+        "softmax_scale": softmax_scale,
+    }
+
+    if supports_mapping["dropout_p"]:
+        flash_kwargs["dropout_p"] = dropout
+
+    if supports_mapping["window_size"] and sliding_window is not None and key_length > sliding_window:
+        # The flash attention API sets inclusive boundaries, i.e. (4, 0) would take 4 tokens to the left
+        # and the current token for a total size of 5. However, we usually define our window sizes by
+        # their total window size (when causal). Encoder models as of now seldom use SWA and when they
+        # do, they have a custom workaround (e.g. ModernBERT) which would align with this symmetric logic, i.e.
+        # for a total of `2*sliding_window - 1`.
+        flash_kwargs["window_size"] = (sliding_window - 1, sliding_window - 1)
+
+    if supports_mapping["deterministic"]:
+        flash_kwargs["deterministic"] = (
+            deterministic if deterministic is not None else os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+        )
+
+    if supports_mapping["softcap"] and softcap is not None:
+        flash_kwargs["softcap"] = softcap
+
+    # Only within kernel implementation atm
+    if supports_mapping["s_aux"] and s_aux is not None:
+        flash_kwargs["s_aux"] = s_aux
+
+    return flash_kwargs
+
+
+def _flash_attention_forward(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    query_length: int,
+    is_causal: bool,
+    dropout: float = 0.0,
+    position_ids: Optional[torch.Tensor] = None,
+    softmax_scale: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    use_top_left_mask: bool = False,
+    softcap: Optional[float] = None,
+    deterministic: Optional[bool] = None,
+    cu_seq_lens_q: Optional[torch.LongTensor] = None,
+    cu_seq_lens_k: Optional[torch.LongTensor] = None,
+    max_length_q: Optional[int] = None,
+    max_length_k: Optional[int] = None,
+    target_dtype: Optional[torch.dtype] = None,
+    implementation: Optional[str] = None,
+    **kwargs,
+):
+    """
+    Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
+    the input is first unpadded, the attention scores are computed, and the output is then padded back.
+
+    (Optional) kwargs are described further in `_process_flash_attention_kwargs` and `FlashAttentionKwargs`.
+
+    Args:
+        query_states (`torch.Tensor`):
+            Input query states to be passed to Flash Attention API
+        key_states (`torch.Tensor`):
+            Input key states to be passed to Flash Attention API
+        value_states (`torch.Tensor`):
+            Input value states to be passed to Flash Attention API
+        attention_mask (`torch.Tensor`, *optional*):
+            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+            position of padding tokens and 1 for the position of non-padding tokens.
+        implementation (`str`, *optional*):
+            The attention implementation to use. If None, will default to the one based on the environment.
+    """
+    (flash_fn, flash_varlen_fn, pad_fn, unpad_fn), process_flash_kwargs_fn = lazy_import_flash_attention(
+        implementation
+    )
+
+    # PEFT may silently cast tensors to fp32; this converts them back to the correct dtype (or is a no-op)
+    query_states, key_states, value_states = fa_peft_integration_check(
+        query_states, key_states, value_states, target_dtype
+    )
+
+    # Extract the flash attention kwargs that have been requested (and are supported by the implementation)
+    flash_kwargs = process_flash_kwargs_fn(
+        query_length=query_length,
+        key_length=key_states.size(1),
+        is_causal=is_causal,
+        dropout=dropout,
+        softmax_scale=softmax_scale,
+        sliding_window=sliding_window,
+        use_top_left_mask=use_top_left_mask,
+        softcap=softcap,
+        deterministic=deterministic,
+        **kwargs,
+    )
+
+    # We will use `flash_varlen_fn` to prevent cross-example attention and also allow a padding-free approach in two cases:
+    # Case 1. Position ids are provided and indicate packed sequences, see `_is_packed_sequence`.
+    # Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is safe to
+    #         use `flash_varlen_fn` knowing we already have all the necessary kwargs.
+    #
+    # NOTE: it is user's responsibility to take care of flattening `position_ids` if that's needed by the model.
+    # See #39121 for more information.
+    is_fa_with_position_ids = _is_packed_sequence(position_ids, batch_size=query_states.size(0))
+    is_fa_with_varlen_kwargs = all(
+        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
+    )
+
+    # Contains at least one padding token in the sequence
+    if attention_mask is not None:
+        q, k, v, indices_q, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _upad_input(
+            query_states, key_states, value_states, attention_mask, query_length, unpad_fn
+        )
+
+        # TODO for now this is required to work with
+        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
+        if "mps" in str(q.device):
+            cu_seq_lens_k = cu_seq_lens_k.clone()
+
+        out_unpad = flash_varlen_fn(
+            q,
+            k,
+            v,
+            cu_seqlens_q=cu_seq_lens_q,
+            cu_seqlens_k=cu_seq_lens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
+            **flash_kwargs,
+        )
+        if isinstance(out_unpad, tuple):
+            out_unpad = out_unpad[0]
+
+        out = pad_fn(out_unpad, indices_q, query_states.size(0), query_length)
+
+    # Padding free, i.e. sequences flattened into one total sequence
+    elif is_fa_with_varlen_kwargs or is_fa_with_position_ids:
+        if cu_seq_lens_q is None or cu_seq_lens_k is None:
+            q, k, v, (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) = _prepare_from_posids(
+                query_states, key_states, value_states, position_ids
+            )
+        else:
+            q = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
+            k = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
+            v = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
+
+        # TODO for now this is required to work with
+        # https://huggingface.co/kernels-community/metal-flash-sdpa/blob/main/torch-ext/metal_flash_sdpa/__init__.py
+        if "mps" in str(q.device):
+            cu_seq_lens_k = cu_seq_lens_k.clone()
+
+        out = flash_varlen_fn(
+            q,
+            k,
+            v,
+            cu_seqlens_q=cu_seq_lens_q,
+            cu_seqlens_k=cu_seq_lens_k,
+            max_seqlen_q=max_length_q,
+            max_seqlen_k=max_length_k,
+            **flash_kwargs,
+        )
+        if isinstance(out, tuple):
+            out = out[0]
+
+        out = out.view(query_states.size(0), -1, out.size(-2), out.size(-1))
+
+    # No padding
+    else:
+        out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
+        if isinstance(out, tuple):
+            out = out[0]
+
+    return out
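+
+
+# Minimal usage sketch (editor's illustration, assuming a flash-attn backend is available; shapes follow
+# the (batch_size, seq_len, num_heads, head_dim) convention used above):
+#
+#     q = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.bfloat16)
+#     k = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.bfloat16)
+#     v = torch.randn(2, 16, 8, 64, device="cuda", dtype=torch.bfloat16)
+#     out = _flash_attention_forward(q, k, v, attention_mask=None, query_length=16, is_causal=True)
+#     # `out` has shape (2, 16, 8, 64), matching the query layout.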
diff --git a/phivenv/Lib/site-packages/transformers/modeling_flax_outputs.py b/phivenv/Lib/site-packages/transformers/modeling_flax_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a25a6059a255659c6d900b35d2ffa7cab57f071
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_flax_outputs.py
@@ -0,0 +1,700 @@
+# Copyright 2021 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import flax
+import jax.numpy as jnp
+
+from .utils import ModelOutput
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutput(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
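+
+# Editor's note (illustrative): like all `ModelOutput` subclasses, this is a frozen dataclass with
+# attribute access, e.g.
+#     out = FlaxBaseModelOutput(last_hidden_state=jnp.ones((1, 4, 8)))
+#     out.last_hidden_state.shape  # (1, 4, 8)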
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithNoAttention(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
+            model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state after a pooling operation on the spatial dimensions.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
+            model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    pooler_output: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
+    """
+    Base class for outputs of image classification models.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
+        `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
+            called feature maps) of the model at the output of each stage.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        past_key_values (`dict[str, jnp.ndarray]`):
+            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
+            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    past_key_values: Optional[dict[str, jnp.ndarray]] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPooling(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+            prediction (classification) objective during pretraining.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    pooler_output: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) after further processing
+            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+            the classification token after processing through a linear layer and a tanh activation function. The linear
+            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
+            for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    pooler_output: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
+            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
+            encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqModelOutput(ModelOutput):
+    """
+    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up
+    sequential decoding.
+
+    Args:
+        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+    encoder_last_hidden_state: Optional[jnp.ndarray] = None
+    encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value
+            states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting.
+            Only relevant if `config.is_decoder = True`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxMaskedLMOutput(ModelOutput):
+    """
+    Base class for masked language models outputs.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+FlaxCausalLMOutput = FlaxMaskedLMOutput
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqLMOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence language models outputs.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+    encoder_last_hidden_state: Optional[jnp.ndarray] = None
+    encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxNextSentencePredictorOutput(ModelOutput):
+    """
+    Base class for outputs of models predicting if two sentences are consecutive or not.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, 2)`):
+            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence sentence classification models.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+    encoder_last_hidden_state: Optional[jnp.ndarray] = None
+    encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    encoder_attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxMultipleChoiceModelOutput(ModelOutput):
+    """
+    Base class for outputs of multiple choice models.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
+            *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
+
+            Classification scores (before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxTokenClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of token classification models.
+
+    Args:
+        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models.
+
+    Args:
+        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    start_logits: Optional[jnp.ndarray] = None
+    end_logits: Optional[jnp.ndarray] = None
+    hidden_states: Optional[tuple[jnp.ndarray]] = None
+    attentions: Optional[tuple[jnp.ndarray]] = None
+
+
+@flax.struct.dataclass
+class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence question answering models.
+
+    Args:
+        start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
+            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    start_logits: Optional[jnp.ndarray] = None
+    end_logits: Optional[jnp.ndarray] = None
+    past_key_values: Optional[tuple[tuple[jnp.ndarray]]] = None
+    decoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    decoder_attentions: Optional[tuple[jnp.ndarray]] = None
+    cross_attentions: Optional[tuple[jnp.ndarray]] = None
+    encoder_last_hidden_state: Optional[jnp.ndarray] = None
+    encoder_hidden_states: Optional[tuple[jnp.ndarray]] = None
+    encoder_attentions: Optional[tuple[jnp.ndarray]] = None
diff --git a/phivenv/Lib/site-packages/transformers/modeling_flax_pytorch_utils.py b/phivenv/Lib/site-packages/transformers/modeling_flax_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..07d83d5e4aa64e822203ef99393ed22b5e5b98e3
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_flax_pytorch_utils.py
@@ -0,0 +1,491 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch - Flax general utilities."""
+
+import os
+from pickle import UnpicklingError
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from flax.serialization import from_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+
+import transformers
+
+from . import is_safetensors_available, is_torch_available
+from .utils import check_torch_load_is_safe, logging
+
+
+if is_torch_available():
+    import torch
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.flax import load_file as safe_load_file
+
+
+logger = logging.get_logger(__name__)
+
+
+#####################
+# PyTorch => Flax #
+#####################
+
+
+def load_pytorch_checkpoint_in_flax_state_dict(
+    flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
+):
+    """Load pytorch checkpoints in a flax model"""
+
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
+
+        if pt_path.endswith(".safetensors"):
+            pt_state_dict = {}
+            with safe_open(pt_path, framework="flax") as f:
+                for k in f.keys():
+                    pt_state_dict[k] = f.get_tensor(k)
+        else:
+            try:
+                import torch  # noqa: F401
+            except (ImportError, ModuleNotFoundError):
+                logger.error(
+                    "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
+                    " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
+                    " instructions."
+                )
+                raise
+
+            check_torch_load_is_safe()
+            pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
+            logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
+
+        flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+    else:
+        # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
+        flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
+    return flax_state_dict
+
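+# Editorial note: an illustrative usage sketch, not part of the upstream module. A
+# typical call converts a single (non-sharded) PyTorch checkpoint into a Flax
+# parameter tree; the model class and path below are hypothetical:
+#
+#     model = FlaxBertModel(config)
+#     flax_params = load_pytorch_checkpoint_in_flax_state_dict(
+#         model, "./pt_model/pytorch_model.bin", is_sharded=False
+#     )
+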
+
+def rename_key_and_reshape_tensor(
+    pt_tuple_key: tuple[str],
+    pt_tensor: np.ndarray,
+    random_flax_state_dict: dict[str, jnp.ndarray],
+    model_prefix: str,
+) -> tuple[tuple[str], np.ndarray]:
+    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
+
+    def is_key_or_prefix_key_in_dict(key: tuple[str]) -> bool:
+        """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
+        return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0
+
+    # layer norm
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+    if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
+    # batch norm layer mean
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
+    if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
+    # batch norm layer var
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
+    if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
+    # embedding
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
+    if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
+        return renamed_pt_tuple_key, pt_tensor
+
+    # conv layer
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+    if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
+        return renamed_pt_tuple_key, pt_tensor
+
+    # linear layer
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
+    if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
+        pt_tensor = pt_tensor.T
+        return renamed_pt_tuple_key, pt_tensor
+
+    # old PyTorch layer norm weight
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
+    if pt_tuple_key[-1] == "gamma":
+        return renamed_pt_tuple_key, pt_tensor
+
+    # old PyTorch layer norm bias
+    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
+    if pt_tuple_key[-1] == "beta":
+        return renamed_pt_tuple_key, pt_tensor
+
+    # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+    name = None
+    if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
+        name = pt_tuple_key[-2] + "_g"
+    elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
+        name = pt_tuple_key[-2] + "_v"
+    if name is not None:
+        renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
+        return renamed_pt_tuple_key, pt_tensor
+
+    return pt_tuple_key, pt_tensor
+
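+# Editorial note: an illustrative sketch, not part of the upstream module. For a plain
+# linear layer the function above renames the PyTorch "weight" leaf to Flax's "kernel"
+# and transposes it, since PyTorch stores Linear weights as (out, in); the key names
+# below are hypothetical:
+#
+#     pt_key = ("dense", "weight")
+#     pt_tensor = np.zeros((4, 8))                          # (out_features, in_features)
+#     flax_params = {("dense", "kernel"): jnp.zeros((8, 4))}
+#     key, tensor = rename_key_and_reshape_tensor(pt_key, pt_tensor, flax_params, "model")
+#     # key == ("dense", "kernel") and tensor.shape == (8, 4)
+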
+
+def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
+    # convert pytorch tensor to numpy
+    from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
+    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
+
+    weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
+
+    if from_bin:
+        for k, v in pt_state_dict.items():
+            # numpy currently does not support bfloat16, so go through float32 in this case to avoid losing precision
+            if v.dtype == bfloat16:
+                v = v.float()
+            pt_state_dict[k] = v.cpu().numpy()
+
+    model_prefix = flax_model.base_model_prefix
+
+    # use params dict if the model contains batch norm layers
+    if "params" in flax_model.params:
+        flax_model_params = flax_model.params["params"]
+    else:
+        flax_model_params = flax_model.params
+    random_flax_state_dict = flatten_dict(flax_model_params)
+
+    # add batch_stats keys and values to the dict
+    if "batch_stats" in flax_model.params:
+        flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
+        random_flax_state_dict.update(flax_batch_stats)
+
+    flax_state_dict = {}
+
+    load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
+        model_prefix in {k.split(".")[0] for k in pt_state_dict}
+    )
+    load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
+        model_prefix not in {k.split(".")[0] for k in pt_state_dict}
+    )
+
+    # Need to change some parameter names to match Flax names
+    for pt_key, pt_tensor in pt_state_dict.items():
+        pt_tuple_key = tuple(pt_key.split("."))
+        is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
+
+        # remove base model prefix if necessary
+        has_base_model_prefix = pt_tuple_key[0] == model_prefix
+        if load_model_with_head_into_base_model and has_base_model_prefix:
+            pt_tuple_key = pt_tuple_key[1:]
+
+        # Correctly rename weight parameters
+        flax_key, flax_tensor = rename_key_and_reshape_tensor(
+            pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+        )
+
+        # add model prefix if necessary
+        require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+        if load_base_model_into_model_with_head and require_base_model_prefix:
+            flax_key = (model_prefix,) + flax_key
+
+        if flax_key in random_flax_state_dict:
+            if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                raise ValueError(
+                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                    f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                )
+
+        # add batch stats if the model contains batchnorm layers
+        if "batch_stats" in flax_model.params:
+            if "mean" in flax_key[-1] or "var" in flax_key[-1]:
+                flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                continue
+            # remove num_batches_tracked key
+            if "num_batches_tracked" in flax_key[-1]:
+                flax_state_dict.pop(flax_key, None)
+                continue
+
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[("params",) + flax_key] = (
+                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+            )
+        else:
+            # also add unexpected weight so that warning is thrown
+            flax_state_dict[flax_key] = (
+                jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+            )
+
+    return unflatten_dict(flax_state_dict)
+
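+# Editorial note: an illustrative sketch, not part of the upstream module. The
+# conversion above returns a nested Flax-style dict; for instance, an embedding weight
+# keeps its shape but is renamed to Flax's "embedding" leaf (hypothetical key, assuming
+# `flax_model` defines a matching parameter):
+#
+#     pt_state_dict = {"embeddings.word_embeddings.weight": torch.zeros(10, 4)}
+#     flax_params = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
+#     # flax_params["embeddings"]["word_embeddings"]["embedding"].shape == (10, 4)
+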
+
+############################
+# Sharded Pytorch => Flax #
+############################
+
+
+def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
+    import torch
+
+    # Load the index
+    flax_state_dict = {}
+    for shard_file in shard_filenames:
+        # load using msgpack utils
+        check_torch_load_is_safe()
+        pt_state_dict = torch.load(shard_file, weights_only=True)
+        weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
+        pt_state_dict = {
+            k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+        }
+
+        model_prefix = flax_model.base_model_prefix
+
+        # use params dict if the model contains batch norm layers, then add batch_stats keys and values to the dict
+        if "batch_stats" in flax_model.params:
+            flax_model_params = flax_model.params["params"]
+
+            random_flax_state_dict = flatten_dict(flax_model_params)
+            random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
+        else:
+            flax_model_params = flax_model.params
+            random_flax_state_dict = flatten_dict(flax_model_params)
+
+        load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
+            model_prefix in {k.split(".")[0] for k in pt_state_dict}
+        )
+        load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
+            model_prefix not in {k.split(".")[0] for k in pt_state_dict}
+        )
+        # Need to change some parameter names to match Flax names
+        for pt_key, pt_tensor in pt_state_dict.items():
+            pt_tuple_key = tuple(pt_key.split("."))
+            is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
+
+            # remove base model prefix if necessary
+            has_base_model_prefix = pt_tuple_key[0] == model_prefix
+            if load_model_with_head_into_base_model and has_base_model_prefix:
+                pt_tuple_key = pt_tuple_key[1:]
+
+            # Correctly rename weight parameters
+            flax_key, flax_tensor = rename_key_and_reshape_tensor(
+                pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
+            )
+            # add model prefix if necessary
+            require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
+            if load_base_model_into_model_with_head and require_base_model_prefix:
+                flax_key = (model_prefix,) + flax_key
+
+            if flax_key in random_flax_state_dict:
+                if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
+                    raise ValueError(
+                        f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
+                        f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                    )
+
+            # add batch stats if the model contains batchnorm layers
+            if "batch_stats" in flax_model.params:
+                if "mean" in flax_key[-1]:
+                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                    continue
+                if "var" in flax_key[-1]:
+                    flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
+                    continue
+                # remove num_batches_tracked key
+                if "num_batches_tracked" in flax_key[-1]:
+                    flax_state_dict.pop(flax_key, None)
+                    continue
+
+                # also add unexpected weight so that warning is thrown
+                flax_state_dict[("params",) + flax_key] = (
+                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+                )
+
+            else:
+                # also add unexpected weight so that warning is thrown
+                flax_state_dict[flax_key] = (
+                    jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
+                )
+    return unflatten_dict(flax_state_dict)
+
+
+#####################
+# Flax => PyTorch #
+#####################
+
+
+def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
+    """Load flax checkpoints in a PyTorch model"""
+    flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
+    logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
+
+    # import correct flax class
+    flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
+
+    # load flax weight dict
+    if flax_checkpoint_path.endswith(".safetensors"):
+        flax_state_dict = safe_load_file(flax_checkpoint_path)
+        flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
+    else:
+        with open(flax_checkpoint_path, "rb") as state_f:
+            try:
+                flax_state_dict = from_bytes(flax_cls, state_f.read())
+            except UnpicklingError:
+                raise OSError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
+
+    return load_flax_weights_in_pytorch_model(model, flax_state_dict)
+
+
+def load_flax_weights_in_pytorch_model(pt_model, flax_state):
+    """Load flax checkpoints in a PyTorch model"""
+
+    try:
+        import torch  # noqa: F401
+    except (ImportError, ModuleNotFoundError):
+        logger.error(
+            "Loading a Flax weights in PyTorch, requires both PyTorch and Flax to be installed. Please see"
+            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/index.html#installation for installation"
+            " instructions."
+        )
+        raise
+
+    # check if we have bf16 weights
+    is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
+    if any(is_type_bf16):
+        # convert all weights to fp32 if they are bf16, since torch.from_numpy cannot handle bf16
+        # and bf16 is not fully supported in PT yet.
+        logger.warning(
+            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
+            "before loading those in PyTorch model."
+        )
+        flax_state = jax.tree_util.tree_map(
+            lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
+        )
+
+    flax_state_dict = flatten_dict(flax_state)
+    pt_model_dict = pt_model.state_dict()
+
+    load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
+        pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict}
+    )
+    load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
+        pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict}
+    )
+
+    # keep track of unexpected & missing keys
+    unexpected_keys = []
+    missing_keys = set(pt_model_dict.keys())
+
+    for flax_key_tuple, flax_tensor in flax_state_dict.items():
+        has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
+        require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
+
+        # adapt flax_key to prepare for loading from/to base model only
+        if load_model_with_head_into_base_model and has_base_model_prefix:
+            flax_key_tuple = flax_key_tuple[1:]
+        elif load_base_model_into_model_with_head and require_base_model_prefix:
+            flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
+
+        # rename flax weights to PyTorch format
+        if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict:
+            # conv layer
+            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
+            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
+        elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
+            # linear layer
+            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
+            flax_tensor = flax_tensor.T
+        elif flax_key_tuple[-1] in ["scale", "embedding"]:
+            flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
+
+        # adding batch stats from flax batch norm to pt
+        elif "mean" in flax_key_tuple[-1]:
+            flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
+        elif "var" in flax_key_tuple[-1]:
+            flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)
+
+        if "batch_stats" in flax_state:
+            flax_key = ".".join(flax_key_tuple[1:])  # Remove the params/batch_stats header
+        else:
+            flax_key = ".".join(flax_key_tuple)
+
+        # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
+        special_pt_names = {}
+        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+        for key in pt_model_dict:
+            key_components = key.split(".")
+            name = None
+            if key_components[-3::2] == ["parametrizations", "original0"]:
+                name = key_components[-2] + "_g"
+            elif key_components[-3::2] == ["parametrizations", "original1"]:
+                name = key_components[-2] + "_v"
+            if name is not None:
+                key_components = key_components[:-3] + [name]
+                key_to_check = ".".join(key_components)
+                special_pt_names[key_to_check] = key
+
+        if flax_key in special_pt_names:
+            flax_key = special_pt_names[flax_key]
+
+        if flax_key in pt_model_dict:
+            if flax_tensor.shape != pt_model_dict[flax_key].shape:
+                raise ValueError(
+                    f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
+                    f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
+                )
+            else:
+                # add weight to pytorch dict
+                flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
+                pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
+                # remove from missing keys
+                missing_keys.remove(flax_key)
+        else:
+            # weight is not expected by PyTorch model
+            unexpected_keys.append(flax_key)
+
+    pt_model.load_state_dict(pt_model_dict)
+
+    # re-transform missing_keys to list
+    missing_keys = list(missing_keys)
+
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the Flax model were not used when initializing the PyTorch model"
+            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
+            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
+            f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
+            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
+            " FlaxBertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
+            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
+            " use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
+        )
+
+    return pt_model
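+
+# Editorial note: an illustrative usage sketch, not part of the upstream module. The
+# reverse direction in one call, loading a Flax msgpack or safetensors checkpoint into
+# a PyTorch model instance; the model class and path below are hypothetical:
+#
+#     pt_model = BertModel(config)
+#     pt_model = load_flax_checkpoint_in_pytorch_model(pt_model, "./flax_model/flax_model.msgpack")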
diff --git a/phivenv/Lib/site-packages/transformers/modeling_flax_utils.py b/phivenv/Lib/site-packages/transformers/modeling_flax_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9a4d473f36f95bb13d6de17dc0bfa7cfae2279
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_flax_utils.py
@@ -0,0 +1,1274 @@
+# coding=utf-8
+# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import gc
+import json
+import os
+import warnings
+from functools import partial
+from pickle import UnpicklingError
+from typing import Any, Optional, Union
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import msgpack.exceptions
+from flax.core.frozen_dict import FrozenDict, unfreeze
+from flax.serialization import from_bytes, to_bytes
+from flax.traverse_util import flatten_dict, unflatten_dict
+from jax.random import PRNGKey
+
+from .configuration_utils import PretrainedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import FlaxGenerationMixin, GenerationConfig
+from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
+from .utils import (
+    FLAX_WEIGHTS_INDEX_NAME,
+    FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    PushToHubMixin,
+    add_code_sample_docstrings,
+    add_start_docstrings_to_model_forward,
+    cached_file,
+    copy_func,
+    download_url,
+    has_file,
+    is_offline_mode,
+    is_remote_url,
+    logging,
+    replace_return_docstrings,
+)
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+from .utils.import_utils import is_safetensors_available
+
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.flax import load_file as safe_load_file
+    from safetensors.flax import save_file as safe_save_file
+
+logger = logging.get_logger(__name__)
+
+
+def quick_gelu(x):
+    return x * jax.nn.sigmoid(1.702 * x)
+
+
+ACT2FN = {
+    "gelu": partial(nn.gelu, approximate=False),
+    "relu": nn.relu,
+    "silu": nn.swish,
+    "swish": nn.swish,
+    "gelu_new": partial(nn.gelu, approximate=True),
+    "quick_gelu": quick_gelu,
+    "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
+    "tanh": nn.tanh,
+}
+
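+# Editorial note: an illustrative sketch, not part of the upstream module. ACT2FN maps
+# the activation name found in a model config to a callable, e.g.:
+#
+#     act_fn = ACT2FN["gelu_new"]
+#     y = act_fn(jnp.linspace(-2.0, 2.0, 5))   # tanh-approximated GELU, applied elementwise
+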
+
+def flax_shard_checkpoint(params, max_shard_size="10GB"):
+    """
+    Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+    given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
+    there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
+    example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
+    [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+    
+
+    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
+    have a size greater than `max_shard_size`.
+
+    
+
+    Args:
+        params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
+        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+            (like `"5MB"`).
+    """
+    max_shard_size = convert_file_size_to_int(max_shard_size)
+
+    sharded_state_dicts = []
+    current_block = {}
+    current_block_size = 0
+    total_size = 0
+
+    # flatten the weights to chunk
+    weights = flatten_dict(params, sep="/")
+    for item in weights:
+        weight_size = weights[item].size * weights[item].dtype.itemsize
+
+        # If this weight is going to tip over the maximal size, we split.
+        if current_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append(current_block)
+            current_block = {}
+            current_block_size = 0
+
+        current_block[item] = weights[item]
+        current_block_size += weight_size
+        total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
+        shards[shard_file] = shard
+        for weight_name in shard:
+            weight_map[weight_name] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
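+# Editorial note: an illustrative sketch, not part of the upstream module. With weights
+# of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] and max_shard_size="10GB", the greedy split
+# above yields shards of [6GB], [6GB + 2GB], [6GB + 2GB + 2GB]. A typical call:
+#
+#     shards, index = flax_shard_checkpoint(model.params, max_shard_size="10GB")
+#     # index is None for a single shard; otherwise index["weight_map"] maps each
+#     # "/"-joined weight name to its "...-0000k-of-0000n.msgpack" shard file.
+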
+
+class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
+    r"""
+    Base class for all models.
+
+    [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+    """
+
+    config_class = None
+    base_model_prefix = ""
+    main_input_name = "input_ids"
+    _auto_class = None
+    _missing_keys = set()
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        module: nn.Module,
+        input_shape: tuple = (1, 1),
+        seed: int = 0,
+        dtype: jnp.dtype = jnp.float32,
+        _do_init: bool = True,
+    ):
+        logger.warning_once(
+            "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+            "recommend migrating to PyTorch classes or pinning your version of Transformers."
+        )
+        if config is None:
+            raise ValueError("config cannot be None")
+
+        if module is None:
+            raise ValueError("module cannot be None")
+
+        # Those are private to be exposed as typed property on derived classes.
+        self._config = config
+        self._module = module
+
+        # Those are public as their type is generic to every derived classes.
+        self.key = PRNGKey(seed)
+        self.dtype = dtype
+        self.input_shape = input_shape
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+
+        # To check if the model was initialized automatically.
+        self._is_initialized = _do_init
+
+        if _do_init:
+            # randomly initialized parameters
+            random_params = self.init_weights(self.key, input_shape)
+            params_shape_tree = jax.eval_shape(lambda params: params, random_params)
+        else:
+            init_fn = partial(self.init_weights, input_shape=input_shape)
+            params_shape_tree = jax.eval_shape(init_fn, self.key)
+
+            logger.info(
+                "Model weights are not initialized as `_do_init` is set to `False`. "
+                f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
+            )
+
+        # get the shape of the parameters
+        self._params_shape_tree = params_shape_tree
+
+        # save required_params as set
+        self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
+
+        # initialize the parameters
+        if _do_init:
+            self.params = random_params
+
+    def init_weights(self, rng: jax.random.PRNGKey, input_shape: tuple, params: FrozenDict = None) -> dict:
+        raise NotImplementedError(f"init method has to be implemented for {self}")
+
+    def enable_gradient_checkpointing(self):
+        raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
+
+    @classmethod
+    def _from_config(cls, config, **kwargs):
+        """
+        All context managers that the model should be initialized under go here.
+        """
+        return cls(config, **kwargs)
+
+    @property
+    def framework(self) -> str:
+        """
+        :str: Identifies that this is a Flax model.
+        """
+        return "flax"
+
+    @property
+    def config(self) -> PretrainedConfig:
+        return self._config
+
+    @property
+    def module(self) -> nn.Module:
+        return self._module
+
+    @property
+    def params(self) -> Union[dict, FrozenDict]:
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
+                "You must call `init_weights` manually and store the params outside of the model and "
+                "pass it explicitly where needed."
+            )
+        return self._params
+
+    @property
+    def required_params(self) -> set:
+        return self._required_params
+
+    @property
+    def params_shape_tree(self) -> dict:
+        return self._params_shape_tree
+
+    @params.setter
+    def params(self, params: Union[dict, FrozenDict]):
+        # don't set params if the model is not initialized
+        if not self._is_initialized:
+            raise ValueError(
+                "`params` cannot be set from model when the model is created with `_do_init=False`. "
+                "You store the params outside of the model."
+            )
+
+        if isinstance(params, FrozenDict):
+            params = unfreeze(params)
+        param_keys = set(flatten_dict(params).keys())
+        if len(self.required_params - param_keys) > 0:
+            raise ValueError(
+                "Some parameters are missing. Make sure that `params` include the following "
+                f"parameters {self.required_params - param_keys}"
+            )
+        self._params = params
+
+    def _cast_floating_to(self, params: Union[dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
+        """
+        Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
+        """
+
+        # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
+        def conditional_cast(param):
+            if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
+                param = param.astype(dtype)
+            return param
+
+        if mask is None:
+            return jax.tree_util.tree_map(conditional_cast, params)
+
+        flat_params = flatten_dict(params)
+        flat_mask, _ = jax.tree_util.tree_flatten(mask)
+
+        for masked, key in zip(flat_mask, sorted(flat_params.keys())):
+            if masked:
+                flat_params[key] = conditional_cast(flat_params[key])
+
+        return unflatten_dict(flat_params)
+
+    def to_bf16(self, params: Union[dict, FrozenDict], mask: Any = None):
+        r"""
+        Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
+        the `params` in place.
+
+        This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
+        half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
+
+        Arguments:
+            params (`Union[Dict, FrozenDict]`):
+                A `PyTree` of model parameters.
+            mask (`Union[Dict, FrozenDict]`):
+                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+                you want to cast, and should be `False` for those you want to skip.
+
+        Examples:
+
+        ```python
+        >>> from transformers import FlaxBertModel
+
+        >>> # load model
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
+        >>> model.params = model.to_bf16(model.params)
+        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
+        >>> # then pass the mask as follows
+        >>> from flax import traverse_util
+
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> flat_params = traverse_util.flatten_dict(model.params)
+        >>> mask = {
+        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+        ...     for path in flat_params
+        ... }
+        >>> mask = traverse_util.unflatten_dict(mask)
+        >>> model.params = model.to_bf16(model.params, mask)
+        ```"""
+        return self._cast_floating_to(params, jnp.bfloat16, mask)
+
+    def to_fp32(self, params: Union[dict, FrozenDict], mask: Any = None):
+        r"""
+        Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
+        model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
+
+        Arguments:
+            params (`Union[Dict, FrozenDict]`):
+                A `PyTree` of model parameters.
+            mask (`Union[Dict, FrozenDict]`):
+                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+                you want to cast, and should be `False` for those you want to skip
+
+        Examples:
+
+        ```python
+        >>> from transformers import FlaxBertModel
+
+        >>> # Download model and configuration from huggingface.co
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> # By default, the model params will be in fp32, to illustrate the use of this method,
+        >>> # we'll first cast to fp16 and back to fp32
+        >>> model.params = model.to_fp16(model.params)
+        >>> # now cast back to fp32
+        >>> model.params = model.to_fp32(model.params)
+        ```"""
+        return self._cast_floating_to(params, jnp.float32, mask)
+
+    def to_fp16(self, params: Union[dict, FrozenDict], mask: Any = None):
+        r"""
+        Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
+        `params` in place.
+
+        This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
+        half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
+
+        Arguments:
+            params (`Union[Dict, FrozenDict]`):
+                A `PyTree` of model parameters.
+            mask (`Union[Dict, FrozenDict]`):
+                A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
+                you want to cast, and should be `False` for those you want to skip
+
+        Examples:
+
+        ```python
+        >>> from transformers import FlaxBertModel
+
+        >>> # load model
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> # By default, the model params will be in fp32, to cast these to float16
+        >>> model.params = model.to_fp16(model.params)
+        >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
+        >>> # then pass the mask as follows
+        >>> from flax import traverse_util
+
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> flat_params = traverse_util.flatten_dict(model.params)
+        >>> mask = {
+        ...     path: (path[-2:] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
+        ...     for path in flat_params
+        ... }
+        >>> mask = traverse_util.unflatten_dict(mask)
+        >>> model.params = model.to_fp16(model.params, mask)
+        ```"""
+        return self._cast_floating_to(params, jnp.float16, mask)
+
+    @classmethod
+    def load_flax_weights(cls, resolved_archive_file):
+        try:
+            if resolved_archive_file.endswith(".safetensors"):
+                state = safe_load_file(resolved_archive_file)
+                state = unflatten_dict(state, sep=".")
+            else:
+                with open(resolved_archive_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+        except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+            try:
+                with open(resolved_archive_file) as f:
+                    if f.read().startswith("version"):
+                        raise OSError(
+                            "You seem to have cloned a repository without having git-lfs installed. Please"
+                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                            " folder you cloned."
+                        )
+                    else:
+                        raise ValueError from e
+            except (UnicodeDecodeError, ValueError):
+                raise OSError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
+
+        return state
+
+    @classmethod
+    def load_flax_sharded_weights(cls, shard_files):
+        """
+        This is the same as [`flax.serialization.from_bytes`]
+        (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
+
+        This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+        loaded in the model.
+
+        Args:
+            shard_files (`list[str]`):
+                The list of shard files to load.
+
+        Returns:
+            `Dict`: A nested dictionary of the model parameters, in the expected format for flax models: `{'model':
+            {'params': {'...'}}}`.
+        """
+
+        # Load the index
+        state_sharded_dict = {}
+
+        for shard_file in shard_files:
+            # load using msgpack utils
+            try:
+                with open(shard_file, "rb") as state_f:
+                    state = from_bytes(cls, state_f.read())
+            except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
+                with open(shard_file) as f:
+                    if f.read().startswith("version"):
+                        raise OSError(
+                            "You seem to have cloned a repository without having git-lfs installed. Please"
+                            " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
+                            " folder you cloned."
+                        )
+                    else:
+                        raise ValueError from e
+            except (UnicodeDecodeError, ValueError):
+                raise OSError(f"Unable to convert {shard_file} to Flax deserializable object. ")
+
+            state = flatten_dict(state, sep="/")
+            state_sharded_dict.update(state)
+            del state
+            gc.collect()
+
+        # the state dict is unflattened to match the format of model.params
+        return unflatten_dict(state_sharded_dict, sep="/")
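+
+    # Editorial note: an illustrative sketch, not part of the upstream module. Each shard
+    # is a msgpack file holding a flat slice of the parameter tree; once merged, the keys
+    # are unflattened back into nested dicts, e.g. (hypothetical key):
+    #
+    #     {"encoder/layer/0/attention/kernel": arr}  ->  {"encoder": {"layer": {"0": {"attention": {"kernel": arr}}}}}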
+
+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+            return False
+        return True
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        dtype: jnp.dtype = jnp.float32,
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        **kwargs,
+    ):
+        r"""
+        Instantiate a pretrained flax model from a pre-trained model configuration.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or url to a *pt index checkpoint file* (e.g., `./tf_model/model.ckpt.index`). In this case,
+                      `from_pt` should be set to `True`.
+            dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
+                The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
+                `jax.numpy.bfloat16` (on TPUs).
+
+                This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
+                specified all the computation will be performed with the given `dtype`.
+
+                **Note that this only specifies the dtype of the computation and does not influence the dtype of model
+                parameters.**
+
+                If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
+                [`~FlaxPreTrainedModel.to_bf16`].
+            model_args (sequence of positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+                Can be either:
+
+                    - an instance of a class derived from [`PretrainedConfig`],
+                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+
+                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+                be automatically loaded when:
+
+                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+                      model).
+                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+                      save directory.
+                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+                      configuration JSON file named *config.json* is found in the directory.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+                checkpoint with 3 labels).
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+
+                
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+                
+
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
+                      already been done)
+                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+                      corresponds to a configuration attribute will be used to override said attribute with the
+                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+                      will be passed to the underlying model's `__init__` function.
+
+        Examples:
+
+        ```python
+        >>> from transformers import BertConfig, FlaxBertModel
+
+        >>> # Download model and configuration from huggingface.co and cache.
+        >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
+        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+        >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
+        >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+        >>> config = BertConfig.from_json_file("./pt_model/config.json")
+        >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
+        ```"""
+        from_pt = kwargs.pop("from_pt", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+        _do_init = kwargs.pop("_do_init", True)
+        subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)
+
+        # Not relevant for Flax Models
+        _ = kwargs.pop("adapter_kwargs", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if trust_remote_code is True:
+            logger.warning(
+                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
+                " ignored."
+            )
+
+        user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs.copy()
+
+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
+        # Add the dtype to model_kwargs
+        model_kwargs["dtype"] = dtype
+
+        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+        # index of the files.
+        is_sharded = False
+
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+            is_local = os.path.isdir(pretrained_model_name_or_path)
+            if is_local:
+                if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
+                    # Load from a Flax checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded Flax checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+                elif is_safetensors_available() and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+                ):
+                    # Load from a safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, SAFE_WEIGHTS_NAME)
+                elif is_safetensors_available() and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+                ):
+                    # Load from a safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
+                    # Load from a PyTorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
+                ):
+                    # Load from a sharded pytorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+                # At this stage we don't have a weight file so we will raise an error.
+                elif is_safetensors_available() and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                ):
+                    # Load from a sharded safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+                    raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
+                    raise OSError(
+                        f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
+                        "weights."
+                    )
+                else:
+                    raise OSError(
+                        f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+                        f"{pretrained_model_name_or_path}."
+                    )
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+                archive_file = pretrained_model_name_or_path
+                is_local = True
+            elif is_remote_url(pretrained_model_name_or_path):
+                filename = pretrained_model_name_or_path
+                resolved_archive_file = download_url(pretrained_model_name_or_path)
+            else:
+                if from_pt:
+                    filename = WEIGHTS_NAME
+                else:
+                    filename = FLAX_WEIGHTS_NAME
+
+                try:
+                    # Load from URL or cache if already cached
+                    cached_file_kwargs = {
+                        "cache_dir": cache_dir,
+                        "force_download": force_download,
+                        "proxies": proxies,
+                        "resume_download": resume_download,
+                        "local_files_only": local_files_only,
+                        "token": token,
+                        "user_agent": user_agent,
+                        "revision": revision,
+                        "subfolder": subfolder,
+                        "_raise_exceptions_for_gated_repo": False,
+                        "_raise_exceptions_for_missing_entries": False,
+                        "_commit_hash": commit_hash,
+                    }
+                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                    # Maybe the checkpoint is sharded; we try to grab the index name in this case.
+                    if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+
+                    # Maybe the checkpoint is PyTorch sharded; we try to grab the PyTorch index name in this case.
+                    if resolved_archive_file is None and from_pt:
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+
+                    # If we still haven't found anything, look for `safetensors`.
+                    if resolved_archive_file is None:
+                        # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+                        filename = SAFE_WEIGHTS_NAME
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
+                        )
+
+                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                    # result when internet is up, the repo and revision exist, but the file does not.
+                    if resolved_archive_file is None:
+                        # Otherwise, maybe there is a TF or Torch model file.  We try those to give a helpful error
+                        # message.
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                            "cache_dir": cache_dir,
+                            "local_files_only": local_files_only,
+                        }
+                        if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                            is_sharded = True
+                            raise NotImplementedError(
+                                "Support for sharded checkpoints using safetensors is coming soon!"
+                            )
+                        elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
+                                " load this model from those weights."
+                            )
+                        elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
+                                " `from_pt=True` to load this model from those weights."
+                            )
+                        else:
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+                            )
+                except OSError:
+                    # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+                    # to the original exception.
+                    raise
+                except Exception:
+                    # For any other exception, we throw a generic error.
+                    raise OSError(
+                        f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                        f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                        f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
+                    )
+
+            if is_local:
+                logger.info(f"loading weights file {archive_file}")
+                resolved_archive_file = archive_file
+                filename = resolved_archive_file.split(os.path.sep)[-1]
+            else:
+                logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+        else:
+            resolved_archive_file = None
+
+        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+        if is_sharded:
+            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            resolved_archive_file, _ = get_checkpoint_shard_files(
+                pretrained_model_name_or_path,
+                resolved_archive_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                token=token,
+                user_agent=user_agent,
+                revision=revision,
+                subfolder=subfolder,
+                _commit_hash=commit_hash,
+            )
+
+        safetensors_from_pt = False
+        if filename == SAFE_WEIGHTS_NAME:
+            with safe_open(resolved_archive_file, framework="flax") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
+        # init random models
+        model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
+
+        if from_pt or safetensors_from_pt:
+            state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
+        else:
+            if is_sharded:
+                state = cls.load_flax_sharded_weights(resolved_archive_file)
+            else:
+                state = cls.load_flax_weights(resolved_archive_file)
+            # make sure all arrays are stored as jnp.arrays
+            # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
+            # https://github.com/google/flax/issues/1261
+            if _do_init:
+                state = jax.tree_util.tree_map(jnp.array, state)
+            else:
+                # keep the params on CPU if we don't want to initialize
+                state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
+
+        if "batch_stats" in state:  # if flax model contains batch norm layers
+            # if model is base model only use model_prefix key
+            if (
+                cls.base_model_prefix not in dict(model.params_shape_tree["params"])
+                and cls.base_model_prefix in state["params"]
+            ):
+                state["params"] = state["params"][cls.base_model_prefix]
+                state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
+
+            # if model is head model and we are loading weights from base model
+            # we initialize new params dict with base_model_prefix
+            if (
+                cls.base_model_prefix in dict(model.params_shape_tree["params"])
+                and cls.base_model_prefix not in state["params"]
+            ):
+                state = {
+                    "params": {cls.base_model_prefix: state["params"]},
+                    "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
+                }
+
+        else:
+            # if model is base model only use model_prefix key
+            if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
+                state = state[cls.base_model_prefix]
+
+            # if model is head model and we are loading weights from base model
+            # we initialize new params dict with base_model_prefix
+            if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
+                state = {cls.base_model_prefix: state}
+
+        # flatten dicts
+        state = flatten_dict(state)
+
+        random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
+
+        missing_keys = model.required_params - set(state.keys())
+        unexpected_keys = set(state.keys()) - model.required_params
+
+        # Disable the warning when porting PyTorch weights to Flax; Flax does not use num_batches_tracked
+        for unexpected_key in unexpected_keys.copy():
+            if "num_batches_tracked" in unexpected_key[-1]:
+                unexpected_keys.remove(unexpected_key)
+
+        if missing_keys and not _do_init:
+            logger.warning(
+                f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
+                "Make sure to call model.init_weights to initialize the missing weights."
+            )
+            cls._missing_keys = missing_keys
+
+        # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
+        # matching the weights in the model.
+        mismatched_keys = []
+        for key in state:
+            if key in random_state and state[key].shape != random_state[key].shape:
+                if ignore_mismatched_sizes:
+                    mismatched_keys.append((key, state[key].shape, random_state[key].shape))
+                    state[key] = random_state[key]
+                else:
+                    raise ValueError(
+                        f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
+                        f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
+                        "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
+                        "model."
+                    )
+
+        # add missing keys as random parameters if we are initializing
+        if missing_keys and _do_init:
+            for missing_key in missing_keys:
+                state[missing_key] = random_state[missing_key]
+
+        # remove unexpected keys to not be saved again
+        for unexpected_key in unexpected_keys:
+            del state[unexpected_key]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
+        # dictionary of key: dtypes for the model params
+        param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
+        # extract keys of parameters not in jnp.float32
+        fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
+        bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
+
+        # raise a warning if any of the parameters are not in jnp.float32
+        if len(fp16_params) > 0:
+            logger.warning(
+                f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
+                f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
+                "You should probably UPCAST the model weights to float32 if this was not intended. "
+                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
+            )
+
+        if len(bf16_params) > 0:
+            logger.warning(
+                f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
+                f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
+                "You should probably UPCAST the model weights to float32 if this was not intended. "
+                "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
+            )
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
+        if _do_init:
+            # set correct parameters
+            model.params = unflatten_dict(state)
+            return model
+        else:
+            return model, unflatten_dict(state)
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        params=None,
+        push_to_hub=False,
+        max_shard_size="10GB",
+        token: Optional[Union[str, bool]] = None,
+        safe_serialization: bool = False,
+        **kwargs,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~FlaxPreTrainedModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be of a size
+                lower than this limit. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
+
+                
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+                which will be bigger than `max_shard_size`.
+
+                
+
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or through msgpack.
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # get abs dir
+        save_directory = os.path.abspath(save_directory)
+        # save config as well
+        self.config.architectures = [self.__class__.__name__[4:]]
+
+        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self.config)
+
+        self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)
+
+        # save model
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
+        output_model_file = os.path.join(save_directory, weights_name)
+
+        shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards:
+                os.remove(full_filename)
+
+        if index is None:
+            if safe_serialization:
+                params = params if params is not None else self.params
+                flat_dict = flatten_dict(params, sep=".")
+                safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
+            else:
+                with open(output_model_file, "wb") as f:
+                    params = params if params is not None else self.params
+                    model_bytes = to_bytes(params)
+                    f.write(model_bytes)
+
+        else:
+            save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
+            for shard_file, shard in shards.items():
+                # the shard items are flattened dicts; we need to unflatten them before serializing
+                with open(os.path.join(save_directory, shard_file), mode="wb") as f:
+                    params = unflatten_dict(shard, sep="/")
+                    shard_bytes = to_bytes(params)
+                    f.write(shard_bytes)
+
+        logger.info(f"Model weights saved in {output_model_file}")
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+            )
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
+        """
+        Register this class with a given auto class. This should only be used for custom models as the ones in the
+        library are already mapped with an auto class.
+
+
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
+                The auto class to register this new model with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+
+# To update the docstring, we need to copy the method, otherwise we change the original docstring.
+FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
+if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
+    FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
+        object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
+    )
+
+
+def overwrite_call_docstring(model_class, docstring):
+    # copy __call__ function to be sure docstring is changed only for this function
+    model_class.__call__ = copy_func(model_class.__call__)
+    # delete existing docstring
+    model_class.__call__.__doc__ = None
+    # set correct docstring
+    model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
+
+
+def append_call_sample_docstring(
+    model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
+):
+    model_class.__call__ = copy_func(model_class.__call__)
+    model_class.__call__ = add_code_sample_docstrings(
+        checkpoint=checkpoint,
+        output_type=output_type,
+        config_class=config_class,
+        model_cls=model_class.__name__,
+        revision=revision,
+        real_checkpoint=real_checkpoint,
+    )(model_class.__call__)
+
+
+def append_replace_return_docstrings(model_class, output_type, config_class):
+    model_class.__call__ = copy_func(model_class.__call__)
+    model_class.__call__ = replace_return_docstrings(
+        output_type=output_type,
+        config_class=config_class,
+    )(model_class.__call__)
diff --git a/phivenv/Lib/site-packages/transformers/modeling_gguf_pytorch_utils.py b/phivenv/Lib/site-packages/transformers/modeling_gguf_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef2725c10b0559d18859d5aa6c2c591ef38db2b
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_gguf_pytorch_utils.py
@@ -0,0 +1,500 @@
+# Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
+# https://github.com/99991/pygguf
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import NamedTuple, Optional
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from .integrations import (
+    GGUF_CONFIG_MAPPING,
+    GGUF_TOKENIZER_MAPPING,
+    _gguf_parse_value,
+)
+from .utils import is_torch_available
+from .utils.import_utils import is_gguf_available
+from .utils.logging import get_logger
+
+
+if is_torch_available():
+    import torch
+
+logger = get_logger(__name__)
+
+
+GGUF_TO_TRANSFORMERS_MAPPING = {
+    "ignore": {
+        "GGUF": {
+            "version": "version",
+            "tensor_count": "tensor_count",
+            "kv_count": "kv_count",
+        },
+        "general": {"file_type": "file_type", "quantization_version": "quantization_version"},
+    },
+    "config": GGUF_CONFIG_MAPPING,
+    "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
+    "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
+}
+
+GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["config"].keys())
+
+
+class GGUFTensor(NamedTuple):
+    weights: np.ndarray
+    name: str
+    metadata: dict
+
+
+class TensorProcessor:
+    def __init__(self, config=None):
+        self.config = config or {}
+
+    def process(self, weights, name, **kwargs):
+        return GGUFTensor(weights, name, {})
+
+
+class LlamaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if ".attn_k." in name or ".attn_q." in name:
+            num_heads = self.config.get("num_attention_heads")
+            num_kv_heads = self.config.get("num_key_value_heads")
+
+            if None in (num_heads, num_kv_heads):
+                return GGUFTensor(weights, name, {})
+            if ".attn_q." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_permute_weights(
+        self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
+    ) -> np.ndarray:
+        # Original permutation implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
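+        # Undo llama.cpp's rotary permutation: the converter interleaves the two rotary halves of each
+        # attention head, so reshaping to (n_head, dim, 2, ...) and swapping those axes restores the
+        # original HF row ordering of the Q/K projection weights.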
+        if num_kv_heads is not None and n_head != num_kv_heads:
+            n_head = num_kv_heads
+
+        dim = weights.shape[0] // n_head // 2
+        w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+        return w.swapaxes(2, 1).reshape(weights.shape)
+
+
+class Qwen2MoeTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "_exp" in name:
+            tensor_key_mapping = kwargs.get("tensor_key_mapping")
+            parsed_parameters = kwargs.get("parsed_parameters")
+            if tensor_key_mapping:
+                self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
+                return GGUFTensor(weights, None, {})
+        if "ffn_gate_inp_shexp" in name:
+            # For compatibility the shared_expert_gate tensor must have shape (1, 2048);
+            # the quantized one has shape (2048,)
+            weights = np.expand_dims(weights, axis=0)
+        return GGUFTensor(weights, name, {})
+
+    def _split_moe_expert_tensor(
+        self, weights: np.ndarray, parsed_parameters: dict[str, dict], name: str, tensor_key_mapping: dict
+    ):
+        # Original merge implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
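+        # GGUF packs all experts of a layer into a single stacked tensor; split it back into the
+        # per-expert parameters expected by the HF MoE modules.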
+        name = tensor_key_mapping[name]
+        w_counter = self.config.get("num_experts", 60)
+        for i in range(0, w_counter):
+            temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.")
+            exp_weight = weights[i]
+            parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
+
+
+class BloomTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "attn_qkv" in name:
+            num_heads = self.config["n_head"]
+            n_embed = self.config["hidden_size"]
+            if "weight" in name:
+                weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
+            else:
+                weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
+        return GGUFTensor(weights, name, {})
+
+    def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
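+        # llama.cpp stores the fused QKV projection as contiguous Q, K and V blocks, while HF Bloom
+        # expects a per-head interleaved [q, k, v] layout; split, reshape and re-stack per head.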
+        q, k, v = np.array_split(weights, 3, axis=0)
+
+        q = q.reshape(n_head, n_embed // n_head, n_embed)
+        k = k.reshape(n_head, n_embed // n_head, n_embed)
+        v = v.reshape(n_head, n_embed // n_head, n_embed)
+        qkv_weights = np.stack([q, k, v], axis=1)
+
+        return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
+
+    def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
+        # Original reshape implementation
+        # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
+        q_bias, k_bias, v_bias = np.array_split(weights, 3)
+
+        q_bias = q_bias.reshape(n_head, n_embed // n_head)
+        k_bias = k_bias.reshape(n_head, n_embed // n_head)
+        v_bias = v_bias.reshape(n_head, n_embed // n_head)
+
+        qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
+        return qkv_bias
+
+
+class T5TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        bid = None
+        for chunk in name.split("."):
+            if chunk.isdigit():
+                bid = int(chunk)
+                break
+        return GGUFTensor(weights, name, {"bid": bid})
+
+
+class GPT2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        # Original transpose implementation
+        # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
+        if (
+            "attn_qkv.weight" in name
+            or "ffn_down.weight" in name
+            or "ffn_up.weight" in name
+            or "attn_output.weight" in name
+        ):
+            weights = weights.T
+
+        # Handle special case for output.weight
+        if name == "output.weight":
+            # output.weight has conflicts with attn_output.weight in name checking
+            # Store the tensor directly and signal to skip further processing
+            name = "lm_head.weight"
+            parsed_parameters = kwargs.get("parsed_parameters", {})
+            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+            name = None  # Signal to skip further processing
+        return GGUFTensor(weights, name, {})
+
+
+class MambaTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "ssm_conv1d.weight" in name:
+            # For compatibility the ssm_conv1d tensor must have shape (5120, 1, 4);
+            # the quantized one has shape (5120, 4)
+            weights = np.expand_dims(weights, axis=1)
+        if "ssm_a" in name:
+            # Original exponential implementation
+            # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
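+            # The GGUF converter stores A as -exp(A_log); log(-weights) recovers the A_log
+            # parameter expected by the HF Mamba implementation.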
+            weights = np.log(-weights)
+        return GGUFTensor(weights, name, {})
+
+
+class NemotronTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    # ref : https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L4666
+    def process(self, weights, name, **kwargs):
+        if "norm.weight" in name:
+            weights = weights - 1
+        return GGUFTensor(weights, name, {})
+
+
+class Gemma2TensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191
+    # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
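+    # The converter folds the "+1" of Gemma's RMSNorm (which computes `1 + weight` at runtime) into the
+    # stored norm weights, so it is subtracted here to recover the HF parameterization.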
+    def process(self, weights, name, **kwargs):
+        if "norm.weight" in name:
+            weights = weights - 1
+        return GGUFTensor(weights, name, {})
+
+
+TENSOR_PROCESSORS = {
+    "llama": LlamaTensorProcessor,
+    "qwen2moe": Qwen2MoeTensorProcessor,
+    "qwen3moe": Qwen2MoeTensorProcessor,
+    "bloom": BloomTensorProcessor,
+    "t5": T5TensorProcessor,
+    "t5encoder": T5TensorProcessor,
+    "gpt2": GPT2TensorProcessor,
+    "mamba": MambaTensorProcessor,
+    "nemotron": NemotronTensorProcessor,
+    "gemma2": Gemma2TensorProcessor,
+    "gemma3": Gemma2TensorProcessor,
+}
+
+
+def read_field(reader, field):
+    if field not in reader.fields:
+        return []
+    value = reader.fields[field]
+    return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
+
+
+# modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147
+def get_gguf_hf_weights_map(
+    hf_model,
+    model_type: Optional[str] = None,
+    num_layers: Optional[int] = None,
+    qual_name: str = "",
+):
+    """
+    GGUF uses this naming convention for tensors converted from an HF checkpoint:
+    `blk.N.BB.weight` and `blk.N.BB.bias`
+    where N signifies the block number of a layer, and BB signifies the
+    attention/mlp layer components.
+    See "Standardized tensor names" in
+    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
+    """
+    if is_gguf_available() and is_torch_available():
+        from gguf import MODEL_ARCH_NAMES, get_tensor_name_map
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    model_type = hf_model.config.model_type if model_type is None else model_type
+    num_layers = hf_model.config.num_hidden_layers if num_layers is None else num_layers
+    # hack: ggufs have a different name for cohere
+    if model_type == "cohere":
+        model_type = "command-r"
+    elif model_type == "qwen2_moe":
+        model_type = "qwen2moe"
+    elif model_type == "qwen3_moe":
+        model_type = "qwen3moe"
+    elif model_type == "gemma3_text":
+        model_type = "gemma3"
+    arch = None
+    for key, value in MODEL_ARCH_NAMES.items():
+        if value == model_type:
+            arch = key
+            break
+    if arch is None:
+        raise NotImplementedError(
+            f"Unknown gguf model_type: {model_type} in gguf-py. "
+            "This might because you're using an outdated version of gguf-py package, "
+            "you can install `gguf` package from source refer to "
+            "https://github.com/ggerganov/llama.cpp/tree/master/gguf-py#development"
+        )
+    name_map = get_tensor_name_map(arch, num_layers)
+
+    # Use a dummy conversion to get the mapping, because
+    # hf => gguf and gguf => hf mappings are reversed
+    gguf_to_hf_name_map = {}
+    state_dict = hf_model.state_dict()
+    for hf_name in state_dict:
+        # An exception for qwen2moe/qwen3moe model, where the expert layers are packed
+        if model_type in ("qwen2moe", "qwen3moe") and "mlp.experts." in hf_name:
+            hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
+
+        name, suffix = hf_name, ""
+        if hf_name.endswith(".weight") or hf_name.endswith(".bias"):
+            name, suffix = hf_name.rsplit(".", 1)
+            suffix = "." + suffix
+
+        gguf_name = name_map.get_name(name)
+        if gguf_name is None:
+            continue
+
+        gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name
+
+    # Some models, like Bloom, are converted from BloomModel instead of BloomForCausalLM.
+    # Therefore, we also need to check submodules to get a correct mapping
+    if named_children := hf_model.named_children():
+        for name, child in named_children:
+            sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.")
+            # Ignore the keys that are already in the main map to avoid overwriting
+            sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map}
+            gguf_to_hf_name_map.update(sub_map)
+
+    return gguf_to_hf_name_map
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
+    tokenizer and config attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path to the GGUF file to load.
+        return_tensors (`bool`, defaults to `False`):
+            Whether to read the tensors from the file and return them. Not doing so is faster
+            and only loads the metadata in memory.
+    """
+    if is_gguf_available() and is_torch_available():
+        from gguf import GGUFReader, dequantize
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+    fields = reader.fields
+    reader_keys = list(fields.keys())
+
+    parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}
+
+    architecture = read_field(reader, "general.architecture")[0]
+    # NOTE: Some GGUF checkpoints may miss `general.name` field in metadata
+    model_name = read_field(reader, "general.name")
+
+    updated_architecture = None
+    # in llama.cpp mistral models use the same architecture as llama. We need
+    # to add this patch to ensure things work correctly on our side.
+    if "llama" in architecture and "mistral" in model_name:
+        updated_architecture = "mistral"
+    # FIXME: Currently this implementation is only for flan-t5 architecture.
+    # It needs to be extended to support the legacy t5 architecture.
+    elif "t5" in architecture or "t5encoder" in architecture:
+        parsed_parameters["config"]["is_gated_act"] = True
+        if "t5encoder" in architecture:
+            parsed_parameters["config"]["architectures"] = ["T5EncoderModel"]
+        updated_architecture = "t5"
+    else:
+        updated_architecture = architecture
+
+    if "qwen2moe" in architecture:
+        updated_architecture = "qwen2_moe"
+    elif "qwen3moe" in architecture:
+        updated_architecture = "qwen3_moe"
+
+    # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors
+    # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors
+    # If `use_parallel_residual=False`, ffn_norm will be present in the tensors
+    if "stablelm" in architecture:
+        attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"}
+        ffn_norm_name = "ffn_norm"
+        qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name)
+        use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors)
+        parsed_parameters["config"]["use_qkv_bias"] = qkv_bias
+        parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
+
+    if architecture not in GGUF_SUPPORTED_ARCHITECTURES and updated_architecture not in GGUF_SUPPORTED_ARCHITECTURES:
+        raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.")
+
+    # Handle tie_word_embeddings: if lm_head.weight is not present in the tensors,
+    # tie_word_embeddings is True, otherwise False
+    exceptions = ["falcon", "bloom"]
+    parsed_parameters["config"]["tie_word_embeddings"] = (
+        all("output.weight" != tensor.name for tensor in reader.tensors) or architecture in exceptions
+    )
+
+    # Parse all key-value pairs and map them onto transformers config / tokenizer attributes
+    for gguf_key, field in reader.fields.items():
+        gguf_key = gguf_key.replace(architecture, updated_architecture)
+        split = gguf_key.split(".")
+        prefix = split[0]
+        config_key = ".".join(split[1:])
+
+        value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]
+
+        if len(value) == 1:
+            value = value[0]
+
+        if isinstance(value, str) and architecture in value:
+            value = value.replace(architecture, updated_architecture)
+
+        for parameter, parameter_renames in GGUF_TO_TRANSFORMERS_MAPPING.items():
+            if prefix in parameter_renames and config_key in parameter_renames[prefix]:
+                renamed_config_key = parameter_renames[prefix][config_key]
+                if renamed_config_key == -1:
+                    continue
+
+                if renamed_config_key is not None:
+                    parsed_parameters[parameter][renamed_config_key] = value
+
+                if gguf_key in reader_keys:
+                    reader_keys.remove(gguf_key)
+
+        if gguf_key in reader_keys:
+            logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
+
+    # Gemma3 GGUF checkpoint only contains weights of text backbone
+    if parsed_parameters["config"]["model_type"] == "gemma3":
+        parsed_parameters["config"]["model_type"] = "gemma3_text"
+
+    # retrieve config vocab_size from tokenizer
+    # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
+    if "vocab_size" not in parsed_parameters["config"]:
+        tokenizer_parameters = parsed_parameters["tokenizer"]
+        if "tokens" in tokenizer_parameters:
+            parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"])
+        else:
+            logger.warning(
+                "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. "
+                "This will use default value from model config class and cause unexpected behavior."
+            )
+
+    if return_tensors:
+        parsed_parameters["tensors"] = {}
+
+        tensor_key_mapping = get_gguf_hf_weights_map(model_to_load)
+        config = parsed_parameters.get("config", {})
+
+        ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
+        processor = ProcessorClass(config=config)
+
+        for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
+            name = tensor.name
+            weights = dequantize(tensor.data, tensor.tensor_type)
+
+            result = processor.process(
+                weights=weights,
+                name=name,
+                tensor_key_mapping=tensor_key_mapping,
+                parsed_parameters=parsed_parameters,
+            )
+
+            weights = result.weights
+            name = result.name
+
+            if name not in tensor_key_mapping:
+                continue
+
+            name = tensor_key_mapping[name]
+
+            parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
+
+    if len(reader_keys) > 0:
+        logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
+
+    return parsed_parameters
diff --git a/phivenv/Lib/site-packages/transformers/modeling_layers.py b/phivenv/Lib/site-packages/transformers/modeling_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea5595dc49e4a19e5b6195271b32bc3d346e3c0
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_layers.py
@@ -0,0 +1,289 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from .cache_utils import Cache
+from .modeling_outputs import (
+    BaseModelOutputWithPast,
+    QuestionAnsweringModelOutput,
+    SequenceClassifierOutputWithPast,
+    TokenClassifierOutput,
+)
+from .models.auto import AutoModel
+from .processing_utils import Unpack
+from .utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class GradientCheckpointingLayer(nn.Module):
+    """Base class for layers with gradient checkpointing.
+
+    This class enables gradient checkpointing functionality for a layer. By default, gradient checkpointing is disabled
+    (`gradient_checkpointing = False`). When `model.set_gradient_checkpointing()` is called, gradient checkpointing is
+    enabled by setting `gradient_checkpointing = True` and assigning a checkpointing function to `_gradient_checkpointing_func`.
+
+    Important:
+
+        When using gradient checkpointing with `use_reentrant=True`, inputs that require gradients (e.g. hidden states)
+        must be passed as positional arguments (`*args`) rather than keyword arguments to properly propagate gradients.
+
+        Example:
+
+            ```python
+            >>> # Correct - hidden_states passed as positional arg
+            >>> out = self.layer(hidden_states, attention_mask=attention_mask)
+
+            >>> # Incorrect - hidden_states passed as keyword arg
+            >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
+            ```
+    """
+
+    gradient_checkpointing = False
+
+    def __call__(self, *args, **kwargs):
+        if self.gradient_checkpointing and self.training:
+            do_warn = False
+            layer_name = self.__class__.__name__
+            message = f"Caching is incompatible with gradient checkpointing in {layer_name}. Setting"
+
+            if "use_cache" in kwargs and kwargs["use_cache"]:
+                kwargs["use_cache"] = False
+                message += " `use_cache=False`,"
+                do_warn = True
+
+            # different names for the same thing in different layers
+            # TODO cyril: this one without `s` can be removed after the deprecation cycle
+            if "past_key_value" in kwargs and kwargs["past_key_value"] is not None:
+                kwargs["past_key_value"] = None
+                message += " `past_key_value=None`,"
+                do_warn = True
+
+            if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
+                kwargs["past_key_values"] = None
+                message += " `past_key_values=None`,"
+                do_warn = True
+
+            if "layer_past" in kwargs and kwargs["layer_past"] is not None:
+                kwargs["layer_past"] = None
+                message += " `layer_past=None`,"
+                do_warn = True
+
+            # warn if anything was changed
+            if do_warn:
+                message = message.rstrip(",") + "."
+                logger.warning_once(message)
+
+            return self._gradient_checkpointing_func(partial(super().__call__, **kwargs), *args)
+        return super().__call__(*args, **kwargs)
+
+
+@auto_docstring
+class GenericForSequenceClassification(object):
+    base_model_prefix = "model"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        # Similar to `self.model = AutoModel.from_config(config)` but allows changing the base model name in the child class if needed
+        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> SequenceClassifierOutputWithPast:
+        transformer_outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        hidden_states = transformer_outputs.last_hidden_state
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+        else:
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+
+@auto_docstring
+class GenericForQuestionAnswering(object):
+    base_model_prefix = "model"
+
+    def __init__(self, config):
+        super().__init__(config)
+        # Similar to `self.model = AutoModel.from_config(config)` but allows changing the base model name in the child class if needed
+        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+        self.qa_outputs = nn.Linear(config.hidden_size, 2)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return getattr(self, self.base_model_prefix).embed_tokens
+
+    def set_input_embeddings(self, value):
+        getattr(self, self.base_model_prefix).embed_tokens = value
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        start_positions: Optional[torch.LongTensor] = None,
+        end_positions: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> QuestionAnsweringModelOutput:
+        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        sequence_output = outputs.last_hidden_state
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        loss = None
+        if start_positions is not None and end_positions is not None:
+            loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
+
+        return QuestionAnsweringModelOutput(
+            loss=loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
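+# NOTE: illustrative sketch, not part of the upstream module. `qa_outputs` scores every token with
+# two numbers; `split(1, dim=-1)` separates them into start/end logits, from which a span can be
+# read off. The shapes below are assumptions for the example:
+#
+#     import torch
+#     logits = torch.randn(2, 16, 2)                   # (batch, seq_len, 2)
+#     start_logits, end_logits = logits.split(1, dim=-1)
+#     start_logits = start_logits.squeeze(-1)          # (batch, seq_len)
+#     end_logits = end_logits.squeeze(-1)              # (batch, seq_len)
+#     answer_start = start_logits.argmax(-1)           # predicted span start per example
+#     answer_end = end_logits.argmax(-1)               # predicted span end per example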
+
+@auto_docstring
+class GenericForTokenClassification(object):
+    base_model_prefix = "model"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        # Similar to `self.model = AutoModel.from_config(config)` but allows to change the base model name if needed in the child class
+        setattr(self, self.base_model_prefix, AutoModel.from_config(config))
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        **kwargs,
+    ) -> TokenClassifierOutput:
+        outputs: BaseModelOutputWithPast = getattr(self, self.base_model_prefix)(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        sequence_output = outputs.last_hidden_state
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.config)
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
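+
+# NOTE: illustrative sketch, not part of the upstream module. The token-classification head scores
+# each position independently, so per-token label predictions are simply an argmax over the label
+# dimension. The shapes and the tag count of 9 are assumptions for the example:
+#
+#     import torch
+#     logits = torch.randn(2, 16, 9)     # (batch, seq_len, num_labels), e.g. 9 NER tags
+#     predictions = logits.argmax(-1)    # (batch, seq_len) predicted label id per token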
diff --git a/phivenv/Lib/site-packages/transformers/modeling_outputs.py b/phivenv/Lib/site-packages/transformers/modeling_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..597e20b28ca8bc8aa84b7f94b1b8033b89e0cd30
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_outputs.py
@@ -0,0 +1,1715 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from .cache_utils import Cache, EncoderDecoderCache
+from .utils import ModelOutput
+
+
+@dataclass
+class BaseModelOutput(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
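+# NOTE: illustrative sketch, not part of the upstream module. `ModelOutput` subclasses such as
+# `BaseModelOutput` support attribute, key, and index access, with `None` fields skipped when
+# indexing like a tuple:
+#
+#     import torch
+#     out = BaseModelOutput(last_hidden_state=torch.zeros(1, 4, 8))
+#     out.last_hidden_state       # attribute access
+#     out["last_hidden_state"]    # dict-style access
+#     out[0]                      # tuple-style access (None fields are skipped)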
+
+@dataclass
+class BaseModelOutputWithNoAttention(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPooling(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) after further processing
+            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+            the classification token after processing through a linear layer and a tanh activation function. The linear
+            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state after a pooling operation on the spatial dimensions.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
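+# NOTE: illustrative sketch, not part of the upstream module. With `use_cache=True` a decoder-style
+# base model returns its key/value cache, so later calls only need the newly added tokens. `model`
+# and `input_ids` are assumptions, and attention-mask handling is omitted for brevity:
+#
+#     out = model(input_ids, use_cache=True)  # out.last_hidden_state: (batch, seq_len, hidden)
+#     new_token = input_ids[:, -1:]           # stand-in for the next token id
+#     out = model(new_token, past_key_values=out.past_key_values, use_cache=True)
+#     # out.last_hidden_state now covers only the new position: (batch, 1, hidden)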
+
+@dataclass
+class BaseModelOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) after further processing
+            through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
+            the classification token after processing through a linear layer and a tanh activation function. The linear
+            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    past_key_values: Optional[Cache] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MoECausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
+    states terms, to train a MoE model.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            z_loss for the sparse modules.
+        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            aux_loss for the sparse modules.
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+            modules.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    z_loss: Optional[torch.FloatTensor] = None
+    aux_loss: Optional[torch.FloatTensor] = None
+    router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutput(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router probabilities that are computed by MoE routers; these terms are used to compute the auxiliary
+            loss and the z_loss for Mixture of Experts models.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    router_probs: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoeModelOutputWithPast(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
+            loss for Mixture of Experts models.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    router_logits: Optional[tuple[torch.FloatTensor]] = None
+
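+# NOTE: illustrative sketch, not part of the upstream module. The per-layer `router_logits` can be
+# folded into a Switch-Transformers-style top-1 load-balancing term; this is a simplified sketch,
+# not the library's own loss helper:
+#
+#     import torch
+#     def load_balancing_loss(router_logits, num_experts):
+#         logits = torch.cat(router_logits, dim=0)   # (all_tokens, num_experts)
+#         probs = logits.softmax(dim=-1)
+#         expert_idx = probs.argmax(dim=-1)          # top-1 expert per token
+#         tokens_per_expert = torch.nn.functional.one_hot(expert_idx, num_experts).float().mean(dim=0)
+#         router_prob_per_expert = probs.mean(dim=0)
+#         return num_experts * (tokens_per_expert * router_prob_per_expert).sum()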
+
+@dataclass
+class MoeCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) with mixture of experts outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+
+        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
+            aux_loss for the sparse modules.
+
+        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router logits (post-softmax) that are computed by MoE routers; these terms are used to compute the auxiliary
+            loss for Mixture of Experts models.
+
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    aux_loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class MoEModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding) as well as
+    Mixture of Expert's router hidden states terms, to train a MoE model.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
+            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
+            input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Raw router probabilities that are computed by MoE routers; these terms are used to compute the auxiliary
+            loss and the z_loss for Mixture of Experts models.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    router_probs: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqModelOutput(ModelOutput):
+    """
+    Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
+    decoding.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqMoEModelOutput(ModelOutput):
+    """
+    Base class for model encoder's outputs that also contains: pre-computed hidden states that can speed up sequential
+    decoding.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Router logits of the encoder model, useful to compute the auxiliary loss and the z_loss for the sparse
+            modules.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class CausalLMOutput(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class CausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
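+# NOTE: illustrative sketch, not part of the upstream module. A minimal greedy decoding loop over a
+# causal LM head reads the next token from `logits` and feeds `past_key_values` back in. `lm_model`,
+# `input_ids`, and `max_new_tokens` are assumptions for the example:
+#
+#     import torch
+#     generated, past = input_ids, None
+#     for _ in range(max_new_tokens):
+#         step_input = generated if past is None else generated[:, -1:]
+#         out = lm_model(step_input, past_key_values=past, use_cache=True)
+#         next_token = out.logits[:, -1, :].argmax(-1, keepdim=True)
+#         generated = torch.cat([generated, next_token], dim=-1)
+#         past = out.past_key_values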
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Cross attentions weights after the attention softmax, used to compute the weighted average in the
+            cross-attention heads.
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutputWithPast(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[Cache] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MaskedLMOutput(ModelOutput):
+    """
+    Base class for masked language models outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Masked language modeling (MLM) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqLMOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence language models outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
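+
+# Editor's note: a hedged sketch (not upstream code) of incremental decoding with a
+# Seq2SeqLMOutput; `model`, `encoder_ids` and `decoder_start` are hypothetical inputs.
+#
+#     out = model(input_ids=encoder_ids, decoder_input_ids=decoder_start, use_cache=True)
+#     next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
+#     # Feed only the new token plus the returned cache to avoid recomputing past keys/values.
+#     out = model(input_ids=encoder_ids, decoder_input_ids=next_token,
+#                 past_key_values=out.past_key_values, use_cache=True)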
+
+
+@dataclass
+class Seq2SeqMoEOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence language models outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Language modeling loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        decoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Router logits of the decoder model, useful to compute the auxiliary loss for Mixture of Experts models.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        encoder_router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
+
+            Router logits of the encoder model, useful to compute the auxiliary loss and z_loss for Mixture of Experts
+            models.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    encoder_z_loss: Optional[torch.FloatTensor] = None
+    decoder_z_loss: Optional[torch.FloatTensor] = None
+    encoder_aux_loss: Optional[torch.FloatTensor] = None
+    decoder_aux_loss: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_router_logits: Optional[tuple[torch.FloatTensor]] = None
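+
+# Editor's note: an assumed (non-normative) sketch of how the extra MoE terms carried by
+# Seq2SeqMoEOutput are often folded into a training objective; the coefficients
+# `router_aux_coef` and `router_z_coef` are hypothetical hyperparameters.
+#
+#     total_loss = out.loss
+#     total_loss = total_loss + router_aux_coef * (out.encoder_aux_loss + out.decoder_aux_loss)
+#     total_loss = total_loss + router_z_coef * (out.encoder_z_loss + out.decoder_z_loss)
+#     total_loss.backward()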
+
+
+@dataclass
+class NextSentencePredictorOutput(ModelOutput):
+    """
+    Base class for outputs of models predicting if two sentences are consecutive or not.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `next_sentence_label` is provided):
+            Next sequence prediction (classification) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
+            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
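+
+# Editor's note: illustrative only; `outputs` is assumed to come from any sequence
+# classification head that returns this class.
+#
+#     probs = outputs.logits.softmax(dim=-1)        # (batch, num_labels)
+#     predictions = probs.argmax(dim=-1)            # predicted class index per example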
+
+
+@dataclass
+class Seq2SeqSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence sentence classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class MultipleChoiceModelOutput(ModelOutput):
+    """
+    Base class for outputs of multiple choice models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`):
+            Classification scores (before SoftMax). *num_choices* is the second dimension of the input tensors
+            (see *input_ids* above).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
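+
+# Editor's note: a hypothetical usage sketch; `outputs` is assumed to come from a
+# multiple-choice head, so `logits` has shape (batch_size, num_choices).
+#
+#     best_choice = outputs.logits.argmax(dim=-1)   # index of the highest-scoring choice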
+
+
+@dataclass
+class TokenClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of token classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class QuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Total span-extraction loss: the sum of the cross-entropy losses for the start and end positions.
+        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_logits: Optional[torch.FloatTensor] = None
+    end_logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
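+
+# Editor's note: a hedged sketch of extracting an answer span from this output; `outputs`
+# and `input_ids` are assumed, and no check that start <= end is shown.
+#
+#     start = outputs.start_logits.argmax(dim=-1)   # (batch,)
+#     end = outputs.end_logits.argmax(dim=-1)       # (batch,)
+#     answer_ids = input_ids[0, start[0] : end[0] + 1]   # token ids of the predicted span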
+
+
+@dataclass
+class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence question answering models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Total span-extraction loss: the sum of the cross-entropy losses for the start and end positions.
+        start_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    start_logits: Optional[torch.FloatTensor] = None
+    end_logits: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class SemanticSegmenterOutput(ModelOutput):
+    """
+    Base class for outputs of semantic segmentation models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+            Classification scores for each pixel.
+
+            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+            to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
+            original image size as post-processing. You should always check your logits shape and resize as needed.
+
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
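+
+# Editor's note: an illustrative sketch of the resizing step the docstring above refers to;
+# `outputs` and the original image size (orig_h, orig_w) are assumptions.
+#
+#     import torch.nn.functional as F
+#     upsampled = F.interpolate(outputs.logits, size=(orig_h, orig_w),
+#                               mode="bilinear", align_corners=False)
+#     segmentation_map = upsampled.argmax(dim=1)    # (batch, orig_h, orig_w) class ids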
+
+
+@dataclass
+class ImageClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of image classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+            (also called feature maps) of the model at the output of each stage.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class ImageClassifierOutputWithNoAttention(ModelOutput):
+    """
+    Base class for outputs of image classification models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
+            called feature maps) of the model at the output of each stage.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class DepthEstimatorOutput(ModelOutput):
+    """
+    Base class for outputs of depth estimation models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        predicted_depth (`torch.FloatTensor` of shape `(batch_size, height, width)`):
+            Predicted depth for each pixel.
+
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    predicted_depth: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
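+
+# Editor's note: a hypothetical sketch of resizing `predicted_depth` back to the input
+# resolution; `outputs`, `orig_h` and `orig_w` are assumed.
+#
+#     import torch.nn.functional as F
+#     depth = F.interpolate(outputs.predicted_depth.unsqueeze(1),   # add channel dim
+#                           size=(orig_h, orig_w), mode="bicubic",
+#                           align_corners=False).squeeze(1)         # (batch, orig_h, orig_w)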
+
+
+@dataclass
+class ImageSuperResolutionOutput(ModelOutput):
+    """
+    Base class for outputs of image super resolution models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Reconstruction loss.
+        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+           Reconstructed images, possibly upscaled.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+            (also called feature maps) of the model at the output of each stage.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    reconstruction: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
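+
+# Editor's note: a hedged, illustrative sketch of turning `reconstruction` into a viewable
+# image; `outputs` is assumed, and the [0, 1] value range is an assumption about the model.
+#
+#     import torch
+#     image = outputs.reconstruction[0].clamp(0.0, 1.0)        # (num_channels, height, width)
+#     image_uint8 = (image * 255.0).round().to(torch.uint8)    # ready for saving / display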
+
+
+@dataclass
+class Wav2Vec2BaseModelOutput(ModelOutput):
+    """
+    Base class for models that have been trained with the Wav2Vec2 loss objective.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        extract_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, conv_dim[-1])`):
+            Sequence of extracted feature vectors of the last convolutional layer of the model.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    extract_features: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class XVectorOutput(ModelOutput):
+    """
+    Output type of [`Wav2Vec2ForXVector`].
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification loss.
+        logits (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+            Classification hidden states before AMSoftmax.
+        embeddings (`torch.FloatTensor` of shape `(batch_size, config.xvector_output_dim)`):
+            Utterance embeddings used for vector similarity-based retrieval.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: Optional[torch.FloatTensor] = None
+    embeddings: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
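+
+# Editor's note: a non-normative sketch of speaker verification with XVectorOutput
+# embeddings; `outputs` is assumed to hold two utterances and `threshold` is hypothetical.
+#
+#     import torch.nn.functional as F
+#     similarity = F.cosine_similarity(outputs.embeddings[0], outputs.embeddings[1], dim=-1)
+#     same_speaker = similarity > threshold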
+
+
+@dataclass
+class BackboneOutput(ModelOutput):
+    """
+    Base class for outputs of backbones.
+
+    Args:
+        feature_maps (`tuple(torch.FloatTensor)` of shape `(batch_size, num_channels, height, width)`):
+            Feature maps of the stages.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
+            shape `(batch_size, sequence_length, hidden_size)` or `(batch_size, num_channels, height, width)`,
+            depending on the backbone.
+
+            Hidden-states of the model at the output of each stage plus the initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`. Only applicable if the backbone uses attention.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    feature_maps: Optional[tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class BaseModelOutputWithPoolingAndProjection(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) after further processing
+            through the layers used for the auxiliary pretraining task. E.g., for the BERT family of models, this returns
+            the classification token after processing through a linear layer and a tanh activation function. The linear
+            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        projection_state (`tuple(torch.FloatTensor)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` of shape `(batch_size, config.project_dim)`.
+
+            Text embeddings before the projection layer, used to mimic the last hidden state of the teacher encoder.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    pooler_output: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    projection_state: Optional[tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class Seq2SeqSpectrogramOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence spectrogram outputs.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Spectrogram generation loss.
+        spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
+            The predicted spectrogram.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    spectrogram: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+@dataclass
+class Seq2SeqTSModelOutput(ModelOutput):
+    """
+    Base class for time series model's encoder outputs that also contains pre-computed hidden states that can speed up
+    sequential decoding.
+
+    Args:
+        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+            Shift values of each time series' context window, used to give the model inputs of the same magnitude and
+            then to shift the outputs back to the original magnitude.
+        scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+            Scaling values of each time series' context window, used to give the model inputs of the same magnitude
+            and then to rescale the outputs back to the original magnitude.
+        static_features (`torch.FloatTensor` of shape `(batch_size, feature_size)`, *optional*):
+            Static features of each time series in a batch, copied to the covariates at inference time.
+    """
+
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    loc: Optional[torch.FloatTensor] = None
+    scale: Optional[torch.FloatTensor] = None
+    static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class Seq2SeqTSPredictionOutput(ModelOutput):
+    """
+    Base class for time series model's decoder outputs that also contain the loss as well as the parameters of the
+    chosen distribution.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `future_values` is provided):
+            Distributional loss.
+        params (`torch.FloatTensor` of shape `(batch_size, num_samples, num_params)`):
+            Parameters of the chosen distribution.
+        past_key_values (`EncoderDecoderCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            It is a [`~cache_utils.EncoderDecoderCache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
+            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        loc (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+            Shift values of each time series' context window, used to give the model inputs of the same magnitude and
+            then to shift the outputs back to the original magnitude.
+        scale (`torch.FloatTensor` of shape `(batch_size,)` or `(batch_size, input_size)`, *optional*):
+            Scaling values of each time series' context window, used to give the model inputs of the same magnitude
+            and then to rescale the outputs back to the original magnitude.
+        static_features (`torch.FloatTensor` of shape `(batch_size, feature_size)`, *optional*):
+            Static features of each time series in a batch, copied to the covariates at inference time.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    params: Optional[tuple[torch.FloatTensor]] = None
+    past_key_values: Optional[EncoderDecoderCache] = None
+    decoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    decoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
+    encoder_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    encoder_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+    loc: Optional[torch.FloatTensor] = None
+    scale: Optional[torch.FloatTensor] = None
+    static_features: Optional[torch.FloatTensor] = None
+
+
+@dataclass
+class SampleTSPredictionOutput(ModelOutput):
+    """
+    Base class for time series model's predictions outputs that contains the sampled values from the chosen
+    distribution.
+
+    Args:
+        sequences (`torch.FloatTensor` of shape `(batch_size, num_samples, prediction_length)` or `(batch_size, num_samples, prediction_length, input_size)`):
+            Sampled values from the chosen distribution.
+    """
+
+    sequences: Optional[torch.FloatTensor] = None
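+
+# Editor's note: an illustrative sketch of reducing the sampled trajectories in
+# SampleTSPredictionOutput to point and quantile forecasts; `outputs` is assumed.
+#
+#     point_forecast = outputs.sequences.median(dim=1).values   # (batch, prediction_length[, input_size])
+#     upper_band = outputs.sequences.quantile(0.9, dim=1)       # 90th-percentile forecast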
+
+
+@dataclass
+class MaskedImageModelingOutput(ModelOutput):
+    """
+    Base class for outputs of masked image completion / in-painting models.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+            Reconstruction loss.
+        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+           Reconstructed / completed images.
+        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
+            (also called feature maps) of the model at the output of each stage.
+        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
+            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
+            the self-attention heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    reconstruction: Optional[torch.FloatTensor] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+    @property
+    def logits(self):
+        warnings.warn(
+            "logits attribute is deprecated and will be removed in version 5 of Transformers."
+            " Please use the reconstruction attribute to retrieve the final output instead.",
+            FutureWarning,
+        )
+        return self.reconstruction
diff --git a/phivenv/Lib/site-packages/transformers/modeling_rope_utils.py b/phivenv/Lib/site-packages/transformers/modeling_rope_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..34c136980234c7a8c62676ae6f8cc1cbe89724bb
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_rope_utils.py
@@ -0,0 +1,634 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from functools import wraps
+from typing import Optional
+
+from .configuration_utils import PretrainedConfig
+from .utils import is_torch_available, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+    import torch
+
+
+def dynamic_rope_update(rope_forward):
+    """
+    Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
+    (i.e. a RoPE implementation that may recompute its frequencies in the forward pass).
+
+    Args:
+        rope_forward (Callable):
+            The forward pass of the RoPE implementation.
+
+    Returns:
+        The decorated forward pass.
+    """
+
+    def longrope_frequency_update(self, position_ids, device):
+        """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
+        seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        if seq_len > original_max_position_embeddings:
+            if not hasattr(self, "long_inv_freq"):
+                self.long_inv_freq, _ = self.rope_init_fn(
+                    self.config, device, seq_len=original_max_position_embeddings + 1
+                )
+            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+        else:
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+
+    def dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
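+# A minimal usage sketch of `dynamic_rope_update` on a hypothetical rotary embedding
+# module. The attribute names (`rope_type`, `rope_init_fn`, `original_inv_freq`,
+# `max_seq_len_cached`, `original_max_seq_len`, `attention_scaling`) are the ones the
+# decorator reads above; the class itself is illustrative, not an actual model class.
+#
+#     class ExampleRotaryEmbedding(torch.nn.Module):
+#         def __init__(self, config, device=None):
+#             super().__init__()
+#             self.config = config
+#             rope_scaling = getattr(config, "rope_scaling", None) or {}
+#             self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
+#             self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+#             inv_freq, self.attention_scaling = self.rope_init_fn(config, device)
+#             self.register_buffer("inv_freq", inv_freq, persistent=False)
+#             self.original_inv_freq = self.inv_freq
+#             self.max_seq_len_cached = config.max_position_embeddings
+#             self.original_max_seq_len = config.max_position_embeddings
+#
+#         @dynamic_rope_update  # refreshes `inv_freq` for "dynamic" / "longrope" types
+#         def forward(self, x, position_ids):
+#             freqs = (self.inv_freq[None, :, None] * position_ids[:, None, :].float()).transpose(1, 2)
+#             emb = torch.cat((freqs, freqs), dim=-1)
+#             return emb.cos() * self.attention_scaling, emb.sin() * self.attention_scaling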
+
+def _compute_default_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies according to the original RoPE implementation
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
+    dim = int(head_dim * partial_rotary_factor)
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+    return inv_freq, attention_factor
+
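+# Worked illustration of the formula above, assuming a typical head_dim = 128 and
+# rope_theta = 10000.0 (hypothetical values):
+#
+#     inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float) / 128))
+#     inv_freq[0]   # 1.0       -> fastest-rotating pair of dimensions
+#     inv_freq[-1]  # ~1.16e-4  -> slowest-rotating pair (10000 ** (-126/128))
+#
+# i.e. 64 geometrically decaying frequencies, one per pair of rotary dimensions.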
+
+def _compute_linear_scaling_rope_parameters(
+    config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+    factor = config.rope_scaling["factor"]
+
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+
+    # Then applies linear scaling to the frequencies.
+    # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
+    # applying scaling to the inverse frequencies is equivalent.
+    inv_freq /= factor
+    return inv_freq, attention_factor
+
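+# Worked illustration of the note above: the rotary angles are `position_ids * inv_freq`,
+# so dividing `inv_freq` by `factor` is the same as feeding `position_ids / factor`.
+# With a hypothetical factor = 4.0, position 4096 is rotated by exactly the angles the
+# unscaled model used for position 1024.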
+
+def _compute_dynamic_ntk_parameters(
+    config: Optional[PretrainedConfig] = None,
+    device: Optional["torch.device"] = None,
+    seq_len: Optional[int] = None,
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length, used to update the dynamic RoPE at inference time.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
+    """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    max_position_embeddings = config.max_position_embeddings
+    factor = config.rope_scaling["factor"]
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # seq_len: default to max_position_embeddings, e.g. at init time
+    if seq_len is None:
+        seq_len = max_position_embeddings
+    elif isinstance(seq_len, torch.Tensor):
+        seq_len = torch.maximum(
+            seq_len,
+            torch.tensor(max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device),
+        )
+    else:
+        seq_len = max(seq_len, max_position_embeddings)
+
+    # Compute the inverse frequencies
+    base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
+    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+    return inv_freq, attention_factor
+
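+# Worked illustration with hypothetical values: max_position_embeddings = 4096,
+# factor = 2.0, seq_len = 8192. The base is then rescaled to
+#
+#     base' = base * ((2.0 * 8192 / 4096) - (2.0 - 1)) ** (dim / (dim - 2))
+#           = base * 3 ** (dim / (dim - 2))
+#
+# which stretches the low frequencies as the sequence grows, instead of rescaling the
+# position ids as linear scaling does.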
+
+def _compute_yarn_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with YaRN scaling. Please refer to the
+    [original paper](https://huggingface.co/papers/2309.00071)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    factor = config.rope_scaling["factor"]
+    attention_factor = config.rope_scaling.get("attention_factor")
+    mscale = config.rope_scaling.get("mscale")
+    mscale_all_dim = config.rope_scaling.get("mscale_all_dim")
+    original_max_position_embeddings = (
+        config.rope_scaling.get("original_max_position_embeddings") or config.max_position_embeddings
+    )
+
+    def get_mscale(scale, mscale=1):
+        if scale <= 1:
+            return 1.0
+        return 0.1 * mscale * math.log(scale) + 1.0
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if mscale and mscale_all_dim:
+            attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim))
+        else:
+            attention_factor = get_mscale(factor)
+
+    # Optional config options
+    # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
+    beta_fast = config.rope_scaling.get("beta_fast") or 32
+    beta_slow = config.rope_scaling.get("beta_slow") or 1
+
+    # Compute the inverse frequencies
+    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
+        """Inverse dimension formula to find the dimension based on the number of rotations"""
+        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
+
+    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate):
+        """Find dimension range bounds based on rotations"""
+        low = find_correction_dim(low_rot, dim, base, max_position_embeddings)
+        high = find_correction_dim(high_rot, dim, base, max_position_embeddings)
+        if truncate:
+            low = math.floor(low)
+            high = math.ceil(high)
+        return max(low, 0), min(high, dim - 1)
+
+    def linear_ramp_factor(min, max, dim):
+        if min == max:
+            max += 0.001  # Prevent singularity
+
+        linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
+        ramp_func = torch.clamp(linear_func, 0, 1)
+        return ramp_func
+
+    # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
+    # to expand the possible context length. In other words, interpolation = apply scaling factor.
+    pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim)
+    inv_freq_extrapolation = 1.0 / pos_freqs
+    inv_freq_interpolation = 1.0 / (factor * pos_freqs)
+
+    truncate = config.rope_scaling.get("truncate", True)
+    low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate)
+
+    # Get n-dimensional rotational scaling corrected for extrapolation
+    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float)
+    inv_freq = (
+        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
+        + inv_freq_extrapolation * inv_freq_extrapolation_factor
+    )
+    return inv_freq, attention_factor
+
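+# Worked illustration with a hypothetical factor = 8.0 and no explicit
+# `attention_factor`/`mscale`: the cos/sin post-scaling becomes
+#
+#     get_mscale(8.0) = 0.1 * ln(8) + 1 ≈ 1.208
+#
+# while the ramp above blends each dimension between pure interpolation
+# (`inv_freq / factor`, long-wavelength dims) and pure extrapolation (unchanged
+# `inv_freq`, short-wavelength dims) across the [low, high] correction range.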
+
+def _compute_longrope_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies with LongRoPE scaling. Please refer to the
+    [original implementation](https://github.com/microsoft/LongRoPE)
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    long_factor = config.rope_scaling["long_factor"]
+    short_factor = config.rope_scaling["short_factor"]
+    factor = config.rope_scaling.get("factor")
+    attention_factor = config.rope_scaling.get("attention_factor")
+
+    # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
+    # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
+    # values to compute the default attention scaling factor, instead of using `factor`.
+    if hasattr(config, "original_max_position_embeddings"):
+        original_max_position_embeddings = config.original_max_position_embeddings
+        factor = config.max_position_embeddings / config.original_max_position_embeddings
+    else:
+        original_max_position_embeddings = config.max_position_embeddings
+
+    # Sets the attention factor as suggested in the paper
+    if attention_factor is None:
+        if factor <= 1.0:
+            attention_factor = 1.0
+        else:
+            attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
+
+    # Compute the inverse frequencies -- scaled based on the target sequence length
+    if seq_len and seq_len > original_max_position_embeddings:
+        ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
+    else:
+        ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
+    inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
+    inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
+
+    return inv_freq, attention_factor
+
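+# Worked illustration with Phi-3-style values (hypothetical here):
+# original_max_position_embeddings = 4096 and max_position_embeddings = 131072 give
+# factor = 32, so when `attention_factor` is unset
+#
+#     attention_factor = sqrt(1 + ln(32) / ln(4096)) ≈ 1.19
+#
+# and `long_factor` is selected whenever `seq_len` exceeds 4096, `short_factor` otherwise.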
+
+def _compute_llama3_parameters(
+    config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None
+) -> tuple["torch.Tensor", float]:
+    """
+    Computes the inverse frequencies for llama 3.1.
+
+    Args:
+        config ([`~transformers.PretrainedConfig`]):
+            The model configuration.
+        device (`torch.device`):
+            The device to use for initialization of the inverse frequencies.
+        seq_len (`int`, *optional*):
+            The current sequence length. Unused for this type of RoPE.
+    Returns:
+        Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
+        post-processing scaling factor applied to the computed cos/sin.
+    """
+    # Gets the default RoPE parameters
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+
+    factor = config.rope_scaling["factor"]  # `8` in the original implementation
+    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+
+    wavelen = 2 * math.pi / inv_freq
+    # wavelen < high_freq_wavelen: do nothing
+    # wavelen > low_freq_wavelen: divide by factor
+    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
+    # otherwise: interpolate between the two, using a smooth factor
+    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
+    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
+    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
+
+    return inv_freq_llama, attention_factor
+
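+# Worked illustration with the reference values noted above (factor = 8,
+# low_freq_factor = 1, high_freq_factor = 4, old_context_len = 8192):
+#
+#     low_freq_wavelen  = 8192 / 1 = 8192
+#     high_freq_wavelen = 8192 / 4 = 2048
+#
+# dimensions with wavelength > 8192 get `inv_freq / 8`, dimensions with wavelength
+# < 2048 are left untouched, and the band in between is blended via `smooth_factor`.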
+
+# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
+# from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
+# parameterizations, as long as the callable has the same signature.
+ROPE_INIT_FUNCTIONS = {
+    "default": _compute_default_rope_parameters,
+    "linear": _compute_linear_scaling_rope_parameters,
+    "dynamic": _compute_dynamic_ntk_parameters,
+    "yarn": _compute_yarn_parameters,
+    "longrope": _compute_longrope_parameters,
+    "llama3": _compute_llama3_parameters,
+}
+
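+# A minimal sketch of the extension mechanism described above. The "my_rope" key and
+# `_compute_my_rope_parameters` are hypothetical; the callable only needs to follow the
+# `(config, device, seq_len) -> (inv_freq, attention_factor)` signature used above, and
+# a matching entry can be added to `ROPE_VALIDATION_FUNCTIONS` further below.
+#
+#     def _compute_my_rope_parameters(config, device=None, seq_len=None):
+#         inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+#         return inv_freq * 0.5, attention_factor  # toy modification for illustration
+#
+#     ROPE_INIT_FUNCTIONS["my_rope"] = _compute_my_rope_parameters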
+
+def _check_received_keys(
+    rope_type: str,
+    received_keys: set,
+    required_keys: set,
+    optional_keys: Optional[set] = None,
+    ignore_keys: Optional[set] = None,
+):
+    """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
+    # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
+    if "type" in received_keys:
+        received_keys -= {"type"}
+        required_keys.add("rope_type")
+
+    # Some models need to store model-specific keys, and we don't want to throw warning at them
+    if ignore_keys is not None:
+        received_keys -= ignore_keys
+
+    missing_keys = required_keys - received_keys
+    if missing_keys:
+        raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
+
+    if optional_keys is not None:
+        unused_keys = received_keys - required_keys - optional_keys
+    else:
+        unused_keys = received_keys - required_keys
+    if unused_keys:
+        logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
+
+
+def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+
+def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type", "factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type", "factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+
+def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type", "factor"}
+    optional_keys = {
+        "attention_factor",
+        "beta_fast",
+        "beta_slow",
+        "original_max_position_embeddings",
+        "mscale",
+        "mscale_all_dim",
+        "truncate",
+    }
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    attention_factor = rope_scaling.get("attention_factor")
+    if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
+        logger.warning(
+            f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+        )
+    beta_fast = rope_scaling.get("beta_fast")
+    if beta_fast is not None and not isinstance(beta_fast, float):
+        logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
+    beta_slow = rope_scaling.get("beta_slow")
+    if beta_slow is not None and not isinstance(beta_slow, float):
+        logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
+
+    if (beta_fast or 32) < (beta_slow or 1):
+        logger.warning(
+            f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
+            f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
+        )
+
+    # Models should set `config.rope_scaling["original_max_position_embeddings"]` to their original (pre-yarn) context
+    # length, with `config.max_position_embeddings` corresponding to their post-yarn context length.
+    # However, for BC purposes, we allow the former to be unset.
+    original_max_position_embeddings = config.rope_scaling.get("original_max_position_embeddings")
+    if original_max_position_embeddings is not None:
+        # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths.
+        implicit_factor = config.max_position_embeddings / original_max_position_embeddings
+        if implicit_factor != factor:
+            logger.warning_once(
+                f"The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = {factor}) does not match "
+                "the ratio implicitly set by other parameters (implicit factor = "
+                "post-yarn context length / pre-yarn context length = "
+                "config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = "
+                f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected "
+                "behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config."
+            )
+    # No `config.rope_scaling["original_max_position_embeddings"]`. Is `config.max_position_embeddings` the
+    # pre-yarn or the post-yarn context length?
+    # BC: we assume it is the pre-yarn context length.
+    else:
+        logger.warning_once(
+            "config.rope_scaling['original_max_position_embeddings'], the pre-yarn context length, is unset. We will "
+            "**assume** config.max_position_embeddings holds the pre-yarn context length. Some use cases may expect "
+            "config.max_position_embeddings to hold the post-yarn context length (pre-yarn context length * "
+            "factor) -- we recommend updating both fields for optimal downstream model usage."
+        )
+
+
+def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type", "short_factor", "long_factor"}
+    # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
+    optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
+
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
+    head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+
+    short_factor = rope_scaling.get("short_factor")
+    if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
+        logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
+    if len(short_factor) != dim // 2:
+        logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
+
+    long_factor = rope_scaling.get("long_factor")
+    if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
+        logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
+    if len(long_factor) != dim // 2:
+        logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
+
+    # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
+    # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
+    # unique to longrope (= undesirable)
+    if hasattr(config, "original_max_position_embeddings"):
+        logger.warning_once(
+            "This model has set an `original_max_position_embeddings` field, to be used together with "
+            "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling` "
+            "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
+            "as it is compatible with most model architectures."
+        )
+    else:
+        factor = rope_scaling.get("factor")
+        if factor is None:
+            logger.warning("Missing required keys in `rope_scaling`: 'factor'")
+        elif not isinstance(factor, float) or factor < 1.0:
+            logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+        attention_factor = rope_scaling.get("attention_factor")
+        if attention_factor is not None:
+            if not isinstance(attention_factor, float) or attention_factor < 0.0:
+                logger.warning(
+                    f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
+                )
+
+
+def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    rope_scaling = config.rope_scaling
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))  # BC: "rope_type" was originally "type"
+    required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
+    received_keys = set(rope_scaling.keys())
+    _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
+
+    factor = rope_scaling["factor"]
+    if factor is None or not isinstance(factor, float) or factor < 1.0:
+        logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
+
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    if low_freq_factor is None or not isinstance(low_freq_factor, float):
+        logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
+    if high_freq_factor is None or not isinstance(high_freq_factor, float):
+        logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
+    if high_freq_factor <= low_freq_factor:
+        logger.warning(
+            "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
+            f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
+        )
+
+    original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
+    if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
+        logger.warning(
+            "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
+            f"{original_max_position_embeddings}"
+        )
+    if original_max_position_embeddings >= config.max_position_embeddings:
+        logger.warning(
+            "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
+            f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
+        )
+
+
+# Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
+ROPE_VALIDATION_FUNCTIONS = {
+    "default": _validate_default_rope_parameters,
+    "linear": _validate_linear_scaling_rope_parameters,
+    "dynamic": _validate_dynamic_scaling_rope_parameters,
+    "yarn": _validate_yarn_parameters,
+    "longrope": _validate_longrope_parameters,
+    "llama3": _validate_llama3_parameters,
+}
+
+
+def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
+    """
+    Validate the RoPE config arguments, given a `PretrainedConfig` object
+    """
+    rope_scaling = getattr(config, "rope_scaling", None)  # not a default parameter in `PretrainedConfig`
+    if rope_scaling is None:
+        return
+
+    # BC: "rope_type" was originally "type"
+    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
+    validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
+    if validation_fn is not None:
+        validation_fn(config, ignore_keys=ignore_keys)
+    else:
+        logger.warning(
+            f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
+        )
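+# A minimal usage sketch, assuming a Llama-style config; the concrete field values are
+# illustrative only (they mirror the llama3 reference values documented above).
+#
+#     from transformers import LlamaConfig
+#
+#     config = LlamaConfig(
+#         max_position_embeddings=131072,
+#         rope_scaling={
+#             "rope_type": "llama3",
+#             "factor": 8.0,
+#             "low_freq_factor": 1.0,
+#             "high_freq_factor": 4.0,
+#             "original_max_position_embeddings": 8192,
+#         },
+#     )
+#     rope_config_validation(config)  # logs a warning for any malformed `rope_scaling` entry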
diff --git a/phivenv/Lib/site-packages/transformers/modeling_tf_outputs.py b/phivenv/Lib/site-packages/transformers/modeling_tf_outputs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7491b67f9aebb93b95632a5a7db2fd85cf5b4c7
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_tf_outputs.py
@@ -0,0 +1,990 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from dataclasses import dataclass
+
+import tensorflow as tf
+
+from .utils import ModelOutput
+
+
+@dataclass
+class TFBaseModelOutput(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
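+# A minimal usage sketch: like the other `ModelOutput` subclasses, these dataclasses can
+# be read by attribute, by key, or by integer index, and fields left as `None` are
+# skipped when converting to a tuple. `outputs` below is a hypothetical model call result.
+#
+#     outputs = model(inputs, output_hidden_states=True)  # hypothetical TF model
+#     outputs.last_hidden_state                            # attribute access
+#     outputs["hidden_states"]                             # dict-style access
+#     outputs[0]                                           # tuple-style access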
+
+@dataclass
+class TFBaseModelOutputWithNoAttention(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPooling(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+            prediction (classification) objective during pretraining.
+
+            This output is usually *not* a good summary of the semantic content of the input; you're often better off
+            averaging or pooling the sequence of hidden-states for the whole input sequence.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    pooler_output: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state after a pooling operation on the spatial dimensions.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each layer) of shape `(batch_size, num_channels, height, width)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    pooler_output: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that also contains a pooling of the last hidden states.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
+            Last layer hidden-state of the first token of the sequence (classification token) further processed by a
+            Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
+            prediction (classification) objective during pretraining.
+
+            This output is usually *not* a good summary of the semantic content of the input; you're often better off
+            averaging or pooling the sequence of hidden-states for the whole input sequence.
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    pooler_output: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPast(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs, with potential hidden states and attentions.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+        hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
+    """
+    Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqModelOutput(ModelOutput):
+    """
+    Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up
+    sequential decoding.
+
+    Args:
+        last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
+            Sequence of hidden-states at the output of the last layer of the decoder of the model.
+
+            If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
+            hidden_size)` is output.
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`.
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    last_hidden_state: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    decoder_hidden_states: tuple[tf.Tensor] | None = None
+    decoder_attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+    encoder_last_hidden_state: tf.Tensor | None = None
+    encoder_hidden_states: tuple[tf.Tensor] | None = None
+    encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFCausalLMOutput(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
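+
+# Illustrative usage (an editorial sketch, not part of the original module; `model` and `input_ids`
+# are placeholders for a TF causal LM and its tokenized input): the `past_key_values` returned when
+# `use_cache=True` can be fed back on the next call so only the newly generated token is processed:
+#
+#     outputs = model(input_ids, use_cache=True)
+#     next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1)[:, None]
+#     outputs = model(next_token, past_key_values=outputs.past_key_values, use_cache=True)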
+
+
+@dataclass
+class TFCausalLMOutputWithCrossAttentions(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+            Language modeling loss (for next-token prediction).
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFMaskedLMOutput(ModelOutput):
+    """
+    Base class for masked language models outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+            Masked language modeling (MLM) loss.
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSeq2SeqLMOutput(ModelOutput):
+    """
+    Base class for sequence-to-sequence language models outputs.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
+            Language modeling loss.
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    decoder_hidden_states: tuple[tf.Tensor] | None = None
+    decoder_attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+    encoder_last_hidden_state: tf.Tensor | None = None
+    encoder_hidden_states: tuple[tf.Tensor] | None = None
+    encoder_attentions: tuple[tf.Tensor] | None = None
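+
+# Illustrative usage (an editorial sketch; `model`, `input_ids` and `decoder_input_ids` are
+# placeholders): a seq2seq LM exposes decoder-side and encoder-side activations on one output object:
+#
+#     outputs = model(input_ids, decoder_input_ids=decoder_input_ids, output_attentions=True)
+#     lm_logits = outputs.logits                        # (batch, target_len, vocab_size)
+#     enc_states = outputs.encoder_last_hidden_state    # (batch, source_len, hidden_size)
+#     cross_attn = outputs.cross_attentions[-1]         # cross-attention of the last decoder layer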
+
+
+@dataclass
+class TFNextSentencePredictorOutput(ModelOutput):
+    """
+    Base class for outputs of models predicting if two sentences are consecutive or not.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided):
+            Next sentence prediction loss.
+        logits (`tf.Tensor` of shape `(batch_size, 2)`):
+            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
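+
+# Illustrative usage (an editorial sketch; `model` and `inputs` are placeholders): the raw `logits`
+# are turned into class probabilities with a softmax, and the predicted label is their argmax:
+#
+#     outputs = model(inputs)
+#     probs = tf.nn.softmax(outputs.logits, axis=-1)    # (batch_size, num_labels)
+#     predicted_class = tf.argmax(probs, axis=-1)       # (batch_size,)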
+
+
+@dataclass
+class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence sentence classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
+        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    decoder_hidden_states: tuple[tf.Tensor] | None = None
+    decoder_attentions: tuple[tf.Tensor] | None = None
+    cross_attentions: tuple[tf.Tensor] | None = None
+    encoder_last_hidden_state: tf.Tensor | None = None
+    encoder_hidden_states: tuple[tf.Tensor] | None = None
+    encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSemanticSegmenterOutput(ModelOutput):
+    """
+    Base class for outputs of semantic segmentation models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+            Classification scores for each pixel.
+
+            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+            to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+            original image size as post-processing. You should always check your logits shape and resize as needed.
+
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
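+
+# Illustrative post-processing sketch (editorial addition; `outputs`, `orig_height` and `orig_width`
+# are placeholders): since the logits may be smaller than the input image, they are typically
+# upsampled back to the original resolution before taking the per-pixel argmax:
+#
+#     logits = tf.transpose(outputs.logits, [0, 2, 3, 1])           # to (batch, h, w, num_labels)
+#     logits = tf.image.resize(logits, size=(orig_height, orig_width), method="bilinear")
+#     segmentation = tf.argmax(logits, axis=-1)                     # (batch, orig_height, orig_width)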
+
+
+@dataclass
+class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
+    """
+    Base class for outputs of semantic segmentation models that do not output attention scores.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
+            Classification scores for each pixel.
+
+            The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+            to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+            original image size as post-processing. You should always check your logits shape and resize as needed.
+
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFImageClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of image classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
+            feature maps) of the model at the output of each stage.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFMultipleChoiceModelOutput(ModelOutput):
+    """
+    Base class for outputs of multiple choice models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
+            Classification loss.
+        logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
+            *num_choices* is the second dimension of the input tensors (see *input_ids* above).
+
+            Classification scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
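+
+# Illustrative shape note (editorial sketch; `model` and `input_ids` are placeholders): multiple-choice
+# inputs are laid out as `(batch_size, num_choices, sequence_length)`, and `logits` holds one score per
+# choice, so the selected answer is the argmax over the last axis:
+#
+#     outputs = model({"input_ids": input_ids})          # input_ids: (batch, num_choices, seq_len)
+#     best_choice = tf.argmax(outputs.logits, axis=-1)   # (batch,)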
+
+
+@dataclass
+class TFTokenClassifierOutput(ModelOutput):
+    """
+    Base class for outputs of token classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided):
+            Classification loss.
+        logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
+            Classification scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of question answering models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided):
+            Total span-extraction loss: the sum of the cross-entropy losses for the start and end positions.
+        start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    start_logits: tf.Tensor | None = None
+    end_logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
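+
+# Illustrative decoding sketch (editorial addition; `outputs`, `input_ids` and `tokenizer` are
+# placeholders): the most likely answer span is read off the argmax of the start and end logits:
+#
+#     start = int(tf.argmax(outputs.start_logits, axis=-1)[0])
+#     end = int(tf.argmax(outputs.end_logits, axis=-1)[0])
+#     answer = tokenizer.decode(input_ids[0, start : end + 1])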
+
+
+@dataclass
+class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
+    """
+    Base class for outputs of sequence-to-sequence question answering models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Total span-extraction loss: the sum of the cross-entropy losses for the start and end positions.
+        start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Span-start scores (before SoftMax).
+        end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
+            Span-end scores (before SoftMax).
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
+            used (see `past_key_values` input) to speed up sequential decoding.
+        decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
+        decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+        encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Sequence of hidden-states at the output of the last layer of the encoder of the model.
+        encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
+        encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
+            self-attention heads.
+    """
+
+    loss: tf.Tensor | None = None
+    start_logits: tf.Tensor | None = None
+    end_logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    decoder_hidden_states: tuple[tf.Tensor] | None = None
+    decoder_attentions: tuple[tf.Tensor] | None = None
+    encoder_last_hidden_state: tf.Tensor | None = None
+    encoder_hidden_states: tuple[tf.Tensor] | None = None
+    encoder_attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFSequenceClassifierOutputWithPast(ModelOutput):
+    """
+    Base class for outputs of sentence classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        past_key_values (`list[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+            List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
+
+            Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
+            `past_key_values` input) to speed up sequential decoding.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
+            `(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+            sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    past_key_values: list[tf.Tensor] | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+
+@dataclass
+class TFImageClassifierOutputWithNoAttention(ModelOutput):
+    """
+    Base class for outputs of image classification models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Classification (or regression if config.num_labels==1) loss.
+        logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
+            Classification (or regression if config.num_labels==1) scores (before SoftMax).
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
+            feature maps) of the model at the output of each stage.
+    """
+
+    loss: tf.Tensor | None = None
+    logits: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor, ...] | None = None
+
+
+@dataclass
+class TFMaskedImageModelingOutput(ModelOutput):
+    """
+    Base class for outputs of masked image completion / in-painting models.
+
+    Args:
+        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
+            Reconstruction loss.
+        reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
+           Reconstructed / completed images.
+        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
+        `config.output_hidden_states=True`):
+            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
+            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
+            feature maps) of the model at the output of each stage.
+        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
+        `config.output_attentions=True`):
+            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: tf.Tensor | None = None
+    reconstruction: tf.Tensor | None = None
+    hidden_states: tuple[tf.Tensor] | None = None
+    attentions: tuple[tf.Tensor] | None = None
+
+    @property
+    def logits(self):
+        warnings.warn(
+            "logits attribute is deprecated and will be removed in version 5 of Transformers."
+            " Please use the reconstruction attribute to retrieve the final output instead.",
+            FutureWarning,
+        )
+        return self.reconstruction
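+
+# Illustrative note (editorial addition; `outputs` is a placeholder `TFMaskedImageModelingOutput`):
+# `outputs.logits` still resolves to the reconstruction but emits a `FutureWarning`; new code should
+# read `outputs.reconstruction` directly:
+#
+#     images = outputs.reconstruction   # preferred
+#     images = outputs.logits           # same tensor, with a deprecation warning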
diff --git a/phivenv/Lib/site-packages/transformers/modeling_tf_pytorch_utils.py b/phivenv/Lib/site-packages/transformers/modeling_tf_pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f688af7be36439311465d019de36c20c6aaae77
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_tf_pytorch_utils.py
@@ -0,0 +1,676 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch - TF 2.0 general utilities."""
+
+import os
+import re
+
+import numpy
+
+from .utils import (
+    ExplicitEnum,
+    check_torch_load_is_safe,
+    expand_dims,
+    is_numpy_array,
+    is_safetensors_available,
+    is_torch_tensor,
+    logging,
+    reshape,
+    squeeze,
+    tensor_size,
+)
+from .utils import transpose as transpose_func
+
+
+if is_safetensors_available():
+    from safetensors import safe_open
+
+
+logger = logging.get_logger(__name__)
+
+
+class TransposeType(ExplicitEnum):
+    """
+    Possible ways the TF 2.0 and PyTorch weight matrices may need to be transposed relative to each other.
+    """
+
+    NO = "no"
+    SIMPLE = "simple"
+    CONV1D = "conv1d"
+    CONV2D = "conv2d"
+
+
+def convert_tf_weight_name_to_pt_weight_name(
+    tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
+):
+    """
+    Convert a TF 2.0 model variable name to a PyTorch model weight name.
+
+    Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
+
+        - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
+        - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
+
+    Returns a tuple with:
+
+        - pytorch model weight name
+        - transpose: `TransposeType` member indicating whether and how the TF 2.0 and PyTorch weight matrices should
+          be transposed with regard to each other
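+
+    Illustrative example (an editorial sketch, assuming a BERT-style variable name with a 2D kernel):
+
+        convert_tf_weight_name_to_pt_weight_name(
+            "tf_bert_model/bert/encoder/layer_._0/attention/self/query/kernel:0",
+            start_prefix_to_remove="bert.",
+            tf_weight_shape=(768, 768),
+        )
+        # -> ("encoder.layer.0.attention.self.query.weight", TransposeType.SIMPLE)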
+    """
+    if name_scope is not None:
+        if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
+            raise ValueError(
+                f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
+                "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
+            )
+        tf_name = tf_name[len(name_scope) :]
+        tf_name = tf_name.lstrip("/")
+    tf_name = tf_name.replace(":0", "")  # device ids
+    if (len(tf_name) > 2048 and "___" in tf_name) or tf_name.count("___") > 10:
+        # ReDOS check
+        raise ValueError("TF variable name is too long or contains too many ___ separators: " + tf_name)
+    tf_name = re.sub(
+        r"/[^/]*___([^/]*)/", r"/\1/", tf_name
+    )  # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
+    tf_name = tf_name.replace(
+        "_._", "/"
+    )  # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
+    tf_name = re.sub(r"//+", "/", tf_name)  # Remove empty levels at the end
+    tf_name = tf_name.split("/")  # Convert from TF2.0 '/' separators to PyTorch '.' separators
+    # Some weights have a single name without "/" such as final_logits_bias in BART
+    if len(tf_name) > 1:
+        tf_name = tf_name[1:]  # Remove level zero
+
+    # Guard against the default `tf_weight_shape=None` before converting it to a list
+    tf_weight_shape = list(tf_weight_shape) if tf_weight_shape is not None else None
+
+    # When should we transpose the weights
+    if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
+        transpose = TransposeType.CONV2D
+    elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
+        transpose = TransposeType.CONV1D
+    elif bool(
+        tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
+        or "emb_projs" in tf_name
+        or "out_projs" in tf_name
+    ):
+        transpose = TransposeType.SIMPLE
+    else:
+        transpose = TransposeType.NO
+
+    # Convert standard TF2.0 names in PyTorch names
+    if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
+        tf_name[-1] = "weight"
+    if tf_name[-1] == "beta":
+        tf_name[-1] = "bias"
+
+    # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
+    if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
+        tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")
+
+    # Remove prefix if needed
+    tf_name = ".".join(tf_name)
+    if start_prefix_to_remove:
+        tf_name = tf_name.replace(start_prefix_to_remove, "", 1)
+
+    return tf_name, transpose
+
+
+def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
+    """
+    Apply a transpose to a weight, then try to reshape it to a given target shape, all in a
+    framework-agnostic way.
+    """
+    if transpose is TransposeType.CONV2D:
+        # Conv2D weight:
+        #    PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
+        # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
+        axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
+        weight = transpose_func(weight, axes=axes)
+    elif transpose is TransposeType.CONV1D:
+        # Conv1D weight:
+        #    PT: (num_out_channel, num_in_channel, kernel)
+        # -> TF: (kernel, num_in_channel, num_out_channel)
+        weight = transpose_func(weight, axes=(2, 1, 0))
+    elif transpose is TransposeType.SIMPLE:
+        weight = transpose_func(weight)
+
+    if match_shape is None:
+        return weight
+
+    if len(match_shape) < len(weight.shape):
+        weight = squeeze(weight)
+    elif len(match_shape) > len(weight.shape):
+        weight = expand_dims(weight, axis=0)
+
+    if list(match_shape) != list(weight.shape):
+        try:
+            weight = reshape(weight, match_shape)
+        except AssertionError as e:
+            e.args += (match_shape, weight.shape)  # report both the target shape and the actual shape
+            raise e
+
+    return weight
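+
+# Illustrative example (editorial addition): for a Conv2D kernel, `apply_transpose` with
+# `TransposeType.CONV2D` reorders a PyTorch `(out_ch, in_ch, kh, kw)` array into the TF
+# `(kh, kw, in_ch, out_ch)` layout:
+#
+#     w_pt = numpy.zeros((8, 3, 5, 5))
+#     w_tf = apply_transpose(TransposeType.CONV2D, w_pt)   # shape (5, 5, 3, 8)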
+
+
+#####################
+# PyTorch => TF 2.0 #
+#####################
+
+
+def load_pytorch_checkpoint_in_tf2_model(
+    tf_model,
+    pytorch_checkpoint_path,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+):
+    """Load pytorch checkpoints in a TF 2.0 model"""
+    try:
+        import tensorflow as tf  # noqa: F401
+        import torch  # noqa: F401
+        from safetensors.torch import load_file as safe_load_file  # noqa: F401
+    except ImportError:
+        logger.error(
+            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
+            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+
+    # Treats a single file as a collection of shards with 1 shard.
+    if isinstance(pytorch_checkpoint_path, str):
+        pytorch_checkpoint_path = [pytorch_checkpoint_path]
+
+    # Loads all shards into a single state dictionary
+    pt_state_dict = {}
+    for path in pytorch_checkpoint_path:
+        pt_path = os.path.abspath(path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
+        if pt_path.endswith(".safetensors"):
+            state_dict = safe_load_file(pt_path)
+        else:
+            check_torch_load_is_safe()
+            state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
+
+        pt_state_dict.update(state_dict)
+
+    logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
+
+    return load_pytorch_weights_in_tf2_model(
+        tf_model,
+        pt_state_dict,
+        tf_inputs=tf_inputs,
+        allow_missing_keys=allow_missing_keys,
+        output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+    )
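+
+# Illustrative usage (an editorial sketch; `TFBertModel`, `config` and the checkpoint path are
+# placeholders): a freshly initialized TF model can be populated from a PyTorch checkpoint file:
+#
+#     tf_model = TFBertModel(config)
+#     tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, "pytorch_model.bin")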
+
+
+def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
+    """Load pytorch checkpoints in a TF 2.0 model"""
+    pt_state_dict = pt_model.state_dict()
+
+    return load_pytorch_weights_in_tf2_model(
+        tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
+    )
+
+
+def load_pytorch_weights_in_tf2_model(
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+):
+    """Load pytorch state_dict in a TF 2.0 model."""
+    try:
+        import tensorflow as tf  # noqa: F401
+        import torch  # noqa: F401
+    except ImportError:
+        logger.error(
+            "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
+            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+
+    # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
+    pt_state_dict = {
+        k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+    }
+    return load_pytorch_state_dict_in_tf2_model(
+        tf_model,
+        pt_state_dict,
+        tf_inputs=tf_inputs,
+        allow_missing_keys=allow_missing_keys,
+        output_loading_info=output_loading_info,
+        _prefix=_prefix,
+        tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+    )
+
+
+def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
+            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {class_name} from a PyTorch model trained on another task or with another architecture"
+            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
+            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
+            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
+            " BertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
+            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
+            " down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {class_name} for predictions without further training."
+        )
+
+    if len(mismatched_keys) > 0:
+        mismatched_warning = "\n".join(
+            [
+                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                for key, shape1, shape2 in mismatched_keys
+            ]
+        )
+        logger.warning(
+            f"Some weights of {class_name} were not initialized from the model checkpoint"
+            f" are newly initialized because the shapes did not"
+            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+            " to use it for predictions and inference."
+        )
+
+
+def load_pytorch_state_dict_in_tf2_model(
+    tf_model,
+    pt_state_dict,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+    ignore_mismatched_sizes=False,
+    skip_logger_warnings=False,
+):
+    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
+    safetensors archive created with the safe_open() function."""
+    import tensorflow as tf
+
+    if tf_inputs is None:
+        tf_inputs = tf_model.dummy_inputs
+
+    if _prefix is None:
+        _prefix = ""
+    if tf_inputs:
+        with tf.name_scope(_prefix):
+            tf_model(tf_inputs, training=False)  # Make sure model is built
+    # Convert old format to new format if needed from a PyTorch state_dict
+    tf_keys_to_pt_keys = {}
+    for key in pt_state_dict:
+        new_key = None
+        if "gamma" in key:
+            new_key = key.replace("gamma", "weight")
+        if "beta" in key:
+            new_key = key.replace("beta", "bias")
+        if "running_var" in key:
+            new_key = key.replace("running_var", "moving_variance")
+        if "running_mean" in key:
+            new_key = key.replace("running_mean", "moving_mean")
+
+        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+        key_components = key.split(".")
+        name = None
+        if key_components[-3::2] == ["parametrizations", "original0"]:
+            name = key_components[-2] + "_g"
+        elif key_components[-3::2] == ["parametrizations", "original1"]:
+            name = key_components[-2] + "_v"
+        if name is not None:
+            key_components = key_components[:-3] + [name]
+            new_key = ".".join(key_components)
+
+        if new_key is None:
+            new_key = key
+        tf_keys_to_pt_keys[new_key] = key
+
+    # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
+    # In PT, the derived models (with heads) use the base model class as the stem instead,
+    # and there is no MainLayer class. This means that TF base classes have one
+    # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
+    start_prefix_to_remove = ""
+    if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys):
+        start_prefix_to_remove = tf_model.base_model_prefix + "."
+
+    symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
+    tf_loaded_numel = 0
+    all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
+    missing_keys = []
+    mismatched_keys = []
+    is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
+    for symbolic_weight in symbolic_weights:
+        sw_name = symbolic_weight.name
+        name, transpose = convert_tf_weight_name_to_pt_weight_name(
+            sw_name,
+            start_prefix_to_remove=start_prefix_to_remove,
+            tf_weight_shape=symbolic_weight.shape,
+            name_scope=_prefix,
+        )
+        if tf_to_pt_weight_rename is not None:
+            aliases = tf_to_pt_weight_rename(name)  # Is a tuple to account for possible name aliasing
+            for alias in aliases:  # The aliases are in priority order, take the first one that matches
+                if alias in tf_keys_to_pt_keys:
+                    name = alias
+                    break
+            else:
+                # If none of the aliases match, just use the first one (it'll be reported as missing)
+                name = aliases[0]
+
+        # Find associated numpy array in pytorch model state dict
+        if name not in tf_keys_to_pt_keys:
+            if allow_missing_keys:
+                missing_keys.append(name)
+                continue
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
+                # authorized missing keys don't have to be loaded
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
+                    continue
+            raise AttributeError(f"{name} not found in PyTorch model")
+        state_dict_name = tf_keys_to_pt_keys[name]
+        if is_safetensor_archive:
+            array = pt_state_dict.get_tensor(state_dict_name)
+        else:
+            array = pt_state_dict[state_dict_name]
+        try:
+            array = apply_transpose(transpose, array, symbolic_weight.shape)
+        except tf.errors.InvalidArgumentError as e:
+            if not ignore_mismatched_sizes:
+                error_msg = str(e)
+                error_msg += (
+                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                )
+                raise tf.errors.InvalidArgumentError(error_msg)
+            else:
+                mismatched_keys.append((name, array.shape, symbolic_weight.shape))
+                continue
+
+        tf_loaded_numel += tensor_size(array)
+
+        symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
+        del array  # Immediately free memory to keep peak usage as low as possible
+        all_pytorch_weights.discard(name)
+
+    logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
+
+    unexpected_keys = list(all_pytorch_weights)
+
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
+            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+    if not skip_logger_warnings:
+        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+    if output_loading_info:
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+        }
+        return tf_model, loading_info
+
+    return tf_model
+
+
+def load_sharded_pytorch_safetensors_in_tf2_model(
+    tf_model,
+    safetensors_shards,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+    ignore_mismatched_sizes=False,
+):
+    all_loading_infos = []
+    for shard in safetensors_shards:
+        with safe_open(shard, framework="tf") as safetensors_archive:
+            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
+                tf_model,
+                safetensors_archive,
+                tf_inputs=tf_inputs,
+                allow_missing_keys=allow_missing_keys,
+                output_loading_info=True,
+                _prefix=_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                skip_logger_warnings=True,  # We will emit merged warnings at the end
+            )
+        all_loading_infos.append(loading_info)
+    # Now we just need to merge the loading info
+    # Keys are missing only if they're missing in *every* shard
+    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
+    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
+    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
+    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
+
+    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+    if output_loading_info:
+        loading_info = {
+            "missing_keys": missing_keys,
+            "unexpected_keys": unexpected_keys,
+            "mismatched_keys": mismatched_keys,
+        }
+        return tf_model, loading_info
+
+    return tf_model
+
+
+#####################
+# TF 2.0 => PyTorch #
+#####################
+
+
+def load_tf2_checkpoint_in_pytorch_model(
+    pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
+):
+    """
+    Load a TF 2.0 HDF5 checkpoint in a PyTorch model. We use HDF5 to easily do transfer learning (see
+    https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
+    """
+    try:
+        import tensorflow as tf  # noqa: F401
+        import torch  # noqa: F401
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
+            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+
+    import transformers
+
+    from .modeling_tf_utils import load_tf_weights
+
+    logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")
+
+    # Instantiate and load the associated TF 2.0 model
+    tf_model_class_name = "TF" + pt_model.__class__.__name__  # Add "TF" at the beginning
+    tf_model_class = getattr(transformers, tf_model_class_name)
+    tf_model = tf_model_class(pt_model.config)
+
+    if tf_inputs is None:
+        tf_inputs = tf_model.dummy_inputs
+
+    if tf_inputs is not None:
+        tf_model(tf_inputs, training=False)  # Make sure model is built
+
+    load_tf_weights(tf_model, tf_checkpoint_path)
+
+    return load_tf2_model_in_pytorch_model(
+        pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+    )
+
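+
+# Illustrative sketch (not part of the upstream module): converting a TF 2.0 HDF5 checkpoint into a
+# PyTorch model with the helper above. It assumes both frameworks are installed; the checkpoint path
+# and the choice of BERT classes are placeholders. Wrapped in a function so nothing runs at import.
+def _sketch_convert_tf_checkpoint_to_pytorch():
+    from transformers import BertConfig, BertForSequenceClassification
+
+    config = BertConfig()  # in practice, use the config that matches the checkpoint
+    pt_model = BertForSequenceClassification(config)
+    pt_model, loading_info = load_tf2_checkpoint_in_pytorch_model(
+        pt_model, "path/to/tf_model.h5", allow_missing_keys=True, output_loading_info=True
+    )
+    return pt_model, loading_info["missing_keys"], loading_info["unexpected_keys"]
+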
+
+def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
+    """Load TF 2.0 model in a pytorch model"""
+    weights = tf_model.weights
+
+    return load_tf2_weights_in_pytorch_model(
+        pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+    )
+
+
+def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
+    """Load TF2.0 symbolic weights in a PyTorch model"""
+    try:
+        import tensorflow as tf  # noqa: F401
+        import torch  # noqa: F401
+    except ImportError:
+        logger.error(
+            "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
+            "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
+        )
+        raise
+
+    tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
+    return load_tf2_state_dict_in_pytorch_model(
+        pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
+    )
+
+
+def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
+    import torch
+
+    new_pt_params_dict = {}
+    current_pt_params_dict = dict(pt_model.named_parameters())
+
+    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
+    # TF models always have a prefix; some PyTorch models (base ones) don't
+    start_prefix_to_remove = ""
+    if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict):
+        start_prefix_to_remove = pt_model.base_model_prefix + "."
+
+    # Build a map from potential PyTorch weight names to TF 2.0 Variables
+    tf_weights_map = {}
+    for name, tf_weight in tf_state_dict.items():
+        pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
+            name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
+        )
+        tf_weights_map[pt_name] = (tf_weight, transpose)
+
+    all_tf_weights = set(tf_weights_map.keys())
+    loaded_pt_weights_data_ptr = {}
+    missing_keys_pt = []
+    for pt_weight_name, pt_weight in current_pt_params_dict.items():
+        # Handle PyTorch shared weight not duplicated in TF 2.0
+        if pt_weight.data_ptr() in loaded_pt_weights_data_ptr and pt_weight.data_ptr() != 0:
+            new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
+            continue
+
+        pt_weight_name_to_check = pt_weight_name
+        # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
+        key_components = pt_weight_name.split(".")
+        name = None
+        if key_components[-3::2] == ["parametrizations", "original0"]:
+            name = key_components[-2] + "_g"
+        elif key_components[-3::2] == ["parametrizations", "original1"]:
+            name = key_components[-2] + "_v"
+        if name is not None:
+            key_components = key_components[:-3] + [name]
+            pt_weight_name_to_check = ".".join(key_components)
+
+        # Find the associated numpy array in the TF 2.0 weights map
+        if pt_weight_name_to_check not in tf_weights_map:
+            if allow_missing_keys:
+                missing_keys_pt.append(pt_weight_name)
+                continue
+
+            raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
+
+        array, transpose = tf_weights_map[pt_weight_name_to_check]
+
+        array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
+
+        if numpy.isscalar(array):
+            array = numpy.array(array)
+        if not is_torch_tensor(array) and not is_numpy_array(array):
+            array = array.numpy()
+        if is_numpy_array(array):
+            # Convert to torch tensor
+            array = torch.from_numpy(array)
+
+        new_pt_params_dict[pt_weight_name] = array
+        loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
+        all_tf_weights.discard(pt_weight_name)
+
+    missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
+    missing_keys += missing_keys_pt
+
+    # Some models may have keys that are not in the state by design, removing them before needlessly warning
+    # the user.
+    if pt_model._keys_to_ignore_on_load_missing is not None:
+        for pat in pt_model._keys_to_ignore_on_load_missing:
+            missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+    if pt_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in pt_model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
+            f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
+            " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
+            f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
+            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
+            " TFBertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
+            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
+            " use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
+        )
+
+    logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")
+
+    if output_loading_info:
+        loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
+        return pt_model, loading_info
+
+    return pt_model
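+
+
+# Illustrative sketch (not part of the upstream module): the `weight_norm` parametrization rename
+# performed in `load_tf2_state_dict_in_pytorch_model` above. A PyTorch key such as
+# "conv.parametrizations.weight.original0" is looked up in the TF weight map as "conv.weight_g",
+# and "...original1" as "conv.weight_v"; other keys are left untouched.
+def _sketch_weight_norm_key_rename(pt_weight_name):
+    key_components = pt_weight_name.split(".")
+    name = None
+    if key_components[-3::2] == ["parametrizations", "original0"]:
+        name = key_components[-2] + "_g"
+    elif key_components[-3::2] == ["parametrizations", "original1"]:
+        name = key_components[-2] + "_v"
+    if name is None:
+        return pt_weight_name
+    return ".".join(key_components[:-3] + [name])
+
+
+# _sketch_weight_norm_key_rename("conv.parametrizations.weight.original0") == "conv.weight_g"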
diff --git a/phivenv/Lib/site-packages/transformers/modeling_tf_utils.py b/phivenv/Lib/site-packages/transformers/modeling_tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7bb80656d1b8596e9f3313b6f768c6d2251099d
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_tf_utils.py
@@ -0,0 +1,3529 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""TF general model utils."""
+
+from __future__ import annotations
+
+import functools
+import gc
+import inspect
+import json
+import os
+import pickle
+import re
+import warnings
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Union
+
+import h5py
+import numpy as np
+import tensorflow as tf
+from packaging.version import parse
+
+from . import DataCollatorWithPadding, DefaultDataCollator
+from .activations_tf import get_tf_activation
+from .configuration_utils import PretrainedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import GenerationConfig, TFGenerationMixin
+from .tf_utils import (
+    convert_batch_encoding,
+    expand_1d,
+    load_attributes_from_hdf5_group,
+    save_attributes_to_hdf5_group,
+    shape_list,
+)
+from .utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_INDEX_NAME,
+    TF2_WEIGHTS_NAME,
+    TF_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    ModelOutput,
+    PushToHubMixin,
+    cached_file,
+    download_url,
+    find_labels,
+    has_file,
+    is_offline_mode,
+    is_remote_url,
+    is_safetensors_available,
+    is_tf_symbolic_tensor,
+    logging,
+    requires_backends,
+    working_or_temp_dir,
+)
+from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
+
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.tensorflow import save_file as safe_save_file
+
+if TYPE_CHECKING:
+    from . import PreTrainedTokenizerBase
+
+logger = logging.get_logger(__name__)
+
+if "TF_USE_LEGACY_KERAS" not in os.environ:
+    os.environ["TF_USE_LEGACY_KERAS"] = "1"  # Compatibility fix to make sure tf.keras stays at Keras 2
+elif os.environ["TF_USE_LEGACY_KERAS"] != "1":
+    logger.warning(
+        "Transformers is only compatible with Keras 2, but you have explicitly set `TF_USE_LEGACY_KERAS` to `0`. "
+        "This may result in unexpected behaviour or errors if Keras 3 objects are passed to Transformers models."
+    )
+
+try:
+    import tf_keras as keras
+    from tf_keras import backend as K
+except (ModuleNotFoundError, ImportError):
+    import keras
+    from keras import backend as K
+
+    if parse(keras.__version__).major > 2:
+        raise ValueError(
+            "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
+            "Transformers. Please install the backwards-compatible tf-keras package with "
+            "`pip install tf-keras`."
+        )
+
+
+tf_logger = tf.get_logger()
+
+TFModelInputType = Union[
+    list[tf.Tensor],
+    list[np.ndarray],
+    dict[str, tf.Tensor],
+    dict[str, np.ndarray],
+    tf.Tensor,
+    np.ndarray,
+]
+
+
+def dummy_loss(y_true, y_pred):
+    if y_pred.shape.rank <= 1:
+        return y_pred
+    else:
+        reduction_axes = list(range(1, y_pred.shape.rank))
+        return tf.reduce_mean(y_pred, axis=reduction_axes)
+
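+
+# Illustrative sketch (not part of the upstream module): `dummy_loss` simply averages a
+# model-computed loss over all non-batch axes, so a per-token loss of shape (batch, seq_len)
+# becomes one value per sample. The shapes below are arbitrary.
+def _sketch_dummy_loss_shapes():
+    per_token_loss = tf.ones((2, 5))  # pretend per-token loss for a batch of 2, length 5
+    per_sample_loss = dummy_loss(y_true=None, y_pred=per_token_loss)
+    return per_sample_loss.shape  # TensorShape([2])
+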
+
+class TFModelUtilsMixin:
+    """
+    A few utilities for `keras.Model`, to be used as a mixin.
+    """
+
+    def num_parameters(self, only_trainable: bool = False) -> int:
+        """
+        Get the number of (optionally, trainable) parameters in the model.
+
+        Args:
+            only_trainable (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of trainable parameters
+
+        Returns:
+            `int`: The number of parameters.
+        """
+        if only_trainable:
+            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
+        else:
+            return self.count_params()
+
+
+def keras_serializable(cls):
+    """
+    Decorate a Keras Layer class to support Keras serialization.
+
+    This is done by:
+
+    1. Adding a `config` dict to the Keras config dictionary in `get_config` (called by Keras at
+       serialization time).
+    2. Wrapping `__init__` to accept that `config` dict (passed by Keras at deserialization time) and
+       converting it to a config object for the actual layer initializer.
+    3. Registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does not
+       need to be supplied in `custom_objects` in the call to `keras.models.load_model`.
+
+    Args:
+        cls (a `keras.layers.Layer` subclass):
+            Typically a `TF.MainLayer` class in this project, in general must accept a `config` argument to its
+            initializer.
+
+    Returns:
+        The same class object, with modifications for Keras deserialization.
+    """
+    initializer = cls.__init__
+
+    config_class = getattr(cls, "config_class", None)
+    if config_class is None:
+        raise AttributeError("Must set `config_class` to use @keras_serializable")
+
+    @functools.wraps(initializer)
+    def wrapped_init(self, *args, **kwargs):
+        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.pop("config", None)
+
+        if isinstance(config, dict):
+            config = config_class.from_dict(config)
+            initializer(self, config, *args, **kwargs)
+        elif isinstance(config, PretrainedConfig):
+            if len(args) > 0:
+                initializer(self, *args, **kwargs)
+            else:
+                initializer(self, config, *args, **kwargs)
+        else:
+            raise TypeError("Must pass either `config` (PretrainedConfig) or `config` (dict)")
+
+        self._config = config
+        self._kwargs = kwargs
+
+    cls.__init__ = wrapped_init
+
+    if not hasattr(cls, "get_config"):
+        raise TypeError("Only use @keras_serializable on keras.layers.Layer subclasses")
+    if hasattr(cls.get_config, "_is_default"):
+
+        def get_config(self):
+            cfg = super(cls, self).get_config()
+            cfg["config"] = self._config.to_dict()
+            cfg.update(self._kwargs)
+            return cfg
+
+        cls.get_config = get_config
+
+    cls._keras_serializable = True
+    if hasattr(keras.utils, "register_keras_serializable"):
+        cls = keras.utils.register_keras_serializable()(cls)
+    return cls
+
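+
+# Illustrative sketch (not part of the upstream module): a minimal layer using the decorator above.
+# `_ToyMainLayer` and its attributes are made up; the decorator only requires `config_class` to be
+# set and `config` to be the first argument of `__init__`. Defined inside a function so nothing is
+# registered with Keras at import time.
+def _sketch_keras_serializable_usage():
+    @keras_serializable
+    class _ToyMainLayer(keras.layers.Layer):
+        config_class = PretrainedConfig
+
+        def __init__(self, config, **kwargs):
+            super().__init__(**kwargs)
+            self.units = getattr(config, "hidden_size", 4)
+
+    # The wrapped `__init__` also accepts a plain dict (as Keras passes at deserialization time)
+    # and converts it back into a `PretrainedConfig` before calling the real initializer; the
+    # patched `get_config` embeds the serialized config under the "config" key.
+    layer = _ToyMainLayer(config={"hidden_size": 8})
+    return layer.units, type(layer._config).__name__  # (8, "PretrainedConfig")
+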
+
+class TFCausalLanguageModelingLoss:
+    """
+    Loss function suitable for causal language modeling (CLM), that is, the task of guessing the next token.
+
+    
+
+    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+    
+    """
+
+    def hf_compute_loss(self, labels, logits):
+        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100 affect the loss
+            active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 affect the loss
+        loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+        return tf.reshape(reduced_masked_loss, (1,))
+
+
+class TFQuestionAnsweringLoss:
+    """
+    Loss function suitable for question answering.
+    """
+
+    def hf_compute_loss(self, labels, logits):
+        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+        start_loss = loss_fn(labels["start_position"], logits[0])
+        end_loss = loss_fn(labels["end_position"], logits[1])
+
+        return (start_loss + end_loss) / 2.0
+
+
+class TFTokenClassificationLoss:
+    """
+    Loss function suitable for token classification.
+
+    
+
+    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+    
+    """
+
+    def hf_compute_loss(self, labels, logits):
+        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+        if tf.executing_eagerly():  # Data-dependent conditionals are forbidden in XLA
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            if tf.math.reduce_any(labels == -1):
+                tf.print("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
+                active_loss = tf.reshape(labels, (-1,)) != -1
+            else:
+                active_loss = tf.reshape(labels, (-1,)) != -100
+            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
+            labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
+
+            return loss_fn(labels, reduced_logits)
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
+        # make sure only labels that are not equal to -100 or -1
+        # are taken into account as loss
+        loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
+        # Avoid possible division by zero later
+        # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
+        masked_loss = unmasked_loss * loss_mask
+        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+        return tf.reshape(reduced_masked_loss, (1,))
+
+
+class TFSequenceClassificationLoss:
+    """
+    Loss function suitable for sequence classification.
+    """
+
+    def hf_compute_loss(self, labels, logits):
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
+            loss_fn = keras.losses.MeanSquaredError(reduction=keras.losses.Reduction.NONE)
+            if labels.shape.rank == 1:
+                # MeanSquaredError returns a scalar loss if the labels are 1D, so avoid that
+                labels = tf.expand_dims(labels, axis=-1)
+        else:
+            loss_fn = keras.losses.SparseCategoricalCrossentropy(
+                from_logits=True, reduction=keras.losses.Reduction.NONE
+            )
+
+        return loss_fn(labels, logits)
+
+
+class TFMultipleChoiceLoss:
+    """Loss function suitable for multiple choice tasks."""
+
+    def hf_compute_loss(self, labels, logits):
+        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+        return loss_fn(labels, logits)
+
+
+class TFMaskedLanguageModelingLoss(TFCausalLanguageModelingLoss):
+    """
+    Loss function suitable for masked language modeling (MLM), that is, the task of guessing the masked tokens.
+
+    
+
+    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+    
+    """
+
+
+class TFNextSentencePredictionLoss:
+    """
+    Loss function suitable for next sentence prediction (NSP), that is, the task of guessing the next sentence.
+
+    
+
+    Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
+
+    
+    """
+
+    def hf_compute_loss(self, labels, logits):
+        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
+        if self.config.tf_legacy_loss:
+            # make sure only labels that are not equal to -100
+            # are taken into account as loss
+            next_sentence_active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
+            next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, 2)), next_sentence_active_loss)
+            next_sentence_label = tf.boolean_mask(tf.reshape(labels, (-1,)), next_sentence_active_loss)
+
+            return loss_fn(next_sentence_label, next_sentence_reduced_logits)
+
+        # make sure only labels that are not equal to -100
+        # are taken into account as loss
+
+        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
+        unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels), y_pred=logits)
+        ns_loss_mask = tf.cast(labels != -100, dtype=unmasked_ns_loss.dtype)
+        # Just zero out samples where label is -100, no reduction
+        masked_ns_loss = unmasked_ns_loss * ns_loss_mask
+
+        return masked_ns_loss
+
+
+def booleans_processing(config, **kwargs):
+    """
+    Process the input booleans of each model.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config of the running model.
+        **kwargs:
+            The boolean parameters
+
+    Returns:
+        A dictionary with the proper values for each boolean
+    """
+    final_booleans = {}
+
+    # Pure conv models (such as ConvNext) do not have `output_attentions`. If the signature has
+    # `output_attentions`, it will be present here in `kwargs`, even if unset (in that case, as `None`)
+    if "output_attentions" in kwargs:
+        final_booleans["output_attentions"] = (
+            kwargs["output_attentions"] if kwargs["output_attentions"] is not None else config.output_attentions
+        )
+    final_booleans["output_hidden_states"] = (
+        kwargs["output_hidden_states"] if kwargs["output_hidden_states"] is not None else config.output_hidden_states
+    )
+    final_booleans["return_dict"] = kwargs["return_dict"] if kwargs["return_dict"] is not None else config.return_dict
+
+    if "use_cache" in kwargs:
+        final_booleans["use_cache"] = (
+            kwargs["use_cache"] if kwargs["use_cache"] is not None else getattr(config, "use_cache", None)
+        )
+    return final_booleans
+
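+
+# Illustrative sketch (not part of the upstream module): booleans left as `None` fall back to the
+# values stored on the config, while explicitly passed booleans win. The config values are made up.
+def _sketch_booleans_processing():
+    config = PretrainedConfig(output_attentions=False, output_hidden_states=True, return_dict=True)
+    flags = booleans_processing(
+        config, output_attentions=None, output_hidden_states=None, return_dict=False
+    )
+    # Expected: {"output_attentions": False, "output_hidden_states": True, "return_dict": False}
+    return flags
+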
+
+def unpack_inputs(func):
+    """
+    Decorator that processes the inputs to a Keras layer, passing them to the layer as keyword arguments. This enables
+    downstream use of the inputs by their variable name, even if they arrive packed as a dictionary in the first input
+    (common case in Keras).
+
+    Args:
+        func (`callable`):
+            The callable function of the TensorFlow model.
+
+
+    Returns:
+        A callable that wraps the original `func` with the behavior described above.
+    """
+
+    original_signature = inspect.signature(func)
+
+    @functools.wraps(func)
+    def run_call_with_unpacked_inputs(self, *args, **kwargs):
+        # isolates the actual `**kwargs` for the decorated function
+        kwargs_call = {key: val for key, val in kwargs.items() if key not in dict(original_signature.parameters)}
+        fn_args_and_kwargs = {key: val for key, val in kwargs.items() if key not in kwargs_call}
+        fn_args_and_kwargs.update({"kwargs_call": kwargs_call})
+
+        # move any arg into kwargs, if they exist
+        fn_args_and_kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
+
+        # Encoder Decoder models delegate the application of the configuration options to their inner models.
+        if "EncoderDecoder" in self.__class__.__name__:
+            config = None
+        else:
+            config = self.config
+
+        unpacked_inputs = input_processing(func, config, **fn_args_and_kwargs)
+        return func(self, **unpacked_inputs)
+
+    # Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
+    # function does not follow wrapper chains (i.e. ignores `functools.wraps()`), meaning that without the line below
+    # Keras would attempt to check the first argument against the literal signature of the wrapper.
+    run_call_with_unpacked_inputs.__signature__ = original_signature
+
+    return run_call_with_unpacked_inputs
+
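+
+# Illustrative sketch (not part of the upstream module): a toy class whose `call` is decorated with
+# `@unpack_inputs`. Real model signatures have many more arguments; `_ToyModel` and its tensors are
+# made up. Whether the inputs arrive packed in a dict or as keyword arguments, the decorated method
+# receives them as named arguments, with unset booleans resolved from `self.config`.
+def _sketch_unpack_inputs_usage():
+    class _ToyModel:
+        config = PretrainedConfig(output_hidden_states=False, return_dict=True)
+
+        @unpack_inputs
+        def call(self, input_ids=None, attention_mask=None, output_hidden_states=None, return_dict=None):
+            return input_ids, attention_mask, output_hidden_states, return_dict
+
+    toy = _ToyModel()
+    packed = {"input_ids": tf.constant([[1, 2, 3]]), "attention_mask": tf.constant([[1, 1, 1]])}
+    # output_hidden_states resolves to False and return_dict to True, both taken from the config
+    return toy.call(packed)
+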
+
+def input_processing(func, config, **kwargs):
+    """
+    Process the input of each TensorFlow model, including the booleans. In case of a list of symbolic inputs, each
+    input has to be named according to its parameter name, i.e. `input_ids = keras.Input(shape=(128,), dtype='int32',
+    name="input_ids")`, otherwise the order of the tensors is not guaranteed during training.
+
+    Args:
+        func (`callable`):
+            The callable function of the TensorFlow model.
+        config ([`PretrainedConfig`]):
+            The config of the running model.
+        **kwargs:
+            The inputs of the model.
+
+    Returns:
+        A dictionary mapping each parameter name of `func` to its processed value.
+    """
+    signature = dict(inspect.signature(func).parameters)
+    has_kwargs = bool(signature.pop("kwargs", None))
+    signature.pop("self", None)
+    parameter_names = list(signature.keys())
+    main_input_name = parameter_names[0]
+    main_input = kwargs.pop(main_input_name, None)
+    output = {}
+    allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
+
+    if "inputs" in kwargs["kwargs_call"]:
+        warnings.warn(
+            "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids` instead.",
+            FutureWarning,
+        )
+
+        output["input_ids"] = kwargs["kwargs_call"].pop("inputs")
+
+    if "decoder_cached_states" in kwargs["kwargs_call"]:
+        warnings.warn(
+            "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+            " `past_key_values` instead.",
+            FutureWarning,
+        )
+        output["past_key_values"] = kwargs["kwargs_call"].pop("decoder_cached_states")
+
+    if "past" in kwargs["kwargs_call"] and "past_key_values" in parameter_names:
+        warnings.warn(
+            "The `past` argument is deprecated and will be removed in a future version, use `past_key_values`"
+            " instead.",
+            FutureWarning,
+        )
+        kwargs["past_key_values"] = kwargs["kwargs_call"].pop("past")
+    elif "past_key_values" in kwargs["kwargs_call"] and "past" in parameter_names:
+        kwargs["past"] = kwargs["kwargs_call"].pop("past_key_values")
+
+    if has_kwargs:
+        output["kwargs"] = kwargs.pop("kwargs_call", {})
+    else:
+        if len(kwargs["kwargs_call"]) > 0:
+            raise ValueError(
+                "The following keyword arguments are not supported by this model:"
+                f" {list(kwargs['kwargs_call'].keys())}."
+            )
+        kwargs.pop("kwargs_call")
+
+    for k, v in kwargs.items():
+        if isinstance(v, allowed_types) or tf.is_tensor(v) or v is None:
+            output[k] = v
+        else:
+            raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
+
+    if isinstance(main_input, (tuple, list)):
+        for i, input in enumerate(main_input):
+            # EagerTensors don't allow to use the .name property so we check for a real Tensor
+            if is_tf_symbolic_tensor(input):
+                # Tensor names have always the pattern `name:id` then we check only the
+                # `name` part
+                tensor_name = input.name.split(":")[0]
+
+                if tensor_name in parameter_names:
+                    output[tensor_name] = input
+                else:
+                    output[parameter_names[i]] = input
+            elif isinstance(input, allowed_types) or input is None:
+                output[parameter_names[i]] = input
+            else:
+                raise ValueError(
+                    f"Data of type {type(input)} is not allowed only {allowed_types} is accepted for"
+                    f" {parameter_names[i]}."
+                )
+    elif isinstance(main_input, Mapping):
+        if "inputs" in main_input:
+            warnings.warn(
+                "The `inputs` argument is deprecated and will be removed in a future version, use `input_ids`"
+                " instead.",
+                FutureWarning,
+            )
+
+            output["input_ids"] = main_input.pop("inputs")
+
+        if "decoder_cached_states" in main_input:
+            warnings.warn(
+                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use"
+                " `past_key_values` instead.",
+                FutureWarning,
+            )
+            output["past_key_values"] = main_input.pop("decoder_cached_states")
+
+        for k, v in dict(main_input).items():
+            if isinstance(v, allowed_types) or v is None:
+                output[k] = v
+            elif k not in parameter_names and "args" not in parameter_names:
+                logger.warning(
+                    f"The parameter {k} does not belongs to the parameter list {parameter_names} and will be ignored."
+                )
+                continue
+            else:
+                raise ValueError(f"Data of type {type(v)} is not allowed only {allowed_types} is accepted for {k}.")
+    else:
+        if tf.is_tensor(main_input) or main_input is None:
+            output[main_input_name] = main_input
+        else:
+            raise ValueError(
+                f"Data of type {type(main_input)} is not allowed only {allowed_types} is accepted for"
+                f" {main_input_name}."
+            )
+
+    # Populates any unspecified argument with their default value, according to the signature.
+    for name in parameter_names:
+        if name not in list(output.keys()) and name != "args":
+            output[name] = kwargs.pop(name, signature[name].default)
+
+    # When creating a SavedModel TF calls the method with LayerCall.__call__(args, **kwargs)
+    # So to respect the proper output we have to add this exception
+    if "args" in output:
+        if output["args"] is not None and is_tf_symbolic_tensor(output["args"]):
+            tensor_name = output["args"].name.split(":")[0]
+            output[tensor_name] = output["args"]
+        else:
+            # `args` in this case is always the first parameter, then `input_ids`
+            output["input_ids"] = output["args"]
+
+        del output["args"]
+
+    if "kwargs" in output:
+        del output["kwargs"]
+
+    cast_output = {}
+    for key, val in output.items():
+        if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
+            cast_output[key] = tf.cast(val, tf.int32)
+        elif isinstance(val, np.ndarray) and val.dtype == np.int64:
+            cast_output[key] = val.astype(np.int32)
+        else:
+            cast_output[key] = val
+
+    output = cast_output
+    del cast_output
+
+    if config is not None:
+        boolean_dict = {
+            k: v
+            for k, v in output.items()
+            if k in ["return_dict", "output_attentions", "output_hidden_states", "use_cache"]
+        }
+
+        output.update(
+            booleans_processing(
+                config=config,
+                **boolean_dict,
+            )
+        )
+
+    return output
+
+
+def strip_model_name_and_prefix(name, _prefix=None):
+    if _prefix is not None and name.startswith(_prefix):
+        name = name[len(_prefix) :]
+        if name.startswith("/"):
+            name = name[1:]
+    if "model." not in name and len(name.split("/")) > 1:
+        name = "/".join(name.split("/")[1:])
+    return name
+
+
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
+    """
+    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
+    given size.
+
+    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
+    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
+    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
+    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
+
+    
+
+    If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
+    have a size greater than `max_shard_size`.
+
+    
+
+    Args:
+        weights (`list[tf.ResourceVariable]`): The list of `tf.ResourceVariable`s of the model to save.
+        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
+            (like `"5MB"`).
+    """
+    max_shard_size = convert_file_size_to_int(max_shard_size)
+
+    sharded_state_dicts = []
+    current_block = []
+    current_block_size = 0
+    total_size = 0
+
+    for item in weights:
+        weight_size = item.numpy().size * item.dtype.size
+
+        # If this weight is going to tip the current block over the maximal size, we split.
+        if current_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append(current_block)
+            current_block = []
+            current_block_size = 0
+
+        current_block.append(item)
+        current_block_size += weight_size
+        total_size += weight_size
+
+    # Add the last block
+    sharded_state_dicts.append(current_block)
+
+    # If we only have one shard, we return it
+    if len(sharded_state_dicts) == 1:
+        return {weights_name: sharded_state_dicts[0]}, None
+
+    # Otherwise, let's build the index
+    weight_map = {}
+    shards = {}
+    for idx, shard in enumerate(sharded_state_dicts):
+        shard_file = weights_name.replace(".h5", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.h5")
+        shard_file = shard_file.replace(
+            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
+        )
+        shards[shard_file] = shard
+        for weight in shard:
+            weight_name = weight.name
+            weight_map[weight_name] = shard_file
+
+    # Add the metadata
+    metadata = {"total_size": total_size}
+    index = {"metadata": metadata, "weight_map": weight_map}
+    return shards, index
+
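+
+# Illustrative sketch (not part of the upstream module): sharding a few small float32 variables with
+# a deliberately tiny `max_shard_size` so the greedy split above is visible. Variable names and
+# sizes are made up.
+def _sketch_tf_shard_checkpoint():
+    weights = [
+        tf.Variable(tf.zeros((1000,)), name="kernel_0"),  # ~4 kB of float32
+        tf.Variable(tf.zeros((1000,)), name="kernel_1"),  # ~4 kB of float32
+        tf.Variable(tf.zeros((100,)), name="bias_1"),     # ~0.4 kB of float32
+    ]
+    shards, index = tf_shard_checkpoint(weights, max_shard_size="5KB")
+    # Expected: two shards ("tf_model-00001-of-00002.h5" holding kernel_0, the second shard holding
+    # kernel_1 and bias_1); `index["weight_map"]` maps each variable name to its shard file.
+    return sorted(shards.keys()), index["weight_map"]
+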
+
+def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None):
+    """
+    This is the same as `load_tf_weights` but for a sharded checkpoint. Detect missing and unexpected layers and load
+    the TF weights from the shard files according to their names and shapes.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`keras.models.Model`): The model in which to load the checkpoint.
+        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+            Whether or not to ignore weights whose shapes do not match between the checkpoint and the model.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+
+    # Load the index
+    unexpected_keys = set()
+    saved_keys = set()
+    mismatched_keys = set()
+
+    # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
+    # the weight, we have to get rid of the first prefix of the name of the layer.
+    model_keys = set()
+    model_layer_map = {}
+    for i, k in enumerate(model.weights):
+        layer_name = k.name
+        if _prefix is not None and layer_name.startswith(_prefix):
+            layer_name = layer_name[len(_prefix) :]
+            layer_name = layer_name.lstrip("/")
+        if not ("model." in layer_name or len(layer_name.split("/")) == 1):
+            layer_name = "/".join(layer_name.split("/")[1:])
+        model_keys.add(layer_name)
+        model_layer_map[layer_name] = i
+
+    for shard_file in shard_files:
+        saved_weight_names_set, unexpected_keys_set, mismatched_keys_set = load_tf_shard(
+            model,
+            model_layer_map,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
+        )
+        saved_keys.update(saved_weight_names_set)
+        unexpected_keys.update(unexpected_keys_set)
+        mismatched_keys.update(mismatched_keys_set)
+        gc.collect()
+
+    missing_keys = model_keys - saved_keys
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nMissing key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    return missing_keys, unexpected_keys, mismatched_keys
+
+
+def load_tf_shard(model, model_layer_map, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    """
+    Loads a shard from a sharded checkpoint file. Can be either H5 or Safetensors.
+    Handles missing keys and unexpected keys.
+
+    Args:
+        model (`keras.models.Model`): Model in which the weights are loaded
+        model_layer_map (`Dict`): A dictionary mapping the layer name to the index of the layer in the model.
+        resolved_archive_file (`str`): Path to the checkpoint file from which the weights will be loaded
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): Whether to ignore the mismatched keys
+
+    Returns:
+        Three sets: the weight names that were found and successfully restored from the shard file, the unexpected
+        weight names, and the mismatched weight names.
+    """
+    saved_weight_names_set = set()
+    saved_weights = {}
+    mismatched_keys = set()
+    unexpected_keys = set()
+    # Read the H5 file
+    try:
+        with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
+            # Retrieve the name of each layer from the H5 file
+            saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
+            weight_value_tuples = []
+
+            # Compute missing and unexpected sub layers
+            # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
+            for layer_name in saved_h5_model_layers_name:
+                h5_layer_object = sharded_checkpoint_file[layer_name]
+                saved_weights[layer_name] = np.asarray(h5_layer_object)
+
+                saved_weight_names_set.add(layer_name)
+
+                if layer_name not in model_layer_map:
+                    unexpected_keys.add(layer_name)
+                else:
+                    symbolic_weight = model.weights[model_layer_map[layer_name]]
+
+                    saved_weight_value = saved_weights[layer_name]
+                    # If the current weight is found
+                    if saved_weight_value is not None:
+                        # Check if the shape of the current weight and the one from the H5 file are different
+                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                            # If yes, we reshape the weight from the H5 file to match the current weight
+                            # If the two shapes are not compatible we raise an issue
+                            try:
+                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                            except ValueError as e:
+                                if ignore_mismatched_sizes:
+                                    mismatched_keys.add(
+                                        (layer_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
+                                    )
+                                    continue
+                                else:
+                                    raise e
+                        else:
+                            array = saved_weight_value
+
+                    # We create the tuple that will be loaded and add it to the final list
+                    weight_value_tuples.append((symbolic_weight, array))
+
+        K.batch_set_value(weight_value_tuples)
+
+        return saved_weight_names_set, unexpected_keys, mismatched_keys
+
+    except Exception as e:
+        try:
+            with open(resolved_archive_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {resolved_archive_file} which is necessary to load this pretrained"
+                        " model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from TF checkpoint file for '{resolved_archive_file}' "
+                f"at '{resolved_archive_file}'. "
+                "If you tried to load a TF model from a sharded checkpoint, you should try converting the model "
+                "by loading it in pytorch and saving it locally. A conversion script should be released soon."
+            )
+
+
+def load_tf_sharded_weights_from_safetensors(
+    model, shard_files, ignore_mismatched_sizes=False, strict=False, _prefix=None
+):
+    """
+    This is the same as `load_tf_weights_from_safetensors` but for a sharded TF-format safetensors checkpoint.
+    Detect missing and unexpected layers and load the TF weights from the shard file accordingly to their names and
+    shapes.
+
+    This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
+    loaded in the model.
+
+    Args:
+        model (`keras.models.Model`): The model in which to load the checkpoint.
+        shard_files (`str` or `os.PathLike`): A list containing the sharded checkpoint names.
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+            Whether or not to ignore weights whose shapes do not match between the checkpoint and the model.
+        strict (`bool`, *optional*, defaults to `False`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+
+    # Load the index
+    unexpected_keys = set()
+    all_missing_keys = []
+    mismatched_keys = set()
+
+    for shard_file in shard_files:
+        missing_layers, unexpected_layers, mismatched_layers = load_tf_weights_from_safetensors(
+            model,
+            shard_file,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            _prefix=_prefix,
+        )
+        all_missing_keys.append(set(missing_layers))
+        unexpected_keys.update(unexpected_layers)
+        mismatched_keys.update(mismatched_layers)
+        gc.collect()
+    missing_keys = set.intersection(*all_missing_keys)
+
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nMissing key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    return missing_keys, unexpected_keys, mismatched_keys
+
+
+def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    """
+    Detect missing and unexpected layers and load the TF weights from the checkpoint file according to their names and
+    shapes.
+
+    Args:
+        model (`keras.models.Model`):
+            The model to load the weights into.
+        resolved_archive_file (`str`):
+            The location of the H5 file.
+        ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+            Whether or not to ignore weights with shapes that don't match between the checkpoint of the model.
+
+    Returns:
+        Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
+        mismatched layers.
+    """
+    if resolved_archive_file.endswith(".safetensors"):
+        load_function = load_tf_weights_from_safetensors
+    else:
+        load_function = load_tf_weights_from_h5
+
+    return load_function(
+        model, resolved_archive_file, ignore_mismatched_sizes=ignore_mismatched_sizes, _prefix=_prefix
+    )
+
+
+def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    mismatched_layers = []
+
+    # Read the H5 file
+    with h5py.File(resolved_archive_file, "r") as sharded_checkpoint_file:
+        # Retrieve the name of each layer from the H5 file
+        saved_h5_model_layers_name = set(load_attributes_from_hdf5_group(sharded_checkpoint_file, "layer_names"))
+
+        # Find the missing layers from the high level list of layers
+        missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
+
+        # Find the unexpected layers from the high level list of layers
+        unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
+        saved_weight_names_set = set()
+        symbolic_weights_names = set()
+        weight_value_tuples = []
+
+        # Compute missing and unexpected sub layers
+        # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
+        for layer in model.layers:
+            # if layer_name from the H5 file belongs to the layers from the instantiated model
+            if layer.name in saved_h5_model_layers_name:
+                # Get the H5 layer object from its name
+                h5_layer_object = sharded_checkpoint_file[layer.name]
+                # Get all the weights as a list from the layer object
+                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                saved_weights = {}
+
+                # Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
+                # And a set with only the names
+                for weight_name in load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
+                    # TF names always start with the model name so we ignore it
+                    name = "/".join(weight_name.split("/")[1:])
+
+                    if _prefix is not None:
+                        name = _prefix + "/" + name
+
+                    saved_weights[name] = np.asarray(h5_layer_object[weight_name])
+
+                    # Add the updated name to the final list for computing missing/unexpected values
+                    saved_weight_names_set.add(name)
+
+                # Loop over each weights from the instantiated model and compare with the weights from the H5 file
+                for symbolic_weight in symbolic_weights:
+                    # TF names always start with the model name so we ignore it
+                    if _prefix is not None:
+                        delimiter = len(_prefix.split("/"))
+                        symbolic_weight_name = "/".join(
+                            symbolic_weight.name.split("/")[:delimiter]
+                            + symbolic_weight.name.split("/")[delimiter + 1 :]
+                        )
+                    else:
+                        symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+
+                    # here we check if the current weight is among the weights from the H5 file
+                    # If yes, get the weight_value of the corresponding weight from the H5 file
+                    # If not, make the value to None
+                    saved_weight_value = saved_weights.get(symbolic_weight_name)
+
+                    # Retrocompatibility patch: some embeddings are stored with the weights name (e.g. Bart's
+                    # `model.shared/embeddings:0` are stored as `model.shared/weight:0`)
+                    if saved_weight_value is None and symbolic_weight_name.endswith("embeddings:0"):
+                        symbolic_weight_name = symbolic_weight_name[:-12] + "weight:0"
+                        saved_weight_value = saved_weights.get(symbolic_weight_name)
+
+                    # Add the updated name to the final list for computing missing/unexpected values
+                    symbolic_weights_names.add(symbolic_weight_name)
+
+                    # If the current weight is found
+                    if saved_weight_value is not None:
+                        # Check if the shape of the current weight and the one from the H5 file are different
+                        if K.int_shape(symbolic_weight) != saved_weight_value.shape:
+                            # If yes, we reshape the weight from the H5 file to match the current weight
+                            # If the two shapes are not compatible we raise an issue
+                            try:
+                                array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
+                            except ValueError as e:
+                                if ignore_mismatched_sizes:
+                                    mismatched_layers.append(
+                                        (symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
+                                    )
+                                    continue
+                                else:
+                                    raise e
+                        else:
+                            array = saved_weight_value
+
+                        # We create the tuple that will be loaded and add it to the final list
+                        weight_value_tuples.append((symbolic_weight, array))
+
+    # Load all the weights
+    K.batch_set_value(weight_value_tuples)
+
+    # Compute the missing and unexpected layers
+    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+
+    return missing_layers, unexpected_layers, mismatched_layers
+
+
+def load_tf_weights_from_safetensors(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
+    # Read the safetensors file
+    with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+        mismatched_layers = []
+        weight_names = [strip_model_name_and_prefix(w.name, _prefix=_prefix) for w in model.weights]
+        loaded_weight_names = list(safetensors_archive.keys())
+        # Find the missing layers from the high level list of layers
+        missing_layers = list(set(weight_names) - set(loaded_weight_names))
+        # Find the unexpected layers from the high level list of layers
+        unexpected_layers = list(set(loaded_weight_names) - set(weight_names))
+
+        for weight in model.weights:
+            weight_name = strip_model_name_and_prefix(weight.name, _prefix=_prefix)
+            if weight_name in loaded_weight_names:
+                weight_value = safetensors_archive.get_tensor(weight_name)
+                # Check if the shape of the current weight and the one from the checkpoint are different
+                if K.int_shape(weight) != weight_value.shape:
+                    # If yes, we reshape the weight from the checkpoint to match the current weight
+                    # If the two shapes are not compatible we raise an issue
+                    try:
+                        weight_value = tf.reshape(weight_value, K.int_shape(weight))
+                    except (ValueError, tf.errors.InvalidArgumentError) as e:
+                        if ignore_mismatched_sizes:
+                            mismatched_layers.append((weight_name, weight_value.shape, K.int_shape(weight)))
+                            continue
+                        else:
+                            raise e
+
+                K.set_value(weight, weight_value)  # weight.assign() might break if weight is a DTensor
+    return missing_layers, unexpected_layers, mismatched_layers
+
+
+def init_copy_embeddings(old_embeddings, new_num_tokens):
+    r"""
+    This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
+    new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
+    kept or not. Example:
+
+        - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
+
+            -  mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
+        - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
+
+            - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
+    """
+    old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
+    size_diff = new_num_tokens - old_num_tokens
+
+    # initialize new embeddings
+    # Copy token embeddings from the previous ones
+    if tf.math.greater(size_diff, 0):
+        # if the new size is greater than the old one, we pad the current embeddings up to the new size
+        # and we create a mask to identify the padded values, which will be replaced by the values of the newly
+        # created embeddings
+        current_weights = tf.pad(
+            old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
+        )
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
+        mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
+    else:
+        # if the new size is lower than the old one, we take the current embeddings up to the new size
+        current_weights = tf.slice(
+            old_embeddings.value(),
+            tf.convert_to_tensor([0, 0]),
+            tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
+        )
+        mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
+
+    return mask, current_weights
+
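+
+# Illustrative sketch (not part of the upstream module): growing a made-up (4, 3) embedding matrix
+# to 6 rows. The returned mask marks the 4 copied rows; the 2 padded rows (filled with -1) are meant
+# to be overwritten with freshly initialized embeddings by the caller.
+def _sketch_init_copy_embeddings():
+    old_embeddings = tf.Variable(tf.random.uniform((4, 3)))
+    mask, current_weights = init_copy_embeddings(old_embeddings, new_num_tokens=6)
+    return mask.shape, current_weights.shape  # (TensorShape([6, 1]), TensorShape([6, 3]))
+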
+
+class TFPreTrainedModel(keras.Model, TFModelUtilsMixin, TFGenerationMixin, PushToHubMixin):
+    r"""
+    Base class for all TF models.
+
+    [`TFPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models as well as a few methods common to all models to:
+
+        - resize the input embeddings,
+        - prune heads in the self-attention heads.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+    """
+
+    config_class = None
+    base_model_prefix = ""
+    main_input_name = "input_ids"
+    _auto_class = None
+    _using_dummy_loss = None
+    _label_to_output_map = None
+
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    _requires_load_weight_prefix = False
+
+    @property
+    def dummy_inputs(self) -> dict[str, tf.Tensor]:
+        """
+        Dummy inputs to build the network.
+
+        Returns:
+            `dict[str, tf.Tensor]`: The dummy inputs.
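+
+        Example (an illustrative sketch; `model` stands for an instantiated text model whose `input_signature`
+        contains `input_ids`):
+
+        ```python
+        >>> dummies = model.dummy_inputs
+        >>> dummies["input_ids"].shape  # the batch dimension is kept at 1, other free dimensions default to 2
+        TensorShape([1, 2])
+        ```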
+        """
+        dummies = {}
+        for key, spec in self.input_signature.items():
+            # 2 is the most correct arbitrary size. I will not be taking questions
+            dummy_shape = [dim if dim is not None else 2 for dim in spec.shape]
+            if spec.shape[0] is None:
+                # But let's make the batch size 1 to save memory anyway
+                dummy_shape[0] = 1
+            dummies[key] = tf.ones(shape=dummy_shape, dtype=spec.dtype)
+            if key == "token_type_ids":
+                # Some models have token_type_ids but with a vocab_size of 1
+                dummies[key] = tf.zeros_like(dummies[key])
+        if self.config.add_cross_attention and "encoder_hidden_states" in inspect.signature(self.call).parameters:
+            if "encoder_hidden_states" not in dummies:
+                if self.main_input_name == "input_ids":
+                    dummies["encoder_hidden_states"] = tf.ones(
+                        shape=(1, 2, self.config.hidden_size), dtype=tf.float32, name="encoder_hidden_states"
+                    )
+                else:
+                    raise NotImplementedError(
+                        "Model has cross-attention but we couldn't infer the shape for the encoder hidden states. Please manually override dummy_inputs!"
+                    )
+        return dummies
+
+    def build_in_name_scope(self):
+        with tf.name_scope(self.name):
+            self.build(input_shape=None)
+
+    @property
+    def framework(self) -> str:
+        """
+        :str: Identifies that this is a TensorFlow model.
+        """
+        return "tf"
+
+    def build(self, input_shape=None):
+        pass  # This is just here to make sure we don't call the superclass build()
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+        if not isinstance(config, PretrainedConfig):
+            raise TypeError(
+                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+                "`PretrainedConfig`. To create a model from a pretrained model use "
+                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        # Save config and origin of the pretrained weights if given in model
+        self.config = config
+        self.name_or_path = config.name_or_path
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        self._set_save_spec(self.input_signature)
+        logger.warning_once(
+            "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+            "recommend migrating to PyTorch classes or pinning your version of Transformers."
+        )
+
+    def get_config(self):
+        return self.config.to_dict()
+
+    @functools.wraps(keras.Model.fit)
+    def fit(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().fit(*args, **kwargs)
+
+    @functools.wraps(keras.Model.train_on_batch)
+    def train_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().train_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.test_on_batch)
+    def test_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().test_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.predict_on_batch)
+    def predict_on_batch(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().predict_on_batch(*args, **kwargs)
+
+    @functools.wraps(keras.Model.predict)
+    def predict(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().predict(*args, **kwargs)
+
+    @functools.wraps(keras.Model.evaluate)
+    def evaluate(self, *args, **kwargs):
+        args, kwargs = convert_batch_encoding(*args, **kwargs)
+        return super().evaluate(*args, **kwargs)
+
+    @classmethod
+    def from_config(cls, config, **kwargs):
+        if isinstance(config, PretrainedConfig):
+            return cls._from_config(config, **kwargs)
+        return cls._from_config(cls.config_class.from_dict(config, **kwargs))
+
+    @classmethod
+    def _from_config(cls, config, **kwargs):
+        """
+        All context managers that the model should be initialized under go here.
+        """
+        return cls(config, **kwargs)
+
+    def get_head_mask(self, head_mask: tf.Tensor | None, num_hidden_layers: int) -> tf.Tensor:
+        """
+        Prepare the head mask if needed.
+
+        Args:
+            head_mask (`tf.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+            num_hidden_layers (`int`):
+                The number of hidden layers in the model.
+
+        Returns:
+            `tf.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+            `[None]` for each layer.
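+
+        Example (an illustrative sketch; `model` stands for any built `TFPreTrainedModel` and the mask values are
+        arbitrary):
+
+        ```python
+        >>> # Keep both heads of a hypothetical 2-layer, 2-head model in every layer
+        >>> head_mask = model.get_head_mask(tf.constant([1.0, 1.0]), num_hidden_layers=2)
+        >>> head_mask.shape  # broadcastable to [num_hidden_layers, batch, num_heads, seq_length, seq_length]
+        TensorShape([2, 1, 2, 1, 1])
+        ```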
+        """
+        if head_mask is not None:
+            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+        else:
+            head_mask = [None] * num_hidden_layers
+
+        return head_mask
+
+    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+        if head_mask.shape.rank == 1:
+            head_mask = head_mask[None, None, :, None, None]
+            head_mask = tf.repeat(head_mask, repeats=num_hidden_layers, axis=0)
+        elif head_mask.shape.rank == 2:
+            head_mask = head_mask[:, None, :, None, None]
+        assert head_mask.shape.rank == 5, f"head_mask rank should be 5, got {head_mask.shape.rank}"
+        head_mask = tf.cast(head_mask, tf.float32)  # switch to float if need + fp16 compatibility
+        return head_mask
+
+    @tf.function
+    def serving(self, inputs):
+        """
+        Method used for serving the model. Does not have a specific signature, but will be specialized as concrete
+        functions when saving with `save_pretrained`.
+
+        Args:
+            inputs (`dict[str, tf.Tensor]`):
+                The input of the saved model as a dictionary of tensors.
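+
+        Example (an illustrative sketch; the directory name is a placeholder):
+
+        ```python
+        >>> model.save_pretrained("./my_model", saved_model=True)  # exports a SavedModel under ./my_model/saved_model/1
+        >>> restored = tf.saved_model.load("./my_model/saved_model/1")
+        >>> serving_fn = restored.signatures["serving_default"]
+        ```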
+        """
+        output = self.call(inputs)
+
+        return self.serving_output(output)
+
+    @property
+    def input_signature(self) -> dict[str, tf.TensorSpec]:
+        """
+        This property should return a dict mapping input names to tf.TensorSpec objects, representing the expected
+        shape and dtype for model inputs. It is used for both serving and for generating dummy inputs.
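+
+        Example (illustrative; the exact keys depend on the signature of `call` for the concrete model):
+
+        ```python
+        >>> sig = model.input_signature  # for a typical text model
+        >>> sorted(sig)
+        ['attention_mask', 'input_ids', 'token_type_ids']
+        >>> sig["input_ids"]
+        TensorSpec(shape=(None, None), dtype=tf.int32, name='input_ids')
+        ```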
+        """
+        model_inputs = list(inspect.signature(self.call).parameters)
+        sig = {}
+        if "input_ids" in model_inputs:
+            if self.__class__.__name__.endswith("ForMultipleChoice"):
+                text_dims = 3
+            else:
+                text_dims = 2
+            for input_name in (
+                "input_ids",
+                "attention_mask",
+                "token_type_ids",
+                "decoder_input_ids",
+                "decoder_attention_mask",
+            ):
+                if input_name in model_inputs:
+                    sig[input_name] = tf.TensorSpec([None] * text_dims, tf.int32, name=input_name)
+        if "pixel_values" in model_inputs:
+            pixel_values_shape = [None, None, None, None]
+            if hasattr(self.config, "vision_config"):
+                vision_config = self.config.vision_config
+            else:
+                vision_config = self.config
+            if hasattr(vision_config, "num_channels"):
+                pixel_values_shape[1] = vision_config.num_channels
+            else:
+                raise NotImplementedError(
+                    "Could not infer number of channels from config, please override input_signature to specify input shapes."
+                )
+            if hasattr(vision_config, "image_size"):
+                pixel_values_shape[2] = pixel_values_shape[3] = vision_config.image_size
+            elif hasattr(vision_config, "input_size"):
+                pixel_values_shape[2] = pixel_values_shape[3] = vision_config.input_size
+            else:
+                raise NotImplementedError(
+                    "Could not infer input image shape from config, please override input_signature to specify input shapes."
+                )
+            sig["pixel_values"] = tf.TensorSpec(pixel_values_shape, tf.float32, name="pixel_values")
+        if "input_features" in model_inputs:
+            raise NotImplementedError("Audio models need a manually defined input_signature")
+        return sig
+
+    def serving_output(self, output):
+        """
+        Prepare the output of the saved model. Can be overridden if specific serving modifications are required.
+        """
+        if not isinstance(output, ModelOutput):
+            return output
+        for key in output:
+            if key.endswith("hidden_states") and not getattr(self.config, "output_hidden_states", False):
+                output[key] = None
+            elif key.endswith("attentions") and not getattr(self.config, "output_attentions", False):
+                output[key] = None
+            elif key == "past_key_values" and not getattr(self.config, "use_cache", False):
+                output[key] = None
+            elif key == "cross_attentions" and not (
+                getattr(self.config, "output_attentions", False) and getattr(self.config, "add_cross_attention", False)
+            ):
+                output[key] = None
+            if isinstance(output[key], (tuple, list)):
+                try:
+                    output[key] = tf.convert_to_tensor(output[key])
+                except (ValueError, tf.errors.InvalidArgumentError):
+                    pass  # Layers may not have the same dimensions
+        return output
+
+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+            return False
+        return True
+
+    def get_input_embeddings(self) -> keras.layers.Layer:
+        """
+        Returns the model's input embeddings layer.
+
+        Returns:
+            `tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
+        """
+        main_layer = getattr(self, self.base_model_prefix, self)
+
+        if main_layer is not self:
+            return main_layer.get_input_embeddings()
+        else:
+            raise NotImplementedError
+
+    def _save_checkpoint(self, checkpoint_dir, epoch):
+        if not os.path.isdir(checkpoint_dir):
+            os.mkdir(checkpoint_dir)
+        # We avoid tf.train.checkpoint or saving weights in TF format, even though that includes optimizer
+        # state for us, because it requires special handling for objects like custom losses, which we use
+        # internally and which users are likely to use too
+        weights_path = os.path.join(checkpoint_dir, "weights.h5")
+        self.save_weights(weights_path)
+        extra_data = {"epoch": epoch, "optimizer_state": self.optimizer.get_weights()}
+        extra_data_path = os.path.join(checkpoint_dir, "extra_data.pickle")
+        with open(extra_data_path, "wb") as f:
+            pickle.dump(extra_data, f)
+
+    def prepare_tf_dataset(
+        self,
+        dataset: datasets.Dataset,  # noqa:F821
+        batch_size: int = 8,
+        shuffle: bool = True,
+        tokenizer: PreTrainedTokenizerBase | None = None,
+        collate_fn: Callable | None = None,
+        collate_fn_args: dict[str, Any] | None = None,
+        drop_remainder: bool | None = None,
+        prefetch: bool = True,
+    ):
+        """
+        Wraps a HuggingFace [`~datasets.Dataset`] as a `tf.data.Dataset` with collation and batching. This method is
+        designed to create a "ready-to-use" dataset that can be passed directly to Keras methods like `fit()` without
+        further modification. The method will drop columns from the dataset if they don't match input names for the
+        model. If you want to specify the column names to return rather than using the names that match this model, we
+        recommend using `Dataset.to_tf_dataset()` instead.
+
+        Args:
+            dataset (`Any`):
+                A [`~datasets.Dataset`] to be wrapped as a `tf.data.Dataset`.
+            batch_size (`int`, *optional*, defaults to 8):
+                The size of batches to return.
+            shuffle (`bool`, defaults to `True`):
+                Whether to return samples from the dataset in random order. Usually `True` for training datasets and
+                `False` for validation/test datasets.
+            tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+                A `PreTrainedTokenizer` that will be used to pad samples to create batches. Has no effect if a specific
+                `collate_fn` is passed instead.
+            collate_fn (`Callable`, *optional*):
+                A function that collates samples from the dataset into a single batch. Defaults to
+                `DefaultDataCollator` if no `tokenizer` is supplied or `DataCollatorWithPadding` if a `tokenizer` is
+                passed.
+            collate_fn_args (`dict[str, Any]`, *optional*):
+                A dict of arguments to pass to the `collate_fn` alongside the list of samples.
+            drop_remainder (`bool`, *optional*):
+                Whether to drop the final batch, if the batch_size does not evenly divide the dataset length. Defaults
+                to the same setting as `shuffle`.
+            prefetch (`bool`, defaults to `True`):
+                Whether to add prefetching to the end of the `tf.data` pipeline. This is almost always beneficial for
+                performance, but can be disabled in edge cases.
+
+
+        Returns:
+            `Dataset`: A `tf.data.Dataset` which is ready to pass to the Keras API.
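+
+        Example (a minimal sketch; the checkpoint and dataset names below are placeholders):
+
+        ```python
+        >>> from datasets import load_dataset
+        >>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+        >>> model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
+        >>> dataset = load_dataset("glue", "cola", split="train")
+        >>> dataset = dataset.map(lambda batch: tokenizer(batch["sentence"]), batched=True)
+        >>> tf_dataset = model.prepare_tf_dataset(dataset, batch_size=16, shuffle=True, tokenizer=tokenizer)
+        >>> model.compile(optimizer="adam")
+        >>> model.fit(tf_dataset)  # doctest: +SKIP
+        ```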
+        """
+        requires_backends(self, ["datasets"])
+        import datasets
+
+        if collate_fn is None:
+            if tokenizer is None:
+                collate_fn = DefaultDataCollator(return_tensors="np")
+            else:
+                collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
+        if collate_fn_args is None:
+            collate_fn_args = {}
+
+        if not isinstance(dataset, datasets.Dataset):
+            raise TypeError("Dataset argument should be a datasets.Dataset!")
+        model_inputs = list(inspect.signature(self.call).parameters)
+        model_labels = find_labels(self.__class__)
+        if "cols_to_retain" in list(inspect.signature(dataset._get_output_signature).parameters.keys()):
+            output_signature, _ = dataset._get_output_signature(
+                dataset,
+                batch_size=None,
+                collate_fn=collate_fn,
+                collate_fn_args=collate_fn_args,
+                cols_to_retain=model_inputs,
+            )
+        else:
+            # TODO Matt: This is a workaround for older versions of datasets that are missing the `cols_to_retain`
+            #            argument. We should remove this once the minimum supported version of datasets is > 2.3.2
+            unwanted_columns = [
+                feature
+                for feature in dataset.features
+                if feature not in model_inputs and feature not in ("label_ids", "label")
+            ]
+            dataset = dataset.remove_columns(unwanted_columns)
+            output_signature, _ = dataset._get_output_signature(
+                dataset, batch_size=None, collate_fn=collate_fn, collate_fn_args=collate_fn_args
+            )
+        output_columns = list(output_signature.keys())
+        feature_cols = [col for col in output_columns if col in model_inputs and col not in model_labels]
+        label_cols = [col for col in output_columns if col in model_labels]
+
+        # Backwards compatibility for older versions of datasets. Previously, if `columns` or `label_cols`
+        # were a single element list, the returned element spec would be a single element. Now, passing [feature]
+        # will return a dict structure {"feature": feature}, and passing a single string will return a single element.
+        feature_cols = feature_cols[0] if len(feature_cols) == 1 else feature_cols
+        label_cols = label_cols[0] if len(label_cols) == 1 else label_cols
+
+        if drop_remainder is None:
+            drop_remainder = shuffle
+        tf_dataset = dataset.to_tf_dataset(
+            columns=feature_cols,
+            label_cols=label_cols,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_remainder=drop_remainder,
+            collate_fn=collate_fn,
+            collate_fn_args=collate_fn_args,
+            prefetch=prefetch,
+        )
+        return tf_dataset
+
+    def compile(
+        self,
+        optimizer="rmsprop",
+        loss="auto_with_warning",
+        metrics=None,
+        loss_weights=None,
+        weighted_metrics=None,
+        run_eagerly=None,
+        steps_per_execution=None,
+        **kwargs,
+    ):
+        """
+        This is a thin wrapper that sets the model's loss output head as the loss if the user does not specify a loss
+        function themselves.
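+
+        Example (illustrative; any Keras optimizer can be passed):
+
+        ```python
+        >>> model.compile(optimizer="adam")  # no loss given: the model's internal loss is used, with an info message
+        >>> model.compile(optimizer="adam", loss="auto")  # same behaviour, without the info message
+        ```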
+        """
+        if loss in ("auto_with_warning", "passthrough"):  # "passthrough" for workflow backward compatibility
+            logger.info(
+                "No loss specified in compile() - the model's internal loss computation will be used as the "
+                "loss. Don't panic - this is a common way to train TensorFlow models in Transformers! "
+                "To disable this behaviour please pass a loss argument, or explicitly pass "
+                "`loss=None` if you do not want your model to compute a loss. You can also specify `loss='auto'` to "
+                "get the internal loss without printing this info string."
+            )
+            loss = "auto"
+        if loss == "auto":
+            loss = dummy_loss
+            self._using_dummy_loss = True
+        else:
+            self._using_dummy_loss = False
+        parent_args = list(inspect.signature(keras.Model.compile).parameters.keys())
+        # This argument got renamed, we need to support both versions
+        if "steps_per_execution" in parent_args:
+            super().compile(
+                optimizer=optimizer,
+                loss=loss,
+                metrics=metrics,
+                loss_weights=loss_weights,
+                weighted_metrics=weighted_metrics,
+                run_eagerly=run_eagerly,
+                steps_per_execution=steps_per_execution,
+                **kwargs,
+            )
+        else:
+            super().compile(
+                optimizer=optimizer,
+                loss=loss,
+                metrics=metrics,
+                loss_weights=loss_weights,
+                weighted_metrics=weighted_metrics,
+                run_eagerly=run_eagerly,
+                experimental_steps_per_execution=steps_per_execution,
+                **kwargs,
+            )
+
+    def compute_loss(self, *args, **kwargs):
+        if hasattr(keras.Model, "compute_loss"):
+            # This will be true in TF 2.8 or greater
+            return super().compute_loss(*args, **kwargs)
+        else:
+            warnings.warn(
+                "The old compute_loss method is deprecated as it conflicts with the Keras compute_loss "
+                "method added in TF 2.8. If you want the original HF compute_loss, please call "
+                "hf_compute_loss() instead. From TF versions >= 2.8, or Transformers versions >= 5, "
+                "calling compute_loss() will get the Keras method instead.",
+                FutureWarning,
+            )
+            return self.hf_compute_loss(*args, **kwargs)
+
+    def get_label_to_output_name_mapping(self):
+        arg_names = list(inspect.signature(self.call).parameters)
+        if self._label_to_output_map is not None:
+            return self._label_to_output_map
+        elif "start_positions" in arg_names:
+            return {"start_positions": "start_logits", "end_positions": "end_logits"}
+        elif "sentence_order_label" in arg_names:
+            return {"labels": "prediction_logits", "sentence_order_label": "sop_logits"}
+        elif "next_sentence_label" in arg_names:
+            return {"labels": "prediction_logits", "next_sentence_label": "seq_relationship_logits"}
+        elif "mc_labels" in arg_names:
+            return {"labels": "logits", "mc_labels": "mc_logits"}
+        else:
+            return {}
+
+    def train_step(self, data):
+        """
+        A modification of Keras's default `train_step` that correctly handles matching outputs to labels for our models
+        and supports directly training on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
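+
+        Example (a sketch of the label handling; `tf_dataset`, `input_ids` and `labels` are placeholders):
+
+        ```python
+        >>> model.compile(optimizer="adam")  # dummy loss: labels are forwarded to the model's label argument
+        >>> # Labels can come as the second element of each dataset batch...
+        >>> model.fit(tf_dataset)  # doctest: +SKIP
+        >>> # ...or already inside the input dict, under the model's label argument name (e.g. "labels")
+        >>> model.fit({"input_ids": input_ids, "labels": labels})  # doctest: +SKIP
+        ```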
+        """
+
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(inspect.signature(self.call).parameters)
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
+        if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
+            # Newer TF train steps leave this out
+            data = expand_1d(data)
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()
+
+        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+        # if those keys are not already present in the input dict
+        if self._using_dummy_loss and y is not None:
+            # If y is a tensor and the model only has one label-like input, map y to that input
+            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                label_kwarg = next(iter(label_kwargs))
+                if label_kwarg not in x:
+                    x[label_kwarg] = y
+            # Otherwise, copy keys from y to x as long as they weren't already present in x
+            elif isinstance(y, dict):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                for key, val in y.items():
+                    if key in arg_names and key not in x:
+                        x[key] = val
+                    elif output_to_label.get(key) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
+
+        # Run forward pass.
+        with tf.GradientTape() as tape:
+            if self._using_dummy_loss and "return_loss" in arg_names:
+                y_pred = self(x, training=True, return_loss=True)
+            else:
+                y_pred = self(x, training=True)
+            if self._using_dummy_loss:
+                loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+            else:
+                loss = None
+
+            # This next block matches outputs to label keys. TensorFlow's standard method for doing this
+            # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+            if isinstance(y, dict) and len(y) == 1:
+                if list(y.keys())[0] in y_pred:
+                    y_pred = y_pred[list(y.keys())[0]]
+                elif list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+                _, y = y.popitem()
+            elif isinstance(y, dict):
+                # If the labels are a dict, match keys from the output by name
+                y_pred = {key: val for key, val in y_pred.items() if key in y}
+            elif isinstance(y, (tuple, list)):
+                # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred.to_tuple()[1:]
+                else:
+                    y_pred = y_pred.to_tuple()
+                y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
+            else:
+                # If the labels are a single tensor, match them to the first non-loss tensor in the output
+                if list(y_pred.keys())[0] == "loss":
+                    y_pred = y_pred[1]
+                else:
+                    y_pred = y_pred[0]
+
+            if loss is None:
+                loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+        # Run backwards pass.
+        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
+        # Collect metrics to return
+        return_metrics = {}
+        for metric in self.metrics:
+            result = metric.result()
+            if isinstance(result, dict):
+                return_metrics.update(result)
+            else:
+                return_metrics[metric.name] = result
+        return return_metrics
+
+    def test_step(self, data):
+        """
+        A modification of Keras's default `test_step` that correctly handles matching outputs to labels for our models
+        and supports directly evaluating on the loss output head. In addition, it ensures input keys are copied to the
+        labels where appropriate. It will also copy label keys into the input dict when using the dummy loss, to ensure
+        that they are available to the model during the forward pass.
+        """
+        # We hardcode the most common renamings; models with weirder names can set `self._label_to_output_map`
+        arg_names = list(inspect.signature(self.call).parameters)
+        label_kwargs = find_labels(self.__class__)
+        label_to_output = self.get_label_to_output_name_mapping()
+        output_to_label = {val: key for key, val in label_to_output.items()}
+        if not self._using_dummy_loss and parse(tf.__version__) < parse("2.11.0"):
+            # Newer versions leave this out
+            data = expand_1d(data)
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
+        # If the inputs are mutable dictionaries, make a shallow copy of them because we will modify
+        # them during input/label pre-processing. This avoids surprising the user by wrecking their data.
+        # In addition, modifying mutable Python inputs makes XLA compilation impossible.
+        if isinstance(x, dict):
+            x = x.copy()
+        if isinstance(y, dict):
+            y = y.copy()
+
+        # When using a dummy loss, we ensure that separate labels are copied to the correct model arguments,
+        # if those keys are not already present in the input dict
+        if self._using_dummy_loss and y is not None:
+            arg_names = list(inspect.signature(self.call).parameters)
+            # If y is a tensor and the model only has one label-like input, map y to that input
+            if len(label_kwargs) == 1 and isinstance(y, tf.Tensor):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                label_kwarg = next(iter(label_kwargs))
+                if label_kwarg not in x:
+                    x[label_kwarg] = y
+            # Otherwise, copy keys from y to x as long as they weren't already present in x
+            elif isinstance(y, dict):
+                if isinstance(x, tf.Tensor):
+                    x = {arg_names[0]: x}
+                for key, val in y.items():
+                    if key in arg_names and key not in x:
+                        x[key] = val
+                    elif output_to_label.get(key) in arg_names and key not in x:
+                        x[output_to_label[key]] = val
+        if y is None:
+            y = {key: val for key, val in x.items() if key in label_kwargs}
+            if not y and not self._using_dummy_loss:
+                raise ValueError("Could not find label column(s) in input dict and no separate labels were provided!")
+
+        if isinstance(y, dict):
+            # Rename labels at this point to match output heads
+            y = {label_to_output.get(key, key): val for key, val in y.items()}
+
+        # Run forward pass.
+        if self._using_dummy_loss and "return_loss" in arg_names:
+            y_pred = self(x, return_loss=True, training=False)
+        else:
+            y_pred = self(x, training=False)
+        if self._using_dummy_loss:
+            loss = self.compiled_loss(y_pred.loss, y_pred.loss, sample_weight, regularization_losses=self.losses)
+        else:
+            loss = None
+
+        # This next block matches outputs to label keys. TensorFlow's standard method for doing this
+        # can get very confused if any of the keys contain nested values (e.g. lists/tuples of Tensors)
+        if isinstance(y, dict) and len(y) == 1:
+            if list(y.keys())[0] in y_pred:
+                y_pred = y_pred[list(y.keys())[0]]
+            elif list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+            _, y = y.popitem()
+        elif isinstance(y, dict):
+            # If the labels are a dict, match keys from the output by name
+            y_pred = {key: val for key, val in y_pred.items() if key in y}
+        elif isinstance(y, (tuple, list)):
+            # If the labels are a tuple/list, match keys to the output by order, skipping the loss.
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred.to_tuple()[1:]
+            else:
+                y_pred = y_pred.to_tuple()
+            y_pred = y_pred[: len(y)]  # Remove unused fields in case those cause problems
+        else:
+            # If the labels are a single tensor, match them to the first non-loss tensor in the output
+            if list(y_pred.keys())[0] == "loss":
+                y_pred = y_pred[1]
+            else:
+                y_pred = y_pred[0]
+
+        if loss is None:
+            loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
+
+        self.compiled_metrics.update_state(y, y_pred, sample_weight)
+        # Collect metrics to return
+        return_metrics = {}
+        for metric in self.metrics:
+            result = metric.result()
+            if isinstance(result, dict):
+                return_metrics.update(result)
+            else:
+                return_metrics[metric.name] = result
+        return return_metrics
+
+    def create_model_card(
+        self,
+        output_dir,
+        model_name: str,
+        language: str | None = None,
+        license: str | None = None,
+        tags: str | None = None,
+        finetuned_from: str | None = None,
+        tasks: str | None = None,
+        dataset_tags: str | list[str] | None = None,
+        dataset: str | list[str] | None = None,
+        dataset_args: str | list[str] | None = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            output_dir (`str` or `os.PathLike`):
+                The folder in which to create the model card.
+            model_name (`str`):
+                The name of the model.
+            language (`str`, *optional*):
+                The language of the model (if applicable)
+            license (`str`, *optional*):
+                The license of the model. Will default to the license of the pretrained model used, if the original
+                model given to the `Trainer` comes from a repo on the Hub.
+            tags (`str` or `list[str]`, *optional*):
+                Some tags to be included in the metadata of the model card.
+            finetuned_from (`str`, *optional*):
+                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+                of the original model given to the `Trainer` (if it comes from the Hub).
+            tasks (`str` or `list[str]`, *optional*):
+                One or several task identifiers, to be included in the metadata of the model card.
+            dataset_tags (`str` or `list[str]`, *optional*):
+                One or several dataset tags, to be included in the metadata of the model card.
+            dataset (`str` or `list[str]`, *optional*):
+                One or several dataset identifiers, to be included in the metadata of the model card.
+            dataset_args (`str` or `list[str]`, *optional*):
+               One or several dataset arguments, to be included in the metadata of the model card.
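+
+        Example (a minimal sketch; it assumes the model was already trained with `fit()` so that `self.history`
+        exists, and the directory and model names are placeholders):
+
+        ```python
+        >>> model.fit(tf_dataset, epochs=1)  # doctest: +SKIP
+        >>> model.create_model_card(output_dir="./my-model", model_name="my-model")  # writes ./my-model/README.md
+        ```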
+        """
+        # Avoids a circular import by doing this when necessary.
+        from .modelcard import TrainingSummary  # tests_ignore
+
+        training_summary = TrainingSummary.from_keras(
+            self,
+            keras_history=self.history,
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+        )
+        model_card = training_summary.to_model_card()
+        with open(os.path.join(output_dir, "README.md"), "w") as f:
+            f.write(model_card)
+
+    def set_input_embeddings(self, value):
+        """
+        Set model's input embeddings
+
+        Args:
+            value (`tf.Variable`):
+                The new weights mapping vocabulary to hidden states.
+        """
+        main_layer = getattr(self, self.base_model_prefix)
+
+        if main_layer is None:
+            raise NotImplementedError("The model does not implement the base_model_prefix attribute.")
+
+        try:
+            main_layer.set_input_embeddings(value)
+        except AttributeError:
+            logger.info("Building the model")
+            self.build_in_name_scope()
+            main_layer.set_input_embeddings(value)
+
+    def get_output_embeddings(self) -> None | keras.layers.Layer:
+        """
+        Returns the model's output embeddings.
+
+        Returns:
+            `tf.Variable`: The weights mapping hidden states to vocabulary.
+        """
+        if self.get_lm_head() is not None:
+            lm_head = self.get_lm_head()
+
+            try:
+                return lm_head.get_output_embeddings()
+            except AttributeError:
+                logger.info("Building the model")
+                self.build_in_name_scope()
+
+                return lm_head.get_output_embeddings()
+
+        return None  # Overwrite for models with output embeddings
+
+    def set_output_embeddings(self, value):
+        """
+        Set model's output embeddings
+
+        Args:
+            value (`tf.Variable`):
+                The new weights mapping hidden states to vocabulary.
+        """
+        if self.get_lm_head() is not None:
+            lm_head = self.get_lm_head()
+            try:
+                lm_head.set_output_embeddings(value)
+            except AttributeError:
+                logger.info("Building the model")
+                self.build_in_name_scope()
+                lm_head.set_output_embeddings(value)
+
+    def get_output_layer_with_bias(self) -> None | keras.layers.Layer:
+        """
+        Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the
+        embeddings
+
+        Return:
+            `keras.layers.Layer`: The layer that handles the bias, None if not an LM model.
+        """
+        warnings.warn(
+            "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning
+        )
+        return self.get_lm_head()
+
+    def get_prefix_bias_name(self) -> None | str:
+        """
+        Get the concatenated _prefix name of the bias from the model name to the parent layer
+
+        Return:
+            `str`: The _prefix name of the bias.
+        """
+        warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
+        return None
+
+    def get_bias(self) -> None | dict[str, tf.Variable]:
+        """
+        Dict of bias attached to an LM head. The key represents the name of the bias attribute.
+
+        Return:
+            `dict[str, tf.Variable]`: The weights representing the bias, None if not an LM model.
+        """
+        if self.get_lm_head() is not None:
+            lm_head = self.get_lm_head()
+            try:
+                return lm_head.get_bias()
+            except AttributeError:
+                self.build_in_name_scope()
+
+                return lm_head.get_bias()
+        return None
+
+    def set_bias(self, value):
+        """
+        Set all the bias in the LM head.
+
+        Args:
+            value (`dict[str, tf.Variable]`):
+                All the new bias attached to an LM head.
+        """
+        if self.get_lm_head() is not None:
+            lm_head = self.get_lm_head()
+            try:
+                lm_head.set_bias(value)
+            except AttributeError:
+                self.build_in_name_scope()
+                lm_head.set_bias(value)
+
+    def get_lm_head(self) -> keras.layers.Layer:
+        """
+        The LM head layer. This method must be overwritten by all the models that have an LM head.
+
+        Return:
+            `keras.layers.Layer`: The LM head layer if the model has one, None if not.
+        """
+        return None
+
+    def resize_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding | tf.Variable:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens without doing anything.
+
+        Return:
+            `tf.Variable` or `keras.layers.Embedding`: Pointer to the input tokens of the model.
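+
+        Example (an illustrative sketch; `tokenizer` is assumed to be the tokenizer paired with this model):
+
+        ```python
+        >>> num_added = tokenizer.add_tokens(["<new_token>"])
+        >>> embeddings = model.resize_token_embeddings(new_num_tokens=len(tokenizer))
+        >>> model.config.vocab_size == len(tokenizer)
+        True
+        ```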
+        """
+        # TODO (joao): flagged for replacement (by `_v2_resized_token_embeddings`) due to embeddings refactor
+
+        # Run the new code path if the model has a keras embeddings layer
+        if isinstance(self.get_input_embeddings(), keras.layers.Embedding):
+            return self._v2_resized_token_embeddings(new_num_tokens)
+
+        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+            return self._get_word_embedding_weight(self.get_input_embeddings())
+
+        model_embeds = self._resize_token_embeddings(new_num_tokens)
+
+        # Update base model and current model config
+        self.config.vocab_size = new_num_tokens
+
+        return model_embeds
+
+    def _v2_resized_token_embeddings(self, new_num_tokens: int | None = None) -> keras.layers.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens without doing anything.
+
+        Return:
+            `keras.layers.Embedding`: Pointer to the input tokens of the model.
+        """
+        if new_num_tokens is None or new_num_tokens == self.config.vocab_size:
+            return self.get_input_embeddings()
+
+        model_embeds = self._v2_resize_token_embeddings(new_num_tokens)
+
+        # Update base model and current model config
+        self.config.vocab_size = new_num_tokens
+
+        return model_embeds
+
+    def _get_word_embedding_weight(model, embedding_layer):
+        # TODO (joao): flagged for deletion due to embeddings refactor
+
+        # If the variable holds the weights themselves, return them
+        if isinstance(embedding_layer, tf.Tensor):
+            return embedding_layer
+        # Otherwise, try to get them from the layer's attributes
+
+        embeds = getattr(embedding_layer, "weight", None)
+        if embeds is not None:
+            return embeds
+
+        embeds = getattr(embedding_layer, "decoder", None)
+        if embeds is not None:
+            return embeds
+
+        # The attributes may not exist because the model is not built,
+        # so retry getting them after building the model
+        model.build_in_name_scope()
+
+        embeds = getattr(embedding_layer, "weight", None)
+        if embeds is not None:
+            return embeds
+
+        embeds = getattr(embedding_layer, "decoder", None)
+        if embeds is not None:
+            return embeds
+
+        return None
+
+    def _resize_token_embeddings(self, new_num_tokens):
+        # TODO (joao): flagged for replacement (by `_v2_resize_token_embeddings`) due to embeddings refactor
+        old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
+        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
+
+        # if word embeddings are not tied, make sure that lm head bias is resized as well
+        if self.get_bias() is not None:
+            old_lm_head_bias = self.get_bias()
+            new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+
+            self.set_bias(new_lm_head_bias)
+
+        # if word embeddings are not tied, make sure that lm head decoder is resized as well
+        if self.get_output_embeddings() is not None:
+            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+
+            self.set_output_embeddings(new_lm_head_decoder)
+
+        self.set_input_embeddings(new_embeddings)
+
+        return self.get_input_embeddings()
+
+    def _v2_resize_token_embeddings(self, new_num_tokens):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._v2_get_resized_embeddings(old_embeddings, new_num_tokens)
+        self.set_input_embeddings(new_embeddings)
+
+        # If word embeddings are not tied, make sure that lm head bias is resized as well
+        if self.get_bias() is not None:
+            old_lm_head_bias = self.get_bias()
+            new_lm_head_bias = self._v2_get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens)
+            self.set_bias(new_lm_head_bias)
+
+        # If word embeddings are not tied, make sure that lm head decoder is resized as well.
+        tied_weights = self.get_input_embeddings() == self.get_output_embeddings()
+        if self.get_output_embeddings() is not None and not tied_weights:
+            old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings())
+            # TODO (joao): this one probably needs a v2 version with other models
+            new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens)
+            self.set_output_embeddings(new_lm_head_decoder)
+
+        return self.get_input_embeddings()
+
+    def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens):
+        """
+        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
+        Reducing the size will remove vectors from the end
+
+        Args:
+            old_lm_head_bias (`tf.Variable`):
+                Old lm head bias to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the linear matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns None
+
+        Return:
+            `tf.Variable`: Pointer to the resized bias.
+        """
+        # TODO (joao): flagged for replacement (by `_v2_get_resized_lm_head_bias`) due to embeddings refactor
+        new_lm_head_bias = {}
+
+        for attr, weight in old_lm_head_bias.items():
+            first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
+            size_diff = new_num_tokens - old_num_tokens
+            final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens]
+
+            # initialize new bias
+            if tf.math.greater(size_diff, 0):
+                padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
+                current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1)
+                num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+                mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy]
+                bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True)
+                bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False)
+            else:
+                slice_from = [0] if first_dim is None else [0, 0]
+                current_bias = tf.slice(
+                    weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape)
+                )
+                bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True)
+
+            new_bias = self.add_weight(
+                shape=final_shape,
+                initializer="zeros",
+                trainable=True,
+                name=weight.name.split(":")[0],
+            )
+            init_bias = tf.where(bias_mask, current_bias, new_bias.value())
+
+            new_bias.assign(init_bias)
+            new_lm_head_bias[attr] = new_bias
+
+        return new_lm_head_bias
+
+    def _v2_get_resized_lm_head_bias(
+        self, old_lm_head_bias: dict[str, tf.Variable], new_num_tokens: int
+    ) -> dict[str, tf.Tensor]:
+        """
+        Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end.
+        Reducing the size will remove vectors from the end
+
+        Args:
+            old_lm_head_bias (`dict[str, tf.Variable]`):
+                Old lm head bias to be resized.
+            new_num_tokens (`int`):
+                New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at
+                the end. Reducing the size will remove vectors from the end.
+
+        Return:
+            `tf.Tensor`: Values for the resized bias.
+        """
+        new_lm_head_bias = {}
+
+        for attr, weight in old_lm_head_bias.items():
+            # Determine the size difference (depending on the shape)
+            first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight)
+            size_diff = new_num_tokens - old_num_tokens
+
+            # Copy the old bias values to the new bias
+            if old_num_tokens > new_num_tokens:
+                new_bias = weight.value()[..., :new_num_tokens]
+            else:
+                padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]]
+                new_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape))
+
+            new_lm_head_bias[attr] = new_bias
+        return new_lm_head_bias
+
+    def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens):
+        """
+        Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end.
+        Reducing the size will remove vectors from the end
+
+        Args:
+            old_lm_head_decoder (`tf.Variable`):
+                Old lm head decoder to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the linear matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns None
+
+        Return:
+            `tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the input
+            ones.
+        """
+        new_lm_head_decoder = old_lm_head_decoder
+        is_input_output_equals = tf.reduce_any(
+            self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
+        )
+
+        if old_lm_head_decoder is not None and not is_input_output_equals:
+            old_embedding_dim = shape_list(old_lm_head_decoder)[1]
+            decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
+            new_lm_head_decoder = self.add_weight(
+                shape=(new_num_tokens, old_embedding_dim),
+                initializer="zeros",
+                trainable=True,
+                name=old_lm_head_decoder.name.split(":")[0],
+            )
+            init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
+
+            new_lm_head_decoder.assign(init_decoder)
+
+        return new_lm_head_decoder
+
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
+        """
+        Build resized embedding weights from the provided token embedding weights. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end
+
+        Args:
+            old_embeddings (`tf.Variable`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `tf.Variable` module of the model without doing anything.
+
+        Return:
+            `tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if `new_num_tokens` is
+            `None`
+        """
+        # TODO (joao): flagged for replacement (by `_v2_get_resized_embeddings`) due to embeddings refactor
+        old_embedding_dim = shape_list(old_embeddings)[1]
+        init_range = getattr(self.config, "initializer_range", 0.02)
+        embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens)
+        new_embeddings = self.add_weight(
+            name=old_embeddings.name.split(":")[0],
+            shape=[new_num_tokens, old_embedding_dim],
+            initializer=get_initializer(init_range),
+            dtype=tf.float32,
+        )
+        init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value())
+
+        new_embeddings.assign(init_embeddings)
+
+        return new_embeddings
+
+    def _v2_get_resized_embeddings(
+        self, old_embeddings: keras.layers.Embedding, new_num_tokens: int
+    ) -> keras.layers.Embedding:
+        """
+        Build a resized Embedding layer from a provided Embedding layer. Increasing the size will add newly initialized
+        vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_embeddings (`keras.layers.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+        Return:
+            `keras.layers.Embedding`: Resized Embedding layer.
+        """
+
+        # Get the initialization range for the embeddings
+        init_range = 0.02  # default value
+        potential_initialization_variable_names = [
+            "initializer_range",  # most common
+            "initializer_factor",  # e.g. T5
+            "init_std",  # e.g BART
+        ]
+        for var_name in potential_initialization_variable_names:
+            if hasattr(self.config, var_name):
+                init_range = getattr(self.config, var_name)
+
+        # Get a new (initialized) embeddings layer
+        new_embeddings = keras.layers.Embedding(
+            input_dim=new_num_tokens,
+            output_dim=old_embeddings.output_dim,
+            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=init_range),
+            name=old_embeddings.embeddings.name[:-13],  # exact same scoped name except "/embeddings:0"
+        )
+        new_embeddings(tf.constant([[0]]))
+
+        # Copy the old embeddings to the new embeddings
+        if old_embeddings.input_dim >= new_num_tokens:
+            init_embeddings = old_embeddings.embeddings[:new_num_tokens]
+        else:
+            init_embeddings = tf.concat(
+                [old_embeddings.embeddings, new_embeddings.embeddings[old_embeddings.input_dim :]], axis=0
+            )
+        new_embeddings.embeddings.assign(init_embeddings)
+        return new_embeddings
+
+    def prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the base model.
+
+        Arguments:
+            heads_to_prune (`dict[int, list[int]]`):
+                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
+                to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
+                layer 1 and heads 2 and 3 on layer 2.
+        """
+        raise NotImplementedError
+
+    def save_pretrained(
+        self,
+        save_directory,
+        saved_model=False,
+        version=1,
+        push_to_hub=False,
+        signatures=None,
+        max_shard_size: int | str = "5GB",
+        create_pr: bool = False,
+        safe_serialization: bool = False,
+        token: str | bool | None = None,
+        **kwargs,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~TFPreTrainedModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str`):
+                Directory to which to save. Will be created if it doesn't exist.
+            saved_model (`bool`, *optional*, defaults to `False`):
+                If the model has to be saved in saved model format as well or not.
+            version (`int`, *optional*, defaults to 1):
+                The version of the saved model. A saved model needs to be versioned in order to be properly loaded by
+                TensorFlow Serving as detailed in the official documentation
+                https://www.tensorflow.org/tfx/serving/serving_basic
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            signatures (`dict` or `tf.function`, *optional*):
+                Model's signature used for serving. This will be passed to the `signatures` argument of model.save().
+            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
+                than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+                which will be bigger than `max_shard_size`.
+
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional TensorFlow way (that uses `h5`).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
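+
+        Example (a minimal usage sketch showing a save/reload round trip; the checkpoint name and target directory
+        are illustrative):
+
+        ```python
+        >>> from transformers import TFAutoModel
+
+        >>> model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
+        >>> # Save the weights and configuration locally...
+        >>> model.save_pretrained("./my_model_directory/")
+        >>> # ...so they can be reloaded later with `from_pretrained`.
+        >>> reloaded = TFAutoModel.from_pretrained("./my_model_directory/")
+        ```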
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        if saved_model:
+            # If `torch_dtype` is in the config with a torch dtype class as the value, we need to change it to string.
+            # (Although TF doesn't care about this attribute, we can't just remove it or set it to `None`.)
+            if getattr(self.config, "torch_dtype", None) is not None and not isinstance(self.config.torch_dtype, str):
+                self.config.torch_dtype = str(self.config.torch_dtype).split(".")[1]
+            if signatures is None:
+                serving_default = self.serving.get_concrete_function(self.input_signature)
+                if any(spec.dtype == tf.int32 for spec in self.input_signature.values()):
+                    int64_spec = {
+                        key: tf.TensorSpec(
+                            shape=spec.shape, dtype=tf.int64 if spec.dtype == tf.int32 else spec.dtype, name=spec.name
+                        )
+                        for key, spec in self.input_signature.items()
+                    }
+                    int64_serving = self.serving.get_concrete_function(int64_spec)
+                    signatures = {"serving_default": serving_default, "int64_serving": int64_serving}
+                else:
+                    signatures = serving_default
+            saved_model_dir = os.path.join(save_directory, "saved_model", str(version))
+            self.save(saved_model_dir, include_optimizer=False, signatures=signatures)
+            logger.info(f"Saved model created in {saved_model_dir}")
+
+        # Save configuration file
+        self.config.architectures = [self.__class__.__name__[2:]]
+
+        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self.config)
+
+        self.config.save_pretrained(save_directory)
+        if self.can_generate():
+            self.generation_config.save_pretrained(save_directory)
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else TF2_WEIGHTS_NAME
+        output_model_file = os.path.join(save_directory, weights_name)
+
+        shards, index = tf_shard_checkpoint(self.weights, max_shard_size, weights_name=weights_name)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+            # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and filename not in shards:
+                os.remove(full_filename)
+
+        if index is None:
+            if safe_serialization:
+                state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in self.weights}
+                safe_save_file(state_dict, output_model_file, metadata={"format": "tf"})
+            else:
+                self.save_weights(output_model_file)
+            logger.info(f"Model weights saved in {output_model_file}")
+        else:
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else TF2_WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, save_index_file)
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as index_file:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                index_file.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+                f"index located at {save_index_file}."
+            )
+            for shard_file, shard in shards.items():
+                if safe_serialization:
+                    shard_state_dict = {strip_model_name_and_prefix(w.name): w.value() for w in shard}
+                    safe_save_file(
+                        shard_state_dict, os.path.join(save_directory, shard_file), metadata={"format": "tf"}
+                    )
+                else:
+                    with h5py.File(os.path.join(save_directory, shard_file), mode="w") as shard_file:
+                        layers = []
+                        for layer in sorted(shard, key=lambda x: x.name):
+                            if "model." in layer.name or len(layer.name.split("/")) == 1:
+                                layer_name = layer.name
+                            else:
+                                layer_name = "/".join(layer.name.split("/")[1:])
+                            param_dset = shard_file.create_dataset(
+                                layer_name, layer.numpy().shape, dtype=layer.numpy().dtype
+                            )
+                            param_dset[:] = layer.numpy()
+                            layers.append(layer_name.encode("utf8"))
+                        save_attributes_to_hdf5_group(shard_file, "layer_names", layers)
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+            )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: str | os.PathLike | None,
+        *model_args,
+        config: PretrainedConfig | str | os.PathLike | None = None,
+        cache_dir: str | os.PathLike | None = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: str | bool | None = None,
+        revision: str = "main",
+        use_safetensors: bool | None = None,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~TFPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or url to a *PyTorch state_dict save file* (e.g., `./pt_model/pytorch_model.bin`). In this
+                      case, `from_pt` should be set to `True` and a configuration object should be provided as `config`
+                      argument. This loading path is slower than converting the PyTorch model in a TensorFlow model
+                      using the provided conversion scripts and loading the TensorFlow model afterwards.
+                    - `None` if you are both providing the configuration and state dictionary (resp. with keyword
+                      arguments `config` and `state_dict`).
+            model_args (sequence of positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+            config (`Union[PretrainedConfig, str]`, *optional*):
+                Can be either:
+
+                    - an instance of a class derived from [`PretrainedConfig`],
+                    - a string valid as input to [`~PretrainedConfig.from_pretrained`].
+
+                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+                be automatically loaded when:
+
+                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+                      model).
+                    - The model was saved using [`~TFPreTrainedModel.save_pretrained`] and is reloaded by supplying the
+                      save directory.
+                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+                      configuration JSON file named *config.json* is found in the directory.
+            from_pt (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a PyTorch state_dict save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                Whether or not to ignore a size mismatch between some of the checkpoint weights and the model weights
+                (for instance, when instantiating a model with 10 labels from a checkpoint with 3 labels). Mismatched
+                weights are left newly initialized instead of raising an error.
+            cache_dir (`str`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
+                `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (e.g., not try downloading the model).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
+
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
+                Please refer to the mirror site for more information.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+            tf_to_pt_weight_rename (`Callable`, *optional*):
+                A function that is called to transform the names of weights during the PyTorch to TensorFlow
+                crossloading process. This is not necessary for most models, but is useful to allow composite models to
+                be crossloaded correctly.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+                is not installed, it will be set to `False`.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it is loaded) and initialize the model (e.g.,
+                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
+                      already been done)
+                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+                      corresponds to a configuration attribute will be used to override said attribute with the
+                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+                      will be passed to the underlying model's `__init__` function.
+
+        Examples:
+
+        ```python
+        >>> from transformers import BertConfig, TFBertModel
+
+        >>> # Download model and configuration from huggingface.co and cache.
+        >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased")
+        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+        >>> model = TFBertModel.from_pretrained("./test/saved_model/")
+        >>> # Update configuration during loading.
+        >>> model = TFBertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
+        >>> assert model.config.output_attentions == True
+        >>> # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
+        >>> config = BertConfig.from_json_file("./pt_model/my_pt_model_config.json")
+        >>> model = TFBertModel.from_pretrained("./pt_model/my_pytorch_model.bin", from_pt=True, config=config)
+        ```"""
+        from_pt = kwargs.pop("from_pt", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        _ = kwargs.pop("mirror", None)
+        load_weight_prefix = kwargs.pop("load_weight_prefix", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+        subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)
+        tf_to_pt_weight_rename = kwargs.pop("tf_to_pt_weight_rename", None)
+
+        # Not relevant for TF models
+        _ = kwargs.pop("adapter_kwargs", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if trust_remote_code is True:
+            logger.warning(
+                "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
+                " ignored."
+            )
+
+        user_agent = {"file_type": "model", "framework": "tensorflow", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        if use_safetensors is None and not is_safetensors_available():
+            use_safetensors = False
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                _commit_hash=commit_hash,
+                **kwargs,
+            )
+        else:
+            model_kwargs = kwargs
+
+        if commit_hash is None:
+            commit_hash = getattr(config, "_commit_hash", None)
+
+        # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
+        # index of the files.
+        is_sharded = False
+        # Load model
+        if pretrained_model_name_or_path is not None:
+            pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+            is_local = os.path.isdir(pretrained_model_name_or_path)
+            if is_local:
+                if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    # Load from a PyTorch checkpoint in priority if from_pt
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded PyTorch checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+                elif use_safetensors is not False and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+                ):
+                    # Load from a safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
+                elif use_safetensors is not False and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                ):
+                    # Load from a sharded safetensors checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+                    # Load from a TF 2.0 checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)):
+                    # Load from a sharded TF 2.0 checkpoint
+                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
+                    is_sharded = True
+
+                # At this stage we don't have a weight file so we will raise an error.
+                elif use_safetensors:
+                    raise OSError(
+                        f"Error no file named {SAFE_WEIGHTS_NAME} or {SAFE_WEIGHTS_INDEX_NAME} found in directory {pretrained_model_name_or_path}. "
+                        f"Please make sure that the model has been saved with `safe_serialization=True` or do not "
+                        f"set `use_safetensors=True`."
+                    )
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                ):
+                    raise OSError(
+                        f"Error no file named {TF2_WEIGHTS_NAME} or {SAFE_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
+                        "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
+                        "weights."
+                    )
+                else:
+                    raise OSError(
+                        f"Error no file named {TF2_WEIGHTS_NAME}, {SAFE_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
+                        f"{pretrained_model_name_or_path}."
+                    )
+            elif os.path.isfile(pretrained_model_name_or_path):
+                archive_file = pretrained_model_name_or_path
+                is_local = True
+            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
+                archive_file = pretrained_model_name_or_path + ".index"
+                is_local = True
+            elif is_remote_url(pretrained_model_name_or_path):
+                filename = pretrained_model_name_or_path
+                resolved_archive_file = download_url(pretrained_model_name_or_path)
+            else:
+                # set correct filename
+                if from_pt:
+                    filename = WEIGHTS_NAME
+                elif use_safetensors is not False:
+                    filename = SAFE_WEIGHTS_NAME
+                else:
+                    filename = TF2_WEIGHTS_NAME
+
+                try:
+                    # Load from URL or cache if already cached
+                    cached_file_kwargs = {
+                        "cache_dir": cache_dir,
+                        "force_download": force_download,
+                        "proxies": proxies,
+                        "resume_download": resume_download,
+                        "local_files_only": local_files_only,
+                        "token": token,
+                        "user_agent": user_agent,
+                        "revision": revision,
+                        "subfolder": subfolder,
+                        "_raise_exceptions_for_gated_repo": False,
+                        "_raise_exceptions_for_missing_entries": False,
+                        "_commit_hash": commit_hash,
+                    }
+                    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                    # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                    # result when internet is up, the repo and revision exist, but the file does not.
+                    if resolved_archive_file is None and filename == SAFE_WEIGHTS_NAME:
+                        # Did not find the safetensors file, let's fallback to TF.
+                        # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
+                        filename = TF2_WEIGHTS_NAME
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **cached_file_kwargs
+                        )
+                    if resolved_archive_file is None and filename == TF2_WEIGHTS_NAME:
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+                    if resolved_archive_file is None and filename == WEIGHTS_NAME:
+                        # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
+                        )
+                        if resolved_archive_file is not None:
+                            is_sharded = True
+                    if resolved_archive_file is None:
+                        # Otherwise, maybe there is a PyTorch or Flax model file.  We try those to give a helpful error
+                        # message.
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                            "cache_dir": cache_dir,
+                            "local_files_only": local_files_only,
+                        }
+                        if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
+                            is_sharded = True
+                        elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {TF2_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
+                                " load this model from those weights."
+                            )
+                        else:
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
+                                f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
+                            )
+
+                except OSError:
+                    # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
+                    # to the original exception.
+                    raise
+                except Exception:
+                    # For any other exception, we throw a generic error.
+
+                    raise OSError(
+                        f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                        " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                        f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                        f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
+                    )
+            if is_local:
+                logger.info(f"loading weights file {archive_file}")
+                resolved_archive_file = archive_file
+                filename = resolved_archive_file.split(os.path.sep)[-1]
+            else:
+                logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+        else:
+            resolved_archive_file = None
+
+        # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
+        if is_sharded:
+            # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
+            resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
+                pretrained_model_name_or_path,
+                resolved_archive_file,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                local_files_only=local_files_only,
+                token=token,
+                user_agent=user_agent,
+                revision=revision,
+                _commit_hash=commit_hash,
+            )
+
+        safetensors_from_pt = False
+        if filename == SAFE_WEIGHTS_NAME:
+            with safe_open(resolved_archive_file, framework="tf") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+        elif filename == SAFE_WEIGHTS_INDEX_NAME:
+            with safe_open(resolved_archive_file[0], framework="tf") as f:
+                safetensors_metadata = f.metadata()
+            if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {resolved_archive_file} does not contain valid metadata."
+                    " Make sure you save your model with the `save_pretrained` method."
+                )
+            safetensors_from_pt = safetensors_metadata.get("format") == "pt"
+
+        config.name_or_path = pretrained_model_name_or_path
+
+        # composed models, *e.g.* TFRag, require special treatment when it comes to loading
+        # pre-trained weights.
+        if cls._requires_load_weight_prefix and model_kwargs.get("name") is not None:
+            model_kwargs["load_weight_prefix"] = load_weight_prefix + "/" + model_kwargs.get("name")
+
+        # Instantiate model.
+        model = cls(config, *model_args, **model_kwargs)
+
+        if tf_to_pt_weight_rename is None and hasattr(model, "tf_to_pt_weight_rename"):
+            # TODO Matt: This is a temporary workaround to allow weight renaming, but requires a method
+            #            to be defined for each class that requires a rename. We can probably just have a class-level
+            #            dict and a single top-level method or something and cut down a lot of boilerplate code
+            tf_to_pt_weight_rename = model.tf_to_pt_weight_rename
+
+        if from_pt:
+            from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model
+
+            # Load from a PyTorch checkpoint
+            return load_pytorch_checkpoint_in_tf2_model(
+                model,
+                resolved_archive_file,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+            )
+
+        # we might need to extend the variable scope for composite models
+        if load_weight_prefix is not None:
+            with tf.compat.v1.variable_scope(load_weight_prefix):
+                model.build_in_name_scope()  # build the network with dummy inputs
+        else:
+            model.build_in_name_scope()  # build the network with dummy inputs
+
+        if safetensors_from_pt and not is_sharded:
+            from .modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model
+
+            with safe_open(resolved_archive_file, framework="tf") as safetensors_archive:
+                # Load from a PyTorch safetensors checkpoint
+                # We load in TF format here because PT weights often need to be transposed, and this is much
+                # faster on GPU. Loading as numpy and transposing on CPU adds several seconds to load times.
+                return load_pytorch_state_dict_in_tf2_model(
+                    model,
+                    safetensors_archive,
+                    tf_inputs=False,  # No need to build the model again
+                    allow_missing_keys=True,
+                    output_loading_info=output_loading_info,
+                    _prefix=load_weight_prefix,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+                )
+        elif safetensors_from_pt:
+            from .modeling_tf_pytorch_utils import load_sharded_pytorch_safetensors_in_tf2_model
+
+            return load_sharded_pytorch_safetensors_in_tf2_model(
+                model,
+                resolved_archive_file,
+                tf_inputs=False,
+                allow_missing_keys=True,
+                output_loading_info=output_loading_info,
+                _prefix=load_weight_prefix,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+            )
+
+        # 'by_name' allow us to do transfer learning by skipping/adding layers
+        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
+        try:
+            if is_sharded:
+                for file in resolved_archive_file:
+                    assert os.path.isfile(file), f"Error retrieving files {file}"
+                if filename == SAFE_WEIGHTS_INDEX_NAME:
+                    missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights_from_safetensors(
+                        model,
+                        resolved_archive_file,
+                        ignore_mismatched_sizes=ignore_mismatched_sizes,
+                        _prefix=load_weight_prefix,
+                    )
+                else:
+                    missing_keys, unexpected_keys, mismatched_keys = load_tf_sharded_weights(
+                        model,
+                        resolved_archive_file,
+                        ignore_mismatched_sizes=ignore_mismatched_sizes,
+                        _prefix=load_weight_prefix,
+                    )
+            else:
+                # Handles both H5 and safetensors
+                missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
+                    model,
+                    resolved_archive_file,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                    _prefix=load_weight_prefix,
+                )
+        except OSError as e:
+            try:
+                with open(resolved_archive_file) as f:
+                    if f.read().startswith("version"):
+                        raise OSError(
+                            "You seem to have cloned a repository without having git-lfs installed. Please install "
+                            "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                            "you cloned."
+                        )
+                    else:
+                        raise ValueError from e
+            except (UnicodeDecodeError, ValueError):
+                raise OSError(
+                    "Unable to load weights from h5 file. "
+                    "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
+                )
+
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
+                missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
+
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some layers from the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.warning(f"All model checkpoint layers were used when initializing {model.__class__.__name__}.\n")
+
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some layers of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.warning(
+                f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, shape1, shape2 in mismatched_keys
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
+        if output_loading_info:
+            loading_info = {
+                "missing_keys": missing_keys,
+                "unexpected_keys": unexpected_keys,
+                "mismatched_keys": mismatched_keys,
+            }
+
+            return model, loading_info
+
+        return model
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        use_temp_dir: bool | None = None,
+        commit_message: str | None = None,
+        private: bool | None = None,
+        max_shard_size: int | str | None = "10GB",
+        token: bool | str | None = None,
+        # (`use_auth_token` is deprecated: we have to keep it here as we don't have **kwargs)
+        use_auth_token: bool | str | None = None,
+        create_pr: bool = False,
+        **base_model_card_args,
+    ) -> str:
+        """
+        Upload the model files to the 🤗 Model Hub while synchronizing a local clone of the repo in `repo_path_or_name`.
+
+        Parameters:
+            repo_id (`str`):
+                The name of the repository you want to push your model to. It should contain your organization name
+                when pushing to a given organization.
+            use_temp_dir (`bool`, *optional*):
+                Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub.
+                Will default to `True` if there is no directory named like `repo_id`, `False` otherwise.
+            commit_message (`str`, *optional*):
+                Message to commit while pushing. Will default to `"Upload model"`.
+            private (`bool`, *optional*):
+                Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+            token (`bool` or `str`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `hf auth login` (stored in `~/.huggingface`). Will default to `True` if `repo_url`
+                is not specified.
+            max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
+                Only applicable for models. The maximum size for a checkpoint before being sharded. Each checkpoint
+                shard will then be smaller than this size. If expressed as a string, needs to be digits followed by a
+                unit (like `"5MB"`).
+            create_pr (`bool`, *optional*, defaults to `False`):
+                Whether or not to create a PR with the uploaded files or directly commit.
+
+        Examples:
+
+        ```python
+        from transformers import TFAutoModel
+
+        model = TFAutoModel.from_pretrained("google-bert/bert-base-cased")
+
+        # Push the model to your namespace with the name "my-finetuned-bert".
+        model.push_to_hub("my-finetuned-bert")
+
+        # Push the model to an organization with the name "my-finetuned-bert".
+        model.push_to_hub("huggingface/my-finetuned-bert")
+        ```
+        """
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if "repo_path_or_name" in base_model_card_args:
+            warnings.warn(
+                "The `repo_path_or_name` argument is deprecated and will be removed in v5 of Transformers. Use "
+                "`repo_id` instead."
+            )
+            repo_id = base_model_card_args.pop("repo_path_or_name")
+        # Deprecation warning will be sent after for repo_url and organization
+        repo_url = base_model_card_args.pop("repo_url", None)
+        organization = base_model_card_args.pop("organization", None)
+
+        if os.path.isdir(repo_id):
+            working_dir = repo_id
+            repo_id = repo_id.split(os.path.sep)[-1]
+        else:
+            working_dir = repo_id.split("/")[-1]
+
+        repo_id = self._create_repo(
+            repo_id, private=private, token=token, repo_url=repo_url, organization=organization
+        )
+
+        if use_temp_dir is None:
+            use_temp_dir = not os.path.isdir(working_dir)
+
+        with working_or_temp_dir(working_dir=working_dir, use_temp_dir=use_temp_dir) as work_dir:
+            files_timestamps = self._get_files_timestamps(work_dir)
+
+            # Save all files.
+            self.save_pretrained(work_dir, max_shard_size=max_shard_size)
+            if hasattr(self, "history") and hasattr(self, "create_model_card"):
+                # This is a Keras model and we might be able to fish out its History and make a model card out of it
+                model_card_kwargs = {
+                    "output_dir": work_dir,
+                    "model_name": Path(repo_id).name,
+                }
+                # Merge in the remaining user-provided kwargs so they reach `create_model_card`.
+                model_card_kwargs.update(base_model_card_args)
+                self.create_model_card(**model_card_kwargs)
+
+            self._upload_modified_files(
+                work_dir,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+                create_pr=create_pr,
+            )
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="TFAutoModel"):
+        """
+        Register this class with a given auto class. This should only be used for custom models as the ones in the
+        library are already mapped with an auto class.
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
+                The auto class to register this new model with.
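+
+        Example (an illustrative sketch; `MyTFModel` stands for a hypothetical custom model class deriving from
+        `TFPreTrainedModel`):
+
+        ```python
+        >>> # Register a custom model class so it can be loaded through the TFAutoModel API.
+        >>> MyTFModel.register_for_auto_class("TFAutoModel")
+        ```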
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+
+class TFConv1D(keras.layers.Layer):
+    """
+    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+    Basically works like a linear layer but the weights are transposed.
+
+    Args:
+        nf (`int`):
+            The number of output features.
+        nx (`int`):
+            The number of input features.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation to use to initialize the weights.
+        kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
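+
+    Example (a minimal sketch; the batch, sequence and feature sizes are illustrative):
+
+    ```python
+    >>> import tensorflow as tf
+
+    >>> layer = TFConv1D(nf=768, nx=768)
+    >>> hidden_states = tf.random.uniform((2, 5, 768))  # (batch_size, seq_length, nx)
+    >>> outputs = layer(hidden_states)  # (batch_size, seq_length, nf)
+    ```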
+    """
+
+    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
+        super().__init__(**kwargs)
+        self.nf = nf
+        self.nx = nx
+        self.initializer_range = initializer_range
+
+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        self.weight = self.add_weight(
+            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
+        )
+        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
+
+    def call(self, x):
+        bz, sl = shape_list(x)[:2]
+
+        x = tf.reshape(x, [-1, self.nx])
+        x = tf.matmul(x, self.weight) + self.bias
+
+        x = tf.reshape(x, [bz, sl, self.nf])
+
+        return x
+
+
+class TFSharedEmbeddings(keras.layers.Layer):
+    r"""
+    Construct shared token embeddings.
+
+    The weights of the embedding layer is usually shared with the weights of the linear decoder when doing language
+    modeling.
+
+    Args:
+        vocab_size (`int`):
+            The size of the vocabulary, e.g., the number of unique tokens.
+        hidden_size (`int`):
+            The size of the embedding vectors.
+        initializer_range (`float`, *optional*):
+            The standard deviation to use when initializing the weights. If no value is provided, it will default to
+            \\(1/\sqrt{hidden\_size}\\).
+        kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
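+
+    Example (a minimal sketch of the two modes; the vocabulary and hidden sizes are illustrative):
+
+    ```python
+    >>> import tensorflow as tf
+
+    >>> shared = TFSharedEmbeddings(vocab_size=50257, hidden_size=768)
+    >>> token_embeddings = shared(tf.constant([[1, 2, 3]]), mode="embedding")  # (1, 3, 768)
+    >>> logits = shared(token_embeddings, mode="linear")  # (1, 3, 50257)
+    ```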
+    """
+
+    # TODO (joao): flagged for deletion due to embeddings refactor
+
+    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float | None = None, **kwargs):
+        super().__init__(**kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range
+        warnings.warn(
+            "`TFSharedEmbeddings` is scheduled for deletion in v4.32, use `keras.layers.Embedding` instead.",
+            DeprecationWarning,
+        )
+
+    def build(self, input_shape):
+        """
+        Build shared token embedding layer Shared weights logic adapted from
+        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
+        """
+        self.weight = self.add_weight(
+            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
+        )
+        super().build(input_shape)
+
+    def get_config(self):
+        config = {
+            "vocab_size": self.vocab_size,
+            "hidden_size": self.hidden_size,
+            "initializer_range": self.initializer_range,
+        }
+        base_config = super().get_config()
+
+        return dict(list(base_config.items()) + list(config.items()))
+
+    def call(self, inputs: tf.Tensor, mode: str = "embedding") -> tf.Tensor:
+        """
+        Get token embeddings of inputs or decode final hidden state.
+
+        Args:
+            inputs (`tf.Tensor`):
+                In embedding mode, should be an int64 tensor with shape `[batch_size, length]`.
+
+                In linear mode, should be a float tensor with shape `[batch_size, length, hidden_size]`.
+            mode (`str`, defaults to `"embedding"`):
+               A valid value is either `"embedding"` or `"linear"`, the first one indicates that the layer should be
+               used as an embedding layer, the second one that the layer should be used as a linear decoder.
+
+        Returns:
+            `tf.Tensor`: In embedding mode, the output is a float32 embedding tensor, with shape `[batch_size, length,
+            embedding_size]`.
+
+            In linear mode, the output is a float32 tensor with shape `[batch_size, length, vocab_size]`.
+
+        Raises:
+            ValueError: if `mode` is not valid.
+
+        Shared weights logic is adapted from
+        [here](https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24).
+        """
+        if mode == "embedding":
+            return self._embedding(inputs)
+        elif mode == "linear":
+            return self._linear(inputs)
+        else:
+            raise ValueError(f"mode {mode} is not valid.")
+
+    def _embedding(self, input_ids):
+        """Applies embedding based on inputs tensor."""
+        return tf.gather(self.weight, input_ids)
+
+    def _linear(self, inputs):
+        """
+        Computes logits by running inputs through a linear layer.
+
+        Args:
+            inputs: A float32 tensor with shape [..., hidden_size]
+
+        Returns:
+            float32 tensor with shape [..., vocab_size].
+        """
+        first_dims = shape_list(inputs)[:-1]
+        x = tf.reshape(inputs, [-1, self.hidden_size])
+        logits = tf.matmul(x, self.weight, transpose_b=True)
+
+        return tf.reshape(logits, first_dims + [self.vocab_size])
+
+
+class TFSequenceSummary(keras.layers.Layer):
+    """
+    Compute a single vector summary of a sequence hidden states.
+
+    Args:
+        config ([`PretrainedConfig`]):
+            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
+            config class of your model for the default values it uses):
+
+            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
+
+                - `"last"` -- Take the last token hidden state (like XLNet)
+                - `"first"` -- Take the first token hidden state (like Bert)
+                - `"mean"` -- Take the mean of all tokens hidden states
+                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
+                - `"attn"` -- Not implemented for now; would use multi-head attention
+
+            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
+            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
+              (otherwise to `config.hidden_size`).
+            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output;
+              any other string or `None` adds no activation.
+            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
+            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
+
+        initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation to use to initialize the weights.
+        kwargs (`dict[str, Any]`, *optional*):
+            Additional keyword arguments passed along to the `__init__` of `keras.layers.Layer`.
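+
+    Example (a usage sketch; `config` and `hidden_states` are placeholders, assuming a GPT-2-style config that defines
+    the `summary_*` attributes):
+
+    ```python
+    >>> summary = TFSequenceSummary(config, initializer_range=config.initializer_range)
+    >>> pooled = summary(hidden_states)  # with `summary_type="last"`, this is the last token's hidden state
+    ```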
+    """
+
+    def __init__(self, config: PretrainedConfig, initializer_range: float = 0.02, **kwargs):
+        super().__init__(**kwargs)
+
+        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
+        if self.summary_type == "attn":
+            # We should use a standard multi-head attention module with absolute positional embedding for that.
+            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
+            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
+            raise NotImplementedError
+
+        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
+        if self.has_summary:
+            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
+                num_classes = config.num_labels
+            else:
+                num_classes = config.hidden_size
+            self.summary = keras.layers.Dense(
+                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
+            )
+
+        self.has_activation = False
+        activation_string = getattr(config, "summary_activation", None)
+        if activation_string is not None:
+            self.has_activation = True
+            self.activation = get_tf_activation(activation_string)
+
+        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
+        if self.has_first_dropout:
+            self.first_dropout = keras.layers.Dropout(config.summary_first_dropout)
+
+        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
+        if self.has_last_dropout:
+            self.last_dropout = keras.layers.Dropout(config.summary_last_dropout)
+        self.hidden_size = config.hidden_size
+
+    def call(self, inputs, cls_index=None, training=False):
+        if not isinstance(inputs, (dict, tuple, list)):
+            hidden_states = inputs
+        elif isinstance(inputs, (tuple, list)):
+            hidden_states = inputs[0]
+            cls_index = inputs[1] if len(inputs) > 1 else None
+            assert len(inputs) <= 2, "Too many inputs."
+        else:
+            hidden_states = inputs.get("hidden_states")
+            cls_index = inputs.get("cls_index", None)
+
+        if self.summary_type == "last":
+            output = hidden_states[:, -1]
+        elif self.summary_type == "first":
+            output = hidden_states[:, 0]
+        elif self.summary_type == "mean":
+            output = tf.reduce_mean(hidden_states, axis=1)
+        elif self.summary_type == "cls_index":
+            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
+            if cls_index is None:
+                cls_index = tf.fill(
+                    hidden_shape[:-2], hidden_shape[-2] - 1
+                )  # A tensor of shape [batch] or [batch, num choices] filled with the index of the last token (sequence length - 1)
+            cls_shape = shape_list(cls_index)
+            if len(cls_shape) <= len(hidden_shape) - 2:
+                cls_index = tf.expand_dims(cls_index, axis=-1)
+            # else:
+            # cls_index = cls_index[..., tf.newaxis]
+            # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
+            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
+            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
+            output = tf.squeeze(
+                output, axis=len(hidden_shape) - 2
+            )  # shape of output: (batch, num choices, hidden_size)
+        elif self.summary_type == "attn":
+            raise NotImplementedError
+
+        if self.has_first_dropout:
+            output = self.first_dropout(output, training=training)
+
+        if self.has_summary:
+            output = self.summary(output)
+
+        if self.has_activation:
+            output = self.activation(output)
+
+        if self.has_last_dropout:
+            output = self.last_dropout(output, training=training)
+
+        return output
+
+    def build(self, input_shape):
+        if self.built:
+            return
+        self.built = True
+        if getattr(self, "summary", None) is not None:
+            with tf.name_scope("summary"):
+                self.summary.build(self.hidden_size)
+
+
+def get_initializer(initializer_range: float = 0.02) -> keras.initializers.TruncatedNormal:
+    """
+    Creates a `keras.initializers.TruncatedNormal` with the given range.
+
+    Args:
+        initializer_range (*float*, defaults to 0.02): Standard deviation of the truncated normal initializer.
+
+    Returns:
+        `keras.initializers.TruncatedNormal`: The truncated normal initializer.
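+
+    Example (a minimal sketch; `keras` is assumed to be the Keras module imported at the top of this file):
+
+    ```python
+    >>> initializer = get_initializer(0.02)
+    >>> dense = keras.layers.Dense(16, kernel_initializer=initializer)
+    ```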
+    """
+    return keras.initializers.TruncatedNormal(stddev=initializer_range)
diff --git a/phivenv/Lib/site-packages/transformers/modeling_utils.py b/phivenv/Lib/site-packages/transformers/modeling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..973ee405cb3a0acb37558385da7fbdb158065556
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/modeling_utils.py
@@ -0,0 +1,6291 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections
+import copy
+import functools
+import gc
+import importlib.metadata
+import inspect
+import itertools
+import json
+import os
+import re
+import shutil
+import sys
+import tempfile
+import warnings
+from abc import abstractmethod
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from contextlib import contextmanager
+from enum import Enum
+from functools import partial, wraps
+from threading import Thread
+from typing import Any, Callable, Optional, TypeVar, Union, get_type_hints
+from zipfile import is_zipfile
+
+import torch
+from huggingface_hub import split_torch_state_dict_into_shards
+from packaging import version
+from torch import Tensor, nn
+from torch.distributions import constraints
+from torch.utils.checkpoint import checkpoint
+
+from .configuration_utils import PretrainedConfig
+from .distributed import DistributedConfig
+from .dynamic_module_utils import custom_object_save
+from .generation import CompileConfig, GenerationConfig
+from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
+from .integrations.accelerate import find_tied_parameters, init_empty_weights
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
+from .integrations.eager_paged import eager_paged_attention_forward
+from .integrations.flash_attention import flash_attention_forward
+from .integrations.flash_paged import paged_attention_forward
+from .integrations.flex_attention import flex_attention_forward
+from .integrations.hub_kernels import is_kernel, load_and_register_kernel
+from .integrations.sdpa_attention import sdpa_attention_forward
+from .integrations.sdpa_paged import sdpa_attention_paged_forward
+from .integrations.tensor_parallel import (
+    _get_parameter_tp_plan,
+    distribute_model,
+    initialize_tensor_parallelism,
+    repack_weights,
+    replace_state_dict_local_with_dtensor,
+    shard_and_distribute_module,
+    verify_tp_plan,
+)
+from .loss.loss_utils import LOSS_MAPPING
+from .modeling_flash_attention_utils import lazy_import_flash_attention
+from .pytorch_utils import id_tensor_storage
+from .quantizers import HfQuantizer
+from .quantizers.auto import get_hf_quantizer
+from .quantizers.quantizers_utils import get_module_from_name
+from .safetensors_conversion import auto_conversion
+from .utils import (
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
+    DUMMY_INPUTS,
+    FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    TF2_WEIGHTS_NAME,
+    TF_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    ContextManagers,
+    PushToHubMixin,
+    cached_file,
+    check_torch_load_is_safe,
+    copy_func,
+    download_url,
+    extract_commit_hash,
+    has_file,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_flash_attn_2_available,
+    is_flash_attn_3_available,
+    is_kernels_available,
+    is_offline_mode,
+    is_optimum_available,
+    is_peft_available,
+    is_remote_url,
+    is_safetensors_available,
+    is_torch_flex_attn_available,
+    is_torch_greater_or_equal,
+    is_torch_mlu_available,
+    is_torch_npu_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    is_torchao_available,
+    logging,
+)
+from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
+from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
+from .utils.import_utils import (
+    ENV_VARS_TRUE_VALUES,
+    is_huggingface_hub_greater_or_equal,
+    is_sagemaker_mp_enabled,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+)
+from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
+
+
+if is_torchao_available():
+    from torchao.quantization import Int4WeightOnlyConfig
+
+if is_accelerate_available():
+    from accelerate import dispatch_model, infer_auto_device_map
+    from accelerate.hooks import add_hook_to_module
+    from accelerate.utils import (
+        check_tied_parameters_on_same_device,
+        extract_model_from_parallel,
+        get_balanced_memory,
+        get_max_memory,
+        load_offloaded_weights,
+        offload_weight,
+        save_offload_index,
+    )
+
+    accelerate_version = version.parse(importlib.metadata.version("accelerate"))
+    if accelerate_version >= version.parse("0.31"):
+        from accelerate.utils.modeling import get_state_dict_from_offload
+
+if is_safetensors_available():
+    from safetensors import safe_open
+    from safetensors.torch import load_file as safe_load_file
+    from safetensors.torch import save_file as safe_save_file
+
+if is_peft_available():
+    from .utils import find_adapter_config_file
+
+_torch_distributed_available = torch.distributed.is_available()
+_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
+if _is_dtensor_available:
+    from torch.distributed.tensor import DTensor
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+logger = logging.get_logger(__name__)
+
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
+SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
+_init_weights = True
+_is_quantized = False
+_is_ds_init_called = False
+
+
+def is_local_dist_rank_0():
+    return (
+        torch.distributed.is_available()
+        and torch.distributed.is_initialized()
+        and int(os.environ.get("LOCAL_RANK", "-1")) == 0
+    )
+
+
+TORCH_INIT_FUNCTIONS = {
+    "uniform_": nn.init.uniform_,
+    "normal_": nn.init.normal_,
+    "trunc_normal_": nn.init.trunc_normal_,
+    "constant_": nn.init.constant_,
+    "xavier_uniform_": nn.init.xavier_uniform_,
+    "xavier_normal_": nn.init.xavier_normal_,
+    "kaiming_uniform_": nn.init.kaiming_uniform_,
+    "kaiming_normal_": nn.init.kaiming_normal_,
+    "uniform": nn.init.uniform,
+    "normal": nn.init.normal,
+    "xavier_uniform": nn.init.xavier_uniform,
+    "xavier_normal": nn.init.xavier_normal,
+    "kaiming_uniform": nn.init.kaiming_uniform,
+    "kaiming_normal": nn.init.kaiming_normal,
+}
+
+# DO NOT MODIFY, KEPT FOR BC ONLY
+VLMS = [
+    "aria",
+    "ayavision",
+    "colpali",
+    "emu3",
+    "fuyu",
+    "gotocr2",
+    "gemma3",
+    "internvl",
+    "llava",  # all llava prefixed models fall under this check
+    "mistral3",
+    "mllama",
+    "paligemma",
+    "shieldgemma2",
+    "qwen2vl",
+    "qwen2_5_vl",
+    "videollava",
+    "vipllava",
+]
+
+
+@contextmanager
+def no_init_weights():
+    """
+    Context manager to globally disable weight initialization to speed up loading large models.
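+
+    Example (a usage sketch; `MyModel` and `config` are placeholders for any `nn.Module` subclass and its config):
+
+    ```python
+    >>> with no_init_weights():
+    ...     model = MyModel(config)  # parameters are allocated, but the usual random init calls are skipped
+    ```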
+    """
+    global _init_weights
+    old_init_weights = _init_weights
+
+    _init_weights = False
+
+    def _skip_init(*args, **kwargs):
+        pass
+
+    # Replace the torch init functions with a no-op (the originals are kept in TORCH_INIT_FUNCTIONS for restoration)
+    for name, init_func in TORCH_INIT_FUNCTIONS.items():
+        setattr(torch.nn.init, name, _skip_init)
+
+    try:
+        yield
+    finally:
+        _init_weights = old_init_weights
+        # Restore the original initialization functions
+        for name, init_func in TORCH_INIT_FUNCTIONS.items():
+            setattr(torch.nn.init, name, init_func)
+
+
+@contextmanager
+def set_quantized_state():
+    global _is_quantized
+    _is_quantized = True
+    try:
+        yield
+    finally:
+        _is_quantized = False
+
+
+# Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
+# This issue occurs with ZeRO stage 3 when using NVMe offloading.
+# For more details, refer to issue #34429.
+@contextmanager
+def set_zero3_state():
+    global _is_ds_init_called
+    _is_ds_init_called = True
+    try:
+        yield
+    finally:
+        _is_ds_init_called = False
+
+
+def restore_default_dtype(func):
+    """
+    Decorator restoring the default torch dtype at the end of the decorated function. It serves as a safety net in
+    case the function changes the default dtype and then raises an error before it could restore it.
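+
+    Example (a sketch; `build_bf16_model` is a hypothetical function, not part of this module):
+
+    ```python
+    >>> @restore_default_dtype
+    ... def build_bf16_model():
+    ...     torch.set_default_dtype(torch.bfloat16)
+    ...     ...  # build something; the previous default dtype is restored afterwards, even if this raises
+    ```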
+    """
+
+    @wraps(func)
+    def _wrapper(*args, **kwargs):
+        old_dtype = torch.get_default_dtype()
+        try:
+            return func(*args, **kwargs)
+        finally:
+            torch.set_default_dtype(old_dtype)
+
+    return _wrapper
+
+
+def get_torch_context_manager_or_global_device():
+    """
+    Check whether a device context manager is currently in use and, if not, whether the global default device is set
+    to something other than "cpu". This is used to infer the device to load the model on when `device_map` is not provided.
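+
+    Example (illustrative; assumes torch>=2.0 so that `torch.device` is usable as a context manager, and that a CUDA
+    device is available):
+
+    ```python
+    >>> with torch.device("cuda:0"):
+    ...     device = get_torch_context_manager_or_global_device()
+    >>> device
+    device(type='cuda', index=0)
+    ```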
+    """
+    device_in_context = torch.tensor([]).device
+    # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
+    default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
+    # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
+    if device_in_context == default_device:
+        if default_device != torch.device("cpu"):
+            return default_device
+        return None
+    return device_in_context
+
+
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
+    try:
+        return next(parameter.parameters()).device
+    except StopIteration:
+        # For nn.DataParallel compatibility in PyTorch 1.5
+
+        def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
+            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+            return tuples
+
+        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+        first_tuple = next(gen)
+        return first_tuple[1].device
+
+
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for t in parameter.parameters():
+        last_dtype = t.dtype
+        if t.is_floating_point():
+            # Adding fix for https://github.com/pytorch/xla/issues/4152
+            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
+            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
+            # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
+            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
+                return torch.bfloat16
+            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
+                if t.dtype == torch.float:
+                    return torch.bfloat16
+                if t.dtype == torch.double:
+                    return torch.float32
+            return t.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the last dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for gen_tuple in gen:
+        last_tuple = gen_tuple
+        if gen_tuple[1].is_floating_point():
+            return gen_tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype
+
+    # fallback to buffer dtype
+    for t in parameter.buffers():
+        last_dtype = t.dtype
+        if t.is_floating_point():
+            return t.dtype
+    return last_dtype
+
+
+def get_state_dict_dtype(state_dict):
+    """
+    Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
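+
+    Example (a minimal sketch with a hand-built state dict):
+
+    ```python
+    >>> sd = {"weight": torch.zeros(2, dtype=torch.float16), "step": torch.tensor(0, dtype=torch.int64)}
+    >>> get_state_dict_dtype(sd)
+    torch.float16
+    ```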
+    """
+    for t in state_dict.values():
+        if t.is_floating_point():
+            return t.dtype
+
+    # if no floating dtype was found return whatever the first dtype is
+    return next(iter(state_dict.values())).dtype
+
+
+def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
+    """
+    This is the same as
+    [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
+    but for a sharded checkpoint.
+
+    This load is performed efficiently: each checkpoint shard is loaded into RAM one at a time and freed after its
+    weights have been loaded into the model.
+
+    Args:
+        model (`torch.nn.Module`): The model in which to load the checkpoint.
+        folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
+        strict (`bool`, *optional*, defaults to `True`):
+            Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
+        prefer_safe (`bool`, *optional*, defaults to `True`):
+            If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
+            safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
+
+    Returns:
+        `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
+            - `missing_keys` is a list of str containing the missing keys
+            - `unexpected_keys` is a list of str containing the unexpected keys
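+
+    Example (illustrative; `model` and the folder path are placeholders):
+
+    ```python
+    >>> load_result = load_sharded_checkpoint(model, "path/to/checkpoint_folder", strict=False)
+    >>> print(load_result.missing_keys, load_result.unexpected_keys)
+    ```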
+    """
+    # Load the index
+    index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
+    safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
+
+    index_present = os.path.isfile(index_file)
+    safe_index_present = os.path.isfile(safe_index_file)
+
+    if not index_present and not (safe_index_present and is_safetensors_available()):
+        filenames = (
+            (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,)
+        )
+        raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
+
+    load_safe = False
+    if safe_index_present:
+        if prefer_safe:
+            if is_safetensors_available():
+                load_safe = True  # load safe due to preference
+            else:
+                logger.warning(
+                    f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!"
+                )
+        elif not index_present:
+            load_safe = True  # load safe since we have no other choice
+
+    load_index = safe_index_file if load_safe else index_file
+
+    with open(load_index, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    shard_files = list(set(index["weight_map"].values()))
+
+    # If strict=True, error before loading any of the state dicts.
+    loaded_keys = index["weight_map"].keys()
+    model_keys = model.state_dict().keys()
+    missing_keys = [key for key in model_keys if key not in loaded_keys]
+    unexpected_keys = [key for key in loaded_keys if key not in model_keys]
+    if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
+        error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
+        if len(missing_keys) > 0:
+            str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
+            error_message += f"\nMissing key(s): {str_missing_keys}."
+        if len(unexpected_keys) > 0:
+            str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
+            error_message += f"\nUnexpected key(s): {str_unexpected_keys}."
+        raise RuntimeError(error_message)
+
+    if load_safe:
+        loader = safe_load_file
+    else:
+        check_torch_load_is_safe()
+        loader = partial(torch.load, map_location="cpu", weights_only=True)
+
+    for shard_file in shard_files:
+        state_dict = loader(os.path.join(folder, shard_file))
+        model.load_state_dict(state_dict, strict=False)
+
+        # Make sure memory is freed before we load the next state dict.
+        del state_dict
+        gc.collect()
+
+    # Return the same thing as PyTorch load_state_dict function.
+    return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
+
+
+str_to_torch_dtype = {
+    "BOOL": torch.bool,
+    "U8": torch.uint8,
+    "I8": torch.int8,
+    "I16": torch.int16,
+    "F16": torch.float16,
+    "BF16": torch.bfloat16,
+    "I32": torch.int32,
+    "F32": torch.float32,
+    "F64": torch.float64,
+    "I64": torch.int64,
+    "F8_E4M3": torch.float8_e4m3fn,
+    "F8_E5M2": torch.float8_e5m2,
+}
+
+
+if is_torch_greater_or_equal("2.3.0"):
+    str_to_torch_dtype["U16"] = torch.uint16
+    str_to_torch_dtype["U32"] = torch.uint32
+    str_to_torch_dtype["U64"] = torch.uint64
+
+
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike],
+    is_quantized: bool = False,
+    map_location: Optional[Union[str, torch.device]] = "cpu",
+    weights_only: bool = True,
+):
+    """
+    Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
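+
+    Example (illustrative; the paths are placeholders):
+
+    ```python
+    >>> state_dict = load_state_dict("path/to/model.safetensors")
+    >>> state_dict = load_state_dict("path/to/pytorch_model.bin", map_location="cpu")
+    ```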
+    """
+    # Use safetensors if possible
+    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
+        with safe_open(checkpoint_file, framework="pt") as f:
+            metadata = f.metadata()
+
+            if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
+                raise OSError(
+                    f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata. Make sure "
+                    "you save your model with the `save_pretrained` method."
+                )
+            state_dict = {}
+            for k in f.keys():
+                if map_location == "meta":
+                    _slice = f.get_slice(k)
+                    k_dtype = _slice.get_dtype()
+                    if k_dtype in str_to_torch_dtype:
+                        dtype = str_to_torch_dtype[k_dtype]
+                    else:
+                        raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
+                    state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
+                else:
+                    state_dict[k] = f.get_tensor(k)
+            return state_dict
+
+    # Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
+    if weights_only:
+        check_torch_load_is_safe()
+    try:
+        if map_location is None:
+            if (
+                (
+                    is_deepspeed_zero3_enabled()
+                    and torch.distributed.is_initialized()
+                    and torch.distributed.get_rank() > 0
+                )
+                or (is_fsdp_enabled() and not is_local_dist_rank_0())
+            ) and not is_quantized:
+                map_location = "meta"
+            else:
+                map_location = "cpu"
+        extra_args = {}
+        # mmap can only be used with files serialized with zipfile-based format.
+        if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
+            extra_args = {"mmap": True}
+        return torch.load(
+            checkpoint_file,
+            map_location=map_location,
+            weights_only=weights_only,
+            **extra_args,
+        )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read(7) == "version":
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from the pytorch checkpoint file at '{checkpoint_file}'. "
+                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
+            )
+
+
+def set_initialized_submodules(model, state_dict_keys):
+    """
+    Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state
+    dict.
+    """
+    state_dict_keys = set(state_dict_keys)
+    not_initialized_submodules = {}
+    for module_name, module in model.named_modules():
+        if module_name == "":
+            # When checking if the root module is loaded there's no need to prepend module_name.
+            module_keys = set(module.state_dict())
+        else:
+            module_keys = {f"{module_name}.{k}" for k in module.state_dict()}
+        if module_keys.issubset(state_dict_keys):
+            module._is_hf_initialized = True
+        else:
+            not_initialized_submodules[module_name] = module
+    return not_initialized_submodules
+
+
+def _end_ptr(tensor: torch.Tensor) -> int:
+    # extract the end of the pointer if the tensor is a slice of a bigger tensor
+    if tensor.nelement():
+        stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+    else:
+        stop = tensor.data_ptr()
+    return stop
+
+
+def _get_tied_weight_keys(module: nn.Module, prefix=""):
+    tied_weight_keys = []
+    if getattr(module, "_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
+        tied_weight_keys.extend(names)
+    if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
+        names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
+        tied_weight_keys.extend(names)
+    for name, submodule in module.named_children():
+        local_prefix = f"{prefix}.{name}" if prefix else name
+        tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
+    return tied_weight_keys
+
+
+def _find_disjoint(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], list[str]]:
+    filtered_tensors = []
+    for shared in tensors:
+        if len(shared) < 2:
+            filtered_tensors.append(shared)
+            continue
+
+        areas = []
+        for name in shared:
+            tensor = state_dict[name]
+            areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
+        areas.sort()
+
+        _, last_stop, last_name = areas[0]
+        filtered_tensors.append({last_name})
+        for start, stop, name in areas[1:]:
+            if start >= last_stop:
+                filtered_tensors.append({name})
+            else:
+                filtered_tensors[-1].add(name)
+            last_stop = stop
+    disjoint_tensors = []
+    shared_tensors = []
+    for tensors in filtered_tensors:
+        if len(tensors) == 1:
+            disjoint_tensors.append(tensors.pop())
+        else:
+            shared_tensors.append(tensors)
+    return shared_tensors, disjoint_tensors
+
+
+def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], set[str]]:
+    shared_tensors = []
+    identical = []
+    for shared in tensors:
+        if len(shared) < 2:
+            continue
+
+        areas = collections.defaultdict(set)
+        for name in shared:
+            tensor = state_dict[name]
+            area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
+            areas[area].add(name)
+        if len(areas) == 1:
+            identical.append(shared)
+        else:
+            shared_tensors.append(shared)
+    return shared_tensors, identical
+
+
+def _infer_parameter_dtype(
+    model: "PreTrainedModel",
+    param_name: str,
+    empty_param: torch.Tensor,
+    keep_in_fp32_regex: Optional[re.Pattern] = None,
+    hf_quantizer: Optional[HfQuantizer] = None,
+) -> tuple[bool, Optional[torch.dtype]]:
+    try:
+        old_param = model.get_parameter_or_buffer(param_name)
+    except Exception as e:
+        if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
+            QuantizationMethod.HQQ,
+            QuantizationMethod.QUARK,
+            QuantizationMethod.MXFP4,
+        }:
+            return True, None
+        else:
+            raise e
+    is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
+    # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
+    # in int/uint/bool and not cast them.
+    casting_dtype = None
+    is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
+    if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
+        # First fp32 if part of the exception list
+        if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name):
+            casting_dtype = torch.float32
+        # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
+        elif hf_quantizer is not None:
+            casting_dtype = model.config._pre_quantization_dtype
+        else:
+            casting_dtype = old_param.dtype
+    return old_param is not None and old_param.is_contiguous(), casting_dtype
+
+
+def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
+    """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
+    module, param_type = get_module_from_name(model, param_name)
+    # This will check potential shape mismatch if skipped before
+    module.load_state_dict({param_type: tensor}, strict=False, assign=True)
+
+
+@torch.no_grad()
+def _load_state_dict_into_meta_model(
+    model: "PreTrainedModel",
+    state_dict: dict,
+    shard_file: str,
+    expected_keys: list[str],
+    reverse_renaming_mapping: dict[str, str],
+    device_map: Optional[dict] = None,
+    disk_offload_folder: Optional[str] = None,
+    disk_offload_index: Optional[dict] = None,
+    cpu_offload_folder: Optional[str] = None,
+    cpu_offload_index: Optional[dict] = None,
+    hf_quantizer: Optional[HfQuantizer] = None,
+    is_safetensors: bool = False,
+    keep_in_fp32_regex: Optional[re.Pattern] = None,
+    unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
+    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+) -> tuple[Optional[dict], Optional[dict]]:
+    """Load parameters from `state_dict` into the model. The parameters of `state_dict` are on the meta device, which
+    makes it easy to infer the shapes and dtypes they will have. The actual parameters are then loaded from
+    `shard_file`, which is the state dict file on disk.
+    This function takes care of correctly casting dtypes and devices, and of sharding tensors in case of tensor parallelism.
+    """
+    tensor_device = "cpu"
+    if device_map is not None and device_map.get("", None) is not None:
+        if device_map[""] not in ("cpu", torch.device("cpu")):
+            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
+    if device_map is not None:
+        device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
+
+    is_quantized = hf_quantizer is not None
+    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+        QuantizationMethod.HQQ,
+        QuantizationMethod.BITS_AND_BYTES,
+    }
+    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
+    file_pointer = None
+    if is_meta_state_dict:
+        file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
+
+    for param_name, empty_param in state_dict.items():
+        if param_name not in expected_keys:  # when loading from a checkpoint, we skip params that don't exist in the model
+            continue
+        # we need to use serialized_param_name as file pointer is untouched
+        if is_meta_state_dict:
+            # This is the name of the parameter as it appears on disk file
+            serialized_param_name = reverse_renaming_mapping[param_name]
+            param = file_pointer.get_slice(serialized_param_name)
+        else:
+            param = empty_param.to(tensor_device)  # It is actually not empty!
+        to_contiguous, casting_dtype = _infer_parameter_dtype(
+            model,
+            param_name,
+            empty_param,
+            keep_in_fp32_regex,
+            hf_quantizer,
+        )
+
+        if device_mesh is not None:
+            if (
+                not is_quantized
+                or (not hf_quantizer.requires_parameters_quantization)
+                or (
+                    not hf_quantizer.check_quantized_param(
+                        model,
+                        param,
+                        param_name,
+                        state_dict,
+                        device_map=device_map,
+                    )
+                )
+            ):  # In this case, the param is already on the correct device!
+                shard_and_distribute_module(
+                    model,
+                    param,
+                    empty_param,
+                    param_name,
+                    casting_dtype,
+                    to_contiguous,
+                    device_mesh.get_local_rank(),
+                    device_mesh,
+                )
+            else:  # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param:
+                sharding_kwargs = {
+                    "empty_param": empty_param,
+                    "casting_dtype": casting_dtype,
+                    "to_contiguous": to_contiguous,
+                    "rank": device_mesh.get_local_rank(),
+                    "device_mesh": device_mesh,
+                }
+                hf_quantizer.create_quantized_param(
+                    model,
+                    param,
+                    param_name,
+                    device_mesh.get_local_rank(),
+                    state_dict,
+                    unexpected_keys,
+                    **sharding_kwargs,
+                )
+        else:
+            param = param[...]
+            if casting_dtype is not None:
+                param = param.to(casting_dtype)
+            if to_contiguous:
+                param = param.contiguous()
+
+            if device_map is None:
+                param_device = "cpu"
+            else:
+                module_layer = re.search(device_map_regex, param_name)
+                if not module_layer:
+                    raise ValueError(f"{param_name} doesn't have any device set.")
+                else:
+                    param_device = device_map[module_layer.group()]
+
+            if param_device == "disk":
+                if not is_safetensors:
+                    disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
+            elif param_device == "cpu" and cpu_offload_index is not None:
+                cpu_offload_index = offload_weight(param, param_name, cpu_offload_folder, cpu_offload_index)
+            elif (
+                not is_quantized
+                or (not hf_quantizer.requires_parameters_quantization)
+                or (
+                    not hf_quantizer.check_quantized_param(
+                        model,
+                        param,
+                        param_name,
+                        state_dict,
+                        param_device=param_device,
+                        device_map=device_map,
+                    )
+                )
+            ):
+                if is_fsdp_enabled():
+                    param_device = "cpu" if is_local_dist_rank_0() else "meta"
+
+                _load_parameter_into_model(model, param_name, param.to(param_device))
+
+            else:
+                # TODO: rename; despite its name, this call also loads the quantized param into the model
+                hf_quantizer.create_quantized_param(
+                    model, param, param_name, param_device, state_dict, unexpected_keys
+                )
+
+                # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
+                # and then cast it to CPU to avoid excessive memory usage on each GPU
+                # in comparison to the sharded model across GPUs.
+                if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
+                    param_name = hf_quantizer.update_param_name(param_name)
+                    module, param_type = get_module_from_name(model, param_name)
+                    value = getattr(module, param_type)
+                    # special case for the gpt_oss model: we wait for the param to leave the meta device before casting it to cpu
+                    if model.config.model_type == "gpt_oss" and value.device.type == "meta":
+                        continue
+                    param_to = "cpu"
+                    if is_fsdp_enabled() and not is_local_dist_rank_0():
+                        param_to = "meta"
+                    val_kwargs = {}
+                    if (hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params") or (
+                        value.dtype == torch.uint8 or value.dtype == torch.int8
+                    ):
+                        val_kwargs["requires_grad"] = False
+                    value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
+                    setattr(module, param_type, value)
+
+    if file_pointer is not None:
+        file_pointer.__exit__(None, None, None)
+
+    return disk_offload_index, cpu_offload_index
+
+
+def load_shard_file(args):
+    (
+        shard_file,
+        state_dict,
+        disk_only_shard_files,
+        is_hqq_or_bnb,
+        is_quantized,
+        device_map,
+        hf_quantizer,
+        key_renaming_mapping,
+        weights_only,
+        model_to_load,
+        expected_keys,
+        reverse_key_renaming_mapping,
+        disk_offload_folder,
+        disk_offload_index,
+        cpu_offload_folder,
+        cpu_offload_index,
+        is_offloaded_safetensors,
+        keep_in_fp32_regex,
+        unexpected_keys,
+        device_mesh,
+    ) = args
+
+    # Skip the load for shards that only contain disk-offloaded weights
+    if shard_file in disk_only_shard_files:
+        return [], disk_offload_index, cpu_offload_index
+
+    map_location = "cpu"
+    if (
+        shard_file.endswith(".safetensors")
+        and not is_hqq_or_bnb
+        and not (is_deepspeed_zero3_enabled() and not is_quantized)
+    ):
+        map_location = "meta"
+    elif (
+        device_map is not None
+        and hf_quantizer is not None
+        and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+        and (
+            hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
+        )
+    ):
+        map_location = torch.device([d for d in device_map.values() if d not in ["disk"]][0])
+
+    # If shard_file is "", we use the existing state_dict instead of loading it
+    if shard_file != "":
+        state_dict = load_state_dict(
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+        )
+
+    # Fix the key names
+    state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+
+    error_msgs = []
+
+    if is_deepspeed_zero3_enabled() and not is_quantized:
+        error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
+    # Skip it with fsdp on ranks other than 0
+    elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
+        disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
+            model_to_load,
+            state_dict,
+            shard_file,
+            expected_keys,
+            reverse_key_renaming_mapping,
+            device_map=device_map,
+            disk_offload_folder=disk_offload_folder,
+            disk_offload_index=disk_offload_index,
+            cpu_offload_folder=cpu_offload_folder,
+            cpu_offload_index=cpu_offload_index,
+            hf_quantizer=hf_quantizer,
+            is_safetensors=is_offloaded_safetensors,
+            keep_in_fp32_regex=keep_in_fp32_regex,
+            unexpected_keys=unexpected_keys,
+            device_mesh=device_mesh,
+        )
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
+def load_shard_files_with_threadpool(args_list):
+    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
+
+    # Do not spawn more workers than needed
+    num_workers = min(len(args_list), num_workers)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
+            for future in as_completed(futures):
+                result = future.result()
+                (
+                    _error_msgs,
+                    disk_offload_index,
+                    cpu_offload_index,
+                ) = result
+
+                error_msgs += _error_msgs
+
+                pbar.update(1)
+
+    return error_msgs, disk_offload_index, cpu_offload_index
+
+
+def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
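+    # Illustrative: _add_variant("model.safetensors", "fp16") -> "model.fp16.safetensors"; with variant=None the name is unchanged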
+    if variant is not None:
+        path, name = weights_name.rsplit(".", 1)
+        weights_name = f"{path}.{variant}.{name}"
+    return weights_name
+
+
+def _get_resolved_checkpoint_files(
+    pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+    subfolder: str,
+    variant: Optional[str],
+    gguf_file: Optional[str],
+    from_tf: bool,
+    from_flax: bool,
+    use_safetensors: bool,
+    cache_dir: str,
+    force_download: bool,
+    proxies: Optional[dict[str, str]],
+    local_files_only: bool,
+    token: Optional[Union[str, bool]],
+    user_agent: dict,
+    revision: str,
+    commit_hash: Optional[str],
+    is_remote_code: bool,  # Because we can't determine this inside this function, we need it to be passed in
+    transformers_explicit_filename: Optional[str] = None,
+) -> tuple[Optional[list[str]], Optional[dict]]:
+    """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
+    checkpoints are sharded.
+    This function will download the data if necessary.
+    """
+    is_sharded = False
+
+    if pretrained_model_name_or_path is not None and gguf_file is None:
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if is_local:
+            if transformers_explicit_filename is not None:
+                # If the filename is explicitly defined, load this by default.
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
+                is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+            elif from_tf and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+            ):
+                # Load from a TF 1.0 checkpoint in priority if from_tf
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+            elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
+                # Load from a TF 2.0 checkpoint in priority if from_tf
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+            elif from_flax and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            ):
+                # Load from a Flax checkpoint in priority if from_flax
+                archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            elif use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
+            ):
+                # Load from a safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
+                )
+            elif use_safetensors is not False and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded safetensors checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
+                )
+                is_sharded = True
+            elif not use_safetensors and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
+            ):
+                # Load from a PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
+                )
+            elif not use_safetensors and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+            ):
+                # Load from a sharded PyTorch checkpoint
+                archive_file = os.path.join(
+                    pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+                )
+                is_sharded = True
+            # At this stage we don't have a weight file so we will raise an error.
+            elif not use_safetensors and (
+                os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
+                or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
+            ):
+                raise OSError(
+                    f"Error: no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                    f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
+                    " `from_tf=True` to load this model from those weights."
+                )
+            elif not use_safetensors and os.path.isfile(
+                os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+            ):
+                raise OSError(
+                    f"Error: no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
+                    f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
+                    " to load this model from those weights."
+                )
+            elif use_safetensors:
+                raise OSError(
+                    f"Error: no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
+                    f" {pretrained_model_name_or_path}."
+                )
+            else:
+                raise OSError(
+                    f"Error: no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
+                    f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
+                    f" {pretrained_model_name_or_path}."
+                )
+        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+            archive_file = pretrained_model_name_or_path
+            is_local = True
+        elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
+            if not from_tf:
+                raise ValueError(
+                    f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
+                    "from_tf to True to load from this checkpoint."
+                )
+            archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            filename = pretrained_model_name_or_path
+            resolved_archive_file = download_url(pretrained_model_name_or_path)
+        else:
+            # set correct filename
+            if transformers_explicit_filename is not None:
+                filename = transformers_explicit_filename
+                is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
+            elif from_tf:
+                filename = TF2_WEIGHTS_NAME
+            elif from_flax:
+                filename = FLAX_WEIGHTS_NAME
+            elif use_safetensors is not False:
+                filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
+            else:
+                filename = _add_variant(WEIGHTS_NAME, variant)
+
+            try:
+                # Load from URL or cache if already cached
+                cached_file_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "local_files_only": local_files_only,
+                    "token": token,
+                    "user_agent": user_agent,
+                    "revision": revision,
+                    "subfolder": subfolder,
+                    "_raise_exceptions_for_gated_repo": False,
+                    "_raise_exceptions_for_missing_entries": False,
+                    "_commit_hash": commit_hash,
+                }
+                resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+
+                # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
+                # result when internet is up, the repo and revision exist, but the file does not.
+                if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
+                    elif use_safetensors:
+                        if revision == "main":
+                            resolved_archive_file, revision, is_sharded = auto_conversion(
+                                pretrained_model_name_or_path, **cached_file_kwargs
+                            )
+                        cached_file_kwargs["revision"] = revision
+                        if resolved_archive_file is None:
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
+                                "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
+                                "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
+                            )
+                    else:
+                        # This repo has no safetensors file of any kind, we switch to PyTorch.
+                        filename = _add_variant(WEIGHTS_NAME, variant)
+                        resolved_archive_file = cached_file(
+                            pretrained_model_name_or_path, filename, **cached_file_kwargs
+                        )
+                if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
+                    # Maybe the checkpoint is sharded, we try to grab the index name in this case.
+                    resolved_archive_file = cached_file(
+                        pretrained_model_name_or_path,
+                        _add_variant(WEIGHTS_INDEX_NAME, variant),
+                        **cached_file_kwargs,
+                    )
+                    if resolved_archive_file is not None:
+                        is_sharded = True
+                if not local_files_only and not is_offline_mode():
+                    if resolved_archive_file is not None:
+                        if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
+                            # If the PyTorch file was found, check if there is a safetensors file on the repository
+                            # If there is no safetensors file in the repository, start an auto conversion
+                            safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                            has_file_kwargs = {
+                                "revision": revision,
+                                "proxies": proxies,
+                                "token": token,
+                                "cache_dir": cache_dir,
+                                "local_files_only": local_files_only,
+                            }
+                            cached_file_kwargs = {
+                                "cache_dir": cache_dir,
+                                "force_download": force_download,
+                                "local_files_only": local_files_only,
+                                "user_agent": user_agent,
+                                "subfolder": subfolder,
+                                "_raise_exceptions_for_gated_repo": False,
+                                "_raise_exceptions_for_missing_entries": False,
+                                "_commit_hash": commit_hash,
+                                **has_file_kwargs,
+                            }
+                            if (
+                                not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
+                                and not is_remote_code
+                            ):
+                                Thread(
+                                    target=auto_conversion,
+                                    args=(pretrained_model_name_or_path,),
+                                    kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                    name="Thread-auto_conversion",
+                                ).start()
+                    else:
+                        # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
+                        # We try those to give a helpful error message.
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                            "cache_dir": cache_dir,
+                            "local_files_only": local_files_only,
+                        }
+                        if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
+                                " Use `from_tf=True` to load this model from those weights."
+                            )
+                        elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
+                                " `from_flax=True` to load this model from those weights."
+                            )
+                        elif variant is not None and has_file(
+                            pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
+                        ):
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
+                                f" {variant}. Use `variant=None` to load this model from those weights."
+                            )
+                        else:
+                            raise OSError(
+                                f"{pretrained_model_name_or_path} does not appear to have a file named"
+                                f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
+                                f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+                            )
+
+            except OSError:
+                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
+                # to the original exception.
+                raise
+            except Exception as e:
+                # For any other exception, we throw a generic error.
+                raise OSError(
+                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
+                    " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
+                    f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
+                ) from e
+
+        if is_local:
+            logger.info(f"loading weights file {archive_file}")
+            resolved_archive_file = archive_file
+        else:
+            logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
+
+    elif gguf_file:
+        # Case 1: the GGUF file is present locally
+        if os.path.isfile(gguf_file):
+            resolved_archive_file = gguf_file
+        # Case 2: The GGUF path is a location on the Hub
+        # Load from URL or cache if already cached
+        else:
+            cached_file_kwargs = {
+                "cache_dir": cache_dir,
+                "force_download": force_download,
+                "proxies": proxies,
+                "local_files_only": local_files_only,
+                "token": token,
+                "user_agent": user_agent,
+                "revision": revision,
+                "subfolder": subfolder,
+                "_raise_exceptions_for_gated_repo": False,
+                "_raise_exceptions_for_missing_entries": False,
+                "_commit_hash": commit_hash,
+            }
+
+            resolved_archive_file = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
+
+    # We now download and resolve all checkpoint files if the checkpoint is sharded
+    sharded_metadata = None
+    if is_sharded:
+        checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
+            pretrained_model_name_or_path,
+            resolved_archive_file,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            user_agent=user_agent,
+            revision=revision,
+            subfolder=subfolder,
+            _commit_hash=commit_hash,
+        )
+    else:
+        checkpoint_files = [resolved_archive_file] if pretrained_model_name_or_path is not None else None
+
+    return checkpoint_files, sharded_metadata
+
+
+def _get_dtype(
+    cls,
+    dtype: Optional[Union[str, torch.dtype, dict]],
+    checkpoint_files: Optional[list[str]],
+    config: PretrainedConfig,
+    sharded_metadata: Optional[dict],
+    state_dict: Optional[dict],
+    weights_only: bool,
+) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
+    """Find the correct `dtype` to use based on provided arguments. Also update the `config` based on the
+    inferred dtype. We do the following:
+    1. If dtype is not None, we use that dtype
+    2. If dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+        weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype
+    We may also have `config.dtype` available, but we won't rely on it until v5.
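+
+    Illustrative sketch of the accepted `dtype` forms when loading a model (the checkpoint names are only
+    examples):
+
+    ```python
+    # a single dtype for the whole model
+    model = AutoModel.from_pretrained("google-bert/bert-base-cased", dtype=torch.bfloat16)
+    # infer the dtype from the first floating-point weight in the checkpoint
+    model = AutoModel.from_pretrained("google-bert/bert-base-cased", dtype="auto")
+    # a per-sub-config mapping, with "" as the fallback for modules outside any sub-config
+    model = AutoModel.from_pretrained("some/composite-model", dtype={"text_config": torch.bfloat16, "": torch.float32})
+    ```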
+    """
+    dtype_orig = None
+    is_sharded = sharded_metadata is not None
+
+    if dtype is not None:
+        if isinstance(dtype, str):
+            if dtype == "auto":
+                if hasattr(config, "dtype") and config.dtype is not None:
+                    dtype = config.dtype
+                    logger.info(f"Will use dtype={dtype} as defined in model's config object")
+                else:
+                    if is_sharded and "dtype" in sharded_metadata:
+                        dtype = sharded_metadata["dtype"]
+                    elif state_dict is not None:
+                        dtype = get_state_dict_dtype(state_dict)
+                    else:
+                        state_dict = load_state_dict(
+                            checkpoint_files[0], map_location="meta", weights_only=weights_only
+                        )
+                        dtype = get_state_dict_dtype(state_dict)
+                    logger.info(
+                        "Since the `dtype` attribute can't be found in model's config object, "
+                        "will use dtype={dtype} as derived from model's weights"
+                    )
+            elif hasattr(torch, dtype):
+                dtype = getattr(torch, dtype)
+                config.dtype = dtype
+                for sub_config_key in config.sub_configs:
+                    sub_config = getattr(config, sub_config_key)
+                    sub_config.dtype = dtype
+        elif isinstance(dtype, torch.dtype):
+            config.dtype = dtype
+            for sub_config_key in config.sub_configs:
+                sub_config = getattr(config, sub_config_key)
+                sub_config.dtype = dtype
+        elif isinstance(dtype, dict):
+            for key, curr_dtype in dtype.items():
+                if hasattr(config, key):
+                    value = getattr(config, key)
+                    curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
+                    value.dtype = curr_dtype
+            # main torch dtype for modules that aren't part of any sub-config
+            dtype = dtype.get("")
+            dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
+            config.dtype = dtype
+            if dtype is None:
+                dtype = torch.float32
+        else:
+            raise ValueError(
+                f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
+                f"for each sub-config in composite configs, but received {dtype}"
+            )
+
+        dtype_orig = cls._set_default_dtype(dtype)
+    else:
+        # set fp32 as the default dtype for BC
+        default_dtype = torch.get_default_dtype()
+        config.dtype = default_dtype
+        for key in config.sub_configs:
+            value = getattr(config, key)
+            value.dtype = default_dtype
+
+    return config, dtype, dtype_orig
+
+
+def _get_device_map(
+    model: "PreTrainedModel",
+    device_map: Optional[Union[dict, str]],
+    max_memory: Optional[dict],
+    hf_quantizer: Optional[HfQuantizer],
+    dtype: Optional[torch.dtype],
+    keep_in_fp32_regex: Optional[re.Pattern],
+) -> dict:
+    """Compute the final `device_map` to use if we passed a value in ['auto', 'balanced', 'balanced_low_0', 'sequential'].
+    Otherwise, we check for any device inconsistencies in the device_map.
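+
+    Illustrative sketch of typical `device_map` values passed by the caller (the checkpoint name is only
+    an example):
+
+    ```python
+    # let accelerate infer a placement across the available devices
+    model = AutoModel.from_pretrained("google-bert/bert-base-cased", device_map="auto")
+    # or pass an explicit mapping from module names to devices ("" means the whole model)
+    model = AutoModel.from_pretrained("google-bert/bert-base-cased", device_map={"": "cpu"})
+    ```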
+    """
+    if isinstance(device_map, str):
+        special_dtypes = {}
+        if hf_quantizer is not None:
+            special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, dtype))
+        if keep_in_fp32_regex is not None:
+            special_dtypes.update(
+                {name: torch.float32 for name, _ in model.named_parameters() if keep_in_fp32_regex.search(name)}
+            )
+
+        target_dtype = dtype
+
+        if hf_quantizer is not None:
+            target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+            device_map_kwargs["special_dtypes"] = special_dtypes
+        elif len(special_dtypes) > 0:
+            logger.warning(
+                "This model has some weights that should be kept in higher precision, you need to upgrade "
+                "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+            )
+
+        if device_map != "sequential":
+            inferred_max_memory = get_balanced_memory(
+                model,
+                dtype=target_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            inferred_max_memory = get_max_memory(max_memory)
+        if hf_quantizer is not None:
+            inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)
+
+        # `inferred_max_memory` contains non-reserved memory. There may be *unused* reserved memory in the GPU,
+        # which we can use to allocate parameters.
+        for device_name in inferred_max_memory:
+            if isinstance(device_name, int):  # it's a GPU device
+                if is_torch_xpu_available():
+                    unused_memory = torch.xpu.memory_reserved(device_name) - torch.xpu.memory_allocated(device_name)
+                else:
+                    unused_memory = torch.cuda.memory_reserved(device_name) - torch.cuda.memory_allocated(device_name)
+                inferred_max_memory[device_name] += unused_memory
+            # respect the `max_memory` passed by the user
+            if max_memory is not None and device_name in max_memory:
+                inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
+        device_map_kwargs["max_memory"] = inferred_max_memory
+
+        device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            hf_quantizer.validate_environment(device_map=device_map)
+
+    elif device_map is not None:
+        tied_params = find_tied_parameters(model)
+        # check if we don't have tied param in different devices
+        check_tied_parameters_on_same_device(tied_params, device_map)
+
+    return device_map
+
+
+def _find_missing_and_unexpected_keys(
+    cls,
+    model: "PreTrainedModel",
+    original_checkpoint_keys: list[str],
+    checkpoint_keys: list[str],
+    loading_base_model_from_task_state_dict: bool,
+    hf_quantizer: Optional[HfQuantizer],
+    device_map: dict,
+) -> tuple[list[str], list[str]]:
+    """Find missing keys (keys that are part of the model parameters but were NOT found in the loaded state dict keys) and unexpected keys
+    (keys found in the loaded state dict keys, but that are NOT part of the model parameters)
+    """
+    prefix = model.base_model_prefix
+
+    # Compute expected keys, i.e. keys that the FULL model (not model_to_load) expects
+    expected_keys = list(model.state_dict().keys())
+    if hf_quantizer is not None:
+        expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)
+
+    # Adjust prefix of the keys to make them match loaded keys before removing them
+    missing_keys = sorted(set(expected_keys) - set(checkpoint_keys))
+    unexpected_keys = set(checkpoint_keys) - set(expected_keys)
+    # If a module has the same name under the base and task specific model, we have to re-add it to unexpected keys
+    if loading_base_model_from_task_state_dict:
+        task_specific_keys = [k for k in original_checkpoint_keys if not k.startswith(f"{prefix}.")]
+        unexpected_keys.update(task_specific_keys)
+
+    # Remove nonpersistent buffers from unexpected keys: they are not in the expected keys (model state dict), but
+    # may be in the loaded keys. Note that removing all buffers does the job, as they were part of the expected keys anyway
+    model_buffers = {n for n, _ in model.named_buffers()}
+    unexpected_keys = sorted(unexpected_keys - model_buffers)
+
+    # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
+    # (so the buffer name has changed). Remove them in such a case
+    has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer in model_buffers)
+    if has_inv_freq_buffers:
+        unexpected_keys = [k for k in unexpected_keys if "rotary_emb.inv_freq" not in k]
+
+    tied_params = find_tied_parameters(model)
+    for group in tied_params:
+        missing_in_group = [k for k in missing_keys if k in group]
+        if len(missing_in_group) > 0 and len(missing_in_group) < len(group):
+            missing_keys = [k for k in missing_keys if k not in missing_in_group]
+
+    if hf_quantizer is not None:
+        missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
+        unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys, prefix)
+
+    # Model-specific exceptions for missing and unexpected keys (e.g. if the modeling change over time, or any other reason...)
+    if cls._keys_to_ignore_on_load_missing is not None:
+        for pattern in cls._keys_to_ignore_on_load_missing:
+            missing_keys = [k for k in missing_keys if re.search(pattern, k) is None]
+
+    if cls._keys_to_ignore_on_load_unexpected is not None:
+        for pattern in cls._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pattern, k) is None]
+
+    return missing_keys, unexpected_keys
+
+
+def _find_mismatched_keys(
+    model: "PreTrainedModel",
+    state_dict: Optional[dict],
+    checkpoint_files: Optional[list[str]],
+    ignore_mismatched_sizes: bool,
+    keys_to_rename_mapping: dict[str, str],
+    is_quantized: bool,
+    weights_only: bool,
+) -> tuple[list[str], list[tuple[int, int]]]:
+    """
+    Find potential shape mismatches between the different state dicts and the model parameters, but only if `ignore_mismatched_sizes`
+    is True. Otherwise, return immediately and any shape mismatch that may exist will be raised later on. This avoids checking
+    every parameter in advance, as shape mismatches are extremely rare in practice. If we do want to ignore them, however, we
+    need to check in advance, because we need to know which parameters to move back from meta to cpu and initialize
+    correctly. Indeed, as our model initialization takes place at the module level, and not the weight level, in the
+    case of a sharded checkpoint we cannot correctly initialize the weights according to `model._init_weights()` if we perform
+    this check on each state dict at loading time (after the first loaded checkpoint, there is no way to initialize only the
+    mismatched weights, if any, without also overwriting the previously loaded weights, because the whole module would be
+    initialized, not only the mismatched weights).
+    """
+
+    # An error will be raised later on anyway if there is a mismatch - this avoids running the rest of this function
+    # if there is no mismatch (which is almost always the case)
+    if not ignore_mismatched_sizes:
+        return [], []
+
+    if state_dict is not None:
+        checkpoint_files = [""]
+
+    model_state_dict = model.state_dict()
+    mismatched_keys = []
+    mismatched_shapes = []
+    for shard_file in checkpoint_files:
+        # If shard_file is "", we use the existing state_dict instead of loading it
+        if shard_file != "":
+            state_dict = load_state_dict(
+                shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
+            )
+
+        # Fix the key names
+        new_state_dict = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
+
+        for key, tensor in new_state_dict.items():
+            if key in model_state_dict and tensor.shape != model_state_dict[key].shape:
+                # This skips size mismatches for 4-bit weights. Two 4-bit values share an 8-bit container, causing size differences.
+                # Without matching on module type or parameter type, this seems like a practical way to detect valid 4-bit weights.
+                if not (
+                    is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel()
+                ):
+                    mismatched_keys.append(key)
+                    mismatched_shapes.append((tensor.shape, model_state_dict[key].shape))
+
+    return mismatched_keys, mismatched_shapes
+
+
+class PipelineParallel(Enum):
+    inputs = 0
+    outputs = 1
+
+
+class ModuleUtilsMixin:
+    """
+    A few utilities for `torch.nn.Modules`, to be used as a mixin.
+    """
+
+    @staticmethod
+    def _hook_rss_memory_pre_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except ImportError:
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_pre_forward = mem.rss
+        return None
+
+    @staticmethod
+    def _hook_rss_memory_post_forward(module, *args, **kwargs):
+        try:
+            import psutil
+        except ImportError:
+            raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.")
+
+        process = psutil.Process(os.getpid())
+        mem = process.memory_info()
+        module.mem_rss_post_forward = mem.rss
+        mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
+        module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0)
+        return None
+
+    def add_memory_hooks(self):
+        """
+        Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
+
+        Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero
+        with `model.reset_memory_hooks_state()`.
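+
+        Example (illustrative sketch, assuming `model` and `inputs` are already defined):
+
+        ```python
+        model.add_memory_hooks()  # requires `psutil`
+        outputs = model(**inputs)
+        rss_increase = {name: module.mem_rss_diff for name, module in model.named_modules()}
+        model.reset_memory_hooks_state()
+        ```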
+        """
+        for module in self.modules():
+            module.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
+            module.register_forward_hook(self._hook_rss_memory_post_forward)
+        self.reset_memory_hooks_state()
+
+    def reset_memory_hooks_state(self):
+        """
+        Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
+        """
+        for module in self.modules():
+            module.mem_rss_diff = 0
+            module.mem_rss_post_forward = 0
+            module.mem_rss_pre_forward = 0
+
+    @property
+    def device(self) -> torch.device:
+        """
+        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
+        device).
+        """
+        return get_parameter_device(self)
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        return get_parameter_dtype(self)
+
+    def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
+        """
+        Invert an attention mask (e.g., switches 0. and 1.).
+
+        Args:
+            encoder_attention_mask (`torch.Tensor`): An attention mask.
+
+        Returns:
+            `torch.Tensor`: The inverted attention mask.
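+
+        Example (illustrative sketch; a 1 marks a token to attend to):
+
+        ```python
+        mask = torch.tensor([[1, 1, 0]])
+        extended_mask = model.invert_attention_mask(mask)
+        # shape (1, 1, 1, 3): 0.0 for attended positions, the dtype's minimum for masked ones
+        ```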
+        """
+        if encoder_attention_mask.dim() == 3:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+        if encoder_attention_mask.dim() == 2:
+            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+        # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+        # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+        # /transformer/transformer_layers.py#L270
+        # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+        # encoder_extended_attention_mask.transpose(-1, -2))
+        encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
+
+        return encoder_extended_attention_mask
+
+    @staticmethod
+    def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
+        if device is not None:
+            warnings.warn(
+                "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+            )
+        else:
+            device = attention_mask.device
+        batch_size, seq_length = input_shape
+        seq_ids = torch.arange(seq_length, device=device)
+        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+        causal_mask = causal_mask.to(attention_mask.dtype)
+
+        if causal_mask.shape[1] < attention_mask.shape[1]:
+            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+            causal_mask = torch.cat(
+                [
+                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                    causal_mask,
+                ],
+                axis=-1,
+            )
+
+        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        return extended_attention_mask
+
+    def get_extended_attention_mask(
+        self, attention_mask: Tensor, input_shape: tuple[int], device: torch.device = None, dtype: torch.float = None
+    ) -> Tensor:
+        """
+        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+        Arguments:
+            attention_mask (`torch.Tensor`):
+                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+            input_shape (`tuple[int]`):
+                The shape of the input to the model.
+
+        Returns:
+            `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
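+
+        Example (illustrative sketch for an encoder receiving a 2D padding mask):
+
+        ```python
+        attention_mask = torch.tensor([[1, 1, 1, 0]])
+        extended_mask = model.get_extended_attention_mask(attention_mask, attention_mask.shape)
+        # shape (1, 1, 1, 4): 0.0 where tokens are attended, the dtype's minimum where they are masked
+        ```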
+        """
+        if dtype is None:
+            dtype = self.dtype
+
+        if not (attention_mask.dim() == 2 and self.config.is_decoder):
+            # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
+            if device is not None:
+                warnings.warn(
+                    "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
+                )
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        if attention_mask.dim() == 3:
+            extended_attention_mask = attention_mask[:, None, :, :]
+        elif attention_mask.dim() == 2:
+            # Provided a padding mask of dimensions [batch_size, seq_length]
+            # - if the model is a decoder, apply a causal mask in addition to the padding mask
+            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+            if self.config.is_decoder:
+                extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
+            else:
+                extended_attention_mask = attention_mask[:, None, None, :]
+        else:
+            raise ValueError(
+                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+            )
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and the dtype's smallest value for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
+        return extended_attention_mask
+
+    def get_head_mask(
+        self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
+    ) -> Tensor:
+        """
+        Prepare the head mask if needed.
+
+        Args:
+            head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
+                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
+            num_hidden_layers (`int`):
+                The number of hidden layers in the model.
+            is_attention_chunked (`bool`, *optional*, defaults to `False`):
+                Whether or not the attention scores are computed by chunks.
+
+        Returns:
+            `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
+            `[None]` for each layer.
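+
+        Example (illustrative sketch for a model with 12 layers and 12 heads):
+
+        ```python
+        head_mask = torch.ones(12)  # keep every head...
+        head_mask[0] = 0.0          # ...except the first one
+        head_mask = model.get_head_mask(head_mask, num_hidden_layers=12)
+        # head_mask now has shape [12 x 1 x 12 x 1 x 1], broadcastable over batch and sequence dimensions
+        ```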
+        """
+        if head_mask is not None:
+            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+            if is_attention_chunked is True:
+                head_mask = head_mask.unsqueeze(-1)
+        else:
+            head_mask = [None] * num_hidden_layers
+
+        return head_mask
+
+    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
+        return head_mask
+
+    def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
+        """
+        Get number of (optionally, trainable or non-embeddings) parameters in the module.
+
+        Args:
+            only_trainable (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of trainable parameters
+
+            exclude_embeddings (`bool`, *optional*, defaults to `False`):
+                Whether or not to return only the number of non-embeddings parameters
+
+        Returns:
+            `int`: The number of parameters.
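+
+        Example (illustrative sketch):
+
+        ```python
+        total = model.num_parameters()
+        trainable = model.num_parameters(only_trainable=True)
+        non_embedding = model.num_parameters(exclude_embeddings=True)
+        ```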
+        """
+
+        if exclude_embeddings:
+            embedding_param_names = [
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
+            ]
+            total_parameters = [
+                parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
+            ]
+        else:
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4-bit precision. Something went wrong;"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU."
+                )
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def estimate_tokens(self, input_dict: dict[str, Union[torch.Tensor, Any]]) -> int:
+        """
+        Helper function to estimate the total number of tokens from the model inputs.
+
+        Args:
+            input_dict (`dict`): The model inputs.
+
+        Returns:
+            `int`: The total number of tokens.
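+
+        Example (illustrative sketch, assuming the model's main input is `input_ids`):
+
+        ```python
+        input_ids = torch.ones(2, 128, dtype=torch.long)
+        model.estimate_tokens({"input_ids": input_ids})  # 2 * 128 = 256
+        ```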
+        """
+        if not hasattr(self, "warnings_issued"):
+            self.warnings_issued = {}
+        if self.main_input_name in input_dict:
+            return input_dict[self.main_input_name].numel()
+        elif "estimate_tokens" not in self.warnings_issued:
+            logger.warning(
+                "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
+            )
+            self.warnings_issued["estimate_tokens"] = True
+        return 0
+
+    def floating_point_ops(
+        self, input_dict: dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
+    ) -> int:
+        """
+        Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
+        batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
+        tokens (valid if `12 * d_model << sequence_length`) as laid out in [this
+        paper](https://huggingface.co/papers/2001.08361) section 2.1. Should be overridden for transformers with parameter
+        re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths.
+
+        Args:
+            input_dict (`dict[str, Union[torch.Tensor, Any]]`):
+                The model inputs; the number of tokens is estimated from the main model input (usually `input_ids`).
+
+            exclude_embeddings (`bool`, *optional*, defaults to `True`):
+                Whether or not to exclude embedding and softmax operations from the count.
+
+        Returns:
+            `int`: The number of floating-point operations.
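+
+        Example (illustrative sketch, assuming the model's main input is `input_ids`):
+
+        ```python
+        input_ids = torch.randint(0, 1000, (2, 128))
+        flops = model.floating_point_ops({"input_ids": input_ids})
+        # roughly 6 * (2 * 128) * model.num_parameters(exclude_embeddings=True)
+        ```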
+        """
+
+        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
+
+
+class EmbeddingAccessMixin:
+    """
+    Base utilities to regroup getters and setters for embeddings.
+    Introduces the `_input_embed_layer` attribute, which indicates
+    where the input embeddings come from and where they
+    should be set.
+    """
+
+    _input_embed_layer = "embed_tokens"  # default layer that holds input embeddings.
+
+    def get_input_embeddings(self) -> nn.Module:
+        """
+        Returns the model's input embeddings.
+
+        Returns:
+            `nn.Module`: A torch module mapping vocabulary to hidden states.
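+
+        Example (illustrative sketch):
+
+        ```python
+        embeddings = model.get_input_embeddings()
+        # for a typical `nn.Embedding`, the weight has shape (vocab_size, hidden_size)
+        vocab_size, hidden_size = embeddings.weight.shape
+        ```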
+        """
+
+        # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
+        #  for most NLP models), and if so, return it.
+
+        name = getattr(self, "_input_embed_layer", "embed_tokens")
+
+        if (default_embedding := getattr(self, name, None)) is not None:
+            return default_embedding
+        # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+
+        if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
+            return self.model.embed_tokens
+
+        # 3) vanilla decoder‑only architectures
+        elif hasattr(self, "embed_tokens"):
+            return self.embed_tokens
+        else:
+            base_model = getattr(self, "base_model_prefix", None)
+            if base_model is not None:
+                base_model = getattr(self, base_model, None)
+                if base_model is not None and base_model is not self:
+                    return base_model.get_input_embeddings()
+            raise NotImplementedError(
+                f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
+                "please override in the subclass."
+            )
+
+    def set_input_embeddings(self, value: nn.Module):
+        """Fallback setter that handles **~70 %** of models in the code‑base.
+
+        Order of attempts:
+        1. `self.model.embed_tokens`
+        2. `self.embed_tokens`
+        3. delegate to the *base model* if one exists
+        4. otherwise raise `NotImplementedError` so subclasses still can (and
+            should) override for exotic layouts.
+        """
+
+        # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
+        name = getattr(self, "_input_embed_layer", "embed_tokens")
+        if hasattr(self, "model") and hasattr(self.model, name):
+            setattr(self.model, name, value)
+        # 2) as well as vanilla decoder‑only architectures
+        elif hasattr(self, name):
+            setattr(self, name, value)
+        # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
+        elif getattr(self, self.base_model_prefix, self) is not self:
+            base_model = getattr(self, self.base_model_prefix, self)
+            base_model.set_input_embeddings(value)
+        else:
+            raise NotImplementedError(
+                f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
+            )
+
+    def get_output_embeddings(self):
+        if not hasattr(self, "lm_head"):
+            return None
+        try:
+            # Speech / vision backbones raise here, so we return None.
+            # Legit use of get_input_embs?
+            self.get_input_embeddings()
+        except NotImplementedError:
+            return None
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        """
+        Sets the model's output embeddings; by default, assigns `new_embeddings` to `lm_head`.
+        """
+        if getattr(self, "lm_head"):
+            self.lm_head = new_embeddings
+
+
+class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
+    r"""
+    Base class for all models.
+
+    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models as well as a few methods common to all models to:
+
+        - resize the input embeddings,
+        - prune heads in the self-attention heads.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
+          taking as arguments:
+
+            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+            - **config** ([`PretrainedConfig`]) -- An instance of the configuration associated to the model.
+            - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+        - **can_record_outputs** (dict):"""
+
+    config_class = None
+    base_model_prefix = ""
+    main_input_name = "input_ids"
+    model_tags = None
+
+    _checkpoint_conversion_mapping = {}  # used for BC support in VLMs, not meant to be used by new models
+
+    _auto_class = None
+    _no_split_modules = None
+    _skip_keys_device_placement = None
+
+    _keep_in_fp32_modules = None
+    # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
+    # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
+    _keep_in_fp32_modules_strict = None
+
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+    # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+    _keys_to_ignore_on_load_missing = None
+    # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+    # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+    # warnings.
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
+    # trained, but which are either deterministic or tied variables)
+    _keys_to_ignore_on_save = None
+    # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
+    _tied_weights_keys = None
+
+    is_parallelizable = False
+    supports_gradient_checkpointing = False
+    _is_stateful = False
+
+    # Flash Attention support
+    _supports_flash_attn = False
+
+    # SDPA support
+    _supports_sdpa = False
+
+    # Flex Attention support
+    _supports_flex_attn = False
+
+    _can_compile_fullgraph = False
+
+    # A tensor parallel plan to be applied to the model when TP is enabled. For
+    # top-level models, this attribute is currently defined in respective model
+    # code. For base models, this attribute comes from
+    # `config.base_model_tp_plan` during `__init__`.
+    # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
+    # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
+    # for example.
+    _tp_plan = None
+
+    # tensor parallel degree to which model is sharded to.
+    _tp_size = None
+
+    # A pipeline parallel plan specifying the layers which may not be present
+    # on all ranks when PP is enabled. For top-level models, this attribute is
+    # currently defined in respective model code. For base models, this
+    # attribute comes from `config.base_model_pp_plan` during `post_init`.
+    #
+    # The variable names for the inputs and outputs of the specified layers can
+    # be indexed using the `PipelineParallel` enum as follows:
+    # - `_pp_plan["layers"][PipelineParallel.inputs]`
+    # - `_pp_plan["layers"][PipelineParallel.outputs]`
+    _pp_plan = None
+
+    # This flag signals that the model can be used as an efficient backend in TGI and vLLM
+    # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
+    # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
+    _supports_attention_backend = False
+    _can_record_outputs = None
+
+    @property
+    @torch._dynamo.allow_in_graph
+    def can_record_outputs(self) -> dict[str, OutputRecorder]:
+        """
+         Maps output names (e.g., "attentions", "hidden_states")
+         to either:
+             - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
+                 * index=0 for "hidden_states"
+                 * index=1 for "attentions"
+             - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
+
+         Examples:
+             These two are equivalent:
+
+         ```python
+             _can_record_outputs = {
+                 "attentions": LlamaAttention,
+                 "hidden_states": LlamaDecoderLayer
+             }
+
+             _can_record_outputs = {
+                 "attentions": OutputRecorder(LlamaAttention, index=1),
+                 "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
+             }
+        ```
+
+         This means you can record outputs from the same class, by specifying a layer name. Before
+         collecting outputs, we check that they come from this layer.
+
+         If you have cross attentions that come from `LlamaAttention` and self attentions that also
+         come from `LlamaAttention` but from `self_attn`, you can do this:
+
+         ```python
+         class LlamaModel(PreTrainedModel):
+             _can_record_outputs = {
+                 "attentions": OutputRecorder(LlamaAttention, index=1, layer-name="self_attn"),
+                 "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
+             }
+
+        ```
+        """
+        return self._can_record_outputs or {}
+
+    @property
+    def dummy_inputs(self) -> dict[str, torch.Tensor]:
+        """
+        `dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
+        """
+        return {"input_ids": torch.tensor(DUMMY_INPUTS)}
+
+    @property
+    def framework(self) -> str:
+        """
+        `str`: Identifies that this is a PyTorch model.
+        """
+        return "pt"
+
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        # For BC we keep the original `config_class` definition in case
+        # there is a `config_class` attribute (e.g. remote code models),
+        # otherwise we derive it from the annotated `config` attribute.
+
+        # defined in this particular subclass
+        child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None)
+        child_attribute = cls.__dict__.get("config_class", None)
+
+        # defined in the class (this subclass or any parent class)
+        full_annotation = get_type_hints(cls).get("config", None)
+        full_attribute = cls.config_class
+
+        # priority (child class_config -> child annotation -> global class_config -> global annotation)
+        if child_attribute is not None:
+            cls.config_class = child_attribute
+        elif child_annotation is not None:
+            cls.config_class = child_annotation
+        elif full_attribute is not None:
+            cls.config_class = full_attribute
+        elif full_annotation is not None:
+            cls.config_class = full_annotation
+
+    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
+        super().__init__()
+        if not isinstance(config, PretrainedConfig):
+            raise TypeError(
+                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+                "`PretrainedConfig`. To create a model from a pretrained model use "
+                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        self.config = config
+
+        # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
+        # setting it recursively)
+        self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
+            self.config._attn_implementation, is_init_check=True
+        )
+
+        # for initialization of the loss
+        loss_type = self.__class__.__name__
+        if loss_type not in LOSS_MAPPING:
+            loss_groups = f"({'|'.join(LOSS_MAPPING)})"
+            loss_type = re.findall(loss_groups, self.__class__.__name__)
+            if len(loss_type) > 0:
+                loss_type = loss_type[0]
+            else:
+                loss_type = None
+        self.loss_type = loss_type
+
+        self.name_or_path = config.name_or_path
+        self.warnings_issued = {}
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        # Overwrite the class attribute to make it an instance attribute, so models like
+        # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
+        # when a different component (e.g. language_model) is used.
+        self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
+        self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
+
+        self._no_split_modules = self._no_split_modules or []
+        _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs  # added for executorch support only
+
+    def post_init(self):
+        """
+        A method executed at the end of each Transformer model initialization, to execute code that needs the model's
+        modules properly initialized (such as weight initialization).
+
+        This is also used when the user is running distributed code. We add hooks to the modules here, according to
+        the model's tp_plan!
+        """
+        self.init_weights()
+        self._backward_compatibility_gradient_checkpointing()
+
+        # Make sure the modules correctly exist if the flag is active
+        if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
+            all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
+            unique_module_names = set()
+            # Get all unique module names in the module graph, without the prefixes
+            for param in all_parameters:
+                unique_module_names.update(
+                    [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
+                )
+            # Check that every module in the keep_in_fp32 list is part of the module graph
+            if self._keep_in_fp32_modules is not None:
+                for module in self._keep_in_fp32_modules:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+            if self._keep_in_fp32_modules_strict is not None:
+                for module in self._keep_in_fp32_modules_strict:
+                    if module not in unique_module_names:
+                        raise ValueError(
+                            f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
+                            f" {self.__class__.__name__}"
+                        )
+
+        # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
+        self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
+        self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
+        self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
+        for name, module in self.named_children():
+            if plan := getattr(module, "_ep_plan", None):
+                self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
+            if plan := getattr(module, "_tp_plan", None):
+                self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
+            if plan := getattr(module, "_pp_plan", None):
+                self._pp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
+
+    @property
+    def tp_plan(self) -> dict[str, str]:
+        """
+        The full tp plan for the model's modules
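+
+        Illustrative sketch of the expected mapping (the layer names depend on the model):
+
+        ```python
+        model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
+        ```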
+        """
+        if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
+            return self._ep_plan
+        return self._tp_plan
+
+    @property
+    def pp_plan(self) -> dict[str, tuple[str, str]]:
+        return self._pp_plan
+
+    @tp_plan.setter
+    def tp_plan(self, plan: dict[str, str]):
+        if plan is not None:
+            # Validate that all parallel styles in the plan are supported
+            from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
+
+            for layer_pattern, parallel_style in plan.items():
+                if parallel_style not in ALL_PARALLEL_STYLES:
+                    raise ValueError(
+                        f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
+                        f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
+                    )
+
+            # Validate that the layer patterns match existing model structure
+            # We check this by getting all parameter names and seeing if any match the patterns
+            if hasattr(self, "named_parameters"):
+                model_param_names = [name for name, _ in self.named_parameters()]
+                if model_param_names:  # Only validate if model has parameters
+                    import re
+
+                    for layer_pattern in plan.keys():
+                        # Convert pattern to regex (replace * with \d+ to match layer indices)
+                        regex_pattern = layer_pattern.replace("*", r"\d+")
+                        pattern_matched = False
+                        for param_name in model_param_names:
+                            if re.match(regex_pattern, param_name):
+                                pattern_matched = True
+                                break
+                        if not pattern_matched:
+                            # Try more flexible matching - check if pattern components exist
+                            pattern_parts = layer_pattern.split(".")
+                            flexible_matched = False
+                            for param_name in model_param_names:
+                                param_parts = param_name.split(".")
+                                if len(pattern_parts) <= len(param_parts):
+                                    match_count = 0
+                                    for i, pattern_part in enumerate(pattern_parts):
+                                        if pattern_part == "*":
+                                            match_count += 1
+                                        elif i < len(param_parts) and pattern_part == param_parts[i]:
+                                            match_count += 1
+                                    if match_count == len(pattern_parts):
+                                        flexible_matched = True
+                                        break
+                            if not flexible_matched:
+                                import warnings
+
+                                warnings.warn(
+                                    f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
+                                    f"This rule may not be applied during tensor parallelization."
+                                )
+
+        self._tp_plan = plan if plan is not None else {}
+
+    @pp_plan.setter
+    def pp_plan(self, plan: dict[str, tuple[str, str]]):
+        self._pp_plan = plan
+
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that supports
+        dequantization.
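+
+        Example (illustrative; the checkpoint path is only a placeholder, any quantization method
+        that supports dequantization works the same way):
+
+        ```python
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained("path/to/quantized-checkpoint")
+        model = model.dequantize()
+        ```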
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
+    def _backward_compatibility_gradient_checkpointing(self):
+        if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
+            self.gradient_checkpointing_enable()
+            # Remove the attribute now that it has been consumed, so it's not saved in the config.
+            delattr(self.config, "gradient_checkpointing")
+
+    def add_model_tags(self, tags: Union[list[str], str]) -> None:
+        r"""
+        Add custom tags to the model that will be pushed to the Hugging Face Hub. Existing tags
+        in the model are not overwritten.
+
+        Args:
+            tags (`Union[list[str], str]`):
+                The desired tags to inject in the model
+
+        Examples:
+
+        ```python
+        from transformers import AutoModel
+
+        model = AutoModel.from_pretrained("google-bert/bert-base-cased")
+
+        model.add_model_tags(["custom", "custom-bert"])
+
+        # Push the model to your namespace with the name "my-custom-bert".
+        model.push_to_hub("my-custom-bert")
+        ```
+        """
+        if isinstance(tags, str):
+            tags = [tags]
+
+        if self.model_tags is None:
+            self.model_tags = []
+
+        for tag in tags:
+            if tag not in self.model_tags:
+                self.model_tags.append(tag)
+
+    @classmethod
+    @restore_default_dtype
+    def _from_config(cls, config, **kwargs):
+        """
+        All context managers that the model should be initialized under go here.
+
+        Args:
+            dtype (`torch.dtype`, *optional*):
+                Override the default `dtype` and load the model under this dtype.
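+
+        Example (illustrative; builds a freshly initialized model from a config, no pretrained
+        weights are loaded):
+
+        ```python
+        import torch
+        from transformers import AutoConfig, AutoModelForCausalLM
+
+        config = AutoConfig.from_pretrained("gpt2")
+        # `from_config` forwards its kwargs to `_from_config`
+        model = AutoModelForCausalLM.from_config(config, dtype=torch.bfloat16)
+        ```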
+        """
+        # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
+        # a warning is raised that dtype should be fp16. Since we never pass dtype from within
+        # modeling code, we can try to infer it here the same way as is done in `from_pretrained`
+        # For BC on the old `torch_dtype`
+        dtype = kwargs.pop("dtype", config.dtype)
+        if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # if both kwargs are provided, use `dtype`
+            dtype = dtype if dtype != config.dtype else torch_dtype
+        if isinstance(dtype, str):
+            dtype = getattr(torch, dtype)
+
+        # override default dtype if needed
+        dtype_orig = None
+        if dtype is not None:
+            dtype_orig = cls._set_default_dtype(dtype)
+
+        # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
+        if "attn_implementation" in kwargs:
+            config._attn_implementation = kwargs.pop("attn_implementation")
+
+        if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
+            logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+            # this immediately partitions the model across all gpus, to avoid the overhead in time
+            # and memory copying it on CPU or each GPU first
+            import deepspeed
+
+            init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
+            with ContextManagers(init_contexts):
+                model = cls(config, **kwargs)
+
+        else:
+            model = cls(config, **kwargs)
+
+        # restore default dtype if it was modified
+        if dtype_orig is not None:
+            torch.set_default_dtype(dtype_orig)
+
+        return model
+
+    @classmethod
+    def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
+        """
+        Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
+        under a specific dtype.
+
+        Args:
+            dtype (`torch.dtype`):
+                a floating point dtype to set as the default.
+
+        Returns:
+            `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
+            modified. If it wasn't, returns `None`.
+
+        Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
+        `torch.int64` is passed. So if a non-float `dtype` is passed, this function will raise an exception.
+        """
+        if not dtype.is_floating_point:
+            raise ValueError(
+                f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
+            )
+
+        logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
+        dtype_orig = torch.get_default_dtype()
+        torch.set_default_dtype(dtype)
+        return dtype_orig
+
+    @property
+    def base_model(self) -> nn.Module:
+        """
+        `torch.nn.Module`: The main body of the model.
+        """
+        return getattr(self, self.base_model_prefix, self)
+
+    @classmethod
+    def can_generate(cls) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.
+
+        Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
+        for instance, the model instance will have a populated `generation_config` attribute.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
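+
+        A minimal sketch of how a custom architecture opts in (the class name is hypothetical):
+
+        ```python
+        from transformers import GenerationMixin, PreTrainedModel
+
+        class MyModelForCausalLM(PreTrainedModel, GenerationMixin):  # GenerationMixin after PreTrainedModel
+            ...
+
+        assert MyModelForCausalLM.can_generate()
+        ```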
+        """
+        # Directly inherits `GenerationMixin` -> can generate
+        if "GenerationMixin" in str(cls.__bases__):
+            return True
+        # The class inherits from a class that can generate (recursive check) -> can generate
+        for base in cls.__bases__:
+            if not hasattr(base, "can_generate"):
+                continue
+            if "PreTrainedModel" not in str(base) and base.can_generate():
+                return True
+        # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # was how we detected whether a model could generate.
+        if hasattr(cls, "prepare_inputs_for_generation"):  # implicit: doesn't inherit `GenerationMixin`
+            logger.warning(
+                f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
+                "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
+                "to call `generate` and other related functions."
+                "\n  - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
+                "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
+                "\n  - If you are the owner of the model architecture code, please modify your model class such that "
+                "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
+                "\n  - If you are not the owner of the model architecture class, please contact the model code owner "
+                "to update it."
+            )
+        # Otherwise, can't generate
+        return False
+
+    def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
+        """
+        Check the availability of Flash Attention 2 for a given model.
+
+        Args:
+            is_init_check (`bool`, *optional*):
+                Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+                fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows raising proper exceptions
+                early, before instantiating the full model, if we know that it does not support the requested attention.
+        """
+        dtype = self.config.dtype
+
+        # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+        if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
+            raise ValueError(
+                f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
+                f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
+                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_2_available():
+            preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
+            install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
+
+            # the package `flash-attn` cannot be installed on Ascend NPU, so the following validation logic can be skipped.
+            if is_torch_npu_available():
+                logger.info("Detected FlashAttention2 on Ascend NPU; skipping the flash_attn package checks.")
+                return True
+
+            if importlib.util.find_spec("flash_attn") is None:
+                raise ImportError(f"{preface} the package flash_attn does not seem to be installed. {install_message}")
+            else:
+                # Check FA2 installed version compatibility
+                flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
+                if torch.version.cuda:
+                    if flash_attention_version < version.parse("2.1.0"):
+                        raise ImportError(
+                            f"{preface} you need the flash_attn package version to be greater than or equal to 2.1.0. Detected version {flash_attention_version}. {install_message}"
+                        )
+                    elif not torch.cuda.is_available():
+                        raise ValueError(
+                            f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
+                        )
+                    else:
+                        raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
+                elif torch.version.hip:
+                    if flash_attention_version < version.parse("2.0.4"):
+                        raise ImportError(
+                            f"{preface} you need the flash_attn package version to be greater than or equal to 2.0.4. Detected version {flash_attention_version}. {install_message}"
+                        )
+                    else:
+                        raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
+
+        if dtype is None:
+            logger.warning_once(
+                "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour"
+            )
+        elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
+            logger.warning_once(
+                "Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dtype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` context manager,"
+                ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`'
+            )
+
+        # With the early check, the parameters are not yet initialized correctly
+        if not is_init_check:
+            if getattr(self, "use_bettertransformer", False):
+                raise ValueError(
+                    "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+                )
+
+            param_devices = list({param.device for param in self.parameters()})
+            if len(param_devices) == 1 and param_devices[0].type == "cpu":
+                if torch.cuda.is_available():
+                    logger.warning_once(
+                        "You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU"
+                        " after initializing it on CPU with `model.to('cuda')`."
+                    )
+                elif is_torch_mlu_available():
+                    logger.warning_once(
+                        "You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU"
+                        " after initializing it on CPU with `model.to('mlu')`."
+                    )
+                else:
+                    raise ValueError(
+                        "You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. "
+                        "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+                        "or initialising the model on CPU and then moving it to GPU."
+                    )
+
+        # If no error was raised by this point, we can return `True`
+        return True
+
+    def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool:
+        """
+        Check the availability of Flash Attention 3 for a given model.
+
+        Args:
+            is_init_check (`bool`, *optional*):
+                Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+                fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows raising proper exceptions
+                early, before instantiating the full model, if we know that it does not support the requested attention.
+        """
+        dtype = self.config.dtype
+
+        if not self._supports_flash_attn:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where"
+                f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
+                " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+            )
+
+        if not is_flash_attn_3_available():
+            preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
+
+            if importlib.util.find_spec("flash_attn_3") is None:
+                raise ImportError(f"{preface} the package flash_attn_3 does not seem to be installed.")
+
+            if torch.cuda.is_available():
+                major, _ = torch.cuda.get_device_capability()
+                if major < 9:
+                    raise ValueError(
+                        f"{preface} Flash Attention 3 requires a GPU with compute capability >= 9.0, but the current device reports compute capability {torch.cuda.get_device_capability()}."
+                    )
+                else:
+                    raise ImportError(f"{preface} Flash Attention 3 is not available.")
+            else:
+                raise ValueError(
+                    f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
+                )
+
+        if dtype is None:
+            logger.warning_once(
+                "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
+            )
+        elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
+            logger.warning_once(
+                "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
+                f" the current dtype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` context manager,"
+                ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", dtype=torch.float16)`'
+            )
+
+        if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False):
+            raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
+
+        # Check for attention dropout, which is incompatible with FA3
+        if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0:
+            raise ValueError(
+                f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
+            )
+
+        # With the early check, the parameters are not yet initialized correctly
+        if not is_init_check:
+            param_devices = list({param.device for param in self.parameters()})
+            if len(param_devices) == 1 and param_devices[0].type == "cpu":
+                if torch.cuda.is_available():
+                    logger.warning_once(
+                        "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
+                        " after initializing it on CPU with `model.to('cuda')`."
+                    )
+                else:
+                    raise ValueError(
+                        "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
+                        "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+                        "or initialising the model on CPU and then moving it to GPU."
+                    )
+
+        return True
+
+    def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
+        """
+        Check the availability of SDPA for a given model.
+
+        Args:
+            is_init_check (`bool`, *optional*):
+                Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+                BetterTransformer, which are only available later after __init__. This allows raising proper exceptions
+                early, before instantiating the full model, if we know that it does not support the requested attention.
+                before instantiating the full models if we know that the model does not support the requested attention.
+        """
+        if not self._supports_sdpa:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
+                " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
+                ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+            )
+
+        if (
+            torch.version.hip is not None
+            and torch.cuda.device_count() > 1
+            and version.parse(torch.__version__) < version.parse("2.4.1")
+        ):
+            logger.warning_once(
+                "Using the `SDPA` attention implementation on a multi-GPU setup with ROCm may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
+            )
+            torch.backends.cuda.enable_flash_sdp(False)
+
+        if not is_init_check:
+            if getattr(self, "use_bettertransformer", False):
+                raise ValueError(
+                    "SDPA and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+                )
+
+        return True
+
+    def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
+        """
+        Check the availability of Flex Attention for a given model.
+
+        Args:
+            is_init_check (`bool`, *optional*):
+                Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+                BetterTransformer, which are only available later after __init__. This allows raising proper exceptions
+                early, before instantiating the full model, if we know that it does not support the requested attention.
+                before instantiating the full models if we know that the model does not support the requested attention.
+        """
+        if not self._supports_flex_attn:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention."
+                " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
+                " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
+                ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
+                ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
+            )
+        if not is_torch_flex_attn_available():
+            raise ImportError(
+                "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
+            )
+
+        if not is_init_check:
+            if getattr(self, "use_bettertransformer", False):
+                raise ValueError(
+                    "FlexAttention and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+                )
+
+        # If no error was raised by this point, we can return `True`
+        return True
+
+    def _check_and_adjust_attn_implementation(
+        self, attn_implementation: Optional[str], is_init_check: bool = False
+    ) -> str:
+        """
+        Check that the `attn_implementation` exists and is supported by the model, and try to fetch the kernel from the
+        Hub if it matches the hf kernels pattern.
+
+        Args:
+            attn_implementation (`str` or `None`):
+                The attention implementation to check for existence/validity.
+            is_init_check (`bool`, *optional*):
+                Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
+                fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
+                BetterTransformer, which are only available later after __init__. This allows raising proper exceptions
+                early, before instantiating the full model, if we know that it does not support the requested attention.
+
+        Returns:
+            `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
+            None to sdpa (to potentially eager).
+        """
+        applicable_attn_implementation = attn_implementation
+        # If FA not installed, do not fail but use kernels instead
+        if (
+            applicable_attn_implementation == "flash_attention_2"
+            and self._supports_flash_attn
+            and not is_flash_attn_2_available()
+            and is_kernels_available()
+        ):
+            applicable_attn_implementation = "kernels-community/flash-attn"
+        if is_kernel(applicable_attn_implementation):
+            try:
+                load_and_register_kernel(applicable_attn_implementation)
+                # log that we used kernel fallback if successful
+                if attn_implementation == "flash_attention_2":
+                    logger.warning_once(
+                        "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
+                        "library instead!"
+                    )
+            except Exception as e:
+                if attn_implementation == "flash_attention_2":
+                    self._flash_attn_2_can_dispatch()  # will fail as fa2 is not available but raise the proper exception
+                logger.warning_once(
+                    f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
+                    f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
+                )
+                try:
+                    self._sdpa_can_dispatch(is_init_check)
+                    applicable_attn_implementation = "sdpa"
+                except (ValueError, ImportError) as e:
+                    applicable_attn_implementation = "eager"
+        else:
+            applicable_attn_implementation = self.get_correct_attn_implementation(
+                applicable_attn_implementation, is_init_check
+            )
+            # preload flash attention here to allow compile with fullgraph
+            if applicable_attn_implementation.startswith("flash_attention"):
+                lazy_import_flash_attention(applicable_attn_implementation)
+
+        return applicable_attn_implementation
+
+    def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
+        applicable_attention = "sdpa" if requested_attention is None else requested_attention
+
+        if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
+            message = (
+                f'Specified `attn_implementation="{applicable_attention}"` is not supported. The only possible arguments are '
+                '`attn_implementation="eager"`'
+            )
+            # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
+            if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
+                message += ', `"attn_implementation=flash_attention_3"`, `"attn_implementation=flash_attention_2"`'
+            if self._supports_sdpa:
+                message += ', `"attn_implementation=sdpa"`'
+            if self._supports_flex_attn:
+                message += ', `"attn_implementation=flex_attention"`'
+            raise ValueError(message + ".")
+
+        # Perform relevant checks
+        if applicable_attention == "flash_attention_2":
+            self._flash_attn_2_can_dispatch(is_init_check)
+        elif applicable_attention == "flash_attention_3":
+            self._flash_attn_3_can_dispatch(is_init_check)
+        elif applicable_attention == "flex_attention":
+            self._flex_attn_can_dispatch(is_init_check)
+        elif applicable_attention == "sdpa":
+            # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
+            try:
+                self._sdpa_can_dispatch(is_init_check)
+            except (ValueError, ImportError) as e:
+                if requested_attention == "sdpa":
+                    raise e
+                applicable_attention = "eager"
+
+        return applicable_attention
+
+    @classmethod
+    def _can_set_attn_implementation(cls) -> bool:
+        """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
+        opening the file, but avoids maintaining yet another property flag.
+        """
+        class_file = sys.modules[cls.__module__].__file__
+        with open(class_file, "r") as f:
+            code = f.read()
+        # heuristic -> if we find those patterns, the model uses the correct interface
+        if re.search(r"class \w+Attention\(nn.Module\)", code):
+            return (
+                "eager_attention_forward" in code
+                and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
+            )
+        else:
+            # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
+            return True
+
+    def set_attn_implementation(self, attn_implementation: Union[str, dict]):
+        """
+        Set the requested `attn_implementation` for this model.
+
+        Args:
+            attn_implementation (`str` or `dict`):
+                The attention implementation to set for this model. It can be either a `str`, in which case it will be
+                dispatched to all submodels if relevant, or a `dict` whose keys are the sub-config names, in which case each
+                submodel will dispatch the corresponding value.
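+
+        Example (illustrative; the checkpoint name and the sub-config keys are only examples):
+
+        ```python
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        model.set_attn_implementation("sdpa")
+
+        # For composite models, a dict keyed by sub-config name dispatches per submodel, e.g.
+        # model.set_attn_implementation({"text_config": "sdpa", "vision_config": "eager"})
+        ```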
+        """
+        requested_implementation = (
+            attn_implementation
+            if not isinstance(attn_implementation, dict)
+            else attn_implementation.get("", self.config._attn_implementation)
+        )
+
+        # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
+        # warn the user that the requested value is not working
+        if requested_implementation != self.config._attn_implementation:
+            # In this case, raise
+            if not self._can_set_attn_implementation():
+                logger.warning(
+                    f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
+                    "does not follow the functional approach based on AttentionInterface "
+                    "(see https://huggingface.co/docs/transformers/en/attention_interface)"
+                )
+            else:
+                requested_implementation = self._check_and_adjust_attn_implementation(
+                    requested_implementation, is_init_check=False
+                )
+                # Apply the change (on the internal attr, to avoid setting it recursively)
+                self.config._attn_implementation_internal = requested_implementation
+
+        # Apply it to all submodels as well
+        for submodule in self.modules():
+            # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
+            # e.g. ForCausalLM has a Model inside, but no need to check it again)
+            if (
+                submodule is not self
+                and isinstance(submodule, PreTrainedModel)
+                and submodule.config.__class__ != self.config.__class__
+                # If it was already changed, no need to do it again
+                and not hasattr(submodule.config, "_attn_was_changed")
+            ):
+                # In this case, warn and skip
+                if not submodule._can_set_attn_implementation():
+                    logger.warning(
+                        f"{submodule.__class__.__name__} does not support setting its attention implementation dynamically, because it "
+                        "does not follow the functional approach based on AttentionInterface "
+                        "(see https://huggingface.co/docs/transformers/en/attention_interface)"
+                    )
+                # Set the attn on the submodule
+                else:
+                    sub_implementation = requested_implementation
+                    if isinstance(attn_implementation, dict):
+                        for subconfig_key in self.config.sub_configs:
+                            # We need to check for exact object match here, with `is`
+                            if getattr(self.config, subconfig_key) is submodule.config:
+                                sub_implementation = attn_implementation.get(
+                                    subconfig_key, submodule.config._attn_implementation
+                                )
+                                break
+                    # Check that the submodule can actually use it; this raises an error if the requested attention can't be set for the submodule
+                    sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
+                    submodule.config._attn_implementation_internal = sub_implementation
+
+                # Still add it as "changed" even if it was skipped, as we would otherwise try to set it in the dark afterwards
+                # We need to set it on the config itself, to differentiate 2 subconfigs of the same __class__ potentially
+                submodule.config._attn_was_changed = True
+
+        # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
+        for subconfig_key in self.config.sub_configs:
+            subconfig = getattr(self.config, subconfig_key)
+            sub_implementation = (
+                requested_implementation
+                if not isinstance(attn_implementation, dict)
+                else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
+            )
+            # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
+            if (
+                not hasattr(subconfig, "_attn_was_changed")
+                # If it's already the same, then no need to enter here and raise warnings
+                and sub_implementation != subconfig._attn_implementation
+            ):
+                if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
+                    raise ValueError(
+                        f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
+                        'The only possible arguments are "eager" (manual attention implementation) '
+                        f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
+                    )
+                subconfig._attn_implementation_internal = sub_implementation
+                logger.warning(
+                    f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
+                    "without finding the associated sub-model. For this reason we could not check if the model supports it. "
+                    "You may encounter undefined behavior."
+                )
+            # Unset the attribute in this case, to avoid issues in the future
+            else:
+                if hasattr(subconfig, "_attn_was_changed"):
+                    del subconfig._attn_was_changed
+
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
+        the model weights fixed.
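+
+        A typical (illustrative) pairing with gradient checkpointing when training adapters on top of
+        a frozen base model:
+
+        ```python
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+        model.gradient_checkpointing_enable()
+        model.enable_input_require_grads()
+        ```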
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            output.requires_grad_(True)
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
+    def disable_input_require_grads(self):
+        """
+        Removes the `_require_grads_hook`.
+        """
+        self._require_grads_hook.remove()
+
+    def get_decoder(self):
+        """
+        Best-effort lookup of the *decoder* module.
+
+        Order of attempts (covers ~85% of current usages):
+
+        1. `self.decoder`
+        2. `self.model`                       (many wrappers store the decoder here)
+        3. `self.model.get_decoder()`         (nested wrappers)
+        4. fallback: return `None` for the few exotic models that need a bespoke rule
+        """
+        if hasattr(self, "decoder"):
+            return self.decoder
+
+        if hasattr(self, "model"):
+            inner = self.model
+            if hasattr(inner, "get_decoder"):
+                return inner.get_decoder()
+            return inner
+
+        return None  # raise AttributeError(f"{self.__class__.__name__} has no decoder; override `get_decoder()` if needed.")
+
+    def set_decoder(self, decoder):
+        """
+        Symmetric setter. Mirrors the lookup logic used in `get_decoder`.
+        """
+
+        if hasattr(self, "decoder"):
+            self.decoder = decoder
+            return
+
+        if hasattr(self, "model"):
+            inner = self.model
+            if hasattr(inner, "set_decoder"):
+                inner.set_decoder(decoder)
+            else:
+                self.model = decoder
+            return
+
+        return  # raise AttributeError(f"{self.__class__.__name__} cannot accept a decoder; override `set_decoder()`.")
+
+    def _init_weights(self, module):
+        """
+        Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
+        initialization schemes, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
+        `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
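+
+        A minimal sketch of such an override in a derived class (the module and parameter names are
+        hypothetical):
+
+        ```python
+        def _init_weights(self, module):
+            super()._init_weights(module)
+            # hypothetical block holding an explicit nn.Parameter
+            if isinstance(module, MyBlockWithExtraParam):
+                module.extra_scale.data.fill_(1.0)
+        ```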
+        """
+        if hasattr(self.config, "initializer_range"):
+            std = self.config.initializer_range
+        else:
+            # 0.02 is the standard default value across the library
+            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
+
+        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.MultiheadAttention):
+            # This uses torch's original init
+            module._reset_parameters()
+        # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
+        # between modelings (because they are prefixed with the model name)
+        elif (
+            isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
+            or "LayerNorm" in module.__class__.__name__
+            or "RMSNorm" in module.__class__.__name__
+        ):
+            # Norms can exist without weights (in which case they are None from torch primitives)
+            if hasattr(module, "weight") and module.weight is not None:
+                module.weight.data.fill_(1.0)
+            if hasattr(module, "bias") and module.bias is not None:
+                module.bias.data.zero_()
+
+    def _initialize_weights(self, module):
+        """
+        Initialize the weights if they are not already initialized.
+        """
+        if getattr(module, "_is_hf_initialized", False):
+            return
+        self._init_weights(module)
+        module._is_hf_initialized = True
+
+    @torch.no_grad()
+    def initialize_weights(self):
+        """
+        This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
+        This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
+        module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
+        model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
+        is extremely error prone and inefficient.
+
+        Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
+        `torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
+        `module.weight.data.zero_()`.
+        """
+        if not hasattr(torch.nn.Module, "smart_apply"):
+            # This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjust the function
+            # to apply as we go down the graph
+            def smart_apply(self, fn):
+                for module in self.children():
+                    # We found a sub-model: recursively dispatch its own init function now!
+                    if isinstance(module, PreTrainedModel):
+                        module.smart_apply(module._initialize_weights)
+                    else:
+                        module.smart_apply(fn)
+                fn(self)
+                return self
+
+            torch.nn.Module.smart_apply = smart_apply
+
+        # Let the magic happen with this simple call
+        self.smart_apply(self._initialize_weights)
+
+    def tie_embeddings_and_encoder_decoder(self):
+        """
+        If set in the config, tie the weights between the input embeddings and the output embeddings,
+        and the encoder and decoder.
+
+        If the `torchscript` flag is set in the configuration, TorchScript can't handle parameter sharing, so we
+        clone the weights instead.
+        """
+        if getattr(self.config.get_text_config(decoder=True), "tie_word_embeddings", True):
+            output_embeddings = self.get_output_embeddings()
+            if output_embeddings is not None:
+                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+
+        if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
+            if hasattr(self, self.base_model_prefix):
+                self = getattr(self, self.base_model_prefix)
+            tied_weights = self._tie_encoder_decoder_weights(
+                self.encoder, self.decoder, self.base_model_prefix, "encoder"
+            )
+            # Setting a dynamic variable instead of `_tied_weights_keys` because the latter is a class
+            # attribute, not an instance member; modifying it would modify the entire class, leading to
+            # issues on subsequent calls by different tests or on repeated calls.
+            self._dynamic_tied_weights_keys = tied_weights
+
+    def tie_weights(self):
+        """
+        Recursively (for all submodels) tie all the weights of the model.
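+
+        Example (illustrative; assumes a config with `tie_word_embeddings=True`):
+
+        ```python
+        from transformers import AutoConfig, AutoModelForCausalLM
+
+        config = AutoConfig.from_pretrained("gpt2", tie_word_embeddings=True)
+        model = AutoModelForCausalLM.from_config(config)
+        model.tie_weights()
+        assert model.get_input_embeddings().weight is model.get_output_embeddings().weight
+        ```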
+        """
+        # Note that `self` is included in `self.modules` so we also apply to current PreTrainedModel with this call
+        for module in self.modules():
+            # If it's a PreTrainedModel, may need to tie the embeddings and/or encoder/decoder weights
+            if isinstance(module, PreTrainedModel):
+                module.tie_embeddings_and_encoder_decoder()
+            # Additionally, if it has a custom `_tie_weights`, honor it
+            if hasattr(module, "_tie_weights"):
+                module._tie_weights()
+
+    @staticmethod
+    def _tie_encoder_decoder_weights(
+        encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
+    ):
+        uninitialized_encoder_weights: list[str] = []
+        tied_weights: list[str] = []
+        if decoder.__class__ != encoder.__class__:
+            logger.info(
+                f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
+                " weights are correctly initialized."
+            )
+
+        def tie_encoder_to_decoder_recursively(
+            decoder_pointer: nn.Module,
+            encoder_pointer: nn.Module,
+            module_name: str,
+            base_encoder_name: str,
+            uninitialized_encoder_weights: list[str],
+            depth=0,
+            total_decoder_name="",
+            total_encoder_name="",
+        ):
+            assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), (
+                f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
+            )
+            if hasattr(decoder_pointer, "weight"):
+                assert hasattr(encoder_pointer, "weight")
+                encoder_pointer.weight = decoder_pointer.weight
+                tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
+                if hasattr(decoder_pointer, "bias"):
+                    assert hasattr(encoder_pointer, "bias")
+                    tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
+                    encoder_pointer.bias = decoder_pointer.bias
+                return
+
+            encoder_modules = encoder_pointer._modules
+            decoder_modules = decoder_pointer._modules
+            if len(decoder_modules) > 0:
+                assert len(encoder_modules) > 0, (
+                    f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
+                )
+
+                all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules}
+                encoder_layer_pos = 0
+                for name in decoder_modules:
+                    if name.isdigit():
+                        encoder_name = str(int(name) + encoder_layer_pos)
+                        decoder_name = name
+                        if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
+                            encoder_modules
+                        ) != len(decoder_modules):
+                            # this can happen if the name corresponds to the position in a module list of layers
+                            # in this case the decoder has added a cross-attention that the encoder does not have
+                            # thus skip this step and subtract one layer pos from encoder
+                            encoder_layer_pos -= 1
+                            continue
+                    elif name not in encoder_modules:
+                        continue
+                    elif depth > 500:
+                        raise ValueError(
+                            "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is"
+                            " a circular dependency between two or more `nn.Modules` of your model."
+                        )
+                    else:
+                        decoder_name = encoder_name = name
+                    tie_encoder_to_decoder_recursively(
+                        decoder_modules[decoder_name],
+                        encoder_modules[encoder_name],
+                        module_name + "/" + name,
+                        base_encoder_name,
+                        uninitialized_encoder_weights,
+                        depth=depth + 1,
+                        total_encoder_name=f"{total_encoder_name}.{encoder_name}",
+                        total_decoder_name=f"{total_decoder_name}.{decoder_name}",
+                    )
+                    all_encoder_weights.remove(module_name + "/" + encoder_name)
+
+                uninitialized_encoder_weights += list(all_encoder_weights)
+
+        # tie weights recursively
+        tie_encoder_to_decoder_recursively(
+            decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
+        )
+
+        if len(uninitialized_encoder_weights) > 0:
+            logger.warning(
+                f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
+            )
+        return tied_weights
+
+    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
+        """Tie or clone module weights depending on whether we are using TorchScript or not"""
+        if self.config.torchscript:
+            output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
+        else:
+            output_embeddings.weight = input_embeddings.weight
+
+        # Passing hooks over to the embeddings if needed
+        # (currently limited to tensor parallel hooks and flags only)
+        if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None):
+            output_embeddings._is_hooked = input_embeddings._is_hooked
+            output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan
+            output_embeddings._forward_hooks = input_embeddings._forward_hooks
+            output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks
+            original_repr = output_embeddings.__repr__
+            output_embeddings.__repr__ = (
+                lambda: f"{original_repr()}\nTP Plan: {output_embeddings._hf_tp_plan}"
+            )
+
+        if getattr(output_embeddings, "bias", None) is not None:
+            output_embeddings.bias.data = nn.functional.pad(
+                output_embeddings.bias.data,
+                (
+                    0,
+                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
+                ),
+                "constant",
+                0,
+            )
+        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
+            output_embeddings.out_features = input_embeddings.num_embeddings
+
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules to
+        get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `list[str]`: List of modules that should not be split
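+
+        A model class typically opts into `device_map` support by declaring the attribute on itself
+        (sketch, with a hypothetical decoder block name):
+
+        ```python
+        from transformers import PreTrainedModel
+
+        class MyPreTrainedModel(PreTrainedModel):
+            _no_split_modules = ["MyDecoderLayer"]
+        ```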
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, PreTrainedModel):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
+    def resize_token_embeddings(
+        self,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
+    ) -> nn.Embedding:
+        """
+        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+
+        Takes care of re-tying the embedding weights afterwards if the model class has a `tie_weights()` method.
+
+        Arguments:
+            new_num_tokens (`int`, *optional*):
+                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
+                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
+                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
+                old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
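+
+        Example (illustrative; the checkpoint name is only a placeholder):
+
+        ```python
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+        tokenizer.add_tokens(["<new_token>"])
+        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
+        ```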
+        """
+        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
+        if new_num_tokens is None and pad_to_multiple_of is None:
+            return model_embeds
+
+        # Since we are basically reusing the same old embeddings with new weight values, gathering is required
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
+                vocab_size = model_embeds.weight.shape[0]
+        else:
+            vocab_size = model_embeds.weight.shape[0]
+
+        # Update base model and current model config.
+        self.config.get_text_config().vocab_size = vocab_size
+        self.vocab_size = vocab_size
+
+        # Tie weights again if needed
+        self.tie_weights()
+
+        return model_embeds
+
+    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
+        old_embeddings = self.get_input_embeddings()
+        new_embeddings = self._get_resized_embeddings(
+            old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
+        )
+        if hasattr(old_embeddings, "_hf_hook"):
+            hook = old_embeddings._hf_hook
+            add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
+        self.set_input_embeddings(new_embeddings)
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+
+        # Update new_num_tokens with the actual size of new_embeddings
+        if pad_to_multiple_of is not None:
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
+                    new_num_tokens = new_embeddings.weight.shape[0]
+            else:
+                new_num_tokens = new_embeddings.weight.shape[0]
+
+        # if word embeddings are not tied, make sure that lm head is resized as well
+        if (
+            self.get_output_embeddings() is not None
+            and not self.config.get_text_config(decoder=True).tie_word_embeddings
+        ):
+            old_lm_head = self.get_output_embeddings()
+            if isinstance(old_lm_head, torch.nn.Embedding):
+                new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
+            else:
+                new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
+            if hasattr(old_lm_head, "_hf_hook"):
+                hook = old_lm_head._hf_hook
+                add_hook_to_module(new_lm_head, hook)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.requires_grad_(old_lm_head_requires_grad)
+            self.set_output_embeddings(new_lm_head)
+
+        return self.get_input_embeddings()
+
+    def _get_resized_embeddings(
+        self,
+        old_embeddings: nn.Embedding,
+        new_num_tokens: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        mean_resizing: bool = True,
+    ) -> nn.Embedding:
+        """
+        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_embeddings (`torch.nn.Embedding`):
+                Old embeddings to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the embedding matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `torch.nn.Embedding` module of the model without doing anything.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
+                `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
+                details about this, or help on choosing the correct value for resizing, refer to this guide:
+                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                because initializing the new embeddings with the old embeddings' mean keeps the generated tokens' probabilities
+                largely unaffected: it reduces the KL divergence between the next-token distribution before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+
+        Return:
+            `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
+            `new_num_tokens` is `None`
+        """
+
+        if pad_to_multiple_of is not None:
+            if not isinstance(pad_to_multiple_of, int):
+                raise ValueError(
+                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
+                )
+            if new_num_tokens is None:
+                new_num_tokens = old_embeddings.weight.shape[0]
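+            # Round new_num_tokens up to the nearest multiple of pad_to_multiple_of (e.g. 50257 -> 50304 for a multiple of 64).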
+            new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+        else:
+            logger.info(
+                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
+                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
+                " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
+                " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
+            )
+
+        if new_num_tokens is None:
+            return old_embeddings
+
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
+                old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+        else:
+            old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
+
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+            return old_embeddings
+
+        if not isinstance(old_embeddings, nn.Embedding):
+            raise TypeError(
+                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
+                " should either use a different resize function or make sure that `old_embeddings` are an instance of"
+                f" {nn.Embedding}."
+            )
+
+        # Build new embeddings
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
+            device=old_embeddings.weight.device,
+            dtype=old_embeddings.weight.dtype,
+        )
+
+        if new_num_tokens > old_num_tokens and not mean_resizing:
+            # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
+            self._init_weights(new_embeddings)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new embeddings  (in particular added tokens). The new embeddings will be initialized
+            # from a multivariate normal distribution that has old embeddings' mean and covariance.
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
+                "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`"
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
+                    self._init_added_embeddings_weights_with_mean(
+                        old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                    )
+            else:
+                self._init_added_embeddings_weights_with_mean(
+                    old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+                )
+
+        # Copy token embeddings from the previous weights
+
+        # numbers of tokens to copy
+        n = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+        else:
+            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
+
+        # Replace weights in old_embeddings and return to maintain the same embedding type.
+        # This ensures correct functionality when a Custom Embedding class is passed as input.
+        # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_embeddings.weight, new_embeddings.weight]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                old_embeddings.weight = new_embeddings.weight
+                old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
+
+                # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
+                # will be set to `None` in the resized embeddings.
+                if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
+                    old_embeddings.padding_idx = None
+        else:
+            old_embeddings.weight.data = new_embeddings.weight.data
+            old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
+            if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
+                old_embeddings.padding_idx = None
+
+        return old_embeddings
+
+    def _get_resized_lm_head(
+        self,
+        old_lm_head: nn.Linear,
+        new_num_tokens: Optional[int] = None,
+        transposed: Optional[bool] = False,
+        mean_resizing: bool = True,
+    ) -> nn.Linear:
+        """
+        Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
+        vectors at the end. Reducing the size will remove vectors from the end.
+
+        Args:
+            old_lm_head (`torch.nn.Linear`):
+                Old lm head linear layer to be resized.
+            new_num_tokens (`int`, *optional*):
+                New number of tokens in the linear matrix.
+
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+                `torch.nn.Linear` module of the model without doing anything.
+            transposed (`bool`, *optional*, defaults to `False`):
+                Whether `old_lm_head` is transposed or not. If `True`, `old_lm_head.size()` is `lm_head_dim,
+                vocab_size`, else `vocab_size, lm_head_dim`.
+            mean_resizing (`bool`):
+                Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
+                covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
+
+                Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
+                because initializing the new embeddings with the old embeddings' mean keeps the generated tokens' probabilities
+                largely unaffected: it reduces the KL divergence between the next-token distribution before and after adding the new embeddings.
+                Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+
+        Return:
+            `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
+            `None`
+        """
+
+        if new_num_tokens is None:
+            return old_lm_head
+
+        is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
+                old_num_tokens, old_lm_head_dim = (
+                    old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+                )
+        else:
+            old_num_tokens, old_lm_head_dim = (
+                old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+            )
+
+        if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
+            return old_lm_head
+
+        if not isinstance(old_lm_head, nn.Linear):
+            raise TypeError(
+                f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
+                " should either use a different resize function or make sure that `old_lm_head` are an instance of"
+                f" {nn.Linear}."
+            )
+
+        # Build new lm head
+        new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
+        has_new_lm_head_bias = old_lm_head.bias is not None
+
+        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+        # because the shape of the new embedding layer is used across various modeling files
+        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+        # to errors when training.
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=old_lm_head.weight.device,
+            dtype=old_lm_head.weight.dtype,
+        )
+
+        if new_num_tokens > old_num_tokens and not mean_resizing:
+            # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
+            self._init_weights(new_lm_head)
+
+        elif new_num_tokens > old_num_tokens and mean_resizing:
+            # initialize new lm_head weights (in particular added tokens). The new lm_head weights
+            # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
+            # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
+            logger.warning_once(
+                "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
+                "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
+                "To disable this, use `mean_resizing=False`"
+            )
+
+            added_num_tokens = new_num_tokens - old_num_tokens
+            if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
+                params = [old_lm_head.weight]
+                if has_new_lm_head_bias:
+                    params += [old_lm_head.bias]
+                with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
+                    self._init_added_lm_head_weights_with_mean(
+                        old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                    )
+                    if has_new_lm_head_bias:
+                        self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
+
+            else:
+                self._init_added_lm_head_weights_with_mean(
+                    old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
+                )
+                if has_new_lm_head_bias:
+                    self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
+            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
+                self._copy_lm_head_original_to_resized(
+                    new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+                )
+        else:
+            self._copy_lm_head_original_to_resized(
+                new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+            )
+
+        return new_lm_head
+
+    def _init_added_embeddings_weights_with_mean(
+        self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
+    ):
+        old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
+        mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
+        old_centered_embeddings = old_embeddings_weight - mean_embeddings
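+        # Empirical covariance of the old embedding rows (hidden_dim x hidden_dim), computed in float32.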
+        covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
+
+        # Check if the covariance is positive definite.
+        epsilon = 1e-9
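+        # Note that the covariance is scaled by epsilon both in the check below and when sampling, so new vectors stay very close to the mean.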
+        is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
+        if is_covariance_psd:
+            # If the covariance is positive definite, a distribution can be created and we can sample new weights from it.
+            distribution = torch.distributions.multivariate_normal.MultivariateNormal(
+                mean_embeddings, covariance_matrix=epsilon * covariance
+            )
+            new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
+                sample_shape=(added_num_tokens,)
+            ).to(old_embeddings.weight.dtype)
+        else:
+            # Otherwise, just initialize with the mean, because a distribution cannot be created.
+            new_embeddings.weight.data[-1 * added_num_tokens :, :] = (
+                mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype)
+            )
+
+    def _init_added_lm_head_weights_with_mean(
+        self,
+        old_lm_head,
+        new_lm_head,
+        old_lm_head_dim,
+        old_num_tokens,
+        added_num_tokens,
+        transposed=False,
+    ):
+        if transposed:
+            # Transpose to the desired shape for the function.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+            old_lm_head.weight.data = old_lm_head.weight.data.T
+
+        # The same initialization logic as Embeddings.
+        self._init_added_embeddings_weights_with_mean(
+            old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
+        )
+
+        if transposed:
+            # Transpose again to the correct shape.
+            new_lm_head.weight.data = new_lm_head.weight.data.T
+            old_lm_head.weight.data = old_lm_head.weight.data.T
+
+    def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
+        bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
+        bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
+        new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std)
+
+    def _copy_lm_head_original_to_resized(
+        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+    ):
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        else:
+            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+
+    def resize_position_embeddings(self, new_num_position_embeddings: int):
+        raise NotImplementedError(
+            f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
+            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
+        )
+
+    def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
+        raise NotImplementedError(
+            f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
+            f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
+        )
+
+    def init_weights(self):
+        """
+        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
+        initialization logic in `_init_weights`.
+        """
+        # Prune heads if needed
+        if self.config.pruned_heads:
+            self.prune_heads(self.config.pruned_heads)
+
+        if _init_weights:
+            # Initialize weights
+            self.initialize_weights()
+
+            # Tie weights should be skipped when not initializing all weights
+            # since from_pretrained(...) calls tie weights anyways
+            self.tie_weights()
+
+    def prune_heads(self, heads_to_prune: dict[int, list[int]]):
+        """
+        Prunes heads of the base model.
+
+        Arguments:
+            heads_to_prune (`dict[int, list[int]]`):
+                Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
+                to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
+                layer 1 and heads 2 and 3 on layer 2.
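+
+        Example (a minimal sketch of the call described above):
+
+        ```python
+        model.prune_heads({1: [0, 2], 2: [2, 3]})  # prunes heads 0 and 2 of layer 1 and heads 2 and 3 of layer 2
+        ```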
+        """
+        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
+        for layer, heads in heads_to_prune.items():
+            union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
+            self.config.pruned_heads[layer] = list(union_heads)  # Unfortunately we have to store it as list for JSON
+
+        self.base_model._prune_heads(heads_to_prune)
+
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+
+        We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
+        the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
+
+        Args:
+            gradient_checkpointing_kwargs (dict, *optional*):
+                Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
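+
+        Example (a minimal sketch; `use_reentrant` is a standard argument of `torch.utils.checkpoint.checkpoint`):
+
+        ```python
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        ```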
+        """
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
+
+        # For old GC format (transformers < 4.35.0) for models that live on the Hub
+        # we will fall back to the overwritten `_set_gradient_checkpointing` method
+        _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+
+        if not _is_using_old_format:
+            self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
+        else:
+            self.apply(partial(self._set_gradient_checkpointing, value=True))
+            logger.warning(
+                "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+            )
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only LoRA layers will have requires_grad set to True, but the output of frozen layers needs to propagate
+            # the gradients to make sure the gradient flows.
+            self.enable_input_require_grads()
+
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
+        is_gradient_checkpointing_set = False
+
+        # Apply it on the top-level module in case the top-level module supports it,
+        # for example, LongT5Stack inherits from `PreTrainedModel`.
+        if hasattr(self, "gradient_checkpointing"):
+            self._gradient_checkpointing_func = gradient_checkpointing_func
+            self.gradient_checkpointing = enable
+            is_gradient_checkpointing_set = True
+
+        for module in self.modules():
+            if hasattr(module, "gradient_checkpointing"):
+                module._gradient_checkpointing_func = gradient_checkpointing_func
+                module.gradient_checkpointing = enable
+                is_gradient_checkpointing_set = True
+
+        if not is_gradient_checkpointing_set:
+            raise ValueError(
+                f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
+                " `gradient_checkpointing` to modules of the model that uses checkpointing."
+            )
+
+    def gradient_checkpointing_disable(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self.supports_gradient_checkpointing:
+            # For old GC format (transformers < 4.35.0) for models that live on the Hub
+            # we will fall back to the overwritten `_set_gradient_checkpointing` method
+            _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
+            if not _is_using_old_format:
+                self._set_gradient_checkpointing(enable=False)
+            else:
+                logger.warning(
+                    "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
+                    "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
+                )
+                self.apply(partial(self._set_gradient_checkpointing, value=False))
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            self.disable_input_require_grads()
+
+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        state_dict: Optional[dict] = None,
+        save_function: Callable = torch.save,
+        push_to_hub: bool = False,
+        max_shard_size: Union[int, str] = "5GB",
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+        token: Optional[Union[str, bool]] = None,
+        save_peft_format: bool = True,
+        **kwargs,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~PreTrainedModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
+                TPUs) when this function needs to be called on all processes. In this case, set `is_main_process=True`
+                only on the main process to avoid race conditions.
+            state_dict (nested dictionary of `torch.Tensor`):
+                The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
+                save parts of the model or if special precautions need to be taken when recovering the state dictionary
+                of a model (like when using model parallelism).
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful in distributed training (e.g. on TPUs) when one
+                needs to replace `torch.save` with another method.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
+                The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
+                than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
+                We default it to 5GB so that models can run easily on free-tier Google Colab instances without CPU
+                OOM issues.
+
+                
+
+                If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+                which will be bigger than `max_shard_size`.
+
+                
+
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            save_peft_format (`bool`, *optional*, defaults to `True`):
+                For backward compatibility with the PEFT library, in case adapter weights are attached to the model, all
+                keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
+                disable this behaviour by setting `save_peft_format` to `False`.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
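+
+        Example (a minimal sketch; the directory name is illustrative):
+
+        ```python
+        model.save_pretrained("./my-finetuned-model", max_shard_size="2GB")
+        # The folder can later be reloaded with `from_pretrained("./my-finetuned-model")`.
+        ```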
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)
+
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        quantization_serializable = (
+            hf_quantizer is not None
+            and isinstance(hf_quantizer, HfQuantizer)
+            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
+        )
+
+        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+            raise ValueError(
+                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                " the logger on the traceback to understand the reason why the quantized model is not serializable."
+            )
+
+        if "save_config" in kwargs:
+            warnings.warn(
+                "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
+            )
+            is_main_process = kwargs.pop("save_config")
+        if safe_serialization and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
+
+        # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
+        if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
+            raise ImportError(
+                "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
+            )
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            create_pr = kwargs.pop("create_pr", False)
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        if hf_quantizer is not None:
+            state_dict = hf_quantizer.get_state_dict(self)
+        # Only save the model itself if we are using distributed training
+        model_to_save = unwrap_model(self)
+        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+        # we currently don't use this setting automatically, but may start to use with v5
+        dtype = get_parameter_dtype(model_to_save)
+        model_to_save.config.dtype = str(dtype).split(".")[1]
+
+        # Attach architecture to the config
+        model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self.config)
+
+        # Save the config
+        if is_main_process:
+            if not _hf_peft_config_loaded:
+                # If the model config has set attributes that should be in the generation config, move them there.
+                misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters()
+                if self.can_generate() and len(misplaced_generation_parameters) > 0:
+                    warnings.warn(
+                        "Moving the following attributes in the config to the generation config: "
+                        f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
+                        "generation parameters in the model config, as opposed to in the generation config.",
+                        UserWarning,
+                    )
+                    for param_name, param_value in misplaced_generation_parameters.items():
+                        setattr(model_to_save.generation_config, param_name, param_value)
+                        setattr(model_to_save.config, param_name, None)
+
+                model_to_save.config.save_pretrained(save_directory)
+            if self.can_generate():
+                model_to_save.generation_config.save_pretrained(save_directory)
+
+            if _hf_peft_config_loaded:
+                logger.info(
+                    "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
+                )
+                state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)
+
+                if save_peft_format:
+                    logger.info(
+                        "To match the expected format of the PEFT library, all keys of the state dict of adapters will be prepended with `base_model.model`."
+                    )
+                    peft_state_dict = {}
+                    for key, value in state_dict.items():
+                        peft_state_dict[f"base_model.model.{key}"] = value
+                    state_dict = peft_state_dict
+
+                active_adapter = self.active_adapters()
+
+                if len(active_adapter) > 1:
+                    raise ValueError(
+                        "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
+                        "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
+                    )
+                active_adapter = active_adapter[0]
+
+                current_peft_config = self.peft_config[active_adapter]
+                current_peft_config.save_pretrained(save_directory)
+
+        # for offloaded modules
+        module_map = {}
+
+        # Save the model
+        if state_dict is None:
+            # if any model parameters are offloaded, make module map
+            if (
+                hasattr(self, "hf_device_map")
+                and len(set(self.hf_device_map.values())) > 1
+                and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
+            ):
+                warnings.warn(
+                    "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
+                )
+                for name, module in model_to_save.named_modules():
+                    if name == "":
+                        continue
+                    module_state_dict = module.state_dict()
+
+                    for key in module_state_dict:
+                        module_map[name + f".{key}"] = module
+            state_dict = model_to_save.state_dict()
+
+        if any(
+            allowed_name in class_name.__name__.lower()
+            for class_name in self.__class__.__mro__[:-1]
+            for allowed_name in VLMS
+        ):
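+            # Undo the load-time checkpoint key remapping so that VLM weights are saved under their original key names.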
+            reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}
+
+            original_state_dict = {}
+            for key, value in state_dict.items():
+                for pattern, replacement in reverse_key_mapping.items():
+                    replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
+                    replacement = re.sub(r"\(.*\)", "", replacement)
+                    key, n_replace = re.subn(pattern, replacement, key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        break
+                original_state_dict[key] = value
+            state_dict = original_state_dict
+
+        # Translate state_dict from smp to hf if saving with smp >= 1.10
+        if IS_SAGEMAKER_MP_POST_1_10:
+            for smp_to_hf, _ in smp.state.module_manager.translate_functions:
+                state_dict = smp_to_hf(state_dict)
+
+        # Handle the case where some state_dict keys shouldn't be saved
+        if self._keys_to_ignore_on_save is not None:
+            for ignore_key in self._keys_to_ignore_on_save:
+                if ignore_key in state_dict:
+                    del state_dict[ignore_key]
+
+        # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
+        # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
+        state_dict = self._fix_state_dict_keys_on_save(state_dict)
+        # If the model was sharded, we cannot properly determine the sizes of tensors for which a `local_*` strategy was used,
+        # therefore we replace them with DTensors that are equivalently sharded
+        if self._tp_size is not None:
+            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
+
+        if safe_serialization:
+            # TODO: fix safe_serialization for tied weights
+            # Safetensors does not allow tensor aliasing.
+            # We're going to remove aliases before saving
+            ptrs = collections.defaultdict(list)
+            for name, tensor in state_dict.items():
+                if not isinstance(tensor, torch.Tensor):
+                    # Sometimes in the state_dict we have non-tensor objects.
+                    # e.g. in bitsandbytes we have some `str` objects in the state_dict
+                    # In the non-tensor case, fall back to the pointer of the object itself
+                    ptrs[id(tensor)].append(name)
+
+                elif tensor.device.type == "meta":
+                    # In offloaded cases, there may be meta tensors in the state_dict.
+                    # For these cases, key by the pointer of the original tensor object
+                    # (state_dict tensors are detached and therefore no longer shared)
+                    tensor = self.get_parameter(name)
+                    ptrs[id(tensor)].append(name)
+
+                else:
+                    ptrs[id_tensor_storage(tensor)].append(name)
+
+            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+
+            # Recursively descend to find tied weight keys
+            _tied_weights_keys = _get_tied_weight_keys(self)
+            error_names = []
+            to_delete_names = set()
+            for names in shared_ptrs.values():
+                # Removing the keys which are declared as known duplicates on
+                # load. This makes sure that the name which is kept is consistent.
+                if _tied_weights_keys is not None:
+                    found = 0
+                    for name in sorted(names):
+                        matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
+                        if matches_pattern and name in state_dict:
+                            found += 1
+                            if found < len(names):
+                                to_delete_names.add(name)
+            # We are entering a place where the weights and the transformers configuration do NOT match.
+            shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
+            # Those are actually tensor sharing but disjoint from each other, we can safely clone them
+            # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
+            for name in disjoint_names:
+                state_dict[name] = state_dict[name].clone()
+
+            # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
+            # If the link between tensors was done at runtime then `from_pretrained` will not get
+            # the key back leading to random tensor. A proper warning will be shown
+            # during reload (if applicable), but since the file is not necessarily compatible with
+            # the config, better show a proper warning.
+            shared_names, identical_names = _find_identical(shared_names, state_dict)
+            # delete tensors that have identical storage
+            for inames in identical_names:
+                known = inames.intersection(to_delete_names)
+                for name in known:
+                    del state_dict[name]
+                unknown = inames.difference(to_delete_names)
+                if len(unknown) > 1:
+                    error_names.append(unknown)
+
+            if shared_names:
+                error_names.extend(shared_names)
+
+            if len(error_names) > 0:
+                raise RuntimeError(
+                    f"The weights trying to be saved contained shared tensors {error_names} that are mismatching "
+                    "the transformers base configuration. Try saving using `safe_serialization=False`, setting the "
+                    "`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.",
+                )
+
+        # Shard the model if it is too big.
+        if not _hf_peft_config_loaded:
+            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = _add_variant(weights_name, variant)
+        else:
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME
+
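+        # "{suffix}" is filled in by the sharding utility (e.g. "-00001-of-00005") when the checkpoint is split into shards.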
+        filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
+        )
+        # Save index if sharded
+        index = None
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": {"total_parameters": self.num_parameters(), **state_dict_split.metadata},
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_directory):
+            full_filename = os.path.join(save_directory, filename)
+            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+            # in distributed settings to avoid race conditions.
+            weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(weights_no_suffix)
+                and os.path.isfile(full_filename)
+                and filename not in state_dict_split.filename_to_tensors
+                and is_main_process
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+        # Save the model
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        if module_map:
+            filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
+        for shard_file, tensors in filename_to_tensors:
+            shard = {}
+            for tensor in tensors:
+                if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
+                    full_tensor = state_dict[tensor].full_tensor()
+                    # to get the correctly ordered tensor we need to repack if packed
+                    if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",):
+                        full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
+                    shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
+                else:
+                    shard[tensor] = state_dict[tensor].contiguous()
+                # delete reference, see https://github.com/huggingface/transformers/pull/34890
+                del state_dict[tensor]
+
+            # remake shard with onloaded parameters if necessary
+            if module_map:
+                if accelerate_version < version.parse("0.31"):
+                    raise ImportError(
+                        f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
+                        f"Please upgrade accelerate with `pip install -U accelerate`"
+                    )
+                # init state_dict for this shard
+                shard_state_dict = dict.fromkeys(shard, "")
+                for module_name in shard:
+                    # note that get_state_dict_from_offload can update with meta tensors
+                    # if both a parent module and its descendant are offloaded
+                    tensor = shard_state_dict[module_name]
+                    if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
+                        # update state dict with onloaded parameters
+                        module = module_map[module_name]
+                        shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)
+
+                # assign shard to be the completed state dict
+                shard = shard_state_dict
+                del shard_state_dict
+                gc.collect()
+
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+            else:
+                save_function(shard, os.path.join(save_directory, shard_file))
+
+        del state_dict
+
+        if index is None:
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
+        else:
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )
+
+        if push_to_hub:
+            # Eventually create an empty model card
+            model_card = create_and_tag_model_card(
+                repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
+            )
+
+            # Update model card if needed:
+            model_card.save(os.path.join(save_directory, "README.md"))
+
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+                create_pr=create_pr,
+            )
+
+    @wraps(PushToHubMixin.push_to_hub)
+    def push_to_hub(self, *args, **kwargs):
+        tags = self.model_tags if self.model_tags is not None else []
+
+        tags_kwargs = kwargs.get("tags", [])
+        if isinstance(tags_kwargs, str):
+            tags_kwargs = [tags_kwargs]
+
+        for tag in tags_kwargs:
+            if tag not in tags:
+                tags.append(tag)
+
+        if tags:
+            kwargs["tags"] = tags
+        return super().push_to_hub(*args, **kwargs)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired by the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and are not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
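+
+        Example (a minimal sketch):
+
+        ```python
+        footprint_mib = model.get_memory_footprint() / 1024**2  # bytes -> MiB
+        ```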
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
+
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute,
+            # it's necessary to manually call the `cuda` method on HQQLinear layers.
+            super().cuda(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if len(args) > 0:
+                        device = args[0]
+                    else:
+                        device = kwargs.get("device", "cuda")
+                    module.cuda(device)
+            return self
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+        # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+            from hqq.core.quantize import HQQLinear
+
+            # Since HQQLinear stores some tensors in the 'meta' attribute, we must
+            # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
+            super().to(*args, **kwargs)
+            for module in self.modules():
+                if isinstance(module, HQQLinear):
+                    if "device" in kwargs:
+                        device = kwargs["device"]
+                    else:
+                        device = args[0]
+                    if "dtype" in kwargs:
+                        dtype = kwargs["dtype"]
+                    elif dtype_present_in_args:
+                        dtype = arg
+                    else:
+                        dtype = None
+                    # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
+                    # followed by calling the `cuda` method achieves the intended behavior of `to`,
+                    # even when the target device is CPU.
+                    if dtype is not None:
+                        module.compute_dtype = dtype
+                    module.cuda(device)
+            return self
+
+        if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
+            raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model to a new `dtype`. Make sure to load the model with"
+                    " `from_pretrained`, passing the desired `dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and cast to the correct `dtype`."
+                )
+            elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a GPTQ model to a new `dtype`. Make sure to load the model with"
+                    " `from_pretrained`, passing the desired `dtype` argument."
+                )
+        return super().to(*args, **kwargs)
+
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized models. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
+    @classmethod
+    def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
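+        """
+        Return the list of context managers to enter when instantiating the model skeleton, depending on the
+        quantization and DeepSpeed ZeRO-3 state (e.g. `no_init_weights`, `init_empty_weights`, `deepspeed.zero.Init`).
+        """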
+        if is_deepspeed_zero3_enabled():
+            import deepspeed
+
+            init_contexts = [no_init_weights()]
+            # We cannot initialize the model on meta device with deepspeed when not quantized
+            if not is_quantized and not _is_ds_init_called:
+                logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
+                init_contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
+            elif is_quantized:
+                init_contexts.extend([init_empty_weights(), set_quantized_state()])
+        else:
+            init_contexts = [no_init_weights(), init_empty_weights()]
+
+        return init_contexts
+
+    @classmethod
+    @restore_default_dtype
+    def from_pretrained(
+        cls: type[SpecificPreTrainedModelType],
+        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+        *model_args,
+        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        ignore_mismatched_sizes: bool = False,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        use_safetensors: Optional[bool] = None,
+        weights_only: bool = True,
+        **kwargs,
+    ) -> SpecificPreTrainedModelType:
+        r"""
+        Instantiate a pretrained pytorch model from a pre-trained model configuration.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                    - A path to a *directory* containing model weights saved using
+                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
+                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
+                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
+                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
+                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
+                    - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
+                      `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
+                      `True`.
+                    - `None` if you are both providing the configuration and state dictionary (resp. with keyword
+                      arguments `config` and `state_dict`).
+            model_args (sequence of positional arguments, *optional*):
+                All remaining positional arguments will be passed to the underlying model's `__init__` method.
+            config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+                Can be either:
+
+                    - an instance of a class derived from [`PretrainedConfig`],
+                    - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+
+                Configuration for the model to use instead of an automatically loaded configuration. Configuration can
+                be automatically loaded when:
+
+                    - The model is a model provided by the library (loaded with the *model id* string of a pretrained
+                      model).
+                    - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
+                      save directory.
+                    - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
+                      configuration JSON file named *config.json* is found in the directory.
+            state_dict (`dict[str, torch.Tensor]`, *optional*):
+                A state dictionary to use instead of a state dictionary loaded from saved weights file.
+
+                This option can be used if you want to create a model from a pretrained configuration but load your own
+                weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
+                [`~PreTrainedModel.from_pretrained`] is not a simpler option.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            from_tf (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a TensorFlow checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+                checkpoint with 3 labels).
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+                <Tip>
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
+
+                </Tip>
+            attn_implementation (`str`, *optional*):
+                The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
+
+                Accept HF kernel references in the form:
+                  <org>/<model>[@<revision>][:<kernel_name>]
+
+                - <org> and <model> are any non-"/" and non-":" sequences.
+                - "@" is optional (branch, tag, or commit-ish), e.g. "@main", "@v1.2.0", "@abc123".
+                - ":" is optional and selects a function inside the kernel repo.
+                - Both options can appear together and in this order only: @revision first, then :kernel_name.
+                - We intentionally allow a leading "|" prefix (e.g., "flash|...") because the code
+                  strips it before loading; '|' is not excluded in the character classes here.
+
+                Examples that match:
+                  "org/model"
+                  "org/model@main"
+                  "org/model:custom_kernel"
+                  "org/model@v1.2.3:custom_kernel"
+
+            > Parameters for big model inference
+
+            dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
+                are:
+
+                1. `torch.float16`, `torch.bfloat16` or `torch.float`: load in the specified
+                  `dtype`, ignoring the model's `config.dtype` if one exists. If not specified, the model is
+                  loaded in `torch.float` (fp32).
+
+                2. `"auto"` - the `dtype` or `torch_dtype` entry in the model's `config.json` file is used if
+                  present. If that entry isn't found, the `dtype` of the first floating-point weight in the
+                  checkpoint is used instead. This loads the model in the `dtype` it was saved in at the end of
+                  training, which is not necessarily the `dtype` it was trained in: a model may be trained in a
+                  half-precision dtype but saved in fp32.
+
+                3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
+
+                <Tip>
+
+                For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
+                reach out to the authors and ask them to add this information to the model's card and to insert the
+                `dtype` or `torch_dtype` entry in `config.json` on the hub.
+
+                </Tip>
+
+            device_map (`str` or `dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is included, every submodule of it will be sent to the
+                same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
+                like `1`) on which the model will be allocated, the device map will map the entire model to this
+                device. Passing `device_map = 0` means put the whole model on GPU 0.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory, used together with `device_map`. Will
+                default to the maximum memory available for each GPU and the available CPU RAM if unset.
+            tp_plan (`str`, *optional*):
+                A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
+                `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
+                `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
+            tp_size (`int`, *optional*):
+                The torch tensor parallel degree. If not provided, defaults to the world size.
+            device_mesh (`torch.distributed.DeviceMesh`, *optional*):
+                A torch device mesh, currently only used for tensor parallelism. If not provided, one is created from
+                the world size. If it has more than one dimension, it must contain a dimension named `"tp"`, which
+                will be used for tensor parallelism.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
+                RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
+                `True` when there is some disk offload.
+            offload_buffers (`bool`, *optional*):
+                Whether or not to offload the buffers with the model parameters.
+            quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
+                A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g.
+                bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
+                `load_in_8bit`, which are parsed by QuantizationConfigParser. These are supported only for bitsandbytes
+                quantization and are not preferred; consider passing all such arguments through `quantization_config`
+                instead.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+            variant (`str`, *optional*):
+                If specified, load weights from the `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant`
+                is ignored when using `from_tf` or `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
+                is not installed, it will be set to `False`.
+            weights_only (`bool`, *optional*, defaults to `True`):
+                Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via torch.serialization.add_safe_globals().
+                When set to False, we can load wrapper tensor subclass weights.
+            key_mapping (`dict[str, str]`, *optional*):
+                A potential mapping of the weight names if using a model on the Hub which is compatible with a
+                Transformers architecture, but was not converted accordingly.
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
+                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
+                      already been done)
+                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+                      corresponds to a configuration attribute will be used to override said attribute with the
+                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+                      will be passed to the underlying model's `__init__` function.
+
+        <Tip>
+
+        Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
+        use this method in a firewalled environment.
+
+        </Tip>
+
+        Examples:
+
+        ```python
+        >>> from transformers import BertConfig, BertModel
+
+        >>> # Download model and configuration from huggingface.co and cache.
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+        >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
+        >>> model = BertModel.from_pretrained("./test/saved_model/")
+        >>> # Update configuration during loading.
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
+        >>> assert model.config.output_attentions == True
+        >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
+        >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
+        >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
+        >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
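+        >>> # Illustrative: pick the checkpoint's dtype and let Accelerate place the weights (requires `accelerate`)
+        >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", dtype="auto", device_map="auto")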
+        ```
+        """
+        state_dict = kwargs.pop("state_dict", None)
+        from_tf = kwargs.pop("from_tf", False)
+        from_flax = kwargs.pop("from_flax", False)
+        proxies = kwargs.pop("proxies", None)
+        output_loading_info = kwargs.pop("output_loading_info", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+        dtype = kwargs.pop("dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)  # kept for BC
+        device_map = kwargs.pop("device_map", None)
+        max_memory = kwargs.pop("max_memory", None)
+        offload_folder = kwargs.pop("offload_folder", None)
+        offload_state_dict = kwargs.pop("offload_state_dict", False)
+        offload_buffers = kwargs.pop("offload_buffers", False)
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
+        load_in_4bit = kwargs.pop("load_in_4bit", False)
+        quantization_config = kwargs.pop("quantization_config", None)
+        subfolder = kwargs.pop("subfolder", "")
+        commit_hash = kwargs.pop("_commit_hash", None)
+        variant = kwargs.pop("variant", None)
+        adapter_kwargs = kwargs.pop("adapter_kwargs", {})
+        adapter_name = kwargs.pop("adapter_name", "default")
+        generation_config = kwargs.pop("generation_config", None)
+        gguf_file = kwargs.pop("gguf_file", None)
+        tp_plan = kwargs.pop("tp_plan", None)
+        tp_size = kwargs.pop("tp_size", None)
+        distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
+        device_mesh = kwargs.pop("device_mesh", None)
+        trust_remote_code = kwargs.pop("trust_remote_code", None)
+        use_kernels = kwargs.pop("use_kernels", False)
+
+        key_mapping = kwargs.pop("key_mapping", None)
+        # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
+        if key_mapping is None and any(
+            allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
+        ):
+            key_mapping = cls._checkpoint_conversion_mapping
+
+        if distributed_config is not None:
+            tp_plan = "auto"
+
+        # Not used anymore -- remove them from the kwargs
+        _ = kwargs.pop("resume_download", None)
+        _ = kwargs.pop("mirror", None)
+        _ = kwargs.pop("_fast_init", True)
+        _ = kwargs.pop("low_cpu_mem_usage", None)
+
+        # For BC on torch_dtype argument
+        if torch_dtype is not None:
+            logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
+            # If both kwargs are provided, use `dtype`
+            dtype = dtype if dtype is not None else torch_dtype
+
+        if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
+            raise ValueError(
+                "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
+            )
+        if tp_size is not None and tp_plan is None:
+            raise ValueError("tp_plan has to be set when tp_size is passed.")
+        if tp_plan is not None and tp_plan != "auto":
+            # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
+            raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
+        if tp_plan is not None and device_map is not None:
+            raise ValueError(
+                "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
+            )
+
+        if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")):
+            logger.info(
+                "You've set device_map=`auto` while triggering a distributed run with torchrun. This might lead to unexpected behavior. "
+                "If your plan is to load the model on each device, you should set "
+                'device_map={"": PartialState().process_index}, where PartialState comes from the accelerate library.'
+            )
+
+        # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
+        # `device_map` pointing to the correct device
+        if tp_plan is not None:
+            if device_mesh is None:
+                tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
+            else:
+                if device_mesh.ndim > 1:
+                    if "tp" not in device_mesh.mesh_dim_names:
+                        raise ValueError(
+                            "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
+                            "Please provide a valid `device_mesh`."
+                        )
+                    device_mesh = device_mesh["tp"]
+                tp_size = device_mesh.size()
+                device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
+
+            if tp_size is None:
+                tp_size = torch.distributed.get_world_size()
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
+            adapter_kwargs["token"] = token
+
+        if use_safetensors is None and not is_safetensors_available():
+            use_safetensors = False
+
+        if gguf_file is not None and not is_accelerate_available():
+            raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
+
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_gated_repo=False,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
+
+        if is_peft_available():
+            _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
+
+            if _adapter_model_path is None:
+                _adapter_model_path = find_adapter_config_file(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    _commit_hash=commit_hash,
+                    **adapter_kwargs,
+                )
+            if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
+                with open(_adapter_model_path, "r", encoding="utf-8") as f:
+                    _adapter_model_path = pretrained_model_name_or_path
+                    pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+        else:
+            _adapter_model_path = None
+
+        # Potentially detect context manager or global device, and use it (only if no device_map was provided)
+        if device_map is None and not is_deepspeed_zero3_enabled():
+            device_in_context = get_torch_context_manager_or_global_device()
+            if device_in_context == torch.device("meta"):
+                # TODO Cyril: raise an error instead of the warning in v4.53 (and change the test to check for raise instead of success)
+                logger.warning(
+                    "We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`\n"
+                    "This is an anti-pattern and will raise an Error in version v4.53\nIf you want to initialize a model on the meta device, use "
+                    "the context manager or global device with `from_config`, or `ModelClass(config)`"
+                )
+            device_map = device_in_context
+
+        # change device_map into a map if we passed an int, a str or a torch.device
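+        # e.g. device_map="cuda:0" becomes {"": torch.device("cuda:0")}, and device_map=1 becomes {"": 1}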
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if is_deepspeed_zero3_enabled():
+                raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
+            if not is_accelerate_available():
+                raise ValueError(
+                    "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
+                    "requires `accelerate`. You can install it with `pip install accelerate`"
+                )
+
+        # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
+        if load_in_4bit or load_in_8bit:
+            if quantization_config is not None:
+                raise ValueError(
+                    "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing "
+                    "`quantization_config` argument at the same time."
+                )
+
+            # preparing BitsAndBytesConfig from kwargs
+            config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
+            config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
+            quantization_config, kwargs = BitsAndBytesConfig.from_dict(
+                config_dict=config_dict, return_unused_kwargs=True, **kwargs
+            )
+            logger.warning(
+                "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. "
+                "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead."
+            )
+
+        from_pt = not (from_tf | from_flax)
+
+        user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        # Load config if we don't provide a configuration
+        if not isinstance(config, PretrainedConfig):
+            config_path = config if config is not None else pretrained_model_name_or_path
+            config, model_kwargs = cls.config_class.from_pretrained(
+                config_path,
+                cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                gguf_file=gguf_file,
+                _from_auto=from_auto_class,
+                _from_pipeline=from_pipeline,
+                **kwargs,
+            )
+            if "gguf_file" in model_kwargs:
+                model_kwargs.pop("gguf_file")
+        else:
+            config = copy.deepcopy(config)
+            model_kwargs = kwargs
+
+        # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
+        # to correctly redispatch recursively if the kwarg is provided
+        if "attn_implementation" in kwargs:
+            config._attn_implementation = kwargs.pop("attn_implementation")
+
+        transformers_explicit_filename = getattr(config, "transformers_weights", None)
+
+        if transformers_explicit_filename is not None:
+            if not transformers_explicit_filename.endswith(
+                ".safetensors"
+            ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
+                raise ValueError(
+                    "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
+                    "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
+                    f"{transformers_explicit_filename}"
+                )
+
+        hf_quantizer, config, dtype, device_map = get_hf_quantizer(
+            config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent
+        )
+
+        if gguf_file is not None and hf_quantizer is not None:
+            raise ValueError(
+                "You cannot combine quantization and loading a model from a GGUF file; try again by making sure you did not pass a `quantization_config` or load a quantized model from the Hub."
+            )
+
+        if (
+            gguf_file
+            and device_map is not None
+            and ((isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map)
+        ):
+            raise RuntimeError(
+                "One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
+                "loaded from GGUF files."
+            )
+
+        checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder,
+            variant=variant,
+            gguf_file=gguf_file,
+            from_tf=from_tf,
+            from_flax=from_flax,
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            user_agent=user_agent,
+            revision=revision,
+            commit_hash=commit_hash,
+            is_remote_code=cls._auto_class is not None,
+            transformers_explicit_filename=transformers_explicit_filename,
+        )
+
+        is_sharded = sharded_metadata is not None
+        is_quantized = hf_quantizer is not None
+        is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
+
+        if (
+            is_safetensors_available()
+            and is_from_file
+            and not is_sharded
+            and checkpoint_files[0].endswith(".safetensors")
+        ):
+            with safe_open(checkpoint_files[0], framework="pt") as f:
+                metadata = f.metadata()
+
+            if metadata is None:
+                # Assume it's a pytorch checkpoint (introduced for timm checkpoints)
+                pass
+            elif metadata.get("format") == "pt":
+                pass
+            elif metadata.get("format") == "tf":
+                from_tf = True
+                logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
+            elif metadata.get("format") == "flax":
+                from_flax = True
+                logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
+            elif metadata.get("format") == "mlx":
+                # This is a mlx file, we assume weights are compatible with pt
+                pass
+            else:
+                raise ValueError(
+                    f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
+                )
+
+        from_pt = not (from_tf | from_flax)
+
+        if from_pt:
+            if gguf_file:
+                from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+
+                # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was
+                # passed directly as a kwarg from now on
+                with torch.device("meta"):
+                    dummy_model = cls(config)
+                state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[
+                    "tensors"
+                ]
+
+            # Find the correct dtype based on current state
+            config, dtype, dtype_orig = _get_dtype(
+                cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
+            )
+
+        config.name_or_path = pretrained_model_name_or_path
+        model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
+        config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
+        with ContextManagers(model_init_context):
+            # Let's make sure we don't run the init function of buffer modules
+            model = cls(config, *model_args, **model_kwargs)
+
+        # Make sure to tie the weights correctly
+        model.tie_weights()
+
+        # make sure we use the model's config since the __init__ call might have copied it
+        config = model.config
+
+        # Find fp32 modules if needed
+        keep_in_fp32_modules = []
+        # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
+        # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
+        # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
+        if model._keep_in_fp32_modules is not None and (
+            dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
+        ):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
+
+        if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
+            keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
+
+        keep_in_fp32_regex = None
+        if keep_in_fp32_modules:
+            # We need to match exact layers, so we add either `.` on each side, or start/end of string
+            keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
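+            # e.g. ["lm_head"] gives r"((^|\.)lm_head($|\.))", matching "lm_head.weight" but not "my_lm_head.weight"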
+
+        if hf_quantizer is not None:
+            hf_quantizer.preprocess_model(
+                model=model,
+                device_map=device_map,
+                keep_in_fp32_modules=model._keep_in_fp32_modules,
+                config=config,
+                use_kernels=use_kernels,
+            )
+            # We store the original dtype for quantized models as we cannot easily retrieve it
+            # once the weights have been quantized
+            # Note that once you have loaded a quantized model, you can't change its dtype so this will
+            # remain a single source of truth
+            original_dtype = dtype if dtype is not None else torch.get_default_dtype()
+
+            def _assign_original_dtype(module):
+                for child in module.children():
+                    if isinstance(child, PreTrainedModel):
+                        child.config._pre_quantization_dtype = original_dtype
+                    _assign_original_dtype(child)
+
+            config._pre_quantization_dtype = original_dtype
+            _assign_original_dtype(model)
+
+        if _torch_distributed_available and device_mesh is not None:
+            model = distribute_model(model, distributed_config, device_mesh, tp_size)
+
+        # Prepare the full device map
+        if device_map is not None:
+            device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, dtype, keep_in_fp32_regex)
+
+        # Finalize model weight initialization
+        if from_tf:
+            model, loading_info = cls._load_from_tf(model, config, checkpoint_files)
+        elif from_flax:
+            model = cls._load_from_flax(model, checkpoint_files)
+        elif from_pt:
+            # restore default dtype
+            if dtype_orig is not None:
+                torch.set_default_dtype(dtype_orig)
+
+            (
+                model,
+                missing_keys,
+                unexpected_keys,
+                mismatched_keys,
+                offload_index,
+                error_msgs,
+            ) = cls._load_pretrained_model(
+                model,
+                state_dict,
+                checkpoint_files,
+                pretrained_model_name_or_path,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                sharded_metadata=sharded_metadata,
+                device_map=device_map,
+                disk_offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
+                dtype=dtype,
+                hf_quantizer=hf_quantizer,
+                keep_in_fp32_regex=keep_in_fp32_regex,
+                device_mesh=device_mesh,
+                key_mapping=key_mapping,
+                weights_only=weights_only,
+            )
+        # make sure token embedding weights are still tied if needed
+        model.tie_weights()
+
+        # Set model in evaluation mode to deactivate DropOut modules by default
+        model.eval()
+
+        # check if using kernels
+        if use_kernels:
+            if not is_kernels_available():
+                raise ValueError(
+                    "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
+                )
+
+            from kernels import Device, kernelize
+
+            kernelize(model, device=Device(type=model.device.type))
+
+        # If it is a model with generation capabilities, attempt to load generation files (generation config,
+        # custom generate function)
+        if model.can_generate() and generation_config is not None:
+            logger.info("The user-defined `generation_config` will be used to override the default generation config.")
+            model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
+        elif model.can_generate() and pretrained_model_name_or_path is not None:
+            repo_loading_kwargs = {
+                "cache_dir": cache_dir,
+                "force_download": force_download,
+                "proxies": proxies,
+                "local_files_only": local_files_only,
+                "token": token,
+                "revision": revision,
+                "subfolder": subfolder,
+                **kwargs,
+            }
+            # Load generation config
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **repo_loading_kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+            # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
+            if hasattr(model, "load_custom_generate"):
+                try:
+                    custom_generate = model.load_custom_generate(
+                        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
+                    )
+                    model.generate = functools.partial(custom_generate, model=model)
+                except OSError:  # there is no custom generate function
+                    pass
+
+        # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it
+        # slightly harms performance)
+        if device_map is not None and device_mesh is None:
+            device_map_kwargs = {
+                "device_map": device_map,
+                "offload_dir": offload_folder,
+                "offload_index": offload_index,
+                "offload_buffers": offload_buffers,
+            }
+            if "skip_keys" in inspect.signature(dispatch_model).parameters:
+                device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
+            # For HQQ method we force-set the hooks for single GPU envs
+            if (
+                "force_hooks" in inspect.signature(dispatch_model).parameters
+                and hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
+            ):
+                device_map_kwargs["force_hooks"] = True
+            if (
+                hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
+                and isinstance(device_map, dict)
+                and ("cpu" in device_map.values() or "disk" in device_map.values())
+            ):
+                device_map_kwargs["offload_buffers"] = True
+
+            if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
+                dispatch_model(model, **device_map_kwargs)
+
+        if hf_quantizer is not None:
+            model.hf_quantizer = hf_quantizer
+            hf_quantizer.postprocess_model(model, config=config)
+
+        if _adapter_model_path is not None:
+            adapter_kwargs["key_mapping"] = key_mapping
+            model.load_adapter(
+                _adapter_model_path,
+                adapter_name=adapter_name,
+                token=token,
+                adapter_kwargs=adapter_kwargs,
+            )
+
+        if output_loading_info:
+            if from_pt:
+                loading_info = {
+                    "missing_keys": missing_keys,
+                    "unexpected_keys": unexpected_keys,
+                    "mismatched_keys": mismatched_keys,
+                    "error_msgs": error_msgs,
+                }
+            elif from_flax:
+                loading_info = None
+            return model, loading_info
+        return model
+
+    @staticmethod
+    def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
+        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
+        # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
+        # This rename is logged.
+        if key.endswith("LayerNorm.beta"):
+            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
+        if key.endswith("LayerNorm.gamma"):
+            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
+
+        # Rename weight norm parametrizations to match changes across torch versions.
+        # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
+        # This rename is not logged.
+        if hasattr(nn.utils.parametrizations, "weight_norm"):
+            if key.endswith("weight_g"):
+                return key.replace("weight_g", "parametrizations.weight.original0"), True
+            if key.endswith("weight_v"):
+                return key.replace("weight_v", "parametrizations.weight.original1"), True
+        else:
+            if key.endswith("parametrizations.weight.original0"):
+                return key.replace("parametrizations.weight.original0", "weight_g"), True
+            if key.endswith("parametrizations.weight.original1"):
+                return key.replace("parametrizations.weight.original1", "weight_v"), True
+
+        return key, False
+
+    def _get_key_renaming_mapping(
+        self,
+        checkpoint_keys: list[str],
+        key_mapping: Optional[dict[str, str]] = None,
+        loading_base_model_from_task_state_dict: bool = False,
+        loading_task_model_from_base_state_dict: bool = False,
+    ):
+        """
+        Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
+        that we are loading expects. This is the single entry point for key renaming that will be used during
+        loading.
+        Log if any parameters have been renamed.
+        """
+        prefix = self.base_model_prefix
+        _prefix = f"{prefix}."
+
+        renamed_keys = {}
+        key_renaming_mapping = {}
+        for key in checkpoint_keys:
+            # Class specific rename
+            new_key, has_changed = self._fix_state_dict_key_on_load(key)
+
+            # Optionally map the key according to `key_mapping`
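+            # e.g. an (illustrative) entry like {r"^vision_tower": "model.vision_tower"} is applied with re.subn below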
+            if key_mapping is not None:
+                for pattern, replacement in key_mapping.items():
+                    new_key, n_replace = re.subn(pattern, replacement, new_key)
+                    # Early exit of the loop
+                    if n_replace > 0:
+                        has_changed = True
+                        break
+
+            # In this case, we need to add the prefix to the keys, to match them to the expected keys
+            if loading_task_model_from_base_state_dict:
+                new_key = ".".join([prefix, new_key])
+            # In this case we need to remove the prefix from the key to match them to the expected keys, and use
+            # only the keys starting with the prefix
+            elif loading_base_model_from_task_state_dict:
+                if not new_key.startswith(_prefix):
+                    continue
+                new_key = new_key[len(_prefix) :]
+
+            key_renaming_mapping[key] = new_key
+
+            # track gamma/beta rename for logging
+            if has_changed:
+                if key.endswith("LayerNorm.gamma"):
+                    renamed_keys["LayerNorm.gamma"] = (key, new_key)
+                elif key.endswith("LayerNorm.beta"):
+                    renamed_keys["LayerNorm.beta"] = (key, new_key)
+
+        if renamed_keys:
+            warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
+            warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
+            for old_key, new_key in renamed_keys.values():
+                warning_msg += f"* `{old_key}` -> `{new_key}`\n"
+            warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
+            logger.info_once(warning_msg)
+
+        return key_renaming_mapping
+
+    @staticmethod
+    def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+        """
+        Similar to `_fix_state_dict_key_on_load`, this allows defining a hook for state dict key renaming on model save.
+        Does nothing by default, but can be overridden in particular models.
+        """
+        return key, False
+
+    def _fix_state_dict_keys_on_save(self, state_dict):
+        """
+        Similar to `_fix_state_dict_keys_on_load`, this allows defining a hook for state dict key renaming on model save.
+        Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
+        """
+        return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}
+
+    @classmethod
+    def _load_pretrained_model(
+        cls,
+        model: "PreTrainedModel",
+        state_dict: Optional[dict],
+        checkpoint_files: Optional[list[str]],
+        pretrained_model_name_or_path: Optional[str],
+        ignore_mismatched_sizes: bool = False,
+        sharded_metadata: Optional[dict] = None,
+        device_map: Optional[dict] = None,
+        disk_offload_folder: Optional[str] = None,
+        offload_state_dict: Optional[bool] = None,
+        dtype: Optional[torch.dtype] = None,
+        hf_quantizer: Optional[HfQuantizer] = None,
+        keep_in_fp32_regex: Optional[re.Pattern] = None,
+        device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+        key_mapping: Optional[dict[str, str]] = None,
+        weights_only: bool = True,
+    ):
+        # TODO: we should only be calling hf_quantizer.skip_placement or something like that
+        is_quantized = hf_quantizer is not None
+        is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
+            QuantizationMethod.HQQ,
+            QuantizationMethod.QUARK,
+        }
+        is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+            QuantizationMethod.HQQ,
+            QuantizationMethod.BITS_AND_BYTES,
+        }
+
+        # Get all the keys of the state dicts that we have to initialize the model
+        if sharded_metadata is not None:
+            original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
+        elif state_dict is not None:
+            original_checkpoint_keys = list(state_dict.keys())
+        else:
+            original_checkpoint_keys = list(
+                load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys()
+            )
+
+        # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
+        prefix = model.base_model_prefix
+        _prefix = f"{prefix}."
+        has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
+        expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
+        loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
+        loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module
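+        # e.g. a BertModel checkpoint (keys without the "bert." prefix) loaded into BertForSequenceClassification
+        # (which expects that prefix) is the first case; the reverse direction is the second one.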
+
+        # Find the key names that the model expects from the serialized keys
+        key_renaming_mapping = model._get_key_renaming_mapping(
+            original_checkpoint_keys,
+            key_mapping,
+            loading_base_model_from_task_state_dict,
+            loading_task_model_from_base_state_dict,
+        )
+        checkpoint_keys = list(key_renaming_mapping.values())
+
+        # Find missing and unexpected keys from the state dict
+        missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
+            cls,
+            model,
+            original_checkpoint_keys,
+            checkpoint_keys,
+            loading_base_model_from_task_state_dict,
+            hf_quantizer,
+            device_map,
+        )
+        # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
+        # same way as missing keys)
+        mismatched_keys, mismatched_shapes = _find_mismatched_keys(
+            model,
+            state_dict,
+            checkpoint_files,
+            ignore_mismatched_sizes,
+            key_renaming_mapping,
+            is_quantized,
+            weights_only,
+        )
+
+        # We need to update both the mapping and the list of checkpoint keys to remove the mismatched ones
+        key_renaming_mapping = {k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys}
+        checkpoint_keys = list(key_renaming_mapping.values())
+
+        # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
+        # loading the weights as they are not in the loaded state dict)
+        model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, unexpected_keys, dtype, hf_quantizer)
+
+        # correctly initialize the missing (and potentially mismatched) keys
+        model._initialize_missing_keys(checkpoint_keys, ignore_mismatched_sizes, is_quantized)
+
+        # Set some modules to fp32 if needed
+        if keep_in_fp32_regex is not None:
+            for name, param in model.named_parameters():
+                if keep_in_fp32_regex.search(name):
+                    # param = param.to(torch.float32) does not work here, as it would only rebind the local variable.
+                    param.data = param.data.to(torch.float32)
+
+        # Make sure we are able to load base models as well as derived models (specific task models, with heads)
+        model_to_load = model
+        # In this case, we load a ForTaskModel with keys from a BaseModel -> only load keys to the BaseModel
+        if loading_task_model_from_base_state_dict:
+            model_to_load = getattr(model, prefix)
+            # Here we need to remove the prefix we added to correctly find missing/unexpected keys, as we will load
+            # in the submodule
+            key_renaming_mapping = {k: v[len(_prefix) :] for k, v in key_renaming_mapping.items()}
+            checkpoint_keys = list(key_renaming_mapping.values())
+            # We need to update the device map as well
+            if device_map is not None:
+                device_map = {k[len(_prefix) :] if k.startswith(_prefix) else k: v for k, v in device_map.items()}
+            # small sanity check: the base model should not contain task-specific head keys
+            task_specific_expected_keys = [s for s in model.state_dict() if not s.startswith(_prefix)]
+            base_model_expected_keys = list(model_to_load.state_dict().keys())
+            if any(
+                key in task_specific_expected_keys and key not in base_model_expected_keys for key in checkpoint_keys
+            ):
+                raise ValueError(
+                    "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
+                    "properly saved?"
+                )
+
+        # Get reverse key mapping
+        reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()}
+
+        is_offloaded_safetensors = False
+        # This offload index is for params explicitly placed on "disk" in the device_map
+        disk_offload_index = None
+        disk_only_shard_files = []
+        # Prepare parameters offloading if needed
+        if device_map is not None and "disk" in device_map.values():
+            if offload_state_dict is None:
+                offload_state_dict = True
+            if disk_offload_folder is not None:
+                os.makedirs(disk_offload_folder, exist_ok=True)
+            is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
+            if disk_offload_folder is None and not is_offloaded_safetensors:
+                raise ValueError(
+                    "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
+                    " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
+                    " offers the weights in this format."
+                )
+            if is_offloaded_safetensors:
+                param_device_map = expand_device_map(device_map, checkpoint_keys)
+                str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
+                if sharded_metadata is None:
+                    weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
+                else:
+                    folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
+                    # Fix the weight map keys according to the key mapping
+                    weight_map = {
+                        key_renaming_mapping[k]: v
+                        for k, v in sharded_metadata["weight_map"].items()
+                        if k in key_renaming_mapping
+                    }
+                    weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
+                    # Find potential checkpoints containing only offloaded weights
+                    disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
+                disk_offload_index = {
+                    name: {
+                        "safetensors_file": file,
+                        "weight_name": reverse_key_renaming_mapping[name],
+                        "dtype": str_dtype,
+                    }
+                    for name, file in weight_map.items()
+                    if param_device_map[name] == "disk"
+                }
+            else:
+                disk_offload_index = {}
+
+        # This offload index is for params that are supposed to be on the "cpu", either with or without a device_map.
+        # It allows loading parameters one-by-one from the state dict, avoiding a memory peak of 2x the state dict size,
+        # i.e. 1x to load it, and 1x to copy it into the model
+        cpu_offload_folder = None
+        cpu_offload_index = None
+        if offload_state_dict:
+            cpu_offload_folder = tempfile.mkdtemp()
+            cpu_offload_index = {}
+
+        # Use a dummy entry so we can still iterate below; it is not actually read when the state_dict is already provided
+        elif state_dict is not None:
+            checkpoint_files = [""]
+
+        # Compute expected model keys
+        expected_keys = list(model_to_load.state_dict().keys())
+        if hf_quantizer is not None:
+            expected_keys = hf_quantizer.update_expected_keys(model_to_load, expected_keys, checkpoint_keys)
+
+        if logger.level >= logging.WARNING:
+            verify_tp_plan(expected_keys, getattr(model_to_load, "_tp_plan", None))
+
+        # Warm up the accelerator's caching allocator so the weights load much faster onto the devices
+        if device_map is not None and not is_hqq_or_quark:
+            expanded_device_map = expand_device_map(device_map, expected_keys)
+            caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)
+
+        # Prepare a single argument list compatible with both serial and parallel shard loading
+        args_list = [
+            (
+                shard_file,
+                state_dict,
+                disk_only_shard_files,
+                is_hqq_or_bnb,
+                is_quantized,
+                device_map,
+                hf_quantizer,
+                key_renaming_mapping,
+                weights_only,
+                model_to_load,
+                expected_keys,
+                reverse_key_renaming_mapping,
+                disk_offload_folder,
+                disk_offload_index,
+                cpu_offload_folder,
+                cpu_offload_index,
+                is_offloaded_safetensors,
+                keep_in_fp32_regex,
+                unexpected_keys,
+                device_mesh,
+            )
+            for shard_file in checkpoint_files
+        ]
+
+        error_msgs = []
+
+        if (
+            os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+            and not is_deepspeed_zero3_enabled()
+        ):
+            _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
+            error_msgs += _error_msgs
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
+
+            for args in args_list:
+                _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
+                error_msgs += _error_msgs
+
+        # Adjust offloaded weight names and save the offload index if needed
+        if disk_offload_index is not None and len(disk_offload_index) > 0:
+            if loading_task_model_from_base_state_dict:
+                # We need to add the prefix of the base model
+                prefix = cls.base_model_prefix
+                if not is_offloaded_safetensors:
+                    for weight_name in disk_offload_index:
+                        shutil.move(
+                            os.path.join(disk_offload_folder, f"{weight_name}.dat"),
+                            os.path.join(disk_offload_folder, f"{prefix}.{weight_name}.dat"),
+                        )
+                disk_offload_index = {f"{prefix}.{key}": value for key, value in disk_offload_index.items()}
+            if not is_offloaded_safetensors:
+                save_offload_index(disk_offload_index, disk_offload_folder)
+                disk_offload_index = None
+
+        # one-at-a-time param loading for the cpu offloaded params
+        if offload_state_dict:
+            # Load back temporarily offloaded state dict
+            load_offloaded_weights(model_to_load, cpu_offload_index, cpu_offload_folder)
+            shutil.rmtree(cpu_offload_folder)
+
+        if hf_quantizer is not None:
+            missing_keys = hf_quantizer.update_missing_keys_after_loading(model_to_load, missing_keys, prefix)
+
+        # Post-processing for tensor parallelism
+        if device_mesh is not None:
+            # When using TP, the device map is a single device for all parameters
+            tp_device = list(device_map.values())[0]
+            # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
+            # not part of the state_dict (persistent=False)
+            for buffer in model.buffers():
+                if buffer.device != tp_device:
+                    buffer.data = buffer.to(tp_device)
+
+            # In this case, the top-most task module weights were not moved to device and parallelized as they
+            # were not part of the loaded weights: do it now
+            if loading_task_model_from_base_state_dict:
+                parameters_to_initialize = {
+                    name: param for name, param in model.named_parameters() if not name.startswith(prefix)
+                }
+                for name, param in parameters_to_initialize.items():
+                    # If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it
+                    if param.device.type == "meta":
+                        continue
+                    # Shard the param
+                    to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
+                    shard_and_distribute_module(
+                        model,
+                        param.to(tp_device),
+                        param,
+                        name,
+                        casting_dtype,
+                        to_contiguous,
+                        device_mesh.get_local_rank(),
+                        device_mesh,
+                    )
+
+        # All potential warnings/infos
+        if len(error_msgs) > 0:
+            error_msg = "\n\t".join(error_msgs)
+            if "size mismatch" in error_msg:
+                error_msg += (
+                    "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
+                )
+            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
+        if len(unexpected_keys) > 0:
+            archs = [] if model.config.architectures is None else model.config.architectures
+            warner = logger.warning if model.__class__.__name__ in archs else logger.info
+            warner(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
+                f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
+                f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
+                " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
+                " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
+                f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
+                " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
+                " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        elif len(mismatched_keys) == 0:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
+                f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
+                " training."
+            )
+        if len(mismatched_keys) > 0:
+            mismatched_warning = "\n".join(
+                [
+                    f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                    for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
+                ]
+            )
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
+                f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
+                f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+                " to use it for predictions and inference."
+            )
+
+        return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
+
+    @classmethod
+    def _load_from_tf(cls, model, config, checkpoint_files):
+        if checkpoint_files[0].endswith(".index"):
+            # Load from a TensorFlow 1.X checkpoint - provided by original authors
+            model = cls.load_tf_weights(model, config, checkpoint_files[0][:-6])  # Remove the '.index'
+            loading_info = None
+        else:
+            # Load from our TensorFlow 2.0 checkpoints
+            try:
+                from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
+
+                model, loading_info = load_tf2_checkpoint_in_pytorch_model(
+                    model, checkpoint_files[0], allow_missing_keys=True, output_loading_info=True
+                )
+            except ImportError:
+                logger.error(
+                    "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
+                    " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation"
+                    " instructions."
+                )
+                raise
+        return model, loading_info
+
+    @classmethod
+    def _load_from_flax(cls, model, checkpoint_files):
+        try:
+            from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model
+
+            model = load_flax_checkpoint_in_pytorch_model(model, checkpoint_files[0])
+        except ImportError:
+            logger.error(
+                "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
+                " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
+                " installation instructions."
+            )
+            raise
+        return model
+
+    def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
+        module_keys = {".".join(key.split(".")[:-1]) for key in names}
+
+        # torch.nn.ParameterList is a special case where two parameter keywords
+        # are appended to the module name, *e.g.* bert.special_embeddings.0
+        module_keys = module_keys.union(
+            {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
+        )
+
+        retrieved_modules = []
+        # retrieve all modules that have at least one missing weight name
+        for name, module in self.named_modules():
+            if remove_prefix:
+                _prefix = f"{self.base_model_prefix}."
+                name = name[len(_prefix) :] if name.startswith(_prefix) else name
+            elif add_prefix:
+                name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix
+
+            if name in module_keys:
+                retrieved_modules.append(module)
+
+        return retrieved_modules
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoModel"):
+        """
+        Register this class with a given auto class. This should only be used for custom models as the ones in the
+        library are already mapped with an auto class.
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
+                The auto class to register this new model with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+    def to_bettertransformer(self) -> "PreTrainedModel":
+        """
+        Converts the model to use [PyTorch's native attention
+        implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to
+        Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a
+        subset of all Transformers models are supported.
+
+        PyTorch's attention fastpath allows speeding up inference through kernel fusions and the use of [nested
+        tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
+        post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).
+
+        Returns:
+            [`PreTrainedModel`]: The model converted to BetterTransformer.
+        """
+        if not is_optimum_available():
+            raise ImportError("The package `optimum` is required to use Better Transformer.")
+
+        from optimum.version import __version__ as optimum_version
+
+        if version.parse(optimum_version) < version.parse("1.7.0"):
+            raise ImportError(
+                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
+            )
+
+        from optimum.bettertransformer import BetterTransformer
+
+        return BetterTransformer.transform(self)
+
+    def reverse_bettertransformer(self):
+        """
+        Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is
+        used, for example in order to save the model.
+
+        Returns:
+            [`PreTrainedModel`]: The model converted back to the original modeling.
+        """
+        if not is_optimum_available():
+            raise ImportError("The package `optimum` is required to use Better Transformer.")
+
+        from optimum.version import __version__ as optimum_version
+
+        if version.parse(optimum_version) < version.parse("1.7.0"):
+            raise ImportError(
+                f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found."
+            )
+
+        from optimum.bettertransformer import BetterTransformer
+
+        return BetterTransformer.reverse(self)
+
+    def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
+        """
+        Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
+        """
+
+        # Skip the check during tracing.
+        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
+            return
+
+        if (attention_mask is not None) or (self.config.pad_token_id is None):
+            return
+
+        # Check only the first and last input IDs to reduce overhead.
+        if self.config.pad_token_id in input_ids[:, [-1, 0]]:
+            warn_string = (
+                "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
+                "https://huggingface.co/docs/transformers/troubleshooting"
+                "#incorrect-output-when-padding-tokens-arent-masked."
+            )
+
+            # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
+            # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+            if (
+                (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
+                or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
+                or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
+            ):
+                warn_string += (
+                    f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
+                    f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
+                    f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
+                )
+
+            logger.warning_once(warn_string)
+
+    @property
+    def supports_tp_plan(self):
+        """
+        Returns whether the model has a tensor parallelism plan.
+        """
+        if self._tp_plan is not None:
+            return True
+        # Check if base model has a TP plan
+        if getattr(self.base_model, "_tp_plan", None) is not None:
+            return True
+        if self.config.base_model_tp_plan is not None:
+            return True
+        return False
+
+    @property
+    def tp_size(self):
+        """
+        Returns the model's tensor parallelism degree.
+        """
+        # if None, the model didn't undergo tensor parallel sharding
+        return self._tp_size
+
+    @property
+    def supports_pp_plan(self):
+        if self._pp_plan is not None:
+            return True
+        # Check if base model has PP plan
+        if getattr(self.base_model, "_pp_plan", None) is not None:
+            return True
+        return False
+
+    @property
+    def loss_function(self):
+        if hasattr(self, "_loss_function"):
+            return self._loss_function
+
+        loss_type = getattr(self, "loss_type", None)
+
+        if loss_type is None or loss_type not in LOSS_MAPPING:
+            logger.warning_once(
+                f"`loss_type={loss_type}` was set in the config but it is unrecognized. "
+                f"Using the default loss: `ForCausalLMLoss`."
+            )
+            loss_type = "ForCausalLM"
+        return LOSS_MAPPING[loss_type]
+
+    @loss_function.setter
+    def loss_function(self, value):
+        self._loss_function = value
+
+    def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
+        """Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
+        non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
+        want to use the compiled version, to avoid recomputing the graph with new shapes) and iterative decoding
+        (where we want the speed-ups of the compiled version with static shapes)."""
+        # Only reset it if not present or different from previous config
+        if "llama4" in self.config.model_type:  # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
+            return self.__call__
+        compile_config = compile_config or CompileConfig()
+        default_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()
+        if (
+            not hasattr(self, "_compiled_call")
+            or getattr(self, "_last_compile_config", default_config) != compile_config
+        ):
+            self._last_compile_config = compile_config
+            self._compiled_call = torch.compile(self.__call__, **compile_config.to_dict())
+        return self._compiled_call
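+
+    # Illustrative sketch (not part of the library): a generation loop might alternate between the plain and the
+    # compiled call, e.g. (assuming a hypothetical `model` and pre-built `prefill_inputs` / `decode_steps`):
+    #
+    #     compiled_forward = model.get_compiled_call(CompileConfig())
+    #     outputs = model(**prefill_inputs)              # dynamic shapes -> avoid compiling the prefill
+    #     for step_inputs in decode_steps:
+    #         outputs = compiled_forward(**step_inputs)  # static shapes -> reuse the compiled graph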
+
+    @classmethod
+    def is_backend_compatible(cls):
+        return cls._supports_attention_backend
+
+    def _move_missing_keys_from_meta_to_cpu(
+        self,
+        missing_keys: list[str],
+        unexpected_keys: list[str],
+        dtype: Optional[torch.dtype],
+        hf_quantizer: Optional[HfQuantizer],
+    ) -> "PreTrainedModel":
+        """Move the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts) back
+        from meta device to cpu.
+        """
+        is_quantized = hf_quantizer is not None
+
+        # In this case we need to move everything back
+        if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
+            # We only do it for the parameters, as the buffers are not initialized on the meta device by default
+            for key, param in self.named_parameters():
+                value = torch.empty_like(param, dtype=dtype, device="cpu")
+                _load_parameter_into_model(self, key, value)
+            return
+
+        model_state_dict = self.state_dict()
+        for key in missing_keys:
+            param = model_state_dict[key]
+            # Buffers are not initialized on the meta device, so we still need this check to avoid overwriting them
+            if param.device == torch.device("meta"):
+                value = torch.empty_like(param, dtype=dtype, device="cpu")
+                if (
+                    not is_quantized
+                    or (getattr(hf_quantizer, "requires_parameters_quantization", False))
+                    or not hf_quantizer.check_quantized_param(self, param_value=value, param_name=key, state_dict={})
+                ):
+                    _load_parameter_into_model(self, key, value)
+                else:
+                    hf_quantizer.create_quantized_param(self, value, key, "cpu", model_state_dict, unexpected_keys)
+
+    def _initialize_missing_keys(
+        self,
+        loaded_keys: list[str],
+        ignore_mismatched_sizes: bool,
+        is_quantized: bool,
+    ) -> "PreTrainedModel":
+        """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded state dicts), according to
+        `_initialize_weights`. Indeed, since the corresponding weights are missing from the state dict, they will not be replaced and need to
+        be initialized correctly (i.e. weight initialization distribution).
+        Also take care of setting the `_is_hf_initialized` flag for keys that are not missing.
+        """
+        if not ignore_mismatched_sizes:
+            not_initialized_submodules = set_initialized_submodules(self, loaded_keys)
+            # If we're about to tie the output embeds to the input embeds we don't need to init them
+            if (
+                hasattr(self.config.get_text_config(decoder=True), "tie_word_embeddings")
+                and self.config.get_text_config(decoder=True).tie_word_embeddings
+            ):
+                output_embeddings = self.get_output_embeddings()
+                if output_embeddings is not None:
+                    # Still need to initialize if there is a bias term since biases are not tied.
+                    if not hasattr(output_embeddings, "bias") or output_embeddings.bias is None:
+                        output_embeddings._is_hf_initialized = True
+        else:
+            not_initialized_submodules = dict(self.named_modules())
+        # This will only initialize submodules that are not marked as initialized by the line above.
+        if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
+            not_initialized_parameters = list(
+                set(
+                    itertools.chain.from_iterable(
+                        submodule.parameters(recurse=False) for submodule in not_initialized_submodules.values()
+                    )
+                )
+            )
+            with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
+                self.initialize_weights()
+        else:
+            self.initialize_weights()
+
+    def get_parameter_or_buffer(self, target: str):
+        """
+        Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
+        `get_parameter()` and `get_buffer()` in a single handy function. If the target is an `_extra_state` attribute,
+        it will return the extra state provided by the module. Note that it only works if `target` is a leaf of the model.
+        """
+        try:
+            return self.get_parameter(target)
+        except AttributeError:
+            pass
+        try:
+            return self.get_buffer(target)
+        except AttributeError:
+            pass
+        module, param_name = get_module_from_name(self, target)
+        if (
+            param_name == "_extra_state"
+            and getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
+            return module.get_extra_state()
+
+        raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
+
+
+PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
+if PreTrainedModel.push_to_hub.__doc__ is not None:
+    PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
+        object="model", object_class="AutoModel", object_files="model file"
+    )
+
+
+def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
+    """
+    Recursively unwraps a model from potential containers (as used in distributed training).
+
+    Args:
+        model (`torch.nn.Module`): The model to unwrap.
+        recursive (`bool`, *optional*, defaults to `False`):
+            Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
+            recursively, not just the top-level distributed containers.
+    """
+    # Use accelerate implementation if available (should always be the case when using torch)
+    # This is for pytorch, as we also have to handle things like dynamo
+    if is_accelerate_available():
+        kwargs = {}
+        if recursive:
+            if not is_accelerate_available("0.29.0"):
+                raise RuntimeError(
+                    "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
+                )
+            else:
+                kwargs["recursive"] = recursive
+        return extract_model_from_parallel(model, **kwargs)
+    else:
+        # since there could be multiple levels of wrapping, unwrap recursively
+        if hasattr(model, "module"):
+            return unwrap_model(model.module)
+        else:
+            return model
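+
+
+# Illustrative sketch (hypothetical setup, not part of the library): unwrapping a DistributedDataParallel
+# container gives back the underlying module:
+#
+#     ddp_model = torch.nn.parallel.DistributedDataParallel(model)
+#     assert unwrap_model(ddp_model) is model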
+
+
+def expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence between parameter names and devices.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
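+
+
+# Illustrative sketch (hypothetical parameter names): `expand_device_map` turns a module-level device_map into a
+# per-parameter mapping, e.g.:
+#
+#     expand_device_map({"model.layers.0": 0, "lm_head": "cpu"},
+#                       ["model.layers.0.self_attn.q_proj.weight", "lm_head.weight"])
+#     # -> {"model.layers.0.self_attn.q_proj.weight": 0, "lm_head.weight": "cpu"}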
+
+
+def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+    """Check if the device is an accelerator. We need to function, as device_map can be "disk" as well, which is not
+    a proper `torch.device`.
+    """
+    if device == "disk":
+        return False
+    else:
+        return torch.device(device).type not in ["meta", "cpu"]
+
+
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
+    """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
+    the model, which is actually the loading speed bottleneck.
+    Calling this function allows to cut the model loading time by a very large margin.
+
+    A few facts related to loading speed (taking into account the use of this function):
+    - When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
+    to cache the different state dicts (if enough resources/RAM are available)
+    - Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
+    and not a good idea in general, as these are low-level OS optimizations that depend on resource usage anyway
+    - As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
+    The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
+    These numbers are reported for TP on 4 H100 GPUs.
+    - It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
+    cudaMalloc is not a bottleneck at all anymore
+    - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
+    However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
+    """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()
+
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
+    }
+    if not accelerator_device_map:
+        return
+
+    tp_plan = getattr(model, "_tp_plan", []) or []
+    tp_plan_regex = (
+        re.compile("|".join([re.escape(plan) for plan in tp_plan]))
+        if _torch_distributed_available and torch.distributed.is_initialized()
+        else None
+    )
+    total_byte_count = defaultdict(lambda: 0)
+    tied_param_names = _get_tied_weight_keys(model)
+    for param_name, device in accelerator_device_map.items():
+        # Skip if the parameter has already been accounted for (tied weights)
+        if param_name in tied_param_names:
+            continue
+
+        # For example, in the case of MXFP4 quantization, we need to map the param name back to the original param name,
+        # because the checkpoint contains blocks and scales, but since we are dequantizing, we need the original param name
+        if hf_quantizer is not None:
+            param_name = hf_quantizer.update_param_name(param_name)
+
+        try:
+            param = model.get_parameter_or_buffer(param_name)
+        except AttributeError:
+            raise AttributeError(f"Parameter {param_name} not found in model")
+
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+
+        if tp_plan_regex is not None:
+            generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
+            param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
+
+        total_byte_count[device] += param_byte_count
+
+    # This will kick off the caching allocator to avoid having to malloc afterwards
+    for device, byte_count in total_byte_count.items():
+        if device.type in ["cuda", "xpu"]:
+            torch_accelerator_module = getattr(torch, device.type)
+            index = device.index if device.index is not None else torch_accelerator_module.current_device()
+            device_memory = torch_accelerator_module.mem_get_info(index)[0]
+            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
+            # than that amount might sometimes lead to unnecessary cuda/xpu OOM, if the last parameter to be loaded on the device is large,
+            # and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
+            # the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
+            # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
+            # Note that we use an absolute value instead of a device proportion here, as an 8GiB device could still allocate too much
+            # if using e.g. 90% of device size, while a 140GiB device would allocate too little
+            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
+            # If there is *unused* reserved cuda/xpu memory, we can skip/reduce the allocation.
+            unused_memory = torch_accelerator_module.memory_reserved(
+                index
+            ) - torch_accelerator_module.memory_allocated(index)
+            byte_count = max(0, byte_count - unused_memory)
+        # Allocate memory
+        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
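+
+
+# Rough arithmetic behind the warmup above (illustrative only): with no quantizer the factor is 2, so for a device
+# that will hold ~14e9 bytes of float16 weights (roughly a 7B-parameter model), the `torch.empty` call reserves
+# 14e9 // 2 float16 elements, i.e. ~14 GB in a single allocation instead of one allocation per tensor.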
+
+
+def get_disk_only_shard_files(device_map, weight_map):
+    """
+    Returns the list of shard files containing only weights offloaded to disk.
+    """
+    files_content = collections.defaultdict(list)
+    for weight_name, filename in weight_map.items():
+        while len(weight_name) > 0 and weight_name not in device_map:
+            weight_name = ".".join(weight_name.split(".")[:-1])
+        files_content[filename].append(device_map[weight_name])
+
+    return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
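+
+
+# Illustrative sketch (hypothetical shard/weight names): a shard whose every weight maps to "disk" can be skipped
+# during the main loading loop:
+#
+#     get_disk_only_shard_files({"model.layers.0": 0, "model.layers.1": "disk"},
+#                               {"model.layers.0.w": "shard-00001.safetensors",
+#                                "model.layers.1.w": "shard-00002.safetensors"})
+#     # -> ["shard-00002.safetensors"]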
+
+
+class AttentionInterface(GeneralInterface):
+    """
+    Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
+    with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
+    it needs to declare a new instance of this class inside its `modeling_<model>.py`, and declare it on that instance.
+    """
+
+    # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
+    # a new instance is created (in order to locally override a given function)
+    _global_mapping = {
+        "flash_attention_3": flash_attention_forward,
+        "flash_attention_2": flash_attention_forward,
+        "flex_attention": flex_attention_forward,
+        "paged_attention": paged_attention_forward,
+        "sdpa": sdpa_attention_forward,
+        "sdpa_paged": sdpa_attention_paged_forward,
+        "eager_paged": eager_paged_attention_forward,
+    }
+
+
+# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
+ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
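+
+
+# Illustrative sketch (hypothetical function name): a custom attention kernel can be exposed to all models through
+# the interface above and then selected via `attn_implementation="my_attention"`:
+#
+#     def my_attention(module, query, key, value, attention_mask, **kwargs):
+#         ...  # must return (attn_output, attn_weights)
+#
+#     AttentionInterface.register("my_attention", my_attention)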
+
+
+class PreTrainedAudioTokenizerBase(PreTrainedModel):
+    """
+    Class that additionally defines the behavior of any `audio_tokenizer` to be added.
+    Characteristics shared by all of them:
+        1. Encode raw audio into discrete audio codebooks (with x channels)
+        2. Decode from discrete audio codebooks back to raw audio
+    They may be able to decode in different ways given a different representation,
+    but they must support 2. nonetheless, e.g. see `DAC`.
+    """
+
+    @abstractmethod
+    def encode(self, input_values: torch.Tensor, *args, **kwargs):
+        """
+        Encode raw audio retrieved from a respective `FeatureExtractor` into discrete audio codebooks (with x channels)
+        """
+        pass
+
+    @abstractmethod
+    def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
+        """Decode from discrete audio codebooks back to raw audio"""
+        pass
diff --git a/phivenv/Lib/site-packages/transformers/optimization.py b/phivenv/Lib/site-packages/transformers/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..688d0f8db56f393910dc05bb35b53213b46ced2b
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/optimization.py
@@ -0,0 +1,973 @@
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for BERT model."""
+
+import math
+import warnings
+from functools import partial
+from typing import Optional, Union
+
+import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
+
+from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
+from .trainer_utils import SchedulerType
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def _get_constant_lambda(_=None):
+    return 1
+
+
+def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate, using the learning rate set in optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
+
+
+def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
+    """
+    Create a schedule with a constant learning rate that is reduced when a metric has stopped improving.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        kwargs (`dict`, *optional*):
+            Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
+            for possible parameters.
+
+    Return:
+        `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
+    """
+
+    return ReduceLROnPlateau(optimizer, **kwargs)
+
+
+def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1.0, num_warmup_steps))
+    return 1.0
+
+
+def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
+    """
+    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
+    increases linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
+
+
+def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+    """
+    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
+    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_linear_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
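+
+
+# Illustrative usage sketch (not part of this module; assumes a hypothetical `model` and `dataloader`):
+#
+#     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
+#     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)
+#     for batch in dataloader:
+#         loss = model(**batch).loss
+#         loss.backward()
+#         optimizer.step()
+#         scheduler.step()  # lr ramps 0 -> 5e-5 over 100 steps, then decays linearly to 0 at step 1000
+#         optimizer.zero_grad()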
+
+
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
+
+
+def get_cosine_schedule_with_warmup(
+    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
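+
+
+# Worked example of the cosine factor above (with the default num_cycles=0.5): halfway through the decay phase
+# progress == 0.5, so the multiplier is 0.5 * (1 + cos(pi * 0.5 * 2 * 0.5)) == 0.5 * (1 + cos(pi/2)) == 0.5,
+# i.e. the learning rate is at half its initial value.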
+
+
+def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    if progress >= 1.0:
+        return 0.0
+    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(
+    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
+    linearly between 0 and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`int`, *optional*, defaults to 1):
+            The number of hard restarts to use.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
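+
+
+# Worked example of the hard-restart factor above (num_cycles=2): at progress == 0.5, (num_cycles * progress) % 1.0
+# wraps back to 0.0, so the multiplier is 0.5 * (1 + cos(0)) == 1.0, i.e. the learning rate jumps back to its
+# initial value before decaying again.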
+
+
+def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    lr_end: float,
+    power: float,
+    lr_init: int,
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    elif current_step > num_training_steps:
+        return lr_end / lr_init  # as LambdaLR multiplies by lr_init
+    else:
+        lr_range = lr_init - lr_end
+        decay_steps = num_training_steps - num_warmup_steps
+        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
+        decay = lr_range * pct_remaining**power + lr_end
+        return decay / lr_init  # as LambdaLR multiplies by lr_init
+
+
+def get_polynomial_decay_schedule_with_warmup(
+    optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
+):
+    """
+    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
+    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        lr_end (`float`, *optional*, defaults to 1e-7):
+            The end LR.
+        power (`float`, *optional*, defaults to 1.0):
+            Power factor.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
+    implementation at
+    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+
+    """
+
+    lr_init = optimizer.defaults["lr"]
+    if not (lr_init > lr_end):
+        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
+
+    lr_lambda = partial(
+        _get_polynomial_decay_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        lr_end=lr_end,
+        power=power,
+        lr_init=lr_init,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
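+
+
+# Worked example of the polynomial decay above (power=1.0, i.e. linear decay): with lr_init=1e-4, lr_end=1e-7,
+# num_warmup_steps=0 and num_training_steps=1000, at step 500 pct_remaining == 0.5, so the lambda returns
+# (lr_range * 0.5 + lr_end) / lr_init ~= 0.5, giving an effective learning rate of ~5e-5.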
+
+
+def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    shift = timescale - num_warmup_steps
+    decay = 1.0 / math.sqrt((current_step + shift) / timescale)
+    return decay
+
+
+def get_inverse_sqrt_schedule(
+    optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
+):
+    """
+    Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
+    warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        timescale (`int`, *optional*, defaults to `num_warmup_steps`):
+            Time scale.
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+    # Note: this implementation is adapted from
+    # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
+
+    if timescale is None:
+        timescale = num_warmup_steps or 10_000
+
+    lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
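+
+
+# Worked example of the inverse-sqrt decay above: with num_warmup_steps == timescale == 1000, shift == 0 and the
+# multiplier at step 4000 is 1 / sqrt(4000 / 1000) == 0.5, i.e. the learning rate halves each time the step count
+# quadruples.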
+
+
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_with_min_lr_schedule_with_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+    min_lr: Optional[float] = None,
+    min_lr_rate: Optional[float] = None,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+        min_lr (`float`, *optional*):
+            The minimum learning rate to reach after the cosine schedule.
+        min_lr_rate (`float`, *optional*):
+            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    if min_lr is not None and min_lr_rate is not None:
+        raise ValueError("Only one of min_lr or min_lr_rate should be set")
+    elif min_lr is not None:
+        min_lr_rate = min_lr / optimizer.defaults["lr"]
+    elif min_lr_rate is None:
+        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+        min_lr_rate=min_lr_rate,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
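+# Illustrative usage sketch (comments only, not executed): exactly one of `min_lr` or
+# `min_lr_rate` must be given; here the lr warms up to 5e-5 and then decays down to 5e-6
+# (min_lr_rate=0.1). `model` is a placeholder.
+#
+#     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
+#     scheduler = get_cosine_with_min_lr_schedule_with_warmup(
+#         optimizer, num_warmup_steps=500, num_training_steps=10_000, min_lr_rate=0.1
+#     )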
+
+def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+    min_lr_rate: float = 0.0,
+    warmup_lr_rate: Optional[float] = None,
+):
+    current_step = float(current_step)
+    num_warmup_steps = float(num_warmup_steps)
+    num_training_steps = float(num_training_steps)
+
+    if current_step < num_warmup_steps:
+        if warmup_lr_rate is None:
+            return (current_step + 1.0) / max(1.0, num_warmup_steps)
+        else:
+            warmup_lr_rate = float(warmup_lr_rate)
+            return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
+    progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+    min_lr: Optional[float] = None,
+    min_lr_rate: Optional[float] = None,
+    warmup_lr_rate: Optional[float] = None,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function from the
+    initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly from 0 to the
+    initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+        min_lr (`float`, *optional*):
+            The minimum learning rate to reach after the cosine schedule.
+        min_lr_rate (`float`, *optional*):
+            The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
+        warmup_lr_rate (`float`, *optional*):
+            The learning rate at the start of the warmup phase, as a ratio of the initial learning rate. If not set,
+            the warmup starts at `1 / num_warmup_steps` of the initial learning rate.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    if min_lr is not None and min_lr_rate is not None:
+        raise ValueError("Only one of min_lr or min_lr_rate should be set")
+    elif min_lr is not None:
+        min_lr_rate = min_lr / optimizer.defaults["lr"]
+    elif min_lr_rate is None:
+        raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
+
+    lr_lambda = partial(
+        _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+        min_lr_rate=min_lr_rate,
+        warmup_lr_rate=warmup_lr_rate,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
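+# Illustrative usage sketch (comments only, not executed): `warmup_lr_rate=0.1` starts the warmup
+# at 10% of the initial lr instead of the default `(step + 1) / num_warmup_steps`. `optimizer` is
+# assumed to be an already constructed torch optimizer.
+#
+#     scheduler = get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
+#         optimizer, num_warmup_steps=500, num_training_steps=10_000,
+#         min_lr_rate=0.1, warmup_lr_rate=0.1,
+#     )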
+
+def _get_wsd_scheduler_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_stable_steps: int,
+    num_decay_steps: int,
+    warmup_type: str,
+    decay_type: str,
+    min_lr_ratio: float,
+    num_cycles: float,
+):
+    if current_step < num_warmup_steps:
+        progress = float(current_step) / float(max(1, num_warmup_steps))
+        if warmup_type == "linear":
+            factor = progress
+        elif warmup_type == "cosine":
+            factor = 0.5 * (1.0 - math.cos(math.pi * progress))
+        elif warmup_type == "1-sqrt":
+            factor = 1.0 - math.sqrt(1.0 - progress)
+        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+        return max(0.0, factor)
+
+    if current_step < num_warmup_steps + num_stable_steps:
+        return 1.0
+
+    if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
+        progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
+        if decay_type == "linear":
+            factor = 1.0 - progress
+        elif decay_type == "cosine":
+            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+        elif decay_type == "1-sqrt":
+            factor = 1.0 - math.sqrt(progress)
+        factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
+        return max(0.0, factor)
+    return min_lr_ratio
+
+
+def get_wsd_schedule(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_decay_steps: int,
+    num_training_steps: Optional[int] = None,
+    num_stable_steps: Optional[int] = None,
+    warmup_type: str = "linear",
+    decay_type: str = "cosine",
+    min_lr_ratio: float = 0,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that has three stages:
+    1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
+    2. stable: constant learning rate.
+    3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_decay_steps (`int`):
+            The number of steps for the decay phase.
+        num_training_steps (`int`, *optional*):
+            The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
+        num_stable_steps (`int`, *optional*):
+            The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise any remaining steps will run at the minimum learning rate.
+        warmup_type (`str`, *optional*, defaults to "linear"):
+            The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
+        decay_type (`str`, *optional*, defaults to "cosine"):
+            The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
+        min_lr_ratio (`float`, *optional*, defaults to 0):
+            The minimum learning rate as a ratio of the initial learning rate.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    if num_training_steps is None and num_stable_steps is None:
+        raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
+
+    if num_training_steps is not None and num_stable_steps is not None:
+        warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
+
+    if warmup_type not in ["linear", "cosine", "1-sqrt"]:
+        raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+    if decay_type not in ["linear", "cosine", "1-sqrt"]:
+        raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
+
+    if num_stable_steps is None:
+        num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
+
+    lr_lambda = partial(
+        _get_wsd_scheduler_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_stable_steps=num_stable_steps,
+        num_decay_steps=num_decay_steps,
+        warmup_type=warmup_type,
+        decay_type=decay_type,
+        min_lr_ratio=min_lr_ratio,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
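+# Illustrative usage sketch (comments only, not executed): a 10,000-step run split into
+# 1,000 warmup + 8,000 stable + 1,000 cosine-decay steps, decaying to 10% of the initial lr.
+# `optimizer` is assumed to be an already constructed torch optimizer.
+#
+#     scheduler = get_wsd_schedule(
+#         optimizer,
+#         num_warmup_steps=1_000,
+#         num_decay_steps=1_000,
+#         num_training_steps=10_000,  # stable phase = 10_000 - 1_000 - 1_000
+#         decay_type="cosine",
+#         min_lr_ratio=0.1,
+#     )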
+
+TYPE_TO_SCHEDULER_FUNCTION = {
+    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
+    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
+    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
+    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
+    SchedulerType.CONSTANT: get_constant_schedule,
+    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
+    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
+    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
+    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
+    SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
+    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
+}
+
+
+def get_scheduler(
+    name: Union[str, SchedulerType],
+    optimizer: Optimizer,
+    num_warmup_steps: Optional[int] = None,
+    num_training_steps: Optional[int] = None,
+    scheduler_specific_kwargs: Optional[dict] = None,
+):
+    """
+    Unified API to get any scheduler from its name.
+
+    Args:
+        name (`str` or `SchedulerType`):
+            The name of the scheduler to use.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer that will be used during training.
+        num_warmup_steps (`int`, *optional*):
+            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        num_training_steps (`int`, *optional*):
+            The number of training steps to do. This is not required by all schedulers (hence the argument being
+            optional), the function will raise an error if it's unset and the scheduler type requires it.
+        scheduler_specific_kwargs (`dict`, *optional*):
+            Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
+            parameters will cause the scheduler function to raise a TypeError.
+    """
+    name = SchedulerType(name)
+    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+
+    # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
+    # recursively call `get_scheduler` to get the proper schedulers on each parameter
+    if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
+        optimizer_dict = optimizer.optimizer_dict
+        scheduler_dict = {}
+
+        for param in optimizer_dict:
+            scheduler_dict[param] = get_scheduler(
+                name,
+                optimizer=optimizer_dict[param],
+                num_warmup_steps=num_warmup_steps,
+                num_training_steps=num_training_steps,
+                scheduler_specific_kwargs=scheduler_specific_kwargs,
+            )
+
+        def scheduler_hook(param):
+            # Since the optimizer hook has been already attached we only need to
+            # attach the scheduler hook, the gradients have been zeroed here
+            scheduler_dict[param].step()
+
+        for param in optimizer_dict:
+            if param.requires_grad:
+                param.register_post_accumulate_grad_hook(scheduler_hook)
+
+        return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
+
+    if name == SchedulerType.CONSTANT:
+        return schedule_func(optimizer)
+
+    if scheduler_specific_kwargs is None:
+        scheduler_specific_kwargs = {}
+
+    if name == SchedulerType.REDUCE_ON_PLATEAU:
+        return schedule_func(optimizer, **scheduler_specific_kwargs)
+
+    # All other schedulers require `num_warmup_steps`
+    if num_warmup_steps is None:
+        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+
+    if name == SchedulerType.CONSTANT_WITH_WARMUP:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    if name == SchedulerType.INVERSE_SQRT:
+        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+
+    # wsd scheduler requires either num_training_steps or num_stable_steps
+    if name == SchedulerType.WARMUP_STABLE_DECAY:
+        return schedule_func(
+            optimizer,
+            num_warmup_steps=num_warmup_steps,
+            num_training_steps=num_training_steps,
+            **scheduler_specific_kwargs,
+        )
+
+    # All other schedulers require `num_training_steps`
+    if num_training_steps is None:
+        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+
+    return schedule_func(
+        optimizer,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        **scheduler_specific_kwargs,
+    )
+
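+# Illustrative usage sketch (comments only, not executed): the same schedulers can be built by
+# name through the unified API; scheduler-specific options go in `scheduler_specific_kwargs`.
+# `optimizer` is assumed to be an already constructed torch optimizer.
+#
+#     scheduler = get_scheduler(
+#         SchedulerType.COSINE_WITH_MIN_LR,
+#         optimizer,
+#         num_warmup_steps=500,
+#         num_training_steps=10_000,
+#         scheduler_specific_kwargs={"min_lr_rate": 0.1},
+#     )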
+
+class Adafactor(Optimizer):
+    """
+    AdaFactor PyTorch implementation that can be used as a drop-in replacement for Adam; ported from the original
+    fairseq code:
+    https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+    Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* (https://huggingface.co/papers/1804.04235).
+    Note that this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step`
+    and `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False`
+    and `relative_step=False`.
+
+    Arguments:
+        params (`Iterable[nn.parameter.Parameter]`):
+            Iterable of parameters to optimize or dictionaries defining parameter groups.
+        lr (`float`, *optional*):
+            The external learning rate.
+        eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
+            Regularization constants for square gradient and parameter scale respectively
+        clip_threshold (`float`, *optional*, defaults to 1.0):
+            Threshold of root mean square of final gradient update
+        decay_rate (`float`, *optional*, defaults to -0.8):
+            Coefficient used to compute running averages of square
+        beta1 (`float`, *optional*):
+            Coefficient used for computing running averages of gradient
+        weight_decay (`float`, *optional*, defaults to 0.0):
+            Weight decay (L2 penalty)
+        scale_parameter (`bool`, *optional*, defaults to `True`):
+            If True, learning rate is scaled by root mean square
+        relative_step (`bool`, *optional*, defaults to `True`):
+            If True, time-dependent learning rate is computed instead of external learning rate
+        warmup_init (`bool`, *optional*, defaults to `False`):
+            Time-dependent learning rate computation depends on whether warm-up initialization is being used
+
+    This implementation handles low-precision (FP16, bfloat16) values, but it has not been thoroughly tested.
+
+    Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
+
+        - Training without LR warmup or clip_threshold is not recommended.
+
+           - use scheduled LR warm-up to fixed LR
+           - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
+        - Disable relative updates
+        - Use scale_parameter=False
+        - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
+
+    Example:
+
+    ```python
+    Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
+    ```
+
+    Others reported the following combination to work well:
+
+    ```python
+    Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+    ```
+
+    When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
+    scheduler as following:
+
+    ```python
+    from transformers.optimization import Adafactor, AdafactorSchedule
+
+    optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
+    lr_scheduler = AdafactorSchedule(optimizer)
+    trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
+    ```
+
+    Usage:
+
+    ```python
+    # replace AdamW with Adafactor
+    optimizer = Adafactor(
+        model.parameters(),
+        lr=1e-3,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        beta1=None,
+        weight_decay=0.0,
+        relative_step=False,
+        scale_parameter=False,
+        warmup_init=False,
+    )
+    ```"""
+
+    def __init__(
+        self,
+        params,
+        lr=None,
+        eps=(1e-30, 1e-3),
+        clip_threshold=1.0,
+        decay_rate=-0.8,
+        beta1=None,
+        weight_decay=0.0,
+        scale_parameter=True,
+        relative_step=True,
+        warmup_init=False,
+    ):
+        if lr is not None and relative_step:
+            raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
+        if warmup_init and not relative_step:
+            raise ValueError("`warmup_init=True` requires `relative_step=True`")
+
+        defaults = {
+            "lr": lr,
+            "eps": eps,
+            "clip_threshold": clip_threshold,
+            "decay_rate": decay_rate,
+            "beta1": beta1,
+            "weight_decay": weight_decay,
+            "scale_parameter": scale_parameter,
+            "relative_step": relative_step,
+            "warmup_init": warmup_init,
+        }
+        super().__init__(params, defaults)
+
+    @staticmethod
+    def _get_lr(param_group, param_state):
+        rel_step_sz = param_group["lr"]
+        if param_group["relative_step"]:
+            min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
+            rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
+        param_scale = 1.0
+        if param_group["scale_parameter"]:
+            param_scale = max(param_group["eps"][1], param_state["RMS"])
+        return param_scale * rel_step_sz
+
+    @staticmethod
+    def _get_options(param_group, param_shape):
+        factored = len(param_shape) >= 2
+        use_first_moment = param_group["beta1"] is not None
+        return factored, use_first_moment
+
+    @staticmethod
+    def _rms(tensor):
+        return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+    @staticmethod
+    def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
+        # copy from fairseq's adafactor implementation:
+        # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
+        r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+        return torch.mul(r_factor, c_factor)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """
+        Performs a single optimization step
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
+                if grad.is_sparse:
+                    raise RuntimeError("Adafactor does not support sparse gradients.")
+
+                state = self.state[p]
+                grad_shape = grad.shape
+
+                factored, use_first_moment = self._get_options(group, grad_shape)
+                # State Initialization
+                if len(state) == 0:
+                    state["step"] = 0
+
+                    if use_first_moment:
+                        # Exponential moving average of gradient values
+                        state["exp_avg"] = torch.zeros_like(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
+                        state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
+                    else:
+                        state["exp_avg_sq"] = torch.zeros_like(grad)
+
+                    state["RMS"] = 0
+                else:
+                    if use_first_moment:
+                        state["exp_avg"] = state["exp_avg"].to(grad)
+                    if factored:
+                        state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
+                        state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+                    else:
+                        state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+
+                p_data_fp32 = p
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p_data_fp32 = p_data_fp32.float()
+
+                state["step"] += 1
+                state["RMS"] = self._rms(p_data_fp32)
+                lr = self._get_lr(group, state)
+
+                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
+                update = (grad**2) + group["eps"][0]
+                if factored:
+                    exp_avg_sq_row = state["exp_avg_sq_row"]
+                    exp_avg_sq_col = state["exp_avg_sq_col"]
+
+                    exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
+                    exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
+
+                    # Approximation of exponential moving average of square of gradient
+                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+                    update.mul_(grad)
+                else:
+                    exp_avg_sq = state["exp_avg_sq"]
+
+                    exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
+                    update = exp_avg_sq.rsqrt().mul_(grad)
+
+                update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
+                update.mul_(lr)
+
+                if use_first_moment:
+                    exp_avg = state["exp_avg"]
+                    exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
+                    update = exp_avg
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
+
+                p_data_fp32.add_(-update)
+
+                if p.dtype in {torch.float16, torch.bfloat16}:
+                    p.copy_(p_data_fp32)
+
+        return loss
+
+
+class AdafactorSchedule(LambdaLR):
+    """
+    Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
+    for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
+
+    It returns `initial_lr` during startup and the actual `lr` during stepping.
+    """
+
+    def __init__(self, optimizer, initial_lr=0.0):
+        def lr_lambda(_):
+            return initial_lr
+
+        for group in optimizer.param_groups:
+            group["initial_lr"] = initial_lr
+        super().__init__(optimizer, lr_lambda)
+        for group in optimizer.param_groups:
+            del group["initial_lr"]
+
+    def get_lr(self):
+        opt = self.optimizer
+        lrs = [
+            opt._get_lr(group, opt.state[group["params"][0]])
+            for group in opt.param_groups
+            if group["params"][0].grad is not None
+        ]
+        if len(lrs) == 0:
+            lrs = self.base_lrs  # if called before stepping
+        return lrs
+
+
+def get_adafactor_schedule(optimizer, initial_lr=0.0):
+    """
+    Get a proxy schedule for [`~optimization.Adafactor`]
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        initial_lr (`float`, *optional*, defaults to 0.0):
+            Initial lr
+
+    Return:
+        [`~optimization.Adafactor`] proxy schedule object.
+    """
+    return AdafactorSchedule(optimizer, initial_lr)
diff --git a/phivenv/Lib/site-packages/transformers/optimization_tf.py b/phivenv/Lib/site-packages/transformers/optimization_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..71a77251f2bf9431a08295b8daabbcbe576de71b
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/optimization_tf.py
@@ -0,0 +1,378 @@
+# Copyright 2019 The TensorFlow Authors, The Hugging Face Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions and classes related to optimization (weight updates)."""
+
+from typing import Callable, Optional, Union
+
+import tensorflow as tf
+
+
+try:
+    from tf_keras.optimizers.legacy import Adam
+except (ImportError, ModuleNotFoundError):
+    from tensorflow.keras.optimizers.legacy import Adam
+
+from .modeling_tf_utils import keras
+
+
+# Keras moved this module at some point between 2.10 and 2.15, so handle both locations.
+if hasattr(keras.optimizers.schedules, "learning_rate_schedule"):
+    schedules = keras.optimizers.schedules.learning_rate_schedule
+else:
+    schedules = keras.optimizers.schedules
+
+
+class WarmUp(schedules.LearningRateSchedule):
+    """
+    Applies a warmup schedule on a given learning rate decay schedule.
+
+    Args:
+        initial_learning_rate (`float`):
+            The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
+            of the warmup).
+        decay_schedule_fn (`Callable`):
+            The schedule function to apply after the warmup for the rest of training.
+        warmup_steps (`int`):
+            The number of steps for the warmup part of training.
+        power (`float`, *optional*, defaults to 1.0):
+            The power to use for the polynomial warmup (defaults is a linear warmup).
+        name (`str`, *optional*):
+            Optional name prefix for the returned tensors during the schedule.
+    """
+
+    def __init__(
+        self,
+        initial_learning_rate: float,
+        decay_schedule_fn: Callable,
+        warmup_steps: int,
+        power: float = 1.0,
+        name: Optional[str] = None,
+    ):
+        super().__init__()
+        self.initial_learning_rate = initial_learning_rate
+        self.warmup_steps = warmup_steps
+        self.power = power
+        self.decay_schedule_fn = decay_schedule_fn
+        self.name = name
+
+    def __call__(self, step):
+        with tf.name_scope(self.name or "WarmUp") as name:
+            # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+            # learning rate will be `global_step/num_warmup_steps * init_lr`.
+            global_step_float = tf.cast(step, tf.float32)
+            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+            warmup_percent_done = global_step_float / warmup_steps_float
+            warmup_learning_rate = self.initial_learning_rate * tf.math.pow(warmup_percent_done, self.power)
+            return tf.cond(
+                global_step_float < warmup_steps_float,
+                lambda: warmup_learning_rate,
+                lambda: self.decay_schedule_fn(step - self.warmup_steps),
+                name=name,
+            )
+
+    def get_config(self):
+        return {
+            "initial_learning_rate": self.initial_learning_rate,
+            "decay_schedule_fn": self.decay_schedule_fn,
+            "warmup_steps": self.warmup_steps,
+            "power": self.power,
+            "name": self.name,
+        }
+
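+# Illustrative usage sketch (comments only, not executed): wrapping a Keras decay schedule with a
+# linear warmup; after `warmup_steps` the wrapped schedule takes over, evaluated at
+# `step - warmup_steps`.
+#
+#     decay_fn = schedules.PolynomialDecay(
+#         initial_learning_rate=5e-5, decay_steps=9_000, end_learning_rate=0.0
+#     )
+#     lr_schedule = WarmUp(initial_learning_rate=5e-5, decay_schedule_fn=decay_fn, warmup_steps=1_000)
+#     optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)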
+
+def create_optimizer(
+    init_lr: float,
+    num_train_steps: int,
+    num_warmup_steps: int,
+    min_lr_ratio: float = 0.0,
+    adam_beta1: float = 0.9,
+    adam_beta2: float = 0.999,
+    adam_epsilon: float = 1e-8,
+    adam_clipnorm: Optional[float] = None,
+    adam_global_clipnorm: Optional[float] = None,
+    weight_decay_rate: float = 0.0,
+    power: float = 1.0,
+    include_in_weight_decay: Optional[list[str]] = None,
+):
+    """
+    Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.
+
+    Args:
+        init_lr (`float`):
+            The desired learning rate at the end of the warmup phase.
+        num_train_steps (`int`):
+            The total number of training steps.
+        num_warmup_steps (`int`):
+            The number of warmup steps.
+        min_lr_ratio (`float`, *optional*, defaults to 0):
+            The final learning rate at the end of the linear decay will be `init_lr * min_lr_ratio`.
+        adam_beta1 (`float`, *optional*, defaults to 0.9):
+            The beta1 to use in Adam.
+        adam_beta2 (`float`, *optional*, defaults to 0.999):
+            The beta2 to use in Adam.
+        adam_epsilon (`float`, *optional*, defaults to 1e-8):
+            The epsilon to use in Adam.
+        adam_clipnorm (`float`, *optional*, defaults to `None`):
+            If not `None`, clip the gradient norm for each weight tensor to this value.
+        adam_global_clipnorm (`float`, *optional*, defaults to `None`):
+            If not `None`, clip gradient norm to this value. When using this argument, the norm is computed over all
+            weight tensors, as if they were concatenated into a single vector.
+        weight_decay_rate (`float`, *optional*, defaults to 0):
+            The weight decay to use.
+        power (`float`, *optional*, defaults to 1.0):
+            The power to use for PolynomialDecay.
+        include_in_weight_decay (`list[str]`, *optional*):
+            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
+            applied to all parameters except bias and layer norm parameters.
+    """
+    # Implements linear decay of the learning rate.
+    lr_schedule = schedules.PolynomialDecay(
+        initial_learning_rate=init_lr,
+        decay_steps=num_train_steps - num_warmup_steps,
+        end_learning_rate=init_lr * min_lr_ratio,
+        power=power,
+    )
+    if num_warmup_steps:
+        lr_schedule = WarmUp(
+            initial_learning_rate=init_lr,
+            decay_schedule_fn=lr_schedule,
+            warmup_steps=num_warmup_steps,
+        )
+    if weight_decay_rate > 0.0:
+        optimizer = AdamWeightDecay(
+            learning_rate=lr_schedule,
+            weight_decay_rate=weight_decay_rate,
+            beta_1=adam_beta1,
+            beta_2=adam_beta2,
+            epsilon=adam_epsilon,
+            clipnorm=adam_clipnorm,
+            global_clipnorm=adam_global_clipnorm,
+            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+            include_in_weight_decay=include_in_weight_decay,
+        )
+    else:
+        optimizer = keras.optimizers.Adam(
+            learning_rate=lr_schedule,
+            beta_1=adam_beta1,
+            beta_2=adam_beta2,
+            epsilon=adam_epsilon,
+            clipnorm=adam_clipnorm,
+            global_clipnorm=adam_global_clipnorm,
+        )
+    # We return the optimizer and the LR scheduler in order to better track the
+    # evolution of the LR independently of the optimizer.
+    return optimizer, lr_schedule
+
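+# Illustrative usage sketch (comments only, not executed): the returned optimizer can be passed
+# straight to `model.compile`; the schedule is returned separately so the lr can be tracked or
+# logged. `model` is a placeholder Keras model.
+#
+#     optimizer, lr_schedule = create_optimizer(
+#         init_lr=5e-5, num_train_steps=10_000, num_warmup_steps=1_000, weight_decay_rate=0.01
+#     )
+#     model.compile(optimizer=optimizer, loss="sparse_categorical_crossentropy")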
+
+class AdamWeightDecay(Adam):
+    """
+    Adam with support for L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights
+    to the loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will
+    interact with the m and v parameters in strange ways, as shown in [Decoupled Weight Decay
+    Regularization](https://huggingface.co/papers/1711.05101).
+
+    Instead we want to decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
+    to adding the square of the weights to the loss with plain (non-momentum) SGD.
+
+    Args:
+        learning_rate (`Union[float, LearningRateSchedule]`, *optional*, defaults to 0.001):
+            The learning rate to use or a schedule.
+        beta_1 (`float`, *optional*, defaults to 0.9):
+            The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
+        beta_2 (`float`, *optional*, defaults to 0.999):
+            The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
+        epsilon (`float`, *optional*, defaults to 1e-07):
+            The epsilon parameter in Adam, which is a small constant for numerical stability.
+        amsgrad (`bool`, *optional*, defaults to `False`):
+            Whether to apply AMSGrad variant of this algorithm or not, see [On the Convergence of Adam and
+            Beyond](https://huggingface.co/papers/1904.09237).
+        weight_decay_rate (`float`, *optional*, defaults to 0.0):
+            The weight decay to apply.
+        include_in_weight_decay (`list[str]`, *optional*):
+            List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
+            applied to all parameters by default (unless they are in `exclude_from_weight_decay`).
+        exclude_from_weight_decay (`list[str]`, *optional*):
+            List of the parameter names (or re patterns) to exclude from applying weight decay to. If a
+            `include_in_weight_decay` is passed, the names in it will supersede this list.
+        name (`str`, *optional*, defaults to `"AdamWeightDecay"`):
+            Optional name for the operations created when applying gradients.
+        kwargs (`dict[str, Any]`, *optional*):
+            Keyword arguments. Allowed arguments are `clipnorm`, `clipvalue`, `lr` and `decay`. `clipnorm` clips
+            gradients by norm; `clipvalue` clips gradients by value; `decay` is included for backward compatibility to
+            allow time-inverse decay of the learning rate; `lr` is included for backward compatibility, but using
+            `learning_rate` is recommended.
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, schedules.LearningRateSchedule] = 0.001,
+        beta_1: float = 0.9,
+        beta_2: float = 0.999,
+        epsilon: float = 1e-7,
+        amsgrad: bool = False,
+        weight_decay_rate: float = 0.0,
+        include_in_weight_decay: Optional[list[str]] = None,
+        exclude_from_weight_decay: Optional[list[str]] = None,
+        name: str = "AdamWeightDecay",
+        **kwargs,
+    ):
+        super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
+        self.weight_decay_rate = weight_decay_rate
+        self._include_in_weight_decay = include_in_weight_decay
+        self._exclude_from_weight_decay = exclude_from_weight_decay
+
+    @classmethod
+    def from_config(cls, config):
+        """Creates an optimizer from its config with WarmUp custom object."""
+        custom_objects = {"WarmUp": WarmUp}
+        return super().from_config(config, custom_objects=custom_objects)
+
+    def _prepare_local(self, var_device, var_dtype, apply_state):
+        super()._prepare_local(var_device, var_dtype, apply_state)
+        apply_state[(var_device, var_dtype)]["weight_decay_rate"] = tf.constant(
+            self.weight_decay_rate, name="adam_weight_decay_rate"
+        )
+
+    def _decay_weights_op(self, var, learning_rate, apply_state):
+        do_decay = self._do_use_weight_decay(var.name)
+        if do_decay:
+            return var.assign_sub(
+                learning_rate * var * apply_state[(var.device, var.dtype.base_dtype)]["weight_decay_rate"],
+                use_locking=self._use_locking,
+            )
+        return tf.no_op()
+
+    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
+        grads, tvars = list(zip(*grads_and_vars))
+        return super().apply_gradients(zip(grads, tvars), name=name, **kwargs)
+
+    def _get_lr(self, var_device, var_dtype, apply_state):
+        """Retrieves the learning rate with the given state."""
+        if apply_state is None:
+            return self._decayed_lr_t[var_dtype], {}
+
+        apply_state = apply_state or {}
+        coefficients = apply_state.get((var_device, var_dtype))
+        if coefficients is None:
+            coefficients = self._fallback_apply_state(var_device, var_dtype)
+            apply_state[(var_device, var_dtype)] = coefficients
+
+        return coefficients["lr_t"], {"apply_state": apply_state}
+
+    def _resource_apply_dense(self, grad, var, apply_state=None):
+        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+        decay = self._decay_weights_op(var, lr_t, apply_state)
+        with tf.control_dependencies([decay]):
+            return super()._resource_apply_dense(grad, var, **kwargs)
+
+    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+        decay = self._decay_weights_op(var, lr_t, apply_state)
+        with tf.control_dependencies([decay]):
+            return super()._resource_apply_sparse(grad, var, indices, **kwargs)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({"weight_decay_rate": self.weight_decay_rate})
+        return config
+
+    def _do_use_weight_decay(self, param_name):
+        """Whether to use L2 weight decay for `param_name`."""
+        if self.weight_decay_rate == 0:
+            return False
+
+        if self._include_in_weight_decay:
+            for r in self._include_in_weight_decay:
+                if r in param_name:
+                    return True
+
+        if self._exclude_from_weight_decay:
+            for r in self._exclude_from_weight_decay:
+                if r in param_name:
+                    return False
+        return True
+
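+# Illustrative construction sketch (comments only, not executed): decay is applied to every
+# variable except those whose names contain one of the exclude patterns (substring match, see
+# `_do_use_weight_decay`).
+#
+#     optimizer = AdamWeightDecay(
+#         learning_rate=5e-5,
+#         weight_decay_rate=0.01,
+#         exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+#     )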
+
+# Extracted from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
+class GradientAccumulator:
+    """
+    Gradient accumulation utility. When used with a distribution strategy, the accumulator should be called in a
+    replica context. Gradients will be accumulated locally on each replica and without synchronization. Users should
+    then call `.gradients`, scale the gradients if required, and pass the result to `apply_gradients`.
+    """
+
+    # We use the ON_READ synchronization policy so that no synchronization is
+    # performed on assignment. To get the value, we call .value() which returns the
+    # value on the current replica without synchronization.
+
+    def __init__(self):
+        """Initializes the accumulator."""
+        self._gradients = []
+        self._accum_steps = None
+
+    @property
+    def step(self):
+        """Number of accumulated steps."""
+        if self._accum_steps is None:
+            self._accum_steps = tf.Variable(
+                tf.constant(0, dtype=tf.int64),
+                trainable=False,
+                synchronization=tf.VariableSynchronization.ON_READ,
+                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+            )
+
+        return self._accum_steps.value()
+
+    @property
+    def gradients(self):
+        """The accumulated gradients on the current replica."""
+        if not self._gradients:
+            raise ValueError("The accumulator should be called first to initialize the gradients")
+        return [gradient.value() if gradient is not None else gradient for gradient in self._gradients]
+
+    def __call__(self, gradients):
+        """Accumulates `gradients` on the current replica."""
+        if not self._gradients:
+            _ = self.step  # Create the step variable.
+            self._gradients.extend(
+                [
+                    tf.Variable(
+                        tf.zeros_like(gradient),
+                        trainable=False,
+                        synchronization=tf.VariableSynchronization.ON_READ,
+                        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+                    )
+                    if gradient is not None
+                    else gradient
+                    for gradient in gradients
+                ]
+            )
+        if len(gradients) != len(self._gradients):
+            raise ValueError(f"Expected {len(self._gradients)} gradients, but got {len(gradients)}")
+
+        for accum_gradient, gradient in zip(self._gradients, gradients):
+            if accum_gradient is not None and gradient is not None:
+                accum_gradient.assign_add(gradient)
+
+        self._accum_steps.assign_add(1)
+
+    def reset(self):
+        """Resets the accumulated gradients on the current replica."""
+        if not self._gradients:
+            return
+        self._accum_steps.assign(0)
+        for gradient in self._gradients:
+            if gradient is not None:
+                gradient.assign(tf.zeros_like(gradient))
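+
+
+# Illustrative usage sketch (comments only, not executed): accumulate gradients over several
+# micro-batches, then apply the averaged result once. `model`, `loss_fn`, `micro_batches` and
+# `optimizer` are placeholders, and no gradient is assumed to be None.
+#
+#     accumulator = GradientAccumulator()
+#     for micro_batch in micro_batches:
+#         with tf.GradientTape() as tape:
+#             loss = loss_fn(model(micro_batch, training=True))
+#         accumulator(tape.gradient(loss, model.trainable_variables))
+#     grads = [g / tf.cast(accumulator.step, g.dtype) for g in accumulator.gradients]
+#     optimizer.apply_gradients(zip(grads, model.trainable_variables))
+#     accumulator.reset()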
diff --git a/phivenv/Lib/site-packages/transformers/processing_utils.py b/phivenv/Lib/site-packages/transformers/processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..baa98881ff3380fe546a4e7d6b2f421943a66e78
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/processing_utils.py
@@ -0,0 +1,1714 @@
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Processing saving/loading class for common processors.
+"""
+
+import bisect
+import copy
+import inspect
+import json
+import os
+import sys
+import typing
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional, TypedDict, TypeVar, Union
+
+import numpy as np
+import typing_extensions
+from huggingface_hub.errors import EntryNotFoundError
+
+from .audio_utils import load_audio
+from .dynamic_module_utils import custom_object_save
+from .feature_extraction_utils import BatchFeature
+from .image_utils import ChannelDimension, is_vision_available
+from .utils.chat_template_utils import render_jinja_template
+from .video_utils import VideoMetadata
+
+
+if is_vision_available():
+    from .image_utils import PILImageResampling
+
+
+from .tokenization_utils_base import (
+    PaddingStrategy,
+    PreTokenizedInput,
+    PreTrainedTokenizerBase,
+    TextInput,
+    TruncationStrategy,
+)
+from .utils import (
+    AUDIO_TOKENIZER_NAME,
+    CHAT_TEMPLATE_DIR,
+    CHAT_TEMPLATE_FILE,
+    LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
+    PROCESSOR_NAME,
+    PushToHubMixin,
+    TensorType,
+    cached_file,
+    copy_func,
+    direct_transformers_import,
+    download_url,
+    is_offline_mode,
+    is_remote_url,
+    is_torch_available,
+    list_repo_templates,
+    logging,
+)
+from .utils.deprecation import deprecate_kwarg
+
+
+if is_torch_available():
+    from .modeling_utils import PreTrainedAudioTokenizerBase
+
+
+logger = logging.get_logger(__name__)
+
+# type hinting: specifying the type of processor class that inherits from ProcessorMixin
+SpecificProcessorType = TypeVar("SpecificProcessorType", bound="ProcessorMixin")
+
+# Dynamically import the Transformers module to grab the attribute classes of the processor from their names.
+transformers_module = direct_transformers_import(Path(__file__).parent)
+
+
+AUTO_TO_BASE_CLASS_MAPPING = {
+    "AutoTokenizer": "PreTrainedTokenizerBase",
+    "AutoFeatureExtractor": "FeatureExtractionMixin",
+    "AutoImageProcessor": "ImageProcessingMixin",
+    "AutoVideoProcessor": "BaseVideoProcessor",
+}
+
+if sys.version_info >= (3, 11):
+    Unpack = typing.Unpack
+else:
+    Unpack = typing_extensions.Unpack
+
+
+class TextKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for text processing. For extended documentation, check out tokenization_utils_base methods and
+    docstrings associated.
+
+    Attributes:
+        add_special_tokens (`bool`, *optional*)
+            Whether or not to add special tokens when encoding the sequences.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*)
+            Activates and controls padding.
+        truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*):
+            Activates and controls truncation.
+        max_length (`int`, *optional*):
+            Controls the maximum length to use by one of the truncation/padding parameters.
+        stride (`int`, *optional*):
+            If set, the overflowing tokens will contain some tokens from the end of the truncated sequence.
+        is_split_into_words (`bool`, *optional*):
+            Whether or not the input is already pre-tokenized.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, will pad the sequence to a multiple of the provided value.
+        return_token_type_ids (`bool`, *optional*):
+            Whether to return token type IDs.
+        return_attention_mask (`bool`, *optional*):
+            Whether to return the attention mask.
+        return_overflowing_tokens (`bool`, *optional*):
+            Whether or not to return overflowing token sequences.
+        return_special_tokens_mask (`bool`, *optional*):
+            Whether or not to return special tokens mask information.
+        return_offsets_mapping (`bool`, *optional*):
+            Whether or not to return `(char_start, char_end)` for each token.
+        return_length (`bool`, *optional*):
+            Whether or not to return the lengths of the encoded inputs.
+        verbose (`bool`, *optional*):
+            Whether or not to print more information and warnings.
+        padding_side (`str`, *optional*):
+            The side on which padding will be applied.
+        return_mm_token_type_ids (`bool`, *optional*):
+            Whether to return multimodal token type ids indicating mm placeholder token positions.
+    """
+
+    text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
+    text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]
+    text_pair_target: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]]
+    add_special_tokens: Optional[bool]
+    padding: Union[bool, str, PaddingStrategy]
+    truncation: Union[bool, str, TruncationStrategy]
+    max_length: Optional[int]
+    stride: Optional[int]
+    is_split_into_words: Optional[bool]
+    pad_to_multiple_of: Optional[int]
+    return_token_type_ids: Optional[bool]
+    return_attention_mask: Optional[bool]
+    return_overflowing_tokens: Optional[bool]
+    return_special_tokens_mask: Optional[bool]
+    return_offsets_mapping: Optional[bool]
+    return_length: Optional[bool]
+    verbose: Optional[bool]
+    padding_side: Optional[str]
+    return_mm_token_type_ids: Optional[bool]
+
+
+class ImagesKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for image processing. For extended documentation, check the appropriate ImageProcessor
+    class methods and docstrings.
+
+    Attributes:
+        do_resize (`bool`, *optional*):
+            Whether to resize the image.
+        size (`dict[str, int]`, *optional*):
+            Resize the shorter side of the input to `size["shortest_edge"]`.
+        size_divisor (`int`, *optional*):
+            The size by which to make sure both the height and width can be divided.
+        crop_size (`dict[str, int]`, *optional*):
+            Desired output size when applying center-cropping.
+        resample (`PILImageResampling`, *optional*):
+            Resampling filter to use if resizing the image.
+        do_rescale (`bool`, *optional*):
+            Whether to rescale the image by the specified scale `rescale_factor`.
+        rescale_factor (`int` or `float`, *optional*):
+            Scale factor to use if rescaling the image.
+        do_normalize (`bool`, *optional*):
+            Whether to normalize the image.
+        image_mean (`float` or `list[float]`, *optional*):
+            Mean to use if normalizing the image.
+        image_std (`float` or `list[float]`, *optional*):
+            Standard deviation to use if normalizing the image.
+        do_pad (`bool`, *optional*):
+            Whether to pad the image to the `(max_height, max_width)` of the images in the batch.
+        pad_size (`dict[str, int]`, *optional*):
+            The size `{"height": int, "width": int}` to pad the images to.
+        do_center_crop (`bool`, *optional*):
+            Whether to center crop the image.
+        data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the output image.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input image.
+        device (`str`, *optional*):
+            The device to use for processing (e.g. "cpu", "cuda"), only relevant for fast image processing.
+    """
+
+    do_resize: Optional[bool]
+    size: Optional[dict[str, int]]
+    size_divisor: Optional[int]
+    crop_size: Optional[dict[str, int]]
+    resample: Optional[Union["PILImageResampling", int]]
+    do_rescale: Optional[bool]
+    rescale_factor: Optional[float]
+    do_normalize: Optional[bool]
+    image_mean: Optional[Union[float, list[float]]]
+    image_std: Optional[Union[float, list[float]]]
+    do_pad: Optional[bool]
+    pad_size: Optional[dict[str, int]]
+    do_center_crop: Optional[bool]
+    data_format: Optional[ChannelDimension]
+    input_data_format: Optional[Union[str, ChannelDimension]]
+    device: Optional[str]
+
+
+class VideosKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for video processing.
+
+    Attributes:
+        do_convert_rgb (`bool`):
+            Whether to convert the video to RGB format.
+        do_resize (`bool`):
+            Whether to resize the video.
+        size (`dict[str, int]`, *optional*):
+            Resize the shorter side of the input to `size["shortest_edge"]`.
+        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
+            Whether to default to a square when resizing, if size is an int.
+        size_divisor (`int`, *optional*):
+            The size by which to make sure both the height and width can be divided.
+        resample (`PILImageResampling`, *optional*):
+            Resampling filter to use if resizing the video.
+        do_rescale (`bool`, *optional*):
+            Whether to rescale the video by the specified scale `rescale_factor`.
+        rescale_factor (`int` or `float`, *optional*):
+            Scale factor to use if rescaling the video.
+        do_normalize (`bool`, *optional*):
+            Whether to normalize the video.
+        image_mean (`float` or `list[float]`, *optional*):
+            Mean to use if normalizing the video.
+        image_std (`float` or `list[float]`, *optional*):
+            Standard deviation to use if normalizing the video.
+        do_pad (`bool`, *optional*):
+            Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
+        do_center_crop (`bool`, *optional*):
+            Whether to center crop the video.
+        do_sample_frames (`bool`, *optional*):
+            Whether to sample frames from the video before processing or to process the whole video.
+        video_metadata (`Union[VideoMetadata, dict]`, *optional*):
+            Metadata of the video containing information about total duration, fps and total number of frames.
+        num_frames (`int`, *optional*):
+            Maximum number of frames to sample when `do_sample_frames=True`.
+        fps (`int` or `float`, *optional*):
+            Target frames to sample per second when `do_sample_frames=True`.
+        crop_size (`dict[str, int]`, *optional*):
+            Desired output size when applying center-cropping.
+        data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the output video.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input video.
+        return_metadata (`bool`, *optional*):
+            Whether to return video metadata or not.
+    """
+
+    do_convert_rgb: Optional[bool]
+    do_resize: Optional[bool]
+    size: Optional[dict[str, int]]
+    size_divisor: Optional[int]
+    default_to_square: Optional[bool]
+    resample: Optional["PILImageResampling"]
+    do_rescale: Optional[bool]
+    rescale_factor: Optional[float]
+    do_normalize: Optional[bool]
+    image_mean: Optional[Union[float, list[float]]]
+    image_std: Optional[Union[float, list[float]]]
+    do_pad: Optional[bool]
+    do_center_crop: Optional[bool]
+    crop_size: Optional[dict[str, int]]
+    data_format: Optional[ChannelDimension]
+    input_data_format: Optional[Union[str, ChannelDimension]]
+    device: Optional[str]
+    do_sample_frames: Optional[bool]
+    video_metadata: Optional[Union[VideoMetadata, dict]]
+    fps: Optional[Union[int, float]]
+    num_frames: Optional[int]
+    return_metadata: Optional[bool]
+
+
+class AudioKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for audio processing.
+
+    Attributes:
+        sampling_rate (`int`, *optional*):
+            The sampling rate at which the `raw_speech` input was sampled.
+        raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
+            The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
+            values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
+            stereo, i.e. single float per timestep.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding
+            index) among:
+
+            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        truncation (`bool`, *optional*):
+            Activates truncation to cut input sequences longer than *max_length* to *max_length*.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, will pad the sequence to a multiple of the provided value.
+        return_attention_mask (`bool`, *optional*):
+            Whether or not [`~ASTFeatureExtractor.__call__`] should return `attention_mask`.
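+
+    Example (a minimal usage sketch; `processor` and `speech_array` stand for any audio-capable processor
+    and a mono waveform, neither of which is defined in this module):
+
+    ```python
+    outputs = processor(
+        audio=speech_array,
+        sampling_rate=16000,
+        padding="max_length",
+        max_length=16000,
+        return_attention_mask=True,
+    )
+    ```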
+    """
+
+    sampling_rate: Optional[int]
+    raw_speech: Optional[Union["np.ndarray", list[float], list["np.ndarray"], list[list[float]]]]
+    padding: Optional[Union[bool, str, PaddingStrategy]]
+    max_length: Optional[int]
+    truncation: Optional[bool]
+    pad_to_multiple_of: Optional[int]
+    return_attention_mask: Optional[bool]
+
+
+class CommonKwargs(TypedDict, total=False):
+    return_tensors: Optional[Union[str, TensorType]]
+
+
+class ProcessingKwargs(TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, total=False):
+    """
+    Base class for kwargs passing to processors.
+    A model should have its own `ModelProcessorKwargs` class that inherits from `ProcessingKwargs` to provide:
+        1) Additional typed keys that this model requires to process inputs.
+        2) Default values for existing keys under a `_defaults` attribute.
+    New keys have to be defined as follows to ensure type hinting is done correctly.
+
+    ```python
+    # adding a new image kwarg for this model
+    class ModelImagesKwargs(ImagesKwargs, total=False):
+        new_image_kwarg: Optional[bool]
+
+    class ModelProcessorKwargs(ProcessingKwargs, total=False):
+        images_kwargs: ModelImagesKwargs
+        _defaults = {
+            "images_kwargs: {
+                "new_image_kwarg": False,
+            }
+            "text_kwargs": {
+                "padding": "max_length",
+            },
+        }
+
+    ```
+
+    For Python 3.8 compatibility, when inheriting from this class and overriding one of the kwargs,
+    you need to manually update the __annotations__ dictionary. This can be done as follows:
+
+    ```python
+    class CustomProcessorKwargs(ProcessingKwargs, total=False):
+        images_kwargs: CustomImagesKwargs
+
+    CustomProcessorKwargs.__annotations__["images_kwargs"] = CustomImagesKwargs  # python 3.8 compatibility
+    ```
+
+    """
+
+    common_kwargs: CommonKwargs = {
+        **CommonKwargs.__annotations__,
+    }
+    text_kwargs: TextKwargs = {
+        **TextKwargs.__annotations__,
+    }
+    images_kwargs: ImagesKwargs = {
+        **ImagesKwargs.__annotations__,
+    }
+    videos_kwargs: VideosKwargs = {
+        **VideosKwargs.__annotations__,
+    }
+    audio_kwargs: AudioKwargs = {
+        **AudioKwargs.__annotations__,
+    }
+
+
+class TokenizerChatTemplateKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for tokenizer's `apply_chat_template`, when it is called from within a processor.
+
+    tools (`list[Dict]`, *optional*):
+        A list of tools (callable functions) that will be accessible to the model. If the template does not
+        support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+        giving the name, description and argument types for the tool. See our
+        [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+        for more information.
+    documents (`list[dict[str, str]]`, *optional*):
+        A list of dicts representing documents that will be accessible to the model if it is performing RAG
+        (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+        effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
+        see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+        for examples of passing documents with chat templates.
+    add_generation_prompt (bool, *optional*):
+        If this is set, a prompt with the token(s) that indicate
+        the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
+        Note that this argument will be passed to the chat template, and so it must be supported in the
+        template for this argument to have any effect.
+    continue_final_message (bool, *optional*):
+        If this is set, the chat will be formatted so that the final
+        message in the chat is open-ended, without any EOS tokens. The model will continue this message
+        rather than starting a new one. This allows you to "prefill" part of
+        the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+    return_assistant_tokens_mask (`bool`, defaults to `False`):
+        Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
+        the mask will contain 1. For user and system tokens, the mask will contain 0.
+        This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
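+
+    Example (a minimal usage sketch; `processor` and `messages` are assumed to already exist and are not
+    defined in this module):
+
+    ```python
+    prompt = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+    ```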
+    """
+
+    tools: Optional[list[dict]] = None
+    documents: Optional[list[dict[str, str]]] = None
+    add_generation_prompt: Optional[bool] = False
+    continue_final_message: Optional[bool] = False
+    return_assistant_tokens_mask: Optional[bool] = False
+
+
+class ChatTemplateLoadKwargs(TypedDict, total=False):
+    """
+    Keyword arguments used to load multimodal data in processor chat templates.
+
+    sampling_rate (`int`, *optional*, defaults to 16000):
+        The sampling rate at which audio inputs (including the audio track of a video) are loaded.
+    load_audio_from_video (`bool`, *optional*):
+        Whether to use the audio track of the input video. If `True`, the audio track will be loaded and passed to the
+        processor. This flag has no effect if the model doesn't support the audio modality.
+    """
+
+    sampling_rate: Optional[int] = 16_000
+    load_audio_from_video: Optional[bool] = False
+
+
+class ProcessorChatTemplateKwargs(ChatTemplateLoadKwargs, TokenizerChatTemplateKwargs, total=False):
+    """
+    Keyword arguments for processor's `apply_chat_template`.
+
+    tokenize (`bool`, *optional*, defaults to `False`):
+        Whether to tokenize the output or not.
+    return_dict (`bool`, defaults to `False`):
+        Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+    """
+
+    tokenize: Optional[bool] = False
+    return_dict: Optional[bool] = False
+
+
+class AllKwargsForChatTemplate(
+    TextKwargs, ImagesKwargs, VideosKwargs, AudioKwargs, CommonKwargs, ProcessorChatTemplateKwargs
+):
+    processor_kwargs: ProcessingKwargs = {
+        **ProcessingKwargs.__annotations__,
+    }
+    mm_load_kwargs: ChatTemplateLoadKwargs = {
+        **ChatTemplateLoadKwargs.__annotations__,
+    }
+    template_kwargs: ProcessorChatTemplateKwargs = {
+        **ProcessorChatTemplateKwargs.__annotations__,
+    }
+
+
+@dataclass
+class MultiModalData:
+    """
+    Dataclass that holds extra useful data obtained when processing
+    multimodal inputs. Processors currently cannot return extra keys
+    unless they are used in the model's forward pass, so we provide
+    helper methods that calculate and return useful data from processing
+    the input multimodal data (images/videos).
+    Note that this dataclass is aimed to be used only by vLLM
+    and we might change its API in the future.
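+
+    Example (a minimal sketch of the dict-like access helpers defined below):
+
+    ```python
+    data = MultiModalData(num_image_tokens=[576], num_image_patches=[1])
+    "num_image_tokens" in data  # True
+    data["num_image_tokens"]  # [576]
+    "num_video_tokens" in data  # False, the field was left as None
+    ```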
+    """
+
+    num_image_tokens: Optional[list[int]] = None
+    num_video_tokens: Optional[list[int]] = None
+    num_audio_tokens: Optional[list[int]] = None
+    num_image_patches: Optional[list[int]] = None
+
+    def __contains__(self, key):
+        return hasattr(self, key) and getattr(self, key) is not None
+
+    def __getitem__(self, key):
+        if hasattr(self, key):
+            return getattr(self, key)
+        raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+
+
+class ProcessorMixin(PushToHubMixin):
+    """
+    This is a mixin used to provide saving/loading functionality for all processor classes.
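+
+    A typical subclass declares its sub-processors through the `attributes` class attribute and the matching
+    `*_class` names (a minimal sketch; `MyImageProcessor` and `MyTokenizer` are placeholder class names, not
+    classes from this library):
+
+    ```python
+    class MyProcessor(ProcessorMixin):
+        attributes = ["image_processor", "tokenizer"]
+        image_processor_class = "MyImageProcessor"
+        tokenizer_class = "MyTokenizer"
+
+        def __init__(self, image_processor, tokenizer, chat_template=None):
+            super().__init__(image_processor, tokenizer, chat_template=chat_template)
+    ```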
+    """
+
+    attributes = ["feature_extractor", "tokenizer"]
+    optional_attributes = ["chat_template", "audio_tokenizer"]
+    optional_call_args: list[str] = []
+    # Names need to be attr_class for attr in attributes
+    feature_extractor_class = None
+    tokenizer_class = None
+    _auto_class = None
+
+    # args have to match the attributes class attribute
+    def __init__(self, *args, **kwargs):
+        # First, extract optional attributes from kwargs if present
+        # Optional attributes can never be positional arguments
+        for optional_attribute in self.optional_attributes:
+            optional_attribute_value = kwargs.pop(optional_attribute, None)
+            setattr(self, optional_attribute, optional_attribute_value)
+
+            # Check audio tokenizer for its class but do not treat it as attr to avoid saving weights
+            if optional_attribute == "audio_tokenizer" and optional_attribute_value is not None:
+                proper_class = self.check_argument_for_proper_class(optional_attribute, optional_attribute_value)
+
+                if not (is_torch_available() and isinstance(optional_attribute_value, PreTrainedAudioTokenizerBase)):
+                    raise ValueError(
+                        f"Tried to use `{proper_class}` for audio tokenization. However, this class is not"
+                        " registered for audio tokenization."
+                    )
+
+        # Sanitize args and kwargs
+        for key in kwargs:
+            if key not in self.attributes:
+                raise TypeError(f"Unexpected keyword argument {key}.")
+        for arg, attribute_name in zip(args, self.attributes):
+            if attribute_name in kwargs:
+                raise TypeError(f"Got multiple values for argument {attribute_name}.")
+            else:
+                kwargs[attribute_name] = arg
+
+        if len(kwargs) != len(self.attributes):
+            raise ValueError(
+                f"This processor requires {len(self.attributes)} arguments: {', '.join(self.attributes)}. Got "
+                f"{len(args)} arguments instead."
+            )
+
+        # Check each arg is of the proper class (this will also catch a user initializing in the wrong order)
+        for attribute_name, arg in kwargs.items():
+            self.check_argument_for_proper_class(attribute_name, arg)
+            setattr(self, attribute_name, arg)
+
+    def check_argument_for_proper_class(self, argument_name, argument):
+        """
+        Checks the passed argument's class against the expected transformers class. In case of an unexpected
+        mismatch between the expected and actual class, an error is raised. Otherwise, the retrieved proper class
+        is returned.
+        """
+        class_name = getattr(self, f"{argument_name}_class")
+        # Nothing is ever going to be an instance of "AutoXxx", in that case we check the base class.
+        class_name = AUTO_TO_BASE_CLASS_MAPPING.get(class_name, class_name)
+        if isinstance(class_name, tuple):
+            proper_class = tuple(self.get_possibly_dynamic_module(n) for n in class_name if n is not None)
+        else:
+            proper_class = self.get_possibly_dynamic_module(class_name)
+
+        if not isinstance(argument, proper_class):
+            raise TypeError(
+                f"Received a {type(argument).__name__} for argument {argument_name}, but a {class_name} was expected."
+            )
+
+        return proper_class
+
+    def to_dict(self, legacy_serialization=True) -> dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            `dict[str, Any]`: Dictionary of all the attributes that make up this processor instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+
+        # Get the kwargs in `__init__`.
+        sig = inspect.signature(self.__init__)
+        # Only save the attributes that are presented in the kwargs of `__init__`.
+        attrs_to_save = list(sig.parameters)
+        # extra attributes to be kept
+        attrs_to_save += ["auto_map"]
+
+        if legacy_serialization:
+            # Don't save attributes like `tokenizer`, `image processor` etc. in the processor config if `legacy_serialization=True`
+            attrs_to_save = [x for x in attrs_to_save if x not in self.__class__.attributes]
+
+        if "tokenizer" in output:
+            del output["tokenizer"]
+        if "qformer_tokenizer" in output:
+            del output["qformer_tokenizer"]
+        if "protein_tokenizer" in output:
+            del output["protein_tokenizer"]
+        if "chat_template" in output:
+            del output["chat_template"]
+
+        # Serialize attributes as a dict
+        output = {
+            k: v.to_dict() if isinstance(v, PushToHubMixin) else v
+            for k, v in output.items()
+            if (
+                k in attrs_to_save  # keep all attributes that have to be serialized
+                and v.__class__.__name__ != "BeamSearchDecoderCTC"  # remove attributes that are objects
+                and (
+                    (legacy_serialization and not isinstance(v, PushToHubMixin)) or not legacy_serialization
+                )  # remove `PushToHubMixin` objects
+            )
+        }
+
+        # Special case, add `audio_tokenizer` dict which points to model weights and path
+        if not legacy_serialization and "audio_tokenizer" in output:
+            audio_tokenizer_dict = {
+                "audio_tokenizer_class": self.audio_tokenizer.__class__.__name__,
+                "audio_tokenizer_name_or_path": self.audio_tokenizer.name_or_path,
+            }
+            # Update or overwrite, what do audio tokenizers expect when loading?
+            output["audio_tokenizer"] = audio_tokenizer_dict
+
+        output["processor_class"] = self.__class__.__name__
+
+        return output
+
+    def to_json_string(self, legacy_serialization=True) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this processor instance in JSON format.
+        """
+        dictionary = self.to_dict(legacy_serialization=legacy_serialization)
+
+        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike], legacy_serialization=True):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this processor instance's parameters will be saved.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string(legacy_serialization=legacy_serialization))
+
+    def __repr__(self):
+        attributes_repr = [f"- {name}: {repr(getattr(self, name))}" for name in self.attributes]
+        attributes_repr = "\n".join(attributes_repr)
+        return f"{self.__class__.__name__}:\n{attributes_repr}\n\n{self.to_json_string()}"
+
+    def save_pretrained(self, save_directory, push_to_hub: bool = False, legacy_serialization: bool = True, **kwargs):
+        """
+        Saves the attributes of this processor (feature extractor, tokenizer...) in the specified directory so that it
+        can be reloaded using the [`~ProcessorMixin.from_pretrained`] method.
+
+        
+
+        This class method is simply calling [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] and
+        [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`]. Please refer to the docstrings of the
+        methods above for more information.
+
+        
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
+                be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            legacy_serialization (`bool`, *optional*, defaults to `True`):
+                Whether or not to save processor attributes in separate config files (legacy) or in the processor's
+                config file as a nested dict. Saving all attributes in a single dict will become the default in future
+                versions; until then, pass `legacy_serialization=False` to opt in to the new format.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
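+
+        Example (a minimal usage sketch; `processor` is assumed to be an already instantiated processor):
+
+        ```python
+        processor.save_pretrained("./my_processor")
+        # The saved files can then be reloaded, e.g. with `AutoProcessor.from_pretrained("./my_processor")`.
+        ```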
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if kwargs.get("token") is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            kwargs["token"] = use_auth_token
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            attrs = [getattr(self, attribute_name) for attribute_name in self.attributes]
+            configs = [(a.init_kwargs if isinstance(a, PreTrainedTokenizerBase) else a) for a in attrs]
+            configs.append(self)
+            custom_object_save(self, save_directory, config=configs)
+
+        save_jinja_files = kwargs.get("save_jinja_files", True)
+
+        for attribute_name in self.attributes:
+            # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
+            if attribute_name == "tokenizer":
+                attribute = getattr(self, attribute_name)
+                if hasattr(attribute, "_set_processor_class"):
+                    attribute._set_processor_class(self.__class__.__name__)
+
+                # Propagate save_jinja_files to tokenizer to ensure we don't get conflicts
+                attribute.save_pretrained(save_directory, save_jinja_files=save_jinja_files)
+            elif legacy_serialization:
+                attribute = getattr(self, attribute_name)
+                # Include the processor class in attribute config so this processor can then be reloaded with `AutoProcessor` API.
+                if hasattr(attribute, "_set_processor_class"):
+                    attribute._set_processor_class(self.__class__.__name__)
+                attribute.save_pretrained(save_directory)
+
+        if self._auto_class is not None:
+            # We added an attribute to the init_kwargs of the tokenizers, which needs to be cleaned up.
+            for attribute_name in self.attributes:
+                attribute = getattr(self, attribute_name)
+                if isinstance(attribute, PreTrainedTokenizerBase):
+                    del attribute.init_kwargs["auto_map"]
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        # plus we save chat_template in its own file
+        output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
+        output_chat_template_file_jinja = os.path.join(save_directory, CHAT_TEMPLATE_FILE)
+        output_chat_template_file_legacy = os.path.join(
+            save_directory, LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE
+        )  # Legacy filename
+        chat_template_dir = os.path.join(save_directory, CHAT_TEMPLATE_DIR)
+
+        # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
+        # to avoid serializing chat template in json config file. So let's get it from `self` directly
+        if self.chat_template is not None:
+            save_jinja_files = kwargs.get("save_jinja_files", True)
+            is_single_template = isinstance(self.chat_template, str)
+            if save_jinja_files and is_single_template:
+                # New format for single templates is to save them as chat_template.jinja
+                with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
+                    f.write(self.chat_template)
+                logger.info(f"chat template saved in {output_chat_template_file_jinja}")
+            elif save_jinja_files and not is_single_template:
+                # New format for multiple templates is to save the default as chat_template.jinja
+                # and the other templates in the chat_templates/ directory
+                for template_name, template in self.chat_template.items():
+                    if template_name == "default":
+                        with open(output_chat_template_file_jinja, "w", encoding="utf-8") as f:
+                            f.write(self.chat_template["default"])
+                        logger.info(f"chat template saved in {output_chat_template_file_jinja}")
+                    else:
+                        os.makedirs(chat_template_dir, exist_ok=True)
+                        template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
+                        with open(template_filepath, "w", encoding="utf-8") as f:
+                            f.write(template)
+                        logger.info(f"chat template saved in {template_filepath}")
+            elif is_single_template:
+                # Legacy format for single templates: Put them in chat_template.json
+                chat_template_json_string = (
+                    json.dumps({"chat_template": self.chat_template}, indent=2, sort_keys=True) + "\n"
+                )
+                with open(output_chat_template_file_legacy, "w", encoding="utf-8") as writer:
+                    writer.write(chat_template_json_string)
+                logger.info(f"chat template saved in {output_chat_template_file_legacy}")
+            elif self.chat_template is not None:
+                # At this point we have multiple templates in the legacy format, which is not supported
+                # chat template dicts are saved to chat_template.json as lists of dicts with fixed key names.
+                raise ValueError(
+                    "Multiple chat templates are not supported in the legacy format. Please save them as "
+                    "separate files using the `save_jinja_files` argument."
+                )
+
+        if legacy_serialization:
+            output_audio_tokenizer_file = os.path.join(save_directory, AUDIO_TOKENIZER_NAME)
+            processor_dict = self.to_dict()
+
+            # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
+            # `auto_map` is not specified.
+            if set(processor_dict.keys()) != {"processor_class"}:
+                self.to_json_file(output_processor_file)
+                logger.info(f"processor saved in {output_processor_file}")
+
+            if set(processor_dict.keys()) == {"processor_class"}:
+                return_files = []
+            else:
+                return_files = [output_processor_file]
+
+            if self.audio_tokenizer is not None:
+                audio_tokenizer_class = self.audio_tokenizer.__class__.__name__
+                audio_tokenizer_name_or_path = self.audio_tokenizer.name_or_path
+                audio_tokenizer_dict = {
+                    "audio_tokenizer_class": audio_tokenizer_class,
+                    "audio_tokenizer_name_or_path": audio_tokenizer_name_or_path,
+                }
+                audio_tokenizer_json = json.dumps(audio_tokenizer_dict, indent=2, sort_keys=True) + "\n"
+                with open(output_audio_tokenizer_file, "w", encoding="utf-8") as writer:
+                    writer.write(audio_tokenizer_json)
+
+        # Create a unified `preprocessor_config.json` and save all attributes as a composite config, except for tokenizers
+        # NOTE: this will become the default way to save all processor attributes in future versions. Toggled off for now to give
+        # us time for a smoother transition
+        else:
+            self.to_json_file(output_processor_file, legacy_serialization=False)
+            logger.info(f"processor saved in {output_processor_file}")
+            return_files = [output_processor_file]
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=kwargs.get("token"),
+            )
+
+        return return_files
+
+    @classmethod
+    def get_processor_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+        processor of type [`~processing_utils.ProcessingMixin`] using `from_args_and_dict`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+
+        Returns:
+            `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the processor object.
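+
+        Example (a minimal sketch; `MyProcessor` is a placeholder for a concrete `ProcessorMixin` subclass and
+        `"org/checkpoint"` is a placeholder repo id):
+
+        ```python
+        processor_dict, remaining_kwargs = MyProcessor.get_processor_dict("org/checkpoint")
+        ```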
+        """
+        # holding a copy for optionally loading the audio tokenizer (if available)
+        audio_tokenizer_kwargs = copy.deepcopy(kwargs)
+
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", "")
+
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        user_agent = {"file_type": "processor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isdir(pretrained_model_name_or_path):
+            processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
+
+        additional_chat_template_files = {}
+        resolved_additional_chat_template_files = {}
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_processor_file = pretrained_model_name_or_path
+            # can't load chat-template and audio tokenizer when given a file as pretrained_model_name_or_path
+            resolved_chat_template_file = None
+            resolved_raw_chat_template_file = None
+            resolved_audio_tokenizer_file = None
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            processor_file = pretrained_model_name_or_path
+            resolved_processor_file = download_url(pretrained_model_name_or_path)
+            # can't load chat-template and audio tokenizer when given a file url as pretrained_model_name_or_path
+            resolved_chat_template_file = None
+            resolved_raw_chat_template_file = None
+            resolved_audio_tokenizer_file = None
+        else:
+            if is_local:
+                template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
+                if template_dir.is_dir():
+                    for template_file in template_dir.glob("*.jinja"):
+                        template_name = template_file.stem
+                        additional_chat_template_files[template_name] = f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
+            else:
+                try:
+                    for template in list_repo_templates(
+                        pretrained_model_name_or_path,
+                        local_files_only=local_files_only,
+                        revision=revision,
+                        cache_dir=cache_dir,
+                    ):
+                        additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
+                except EntryNotFoundError:
+                    pass  # No template dir means no template files
+            processor_file = PROCESSOR_NAME
+
+            try:
+                # Load from local folder or from cache or download from model Hub and cache
+                resolved_processor_file = cached_file(
+                    pretrained_model_name_or_path,
+                    processor_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+
+                # chat_template.json is a legacy file used by the processor class
+                # a raw chat_template.jinja is preferred in future
+                resolved_chat_template_file = cached_file(
+                    pretrained_model_name_or_path,
+                    LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+
+                resolved_raw_chat_template_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CHAT_TEMPLATE_FILE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+
+                resolved_additional_chat_template_files = {
+                    template_name: cached_file(
+                        pretrained_model_name_or_path,
+                        template_file,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _raise_exceptions_for_missing_entries=False,
+                    )
+                    for template_name, template_file in additional_chat_template_files.items()
+                }
+
+                resolved_audio_tokenizer_file = cached_file(
+                    pretrained_model_name_or_path,
+                    AUDIO_TOKENIZER_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+            except OSError:
+                # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
+                # the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise OSError(
+                    f"Can't load processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a {PROCESSOR_NAME} file"
+                )
+
+        # Add chat template as kwarg before returning because most models don't have processor config
+        if resolved_chat_template_file is not None:
+            # This is the legacy path
+            with open(resolved_chat_template_file, encoding="utf-8") as reader:
+                chat_template_json = json.loads(reader.read())
+                chat_templates = {"default": chat_template_json["chat_template"]}
+                if resolved_additional_chat_template_files:
+                    raise ValueError(
+                        "Cannot load chat template due to conflicting files - this checkpoint combines "
+                        "a legacy chat_template.json file with separate template files, which is not "
+                        "supported. To resolve this error, replace the legacy chat_template.json file "
+                        "with a modern chat_template.jinja file."
+                    )
+        else:
+            chat_templates = {
+                template_name: open(template_file, "r", encoding="utf-8").read()
+                for template_name, template_file in resolved_additional_chat_template_files.items()
+            }
+            if resolved_raw_chat_template_file is not None:
+                with open(resolved_raw_chat_template_file, "r", encoding="utf-8") as reader:
+                    chat_templates["default"] = reader.read()
+        if isinstance(chat_templates, dict) and "default" in chat_templates and len(chat_templates) == 1:
+            chat_templates = chat_templates["default"]  # Flatten when we just have a single template/file
+
+        if chat_templates:
+            kwargs["chat_template"] = chat_templates
+
+        # Existing processors on the Hub created before #27761 was merged don't have `processor_config.json` (if not
+        # updated afterward), and we need to keep `from_pretrained` working. So here it falls back to an empty dict.
+        # (`cached_file` is called with `_raise_exceptions_for_missing_entries=False` to avoid an exception.)
+        # However, for models added in the future, we won't get the expected error if this file is missing.
+        if resolved_processor_file is None:
+            # In any case we need to pass `chat_template` if it is available
+            processor_dict = {}
+        else:
+            try:
+                # Load processor dict
+                with open(resolved_processor_file, encoding="utf-8") as reader:
+                    text = reader.read()
+                processor_dict = json.loads(text)
+
+            except json.JSONDecodeError:
+                raise OSError(
+                    f"It looks like the config file at '{resolved_processor_file}' is not a valid JSON file."
+                )
+
+        if is_local:
+            logger.info(f"loading configuration file {resolved_processor_file}")
+        else:
+            logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")
+
+        if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
+            logger.warning_once(
+                "Chat templates should be in a 'chat_template.jinja' file but found key='chat_template' "
+                "in the processor's config. Make sure to move your template to its own file."
+            )
+
+        if "chat_template" in kwargs:
+            processor_dict["chat_template"] = kwargs.pop("chat_template")
+
+        # Audio tokenizer needs to load the model checkpoint first, because the saved
+        # json file contains only references to the model path and repo id
+        if resolved_audio_tokenizer_file is not None or "audio_tokenizer" in processor_dict:
+            if resolved_audio_tokenizer_file is not None:
+                with open(resolved_audio_tokenizer_file, "r", encoding="utf-8") as reader:
+                    audio_tokenizer_dict = json.loads(reader.read())
+            else:
+                audio_tokenizer_dict = processor_dict["audio_tokenizer"]
+
+            audio_tokenizer_class = cls.get_possibly_dynamic_module(audio_tokenizer_dict["audio_tokenizer_class"])
+            audio_tokenizer_path = audio_tokenizer_dict["audio_tokenizer_name_or_path"]
+            processor_dict["audio_tokenizer"] = audio_tokenizer_class.from_pretrained(
+                audio_tokenizer_path, **audio_tokenizer_kwargs
+            )
+
+        # Pop attributes if saved in a single processor dict, they are loaded in `_get_arguments_from_pretrained`
+        for attribute in cls.attributes:
+            processor_dict.pop(attribute, None)
+
+        return processor_dict, kwargs
+
+    @classmethod
+    def from_args_and_dict(cls, args, processor_dict: dict[str, Any], **kwargs):
+        """
+        Instantiates a type of [`~processing_utils.ProcessingMixin`] from a Python dictionary of parameters.
+
+        Args:
+            processor_dict (`dict[str, Any]`):
+                Dictionary that will be used to instantiate the processor object. Such a dictionary can be
+                retrieved from a pretrained checkpoint by leveraging the
+                [`~processing_utils.ProcessingMixin.to_dict`] method.
+            kwargs (`dict[str, Any]`):
+                Additional parameters from which to initialize the processor object.
+
+        Returns:
+            [`~processing_utils.ProcessingMixin`]: The processor object instantiated from those
+            parameters.
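+
+        Example (a minimal sketch; `MyProcessor`, `image_processor` and `tokenizer` are placeholders for a
+        concrete processor class and its already-loaded sub-processors):
+
+        ```python
+        processor_dict, kwargs = MyProcessor.get_processor_dict("org/checkpoint")
+        processor = MyProcessor.from_args_and_dict([image_processor, tokenizer], processor_dict, **kwargs)
+        ```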
+        """
+        processor_dict = processor_dict.copy()
+        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+        # We have to pop some specific (but unused) kwargs before validating the rest.
+        # If we don't pop them, these specific kwargs would be reported as unused and trigger a warning.
+        if "processor_class" in processor_dict:
+            del processor_dict["processor_class"]
+
+        if "auto_map" in processor_dict:
+            del processor_dict["auto_map"]
+
+        # override processor_dict with given kwargs
+        processor_dict.update(kwargs)
+
+        # check if there is an overlap between args and processor_dict
+        accepted_args_and_kwargs = cls.__init__.__code__.co_varnames[: cls.__init__.__code__.co_argcount][1:]
+
+        # validate both processor_dict and given kwargs
+        unused_kwargs, valid_kwargs = cls.validate_init_kwargs(
+            processor_config=processor_dict, valid_kwargs=accepted_args_and_kwargs
+        )
+
+        # update args that are already in processor_dict to avoid duplicate arguments
+        args_to_update = {
+            i: valid_kwargs.pop(arg)
+            for i, arg in enumerate(accepted_args_and_kwargs)
+            if (arg in valid_kwargs and i < len(args))
+        }
+        args = [args_to_update.get(i, arg) for i, arg in enumerate(args)]
+
+        # instantiate processor with used (and valid) kwargs only
+        processor = cls(*args, **valid_kwargs)
+
+        logger.info(f"Processor {processor}")
+        if return_unused_kwargs:
+            return processor, unused_kwargs
+        else:
+            return processor
+
+    def _merge_kwargs(
+        self,
+        ModelProcessorKwargs: ProcessingKwargs,
+        tokenizer_init_kwargs: Optional[dict] = None,
+        **kwargs,
+    ) -> dict[str, dict]:
+        """
+        Method to merge dictionaries of kwargs cleanly separated by modality within a Processor instance.
+        The order of operations is as follows:
+            1) kwargs passed as before have highest priority to preserve BC.
+                ```python
+                high_priority_kwargs = {"crop_size": {"height": 222, "width": 222}, "padding": "max_length"}
+                processor(..., **high_priority_kwargs)
+                ```
+            2) kwargs passed as modality-specific kwargs have second priority. This is the recommended API.
+                ```python
+                processor(..., text_kwargs={"padding": "max_length"}, images_kwargs={"crop_size": {"height": 222, "width": 222}})
+                ```
+            3) kwargs passed during instantiation of a modality processor have third priority.
+                ```python
+                tokenizer = tokenizer_class(..., padding="max_length")
+                image_processor = image_processor_class(...)
+                processor(tokenizer, image_processor) # will pass max_length unless overridden by kwargs at call
+                ```
+            4) default kwargs specified at the processor level have the lowest priority.
+                ```python
+                class MyProcessingKwargs(ProcessingKwargs, CommonKwargs, TextKwargs, ImagesKwargs, total=False):
+                    _defaults = {
+                        "text_kwargs": {
+                            "padding": "max_length",
+                            "max_length": 64,
+                        },
+                    }
+                ```
+        Args:
+            ModelProcessorKwargs (`ProcessingKwargs`):
+                Typed dictionary of kwargs specifically required by the model passed.
+            tokenizer_init_kwargs (`Dict`, *optional*):
+                Dictionary of kwargs the tokenizer was instantiated with and need to take precedence over defaults.
+
+        Returns:
+            output_kwargs (`Dict`):
+                Dictionary of per-modality kwargs to be passed to each modality-specific processor.
+
+        """
+        # Initialize dictionaries
+        output_kwargs = {
+            "text_kwargs": {},
+            "images_kwargs": {},
+            "audio_kwargs": {},
+            "videos_kwargs": {},
+            "common_kwargs": {},
+        }
+
+        default_kwargs = {
+            "text_kwargs": {},
+            "images_kwargs": {},
+            "audio_kwargs": {},
+            "videos_kwargs": {},
+            "common_kwargs": {},
+        }
+
+        possible_modality_keywords = {"text", "audio", "videos", "images"}
+        used_keys = set()
+
+        # get defaults from set model processor kwargs if they exist
+        for modality in default_kwargs:  # noqa: PLC0206
+            default_kwargs[modality] = ModelProcessorKwargs._defaults.get(modality, {}).copy()
+            # update defaults with arguments from tokenizer init
+            for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__:
+                # init with tokenizer init kwargs if necessary
+                if tokenizer_init_kwargs is not None and modality_key in tokenizer_init_kwargs:
+                    value = (
+                        getattr(self.tokenizer, modality_key)
+                        if hasattr(self.tokenizer, modality_key)
+                        else tokenizer_init_kwargs[modality_key]
+                    )
+                    default_kwargs[modality][modality_key] = value
+        # now defaults kwargs are updated with the tokenizers defaults.
+        # pass defaults to output dictionary
+        output_kwargs.update(default_kwargs)
+
+        # update modality kwargs with passed kwargs
+        non_modality_kwargs = set(kwargs) - set(output_kwargs)
+        for modality, output_kwarg in output_kwargs.items():
+            for modality_key in ModelProcessorKwargs.__annotations__[modality].__annotations__:
+                # check if we received a structured kwarg dict or not to handle it correctly
+                if modality in kwargs:
+                    kwarg_value = kwargs[modality].pop(modality_key, "__empty__")
+                    # check if this key was passed as a flat kwarg.
+                    if kwarg_value != "__empty__" and modality_key in non_modality_kwargs:
+                        raise ValueError(
+                            f"Keyword argument {modality_key} was passed two times:\n"
+                            f"in a dictionary for {modality} and as a **kwarg."
+                        )
+                elif modality_key in kwargs:
+                    # we get a modality_key instead of popping it because modality-specific processors
+                    # can have overlapping kwargs
+                    kwarg_value = kwargs.get(modality_key, "__empty__")
+                else:
+                    kwarg_value = "__empty__"
+                if not isinstance(kwarg_value, str) or kwarg_value != "__empty__":
+                    output_kwarg[modality_key] = kwarg_value
+                    used_keys.add(modality_key)
+
+        # Determine if kwargs is a flat dictionary or contains nested dictionaries
+        if any(key in default_kwargs for key in kwargs):
+            # kwargs is dictionary-based, and some keys match modality names
+            for modality, subdict in kwargs.items():
+                if modality in default_kwargs:
+                    for subkey, subvalue in subdict.items():
+                        if subkey not in used_keys:
+                            output_kwargs[modality][subkey] = subvalue
+                            used_keys.add(subkey)
+        else:
+            # kwargs is a flat dictionary
+            for key, kwarg in kwargs.items():
+                if key not in used_keys:
+                    if key in ModelProcessorKwargs.__annotations__["common_kwargs"].__annotations__:
+                        output_kwargs["common_kwargs"][key] = kwarg
+                    elif key not in possible_modality_keywords:
+                        logger.warning_once(
+                            f"Keyword argument `{key}` is not a valid argument for this processor and will be ignored."
+                        )
+
+        # all modality-specific kwargs are updated with common kwargs
+        for kwarg in output_kwargs.values():
+            kwarg.update(output_kwargs["common_kwargs"])
+        return output_kwargs
+
+    @classmethod
+    def from_pretrained(
+        cls: type[SpecificProcessorType],
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        **kwargs,
+    ) -> SpecificProcessorType:
+        r"""
+        Instantiate a processor associated with a pretrained model.
+
+        
+
+        This class method is simply calling the feature extractor
+        [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`], image processor
+        [`~image_processing_utils.ImageProcessingMixin`] and the tokenizer
+        [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] methods. Please refer to the docstrings of the
+        methods above for more information.
+
+        
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                This can be either:
+
+                - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
+                  huggingface.co.
+                - a path to a *directory* containing a feature extractor file saved using the
+                  [`~SequenceFeatureExtractor.save_pretrained`] method, e.g., `./my_model_directory/`.
+                - a path or url to a saved feature extractor JSON *file*, e.g.,
+                  `./my_model_directory/preprocessor_config.json`.
+            **kwargs
+                Additional keyword arguments passed along to both
+                [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] and
+                [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`].
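+
+        Example (a minimal usage sketch with a placeholder checkpoint name):
+
+        ```python
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained("org/checkpoint")
+        ```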
+        """
+        kwargs["cache_dir"] = cache_dir
+        kwargs["force_download"] = force_download
+        kwargs["local_files_only"] = local_files_only
+        kwargs["revision"] = revision
+
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        args = cls._get_arguments_from_pretrained(pretrained_model_name_or_path, **kwargs)
+        processor_dict, kwargs = cls.get_processor_dict(pretrained_model_name_or_path, **kwargs)
+        return cls.from_args_and_dict(args, processor_dict, **kwargs)
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoProcessor"):
+        """
+        Register this class with a given auto class. This should only be used for custom feature extractors as the ones
+        in the library are already mapped with `AutoProcessor`.
+
+
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoProcessor"`):
+                The auto class to register this new feature extractor with.
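+
+        Example (a minimal sketch for a custom processor, here named `MyCustomProcessor` for illustration):
+
+        ```python
+        class MyCustomProcessor(ProcessorMixin):
+            ...
+
+        MyCustomProcessor.register_for_auto_class()  # registers with `AutoProcessor` by default
+        ```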
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+    @classmethod
+    def _get_arguments_from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Identify and instantiate the subcomponents of Processor classes, like image processors and
+        tokenizers. This method uses the Processor attributes like `tokenizer_class` to figure out what class those
+        subcomponents should be. Note that any subcomponents must either be library classes that are accessible in
+        the `transformers` root, or they must be custom code that has been registered with the relevant autoclass,
+        via methods like `AutoTokenizer.register()`. If neither of these conditions are fulfilled, this method
+        will be unable to find the relevant subcomponent class and will raise an error.
+        """
+        args = []
+        for attribute_name in cls.attributes:
+            class_name = getattr(cls, f"{attribute_name}_class")
+            if isinstance(class_name, tuple):
+                classes = tuple(cls.get_possibly_dynamic_module(n) if n is not None else None for n in class_name)
+                if attribute_name == "image_processor":
+                    # TODO: @yoni, change logic in v4.52 (when use_fast set to True by default)
+                    use_fast = kwargs.get("use_fast")
+                    if use_fast is None:
+                        logger.warning_once(
+                            "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. "
+                            "`use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. "
+                            "This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`."
+                        )
+                else:
+                    use_fast = kwargs.get("use_fast", True)
+                if use_fast and classes[1] is not None:
+                    attribute_class = classes[1]
+                else:
+                    attribute_class = classes[0]
+            else:
+                attribute_class = cls.get_possibly_dynamic_module(class_name)
+
+            args.append(attribute_class.from_pretrained(pretrained_model_name_or_path, **kwargs))
+
+        return args
+
+    @staticmethod
+    def get_possibly_dynamic_module(module_name):
+        if hasattr(transformers_module, module_name):
+            return getattr(transformers_module, module_name)
+        lookup_locations = [
+            transformers_module.IMAGE_PROCESSOR_MAPPING,
+            transformers_module.VIDEO_PROCESSOR_MAPPING,
+            transformers_module.TOKENIZER_MAPPING,
+            transformers_module.FEATURE_EXTRACTOR_MAPPING,
+            transformers_module.MODEL_FOR_AUDIO_TOKENIZATION_MAPPING,
+        ]
+        for lookup_location in lookup_locations:
+            for custom_class in lookup_location._extra_content.values():
+                if isinstance(custom_class, tuple):
+                    for custom_subclass in custom_class:
+                        if custom_subclass is not None and custom_subclass.__name__ == module_name:
+                            return custom_subclass
+                elif custom_class is not None and custom_class.__name__ == module_name:
+                    return custom_class
+        raise ValueError(
+            f"Could not find module {module_name} in `transformers`. If this is a custom class, "
+            f"it should be registered using the relevant `AutoClass.register()` function so that "
+            f"other functions can find it!"
+        )
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
+        refer to the docstring of this method for more information.
+        """
+        if not hasattr(self, "tokenizer"):
+            raise ValueError(f"Cannot batch decode text: {self.__class__.__name__} has no tokenizer.")
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
+        the docstring of this method for more information.
+        """
+        if not hasattr(self, "tokenizer"):
+            raise ValueError(f"Cannot decode text: {self.__class__.__name__} has no tokenizer.")
+        return self.tokenizer.decode(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        model_input_names = []
+        for attribute_name in self.attributes:
+            attribute = getattr(self, attribute_name, None)
+            attr_input_names = getattr(attribute, "model_input_names")
+            model_input_names.extend(attr_input_names)
+        return model_input_names
+
+    @staticmethod
+    def validate_init_kwargs(processor_config, valid_kwargs):
+        kwargs_from_config = set(processor_config.keys())
+        valid_kwargs_set = set(valid_kwargs)
+
+        unused_keys = kwargs_from_config - valid_kwargs_set
+        valid_keys = kwargs_from_config & valid_kwargs_set
+
+        unused_kwargs = {k: processor_config[k] for k in unused_keys} if unused_keys else {}
+        valid_kwargs = {k: processor_config[k] for k in valid_keys} if valid_keys else {}
+
+        return unused_kwargs, valid_kwargs
+
+    @deprecate_kwarg("video_fps", version="4.58", new_name="fps")
+    @deprecate_kwarg(
+        "video_load_backend",
+        version="4.59",
+        additional_message=". This function will use `torchcodec` by default, or `torchvision` if `torchcodec` is not installed.",
+    )
+    def apply_chat_template(
+        self,
+        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+        chat_template: Optional[str] = None,
+        **kwargs: Unpack[AllKwargsForChatTemplate],
+    ) -> str:
+        """
+        Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
+        conversations to turn them into a single tokenizable string.
+
+        The input is expected to be in the following format, where each message content is a list consisting of text and
+        optionally image or video inputs. One can also provide an image, video, URL or local path, which will be used to form
+        `pixel_values` when `return_dict=True`. If no media are provided, the output contains only the formatted text, optionally tokenized.
+
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+                    {"type": "text", "text": "Please describe this image in detail."},
+                ],
+            },
+        ]
+
+        Args:
+            conversation (`Union[list[dict[str, str]], list[list[dict[str, str]]]]`):
+                The conversation to format.
+            chat_template (`Optional[str]`, *optional*):
+                The Jinja template to use for formatting the conversation. If not provided, the processor's
+                chat template is used.
+        """
+        if chat_template is None:
+            if isinstance(self.chat_template, dict) and "default" in self.chat_template:
+                chat_template = self.chat_template["default"]
+            elif isinstance(self.chat_template, dict):
+                raise ValueError(
+                    'The processor has multiple chat templates but none of them are named "default". You need to specify'
+                    " which one to use by passing the `chat_template` argument. Available templates are: "
+                    f"{', '.join(self.chat_template.keys())}"
+                )
+            elif self.chat_template is not None:
+                chat_template = self.chat_template
+            else:
+                raise ValueError(
+                    "Cannot use apply_chat_template because this processor does not have a chat template."
+                )
+        else:
+            if isinstance(self.chat_template, dict) and chat_template in self.chat_template:
+                # It's the name of a template, not a full template string
+                chat_template = self.chat_template[chat_template]
+            else:
+                # It's a template string, render it directly
+                pass
+
+        is_tokenizers_fast = hasattr(self, "tokenizer") and self.tokenizer.__class__.__name__.endswith("Fast")
+
+        if kwargs.get("continue_final_message", False):
+            if kwargs.get("add_generation_prompt", False):
+                raise ValueError(
+                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
+                )
+            if kwargs.get("return_assistant_tokens_mask", False):
+                raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
+
+        if kwargs.get("return_assistant_tokens_mask", False):
+            if not is_tokenizers_fast:
+                raise ValueError(
+                    "`return_assistant_tokens_mask` is not possible with slow tokenizers. Make sure you have `tokenizers` installed. "
+                    "If the error persists, open an issue to support a Fast tokenizer for your model."
+                )
+            else:
+                kwargs["return_offsets_mapping"] = True  # force offset mapping so we can infer token boundaries
+
+        # Fill sets of kwargs that should be used by different parts of template
+        processed_kwargs = {
+            "mm_load_kwargs": {},
+            "template_kwargs": {},
+        }
+
+        for kwarg_type in processed_kwargs:
+            for key in AllKwargsForChatTemplate.__annotations__[kwarg_type].__annotations__:
+                kwarg_type_defaults = AllKwargsForChatTemplate.__annotations__[kwarg_type]
+                default_value = getattr(kwarg_type_defaults, key, None)
+                value = kwargs.pop(key, default_value)
+                if value is not None and not isinstance(value, dict):
+                    processed_kwargs[kwarg_type][key] = value
+
+        # pop unused and deprecated kwarg
+        kwargs.pop("video_load_backend", None)
+
+        # Pass unprocessed custom kwargs
+        processed_kwargs["template_kwargs"].update(kwargs)
+
+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "content")
+        ):
+            is_batched = True
+            conversations = conversation
+        else:
+            is_batched = False
+            conversations = [conversation]
+
+        tokenize = processed_kwargs["template_kwargs"].pop("tokenize", False)
+        return_dict = processed_kwargs["template_kwargs"].pop("return_dict", False)
+        mm_load_kwargs = processed_kwargs["mm_load_kwargs"]
+
+        if tokenize:
+            batch_images, batch_videos = [], []
+            batch_audios = []
+            for conversation in conversations:
+                for message in conversation:
+                    visuals = [content for content in message["content"] if content["type"] in ["image", "video"]]
+                    audio_fnames = [
+                        content[key]
+                        for content in message["content"]
+                        for key in ["audio", "url", "path"]
+                        if key in content and content["type"] == "audio"
+                    ]
+                    image_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["image", "url", "path", "base64"]
+                        if key in vision_info and vision_info["type"] == "image"
+                    ]
+                    video_fnames = [
+                        vision_info[key]
+                        for vision_info in visuals
+                        for key in ["video", "url", "path"]
+                        if key in vision_info and vision_info["type"] == "video"
+                    ]
+
+                    # Audio models do not accept nested list of audios (yet!) so we construct a flat input audio list
+                    if not mm_load_kwargs["load_audio_from_video"]:
+                        for fname in audio_fnames:
+                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+                    else:
+                        for fname in video_fnames:
+                            batch_audios.append(load_audio(fname, sampling_rate=mm_load_kwargs["sampling_rate"]))
+
+                    # Currently all processors can accept a nested list of batches, but not a flat list of visuals,
+                    # so we build a batched list of images and let the processor handle it
+                    if image_fnames:
+                        batch_images.append(image_fnames)
+                    if video_fnames:
+                        batch_videos.append(video_fnames)
+
+        prompt, generation_indices = render_jinja_template(
+            conversations=conversations,
+            chat_template=chat_template,
+            **processed_kwargs["template_kwargs"],  # different flags such as `return_assistant_mask`
+            **self.tokenizer.special_tokens_map,  # tokenizer special tokens are used by some templates
+        )
+
+        if not is_batched:
+            prompt = prompt[0]
+
+        if tokenize:
+            # Tokenizer's `apply_chat_template` never adds special tokens when tokenizing.
+            # But processor's `apply_chat_template` didn't have an option to tokenize, so users had to format the prompt
+            # and pass it to the processor. Users thus never worried about special tokens, relying on the processor handling
+            # everything internally. The line below keeps BC for that and makes it possible to work with models that have
+            # special tokens in the template (consistent with tokenizers). We don't want to raise a warning: it would flood
+            # the command line without an actionable solution for users.
+            single_prompt = prompt[0] if is_batched else prompt
+            if self.tokenizer.bos_token is not None and single_prompt.startswith(self.tokenizer.bos_token):
+                kwargs["add_special_tokens"] = False
+
+            # Always sample frames by default unless explicitly set to `False` by users. If users do not pass
+            # `num_frames`/`video_fps`, sampling should not be done, for backward compatibility.
+            if "do_sample_frames" not in kwargs and ("fps" in kwargs or "num_frames" in kwargs):
+                kwargs["do_sample_frames"] = True
+
+            out = self(
+                text=prompt,
+                images=batch_images if batch_images else None,
+                videos=batch_videos if batch_videos else None,
+                audio=batch_audios if batch_audios else None,
+                **kwargs,
+            )
+
+            if return_dict:
+                if processed_kwargs["template_kwargs"].get("return_assistant_tokens_mask", False):
+                    assistant_masks = []
+                    offset_mapping = out.pop("offset_mapping")
+                    input_ids = out["input_ids"]
+                    for i in range(len(input_ids)):
+                        current_mask = [0] * len(input_ids[i])
+                        offsets = offset_mapping[i]
+                        offset_starts = [start for start, end in offsets]
+                        for assistant_start_char, assistant_end_char in generation_indices[i]:
+                            start_pos = bisect.bisect_left(offset_starts, assistant_start_char)
+                            end_pos = bisect.bisect_left(offset_starts, assistant_end_char)
+
+                            if not (
+                                start_pos >= 0
+                                and offsets[start_pos][0] <= assistant_start_char < offsets[start_pos][1]
+                            ):
+                                # start position is out of bounds, maybe due to truncation.
+                                continue
+                            for token_id in range(start_pos, end_pos if end_pos else len(input_ids[i])):
+                                current_mask[token_id] = 1
+                        assistant_masks.append(current_mask)
+                    out["assistant_masks"] = assistant_masks
+                    out.convert_to_tensors(tensor_type=kwargs.get("return_tensors"))
+                return out
+            else:
+                return out["input_ids"]
+        return prompt
+
+    def post_process_image_text_to_text(self, generated_outputs, skip_special_tokens=True, **kwargs):
+        """
+        Post-process the output of a VLM to decode the text.
+
+        Args:
+            generated_outputs (`torch.Tensor` or `np.ndarray`):
+                The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
+                or `(sequence_length,)`.
+            skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
+            **kwargs:
+                Additional arguments to be passed to the tokenizer's `batch_decode` method.
+
+        Returns:
+            `list[str]`: The decoded text.
+        """
+        return self.tokenizer.batch_decode(generated_outputs, skip_special_tokens=skip_special_tokens, **kwargs)
+
+    def _check_special_mm_tokens(self, text: list[str], text_inputs: "BatchFeature", modalities: list[str]):
+        """
+        Checks that the number of special tokens in the text and in the processed text is the same. The counts can differ
+        if the tokenized text was truncated, which leads to issues in the model code.
+        """
+        for modality in modalities:
+            token_str = getattr(self, f"{modality}_token")
+            token_id = getattr(self, f"{modality}_token_id")
+            ids_count = [list(ids).count(token_id) for ids in text_inputs["input_ids"]]
+            text_count = [sample.count(token_str) for sample in text]
+
+            if ids_count != text_count:
+                raise ValueError(
+                    f"Mismatch in `{modality}` token count between text and `input_ids`. Got ids={ids_count} and text={text_count}. "
+                    "Likely due to `truncation='max_length'`. Please disable truncation or increase `max_length`."
+                )
+
+
+ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
+if ProcessorMixin.push_to_hub.__doc__ is not None:
+    ProcessorMixin.push_to_hub.__doc__ = ProcessorMixin.push_to_hub.__doc__.format(
+        object="processor", object_class="AutoProcessor", object_files="processor files"
+    )
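+
+
+# Hedged usage sketch for `ProcessorMixin.apply_chat_template` (the checkpoint name and image URL below are
+# placeholders, not real assets); it mirrors the conversation format documented in the method's docstring:
+#
+#     processor = AutoProcessor.from_pretrained("some-org/some-vlm")
+#     messages = [
+#         {
+#             "role": "user",
+#             "content": [
+#                 {"type": "image", "url": "https://example.com/cat.png"},
+#                 {"type": "text", "text": "What is in this image?"},
+#             ],
+#         }
+#     ]
+#     inputs = processor.apply_chat_template(
+#         messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
+#     )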
diff --git a/phivenv/Lib/site-packages/transformers/py.typed b/phivenv/Lib/site-packages/transformers/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/phivenv/Lib/site-packages/transformers/pytorch_utils.py b/phivenv/Lib/site-packages/transformers/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..87136d079f104f28480a6fa8183dcff3640e63a3
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/pytorch_utils.py
@@ -0,0 +1,380 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import inspect
+from functools import lru_cache, wraps
+from typing import Callable
+
+import torch
+from safetensors.torch import storage_ptr, storage_size
+from torch import nn
+
+from .utils import (
+    is_torch_greater_or_equal,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    is_torchdynamo_compiling,
+    logging,
+)
+
+
+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+
+logger = logging.get_logger(__name__)
+
+is_torch_greater_or_equal_than_2_8 = is_torch_greater_or_equal("2.8", accept_dev=True)
+is_torch_greater_or_equal_than_2_6 = is_torch_greater_or_equal("2.6", accept_dev=True)
+is_torch_greater_or_equal_than_2_4 = is_torch_greater_or_equal("2.4", accept_dev=True)
+is_torch_greater_or_equal_than_2_3 = is_torch_greater_or_equal("2.3", accept_dev=True)
+
+# For backwards compatibility (e.g. some remote codes on Hub using those variables).
+is_torch_greater_or_equal_than_2_2 = is_torch_greater_or_equal("2.2", accept_dev=True)
+is_torch_greater_or_equal_than_2_1 = is_torch_greater_or_equal("2.1", accept_dev=True)
+is_torch_greater_or_equal_than_2_0 = is_torch_greater_or_equal("2.0", accept_dev=True)
+is_torch_greater_or_equal_than_1_13 = is_torch_greater_or_equal("1.13", accept_dev=True)
+is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_dev=True)
+
+# Cache this result as it's a C FFI call which can be pretty time-consuming
+_torch_distributed_available = torch.distributed.is_available()
+
+
+def softmax_backward_data(parent, grad_output, output, dim, self):
+    """
+    A function that calls the internal `_softmax_backward_data` PyTorch method and that adjusts the arguments according
+    to the torch version detected.
+    """
+
+    from torch import _softmax_backward_data
+
+    return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
+
+
+def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
+    """
+    Prune a linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`torch.nn.Linear`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).detach().clone()
+    if layer.bias is not None:
+        if dim == 1:
+            b = layer.bias.detach().clone()
+        else:
+            b = layer.bias[index].detach().clone()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    if layer.bias is not None:
+        new_layer.bias.requires_grad = False
+        new_layer.bias.copy_(b.contiguous())
+        new_layer.bias.requires_grad = True
+    return new_layer
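+
+
+# Minimal sketch (not part of the library): pruning a hypothetical 12-head, 768-dim projection down to 11 heads
+# by keeping the first 704 rows. The shapes below follow directly from `prune_linear_layer`'s behavior.
+def _example_prune_linear_layer() -> torch.Size:
+    layer = nn.Linear(768, 768)
+    index = torch.arange(11 * 64)  # keep heads 0..10, drop head 11 (64 dims per head)
+    pruned = prune_linear_layer(layer, index, dim=0)
+    return pruned.weight.shape  # torch.Size([704, 768])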
+
+
+class Conv1D(nn.Module):
+    """
+    1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
+
+    Basically works like a linear layer but the weights are transposed.
+
+    Args:
+        nf (`int`): The number of output features.
+        nx (`int`): The number of input features.
+    """
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        self.nx = nx
+        self.weight = nn.Parameter(torch.empty(nx, nf))
+        self.bias = nn.Parameter(torch.zeros(nf))
+        nn.init.normal_(self.weight, std=0.02)
+
+    def __repr__(self) -> str:
+        return "Conv1D(nf={nf}, nx={nx})".format(**self.__dict__)
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
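+
+
+# Illustrative check (an assumption, not part of the library): `Conv1D(nf, nx)` behaves like a linear layer whose
+# weight is stored transposed, mapping the last dimension from `nx` to `nf`.
+def _example_conv1d_equivalence() -> bool:
+    conv = Conv1D(nf=3072, nx=768)
+    x = torch.randn(2, 5, 768)
+    out = conv(x)  # shape (2, 5, 3072)
+    manual = x @ conv.weight + conv.bias
+    return out.shape == (2, 5, 3072) and torch.allclose(out, manual, atol=1e-6)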
+
+
+def prune_conv1d_layer(layer: Conv1D, index: torch.LongTensor, dim: int = 1) -> Conv1D:
+    """
+    Prune a Conv1D layer to keep only entries in index. A Conv1D work as a Linear layer (see e.g. BERT) but the weights
+    are transposed.
+
+    Used to remove heads.
+
+    Args:
+        layer ([`~pytorch_utils.Conv1D`]): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*, defaults to 1): The dimension on which to keep the indices.
+
+    Returns:
+        [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    index = index.to(layer.weight.device)
+    W = layer.weight.index_select(dim, index).detach().clone()
+    if dim == 0:
+        b = layer.bias.detach().clone()
+    else:
+        b = layer.bias[index].detach().clone()
+    new_size = list(layer.weight.size())
+    new_size[dim] = len(index)
+    new_layer = Conv1D(new_size[1], new_size[0]).to(layer.weight.device)
+    new_layer.weight.requires_grad = False
+    new_layer.weight.copy_(W.contiguous())
+    new_layer.weight.requires_grad = True
+    new_layer.bias.requires_grad = False
+    new_layer.bias.copy_(b.contiguous())
+    new_layer.bias.requires_grad = True
+    return new_layer
+
+
+def prune_layer(layer: nn.Linear | Conv1D, index: torch.LongTensor, dim: int | None = None) -> nn.Linear | Conv1D:
+    """
+    Prune a Conv1D or linear layer to keep only entries in index.
+
+    Used to remove heads.
+
+    Args:
+        layer (`Union[torch.nn.Linear, Conv1D]`): The layer to prune.
+        index (`torch.LongTensor`): The indices to keep in the layer.
+        dim (`int`, *optional*): The dimension on which to keep the indices.
+
+    Returns:
+        `torch.nn.Linear` or [`~pytorch_utils.Conv1D`]: The pruned layer as a new layer with `requires_grad=True`.
+    """
+    if isinstance(layer, nn.Linear):
+        return prune_linear_layer(layer, index, dim=0 if dim is None else dim)
+    elif isinstance(layer, Conv1D):
+        return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
+    else:
+        raise ValueError(f"Can't prune layer of class {layer.__class__}")
+
+
+def apply_chunking_to_forward(
+    forward_fn: Callable[..., torch.Tensor],
+    chunk_size: int,
+    chunk_dim: int,
+    *input_tensors,
+) -> torch.Tensor:
+    """
+    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
+    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
+
+    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
+    applying `forward_fn` to `input_tensors`.
+
+    Args:
+        forward_fn (`Callable[..., torch.Tensor]`):
+            The forward function of the model.
+        chunk_size (`int`):
+            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
+        chunk_dim (`int`):
+            The dimension over which the `input_tensors` should be chunked.
+        input_tensors (`tuple[torch.Tensor]`):
+            The input tensors of `forward_fn` which will be chunked
+
+    Returns:
+        `torch.Tensor`: A tensor with the same shape as the one `forward_fn` would have given if applied directly.
+
+
+    Examples:
+
+    ```python
+    # rename the usual forward() fn to forward_chunk()
+    def forward_chunk(self, hidden_states):
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+    # implement a chunked forward function
+    def forward(self, hidden_states):
+        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
+    ```"""
+
+    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
+
+    # inspect.signature has existed since Python 3.5 and is a Python method -> no problem with backward compatibility
+    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+    if num_args_in_forward_chunk_fn != len(input_tensors):
+        raise ValueError(
+            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
+            "tensors are given"
+        )
+
+    if chunk_size > 0:
+        tensor_shape = input_tensors[0].shape[chunk_dim]
+        for input_tensor in input_tensors:
+            if input_tensor.shape[chunk_dim] != tensor_shape:
+                raise ValueError(
+                    f"All input tenors have to be of the same shape: {tensor_shape}, "
+                    f"found shape {input_tensor.shape[chunk_dim]}"
+                )
+
+        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
+            raise ValueError(
+                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
+                f"size {chunk_size}"
+            )
+
+        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+        # chunk input tensor into tuples
+        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+        # apply forward fn to every tuple
+        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+        # concatenate output at same dimension
+        return torch.cat(output_chunks, dim=chunk_dim)
+
+    return forward_fn(*input_tensors)
+
+
+def find_pruneable_heads_and_indices(
+    heads: list[int], n_heads: int, head_size: int, already_pruned_heads: set[int]
+) -> tuple[set[int], torch.LongTensor]:
+    """
+    Finds the heads and their indices taking `already_pruned_heads` into account.
+
+    Args:
+        heads (`list[int]`): List of the indices of heads to prune.
+        n_heads (`int`): The number of heads in the model.
+        head_size (`int`): The size of each head.
+        already_pruned_heads (`Set[int]`): A set of already pruned heads.
+
+    Returns:
+        `tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
+        into account and the indices of rows/columns to keep in the layer weight.
+    """
+    mask = torch.ones(n_heads, head_size)
+    heads = set(heads) - already_pruned_heads  # Convert to set and remove already pruned heads
+    for head in heads:
+        # Compute how many pruned heads are before the head and move the index accordingly
+        head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
+        mask[head] = 0
+    mask = mask.view(-1).contiguous().eq(1)
+    index: torch.LongTensor = torch.arange(len(mask))[mask].long()
+    return heads, index
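+
+
+# Small sketch (hypothetical numbers): with 12 heads of size 64 and heads 0 and 2 to prune, the helper returns
+# the set of heads to remove and the 640 row/column indices that survive in the projection weights.
+def _example_find_pruneable_heads() -> tuple[int, int]:
+    heads, index = find_pruneable_heads_and_indices(heads=[0, 2], n_heads=12, head_size=64, already_pruned_heads=set())
+    return len(heads), len(index)  # (2, 640)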
+
+
+def meshgrid(*tensors: torch.Tensor | list[torch.Tensor], indexing: str | None = None) -> tuple[torch.Tensor, ...]:
+    """
+    Wrapper around torch.meshgrid to avoid warning messages about the introduced `indexing` argument.
+
+    Reference: https://pytorch.org/docs/1.13/generated/torch.meshgrid.html
+    """
+    return torch.meshgrid(*tensors, indexing=indexing)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
+    """
+    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    """
+    if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+        from torch.distributed.tensor import DTensor
+
+        if isinstance(tensor, DTensor):
+            local_tensor = tensor.to_local()
+            return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes
+
+    if tensor.device.type == "xla" and is_torch_xla_available():
+        # NOTE: XLA tensors don't have storage,
+        # so use some other unique id to distinguish.
+        # This is an XLA tensor, so it must have been created on a torch_xla
+        # device. The following import is therefore safe:
+        import torch_xla
+
+        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    else:
+        unique_id = storage_ptr(tensor)
+
+    return tensor.device, unique_id, storage_size(tensor)
+
+
+def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int) -> torch.Tensor:
+    """
+    Same as `torch.isin` without flags, but MPS-friendly. We can remove this function when we stop supporting
+    torch <= 2.3. See https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
+
+    Args:
+        elements (`torch.Tensor`): Input elements
+        test_elements (`torch.Tensor` or `int`): The elements to check against.
+
+    Returns:
+        `torch.Tensor`: A boolean tensor of the same shape as `elements` that is True for `elements` in `test_elements`
+        and False otherwise
+    """
+
+    if elements.device.type == "mps" and not is_torch_greater_or_equal_than_2_4:
+        test_elements = torch.tensor(test_elements)
+        if test_elements.ndim == 0:
+            test_elements = test_elements.unsqueeze(0)
+        return elements.tile(test_elements.shape[0], 1).eq(test_elements.unsqueeze(1)).sum(dim=0).bool().squeeze()
+    else:
+        # Note: don't use named arguments in `torch.isin`, see https://github.com/pytorch/pytorch/issues/126045
+        return torch.isin(elements, test_elements)
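+
+
+# Tiny usage sketch (hypothetical token ids): checking which generated ids belong to a set of EOS candidates.
+def _example_isin_mps_friendly() -> torch.Tensor:
+    elements = torch.tensor([1, 5, 7, 5])
+    eos_ids = torch.tensor([5, 7])
+    return isin_mps_friendly(elements, eos_ids)  # tensor([False, True, True, True])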
+
+
+@wraps(lru_cache)
+def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
+    """
+    LRU cache decorator from standard functools library, but with a workaround to disable
+    caching when torchdynamo is compiling. Expected to work with class methods.
+    """
+
+    def decorator(func):
+        func_with_cache = lru_cache(*lru_args, **lru_kwargs)(func)
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if is_torchdynamo_compiling():
+                return func(*args, **kwargs)
+            else:
+                return func_with_cache(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
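+
+
+# Sketch of how the decorator above might be used (the class and method are made up for illustration):
+# the cached value is reused in eager mode, while compiled code always recomputes to stay traceable.
+class _ExampleFrequencyTable:
+    @compile_compatible_method_lru_cache(maxsize=8)
+    def inv_freq(self, dim: int) -> torch.Tensor:
+        return 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))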
+
+
+def infer_device():
+    """
+    Infers available device.
+    """
+    torch_device = "cpu"
+    if torch.cuda.is_available():
+        torch_device = "cuda"
+    elif is_torch_xpu_available():
+        torch_device = "xpu"
+
+    return torch_device
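+
+
+# Minimal usage sketch (the module below is hypothetical): place a layer on whatever accelerator is available.
+def _example_infer_device() -> str:
+    device = infer_device()  # "cuda", "xpu" or "cpu"
+    layer = nn.Linear(4, 4).to(device)
+    return layer.weight.device.type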
diff --git a/phivenv/Lib/site-packages/transformers/safetensors_conversion.py b/phivenv/Lib/site-packages/transformers/safetensors_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1612d3ea57c98fd1d383887cfbeb4e2882d3963
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/safetensors_conversion.py
@@ -0,0 +1,105 @@
+from typing import Optional
+
+import requests
+from huggingface_hub import Discussion, HfApi, get_repo_discussions
+
+from .utils import cached_file, http_user_agent, logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
+    main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
+    for discussion in get_repo_discussions(repo_id=model_id, token=token):
+        if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
+            commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
+
+            if main_commit == commits[1].commit_id:
+                return discussion
+    return None
+
+
+def spawn_conversion(token: str, private: bool, model_id: str):
+    logger.info("Attempting to convert .bin model on the fly to safetensors.")
+
+    safetensors_convert_space_url = "https://safetensors-convert.hf.space"
+    sse_url = f"{safetensors_convert_space_url}/call/run"
+
+    def start(_sse_connection):
+        for line in _sse_connection.iter_lines():
+            line = line.decode()
+            if line.startswith("event:"):
+                status = line[7:]
+                logger.debug(f"Safetensors conversion status: {status}")
+
+                if status == "complete":
+                    return
+                elif status == "heartbeat":
+                    logger.debug("Heartbeat")
+                else:
+                    logger.debug(f"Unknown status {status}")
+            else:
+                logger.debug(line)
+
+    data = {"data": [model_id, private, token]}
+
+    result = requests.post(sse_url, stream=True, json=data).json()
+    event_id = result["event_id"]
+
+    with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
+        try:
+            logger.debug("Spawning safetensors automatic conversion.")
+            start(sse_connection)
+        except Exception as e:
+            logger.warning(f"Error during conversion: {repr(e)}")
+
+
+def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
+    private = api.model_info(model_id).private
+
+    logger.info("Attempting to create safetensors variant")
+    pr_title = "Adding `safetensors` variant of this model"
+    token = kwargs.get("token")
+
+    # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
+    # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
+    # security breaches.
+    pr = previous_pr(api, model_id, pr_title, token=token)
+
+    if pr is None or (not private and pr.author != "SFconvertbot"):
+        spawn_conversion(token, private, model_id)
+        pr = previous_pr(api, model_id, pr_title, token=token)
+    else:
+        logger.info("Safetensors PR exists")
+
+    sha = f"refs/pr/{pr.num}"
+
+    return sha
+
+
+def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
+    try:
+        api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
+        sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
+
+        if sha is None:
+            return None, None
+        cached_file_kwargs["revision"] = sha
+        del cached_file_kwargs["_commit_hash"]
+
+        # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
+        # description.
+        sharded = api.file_exists(
+            pretrained_model_name_or_path,
+            "model.safetensors.index.json",
+            revision=sha,
+            token=cached_file_kwargs.get("token"),
+        )
+        filename = "model.safetensors.index.json" if sharded else "model.safetensors"
+
+        resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+        return resolved_archive_file, sha, sharded
+    except Exception as e:
+        if not ignore_errors_during_conversion:
+            raise e
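+
+
+# Hedged usage sketch (requires network access and a Hub token; the repo id and token below are placeholders):
+#
+#     resolved_file, sha, sharded = auto_conversion(
+#         "some-org/legacy-bin-model", token="hf_xxx", _commit_hash=commit_hash
+#     )
+#
+# This path is typically exercised when a checkpoint only ships `.bin` weights and a safetensors
+# variant has to be produced (or fetched from an existing conversion PR) on the fly.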
diff --git a/phivenv/Lib/site-packages/transformers/testing_utils.py b/phivenv/Lib/site-packages/transformers/testing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a35f13829c37929b8d03733ebe055ae3946138f
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/testing_utils.py
@@ -0,0 +1,3518 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import contextlib
+import copy
+import doctest
+import functools
+import gc
+import importlib
+import inspect
+import logging
+import multiprocessing
+import os
+import re
+import shlex
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import types
+import unittest
+from collections import UserDict, defaultdict
+from collections.abc import Generator, Iterable, Iterator, Mapping
+from dataclasses import MISSING, fields
+from functools import cache, wraps
+from io import StringIO
+from pathlib import Path
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+from unittest.mock import patch
+
+import huggingface_hub.utils
+import requests
+import urllib3
+from huggingface_hub import delete_repo
+from packaging import version
+
+from transformers import Trainer
+from transformers import logging as transformers_logging
+
+from .integrations import (
+    is_clearml_available,
+    is_optuna_available,
+    is_ray_available,
+    is_sigopt_available,
+    is_swanlab_available,
+    is_tensorboard_available,
+    is_trackio_available,
+    is_wandb_available,
+)
+from .integrations.deepspeed import is_deepspeed_available
+from .utils import (
+    ACCELERATE_MIN_VERSION,
+    GGUF_MIN_VERSION,
+    TRITON_MIN_VERSION,
+    is_accelerate_available,
+    is_apex_available,
+    is_apollo_torch_available,
+    is_aqlm_available,
+    is_auto_awq_available,
+    is_auto_gptq_available,
+    is_auto_round_available,
+    is_av_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_multi_backend_available,
+    is_bs4_available,
+    is_compressed_tensors_available,
+    is_cv2_available,
+    is_cython_available,
+    is_decord_available,
+    is_detectron2_available,
+    is_eetq_available,
+    is_essentia_available,
+    is_faiss_available,
+    is_fbgemm_gpu_available,
+    is_flash_attn_2_available,
+    is_flash_attn_3_available,
+    is_flax_available,
+    is_flute_available,
+    is_fp_quant_available,
+    is_fsdp_available,
+    is_ftfy_available,
+    is_g2p_en_available,
+    is_galore_torch_available,
+    is_gguf_available,
+    is_gptqmodel_available,
+    is_grokadamw_available,
+    is_hadamard_available,
+    is_hqq_available,
+    is_huggingface_hub_greater_or_equal,
+    is_ipex_available,
+    is_jieba_available,
+    is_jinja_available,
+    is_jumanpp_available,
+    is_keras_nlp_available,
+    is_kernels_available,
+    is_levenshtein_available,
+    is_librosa_available,
+    is_liger_kernel_available,
+    is_lomo_available,
+    is_mistral_common_available,
+    is_natten_available,
+    is_nltk_available,
+    is_onnx_available,
+    is_openai_available,
+    is_optimum_available,
+    is_optimum_quanto_available,
+    is_pandas_available,
+    is_peft_available,
+    is_phonemizer_available,
+    is_pretty_midi_available,
+    is_psutil_available,
+    is_pyctcdecode_available,
+    is_pytesseract_available,
+    is_pytest_available,
+    is_pytorch_quantization_available,
+    is_quark_available,
+    is_qutlass_available,
+    is_rjieba_available,
+    is_sacremoses_available,
+    is_safetensors_available,
+    is_schedulefree_available,
+    is_scipy_available,
+    is_sentencepiece_available,
+    is_seqio_available,
+    is_spacy_available,
+    is_speech_available,
+    is_spqr_available,
+    is_sudachi_available,
+    is_sudachi_projection_available,
+    is_tensorflow_probability_available,
+    is_tensorflow_text_available,
+    is_tf2onnx_available,
+    is_tf_available,
+    is_tiktoken_available,
+    is_timm_available,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_bf16_available_on_device,
+    is_torch_bf16_gpu_available,
+    is_torch_fp16_available_on_device,
+    is_torch_greater_or_equal,
+    is_torch_hpu_available,
+    is_torch_mlu_available,
+    is_torch_neuroncore_available,
+    is_torch_npu_available,
+    is_torch_optimi_available,
+    is_torch_tensorrt_fx_available,
+    is_torch_tf32_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    is_torchao_available,
+    is_torchaudio_available,
+    is_torchcodec_available,
+    is_torchdynamo_available,
+    is_torchvision_available,
+    is_triton_available,
+    is_vision_available,
+    is_vptq_available,
+    strtobool,
+)
+
+
+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils.imports import is_fp8_available
+
+
+if is_pytest_available():
+    from _pytest.doctest import (
+        Module,
+        _get_checker,
+        _get_continue_on_failure,
+        _get_runner,
+        _is_mocked,
+        _patch_unwrap_mock_aware,
+        get_optionflags,
+    )
+    from _pytest.outcomes import skip
+    from _pytest.pathlib import import_path
+    from pytest import DoctestItem
+else:
+    Module = object
+    DoctestItem = object
+
+
+SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
+DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
+# Used to test Auto{Config, Model, Tokenizer} model_type detection.
+
+# Used to test the hub
+USER = "__DUMMY_TRANSFORMERS_USER__"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
+
+if is_torch_available():
+    import torch
+
+    IS_ROCM_SYSTEM = torch.version.hip is not None
+    IS_CUDA_SYSTEM = torch.version.cuda is not None
+    IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
+else:
+    IS_ROCM_SYSTEM = False
+    IS_CUDA_SYSTEM = False
+    IS_XPU_SYSTEM = False
+
+logger = transformers_logging.get_logger(__name__)
+
+
+def parse_flag_from_env(key, default=False):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        # KEY isn't set, default to `default`.
+        _value = default
+    else:
+        # KEY is set, convert it to True or False.
+        try:
+            _value = strtobool(value)
+        except ValueError:
+            # More values are supported, but let's keep the message simple.
+            raise ValueError(f"If set, {key} must be yes or no.")
+    return _value
+
+
+def parse_int_from_env(key, default=None):
+    try:
+        value = os.environ[key]
+    except KeyError:
+        _value = default
+    else:
+        try:
+            _value = int(value)
+        except ValueError:
+            raise ValueError(f"If set, {key} must be a int.")
+    return _value
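+
+
+# Hedged illustration of how these parsers behave (the env var name below is invented for the sketch):
+def _example_parse_flag_from_env() -> bool:
+    os.environ["_EXAMPLE_FLAG"] = "yes"  # hypothetical variable, set only for this example
+    return bool(parse_flag_from_env("_EXAMPLE_FLAG", default=False))  # True; unset variables fall back to `default`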
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_flaky_tests = parse_flag_from_env("RUN_FLAKY", default=True)
+_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
+_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
+_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
+_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
+
+
+def is_staging_test(test_case):
+    """
+    Decorator marking a test as a staging test.
+
+    Those tests will run using the staging environment of huggingface.co instead of the real model hub.
+    """
+    if not _run_staging:
+        return unittest.skip(reason="test is staging test")(test_case)
+    else:
+        try:
+            import pytest  # We don't need a hard dependency on pytest in the main library
+        except ImportError:
+            return test_case
+        else:
+            return pytest.mark.is_staging_test()(test_case)
+
+
+def is_pipeline_test(test_case):
+    """
+    Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be
+    skipped.
+    """
+    if not _run_pipeline_tests:
+        return unittest.skip(reason="test is pipeline test")(test_case)
+    else:
+        try:
+            import pytest  # We don't need a hard dependency on pytest in the main library
+        except ImportError:
+            return test_case
+        else:
+            return pytest.mark.is_pipeline_test()(test_case)
+
+
+def is_agent_test(test_case):
+    """
+    Decorator marking a test as an agent test. If RUN_AGENT_TESTS is set to a falsy value, those tests will be skipped.
+    """
+    if not _run_agent_tests:
+        return unittest.skip(reason="test is an agent test")(test_case)
+    else:
+        try:
+            import pytest  # We don't need a hard dependency on pytest in the main library
+        except ImportError:
+            return test_case
+        else:
+            return pytest.mark.is_agent_test()(test_case)
+
+
+def slow(test_case):
+    """
+    Decorator marking a test as slow.
+
+    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+    """
+    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def tooslow(test_case):
+    """
+    Decorator marking a test as too slow.
+
+    Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as
+    these will not be tested by the CI.
+
+    """
+    return unittest.skip(reason="test is too slow")(test_case)
+
+
+def skip_if_not_implemented(test_func):
+    @functools.wraps(test_func)
+    def wrapper(*args, **kwargs):
+        try:
+            return test_func(*args, **kwargs)
+        except NotImplementedError as e:
+            raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}")
+
+    return wrapper
+
+
+def apply_skip_if_not_implemented(cls):
+    """
+    Class decorator to apply @skip_if_not_implemented to all test methods.
+    """
+    for attr_name in dir(cls):
+        if attr_name.startswith("test_"):
+            attr = getattr(cls, attr_name)
+            if callable(attr):
+                setattr(cls, attr_name, skip_if_not_implemented(attr))
+    return cls
+
+
+def custom_tokenizers(test_case):
+    """
+    Decorator marking a test for a custom tokenizer.
+
+    Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
+    environment variable to a truthy value to run them.
+    """
+    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
+
+
+def require_bs4(test_case):
+    """
+    Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed.
+    """
+    return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
+
+
+def require_galore_torch(test_case):
+    """
+    Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
+    https://github.com/jiaweizzhao/GaLore
+    """
+    return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
+
+
+def require_apollo_torch(test_case):
+    """
+    Decorator marking a test that requires APOLLO. These tests are skipped when APOLLO isn't installed.
+    https://github.com/zhuhanqing/APOLLO
+    """
+    return unittest.skipUnless(is_apollo_torch_available(), "test requires APOLLO")(test_case)
+
+
+def require_torch_optimi(test_case):
+    """
+    Decorator marking a test that requires torch-optimi. These tests are skipped when torch-optimi isn't installed.
+    https://github.com/jxnl/torch-optimi
+    """
+    return unittest.skipUnless(is_torch_optimi_available(), "test requires torch-optimi")(test_case)
+
+
+def require_lomo(test_case):
+    """
+    Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
+    https://github.com/OpenLMLab/LOMO
+    """
+    return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
+
+
+def require_grokadamw(test_case):
+    """
+    Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
+    """
+    return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
+
+
+def require_schedulefree(test_case):
+    """
+    Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
+    https://github.com/facebookresearch/schedule_free
+    """
+    return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
+
+
+def require_cv2(test_case):
+    """
+    Decorator marking a test that requires OpenCV.
+
+    These tests are skipped when OpenCV isn't installed.
+
+    """
+    return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)
+
+
+def require_levenshtein(test_case):
+    """
+    Decorator marking a test that requires Levenshtein.
+
+    These tests are skipped when Levenshtein isn't installed.
+
+    """
+    return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)
+
+
+def require_nltk(test_case):
+    """
+    Decorator marking a test that requires NLTK.
+
+    These tests are skipped when NLTK isn't installed.
+
+    """
+    return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
+
+
+def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
+    """
+    Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+    """
+    return unittest.skipUnless(
+        is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
+    )(test_case)
+
+
+def require_triton(min_version: str = TRITON_MIN_VERSION):
+    """
+    Decorator marking a test that requires triton. These tests are skipped when triton isn't installed.
+    """
+
+    def decorator(test_case):
+        return unittest.skipUnless(is_triton_available(min_version), f"test requires triton version >= {min_version}")(
+            test_case
+        )
+
+    return decorator
+
+
+def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
+    """
+    Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.
+    """
+    return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
+        test_case
+    )
+
+
+def require_fsdp(test_case, min_version: str = "1.12.0"):
+    """
+    Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
+    """
+    return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
+        test_case
+    )
+
+
+def require_g2p_en(test_case):
+    """
+    Decorator marking a test that requires g2p_en. These tests are skipped when g2p_en isn't installed.
+    """
+    return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)
+
+
+def require_safetensors(test_case):
+    """
+    Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+    """
+    return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
+def require_rjieba(test_case):
+    """
+    Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
+    """
+    return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
+
+
+def require_jieba(test_case):
+    """
+    Decorator marking a test that requires jieba. These tests are skipped when jieba isn't installed.
+    """
+    return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
+
+
+def require_jinja(test_case):
+    """
+    Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
+    """
+    return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
+
+
+def require_tf2onnx(test_case):
+    logger.warning_once(
+        "TensorFlow test-related code, including `require_tf2onnx`, is deprecated and will be removed in "
+        "Transformers v4.55"
+    )
+    return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
+
+
+def require_onnx(test_case):
+    return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
+
+
+def require_timm(test_case):
+    """
+    Decorator marking a test that requires Timm.
+
+    These tests are skipped when Timm isn't installed.
+
+    """
+    return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
+
+
+def require_natten(test_case):
+    """
+    Decorator marking a test that requires NATTEN.
+
+    These tests are skipped when NATTEN isn't installed.
+
+    """
+    return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)
+
+
+def require_torch(test_case):
+    """
+    Decorator marking a test that requires PyTorch.
+
+    These tests are skipped when PyTorch isn't installed.
+
+    """
+    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
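+
+
+# Hedged sketch of how these markers are combined in practice (the test class below is made up):
+@require_torch
+@slow
+class _ExampleIntegrationTest(unittest.TestCase):
+    def test_torch_is_available(self):
+        self.assertTrue(is_torch_available())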
+
+
+def require_torch_greater_or_equal(version: str):
+    """
+    Decorator marking a test that requires PyTorch version >= `version`.
+
+    These tests are skipped when PyTorch version is less than `version`.
+    """
+
+    def decorator(test_case):
+        return unittest.skipUnless(is_torch_greater_or_equal(version), f"test requires PyTorch version >= {version}")(
+            test_case
+        )
+
+    return decorator
+
+
+def require_huggingface_hub_greater_or_equal(version: str):
+    """
+    Decorator marking a test that requires huggingface_hub version >= `version`.
+
+    These tests are skipped when huggingface_hub version is less than `version`.
+    """
+
+    def decorator(test_case):
+        return unittest.skipUnless(
+            is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}"
+        )(test_case)
+
+    return decorator
+
+
+def require_flash_attn(test_case):
+    """
+    Decorator marking a test that requires Flash Attention.
+
+    These tests are skipped when Flash Attention isn't installed.
+
+    """
+    flash_attn_available = is_flash_attn_2_available()
+    kernels_available = is_kernels_available()
+    try:
+        from kernels import get_kernel
+
+        get_kernel("kernels-community/flash-attn")
+    except Exception as _:
+        kernels_available = False
+
+    return unittest.skipUnless(kernels_available | flash_attn_available, "test requires Flash Attention")(test_case)
+
+
+def require_kernels(test_case):
+    """
+    Decorator marking a test that requires the `kernels` library.
+
+    These tests are skipped when `kernels` isn't installed.
+
+    """
+    return unittest.skipUnless(is_kernels_available(), "test requires kernels")(test_case)
+
+
+def require_flash_attn_3(test_case):
+    """
+    Decorator marking a test that requires Flash Attention 3.
+
+    These tests are skipped when Flash Attention 3 isn't installed.
+    """
+    return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
+def require_read_token(test_case):
+    """
+    A decorator that loads the HF token for tests that require to load gated models.
+    """
+    token = os.getenv("HF_HUB_READ_TOKEN")
+
+    if isinstance(test_case, type):
+        for attr_name in dir(test_case):
+            attr = getattr(test_case, attr_name)
+            if isinstance(attr, types.FunctionType):
+                if getattr(attr, "__require_read_token__", False):
+                    continue
+                wrapped = require_read_token(attr)
+                setattr(test_case, attr_name, wrapped)
+        return test_case
+    else:
+        if getattr(test_case, "__require_read_token__", False):
+            return test_case
+
+        @functools.wraps(test_case)
+        def wrapper(*args, **kwargs):
+            if token is not None:
+                with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+                    return test_case(*args, **kwargs)
+            else:  # Allow running locally with the default token env variable
+                # dealing with static/class methods and called by `self.xxx`
+                if "staticmethod" in inspect.getsource(test_case).strip():
+                    if len(args) > 0 and isinstance(args[0], unittest.TestCase):
+                        return test_case(*args[1:], **kwargs)
+                return test_case(*args, **kwargs)
+
+        wrapper.__require_read_token__ = True
+        return wrapper
+
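+# A minimal usage sketch (illustrative only): `require_read_token` can wrap either a whole
+# test class or a single test method; when `HF_HUB_READ_TOKEN` is set in the environment,
+# the token is patched in for the duration of the test.
+#
+#     @require_read_token
+#     class GatedModelTest(unittest.TestCase):
+#         def test_load_gated_model(self):
+#             ...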
+
+def require_peft(test_case):
+    """
+    Decorator marking a test that requires PEFT.
+
+    These tests are skipped when PEFT isn't installed.
+
+    """
+    return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
+
+
+def require_torchvision(test_case):
+    """
+    Decorator marking a test that requires Torchvision.
+
+    These tests are skipped when Torchvision isn't installed.
+
+    """
+    return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
+
+
+def require_torchcodec(test_case):
+    """
+    Decorator marking a test that requires Torchcodec.
+
+    These tests are skipped when Torchcodec isn't installed.
+
+    """
+    return unittest.skipUnless(is_torchcodec_available(), "test requires Torchcodec")(test_case)
+
+
+def require_torch_or_tf(test_case):
+    """
+    Decorator marking a test that requires PyTorch or TensorFlow.
+
+    These tests are skipped when neither PyTorch nor TensorFlow is installed.
+
+    """
+    return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
+        test_case
+    )
+
+
+def require_intel_extension_for_pytorch(test_case):
+    """
+    Decorator marking a test that requires Intel Extension for PyTorch.
+
+    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
+    version.
+
+    """
+    return unittest.skipUnless(
+        is_ipex_available(),
+        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
+        " https://github.com/intel/intel-extension-for-pytorch",
+    )(test_case)
+
+
+def require_tensorflow_probability(test_case):
+    """
+    Decorator marking a test that requires TensorFlow probability.
+
+    These tests are skipped when TensorFlow probability isn't installed.
+
+    """
+    logger.warning_once(
+        "TensorFlow test-related code, including `require_tensorflow_probability`, is deprecated and will be "
+        "removed in Transformers v4.55"
+    )
+    return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
+        test_case
+    )
+
+
+def require_torchaudio(test_case):
+    """
+    Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
+    """
+    return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
+
+
+def require_tf(test_case):
+    """
+    Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
+    """
+    logger.warning_once(
+        "TensorFlow test-related code, including `require_tf`, is deprecated and will be removed in Transformers v4.55"
+    )
+    return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
+
+
+def require_flax(test_case):
+    """
+    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+    """
+    logger.warning_once(
+        "JAX test-related code, including `require_flax`, is deprecated and will be removed in Transformers v4.55"
+    )
+    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
+def require_sentencepiece(test_case):
+    """
+    Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
+    """
+    return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
+
+
+def require_sacremoses(test_case):
+    """
+    Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed.
+    """
+    return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)
+
+
+def require_seqio(test_case):
+    """
+    Decorator marking a test that requires Seqio. These tests are skipped when Seqio isn't installed.
+    """
+    return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)
+
+
+def require_scipy(test_case):
+    """
+    Decorator marking a test that requires Scipy. These tests are skipped when Scipy isn't installed.
+    """
+    return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
+
+
+def require_tokenizers(test_case):
+    """
+    Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
+    """
+    return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
+
+
+def require_tensorflow_text(test_case):
+    """
+    Decorator marking a test that requires tensorflow_text. These tests are skipped when tensorflow_text isn't
+    installed.
+    """
+    logger.warning_once(
+        "TensorFlow test-related code, including `require_tensorflow_text`, is deprecated and will be "
+        "removed in Transformers v4.55"
+    )
+    return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)
+
+
+def require_keras_nlp(test_case):
+    """
+    Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed.
+    """
+    return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case)
+
+
+def require_pandas(test_case):
+    """
+    Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
+    """
+    return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
+
+
+def require_pytesseract(test_case):
+    """
+    Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
+    """
+    return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
+
+
+def require_pytorch_quantization(test_case):
+    """
+    Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
+    Quantization Toolkit isn't installed.
+    """
+    return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
+        test_case
+    )
+
+
+def require_vision(test_case):
+    """
+    Decorator marking a test that requires the vision dependencies. These tests are skipped when the vision
+    dependencies aren't installed.
+    """
+    return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
+
+
+def require_ftfy(test_case):
+    """
+    Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
+    """
+    return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
+
+
+def require_spacy(test_case):
+    """
+    Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
+    """
+    return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
+
+
+def require_torch_multi_gpu(test_case):
+    """
+    Decorator marking a test that requires a multi-GPU CUDA setup (in PyTorch). These tests are skipped on a machine without
+    multiple CUDA GPUs.
+
+    To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    import torch
+
+    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple CUDA GPUs")(test_case)
+
+
+def require_torch_multi_accelerator(test_case):
+    """
+    Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
+    without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
+    multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
+        test_case
+    )
+
+
+def require_torch_non_multi_gpu(test_case):
+    """
+    Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    import torch
+
+    return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
+
+
+def require_torch_non_multi_accelerator(test_case):
+    """
+    Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
+
+
+def require_torch_up_to_2_gpus(test_case):
+    """
+    Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    import torch
+
+    return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
+
+
+def require_torch_up_to_2_accelerators(test_case):
+    """
+    Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
+    """
+    if not is_torch_available():
+        return unittest.skip(reason="test requires PyTorch")(test_case)
+
+    return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(
+        test_case
+    )
+
+
+def require_torch_xla(test_case):
+    """
+    Decorator marking a test that requires TorchXLA (in PyTorch).
+    """
+    return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
+
+
+def require_torch_neuroncore(test_case):
+    """
+    Decorator marking a test that requires NeuronCore (in PyTorch).
+    """
+    return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(
+        test_case
+    )
+
+
+def require_torch_npu(test_case):
+    """
+    Decorator marking a test that requires NPU (in PyTorch).
+    """
+    return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case)
+
+
+def require_torch_multi_npu(test_case):
+    """
+    Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without
+    multiple NPUs.
+
+    To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
+    """
+    if not is_torch_npu_available():
+        return unittest.skip(reason="test requires PyTorch NPU")(test_case)
+
+    return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
+
+
+def require_non_hpu(test_case):
+    """
+    Decorator marking a test that should be skipped for HPU.
+    """
+    return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case)
+
+
+def require_torch_xpu(test_case):
+    """
+    Decorator marking a test that requires XPU (in PyTorch).
+
+    These tests are skipped when XPU backend is not available. XPU backend might be available either via stock
+    PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
+    must match the current PyTorch version.
+    """
+    return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
+
+
+def require_non_xpu(test_case):
+    """
+    Decorator marking a test that should be skipped for XPU.
+    """
+    return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)
+
+
+def require_torch_multi_xpu(test_case):
+    """
+    Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
+    multiple XPUs.
+
+    To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
+    """
+    if not is_torch_xpu_available():
+        return unittest.skip(reason="test requires PyTorch XPU")(test_case)
+
+    return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
+def require_torch_multi_hpu(test_case):
+    """
+    Decorator marking a test that requires a multi-HPU setup (in PyTorch). These tests are skipped on a machine without
+    multiple HPUs.
+
+    To run *only* the multi_hpu tests, assuming all test names contain multi_hpu: $ pytest -sv ./tests -k "multi_hpu"
+    """
+    if not is_torch_hpu_available():
+        return unittest.skip(reason="test requires PyTorch HPU")(test_case)
+
+    return unittest.skipUnless(torch.hpu.device_count() > 1, "test requires multiple HPUs")(test_case)
+
+
+if is_torch_available():
+    # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
+    import torch
+
+    if "TRANSFORMERS_TEST_BACKEND" in os.environ:
+        backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
+        try:
+            _ = importlib.import_module(backend)
+        except ModuleNotFoundError as e:
+            raise ModuleNotFoundError(
+                f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
+                f" traceback):\n{e}"
+            ) from e
+
+    if "TRANSFORMERS_TEST_DEVICE" in os.environ:
+        torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
+        if torch_device == "cuda" and not torch.cuda.is_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
+            )
+        if torch_device == "xpu" and not is_torch_xpu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
+            )
+        if torch_device == "npu" and not is_torch_npu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
+            )
+        if torch_device == "mlu" and not is_torch_mlu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but MLU is unavailable. Please double-check your testing environment."
+            )
+        if torch_device == "hpu" and not is_torch_hpu_available():
+            raise ValueError(
+                f"TRANSFORMERS_TEST_DEVICE={torch_device}, but HPU is unavailable. Please double-check your testing environment."
+            )
+
+        try:
+            # try creating device to see if provided device is valid
+            _ = torch.device(torch_device)
+        except RuntimeError as e:
+            raise RuntimeError(
+                f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
+            ) from e
+    elif torch.cuda.is_available():
+        torch_device = "cuda"
+    elif is_torch_npu_available():
+        torch_device = "npu"
+    elif is_torch_mlu_available():
+        torch_device = "mlu"
+    elif is_torch_hpu_available():
+        torch_device = "hpu"
+    elif is_torch_xpu_available():
+        torch_device = "xpu"
+    else:
+        torch_device = "cpu"
+else:
+    torch_device = None
+
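+# The selection above can be overridden from the environment; for example (an illustrative
+# sketch, assuming the corresponding backend package is installed):
+#
+#     TRANSFORMERS_TEST_BACKEND="torch_npu" TRANSFORMERS_TEST_DEVICE="npu" python -m pytest tests/
+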
+if is_tf_available():
+    import tensorflow as tf
+
+if is_flax_available():
+    import jax
+
+    jax_device = jax.default_backend()
+else:
+    jax_device = None
+
+
+def require_torchdynamo(test_case):
+    """Decorator marking a test that requires TorchDynamo"""
+    return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
+def require_torchao(test_case):
+    """Decorator marking a test that requires torchao"""
+    return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
+
+
+def require_torchao_version_greater_or_equal(torchao_version):
+    def decorator(test_case):
+        correct_torchao_version = is_torchao_available() and version.parse(
+            version.parse(importlib.metadata.version("torchao")).base_version
+        ) >= version.parse(torchao_version)
+        return unittest.skipUnless(
+            correct_torchao_version, f"Test requires torchao with the version greater than {torchao_version}."
+        )(test_case)
+
+    return decorator
+
+
+def require_torch_tensorrt_fx(test_case):
+    """Decorator marking a test that requires Torch-TensorRT FX"""
+    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
+
+
+def require_torch_gpu(test_case):
+    """Decorator marking a test that requires CUDA and PyTorch."""
+    return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+
+
+def require_torch_mps(test_case):
+    """Decorator marking a test that requires CUDA and PyTorch."""
+    return unittest.skipUnless(torch_device == "mps", "test requires MPS")(test_case)
+
+
+def require_large_cpu_ram(test_case, memory: float = 80):
+    """Decorator marking a test that requires a CPU RAM with more than `memory` GiB of memory."""
+    if not is_psutil_available():
+        return test_case
+
+    import psutil
+
+    return unittest.skipUnless(
+        psutil.virtual_memory().total / 1024**3 > memory,
+        f"test requires a machine with more than {memory} GiB of CPU RAM memory",
+    )(test_case)
+
+
+def require_torch_large_gpu(test_case, memory: float = 20):
+    """Decorator marking a test that requires a CUDA GPU with more than `memory` GiB of memory."""
+    if torch_device != "cuda":
+        return unittest.skip(reason=f"test requires a CUDA GPU with more than {memory} GiB of memory")(test_case)
+
+    return unittest.skipUnless(
+        torch.cuda.get_device_properties(0).total_memory / 1024**3 > memory,
+        f"test requires a GPU with more than {memory} GiB of memory",
+    )(test_case)
+
+
+def require_torch_large_accelerator(test_case, memory: float = 20):
+    """Decorator marking a test that requires an accelerator with more than `memory` GiB of memory."""
+    if torch_device != "cuda" and torch_device != "xpu":
+        return unittest.skip(reason=f"test requires a GPU or XPU with more than {memory} GiB of memory")(test_case)
+
+    torch_accelerator_module = getattr(torch, torch_device)
+
+    return unittest.skipUnless(
+        torch_accelerator_module.get_device_properties(0).total_memory / 1024**3 > memory,
+        f"test requires a GPU or XPU with more than {memory} GiB of memory",
+    )(test_case)
+
+
+def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
+    """
+    Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled.
+    """
+    if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available():
+        return test_case
+    return require_torch_gpu(test_case)
+
+
+def require_torch_accelerator(test_case):
+    """Decorator marking a test that requires an accessible accelerator and PyTorch."""
+    return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
+        test_case
+    )
+
+
+def require_torch_fp16(test_case):
+    """Decorator marking a test that requires a device that supports fp16"""
+    return unittest.skipUnless(
+        is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
+    )(test_case)
+
+
+def require_fp8(test_case):
+    """Decorator marking a test that requires supports for fp8"""
+    return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
+        test_case
+    )
+
+
+def require_torch_bf16(test_case):
+    """Decorator marking a test that requires a device that supports bf16"""
+    return unittest.skipUnless(
+        is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
+    )(test_case)
+
+
+def require_torch_bf16_gpu(test_case):
+    """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
+    return unittest.skipUnless(
+        is_torch_bf16_gpu_available(),
+        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
+    )(test_case)
+
+
+def require_deterministic_for_xpu(test_case):
+    @wraps(test_case)
+    def wrapper(*args, **kwargs):
+        if is_torch_xpu_available():
+            original_state = torch.are_deterministic_algorithms_enabled()
+            try:
+                torch.use_deterministic_algorithms(True)
+                return test_case(*args, **kwargs)
+            finally:
+                torch.use_deterministic_algorithms(original_state)
+        else:
+            return test_case(*args, **kwargs)
+
+    return wrapper
+
+
+def require_torch_tf32(test_case):
+    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
+    return unittest.skipUnless(
+        is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
+    )(test_case)
+
+
+def require_detectron2(test_case):
+    """Decorator marking a test that requires detectron2."""
+    return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
+
+
+def require_faiss(test_case):
+    """Decorator marking a test that requires faiss."""
+    return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
+
+
+def require_optuna(test_case):
+    """
+    Decorator marking a test that requires optuna.
+
+    These tests are skipped when optuna isn't installed.
+
+    """
+    return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
+
+
+def require_ray(test_case):
+    """
+    Decorator marking a test that requires Ray/tune.
+
+    These tests are skipped when Ray/tune isn't installed.
+
+    """
+    return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
+
+
+def require_sigopt(test_case):
+    """
+    Decorator marking a test that requires SigOpt.
+
+    These tests are skipped when SigOpt isn't installed.
+
+    """
+    return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
+
+
+def require_swanlab(test_case):
+    """
+    Decorator marking a test that requires swanlab.
+
+    These tests are skipped when swanlab isn't installed.
+
+    """
+    return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
+
+
+def require_trackio(test_case):
+    """
+    Decorator marking a test that requires trackio.
+
+    These tests are skipped when trackio isn't installed.
+
+    """
+    return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case)
+
+
+def require_wandb(test_case):
+    """
+    Decorator marking a test that requires wandb.
+
+    These tests are skipped when wandb isn't installed.
+
+    """
+    return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
+
+
+def require_clearml(test_case):
+    """
+    Decorator marking a test that requires clearml.
+
+    These tests are skipped when clearml isn't installed.
+
+    """
+    return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
+
+
+def require_deepspeed(test_case):
+    """
+    Decorator marking a test that requires deepspeed
+    """
+    return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
+
+
+def require_apex(test_case):
+    """
+    Decorator marking a test that requires apex
+    """
+    return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
+
+
+def require_aqlm(test_case):
+    """
+    Decorator marking a test that requires aqlm
+    """
+    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
+def require_vptq(test_case):
+    """
+    Decorator marking a test that requires vptq
+    """
+    return unittest.skipUnless(is_vptq_available(), "test requires vptq")(test_case)
+
+
+def require_spqr(test_case):
+    """
+    Decorator marking a test that requires spqr
+    """
+    return unittest.skipUnless(is_spqr_available(), "test requires spqr")(test_case)
+
+
+def require_eetq(test_case):
+    """
+    Decorator marking a test that requires eetq
+    """
+    eetq_available = is_eetq_available()
+    if eetq_available:
+        try:
+            import eetq  # noqa: F401
+        except ImportError as exc:
+            if "shard_checkpoint" in str(exc):
+                # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
+                # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
+                # TODO: Remove once eetq releases a fix and this release is used in CI
+                eetq_available = False
+    return unittest.skipUnless(eetq_available, "test requires eetq")(test_case)
+
+
+def require_av(test_case):
+    """
+    Decorator marking a test that requires av
+    """
+    return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
+
+
+def require_decord(test_case):
+    """
+    Decorator marking a test that requires decord
+    """
+    return unittest.skipUnless(is_decord_available(), "test requires decord")(test_case)
+
+
+def require_bitsandbytes(test_case):
+    """
+    Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
+    """
+    if is_bitsandbytes_available() and is_torch_available():
+        try:
+            import pytest
+
+            return pytest.mark.bitsandbytes(test_case)
+        except ImportError:
+            return test_case
+    else:
+        return unittest.skip(reason="test requires bitsandbytes and torch")(test_case)
+
+
+def require_optimum(test_case):
+    """
+    Decorator for optimum dependency
+    """
+    return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
+
+
+def require_tensorboard(test_case):
+    """
+    Decorator for `tensorboard` dependency
+    """
+    return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")(test_case)
+
+
+def require_gptq(test_case):
+    """
+    Decorator for auto_gptq dependency
+    """
+    return unittest.skipUnless(
+        is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq"
+    )(test_case)
+
+
+def require_hqq(test_case):
+    """
+    Decorator for hqq dependency
+    """
+    return unittest.skipUnless(is_hqq_available(), "test requires hqq")(test_case)
+
+
+def require_auto_awq(test_case):
+    """
+    Decorator for auto_awq dependency
+    """
+    return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
+
+
+def require_auto_round(test_case):
+    """
+    Decorator for auto_round dependency
+    """
+    return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case)
+
+
+def require_optimum_quanto(test_case):
+    """
+    Decorator for quanto dependency
+    """
+    return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
+
+
+def require_compressed_tensors(test_case):
+    """
+    Decorator for compressed_tensors dependency
+    """
+    return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case)
+
+
+def require_fbgemm_gpu(test_case):
+    """
+    Decorator for fbgemm_gpu dependency
+    """
+    return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
+
+
+def require_quark(test_case):
+    """
+    Decorator for quark dependency
+    """
+    return unittest.skipUnless(is_quark_available(), "test requires quark")(test_case)
+
+
+def require_flute_hadamard(test_case):
+    """
+    Decorator marking a test that requires higgs and hadamard
+    """
+    return unittest.skipUnless(
+        is_flute_available() and is_hadamard_available(), "test requires flute and fast_hadamard_transform"
+    )(test_case)
+
+
+def require_fp_quant(test_case):
+    """
+    Decorator marking a test that requires fp_quant
+    """
+    return unittest.skipUnless(is_fp_quant_available(), "test requires fp_quant")(test_case)
+
+
+def require_qutlass(test_case):
+    """
+    Decorator marking a test that requires qutlass
+    """
+    return unittest.skipUnless(is_qutlass_available(), "test requires qutlass")(test_case)
+
+
+def require_phonemizer(test_case):
+    """
+    Decorator marking a test that requires phonemizer
+    """
+    return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
+
+
+def require_pyctcdecode(test_case):
+    """
+    Decorator marking a test that requires pyctcdecode
+    """
+    return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
+
+
+def require_librosa(test_case):
+    """
+    Decorator marking a test that requires librosa
+    """
+    return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
+
+
+def require_liger_kernel(test_case):
+    """
+    Decorator marking a test that requires liger_kernel
+    """
+    return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
+
+
+def require_essentia(test_case):
+    """
+    Decorator marking a test that requires essentia
+    """
+    return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
+
+
+def require_pretty_midi(test_case):
+    """
+    Decorator marking a test that requires pretty_midi
+    """
+    return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
+
+
+def cmd_exists(cmd):
+    return shutil.which(cmd) is not None
+
+
+def require_usr_bin_time(test_case):
+    """
+    Decorator marking a test that requires `/usr/bin/time`
+    """
+    return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
+
+
+def require_sudachi(test_case):
+    """
+    Decorator marking a test that requires sudachi
+    """
+    return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case)
+
+
+def require_sudachi_projection(test_case):
+    """
+    Decorator marking a test that requires sudachi_projection
+    """
+    return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")(
+        test_case
+    )
+
+
+def require_jumanpp(test_case):
+    """
+    Decorator marking a test that requires jumanpp
+    """
+    return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)
+
+
+def require_cython(test_case):
+    """
+    Decorator marking a test that requires Cython
+    """
+    return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)
+
+
+def require_tiktoken(test_case):
+    """
+    Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed.
+    """
+    return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
+
+
+def require_speech(test_case):
+    """
+    Decorator marking a test that requires speech. These tests are skipped when speech isn't available.
+    """
+    return unittest.skipUnless(is_speech_available(), "test requires torchaudio")(test_case)
+
+
+def require_openai(test_case):
+    """
+    Decorator marking a test that requires openai
+    """
+    return unittest.skipUnless(is_openai_available(), "test requires openai")(test_case)
+
+
+def require_mistral_common(test_case):
+    """
+    Decorator marking a test that requires mistral-common. These tests are skipped when mistral-common isn't available.
+    """
+    return unittest.skipUnless(is_mistral_common_available(), "test requires mistral-common")(test_case)
+
+
+def get_gpu_count():
+    """
+    Return the number of available gpus (regardless of whether torch, tf or jax is used)
+    """
+    if is_torch_available():
+        import torch
+
+        return torch.cuda.device_count()
+    elif is_tf_available():
+        import tensorflow as tf
+
+        return len(tf.config.list_physical_devices("GPU"))
+    elif is_flax_available():
+        import jax
+
+        return jax.device_count()
+    else:
+        return 0
+
+
+def get_tests_dir(append_path=None):
+    """
+    Args:
+        append_path: optional path to append to the tests dir path
+
+    Return:
+        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+        joined after the `tests` dir if the former is provided.
+
+    """
+    # this function caller's __file__
+    caller__file__ = inspect.stack()[1][1]
+    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+    while not tests_dir.endswith("tests"):
+        tests_dir = os.path.dirname(tests_dir)
+
+    if append_path:
+        return os.path.join(tests_dir, append_path)
+    else:
+        return tests_dir
+
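+# A minimal usage sketch (the fixture path below is illustrative):
+#
+#     SAMPLE_DATA_DIR = get_tests_dir("fixtures/tests_samples")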
+
+def get_steps_per_epoch(trainer: Trainer) -> int:
+    training_args = trainer.args
+    train_dataloader = trainer.get_train_dataloader()
+
+    initial_training_values = trainer.set_initial_training_values(
+        args=training_args,
+        dataloader=train_dataloader,
+        total_train_batch_size=training_args.per_device_train_batch_size,
+    )
+    steps_per_epoch = initial_training_values[1]
+
+    return steps_per_epoch
+
+
+def evaluate_side_effect_factory(
+    side_effect_values: list[dict[str, float]],
+) -> Generator[dict[str, float], None, None]:
+    """
+    Function that returns side effects for the _evaluate method.
+    Used when we're unsure of exactly how many times _evaluate will be called.
+    """
+    yield from side_effect_values
+
+    while True:
+        yield side_effect_values[-1]
+
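+# A minimal usage sketch (the metric values are hypothetical): the generator is meant to be
+# passed as a mock `side_effect`, so successive calls to `_evaluate` first return the listed
+# values and then keep returning the last one.
+#
+#     from unittest.mock import patch
+#
+#     with patch.object(
+#         Trainer, "_evaluate", side_effect=evaluate_side_effect_factory([{"eval_loss": 1.0}, {"eval_loss": 0.5}])
+#     ):
+#         ...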
+
+#
+# Helper functions for dealing with testing text outputs
+# The original code came from:
+# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
+
+
+# When any function contains print() calls that get overwritten, like progress bars,
+# a special care needs to be applied, since under pytest -s captured output (capsys
+# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
+# \r's. This helper function ensures that the buffer will contain the same output
+# with and without -s in pytest, by turning:
+# foo bar\r tar mar\r final message
+# into:
+# final message
+# it can handle a single string or a multiline buffer
+def apply_print_resets(buf):
+    return re.sub(r"^.*\r", "", buf, 0, re.M)
+
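+# For example (illustrative):
+#
+#     apply_print_resets("downloading...  10%\rdownloading... 100%\rdone")  # -> "done"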
+
+def assert_screenout(out, what):
+    out_pr = apply_print_resets(out).lower()
+    match_str = out_pr.find(what.lower())
+    assert match_str != -1, f"expecting to find {what} in output: {out_pr}"
+
+
+def set_model_tester_for_less_flaky_test(test_case):
+    target_num_hidden_layers = 1
+    # TODO (if possible): Avoid exceptional cases
+    exceptional_classes = [
+        "ZambaModelTester",
+        "Zamba2ModelTester",
+        "RwkvModelTester",
+        "AriaVisionText2TextModelTester",
+        "GPTNeoModelTester",
+        "DPTModelTester",
+    ]
+    if test_case.model_tester.__class__.__name__ in exceptional_classes:
+        target_num_hidden_layers = None
+    if hasattr(test_case.model_tester, "out_features") or hasattr(test_case.model_tester, "out_indices"):
+        target_num_hidden_layers = None
+
+    if hasattr(test_case.model_tester, "num_hidden_layers") and target_num_hidden_layers is not None:
+        test_case.model_tester.num_hidden_layers = target_num_hidden_layers
+    if (
+        hasattr(test_case.model_tester, "vision_config")
+        and "num_hidden_layers" in test_case.model_tester.vision_config
+        and target_num_hidden_layers is not None
+    ):
+        test_case.model_tester.vision_config = copy.deepcopy(test_case.model_tester.vision_config)
+        if isinstance(test_case.model_tester.vision_config, dict):
+            test_case.model_tester.vision_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.vision_config.num_hidden_layers = 1
+    if (
+        hasattr(test_case.model_tester, "text_config")
+        and "num_hidden_layers" in test_case.model_tester.text_config
+        and target_num_hidden_layers is not None
+    ):
+        test_case.model_tester.text_config = copy.deepcopy(test_case.model_tester.text_config)
+        if isinstance(test_case.model_tester.text_config, dict):
+            test_case.model_tester.text_config["num_hidden_layers"] = 1
+        else:
+            test_case.model_tester.text_config.num_hidden_layers = 1
+
+    # A few model class specific handling
+
+    # For Albert
+    if hasattr(test_case.model_tester, "num_hidden_groups"):
+        test_case.model_tester.num_hidden_groups = test_case.model_tester.num_hidden_layers
+
+
+def set_config_for_less_flaky_test(config):
+    target_attrs = [
+        "rms_norm_eps",
+        "layer_norm_eps",
+        "norm_eps",
+        "norm_epsilon",
+        "layer_norm_epsilon",
+        "batch_norm_eps",
+    ]
+    for target_attr in target_attrs:
+        setattr(config, target_attr, 1.0)
+
+    # norm layers (layer/group norm, etc.) could cause flaky tests when the tensors have very small variance.
+    # (We don't need the original epsilon values to check eager/sdpa matches)
+    attrs = ["text_config", "vision_config", "text_encoder", "audio_encoder", "decoder"]
+    for attr in attrs:
+        if hasattr(config, attr):
+            for target_attr in target_attrs:
+                setattr(getattr(config, attr), target_attr, 1.0)
+
+
+def set_model_for_less_flaky_test(model):
+    # Another way to make sure norm layers have the desired epsilon. (Some models don't set it from their config.)
+    target_names = (
+        "LayerNorm",
+        "GroupNorm",
+        "BatchNorm",
+        "RMSNorm",
+        "BatchNorm2d",
+        "BatchNorm1d",
+        "BitGroupNormActivation",
+        "WeightStandardizedConv2d",
+    )
+    target_attrs = ["eps", "epsilon", "variance_epsilon"]
+    if is_torch_available() and isinstance(model, torch.nn.Module):
+        for module in model.modules():
+            if type(module).__name__.endswith(target_names):
+                for attr in target_attrs:
+                    if hasattr(module, attr):
+                        setattr(module, attr, 1.0)
+
+
+class CaptureStd:
+    """
+    Context manager to capture:
+
+        - stdout: replay it, clean it up and make it available via `obj.out`
+        - stderr: replay it and make it available via `obj.err`
+
+    Args:
+        out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
+        err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
+        replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
+            By default each captured stream gets replayed back on context's exit, so that one can see what the test was
+            doing. If this behavior is not wanted and the captured data shouldn't be replayed, pass `replay=False` to
+            disable this feature.
+
+    Examples:
+
+    ```python
+    # to capture stdout only with auto-replay
+    with CaptureStdout() as cs:
+        print("Secret message")
+    assert "message" in cs.out
+
+    # to capture stderr only with auto-replay
+    import sys
+
+    with CaptureStderr() as cs:
+        print("Warning: ", file=sys.stderr)
+    assert "Warning" in cs.err
+
+    # to capture both streams with auto-replay
+    with CaptureStd() as cs:
+        print("Secret message")
+        print("Warning: ", file=sys.stderr)
+    assert "message" in cs.out
+    assert "Warning" in cs.err
+
+    # to capture just one of the streams, and not the other, with auto-replay
+    with CaptureStd(err=False) as cs:
+        print("Secret message")
+    assert "message" in cs.out
+    # but best use the stream-specific subclasses
+
+    # to capture without auto-replay
+    with CaptureStd(replay=False) as cs:
+        print("Secret message")
+    assert "message" in cs.out
+    ```"""
+
+    def __init__(self, out=True, err=True, replay=True):
+        self.replay = replay
+
+        if out:
+            self.out_buf = StringIO()
+            self.out = "error: CaptureStd context is unfinished yet, called too early"
+        else:
+            self.out_buf = None
+            self.out = "not capturing stdout"
+
+        if err:
+            self.err_buf = StringIO()
+            self.err = "error: CaptureStd context is unfinished yet, called too early"
+        else:
+            self.err_buf = None
+            self.err = "not capturing stderr"
+
+    def __enter__(self):
+        if self.out_buf:
+            self.out_old = sys.stdout
+            sys.stdout = self.out_buf
+
+        if self.err_buf:
+            self.err_old = sys.stderr
+            sys.stderr = self.err_buf
+
+        return self
+
+    def __exit__(self, *exc):
+        if self.out_buf:
+            sys.stdout = self.out_old
+            captured = self.out_buf.getvalue()
+            if self.replay:
+                sys.stdout.write(captured)
+            self.out = apply_print_resets(captured)
+
+        if self.err_buf:
+            sys.stderr = self.err_old
+            captured = self.err_buf.getvalue()
+            if self.replay:
+                sys.stderr.write(captured)
+            self.err = captured
+
+    def __repr__(self):
+        msg = ""
+        if self.out_buf:
+            msg += f"stdout: {self.out}\n"
+        if self.err_buf:
+            msg += f"stderr: {self.err}\n"
+        return msg
+
+
+# in tests it's the best to capture only the stream that's wanted, otherwise
+# it's easy to miss things, so unless you need to capture both streams, use the
+# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
+# disable the stream you don't need to test.
+
+
+class CaptureStdout(CaptureStd):
+    """Same as CaptureStd but captures only stdout"""
+
+    def __init__(self, replay=True):
+        super().__init__(err=False, replay=replay)
+
+
+class CaptureStderr(CaptureStd):
+    """Same as CaptureStd but captures only stderr"""
+
+    def __init__(self, replay=True):
+        super().__init__(out=False, replay=replay)
+
+
+class CaptureLogger:
+    """
+    Context manager to capture `logging` streams
+
+    Args:
+        logger: 'logging` logger object
+
+    Returns:
+        The captured output is available via `self.out`
+
+    Example:
+
+    ```python
+    >>> from transformers import logging
+    >>> from transformers.testing_utils import CaptureLogger
+
+    >>> msg = "Testing 1, 2, 3"
+    >>> logging.set_verbosity_info()
+    >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
+    >>> with CaptureLogger(logger) as cl:
+    ...     logger.info(msg)
+    >>> assert cl.out == msg + "\n"
+    ```
+    """
+
+    def __init__(self, logger):
+        self.logger = logger
+        self.io = StringIO()
+        self.sh = logging.StreamHandler(self.io)
+        self.out = ""
+
+    def __enter__(self):
+        self.logger.addHandler(self.sh)
+        return self
+
+    def __exit__(self, *exc):
+        self.logger.removeHandler(self.sh)
+        self.out = self.io.getvalue()
+
+    def __repr__(self):
+        return f"captured: {self.out}\n"
+
+
+@contextlib.contextmanager
+def LoggingLevel(level):
+    """
+    This is a context manager to temporarily change the logging level of transformers modules to the desired value
+    and have it restored to the original setting at the end of the scope.
+
+    Example:
+
+    ```python
+    with LoggingLevel(logging.INFO):
+        AutoModel.from_pretrained("openai-community/gpt2")  # calls logger.info() several times
+    ```
+    """
+    orig_level = transformers_logging.get_verbosity()
+    try:
+        transformers_logging.set_verbosity(level)
+        yield
+    finally:
+        transformers_logging.set_verbosity(orig_level)
+
+
+class TemporaryHubRepo:
+    """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to
+    `tempfile.TemporaryDirectory` and can be used as a context manager. For example:
+
+        with TemporaryHubRepo(token=self._token) as temp_repo:
+            ...
+
+    Upon exiting the context, the repository and everything contained in it are removed.
+
+    Example:
+
+    ```python
+    with TemporaryHubRepo(token=self._token) as temp_repo:
+        model.push_to_hub(temp_repo.repo_id, token=self._token)
+    ```
+    """
+
+    def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None:
+        self.token = token
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            repo_id = Path(tmp_dir).name
+            if namespace is not None:
+                repo_id = f"{namespace}/{repo_id}"
+            self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token)
+
+    def __enter__(self):
+        return self.repo_url
+
+    def __exit__(self, exc, value, tb):
+        delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True)
+
+
+@contextlib.contextmanager
+# adapted from https://stackoverflow.com/a/64789046/9201239
+def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
+    """
+    Temporary add given path to `sys.path`.
+
+    Usage :
+
+    ```python
+    with ExtendSysPath("/path/to/dir"):
+        mymodule = importlib.import_module("mymodule")
+    ```
+    """
+
+    path = os.fspath(path)
+    try:
+        sys.path.insert(0, path)
+        yield
+    finally:
+        sys.path.remove(path)
+
+
+class TestCasePlus(unittest.TestCase):
+    """
+    This class extends *unittest.TestCase* with additional features.
+
+    Feature 1: A set of fully resolved important file and dir path accessors.
+
+    In tests we often need to know where things are relative to the current test file, and it's not trivial since the
+    test could be invoked from more than one directory or could reside in sub-directories with different depths. This
+    class solves this problem by sorting out all the basic paths and provides easy accessors to them:
+
+    - `pathlib` objects (all fully resolved):
+
+       - `test_file_path` - the current test file path (=`__file__`)
+       - `test_file_dir` - the directory containing the current test file
+       - `tests_dir` - the directory of the `tests` test suite
+       - `examples_dir` - the directory of the `examples` test suite
+       - `repo_root_dir` - the directory of the repository
+       - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)
+
+    - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:
+
+       - `test_file_path_str`
+       - `test_file_dir_str`
+       - `tests_dir_str`
+       - `examples_dir_str`
+       - `repo_root_dir_str`
+       - `src_dir_str`
+
+    Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
+
+    1. Create a unique temporary dir:
+
+    ```python
+    def test_whatever(self):
+        tmp_dir = self.get_auto_remove_tmp_dir()
+    ```
+
+    `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
+    test.
+
+
+    2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
+    empty it after the test.
+
+    ```python
+    def test_whatever(self):
+        tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
+    ```
+
+    This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests
+    didn't leave any data in there.
+
+    3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
+        following behavior:
+
+    `before=True`: the temporary dir will always be cleared at the beginning of the test.
+
+    `before=False`: if the temporary dir already existed, any existing files will remain there.
+
+    `after=True`: the temporary dir will always be deleted at the end of the test.
+
+    `after=False`: the temporary dir will always be left intact at the end of the test.
+
+    Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
+    allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
+    will get nuked. i.e. please always pass paths that start with `./`
+
+    Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
+    otherwise.
+
+    Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
+    is useful for invoking external programs from the test suite - e.g. distributed training.
+
+
+    ```python
+    def test_whatever(self):
+        env = self.get_env()
+    ```"""
+
+    def setUp(self):
+        # get_auto_remove_tmp_dir feature:
+        self.teardown_tmp_dirs = []
+
+        # figure out the resolved paths for repo_root, tests, examples, etc.
+        self._test_file_path = inspect.getfile(self.__class__)
+        path = Path(self._test_file_path).resolve()
+        self._test_file_dir = path.parents[0]
+        for up in [1, 2, 3]:
+            tmp_dir = path.parents[up]
+            if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
+                break
+        if tmp_dir:
+            self._repo_root_dir = tmp_dir
+        else:
+            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
+        self._tests_dir = self._repo_root_dir / "tests"
+        self._examples_dir = self._repo_root_dir / "examples"
+        self._src_dir = self._repo_root_dir / "src"
+
+    @property
+    def test_file_path(self):
+        return self._test_file_path
+
+    @property
+    def test_file_path_str(self):
+        return str(self._test_file_path)
+
+    @property
+    def test_file_dir(self):
+        return self._test_file_dir
+
+    @property
+    def test_file_dir_str(self):
+        return str(self._test_file_dir)
+
+    @property
+    def tests_dir(self):
+        return self._tests_dir
+
+    @property
+    def tests_dir_str(self):
+        return str(self._tests_dir)
+
+    @property
+    def examples_dir(self):
+        return self._examples_dir
+
+    @property
+    def examples_dir_str(self):
+        return str(self._examples_dir)
+
+    @property
+    def repo_root_dir(self):
+        return self._repo_root_dir
+
+    @property
+    def repo_root_dir_str(self):
+        return str(self._repo_root_dir)
+
+    @property
+    def src_dir(self):
+        return self._src_dir
+
+    @property
+    def src_dir_str(self):
+        return str(self._src_dir)
+
+    def get_env(self):
+        """
+        Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
+        invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.
+
+        It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
+        the preset `PYTHONPATH` if any (all fully resolved paths).
+
+        """
+        env = os.environ.copy()
+        paths = [self.repo_root_dir_str, self.src_dir_str]
+        if "/examples" in self.test_file_dir_str:
+            paths.append(self.examples_dir_str)
+        else:
+            paths.append(self.tests_dir_str)
+        paths.append(env.get("PYTHONPATH", ""))
+
+        env["PYTHONPATH"] = ":".join(paths)
+        return env
+
+    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+        """
+        Args:
+            tmp_dir (`string`, *optional*):
+                if `None`:
+
+                   - a unique temporary path will be created
+                   - sets `before=True` if `before` is `None`
+                   - sets `after=True` if `after` is `None`
+                else:
+
+                   - `tmp_dir` will be created
+                   - sets `before=True` if `before` is `None`
+                   - sets `after=False` if `after` is `None`
+            before (`bool`, *optional*):
+                If `True` and the `tmp_dir` already exists, make sure to empty it right away; if `False` and the
+                `tmp_dir` already exists, any existing files will remain there.
+            after (`bool`, *optional*):
+                If `True`, delete the `tmp_dir` at the end of the test; if `False`, leave the `tmp_dir` and its
+                contents intact at the end of the test.
+
+        Returns:
+            tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
+        """
+        if tmp_dir is not None:
+            # defining the most likely desired behavior for when a custom path is provided.
+            # this most likely indicates the debug mode where we want an easily locatable dir that:
+            # 1. gets cleared out before the test (if it already exists)
+            # 2. is left intact after the test
+            if before is None:
+                before = True
+            if after is None:
+                after = False
+
+            # using provided path
+            path = Path(tmp_dir).resolve()
+
+            # to avoid nuking parts of the filesystem, only relative paths are allowed
+            if not tmp_dir.startswith("./"):
+                raise ValueError(
+                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
+                )
+
+            # ensure the dir is empty to start with
+            if before is True and path.exists():
+                shutil.rmtree(tmp_dir, ignore_errors=True)
+
+            path.mkdir(parents=True, exist_ok=True)
+
+        else:
+            # defining the most likely desired behavior for when a unique tmp path is auto generated
+            # (not a debug mode), here we require a unique tmp dir that:
+            # 1. is empty before the test (it will be empty in this situation anyway)
+            # 2. gets fully removed after the test
+            if before is None:
+                before = True
+            if after is None:
+                after = True
+
+            # using unique tmp dir (always empty, regardless of `before`)
+            tmp_dir = tempfile.mkdtemp()
+
+        if after is True:
+            # register for deletion
+            self.teardown_tmp_dirs.append(tmp_dir)
+
+        return tmp_dir
+
+    def python_one_liner_max_rss(self, one_liner_str):
+        """
+        Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
+        program.
+
+        Args:
+            one_liner_str (`string`):
+                a python one liner code that gets passed to `python -c`
+
+        Returns:
+            max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
+
+        Requirements:
+            this helper needs `/usr/bin/time` to be installed (`apt install time`)
+
+        Example:
+
+        ```
+        one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
+        max_rss = self.python_one_liner_max_rss(one_liner_str)
+        ```
+        """
+
+        if not cmd_exists("/usr/bin/time"):
+            raise ValueError("/usr/bin/time is required, install with `apt install time`")
+
+        cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
+        with CaptureStd() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+        # returned data is in KB so convert to bytes
+        max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
+        return max_rss
+
+    def tearDown(self):
+        # get_auto_remove_tmp_dir feature: remove registered temp dirs
+        for path in self.teardown_tmp_dirs:
+            shutil.rmtree(path, ignore_errors=True)
+        self.teardown_tmp_dirs = []
+        if is_accelerate_available():
+            AcceleratorState._reset_state()
+            PartialState._reset_state()
+
+            # delete all the env variables having `ACCELERATE` in them
+            for k in list(os.environ.keys()):
+                if "ACCELERATE" in k:
+                    del os.environ[k]
+
+
+def mockenv(**kwargs):
+    """
+    This is a convenience wrapper that allows this::
+
+        @mockenv(RUN_SLOW=True, USE_TF=False)
+        def test_something():
+            run_slow = os.getenv("RUN_SLOW", False)
+            use_tf = os.getenv("USE_TF", False)
+
+    """
+    return mock.patch.dict(os.environ, kwargs)
+
+
+# from https://stackoverflow.com/a/34333710/9201239
+@contextlib.contextmanager
+def mockenv_context(*remove, **update):
+    """
+    Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv
+
+    The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.
+
+    Args:
+      remove: Environment variables to remove.
+      update: Dictionary of environment variables and values to add/update.
+    """
+    env = os.environ
+    update = update or {}
+    remove = remove or []
+
+    # List of environment variables being updated or removed.
+    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
+    # Environment variables and values to restore on exit.
+    update_after = {k: env[k] for k in stomped}
+    # Environment variables and values to remove on exit.
+    remove_after = frozenset(k for k in update if k not in env)
+
+    try:
+        env.update(update)
+        [env.pop(k, None) for k in remove]
+        yield
+    finally:
+        env.update(update_after)
+        [env.pop(k) for k in remove_after]
+
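+# A minimal usage sketch for `mockenv_context` (variable names are hypothetical):
+#
+#     with mockenv_context("UNWANTED_VAR", WANTED_VAR="1"):
+#         ...  # here WANTED_VAR is set and UNWANTED_VAR is unset; os.environ is restored on exit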
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+    """
+    This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+    It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+    option.
+
+    """
+    option = "--make-reports"
+    if option not in pytest_opt_registered:
+        parser.addoption(
+            option,
+            action="store",
+            default=False,
+            help="generate report files. The value of this option is used as a prefix to report names",
+        )
+        pytest_opt_registered[option] = 1
+
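+# A minimal sketch of the `conftest.py` wrapper mentioned above (assuming this module is importable as
+# `transformers.testing_utils`):
+#
+#     def pytest_addoption(parser):
+#         from transformers.testing_utils import pytest_addoption_shared
+#
+#         pytest_addoption_shared(parser)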
+
+def pytest_terminal_summary_main(tr, id):
+    """
+    Generate multiple reports at the end of the test suite run - each report goes into a dedicated file in the current
+    directory. The report files are prefixed with the test suite name.
+
+    This function emulates the `--durations` and `-rA` pytest arguments.
+
+    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+    there.
+
+    Args:
+    - tr: `terminalreporter` passed from `conftest.py`
+    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+    NB: this function taps into a private _pytest API and, while unlikely, it could break should pytest make internal
+    changes - it also calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
+    plugins and interfere.
+
+    """
+    from _pytest.config import create_terminal_writer
+
+    if not len(id):
+        id = "tests"
+
+    config = tr.config
+    orig_writer = config.get_terminal_writer()
+    orig_tbstyle = config.option.tbstyle
+    orig_reportchars = tr.reportchars
+
+    dir = f"reports/{id}"
+    Path(dir).mkdir(parents=True, exist_ok=True)
+    report_files = {
+        k: f"{dir}/{k}.txt"
+        for k in [
+            "durations",
+            "errors",
+            "failures_long",
+            "failures_short",
+            "failures_line",
+            "passes",
+            "stats",
+            "summary_short",
+            "warnings",
+        ]
+    }
+
+    # custom durations report
+    # note: there is no need to call pytest --durations=XX to get this separate report
+    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+    dlist = []
+    for replist in tr.stats.values():
+        for rep in replist:
+            if hasattr(rep, "duration"):
+                dlist.append(rep)
+    if dlist:
+        dlist.sort(key=lambda x: x.duration, reverse=True)
+        with open(report_files["durations"], "w") as f:
+            durations_min = 0.05  # sec
+            f.write("slowest durations\n")
+            for i, rep in enumerate(dlist):
+                if rep.duration < durations_min:
+                    f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
+                    break
+                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+    def summary_failures_short(tr):
+        # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+        reports = tr.getreports("failed")
+        if not reports:
+            return
+        tr.write_sep("=", "FAILURES SHORT STACK")
+        for rep in reports:
+            msg = tr._getfailureheadline(rep)
+            tr.write_sep("_", msg, red=True, bold=True)
+            # chop off the optional leading extra frames, leaving only the last one
+            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+            tr._tw.line(longrepr)
+            # note: not printing out any rep.sections to keep the report short
+
+    # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+    # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+    # pytest-instafail does that)
+
+    # report failures with line/short/long styles
+    config.option.tbstyle = "auto"  # full tb
+    with open(report_files["failures_long"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.summary_failures()
+
+    # config.option.tbstyle = "short" # short tb
+    with open(report_files["failures_short"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        summary_failures_short(tr)
+
+    config.option.tbstyle = "line"  # one line per error
+    with open(report_files["failures_line"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.summary_failures()
+
+    with open(report_files["errors"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.summary_errors()
+
+    with open(report_files["warnings"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.summary_warnings()  # normal warnings
+        tr.summary_warnings()  # final warnings
+
+    tr.reportchars = "wPpsxXEf"  # emulate -rA (used in summary_passes() and short_test_summary())
+
+    # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it times out on CircleCI if it
+    # takes > 10 minutes (as this part doesn't generate any output on the terminal).
+    # (also, it seems there is no useful information in this report, and we rarely need to read it)
+    # with open(report_files["passes"], "w") as f:
+    #     tr._tw = create_terminal_writer(config, f)
+    #     tr.summary_passes()
+
+    with open(report_files["summary_short"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.short_test_summary()
+
+    with open(report_files["stats"], "w") as f:
+        tr._tw = create_terminal_writer(config, f)
+        tr.summary_stats()
+
+    # restore:
+    tr._tw = orig_writer
+    tr.reportchars = orig_reportchars
+    config.option.tbstyle = orig_tbstyle
+
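+# A minimal sketch of the matching `conftest.py` wrapper, passing the `--make-reports` value as the report id
+# (assuming this module is importable as `transformers.testing_utils`):
+#
+#     def pytest_terminal_summary(terminalreporter):
+#         from transformers.testing_utils import pytest_terminal_summary_main
+#
+#         make_reports = terminalreporter.config.getoption("--make-reports")
+#         if make_reports:
+#             pytest_terminal_summary_main(terminalreporter, id=make_reports)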
+
+# --- distributed testing functions --- #
+
+# adapted from https://stackoverflow.com/a/59041913/9201239
+import asyncio  # noqa
+
+
+class _RunOutput:
+    def __init__(self, returncode, stdout, stderr):
+        self.returncode = returncode
+        self.stdout = stdout
+        self.stderr = stderr
+
+
+async def _read_stream(stream, callback):
+    while True:
+        line = await stream.readline()
+        if line:
+            callback(line)
+        else:
+            break
+
+
+async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
+    if echo:
+        print("\nRunning: ", " ".join(cmd))
+
+    p = await asyncio.create_subprocess_exec(
+        cmd[0],
+        *cmd[1:],
+        stdin=stdin,
+        stdout=asyncio.subprocess.PIPE,
+        stderr=asyncio.subprocess.PIPE,
+        env=env,
+    )
+
+    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
+    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
+    #
+    # If it starts hanging, we will need to switch to the following code. The problem is that no data
+    # will be seen until it's done, and if it hangs there will be no debug info.
+    # out, err = await p.communicate()
+    # return _RunOutput(p.returncode, out, err)
+
+    out = []
+    err = []
+
+    def tee(line, sink, pipe, label=""):
+        line = line.decode("utf-8").rstrip()
+        sink.append(line)
+        if not quiet:
+            print(label, line, file=pipe)
+
+    # XXX: the timeout doesn't seem to make any difference here
+    await asyncio.wait(
+        [
+            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
+            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
+        ],
+        timeout=timeout,
+    )
+    return _RunOutput(await p.wait(), out, err)
+
+
+def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
+    loop = asyncio.get_event_loop()
+    result = loop.run_until_complete(
+        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
+    )
+
+    cmd_str = " ".join(cmd)
+    if result.returncode > 0:
+        stderr = "\n".join(result.stderr)
+        raise RuntimeError(
+            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
+            f"The combined stderr from workers follows:\n{stderr}"
+        )
+
+    # check that the subprocess actually did run and produced some output, should the test rely on
+    # the remote side to do the testing
+    if not result.stdout and not result.stderr:
+        raise RuntimeError(f"'{cmd_str}' produced no output.")
+
+    return result
+
+
+def pytest_xdist_worker_id():
+    """
+    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
+    if `-n 1` or `pytest-xdist` isn't being used.
+    """
+    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+    worker = re.sub(r"^gw", "", worker, 0, re.M)
+    return int(worker)
+
+
+def get_torch_dist_unique_port():
+    """
+    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.
+
+    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
+    port at once.
+    """
+    port = 29500
+    uniq_delta = pytest_xdist_worker_id()
+    return port + uniq_delta
+
+
+def nested_simplify(obj, decimals=3):
+    """
+    Simplifies an object by rounding float numbers and downcasting tensors/numpy arrays, to allow simple equality
+    tests within tests.
+    """
+    import numpy as np
+
+    if isinstance(obj, list):
+        return [nested_simplify(item, decimals) for item in obj]
+    if isinstance(obj, tuple):
+        return tuple(nested_simplify(item, decimals) for item in obj)
+    elif isinstance(obj, np.ndarray):
+        return nested_simplify(obj.tolist())
+    elif isinstance(obj, Mapping):
+        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
+    elif isinstance(obj, (str, int, np.int64)) or obj is None:
+        return obj
+    elif is_torch_available() and isinstance(obj, torch.Tensor):
+        return nested_simplify(obj.tolist(), decimals)
+    elif is_tf_available() and tf.is_tensor(obj):
+        return nested_simplify(obj.numpy().tolist())
+    elif isinstance(obj, float):
+        return round(obj, decimals)
+    elif isinstance(obj, (np.int32, np.float32, np.float16)):
+        return nested_simplify(obj.item(), decimals)
+    else:
+        raise Exception(f"Not supported: {type(obj)}")
+
+
+def check_json_file_has_correct_format(file_path):
+    with open(file_path) as f:
+        lines = f.readlines()
+        if len(lines) == 1:
+            # length can only be 1 if dict is empty
+            assert lines[0] == "{}"
+        else:
+            # otherwise make sure json has correct format (at least 3 lines)
+            assert len(lines) >= 3
+            # each key one line, ident should be 2, min length is 3
+            assert lines[0].strip() == "{"
+            for line in lines[1:-1]:
+                left_indent = len(lines[1]) - len(lines[1].lstrip())
+                assert left_indent == 2
+            assert lines[-1].strip() == "}"
+
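+# Illustrative example of a layout the check above accepts (one key per line, two-space indent):
+#
+#     {
+#       "some_key": 1,
+#       "another_key": 2
+#     }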
+
+def to_2tuple(x):
+    if isinstance(x, collections.abc.Iterable):
+        return x
+    return (x, x)
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+    pass
+
+
+def run_command(command: list[str], return_stdout=False):
+    """
+    Runs `command` with `subprocess.check_output`, optionally returning the `stdout`. Also properly captures and
+    reports any error that occurred while running `command`.
+    """
+    try:
+        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+        if return_stdout:
+            if hasattr(output, "decode"):
+                output = output.decode("utf-8")
+            return output
+    except subprocess.CalledProcessError as e:
+        raise SubprocessCallException(
+            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+        ) from e
+
+
+class RequestCounter:
+    """
+    Helper class that will count all requests made online.
+
+    Might not be robust if urllib3 changes its logging format but should be good enough for us.
+
+    Usage:
+    ```py
+    with RequestCounter() as counter:
+        _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+    assert counter["GET"] == 0
+    assert counter["HEAD"] == 1
+    assert counter.total_calls == 1
+    ```
+    """
+
+    def __enter__(self):
+        self._counter = defaultdict(int)
+        self._thread_id = threading.get_ident()
+        self._extra_info = []
+
+        def patched_with_thread_info(func):
+            def wrap(*args, **kwargs):
+                self._extra_info.append(threading.get_ident())
+                return func(*args, **kwargs)
+
+            return wrap
+
+        self.patcher = patch.object(
+            urllib3.connectionpool.log, "debug", side_effect=patched_with_thread_info(urllib3.connectionpool.log.debug)
+        )
+        self.mock = self.patcher.start()
+        return self
+
+    def __exit__(self, *args, **kwargs) -> None:
+        assert len(self.mock.call_args_list) == len(self._extra_info)
+        for thread_id, call in zip(self._extra_info, self.mock.call_args_list):
+            if thread_id != self._thread_id:
+                continue
+            # code 307: the URL being requested by the user has moved to a temporary location
+            if call.args[-2] == 307:
+                continue
+            log = call.args[0] % call.args[1:]
+            for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
+                if method in log:
+                    self._counter[method] += 1
+                    break
+        self.patcher.stop()
+
+    def __getitem__(self, key: str) -> int:
+        return self._counter[key]
+
+    @property
+    def total_calls(self) -> int:
+        return sum(self._counter.values())
+
+
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+    """
+    To decorate flaky tests. They will be retried on failures.
+
+    Please note that our push tests use `pytest-rerunfailures`, which prompts the CI to rerun certain types of
+    failed tests. More specifically, if the test exception contains any substring in `FLAKY_TEST_FAILURE_PATTERNS`
+    (in `.circleci/create_circleci_config.py`), it will be rerun. If you find a recurrent pattern of failures,
+    expand `FLAKY_TEST_FAILURE_PATTERNS` in our CI configuration instead of using `is_flaky`.
+
+    Args:
+        max_attempts (`int`, *optional*, defaults to 5):
+            The maximum number of attempts to retry the flaky test.
+        wait_before_retry (`float`, *optional*):
+            If provided, will wait that number of seconds before retrying the test.
+        description (`str`, *optional*):
+            A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+            etc.)
+    """
+
+    def decorator(test_func_ref):
+        @functools.wraps(test_func_ref)
+        def wrapper(*args, **kwargs):
+            retry_count = 1
+
+            while retry_count < max_attempts:
+                try:
+                    return test_func_ref(*args, **kwargs)
+
+                except Exception as err:
+                    logger.error(f"Test failed with {err} at try {retry_count}/{max_attempts}.")
+                    if wait_before_retry is not None:
+                        time.sleep(wait_before_retry)
+                    retry_count += 1
+
+            return test_func_ref(*args, **kwargs)
+
+        return unittest.skipUnless(_run_flaky_tests, "test is flaky")(wrapper)
+
+    return decorator
+
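+# A minimal usage sketch for `is_flaky` (test name and arguments are hypothetical):
+#
+#     @is_flaky(max_attempts=3, wait_before_retry=1.0, description="occasionally fails, see linked issue")
+#     def test_sometimes_fails(self):
+#         ...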
+
+def hub_retry(max_attempts: int = 5, wait_before_retry: Optional[float] = 2):
+    """
+    To decorate tests that download from the Hub. They can fail due to a
+    variety of network issues such as timeouts, connection resets, etc.
+
+    Args:
+        max_attempts (`int`, *optional*, defaults to 5):
+            The maximum number of attempts to retry the flaky test.
+        wait_before_retry (`float`, *optional*, defaults to 2):
+            If provided, will wait that number of seconds before retrying the test.
+    """
+
+    def decorator(test_func_ref):
+        @functools.wraps(test_func_ref)
+        def wrapper(*args, **kwargs):
+            retry_count = 1
+
+            while retry_count < max_attempts:
+                try:
+                    return test_func_ref(*args, **kwargs)
+                # We catch all exceptions related to network issues from requests
+                except (
+                    requests.exceptions.ConnectionError,
+                    requests.exceptions.Timeout,
+                    requests.exceptions.ReadTimeout,
+                    requests.exceptions.HTTPError,
+                    requests.exceptions.RequestException,
+                ) as err:
+                    logger.error(
+                        f"Test failed with {err} at try {retry_count}/{max_attempts} as it couldn't connect to the specified Hub repository."
+                    )
+                    if wait_before_retry is not None:
+                        time.sleep(wait_before_retry)
+                    retry_count += 1
+
+            return test_func_ref(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def run_first(test_case):
+    """
+    Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator
+    are guaranteed to run first.
+
+    This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
+    single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
+    allocation conflicts.
+    """
+    import pytest
+
+    return pytest.mark.order(1)(test_case)
+
+
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+    """
+    To run a test in a subprocess. In particular, this can avoid (GPU) memory issues.
+
+    Args:
+        test_case (`unittest.TestCase`):
+            The test that will run `target_func`.
+        target_func (`Callable`):
+            The function implementing the actual testing logic.
+        inputs (`dict`, *optional*, defaults to `None`):
+            The inputs that will be passed to `target_func` through an (input) queue.
+        timeout (`int`, *optional*, defaults to `None`):
+            The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+            variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+    """
+    if timeout is None:
+        timeout = int(os.environ.get("PYTEST_TIMEOUT", "600"))
+
+    start_method = "spawn"
+    ctx = multiprocessing.get_context(start_method)
+
+    input_queue = ctx.Queue(1)
+    output_queue = ctx.JoinableQueue(1)
+
+    # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+    input_queue.put(inputs, timeout=timeout)
+
+    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+    process.start()
+    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+    # the test from exiting properly.
+    try:
+        results = output_queue.get(timeout=timeout)
+        output_queue.task_done()
+    except Exception as e:
+        process.terminate()
+        test_case.fail(e)
+    process.join(timeout=timeout)
+
+    if results["error"] is not None:
+        test_case.fail(f"{results['error']}")
+
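+# A minimal sketch of a `target_func` compatible with `run_test_in_subprocess`; the `{"error": ...}` payload is
+# what the caller above expects, everything else is hypothetical:
+#
+#     def _inner_test(in_queue, out_queue, timeout):
+#         error = None
+#         try:
+#             inputs = in_queue.get(timeout=timeout)
+#             ...  # the actual testing logic, using `inputs`
+#         except Exception:
+#             error = traceback.format_exc()  # assumes `traceback` is imported
+#         out_queue.put({"error": error}, timeout=timeout)
+#         out_queue.join()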
+
+def run_test_using_subprocess(func):
+    """
+    To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory
+    issues (GPU OOM or a test that causes many subsequent failures with `CUDA error: device-side assert triggered`).
+    """
+    import pytest
+
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
+            func(*args, **kwargs)
+        else:
+            test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
+            try:
+                import copy
+
+                env = copy.deepcopy(os.environ)
+                env["_INSIDE_SUB_PROCESS"] = "1"
+                # This prevents the entries in `short test summary info` given by the subprocess from being truncated, so the
+                # full information can be passed to the parent pytest process.
+                # See: https://docs.pytest.org/en/stable/explanation/ci.html
+                env["CI"] = "true"
+
+                # If not a subclass of `unittest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
+                if "pytestconfig" in kwargs:
+                    command = list(kwargs["pytestconfig"].invocation_params.args)
+                    for idx, x in enumerate(command):
+                        if x in kwargs["pytestconfig"].args:
+                            test = test.split("::")[1:]
+                            command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
+                    command = [f"{sys.executable}", "-m", "pytest"] + command
+                    command = [x for x in command if x not in ["--no-summary"]]
+                # Otherwise, simply run the test with no option at all
+                else:
+                    command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
+
+                subprocess.run(command, env=env, check=True, capture_output=True)
+            except subprocess.CalledProcessError as e:
+                exception_message = e.stdout.decode()
+                lines = exception_message.split("\n")
+                # Add a first line with more informative information instead of just `= test session starts =`.
+                # This makes the `short test summary info` section more useful.
+                if "= test session starts =" in lines[0]:
+                    text = ""
+                    for line in lines[1:]:
+                        if line.startswith("FAILED "):
+                            text = line[len("FAILED ") :]
+                            text = "".join(text.split(" - ")[1:])
+                        elif line.startswith("=") and line.endswith("=") and " failed in " in line:
+                            break
+                        elif len(text) > 0:
+                            text += f"\n{line}"
+                    text = "(subprocess) " + text
+                    lines = [text] + lines
+                exception_message = "\n".join(lines)
+                raise pytest.fail(exception_message, pytrace=False)
+
+    return wrapper
+
+
+"""
+The following contains utils to run the documentation tests without having to overwrite any files.
+
+The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
+made as a print would otherwise fail the corresponding line.
+
+To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules 
+"""
+
+
+def preprocess_string(string, skip_cuda_tests):
+    """Prepare a docstring or a `.md` file to be run by doctest.
+
+    The argument `string` would be the whole file content if it is a `.md` file. For a Python file, it would be one of
+    its docstrings. In each case, it may contain multiple Python code examples. If `skip_cuda_tests` is `True` and
+    CUDA usage is detected (with a heuristic), this method will return an empty string so no doctest will be run for
+    `string`.
+    """
+    codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )(.*?```)"
+    codeblocks = re.split(codeblock_pattern, string, flags=re.DOTALL)
+    is_cuda_found = False
+    for i, codeblock in enumerate(codeblocks):
+        if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
+            codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
+        if (
+            (">>>" in codeblock or "..." in codeblock)
+            and re.search(r"cuda|to\(0\)|device=0", codeblock)
+            and skip_cuda_tests
+        ):
+            is_cuda_found = True
+            break
+
+    modified_string = ""
+    if not is_cuda_found:
+        modified_string = "".join(codeblocks)
+
+    return modified_string
+
+
+class HfDocTestParser(doctest.DocTestParser):
+    """
+    Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
+    means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
+    added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.
+
+    Tests involving CUDA are skipped based on a naive pattern that should be updated if it is not enough.
+    """
+
+    # This regular expression is used to find doctest examples in a
+    # string.  It defines three groups: `source` is the source code
+    # (including leading indentation and prompts); `indent` is the
+    # indentation of the first (PS1) line of the source code; and
+    # `want` is the expected output (including leading indentation).
+    # fmt: off
+    _EXAMPLE_RE = re.compile(r'''
+        # Source consists of a PS1 line followed by zero or more PS2 lines.
+        (?P<source>
+            (?:^(?P<indent> [ ]*) >>>    .*)    # PS1 line
+            (?:\n           [ ]*  \.\.\. .*)*)  # PS2 lines
+        \n?
+        # Want consists of any non-blank lines that do not start with PS1.
+        (?P<want> (?:(?![ ]*$)    # Not a blank line
+             (?![ ]*>>>)          # Not a line starting with PS1
+             # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+             (?:(?!```).)*        # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line)
+             # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+             (?:\n|$)  # Match a new line or end of string
+          )*)
+        ''', re.MULTILINE | re.VERBOSE
+    )
+    # fmt: on
+
+    # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+    skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False))
+    # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+
+    def parse(self, string, name=""):
+        """
+        Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before
+        calling `super().parse`
+        """
+        string = preprocess_string(string, self.skip_cuda_tests)
+        return super().parse(string, name)
+
+
+class HfDoctestModule(Module):
+    """
+    Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering
+    tests.
+    """
+
+    def collect(self) -> Iterable[DoctestItem]:
+        class MockAwareDocTestFinder(doctest.DocTestFinder):
+            """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.
+
+            https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
+            """
+
+            def _find_lineno(self, obj, source_lines):
+                """Doctest code does not take into account `@property`, this
+                is a hackish way to fix it. https://bugs.python.org/issue17446
+
+                Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
+                reported upstream. #8796
+                """
+                if isinstance(obj, property):
+                    obj = getattr(obj, "fget", obj)
+
+                if hasattr(obj, "__wrapped__"):
+                    # Get the main obj in case of it being wrapped
+                    obj = inspect.unwrap(obj)
+
+                # Type ignored because this is a private function.
+                return super()._find_lineno(  # type:ignore[misc]
+                    obj,
+                    source_lines,
+                )
+
+            def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
+                if _is_mocked(obj):
+                    return
+                with _patch_unwrap_mock_aware():
+                    # Type ignored because this is a private function.
+                    super()._find(  # type:ignore[misc]
+                        tests, obj, name, module, source_lines, globs, seen
+                    )
+
+        if self.path.name == "conftest.py":
+            module = self.config.pluginmanager._importconftest(
+                self.path,
+                self.config.getoption("importmode"),
+                rootpath=self.config.rootpath,
+            )
+        else:
+            try:
+                module = import_path(
+                    self.path,
+                    root=self.config.rootpath,
+                    mode=self.config.getoption("importmode"),
+                )
+            except ImportError:
+                if self.config.getvalue("doctest_ignore_import_errors"):
+                    skip("unable to import module %r" % self.path)
+                else:
+                    raise
+
+        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+        finder = MockAwareDocTestFinder(parser=HfDocTestParser())
+        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+        optionflags = get_optionflags(self)
+        runner = _get_runner(
+            verbose=False,
+            optionflags=optionflags,
+            checker=_get_checker(),
+            continue_on_failure=_get_continue_on_failure(self.config),
+        )
+        for test in finder.find(module, module.__name__):
+            if test.examples:  # skip empty doctests and cuda
+                yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)
+
+
+def _device_agnostic_dispatch(device: str, dispatch_table: dict[str, Callable], *args, **kwargs):
+    if device not in dispatch_table:
+        if not callable(dispatch_table["default"]):
+            return dispatch_table["default"]
+
+        return dispatch_table["default"](*args, **kwargs)
+
+    fn = dispatch_table[device]
+
+    # Some device agnostic functions return values or None; we return them directly.
+    if not callable(fn):
+        return fn
+
+    return fn(*args, **kwargs)
+
+
+if is_torch_available():
+    # Mappings from device names to callable functions to support device agnostic
+    # testing.
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "cpu": torch.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "cpu": lambda: 0,
+        "default": lambda: 1,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "cpu": 0,
+        "default": 0,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.memory_allocated,
+        "cpu": 0,
+        "default": 0,
+    }
+    BACKEND_SYNCHRONIZE = {
+        "cuda": torch.cuda.synchronize,
+        "cpu": None,
+        "default": None,
+    }
+    BACKEND_TORCH_ACCELERATOR_MODULE = {
+        "cuda": torch.cuda,
+        "cpu": None,
+        "default": None,
+    }
+else:
+    BACKEND_MANUAL_SEED = {"default": None}
+    BACKEND_EMPTY_CACHE = {"default": None}
+    BACKEND_DEVICE_COUNT = {"default": lambda: 0}
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {"default": None}
+    BACKEND_RESET_PEAK_MEMORY_STATS = {"default": None}
+    BACKEND_MAX_MEMORY_ALLOCATED = {"default": 0}
+    BACKEND_MEMORY_ALLOCATED = {"default": 0}
+    BACKEND_SYNCHRONIZE = {"default": None}
+    BACKEND_TORCH_ACCELERATOR_MODULE = {"default": None}
+
+
+if is_torch_hpu_available():
+    BACKEND_MANUAL_SEED["hpu"] = torch.hpu.manual_seed
+    BACKEND_DEVICE_COUNT["hpu"] = torch.hpu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["hpu"] = torch.hpu
+
+if is_torch_mlu_available():
+    BACKEND_EMPTY_CACHE["mlu"] = torch.mlu.empty_cache
+    BACKEND_MANUAL_SEED["mlu"] = torch.mlu.manual_seed
+    BACKEND_DEVICE_COUNT["mlu"] = torch.mlu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["mlu"] = torch.mlu
+
+if is_torch_npu_available():
+    BACKEND_EMPTY_CACHE["npu"] = torch.npu.empty_cache
+    BACKEND_MANUAL_SEED["npu"] = torch.npu.manual_seed
+    BACKEND_DEVICE_COUNT["npu"] = torch.npu.device_count
+    BACKEND_TORCH_ACCELERATOR_MODULE["npu"] = torch.npu
+
+if is_torch_xpu_available():
+    BACKEND_EMPTY_CACHE["xpu"] = torch.xpu.empty_cache
+    BACKEND_MANUAL_SEED["xpu"] = torch.xpu.manual_seed
+    BACKEND_DEVICE_COUNT["xpu"] = torch.xpu.device_count
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.reset_peak_memory_stats
+    BACKEND_RESET_PEAK_MEMORY_STATS["xpu"] = torch.xpu.reset_peak_memory_stats
+    BACKEND_MAX_MEMORY_ALLOCATED["xpu"] = torch.xpu.max_memory_allocated
+    BACKEND_MEMORY_ALLOCATED["xpu"] = torch.xpu.memory_allocated
+    BACKEND_SYNCHRONIZE["xpu"] = torch.xpu.synchronize
+    BACKEND_TORCH_ACCELERATOR_MODULE["xpu"] = torch.xpu
+
+
+if is_torch_xla_available():
+    BACKEND_EMPTY_CACHE["xla"] = torch.cuda.empty_cache
+    BACKEND_MANUAL_SEED["xla"] = torch.cuda.manual_seed
+    BACKEND_DEVICE_COUNT["xla"] = torch.cuda.device_count
+
+
+def backend_manual_seed(device: str, seed: int):
+    return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed)
+
+
+def backend_empty_cache(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE)
+
+
+def backend_device_count(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
+def backend_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MEMORY_ALLOCATED)
+
+
+def backend_synchronize(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_SYNCHRONIZE)
+
+
+def backend_torch_accelerator_module(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_TORCH_ACCELERATOR_MODULE)
+
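+# A minimal usage sketch of the device-agnostic helpers above, using the module-level `torch_device`:
+#
+#     backend_manual_seed(torch_device, 0)
+#     backend_empty_cache(torch_device)
+#     n_devices = backend_device_count(torch_device)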
+
+if is_torch_available():
+    # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries
+    # into device to function mappings.
+    if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ:
+        device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"]
+        if not Path(device_spec_path).is_file():
+            raise ValueError(
+                f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}"
+            )
+
+        # Try to strip extension for later import – also verifies we are importing a
+        # python file.
+        device_spec_dir, _ = os.path.split(os.path.realpath(device_spec_path))
+        sys.path.append(device_spec_dir)
+        try:
+            import_name = device_spec_path[: device_spec_path.index(".py")]
+        except ValueError as e:
+            raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e
+
+        device_spec_module = importlib.import_module(import_name)
+
+        # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early.
+        try:
+            device_name = device_spec_module.DEVICE_NAME
+        except AttributeError as e:
+            raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e
+
+        if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name:
+            msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n"
+            msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name."
+            raise ValueError(msg)
+
+        torch_device = device_name
+
+        def update_mapping_from_spec(device_fn_dict: dict[str, Callable], attribute_name: str):
+            try:
+                # Try to import the function directly
+                spec_fn = getattr(device_spec_module, attribute_name)
+                device_fn_dict[torch_device] = spec_fn
+            except AttributeError as e:
+                # If the function doesn't exist, and there is no default, throw an error
+                if "default" not in device_fn_dict:
+                    raise AttributeError(
+                        f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found."
+                    ) from e
+
+        # Add one entry here for each `BACKEND_*` dictionary.
+        update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN")
+        update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
+        update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
+
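+# An illustrative device spec sketch (a hypothetical file pointed to by `TRANSFORMERS_TEST_DEVICE_SPEC`);
+# `DEVICE_NAME` is mandatory, while the `*_FN` entries override the corresponding `BACKEND_*` mappings above
+# (falling back to the defaults when omitted):
+#
+#     # my_device_spec.py
+#     import torch
+#
+#     DEVICE_NAME = "npu"
+#     MANUAL_SEED_FN = torch.npu.manual_seed
+#     EMPTY_CACHE_FN = torch.npu.empty_cache
+#     DEVICE_COUNT_FN = torch.npu.device_count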
+
+def compare_pipeline_output_to_hub_spec(output, hub_spec):
+    missing_keys = []
+    unexpected_keys = []
+    all_field_names = {field.name for field in fields(hub_spec)}
+    matching_keys = sorted([key for key in output if key in all_field_names])
+
+    # Fields with a MISSING default are required and must be in the output
+    for field in fields(hub_spec):
+        if field.default is MISSING and field.name not in output:
+            missing_keys.append(field.name)
+
+    # All output keys must match either a required or optional field in the Hub spec
+    for output_key in output:
+        if output_key not in all_field_names:
+            unexpected_keys.append(output_key)
+
+    if missing_keys or unexpected_keys:
+        error = ["Pipeline output does not match Hub spec!"]
+        if matching_keys:
+            error.append(f"Matching keys: {matching_keys}")
+        if missing_keys:
+            error.append(f"Missing required keys in pipeline output: {missing_keys}")
+        if unexpected_keys:
+            error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}")
+        raise KeyError("\n".join(error))
+
+
+@require_torch
+def cleanup(device: str, gc_collect=False):
+    if gc_collect:
+        gc.collect()
+    backend_empty_cache(device)
+    torch._dynamo.reset()
+
+
+# Type definition of key used in `Expectations` class.
+DeviceProperties = tuple[Optional[str], Optional[int], Optional[int]]
+# Helper type. Makes creating instances of `Expectations` smoother.
+PackedDeviceProperties = tuple[Optional[str], Union[None, int, tuple[int, int]]]
+
+
+@cache
+def get_device_properties() -> DeviceProperties:
+    """
+    Get environment device properties.
+    """
+    if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+        import torch
+
+        major, minor = torch.cuda.get_device_capability()
+        if IS_ROCM_SYSTEM:
+            return ("rocm", major, minor)
+        else:
+            return ("cuda", major, minor)
+    elif IS_XPU_SYSTEM:
+        import torch
+
+        # To get more info of the architecture meaning and bit allocation, refer to https://github.com/intel/llvm/blob/sycl/sycl/include/sycl/ext/oneapi/experimental/device_architecture.def
+        arch = torch.xpu.get_device_capability()["architecture"]
+        gen_mask = 0x000000FF00000000
+        gen = (arch & gen_mask) >> 32
+        return ("xpu", gen, None)
+    else:
+        return (torch_device, None, None)
+
+
+def unpack_device_properties(
+    properties: Optional[PackedDeviceProperties] = None,
+) -> DeviceProperties:
+    """
+    Unpack a `PackedDeviceProperties` tuple into a consistently formatted `DeviceProperties` tuple. If `properties`
+    is `None`, it is fetched from the environment.
+    """
+    if properties is None:
+        return get_device_properties()
+    device_type, major_minor = properties
+    if major_minor is None:
+        major, minor = None, None
+    elif isinstance(major_minor, int):
+        major, minor = major_minor, None
+    else:
+        major, minor = major_minor
+    return device_type, major, minor
+
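+# For example (illustrative): unpack_device_properties(("cuda", 8)) -> ("cuda", 8, None), and
+# unpack_device_properties(("rocm", (9, 4))) -> ("rocm", 9, 4).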
+
+class Expectations(UserDict[PackedDeviceProperties, Any]):
+    def get_expectation(self) -> Any:
+        """
+        Find best matching expectation based on environment device properties.
+        """
+        return self.find_expectation(get_device_properties())
+
+    def unpacked(self) -> list[tuple[DeviceProperties, Any]]:
+        return [(unpack_device_properties(k), v) for k, v in self.data.items()]
+
+    @staticmethod
+    def is_default(expectation_key: PackedDeviceProperties) -> bool:
+        """
+        This function returns True if the expectation_key is the Default expectation (None, None).
+        When an Expectation dict contains a Default value, it is generally because the test existed before Expectations.
+        When we modify a test to use Expectations for specific hardware, we don't want to affect the tests on other
+        hardware. Thus we set the previous value as the Default expectation with key (None, None) and add a value for
+        the specific hardware with key (hardware_type, (major, minor)).
+        """
+        return all(p is None for p in expectation_key)
+
+    @staticmethod
+    def score(properties: DeviceProperties, other: DeviceProperties) -> float:
+        """
+        Returns score indicating how similar two instances of the `Properties` tuple are.
+        Rules are as follows:
+            * Matching `type` adds one point, semi-matching `type` adds 0.1 point (e.g. cuda and rocm).
+            * If types match, matching `major` adds another point, and then matching `minor` adds another.
+            * The Default expectation (None, None) is worth 0.5 point, which is better than semi-matching. More on this
+            in the `is_default` function.
+        """
+        device_type, major, minor = properties
+        other_device_type, other_major, other_minor = other
+
+        score = 0
+        # Matching device type, maybe major and minor
+        if device_type is not None and device_type == other_device_type:
+            score += 1
+            if major is not None and major == other_major:
+                score += 1
+                if minor is not None and minor == other_minor:
+                    score += 1
+        # Semi-matching device type, which carries less importance than the default expectation
+        elif device_type in ["cuda", "rocm"] and other_device_type in ["cuda", "rocm"]:
+            score = 0.1
+
+        # Default expectation
+        if Expectations.is_default(other):
+            score = 0.5
+
+        return score
+
+    def find_expectation(self, properties: DeviceProperties = (None, None, None)) -> Any:
+        """
+        Find best matching expectation based on provided device properties. We score each expectation, and to
+        distinguish between expectations with the same score, we use the major and minor version numbers, prioritizing
+        most recent versions.
+        """
+        (result_key, result) = max(
+            self.unpacked(),
+            key=lambda x: (
+                Expectations.score(properties, x[0]),  # x[0] is a device properties tuple (device_type, major, minor)
+                x[0][1] if x[0][1] is not None else -1,  # This key is the major version, -1 if major is None
+                x[0][2] if x[0][2] is not None else -1,  # This key is the minor version, -1 if minor is None
+            ),
+        )
+
+        if Expectations.score(properties, result_key) == 0:
+            raise ValueError(f"No matching expectation found for {properties}")
+
+        return result
+
+    def __repr__(self):
+        return f"{self.data}"
+
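+# A minimal usage sketch for `Expectations` (the expectation values are hypothetical):
+#
+#     expectations = Expectations(
+#         {
+#             (None, None): 1.0,      # default, used when nothing more specific matches
+#             ("cuda", 8): 2.0,       # any CUDA device with compute capability major == 8
+#             ("rocm", (9, 4)): 3.0,  # ROCm device with major == 9 and minor == 4
+#         }
+#     )
+#     expected = expectations.get_expectation()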
+
+def patch_torch_compile_force_graph():
+    """
+    Patch `torch.compile` to always use `fullgraph=True`.
+
+    This is useful when some `torch.compile` tests are running with `fullgraph=False` and we want to be able to run
+    them with `fullgraph=True` in some occasion (without introducing new tests) to make sure there is no graph break.
+
+    After PR #40137, `CompileConfig.fullgraph` is `False` by default, so this patch is necessary.
+    """
+
+    force_fullgraph = os.environ.get("TORCH_COMPILE_FORCE_FULLGRAPH", "")
+    force_fullgraph = force_fullgraph.lower() in ("yes", "true", "on", "t", "y", "1")
+
+    if force_fullgraph:
+        import torch
+
+        orig_method = torch.compile
+
+        def patched(*args, **kwargs):
+            # In `torch.compile`, all arguments except `model` are keyword-only arguments.
+            kwargs["fullgraph"] = True
+            return orig_method(*args, **kwargs)
+
+        torch.compile = patched
+
+
+def torchrun(script: str, nproc_per_node: int, is_torchrun: bool = True, env: Optional[dict] = None):
+    """Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
+    with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
+        tmp.write(script)
+        tmp.flush()
+        tmp.seek(0)
+        if is_torchrun:
+            cmd = (
+                f"torchrun --nproc_per_node {nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
+            ).split()
+        else:
+            cmd = ["python3", tmp.name]
+
+        # Note that the subprocess will be waited for here, and raise an error if not successful
+        try:
+            _ = subprocess.run(cmd, capture_output=True, env=env, text=True, check=True)
+        except subprocess.CalledProcessError as e:
+            raise Exception(f"The following error was captured: {e.stderr}")
diff --git a/phivenv/Lib/site-packages/transformers/tf_utils.py b/phivenv/Lib/site-packages/transformers/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d07f8d7edab7401bd4582b57a095db6552475a
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/tf_utils.py
@@ -0,0 +1,294 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import numpy as np
+import tensorflow as tf
+
+from .feature_extraction_utils import BatchFeature
+from .tokenization_utils_base import BatchEncoding
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> list[int]:
+    """
+    Deal with dynamic shape in tensorflow cleanly.
+
+    Args:
+        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
+
+    Returns:
+        `list[int]`: The shape of the tensor as a list.
+    """
+    if isinstance(tensor, np.ndarray):
+        return list(tensor.shape)
+
+    dynamic = tf.shape(tensor)
+
+    if tensor.shape == tf.TensorShape(None):
+        return dynamic
+
+    static = tensor.shape.as_list()
+
+    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+
+
+def stable_softmax(logits: tf.Tensor, axis: Optional[int] = None, name: Optional[str] = None) -> tf.Tensor:
+    """
+    Stable wrapper that returns the same output as `tf.nn.softmax`, but that works reliably with XLA on CPU. It is
+    meant as a workaround for the [following issue](https://github.com/tensorflow/tensorflow/issues/55682), and will be
+    removed after it gets fixed. The arguments and outputs are the same as `tf.nn.softmax`, and it relies on the fact that
+    `softmax(x) = softmax(x + c)` (see https://ogunlao.github.io/2020/04/26/you_dont_really_know_softmax.html).
+
+    Args:
+        logits (`tf.Tensor`):
+            Must be one of the following types: half, float32, float64.
+        axis (`int`, *optional*):
+            The dimension softmax would be performed on. The default is -1 which indicates the last dimension.
+        name (`str`, *optional*):
+            A name for the operation.
+
+    Returns:
+        `tf.Tensor`:
+            A Tensor. Has the same type and shape as logits.
+    """
+    # TODO: When the issue linked above gets sorted, add a check on TF version here and use the original function if
+    # it has the fix. After we drop the support for unfixed versions, remove this function.
+    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
+
+
+def functional_layernorm(inputs, weight, bias, epsilon=1e-5, axis=-1):
+    # This is a very simplified functional layernorm, designed to duplicate
+    # the functionality of PyTorch nn.functional.layer_norm when this is needed to port
+    # models in Transformers.
+
+    if weight.shape.rank != 1 or bias.shape.rank != 1 or not isinstance(axis, int):
+        raise NotImplementedError("Only 1D weight and bias tensors are supported for now, with only a single axis.")
+
+    # Get mean and variance on the axis to be normalized
+    mean, variance = tf.nn.moments(inputs, axes=[axis], keepdims=True)
+
+    if axis != -1:
+        # Reshape scale and weight to have the same rank as inputs, but with 1 dimensions
+        # on every dimension except axis
+        shape = [1] * inputs.shape.rank
+        shape[axis] = shape_list(inputs)[axis]
+        weight = tf.reshape(weight, shape)
+        bias = tf.reshape(bias, shape)
+
+    # Compute layer normalization using the batch_normalization
+    # function.
+    outputs = tf.nn.batch_normalization(
+        inputs,
+        mean,
+        variance,
+        offset=bias,
+        scale=weight,
+        variance_epsilon=epsilon,
+    )
+    return outputs
+
+
+def scaled_dot_product_attention(
+    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale: Optional[float] = None
+):
+    """TF equivalent for torch's nn.functional.scaled_dot_product_attention"""
+    if dropout_p != 0.0:
+        raise ValueError(
+            "Dropout is not supported in this implementation - file an issue "
+            "with Transformers and ping @Rocketknight1 if you need it for a port!"
+        )
+    if is_causal and attn_mask is not None:
+        raise ValueError("You cannot specify an attn_mask and is_causal at the same time!")
+    if is_causal:
+        attn_mask = tf.ones((tf.shape(query)[-2], tf.shape(key)[-2]), dtype=tf.int32)
+        attn_mask = tf.experimental.numpy.tril(attn_mask, k=0)
+    if attn_mask is not None and (attn_mask.dtype.is_integer or attn_mask.dtype.is_bool):
+        # Convert boolean mask to a negative logit bias
+        attn_mask = tf.where(attn_mask > 0, tf.cast(0.0, query.dtype), tf.cast(-1000.0, query.dtype))
+    logits = tf.einsum("...qd, ...kd -> ...qk", query, key)
+    if scale is None:
+        scale = tf.cast(tf.shape(key)[-1], logits.dtype) ** -0.5
+    logits *= scale  # scale by 1/sqrt(key_dim)
+    if attn_mask is not None:
+        logits += attn_mask
+    probs = tf.nn.softmax(logits)
+    return probs @ value
+
+
+def flatten(input, start_dim=0, end_dim=-1):
+    # Replicates the behavior of torch.flatten in TF
+
+    # If end_dim or start_dim is negative, count them from the end
+    if end_dim < 0:
+        end_dim += input.shape.rank
+    if start_dim < 0:
+        start_dim += input.shape.rank
+
+    if start_dim == end_dim:
+        return input
+
+    in_shape = tf.shape(input)
+    flattened_dim = tf.math.reduce_prod(in_shape[start_dim : end_dim + 1])
+    out_shape = tf.concat([in_shape[:start_dim], [flattened_dim], in_shape[end_dim + 1 :]], axis=0)
+    return tf.reshape(input, out_shape)
+
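+# For example (illustrative): flatten(tf.zeros((2, 3, 4)), start_dim=1) has shape (2, 12), matching
+# torch.flatten on a tensor of the same shape.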
+
+def invert_attention_mask(encoder_attention_mask: tf.Tensor) -> tf.Tensor:
+    """
+    Invert an attention mask (e.g., switches 0. and 1.).
+
+    Args:
+        encoder_attention_mask (`tf.Tensor`): An attention mask.
+
+    Returns:
+        `tf.Tensor`: The inverted attention mask.
+    """
+    if not isinstance(encoder_attention_mask, tf.Tensor):
+        encoder_attention_mask = tf.convert_to_tensor(encoder_attention_mask)  # Catches stray NumPy inputs
+    if encoder_attention_mask.shape.rank == 3:
+        encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+    if encoder_attention_mask.shape.rank == 2:
+        encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+    # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+    # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+    # /transformer/transformer_layers.py#L270
+    # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+    # encoder_extended_attention_mask.transpose(-1, -2))
+    encoder_extended_attention_mask = (
+        tf.cast(1, encoder_attention_mask.dtype) - encoder_extended_attention_mask
+    ) * encoder_extended_attention_mask.dtype.min
+
+    return encoder_extended_attention_mask
+
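For orientation, a minimal sketch (assuming TensorFlow) of what the inversion produces for a 2-D padding mask:

```python
# Minimal sketch: 1 = attend, 0 = masked. The result is a broadcastable
# additive bias with roughly dtype.min at masked positions and 0 elsewhere.
import tensorflow as tf

mask = tf.constant([[1.0, 1.0, 0.0]])
bias = invert_attention_mask(mask)
print(bias.shape)  # (1, 1, 1, 3)
```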
+
+def check_embeddings_within_bounds(tensor: tf.Tensor, embed_dim: int, tensor_name: str = "input_ids") -> None:
+    """
+    `tf.gather`, on which TF embedding layers are based, won't check positive out-of-bound indices on GPU, returning
+    zeros instead. This function adds a check against that dangerous silent behavior.
+
+    Args:
+        tensor (`tf.Tensor`): The tensor of indices to check.
+        embed_dim (`int`): The embedding dimension.
+        tensor_name (`str`, *optional*): The name of the tensor to use in the error message.
+    """
+    tf.debugging.assert_less(
+        tensor,
+        tf.cast(embed_dim, dtype=tensor.dtype),
+        message=(
+            f"The maximum value of {tensor_name} ({tf.math.reduce_max(tensor)}) must be smaller than the embedding "
+            f"layer's input dimension ({embed_dim}). The likely cause is some problem at tokenization time."
+        ),
+    )
+
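A tiny sketch of how the check behaves (assuming TensorFlow in eager mode):

```python
# Minimal sketch: the assertion passes when all ids are < embed_dim and
# raises tf.errors.InvalidArgumentError otherwise (instead of tf.gather
# silently returning zeros on GPU).
import tensorflow as tf

ids = tf.constant([[1, 5, 7]])
check_embeddings_within_bounds(ids, embed_dim=8)    # passes, max id is 7
# check_embeddings_within_bounds(ids, embed_dim=6)  # would raise
```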
+
+def save_attributes_to_hdf5_group(group, name, data):
+    """Saves attributes (data) of the specified name into the HDF5 group.
+
+    This method deals with an inherent limitation of HDF5 files, which cannot store data larger than
+    `HDF5_OBJECT_HEADER_LIMIT` bytes.
+
+    Args:
+        group: A pointer to a HDF5 group.
+        name: A name of the attributes to save.
+        data: Attributes data to store.
+
+    Raises:
+      RuntimeError: If any single attribute is too large to be saved.
+
+    Copied from Keras to Transformers to avoid versioning issues.
+    """
+    HDF5_OBJECT_HEADER_LIMIT = 64512
+    # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
+    # because in that case even chunking the array would not make the saving
+    # possible.
+    bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
+
+    # Expecting this to never be true.
+    if bad_attributes:
+        raise RuntimeError(
+            "The following attributes cannot be saved to HDF5 file because "
+            f"they are larger than {HDF5_OBJECT_HEADER_LIMIT} "
+            f"bytes: {bad_attributes}"
+        )
+
+    data_npy = np.asarray(data)
+
+    num_chunks = 1
+    chunked_data = np.array_split(data_npy, num_chunks)
+
+    # This will never loop forever thanks to the test above.
+    while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
+        num_chunks += 1
+        chunked_data = np.array_split(data_npy, num_chunks)
+
+    if num_chunks > 1:
+        for chunk_id, chunk_data in enumerate(chunked_data):
+            group.attrs["%s%d" % (name, chunk_id)] = chunk_data
+    else:
+        group.attrs[name] = data
+
+
+def load_attributes_from_hdf5_group(group, name):
+    """Loads attributes of the specified name from the HDF5 group.
+
+    This method deals with an inherent limitation of HDF5 files, which cannot store data larger than
+    `HDF5_OBJECT_HEADER_LIMIT` bytes.
+
+    Args:
+        group: A pointer to a HDF5 group.
+        name: A name of the attributes to load.
+
+    Returns:
+        data: Attributes data.
+
+    Copied from Keras to Transformers to avoid versioning issues.
+    """
+    if name in group.attrs:
+        data = [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs[name]]
+    else:
+        data = []
+        chunk_id = 0
+        while "%s%d" % (name, chunk_id) in group.attrs:
+            data.extend(
+                [n.decode("utf8") if hasattr(n, "decode") else n for n in group.attrs["%s%d" % (name, chunk_id)]]
+            )
+            chunk_id += 1
+    return data
+
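To illustrate the pair of helpers above, a hedged round-trip sketch (assumes `h5py` is installed; the file and group names are placeholders):

```python
# Minimal sketch: attribute lists whose serialized size exceeds
# HDF5_OBJECT_HEADER_LIMIT are chunked into "name0", "name1", ... on save
# and transparently reassembled on load.
import h5py

with h5py.File("demo_weights.h5", "w") as f:
    layer = f.create_group("layer_0")
    save_attributes_to_hdf5_group(layer, "weight_names", ["kernel:0", "bias:0"])

with h5py.File("demo_weights.h5", "r") as f:
    print(load_attributes_from_hdf5_group(f["layer_0"], "weight_names"))
    # ['kernel:0', 'bias:0']
```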
+
+def expand_1d(data):
+    """Expands 1-dimensional `Tensor`s into 2-dimensional `Tensor`s.
+    Copied from Keras to here to avoid versioning issues."""
+
+    def _expand_single_1d_tensor(t):
+        if isinstance(t, tf.Tensor) and t.shape.rank == 1:
+            return tf.expand_dims(t, axis=-1)
+        return t
+
+    return tf.nest.map_structure(_expand_single_1d_tensor, data)
+
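And a short sketch for `expand_1d` (assuming TensorFlow):

```python
# Minimal sketch: rank-1 tensors gain a trailing axis; everything else in
# the nested structure is returned unchanged.
import tensorflow as tf

out = expand_1d({"labels": tf.constant([1.0, 2.0]), "ids": tf.constant([[1, 2]])})
print(out["labels"].shape, out["ids"].shape)  # (2, 1) (1, 2)
```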
+
+def convert_batch_encoding(*args, **kwargs):
+    # Convert HF BatchEncoding/BatchFeature objects in the inputs to dicts that Keras understands
+    if args and isinstance(args[0], (BatchEncoding, BatchFeature)):
+        args = list(args)
+        args[0] = dict(args[0])
+    elif "x" in kwargs and isinstance(kwargs["x"], (BatchEncoding, BatchFeature)):
+        kwargs["x"] = dict(kwargs["x"])
+    return args, kwargs
diff --git a/phivenv/Lib/site-packages/transformers/time_series_utils.py b/phivenv/Lib/site-packages/transformers/time_series_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5cf4f2f4d8f636e623c07ae3400ca5e17b5891
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/time_series_utils.py
@@ -0,0 +1,225 @@
+# Copyright 2023 The HuggingFace Inc. team.
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Time series distributional output classes and utilities.
+"""
+
+from typing import Callable, Optional
+
+import torch
+from torch import nn
+from torch.distributions import (
+    AffineTransform,
+    Distribution,
+    Independent,
+    NegativeBinomial,
+    Normal,
+    StudentT,
+    TransformedDistribution,
+)
+
+
+class AffineTransformed(TransformedDistribution):
+    def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
+        self.scale = 1.0 if scale is None else scale
+        self.loc = 0.0 if loc is None else loc
+
+        super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)])
+
+    @property
+    def mean(self):
+        """
+        Returns the mean of the distribution.
+        """
+        return self.base_dist.mean * self.scale + self.loc
+
+    @property
+    def variance(self):
+        """
+        Returns the variance of the distribution.
+        """
+        return self.base_dist.variance * self.scale**2
+
+    @property
+    def stddev(self):
+        """
+        Returns the standard deviation of the distribution.
+        """
+        return self.variance.sqrt()
+
+
+class ParameterProjection(nn.Module):
+    def __init__(
+        self, in_features: int, args_dim: dict[str, int], domain_map: Callable[..., tuple[torch.Tensor]], **kwargs
+    ) -> None:
+        super().__init__(**kwargs)
+        self.args_dim = args_dim
+        self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
+        self.domain_map = domain_map
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]:
+        params_unbounded = [proj(x) for proj in self.proj]
+
+        return self.domain_map(*params_unbounded)
+
+
+class LambdaLayer(nn.Module):
+    def __init__(self, function):
+        super().__init__()
+        self.function = function
+
+    def forward(self, x, *args):
+        return self.function(x, *args)
+
+
+class DistributionOutput:
+    distribution_class: type
+    in_features: int
+    args_dim: dict[str, int]
+
+    def __init__(self, dim: int = 1) -> None:
+        self.dim = dim
+        self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim}
+
+    def _base_distribution(self, distr_args):
+        if self.dim == 1:
+            return self.distribution_class(*distr_args)
+        else:
+            return Independent(self.distribution_class(*distr_args), 1)
+
+    def distribution(
+        self,
+        distr_args,
+        loc: Optional[torch.Tensor] = None,
+        scale: Optional[torch.Tensor] = None,
+    ) -> Distribution:
+        distr = self._base_distribution(distr_args)
+        if loc is None and scale is None:
+            return distr
+        else:
+            return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim)
+
+    @property
+    def event_shape(self) -> tuple:
+        r"""
+        Shape of each individual event contemplated by the distributions that this object constructs.
+        """
+        return () if self.dim == 1 else (self.dim,)
+
+    @property
+    def event_dim(self) -> int:
+        r"""
+        Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
+        constructs.
+        """
+        return len(self.event_shape)
+
+    @property
+    def value_in_support(self) -> float:
+        r"""
+        A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
+        default 0.0. This value will be used when padding data series.
+        """
+        return 0.0
+
+    def get_parameter_projection(self, in_features: int) -> nn.Module:
+        r"""
+        Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
+        """
+        return ParameterProjection(
+            in_features=in_features,
+            args_dim=self.args_dim,
+            domain_map=LambdaLayer(self.domain_map),
+        )
+
+    def domain_map(self, *args: torch.Tensor):
+        r"""
+        Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
+        correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
+        distribution of the right event_shape.
+        """
+        raise NotImplementedError()
+
+    @staticmethod
+    def squareplus(x: torch.Tensor) -> torch.Tensor:
+        r"""
+        Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
+        https://twitter.com/jon_barron/status/1387167648669048833
+        """
+        return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0
+
+
+class StudentTOutput(DistributionOutput):
+    """
+    Student-T distribution output class.
+    """
+
+    args_dim: dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
+    distribution_class: type = StudentT
+
+    @classmethod
+    def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
+        scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
+        df = 2.0 + cls.squareplus(df)
+        return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)
+
+
+class NormalOutput(DistributionOutput):
+    """
+    Normal distribution output class.
+    """
+
+    args_dim: dict[str, int] = {"loc": 1, "scale": 1}
+    distribution_class: type = Normal
+
+    @classmethod
+    def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
+        scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
+        return loc.squeeze(-1), scale.squeeze(-1)
+
+
+class NegativeBinomialOutput(DistributionOutput):
+    """
+    Negative Binomial distribution output class.
+    """
+
+    args_dim: dict[str, int] = {"total_count": 1, "logits": 1}
+    distribution_class: type = NegativeBinomial
+
+    @classmethod
+    def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
+        total_count = cls.squareplus(total_count)
+        return total_count.squeeze(-1), logits.squeeze(-1)
+
+    def _base_distribution(self, distr_args) -> Distribution:
+        total_count, logits = distr_args
+        if self.dim == 1:
+            return self.distribution_class(total_count=total_count, logits=logits)
+        else:
+            return Independent(self.distribution_class(total_count=total_count, logits=logits), 1)
+
+    # Overwrites the parent class method. We cannot scale using the affine
+    # transformation since negative binomial should return integers. Instead
+    # we scale the parameters.
+    def distribution(
+        self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None
+    ) -> Distribution:
+        total_count, logits = distr_args
+
+        if scale is not None:
+            # See scaling property of Gamma.
+            logits += scale.log()
+
+        return self._base_distribution((total_count, logits))
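As a hedged end-to-end sketch of how these pieces fit together (PyTorch; the feature size and batch/time shapes are arbitrary), one can project hidden states to Student-T parameters and build an affine-transformed distribution:

```python
# Minimal sketch, not part of the module: hidden states -> (df, loc, scale)
# via ParameterProjection, then an optionally loc/scale-shifted StudentT.
import torch

output = StudentTOutput(dim=1)
proj = output.get_parameter_projection(in_features=32)

hidden = torch.randn(4, 10, 32)                      # (batch, time, features)
df, loc, scale = proj(hidden)                        # each (4, 10) after domain_map
distr = output.distribution(
    (df, loc, scale), loc=torch.zeros(4, 10), scale=torch.ones(4, 10)
)
sample = distr.sample()
print(sample.shape, distr.log_prob(sample).shape)    # torch.Size([4, 10]) twice
```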
diff --git a/phivenv/Lib/site-packages/transformers/tokenization_mistral_common.py b/phivenv/Lib/site-packages/transformers/tokenization_mistral_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..a362a7c8b066178bdf59f7aaf8395616b0fb0322
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/tokenization_mistral_common.py
@@ -0,0 +1,1883 @@
+# Copyright 2025 Mistral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import warnings
+from collections.abc import Mapping, Sized
+from enum import Enum
+from pathlib import Path
+from typing import Any, Callable, Optional, Union, overload
+
+import numpy as np
+
+from transformers.audio_utils import load_audio_as
+from transformers.tokenization_utils_base import (
+    LARGE_INTEGER,
+    VERY_LARGE_INTEGER,
+    BatchEncoding,
+    EncodedInput,
+    PreTokenizedInput,
+    PreTrainedTokenizerBase,
+    TextInput,
+    TruncationStrategy,
+)
+from transformers.utils import PaddingStrategy, TensorType, add_end_docstrings, logging, to_py_obj
+from transformers.utils.generic import is_torch_tensor
+from transformers.utils.hub import PushToHubMixin
+from transformers.utils.import_utils import is_mistral_common_available, is_torch_available, requires
+
+
+if is_mistral_common_available():
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.protocol.instruct.validator import ValidationMode
+    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, TokenizerVersion
+    from mistral_common.tokens.tokenizers.image import MultiModalVersion
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+    from mistral_common.tokens.tokenizers.utils import download_tokenizer_from_hf_hub
+
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+
+ENCODE_KWARGS_DOCSTRING = r"""
+            add_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to add special tokens when encoding the sequences. This will use the underlying
+                `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
+                automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
+                automatically.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+                Activates and controls truncation. Accepts the following values:
+
+                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+                  greater than the model maximum admissible input size).
+            max_length (`int`, *optional*):
+                Controls the maximum length to use by one of the truncation/padding parameters.
+
+                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+                is required by one of the truncation/padding parameters. If the model has no specific maximum input
+                length (like XLNet) truncation/padding to a maximum length will be deactivated.
+            stride (`int`, *optional*, defaults to 0):
+                If set to a number along with `max_length`, the overflowing tokens returned when
+                `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+                returned to provide some overlap between truncated and overflowing sequences. The value of this
+                argument defines the number of overlapping tokens.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+"""
+
+ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are attention masks?](../glossary#attention-mask)
+            return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+                of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+                of returning overflowing tokens.
+            return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+                Whether or not to return special tokens mask information.
+            return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+                Whether or not to return `(char_start, char_end)` for each token.
+
+                This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using
+                Python's tokenizer, this method will raise `NotImplementedError`.
+            return_length  (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the lengths of the encoded inputs.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
+            **kwargs: passed to the `self.tokenize()` method
+
+        Return:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model.
+
+              [What are input IDs?](../glossary#input-ids)
+
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+              [What are attention masks?](../glossary#attention-mask)
+
+            - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+              regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+            - **length** -- The length of the inputs (when `return_length=True`)
+"""
+
+
+class MistralTokenizerType(str, Enum):
+    """Enum for the different type of tokenizer."""
+
+    spm = "spm"
+    tekken = "tekken"
+
+
+@requires(backends=("mistral-common",))
+class MistralCommonTokenizer(PushToHubMixin):
+    """
+    Class to wrap `mistral-common` tokenizers.
+
+    `mistral-common` is the official tokenizer library for Mistral AI models. To use it, you need to install it with:
+
+    ```bash
+    pip install transformers[mistral-common]
+    ```
+
+    Otherwise the tokenizer falls back to the Transformers implementation of the tokenizer.
+
+    For more info on `mistral-common`, see [mistral-common](https://github.com/mistralai/mistral-common).
+
+    This class is a wrapper around a `mistral_common.tokens.tokenizers.mistral.MistralTokenizer`.
+    It provides a Hugging Face compatible interface to tokenize using the official mistral-common tokenizer.
+
+    Supports the following methods from the `PreTrainedTokenizerBase` class:
+
+    - [`~MistralCommonTokenizer.get_vocab`]: Returns the vocabulary as a dictionary of token to index.
+    - [`~MistralCommonTokenizer.encode`]: Encode a string to a list of integers.
+    - [`~MistralCommonTokenizer.decode`]: Decode a list of integers to a string.
+    - [`~MistralCommonTokenizer.batch_decode`]: Decode a batch of list of integers to a list of strings.
+    - [`~MistralCommonTokenizer.convert_tokens_to_ids`]: Convert a list of tokens to a list of integers.
+    - [`~MistralCommonTokenizer.convert_ids_to_tokens`]: Convert a list of integers to a list of tokens.
+    - [`~MistralCommonTokenizer.tokenize`]: Tokenize a string.
+    - [`~MistralCommonTokenizer.get_special_tokens_mask`]: Get the special tokens mask for a list of tokens.
+    - [`~MistralCommonTokenizer.prepare_for_model`]: Prepare a list of inputs for the model.
+    - [`~MistralCommonTokenizer.pad`]: Pad a list of inputs to the same length.
+    - [`~MistralCommonTokenizer.truncate_sequences`]: Truncate a list of sequences to the same length.
+    - [`~MistralCommonTokenizer.apply_chat_template`]: Apply a chat template to a list of messages.
+    - [`~MistralCommonTokenizer.__call__`]: Tokenize a string or a list of strings.
+    - [`~MistralCommonTokenizer.from_pretrained`]: Download and cache a pretrained tokenizer from the Hugging Face model hub or local directory.
+    - [`~MistralCommonTokenizer.save_pretrained`]: Save a tokenizer to a directory, so it can be reloaded using the `from_pretrained` class method.
+    - [`~MistralCommonTokenizer.push_to_hub`]: Upload tokenizer to the Hugging Face model hub.
+
+    Here are the key differences with the `PreTrainedTokenizerBase` class:
+
+    - Pairs of sequences are not supported. The signatures have been kept for compatibility, but all arguments related to sequence pairs are ignored. The return values for pairs are returned as `None`.
+    - The `is_split_into_words` argument is not supported.
+    - The `return_token_type_ids` argument is not supported.
+    - It is not possible to add new tokens to the tokenizer. Also, the special tokens are handled differently from Transformers. In `mistral-common`, special tokens are never encoded directly. This means that `tokenizer.encode("<s>")` will not return the ID of the `<s>` token. Instead, it will return a list of IDs corresponding to the tokenization of the string `"<s>"`. For more information, see the [mistral-common documentation](https://mistralai.github.io/mistral-common/usage/tokenizers/#special-tokens).
+
+    If you have suggestions to improve this class, please open an issue on the [mistral-common GitHub repository](https://github.com/mistralai/mistral-common/issues) if it is related to the tokenizer or on the [Transformers GitHub repository](https://github.com/huggingface/transformers/issues) if it is related to the Hugging Face interface.
+    """
+
+    model_input_names: list[str] = ["input_ids", "attention_mask"]
+    padding_side: str = "left"
+    truncation_side: str = "right"
+
+    def __init__(
+        self,
+        tokenizer_path: Union[str, os.PathLike, Path],
+        mode: ValidationMode = ValidationMode.test,
+        model_max_length: int = VERY_LARGE_INTEGER,
+        padding_side: str = "left",
+        truncation_side: str = "right",
+        model_input_names: Optional[list[str]] = None,
+        clean_up_tokenization_spaces: bool = False,
+        **kwargs,
+    ):
+        """
+        Constructs a `MistralCommonTokenizer`.
+
+        - **model_input_names** (`List[str]`) -- A list of inputs expected in the forward pass of the model.
+        - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
+            Should be `'right'` or `'left'`.
+        - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation
+            applied. Should be `'right'` or `'left'`.
+
+        Args:
+            tokenizer_path (`str` or `os.PathLike` or `Path`):
+                Path to the tokenizer file to load the `MistralTokenizer`.
+            mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
+                The mode to use for the tokenizer. This will be passed to the `MistralTokenizer` constructor.
+            model_max_length (`int`, *optional*):
+                The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is
+                loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the
+                value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will
+                default to VERY_LARGE_INTEGER (`int(1e30)`).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            truncation_side (`str`, *optional*):
+                The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            model_input_names (`List[string]`, *optional*):
+                The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
+                `"attention_mask"`). Default value is picked from the class attribute of the same name.
+            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+                Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+                tokenization process.
+        """
+        if kwargs:
+            raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported to init `MistralCommonTokenizer`.")
+
+        self._tokenizer_path = Path(tokenizer_path)
+        self.tokenizer: MistralTokenizer = MistralTokenizer.from_file(str(self._tokenizer_path), mode=mode)
+        self._tokenizer_type = (
+            MistralTokenizerType.tekken
+            if isinstance(self.tokenizer.instruct_tokenizer.tokenizer, Tekkenizer)
+            else MistralTokenizerType.spm
+        )
+        self.truncation_side = truncation_side
+        self.padding_side = padding_side
+        self.model_max_length = model_max_length
+        self.cleanup_tokenization_spaces = clean_up_tokenization_spaces
+        self.deprecation_warnings = {}  # Used to store whether a given deprecation warning has already been emitted (avoids over-logging).
+
+        if model_input_names is not None:
+            if (
+                not isinstance(model_input_names, (list, tuple))
+                or len(model_input_names) == 0
+                or not all(isinstance(i, str) for i in model_input_names)
+            ):
+                raise ValueError(
+                    f"`model_input_names` should be a non-empty list or tuple of str, but got {model_input_names}."
+                )
+            self.model_input_names = model_input_names
+
+        self._cache_get_vocab: Optional[dict[str, int]] = None
+
+    @property
+    def bos_token_id(self) -> int:
+        """
+        Id of the beginning of sentence token in the vocabulary.
+        """
+        return self.tokenizer.instruct_tokenizer.tokenizer.bos_id
+
+    @property
+    def eos_token_id(self) -> int:
+        """
+        Id of the end of sentence token in the vocabulary.
+        """
+        return self.tokenizer.instruct_tokenizer.tokenizer.eos_id
+
+    @property
+    def unk_token_id(self) -> int:
+        """
+        Id of the unknown token in the vocabulary.
+        """
+        return self.tokenizer.instruct_tokenizer.tokenizer.unk_id
+
+    @property
+    def pad_token_id(self) -> int:
+        """
+        Id of the padding token in the vocabulary.
+        """
+        return self.tokenizer.instruct_tokenizer.tokenizer.pad_id
+
+    @property
+    def bos_token(self) -> str:
+        """
+        String associated to the beginning of sentence token in the vocabulary.
+        """
+        return self.convert_ids_to_tokens(self.bos_token_id)
+
+    @property
+    def eos_token(self) -> str:
+        """
+        String associated to the end of sentence token in the vocabulary.
+        """
+        return self.convert_ids_to_tokens(self.eos_token_id)
+
+    @property
+    def unk_token(self) -> str:
+        """
+        String associated to the unknown token in the vocabulary.
+        """
+        return self.convert_ids_to_tokens(self.unk_token_id)
+
+    @property
+    def pad_token(self) -> str:
+        """
+        String associated to the padding token in the vocabulary.
+        """
+        return self.convert_ids_to_tokens(self.pad_token_id)
+
+    @property
+    def vocab_size(self) -> int:
+        """
+        Returns the size of the vocabulary.
+
+        `int`: Size of the vocabulary.
+        """
+        return self.tokenizer.instruct_tokenizer.tokenizer.n_words
+
+    def get_vocab(self) -> dict[str, int]:
+        """
+        Returns the vocabulary as a dictionary of token to index.
+
+        This is a lossy conversion. There may be multiple token ids that decode to the same
+        string due to partial UTF-8 byte sequences being converted to �.
+
+        Returns:
+            `Dict[str, int]`: The vocabulary.
+        """
+        if self._cache_get_vocab is None:
+            self._cache_get_vocab = {
+                token: idx for idx, token in enumerate(self.tokenizer.instruct_tokenizer.tokenizer.vocab())
+            }
+        return self._cache_get_vocab
+
+    def __len__(self):
+        """
+        Size of the full vocabulary with the added tokens.
+        """
+        return self.vocab_size
+
+    @add_end_docstrings(
+        ENCODE_KWARGS_DOCSTRING,
+        """
+            **kwargs: Not supported by `MistralCommonTokenizer.encode`.
+                Will raise an error if used.
+        """,
+        """
+        Returns:
+            `List[int]`, `torch.Tensor`: The tokenized ids of the text.
+        """,
+    )
+    def encode(
+        self,
+        text: Union[TextInput, EncodedInput],
+        text_pair: None = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        verbose: bool = True,
+        **kwargs,
+    ) -> list[int]:
+        """
+        Converts a string to a sequence of ids (integers), using the tokenizer and vocabulary.
+
+        Args:
+            text (`str` or `List[int]`):
+                The first sequence to be encoded. This can be a string or a list of integers (tokenized string ids).
+            text_pair (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer.encode`. Kept to match `PreTrainedTokenizerBase.encode` signature.
+        """
+        if kwargs:
+            raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.encode`.")
+        if text_pair:
+            raise ValueError("`MistralCommonTokenizer.encode` does not support `text_pair`.")
+
+        padding_strategy, truncation_strategy, max_length, _ = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+        )
+
+        encoded_inputs = self._encode_plus(
+            text,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_attention_mask=False,
+            return_overflowing_tokens=False,
+            return_special_tokens_mask=False,
+            return_length=False,
+            verbose=verbose,
+        )
+
+        return encoded_inputs["input_ids"]
+
+    def decode(
+        self,
+        token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
+        tokens and clean up tokenization spaces.
+
+        Args:
+            token_ids (`Union[int, List[int], np.ndarray, torch.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
+            kwargs (additional keyword arguments, *optional*):
+                Not supported by `MistralCommonTokenizer.decode`.
+                Will raise an error if used.
+
+        Returns:
+            `str`: The decoded sentence.
+        """
+        if kwargs:
+            raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.decode`.")
+
+        if clean_up_tokenization_spaces is None:
+            clean_up_tokenization_spaces = self.cleanup_tokenization_spaces
+
+        # Convert inputs to python lists
+        token_ids = to_py_obj(token_ids)
+
+        special_token_policy = SpecialTokenPolicy.IGNORE if skip_special_tokens else SpecialTokenPolicy.KEEP
+
+        decoded_string = self.tokenizer.decode(token_ids, special_token_policy=special_token_policy)
+        if clean_up_tokenization_spaces:
+            decoded_string = PreTrainedTokenizerBase.clean_up_tokenization(decoded_string)
+
+        return decoded_string
+
+    def batch_decode(
+        self,
+        sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> list[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
+            kwargs (additional keyword arguments, *optional*):
+                Not supported by `MistralCommonTokenizer.batch_decode`.
+                Will raise an error if used.
+
+        Returns:
+            `List[str]`: The list of decoded sentences.
+        """
+        return [
+            self.decode(
+                seq,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                **kwargs,
+            )
+            for seq in sequences
+        ]
+
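A standalone usage sketch of the encode/decode surface defined so far; the tokenizer file path is a placeholder and the example assumes `mistral-common` is installed:

```python
# Minimal sketch: construct the wrapper from a local mistral-common
# tokenizer file (placeholder path), then round-trip a string.
tok = MistralCommonTokenizer("path/to/tekken.json")

ids = tok.encode("Hello world", add_special_tokens=True)
print(tok.decode(ids))                               # may include BOS/EOS
print(tok.decode(ids, skip_special_tokens=True))     # "Hello world"
print(tok.batch_decode([ids, ids], skip_special_tokens=True))
```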
+    def _is_control_token(self, token_id: int) -> bool:
+        if self._tokenizer_type == MistralTokenizerType.spm:
+            return token_id in self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
+        elif self._tokenizer_type == MistralTokenizerType.tekken:
+            return token_id < self.tokenizer.instruct_tokenizer.tokenizer.num_special_tokens
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
+    @overload
+    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+    @overload
+    def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ...
+    def convert_ids_to_tokens(
+        self, ids: Union[int, list[int]], skip_special_tokens: bool = False
+    ) -> Union[str, list[str]]:
+        """
+        Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
+        and added tokens.
+
+        Args:
+            ids (`int` or `List[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `List[str]`: The decoded token(s).
+        """
+
+        if isinstance(ids, int):
+            one_token = True
+            ids = [ids]
+        else:
+            one_token = False
+
+        tokens: list[str] = []
+        for token_id in ids:
+            if self._is_control_token(token_id) and skip_special_tokens:
+                continue
+            tokens.append(self.tokenizer.instruct_tokenizer.tokenizer.id_to_piece(token_id))
+
+        if one_token:
+            if tokens == []:
+                raise ValueError(f"Invalid token id {ids}.")
+
+            return tokens[0]
+        return tokens
+
+    def _piece_to_id(self, piece: str) -> int:
+        if self._tokenizer_type == MistralTokenizerType.spm:
+            return self.tokenizer.instruct_tokenizer.tokenizer._model.piece_to_id(piece)
+        elif self._tokenizer_type == MistralTokenizerType.tekken:
+            pieces = self.tokenizer.instruct_tokenizer.tokenizer._model.encode(
+                piece, allowed_special="all", disallowed_special=set()
+            )
+            assert len(pieces) == 1, f"Expected to decode 1 token, got {len(pieces)}"
+            return pieces[0]
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
+    def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
+        """
+        Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
+        vocabulary.
+
+        Args:
+            tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).
+
+        Returns:
+            `int` or `List[int]`: The token id or list of token ids.
+        """
+
+        if isinstance(tokens, str):
+            one_token = True
+            tokens = [tokens]
+        else:
+            one_token = False
+
+        ids: list[int] = []
+        for token in tokens:
+            ids.append(self._piece_to_id(token))
+
+        if one_token:
+            return ids[0]
+        return ids
+
+    def _text_to_ids(self, text: TextInput, add_special_tokens: bool) -> list[int]:
+        """
+        Converts a string into a sequence of tokens ids, using the tokenizer.
+        """
+        tokens_ids = self.tokenizer.instruct_tokenizer.tokenizer.encode(
+            text, bos=add_special_tokens, eos=add_special_tokens
+        )
+        return tokens_ids
+
+    def tokenize(self, text: TextInput, **kwargs) -> list[str]:
+        """
+        Converts a string into a sequence of tokens, using the tokenizer.
+
+        Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            **kwargs (additional keyword arguments):
+                Not supported by `MistralCommonTokenizer.tokenize`.
+                Will raise an error if used.
+
+        Returns:
+            `List[str]`: The list of tokens.
+        """
+        if kwargs:
+            raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.tokenize`.")
+
+        return self.convert_ids_to_tokens(self._text_to_ids(text, add_special_tokens=False), skip_special_tokens=False)
+
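A short follow-up sketch relating `tokenize`, `convert_tokens_to_ids`, and `encode` (same placeholder tokenizer path as above; exact round-tripping is typical but not guaranteed for partial UTF-8 byte pieces):

```python
# Minimal sketch: tokenize() returns the string pieces of the ids produced
# without BOS/EOS; convert_tokens_to_ids() usually maps them back.
tok = MistralCommonTokenizer("path/to/tekken.json")  # placeholder path

pieces = tok.tokenize("Hello world")
ids = tok.convert_tokens_to_ids(pieces)
print(pieces)
print(ids == tok.encode("Hello world", add_special_tokens=False))  # usually True
```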
+    def _encode_plus(
+        self,
+        text: Union[TextInput, EncodedInput],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        if kwargs:
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer._encode_plus`."
+            )
+
+        def get_input_ids(text):
+            if isinstance(text, str):
+                return self._text_to_ids(text, add_special_tokens)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+                return text
+            else:
+                raise ValueError(f"Input {text} is not valid. Should be a string, or a list/tuple of integers.")
+
+        ids = get_input_ids(text)
+
+        return self.prepare_for_model(
+            ids,
+            add_special_tokens=add_special_tokens,
+            padding=padding_strategy.value,
+            truncation=truncation_strategy.value,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            prepend_batch_axis=True,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_length=return_length,
+            verbose=verbose,
+        )
+
+    def _batch_encode_plus(
+        self,
+        batch_text: Union[
+            list[TextInput],
+            list[EncodedInput],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        def get_input_ids(text):
+            if isinstance(text, str):
+                return self._text_to_ids(text, add_special_tokens)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+                return text
+            else:
+                raise ValueError("Input is not valid. Should be a string or a list/tuple of integers.")
+
+        if return_offsets_mapping:
+            raise NotImplementedError(
+                "return_offset_mapping is not available when using Python tokenizers. "
+                "To use this feature, change your tokenizer to one deriving from "
+                "transformers.PreTrainedTokenizerFast."
+            )
+
+        input_ids = []
+        for ids in batch_text:
+            input_ids.append(get_input_ids(ids))
+
+        batch_outputs = self._batch_prepare_for_model(
+            input_ids,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_length=return_length,
+            return_tensors=return_tensors,
+            verbose=verbose,
+        )
+
+        return BatchEncoding(batch_outputs)
+
+    def _all_special_ids(self) -> set[int]:
+        if self._tokenizer_type == MistralTokenizerType.tekken:
+            return {t["rank"] for t in self.tokenizer.instruct_tokenizer.tokenizer._all_special_tokens}
+        elif self._tokenizer_type == MistralTokenizerType.spm:
+            return self.tokenizer.instruct_tokenizer.tokenizer._control_tokens()
+        else:
+            raise ValueError(f"Unknown tokenizer type: {self._tokenizer_type}")
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list, token_ids_1: None = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of ids of the sequence.
+            token_ids_1 (`List[int]`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the interface of `PreTrainedTokenizerBase`.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if token_ids_1 is not None:
+            raise ValueError(
+                "`token_ids_1` is not supported by `MistralCommonTokenizer` and should be `None`, kept for compatibility."
+            )
+        if already_has_special_tokens:
+            raise ValueError(
+                "`already_has_special_tokens` is not supported by `MistralCommonTokenizer` and should be `False`."
+            )
+
+        all_special_ids = self._all_special_ids()  # cache the ids
+
+        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
+        return special_tokens_mask
+
+    def _batch_prepare_for_model(
+        self,
+        batch_ids: list[Union[PreTokenizedInput, list[int]]],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids so that it can be used by the model. It
+        adds special tokens, truncates sequences that overflow while taking the special tokens into account, and
+        manages a moving window (with a user-defined stride) for overflowing tokens.
+
+        Args:
+            batch_ids: list of tokenized input ids
+        """
+
+        batch_outputs = {}
+        for ids in batch_ids:
+            outputs = self.prepare_for_model(
+                ids,
+                add_special_tokens=add_special_tokens,
+                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
+                truncation=truncation_strategy.value,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=None,  # we pad in batch afterward
+                padding_side=None,  # we pad in batch afterward
+                return_attention_mask=False,  # we pad in batch afterward
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length,
+                return_tensors=None,  # We convert the whole batch to tensors at the end
+                prepend_batch_axis=False,
+                verbose=verbose,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                batch_outputs[key].append(value)
+
+        batch_outputs = self.pad(
+            batch_outputs,
+            padding=padding_strategy.value,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_attention_mask=return_attention_mask,
+        )
+
+        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+        return batch_outputs
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def prepare_for_model(
+        self,
+        ids: list[int],
+        pair_ids: None = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids so that it can be used by the model. It
+        adds special tokens, truncates sequences that overflow while taking the special tokens into account, and
+        manages a moving window (with a user-defined stride) for overflowing tokens.
+
+        Args:
+            ids (`List[int]`):
+                Tokenized input ids of the first sequence.
+            pair_ids (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the interface of `PreTrainedTokenizerBase`.
+        """
+        if pair_ids is not None:
+            raise ValueError(
+                "`pair_ids` is not supported by `MistralCommonTokenizer` and should be `None`, kept for compatibility."
+            )
+        if kwargs:
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.prepare_for_model`."
+            )
+
+        padding_strategy, truncation_strategy, max_length, _ = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+        )
+
+        len_ids = len(ids)
+
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Truncation: Handle max sequence length
+        overflowing_tokens = []
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and len_ids > max_length:
+            ids, _, overflowing_tokens = self.truncate_sequences(
+                ids,
+                num_tokens_to_remove=len_ids - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+
+        if return_overflowing_tokens:
+            encoded_inputs["overflowing_tokens"] = overflowing_tokens
+            encoded_inputs["num_truncated_tokens"] = len_ids - max_length
+
+        # Build output dictionary
+        encoded_inputs[self.model_input_names[0]] = ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, None)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(ids)
+
+        # Padding
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+            encoded_inputs = self.pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding=padding_strategy.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def _get_padding_truncation_strategies(
+        self,
+        padding: Union[str, PaddingStrategy, bool] = False,
+        truncation: Optional[Union[str, TruncationStrategy, bool]] = None,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        verbose: bool = True,
+        **kwargs,
+    ):
+        """
+        Find the correct padding/truncation strategy.
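+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer`):
+
+        ```python
+        >>> strategies = tokenizer._get_padding_truncation_strategies(padding=True, truncation=True, max_length=8)
+        >>> # Here this resolves to (PaddingStrategy.LONGEST, TruncationStrategy.LONGEST_FIRST, 8, {}).
+        ```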
+        """
+
+        # Backward compatibility for previous behavior, maybe we should deprecate it:
+        # If you only set max_length, it activates truncation for max_length
+        if max_length is not None and padding is False and truncation is None:
+            if verbose:
+                if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
+                    logger.warning(
+                        "Truncation was not explicitly activated but `max_length` is provided with a specific value, please"
+                        " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
+                        " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
+                        " tokenizer you can select this strategy more precisely by providing a specific strategy to"
+                        " `truncation`."
+                    )
+                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
+            truncation = "longest_first"
+
+        # Get padding strategy
+        if padding is not False:
+            if padding is True:
+                if verbose:
+                    if max_length is not None and (
+                        truncation is None or truncation is False or truncation == "do_not_truncate"
+                    ):
+                        warnings.warn(
+                            "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
+                            "To pad to max length, use `padding='max_length'`."
+                        )
+                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
+            elif not isinstance(padding, PaddingStrategy):
+                padding_strategy = PaddingStrategy(padding)
+            elif isinstance(padding, PaddingStrategy):
+                padding_strategy = padding
+        else:
+            padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+        # Get truncation strategy
+        if truncation is not False and truncation is not None:
+            if truncation is True:
+                truncation_strategy = (
+                    TruncationStrategy.LONGEST_FIRST
+                )  # Default to truncate the longest sequences in pairs of inputs
+            elif not isinstance(truncation, TruncationStrategy):
+                truncation_strategy = TruncationStrategy(truncation)
+            elif isinstance(truncation, TruncationStrategy):
+                truncation_strategy = truncation
+            if truncation in [TruncationStrategy.ONLY_FIRST, TruncationStrategy.ONLY_SECOND]:
+                raise ValueError(
+                    "Truncation strategy `only_first` and `only_second` are not supported by `MistralCommonTokenizer`."
+                )
+        else:
+            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+
+        # Set max length if needed
+        if max_length is None:
+            if padding_strategy == PaddingStrategy.MAX_LENGTH:
+                if self.model_max_length > LARGE_INTEGER:
+                    if verbose:
+                        if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
+                            logger.warning(
+                                "Asking to pad to max_length but no maximum length is provided and the model has no"
+                                " predefined maximum length. Default to no padding."
+                            )
+                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
+                    padding_strategy = PaddingStrategy.DO_NOT_PAD
+                else:
+                    max_length = self.model_max_length
+
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
+                if self.model_max_length > LARGE_INTEGER:
+                    if verbose:
+                        if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
+                            logger.warning(
+                                "Asking to truncate to max_length but no maximum length is provided and the model has"
+                                " no predefined maximum length. Default to no truncation."
+                            )
+                        self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
+                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+                else:
+                    max_length = self.model_max_length
+
+        # Test if we have a padding token
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
+            raise ValueError(
+                "Asking to pad but the tokenizer does not have a padding token. "
+                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
+                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
+            )
+
+        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
+        if (
+            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
+            and padding_strategy != PaddingStrategy.DO_NOT_PAD
+            and pad_to_multiple_of is not None
+            and max_length is not None
+            and (max_length % pad_to_multiple_of != 0)
+        ):
+            raise ValueError(
+                "Truncation and padding are both activated but "
+                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
+            )
+
+        return padding_strategy, truncation_strategy, max_length, kwargs
+
+    def _pad(
+        self,
+        encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """
+        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+        Args:
+            encoded_inputs:
+                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            padding_strategy: PaddingStrategy to use for padding.
+
+                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
+                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
+                - PaddingStrategy.DO_NOT_PAD: Do not pad
+                The tokenizer padding sides are defined in `padding_side` argument:
+
+                    - 'left': pads on the left of the sequences
+                    - 'right': pads on the right of the sequences
+            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side:
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
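+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer` with
+        `padding_side="left"`; the token ids and the pad token id below are purely illustrative):
+
+        ```python
+        >>> encoded = {"input_ids": [1, 2, 3]}
+        >>> padded = tokenizer._pad(encoded, max_length=5, padding_strategy=PaddingStrategy.MAX_LENGTH)
+        >>> # Assuming a pad token id of 0, this gives input_ids == [0, 0, 1, 2, 3]
+        >>> # and attention_mask == [0, 0, 1, 1, 1].
+        ```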
+        """
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(required_input)
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
+
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+        # Initialize attention mask if not present.
+        if return_attention_mask and "attention_mask" not in encoded_inputs:
+            encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+        if needs_to_be_padded:
+            difference = max_length - len(required_input)
+            padding_side = padding_side if padding_side is not None else self.padding_side
+
+            if padding_side == "right":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+            elif padding_side == "left":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+            else:
+                raise ValueError(f"Invalid padding strategy:{padding_side}")
+
+        return encoded_inputs
+
+    def pad(
+        self,
+        encoded_inputs: Union[
+            BatchEncoding,
+            list[BatchEncoding],
+            dict[str, EncodedInput],
+            dict[str, list[EncodedInput]],
+            list[dict[str, EncodedInput]],
+        ],
+        padding: Union[bool, str, PaddingStrategy] = True,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        """
+        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
+        in the batch.
+
+        Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
+        `self.pad_token_id`).
+
+        If the `encoded_inputs` passed are a dictionary of numpy arrays or PyTorch tensors, the
+        result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
+        PyTorch tensors, you will lose the specific device of your tensors however.
+
+        Args:
+            encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `Dict[str, List[int]]`, `Dict[str, List[List[int]]` or `List[Dict[str, List[int]]]`):
+                Tokenized inputs. Can represent one input ([`BatchEncoding`] or `Dict[str, List[int]]`) or a batch of
+                tokenized inputs (list of [`BatchEncoding`], *Dict[str, List[List[int]]]* or *List[Dict[str,
+                List[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+                collate function.
+
+                Instead of `List[int]` you can have tensors (numpy arrays, PyTorch tensors), see
+                the note above for the return type.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                 index) among:
+
+                - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are attention masks?](../glossary#attention-mask)
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
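+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer` and the
+        token ids below are purely illustrative):
+
+        ```python
+        >>> batch = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+        >>> padded = tokenizer.pad(batch, padding="longest", return_tensors="np")
+        >>> # Both sequences now share the same length: the shorter one is padded on `tokenizer.padding_side`
+        >>> # with `tokenizer.pad_token_id`, and an attention mask marks the real tokens.
+        ```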
+        """
+        # If we have a list of dicts, let's convert it in a dict of lists
+        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+
+        # The model's main input name, usually `input_ids`, has been passed for padding
+        if self.model_input_names[0] not in encoded_inputs:
+            raise ValueError(
+                "You should supply an encoding or a list of encodings to this method "
+                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+            )
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
+            if return_attention_mask:
+                encoded_inputs["attention_mask"] = []
+            return encoded_inputs
+
+        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
+        # and rebuild them afterwards if no return_tensors is specified
+        # Note that we lose the specific device the tensor may be on for PyTorch
+
+        first_element = required_input[0]
+        if isinstance(first_element, (list, tuple)):
+            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+            for item in required_input:
+                if len(item) != 0:
+                    first_element = item[0]
+                    break
+        # At this state, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
+        if not isinstance(first_element, (int, list, tuple)):
+            if is_torch_tensor(first_element):
+                return_tensors = "pt" if return_tensors is None else return_tensors
+            elif isinstance(first_element, np.ndarray):
+                return_tensors = "np" if return_tensors is None else return_tensors
+            else:
+                raise ValueError(
+                    f"type of {first_element} unknown: {type(first_element)}. "
+                    "Should be a python, numpy or pytorch object."
+                )
+
+            for key, value in encoded_inputs.items():
+                encoded_inputs[key] = to_py_obj(value)
+
+        # Convert padding_strategy in PaddingStrategy
+        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
+            padding=padding, max_length=max_length, verbose=verbose
+        )
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+        if required_input and not isinstance(required_input[0], (list, tuple)):
+            encoded_inputs = self._pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
+
+        batch_size = len(required_input)
+        assert all(len(v) == batch_size for v in encoded_inputs.values()), (
+            "Some items in the output dictionary have a different batch size than others."
+        )
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = max(len(inputs) for inputs in required_input)
+            padding_strategy = PaddingStrategy.MAX_LENGTH
+
+        batch_outputs = {}
+        for i in range(batch_size):
+            inputs = {k: v[i] for k, v in encoded_inputs.items()}
+            outputs = self._pad(
+                inputs,
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                batch_outputs[key].append(value)
+
+        return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+    def truncate_sequences(
+        self,
+        ids: list[int],
+        pair_ids: None = None,
+        num_tokens_to_remove: int = 0,
+        truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+        stride: int = 0,
+        **kwargs,
+    ) -> tuple[list[int], None, list[int]]:
+        """
+        Truncates a sequence in-place following the strategy.
+
+        Args:
+            ids (`List[int]`):
+                Tokenized input ids. Can be obtained from a string by chaining the `tokenize` and
+                `convert_tokens_to_ids` methods.
+            pair_ids (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the signature of `PreTrainedTokenizerBase.truncate_sequences`.
+            num_tokens_to_remove (`int`, *optional*, defaults to 0):
+                Number of tokens to remove using the truncation strategy.
+            truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+                The strategy to follow for truncation. Can be:
+
+                - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided.
+                - `'do_not_truncate'`: No truncation (i.e., can output batch with sequence lengths greater
+                  than the model maximum admissible input size).
+            stride (`int`, *optional*, defaults to 0):
+                If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+                sequence returned. The value of this argument defines the number of additional tokens.
+
+        Returns:
+            `Tuple[List[int], None, List[int]]`: The truncated `ids`, `None` (kept to match the Transformers
+            signature) and the list of overflowing tokens.
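+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer` with
+        `truncation_side="right"` and the token ids below are purely illustrative):
+
+        ```python
+        >>> ids = [1, 2, 3, 4, 5]
+        >>> truncated_ids, _, overflowing = tokenizer.truncate_sequences(ids, num_tokens_to_remove=2)
+        >>> # With right-side truncation and no stride: truncated_ids == [1, 2, 3] and overflowing == [4, 5].
+        ```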
+        """
+        if kwargs:
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.truncate_sequences`."
+            )
+        if pair_ids:
+            raise ValueError("`pair_ids` is not supported by `MistralCommonTokenizer.truncate_sequences`.")
+
+        if num_tokens_to_remove <= 0:
+            return (ids, None, [])
+
+        if not isinstance(truncation_strategy, TruncationStrategy):
+            truncation_strategy = TruncationStrategy(truncation_strategy)
+
+        if truncation_strategy in [TruncationStrategy.ONLY_FIRST, TruncationStrategy.ONLY_SECOND]:
+            raise ValueError(
+                f"Only {TruncationStrategy.LONGEST_FIRST} and {TruncationStrategy.DO_NOT_TRUNCATE} are supported."
+            )
+
+        overflowing_tokens = []
+        if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            if len(ids) > num_tokens_to_remove:
+                window_len = min(len(ids), stride + num_tokens_to_remove)
+                if self.truncation_side == "left":
+                    overflowing_tokens = ids[:window_len]
+                    ids = ids[num_tokens_to_remove:]
+                elif self.truncation_side == "right":
+                    overflowing_tokens = ids[-window_len:]
+                    ids = ids[:-num_tokens_to_remove]
+                else:
+                    raise ValueError(f"invalid truncation side: {self.truncation_side}, use 'left' or 'right'.")
+
+            else:
+                error_msg = (
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
+                    f"but the sequence has a length of only {len(ids)}. "
+                )
+                logger.error(error_msg)
+
+        return (ids, None, overflowing_tokens)
+
+    def apply_chat_template(
+        self,
+        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+        tools: Optional[list[Union[dict, Callable]]] = None,
+        continue_final_message: bool = False,
+        tokenize: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_dict: bool = False,
+        **kwargs,
+    ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]:
+        """
+        Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
+        ids.
+
+        Args:
+            conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
+                with "role" and "content" keys, representing the chat history so far.
+            tools (`List[Union[Dict, Callable]]`, *optional*):
+                A list of tools (callable functions) that will be accessible to the model. If the template does not
+                support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+                giving the name, description and argument types for the tool. See our
+                [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+                for more information.
+            continue_final_message (`bool`, *optional*, defaults to `False`):
+                If this is set, the chat will be formatted so that the final
+                message in the chat is open-ended, without any EOS tokens. The model will continue this message
+                rather than starting a new one. This allows you to "prefill" part of
+                the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+            tokenize (`bool`, defaults to `True`):
+                Whether to tokenize the output. If `False`, the output will be a string.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                 index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, defaults to `False`):
+                Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
+            max_length (`int`, *optional*):
+                Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
+                not specified, the tokenizer's `max_length` attribute will be used as a default.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
+                values are:
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+            return_dict (`bool`, defaults to `False`):
+                Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+                If at least one conversation contains an image, its pixel values will be returned in the `pixel_values` key.
+            kwargs (additional keyword arguments, *optional*):
+                Not supported by `MistralCommonTokenizer.apply_chat_template`.
+                Will raise an error if used.
+
+        Returns:
+            `Union[str, List[int], List[str], List[List[int]], BatchEncoding]`: The tokenized chat so far, including
+            control tokens: a list of token ids (or a batch of them), a string (or a list of strings) if
+            `tokenize=False`, or a `BatchEncoding` if `return_dict=True`. The tokenized output is ready to pass to the
+            model, either directly or through methods like `generate()`.
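+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer`):
+
+        ```python
+        >>> conversation = [
+        ...     {"role": "user", "content": "Hello, how are you?"},
+        ... ]
+        >>> token_ids = tokenizer.apply_chat_template(conversation)
+        >>> # With `return_dict=True` and `return_tensors="pt"`, a `BatchEncoding` is returned instead,
+        >>> # including "pixel_values" when the conversation contains images.
+        ```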
+        """
+        if kwargs:
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.apply_chat_template`."
+            )
+        if not isinstance(truncation, bool):
+            raise ValueError("`truncation` must be a boolean for `apply_chat_template` method.")
+
+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
+        ):
+            conversations = conversation
+            is_batched = True
+        else:
+            conversations = [conversation]
+            is_batched = False
+
+        def _maybe_adapt_message(message: dict[str, Any]) -> None:
+            """Adapt message to `mistral-common` format and leave validation to `mistral-common`."""
+            if not isinstance(message, dict):
+                return
+            maybe_list_content: Optional[Union[str, list[dict[str, Union[str, dict[str, Any]]]]]] = message.get(
+                "content"
+            )
+            if not maybe_list_content or isinstance(maybe_list_content, str):
+                return
+
+            normalized_content: list[dict[str, Union[str, dict[str, Any]]]] = []
+            for content in maybe_list_content:
+                content_type = content.get("type", None)
+                if not content_type:
+                    continue
+                elif content_type == "image":
+                    maybe_url: Optional[str] = content.get("url")
+                    maybe_path: Optional[str] = content.get("path")
+                    maybe_base64: Optional[str] = content.get("base64")
+                    if maybe_url:
+                        image_content = maybe_url
+                    elif maybe_path:
+                        if not maybe_path.startswith("file://"):
+                            maybe_path = Path(maybe_path).resolve().as_uri()
+                        image_content = maybe_path
+                    elif maybe_base64:
+                        if not maybe_base64.startswith("data:image"):
+                            maybe_base64 = "data:image/unk;base64," + maybe_base64
+                        image_content = maybe_base64
+                    else:
+                        raise ValueError("Image content must be specified.")
+                    normalized_content.append({"type": "image_url", "image_url": {"url": image_content}})
+                elif content_type == "audio":
+                    maybe_url: Optional[str] = content.get("url")
+                    maybe_path: Optional[str] = content.get("path")
+                    maybe_base64: Optional[str] = content.get("base64")
+                    if maybe_url or maybe_path:
+                        audio_data = load_audio_as(maybe_url or maybe_path, return_format="dict", force_mono=True)
+                        normalized_content.append({"type": "input_audio", "input_audio": audio_data})
+                        continue
+                    if not maybe_base64:
+                        raise ValueError("Audio content must be specified.")
+                    normalized_content.append({"type": "audio_url", "audio_url": {"url": maybe_base64}})
+                else:
+                    normalized_content.append(content)
+            message["content"] = normalized_content
+
+        outputs = []
+        images: list[np.ndarray] = []
+        audios: list[np.ndarray] = []
+
+        for conversation in conversations:
+            messages: list[dict[str, Union[str, list[dict[str, Union[str, dict[str, Any]]]]]]] = []
+            for message in conversation:
+                _maybe_adapt_message(message)
+                messages.append(message)
+
+            chat_request = ChatCompletionRequest.from_openai(
+                messages=messages,
+                tools=tools,
+                continue_final_message=continue_final_message,
+            )
+
+            tokenized_request = self.tokenizer.encode_chat_completion(chat_request)
+            if tokenize:
+                outputs.append(tokenized_request.tokens)
+            else:
+                outputs.append(tokenized_request.text)
+            images.extend(tokenized_request.images)
+            audios.extend([el.audio_array for el in tokenized_request.audios])
+
+        if not is_batched:
+            outputs = outputs[0]
+
+        if tokenize:
+            out = self(
+                outputs,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                add_special_tokens=False,
+                return_tensors=return_tensors,
+            )
+            if return_dict:
+                if images:
+                    pixel_values: Union[list[np.ndarray], np.ndarray, torch.Tensor]
+                    if return_tensors == "pt":
+                        if not is_torch_available():
+                            raise ImportError(
+                                "Unable to convert output to PyTorch tensors format, PyTorch is not installed."
+                            )
+
+                        pixel_values = torch.tensor(images)
+                    elif return_tensors == "np":
+                        pixel_values = np.array(images)
+                    elif return_tensors is None:
+                        pixel_values = images
+                    else:
+                        raise ValueError(f"Unsupported return_tensors type: {return_tensors}")
+                    out.data["pixel_values"] = pixel_values
+                if audios:
+                    if return_tensors is not None:
+                        raise NotImplementedError(
+                            "When passing audio content in apply_chat_template, `return_tensors` must be None since we cannot batch the audio inputs. The returned audio will be a list of numpy arrays."
+                        )
+                    # Transformers convention is audio for plural audio (audio does not take a "s")
+                    out.data["audio"] = audios
+                return out
+            else:
+                return out["input_ids"]
+
+        else:
+            logger.warning(
+                "`MistralCommonTokenizer.apply_chat_template(..., tokenize=False)` is unsafe and may lead to unexpected behavior."
+                " Please consider using `tokenize=True` instead and don't encode the output manually."
+            )
+            return outputs
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def __call__(
+        self,
+        text: Union[TextInput, EncodedInput, list[TextInput], list[EncodedInput], None] = None,
+        text_pair: None = None,
+        text_target: None = None,
+        text_pair_target: None = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+        sequences.
+
+        Args:
+            text (`str`, `List[str]`, `List[List[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of int
+                (encoded strings).
+            text_pair (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the signature of `PreTrainedTokenizerBase.__call__`.
+            text_target (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the signature of `PreTrainedTokenizerBase.__call__`.
+            text_pair_target (`None`, *optional*):
+                Not supported by `MistralCommonTokenizer`. Kept to match the signature of `PreTrainedTokenizerBase.__call__`.
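+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer`):
+
+        ```python
+        >>> single = tokenizer("Hello world!")
+        >>> batch = tokenizer(["Hello world!", "How are you?"], padding=True, return_tensors="pt")
+        >>> # Both calls return a `BatchEncoding`; the second one is padded to the longest sequence
+        >>> # and converted to PyTorch tensors.
+        ```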
+        """
+        if kwargs:
+            raise ValueError(f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.__call__`.")
+
+        if text_pair or text_target or text_pair_target:
+            raise ValueError(
+                "`text_pair`, `text_target` and `text_pair_target` are not supported by `MistralCommonTokenizer`."
+            )
+
+        if return_tensors in ("tf", "jax"):
+            raise ValueError(
+                "`MistralCommonTokenizer` does not support `return_tensors='tf'` or `return_tensors='jax'`."
+            )
+
+        def _is_valid_text_input(t):
+            if isinstance(t, str):
+                # Strings are fine
+                return True
+            elif isinstance(t, (list, tuple)):
+                # List are fine as long as they are...
+                if len(t) == 0:
+                    # ... empty
+                    return True
+                elif isinstance(t[0], (str, int)):
+                    # ... list of strings or int
+                    return True
+                elif isinstance(t[0], (list, tuple)):
+                    # ... list with an empty list or with a list of strings or with a list of ints
+                    return len(t[0]) == 0 or isinstance(t[0][0], (str, int))
+                else:
+                    return False
+            else:
+                return False
+
+        if not _is_valid_text_input(text):
+            raise ValueError(
+                "text input must be of type `str` (single example), `List[str]` (batch or single encoded example) "
+                "or `List[List[int]]` (batch of encoded examples)."
+            )
+
+        is_batched = isinstance(text, (list, tuple)) and isinstance(text[0], (str, list, tuple))
+
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        if is_batched:
+            return self._batch_encode_plus(
+                batch_text=text,
+                add_special_tokens=add_special_tokens,
+                padding_strategy=padding_strategy,
+                truncation_strategy=truncation_strategy,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+        else:
+            return self._encode_plus(
+                text=text,
+                add_special_tokens=add_special_tokens,
+                padding_strategy=padding_strategy,
+                truncation_strategy=truncation_strategy,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs,
+            )
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        *init_inputs,
+        mode: ValidationMode = ValidationMode.test,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        model_max_length: int = VERY_LARGE_INTEGER,
+        padding_side: str = "left",
+        truncation_side: str = "right",
+        model_input_names: Optional[list[str]] = None,
+        clean_up_tokenization_spaces: bool = False,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a `MistralCommonTokenizer` from a predefined
+        tokenizer.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing the tokenizer config, for instance saved
+                  using the [`~MistralCommonTokenizer.save_pretrained`] method, e.g.,
+                  `./my_model_directory/`.
+            mode (`ValidationMode`, *optional*, defaults to `ValidationMode.test`):
+                Validation mode for the `MistralTokenizer` tokenizer.
+            cache_dir (`str` or `os.PathLike`, *optional*):
+                Path to a directory in which a downloaded predefined tokenizer vocabulary files should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the vocabulary files and override the cached versions if they
+                exist.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `hf auth login` (stored in `~/.huggingface`).
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only rely on local files and not to attempt to download any files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            model_max_length (`int`, *optional*, defaults to `VERY_LARGE_INTEGER`):
+                The maximum length (in number of tokens) accepted by the model. It is used as the default maximum
+                length by the truncation/padding parameters when no `max_length` is provided. If the model has no
+                specific maximum input length, truncation/padding to a maximum length will be deactivated.
+            padding_side (`str`, *optional*, defaults to `"left"`):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            truncation_side (`str`, *optional*, defaults to `"right"`):
+                The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
+            model_input_names (`List[string]`, *optional*):
+                The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
+                `"attention_mask"`). Default value is picked from the class attribute of the same name.
+            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+                Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+                tokenization process.
+            kwargs (additional keyword arguments, *optional*):
+                Not supported by `MistralCommonTokenizer.from_pretrained`.
+                Will raise an error if used.
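+
+        Example (a minimal sketch; the repository id below is illustrative and the import assumes
+        `MistralCommonTokenizer` is exposed at the top level of `transformers`):
+
+        ```python
+        >>> from transformers import MistralCommonTokenizer
+
+        >>> tokenizer = MistralCommonTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
+        >>> # A local directory containing a `tekken.json` or a sentencepiece `*.model` (possibly versioned) file also works.
+        ```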
+        """
+        if init_inputs:
+            raise ValueError("`init_inputs` are not supported by `MistralCommonTokenizer.from_pretrained`.")
+
+        # Handle kwargs and AutoTokenizer case
+        if kwargs and not set(kwargs.keys()).issubset({"_from_auto", "trust_remote_code"}):
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.from_pretrained`."
+            )
+
+        if not os.path.isdir(pretrained_model_name_or_path):
+            tokenizer_path = download_tokenizer_from_hf_hub(
+                repo_id=pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                token=token,
+                revision=revision,
+                force_download=force_download,
+                local_files_only=local_files_only,
+            )
+        else:
+            valid_tokenizer_files = []
+            tokenizer_file: str
+
+            instruct_versions = list(TokenizerVersion.__members__)
+            mm_versions = list(MultiModalVersion.__members__) + [""]  # allow no mm version
+            sentencepiece_suffixes = [f".model.{v}{m}" for v in instruct_versions for m in mm_versions] + [".model"]
+
+            for path in os.listdir(pretrained_model_name_or_path):
+                pathlib_repo_file = Path(path)
+                file_name = pathlib_repo_file.name
+                suffix = "".join(pathlib_repo_file.suffixes)
+                if file_name == "tekken.json" or suffix in sentencepiece_suffixes:
+                    valid_tokenizer_files.append(file_name)
+
+            if len(valid_tokenizer_files) == 0:
+                raise ValueError(f"No tokenizer file found in directory: {pretrained_model_name_or_path}")
+            # If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
+            if len(valid_tokenizer_files) > 1:
+                if "tekken.json" in valid_tokenizer_files:
+                    tokenizer_file = "tekken.json"
+                else:
+                    tokenizer_file = sorted(valid_tokenizer_files)[-1]
+                logger.warning(
+                    f"Multiple tokenizer files found in directory: {pretrained_model_name_or_path}. Using {tokenizer_file}."
+                )
+            else:
+                tokenizer_file = valid_tokenizer_files[0]
+
+            tokenizer_path = os.path.join(pretrained_model_name_or_path, tokenizer_file)
+
+        return cls(
+            tokenizer_path=tokenizer_path,
+            mode=mode,
+            model_max_length=model_max_length,
+            padding_side=padding_side,
+            truncation_side=truncation_side,
+            model_input_names=model_input_names,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+        )
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike, Path],
+        push_to_hub: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        commit_message: Optional[str] = None,
+        repo_id: Optional[str] = None,
+        private: Optional[bool] = None,
+        repo_url: Optional[str] = None,
+        organization: Optional[str] = None,
+        **kwargs,
+    ) -> tuple[str]:
+        """
+        Save the full tokenizer state.
+
+
+        This method makes sure the full tokenizer can then be re-loaded using the
+        [`~MistralCommonTokenizer.from_pretrained`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            token (`str` or *bool*, *optional*, defaults to `None`):
+                The token to use to push to the model hub. If `True`, will use the token in the `HF_TOKEN` environment
+                variable.
+            commit_message (`str`, *optional*): The commit message to use when pushing to the hub.
+            repo_id (`str`, *optional*): The name of the repository to push to on the Hub.
+            private (`bool`, *optional*): Whether the model repository is private or not.
+            repo_url (`str`, *optional*): The URL of the Git repository to push to on the Hub.
+            organization (`str`, *optional*): The name of the organization to which you would like to push your model.
+            kwargs (`Dict[str, Any]`, *optional*):
+                Not supported by `MistralCommonTokenizer.save_pretrained`.
+                Will raise an error if used.
+
+        Returns:
+            A tuple of `str`: The files saved.
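+
+        Example (a minimal sketch; `tokenizer` is assumed to be an already loaded `MistralCommonTokenizer` and the
+        path below is illustrative):
+
+        ```python
+        >>> saved_files = tokenizer.save_pretrained("./my_tokenizer")
+        >>> # `saved_files` is a one-element tuple containing the path of the copied tokenizer file,
+        >>> # which `MistralCommonTokenizer.from_pretrained("./my_tokenizer")` can load again.
+        ```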
+        """
+        # `save_jinja_files` must be skipped to be able to save from a processor
+        kwargs.pop("save_jinja_files", None)
+        if kwargs:
+            raise ValueError(
+                f"Kwargs {list(kwargs.keys())} are not supported by `MistralCommonTokenizer.save_pretrained`."
+            )
+
+        save_directory = Path(save_directory)
+        save_directory.mkdir(parents=True, exist_ok=True)
+
+        shutil.copy(self._tokenizer_path, save_directory)
+
+        if push_to_hub:
+            repo_id = repo_id or str(save_directory).split(os.path.sep)[-1]
+            repo_id = self._create_repo(
+                repo_id, token=token, private=private, repo_url=repo_url, organization=organization
+            )
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=token,
+            )
+
+        return (str(save_directory / self._tokenizer_path.name),)
diff --git a/phivenv/Lib/site-packages/transformers/tokenization_utils.py b/phivenv/Lib/site-packages/transformers/tokenization_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..08627d62c123859aab53accfbb3070377ef0fd12
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/tokenization_utils.py
@@ -0,0 +1,1135 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tokenization classes for python tokenizers. For fast tokenizers (provided by HuggingFace's tokenizers library) see
+tokenization_utils_fast.py
+"""
+
+import bisect
+import itertools
+import re
+import unicodedata
+from collections import OrderedDict
+from typing import Any, Optional, Union, overload
+
+from .tokenization_utils_base import (
+    ENCODE_KWARGS_DOCSTRING,
+    ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING,
+    INIT_TOKENIZER_DOCSTRING,
+    AddedToken,
+    BatchEncoding,
+    EncodedInput,
+    EncodedInputPair,
+    PreTokenizedInput,
+    PreTokenizedInputPair,
+    PreTrainedTokenizerBase,
+    TextInput,
+    TextInputPair,
+    TruncationStrategy,
+)
+from .utils import PaddingStrategy, TensorType, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+# Slow tokenizers are saved in a vocabulary plus three separated files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+
+class Trie:
+    """
+    Trie in Python. Creates a Trie out of a list of words. The trie is used to split on `added_tokens` in one pass.
+    Loose reference: https://en.wikipedia.org/wiki/Trie
+    """
+
+    def __init__(self, *args):
+        self.data = {}
+        self._tokens = set()
+        self._termination_char = ""
+        self.update(*args)
+
+    def update(self, *args):
+        """
+        Updates the Trie with new tokens provided as arguments.
+
+        Args:
+            *args: Variable number of words to be added to the Trie.
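+
+        Example (the token strings below are illustrative):
+
+        ```python
+        >>> trie = Trie()
+        >>> trie.update(["[CLS]", "[SEP]"])
+        >>> trie.split("[CLS] Hello [SEP]")
+        ["[CLS]", " Hello ", "[SEP]"]
+        ```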
+        """
+        for token in tuple(*args):
+            self.add(token)
+
+    def add(self, word: str):
+        """
+        Passes over every char (utf-8 char) on word and recursively adds it to the internal `data` trie representation.
+        The special key `""` in `self._termination_char` is used to represent termination.
+
+        This function is idempotent: adding the same word twice will leave the trie unchanged.
+
+        Example:
+
+        ```python
+        >>> trie = Trie()
+        >>> trie.add("Hello 友達")
+        >>> trie.data
+        {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}}
+
+        >>> trie.add("Hello")
+        >>> trie.data
+        {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}}
+        ```
+        """
+        if not word:
+            # Prevent empty string
+            return
+
+        self._tokens.add(word)
+        ref = self.data
+        for char in word:
+            ref[char] = ref.setdefault(char, {})
+            ref = ref[char]
+        ref[self._termination_char] = 1
+
+    def split(self, text: str) -> list[str]:
+        """
+        Will look for the words added to the trie within `text`. Output is the original string split along the
+        boundaries of the words found.
+
+        This trie will match the longest possible word first!
+
+        Example:
+
+        ```python
+        >>> trie = Trie()
+        >>> trie.split("[CLS] This is a extra_id_100")
+        ["[CLS] This is a extra_id_100"]
+
+        >>> trie.add("[CLS]")
+        >>> trie.add("extra_id_1")
+        >>> trie.add("extra_id_100")
+        >>> trie.split("[CLS] This is a extra_id_100")
+        ["[CLS]", " This is a ", "extra_id_100"]
+        ```
+        """
+        # Indexes are counted left of the character at that index:
+        # in "hello", index 0 is left of "h", index 1 is between "h" and "e",
+        # and index 5 is right of the "o".
+
+        # States are going to capture every possible start (indexes as above)
+        # as keys, and have as values, a pointer to the position in the trie
+        # where we're at. This is a partial match for now.
+        # This enables to keep track of multiple matches while we're iterating
+        # the string
+        # If the trie contains, "blowing", and "lower" and we encounter the
+        # string "blower", we need to split into ["b", "lower"].
+        # This is where we need to keep track of multiple possible starts.
+        states = OrderedDict()
+
+        # This will contain every index where we need
+        # to cut.
+        # We force to cut at offset 0 and len(text) (added later)
+        offsets = [0]
+
+        # This is used by the lookahead which needs to skip over
+        # some text where the full match exceeded the place in the initial
+        # for loop
+        skip = 0
+        # Main loop, Giving this algorithm O(n) complexity
+        for current, current_char in enumerate(text):
+            if skip and current < skip:
+                # Prevents the lookahead for matching twice
+                # like extra_id_100 and id_100
+                continue
+
+            # This will track every state
+            # that stops matching; we need to stop tracking them.
+            # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
+            # fail on "b", we need to remove 0 from the valid states.
+            to_remove = set()
+            # Whenever we find a match, we need to drop everything;
+            # this is a greedy algorithm, it will match on the first token found.
+            reset = False
+
+            # In this case, we already have partial matches (But unfinished)
+            for start, trie_pointer in states.items():
+                if "" in trie_pointer:
+                    # This is a final match, we need to reset and
+                    # store the results in `offsets`.
+
+                    # Lookahead to match longest first
+                    # Important in case of extra_id_1 vs extra_id_100
+                    # Here we are also actively looking for other earlier partial
+                    # matches
+                    # "[CLS]", "L", we need to match CLS even if L is special
+                    for lookstart, looktrie_pointer in states.items():
+                        if lookstart > start:
+                            # This partial match is later, we can stop looking
+                            break
+                        elif lookstart < start:
+                            # This partial match is earlier, the trie pointer
+                            # was already updated, so index is + 1
+                            lookahead_index = current + 1
+                            end = current + 1
+                        else:
+                            # Here lookstart == start and
+                            #      looktrie_pointer == trie_pointer
+                            # It wasn't updated yet so indices are current ones
+                            lookahead_index = current
+                            end = current
+                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
+                        if "" in looktrie_pointer:
+                            start = lookstart
+                            end = lookahead_index
+                            skip = lookahead_index
+
+                        while next_char in looktrie_pointer:
+                            looktrie_pointer = looktrie_pointer[next_char]
+                            lookahead_index += 1
+                            if "" in looktrie_pointer:
+                                start = lookstart
+                                end = lookahead_index
+                                skip = lookahead_index
+
+                            if lookahead_index == len(text):
+                                # End of string
+                                break
+                            next_char = text[lookahead_index]
+                        # End lookahead
+
+                    # Storing and resetting
+                    offsets.append(start)
+                    offsets.append(end)
+                    reset = True
+                    break
+                elif current_char in trie_pointer:
+                    # The current character being looked at has a match within the trie
+                    # update the pointer (it will be stored back into states later).
+                    trie_pointer = trie_pointer[current_char]
+
+                    # Storing back the new pointer into the states.
+                    # Partial matches got longer by one.
+                    states[start] = trie_pointer
+                else:
+                    # The new character has no match in the trie, so we need
+                    # to stop keeping track of this partial match.
+                    # We can't do it directly within the loop because of how
+                    # Python iteration works.
+                    to_remove.add(start)
+
+            # Either clearing the full start (we found a real match)
+            # Or clearing only the partial matches that didn't work.
+            if reset:
+                states = {}
+            else:
+                for start in to_remove:
+                    del states[start]
+
+            # If this character is a starting character within the trie
+            # start keeping track of this partial match.
+            if current >= skip and current_char in self.data:
+                states[current] = self.data[current_char]
+
+        # Handle a possible final match still pending in `states` at the end of the text.
+        for start, trie_pointer in states.items():
+            if "" in trie_pointer:
+                # This is a final match, we need to reset and
+                # store the results in `offsets`.
+                end = len(text)
+                offsets.append(start)
+                offsets.append(end)
+                # The longest cut is always the one with the lowest start,
+                # i.e. the first item, so we can break here.
+                break
+
+        return self.cut_text(text, offsets)
+
+    def cut_text(self, text, offsets):
+        # We have all the offsets now, we just need to do the actual splitting.
+        # We still need to account for the first part of the string and for
+        # the trailing part after the last cut.
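+        # Illustrative example (not from the original source): with text = "ABCD" and
+        # offsets = [0, 1, 3], appending len(text) gives cuts [0, 1, 3, 4] and the
+        # returned pieces are ["A", "BC", "D"].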
+        offsets.append(len(text))
+        tokens = []
+        start = 0
+        for end in offsets:
+            if start > end:
+                logger.error(
+                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it"
+                    " anyway."
+                )
+                continue
+            elif start == end:
+                # This might happen if there's a match at index 0
+                # we're also preventing zero-width cuts in case of two
+                # consecutive matches
+                continue
+            tokens.append(text[start:end])
+            start = end
+
+        return tokens
+
+
+class ExtensionsTrie(Trie):
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def extensions(self, prefix: str):
+        """
+        Generates all extensions of a given prefix token in the Trie.
+
+        Example:
+
+        ```python
+        >>> trie = Trie()
+        >>> trie.add("apple")
+        >>> trie.add("app")
+        >>> trie.add("application")
+        >>> trie.extensions("app")
+        ['app', 'apple', 'application']
+        ```
+        """
+        prefix_node = self._get_node(prefix)
+        ret = self._collect_tokens(prefix_node)
+        return [prefix + token for token in ret]
+
+    def _get_node(self, token: str) -> dict:
+        """
+        Retrieves the node corresponding to the given token in the Trie.
+
+        Args:
+            token (str): The token for which the corresponding node needs to be retrieved.
+
+        Returns:
+            dict: The node in the Trie corresponding to the given token.
+        """
+        node = self.data
+        for char in token:
+            if char not in node:
+                break
+
+            node = node[char]
+        return node
+
+    def _collect_tokens(self, node: dict) -> list:
+        """
+        Generates all tokens in the Trie starting from a given node.
+
+        Args:
+            node (dict): The node in the Trie from which tokens need to be generated.
+
+        Returns:
+            list: List of tokens generated from the given node.
+        """
+        tokens = [self._termination_char] if self._termination_char in node else []
+        for token, subtrie_head in node.items():
+            if token != self._termination_char:
+                subtokens = self._collect_tokens(subtrie_head)
+                tokens.extend([token + subtoken for subtoken in subtokens])
+        return tokens
+
+
+def _is_whitespace(char):
+    """Checks whether `char` is a whitespace character."""
+    # \t, \n, and \r are technically control characters but we treat them
+    # as whitespace since they are generally considered as such.
+    if char == " " or char == "\t" or char == "\n" or char == "\r":
+        return True
+    cat = unicodedata.category(char)
+    if cat == "Zs":
+        return True
+    return False
+
+
+def _is_control(char):
+    """Checks whether `char` is a control character."""
+    # These are technically control characters but we count them as whitespace
+    # characters.
+    if char == "\t" or char == "\n" or char == "\r":
+        return False
+    cat = unicodedata.category(char)
+    if cat.startswith("C"):
+        return True
+    return False
+
+
+def _is_punctuation(char):
+    """Checks whether `char` is a punctuation character."""
+    cp = ord(char)
+    # We treat all non-letter/number ASCII as punctuation.
+    # Characters such as "^", "$", and "`" are not in the Unicode
+    # Punctuation class but we treat them as punctuation anyways, for
+    # consistency.
+    if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+        return True
+    cat = unicodedata.category(char)
+    if cat.startswith("P"):
+        return True
+    return False
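+
+
+# Illustrative behavior of the helpers above (a sketch, not exhaustive):
+# _is_whitespace(" ") and _is_punctuation("$") return True, while
+# _is_control("\n") returns False because \t, \n and \r are treated as
+# whitespace rather than control characters here.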
+
+
+def _is_end_of_word(text):
+    """Checks whether the last character in text is one of a punctuation, control or whitespace character."""
+    last_char = text[-1]
+    return bool(_is_control(last_char) | _is_punctuation(last_char) | _is_whitespace(last_char))
+
+
+def _is_start_of_word(text):
+    """Checks whether the first character in text is one of a punctuation, control or whitespace character."""
+    first_char = text[0]
+    return bool(_is_control(first_char) | _is_punctuation(first_char) | _is_whitespace(first_char))
+
+
+def _insert_one_token_to_ordered_list(token_list: list[str], new_token: str):
+    """
+    Inserts one token into an ordered list if it does not already exist. Note: token_list must be sorted.
+    """
+    insertion_idx = bisect.bisect_left(token_list, new_token)
+    # Checks if new_token is already in the ordered token_list
+    if insertion_idx < len(token_list) and token_list[insertion_idx] == new_token:
+        # new_token is in token_list, don't add
+        return
+    else:
+        token_list.insert(insertion_idx, new_token)
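+
+
+# Sketch of the expected behavior (illustrative): with token_list = ["a", "c"],
+# _insert_one_token_to_ordered_list(token_list, "b") yields ["a", "b", "c"],
+# while inserting "a" again is a no-op because the token is already present.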
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizer(PreTrainedTokenizerBase):
+    """
+    Base class for all slow tokenizers.
+
+    Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
+
+    Handles all the shared methods for tokenization and special tokens, as well as methods for
+    downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.
+
+    This class also contains the added tokens in a unified way on top of all tokenizers, so we don't have to handle
+    the specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
+    """
+
+    def __init__(self, **kwargs):
+        # 1. Init the tokens trie
+
+        self.tokens_trie = Trie()
+
+        # 2. init `_added_tokens_decoder` if child class did not
+        if not hasattr(self, "_added_tokens_decoder"):
+            self._added_tokens_decoder: dict[int, AddedToken] = {}
+
+        # 3. if an `added_tokens_decoder` is passed, we are loading from a saved tokenizer, so we overwrite
+        self._added_tokens_decoder.update(kwargs.pop("added_tokens_decoder", {}))
+        self._added_tokens_encoder: dict[str, int] = {k.content: v for v, k in self._added_tokens_decoder.items()}
+
+        # 4. Init the parent class
+        super().__init__(**kwargs)
+
+        # 5. If some of the special tokens are not part of the vocab, we add them at the end.
+        # The order of addition is the same as self.SPECIAL_TOKENS_ATTRIBUTES, following `tokenizers`.
+        self._add_tokens(
+            [token for token in self.all_special_tokens_extended if token not in self._added_tokens_encoder],
+            special_tokens=True,
+        )
+
+        self._decode_use_source_tokenizer = False
+
+    @property
+    def is_fast(self) -> bool:
+        return False
+
+    @property
+    def vocab_size(self) -> int:
+        """
+        `int`: Size of the base vocabulary (without the added tokens).
+        """
+        raise NotImplementedError
+
+    @property
+    def added_tokens_encoder(self) -> dict[str, int]:
+        """
+        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+        """
+        return {k.content: v for v, k in sorted(self._added_tokens_decoder.items(), key=lambda item: item[0])}
+
+    @property
+    def added_tokens_decoder(self) -> dict[int, AddedToken]:
+        """
+        Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
+
+        Returns:
+            `dict[int, AddedToken]`: The added tokens.
+        """
+        return dict(sorted(self._added_tokens_decoder.items(), key=lambda item: item[0]))
+
+    @added_tokens_decoder.setter
+    def added_tokens_decoder(self, value: dict[int, Union[AddedToken, str]]) -> dict[int, AddedToken]:
+        # Always raise an error if string because users should define the behavior
+        for index, token in value.items():
+            if not isinstance(token, (str, AddedToken)) or not isinstance(index, int):
+                raise TypeError(
+                    f"The provided `added_tokens_decoder` has an element of type {index.__class__, token.__class__}, should be a dict of {int, Union[AddedToken, str]}"
+                )
+
+            self._added_tokens_decoder[index] = AddedToken(token) if isinstance(token, str) else token
+            self._added_tokens_encoder[str(token)] = index
+        self._update_total_vocab_size()
+
+    def get_added_vocab(self) -> dict[str, int]:
+        """
+        Returns the added tokens in the vocabulary as a dictionary of token to index. Results might be different from
+        the fast call because for now we always add the tokens even if they are already in the vocabulary. This is
+        something we should change.
+
+        Returns:
+            `dict[str, int]`: The added tokens.
+        """
+        return self._added_tokens_encoder
+
+    def __len__(self):
+        """
+        Size of the full vocabulary with the added tokens.
+        """
+        return self.total_vocab_size
+
+    def _update_total_vocab_size(self):
+        """
+        Update the size of the full vocabulary with the added tokens. Counts the `keys` and not the `values` because
+        otherwise if there is a hole in the vocab, we will add tokens at a wrong index. This operation is slow and
+        is only run when adding tokens.
+        """
+        self.total_vocab_size = len(self.get_vocab())
+
+    def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int:
+        """
+        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added
+        to it with indices starting from the length of the current vocabulary. Special tokens are sometimes already in
+        the vocab, which is why they have to be handled specifically.
+
+        Args:
+            new_tokens (`list[str]`or `list[tokenizers.AddedToken]`):
+                Token(s) to add to the vocabulary. A token is counted as added if it's not already in the vocabulary
+                (tested by checking if the tokenizer assigns the index of the `unk_token` to it). If a token is part
+                of the vocabulary then we simply mark this token as an `AddedToken`, which allows controlling the
+                stripping and normalization of this token. This is NOT possible in `tokenizers`.
+            special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the tokens should be added as special tokens.
+
+        Returns:
+            `int`: The number of tokens actually added to the vocabulary.
+
+        Examples:
+
+        ```python
+        # Let's see how to increase the vocabulary of Bert model and tokenizer
+        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+
+        num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+        print("We have added", num_added_toks, "tokens")
+        # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+        model.resize_token_embeddings(len(tokenizer))
+        ```"""
+        added_tokens = 0
+        if new_tokens is None:
+            return added_tokens
+        # TODO: this is fairly slow; it should be improved!
+        current_vocab = self.get_vocab().copy()
+        new_idx = len(current_vocab)  # only call this once, len gives the last index + 1
+        for token in new_tokens:
+            if not isinstance(token, (str, AddedToken)):
+                raise TypeError(f"Token {token} is not a string but a {type(token)}.")
+            if str(token) == "":
+                continue
+            if isinstance(token, str):
+                if token in self._added_tokens_encoder:
+                    continue
+                else:
+                    # very important for fast and slow equivalence!
+                    is_special = token in self.all_special_tokens or special_tokens
+                    token = AddedToken(
+                        token, rstrip=False, lstrip=False, normalized=not is_special, special=is_special
+                    )
+            elif special_tokens:
+                # doing token.special=True changes the normalization! will fix in rust
+                # this is important and the only reason why the AddedTokens in each class are normalized by default
+                token.__setstate__({"special": True, "normalized": token.normalized})
+            if token in self._added_tokens_decoder:
+                continue
+            if not token.special and token.normalized and getattr(self, "do_lower_case", False):
+                # Normalize if requested
+                token.content = token.content.lower()
+            if token.content not in current_vocab:
+                token_index = new_idx + added_tokens
+                current_vocab[token.content] = token_index
+                added_tokens += 1
+            else:
+                token_index = current_vocab[token.content]
+
+            if token.special and str(token) not in self.all_special_tokens:
+                self._special_tokens_map["additional_special_tokens"].append(token)
+            # the setter automatically updates the reverse map
+            self._added_tokens_decoder[token_index] = token
+            self._added_tokens_encoder[token.content] = token_index
+            if self.verbose:
+                logger.info(f"Adding {token} to the vocabulary")
+
+        self._update_trie()
+        self._update_total_vocab_size()
+        return added_tokens
+
+    def _update_trie(self, unique_no_split_tokens: Optional[list[str]] = []):
+        for token in self._added_tokens_decoder.values():
+            if token.content not in self.tokens_trie._tokens:
+                self.tokens_trie.add(token.content)
+        for token in unique_no_split_tokens:
+            if token not in self.tokens_trie._tokens:
+                self.tokens_trie.add(token)
+
+    def num_special_tokens_to_add(self, pair: bool = False) -> int:
+        """
+        Returns the number of added tokens when encoding a sequence with special tokens.
+
+
+        This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
+        this inside your training loop.
+
+
+        Args:
+            pair (`bool`, *optional*, defaults to `False`):
+                Whether the number of added tokens should be computed in the case of a sequence pair or a single
+                sequence.
+
+        Returns:
+            `int`: Number of special tokens added to sequences.
+        """
+        token_ids_0 = []
+        token_ids_1 = []
+        return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None))
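+
+    # For example, with a hypothetical BERT-style tokenizer that wraps inputs as [CLS] A [SEP]
+    # (and appends B [SEP] for pairs), this would return 2 for a single sequence and 3 for a
+    # pair (illustrative values, not from the original source).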
+
+    def tokenize(self, text: TextInput, **kwargs) -> list[str]:
+        """
+        Converts a string into a sequence of tokens, using the tokenizer.
+
+        Splits into words for word-based vocabularies or into sub-words for sub-word-based vocabularies
+        (BPE/SentencePiece/WordPiece). Takes care of added tokens.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            **kwargs (additional keyword arguments):
+                Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
+
+        Returns:
+            `list[str]`: The list of tokens.
+        """
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
+
+        text, kwargs = self.prepare_for_tokenization(text, **kwargs)
+
+        if kwargs:
+            logger.warning(f"Keyword arguments {kwargs} not recognized.")
+
+        if hasattr(self, "do_lower_case") and self.do_lower_case:
+            # convert non-special tokens to lowercase. Might be super slow as well?
+            escaped_special_toks = [re.escape(s_tok) for s_tok in (self.all_special_tokens)]
+            escaped_special_toks += [
+                re.escape(s_tok.content)
+                for s_tok in (self._added_tokens_decoder.values())
+                if not s_tok.special and s_tok.normalized
+            ]
+            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
+            text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
+
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = self._added_tokens_encoder.keys()  # don't split on any of the added tokens
+            # "This is something  else"
+            tokens = self.tokens_trie.split(text)
+
+        # ["This is something", "", "  else"]
+        for i, token in enumerate(tokens):
+            if token in no_split_token:
+                tok_extended = self._added_tokens_decoder.get(self._added_tokens_encoder[token], None)
+                left = tokens[i - 1] if i > 0 else None
+                right = tokens[i + 1] if i < len(tokens) - 1 else None
+                if isinstance(tok_extended, AddedToken):
+                    if tok_extended.rstrip and right:
+                        # A bit counter-intuitive but we strip the left of the string
+                        # since tok_extended.rstrip means the special token is eating all white spaces on its right
+                        tokens[i + 1] = right.lstrip()
+                    # Strip white spaces on the left
+                    if tok_extended.lstrip and left:
+                        tokens[i - 1] = left.rstrip()  # Opposite here
+                    if tok_extended.single_word and left and left[-1] != " ":
+                        tokens[i - 1] += token
+                        tokens[i] = ""
+                    elif tok_extended.single_word and right and right[0] != " ":
+                        tokens[i + 1] = token + tokens[i + 1]
+                        tokens[i] = ""
+                else:
+                    raise ValueError(
+                        f"{tok_extended} cannot be tokenized because it was not properly added"
+                        f" to the tokenizer. This means that it is not an `AddedToken` but a {type(tok_extended)}"
+                    )
+        # ["This is something", "", "else"]
+        tokenized_text = []
+        for token in tokens:
+            # Need to skip any empty (fully stripped) tokens
+            if not token:
+                continue
+            if token in no_split_token:
+                tokenized_text.append(token)
+            else:
+                tokenized_text.extend(self._tokenize(token))
+        # ["This", " is", " something", "", "else"]
+        return tokenized_text
+
+    def _tokenize(self, text, **kwargs):
+        """
+        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
+        vocabularies or into sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).
+
+        Do NOT take care of added tokens.
+        """
+        raise NotImplementedError
+
+    def convert_tokens_to_ids(self, tokens: Union[str, list[str]]) -> Union[int, list[int]]:
+        """
+        Converts a token string (or a sequence of tokens) into a single integer id (or a sequence of ids), using the
+        vocabulary.
+
+        Args:
+            tokens (`str` or `list[str]`): One or several token(s) to convert to token id(s).
+
+        Returns:
+            `int` or `list[int]`: The token id or list of token ids.
+        """
+        if tokens is None:
+            return None
+
+        if isinstance(tokens, str):
+            return self._convert_token_to_id_with_added_voc(tokens)
+
+        ids = []
+        for token in tokens:
+            ids.append(self._convert_token_to_id_with_added_voc(token))
+        return ids
+
+    def _convert_token_to_id_with_added_voc(self, token):
+        if token is None:
+            return None
+
+        if token in self._added_tokens_encoder:
+            return self._added_tokens_encoder[token]
+        return self._convert_token_to_id(token)
+
+    def _convert_token_to_id(self, token):
+        raise NotImplementedError
+
+    def _encode_plus(
+        self,
+        text: Union[TextInput, PreTokenizedInput, EncodedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        def get_input_ids(text):
+            if isinstance(text, str):
+                tokens = self.tokenize(text, **kwargs)
+                return self.convert_tokens_to_ids(tokens)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+                if is_split_into_words:
+                    tokens = list(
+                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+                    )
+                    return self.convert_tokens_to_ids(tokens)
+                else:
+                    return self.convert_tokens_to_ids(text)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+                return text
+            else:
+                if is_split_into_words:
+                    raise ValueError(
+                        f"Input {text} is not valid. Should be a string or a list/tuple of strings when"
+                        " `is_split_into_words=True`."
+                    )
+                else:
+                    raise ValueError(
+                        f"Input {text} is not valid. Should be a string, a list/tuple of strings or a list/tuple of"
+                        " integers."
+                    )
+
+        if return_offsets_mapping:
+            raise NotImplementedError(
+                "return_offset_mapping is not available when using Python tokenizers. "
+                "To use this feature, change your tokenizer to one deriving from "
+                "transformers.PreTrainedTokenizerFast. "
+                "More information on available tokenizers at "
+                "https://github.com/huggingface/transformers/pull/2674"
+            )
+
+        first_ids = get_input_ids(text)
+        second_ids = get_input_ids(text_pair) if text_pair is not None else None
+
+        return self.prepare_for_model(
+            first_ids,
+            pair_ids=second_ids,
+            add_special_tokens=add_special_tokens,
+            padding=padding_strategy.value,
+            truncation=truncation_strategy.value,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            prepend_batch_axis=True,
+            return_attention_mask=return_attention_mask,
+            return_token_type_ids=return_token_type_ids,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_length=return_length,
+            verbose=verbose,
+        )
+
+    def _batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: Union[
+            list[TextInput],
+            list[TextInputPair],
+            list[PreTokenizedInput],
+            list[PreTokenizedInputPair],
+            list[EncodedInput],
+            list[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        def get_input_ids(text):
+            if isinstance(text, str):
+                tokens = self.tokenize(text, **kwargs)
+                return self.convert_tokens_to_ids(tokens)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
+                if is_split_into_words:
+                    tokens = list(
+                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
+                    )
+                    return self.convert_tokens_to_ids(tokens)
+                else:
+                    return self.convert_tokens_to_ids(text)
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
+                return text
+            else:
+                raise ValueError(
+                    "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
+                )
+
+        if return_offsets_mapping:
+            raise NotImplementedError(
+                "return_offset_mapping is not available when using Python tokenizers. "
+                "To use this feature, change your tokenizer to one deriving from "
+                "transformers.PreTrainedTokenizerFast."
+            )
+
+        input_ids = []
+        for ids_or_pair_ids in batch_text_or_text_pairs:
+            if (
+                not isinstance(ids_or_pair_ids, (list, tuple))
+                or is_split_into_words
+                and not isinstance(ids_or_pair_ids[0], (list, tuple))
+            ):
+                ids, pair_ids = ids_or_pair_ids, None
+            else:
+                ids, pair_ids = ids_or_pair_ids
+
+            first_ids = get_input_ids(ids)
+            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
+            input_ids.append((first_ids, second_ids))
+
+        batch_outputs = self._batch_prepare_for_model(
+            input_ids,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_attention_mask=return_attention_mask,
+            return_token_type_ids=return_token_type_ids,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_length=return_length,
+            return_tensors=return_tensors,
+            verbose=verbose,
+            split_special_tokens=split_special_tokens,
+        )
+
+        return BatchEncoding(batch_outputs)
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def _batch_prepare_for_model(
+        self,
+        batch_ids_pairs: list[Union[PreTokenizedInputPair, tuple[list[int], None]]],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[str] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
+        adds special tokens, truncates sequences if they overflow while taking the special tokens into account, and
+        manages a moving window (with user-defined stride) for overflowing tokens.
+
+        Args:
+            batch_ids_pairs: list of tokenized input ids or input ids pairs
+        """
+
+        batch_outputs = {}
+        for first_ids, second_ids in batch_ids_pairs:
+            outputs = self.prepare_for_model(
+                first_ids,
+                second_ids,
+                add_special_tokens=add_special_tokens,
+                padding=PaddingStrategy.DO_NOT_PAD.value,  # we pad in batch afterward
+                truncation=truncation_strategy.value,
+                max_length=max_length,
+                stride=stride,
+                pad_to_multiple_of=None,  # we pad in batch afterward
+                padding_side=None,  # we pad in batch afterward
+                return_attention_mask=False,  # we pad in batch afterward
+                return_token_type_ids=return_token_type_ids,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length,
+                return_tensors=None,  # We convert the whole batch to tensors at the end
+                prepend_batch_axis=False,
+                verbose=verbose,
+                split_special_tokens=split_special_tokens,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                batch_outputs[key].append(value)
+
+        batch_outputs = self.pad(
+            batch_outputs,
+            padding=padding_strategy.value,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_attention_mask=return_attention_mask,
+        )
+
+        batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+        return batch_outputs
+
+    def prepare_for_tokenization(
+        self, text: str, is_split_into_words: bool = False, **kwargs
+    ) -> tuple[str, dict[str, Any]]:
+        """
+        Performs any necessary transformations before tokenization.
+
+        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
+        `kwargs` at the end of the encoding process to be sure all the arguments have been used.
+
+        Args:
+            text (`str`):
+                The text to prepare.
+            is_split_into_words (`bool`, *optional*, defaults to `False`):
+                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                which it will tokenize. This is useful for NER or token classification.
+            kwargs (`dict[str, Any]`, *optional*):
+                Keyword arguments to use for the tokenization.
+
+        Returns:
+            `tuple[str, dict[str, Any]]`: The prepared text and the unused kwargs.
+        """
+        return (text, kwargs)
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list, token_ids_1: Optional[list] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids of the first sequence.
+            token_ids_1 (`list[int]`, *optional*):
+                List of ids of the second sequence.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+        """
+        if already_has_special_tokens:
+            if token_ids_1 is not None:
+                raise ValueError(
+                    "You should not supply a second sequence if the provided sequence of "
+                    "ids is already formatted with special tokens for the model."
+                )
+
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+        return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))
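+
+    # Illustrative: with already_has_special_tokens=False, a token_ids_0 of length 3 and no
+    # second sequence, this base implementation simply returns [0, 0, 0] (no positions flagged
+    # as special).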
+
+    @overload
+    def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: ...
+
+    @overload
+    def convert_ids_to_tokens(self, ids: list[int], skip_special_tokens: bool = False) -> list[str]: ...
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, list[int]], skip_special_tokens: bool = False
+    ) -> Union[str, list[str]]:
+        """
+        Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
+        and added tokens.
+
+        Args:
+            ids (`int` or `list[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `list[str]`: The decoded token(s).
+        """
+        if isinstance(ids, int):
+            if ids in self._added_tokens_decoder:
+                return self._added_tokens_decoder[ids].content
+            else:
+                return self._convert_id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            if index in self._added_tokens_decoder:
+                tokens.append(self._added_tokens_decoder[index].content)
+            else:
+                tokens.append(self._convert_id_to_token(index))
+        return tokens
+
+    def _convert_id_to_token(self, index: int) -> str:
+        raise NotImplementedError
+
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
+        return " ".join(tokens)
+
+    def _decode(
+        self,
+        token_ids: Union[int, list[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        spaces_between_special_tokens: bool = True,
+        **kwargs,
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
+        # If a single id was given, prevent splitting the string in the upcoming loop
+        if isinstance(filtered_tokens, str):
+            filtered_tokens = [filtered_tokens]
+
+        legacy_added_tokens = set(self._added_tokens_encoder.keys()) - set(self.all_special_tokens) | {
+            token for token in self.additional_special_tokens if self.convert_tokens_to_ids(token) >= self.vocab_size
+        }
+        # To avoid mixing byte-level and unicode for byte-level BPE
+        # we need to build string separately for added tokens and byte-level tokens
+        # cf. https://github.com/huggingface/transformers/issues/1133
+        sub_texts = []
+        current_sub_text = []
+        # TODO @ArthurZ in version 5, special tokens should be handled in convert_tokens_to_string, while _convert_tokens_to_string
+        for token in filtered_tokens:
+            if skip_special_tokens and token in self.all_special_tokens:
+                continue
+            if token in legacy_added_tokens:
+                if current_sub_text:
+                    string = self.convert_tokens_to_string(current_sub_text)
+                    if len(string) > 0:
+                        sub_texts.append(string)
+                    current_sub_text = []
+                sub_texts.append(token)
+            else:
+                current_sub_text.append(token)
+        if current_sub_text:
+            sub_texts.append(self.convert_tokens_to_string(current_sub_text))
+
+        if spaces_between_special_tokens:
+            text = " ".join(sub_texts)
+        else:
+            text = "".join(sub_texts)
+
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
diff --git a/phivenv/Lib/site-packages/transformers/tokenization_utils_base.py b/phivenv/Lib/site-packages/transformers/tokenization_utils_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a0d2e0567c8a91f73ce40cbba856be2fa1b5c2b
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/tokenization_utils_base.py
@@ -0,0 +1,4199 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Base classes common to both the slow and the fast tokenization classes: PreTrainedTokenizerBase (hosts all the
+user-facing encoding methods), the special tokens mixin (hosts the special tokens logic) and BatchEncoding (wraps the
+dictionary of outputs with special methods for the fast tokenizers).
+"""
+
+import copy
+import json
+import os
+import re
+import warnings
+from collections import UserDict
+from collections.abc import Mapping, Sequence, Sized
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union
+
+import numpy as np
+from packaging import version
+
+from . import __version__
+from .dynamic_module_utils import custom_object_save
+from .utils import (
+    CHAT_TEMPLATE_DIR,
+    CHAT_TEMPLATE_FILE,
+    ExplicitEnum,
+    PaddingStrategy,
+    PushToHubMixin,
+    TensorType,
+    add_end_docstrings,
+    cached_file,
+    copy_func,
+    download_url,
+    extract_commit_hash,
+    is_flax_available,
+    is_jax_tensor,
+    is_mlx_available,
+    is_numpy_array,
+    is_offline_mode,
+    is_protobuf_available,
+    is_remote_url,
+    is_tf_available,
+    is_tf_tensor,
+    is_tokenizers_available,
+    is_torch_available,
+    is_torch_device,
+    is_torch_tensor,
+    list_repo_templates,
+    logging,
+    requires_backends,
+    to_py_obj,
+)
+from .utils.chat_template_utils import render_jinja_template
+from .utils.import_utils import PROTOBUF_IMPORT_ERROR
+
+
+if TYPE_CHECKING:
+    if is_torch_available():
+        import torch
+    if is_tf_available():
+        import tensorflow as tf
+    if is_flax_available():
+        import jax.numpy as jnp  # noqa: F401
+
+
+def import_protobuf_decode_error(error_message=""):
+    if is_protobuf_available():
+        from google.protobuf.message import DecodeError
+
+        return DecodeError
+    else:
+        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
+
+
+if is_tokenizers_available():
+    from tokenizers import AddedToken
+    from tokenizers import Encoding as EncodingFast
+else:
+
+    @dataclass(frozen=False, eq=True)
+    class AddedToken:
+        """
+        AddedToken represents a token to be added to a Tokenizer. An AddedToken can have special options defining the
+        way it should behave.
+
+        The `normalized` will default to `not special` if it is not specified, similarly to the definition in
+        `tokenizers`.
+        """
+
+        def __init__(
+            self, content: str, single_word=False, lstrip=False, rstrip=False, special=False, normalized=None
+        ):
+            self.content = content
+            self.single_word = single_word
+            self.lstrip = lstrip
+            self.rstrip = rstrip
+            self.special = special
+            self.normalized = normalized if normalized is not None else not special
+
+        def __getstate__(self):
+            return self.__dict__
+
+        def __str__(self):
+            return self.content
+
+    @dataclass
+    class EncodingFast:
+        """This is dummy class because without the `tokenizers` library we don't have these objects anyway"""
+
+        pass
+
+
+logger = logging.get_logger(__name__)
+
+VERY_LARGE_INTEGER = int(1e30)  # This is used to set the max input length for a model with infinite size input
+LARGE_INTEGER = int(1e20)  # This is used when we need something big but slightly smaller than VERY_LARGE_INTEGER
+
+# Define type aliases and NamedTuples
+TextInput = str
+PreTokenizedInput = list[str]
+EncodedInput = list[int]
+TextInputPair = tuple[str, str]
+PreTokenizedInputPair = tuple[list[str], list[str]]
+EncodedInputPair = tuple[list[int], list[int]]
+
+# Define type aliases for text-related non-text modalities
+AudioInput = Union["np.ndarray", "torch.Tensor", list["np.ndarray"], list["torch.Tensor"]]
+
+# Slow tokenizers used to be saved in three separate files
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+ADDED_TOKENS_FILE = "added_tokens.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+
+# Fast tokenizers (provided by the HuggingFace `tokenizers` library) can be saved in a single file
+FULL_TOKENIZER_FILE = "tokenizer.json"
+_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json")
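+# e.g. "tokenizer.chat.json" matches _re_tokenizer_file with group(1) == "chat"
+# (illustrative filename; any "tokenizer.<suffix>.json" name is matched).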
+
+
+class TruncationStrategy(ExplicitEnum):
+    """
+    Possible values for the `truncation` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in
+    an IDE.
+    """
+
+    ONLY_FIRST = "only_first"
+    ONLY_SECOND = "only_second"
+    LONGEST_FIRST = "longest_first"
+    DO_NOT_TRUNCATE = "do_not_truncate"
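+
+    # e.g. TruncationStrategy("only_first") is TruncationStrategy.ONLY_FIRST, which is how string
+    # values passed by users resolve to members (standard Enum lookup by value).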
+
+
+class CharSpan(NamedTuple):
+    """
+    Character span in the original string.
+
+    Args:
+        start (`int`): Index of the first character in the original string.
+        end (`int`): Index of the character following the last character in the original string.
+    """
+
+    start: int
+    end: int
+
+
+class TokenSpan(NamedTuple):
+    """
+    Token span in an encoded string (list of tokens).
+
+    Args:
+        start (`int`): Index of the first token in the span.
+        end (`int`): Index of the token following the last token in the span.
+    """
+
+    start: int
+    end: int
+
+
+class BatchEncoding(UserDict):
+    """
+    Holds the output of the [`~tokenization_utils_base.PreTrainedTokenizerBase.__call__`],
+    [`~tokenization_utils_base.PreTrainedTokenizerBase.encode_plus`] and
+    [`~tokenization_utils_base.PreTrainedTokenizerBase.batch_encode_plus`] methods (tokens, attention_masks, etc).
+
+    This class is derived from a python dictionary and can be used as a dictionary. In addition, this class exposes
+    utility methods to map from word/character space to token space.
+
+    Args:
+        data (`dict`, *optional*):
+            Dictionary of lists/arrays/tensors returned by the `__call__`/`encode_plus`/`batch_encode_plus` methods
+            ('input_ids', 'attention_mask', etc.).
+        encoding (`tokenizers.Encoding` or `Sequence[tokenizers.Encoding]`, *optional*):
+            If the tokenizer is a fast tokenizer which outputs additional information like mapping from word/character
+            space to token space the `tokenizers.Encoding` instance or list of instance (for batches) hold this
+            information.
+        tensor_type (`Union[None, str, TensorType]`, *optional*):
+            You can give a tensor_type here to convert the lists of integers into PyTorch/TensorFlow/Numpy tensors at
+            initialization.
+        prepend_batch_axis (`bool`, *optional*, defaults to `False`):
+            Whether or not to add a batch axis when converting to tensors (see `tensor_type` above). Note that this
+            parameter has an effect if the parameter `tensor_type` is set, *otherwise has no effect*.
+        n_sequences (`Optional[int]`, *optional*):
+            The number of sequences used to generate each sample in this encoding (`1` for a single sentence, `2` for
+            a pair of sentences).
+    """
+
+    def __init__(
+        self,
+        data: Optional[dict[str, Any]] = None,
+        encoding: Optional[Union[EncodingFast, Sequence[EncodingFast]]] = None,
+        tensor_type: Union[None, str, TensorType] = None,
+        prepend_batch_axis: bool = False,
+        n_sequences: Optional[int] = None,
+    ):
+        super().__init__(data)
+
+        if isinstance(encoding, EncodingFast):
+            encoding = [encoding]
+
+        self._encodings = encoding
+
+        if n_sequences is None and encoding is not None and encoding:
+            n_sequences = encoding[0].n_sequences
+
+        self._n_sequences = n_sequences
+
+        self.convert_to_tensors(tensor_type=tensor_type, prepend_batch_axis=prepend_batch_axis)
+
+    @property
+    def n_sequences(self) -> Optional[int]:
+        """
+        `Optional[int]`: The number of sequences used to generate each sample from the batch encoded in this
+        [`BatchEncoding`]. Currently can be one of `None` (unknown), `1` (a single sentence) or `2` (a pair of
+        sentences)
+        """
+        return self._n_sequences
+
+    @property
+    def is_fast(self) -> bool:
+        """
+        `bool`: Indicate whether this [`BatchEncoding`] was generated from the result of a [`PreTrainedTokenizerFast`]
+        or not.
+        """
+        return self._encodings is not None
+
+    def __getitem__(self, item: Union[int, str]) -> Union[Any, EncodingFast]:
+        """
+        If the key is a string, returns the value of the dict associated to `key` ('input_ids', 'attention_mask',
+        etc.).
+
+        If the key is an integer, get the `tokenizers.Encoding` for batch item with index `key`.
+
+        If the key is a slice, returns the values of the dict associated with each key ('input_ids', 'attention_mask',
+        etc.), restricted to that slice.
+        """
+        if isinstance(item, str):
+            return self.data[item]
+        elif self._encodings is not None:
+            return self._encodings[item]
+        elif isinstance(item, slice):
+            return {key: self.data[key][item] for key in self.data}
+        else:
+            raise KeyError(
+                "Invalid key. Only three types of key are available: "
+                "(1) string, (2) integers for backend Encoding, and (3) slices for data subsetting."
+            )
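+
+    # Illustrative usage (a sketch, not from the original source):
+    #   enc["input_ids"] -> the stored values for that key
+    #   enc[0]           -> the backend `tokenizers.Encoding` of sample 0 (fast tokenizers only)
+    #   enc[:2]          -> every field sliced, when there are no backend encodings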
+
+    def __getattr__(self, item: str):
+        try:
+            return self.data[item]
+        except KeyError:
+            raise AttributeError
+
+    def __getstate__(self):
+        return {"data": self.data, "encodings": self._encodings}
+
+    def __setstate__(self, state):
+        if "data" in state:
+            self.data = state["data"]
+
+        if "encodings" in state:
+            self._encodings = state["encodings"]
+
+    # After this point:
+    # Extended properties and methods only available for fast (Rust-based) tokenizers
+    # provided by HuggingFace tokenizers library.
+
+    @property
+    def encodings(self) -> Optional[list[EncodingFast]]:
+        """
+        `Optional[list[tokenizers.Encoding]]`: The list of all encodings from the tokenization process. Returns `None`
+        if the input was tokenized through a Python (i.e., not a fast) tokenizer.
+        """
+        return self._encodings
+
+    def tokens(self, batch_index: int = 0) -> list[str]:
+        """
+        Return the list of tokens (sub-parts of the input strings after word/subword splitting and before conversion to
+        integer indices) at a given batch index (only works for the output of a fast tokenizer).
+
+        Args:
+            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+        Returns:
+            `list[str]`: The list of tokens at that index.
+        """
+        if not self._encodings:
+            raise ValueError(
+                "tokens() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+                " class)."
+            )
+        return self._encodings[batch_index].tokens
+
+    def sequence_ids(self, batch_index: int = 0) -> list[Optional[int]]:
+        """
+        Return a list mapping the tokens to the id of their original sentences:
+
+            - `None` for special tokens added around or between sequences,
+            - `0` for tokens corresponding to words in the first sequence,
+            - `1` for tokens corresponding to words in the second sequence when a pair of sequences was jointly
+              encoded.
+
+        Args:
+            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+        Returns:
+            `list[Optional[int]]`: A list indicating the sequence id corresponding to each token. Special tokens added
+            by the tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding
+            sequence.
+        """
+        if not self._encodings:
+            raise ValueError(
+                "sequence_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+                " class)."
+            )
+        return self._encodings[batch_index].sequence_ids
+
+    def words(self, batch_index: int = 0) -> list[Optional[int]]:
+        """
+        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
+
+        Args:
+            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+        Returns:
+            `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
+            tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
+            (several tokens will be mapped to the same word index if they are parts of that word).
+        """
+        if not self._encodings:
+            raise ValueError(
+                "words() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+                " class)."
+            )
+        warnings.warn(
+            "`BatchEncoding.words()` property is deprecated and should be replaced with the identical, "
+            "but more self-explanatory `BatchEncoding.word_ids()` property.",
+            FutureWarning,
+        )
+        return self.word_ids(batch_index)
+
+    def word_ids(self, batch_index: int = 0) -> list[Optional[int]]:
+        """
+        Return a list mapping the tokens to their actual word in the initial sentence for a fast tokenizer.
+
+        Args:
+            batch_index (`int`, *optional*, defaults to 0): The index to access in the batch.
+
+        Returns:
+            `list[Optional[int]]`: A list indicating the word corresponding to each token. Special tokens added by the
+            tokenizer are mapped to `None` and other tokens are mapped to the index of their corresponding word
+            (several tokens will be mapped to the same word index if they are parts of that word).
+        """
+        if not self._encodings:
+            raise ValueError(
+                "word_ids() is not available when using non-fast tokenizers (e.g. instance of a `XxxTokenizerFast`"
+                " class)."
+            )
+        return self._encodings[batch_index].word_ids
+
+    def token_to_sequence(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
+        """
+        Get the index of the sequence represented by the given token. In the general use case, this method returns `0`
+        for a single sequence or the first sequence of a pair, and `1` for the second sequence of a pair.
+
+        Can be called as:
+
+        - `self.token_to_sequence(token_index)` if batch size is 1
+        - `self.token_to_sequence(batch_index, token_index)` if batch size is greater than 1
+
+        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
+        words are defined by the user). In this case it makes it easy to associate encoded tokens with the provided
+        tokenized words.
+
+        Args:
+            batch_or_token_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the token in the sequence.
+            token_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
+                sequence.
+
+        Returns:
+            `int`: Index of the sequence the given token belongs to (`0` or `1` for a pair of sequences).
+        """
+
+        if not self._encodings:
+            raise ValueError("token_to_sequence() is not available when using Python based tokenizers")
+        if token_index is not None:
+            batch_index = batch_or_token_index
+        else:
+            batch_index = 0
+            token_index = batch_or_token_index
+        if batch_index < 0:
+            batch_index = self._batch_size + batch_index
+        if token_index < 0:
+            token_index = self._seq_len + token_index
+        return self._encodings[batch_index].token_to_sequence(token_index)
+
+    def token_to_word(self, batch_or_token_index: int, token_index: Optional[int] = None) -> int:
+        """
+        Get the index of the word corresponding (i.e. comprising) to an encoded token in a sequence of the batch.
+
+        Can be called as:
+
+        - `self.token_to_word(token_index)` if batch size is 1
+        - `self.token_to_word(batch_index, token_index)` if batch size is greater than 1
+
+        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e.,
+        words are defined by the user). In this case it makes it easy to associate encoded tokens with the provided
+        tokenized words.
+
+        Args:
+            batch_or_token_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the token in the sequence.
+            token_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
+                sequence.
+
+        Returns:
+            `int`: Index of the word in the input sequence.
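+
+        Example (illustrative; assumes a fast tokenizer; the mapping shown depends on how the text is split):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer("Hello world")
+        encoding.token_to_word(1)  # e.g. 0 -- the first non-special token usually belongs to the first word
+        ```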
+        """
+
+        if not self._encodings:
+            raise ValueError("token_to_word() is not available when using Python based tokenizers")
+        if token_index is not None:
+            batch_index = batch_or_token_index
+        else:
+            batch_index = 0
+            token_index = batch_or_token_index
+        if batch_index < 0:
+            batch_index = self._batch_size + batch_index
+        if token_index < 0:
+            token_index = self._seq_len + token_index
+        return self._encodings[batch_index].token_to_word(token_index)
+
+    def word_to_tokens(
+        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
+    ) -> Optional[TokenSpan]:
+        """
+        Get the encoded token span corresponding to a word in a sequence of the batch.
+
+        Token spans are returned as a [`~tokenization_utils_base.TokenSpan`] with:
+
+        - **start** -- Index of the first token.
+        - **end** -- Index of the token following the last token.
+
+        Can be called as:
+
+        - `self.word_to_tokens(word_index, sequence_index: int = 0)` if batch size is 1
+        - `self.word_to_tokens(batch_index, word_index, sequence_index: int = 0)` if batch size is greater or equal to
+          1
+
+        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+        are defined by the user). In this case it makes it easy to associate encoded tokens with the provided
+        tokenized words.
+
+        Args:
+            batch_or_word_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the word in the sequence.
+            word_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_word_index*, this can be the index of the word in the
+                sequence.
+            sequence_index (`int`, *optional*, defaults to 0):
+                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
+                (0 or 1) the provided word index belongs to.
+
+        Returns:
+            ([`~tokenization_utils_base.TokenSpan`], *optional*): Span of tokens in the encoded sequence. Returns
+            `None` if no tokens correspond to the word. This can happen especially when the token is a special token
+            that has been used to format the tokenization. For example when we add a class token at the very beginning
+            of the tokenization.
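+
+        Example (illustrative; assumes a fast tokenizer and pre-tokenized input; span boundaries depend on the
+        vocabulary):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer(["huggingface", "tokenizers"], is_split_into_words=True)
+        span = encoding.word_to_tokens(1)
+        # span is a TokenSpan(start, end); encoding.tokens()[span.start : span.end] are the sub-tokens of word 1
+        ```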
+        """
+
+        if not self._encodings:
+            raise ValueError("word_to_tokens() is not available when using Python based tokenizers")
+        if word_index is not None:
+            batch_index = batch_or_word_index
+        else:
+            batch_index = 0
+            word_index = batch_or_word_index
+        if batch_index < 0:
+            batch_index = self._batch_size + batch_index
+        if word_index < 0:
+            word_index = self._seq_len + word_index
+        span = self._encodings[batch_index].word_to_tokens(word_index, sequence_index)
+        return TokenSpan(*span) if span is not None else None
+
+    def token_to_chars(self, batch_or_token_index: int, token_index: Optional[int] = None) -> Optional[CharSpan]:
+        """
+        Get the character span corresponding to an encoded token in a sequence of the batch.
+
+        Character spans are returned as a [`~tokenization_utils_base.CharSpan`] with:
+
+        - **start** -- Index of the first character in the original string associated to the token.
+        - **end** -- Index of the character following the last character in the original string associated to the
+          token.
+
+        Can be called as:
+
+        - `self.token_to_chars(token_index)` if batch size is 1
+        - `self.token_to_chars(batch_index, token_index)` if batch size is greater or equal to 1
+
+        Args:
+            batch_or_token_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the token in the sequence.
+            token_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_token_index*, this can be the index of the token in the
+                sequence.
+
+        Returns:
+            [`~tokenization_utils_base.CharSpan`]: Span of characters in the original string, or `None` if the token
+            (e.g. `<s>`, `</s>`) doesn't correspond to any characters in the original string.
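+
+        Example (illustrative; assumes a fast tokenizer; the exact span depends on the tokenization):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        text = "Hello world"
+        encoding = tokenizer(text)
+        span = encoding.token_to_chars(1)
+        text[span.start : span.end]  # e.g. 'Hello' -- the characters covered by the second token
+        ```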
+        """
+
+        if not self._encodings:
+            raise ValueError("token_to_chars() is not available when using Python based tokenizers")
+        if token_index is not None:
+            batch_index = batch_or_token_index
+        else:
+            batch_index = 0
+            token_index = batch_or_token_index
+        span_indices = self._encodings[batch_index].token_to_chars(token_index)
+
+        return CharSpan(*span_indices) if span_indices is not None else None
+
+    def char_to_token(
+        self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0
+    ) -> int:
+        """
+        Get the index of the token in the encoded output comprising a character in the original string for a sequence
+        of the batch.
+
+        Can be called as:
+
+        - `self.char_to_token(char_index)` if batch size is 1
+        - `self.char_to_token(batch_index, char_index)` if batch size is greater or equal to 1
+
+        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+        are defined by the user). In this case it makes it easy to associate encoded tokens with the provided
+        tokenized words.
+
+        Args:
+            batch_or_char_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the character in the original string.
+            char_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_char_index*, this can be the index of the character in the
+                original string.
+            sequence_index (`int`, *optional*, defaults to 0):
+                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
+                (0 or 1) the provided character index belongs to.
+
+
+        Returns:
+            `int`: Index of the token, or None if the char index refers to a whitespace only token and whitespace is
+                   trimmed with `trim_offsets=True`.
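+
+        Example (illustrative; assumes a fast tokenizer; the returned index depends on the tokenization):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer("Hello world")
+        encoding.char_to_token(6)  # e.g. 2 -- the token covering the character 'w' at position 6
+        ```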
+        """
+
+        if not self._encodings:
+            raise ValueError("char_to_token() is not available when using Python based tokenizers")
+        if char_index is not None:
+            batch_index = batch_or_char_index
+        else:
+            batch_index = 0
+            char_index = batch_or_char_index
+        return self._encodings[batch_index].char_to_token(char_index, sequence_index)
+
+    def word_to_chars(
+        self, batch_or_word_index: int, word_index: Optional[int] = None, sequence_index: int = 0
+    ) -> CharSpan:
+        """
+        Get the character span in the original string corresponding to given word in a sequence of the batch.
+
+        Character spans are returned as a CharSpan NamedTuple with:
+
+        - start: index of the first character in the original string
+        - end: index of the character following the last character in the original string
+
+        Can be called as:
+
+        - `self.word_to_chars(word_index)` if batch size is 1
+        - `self.word_to_chars(batch_index, word_index)` if batch size is greater or equal to 1
+
+        Args:
+            batch_or_word_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the word in the sequence.
+            word_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_word_index*, this can be the index of the word in the
+                sequence.
+            sequence_index (`int`, *optional*, defaults to 0):
+                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
+                (0 or 1) the provided word index belongs to.
+
+        Returns:
+            `CharSpan` or `list[CharSpan]`: Span(s) of the associated character or characters in the string. CharSpan
+            are NamedTuple with:
+
+                - start: index of the first character associated to the token in the original string
+                - end: index of the character following the last character associated to the token in the original
+                  string
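+
+        Example (illustrative; assumes a fast tokenizer; the character span depends on how the text is split into
+        words):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        text = "Hello world"
+        encoding = tokenizer(text)
+        span = encoding.word_to_chars(1)
+        text[span.start : span.end]  # e.g. 'world'
+        ```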
+        """
+
+        if not self._encodings:
+            raise ValueError("word_to_chars() is not available when using Python based tokenizers")
+        if word_index is not None:
+            batch_index = batch_or_word_index
+        else:
+            batch_index = 0
+            word_index = batch_or_word_index
+        return CharSpan(*(self._encodings[batch_index].word_to_chars(word_index, sequence_index)))
+
+    def char_to_word(self, batch_or_char_index: int, char_index: Optional[int] = None, sequence_index: int = 0) -> int:
+        """
+        Get the word in the original string corresponding to a character in the original string of a sequence of the
+        batch.
+
+        Can be called as:
+
+        - `self.char_to_word(char_index)` if batch size is 1
+        - `self.char_to_word(batch_index, char_index)` if batch size is greater than 1
+
+        This method is particularly suited when the input sequences are provided as pre-tokenized sequences (i.e. words
+        are defined by the user). In this case it makes it easy to associate encoded tokens with the provided
+        tokenized words.
+
+        Args:
+            batch_or_char_index (`int`):
+                Index of the sequence in the batch. If the batch only comprises one sequence, this can be the index of
+                the character in the original string.
+            char_index (`int`, *optional*):
+                If a batch index is provided in *batch_or_char_index*, this can be the index of the character in the
+                original string.
+            sequence_index (`int`, *optional*, defaults to 0):
+                If a pair of sequences is encoded in the batch, this can be used to specify which sequence in the pair
+                (0 or 1) the provided character index belongs to.
+
+
+        Returns:
+            `int` or `list[int]`: Index or indices of the associated word(s) in the input sequence.
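+
+        Example (illustrative; assumes a fast tokenizer; word indices depend on how the text is split into words):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer("Hello world")
+        encoding.char_to_word(6)  # e.g. 1 -- the character at position 6 belongs to the second word
+        ```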
+        """
+
+        if not self._encodings:
+            raise ValueError("char_to_word() is not available when using Python based tokenizers")
+        if char_index is not None:
+            batch_index = batch_or_char_index
+        else:
+            batch_index = 0
+            char_index = batch_or_char_index
+        return self._encodings[batch_index].char_to_word(char_index, sequence_index)
+
+    def convert_to_tensors(
+        self, tensor_type: Optional[Union[str, TensorType]] = None, prepend_batch_axis: bool = False
+    ):
+        """
+        Convert the inner content to tensors.
+
+        Args:
+            tensor_type (`str` or [`~utils.TensorType`], *optional*):
+                The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
+                `None`, no modification is done.
+            prepend_batch_axis (`bool`, *optional*, defaults to `False`):
+                Whether or not to add the batch dimension during the conversion.
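+
+        Example (illustrative; assumes PyTorch is installed; `padding=True` keeps the batched lists rectangular so
+        they can be stacked into tensors):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer(["a short sentence", "a slightly longer sentence"], padding=True)
+        encoding = encoding.convert_to_tensors("pt")  # values such as `input_ids` become `torch.Tensor`s
+        ```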
+        """
+        if tensor_type is None:
+            return self
+
+        # Convert to TensorType
+        if not isinstance(tensor_type, TensorType):
+            tensor_type = TensorType(tensor_type)
+
+        # Get a function reference for the correct framework
+        if tensor_type == TensorType.TENSORFLOW:
+            if not is_tf_available():
+                raise ImportError(
+                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
+                )
+            import tensorflow as tf
+
+            as_tensor = tf.constant
+            is_tensor = tf.is_tensor
+        elif tensor_type == TensorType.PYTORCH:
+            if not is_torch_available():
+                raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
+            import torch
+
+            is_tensor = torch.is_tensor
+
+            def as_tensor(value, dtype=None):
+                if isinstance(value, list) and isinstance(value[0], np.ndarray):
+                    return torch.from_numpy(np.array(value))
+                return torch.tensor(value)
+
+        elif tensor_type == TensorType.JAX:
+            if not is_flax_available():
+                raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
+            import jax.numpy as jnp  # noqa: F811
+
+            as_tensor = jnp.array
+            is_tensor = is_jax_tensor
+
+        elif tensor_type == TensorType.MLX:
+            if not is_mlx_available():
+                raise ImportError("Unable to convert output to MLX tensors format, MLX is not installed.")
+            import mlx.core as mx
+
+            as_tensor = mx.array
+
+            def is_tensor(obj):
+                return isinstance(obj, mx.array)
+        else:
+
+            def as_tensor(value, dtype=None):
+                if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+                    value_lens = [len(val) for val in value]
+                    if len(set(value_lens)) > 1 and dtype is None:
+                        # we have a ragged list so handle explicitly
+                        value = as_tensor([np.asarray(val) for val in value], dtype=object)
+                return np.asarray(value, dtype=dtype)
+
+            is_tensor = is_numpy_array
+
+        # Do the tensor conversion in batch
+        for key, value in self.items():
+            try:
+                if prepend_batch_axis:
+                    value = [value]
+
+                if not is_tensor(value):
+                    tensor = as_tensor(value)
+
+                    # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
+                    # # at-least2d
+                    # if tensor.ndim > 2:
+                    #     tensor = tensor.squeeze(0)
+                    # elif tensor.ndim < 2:
+                    #     tensor = tensor[None, :]
+
+                    self[key] = tensor
+            except Exception as e:
+                if key == "overflowing_tokens":
+                    raise ValueError(
+                        "Unable to create tensor returning overflowing tokens of different lengths. "
+                        "Please see if a fast version of this tokenizer is available to have this feature available."
+                    ) from e
+                raise ValueError(
+                    "Unable to create tensor, you should probably activate truncation and/or padding with"
+                    " 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your"
+                    f" features (`{key}` in this case) have excessive nesting (inputs type `list` where type `int` is"
+                    " expected)."
+                ) from e
+
+        return self
+
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
+        """
+        Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
+
+        Args:
+            device (`str` or `torch.device`): The device to put the tensors on.
+            non_blocking (`bool`): Whether to perform the copy asynchronously.
+
+        Returns:
+            [`BatchEncoding`]: The same instance after modification.
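+
+        Example (illustrative; assumes PyTorch with a CUDA device is available):
+
+        ```python
+        from transformers import AutoTokenizer
+
+        tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+        encoding = tokenizer("Hello world", return_tensors="pt")
+        encoding = encoding.to("cuda")  # tensors are moved; non-tensor values are left untouched
+        ```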
+        """
+        requires_backends(self, ["torch"])
+
+        # This check catches things like APEX blindly calling "to" on all inputs to a module
+        # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
+        # into a HalfTensor
+        if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
+            self.data = {
+                k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v
+                for k, v in self.data.items()
+            }
+        else:
+            logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
+        return self
+
+
+class SpecialTokensMixin:
+    """
+    A mixin derived by [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`] to handle specific behaviors related to
+    special tokens. In particular, this class holds the attributes which can be used to directly access these special
+    tokens in a model-independent manner and allows setting and updating the special tokens.
+
+    Args:
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the beginning of a sentence.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the end of a sentence.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing an out-of-vocabulary token.
+        sep_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token separating two different sentences in the same input (used by BERT for instance).
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
+            attention mechanisms or loss computation.
+        cls_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the class of the input (used by BERT for instance).
+        mask_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing a masked token (used by masked-language modeling pretraining objectives, like
+            BERT).
+        additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+            A tuple or a list of additional tokens, which will be marked as `special`, meaning that they will be
+            skipped when decoding if `skip_special_tokens` is set to `True`.
+    """
+
+    SPECIAL_TOKENS_ATTRIBUTES = [
+        "bos_token",
+        "eos_token",
+        "unk_token",
+        "sep_token",
+        "pad_token",
+        "cls_token",
+        "mask_token",
+        "additional_special_tokens",
+    ]
+
+    def __init__(self, verbose=False, **kwargs):
+        self._pad_token_type_id = 0
+        self.verbose = verbose
+        self._special_tokens_map = dict.fromkeys(self.SPECIAL_TOKENS_ATTRIBUTES)
+        self._special_tokens_map["additional_special_tokens"] = []  # for BC where it defaults to empty list
+
+        # We directly set the hidden value to allow initialization with special tokens
+        # which are not yet in the vocabulary. Necessary for serialization/de-serialization
+        # TODO clean this up at some point (probably by switching to fast tokenizers)
+
+        for key, value in kwargs.items():
+            if value is None:
+                continue
+            if key in self.SPECIAL_TOKENS_ATTRIBUTES:
+                if key == "additional_special_tokens":
+                    assert isinstance(value, (list, tuple)), f"Value {value} is not a list or tuple"
+                    assert all(isinstance(t, (str, AddedToken)) for t in value), (
+                        "One of the tokens is not a string or an AddedToken"
+                    )
+                    setattr(self, key, value)
+                elif isinstance(value, (str, AddedToken)):
+                    setattr(self, key, value)
+                else:
+                    raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")
+
+    def sanitize_special_tokens(self) -> int:
+        """
+        The `sanitize_special_tokens` method is now deprecated, kept only for backward compatibility, and will be
+        removed in transformers v5.
+        """
+        logger.warning_once("The `sanitize_special_tokens` will be removed in transformers v5.")
+        return self.add_tokens(self.all_special_tokens_extended, special_tokens=True)
+
+    def add_special_tokens(
+        self,
+        special_tokens_dict: dict[str, Union[str, AddedToken, Sequence[Union[str, AddedToken]]]],
+        replace_additional_special_tokens=True,
+    ) -> int:
+        """
+        Add a dictionary of special tokens (eos, pad, cls, etc.) to the encoder and link them to class attributes. If
+        special tokens are NOT in the vocabulary, they are added to it (indexed starting from the last index of the
+        current vocabulary).
+
+        When adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix of the
+        model so that its embedding matrix matches the tokenizer.
+
+        In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
+
+        Using `add_special_tokens` will ensure your special tokens can be used in several ways:
+
+        - Special tokens can be skipped when decoding using `skip_special_tokens = True`.
+        - Special tokens are carefully handled by the tokenizer (they are never split), similar to `AddedTokens`.
+        - You can easily refer to special tokens using tokenizer class attributes like `tokenizer.cls_token`. This
+          makes it easy to develop model-agnostic training and fine-tuning scripts.
+
+        When possible, special tokens are already registered for provided pretrained models (for instance
+        [`BertTokenizer`] `cls_token` is already registered to be `'[CLS]'` and XLM's one is also registered to be
+        `'</s>'`).
+
+        Args:
+            special_tokens_dict (dictionary *str* to *str*, `tokenizers.AddedToken`, or `Sequence[Union[str, AddedToken]]`):
+                Keys should be in the list of predefined special attributes: [`bos_token`, `eos_token`, `unk_token`,
+                `sep_token`, `pad_token`, `cls_token`, `mask_token`, `additional_special_tokens`].
+
+                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer
+                assigns the index of the `unk_token` to them).
+            replace_additional_special_tokens (`bool`, *optional*, defaults to `True`):
+                If `True`, the existing list of additional special tokens will be replaced by the list provided in
+                `special_tokens_dict`. Otherwise, `self._special_tokens_map["additional_special_tokens"]` is just extended. In the former
+                case, the tokens will NOT be removed from the tokenizer's full vocabulary - they are only being flagged
+                as non-special tokens. Remember, this only affects which tokens are skipped during decoding, not the
+                `added_tokens_encoder` and `added_tokens_decoder`. This means that the previous
+                `additional_special_tokens` are still added tokens, and will not be split by the model.
+
+        Returns:
+            `int`: Number of tokens added to the vocabulary.
+
+        Examples:
+
+        ```python
+        # Let's see how to add a new classification token to GPT-2
+        tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
+        model = GPT2Model.from_pretrained("openai-community/gpt2")
+
+        special_tokens_dict = {"cls_token": ""}
+
+        num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
+        print("We have added", num_added_toks, "tokens")
+        # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
+        model.resize_token_embeddings(len(tokenizer))
+
+        assert tokenizer.cls_token == "<CLS>"
+        ```"""
+        if not special_tokens_dict:
+            return 0
+
+        added_tokens = []
+        for key, value in special_tokens_dict.items():
+            assert key in self.SPECIAL_TOKENS_ATTRIBUTES, f"Key {key} is not a special token"
+
+            if self.verbose:
+                logger.info(f"Assigning {value} to the {key} key of the tokenizer")
+
+            if key == "additional_special_tokens":
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, (str, AddedToken)) for t in value), (
+                    f"Tokens {value} for key {key} should all be str or AddedToken instances"
+                )
+
+                to_add = []
+                for token in value:
+                    if isinstance(token, str):
+                        # for legacy purpose we default to stripping. `test_add_tokens_tokenizer` depends on this
+                        token = AddedToken(token, rstrip=False, lstrip=False, normalized=False, special=True)
+                    if not replace_additional_special_tokens and str(token) in self.additional_special_tokens:
+                        continue
+                    to_add.append(token)
+                if replace_additional_special_tokens and len(to_add) > 0:
+                    setattr(self, key, list(to_add))
+                else:
+                    self._special_tokens_map["additional_special_tokens"].extend(to_add)
+                added_tokens += to_add
+
+            else:
+                if not isinstance(value, (str, AddedToken)):
+                    raise ValueError(f"Token {value} for key {key} should be a str or an AddedToken instance")
+                if isinstance(value, (str)):
+                    # for legacy purpose we default to stripping. `False` depends on this
+                    value = AddedToken(value, rstrip=False, lstrip=False, normalized=False, special=True)
+                if isinstance(value, AddedToken):
+                    setattr(self, key, value)
+                if value not in added_tokens:
+                    added_tokens.append(value)
+
+        # if we are adding tokens that were not part of the vocab, we ought to add them
+        added_tokens = self.add_tokens(added_tokens, special_tokens=True)
+        return added_tokens
+
+    def add_tokens(
+        self, new_tokens: Union[str, AddedToken, Sequence[Union[str, AddedToken]]], special_tokens: bool = False
+    ) -> int:
+        """
+        Add a list of new tokens to the tokenizer class. If the new tokens are not in the vocabulary, they are added to
+        it with indices starting from the length of the current vocabulary and will be isolated before the tokenization
+        algorithm is applied. Added tokens and tokens from the vocabulary of the tokenization algorithm are therefore
+        not treated in the same way.
+
+        Note, when adding new tokens to the vocabulary, you should make sure to also resize the token embedding matrix
+        of the model so that its embedding matrix matches the tokenizer.
+
+        In order to do that, please use the [`~PreTrainedModel.resize_token_embeddings`] method.
+
+        Args:
+            new_tokens (`str`, `tokenizers.AddedToken` or a sequence of *str* or `tokenizers.AddedToken`):
+                Tokens are only added if they are not already in the vocabulary. `tokenizers.AddedToken` wraps a string
+                token to let you personalize its behavior: whether this token should only match against a single word,
+                whether this token should strip all potential whitespaces on the left side, whether this token should
+                strip all potential whitespaces on the right side, etc.
+            special_tokens (`bool`, *optional*, defaults to `False`):
+                Can be used to specify if the token is a special token. This mostly changes the normalization behavior
+                (special tokens like CLS or [MASK] are usually not lower-cased for instance).
+
+                See details for `tokenizers.AddedToken` in HuggingFace tokenizers library.
+
+        Returns:
+            `int`: Number of tokens added to the vocabulary.
+
+        Examples:
+
+        ```python
+        # Let's see how to increase the vocabulary of Bert model and tokenizer
+        tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
+        model = BertModel.from_pretrained("google-bert/bert-base-uncased")
+
+        num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
+        print("We have added", num_added_toks, "tokens")
+        # Notice: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
+        model.resize_token_embeddings(len(tokenizer))
+        ```"""
+        if not new_tokens:
+            return 0
+
+        if not isinstance(new_tokens, (list, tuple)):
+            new_tokens = [new_tokens]
+
+        return self._add_tokens(new_tokens, special_tokens=special_tokens)
+
+    def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int:
+        raise NotImplementedError
+
+    @property
+    def pad_token_type_id(self) -> int:
+        """
+        `int`: Id of the padding token type in the vocabulary.
+        """
+        return self._pad_token_type_id
+
+    def __setattr__(self, key, value):
+        key_without_id = key
+        key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+        if key_is_special_id:
+            key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+        if self.__dict__.get("_special_tokens_map", None) is not None and any(
+            name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+        ):
+            if key_is_special_id:
+                if value is not None:
+                    value = (
+                        self.convert_ids_to_tokens(value)
+                        if key != "additional_special_tokens"
+                        else [self.convert_ids_to_tokens(val) for val in value]
+                    )
+                key = key_without_id
+
+            if key != "additional_special_tokens" and not isinstance(value, (str, AddedToken)) and value is not None:
+                raise ValueError(f"Cannot set a non-string value as the {key}")
+            self._special_tokens_map[key] = value
+        else:
+            super().__setattr__(key, value)
+
+    def __getattr__(self, key):
+        key_without_id = key
+        key_is_special_id = key.endswith("_id") or key.endswith("_ids")
+        if key_is_special_id:
+            key_without_id = key[:-3] if not key.endswith("_ids") else key[:-4]
+
+        if self.__dict__.get("_special_tokens_map", None) is not None and any(
+            name in self.__dict__["_special_tokens_map"] for name in [key, key_without_id]
+        ):
+            _special_tokens_map = self.__dict__["_special_tokens_map"]
+            if not key_is_special_id:
+                if _special_tokens_map[key] is None:
+                    if self.verbose:
+                        logger.error(f"Using {key}, but it is not set yet.")
+                    return None
+                value = _special_tokens_map[key]
+                return str(value) if key != "additional_special_tokens" else [str(tok) for tok in value]
+            else:
+                attr_as_tokens = getattr(self, key_without_id)
+                return self.convert_tokens_to_ids(attr_as_tokens) if attr_as_tokens is not None else None
+
+        if key not in self.__dict__:
+            raise AttributeError(f"{self.__class__.__name__} has no attribute {key}")
+        else:
+            return super().__getattr__(key)
+
+    @property
+    def special_tokens_map(self) -> dict[str, Union[str, list[str]]]:
+        """
+        `dict[str, Union[str, list[str]]]`: A dictionary mapping special token class attributes (`cls_token`,
+        `unk_token`, etc.) to their values (`'<unk>'`, `'<cls>'`, etc.).
+
+        Convert potential tokens of `tokenizers.AddedToken` type to string.
+        """
+        set_attr = {}
+        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+            attr_value = getattr(self, attr)
+            if attr_value:
+                set_attr[attr] = attr_value
+        return set_attr
+
+    @property
+    def special_tokens_map_extended(self) -> dict[str, Union[str, AddedToken, list[Union[str, AddedToken]]]]:
+        """
+        `dict[str, Union[str, tokenizers.AddedToken, list[Union[str, tokenizers.AddedToken]]]]`: A dictionary mapping
+        special token class attributes (`cls_token`, `unk_token`, etc.) to their values (`'<unk>'`, `'<cls>'`, etc.).
+
+        Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
+        special tokens are tokenized.
+        """
+        set_attr = {}
+        for attr in self.SPECIAL_TOKENS_ATTRIBUTES:
+            attr_value = self._special_tokens_map[attr]
+            if attr_value:
+                set_attr[attr] = attr_value
+        return set_attr
+
+    @property
+    def all_special_tokens_extended(self) -> list[Union[str, AddedToken]]:
+        """
+        `list[Union[str, tokenizers.AddedToken]]`: All the special tokens (`'<unk>'`, `'<cls>'`, etc.), the order has
+        nothing to do with the index of each token. If you want to know the correct indices, check
+        `self.added_tokens_encoder`. We can't create an order anymore as the keys are `AddedTokens` and not `Strings`.
+
+        Don't convert tokens of `tokenizers.AddedToken` type to string so they can be used to control more finely how
+        special tokens are tokenized.
+        """
+        all_tokens = []
+        seen = set()
+        for value in self.special_tokens_map_extended.values():
+            if isinstance(value, (list, tuple)):
+                tokens_to_add = [token for token in value if str(token) not in seen]
+            else:
+                tokens_to_add = [value] if str(value) not in seen else []
+            seen.update(map(str, tokens_to_add))
+            all_tokens.extend(tokens_to_add)
+        return all_tokens
+
+    @property
+    def all_special_tokens(self) -> list[str]:
+        """
+        `list[str]`: A list of the unique special tokens (`'<unk>'`, `'<cls>'`, ..., etc.).
+
+        Convert tokens of `tokenizers.AddedToken` type to string.
+        """
+        all_toks = [str(s) for s in self.all_special_tokens_extended]
+        return all_toks
+
+    @property
+    def all_special_ids(self) -> list[int]:
+        """
+        `list[int]`: A list of the ids of the special tokens (`'<unk>'`, `'<cls>'`, etc.) mapped to class attributes.
+        """
+        all_toks = self.all_special_tokens
+        all_ids = self.convert_tokens_to_ids(all_toks)
+        return all_ids
+
+    def _set_model_specific_special_tokens(self, special_tokens: dict[str, Union[str, AddedToken]]):
+        """
+        Adds new special tokens to the "SPECIAL_TOKENS_ATTRIBUTES" list which will be part
+        of "self.special_tokens" and saved as a special token in tokenizer's config.
+        This allows us to dynamically add new model-type specific tokens after initializing the tokenizer.
+        For example, if the model's tokenizer is multimodal, we can support special image or audio tokens.
+        """
+        self.SPECIAL_TOKENS_ATTRIBUTES = self.SPECIAL_TOKENS_ATTRIBUTES + list(special_tokens.keys())
+        for key, value in special_tokens.items():
+            if isinstance(value, (str, AddedToken)):
+                self._special_tokens_map[key] = value
+            else:
+                raise TypeError(f"Special token {key} has to be either str or AddedToken but got: {type(value)}")
+
+
+ENCODE_KWARGS_DOCSTRING = r"""
+            add_special_tokens (`bool`, *optional*, defaults to `True`):
+                Whether or not to add special tokens when encoding the sequences. This will use the underlying
+                `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
+                automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
+                automatically.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`):
+                Activates and controls truncation. Accepts the following values:
+
+                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will
+                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
+                  sequences (or a batch of pairs) is provided.
+                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
+                  greater than the model maximum admissible input size).
+            max_length (`int`, *optional*):
+                Controls the maximum length to use by one of the truncation/padding parameters.
+
+                If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
+                is required by one of the truncation/padding parameters. If the model has no specific maximum input
+                length (like XLNet) truncation/padding to a maximum length will be deactivated.
+            stride (`int`, *optional*, defaults to 0):
+                If set to a number along with `max_length`, the overflowing tokens returned when
+                `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
+                returned to provide some overlap between truncated and overflowing sequences. The value of this
+                argument defines the number of overlapping tokens.
+            is_split_into_words (`bool`, *optional*, defaults to `False`):
+                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
+                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
+                which it will tokenize. This is useful for NER or token classification.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+"""
+
+ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
+            return_token_type_ids (`bool`, *optional*):
+                Whether to return token type IDs. If left to the default, will return the token type IDs according to
+                the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are token type IDs?](../glossary#token-type-ids)
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are attention masks?](../glossary#attention-mask)
+            return_overflowing_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch
+                of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead
+                of returning overflowing tokens.
+            return_special_tokens_mask (`bool`, *optional*, defaults to `False`):
+                Whether or not to return special tokens mask information.
+            return_offsets_mapping (`bool`, *optional*, defaults to `False`):
+                Whether or not to return `(char_start, char_end)` for each token.
+
+                This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`]; if using a
+                Python-based tokenizer, this method will raise `NotImplementedError`.
+            return_length  (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the lengths of the encoded inputs.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
+            **kwargs: passed to the `self.tokenize()` method
+
+        Return:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to a model.
+
+              [What are input IDs?](../glossary#input-ids)
+
+            - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or
+              if *"token_type_ids"* is in `self.model_input_names`).
+
+              [What are token type IDs?](../glossary#token-type-ids)
+
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+              `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`).
+
+              [What are attention masks?](../glossary#attention-mask)
+
+            - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and
+              `return_overflowing_tokens=True`).
+            - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying
+              regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`).
+            - **length** -- The length of the inputs (when `return_length=True`)
+"""
+
+
+INIT_TOKENIZER_DOCSTRING = r"""
+    Class attributes (overridden by derived classes)
+
+        - **vocab_files_names** (`dict[str, str]`) -- A dictionary with, as keys, the `__init__` keyword name of each
+          vocabulary file required by the model, and as associated values, the filename for saving the associated file
+          (string).
+        - **pretrained_vocab_files_map** (`dict[str, dict[str, str]]`) -- A dictionary of dictionaries, with the
+          high-level keys being the `__init__` keyword name of each vocabulary file required by the model, the
+          low-level being the `short-cut-names` of the pretrained models with, as associated values, the `url` to the
+          associated pretrained vocabulary file.
+        - **model_input_names** (`list[str]`) -- A list of inputs expected in the forward pass of the model.
+        - **padding_side** (`str`) -- The default value for the side on which the model should have padding applied.
+          Should be `'right'` or `'left'`.
+        - **truncation_side** (`str`) -- The default value for the side on which the model should have truncation
+          applied. Should be `'right'` or `'left'`.
+
+    Args:
+        model_max_length (`int`, *optional*):
+            The maximum length (in number of tokens) for the inputs to the transformer model. When the tokenizer is
+            loaded with [`~tokenization_utils_base.PreTrainedTokenizerBase.from_pretrained`], this will be set to the
+            value stored for the associated model in `max_model_input_sizes` (see above). If no value is provided, will
+            default to VERY_LARGE_INTEGER (`int(1e30)`).
+        padding_side (`str`, *optional*):
+            The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+            Default value is picked from the class attribute of the same name.
+        truncation_side (`str`, *optional*):
+            The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
+            Default value is picked from the class attribute of the same name.
+        chat_template (`str`, *optional*):
+            A Jinja template string that will be used to format lists of chat messages. See
+            https://huggingface.co/docs/transformers/chat_templating for a full description.
+        model_input_names (`list[string]`, *optional*):
+            The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
+            `"attention_mask"`). Default value is picked from the class attribute of the same name.
+        bos_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the beginning of a sentence. Will be associated to `self.bos_token` and
+            `self.bos_token_id`.
+        eos_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the end of a sentence. Will be associated to `self.eos_token` and
+            `self.eos_token_id`.
+        unk_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing an out-of-vocabulary token. Will be associated to `self.unk_token` and
+            `self.unk_token_id`.
+        sep_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token separating two different sentences in the same input (used by BERT for instance). Will be
+            associated to `self.sep_token` and `self.sep_token_id`.
+        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
+            attention mechanisms or loss computation. Will be associated to `self.pad_token` and `self.pad_token_id`.
+        cls_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing the class of the input (used by BERT for instance). Will be associated to
+            `self.cls_token` and `self.cls_token_id`.
+        mask_token (`str` or `tokenizers.AddedToken`, *optional*):
+            A special token representing a masked token (used by masked-language modeling pretraining objectives, like
+            BERT). Will be associated to `self.mask_token` and `self.mask_token_id`.
+        additional_special_tokens (tuple or list of `str` or `tokenizers.AddedToken`, *optional*):
+            A tuple or a list of additional special tokens. Add them here to ensure they are skipped when decoding with
+            `skip_special_tokens` is set to True. If they are not part of the vocabulary, they will be added at the end
+            of the vocabulary.
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should cleanup the spaces that were added when splitting the input text during the
+            tokenization process.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. Passing will affect the
+            internal state of the tokenizer. The default behavior is to not split special tokens. This means that if
+            `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") = ['<s>']`. Otherwise, if
+            `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<', 's', '>']`.
+"""
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
+    """
+    Base class for [`PreTrainedTokenizer`] and [`PreTrainedTokenizerFast`].
+
+    Handles shared (mostly boiler plate) methods for those two classes.
+    """
+
+    vocab_files_names: dict[str, str] = {}
+    pretrained_vocab_files_map: dict[str, dict[str, str]] = {}
+    _auto_class: Optional[str] = None
+
+    # first name has to correspond to main model input name
+    # to make sure `tokenizer.pad(...)` works correctly
+    model_input_names: list[str] = ["input_ids", "token_type_ids", "attention_mask"]
+    padding_side: str = "right"
+    truncation_side: str = "right"
+    slow_tokenizer_class = None
+
+    def __init__(self, **kwargs):
+        # inputs and kwargs for saving and re-loading (see ``from_pretrained`` and ``save_pretrained``)
+        self.init_inputs = ()
+        for key in kwargs:
+            if hasattr(self, key) and callable(getattr(self, key)):
+                raise AttributeError(f"{key} conflicts with the method {key} in {self.__class__.__name__}")
+
+        self.init_kwargs = copy.deepcopy(kwargs)
+        self.name_or_path = kwargs.pop("name_or_path", "")
+        self._processor_class = kwargs.pop("processor_class", None)
+
+        # For backward compatibility we fallback to set model_max_length from max_len if provided
+        model_max_length = kwargs.pop("model_max_length", kwargs.pop("max_len", None))
+        self.model_max_length = model_max_length if model_max_length is not None else VERY_LARGE_INTEGER
+
+        # Padding and truncation side are right by default and overridden in subclasses. If specified in the kwargs, it
+        # is changed.
+        self.padding_side = kwargs.pop("padding_side", self.padding_side)
+        if self.padding_side not in ["right", "left"]:
+            raise ValueError(
+                f"Padding side should be selected between 'right' and 'left', current value: {self.padding_side}"
+            )
+
+        self.truncation_side = kwargs.pop("truncation_side", self.truncation_side)
+        if self.truncation_side not in ["right", "left"]:
+            raise ValueError(
+                f"Truncation side should be selected between 'right' and 'left', current value: {self.truncation_side}"
+            )
+
+        self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
+
+        # By default, do not clean up tokenization spaces for either fast or slow tokenizers
+        self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", False)
+
+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
+
+        self.deprecation_warnings = {}  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
+        self._in_target_context_manager = False
+
+        # Stores a Jinja template that formats chat histories into tokenizable strings
+        self.chat_template = kwargs.pop("chat_template", None)
+        if isinstance(self.chat_template, (list, tuple)):
+            # Chat templates are stored as lists of dicts with fixed key names,
+            # we reconstruct that into a single dict while loading them.
+            self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
+
+        super().__init__(**kwargs)
+
+        self.extra_special_tokens = kwargs.pop("extra_special_tokens", {})
+        self._set_model_specific_special_tokens(special_tokens=self.extra_special_tokens)
+
+    @property
+    def max_len_single_sentence(self) -> int:
+        """
+        `int`: The maximum length of a sentence that can be fed to the model.
+        """
+        return self.model_max_length - self.num_special_tokens_to_add(pair=False)
+
+    @property
+    def max_len_sentences_pair(self) -> int:
+        """
+        `int`: The maximum combined length of a pair of sentences that can be fed to the model.
+        """
+        return self.model_max_length - self.num_special_tokens_to_add(pair=True)
+
+    @max_len_single_sentence.setter
+    def max_len_single_sentence(self, value) -> None:
+        # For backward compatibility, allow trying to set 'max_len_single_sentence'.
+        if value == self.model_max_length - self.num_special_tokens_to_add(pair=False) and self.verbose:
+            if not self.deprecation_warnings.get("max_len_single_sentence", False):
+                logger.warning(
+                    "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
+                )
+            self.deprecation_warnings["max_len_single_sentence"] = True
+        else:
+            raise ValueError(
+                "Setting 'max_len_single_sentence' is now deprecated. This value is automatically set up."
+            )
+
+    @max_len_sentences_pair.setter
+    def max_len_sentences_pair(self, value) -> None:
+        # For backward compatibility, allow trying to set 'max_len_sentences_pair'.
+        if value == self.model_max_length - self.num_special_tokens_to_add(pair=True) and self.verbose:
+            if not self.deprecation_warnings.get("max_len_sentences_pair", False):
+                logger.warning(
+                    "Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up."
+                )
+            self.deprecation_warnings["max_len_sentences_pair"] = True
+        else:
+            raise ValueError("Setting 'max_len_sentences_pair' is now deprecated. This value is automatically set up.")
+
+    def _set_processor_class(self, processor_class: str):
+        """Sets processor class as an attribute."""
+        self._processor_class = processor_class
+
+    @property
+    def added_tokens_decoder(self) -> dict[int, AddedToken]:
+        raise NotImplementedError()
+
+    def __repr__(self) -> str:
+        added_tokens_decoder_rep = "\n\t".join([f"{k}: {v.__repr__()}," for k, v in self.added_tokens_decoder.items()])
+        return (
+            f"{self.__class__.__name__}(name_or_path='{self.name_or_path}',"
+            f" vocab_size={self.vocab_size}, model_max_length={self.model_max_length}, is_fast={self.is_fast},"
+            f" padding_side='{self.padding_side}', truncation_side='{self.truncation_side}',"
+            f" special_tokens={self.special_tokens_map}, clean_up_tokenization_spaces={self.clean_up_tokenization_spaces},"
+            " added_tokens_decoder={\n\t" + added_tokens_decoder_rep + "\n}\n)"
+        )
+
+    def __len__(self) -> int:
+        raise NotImplementedError()
+
+    def get_vocab(self) -> dict[str, int]:
+        """
+        Returns the vocabulary as a dictionary of token to index.
+
+        `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the
+        vocab.
+
+        Returns:
+            `dict[str, int]`: The vocabulary.
+        """
+        raise NotImplementedError()
+
+    def apply_chat_template(
+        self,
+        conversation: Union[list[dict[str, str]], list[list[dict[str, str]]]],
+        tools: Optional[list[Union[dict, Callable]]] = None,
+        documents: Optional[list[dict[str, str]]] = None,
+        chat_template: Optional[str] = None,
+        add_generation_prompt: bool = False,
+        continue_final_message: bool = False,
+        tokenize: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_dict: bool = False,
+        return_assistant_tokens_mask: bool = False,
+        tokenizer_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
+    ) -> Union[str, list[int], list[str], list[list[int]], BatchEncoding]:
+        """
+        Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
+        ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
+        determine the format and control tokens to use when converting.
+
+        Args:
+            conversation (Union[list[dict[str, str]], list[list[dict[str, str]]]]): A list of dicts
+                with "role" and "content" keys, representing the chat history so far.
+            tools (`list[Union[Dict, Callable]]`, *optional*):
+                A list of tools (callable functions) that will be accessible to the model. If the template does not
+                support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+                giving the name, description and argument types for the tool. See our
+                [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+                for more information.
+            documents (`list[dict[str, str]]`, *optional*):
+                A list of dicts representing documents that will be accessible to the model if it is performing RAG
+                (retrieval-augmented generation). If the template does not support RAG, this argument will have no
+                effect. We recommend that each document should be a dict containing "title" and "text" keys. Please
+                see the RAG section of the [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#arguments-for-RAG)
+                for examples of passing documents with chat templates.
+            chat_template (`str`, *optional*):
+                A Jinja template to use for this conversion. It is usually not necessary to pass anything to this
+                argument, as the model's template will be used by default.
+            add_generation_prompt (`bool`, *optional*, defaults to `False`):
+                If this is set, a prompt with the token(s) that indicate
+                the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
+                Note that this argument will be passed to the chat template, and so it must be supported in the
+                template for this argument to have any effect.
+            continue_final_message (`bool`, *optional*, defaults to `False`):
+                If this is set, the chat will be formatted so that the final
+                message in the chat is open-ended, without any EOS tokens. The model will continue this message
+                rather than starting a new one. This allows you to "prefill" part of
+                the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
+            tokenize (`bool`, defaults to `True`):
+                Whether to tokenize the output. If `False`, the output will be a string.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
+                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                 index) among:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            truncation (`bool`, defaults to `False`):
+                Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
+            max_length (`int`, *optional*):
+                Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
+                not specified, the tokenizer's `max_length` attribute will be used as a default.
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
+                values are:
+                - `'tf'`: Return TensorFlow `tf.Tensor` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return NumPy `np.ndarray` objects.
+                - `'jax'`: Return JAX `jnp.ndarray` objects.
+            return_dict (`bool`, defaults to `False`):
+                Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
+            tokenizer_kwargs (`dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
+            return_assistant_tokens_mask (`bool`, defaults to `False`):
+                Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
+                the mask will contain 1. For user and system tokens, the mask will contain 0.
+                This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
+            **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
+
+        Returns:
+            `Union[list[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
+            output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
+            set, will return a dict of tokenizer outputs instead.
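+
+        Example (an illustrative sketch; `tokenizer` is assumed to be an already-loaded chat-model tokenizer whose
+        `chat_template` is set; the exact rendered prompt and token ids depend on that template and on the model's
+        vocabulary):
+
+        ```python
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ]
+        # Render the chat to a plain string, appending the assistant header so the model will answer next
+        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        # Or tokenize directly and return framework tensors ready to pass to `generate()`
+        inputs = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
+        )
+        ```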
+        """
+
+        if return_dict and not tokenize:
+            raise ValueError(
+                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
+                "of tokenizer outputs to return."
+            )
+
+        if return_assistant_tokens_mask and not return_dict:
+            raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
+
+        if tokenizer_kwargs is None:
+            tokenizer_kwargs = {}
+
+        chat_template = self.get_chat_template(chat_template, tools)
+
+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
+        ):
+            conversations = conversation
+            is_batched = True
+        else:
+            conversations = [conversation]
+            is_batched = False
+
+        if continue_final_message:
+            if add_generation_prompt:
+                raise ValueError(
+                    "continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."
+                )
+            if return_assistant_tokens_mask:
+                raise ValueError("continue_final_message is not compatible with return_assistant_tokens_mask.")
+
+        template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
+        rendered_chat, generation_indices = render_jinja_template(
+            conversations=conversations,
+            tools=tools,
+            documents=documents,
+            chat_template=chat_template,
+            return_assistant_tokens_mask=return_assistant_tokens_mask,
+            continue_final_message=continue_final_message,
+            add_generation_prompt=add_generation_prompt,
+            **template_kwargs,
+        )
+
+        if not is_batched:
+            rendered_chat = rendered_chat[0]
+
+        if tokenize:
+            out = self(
+                rendered_chat,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                add_special_tokens=False,
+                return_tensors=return_tensors,
+                **tokenizer_kwargs,
+            )
+            if return_dict:
+                if return_assistant_tokens_mask:
+                    assistant_masks = []
+                    if is_batched or return_tensors:
+                        input_ids = out["input_ids"]
+                    else:
+                        input_ids = [out["input_ids"]]
+                    for i in range(len(input_ids)):
+                        current_mask = [0] * len(input_ids[i])
+                        for assistant_start_char, assistant_end_char in generation_indices[i]:
+                            start_token = out.char_to_token(i, assistant_start_char)
+                            end_token = out.char_to_token(i, assistant_end_char - 1)
+                            if start_token is None:
+                                # start_token is out of bounds, possibly due to truncation.
+                                break
+                            for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
+                                current_mask[token_id] = 1
+                        assistant_masks.append(current_mask)
+
+                    if not is_batched and not return_tensors:
+                        assistant_masks = assistant_masks[0]
+
+                    out["assistant_masks"] = assistant_masks
+
+                    if return_tensors:
+                        out.convert_to_tensors(tensor_type=return_tensors)
+
+                return out
+            else:
+                return out["input_ids"]
+        else:
+            return rendered_chat
+
+    def encode_message_with_chat_template(
+        self,
+        message: dict[str, str],
+        conversation_history: Optional[list[dict[str, str]]] = None,
+        **kwargs,
+    ) -> list[int]:
+        """
+        Tokenize a single message. This method is a convenience wrapper around `apply_chat_template` that allows you
+        to tokenize messages one by one. This is useful for things like token-by-token streaming.
+        This method is not guaranteed to be perfect. For some models, it may be impossible to robustly tokenize
+        single messages. For example, if the chat template adds tokens after each message, but also has a prefix that
+        is added to the entire chat, it will be impossible to distinguish a chat-start-token from a message-start-token.
+        In these cases, this method will do its best to find the correct tokenization, but it may not be perfect.
+        **Note:** This method does not support `add_generation_prompt`. If you want to add a generation prompt,
+        you should do it separately after tokenizing the conversation.
+        Args:
+            message (`dict`):
+                A dictionary with "role" and "content" keys, representing the message to tokenize.
+            conversation_history (`list[dict]`, *optional*):
+                A list of dicts with "role" and "content" keys, representing the chat history so far. If you are
+                tokenizing messages one by one, you should pass the previous messages in the conversation here.
+            **kwargs:
+                Additional kwargs to pass to the `apply_chat_template` method.
+        Returns:
+            `list[int]`: A list of token ids representing the tokenized message.
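+
+        Example (an illustrative sketch; `tokenizer` is assumed to be an already-loaded chat-model tokenizer, and the
+        returned token ids depend on its chat template and vocabulary):
+
+        ```python
+        history = [
+            {"role": "user", "content": "Hi!"},
+            {"role": "assistant", "content": "Hello, how can I help?"},
+        ]
+        new_message = {"role": "user", "content": "What is the weather like?"}
+        # Token ids contributed by the new message alone, given the conversation so far
+        new_tokens = tokenizer.encode_message_with_chat_template(new_message, conversation_history=history)
+        ```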
+        """
+        if "add_generation_prompt" in kwargs:
+            raise ValueError(
+                "`encode_message_with_chat_template` does not support `add_generation_prompt`. Please add the generation prompt "
+                "separately."
+            )
+
+        if conversation_history is None or len(conversation_history) == 0:
+            return self.apply_chat_template([message], add_generation_prompt=False, tokenize=True, **kwargs)
+
+        conversation = conversation_history + [message]
+        tokens = self.apply_chat_template(conversation, add_generation_prompt=False, tokenize=True, **kwargs)
+
+        prefix_tokens = self.apply_chat_template(
+            conversation_history, add_generation_prompt=False, tokenize=True, **kwargs
+        )
+        # It's possible that the prefix tokens are not a prefix of the full list of tokens.
+        # For example, if the prefix is `User: Hi` and the full conversation is `User: HiAssistant: Hello`.
+        # In this case, we can't simply find the prefix, so we have to do something a bit more subtle.
+        # We look for the first place where the tokens differ, and that's our split point.
+        # This is not perfect, but it's the best we can do without a token-level API.
+        # To make this more robust, we could do a diff and find the longest common subsequence, but this is
+        # a good first approximation.
+        # This is particularly important for models like Llama3 that have changed their chat template to include
+        # EOS tokens after user messages.
+        min_len = min(len(prefix_tokens), len(tokens))
+        for i in range(min_len):
+            if prefix_tokens[i] != tokens[i]:
+                return tokens[i:]
+        return tokens[min_len:]
+
+    def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[list[dict]] = None) -> str:
+        """
+        Retrieve the chat template string used for tokenizing chat messages. This template is used
+        internally by the `apply_chat_template` method and can also be used externally to retrieve the model's chat
+        template for better generation tracking.
+
+        Args:
+            chat_template (`str`, *optional*):
+                A Jinja template or the name of a template to use for this conversion.
+                It is usually not necessary to pass anything to this argument,
+                as the model's template will be used by default.
+            tools (`list[Dict]`, *optional*):
+                A list of tools (callable functions) that will be accessible to the model. If the template does not
+                support function calling, this argument will have no effect. Each tool should be passed as a JSON Schema,
+                giving the name, description and argument types for the tool. See our
+                [chat templating guide](https://huggingface.co/docs/transformers/main/en/chat_templating#automated-function-conversion-for-tool-use)
+                for more information.
+
+        Returns:
+            `str`: The chat template string.
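+
+        Example (illustrative; assumes the tokenizer has a chat template set, and that a `"tool_use"` entry exists
+        only if the tokenizer stores multiple named templates):
+
+        ```python
+        # The default template (or the single template if only one is defined)
+        template = tokenizer.get_chat_template()
+        # A specific named template, when the tokenizer stores a dict of templates
+        tool_template = tokenizer.get_chat_template(chat_template="tool_use")
+        ```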
+        """
+        # First, handle the cases when the model has a dict of multiple templates
+        if isinstance(self.chat_template, dict):
+            template_dict = self.chat_template
+            if chat_template is not None and chat_template in template_dict:
+                # The user can pass the name of a template to the chat template argument instead of an entire template
+                chat_template = template_dict[chat_template]
+            elif chat_template is None:
+                if tools is not None and "tool_use" in template_dict:
+                    chat_template = template_dict["tool_use"]
+                elif "default" in template_dict:
+                    chat_template = template_dict["default"]
+                else:
+                    raise ValueError(
+                        "This model has multiple chat templates with no default specified! Please either pass a chat "
+                        "template or the name of the template you wish to use to the `chat_template` argument. Available "
+                        f"template names are {sorted(template_dict.keys())}."
+                    )
+
+        elif chat_template is None:
+            # These are the cases when the model has a single template
+            # priority: `chat_template` argument > `tokenizer.chat_template`
+            if self.chat_template is not None:
+                chat_template = self.chat_template
+            else:
+                raise ValueError(
+                    "Cannot use chat template functions because tokenizer.chat_template is not set and no template "
+                    "argument was passed! For information about writing templates and setting the "
+                    "tokenizer.chat_template attribute, please see the documentation at "
+                    "https://huggingface.co/docs/transformers/main/en/chat_templating"
+                )
+
+        return chat_template
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        *init_inputs,
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        trust_remote_code=False,
+        **kwargs,
+    ):
+        r"""
+        Instantiate a [`~tokenization_utils_base.PreTrainedTokenizerBase`] (or a derived class) from a predefined
+        tokenizer.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+                  using the [`~tokenization_utils_base.PreTrainedTokenizerBase.save_pretrained`] method, e.g.,
+                  `./my_model_directory/`.
+                - (**Deprecated**, not applicable to all derived classes) A path or url to a single saved vocabulary
+                  file (if and only if the tokenizer only requires a single vocabulary file like Bert or XLNet), e.g.,
+                  `./my_model_directory/vocab.txt`.
+            cache_dir (`str` or `os.PathLike`, *optional*):
+                Path to a directory in which the downloaded predefined tokenizer vocabulary files should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the vocabulary files and override the cached versions if they
+                exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+                when running `hf auth login` (stored in `~/.huggingface`).
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only rely on local files and not to attempt to download any files.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id. Since we use a
+                git-based system for storing models and other artifacts on huggingface.co, `revision` can be any
+                identifier allowed by git.
+            subfolder (`str`, *optional*):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for
+                facebook/rag-token-base), specify it here.
+            inputs (additional positional arguments, *optional*):
+                Will be passed along to the Tokenizer `__init__` method.
+            trust_remote_code (`bool`, *optional*, defaults to `False`):
+                Whether or not to allow for custom models defined on the Hub in their own modeling files. This option
+                should only be set to `True` for repositories you trust and in which you have read the code, as it will
+                execute code present on the Hub on your local machine.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the Tokenizer `__init__` method. Can be used to set special tokens like `bos_token`,
+                `eos_token`, `unk_token`, `sep_token`, `pad_token`, `cls_token`, `mask_token`,
+                `additional_special_tokens`. See parameters in the `__init__` for more details.
+
+        
+
+        Passing `token=True` is required when you want to use a private model.
+
+        
+
+        Examples:
+
+        ```python
+        # We can't directly instantiate the base class *PreTrainedTokenizerBase*, so let's show our examples on a derived class: BertTokenizer
+        # Download vocabulary from huggingface.co and cache.
+        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+
+        # Download vocabulary from huggingface.co (user-uploaded) and cache.
+        tokenizer = BertTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
+
+        # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/saved_model/')*)
+        tokenizer = BertTokenizer.from_pretrained("./test/saved_model/")
+
+        # If the tokenizer uses a single vocabulary file, you can point directly to this file
+        tokenizer = BertTokenizer.from_pretrained("./test/saved_model/my_vocab.txt")
+
+        # You can link tokens to special vocabulary when instantiating
+        tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased", unk_token="<unk>")
+        # You should be sure '<unk>' is in the vocabulary when doing that.
+        # Otherwise use tokenizer.add_special_tokens({'unk_token': '<unk>'}) instead.
+        assert tokenizer.unk_token == "<unk>"
+        ```"""
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        subfolder = kwargs.pop("subfolder", None)
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+        commit_hash = kwargs.pop("_commit_hash", None)
+        gguf_file = kwargs.get("gguf_file")
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        user_agent = {"file_type": "tokenizer", "from_auto_class": from_auto_class, "is_fast": "Fast" in cls.__name__}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        vocab_files = {}
+        init_configuration = {}
+
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        single_file_id = None
+        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+            if len(cls.vocab_files_names) > 1 and not gguf_file:
+                raise ValueError(
+                    f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not "
+                    "supported for this tokenizer. Use a model identifier or the path to a directory instead."
+                )
+            warnings.warn(
+                f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated and "
+                "won't be possible anymore in v5. Use a model identifier or the path to a directory instead.",
+                FutureWarning,
+            )
+            file_id = list(cls.vocab_files_names.keys())[0]
+
+            vocab_files[file_id] = pretrained_model_name_or_path
+            single_file_id = file_id
+        else:
+            if gguf_file:
+                vocab_files["vocab_file"] = gguf_file
+            else:
+                # At this point pretrained_model_name_or_path is either a directory or a model identifier name
+                additional_files_names = {
+                    "added_tokens_file": ADDED_TOKENS_FILE,  # kept only for legacy
+                    "special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE,  # kept only for legacy
+                    "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
+                    # tokenizer_file is used to initialize a slow tokenizer from a fast one, so that the `addedTokens` are copied properly instead of being added in a random order
+                    "tokenizer_file": FULL_TOKENIZER_FILE,
+                    "chat_template_file": CHAT_TEMPLATE_FILE,
+                }
+
+                vocab_files = {**cls.vocab_files_names, **additional_files_names}
+                if "tokenizer_file" in vocab_files:
+                    # Try to get the tokenizer config to see if there are versioned tokenizer files.
+                    fast_tokenizer_file = FULL_TOKENIZER_FILE
+
+                    try:
+                        resolved_config_file = cached_file(
+                            pretrained_model_name_or_path,
+                            TOKENIZER_CONFIG_FILE,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            token=token,
+                            revision=revision,
+                            local_files_only=local_files_only,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            _raise_exceptions_for_missing_entries=False,
+                            _commit_hash=commit_hash,
+                        )
+                    except OSError:
+                        # Re-raise any error raised by cached_file in order to get a helpful error message
+                        raise
+                    except Exception:
+                        # For any other exception, we throw a generic error.
+                        raise OSError(
+                            f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                            "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                            f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                            f"containing all relevant files for a {cls.__name__} tokenizer."
+                        )
+
+                    commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+                    if resolved_config_file is not None:
+                        with open(resolved_config_file, encoding="utf-8") as reader:
+                            tokenizer_config = json.load(reader)
+                            if "fast_tokenizer_files" in tokenizer_config:
+                                fast_tokenizer_file = get_fast_tokenizer_file(tokenizer_config["fast_tokenizer_files"])
+                    vocab_files["tokenizer_file"] = fast_tokenizer_file
+
+                    # This block looks for any extra chat template files
+                    if is_local:
+                        template_dir = Path(pretrained_model_name_or_path, CHAT_TEMPLATE_DIR)
+                        if template_dir.is_dir():
+                            for template_file in template_dir.glob("*.jinja"):
+                                template_name = template_file.name.removesuffix(".jinja")
+                                vocab_files[f"chat_template_{template_name}"] = (
+                                    f"{CHAT_TEMPLATE_DIR}/{template_file.name}"
+                                )
+                    else:
+                        for template in list_repo_templates(
+                            pretrained_model_name_or_path,
+                            local_files_only=local_files_only,
+                            revision=revision,
+                            cache_dir=cache_dir,
+                        ):
+                            template = template.removesuffix(".jinja")
+                            vocab_files[f"chat_template_{template}"] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
+
+        # Get files from url, cache, or disk depending on the case
+        resolved_vocab_files = {}
+        for file_id, file_path in vocab_files.items():
+            if file_path is None:
+                resolved_vocab_files[file_id] = None
+            elif single_file_id == file_id:
+                if os.path.isfile(file_path):
+                    resolved_vocab_files[file_id] = file_path
+                elif is_remote_url(file_path):
+                    resolved_vocab_files[file_id] = download_url(file_path, proxies=proxies)
+            else:
+                try:
+                    resolved_vocab_files[file_id] = cached_file(
+                        pretrained_model_name_or_path,
+                        file_path,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        proxies=proxies,
+                        resume_download=resume_download,
+                        local_files_only=local_files_only,
+                        token=token,
+                        user_agent=user_agent,
+                        revision=revision,
+                        subfolder=subfolder,
+                        _raise_exceptions_for_missing_entries=False,
+                        _commit_hash=commit_hash,
+                    )
+                except OSError:
+                    # Re-raise any error raised by cached_file in order to get a helpful error message
+                    raise
+                except Exception:
+                    # For any other exception, we throw a generic error.
+                    raise OSError(
+                        f"Can't load tokenizer for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                        "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                        f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                        f"containing all relevant files for a {cls.__name__} tokenizer."
+                    )
+                commit_hash = extract_commit_hash(resolved_vocab_files[file_id], commit_hash)
+
+        for file_id, file_path in vocab_files.items():
+            if file_id not in resolved_vocab_files:
+                continue
+
+            if is_local:
+                logger.info(f"loading file {file_path}")
+            else:
+                logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}")
+
+        return cls._from_pretrained(
+            resolved_vocab_files,
+            pretrained_model_name_or_path,
+            init_configuration,
+            *init_inputs,
+            token=token,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+            _commit_hash=commit_hash,
+            _is_local=is_local,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        resolved_vocab_files,
+        pretrained_model_name_or_path,
+        init_configuration,
+        *init_inputs,
+        token=None,
+        cache_dir=None,
+        local_files_only=False,
+        _commit_hash=None,
+        _is_local=False,
+        trust_remote_code=False,
+        **kwargs,
+    ):
+        # We instantiate fast tokenizers based on a slow tokenizer if we don't have access to the tokenizer.json
+        # file or if `from_slow` is set to True.
+        from_slow = kwargs.get("from_slow", False)
+        gguf_file = kwargs.get("gguf_file")
+        has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
+
+        # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
+        # loaded directly from the GGUF file.
+        if (from_slow or not has_tokenizer_file) and cls.slow_tokenizer_class is not None and not gguf_file:
+            slow_tokenizer = (cls.slow_tokenizer_class)._from_pretrained(
+                copy.deepcopy(resolved_vocab_files),
+                pretrained_model_name_or_path,
+                copy.deepcopy(init_configuration),
+                *init_inputs,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                _commit_hash=_commit_hash,
+                **(copy.deepcopy(kwargs)),
+            )
+        else:
+            slow_tokenizer = None
+
+        # Prepare tokenizer initialization kwargs
+        # Did we save some inputs and kwargs to reload?
+        tokenizer_config_file = resolved_vocab_files.pop("tokenizer_config_file", None)
+        if tokenizer_config_file is not None:
+            with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
+                init_kwargs = json.load(tokenizer_config_handle)
+            # First attempt. We get tokenizer_class from tokenizer_config to check mismatch between tokenizers.
+            config_tokenizer_class = init_kwargs.get("tokenizer_class")
+            init_kwargs.pop("tokenizer_class", None)
+            if not has_tokenizer_file:
+                init_kwargs.pop("tokenizer_file", None)
+            saved_init_inputs = init_kwargs.pop("init_inputs", ())
+            if not init_inputs:
+                init_inputs = saved_init_inputs
+        else:
+            config_tokenizer_class = None
+            init_kwargs = init_configuration
+
+        # If independent chat template file(s) exist, they take priority over template entries in the tokenizer config
+        chat_templates = {}
+        chat_template_file = resolved_vocab_files.pop("chat_template_file", None)
+        extra_chat_templates = [key for key in resolved_vocab_files if key.startswith("chat_template_")]
+        if chat_template_file is not None:
+            with open(chat_template_file, encoding="utf-8") as chat_template_handle:
+                chat_templates["default"] = chat_template_handle.read()
+        for extra_chat_template in extra_chat_templates:
+            template_file = resolved_vocab_files.pop(extra_chat_template, None)
+            if template_file is None:
+                continue  # I think this should never happen, but just in case
+            template_name = extra_chat_template.removeprefix("chat_template_")
+            with open(template_file) as chat_template_handle:
+                chat_templates[template_name] = chat_template_handle.read()
+        if len(chat_templates) == 1 and "default" in chat_templates:
+            init_kwargs["chat_template"] = chat_templates["default"]
+        elif chat_templates:
+            init_kwargs["chat_template"] = chat_templates
+
+        if not _is_local:
+            if "auto_map" in init_kwargs:
+                # For backward compatibility with the old format.
+                if isinstance(init_kwargs["auto_map"], (tuple, list)):
+                    init_kwargs["auto_map"] = {"AutoTokenizer": init_kwargs["auto_map"]}
+
+        if config_tokenizer_class is None:
+            # Matt: This entire block is only used to decide if the tokenizer class matches the class in the repo.
+            #       If not, it raises a warning, but otherwise continues. Since we mostly load tokenizers with
+            #       AutoTokenizer these days, it seems like a lot of work (and a source of bugs) for little gain.
+            #       Maybe we can just remove this entirely?
+            from .models.auto.configuration_auto import AutoConfig  # tests_ignore
+
+            # Second attempt. If we have not yet found tokenizer_class, let's try to use the config.
+            try:
+                config = AutoConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    token=token,
+                    cache_dir=cache_dir,
+                    local_files_only=local_files_only,
+                    trust_remote_code=trust_remote_code,
+                    _commit_hash=_commit_hash,
+                )
+                config_tokenizer_class = config.tokenizer_class
+            except (OSError, ValueError, KeyError):
+                # skip if an error occurred.
+                config = None
+            if config_tokenizer_class is None:
+                # Third attempt. If we have not yet found the original type of the tokenizer we are loading,
+                # see if we can infer it from the type of the configuration file.
+                from .models.auto.tokenization_auto import TOKENIZER_MAPPING_NAMES  # tests_ignore
+
+                if hasattr(config, "model_type"):
+                    model_type = config.model_type
+                else:
+                    # Fallback: use pattern matching on the string.
+                    model_type = None
+                    for pattern in TOKENIZER_MAPPING_NAMES:
+                        if pattern in str(pretrained_model_name_or_path):
+                            model_type = pattern
+                            break
+
+                if model_type is not None:
+                    config_tokenizer_class, config_tokenizer_class_fast = TOKENIZER_MAPPING_NAMES.get(
+                        model_type, (None, None)
+                    )
+                    if config_tokenizer_class is None:
+                        config_tokenizer_class = config_tokenizer_class_fast
+
+        if config_tokenizer_class is not None:
+            if cls.__name__.replace("Fast", "") != config_tokenizer_class.replace("Fast", ""):
+                logger.warning(
+                    "The tokenizer class you load from this checkpoint is not the same type as the class this"
+                    " function is called from. It may result in unexpected tokenization. \nThe tokenizer class you"
+                    f" load from this checkpoint is '{config_tokenizer_class}'. \nThe class this function is called"
+                    f" from is '{cls.__name__}'."
+                )
+
+        # Update with newly provided kwargs
+        init_kwargs.update(kwargs)
+
+        # Merge resolved_vocab_files arguments in init_kwargs.
+        added_tokens_file = resolved_vocab_files.pop("added_tokens_file", None)
+        special_tokens_map_file = resolved_vocab_files.pop("special_tokens_map_file", None)
+        for args_name, file_path in resolved_vocab_files.items():
+            if args_name not in init_kwargs:
+                init_kwargs[args_name] = file_path
+        tokenizer_file = resolved_vocab_files.pop("tokenizer_file", None)
+
+        if slow_tokenizer is not None:
+            init_kwargs["__slow_tokenizer"] = slow_tokenizer
+        init_kwargs["name_or_path"] = pretrained_model_name_or_path
+
+        #### Handle tokenizer serialization of added and special tokens
+        added_tokens_decoder: dict[int, AddedToken] = {}
+        added_tokens_map: dict[str, AddedToken] = {}
+        # if we have info on the slow added tokens
+        if "added_tokens_decoder" in init_kwargs:
+            for idx, token in init_kwargs["added_tokens_decoder"].items():
+                if isinstance(token, dict):
+                    token = AddedToken(**token)
+                if isinstance(token, AddedToken):
+                    added_tokens_decoder[int(idx)] = token
+                    added_tokens_map[str(token)] = token
+                else:
+                    raise TypeError(
+                        f"Found a {token.__class__} in the saved `added_tokens_decoder`, should be a dictionary or an AddedToken instance"
+                    )
+        else:
+            # begin legacy: read the added_tokens_file and update kwargs with special_tokens_map if modified
+            if special_tokens_map_file is not None:
+                with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
+                    special_tokens_map = json.load(special_tokens_map_handle)
+                    for key, value in special_tokens_map.items():
+                        if key in kwargs and kwargs[key]:
+                            # This value has already been redefined by the kwargs
+                            # We keep this new value and ignore the one stored in the special_tokens_map_file
+                            continue
+                        if isinstance(value, dict):
+                            value["special"] = True
+                            value = AddedToken(**value)
+                        elif key == "additional_special_tokens" and isinstance(value, list):
+                            additional_special_tokens = init_kwargs.pop("additional_special_tokens", []) or []
+                            for token in value:
+                                if isinstance(token, dict):
+                                    token["special"] = True
+                                    token = AddedToken(**token)
+                                if token not in additional_special_tokens:
+                                    additional_special_tokens.append(token)
+                            value = additional_special_tokens
+                        init_kwargs[key] = value
+
+            # slow -> slow|fast, legacy: convert the `"added_tokens.json"` file to `added_tokens_decoder`.
+            # This is for legacy purposes; we don't add the tokens after init for efficiency.
+            if added_tokens_file is not None:
+                special_tokens = []
+                for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
+                    if init_kwargs[key] is not None:
+                        if key == "additional_special_tokens":
+                            special_tokens += [str(token) for token in init_kwargs[key]]
+                        else:
+                            special_tokens.append(str(init_kwargs[key]))
+
+                with open(added_tokens_file, encoding="utf-8") as added_tokens_handle:
+                    added_tok_encoder = json.load(added_tokens_handle)
+                for str_token, index in added_tok_encoder.items():
+                    # if index not in added_tokens_decoder and str_token not in added_tokens_map:
+                    special = str_token in special_tokens
+                    added_tokens_decoder[index] = AddedToken(
+                        str_token, rstrip=False, lstrip=False, normalized=not special, special=special
+                    )
+                    added_tokens_map[str_token] = added_tokens_decoder[index]
+
+            # allows converting a fast -> slow: add the `tokenizer.json`'s `"added_tokens"` to the slow tokenizer
+            # if `tokenizer_config.json` is `None`
+            if tokenizer_file is not None:
+                # This only applies to the slow tokenizer, so it can be done before instantiation
+                with open(tokenizer_file, encoding="utf-8") as tokenizer_file_handle:
+                    tokenizer_file_handle = json.load(tokenizer_file_handle)
+                    added_tokens = tokenizer_file_handle.pop("added_tokens")
+                for serialized_tokens in added_tokens:
+                    idx = serialized_tokens.pop("id")
+                    added_tokens_decoder[idx] = AddedToken(**serialized_tokens)
+                    added_tokens_map[str(added_tokens_decoder[idx])] = added_tokens_decoder[idx]
+            # end legacy
+
+        # Passing AddedTokens and not strings to the class to prevent it from casting the string to a different AddedToken
+        # convert serialized dicts like {'__type': 'AddedToken', 'content': ..., 'lstrip': False, 'normalized': True, ...} to AddedTokens
+        init_kwargs["added_tokens_decoder"] = added_tokens_decoder
+        init_kwargs = cls.convert_added_tokens(init_kwargs, save=False)
+        for key in cls.SPECIAL_TOKENS_ATTRIBUTES & init_kwargs.keys():
+            if added_tokens_map != {} and init_kwargs[key] is not None:
+                if key != "additional_special_tokens":
+                    init_kwargs[key] = added_tokens_map.get(str(init_kwargs[key]), init_kwargs[key])
+
+        # Instantiate the tokenizer.
+        try:
+            tokenizer = cls(*init_inputs, **init_kwargs)
+        except import_protobuf_decode_error():
+            logger.info(
+                "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+                "(Google protobuf error: Tried to load SPM model with non-SPM vocab file).",
+            )
+            return False
+        except RuntimeError as e:
+            if "sentencepiece_processor.cc" in str(e):
+                logger.info(
+                    "Unable to load tokenizer model from SPM, loading from TikToken will be attempted instead."
+                    "(SentencePiece RuntimeError: Tried to load SPM model with non-SPM vocab file).",
+                )
+            return False
+        except OSError:
+            raise OSError(
+                "Unable to load vocabulary from file. "
+                "Please check that the provided vocabulary is accessible and not corrupted."
+            )
+
+        if added_tokens_decoder != {} and max(list(added_tokens_decoder.keys())[-1], 0) > tokenizer.vocab_size:
+            logger.info(
+                "Special tokens have been added in the vocabulary, make sure the associated word embeddings are"
+                " fine-tuned or trained."
+            )
+        return tokenizer
+
+    @staticmethod
+    def _eventually_correct_t5_max_length(pretrained_model_name_or_path, max_model_length, init_max_model_length):
+        # This method should be deleted in Transformers v5.
+        # Its only purpose is to potentially warn when incorrectly defined max lengths of T5's tokenizer are used,
+        # which we will correct in Transformers v5.
+        return max_model_length
+
+    @classmethod
+    def convert_added_tokens(cls, obj: Union[AddedToken, Any], save=False, add_type_field=True):
+        if isinstance(obj, dict) and "__type" in obj and obj["__type"] == "AddedToken":
+            obj.pop("__type")
+            return AddedToken(**obj)
+        if isinstance(obj, AddedToken) and save:
+            obj = obj.__getstate__()
+            if add_type_field:
+                obj["__type"] = "AddedToken"
+            else:
+                # Don't save "special" for previous tokenizers
+                obj.pop("special")
+            return obj
+        elif isinstance(obj, (list, tuple)):
+            return [cls.convert_added_tokens(o, save=save, add_type_field=add_type_field) for o in obj]
+        elif isinstance(obj, dict):
+            return {k: cls.convert_added_tokens(v, save=save, add_type_field=add_type_field) for k, v in obj.items()}
+        return obj
+
+    def save_chat_templates(
+        self,
+        save_directory: Union[str, os.PathLike],
+        tokenizer_config: dict,
+        filename_prefix: Optional[str],
+        save_jinja_files: bool,
+    ):
+        """
+        Writes chat templates out to the save directory if we're using the new format, and removes them from
+        the tokenizer config if present. If we're using the legacy format, it doesn't write any files, and instead
+        writes the templates to the tokenizer config in the correct format.
+        """
+        chat_template_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_FILE
+        )
+        chat_template_dir = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + CHAT_TEMPLATE_DIR
+        )
+
+        saved_raw_chat_template_files = []
+        if save_jinja_files and isinstance(self.chat_template, str):
+            # New format for single templates is to save them as chat_template.jinja
+            with open(chat_template_file, "w", encoding="utf-8") as f:
+                f.write(self.chat_template)
+            logger.info(f"chat template saved in {chat_template_file}")
+            saved_raw_chat_template_files.append(chat_template_file)
+            if "chat_template" in tokenizer_config:
+                tokenizer_config.pop("chat_template")  # To ensure it doesn't somehow end up in the config too
+        elif save_jinja_files and isinstance(self.chat_template, dict):
+            # New format for multiple templates is to save the default as chat_template.jinja
+            # and the other templates in the chat_templates/ directory
+            for template_name, template in self.chat_template.items():
+                if template_name == "default":
+                    with open(chat_template_file, "w", encoding="utf-8") as f:
+                        f.write(self.chat_template["default"])
+                    logger.info(f"chat template saved in {chat_template_file}")
+                    saved_raw_chat_template_files.append(chat_template_file)
+                else:
+                    Path(chat_template_dir).mkdir(exist_ok=True)
+                    template_filepath = os.path.join(chat_template_dir, f"{template_name}.jinja")
+                    with open(template_filepath, "w", encoding="utf-8") as f:
+                        f.write(template)
+                    logger.info(f"chat template saved in {template_filepath}")
+                    saved_raw_chat_template_files.append(template_filepath)
+            if "chat_template" in tokenizer_config:
+                tokenizer_config.pop("chat_template")  # To ensure it doesn't somehow end up in the config too
+        elif isinstance(self.chat_template, dict):
+            # Legacy format for multiple templates:
+            # chat template dicts are saved to the config as lists of dicts with fixed key names.
+            tokenizer_config["chat_template"] = [{"name": k, "template": v} for k, v in self.chat_template.items()]
+        elif self.chat_template is not None:
+            # Legacy format for single templates: Just make them a key in tokenizer_config.json
+            tokenizer_config["chat_template"] = self.chat_template
+        return tokenizer_config, saved_raw_chat_template_files
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        legacy_format: Optional[bool] = None,
+        filename_prefix: Optional[str] = None,
+        push_to_hub: bool = False,
+        **kwargs,
+    ) -> tuple[str]:
+        """
+        Save the full tokenizer state.
+
+        This method makes sure the full tokenizer can then be re-loaded using the
+        [`~tokenization_utils_base.PreTrainedTokenizer.from_pretrained`] class method.
+
+        Warning: this won't save modifications you may have applied to the tokenizer after instantiation (for
+        instance, modifying `tokenizer.do_lower_case` after creation).
+
+        Args:
+            save_directory (`str` or `os.PathLike`): The path to a directory where the tokenizer will be saved.
+            legacy_format (`bool`, *optional*):
+                Only applicable for a fast tokenizer. If unset (default), will save the tokenizer in the unified JSON
+                format as well as in legacy format if it exists, i.e. with tokenizer-specific vocabulary and a separate
+                added_tokens file.
+
+                If `False`, will only save the tokenizer in the unified JSON format. This format is incompatible with
+                "slow" tokenizers (not powered by the *tokenizers* library), so the tokenizer will not be able to be
+                loaded in the corresponding "slow" tokenizer.
+
+                If `True`, will save the tokenizer in legacy format. If the "slow" tokenizer doesn't exist, a value
+                error is raised.
+            filename_prefix (`str`, *optional*):
+                A prefix to add to the names of the files saved by the tokenizer.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+
+        Returns:
+            A tuple of `str`: The files saved.
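+
+        Example (illustrative sketch; the checkpoint name and output directory below are placeholders):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> saved_files = tokenizer.save_pretrained("./my_tokenizer")
+        >>> # The directory now contains files such as tokenizer_config.json and special_tokens_map.json,
+        >>> # and the tokenizer can be re-loaded from it:
+        >>> reloaded = AutoTokenizer.from_pretrained("./my_tokenizer")
+        ```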
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if kwargs.get("token") is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            kwargs["token"] = use_auth_token
+
+        if os.path.isfile(save_directory):
+            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        special_tokens_map_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + SPECIAL_TOKENS_MAP_FILE
+        )
+        tokenizer_config_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
+        )
+
+        tokenizer_config = copy.deepcopy(self.init_kwargs)
+
+        # Let's save the init kwargs
+        target_keys = set(self.init_kwargs.keys())
+        # Let's save the special tokens map (only the strings)
+        target_keys.update(["model_max_length", "clean_up_tokenization_spaces"])
+
+        for k in target_keys:
+            if hasattr(self, k):
+                tokenizer_config[k] = getattr(self, k)
+
+        # Let's make sure we properly save the special tokens
+        tokenizer_config.update(self.special_tokens_map)
+        if "extra_special_tokens" not in tokenizer_config:
+            tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
+            tokenizer_config.update(self.extra_special_tokens)
+
+        save_jinja_files = kwargs.get("save_jinja_files", True)
+        tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
+            save_directory, tokenizer_config, filename_prefix, save_jinja_files
+        )
+
+        if len(self.init_inputs) > 0:
+            tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
+        for file_id in self.vocab_files_names:
+            tokenizer_config.pop(file_id, None)
+
+        # no typefields, this way old fast and slow can load it
+        tokenizer_config = self.convert_added_tokens(tokenizer_config, add_type_field=True, save=True)
+
+        # Process added tokens separately: allows previous versions to ignore it!
+        added_tokens = {}
+        for key, value in self.added_tokens_decoder.items():
+            added_tokens[key] = value.__getstate__()
+        tokenizer_config["added_tokens_decoder"] = added_tokens
+
+        # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
+        tokenizer_class = self.__class__.__name__
+        # Remove the Fast at the end if we can save the slow tokenizer
+        if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False):
+            tokenizer_class = tokenizer_class[:-4]
+        tokenizer_config["tokenizer_class"] = tokenizer_class
+        if getattr(self, "_auto_map", None) is not None:
+            tokenizer_config["auto_map"] = self._auto_map
+        if getattr(self, "_processor_class", None) is not None:
+            tokenizer_config["processor_class"] = self._processor_class
+
+        # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=tokenizer_config)
+
+        # remove private information
+        if "name_or_path" in tokenizer_config:
+            tokenizer_config.pop("name_or_path")
+            tokenizer_config.pop("special_tokens_map_file", None)
+            tokenizer_config.pop("tokenizer_file", None)
+        if "device_map" in tokenizer_config:
+            tokenizer_config.pop("device_map")
+
+        with open(tokenizer_config_file, "w", encoding="utf-8") as f:
+            out_str = json.dumps(tokenizer_config, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+            f.write(out_str)
+        logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
+
+        # Sanitize AddedTokens in special_tokens_map.
+        # Kept for forward compatibility; will be removed in Transformers v5. Type fields are not saved for forward
+        # compatibility, and special tokens should not be saved either.
+        write_dict = self.convert_added_tokens(self.special_tokens_map_extended, save=True, add_type_field=False)
+        with open(special_tokens_map_file, "w", encoding="utf-8") as f:
+            out_str = json.dumps(write_dict, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+            f.write(out_str)
+        logger.info(f"Special tokens file saved in {special_tokens_map_file}")
+
+        file_names = (tokenizer_config_file, special_tokens_map_file, *saved_raw_chat_template_files)
+
+        save_files = self._save_pretrained(
+            save_directory=save_directory,
+            file_names=file_names,
+            legacy_format=legacy_format,
+            filename_prefix=filename_prefix,
+        )
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=kwargs.get("token"),
+            )
+
+        return save_files
+
+    def _save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        file_names: tuple[str],
+        legacy_format: Optional[bool] = None,
+        filename_prefix: Optional[str] = None,
+    ) -> tuple[str]:
+        """
+        Save a tokenizer using the slow-tokenizer/legacy format: vocabulary + added tokens.
+
+        Fast tokenizers can also be saved in a single JSON file containing {config + vocab + added-tokens} using the
+        dedicated [`~tokenization_utils_fast.PreTrainedTokenizerFast._save_pretrained`] method.
+        """
+        if legacy_format is False:
+            raise ValueError(
+                "Only fast tokenizers (instances of PreTrainedTokenizerFast) can be saved in non-legacy format."
+            )
+
+        save_directory = str(save_directory)
+
+        added_tokens_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
+        )
+        # the new get_added_vocab() also returns special tokens and tokens that have an index < vocab_size
+        added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
+        if added_vocab:
+            with open(added_tokens_file, "w", encoding="utf-8") as f:
+                out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+                f.write(out_str)
+                logger.info(f"added tokens file saved in {added_tokens_file}")
+
+        vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
+
+        return file_names + vocab_files + (added_tokens_file,)
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
+        """
+        Save only the vocabulary of the tokenizer (vocabulary + added tokens).
+
+        This method won't save the configuration and special token mappings of the tokenizer. Use
+        [`~PreTrainedTokenizerFast._save_pretrained`] to save the whole state of the tokenizer.
+
+        Args:
+            save_directory (`str`):
+                The directory in which to save the vocabulary.
+            filename_prefix (`str`, *optional*):
+                An optional prefix to add to the names of the saved files.
+
+        Returns:
+            `tuple[str]`: Paths to the files saved.
+        """
+        raise NotImplementedError
+
+    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+        """
+        Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`.
+
+        Args:
+            text (`str`):
+                The sequence to be encoded.
+            pair (`str`, *optional*):
+                A second sequence to be encoded with the first.
+            add_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to add the special tokens associated with the corresponding model.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific encode method. See details in
+                [`~PreTrainedTokenizerBase.__call__`]
+
+        Returns:
+            `list[str]`: The list of tokens.
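+
+        Example (illustrative sketch for a concrete subclass; this base method is abstract and the checkpoint name is
+        a placeholder):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> tokens = tokenizer.tokenize("Hello world")  # a list of subword token strings
+        ```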
+        """
+        raise NotImplementedError
+
+    @add_end_docstrings(
+        ENCODE_KWARGS_DOCSTRING,
+        """
+            **kwargs: Passed along to the `.tokenize()` method.
+        """,
+        """
+        Returns:
+            `list[int]`, `torch.Tensor`, `tf.Tensor` or `np.ndarray`: The tokenized ids of the text.
+        """,
+    )
+    def encode(
+        self,
+        text: Union[TextInput, PreTokenizedInput, EncodedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> list[int]:
+        """
+        Converts a string to a sequence of ids (integers), using the tokenizer and vocabulary.
+
+        Same as doing `self.convert_tokens_to_ids(self.tokenize(text))`.
+
+        Args:
+            text (`str`, `list[str]` or `list[int]`):
+                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+                `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
+            text_pair (`str`, `list[str]` or `list[int]`, *optional*):
+                Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
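+
+        Example (illustrative sketch; the checkpoint name is a placeholder):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> input_ids = tokenizer.encode("Hello world", add_special_tokens=True)
+        >>> # Roughly equivalent to tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world")),
+        >>> # with the model's special tokens added around the sequence.
+        ```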
+        """
+        encoded_inputs = self.encode_plus(
+            text,
+            text_pair=text_pair,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            stride=stride,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            **kwargs,
+        )
+
+        return encoded_inputs["input_ids"]
+
+    def num_special_tokens_to_add(self, pair: bool = False) -> int:
+        raise NotImplementedError
+
+    def _get_padding_truncation_strategies(
+        self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
+    ):
+        """
+        Find the correct padding/truncation strategy
+        """
+
+        # Backward compatibility for previous behavior, maybe we should deprecate it:
+        # If you only set max_length, it activates truncation for max_length
+        if max_length is not None and padding is False and truncation is None:
+            if verbose:
+                if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
+                    logger.warning(
+                        "Truncation was not explicitly activated but `max_length` is provided a specific value, please"
+                        " use `truncation=True` to explicitly truncate examples to max length. Defaulting to"
+                        " 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the"
+                        " tokenizer you can select this strategy more precisely by providing a specific strategy to"
+                        " `truncation`."
+                    )
+                self.deprecation_warnings["Truncation-not-explicitly-activated"] = True
+            truncation = "longest_first"
+
+        # Get padding strategy
+        if padding is not False:
+            if padding is True:
+                if verbose:
+                    if max_length is not None and (
+                        truncation is None or truncation is False or truncation == "do_not_truncate"
+                    ):
+                        warnings.warn(
+                            "`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
+                            "To pad to max length, use `padding='max_length'`."
+                        )
+                padding_strategy = PaddingStrategy.LONGEST  # Default to pad to the longest sequence in the batch
+            elif not isinstance(padding, PaddingStrategy):
+                padding_strategy = PaddingStrategy(padding)
+            elif isinstance(padding, PaddingStrategy):
+                padding_strategy = padding
+        else:
+            padding_strategy = PaddingStrategy.DO_NOT_PAD
+
+        # Get truncation strategy
+        if truncation is not False and truncation is not None:
+            if truncation is True:
+                truncation_strategy = (
+                    TruncationStrategy.LONGEST_FIRST
+                )  # Default to truncate the longest sequences in pairs of inputs
+            elif not isinstance(truncation, TruncationStrategy):
+                truncation_strategy = TruncationStrategy(truncation)
+            elif isinstance(truncation, TruncationStrategy):
+                truncation_strategy = truncation
+        else:
+            truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+
+        # Set max length if needed
+        if max_length is None:
+            if padding_strategy == PaddingStrategy.MAX_LENGTH:
+                if self.model_max_length > LARGE_INTEGER:
+                    if verbose:
+                        if not self.deprecation_warnings.get("Asking-to-pad-to-max_length", False):
+                            logger.warning(
+                                "Asking to pad to max_length but no maximum length is provided and the model has no"
+                                " predefined maximum length. Default to no padding."
+                            )
+                        self.deprecation_warnings["Asking-to-pad-to-max_length"] = True
+                    padding_strategy = PaddingStrategy.DO_NOT_PAD
+                else:
+                    max_length = self.model_max_length
+
+            if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE:
+                if self.model_max_length > LARGE_INTEGER:
+                    if verbose:
+                        if not self.deprecation_warnings.get("Asking-to-truncate-to-max_length", False):
+                            logger.warning(
+                                "Asking to truncate to max_length but no maximum length is provided and the model has"
+                                " no predefined maximum length. Default to no truncation."
+                            )
+                        self.deprecation_warnings["Asking-to-truncate-to-max_length"] = True
+                    truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
+                else:
+                    max_length = self.model_max_length
+
+        # Test if we have a padding token
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.pad_token is None or self.pad_token_id < 0):
+            raise ValueError(
+                "Asking to pad but the tokenizer does not have a padding token. "
+                "Please select a token to use as `pad_token` `(tokenizer.pad_token = tokenizer.eos_token e.g.)` "
+                "or add a new pad token via `tokenizer.add_special_tokens({'pad_token': '[PAD]'})`."
+            )
+
+        # Check that we will truncate to a multiple of pad_to_multiple_of if both are provided
+        if (
+            truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE
+            and padding_strategy != PaddingStrategy.DO_NOT_PAD
+            and pad_to_multiple_of is not None
+            and max_length is not None
+            and (max_length % pad_to_multiple_of != 0)
+        ):
+            raise ValueError(
+                "Truncation and padding are both activated but "
+                f"truncation length ({max_length}) is not a multiple of pad_to_multiple_of ({pad_to_multiple_of})."
+            )
+
+        return padding_strategy, truncation_strategy, max_length, kwargs
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+        text_target: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput], None] = None,
+        text_pair_target: Optional[
+            Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]
+        ] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of
+        sequences.
+
+        Args:
+            text (`str`, `list[str]`, `list[list[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_pair (`str`, `list[str]`, `list[list[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_target (`str`, `list[str]`, `list[list[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+            text_pair_target (`str`, `list[str]`, `list[list[str]]`, *optional*):
+                The sequence or batch of sequences to be encoded as target texts. Each sequence can be a string or a
+                list of strings (pretokenized string). If the sequences are provided as list of strings (pretokenized),
+                you must set `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
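+
+        Example (illustrative sketch of batched usage; the checkpoint name is a placeholder):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> batch = tokenizer(
+        ...     ["A first sentence.", "A second, somewhat longer sentence."],
+        ...     padding=True,
+        ...     truncation=True,
+        ...     return_tensors="pt",
+        ... )
+        >>> batch["input_ids"].shape  # (batch_size, padded_sequence_length)
+        ```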
+        """
+        # To avoid duplicating
+        all_kwargs = {
+            "add_special_tokens": add_special_tokens,
+            "padding": padding,
+            "truncation": truncation,
+            "max_length": max_length,
+            "stride": stride,
+            "is_split_into_words": is_split_into_words,
+            "pad_to_multiple_of": pad_to_multiple_of,
+            "padding_side": padding_side,
+            "return_tensors": return_tensors,
+            "return_token_type_ids": return_token_type_ids,
+            "return_attention_mask": return_attention_mask,
+            "return_overflowing_tokens": return_overflowing_tokens,
+            "return_special_tokens_mask": return_special_tokens_mask,
+            "return_offsets_mapping": return_offsets_mapping,
+            "return_length": return_length,
+            "split_special_tokens": kwargs.pop("split_special_tokens", self.split_special_tokens),
+            "verbose": verbose,
+        }
+
+        if return_tensors in ("tf", "jax"):
+            logger.warning_once(
+                "TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We "
+                "recommend migrating to PyTorch classes or pinning your version of Transformers."
+            )
+        all_kwargs.update(kwargs)
+        if text is None and text_target is None:
+            raise ValueError("You need to specify either `text` or `text_target`.")
+        if text is not None:
+            # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the
+            # input mode in this case.
+            if not self._in_target_context_manager:
+                self._switch_to_input_mode()
+            encodings = self._call_one(text=text, text_pair=text_pair, **all_kwargs)
+        if text_target is not None:
+            self._switch_to_target_mode()
+            target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **all_kwargs)
+        # Leave back tokenizer in input mode
+        self._switch_to_input_mode()
+
+        if text_target is None:
+            return encodings
+        elif text is None:
+            return target_encodings
+        else:
+            encodings["labels"] = target_encodings["input_ids"]
+            return encodings
+
+    def _call_one(
+        self,
+        text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        # Input type checking for clearer error
+        def _is_valid_text_input(t):
+            if isinstance(t, str):
+                # Strings are fine
+                return True
+            elif isinstance(t, (list, tuple)):
+                # List are fine as long as they are...
+                if len(t) == 0:
+                    # ... empty
+                    return True
+                elif isinstance(t[0], str):
+                    # ... list of strings
+                    return True
+                elif isinstance(t[0], (list, tuple)):
+                    # ... list with an empty list or with a list of strings
+                    return len(t[0]) == 0 or isinstance(t[0][0], str)
+                else:
+                    return False
+            else:
+                return False
+
+        if not _is_valid_text_input(text):
+            raise ValueError(
+                "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) "
+                "or `list[list[str]]` (batch of pretokenized examples)."
+            )
+
+        if text_pair is not None and not _is_valid_text_input(text_pair):
+            raise ValueError(
+                "text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) "
+                "or `list[list[str]]` (batch of pretokenized examples)."
+            )
+
+        if is_split_into_words:
+            is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple))
+        else:
+            is_batched = isinstance(text, (list, tuple))
+
+        if is_batched:
+            if isinstance(text_pair, str):
+                raise TypeError(
+                    "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as"
+                    " `text`."
+                )
+            if text_pair is not None and len(text) != len(text_pair):
+                raise ValueError(
+                    f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:"
+                    f" {len(text_pair)}."
+                )
+            batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text
+            return self.batch_encode_plus(
+                batch_text_or_text_pairs=batch_text_or_text_pairs,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                is_split_into_words=is_split_into_words,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                split_special_tokens=split_special_tokens,
+                **kwargs,
+            )
+        else:
+            return self.encode_plus(
+                text=text,
+                text_pair=text_pair,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                is_split_into_words=is_split_into_words,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                split_special_tokens=split_special_tokens,
+                **kwargs,
+            )
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def encode_plus(
+        self,
+        text: Union[TextInput, PreTokenizedInput, EncodedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Tokenize and prepare for the model a sequence or a pair of sequences.
+
+        This method is deprecated; `__call__` should be used instead.
+
+        Args:
+            text (`str`, `list[str]` or (for non-fast tokenizers) `list[int]`):
+                The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
+                `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
+            text_pair (`str`, `list[str]` or `list[int]`, *optional*):
+                Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using
+                the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
+                method).
+        """
+
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        return self._encode_plus(
+            text=text,
+            text_pair=text_pair,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            split_special_tokens=kwargs.pop("split_special_tokens", self.split_special_tokens),
+            **kwargs,
+        )
+
+    def _encode_plus(
+        self,
+        text: Union[TextInput, PreTokenizedInput, EncodedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        raise NotImplementedError
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: Union[
+            list[TextInput],
+            list[TextInputPair],
+            list[PreTokenizedInput],
+            list[PreTokenizedInputPair],
+            list[EncodedInput],
+            list[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Tokenize and prepare for the model a list of sequences or a list of pairs of sequences.
+
+        This method is deprecated; `__call__` should be used instead.
+
+        Args:
+            batch_text_or_text_pairs (`list[str]`, `list[tuple[str, str]]`, `list[list[str]]`, `list[tuple[list[str], list[str]]]`, and for non-fast tokenizers, also `list[list[int]]`, `list[tuple[list[int], list[int]]]`):
+                Batch of sequences or pairs of sequences to be encoded. This can be a list of
+                string/string-sequences/int-sequences or a list of pairs of string/string-sequences/int-sequences (see
+                details in `encode_plus`).
+        """
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        return self._batch_encode_plus(
+            batch_text_or_text_pairs=batch_text_or_text_pairs,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            is_split_into_words=is_split_into_words,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            split_special_tokens=split_special_tokens,
+            **kwargs,
+        )
+
+    def _batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: Union[
+            list[TextInput],
+            list[TextInputPair],
+            list[PreTokenizedInput],
+            list[PreTokenizedInputPair],
+            list[EncodedInput],
+            list[EncodedInputPair],
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        raise NotImplementedError
+
+    def pad(
+        self,
+        encoded_inputs: Union[
+            BatchEncoding,
+            list[BatchEncoding],
+            dict[str, EncodedInput],
+            dict[str, list[EncodedInput]],
+            list[dict[str, EncodedInput]],
+        ],
+        padding: Union[bool, str, PaddingStrategy] = True,
+        max_length: Optional[int] = None,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        verbose: bool = True,
+    ) -> BatchEncoding:
+        """
+        Pad a single encoded input or a batch of encoded inputs up to predefined length or to the max sequence length
+        in the batch.
+
+        Padding side (left/right) padding token ids are defined at the tokenizer level (with `self.padding_side`,
+        `self.pad_token_id` and `self.pad_token_type_id`).
+
+        Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the
+        text followed by a call to the `pad` method to get a padded encoding.
+
+        If the `encoded_inputs` passed are dictionaries of numpy arrays, PyTorch tensors or TensorFlow tensors, the
+        result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
+        PyTorch tensors, however, you will lose the specific device of your tensors.
+
+        Args:
+            encoded_inputs ([`BatchEncoding`], list of [`BatchEncoding`], `dict[str, list[int]]`, `dict[str, list[list[int]]` or `list[dict[str, list[int]]]`):
+                Tokenized inputs. Can represent one input ([`BatchEncoding`] or `dict[str, list[int]]`) or a batch of
+                tokenized inputs (list of [`BatchEncoding`], *dict[str, list[list[int]]]* or *list[dict[str,
+                list[int]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
+                collate function.
+
+                Instead of `list[int]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors), see
+                the note above for the return type.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+                 Select a strategy to pad the returned sequences (according to the model's padding side and padding
+                 index) among:
+
+                - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a
+                  single sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            max_length (`int`, *optional*):
+                Maximum length of the returned list and optionally padding length (see above).
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value.
+
+                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask (`bool`, *optional*):
+                Whether to return the attention mask. If left to the default, will return the attention mask according
+                to the specific tokenizer's default, defined by the `return_outputs` attribute.
+
+                [What are attention masks?](../glossary#attention-mask)
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of a list of Python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            verbose (`bool`, *optional*, defaults to `True`):
+                Whether or not to print more information and warnings.
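+
+        Example (illustrative sketch of using `pad` as a collate step; the checkpoint name is a placeholder):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> features = [tokenizer("A short text"), tokenizer("A slightly longer piece of text")]
+        >>> batch = tokenizer.pad(features, padding=True, return_tensors="pt")
+        >>> batch["input_ids"].shape  # (2, length_of_longest_sequence_in_the_batch)
+        ```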
+        """
+        if self.__class__.__name__.endswith("Fast"):
+            if not self.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False):
+                logger.warning_advice(
+                    f"You're using a {self.__class__.__name__} tokenizer. Please note that with a fast tokenizer,"
+                    " using the `__call__` method is faster than using a method to encode the text followed by a call"
+                    " to the `pad` method to get a padded encoding."
+                )
+                self.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
+        # If we have a list of dicts, let's convert it in a dict of lists
+        # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
+        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+
+        # The model's main input name, usually `input_ids`, has been passed for padding
+        if self.model_input_names[0] not in encoded_inputs:
+            raise ValueError(
+                "You should supply an encoding or a list of encodings to this method "
+                f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
+            )
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
+            if return_attention_mask:
+                encoded_inputs["attention_mask"] = []
+            return encoded_inputs
+
+        # If we have PyTorch/TF/NumPy tensors/arrays as inputs, we cast them as python objects
+        # and rebuild them afterwards if no return_tensors is specified
+        # Note that we lose the specific device the tensor may be on for PyTorch
+
+        first_element = required_input[0]
+        if isinstance(first_element, (list, tuple)):
+            # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
+            for item in required_input:
+                if len(item) != 0:
+                    first_element = item[0]
+                    break
+        # At this point, if `first_element` is still a list/tuple, it's an empty one so there is nothing to do.
+        if not isinstance(first_element, (int, list, tuple)):
+            if is_tf_tensor(first_element):
+                return_tensors = "tf" if return_tensors is None else return_tensors
+            elif is_torch_tensor(first_element):
+                return_tensors = "pt" if return_tensors is None else return_tensors
+            elif isinstance(first_element, np.ndarray):
+                return_tensors = "np" if return_tensors is None else return_tensors
+            else:
+                raise ValueError(
+                    f"type of {first_element} unknown: {type(first_element)}. "
+                    "Should be one of a python, numpy, pytorch or tensorflow object."
+                )
+
+            for key, value in encoded_inputs.items():
+                encoded_inputs[key] = to_py_obj(value)
+
+        # Convert padding_strategy in PaddingStrategy
+        padding_strategy, _, max_length, _ = self._get_padding_truncation_strategies(
+            padding=padding, max_length=max_length, verbose=verbose
+        )
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+        if required_input and not isinstance(required_input[0], (list, tuple)):
+            encoded_inputs = self._pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+            return BatchEncoding(encoded_inputs, tensor_type=return_tensors)
+
+        batch_size = len(required_input)
+        assert all(len(v) == batch_size for v in encoded_inputs.values()), (
+            "Some items in the output dictionary have a different batch size than others."
+        )
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = max(len(inputs) for inputs in required_input)
+            padding_strategy = PaddingStrategy.MAX_LENGTH
+
+        batch_outputs = {}
+        for i in range(batch_size):
+            inputs = {k: v[i] for k, v in encoded_inputs.items()}
+            outputs = self._pad(
+                inputs,
+                max_length=max_length,
+                padding_strategy=padding_strategy,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+            for key, value in outputs.items():
+                if key not in batch_outputs:
+                    batch_outputs[key] = []
+                batch_outputs[key].append(value)
+
+        return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Create the token type IDs corresponding to the sequences passed. [What are token type
+        IDs?](../glossary#token-type-ids)
+
+        Should be overridden in a subclass if the model has a special way of building those.
+
+        Args:
+            token_ids_0 (`list[int]`): The first tokenized sequence.
+            token_ids_1 (`list[int]`, *optional*): The second tokenized sequence.
+
+        Returns:
+            `list[int]`: The token type ids.
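+
+        Example (sketch of this base implementation; assumes `tokenizer` is an instance whose `cls_token_id` and
+        `sep_token_id` are both set, so one position is reserved for each):
+
+        ```python
+        >>> token_ids_0 = [5, 6]
+        >>> token_ids_1 = [7, 8, 9]
+        >>> tokenizer.create_token_type_ids_from_sequences(token_ids_0, token_ids_1)
+        >>> # -> [0, 0, 0, 0, 1, 1, 1, 1]: zeros cover cls + token_ids_0 + sep, ones cover token_ids_1 + sep
+        ```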
+        """
+        cls_len = int(getattr(self, "cls_token_id", None) is not None)
+        sep_len = int(getattr(self, "sep_token_id", None) is not None)
+
+        if token_ids_1 is None:
+            return [0] * (cls_len + len(token_ids_0) + sep_len)
+
+        return [0] * (cls_len + len(token_ids_0) + sep_len) + [1] * (len(token_ids_1) + sep_len)
+
+    def build_inputs_with_special_tokens(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
+    ) -> list[int]:
+        """
+        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+        adding special tokens.
+
+        This implementation does not add special tokens and this method should be overridden in a subclass.
+
+        Args:
+            token_ids_0 (`list[int]`): The first tokenized sequence.
+            token_ids_1 (`list[int]`, *optional*): The second tokenized sequence.
+
+        Returns:
+            `list[int]`: The model input with special tokens.
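+
+        Example (sketch of this base implementation, which only concatenates; model-specific subclasses typically
+        insert tokens such as `[CLS]` and `[SEP]` instead):
+
+        ```python
+        >>> tokenizer.build_inputs_with_special_tokens([5, 6], [7, 8])  # base implementation returns [5, 6, 7, 8]
+        ```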
+        """
+        if token_ids_1 is None:
+            return token_ids_0
+        return token_ids_0 + token_ids_1
+
+    @add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
+    def prepare_for_model(
+        self,
+        ids: list[int],
+        pair_ids: Optional[list[int]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy, None] = None,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        prepend_batch_axis: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
+        adds special tokens, truncates sequences if overflowing while taking into account the special tokens, and
+        manages a moving window (with a user-defined stride) for overflowing tokens. Please note that for *pair_ids*
+        different from `None` and *truncation_strategy = longest_first* or `True`, it is not possible to return
+        overflowing tokens. Such a combination of arguments will raise an error.
+
+        Args:
+            ids (`list[int]`):
+                Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+                `convert_tokens_to_ids` methods.
+            pair_ids (`list[int]`, *optional*):
+                Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+                and `convert_tokens_to_ids` methods.
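+
+        Example (illustrative sketch; assumes `tokenizer` is a concrete pretrained tokenizer instance):
+
+        ```python
+        >>> ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
+        >>> encoded = tokenizer.prepare_for_model(ids, add_special_tokens=True, truncation=True, max_length=8)
+        >>> encoded["input_ids"]  # the ids with special tokens added, truncated to at most 8 tokens
+        ```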
+        """
+
+        # Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
+        padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            pad_to_multiple_of=pad_to_multiple_of,
+            verbose=verbose,
+            **kwargs,
+        )
+
+        pair = pair_ids is not None
+        len_ids = len(ids)
+        len_pair_ids = len(pair_ids) if pair else 0
+
+        if return_token_type_ids and not add_special_tokens:
+            raise ValueError(
+                "Asking to return token_type_ids while setting add_special_tokens to False "
+                "results in an undefined behavior. Please set add_special_tokens to True or "
+                "set return_token_type_ids to None."
+            )
+
+        if (
+            return_overflowing_tokens
+            and truncation_strategy == TruncationStrategy.LONGEST_FIRST
+            and pair_ids is not None
+        ):
+            raise ValueError(
+                "Not possible to return overflowing tokens for pair of sequences with the "
+                "`longest_first`. Please select another truncation strategy than `longest_first`, "
+                "for instance `only_second` or `only_first`."
+            )
+
+        # Load from model defaults
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        encoded_inputs = {}
+
+        # Compute the total size of the returned encodings
+        total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
+
+        # Truncation: Handle max sequence length
+        overflowing_tokens = []
+        if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
+            ids, pair_ids, overflowing_tokens = self.truncate_sequences(
+                ids,
+                pair_ids=pair_ids,
+                num_tokens_to_remove=total_len - max_length,
+                truncation_strategy=truncation_strategy,
+                stride=stride,
+            )
+
+        if return_overflowing_tokens:
+            encoded_inputs["overflowing_tokens"] = overflowing_tokens
+            encoded_inputs["num_truncated_tokens"] = total_len - max_length
+
+        # Add special tokens
+        if add_special_tokens:
+            sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
+            token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
+        else:
+            sequence = ids + pair_ids if pair else ids
+            token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else [])
+
+        # Build output dictionary
+        encoded_inputs["input_ids"] = sequence
+        if return_token_type_ids:
+            encoded_inputs["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            if add_special_tokens:
+                encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
+            else:
+                encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
+
+        # Check lengths
+        self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose)
+
+        # Padding
+        if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
+            encoded_inputs = self.pad(
+                encoded_inputs,
+                max_length=max_length,
+                padding=padding_strategy.value,
+                pad_to_multiple_of=pad_to_multiple_of,
+                padding_side=padding_side,
+                return_attention_mask=return_attention_mask,
+            )
+
+        if return_length:
+            encoded_inputs["length"] = len(encoded_inputs["input_ids"])
+
+        batch_outputs = BatchEncoding(
+            encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
+        )
+
+        return batch_outputs
+
+    def truncate_sequences(
+        self,
+        ids: list[int],
+        pair_ids: Optional[list[int]] = None,
+        num_tokens_to_remove: int = 0,
+        truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
+        stride: int = 0,
+    ) -> tuple[list[int], list[int], list[int]]:
+        """
+        Truncates a sequence pair in-place following the strategy.
+
+        Args:
+            ids (`list[int]`):
+                Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and
+                `convert_tokens_to_ids` methods.
+            pair_ids (`list[int]`, *optional*):
+                Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize`
+                and `convert_tokens_to_ids` methods.
+            num_tokens_to_remove (`int`, *optional*, defaults to 0):
+                Number of tokens to remove using the truncation strategy.
+            truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `'longest_first'`):
+                The strategy to follow for truncation. Can be:
+
+                - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will truncate
+                  token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a
+                  batch of pairs) is provided.
+                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'do_not_truncate'`: No truncation (i.e., can output batch with sequence lengths greater
+                  than the model maximum admissible input size).
+            stride (`int`, *optional*, defaults to 0):
+                If set to a positive number, the overflowing tokens returned will contain some tokens from the main
+                sequence returned. The value of this argument defines the number of additional tokens.
+
+        Returns:
+            `tuple[list[int], list[int], list[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of
+            overflowing tokens. Note: The *longest_first* strategy returns an empty list of overflowing tokens if a
+            pair of sequences (or a batch of pairs) is provided.
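+
+        Example (illustrative sketch, not part of the upstream docstring; assumes the tokenizer's
+        `truncation_side` is `"right"` and `stride=0`):
+
+            >>> tokenizer.truncate_sequences([1, 2, 3, 4, 5], num_tokens_to_remove=2, truncation_strategy="only_first")
+            ([1, 2, 3], None, [4, 5])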
+        """
+        if num_tokens_to_remove <= 0:
+            return ids, pair_ids, []
+
+        if not isinstance(truncation_strategy, TruncationStrategy):
+            truncation_strategy = TruncationStrategy(truncation_strategy)
+
+        overflowing_tokens = []
+        if truncation_strategy == TruncationStrategy.ONLY_FIRST or (
+            truncation_strategy == TruncationStrategy.LONGEST_FIRST and pair_ids is None
+        ):
+            if len(ids) > num_tokens_to_remove:
+                window_len = min(len(ids), stride + num_tokens_to_remove)
+                if self.truncation_side == "left":
+                    overflowing_tokens = ids[:window_len]
+                    ids = ids[num_tokens_to_remove:]
+                elif self.truncation_side == "right":
+                    overflowing_tokens = ids[-window_len:]
+                    ids = ids[:-num_tokens_to_remove]
+                else:
+                    raise ValueError(f"invalid truncation strategy: {self.truncation_side}, use 'left' or 'right'.")
+
+            else:
+                error_msg = (
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
+                    f"but the first sequence has a length of {len(ids)}. "
+                )
+                if truncation_strategy == TruncationStrategy.ONLY_FIRST:
+                    error_msg = (
+                        error_msg + "Please select another truncation strategy than "
+                        f"{truncation_strategy}, for instance 'longest_first' or 'only_second'."
+                    )
+                logger.error(error_msg)
+        elif truncation_strategy == TruncationStrategy.LONGEST_FIRST:
+            logger.warning(
+                "Be aware, overflowing tokens are not returned for the setting you have chosen,"
+                f" i.e. sequence pairs with the '{TruncationStrategy.LONGEST_FIRST.value}' "
+                "truncation strategy. So the returned list will always be empty even if some "
+                "tokens have been removed."
+            )
+            len_pair_ids = len(pair_ids) if pair_ids is not None else 0
+            len_ids = len(ids)
+            first_remove = min(abs(len_pair_ids - len_ids), num_tokens_to_remove)
+            second_remove = num_tokens_to_remove - first_remove
+            if len_ids > len_pair_ids:
+                ids_to_move = first_remove + second_remove // 2
+                pair_ids_to_move = second_remove - second_remove // 2
+            else:
+                ids_to_move = second_remove // 2
+                pair_ids_to_move = first_remove + second_remove - (second_remove // 2)
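+            # Illustrative note (not upstream): with len(ids)=10, len(pair_ids)=6 and
+            # num_tokens_to_remove=6, first_remove=4 and second_remove=2, so 5 tokens are
+            # removed from `ids` and 1 from `pair_ids`, leaving both sequences at length 5.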
+
+            if self.truncation_side == "right":
+                ids = ids[:-ids_to_move] if ids_to_move > 0 else ids
+                pair_ids = pair_ids[:-pair_ids_to_move] if pair_ids is not None and pair_ids_to_move > 0 else pair_ids
+            elif self.truncation_side == "left":
+                ids = ids[ids_to_move:]
+                pair_ids = pair_ids[pair_ids_to_move:] if pair_ids is not None else None
+            else:
+                raise ValueError(f"invalid truncation strategy:{self.truncation_side}")
+
+        elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
+            if len(pair_ids) > num_tokens_to_remove:
+                window_len = min(len(pair_ids), stride + num_tokens_to_remove)
+                if self.truncation_side == "right":
+                    overflowing_tokens = pair_ids[-window_len:]
+                    pair_ids = pair_ids[:-num_tokens_to_remove]
+                elif self.truncation_side == "left":
+                    overflowing_tokens = pair_ids[:window_len]
+                    pair_ids = pair_ids[num_tokens_to_remove:]
+                else:
+                    raise ValueError(f"invalid truncation strategy:{self.truncation_side}")
+            else:
+                logger.error(
+                    f"We need to remove {num_tokens_to_remove} tokens to truncate the input "
+                    f"but the second sequence has a length of {len(pair_ids)}. "
+                    f"Please select another truncation strategy than {truncation_strategy}, "
+                    "for instance 'longest_first' or 'only_first'."
+                )
+
+        return (ids, pair_ids, overflowing_tokens)
+
+    def _pad(
+        self,
+        encoded_inputs: Union[dict[str, EncodedInput], BatchEncoding],
+        max_length: Optional[int] = None,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_attention_mask: Optional[bool] = None,
+    ) -> dict:
+        """
+        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
+
+        Args:
+            encoded_inputs:
+                Dictionary of tokenized inputs (`list[int]`) or batch of tokenized inputs (`list[list[int]]`).
+            max_length: maximum length of the returned list and optionally padding length (see below).
+                Will truncate by taking into account the special tokens.
+            padding_strategy: PaddingStrategy to use for padding.
+
+                - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
+                - PaddingStrategy.MAX_LENGTH: Pad to the max length
+                - PaddingStrategy.DO_NOT_PAD: Do not pad (default)
+                The tokenizer padding sides are defined in `padding_side` argument:
+
+                    - 'left': pads on the left of the sequences
+                    - 'right': pads on the right of the sequences
+            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
+                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
+                `>= 7.5` (Volta).
+            padding_side:
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+            return_attention_mask:
+                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
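+
+        Illustrative sketch (not from the upstream docstring; assumes `pad_token_id=0`, right padding and
+        `return_attention_mask=True`): padding `{"input_ids": [5, 6, 7]}` to `max_length=5` yields
+        `{"input_ids": [5, 6, 7, 0, 0], "attention_mask": [1, 1, 1, 0, 0]}`.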
+        """
+        # Load from model defaults
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        required_input = encoded_inputs[self.model_input_names[0]]
+
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = len(required_input)
+
+        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
+            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
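+        # e.g. (illustrative) max_length=60 with pad_to_multiple_of=8 is rounded up to 64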
+
+        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
+
+        # Initialize attention mask if not present.
+        if return_attention_mask and "attention_mask" not in encoded_inputs:
+            encoded_inputs["attention_mask"] = [1] * len(required_input)
+
+        if needs_to_be_padded:
+            difference = max_length - len(required_input)
+            padding_side = padding_side if padding_side is not None else self.padding_side
+
+            if padding_side == "right":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = (
+                        encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference
+                    )
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference
+                encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference
+            elif padding_side == "left":
+                if return_attention_mask:
+                    encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
+                if "token_type_ids" in encoded_inputs:
+                    encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
+                        "token_type_ids"
+                    ]
+                if "special_tokens_mask" in encoded_inputs:
+                    encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
+                encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
+            else:
+                raise ValueError(f"Invalid padding side: {padding_side}, should be 'left' or 'right'")
+
+        return encoded_inputs
+
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
+        """
+        Converts a sequence of tokens into a single string. The simplest way to do it is `" ".join(tokens)`, but we
+        often want to remove sub-word tokenization artifacts at the same time.
+
+        Args:
+            tokens (`list[str]`): The tokens to join into a string.
+
+        Returns:
+            `str`: The joined tokens.
+        """
+        raise NotImplementedError
+
+    def batch_decode(
+        self,
+        sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> list[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
+
+        Returns:
+            `list[str]`: The list of decoded sentences.
+        """
+        return [
+            self.decode(
+                seq,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                **kwargs,
+            )
+            for seq in sequences
+        ]
+
+    def decode(
+        self,
+        token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Converts a sequence of ids into a string, using the tokenizer and vocabulary, with options to remove special
+        tokens and clean up tokenization spaces.
+
+        Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
+
+        Args:
+            token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
+                List of tokenized input ids. Can be obtained using the `__call__` method.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+            clean_up_tokenization_spaces (`bool`, *optional*):
+                Whether or not to clean up the tokenization spaces. If `None`, will default to
+                `self.clean_up_tokenization_spaces`.
+            kwargs (additional keyword arguments, *optional*):
+                Will be passed to the underlying model specific decode method.
+
+        Returns:
+            `str`: The decoded sentence.
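+
+        Example (illustrative; `token_ids` and the decoded text below are placeholders, not upstream output):
+
+            >>> # assuming `token_ids` was produced by encoding the text "hello world"
+            >>> tokenizer.decode(token_ids, skip_special_tokens=True)
+            'hello world'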
+        """
+        # Convert inputs to python lists
+        token_ids = to_py_obj(token_ids)
+
+        return self._decode(
+            token_ids=token_ids,
+            skip_special_tokens=skip_special_tokens,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            **kwargs,
+        )
+
+    def _decode(
+        self,
+        token_ids: Union[int, list[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> str:
+        raise NotImplementedError
+
+    def get_special_tokens_mask(
+        self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
+    ) -> list[int]:
+        """
+        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
+        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
+
+        Args:
+            token_ids_0 (`list[int]`):
+                List of ids of the first sequence.
+            token_ids_1 (`list[int]`, *optional*):
+                List of ids of the second sequence.
+            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not the token list is already formatted with special tokens for the model.
+
+        Returns:
+            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
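+
+        Example (illustrative; assumes a tokenizer whose special token ids include 101 and 102):
+
+            >>> tokenizer.get_special_tokens_mask([101, 2023, 102], already_has_special_tokens=True)
+            [1, 0, 1]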
+        """
+        assert already_has_special_tokens and token_ids_1 is None, (
+            "You cannot use ``already_has_special_tokens=False`` with this tokenizer. "
+            "Please use a slow (full python) tokenizer to activate this argument. "
+            "Or set `return_special_tokens_mask=True` when calling the encoding method "
+            "to get the special tokens mask in any tokenizer. "
+        )
+
+        all_special_ids = self.all_special_ids  # cache the property
+
+        special_tokens_mask = [1 if token in all_special_ids else 0 for token in token_ids_0]
+
+        return special_tokens_mask
+
+    @staticmethod
+    def clean_up_tokenization(out_string: str) -> str:
+        """
+        Clean up a list of simple English tokenization artifacts such as spaces before punctuation and contracted forms.
+
+        Args:
+            out_string (`str`): The text to clean up.
+
+        Returns:
+            `str`: The cleaned-up string.
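+
+        Example (illustrative):
+
+            >>> PreTrainedTokenizerBase.clean_up_tokenization("Hello , world ! It 's fine .")
+            "Hello, world! It's fine."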
+        """
+        out_string = (
+            out_string.replace(" .", ".")
+            .replace(" ?", "?")
+            .replace(" !", "!")
+            .replace(" ,", ",")
+            .replace(" ' ", "'")
+            .replace(" n't", "n't")
+            .replace(" 'm", "'m")
+            .replace(" 's", "'s")
+            .replace(" 've", "'ve")
+            .replace(" 're", "'re")
+        )
+        return out_string
+
+    def _eventual_warn_about_too_long_sequence(self, ids: list[int], max_length: Optional[int], verbose: bool):
+        """
+        Depending on the input and internal state we might trigger a warning about a sequence that is too long for its
+        corresponding model
+
+        Args:
+            ids (`list[str]`): The ids produced by the tokenization
+            max_length (`int`, *optional*): The max_length desired (does not trigger a warning if it is set)
+            verbose (`bool`): Whether or not to print more information and warnings.
+
+        """
+        if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
+            if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
+                logger.warning(
+                    "Token indices sequence length is longer than the specified maximum sequence length "
+                    f"for this model ({len(ids)} > {self.model_max_length}). Running this sequence through the model "
+                    "will result in indexing errors"
+                )
+            self.deprecation_warnings["sequence-length-is-longer-than-the-specified-maximum"] = True
+
+    def _switch_to_input_mode(self):
+        """
+        Private method to put the tokenizer in input mode (when it has different modes for input/outputs)
+        """
+        pass
+
+    def _switch_to_target_mode(self):
+        """
+        Private method to put the tokenizer in target mode (when it has different modes for input/outputs)
+        """
+        pass
+
+    @contextmanager
+    def as_target_tokenizer(self):
+        """
+        Temporarily sets the tokenizer for encoding the targets. Useful for tokenizers associated with
+        sequence-to-sequence models that need slightly different processing for the labels.
+        """
+        warnings.warn(
+            "`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your "
+            "labels by using the argument `text_target` of the regular `__call__` method (either in the same call as "
+            "your input texts if you use the same keyword arguments, or in a separate call)."
+        )
+        self._switch_to_target_mode()
+        self._in_target_context_manager = True
+        yield
+        self._in_target_context_manager = False
+        self._switch_to_input_mode()
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
+        """
+        Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the
+        library are already mapped with `AutoTokenizer`.
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`):
+                The auto class to register this new tokenizer with.
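+
+        Example (illustrative; `MyCustomTokenizer` is a hypothetical custom class, not part of the library):
+
+            >>> class MyCustomTokenizer(PreTrainedTokenizer):
+            ...     pass
+            >>> MyCustomTokenizer.register_for_auto_class("AutoTokenizer")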
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+    def prepare_seq2seq_batch(
+        self,
+        src_texts: list[str],
+        tgt_texts: Optional[list[str]] = None,
+        max_length: Optional[int] = None,
+        max_target_length: Optional[int] = None,
+        padding: str = "longest",
+        return_tensors: Optional[str] = None,
+        truncation: bool = True,
+        **kwargs,
+    ) -> BatchEncoding:
+        """
+        Prepare model inputs for translation. For best performance, translate one sentence at a time.
+
+        Arguments:
+            src_texts (`list[str]`):
+                List of documents to summarize or source language texts.
+            tgt_texts (`list`, *optional*):
+                List of summaries or target language texts.
+            max_length (`int`, *optional*):
+                Controls the maximum length for encoder inputs (documents to summarize or source language texts). If
+                left unset or set to `None`, this will use the predefined model maximum length if a maximum length is
+                required by one of the truncation/padding parameters. If the model has no specific maximum input length
+                (like XLNet) truncation/padding to a maximum length will be deactivated.
+            max_target_length (`int`, *optional*):
+                Controls the maximum length of decoder inputs (target language texts or summaries). If left unset or set
+                to `None`, this will default to the `max_length` value.
+            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `'longest'`):
+                Activates and controls padding. Accepts the following values:
+
+                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+                  sequence is provided).
+                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+                  acceptable input length for the model if that argument is not provided.
+                - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
+                  lengths).
+            return_tensors (`str` or [`~utils.TensorType`], *optional*):
+                If set, will return tensors instead of list of python integers. Acceptable values are:
+
+                - `'tf'`: Return TensorFlow `tf.constant` objects.
+                - `'pt'`: Return PyTorch `torch.Tensor` objects.
+                - `'np'`: Return Numpy `np.ndarray` objects.
+            truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `True`):
+                Activates and controls truncation. Accepts the following values:
+
+                - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
+                  to the maximum acceptable input length for the model if that argument is not provided. This will
+                  truncate token by token, removing a token from the longest sequence in the pair if a pair of
+                  sequences (or a batch of pairs) is provided.
+                - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
+                  maximum acceptable input length for the model if that argument is not provided. This will only
+                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
+                - `False` or `'do_not_truncate'`: No truncation (i.e., can output batch with sequence lengths
+                  greater than the model maximum admissible input size).
+            **kwargs:
+                Additional keyword arguments passed along to `self.__call__`.
+
+        Return:
+            [`BatchEncoding`]: A [`BatchEncoding`] with the following fields:
+
+            - **input_ids** -- List of token ids to be fed to the encoder.
+            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
+            - **labels** -- List of token ids for tgt_texts.
+
+            The full set of keys `[input_ids, attention_mask, labels]` will only be returned if `tgt_texts` is passed.
+            Otherwise, `input_ids` and `attention_mask` will be the only keys.
+        """
+        # docstyle-ignore
+        formatted_warning = """
+`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of HuggingFace Transformers. Use the regular
+`__call__` method to prepare your inputs and targets.
+
+Here is a short example:
+
+model_inputs = tokenizer(src_texts, text_target=tgt_texts, ...)
+
+If you need to use different keyword arguments for the source and target texts, you should do two calls like
+this:
+
+model_inputs = tokenizer(src_texts, ...)
+labels = tokenizer(text_target=tgt_texts, ...)
+model_inputs["labels"] = labels["input_ids"]
+
+See the documentation of your specific tokenizer for more details on the specific arguments to the tokenizer of choice.
+For a more complete example, see the implementation of `prepare_seq2seq_batch`.
+"""
+        warnings.warn(formatted_warning, FutureWarning)
+        # mBART-specific kwargs that should be ignored by other models.
+        kwargs.pop("src_lang", None)
+        kwargs.pop("tgt_lang", None)
+        if max_length is None:
+            max_length = self.model_max_length
+        model_inputs = self(
+            src_texts,
+            add_special_tokens=True,
+            return_tensors=return_tensors,
+            max_length=max_length,
+            padding=padding,
+            truncation=truncation,
+            **kwargs,
+        )
+        if tgt_texts is None:
+            return model_inputs
+        # Process tgt_texts
+        if max_target_length is None:
+            max_target_length = max_length
+        with self.as_target_tokenizer():
+            labels = self(
+                tgt_texts,
+                add_special_tokens=True,
+                return_tensors=return_tensors,
+                padding=padding,
+                max_length=max_target_length,
+                truncation=truncation,
+                **kwargs,
+            )
+        model_inputs["labels"] = labels["input_ids"]
+        return model_inputs
+
+
+def get_fast_tokenizer_file(tokenization_files: list[str]) -> str:
+    """
+    Get the tokenization file to use for this version of transformers.
+
+    Args:
+        tokenization_files (`list[str]`): The list of available tokenizer files.
+
+    Returns:
+        `str`: The tokenization file to use.
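+
+    Illustrative sketch (not from the upstream docstring): assuming files follow the
+    `tokenizer.<version>.json` naming scheme and the installed `transformers` version is at least 4.0.0,
+    `get_fast_tokenizer_file(["tokenizer.json", "tokenizer.4.0.0.json"])` would return
+    `"tokenizer.4.0.0.json"`; with no versioned file it falls back to `FULL_TOKENIZER_FILE`.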
+    """
+    tokenizer_files_map = {}
+    for file_name in tokenization_files:
+        search = _re_tokenizer_file.search(file_name)
+        if search is not None:
+            v = search.groups()[0]
+            tokenizer_files_map[v] = file_name
+    available_versions = sorted(tokenizer_files_map.keys())
+
+    # Default to FULL_TOKENIZER_FILE, then look for the most recent compatible versioned file.
+    tokenizer_file = FULL_TOKENIZER_FILE
+    transformers_version = version.parse(__version__)
+    for v in available_versions:
+        if version.parse(v) <= transformers_version:
+            tokenizer_file = tokenizer_files_map[v]
+        else:
+            # No point going further since the versions are sorted.
+            break
+
+    return tokenizer_file
+
+
+# To update the docstring, we need to copy the method, otherwise we change the original docstring.
+PreTrainedTokenizerBase.push_to_hub = copy_func(PreTrainedTokenizerBase.push_to_hub)
+if PreTrainedTokenizerBase.push_to_hub.__doc__ is not None:
+    PreTrainedTokenizerBase.push_to_hub.__doc__ = PreTrainedTokenizerBase.push_to_hub.__doc__.format(
+        object="tokenizer", object_class="AutoTokenizer", object_files="tokenizer files"
+    )
diff --git a/phivenv/Lib/site-packages/transformers/tokenization_utils_fast.py b/phivenv/Lib/site-packages/transformers/tokenization_utils_fast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ca5bed60511396bf71e3dd1ad25b1070d67c87c
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/tokenization_utils_fast.py
@@ -0,0 +1,922 @@
+# Copyright 2020 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
+see tokenization_utils.py
+"""
+
+import copy
+import json
+import os
+from collections import defaultdict
+from collections.abc import Iterable
+from typing import Any, Optional, Union
+
+import tokenizers.pre_tokenizers as pre_tokenizers_fast
+from tokenizers import Encoding as EncodingFast
+from tokenizers import Tokenizer as TokenizerFast
+from tokenizers.decoders import Decoder as DecoderFast
+from tokenizers.trainers import BpeTrainer, UnigramTrainer, WordLevelTrainer, WordPieceTrainer
+
+from .convert_slow_tokenizer import convert_slow_tokenizer
+from .integrations.ggml import convert_gguf_tokenizer
+from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
+from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils_base import (
+    INIT_TOKENIZER_DOCSTRING,
+    AddedToken,
+    BatchEncoding,
+    PreTokenizedInput,
+    PreTokenizedInputPair,
+    PreTrainedTokenizerBase,
+    SpecialTokensMixin,
+    TextInput,
+    TextInputPair,
+    TruncationStrategy,
+)
+from .utils import PaddingStrategy, add_end_docstrings, logging
+
+
+logger = logging.get_logger(__name__)
+
+# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
+TOKENIZER_FILE = "tokenizer.json"
+SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
+TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+TIKTOKEN_VOCAB_FILE = "tokenizer.model"
+
+# Slow tokenizers have an additional added tokens files
+ADDED_TOKENS_FILE = "added_tokens.json"
+
+INIT_TOKENIZER_DOCSTRING += """
+        tokenizer_object ([`tokenizers.Tokenizer`]):
+            A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
+            tokenizers](../fast_tokenizers) for more information.
+        tokenizer_file ([`str`]):
+            A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
+            tokenizers.
+"""
+
+MODEL_TO_TRAINER_MAPPING = {
+    "BPE": BpeTrainer,
+    "Unigram": UnigramTrainer,
+    "WordLevel": WordLevelTrainer,
+    "WordPiece": WordPieceTrainer,
+}
+
+VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE, "vocab_file": TIKTOKEN_VOCAB_FILE}
+
+
+@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
+class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
+    """
+    Base class for all fast tokenizers (wrapping HuggingFace tokenizers library).
+
+    Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].
+
+    Handles all the shared methods for tokenization and special tokens, as well as methods for
+    downloading/caching/loading pretrained tokenizers, as well as adding tokens to the vocabulary.
+
+    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
+    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
+    """
+
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class: PreTrainedTokenizer = None
+
+    def __init__(self, *args, **kwargs):
+        tokenizer_object = kwargs.pop("tokenizer_object", None)
+        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
+        gguf_file = kwargs.pop("gguf_file", None)
+        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
+        from_slow = kwargs.pop("from_slow", False)
+        added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
+        self.add_prefix_space = kwargs.get("add_prefix_space", False)
+
+        if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
+            raise ValueError(
+                "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
+                "have sentencepiece installed."
+            )
+
+        if tokenizer_object is not None:
+            fast_tokenizer = copy.deepcopy(tokenizer_object)
+        elif fast_tokenizer_file is not None and not from_slow:
+            # We have a serialization from tokenizers which lets us directly build the backend
+            fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
+        elif slow_tokenizer:
+            # We need to convert a slow tokenizer to build the backend
+            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+        elif gguf_file is not None:
+            # We need to convert a slow tokenizer to build the backend
+            gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
+            architecture = gguf_param["config"]["model_type"]
+            tokenizer_dict = gguf_param["tokenizer"]
+            tokenizer_config = gguf_param["tokenizer_config"]
+            fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
+            kwargs.update(tokenizer_config)
+            if len(additional_kwargs) > 0:
+                kwargs.update(additional_kwargs)
+        elif self.slow_tokenizer_class is not None and slow_tokenizer is not False:
+            # We need to create and convert a slow tokenizer to build the backend
+            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
+            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
+        elif not slow_tokenizer:
+            # We tried loading a slow_tokenizer with spm and failed, try to load with tiktoken
+            self.vocab_file = kwargs.get("vocab_file")
+            self.additional_special_tokens = kwargs.get("additional_special_tokens", [])
+            fast_tokenizer = convert_slow_tokenizer(self, from_tiktoken=True)
+            slow_tokenizer = None
+        else:
+            raise ValueError(
+                "Couldn't instantiate the backend tokenizer from one of: \n"
+                "(1) a `tokenizers` library serialization file, \n"
+                "(2) a slow tokenizer instance to convert or \n"
+                "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
+                "You need to have sentencepiece or tiktoken installed to convert a slow tokenizer to a fast one."
+            )
+
+        self._tokenizer = fast_tokenizer
+
+        if slow_tokenizer is not None:
+            kwargs.update(slow_tokenizer.init_kwargs)
+
+        self._decode_use_source_tokenizer = False
+
+        _truncation = self._tokenizer.truncation
+
+        if _truncation is not None:
+            self._tokenizer.enable_truncation(**_truncation)
+            kwargs.setdefault("max_length", _truncation["max_length"])
+            kwargs.setdefault("truncation_side", _truncation["direction"])
+            kwargs.setdefault("stride", _truncation["stride"])
+            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
+        else:
+            self._tokenizer.no_truncation()
+
+        _padding = self._tokenizer.padding
+        if _padding is not None:
+            self._tokenizer.enable_padding(**_padding)
+            kwargs.setdefault("pad_token", _padding["pad_token"])
+            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
+            kwargs.setdefault("padding_side", _padding["direction"])
+            kwargs.setdefault("max_length", _padding["length"])
+            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
+
+        # We call this after having initialized the backend tokenizer because we update it.
+        super().__init__(**kwargs)
+        self._tokenizer.encode_special_tokens = self.split_special_tokens
+
+        added_tokens_decoder_hash = {hash(repr(token)) for token in self.added_tokens_decoder}
+        tokens_to_add = [
+            token
+            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
+            if hash(repr(token)) not in added_tokens_decoder_hash
+        ]
+        encoder = list(self.added_tokens_encoder.keys()) + [str(token) for token in tokens_to_add]
+        # if some of the special tokens are strings, we check if we don't already have a token
+        tokens_to_add += [
+            token for token in self.all_special_tokens_extended if token not in encoder and token not in tokens_to_add
+        ]
+
+        if len(tokens_to_add) > 0:
+            tokens = []
+            special_tokens = self.all_special_tokens
+            for token in tokens_to_add:
+                is_special = (
+                    (token.special or str(token) in special_tokens)
+                    if isinstance(token, AddedToken)
+                    else str(token) in special_tokens
+                )
+                if isinstance(token, str):
+                    token = AddedToken(token, special=is_special)
+                else:
+                    token.special = is_special
+                tokens.append(token)
+            if tokens:
+                self.add_tokens(tokens)
+
+        try:
+            pre_tok_state = json.loads(self.backend_tokenizer.pre_tokenizer.__getstate__())
+            if pre_tok_state.get("add_prefix_space", self.add_prefix_space) != self.add_prefix_space:
+                pre_tok_class = getattr(pre_tokenizers_fast, pre_tok_state.pop("type"))
+                pre_tok_state["add_prefix_space"] = self.add_prefix_space
+                self.backend_tokenizer.pre_tokenizer = pre_tok_class(**pre_tok_state)
+        except Exception:
+            # We'll get an error if there is no pre_tokenizer, or if it's a custom pre_tokenizer that can
+            # not be serialized. In those cases, we just ignore the error as there's no pre_tokenizer
+            # for which we need to update the `add_prefix_space` attribute.
+            pass
+
+    @property
+    def is_fast(self) -> bool:
+        return True
+
+    @property
+    def can_save_slow_tokenizer(self) -> bool:
+        """
+        `bool`: Whether or not the slow tokenizer can be saved. For a sentencepiece based slow tokenizer, this
+        can only be `True` if the original `"sentencepiece.model"` was not deleted.
+        """
+        if "vocab_file" in self.vocab_files_names and self.vocab_files_names["vocab_file"].endswith(".model"):
+            if hasattr(self, "vocab_file") and self.vocab_file:
+                # If the vocab file is a sentencepiece model, we can save it
+                return os.path.isfile(self.vocab_file)
+            return False
+        else:
+            return True
+
+    @property
+    def vocab_size(self) -> int:
+        """
+        `int`: Size of the base vocabulary (without the added tokens).
+        """
+        return self._tokenizer.get_vocab_size(with_added_tokens=False)
+
+    def get_vocab(self) -> dict[str, int]:
+        return self._tokenizer.get_vocab(with_added_tokens=True)
+
+    @property
+    def vocab(self) -> dict[str, int]:
+        return self.get_vocab()
+
+    @property
+    def added_tokens_encoder(self) -> dict[str, int]:
+        """
+        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
+        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
+        """
+        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
+
+    @property
+    def added_tokens_decoder(self) -> dict[int, AddedToken]:
+        """
+        Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.
+
+        Returns:
+            `dict[int, AddedToken]`: The added tokens, mapping indices to `AddedToken` instances.
+        """
+        return self._tokenizer.get_added_tokens_decoder()
+
+    def get_added_vocab(self) -> dict[str, int]:
+        """
+        Returns the added tokens in the vocabulary as a dictionary of token to index.
+
+        Returns:
+            `dict[str, int]`: The added tokens.
+        """
+        return {k.content: v for v, k in sorted(self.added_tokens_decoder.items(), key=lambda item: item[0])}
+
+    def __bool__(self) -> bool:
+        """
+        Returns True, to avoid expensive `assert tokenizer` gotchas.
+        """
+        return True
+
+    def __len__(self) -> int:
+        """
+        Size of the full vocabulary with the added tokens.
+        """
+        return self._tokenizer.get_vocab_size(with_added_tokens=True)
+
+    @property
+    def backend_tokenizer(self) -> TokenizerFast:
+        """
+        `tokenizers.Tokenizer`: The Rust tokenizer used as a backend.
+        """
+        return self._tokenizer
+
+    @property
+    def decoder(self) -> DecoderFast:
+        """
+        `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer.
+        """
+        return self._tokenizer.decoder
+
+    def _convert_encoding(
+        self,
+        encoding: EncodingFast,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+    ) -> tuple[dict[str, Any], list[EncodingFast]]:
+        """
+        Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a list
+        of encodings, taking care of building a batch from overflowing tokens.
+
+        Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are
+        lists (overflows) of lists (tokens).
+
+        Output shape: (overflows, sequence length)
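+
+        For example (illustrative), a single input that produces 3 chunks (the original encoding plus 2
+        overflowing windows) yields an `input_ids` entry containing 3 lists of token ids.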
+        """
+        if return_token_type_ids is None:
+            return_token_type_ids = "token_type_ids" in self.model_input_names
+        if return_attention_mask is None:
+            return_attention_mask = "attention_mask" in self.model_input_names
+
+        if return_overflowing_tokens and encoding.overflowing is not None:
+            encodings = [encoding] + encoding.overflowing
+        else:
+            encodings = [encoding]
+
+        encoding_dict = defaultdict(list)
+        for e in encodings:
+            encoding_dict["input_ids"].append(e.ids)
+
+            if return_token_type_ids:
+                encoding_dict["token_type_ids"].append(e.type_ids)
+            if return_attention_mask:
+                encoding_dict["attention_mask"].append(e.attention_mask)
+            if return_special_tokens_mask:
+                encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
+            if return_offsets_mapping:
+                encoding_dict["offset_mapping"].append(e.offsets)
+            if return_length:
+                encoding_dict["length"].append(len(e.ids))
+
+        return encoding_dict, encodings
+
+    def convert_tokens_to_ids(self, tokens: Union[str, Iterable[str]]) -> Union[int, list[int]]:
+        """
+        Converts a token string (or a sequence of tokens) into a single integer id (or an iterable of ids), using the
+        vocabulary.
+
+        Args:
+            tokens (`str` or `Iterable[str]`): One or several token(s) to convert to token id(s).
+
+        Returns:
+            `int` or `list[int]`: The token id or list of token ids.
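+
+        Example (illustrative; the ids below are made up and depend entirely on the loaded vocabulary):
+
+            >>> tokenizer.convert_tokens_to_ids("hello")
+            7592
+            >>> tokenizer.convert_tokens_to_ids(["hello", "world"])
+            [7592, 2088]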
+        """
+        if isinstance(tokens, str):
+            return self._convert_token_to_id_with_added_voc(tokens)
+
+        return [self._convert_token_to_id_with_added_voc(token) for token in tokens]
+
+    def _convert_token_to_id_with_added_voc(self, token: str) -> int:
+        index = self._tokenizer.token_to_id(token)
+        if index is None:
+            return self.unk_token_id
+        return index
+
+    def _convert_id_to_token(self, index: int) -> Optional[str]:
+        return self._tokenizer.id_to_token(int(index))
+
+    def _add_tokens(self, new_tokens: list[Union[str, AddedToken]], special_tokens=False) -> int:
+        if special_tokens:
+            return self._tokenizer.add_special_tokens(new_tokens)
+
+        return self._tokenizer.add_tokens(new_tokens)
+
+    def num_special_tokens_to_add(self, pair: bool = False) -> int:
+        """
+        Returns the number of added tokens when encoding a sequence with special tokens.
+
+        This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
+        this inside your training loop.
+
+        Args:
+            pair (`bool`, *optional*, defaults to `False`):
+                Whether the number of added tokens should be computed in the case of a sequence pair or a single
+                sequence.
+
+        Returns:
+            `int`: Number of special tokens added to sequences.
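+
+        Example (illustrative; the counts depend on the model, here a BERT-like tokenizer adding [CLS] and [SEP]):
+
+            >>> tokenizer.num_special_tokens_to_add()
+            2
+            >>> tokenizer.num_special_tokens_to_add(pair=True)
+            3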
+        """
+        return self._tokenizer.num_special_tokens_to_add(pair)
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, list[int]], skip_special_tokens: bool = False
+    ) -> Union[str, list[str]]:
+        """
+        Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
+        and added tokens.
+
+        Args:
+            ids (`int` or `list[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `list[str]`: The decoded token(s).
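+
+        Example (illustrative; assumes a vocabulary in which these made-up ids map to these tokens):
+
+            >>> tokenizer.convert_ids_to_tokens([7592, 2088])
+            ['hello', 'world']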
+        """
+        if isinstance(ids, int):
+            return self._tokenizer.id_to_token(ids)
+        tokens = []
+        # self.all_special_ids is an @property which may be slow, so only compute it once before the loop
+        ids_to_skip = set(self.all_special_ids) if skip_special_tokens else set()
+        for index in ids:
+            index = int(index)
+            if index in ids_to_skip:
+                continue
+            tokens.append(self._tokenizer.id_to_token(index))
+        return tokens
+
+    def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> list[str]:
+        return self.encode_plus(text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs).tokens()
+
+    def set_truncation_and_padding(
+        self,
+        padding_strategy: PaddingStrategy,
+        truncation_strategy: TruncationStrategy,
+        max_length: int,
+        stride: int,
+        pad_to_multiple_of: Optional[int],
+        padding_side: Optional[str],
+    ):
+        """
+        Define the truncation and the padding strategies for fast tokenizers (provided by HuggingFace tokenizers
+        library) and restore the tokenizer settings afterwards.
+
+        The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer set a
+        padding / truncation strategy before, then it will be reset to no padding / truncation when exiting the managed
+        section.
+
+        Args:
+            padding_strategy ([`~utils.PaddingStrategy`]):
+                The kind of padding that will be applied to the input
+            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]):
+                The kind of truncation that will be applied to the input
+            max_length (`int`):
+                The maximum size of a sequence.
+            stride (`int`):
+                The stride to use when handling overflow.
+            pad_to_multiple_of (`int`, *optional*):
+                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
+                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
+        """
+        _truncation = self._tokenizer.truncation
+        _padding = self._tokenizer.padding
+        # Set truncation and padding on the backend tokenizer
+        if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
+            if _truncation is not None:
+                self._tokenizer.no_truncation()
+        else:
+            target = {
+                "max_length": max_length,
+                "stride": stride,
+                "strategy": truncation_strategy.value,
+                "direction": self.truncation_side,
+            }
+
+            # _truncation might contain more keys than the target `transformers`
+            # supports. Use only the target keys to trigger `enable_truncation`.
+            # This should enable this code to work on various `tokenizers`
+            # versions.
+            if _truncation is None:
+                current = None
+            else:
+                current = {k: _truncation.get(k, None) for k in target}
+
+            if current != target:
+                self._tokenizer.enable_truncation(**target)
+
+        if padding_strategy == PaddingStrategy.DO_NOT_PAD:
+            if _padding is not None:
+                self._tokenizer.no_padding()
+        else:
+            length = max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
+            target = {
+                "length": length,
+                "direction": padding_side if padding_side is not None else self.padding_side,
+                "pad_id": self.pad_token_id,
+                "pad_token": self.pad_token,
+                "pad_type_id": self.pad_token_type_id,
+                "pad_to_multiple_of": pad_to_multiple_of,
+            }
+            if _padding != target:
+                self._tokenizer.enable_padding(**target)
+
+    def _batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: Union[
+            list[TextInput], list[TextInputPair], list[PreTokenizedInput], list[PreTokenizedInputPair]
+        ],
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[str] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+    ) -> BatchEncoding:
+        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
+            raise TypeError(
+                f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})"
+            )
+
+        # Set the truncation and padding strategy and restore the initial configuration
+        self.set_truncation_and_padding(
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+        )
+
+        if self._tokenizer.encode_special_tokens != split_special_tokens:
+            self._tokenizer.encode_special_tokens = split_special_tokens
+
+        encodings = self._tokenizer.encode_batch(
+            batch_text_or_text_pairs,
+            add_special_tokens=add_special_tokens,
+            is_pretokenized=is_split_into_words,
+        )
+
+        # Convert encoding to dict
+        # `Tokens` has type: tuple[
+        #                       list[dict[str, list[list[int]]]] or list[dict[str, 2D-Tensor]],
+        #                       list[EncodingFast]
+        #                    ]
+        # with nested dimensions corresponding to batch, overflows, sequence length
+        tokens_and_encodings = [
+            self._convert_encoding(
+                encoding=encoding,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+            )
+            for encoding in encodings
+        ]
+
+        # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
+        # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
+        # (we say ~ because the number of overflows varies with the example in the batch)
+        #
+        # To match each overflowing sample with the original sample in the batch
+        # we add an overflow_to_sample_mapping array (see below)
+        sanitized_tokens = {}
+        for key in tokens_and_encodings[0][0]:
+            stack = [e for item, _ in tokens_and_encodings for e in item[key]]
+            sanitized_tokens[key] = stack
+        sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]
+
+        # If returning overflowing tokens, we need to return a mapping
+        # from the batch idx to the original sample
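+        # e.g. (illustrative) if the first sample yields 2 chunks (original + 1 overflow) and the
+        # second yields 1 chunk, the mapping is [0, 0, 1]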
+        if return_overflowing_tokens:
+            overflow_to_sample_mapping = []
+            for i, (toks, _) in enumerate(tokens_and_encodings):
+                overflow_to_sample_mapping += [i] * len(toks["input_ids"])
+            sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping
+
+        for input_ids in sanitized_tokens["input_ids"]:
+            self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
+        return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors)
+
+    def _encode_plus(
+        self,
+        text: Union[TextInput, PreTokenizedInput],
+        text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
+        add_special_tokens: bool = True,
+        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
+        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        padding_side: Optional[str] = None,
+        return_tensors: Optional[bool] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        split_special_tokens: bool = False,
+        **kwargs,
+    ) -> BatchEncoding:
+        batched_input = [(text, text_pair)] if text_pair else [text]
+        batched_output = self._batch_encode_plus(
+            batched_input,
+            is_split_into_words=is_split_into_words,
+            add_special_tokens=add_special_tokens,
+            padding_strategy=padding_strategy,
+            truncation_strategy=truncation_strategy,
+            max_length=max_length,
+            stride=stride,
+            pad_to_multiple_of=pad_to_multiple_of,
+            padding_side=padding_side,
+            return_tensors=return_tensors,
+            return_token_type_ids=return_token_type_ids,
+            return_attention_mask=return_attention_mask,
+            return_overflowing_tokens=return_overflowing_tokens,
+            return_special_tokens_mask=return_special_tokens_mask,
+            return_offsets_mapping=return_offsets_mapping,
+            return_length=return_length,
+            verbose=verbose,
+            split_special_tokens=split_special_tokens,
+            **kwargs,
+        )
+
+        # If `return_tensors` is None, we can remove the leading batch axis
+        # Overflowing tokens are returned as a batch of outputs, so we keep the batch axis in that case
+        if return_tensors is None and not return_overflowing_tokens:
+            batched_output = BatchEncoding(
+                {
+                    key: (value[0] if len(value) > 0 and isinstance(value[0], list) else value)
+                    for key, value in batched_output.items()
+                },
+                batched_output.encodings,
+            )
+
+        self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose)
+
+        return batched_output
+
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
+        return (
+            self.backend_tokenizer.decoder.decode(tokens)
+            if self.backend_tokenizer.decoder is not None
+            else " ".join(tokens)
+        )
+
+    def _decode(
+        self,
+        token_ids: Union[int, list[int]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+        **kwargs,
+    ) -> str:
+        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
+
+        if isinstance(token_ids, int):
+            token_ids = [token_ids]
+        text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
+        clean_up_tokenization_spaces = (
+            clean_up_tokenization_spaces
+            if clean_up_tokenization_spaces is not None
+            else self.clean_up_tokenization_spaces
+        )
+        if clean_up_tokenization_spaces:
+            clean_text = self.clean_up_tokenization(text)
+            return clean_text
+        else:
+            return text
+
+    def _save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        file_names: tuple[str, ...],
+        legacy_format: Optional[bool] = None,
+        filename_prefix: Optional[str] = None,
+    ) -> tuple[str, ...]:
+        """
+        Save a tokenizer either in the slow-tokenizer/legacy format (separate vocabulary and added-tokens files) or in
+        the fast format (a single JSON file containing the whole tokenizer), depending on `legacy_format`.
+        """
+        save_directory = str(save_directory)
+
+        if self.slow_tokenizer_class is None and legacy_format is True:
+            raise ValueError(
+                "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
+                " might consider leaving the legacy_format at `None` or setting it to `False`."
+            )
+
+        save_slow = (
+            (legacy_format is None or legacy_format is True)
+            and self.slow_tokenizer_class is not None
+            and self.can_save_slow_tokenizer
+        )
+        save_fast = legacy_format is None or legacy_format is False
+
+        if save_slow:
+            added_tokens_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE
+            )
+            # make sure to be forward compatible
+            added_vocab = {tok: index for tok, index in self.added_tokens_encoder.items() if index >= self.vocab_size}
+            if added_vocab:
+                with open(added_tokens_file, "w", encoding="utf-8") as f:
+                    out_str = json.dumps(added_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
+                    f.write(out_str)
+
+            vocab_files = self.save_vocabulary(save_directory, filename_prefix=filename_prefix)
+            file_names = file_names + vocab_files + (added_tokens_file,)
+
+        if save_fast:
+            tokenizer_file = os.path.join(
+                save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE
+            )
+            self.backend_tokenizer.save(tokenizer_file)
+            file_names = file_names + (tokenizer_file,)
+
+        return file_names
+
+    def train_new_from_iterator(
+        self,
+        text_iterator,
+        vocab_size,
+        length=None,
+        new_special_tokens=None,
+        special_tokens_map=None,
+        **kwargs,
+    ):
+        """
+        Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
+        as the current one.
+
+        Args:
+            text_iterator (generator of `list[str]`):
+                The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
+                if you have everything in memory.
+            vocab_size (`int`):
+                The size of the vocabulary you want for your tokenizer.
+            length (`int`, *optional*):
+                The total number of sequences in the iterator. This is used to provide meaningful progress tracking.
+            new_special_tokens (list of `str` or `AddedToken`, *optional*):
+                A list of new special tokens to add to the tokenizer you are training.
+            special_tokens_map (`dict[str, str]`, *optional*):
+                If you want to rename some of the special tokens this tokenizer uses, pass along a mapping from old
+                special token name to new special token name in this argument.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.
+
+        Returns:
+            [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
+            `text_iterator`.
+
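+        Example (an illustrative sketch, not from the upstream docstring; assumes the `bert-base-uncased` checkpoint is
+        reachable and the corpus fits in memory):
+
+        ```python
+        >>> from transformers import AutoTokenizer
+
+        >>> old_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        >>> corpus = [["low lower lowest", "new newer newest"], ["widening the vocabulary"]]
+        >>> new_tokenizer = old_tokenizer.train_new_from_iterator(iter(corpus), vocab_size=1000)  # doctest: +SKIP
+        ```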
+        """
+        tokenizer_json = json.loads(self._tokenizer.to_str())
+        # Remove added tokens for now (uses IDs of tokens)
+        added_tokens = tokenizer_json.pop("added_tokens")
+        # Remove post processor for now (uses IDs of tokens)
+        post_processor = tokenizer_json.pop("post_processor")
+
+        unk_token = None
+        # Remove vocab
+        if tokenizer_json["model"]["type"] == "BPE":
+            tokenizer_json["model"]["vocab"] = {}
+            tokenizer_json["model"]["merges"] = []
+        elif tokenizer_json["model"]["type"] == "Unigram":
+            if tokenizer_json["model"]["unk_id"] is not None:
+                unk_id = tokenizer_json["model"]["unk_id"]
+                unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
+                if special_tokens_map is not None and unk_token in special_tokens_map:
+                    unk_token = special_tokens_map[unk_token]
+                tokenizer_json["model"]["unk_id"] = 0
+                tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
+        elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
+            tokenizer_json["model"]["vocab"] = {}
+        else:
+            raise ValueError(
+                f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}) "
+                "only BPE, Unigram, WordLevel and WordPiece."
+            )
+
+        if (
+            special_tokens_map is not None
+            and "unk_token" in tokenizer_json["model"]
+            and tokenizer_json["model"]["unk_token"] in special_tokens_map
+        ):
+            tokenizer_json["model"]["unk_token"] = special_tokens_map[tokenizer_json["model"]["unk_token"]]
+
+        tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+
+        # Get the special tokens from the current tokenizer if none are specified.
+        special_tokens = []
+        for added_token in added_tokens:
+            special = added_token.pop("special", None)
+            _ = added_token.pop("id", None)
+            if tokenizer_json["model"]["type"] != "Unigram" and not special:
+                continue
+            if special_tokens_map is not None and added_token["content"] in special_tokens_map:
+                added_token["content"] = special_tokens_map[added_token["content"]]
+            special_tokens.append(AddedToken(**added_token))
+
+        if new_special_tokens is not None:
+            special_tokens.extend(new_special_tokens)
+
+        # The trainer needs to know the end-of-word suffix and continuing-subword prefix used by BPE
+        if (
+            tokenizer_json["model"]["type"] == "BPE"
+            and "continuing_subword_prefix" not in kwargs
+            and tokenizer_json["model"]["continuing_subword_prefix"] is not None
+        ):
+            kwargs["continuing_subword_prefix"] = tokenizer_json["model"]["continuing_subword_prefix"]
+        if (
+            tokenizer_json["model"]["type"] == "BPE"
+            and "end_of_word_suffix" not in kwargs
+            and tokenizer_json["model"]["end_of_word_suffix"] is not None
+        ):
+            kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
+        if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
+            kwargs["unk_token"] = unk_token
+        if tokenizer_json["pre_tokenizer"] is not None:
+            # Make the intended precedence explicit: ByteLevel directly, or a Sequence containing a ByteLevel step
+            if tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel" or (
+                tokenizer_json["pre_tokenizer"]["type"] == "Sequence"
+                and "pretokenizers" in tokenizer_json["pre_tokenizer"]
+                and any(
+                    pretokenizer["type"] == "ByteLevel"
+                    for pretokenizer in tokenizer_json["pre_tokenizer"]["pretokenizers"]
+                )
+            ):
+                kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()
+
+        trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
+        trainer = trainer_class(vocab_size=vocab_size, special_tokens=special_tokens, **kwargs)
+        tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer)
+
+        if post_processor is not None:
+            trained_tokenizer_json = json.loads(tokenizer.to_str())
+            # Almost done, we just have to adjust the token IDs in the post processor
+            if "special_tokens" in post_processor:
+                for key in post_processor["special_tokens"]:
+                    tokens = post_processor["special_tokens"][key]["tokens"]
+                    if special_tokens_map is not None:
+                        tokens = [special_tokens_map.get(token, token) for token in tokens]
+                    post_processor["special_tokens"][key]["tokens"] = tokens
+                    for token in tokens:
+                        token_id = tokenizer.token_to_id(token)
+                        if token_id is None:
+                            raise ValueError(
+                                "Attempted to set a token in the post processor that does not exist in the mapping"
+                            )
+
+                    post_processor["special_tokens"][key]["ids"] = [tokenizer.token_to_id(token) for token in tokens]
+
+            for special_token in ["cls", "sep"]:
+                if special_token in post_processor:
+                    token, _ = post_processor[special_token]
+                    if special_tokens_map is not None and token in special_tokens_map:
+                        token = special_tokens_map[token]
+                    token_id = tokenizer.token_to_id(token)
+                    if token_id is None:
+                        raise ValueError(
+                            "Attempted to set a token in the post processor that does not exist in the mapping"
+                        )
+                    post_processor[special_token] = [token, token_id]
+
+            trained_tokenizer_json["post_processor"] = post_processor
+            tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))
+
+        kwargs = self.init_kwargs.copy()
+        # Map pad/cls/mask token at the Transformers level
+        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
+        special_tokens_list.remove("additional_special_tokens")
+        for token in special_tokens_list:
+            if getattr(self, token) is not None:
+                special_token = getattr(self, token)
+                if special_tokens_map is not None and special_token in special_tokens_map:
+                    special_token = special_tokens_map[special_token]
+
+                special_token_full = self._special_tokens_map.get(token, None)
+                if isinstance(special_token_full, AddedToken):
+                    # Create an added token with the same parameters except the content
+                    kwargs[token] = AddedToken(
+                        special_token,
+                        single_word=special_token_full.single_word,
+                        lstrip=special_token_full.lstrip,
+                        rstrip=special_token_full.rstrip,
+                        normalized=special_token_full.normalized,
+                        special=True,
+                    )
+                else:
+                    kwargs[token] = special_token
+
+        additional_special_tokens = self.additional_special_tokens
+        if new_special_tokens is not None:
+            additional_special_tokens.extend(new_special_tokens)
+        if len(additional_special_tokens) > 0:
+            kwargs["additional_special_tokens"] = additional_special_tokens
+
+        return self.__class__(tokenizer_object=tokenizer, **kwargs)
diff --git a/phivenv/Lib/site-packages/transformers/trainer.py b/phivenv/Lib/site-packages/transformers/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b931ac0d7ffb6e1c572d24e522a6b0469236136
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/trainer.py
@@ -0,0 +1,5685 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
+"""
+
+import contextlib
+import copy
+import functools
+import glob
+import importlib.metadata
+import inspect
+import json
+import math
+import os
+import random
+import re
+import shutil
+import sys
+import tempfile
+import time
+import warnings
+from collections.abc import Iterator, Mapping
+from functools import partial
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+
+# Integrations must be imported before ML frameworks:
+# ruff: isort: off
+from .integrations import (
+    get_reporting_integration_callbacks,
+)
+
+# ruff: isort: on
+
+import huggingface_hub.utils as hf_hub_utils
+import numpy as np
+import torch
+import torch.distributed as dist
+from huggingface_hub import ModelCard, create_repo, upload_folder
+from packaging import version
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
+
+from . import __version__
+from .configuration_utils import PretrainedConfig
+from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from .debug_utils import DebugOption, DebugUnderflowOverflow
+from .feature_extraction_sequence_utils import SequenceFeatureExtractor
+from .feature_extraction_utils import FeatureExtractionMixin
+from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
+from .image_processing_utils import BaseImageProcessor
+from .integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
+from .integrations.tpu import tpu_spmd_dataloader
+from .modelcard import TrainingSummary
+from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from .models.auto.modeling_auto import (
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_MAPPING_NAMES,
+)
+from .optimization import Adafactor, get_scheduler
+from .processing_utils import ProcessorMixin
+from .pytorch_utils import (
+    is_torch_greater_or_equal_than_2_3,
+)
+from .tokenization_utils_base import PreTrainedTokenizerBase
+from .trainer_callback import (
+    CallbackHandler,
+    DefaultFlowCallback,
+    ExportableState,
+    PrinterCallback,
+    ProgressCallback,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+)
+from .trainer_pt_utils import (
+    DistributedTensorGatherer,
+    EvalLoopContainer,
+    IterableDatasetShard,
+    LabelSmoother,
+    LayerWiseDummyOptimizer,
+    LengthGroupedSampler,
+    SequentialDistributedSampler,
+    distributed_broadcast_scalars,
+    distributed_concat,
+    find_batch_size,
+    get_model_param_count,
+    get_module_class_from_name,
+    get_parameter_names,
+    nested_concat,
+    nested_detach,
+    nested_numpify,
+    nested_xla_mesh_reduce,
+    reissue_pt_warnings,
+    remove_dummy_checkpoint,
+    set_rng_state_for_device,
+)
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    BestRun,
+    EvalLoopOutput,
+    EvalPrediction,
+    HPSearchBackend,
+    HubStrategy,
+    PredictionOutput,
+    RemoveColumnsCollator,
+    SaveStrategy,
+    TrainerMemoryTracker,
+    TrainOutput,
+    check_target_module_exists,
+    default_compute_objective,
+    denumpify_detensorize,
+    enable_full_determinism,
+    find_executable_batch_size,
+    get_last_checkpoint,
+    has_length,
+    neftune_post_forward_hook,
+    number_of_arguments,
+    seed_worker,
+    set_seed,
+    speed_metrics,
+)
+from .training_args import OptimizerNames, ParallelMode, TrainingArguments
+from .utils import (
+    ADAPTER_CONFIG_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFE_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
+    WEIGHTS_NAME,
+    XLA_FSDPV2_MIN_VERSION,
+    PushInProgress,
+    PushToHubMixin,
+    can_return_loss,
+    check_torch_load_is_safe,
+    find_labels,
+    is_accelerate_available,
+    is_apollo_torch_available,
+    is_bitsandbytes_available,
+    is_datasets_available,
+    is_galore_torch_available,
+    is_grokadamw_available,
+    is_in_notebook,
+    is_liger_kernel_available,
+    is_lomo_available,
+    is_peft_available,
+    is_safetensors_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
+    is_schedulefree_available,
+    is_torch_hpu_available,
+    is_torch_mlu_available,
+    is_torch_mps_available,
+    is_torch_musa_available,
+    is_torch_neuroncore_available,
+    is_torch_npu_available,
+    is_torch_optimi_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    is_torchao_available,
+    logging,
+    strtobool,
+)
+from .utils.deprecation import deprecate_kwarg
+from .utils.import_utils import requires
+from .utils.quantization_config import QuantizationMethod
+
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+    from .utils.notebook import NotebookProgressCallback
+
+    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+if is_datasets_available():
+    import datasets
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.metrics as met
+    import torch_xla.runtime as xr
+    from torch_xla import __version__ as XLA_VERSION
+
+    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
+    if IS_XLA_FSDPV2_POST_2_2:
+        import torch_xla.distributed.spmd as xs
+else:
+    IS_XLA_FSDPV2_POST_2_2 = False
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+    from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+    from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+    IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+if is_safetensors_available():
+    import safetensors.torch
+
+if is_peft_available():
+    from peft import PeftModel
+
+
+if is_accelerate_available():
+    from accelerate import Accelerator, skip_first_batches
+    from accelerate import __version__ as accelerate_version
+    from accelerate.state import AcceleratorState
+    from accelerate.utils import (
+        AutocastKwargs,
+        DistributedDataParallelKwargs,
+        DistributedType,
+        load_fsdp_model,
+        load_fsdp_optimizer,
+        save_fsdp_model,
+        save_fsdp_optimizer,
+    )
+
+    DATA_SAMPLERS = [RandomSampler]
+    if version.parse(accelerate_version) > version.parse("1.3.0"):
+        from accelerate.utils import TorchTensorParallelPlugin
+    if version.parse(accelerate_version) > version.parse("0.23.0"):
+        from accelerate.data_loader import SeedableRandomSampler
+
+        DATA_SAMPLERS += [SeedableRandomSampler]
+
+    if is_deepspeed_available():
+        from accelerate.utils import DeepSpeedSchedulerWrapper
+
+if is_accelerate_available("0.28.0"):
+    from accelerate.utils import DataLoaderConfiguration
+
+
+def _is_peft_model(model):
+    if is_peft_available():
+        classes_to_check = (PeftModel,)
+        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
+        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
+            from peft import PeftMixedModel
+
+            classes_to_check = (*classes_to_check, PeftMixedModel)
+        return isinstance(model, classes_to_check)
+    return False
+
+
+def _get_fsdp_ckpt_kwargs():
+    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
+    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
+        return {"adapter_only": True}
+    else:
+        return {}
+
+
+def safe_globals():
+    # Starting from version 2.4, PyTorch introduces a check for the objects loaded
+    # with torch.load(weights_only=True). Starting from 2.6, weights_only=True becomes
+    # the default and requires allow-listing of the objects being loaded.
+    # See: https://github.com/pytorch/pytorch/pull/137602
+    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
+    # See: https://github.com/huggingface/accelerate/pull/3036
+    if version.parse(torch.__version__).release < version.parse("2.6").release:
+        return contextlib.nullcontext()
+
+    np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
+    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
+    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
+    # all versions of numpy
+    allowlist += [type(np.dtype(np.uint32))]
+
+    return torch.serialization.safe_globals(allowlist)
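+
+
+# Illustrative usage (added commentary, not upstream code): the context manager above is meant to wrap
+# `torch.load(..., weights_only=True)` calls so that the allow-listed numpy globals can be deserialized
+# on torch >= 2.6, e.g.:
+#
+#     with safe_globals():
+#         optimizer_state = torch.load("optimizer.pt", weights_only=True)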
+
+
+if TYPE_CHECKING:
+    import optuna
+
+    if is_datasets_available():
+        import datasets
+
+logger = logging.get_logger(__name__)
+
+
+# Names of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+SCALER_NAME = "scaler.pt"
+OPTIMIZER_NAME_BIN = "optimizer.bin"
+SCHEDULER_NAME = "scheduler.pt"
+FSDP_MODEL_NAME = "pytorch_model_fsdp"
+
+
+@requires(
+    backends=(
+        "torch",
+        "accelerate",
+    )
+)
+class Trainer:
+    """
+    Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
+
+    Args:
+        model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
+            The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+            
+
+            [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
+            your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
+            models.
+
+            
+
+        args ([`TrainingArguments`], *optional*):
+            The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+            `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+        data_collator (`DataCollator`, *optional*):
+            The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+            default to [`default_data_collator`] if no `processing_class` is provided, or to an instance of
+            [`DataCollatorWithPadding`] if the `processing_class` is a feature extractor or tokenizer.
+        train_dataset (Union[`torch.utils.data.Dataset`, `torch.utils.data.IterableDataset`, `datasets.Dataset`], *optional*):
+            The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
+            `model.forward()` method are automatically removed.
+
+            Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
+            `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+            manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+            sets the seed of the RNGs used.
+        eval_dataset (Union[`torch.utils.data.Dataset`, dict[str, `torch.utils.data.Dataset`], `datasets.Dataset`], *optional*):
+             The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+             `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+             dataset prepending the dictionary key to the metric name.
+        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
+            Processing class used to process the data. If provided, will be used to automatically process the inputs
+            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
+            reuse the fine-tuned model.
+            This supersedes the `tokenizer` argument, which is now deprecated.
+        model_init (`Callable[[], PreTrainedModel]`, *optional*):
+            A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
+            from a new instance of the model as given by this function.
+
+            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object,
+            so that it can choose different architectures according to hyperparameters (such as layer count, sizes of
+            inner layers, dropout probabilities, etc.).
+        compute_loss_func (`Callable`, *optional*):
+            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
+            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618) used by [`Trainer`].
+        compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+            a dictionary string to metric values. *Note*: when passing `TrainingArguments` with `batch_eval_metrics`
+            set to `True`, your `compute_metrics` function must take a boolean `compute_result` argument. This will be
+            triggered after the last eval batch to signal that the function needs to calculate and return the global
+            summary statistics rather than accumulating the batch-level statistics.
+        callbacks (List of [`TrainerCallback`], *optional*):
+            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+            detailed in [here](callback).
+
+            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
+        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
+            model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use.
+            Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer.
+        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+
+    Important attributes:
+
+        - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
+          subclass.
+        - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+          original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
+          the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
+          model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
+        - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+          data parallelism, this means some of the model layers are split on different GPUs).
+        - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+          to `False` if model parallel or deepspeed is used, or if the default
+          `TrainingArguments.place_model_on_device` is overridden to return `False`.
+        - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+          in `train`)
+
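+    Example (a minimal sketch, not part of the upstream docstring; assumes `train_ds` and `eval_ds` are
+    already-prepared 🤗 Datasets with a `labels` column and that the checkpoint below is reachable):
+
+    ```python
+    >>> from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
+
+    >>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
+    >>> args = TrainingArguments(output_dir="out", eval_strategy="epoch", num_train_epochs=1)
+    >>> trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
+    >>> trainer.train()  # doctest: +SKIP
+    ```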
+    """
+
+    # Those are used as methods of the Trainer in examples.
+    from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
+
+    @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, nn.Module, None] = None,
+        args: TrainingArguments = None,
+        data_collator: Optional[DataCollator] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None,
+        processing_class: Optional[
+            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
+        ] = None,
+        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        compute_loss_func: Optional[Callable] = None,
+        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
+        callbacks: Optional[list[TrainerCallback]] = None,
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+    ):
+        if args is None:
+            output_dir = "tmp_trainer"
+            logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+            args = TrainingArguments(output_dir=output_dir)
+        if args.batch_eval_metrics and compute_metrics is not None:
+            if "compute_result" not in inspect.signature(compute_metrics).parameters:
+                raise ValueError(
+                    "When using `batch_eval_metrics`, your `compute_metrics` function must take a `compute_result`"
+                    " boolean argument which will be triggered after the last batch of the eval set to signal that the"
+                    " summary statistics should be returned by the function."
+                )
+        if args.eval_strategy is not None and args.eval_strategy != "no" and eval_dataset is None:
+            raise ValueError(
+                f"You have set `args.eval_strategy` to {args.eval_strategy} but you didn't pass an `eval_dataset` to `Trainer`. Either set `args.eval_strategy` to `no` or pass an `eval_dataset`. "
+            )
+        if args.save_strategy == SaveStrategy.BEST or args.load_best_model_at_end:
+            if args.metric_for_best_model is None:
+                raise ValueError(
+                    "`args.metric_for_best_model` must be provided when using 'best' save_strategy or if `args.load_best_model_at_end` is set to `True`."
+                )
+
+        self.args = args
+        self.compute_loss_func = compute_loss_func
+        # Seed must be set before instantiating the model when using model_init
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+
+        self.hp_name = None
+        self.deepspeed = None
+        self.is_in_train = False
+        self.model = model
+        self.create_accelerator_and_postprocess()
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+        self._memory_tracker.start()
+
+        # set the correct log level depending on the node
+        log_level = args.get_process_log_level()
+        logging.set_verbosity(log_level)
+
+        # force device and distributed setup init explicitly
+        args._setup_devices
+
+        if model is None:
+            if model_init is not None:
+                self.model_init = model_init
+                model = self.call_model_init()
+            else:
+                raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+        else:
+            if model_init is not None:
+                warnings.warn(
+                    "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+                    " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+                    " release.",
+                    FutureWarning,
+                )
+            self.model_init = model_init
+
+        if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+            raise ValueError(
+                f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+                "computes hidden states and does not accept any labels. You should choose a model with a head "
+                "suitable for your task like any of the `AutoModelForXxx` listed at "
+                "https://huggingface.co/docs/transformers/model_doc/auto"
+            )
+
+        if getattr(model, "is_parallelizable", False) and getattr(model, "model_parallel", False):
+            self.is_model_parallel = True
+        else:
+            self.is_model_parallel = False
+
+        if getattr(model, "hf_device_map", None) is not None:
+            devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
+            if len(devices) > 1:
+                self.is_model_parallel = True
+            elif len(devices) == 1:
+                self.is_model_parallel = self.args.device != torch.device(devices[0])
+            else:
+                self.is_model_parallel = False
+
+            # warn users
+            if self.is_model_parallel:
+                logger.info(
+                    "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+                    " to `True` to avoid any unexpected behavior such as device placement mismatching."
+                )
+
+        if self.args.use_liger_kernel:
+            if is_liger_kernel_available():
+                from liger_kernel.transformers import _apply_liger_kernel_to_instance
+
+                # Prepare kernel config - use provided config or default (empty dict for default behavior)
+                kernel_config = self.args.liger_kernel_config if self.args.liger_kernel_config is not None else {}
+
+                if isinstance(model, PreTrainedModel):
+                    # Patch the model with liger kernels. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model, **kernel_config)
+                elif hasattr(model, "get_base_model") and isinstance(model.get_base_model(), PreTrainedModel):
+                    # Patch the base model with liger kernels where model is a PeftModel. Use the specified or default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model.get_base_model(), **kernel_config)
+                else:
+                    logger.warning(
+                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
+                    )
+            else:
+                raise ImportError(
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
+                    "Please install it with `pip install liger-kernel`"
+                )
+
+        _is_quantized_and_base_model = getattr(model, "is_quantized", False) and not getattr(
+            model, "_hf_peft_config_loaded", False
+        )
+        _quantization_method_supports_training = (
+            getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
+        )
+
+        _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+            model.hf_quantizer, "is_qat_trainable", False
+        )
+
+        # Filter out quantized + compiled models
+        if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
+            raise ValueError(
+                "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
+            )
+
+        # At this stage the model is already loaded
+        if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
+            raise ValueError(
+                "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
+                " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
+                " for more details"
+            )
+        elif _is_quantized_and_base_model and not _quantization_method_supports_training:
+            raise ValueError(
+                f"The model you are trying to fine-tune is quantized with {model.hf_quantizer.quantization_config.quant_method}"
+                " but that quantization method do not support training. Please open an issue on GitHub: https://github.com/huggingface/transformers"
+                f" to request the support for training support for {model.hf_quantizer.quantization_config.quant_method}"
+            )
+
+        self.is_fsdp_xla_enabled = args.fsdp_config["xla"]
+        if len(args.fsdp) > 0:
+            if self.is_deepspeed_enabled:
+                raise ValueError(
+                    "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+                )
+            if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
+                raise ValueError("Using fsdp only works in distributed training.")
+
+        # one place to sort out whether to place the model on device or not
+        # postpone switching model to cuda when:
+        # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+        # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+        #    and we only use deepspeed for training at the moment
+        # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+        # 4. FSDP - same as MP
+        self.place_model_on_device = args.place_model_on_device
+        if (
+            self.is_model_parallel
+            or self.is_deepspeed_enabled
+            or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
+            or self.is_fsdp_xla_enabled
+            or self.is_fsdp_enabled
+        ):
+            self.place_model_on_device = False
+
+        default_collator = (
+            DataCollatorWithPadding(processing_class)
+            if processing_class is not None
+            and isinstance(processing_class, (PreTrainedTokenizerBase, SequenceFeatureExtractor))
+            else default_data_collator
+        )
+        self.data_collator = data_collator if data_collator is not None else default_collator
+        self.train_dataset = train_dataset
+        self.eval_dataset = eval_dataset
+        self.processing_class = processing_class
+
+        # BnB-quantized models don't support the `.to` operation.
+        if (
+            self.place_model_on_device
+            and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
+        ):
+            self._move_model_to_device(model, args.device)
+
+        # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
+        if self.is_model_parallel:
+            self.args._n_gpu = 1
+
+        # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+        self.model_wrapped = model
+        self.model = model
+
+        # Just in case the model was wrapped outside of the `Trainer`
+        unwrapped_model = self.accelerator.unwrap_model(model)
+        # We also unwrap peft model
+        if _is_peft_model(unwrapped_model):
+            if hasattr(unwrapped_model, "get_base_model"):
+                unwrapped_model = unwrapped_model.get_base_model()
+            elif hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model.base_model, "model"):
+                unwrapped_model = unwrapped_model.base_model.model
+            else:
+                raise AttributeError("Cannot extract base model safely from this PEFT wrapper.")
+
+        # Check if the model has explicit setup for loss kwargs,
+        # if not, check if `**kwargs` are in model.forward
+        if hasattr(unwrapped_model, "accepts_loss_kwargs"):
+            self.model_accepts_loss_kwargs = unwrapped_model.accepts_loss_kwargs
+        else:
+            forward_params = inspect.signature(unwrapped_model.forward).parameters
+            self.model_accepts_loss_kwargs = any(
+                k.kind == inspect.Parameter.VAR_KEYWORD for k in forward_params.values()
+            )
+
+        self.neftune_noise_alpha = args.neftune_noise_alpha
+
+        self.compute_metrics = compute_metrics
+        self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+        self.optimizer, self.lr_scheduler = optimizers
+        self.optimizer_cls_and_kwargs = optimizer_cls_and_kwargs
+        if self.optimizer_cls_and_kwargs is not None and self.optimizer is not None:
+            raise RuntimeError("Passing both `optimizers` and `optimizer_cls_and_kwargs` arguments is incompatible.")
+        if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+            raise RuntimeError(
+                "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+            )
+        if is_torch_xla_available() and self.optimizer is not None:
+            for param in self.model.parameters():
+                model_device = param.device
+                break
+            for param_group in self.optimizer.param_groups:
+                if len(param_group["params"]) > 0:
+                    optimizer_device = param_group["params"][0].device
+                    break
+            if model_device != optimizer_device:
+                raise ValueError(
+                    "The model and the optimizer parameters are not on the same device, which probably means you"
+                    " created an optimizer around your model **before** putting on the device and passing it to the"
+                    " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
+                    " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
+                )
+        if (self.is_fsdp_xla_enabled or self.is_fsdp_enabled) and (
+            self.optimizer is not None or self.lr_scheduler is not None
+        ):
+            raise RuntimeError(
+                "Passing `optimizers` is not allowed if PyTorch FSDP is enabled. "
+                "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+            )
+        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+        callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+        self.callback_handler = CallbackHandler(
+            callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+        )
+        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+        # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+        self._loggers_initialized = False
+
+        # Create distant repo and output directory if needed
+        self.hub_model_id = None
+        if self.args.push_to_hub:
+            self.init_hf_repo()
+        if self.args.should_save:
+            os.makedirs(self.args.output_dir, exist_ok=True)
+
+        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+            raise TypeError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+        if args.max_steps > 0 and args.num_train_epochs > 0:
+            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+        if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+            raise ValueError(
+                "The train_dataset does not implement __len__, max_steps has to be specified. "
+                "The number of steps needs to be known in advance for the learning rate scheduler."
+            )
+
+        if (
+            train_dataset is not None
+            and isinstance(train_dataset, torch.utils.data.IterableDataset)
+            and args.group_by_length
+        ):
+            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset")
+
+        self._signature_columns = None
+
+        # Mixed precision setup
+        self.use_apex = False
+        self.use_cpu_amp = False
+
+        # Mixed precision setup for SageMaker Model Parallel
+        if is_sagemaker_mp_enabled():
+            # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+            if args.bf16:
+                raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+                if args.fp16 != smp.state.cfg.fp16:
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        f"but FP16 provided in trainer argument is {args.fp16}, "
+                        f"setting to {smp.state.cfg.fp16}"
+                    )
+                    args.fp16 = smp.state.cfg.fp16
+            else:
+                # smp < 1.10 does not support fp16 in trainer.
+                if hasattr(smp.state.cfg, "fp16"):
+                    logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+                    )
+        if (args.fp16 or args.bf16) and args.half_precision_backend == "auto":
+            if args.device == torch.device("cpu"):
+                if args.fp16:
+                    if not is_torch_greater_or_equal_than_2_3:
+                        raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+                else:
+                    args.half_precision_backend = "cpu_amp"
+            logger.info(f"Using {args.half_precision_backend} half precision backend")
+
+        if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
+            # deepspeed and SageMaker Model Parallel manage their own half precision
+            if args.half_precision_backend == "cpu_amp":
+                self.use_cpu_amp = True
+                self.amp_dtype = torch.bfloat16
+            elif args.half_precision_backend == "apex":
+                self.use_apex = True
+
+        # Label smoothing
+        if self.args.label_smoothing_factor != 0:
+            self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+        else:
+            self.label_smoother = None
+
+        # Check for multi-label classification incompatibility
+        if self.args.label_smoothing_factor > 0:
+            if getattr(self.model.config, "problem_type", None) == "multi_label_classification":
+                warnings.warn(
+                    "Label smoothing is not compatible with multi-label classification. "
+                    "Disabling label smoothing for this training run.",
+                    UserWarning,
+                )
+                self.label_smoother = None
+
+        self.control = TrainerControl()
+
+        self.state = TrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
+        )
+        # Internal variable to count flos in each process, accumulated into `self.state.total_flos` and then
+        # reset to 0 every time flos need to be logged
+        self.current_flos = 0
+        self.hp_search_backend = None
+
+        model_to_inspect = self.model
+        if _is_peft_model(self.model):
+            if hasattr(self.model, "get_base_model"):
+                model_to_inspect = self.model.get_base_model()
+            else:
+                # PeftMixedModel do not provide a `get_base_model` method
+                model_to_inspect = self.model.base_model.model
+        default_label_names = find_labels(model_to_inspect.__class__)
+        self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+        self.can_return_loss = can_return_loss(model_to_inspect.__class__)
+        self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+        # Internal variables to help with automatic batch size reduction
+        self._train_batch_size = args.train_batch_size
+        self._created_lr_scheduler = False
+
+        # very last
+        self._memory_tracker.stop_and_update_metrics()
+
+        self.is_fsdp_xla_v2_enabled = args.fsdp_config.get("xla_fsdp_v2", False)
+        if self.is_fsdp_xla_v2_enabled:
+            if not IS_XLA_FSDPV2_POST_2_2:
+                raise ValueError("FSDPv2 requires `torch_xla` 2.2 or higher.")
+            # Prepare the SPMD mesh that is going to be used by the data loader and the FSDPv2 wrapper.
+            # Tensor axis is just a placeholder where it will not be used in FSDPv2.
+            num_devices = xr.global_runtime_device_count()
+            xs.set_global_mesh(xs.Mesh(np.array(range(num_devices)), (num_devices, 1), axis_names=("fsdp", "tensor")))
+        self.is_fsdp_xla_v1_enabled = self.is_fsdp_xla_enabled and not self.is_fsdp_xla_v2_enabled
+
+    @property
+    def tokenizer(self) -> Optional[PreTrainedTokenizerBase]:
+        logger.warning("Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.")
+        return self.processing_class
+
+    @tokenizer.setter
+    def tokenizer(self, processing_class) -> None:
+        logger.warning(
+            "Trainer.tokenizer is now deprecated. You should use `Trainer.processing_class = processing_class` instead."
+        )
+        self.processing_class = processing_class
+
+    def _activate_neftune(self, model):
+        r"""
+        Activates NEFTune as presented in this code: https://github.com/neelsjain/NEFTune and paper:
+        https://huggingface.co/papers/2310.05914
+        """
+        unwrapped_model = self.accelerator.unwrap_model(model)
+
+        if _is_peft_model(unwrapped_model):
+            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+        else:
+            embeddings = unwrapped_model.get_input_embeddings()
+
+        del unwrapped_model
+
+        embeddings.neftune_noise_alpha = self.neftune_noise_alpha
+        hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
+        self.neftune_hook_handle = hook_handle
+        return model
+
+    def _deactivate_neftune(self, model):
+        """
+        Deactivates the neftune method. Make sure to call `_activate_neftune` first.
+        """
+        if not hasattr(self, "neftune_hook_handle"):
+            raise ValueError("Neftune is not activated make sure to call `trainer._activate_neftune()` first")
+
+        unwrapped_model = self.accelerator.unwrap_model(model)
+
+        if _is_peft_model(unwrapped_model):
+            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
+        else:
+            embeddings = unwrapped_model.get_input_embeddings()
+
+        self.neftune_hook_handle.remove()
+        del embeddings.neftune_noise_alpha, unwrapped_model
+
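+    # Usage sketch (illustrative, not executed here): NEFTune is typically enabled by passing
+    # `neftune_noise_alpha` through `TrainingArguments`; the hooks above are then attached and
+    # removed automatically around training, e.g.
+    #
+    #     args = TrainingArguments(output_dir="out", neftune_noise_alpha=5.0)
+    #     trainer = Trainer(model=model, args=args, train_dataset=train_ds)
+    #     trainer.train()  # noise is added to embedding outputs only while training
+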
+    def add_callback(self, callback):
+        """
+        Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback`]):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will instantiate a member of that class.
+        """
+        self.callback_handler.add_callback(callback)
+
+    def pop_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it.
+
+        If the callback is not found, returns `None` (and no error is raised).
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback`]):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will pop the first member of that class found in the list of callbacks.
+
+        Returns:
+            [`~transformers.TrainerCallback`]: The callback removed, if found.
+        """
+        return self.callback_handler.pop_callback(callback)
+
+    def remove_callback(self, callback):
+        """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+           callback (`type` or [`~transformers.TrainerCallback`]):
+               A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+               first case, will remove the first member of that class found in the list of callbacks.
+        """
+        self.callback_handler.remove_callback(callback)
+
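+    # Usage sketch (illustrative; `MyCallback` is a hypothetical `TrainerCallback` subclass):
+    # callbacks can be managed either by class or by instance, e.g.
+    #
+    #     trainer.add_callback(MyCallback)        # pass the class: an instance is created
+    #     trainer.add_callback(MyCallback())      # or pass an instance directly
+    #     cb = trainer.pop_callback(MyCallback)   # remove the first match and get it back
+    #     trainer.remove_callback(MyCallback)     # or simply drop it
+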
+    def _move_model_to_device(self, model, device):
+        model = model.to(device)
+        # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+        if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+            model.tie_weights()
+
+    def _align_special_tokens(self):
+        """
+        Aligns the special tokens of the tokenizer with the model configs.
+
+        New tokens may be defined in the tokenizer for fine-tuning purposes, e.g. an "end of turn" token may be added
+        for chat models. In that case, we want the model configs to be aligned with the tokenizer, so that all
+        downstream uses work as expected. This alignment should happen before training, to ensure the prediction step
+        uses the new tokens as well.
+        """
+        if isinstance(self.processing_class, ProcessorMixin):
+            tokenizer = self.processing_class.tokenizer
+        else:
+            tokenizer = self.processing_class
+        model_has_generation_config = (
+            hasattr(self.model, "generation_config") and self.model.generation_config is not None
+        )
+        updated_tokens = {}
+
+        # 1 - Align EOS token. EOS is more complex than the others, as `generation_config` may hold more than one EOS
+        # token.
+        tokenizer_has_new_eos = tokenizer.eos_token_id != self.model.config.eos_token_id
+        if model_has_generation_config:
+            # `generation_config.eos_token_id` is None: direct comparison
+            if self.model.generation_config.eos_token_id is None:
+                tokenizer_has_new_eos |= tokenizer.eos_token_id != self.model.generation_config.eos_token_id
+            else:
+                # `generation_config.eos_token_id` is an `int`: convert it to list (and continue below)
+                if isinstance(self.model.generation_config.eos_token_id, int):
+                    self.model.generation_config.eos_token_id = [self.model.generation_config.eos_token_id]
+                # `generation_config.eos_token_id` is a `list`: check if the tokenizer's EOS token is in the list
+                tokenizer_has_new_eos |= tokenizer.eos_token_id not in self.model.generation_config.eos_token_id
+
+        if tokenizer_has_new_eos:
+            updated_tokens["eos_token_id"] = tokenizer.eos_token_id
+            self.model.config.eos_token_id = tokenizer.eos_token_id
+            # The generation config may hold more than one EOS token. We preserve the original EOS tokens: any of the
+            # EOS tokens defined here will halt generation.
+            if model_has_generation_config:
+                all_eos_tokens = [tokenizer.eos_token_id]
+                if self.model.generation_config.eos_token_id is not None:
+                    all_eos_tokens += list(self.model.generation_config.eos_token_id)
+                self.model.generation_config.eos_token_id = [token for token in all_eos_tokens if token is not None]
+
+        # 2 - Align BOS
+        tokenizer_has_new_bos = tokenizer.bos_token_id != self.model.config.bos_token_id
+        if model_has_generation_config:
+            tokenizer_has_new_bos |= tokenizer.bos_token_id != self.model.generation_config.bos_token_id
+
+        if tokenizer_has_new_bos:
+            updated_tokens["bos_token_id"] = tokenizer.bos_token_id
+            self.model.config.bos_token_id = tokenizer.bos_token_id
+            if model_has_generation_config:
+                self.model.generation_config.bos_token_id = tokenizer.bos_token_id
+
+        # 3 - Align PAD
+        tokenizer_has_new_pad = tokenizer.pad_token_id != self.model.config.pad_token_id
+        if model_has_generation_config:
+            tokenizer_has_new_pad |= tokenizer.pad_token_id != self.model.generation_config.pad_token_id
+
+        if tokenizer_has_new_pad:
+            updated_tokens["pad_token_id"] = tokenizer.pad_token_id
+            self.model.config.pad_token_id = tokenizer.pad_token_id
+            if model_has_generation_config:
+                self.model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+        # 4 - Warn users about the changes
+        if len(updated_tokens) > 0:
+            logger.warning(
+                "The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. "
+                "The model config and generation config were aligned accordingly, being updated with the tokenizer's "
+                f"values. Updated tokens: {updated_tokens}."
+            )
+
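+    # Illustrative example (not executed): if fine-tuning adds a new end-of-turn token as EOS, e.g.
+    #
+    #     tokenizer.add_special_tokens({"eos_token": "<|end_of_turn|>"})
+    #
+    # then `_align_special_tokens` copies the new `eos_token_id` into `model.config` and appends it
+    # to `model.generation_config.eos_token_id`, preserving any previous EOS ids so that all of them
+    # still stop generation.
+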
+    def _set_signature_columns_if_needed(self):
+        if self._signature_columns is None:
+            # Inspect model forward signature to keep only the arguments it accepts.
+            model_to_inspect = self.model
+            if _is_peft_model(self.model):
+                if hasattr(self.model, "get_base_model"):
+                    model_to_inspect = self.model.get_base_model()
+                else:
+                    # PeftMixedModel does not provide a `get_base_model` method
+                    model_to_inspect = self.model.base_model.model
+            signature = inspect.signature(model_to_inspect.forward)
+            self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids; the default data collator handles that.
+            self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+    def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+        if not self.args.remove_unused_columns:
+            return dataset
+        self._set_signature_columns_if_needed()
+        signature_columns = self._signature_columns
+
+        ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+        if len(ignored_columns) > 0:
+            dset_description = "" if description is None else f"in the {description} set"
+            logger.info(
+                f"The following columns {dset_description} don't have a corresponding argument in "
+                f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+                f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
+                "you can safely ignore this message."
+            )
+
+        columns = [k for k in signature_columns if k in dataset.column_names]
+        if len(columns) == 0:
+            raise ValueError(
+                f"No columns in the dataset match the model's forward method signature: ({', '.join(signature_columns)}). "
+                f"The following columns have been ignored: [{', '.join(ignored_columns)}]. "
+                "Please check the dataset and model. You may need to set `remove_unused_columns=False` in `TrainingArguments`."
+            )
+
+        if version.parse(datasets.__version__) < version.parse("1.4.0"):
+            dataset.set_format(
+                type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+            )
+            return dataset
+        else:
+            return dataset.remove_columns(ignored_columns)
+
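+    # Illustrative example (hypothetical column names): for a model whose `forward` accepts
+    # `input_ids`, `attention_mask` and `labels`, a dataset with columns
+    # ["input_ids", "attention_mask", "labels", "text", "id"] would have "text" and "id" dropped
+    # here (or stripped batch-wise by the `RemoveColumnsCollator` below), unless
+    # `TrainingArguments(remove_unused_columns=False)` is set.
+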
+    def _get_collator_with_removed_columns(
+        self, data_collator: Callable, description: Optional[str] = None
+    ) -> Callable:
+        """Wrap the data collator in a callable removing unused columns."""
+        if not self.args.remove_unused_columns:
+            return data_collator
+        self._set_signature_columns_if_needed()
+        signature_columns = self._signature_columns
+
+        remove_columns_collator = RemoveColumnsCollator(
+            data_collator=data_collator,
+            signature_columns=signature_columns,
+            logger=logger,
+            description=description,
+            model_name=self.model.__class__.__name__,
+        )
+        return remove_columns_collator
+
+    def _get_train_sampler(self, train_dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
+        if train_dataset is None:
+            train_dataset = self.train_dataset
+        if train_dataset is None or not has_length(train_dataset):
+            return None
+
+        # Build the sampler.
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+                lengths = (
+                    train_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in train_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = (
+                self.processing_class.model_input_names[0] if self.processing_class is not None else None
+            )
+            return LengthGroupedSampler(
+                self.args.train_batch_size * self.args.gradient_accumulation_steps,
+                dataset=train_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
+        else:
+            return RandomSampler(train_dataset)
+
+    def _get_dataloader(
+        self,
+        dataset: Dataset,
+        description: str,
+        batch_size: int,
+        sampler_fn: Optional[Callable[[Dataset], torch.utils.data.Sampler]] = None,
+        is_training: bool = False,
+        dataloader_key: Optional[str] = None,
+    ) -> DataLoader:
+        """Create a [`~torch.utils.data.DataLoader`] from the given dataset."""
+
+        data_collator = self.data_collator
+        if is_datasets_available() and isinstance(dataset, datasets.Dataset):
+            dataset = self._remove_unused_columns(dataset, description=description)
+        else:
+            data_collator = self._get_collator_with_removed_columns(self.data_collator, description=description)
+
+        dataloader_params = {
+            "batch_size": batch_size,
+            "collate_fn": data_collator,
+            "num_workers": self.args.dataloader_num_workers,
+            "pin_memory": self.args.dataloader_pin_memory,
+            "persistent_workers": self.args.dataloader_persistent_workers,
+        }
+
+        if not isinstance(dataset, torch.utils.data.IterableDataset):
+            if sampler_fn is not None:
+                dataloader_params["sampler"] = sampler_fn(dataset)
+            dataloader_params["drop_last"] = self.args.dataloader_drop_last
+            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
+            if is_training:
+                dataloader_params["worker_init_fn"] = partial(
+                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+                )
+
+        dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
+
+        # Store the prepared dataloader for subsequent evaluations if using persistent workers.
+        if dataloader_key is not None and self.args.dataloader_persistent_workers:
+            if hasattr(self, "_eval_dataloaders"):
+                self._eval_dataloaders[dataloader_key] = dataloader
+            else:
+                self._eval_dataloaders = {dataloader_key: dataloader}
+
+        return dataloader
+
+    def get_train_dataloader(self) -> DataLoader:
+        """
+        Returns the training [`~torch.utils.data.DataLoader`].
+
+        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+        training if necessary) otherwise.
+
+        Subclass and override this method if you want to inject some custom behavior.
+        """
+        if self.train_dataset is None:
+            raise ValueError("Trainer: training requires a train_dataset.")
+
+        return self._get_dataloader(
+            dataset=self.train_dataset,
+            description="Training",
+            batch_size=self._train_batch_size,
+            sampler_fn=self._get_train_sampler,
+            is_training=True,
+        )
+
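+    # Override sketch (illustrative; `MyTrainer` and `my_sampler_fn` are hypothetical names): as the
+    # docstring above suggests, custom behavior can be injected by subclassing, e.g.
+    #
+    #     class MyTrainer(Trainer):
+    #         def get_train_dataloader(self):
+    #             return self._get_dataloader(
+    #                 dataset=self.train_dataset,
+    #                 description="Training",
+    #                 batch_size=self._train_batch_size,
+    #                 sampler_fn=my_sampler_fn,
+    #                 is_training=True,
+    #             )
+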
+    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+        if eval_dataset is None or not has_length(eval_dataset):
+            return None
+        # Build the sampler.
+
+        # Deprecated code
+        if self.args.use_legacy_prediction_loop:
+            if is_torch_xla_available():
+                return SequentialDistributedSampler(
+                    eval_dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal()
+                )
+            elif is_sagemaker_mp_enabled():
+                return SequentialDistributedSampler(
+                    eval_dataset,
+                    num_replicas=smp.dp_size(),
+                    rank=smp.dp_rank(),
+                    batch_size=self.args.per_device_eval_batch_size,
+                )
+            else:
+                return SequentialSampler(eval_dataset)
+
+        if self.args.group_by_length:
+            if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+                lengths = (
+                    eval_dataset[self.args.length_column_name]
+                    if self.args.length_column_name in eval_dataset.column_names
+                    else None
+                )
+            else:
+                lengths = None
+            model_input_name = (
+                self.processing_class.model_input_names[0] if self.processing_class is not None else None
+            )
+            return LengthGroupedSampler(
+                self.args.eval_batch_size,
+                dataset=eval_dataset,
+                lengths=lengths,
+                model_input_name=model_input_name,
+            )
+
+        if self.args.world_size <= 1:
+            return SequentialSampler(eval_dataset)
+        else:
+            return None
+
+    def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
+        """
+        Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            eval_dataset (`str` or `torch.utils.data.Dataset`, *optional*):
+                If a `str`, will use `self.eval_dataset[eval_dataset]` as the evaluation dataset. If a `Dataset`, will
+                override `self.eval_dataset` and must implement `__len__`. If it is a [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed.
+        """
+        if eval_dataset is None and self.eval_dataset is None:
+            raise ValueError("Trainer: evaluation requires an eval_dataset.")
+
+        # With persistent dataloader workers, reuse an already-prepared eval dataloader instead of spawning new
+        # worker processes on every evaluation, since eval datasets don't change during training.
+        dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval"
+        if (
+            hasattr(self, "_eval_dataloaders")
+            and dataloader_key in self._eval_dataloaders
+            and self.args.dataloader_persistent_workers
+        ):
+            return self._eval_dataloaders[dataloader_key]
+
+        eval_dataset = (
+            self.eval_dataset[eval_dataset]
+            if isinstance(eval_dataset, str)
+            else eval_dataset
+            if eval_dataset is not None
+            else self.eval_dataset
+        )
+
+        return self._get_dataloader(
+            dataset=eval_dataset,
+            description="Evaluation",
+            batch_size=self.args.eval_batch_size,
+            sampler_fn=self._get_eval_sampler,
+            dataloader_key=dataloader_key,
+        )
+
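+    # Usage sketch (illustrative dataset names): `eval_dataset` may also be a dict mapping names to
+    # datasets at init time, in which case a string key selects one of them here:
+    #
+    #     trainer = Trainer(..., eval_dataset={"in_domain": ds_a, "out_of_domain": ds_b})
+    #     dl = trainer.get_eval_dataloader("out_of_domain")  # uses self.eval_dataset["out_of_domain"]
+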
+    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+        """
+        Returns the test [`~torch.utils.data.DataLoader`].
+
+        Subclass and override this method if you want to inject some custom behavior.
+
+        Args:
+            test_dataset (`torch.utils.data.Dataset`, *optional*):
+                The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. It must implement `__len__`.
+        """
+        return self._get_dataloader(
+            dataset=test_dataset,
+            description="test",
+            batch_size=self.args.eval_batch_size,
+            sampler_fn=self._get_eval_sampler,
+        )
+
+    def create_optimizer_and_scheduler(self, num_training_steps: int):
+        """
+        Set up the optimizer and the learning rate scheduler.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
+        the Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+        `create_scheduler`).
+        """
+        self.create_optimizer()
+        if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+            # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+            optimizer = self.optimizer.optimizer
+        else:
+            optimizer = self.optimizer
+        self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+    def get_decay_parameter_names(self, model) -> list[str]:
+        """
+        Get all parameter names that weight decay will be applied to.
+
+        This function filters out parameters in two ways:
+        1. By layer type (instances of layers specified in ALL_LAYERNORM_LAYERS)
+        2. By parameter name patterns (containing 'bias', or a variation of 'norm')
+        """
+        forbidden_name_patterns = [r"bias", r"layernorm", r"rmsnorm", r"(?:^|\.)norm(?:$|\.)", r"_norm(?:$|\.)"]
+        decay_parameters = get_parameter_names(model, [nn.LayerNorm], forbidden_name_patterns)
+        return decay_parameters
+
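+    # Illustrative example (hypothetical parameter names): given parameters "embed.weight",
+    # "layer.0.q_proj.weight", "layer.0.q_proj.bias" and "layer.0.norm.weight", only "embed.weight"
+    # and "layer.0.q_proj.weight" would receive weight decay; the bias is excluded by name and the
+    # norm weight by name/layer type.
+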
+    def create_optimizer(self):
+        """
+        Set up the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
+        the Trainer's init through `optimizers`, or subclass and override this method.
+        """
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+        if self.optimizer is None:
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            if self.optimizer_cls_and_kwargs is not None:
+                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
+            else:
+                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)
+
+            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for GaLore optimizer.
+            if "params" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("params")
+
+            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
+            # e.g. for LOMO optimizer.
+            if "model" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("model")
+
+            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
+            # to avoid arguments conflicts.
+            if "optimizer_dict" in optimizer_kwargs:
+                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
+
+            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+            if "bitsandbytes" in str(optimizer_cls) and optimizer_kwargs.get("optim_bits", None) == 8:
+                import bitsandbytes
+
+                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                skipped = 0
+                for module in opt_model.modules():
+                    if isinstance(module, nn.Embedding):
+                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+                        logger.info(f"skipped {module}: {skipped / 2**20}M params")
+                        manager.register_module_override(module, "weight", {"optim_bits": 32})
+                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+                logger.info(f"skipped: {skipped / 2**20}M params")
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
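+    # Customization sketch (illustrative): instead of relying on this default, an already-built
+    # optimizer (and optionally a scheduler) can be passed at init time, in which case this method
+    # leaves it untouched:
+    #
+    #     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
+    #     trainer = Trainer(model=model, args=args, train_dataset=train_ds,
+    #                       optimizers=(optimizer, None))
+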
+    def get_num_trainable_parameters(self):
+        """
+        Get the number of trainable parameters.
+        """
+        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+    def get_learning_rates(self):
+        """
+        Returns the learning rate of each parameter from self.optimizer.
+        """
+        if self.optimizer is None:
+            raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+        return [group["lr"] for group in self.optimizer.param_groups]
+
+    def get_optimizer_group(self, param: Optional[Union[str, torch.nn.parameter.Parameter]] = None):
+        """
+        Returns optimizer group for a parameter if given, else returns all optimizer groups for params.
+
+        Args:
+            param (`str` or `torch.nn.parameter.Parameter`, *optional*):
+                The parameter for which optimizer group needs to be returned.
+        """
+        if self.optimizer is None:
+            raise ValueError("Trainer optimizer is None, please make sure you have setup the optimizer before.")
+        if param is not None:
+            for group in self.optimizer.param_groups:
+                if param in group["params"]:
+                    return group
+        return [group["params"] for group in self.optimizer.param_groups]
+
+    @staticmethod
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> tuple[Any, Any]:
+        """
+        Returns the optimizer class and optimizer parameters based on the training arguments.
+
+        Args:
+            args (`transformers.training_args.TrainingArguments`):
+                The training arguments for the training session.
+            model (`PreTrainedModel`, *optional*):
+                The model to optimize. Some optimizers (e.g. GaLore, APOLLO, LOMO) require it to select their target
+                parameters.
+        """
+
+        # parse args.optim_args
+        optim_args = {}
+        if args.optim_args:
+            for mapping in args.optim_args.replace(" ", "").split(","):
+                key, value = mapping.split("=")
+                optim_args[key] = value
+
+        optimizer_kwargs = {"lr": args.learning_rate}
+
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+
+        def setup_low_rank_optimizer(
+            optimizer_name: str,
+            optimizer_mapping: dict[str, Any],
+            optim_kwargs: dict[str, Any],
+            is_layerwise_supported: bool = True,
+        ) -> tuple[Any, Any]:
+            """
+            Helper function to set up low-rank optimizers like GaLore and Apollo.
+
+            Args:
+                optimizer_name (str): Name of the optimizer.
+                optimizer_mapping (dict): Mapping of optimizer names to their classes.
+                optim_kwargs (dict): Keyword arguments for the optimizer.
+                is_layerwise_supported (bool): Whether layerwise optimization is supported.
+
+            Returns:
+                tuple[Any, Any]: Optimizer class and updated optimizer kwargs.
+            """
+            is_layerwise = optimizer_name.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED and is_layerwise_supported:
+                raise NotImplementedError(f"Layer-wise {optimizer_name} does not support DDP at this time")
+
+            optimizer_cls = optimizer_mapping[optimizer_name]
+
+            if args.optim_target_modules is None:
+                raise ValueError(f"You need to define `optim_target_modules` to use {optimizer_name} optimizers")
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise TypeError(
+                    f"`optim_target_modules` must be a list of strings, a regex string, or 'all-linear'. Got: {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError(f"You need to pass a model to initialize {optimizer_name} optimizer.")
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            target_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} matched but ignored. {optimizer_name} only supports linear layers."
+                        )
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                target_params_names.append(module_name + ".weight")
+
+            if len(target_params_names) == 0:
+                raise ValueError(f"No target modules found for {optimizer_name} ({args.optim_target_modules}).")
+
+            target_params = [p for n, p in model.named_parameters() if n in target_params_names]
+            non_target_params = [p for n, p in model.named_parameters() if n not in target_params_names]
+            optim_kwargs.update(optim_args)
+
+            param_groups = [
+                {"params": non_target_params},
+                {"params": target_params, **optim_kwargs},
+            ]
+
+            if is_layerwise:
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError(f"Layerwise {optimizer_name} does not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_target_params:
+                    optimizer_dict[param] = optimizer_cls([{"params": [param]}], **optimizer_kwargs)
+                for param in target_params:
+                    optimizer_dict[param] = optimizer_cls([{"params": [param], **optim_kwargs}], **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if param.grad is not None:
+                        optimizer_dict[param].step()
+                        optimizer_dict[param].zero_grad()
+
+                for param in model.parameters():
+                    if param.requires_grad:
+                        param.register_post_accumulate_grad_hook(optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+            return optimizer_cls, optimizer_kwargs
+
+        if args.optim == OptimizerNames.ADAFACTOR:
+            optimizer_cls = Adafactor
+            optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
+            from torch.optim import AdamW
+
+            optimizer_cls = AdamW
+            optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+                optimizer_kwargs.update({"fused": True})
+        elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+            try:
+                from torch_xla.amp.syncfree import AdamW
+
+                optimizer_cls = AdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+        elif args.optim == OptimizerNames.ADAMW_TORCH_NPU_FUSED:
+            try:
+                from torch_npu.optim import NpuFusedAdamW
+
+                optimizer_cls = NpuFusedAdamW
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer failed to import FusedAdamW from torch_npu.")
+        elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+            try:
+                from apex.optimizers import FusedAdam
+
+                optimizer_cls = FusedAdam
+                optimizer_kwargs.update(adam_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+        elif args.optim in [
+            OptimizerNames.ADAMW_BNB,
+            OptimizerNames.ADAMW_8BIT,
+            OptimizerNames.PAGED_ADAMW,
+            OptimizerNames.PAGED_ADAMW_8BIT,
+            OptimizerNames.ADEMAMIX,
+            OptimizerNames.ADEMAMIX_8BIT,
+            OptimizerNames.PAGED_ADEMAMIX,
+            OptimizerNames.PAGED_ADEMAMIX_8BIT,
+            OptimizerNames.LION,
+            OptimizerNames.LION_8BIT,
+            OptimizerNames.PAGED_LION,
+            OptimizerNames.PAGED_LION_8BIT,
+            OptimizerNames.RMSPROP_BNB,
+            OptimizerNames.RMSPROP_8BIT,
+            OptimizerNames.RMSPROP_32BIT,
+        ]:
+            try:
+                from bitsandbytes.optim import AdamW, Lion, RMSprop
+
+                is_paged = False
+                optim_bits = 32
+                optimizer_cls = None
+                additional_optim_kwargs = adam_kwargs
+                if "paged" in args.optim:
+                    is_paged = True
+                if "8bit" in args.optim:
+                    optim_bits = 8
+                if "adam" in args.optim:
+                    optimizer_cls = AdamW
+                elif "lion" in args.optim:
+                    optimizer_cls = Lion
+                    additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+                elif "rmsprop" in args.optim:
+                    optimizer_cls = RMSprop
+                    # Above we pass all `adam_kwargs` to the optimizer, here
+                    # we only pass `optim_args` which can be passed by the user.
+                    additional_optim_kwargs = optim_args
+                elif "ademamix" in args.optim:
+                    if is_bitsandbytes_available() and version.parse(
+                        importlib.metadata.version("bitsandbytes")
+                    ) < version.parse("0.44.0"):
+                        raise ValueError(
+                            "The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
+                            "Please install `bitsandbytes` >= 0.44.0."
+                        )
+
+                    from bitsandbytes.optim import AdEMAMix
+
+                    optimizer_cls = AdEMAMix
+                    additional_optim_kwargs = {
+                        "betas": (
+                            float(optim_args.get("beta1", args.adam_beta1)),
+                            float(optim_args.get("beta2", args.adam_beta2)),
+                            float(optim_args.get("beta3", 0.9999)),
+                        ),
+                        "alpha": float(optim_args.get("alpha", 5.0)),
+                        "eps": float(optim_args.get("eps", args.adam_epsilon)),
+                    }
+
+                    if "t_alpha" in optim_args:
+                        additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])
+
+                    if "t_beta3" in optim_args:
+                        additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])
+
+                bnb_kwargs = {"optim_bits": optim_bits}
+                if "rmsprop" not in args.optim:
+                    bnb_kwargs["is_paged"] = is_paged
+
+                optimizer_kwargs.update(additional_optim_kwargs)
+                optimizer_kwargs.update(bnb_kwargs)
+            except ImportError:
+                raise ValueError("Trainer tried to instantiate bnb optimizer but `bitsandbytes` is not installed!")
+            if is_bitsandbytes_available() and version.parse(
+                importlib.metadata.version("bitsandbytes")
+            ) < version.parse("0.41.1"):
+                logger.warning(
+                    "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. "
+                    "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers."
+                )
+        elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
+            try:
+                from torchdistx.optimizers import AnyPrecisionAdamW
+
+                optimizer_cls = AnyPrecisionAdamW
+                optimizer_kwargs.update(adam_kwargs)
+
+                # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
+                optimizer_kwargs.update(
+                    {
+                        "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
+                        "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
+                        "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
+                        "compensation_buffer_dtype": getattr(
+                            torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
+                        ),
+                    }
+                )
+            except ImportError:
+                raise ValueError("Please install https://github.com/pytorch/torchdistx")
+        elif args.optim == OptimizerNames.SGD:
+            optimizer_cls = torch.optim.SGD
+        elif args.optim == OptimizerNames.ADAGRAD:
+            optimizer_cls = torch.optim.Adagrad
+        elif args.optim == OptimizerNames.RMSPROP:
+            optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [
+            OptimizerNames.GALORE_ADAMW,
+            OptimizerNames.GALORE_ADAMW_8BIT,
+            OptimizerNames.GALORE_ADAFACTOR,
+            OptimizerNames.GALORE_ADAMW_LAYERWISE,
+            OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE,
+            OptimizerNames.GALORE_ADAFACTOR_LAYERWISE,
+        ]:
+            if not is_galore_torch_available():
+                raise ImportError(
+                    "You need to install `galore_torch` in order to use GaLore optimizers. "
+                    "Install it with `pip install git+https://github.com/jiaweizzhao/GaLore`."
+                )
+            from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+
+            optimizer_mapping = {
+                OptimizerNames.GALORE_ADAMW: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR: GaLoreAdafactor,
+                OptimizerNames.GALORE_ADAMW_LAYERWISE: GaLoreAdamW,
+                OptimizerNames.GALORE_ADAMW_8BIT_LAYERWISE: GaLoreAdamW8bit,
+                OptimizerNames.GALORE_ADAFACTOR_LAYERWISE: GaLoreAdafactor,
+            }
+
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+
+            optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+                args.optim, optimizer_mapping, galore_optim_kwargs
+            )
+            if args.optim == OptimizerNames.GALORE_ADAFACTOR:
+                optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+        elif args.optim in [
+            OptimizerNames.APOLLO_ADAMW,
+            OptimizerNames.APOLLO_ADAMW_LAYERWISE,
+        ]:
+            if not is_apollo_torch_available():
+                raise ImportError(
+                    "You need to install `apollo_torch` in order to use APOLLO optimizers. "
+                    "Install it with `pip install git+https://github.com/zhuhanqing/APOLLO`."
+                )
+            from apollo_torch import APOLLOAdamW
+
+            optimizer_mapping = {
+                OptimizerNames.APOLLO_ADAMW: APOLLOAdamW,
+                OptimizerNames.APOLLO_ADAMW_LAYERWISE: APOLLOAdamW,
+            }
+
+            apollo_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 128)),
+                "proj": optim_args.pop("proj", "random"),
+                "scale_type": optim_args.pop("scale_type", "channel"),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 1.0)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+            }
+            apollo_optim_kwargs.update(adam_kwargs)
+
+            optimizer_cls, optimizer_kwargs = setup_low_rank_optimizer(
+                args.optim, optimizer_mapping, apollo_optim_kwargs
+            )
+        elif args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            if not is_lomo_available():
+                raise ImportError(
+                    "You need to install `lomo_optim` in order to use LOMO optimizers. "
+                    "Install it with `pip install lomo-optim`."
+                )
+            if not is_accelerate_available("0.30.0"):
+                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use LOMO optimizers")
+
+            if model is None:
+                raise ValueError("You need to pass a `model` in order to correctly initialize a LOMO optimizer.")
+
+            from lomo_optim import AdaLomo, Lomo
+
+            if "ada" in args.optim:
+                optimizer_cls = AdaLomo
+            else:
+                optimizer_cls = Lomo
+
+            optimizer_kwargs.update({"model": model})
+        elif args.optim == OptimizerNames.GROKADAMW:
+            if not is_grokadamw_available():
+                raise ValueError("Please install grokadamw with `pip install grokadamw`")
+
+            from grokadamw import GrokAdamW
+
+            optimizer_cls = GrokAdamW
+            optimizer_kwargs.update(
+                {
+                    "alpha_init": float(optim_args.get("alpha_init", 0.98)),
+                    "lamb": float(optim_args.get("lamb", 2.0)),
+                    "gamma": float(optim_args.get("gamma", 0.1)),
+                    "grokking_signal_decay_rate": float(optim_args.get("grokking_signal_decay_rate", 0.1)),
+                    "gradient_clipping": float(optim_args.get("gradient_clipping", 1.0)),
+                }
+            )
+        elif args.optim in [
+            OptimizerNames.ADAMW_TORCH_4BIT,
+            OptimizerNames.ADAMW_TORCH_8BIT,
+        ]:
+            if not is_torchao_available() or version.parse(importlib.metadata.version("torchao")) < version.parse(
+                "0.4.0"
+            ):
+                raise ImportError(
+                    "You need to have `torchao>=0.4.0` in order to use torch 4-bit optimizers. "
+                    "Install it with `pip install torchao` or follow the instructions here: https://github.com/pytorch/ao"
+                )
+            if version.parse(importlib.metadata.version("torch")) <= version.parse("2.4"):
+                raise ImportError(
+                    "You need to have `torch>2.4` in order to use torch 4-bit optimizers. "
+                    "Install it with `pip install --upgrade torch`; it is available on PyPI. Otherwise, you need to install a torch nightly build."
+                )
+            if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.11.0"):
+                # https://github.com/pytorch/ao/pull/2159
+                from torchao.optim import AdamW4bit, AdamW8bit
+            else:
+                from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit
+            if args.optim == OptimizerNames.ADAMW_TORCH_4BIT:
+                optimizer_cls = AdamW4bit
+            elif args.optim == OptimizerNames.ADAMW_TORCH_8BIT:
+                optimizer_cls = AdamW8bit
+            else:
+                raise ValueError("Invalid optimizer")
+            optimizer_kwargs.update(adam_kwargs)
+        elif args.optim in [
+            OptimizerNames.SCHEDULE_FREE_RADAM,
+            OptimizerNames.SCHEDULE_FREE_ADAMW,
+            OptimizerNames.SCHEDULE_FREE_SGD,
+        ]:
+            if not is_schedulefree_available():
+                raise ImportError(
+                    "You need to install `schedulefree` in order to use schedulefree optimizers. "
+                    "Install it with `pip install schedulefree`."
+                )
+            if not is_accelerate_available("0.30.0"):
+                raise ImportError("You need to have `accelerate>=0.30.0` to be able to use schedulefree optimizers")
+            from schedulefree import AdamWScheduleFree, SGDScheduleFree
+
+            additional_optim_kwargs = {}
+            require_warmup = True
+
+            if args.optim == OptimizerNames.SCHEDULE_FREE_RADAM:
+                if not is_schedulefree_available("1.4.0"):
+                    raise ImportError(
+                        "You need to install `schedulefree>=1.4.0` in order to use RAdamScheduleFree optimizer. "
+                        "Install it with `pip install schedulefree`."
+                    )
+                from schedulefree import RAdamScheduleFree
+
+                optimizer_cls = RAdamScheduleFree
+                additional_optim_kwargs = adam_kwargs
+                require_warmup = False
+            elif args.optim == OptimizerNames.SCHEDULE_FREE_ADAMW:
+                optimizer_cls = AdamWScheduleFree
+                additional_optim_kwargs = adam_kwargs
+            elif args.optim == OptimizerNames.SCHEDULE_FREE_SGD:
+                optimizer_cls = SGDScheduleFree
+            else:
+                raise ValueError("Invalid schedulefree optimizer")
+
+            additional_optim_kwargs["weight_decay"] = args.weight_decay
+            if require_warmup:
+                additional_optim_kwargs["warmup_steps"] = args.warmup_steps
+            additional_optim_kwargs.update(
+                {
+                    "weight_lr_power": float(optim_args.get("weight_lr_power", 2.0)),
+                    "r": float(optim_args.get("r", 0.0)),
+                }
+            )
+            optimizer_kwargs.update(additional_optim_kwargs)
+        elif args.optim == OptimizerNames.STABLE_ADAMW:
+            if not is_torch_optimi_available():
+                raise ImportError(
+                    "You need to install `torch-optimi` in order to use stable_adamw optimizers. "
+                    "Install it with `pip install torch-optimi`."
+                )
+            from optimi import StableAdamW
+
+            max_lr = optim_args.pop("max_lr", None)
+            if max_lr is not None:
+                max_lr = float(max_lr)
+
+            kahan_sum = optim_args.pop("kahan_sum", None)
+            if kahan_sum is not None:
+                # `optim_args` values are strings, so parse booleans with `strtobool`;
+                # `bool("False")` would otherwise evaluate to True.
+                kahan_sum = bool(strtobool(kahan_sum))
+
+            adam_kwargs["weight_decay"] = args.weight_decay
+            stable_adamw_kwargs = {
+                "decouple_lr": bool(strtobool(str(optim_args.pop("decouple_lr", False)))),
+                "max_lr": max_lr,
+                "kahan_sum": kahan_sum,
+            }
+
+            optimizer_cls = StableAdamW
+            optimizer_kwargs.update(adam_kwargs)
+            optimizer_kwargs.update(stable_adamw_kwargs)
+        else:
+            raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+        return optimizer_cls, optimizer_kwargs
+
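+    # Configuration sketch (illustrative values): the `optim` and `optim_args` training arguments
+    # feed the selection logic above; `optim_args` is parsed as comma-separated `key=value` pairs,
+    # e.g. for GaLore:
+    #
+    #     args = TrainingArguments(
+    #         output_dir="out",
+    #         optim="galore_adamw",
+    #         optim_args="rank=64,update_proj_gap=100,scale=0.25",
+    #         optim_target_modules=["attn", "mlp"],
+    #     )
+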
+    def create_scheduler(self, num_training_steps: int, optimizer: Optional[torch.optim.Optimizer] = None):
+        """
+        Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+        """
+        if self.lr_scheduler is None:
+            self.lr_scheduler = get_scheduler(
+                self.args.lr_scheduler_type,
+                optimizer=self.optimizer if optimizer is None else optimizer,
+                num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                num_training_steps=num_training_steps,
+                scheduler_specific_kwargs=self.args.lr_scheduler_kwargs,
+            )
+            self._created_lr_scheduler = True
+        return self.lr_scheduler
+
+    def num_examples(self, dataloader: DataLoader) -> int:
+        """
+        Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+        dataloader.dataset does not exist or has no length, estimates as best it can.
+        """
+        try:
+            dataset = dataloader.dataset
+            # Special case for IterableDatasetShard, we need to dig deeper
+            if isinstance(dataset, IterableDatasetShard):
+                return len(dataloader.dataset.dataset)
+            return len(dataloader.dataset)
+        except (NameError, AttributeError, TypeError):  # no dataset or length, estimate by length of dataloader
+            return len(dataloader) * self.args.per_device_train_batch_size
+
+    @staticmethod
+    def num_tokens(train_dl: DataLoader, max_steps: Optional[int] = None) -> int:
+        """
+        Helper to get number of tokens in a [`~torch.utils.data.DataLoader`] by enumerating dataloader.
+        """
+        train_tokens = 0
+        try:
+            for batch in train_dl:
+                tokens = batch["input_ids"].numel()
+                if max_steps is not None:
+                    return tokens * max_steps
+                train_tokens += tokens
+        except KeyError:
+            logger.warning("Cannot get num_tokens from dataloader")
+        return train_tokens
+
+    def _hp_search_setup(self, trial: Union["optuna.Trial", dict[str, Any]]):
+        """HP search setup code"""
+        self._trial = trial
+
+        if self.hp_search_backend is None or trial is None:
+            return
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            params = self.hp_space(trial)
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            params = trial
+            params.pop("wandb", None)
+        elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+            params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+        elif self.hp_search_backend == HPSearchBackend.WANDB:
+            params = trial
+
+        for key, value in params.items():
+            if not hasattr(self.args, key):
+                logger.warning(
+                    f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+                    " `TrainingArguments`."
+                )
+                continue
+            old_attr = getattr(self.args, key, None)
+            # Casting value to the proper type
+            if old_attr is not None:
+                value = type(old_attr)(value)
+
+            setattr(self.args, key, value)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            logger.info(f"Trial: {trial.params}")
+        if self.hp_search_backend == HPSearchBackend.SIGOPT:
+            logger.info(f"SigOpt Assignments: {trial.assignments}")
+        if self.hp_search_backend == HPSearchBackend.WANDB:
+            logger.info(f"W&B Sweep parameters: {trial}")
+        if self.is_deepspeed_enabled:
+            if self.args.deepspeed is None:
+                raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+
+            self.accelerator.free_memory()
+
+            # Rebuild the deepspeed config to reflect the updated training parameters
+            from accelerate.utils import DeepSpeedPlugin
+
+            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+            self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+            self.args.hf_deepspeed_config.trainer_config_process(self.args)
+            self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+
+            # From 1.0 on, we need to fully wipe the DS plugin when doing sweeps.
+            # Simply calling `_reset_state` is enough and doesn't need a version pin.
+            AcceleratorState()._reset_state()
+
+        self.create_accelerator_and_postprocess()
+
+    def _report_to_hp_search(self, trial: Union["optuna.Trial", dict[str, Any]], step: int, metrics: dict[str, float]):
+        if self.hp_search_backend is None or trial is None:
+            return
+        metrics = metrics.copy()
+        self.objective = self.compute_objective(metrics)
+        if self.hp_search_backend == HPSearchBackend.OPTUNA:
+            import optuna
+
+            if hasattr(trial, "study") and not trial.study._is_multi_objective():
+                trial.report(self.objective, step)
+                if trial.should_prune():
+                    self.callback_handler.on_train_end(self.args, self.state, self.control)
+                    raise optuna.TrialPruned()
+        elif self.hp_search_backend == HPSearchBackend.RAY:
+            import ray.train
+
+            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+                checkpoint = None
+                if self.control.should_save:
+                    self._tune_save_checkpoint(checkpoint_dir=temp_checkpoint_dir)
+                    checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
+                metrics["objective"] = self.objective
+                ray.train.report(metrics, checkpoint=checkpoint)
+
+    def _tune_save_checkpoint(self, checkpoint_dir: str):
+        output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+        self.save_model(output_dir, _internal_call=True)
+        if self.args.should_save:
+            # Update the `TrainerControl` state to where we are currently
+            self.state.stateful_callbacks["TrainerControl"] = self.control.state()
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+
+    def call_model_init(self, trial=None):
+        model_init_argcount = number_of_arguments(self.model_init)
+        if model_init_argcount == 0:
+            model = self.model_init()
+        elif model_init_argcount == 1:
+            model = self.model_init(trial)
+        else:
+            raise RuntimeError("model_init should have 0 or 1 argument.")
+
+        if model is None:
+            raise RuntimeError("model_init should not return None.")
+
+        return model
+
+    def torch_jit_model_eval(self, model, dataloader, training=False):
+        if not training:
+            if dataloader is None:
+                logger.warning("Failed to use PyTorch JIT mode because the current dataloader is None.")
+                return model
+            example_batch = next(iter(dataloader))
+            example_batch = self._prepare_inputs(example_batch)
+            try:
+                jit_model = copy.copy(model)
+                jit_model.eval()
+                original_forward = jit_model.__dict__.pop("_original_forward", None)
+                # remove mixed precision hooks from the model
+                if original_forward:
+                    jit_model.forward = original_forward
+                autocast_handler = AutocastKwargs(cache_enabled=False)
+                with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
+                    if isinstance(example_batch, dict):
+                        jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+                    else:
+                        jit_model = torch.jit.trace(
+                            jit_model,
+                            example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+                            strict=False,
+                        )
+                jit_model = torch.jit.freeze(jit_model)
+                with torch.no_grad():
+                    jit_model(**example_batch)
+                    jit_model(**example_batch)
+                model = jit_model
+                self.use_cpu_amp = False
+            except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
+                logger.warning(f"Failed to use PyTorch JIT mode due to: {e}.")
+
+        return model
+
+    def compare_trainer_and_checkpoint_args(self, training_args, trainer_state):
+        attributes_map = {
+            "logging_steps": "logging_steps",
+            "eval_steps": "eval_steps",
+            "save_steps": "save_steps",
+        }
+
+        has_warning = False
+        warning_str = "Warning: The following arguments do not match the ones in the `trainer_state.json` within the checkpoint directory: "
+        for arg_attr, state_attr in attributes_map.items():
+            arg_value = getattr(training_args, arg_attr, None)
+            state_value = getattr(trainer_state, state_attr, None)
+
+            if arg_value is not None and state_value is not None and arg_value != state_value:
+                warning_str += f"\n\t{arg_attr}: {arg_value} (from args) != {state_value} (from trainer_state.json)"
+                has_warning = True
+
+        # train bs is special as we need to account for multi-GPU
+        train_bs_args = training_args.per_device_train_batch_size
+        train_bs_state = trainer_state.train_batch_size // max(1, training_args.n_gpu)
+
+        if train_bs_args != train_bs_state:
+            warning_str += f"\n\tper_device_train_batch_size: {train_bs_args} (from args) != {train_bs_state} (from trainer_state.json)"
+            has_warning = True
+
+        if has_warning:
+            logger.warning_once(warning_str)
+
+    def _wrap_model(self, model, training=True, dataloader=None):
+        if is_sagemaker_mp_enabled():
+            # Wrapping the base model twice in a DistributedModel will raise an error.
+            if isinstance(self.model_wrapped, smp.model.DistributedModel):
+                return self.model_wrapped
+            return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
+        # train/eval could be run multiple times - if already wrapped, don't re-wrap it
+        if self.accelerator.unwrap_model(model, keep_torch_compile=False) is not model:
+            return model
+
+        # Mixed precision training with apex
+        if self.use_apex and training:
+            from apex import amp
+
+            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization); 8-bit models do not support DDP
+        if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+            model = nn.DataParallel(model)
+
+        if self.args.jit_mode_eval:
+            start_time = time.time()
+            model = self.torch_jit_model_eval(model, dataloader, training)
+            self.jit_compilation_time = round(time.time() - start_time, 4)
+
+        # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
+        if not training:
+            return model
+
+        # Distributed training (should be after apex fp16 initialization)
+        # Distributed training using PyTorch FSDP
+        if self.is_fsdp_xla_enabled:
+            try:
+                from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+                from torch_xla.distributed.fsdp import checkpoint_module
+                from torch_xla.distributed.fsdp.wrap import (
+                    size_based_auto_wrap_policy,
+                    transformer_auto_wrap_policy,
+                )
+
+                if self.is_fsdp_xla_v2_enabled:
+                    from torch_xla.experimental.spmd_fully_sharded_data_parallel import (
+                        SpmdFullyShardedDataParallel as FSDPv2,
+                    )
+            except ImportError:
+                raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+            auto_wrap_policy = None
+            auto_wrapper_callable = None
+            default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
+            fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+            )
+
+            if self.args.fsdp_config["min_num_params"] > 0:
+                auto_wrap_policy = functools.partial(
+                    size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
+                )
+            elif fsdp_transformer_layer_cls_to_wrap is not None:
+                transformer_cls_to_wrap = set()
+                for layer_class in fsdp_transformer_layer_cls_to_wrap:
+                    transformer_cls = get_module_class_from_name(model, layer_class)
+                    if transformer_cls is None:
+                        raise Exception("Could not find the transformer layer class to wrap in the model.")
+                    else:
+                        transformer_cls_to_wrap.add(transformer_cls)
+
+                auto_wrap_policy = functools.partial(
+                    transformer_auto_wrap_policy,
+                    # Transformer layer class to wrap
+                    transformer_layer_cls=transformer_cls_to_wrap,
+                )
+            fsdp_kwargs = self.args.xla_fsdp_config
+            if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
+                if model.config.use_cache:
+                    logger.warning_once(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+                    )
+                    model.config.use_cache = False
+
+                # Apply gradient checkpointing to auto-wrapped sub-modules if specified
+                def auto_wrapper_callable(m, *args, **kwargs):
+                    target_cls = FSDP if not self.is_fsdp_xla_v2_enabled else FSDPv2
+                    return target_cls(checkpoint_module(m), *args, **kwargs)
+
+            # Wrap the base model with an outer FSDP wrapper
+            if self.is_fsdp_xla_v2_enabled:
+
+                def shard_output(output, mesh):
+                    from .modeling_outputs import CausalLMOutputWithPast
+
+                    real_output = None
+                    if isinstance(output, torch.Tensor):
+                        real_output = output
+                    elif isinstance(output, tuple):
+                        real_output = output[0]
+                    elif isinstance(output, CausalLMOutputWithPast):
+                        real_output = output.logits
+
+                    if real_output is None:
+                        raise ValueError("Something went wrong, the output of the model shouldn't be `None`")
+                    xs.mark_sharding(real_output, mesh, ("fsdp", None, None))
+
+                self.model = model = FSDPv2(
+                    model,
+                    shard_output=shard_output,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                )
+            else:
+                self.model = model = FSDP(
+                    model,
+                    auto_wrap_policy=auto_wrap_policy,
+                    auto_wrapper_callable=auto_wrapper_callable,
+                    **fsdp_kwargs,
+                )
+
+            # Patch `xm.optimizer_step` so that it does not reduce gradients in this case,
+            # as FSDP does not need gradient reduction over sharded parameters.
+            def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+                loss = optimizer.step(**optimizer_args)
+                if barrier:
+                    xm.mark_step()
+                return loss
+
+            xm.optimizer_step = patched_optimizer_step
+        elif is_sagemaker_dp_enabled():
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+            )
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if is_torch_neuroncore_available():
+                return model
+            kwargs = {}
+            if self.args.ddp_find_unused_parameters is not None:
+                kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+            elif isinstance(model, PreTrainedModel):
+                # find_unused_parameters breaks checkpointing as per
+                # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+                kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+            else:
+                kwargs["find_unused_parameters"] = True
+
+            if self.args.ddp_bucket_cap_mb is not None:
+                kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+
+            if self.args.ddp_broadcast_buffers is not None:
+                kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers
+
+            self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
+
+        return model
+
+    def train(
+        self,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
+        trial: Union["optuna.Trial", dict[str, Any], None] = None,
+        ignore_keys_for_eval: Optional[list[str]] = None,
+        **kwargs,
+    ):
+        """
+        Main training entry point.
+
+        Args:
+            resume_from_checkpoint (`str` or `bool`, *optional*):
+                If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+                `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+                of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+            trial (`optuna.Trial` or `dict[str, Any]`, *optional*):
+                The trial run or the hyperparameter dictionary for hyperparameter search.
+            ignore_keys_for_eval (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions for evaluation during the training.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments used to hide deprecated arguments.
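+
+        Example (illustrative sketch; assumes `model`, `training_args`, and `train_dataset` are already defined):
+
+        ```python
+        from transformers import Trainer
+
+        trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
+        trainer.train()
+        # To resume from the most recent checkpoint in `training_args.output_dir` instead:
+        # trainer.train(resume_from_checkpoint=True)
+        ```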
+        """
+        if resume_from_checkpoint is False:
+            resume_from_checkpoint = None
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        args = self.args
+
+        self.is_in_train = True
+
+        # If the model uses a tokenizer, it may have new tokens added for fine-tuning purposes.
+        if isinstance(self.processing_class, (PreTrainedTokenizerBase, ProcessorMixin)) and hasattr(
+            self.model, "config"
+        ):
+            self._align_special_tokens()
+
+        # Attach NEFTune hooks if necessary
+        if self.neftune_noise_alpha is not None:
+            self.model = self._activate_neftune(self.model)
+
+        # do_train is not a reliable argument, as it might not be set and .train() still called, so
+        # the following is a workaround:
+        if (
+            (args.fp16_full_eval or args.bf16_full_eval)
+            and not args.do_train
+            and not self.is_model_parallel
+            and self.model_init is None
+        ):
+            self._move_model_to_device(self.model, args.device)
+
+        if "model_path" in kwargs:
+            resume_from_checkpoint = kwargs.pop("model_path")
+            warnings.warn(
+                "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+                "instead.",
+                FutureWarning,
+            )
+        if len(kwargs) > 0:
+            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+        # This might change the seed so needs to run first.
+        self._hp_search_setup(trial)
+        self._train_batch_size = self.args.train_batch_size
+
+        # Model re-init
+        model_reloaded = False
+        if self.model_init is not None:
+            # Seed must be set before instantiating the model when using model_init.
+            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+            self.model = self.call_model_init(trial)
+            model_reloaded = True
+            # Reinitializes optimizer and scheduler
+            self.optimizer, self.lr_scheduler = None, None
+
+        # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+        if resume_from_checkpoint is not None:
+            if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint)
+            # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
+            state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            if state.train_batch_size is not None:
+                self._train_batch_size = state.train_batch_size
+
+        # If model was re-initialized, put it on the right device and update self.model_wrapped
+        if model_reloaded:
+            if self.place_model_on_device:
+                self._move_model_to_device(self.model, args.device)
+            self.model_wrapped = self.model
+
+        inner_training_loop = find_executable_batch_size(
+            self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+        )
+        if args.push_to_hub:
+            try:
+                # Disable progress bars when uploading models during checkpoints to avoid polluting stdout
+                hf_hub_utils.disable_progress_bars()
+                return inner_training_loop(
+                    args=args,
+                    resume_from_checkpoint=resume_from_checkpoint,
+                    trial=trial,
+                    ignore_keys_for_eval=ignore_keys_for_eval,
+                )
+            finally:
+                hf_hub_utils.enable_progress_bars()
+        else:
+            return inner_training_loop(
+                args=args,
+                resume_from_checkpoint=resume_from_checkpoint,
+                trial=trial,
+                ignore_keys_for_eval=ignore_keys_for_eval,
+            )
+
+    def get_tp_size(self) -> int:
+        """Get the tensor parallel size from either the model or DeepSpeed config."""
+
+        # 1. Check model.tp_size first
+        if (model_tp := getattr(self.model, "_tp_size", None)) is not None:
+            return model_tp
+
+        # 2. Fall back to DeepSpeed config if enabled
+        if self.is_deepspeed_enabled and (deepspeed_config := getattr(self.args, "hf_deepspeed_config", None)):
+            return deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
+
+        # 3. Default fallback
+        return 1
+
+    def get_total_train_batch_size(self, args) -> int:
+        """Calculates total batch size (micro_batch * grad_accum * dp_world_size).
+
+        Note: Only considers DP and TP (dp_world_size = world_size // tp_size)."""
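+        # Illustrative example: world_size=16 with tp_size=2 gives dp_world_size=8; with _train_batch_size=8 and
+        # gradient_accumulation_steps=4 the total train batch size is 8 * 4 * 8 = 256.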
+        dp_world_size = args.world_size // self.get_tp_size()
+        return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size
+
+    def _inner_training_loop(
+        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+    ):
+        self.accelerator.free_memory()
+        self._train_batch_size = batch_size
+        if self.args.auto_find_batch_size:
+            if self.state.train_batch_size != self._train_batch_size:
+                from accelerate.utils import release_memory
+
+                (self.model_wrapped,) = release_memory(self.model_wrapped)
+                self.model_wrapped = self.model
+
+                # Check for DeepSpeed *after* the initial pass and modify the config
+                if self.is_deepspeed_enabled:
+                    # Temporarily override `self.args.per_device_train_batch_size` so the new value propagates to the DeepSpeed config
+                    original_bs = self.args.per_device_train_batch_size
+                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+                    self.propagate_args_to_deepspeed(True)
+                    self.args.per_device_train_batch_size = original_bs
+            self.state.train_batch_size = self._train_batch_size
+        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+        # Data loader and number of training steps
+        train_dataloader = self.get_train_dataloader()
+        if self.is_fsdp_xla_v2_enabled:
+            train_dataloader = tpu_spmd_dataloader(train_dataloader)
+
+        # Setting up training control variables:
+        # number of training epochs: num_train_epochs
+        # number of training steps per epoch: num_update_steps_per_epoch
+        # total number of training steps to execute: max_steps
+        total_train_batch_size = self.get_total_train_batch_size(args)
+
+        (
+            num_train_epochs,
+            num_update_steps_per_epoch,
+            num_examples,
+            num_train_samples,
+            epoch_based,
+            len_dataloader,
+            max_steps,
+        ) = self.set_initial_training_values(args, train_dataloader, total_train_batch_size)
+
+        num_train_tokens = None
+        if self.args.include_tokens_per_second:
+            num_train_tokens = self.num_tokens(train_dataloader, None if epoch_based else max_steps)
+            # If going by epochs, multiply tokens linearly
+            if len_dataloader is not None and epoch_based:
+                num_train_tokens *= args.num_train_epochs
+            # Otherwise, since it's steps, we just multiply by grad accum
+            else:
+                num_train_tokens *= args.gradient_accumulation_steps
+
+        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+            if self.args.n_gpu > 1:
+                # nn.DataParallel(model) replicates the model, creating new variables, so module
+                # references registered here no longer work on the other GPUs, breaking the module
+                raise ValueError(
+                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+                    " (torchrun or torch.distributed.launch (deprecated))."
+                )
+            else:
+                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
+
+        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+
+        # Can't delay optimizer creation when using FSDP2: https://github.com/huggingface/accelerate/blob/3f636d626063ffcf9a337c7d3624d61b7d187d59/src/accelerate/accelerator.py#L1404
+        is_fsdp2 = self.is_fsdp_enabled and (getattr(self.accelerator.state.fsdp_plugin, "fsdp_version", 1) == 2)
+        if is_fsdp2:
+            delay_optimizer_creation = False
+
+        # We need to reset the scheduler, as its parameters may be different on subsequent calls
+        if self._created_lr_scheduler:
+            self.lr_scheduler = None
+            self._created_lr_scheduler = False
+
+        if self.is_deepspeed_enabled:
+            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+        if not delay_optimizer_creation:
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+        self.state = TrainerState(
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ]
+        )
+        self.state.is_hyper_param_search = trial is not None
+        self.state.train_batch_size = self._train_batch_size
+
+        # Compute absolute values for logging, eval, and save if given as ratio
+        self.state.compute_steps(args, max_steps)
+
+        # Activate gradient checkpointing if needed
+        if args.gradient_checkpointing:
+            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
+
+        model = self._wrap_model(self.model_wrapped)
+
+        # as the model is wrapped, don't use `accelerator.prepare`
+        # this is for unhandled cases such as
+        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        use_accelerator_prepare = model is self.model
+
+        if use_accelerator_prepare and self.is_fsdp_enabled:
+            # In case of auto_find_batch_size=True
+            # Remove FSDP wrapping from sub-models.
+            self.model = unwrap_model(self.model, recursive=True)
+
+        if delay_optimizer_creation:
+            if use_accelerator_prepare:
+                # configure fsdp plugin for qlora if any
+                self._fsdp_qlora_plugin_updates()
+                if self.accelerator.mixed_precision != "fp8":
+                    self.model = self.accelerator.prepare(self.model)
+            self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+        # prepare using `accelerator` prepare
+        if use_accelerator_prepare:
+            self.model.train()
+            if hasattr(self.lr_scheduler, "step"):
+                if self.use_apex:
+                    model = self.accelerator.prepare(self.model)
+                else:
+                    # Avoid preparing the model with accelerate in the TP case: it is not needed (handled by transformers' from_pretrained) and would otherwise go through DDP-based preparation.
+                    if self.is_tp_enabled:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            else:
+                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
+                )
+        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+            # In this case we are in DDP + LOMO, which should be supported
+            self.optimizer = self.accelerator.prepare(self.optimizer)
+
+        if self.is_fsdp_enabled:
+            self.model = self.model_wrapped = model
+
+        # for the rest of this function `model` is the outside model, whether it was wrapped or not
+        if model is not self.model:
+            self.model_wrapped = model
+
+        # backward compatibility
+        if self.is_deepspeed_enabled:
+            self.deepspeed = self.model_wrapped
+
+        # ckpt loading
+        if resume_from_checkpoint is not None:
+            if self.is_deepspeed_enabled:
+                deepspeed_load_checkpoint(
+                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
+                )
+            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
+
+        # Check if saved optimizer or scheduler states exist
+        self._load_optimizer_and_scheduler(resume_from_checkpoint)
+        self._load_scaler(resume_from_checkpoint)
+
+        # important: at this point:
+        # self.model         is the Transformers Model
+        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
+        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.
+
+        # Train!
+        logger.info("***** Running training *****")
+        logger.info(f"  Num examples = {num_examples:,}")
+        logger.info(f"  Num Epochs = {num_train_epochs:,}")
+        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        if self.args.per_device_train_batch_size != self._train_batch_size:
+            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
+        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+        logger.info(f"  Total optimization steps = {max_steps:,}")
+        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+        self.state.epoch = 0
+        start_time = time.time()
+        epochs_trained = 0
+        steps_trained_in_current_epoch = 0
+        steps_trained_progress_bar = None
+
+        # Check if continuing training from a checkpoint
+        if resume_from_checkpoint is not None and os.path.isfile(
+            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+        ):
+            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+            self.compare_trainer_and_checkpoint_args(self.args, self.state)
+            self._load_callback_state()
+            epochs_trained = int(self.state.global_step // num_update_steps_per_epoch)
+            if not args.ignore_data_skip:
+                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+            else:
+                steps_trained_in_current_epoch = 0
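+            # Illustrative example: global_step=250 with num_update_steps_per_epoch=100 and gradient_accumulation_steps=2
+            # means 2 full epochs are already done and 50 update steps (100 batches) are skipped in the resumed epoch,
+            # unless `ignore_data_skip` is set.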
+
+            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
+            logger.info(f"  Continuing training from epoch {epochs_trained}")
+            logger.info(f"  Continuing training from global step {self.state.global_step}")
+            if not args.ignore_data_skip:
+                logger.info(
+                    f"  Will skip the first {epochs_trained} epochs then the first"
+                    f" {steps_trained_in_current_epoch} batches in the first epoch."
+                )
+
+        # Update the references
+        for attr in ("model", "optimizer", "lr_scheduler"):
+            setattr(self.callback_handler, attr, getattr(self, attr))
+        self.callback_handler.train_dataloader = train_dataloader
+
+        self.state.init_training_references(self, max_steps, num_train_epochs, trial)
+
+        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+        tr_loss = torch.tensor(0.0, device=args.device)
+        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
+        self._total_loss_scalar = 0.0
+        self._globalstep_last_logged = self.state.global_step
+        model.zero_grad()
+        grad_norm: Optional[float] = None
+        learning_rate = None
+        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+        if args.eval_on_start:
+            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
+
+        for epoch in range(epochs_trained, num_train_epochs):
+            epoch_dataloader = train_dataloader
+            if hasattr(epoch_dataloader, "set_epoch"):
+                epoch_dataloader.set_epoch(epoch)
+
+            # Reset the past mems state at the beginning of each epoch if necessary.
+            if args.past_index >= 0:
+                self._past = None
+
+            steps_in_epoch = (
+                len(epoch_dataloader)
+                if len_dataloader is not None
+                else args.max_steps * args.gradient_accumulation_steps
+            )
+            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+                self._load_rng_state(resume_from_checkpoint)
+
+            rng_to_sync = False
+            steps_skipped = 0
+            if steps_trained_in_current_epoch > 0:
+                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
+                steps_skipped = steps_trained_in_current_epoch
+                steps_trained_in_current_epoch = 0
+                rng_to_sync = True
+
+            step = -1
+            epoch_iterator = iter(epoch_dataloader)
+            # We chunk the epoch iterator into groups of `gradient_accumulation_steps` batches
+            remainder = steps_in_epoch % args.gradient_accumulation_steps
+            if remainder == 0:
+                remainder = args.gradient_accumulation_steps
+            update_step = -1
+            total_updates = steps_in_epoch // args.gradient_accumulation_steps + int(
+                remainder < args.gradient_accumulation_steps
+            )
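+            # Illustrative example: steps_in_epoch=10 with gradient_accumulation_steps=4 gives remainder=2 and
+            # total_updates=3 (two full updates of 4 batches and one final update of 2 batches).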
+            for _ in range(total_updates):
+                update_step += 1
+                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
+                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
+                # Store the number of batches for the current gradient accumulation step
+                # This is used to correctly scale the loss when the last accumulation step has fewer batches
+                self.current_gradient_accumulation_steps = len(batch_samples)
+                for i, inputs in enumerate(batch_samples):
+                    step += 1
+                    do_sync_step = (step + 1) % args.gradient_accumulation_steps == 0 or (step + 1) == steps_in_epoch
+                    # Since we perform prefetching, we need to manually set sync_gradients
+                    self.accelerator.gradient_state._set_sync_gradients(do_sync_step)
+
+                    if self.args.include_num_input_tokens_seen:
+                        main_input_name = getattr(self.model, "main_input_name", "input_ids")
+                        if main_input_name not in inputs:
+                            logger.warning(
+                                "Tried to track the number of tokens seen, however the current model is "
+                                "not configured properly to know what item is the input. To fix this, add "
+                                "a `main_input_name` attribute to the model class you are using."
+                            )
+                        else:
+                            input_tokens = inputs[main_input_name].numel()
+                            input_tokens = torch.tensor(input_tokens, device=self.args.device, dtype=torch.int64)
+                            self.state.num_input_tokens_seen += self.accelerator.gather(input_tokens).sum().item()
+                    if rng_to_sync:
+                        self._load_rng_state(resume_from_checkpoint)
+                        rng_to_sync = False
+
+                    # Skip past any already trained steps if resuming training
+                    if steps_trained_in_current_epoch > 0:
+                        steps_trained_in_current_epoch -= 1
+                        if steps_trained_progress_bar is not None:
+                            steps_trained_progress_bar.update(1)
+                        if steps_trained_in_current_epoch == 0:
+                            self._load_rng_state(resume_from_checkpoint)
+                        continue
+                    elif steps_trained_progress_bar is not None:
+                        steps_trained_progress_bar.close()
+                        steps_trained_progress_bar = None
+
+                    if step % args.gradient_accumulation_steps == 0:
+                        self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+                    # We explicitly want to avoid relying on `accelerator.accumulate` for generation training
+                    context = (
+                        functools.partial(self.accelerator.no_sync, model=model)
+                        if i != len(batch_samples) - 1
+                        and self.accelerator.distributed_type != DistributedType.DEEPSPEED
+                        else contextlib.nullcontext
+                    )
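+                    # `no_sync` defers gradient synchronization on intermediate micro-batches; the final micro-batch
+                    # syncs normally, and DeepSpeed is excluded since its engine handles accumulation itself.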
+                    with context():
+                        tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+
+                    if (
+                        args.logging_nan_inf_filter
+                        and not is_torch_xla_available()
+                        and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+                    ):
+                        # if loss is nan or inf simply add the average of previous logged losses
+                        tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+                    else:
+                        if tr_loss.device != tr_loss_step.device:
+                            raise ValueError(
+                                f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
+                            )
+                        tr_loss = tr_loss + tr_loss_step
+
+                    self.current_flos += float(self.floating_point_ops(inputs))
+
+                    if do_sync_step:
+                        # Since we perform prefetching, we need to manually set sync_gradients to True
+                        self.accelerator.gradient_state._set_sync_gradients(True)
+
+                        # Gradient clipping
+                        if args.max_grad_norm is not None and args.max_grad_norm > 0:
+                            if is_sagemaker_mp_enabled() and args.fp16:
+                                _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
+                            elif self.use_apex:
+                                from apex import amp
+
+                                # Revert to normal clipping otherwise, handling Apex or full precision
+                                _grad_norm = nn.utils.clip_grad_norm_(
+                                    amp.master_params(self.optimizer),
+                                    args.max_grad_norm,
+                                )
+                            else:
+                                grad_norm_context = contextlib.nullcontext
+                                if self.is_tp_enabled:
+                                    from torch.distributed._tensor.experimental import implicit_replication
+
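+                                    # `implicit_replication` lets clipping mix tensor-parallel DTensor parameters
+                                    # with plain tensors by treating the latter as replicated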
+                                    grad_norm_context = implicit_replication
+                                with grad_norm_context():
+                                    _grad_norm = self.accelerator.clip_grad_norm_(
+                                        model.parameters(),
+                                        args.max_grad_norm,
+                                    )
+
+                            if (
+                                is_accelerate_available()
+                                and self.accelerator.distributed_type == DistributedType.DEEPSPEED
+                            ):
+                                grad_norm = model.get_global_grad_norm()
+                                # In some cases the grad norm may not return a float
+                                if hasattr(grad_norm, "item"):
+                                    grad_norm = grad_norm.item()
+                            else:
+                                grad_norm = _grad_norm
+
+                        self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)
+
+                        context = contextlib.nullcontext
+                        if self.is_tp_enabled:
+                            from torch.distributed._tensor.experimental import implicit_replication
+
+                            context = implicit_replication
+
+                        with context():
+                            self.optimizer.step()
+
+                        self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)
+
+                        # get learning rate before update
+                        learning_rate = self._get_learning_rate()
+
+                        if not self.accelerator.optimizer_step_was_skipped:
+                            # Delay optimizer scheduling until metrics are generated
+                            if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                                self.lr_scheduler.step()
+
+                        model.zero_grad()
+                        self.state.global_step += 1
+                        self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+                        self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+                        self._maybe_log_save_evaluate(
+                            tr_loss,
+                            grad_norm,
+                            model,
+                            trial,
+                            epoch,
+                            ignore_keys_for_eval,
+                            start_time,
+                            learning_rate=learning_rate,
+                        )
+                    else:
+                        self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+                    # PyTorch/XLA relies on the data loader to insert the mark_step for
+                    # each step. Since we are breaking the loop early, we need to manually
+                    # insert the mark_step here.
+                    if self.control.should_epoch_stop or self.control.should_training_stop:
+                        if is_torch_xla_available():
+                            xm.mark_step()
+                        break
+                # We also need to break out of the nested loop
+                if self.control.should_epoch_stop or self.control.should_training_stop:
+                    if is_torch_xla_available():
+                        xm.mark_step()
+                    break
+            if step < 0:
+                logger.warning(
+                    "There seems not to be a single sample in your epoch_iterator, stopping training at step"
+                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+                    f" num_steps ({max_steps}) higher than the number of available samples."
+                )
+                self.control.should_training_stop = True
+
+            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+            self._maybe_log_save_evaluate(
+                tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=learning_rate
+            )
+
+            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+                if is_torch_xla_available():
+                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+                    xm.master_print(met.metrics_report())
+                else:
+                    logger.warning(
+                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+                        "configured. Check your training configuration if this is unexpected."
+                    )
+            if self.control.should_training_stop:
+                break
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of training
+            delattr(self, "_past")
+
+        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+            # Wait for everyone to get here so we are sure the model has been saved by process 0.
+            if is_torch_xla_available():
+                xm.rendezvous("load_best_model_at_end")
+            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+                dist.barrier()
+            elif is_sagemaker_mp_enabled():
+                smp.barrier()
+
+            self._load_best_model()
+
+        # add remaining tr_loss
+        self._total_loss_scalar += tr_loss.item()
+        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
+        train_loss = self._total_loss_scalar / effective_global_step
+
+        metrics = speed_metrics(
+            "train",
+            start_time,
+            num_samples=num_train_samples,
+            num_steps=self.state.max_steps,
+            num_tokens=num_train_tokens,
+        )
+        self.store_flos()
+        metrics["total_flos"] = self.state.total_flos
+        metrics["train_loss"] = train_loss
+
+        self.is_in_train = False
+
+        self._memory_tracker.stop_and_update_metrics(metrics)
+
+        self.log(metrics)
+
+        run_dir = self._get_output_dir(trial)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
+        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+            for checkpoint in checkpoints_sorted:
+                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
+                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+                    shutil.rmtree(checkpoint, ignore_errors=True)
+
+        self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+        # Wait for the checkpoint to be uploaded.
+        self._finish_current_push()
+
+        # After training we make sure to restore the original forward pass method
+        # for the embedding layer by removing the forward post hook.
+        if self.neftune_noise_alpha is not None:
+            self._deactivate_neftune(self.model)
+
+        return TrainOutput(self.state.global_step, train_loss, metrics)
+
+    def _get_output_dir(self, trial):
+        if self.hp_search_backend is not None and trial is not None:
+            if self.hp_search_backend == HPSearchBackend.OPTUNA:
+                run_id = trial.number
+            elif self.hp_search_backend == HPSearchBackend.RAY:
+                import ray.train
+
+                run_id = ray.train.get_context().get_trial_id()
+            elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+                run_id = trial.id
+            elif self.hp_search_backend == HPSearchBackend.WANDB:
+                import wandb
+
+                run_id = wandb.run.id
+            run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+            run_dir = os.path.join(self.args.output_dir, run_name)
+        else:
+            run_dir = self.args.output_dir
+        return run_dir
+
+    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+        if model is None:
+            model = self.model
+
+        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
+        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+        is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
+            # this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
+            any(
+                FSDP_MODEL_NAME in folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+            )
+            # this checks the FSDP state dict when `FULL_STATE_DICT` is used
+            or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
+        )
+        # if multiple adapters exist, they get saved in subdirectories
+        adapter_subdirs = (
+            [
+                folder_name
+                for folder_name in os.listdir(resume_from_checkpoint)
+                if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
+                and (
+                    os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_WEIGHTS_NAME))
+                    or os.path.isfile(os.path.join(resume_from_checkpoint, folder_name, ADAPTER_SAFE_WEIGHTS_NAME))
+                )
+            ]
+            if os.path.isdir(resume_from_checkpoint)
+            else []
+        )
+
+        if is_fsdp_ckpt and not self.is_fsdp_enabled:
+            raise ValueError(f"Checkpoint found at {resume_from_checkpoint} is only supported when using PyTorch FSDP")
+
+        if not (
+            any(
+                os.path.isfile(f)
+                for f in [
+                    weights_file,
+                    safe_weights_file,
+                    weights_index_file,
+                    safe_weights_index_file,
+                    adapter_weights_file,
+                    adapter_safe_weights_file,
+                ]
+            )
+            or is_fsdp_ckpt
+            or adapter_subdirs
+        ):
+            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+        logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+        if os.path.isfile(config_file):
+            config = PretrainedConfig.from_json_file(config_file)
+            checkpoint_version = config.transformers_version
+            if checkpoint_version is not None and checkpoint_version != __version__:
+                logger.warning(
+                    f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                    f"Transformers but your current version is {__version__}. This is not recommended and could "
+                    "yield to errors or unwanted behaviors."
+                )
+
+        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file) or is_fsdp_ckpt:
+            # If the model is on the GPU, it still works!
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if hasattr(self.args, "fp16") and self.args.fp16 is True:
+                        logger.warning(
+                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
+                        )
+                    check_torch_load_is_safe()
+                    state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
+                    # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+                    # release memory
+                    del state_dict
+            elif self.is_fsdp_enabled:
+                load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin,
+                    self.accelerator,
+                    model,
+                    resume_from_checkpoint,
+                    **_get_fsdp_ckpt_kwargs(),
+                )
+            else:
+                # We load the model state dict on the CPU to avoid an OOM error.
+                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+                else:
+                    check_torch_load_is_safe()
+                    state_dict = torch.load(weights_file, map_location="cpu", weights_only=True)
+
+                # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+                # which takes *args instead of **kwargs
+                load_result = model.load_state_dict(state_dict, False)
+                # release memory
+                del state_dict
+                self._issue_warnings_after_load(load_result)
+
+        # Load adapters following PR # 24096
+        elif _is_peft_model(model):
+            # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
+            # TODO: in the future support only specific min PEFT versions
+            if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+                model, "load_adapter"
+            ):
+                if os.path.exists(resume_from_checkpoint):
+                    # For BC for older PEFT versions
+                    if hasattr(model, "active_adapters"):
+                        active_adapters = model.active_adapters
+                        if len(active_adapters) > 1:
+                            logger.warning("Multiple active adapters detected will only consider the first adapter")
+                        active_adapter = active_adapters[0]
+                    else:
+                        active_adapter = model.active_adapter
+
+                    if adapter_subdirs:
+                        for subdir_name in adapter_subdirs:
+                            peft_id = os.path.join(resume_from_checkpoint, subdir_name)
+                            model.load_adapter(peft_id, subdir_name, is_trainable=(subdir_name == active_adapter))
+                        model.set_adapter(active_adapter)
+                    else:
+                        model.load_adapter(resume_from_checkpoint, active_adapter, is_trainable=True)
+                else:
+                    logger.warning(
+                        "The intermediate checkpoints of PEFT may not be saved correctly, "
+                        f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                        "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                    )
+            else:
+                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+        else:
+            # We load the sharded checkpoint
+            load_result = load_sharded_checkpoint(
+                model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
+            )
+            if not is_sagemaker_mp_enabled():
+                self._issue_warnings_after_load(load_result)
+
+    def _load_best_model(self):
+        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+        best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+        best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+        best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+        best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
+        model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.is_deepspeed_enabled:
+            deepspeed_load_checkpoint(
+                self.model_wrapped,
+                self.state.best_model_checkpoint,
+                load_module_strict=not _is_peft_model(self.model),
+            )
+        elif self.is_fsdp_enabled:
+            load_result = load_fsdp_model(
+                self.accelerator.state.fsdp_plugin,
+                self.accelerator,
+                model,
+                self.state.best_model_checkpoint,
+                **_get_fsdp_ckpt_kwargs(),
+            )
+        elif (
+            os.path.exists(best_model_path)
+            or os.path.exists(best_safe_model_path)
+            or os.path.exists(best_adapter_model_path)
+            or os.path.exists(best_safe_adapter_model_path)
+        ):
+            has_been_loaded = True
+            if is_sagemaker_mp_enabled():
+                if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+                    # If the 'user_content.pt' file exists, load with the new smp api.
+                    # Checkpoint must have been saved with the new smp api.
+                    smp.resume_from_checkpoint(
+                        path=self.state.best_model_checkpoint,
+                        tag=WEIGHTS_NAME,
+                        partial=False,
+                        load_optimizer=False,
+                    )
+                else:
+                    # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+                    # Checkpoint must have been saved with the old smp api.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        check_torch_load_is_safe()
+                        state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
+
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+            else:
+                if _is_peft_model(model):
+                    # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
+                    # TODO: in the future support only specific min PEFT versions
+                    if (hasattr(model, "active_adapter") or hasattr(model, "active_adapters")) and hasattr(
+                        model, "load_adapter"
+                    ):
+                        # For BC for older PEFT versions
+                        if hasattr(model, "active_adapters"):
+                            active_adapter = model.active_adapters[0]
+                            if len(model.active_adapters) > 1:
+                                logger.warning("Detected multiple active adapters, will only consider the first one")
+                        else:
+                            active_adapter = model.active_adapter
+
+                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+                            try:
+                                model.load_adapter(self.state.best_model_checkpoint, active_adapter)
+                            except RuntimeError as exc:
+                                if model.peft_config[active_adapter].is_prompt_learning:
+                                    # for context: https://github.com/huggingface/peft/issues/2256
+                                    msg = (
+                                        "When using prompt learning PEFT methods such as "
+                                        f"{model.peft_config[active_adapter].peft_type.value}, setting "
+                                        "load_best_model_at_end=True can lead to errors, it is recommended "
+                                        "to set this to False and to load the model manually from the checkpoint "
+                                        "directory using PeftModel.from_pretrained(base_model, ) after training "
+                                        "has finished."
+                                    )
+                                    raise RuntimeError(msg) from exc
+                                else:
+                                    raise
+                            # load_adapter has no return value, so construct an empty one; modify it when appropriate.
+                            from torch.nn.modules.module import _IncompatibleKeys
+
+                            load_result = _IncompatibleKeys([], [])
+                        else:
+                            logger.warning(
+                                "The intermediate checkpoints of PEFT may not be saved correctly, "
+                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                            )
+                            has_been_loaded = False
+                    else:
+                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                        has_been_loaded = False
+                else:
+                    # We load the model state dict on the CPU to avoid an OOM error.
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        check_torch_load_is_safe()
+                        state_dict = torch.load(best_model_path, map_location="cpu", weights_only=True)
+
+                    # If the model is on the GPU, it still works!
+                    # Workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963,
+                    # which takes *args instead of **kwargs: strict=False is passed positionally.
+                    load_result = model.load_state_dict(state_dict, False)
+                if not is_sagemaker_mp_enabled() and has_been_loaded:
+                    self._issue_warnings_after_load(load_result)
+        elif os.path.exists(os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_INDEX_NAME)) or os.path.exists(
+            os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)
+        ):
+            load_result = load_sharded_checkpoint(
+                model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+            )
+            if not is_sagemaker_mp_enabled():
+                self._issue_warnings_after_load(load_result)
+        else:
+            logger.warning(
+                f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+                "on multiple nodes, you should activate `--save_on_each_node`."
+            )
+
+    def _issue_warnings_after_load(self, load_result):
+        if len(load_result.missing_keys) != 0:
+            if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
+                self.model._keys_to_ignore_on_save
+            ):
+                self.model.tie_weights()
+            else:
+                logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
+        if len(load_result.unexpected_keys) != 0:
+            logger.warning(
+                f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
+            )
+
+    def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
+        metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+        self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+        # Run delayed LR scheduler now that metrics are populated
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) and not skip_scheduler:
+            metric_to_check = self.args.metric_for_best_model
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+            try:
+                self.lr_scheduler.step(metrics[metric_to_check])
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', "
+                    f"which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. "
+                    f"Please ensure that the `compute_metrics` function returns a dictionary that includes '{metric_to_check}' or "
+                    f"consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+        return metrics
+
+    def _maybe_log_save_evaluate(
+        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None
+    ):
+        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
+            if is_torch_xla_available():
+                xm.mark_step()
+
+            logs: dict[str, float] = {}
+
+            # all_gather + mean() to get average loss over all processes
+            tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+            # reset tr_loss to zero
+            tr_loss -= tr_loss
+
+            logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+            if grad_norm is not None:
+                logs["grad_norm"] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+            if learning_rate is not None:
+                logs["learning_rate"] = learning_rate
+            else:
+                logs["learning_rate"] = self._get_learning_rate()
+
+            self._total_loss_scalar += tr_loss_scalar
+            self._globalstep_last_logged = self.state.global_step
+            self.store_flos()
+
+            self.log(logs, start_time)
+
+        metrics = None
+        if self.control.should_evaluate:
+            metrics = self._evaluate(trial, ignore_keys_for_eval)
+            is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
+
+            if self.args.save_strategy == SaveStrategy.BEST:
+                self.control.should_save = is_new_best_metric
+
+        if self.control.should_save:
+            self._save_checkpoint(model, trial)
+            self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        if self.args.world_size > 1:
+            process_index = self.args.process_index
+            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                logger.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+        if is_torch_xla_available():
+            xm.set_rng_state(checkpoint_rng_state["xla"])
+
+        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
+        if torch.cuda.is_available():
+            set_rng_state_for_device("CUDA", torch.cuda, checkpoint_rng_state, is_distributed)
+        if is_torch_npu_available():
+            set_rng_state_for_device("NPU", torch.npu, checkpoint_rng_state, is_distributed)
+        if is_torch_hpu_available():
+            set_rng_state_for_device("HPU", torch.hpu, checkpoint_rng_state, is_distributed)
+        if is_torch_mlu_available():
+            set_rng_state_for_device("MLU", torch.mlu, checkpoint_rng_state, is_distributed)
+        if is_torch_musa_available():
+            set_rng_state_for_device("MUSA", torch.musa, checkpoint_rng_state, is_distributed)
+
+    def _determine_best_metric(self, metrics, trial):
+        """
+        Determine if the model should be saved based on the evaluation metrics.
+
+        Returns:
+            bool: True if a new best metric was found, else False
+        """
+        is_new_best_metric = False
+
+        if self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+
+            if not metric_to_check.startswith("eval_"):
+                metric_to_check = f"eval_{metric_to_check}"
+
+            try:
+                metric_value = metrics[metric_to_check]
+            except KeyError as exc:
+                raise KeyError(
+                    f"The `metric_for_best_model` training argument is set to '{metric_to_check}', which is not found in the evaluation metrics. "
+                    f"The available evaluation metrics are: {list(metrics.keys())}. Consider changing the `metric_for_best_model` via the TrainingArguments."
+                ) from exc
+
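+            # `greater_is_better=True` means larger metric values are better (e.g. accuracy), so compare with
+            # np.greater; otherwise (e.g. for a loss) compare with np.less.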
+            operator = np.greater if self.args.greater_is_better else np.less
+
+            if self.state.best_metric is None:
+                self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")
+
+            if operator(metric_value, self.state.best_metric):
+                self.state.best_metric = metric_value
+
+                if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
+                    self.state.best_global_step = self.state.global_step
+
+                is_new_best_metric = True
+
+        return is_new_best_metric
+
+    def _save_checkpoint(self, model, trial):
+        # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+        # want to save except FullyShardedDDP.
+        # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+        # Save model checkpoint
+        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+        if self.hp_search_backend is None and trial is None:
+            self.store_flos()
+
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        self.save_model(output_dir, _internal_call=True)
+
+        if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
+            best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
+            best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)
+
+            if os.path.exists(best_checkpoint_dir):
+                self.state.best_model_checkpoint = best_checkpoint_dir
+
+        if not self.args.save_only_model:
+            # Save optimizer and scheduler
+            self._save_optimizer_and_scheduler(output_dir)
+            self._save_scaler(output_dir)
+            # Save RNG state
+            self._save_rng_state(output_dir)
+
+        # Save the Trainer state
+        if self.args.should_save:
+            # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
+            for cb in [
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ]:
+                cb_name = cb.__class__.__name__
+                cb_state = cb.state()
+                if isinstance(self.state.stateful_callbacks[cb_name], list):
+                    self.state.stateful_callbacks[cb_name].append(cb_state)
+                else:
+                    self.state.stateful_callbacks[cb_name] = cb_state
+            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+        if self.args.push_to_hub:
+            self._push_from_checkpoint(output_dir)
+
+        # Maybe delete some older checkpoints.
+        if self.args.should_save:
+            # we use mtime by default; filesystems without mtime support will be detected in `_sorted_checkpoints`
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+    def _save_rng_state(self, output_dir):
+        # Save RNG state in non-distributed training
+        rng_states = {
+            "python": random.getstate(),
+            "numpy": np.random.get_state(),
+            "cpu": torch.random.get_rng_state(),
+        }
+        if torch.cuda.is_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                # In distributed mode, save the RNG state of every visible CUDA device
+                rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+            else:
+                rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+        if is_torch_xla_available():
+            rng_states["xla"] = xm.get_rng_state()
+
+        if is_torch_npu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["npu"] = torch.npu.random.get_rng_state_all()
+            else:
+                rng_states["npu"] = torch.npu.random.get_rng_state()
+
+        if is_torch_hpu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["hpu"] = torch.hpu.random.get_rng_state_all()
+            else:
+                rng_states["hpu"] = torch.hpu.random.get_rng_state()
+
+        if is_torch_mlu_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state_all()
+            else:
+                rng_states["mlu"] = torch.mlu.random.get_rng_state()
+
+        if is_torch_musa_available():
+            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+                rng_states["musa"] = torch.musa.get_rng_state_all()
+            else:
+                rng_states["musa"] = torch.musa.get_rng_state()
+
+        # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
+        # not yet exist.
+        os.makedirs(output_dir, exist_ok=True)
+
+        if self.args.world_size <= 1:
+            torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+        else:
+            torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
+
+    def _save_optimizer_and_scheduler(self, output_dir):
+        if is_torch_xla_available():
+            xm.rendezvous("saving_optimizer_states")
+            if self.is_fsdp_xla_v1_enabled:
+                optm = {
+                    "optimizer": self.optimizer.state_dict(),
+                    "shard_metadata": self.model.get_shard_metadata(),
+                }
+                xm.save(
+                    optm,
+                    os.path.join(
+                        output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                    ),
+                    master_only=False,
+                )
+            else:
+                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                reissue_pt_warnings(caught_warnings)
+        elif is_sagemaker_mp_enabled():
+            opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+            smp.barrier()
+            if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+                smp.save(
+                    opt_state_dict,
+                    os.path.join(output_dir, OPTIMIZER_NAME),
+                    partial=True,
+                    v3=smp.state.cfg.shard_optimizer_state,
+                )
+        elif self.is_deepspeed_enabled:
+            # Under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
+            # config `stage3_gather_16bit_weights_on_model_save` is True
+            accept_exclude_frozen_parameters = "exclude_frozen_parameters" in set(
+                inspect.signature(self.model_wrapped.save_checkpoint).parameters.keys()
+            )
+            if accept_exclude_frozen_parameters and _is_peft_model(self.model):
+                self.model_wrapped.save_checkpoint(output_dir, exclude_frozen_parameters=True)
+            else:
+                self.model_wrapped.save_checkpoint(output_dir)
+        elif self.is_fsdp_enabled:
+            # save fsdp specific ckpt for resuming from ckpt
+            save_fsdp_model(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir, **_get_fsdp_ckpt_kwargs()
+            )
+            save_fsdp_optimizer(
+                self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
+            )
+        elif self.args.should_save:
+            # In the DeepSpeed case, save_checkpoint above already saved model/optimizer/scheduler
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
+        # Save SCHEDULER & SCALER
+        is_deepspeed_custom_scheduler = self.is_deepspeed_enabled and not isinstance(
+            self.lr_scheduler, DeepSpeedSchedulerWrapper
+        )
+        if (
+            self.args.should_save
+            and (not self.is_deepspeed_enabled or is_deepspeed_custom_scheduler)
+            and not is_torch_xla_available()
+        ):
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+            reissue_pt_warnings(caught_warnings)
+
+    def _load_optimizer_and_scheduler(self, checkpoint):
+        """If optimizer and scheduler states exist, load them."""
+        if checkpoint is None:
+            return
+
+        if self.is_deepspeed_enabled:
+            # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+            if not isinstance(self.lr_scheduler, DeepSpeedSchedulerWrapper):
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
+                    self.lr_scheduler.load_state_dict(
+                        torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+                    )
+                reissue_pt_warnings(caught_warnings)
+            return
+
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+            if is_sagemaker_mp_enabled()
+            else (
+                os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+                or os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME_BIN))
+                or (
+                    os.path.isdir(checkpoint)
+                    and any(
+                        OPTIMIZER_NAME_BIN.split(".")[0] in folder_name
+                        for folder_name in os.listdir(checkpoint)
+                        if os.path.isdir(os.path.join(checkpoint, folder_name))
+                    )
+                )
+            )
+        )
+        checkpoint_file_exists = (
+            glob.glob(os.path.join(checkpoint, f"rank*-of-{self.args.world_size}-{OPTIMIZER_NAME}"))
+            if self.is_fsdp_xla_v1_enabled
+            else checkpoint_file_exists
+        )
+        if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+            # Load in optimizer and scheduler states
+            if is_torch_xla_available():
+                # On TPU we have to take some extra precautions to properly load the states on the right device.
+                if self.is_fsdp_xla_v1_enabled:
+                    check_torch_load_is_safe()
+                    optimizer_state = torch.load(
+                        os.path.join(
+                            checkpoint, f"rank{self.args.process_index}-of-{self.args.world_size}-{OPTIMIZER_NAME}"
+                        ),
+                        map_location="cpu",
+                        weights_only=True,
+                    )
+                    # We only need `optimizer` when resuming from checkpoint
+                    optimizer_state = optimizer_state["optimizer"]
+                else:
+                    check_torch_load_is_safe()
+                    optimizer_state = torch.load(
+                        os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu", weights_only=True
+                    )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
+                    lr_scheduler_state = torch.load(
+                        os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu", weights_only=True
+                    )
+                reissue_pt_warnings(caught_warnings)
+
+                xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+                xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+                self.optimizer.load_state_dict(optimizer_state)
+                self.lr_scheduler.load_state_dict(lr_scheduler_state)
+            else:
+                if is_sagemaker_mp_enabled():
+                    if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+                        # Optimizer checkpoint was saved with smp >= 1.10
+                        def opt_load_hook(mod, opt):
+                            opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    else:
+                        # Optimizer checkpoint was saved with smp < 1.10
+                        def opt_load_hook(mod, opt):
+                            if IS_SAGEMAKER_MP_POST_1_10:
+                                opt.load_state_dict(
+                                    smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+                                )
+                            else:
+                                opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+                    self.model_wrapped.register_post_step_hook(opt_load_hook)
+                else:
+                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+                    # In distributed training however, we load directly on each GPU and accept the risk of GPU OOM,
+                    # since loading on CPU would be even more likely to OOM (we would load num_gpu copies of the optimizer state).
+                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
+                    if self.is_fsdp_enabled:
+                        load_fsdp_optimizer(
+                            self.accelerator.state.fsdp_plugin,
+                            self.accelerator,
+                            self.optimizer,
+                            self.model,
+                            checkpoint,
+                            **_get_fsdp_ckpt_kwargs(),
+                        )
+                    else:
+                        check_torch_load_is_safe()
+                        self.optimizer.load_state_dict(
+                            torch.load(
+                                os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location, weights_only=True
+                            )
+                        )
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
+                    self.lr_scheduler.load_state_dict(
+                        torch.load(os.path.join(checkpoint, SCHEDULER_NAME), weights_only=True)
+                    )
+                reissue_pt_warnings(caught_warnings)
+
+    def _save_scaler(self, output_dir):
+        # See if there is a scaler attribute
+        try:
+            scaler = self.accelerator.scaler
+        except AttributeError:
+            return
+        if scaler is None:
+            return
+        if is_torch_xla_available():
+            xm.rendezvous("saving_scaler_state")
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                xm.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+                reissue_pt_warnings(caught_warnings)
+
+        # Save SCALER
+        if self.args.should_save and not is_torch_xla_available():
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                torch.save(self.accelerator.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+            reissue_pt_warnings(caught_warnings)
+
+    def _load_scaler(self, checkpoint):
+        """If scaler state exists, load it."""
+        if checkpoint is None:
+            return
+
+        checkpoint_file_exists = os.path.isfile(os.path.join(checkpoint, SCALER_NAME))
+
+        if checkpoint_file_exists:
+            # On TPU we have to take some extra precautions to properly load the states on the right device.
+            # Load in scaler states
+            if is_torch_xla_available():
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
+                    scaler_state = torch.load(
+                        os.path.join(checkpoint, SCALER_NAME), map_location="cpu", weights_only=True
+                    )
+                reissue_pt_warnings(caught_warnings)
+                xm.send_cpu_data_to_device(scaler_state, self.args.device)
+                self.accelerator.scaler.load_state_dict(scaler_state)
+            else:
+                with warnings.catch_warnings(record=True) as caught_warnings:
+                    check_torch_load_is_safe()
+                    self.accelerator.scaler.load_state_dict(
+                        torch.load(os.path.join(checkpoint, SCALER_NAME), weights_only=True)
+                    )
+                reissue_pt_warnings(caught_warnings)
+
+    def _load_callback_state(self):
+        """If callback states exist and were passed in, restore their states if enabled"""
+        if not self.args.restore_callback_states_from_checkpoint:
+            return
+        # Callback states are stored in stateful_callbacks
+        not_found = []
+        new_callbacks = []
+        original_callbacks = self.callback_handler.callbacks + [self.control]
+        for stored_callback, data in self.state.stateful_callbacks.items():
+            if not isinstance(data, list):
+                data = [data]
+            if any(callback.__class__.__name__ == stored_callback for callback in original_callbacks):
+                # We can load/restore from multiple callbacks of the same type.
+                duplicates = [
+                    callback for callback in original_callbacks if callback.__class__.__name__ == stored_callback
+                ]
+                for callback, callback_data in zip(duplicates, data):
+                    args = callback_data.get("args", {})
+                    attributes = callback_data.get("attributes", {})
+                    new_callback = type(callback)(**args)
+                    for attribute, value in attributes.items():
+                        setattr(new_callback, attribute, value)
+                    if isinstance(callback, TrainerControl):
+                        # Specifically for restoring the `control` state
+                        self.control = new_callback
+                    else:
+                        new_callbacks.append(new_callback)
+                    # We remove the existing callback and add it to the list of new callbacks
+                    self.callback_handler.remove_callback(type(new_callback))
+                logger.info("Continuing training from checkpoint, restoring any callbacks that were passed in")
+            else:
+                not_found.append(stored_callback)
+        if len(not_found) > 0:
+            logger.warning(
+                f"Checkpoint included callbacks not included in current configuration. Ignoring. ({', '.join(not_found)})"
+            )
+        for callback in new_callbacks:
+            self.callback_handler.add_callback(callback)
+
+    def hyperparameter_search(
+        self,
+        hp_space: Optional[Callable[["optuna.Trial"], dict[str, float]]] = None,
+        compute_objective: Optional[Callable[[dict[str, float]], float]] = None,
+        n_trials: int = 20,
+        direction: Union[str, list[str]] = "minimize",
+        backend: Optional[Union["str", HPSearchBackend]] = None,
+        hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+        **kwargs,
+    ) -> Union[BestRun, list[BestRun]]:
+        """
+        Launch a hyperparameter search using `optuna`, `Ray Tune` or `SigOpt`. The optimized quantity is determined
+        by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided
+        and the sum of all metrics otherwise.
+
+        <Tip warning={true}>
+
+        To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+        reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+        subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+        optimizer/scheduler.
+
+        </Tip>
+
+        Args:
+            hp_space (`Callable[["optuna.Trial"], dict[str, float]]`, *optional*):
+                A function that defines the hyperparameter search space. Will default to
+                [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+                [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+            compute_objective (`Callable[[dict[str, float]], float]`, *optional*):
+                A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+                method. Will default to [`~trainer_utils.default_compute_objective`].
+            n_trials (`int`, *optional*, defaults to 20):
+                The number of trial runs to test.
+            direction (`str` or `list[str]`, *optional*, defaults to `"minimize"`):
+                The optimization direction(s). For single-objective optimization this is a `str` that can be
+                `"minimize"` or `"maximize"`; for multi-objective optimization it is a `list[str]` of `"minimize"`
+                and/or `"maximize"` values. Pick `"minimize"` when optimizing the validation loss and `"maximize"`
+                when optimizing one or several metrics.
+            backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
+                The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+                on which one is installed. If all are installed, will default to optuna.
+            hp_name (`Callable[["optuna.Trial"], str]`, *optional*):
+                A function that defines the trial/run name. Will default to None.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments for each backend:
+
+                - `optuna`: parameters from
+                  [optuna.study.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
+                  and also the parameters `timeout`, `n_jobs` and `gc_after_trial` from
+                  [optuna.study.Study.optimize](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.Study.html#optuna.study.Study.optimize)
+                - `ray`: parameters from [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run).
+                  If `resources_per_trial` is not set in the `kwargs`, it defaults to 1 CPU core and 1 GPU (if available).
+                  If `progress_reporter` is not set in the `kwargs`,
+                  [ray.tune.CLIReporter](https://docs.ray.io/en/latest/tune/api/doc/ray.tune.CLIReporter.html) is used.
+                - `sigopt`: the parameter `proxies` from
+                  [sigopt.Connection.set_proxies](https://docs.sigopt.com/support/faq#how-do-i-use-sigopt-with-a-proxy).
+
+        Returns:
+            [`trainer_utils.BestRun` or `list[trainer_utils.BestRun]`]: All the information about the best run or best
+            runs for multi-objective optimization. Experiment summary can be found in `run_summary` attribute for Ray
+            backend.
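+
+        Example (illustrative sketch, assuming `optuna` is installed and this `Trainer` was created with a
+        `model_init` function; the search space below is hypothetical):
+
+        ```python
+        def my_hp_space(trial):
+            return {
+                "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
+                "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [8, 16, 32]),
+            }
+
+        best_run = trainer.hyperparameter_search(hp_space=my_hp_space, n_trials=10, direction="minimize")
+        ```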
+        """
+        if backend is None:
+            backend = default_hp_search_backend()
+        backend = HPSearchBackend(backend)
+        backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]()
+        backend_obj.ensure_available()
+        self.hp_search_backend = backend
+        if self.model_init is None:
+            raise RuntimeError(
+                "To use hyperparameter search, you need to pass your model through a model_init function."
+            )
+
+        self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space
+        self.hp_name = hp_name
+        self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+        best_run = backend_obj.run(self, n_trials, direction, **kwargs)
+
+        self.hp_search_backend = None
+        return best_run
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        """
+        Log `logs` on the various objects watching training.
+
+        Subclass and override this method to inject custom behavior.
+
+        Args:
+            logs (`dict[str, float]`):
+                The values to log.
+            start_time (`Optional[float]`):
+                The start of training.
+        """
+        if self.state.epoch is not None:
+            logs["epoch"] = self.state.epoch
+        if self.args.include_num_input_tokens_seen:
+            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen
+            if start_time is not None:
+                logs.update(speed_metrics("train", start_time, num_tokens=self.state.num_input_tokens_seen))
+
+        output = {**logs, **{"step": self.state.global_step}}
+        self.state.log_history.append(output)
+        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+    def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+        """
+        Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+        """
+        if isinstance(data, Mapping):
+            return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+        elif isinstance(data, (tuple, list)):
+            return type(data)(self._prepare_input(v) for v in data)
+        elif isinstance(data, torch.Tensor):
+            kwargs = {"device": self.args.device}
+            if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
+                # NLP models' inputs are int/uint and those get adjusted to the right dtype of the
+                # embedding. Other models such as wav2vec2's inputs are already float and thus
+                # may need special handling to match the dtypes of the model
+                kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+            return data.to(**kwargs)
+        return data
+
+    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
+        """
+        Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+        handling potential state.
+        """
+        inputs = self._prepare_input(inputs)
+        if len(inputs) == 0:
+            raise ValueError(
+                "The batch received was empty, your model won't be able to train on it. Double-check that your "
+                f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+            )
+        if self.args.past_index >= 0 and self._past is not None:
+            inputs["mems"] = self._past
+
+        return inputs
+
+    def _is_attention_mask_causal(self, attention_mask):
+        """
+        Check if an attention mask is causal (compatible with causal attention).
+        Context parallelism only supports causal attention patterns. This function
+        checks if the provided attention mask is compatible.
+
+        Args:
+            attention_mask (torch.Tensor): The attention mask to check
+
+        Returns:
+            bool: True if the mask is causal or compatible with causal attention
+        """
+        if attention_mask is None:
+            return True  # No mask is considered causal (model uses default causal masking)
+
+        # Handle different mask dimensions
+        if attention_mask.dim() == 2:
+            # (batch_size, seq_len) - standard padding mask, compatible with causal attention
+            return True
+        elif attention_mask.dim() in [3, 4]:
+            # (batch_size, seq_len, seq_len) or (batch_size, num_heads, seq_len, seq_len)
+            # Check if it's lower triangular (causal)
+            seq_len = attention_mask.shape[-1]
+            if seq_len <= 1:
+                return True  # Single token or empty is always causal
+
+            # Take first batch and head (if 4D) for checking pattern
+            if attention_mask.dim() == 4:
+                mask = attention_mask[0, 0]  # First batch, first head
+            else:
+                mask = attention_mask[0]  # First batch
+
+            # Check if upper triangular part is masked (should be 0 or very negative for causal)
+            upper_triangular = torch.triu(mask, diagonal=1)
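+            # Illustrative 3x3 additive causal mask (0 = attend, -inf = masked):
+            #   [[0., -inf, -inf],
+            #    [0.,   0., -inf],
+            #    [0.,   0.,   0.]]
+            # torch.triu(mask, diagonal=1) keeps only entries strictly above the diagonal (all -inf here),
+            # so the check below recognizes the mask as causal.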
+
+            # For causal masks, upper triangular should be 0 or very negative (like -inf)
+            # Use a reasonable threshold to handle float precision issues
+            is_causal = torch.all(upper_triangular <= 1e-6) or torch.all(upper_triangular < -1e4)
+            return is_causal.item() if isinstance(is_causal, torch.Tensor) else is_causal
+
+        # For unknown dimensions, be conservative and reject
+        return False
+
+    def _prepare_context_parallel_inputs(self, model, inputs: dict[str, Union[torch.Tensor, Any]]):
+        """
+        Prepare inputs for context parallelism by setting up buffers and validation.
+
+        Args:
+            model: The model being trained
+            inputs: Input tensors to prepare
+
+        Returns:
+            tuple: (context_manager, prepared_inputs) where context_manager is either
+                   the context parallelism wrapper or a no-op context
+        """
+        if (
+            getattr(self.accelerator, "parallelism_config", None) is not None
+            and self.accelerator.parallelism_config.cp_enabled
+        ):
+            if hasattr(model, "config"):
+                if model.config._attn_implementation != "sdpa":
+                    raise ValueError(
+                        f"Context parallelism is supported only with SDPA attention, you are using {model.config._attn_implementation}."
+                    )
+
+            if "position_ids" not in inputs:
+                logger.warning_once("Position IDs not found in the inputs, generating manually")
+                inputs["position_ids"] = torch.arange(
+                    inputs["input_ids"].size(1), device=inputs["input_ids"].device
+                ).expand(inputs["input_ids"].size(0), -1)
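+                # e.g. for input_ids of shape (2, 5) this yields [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]].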
+            if "shift_labels" not in inputs:
+                logger.warning_once("Shift labels not found in the inputs, shifting manually")
+                if "labels" in inputs:
+                    _ignore_index = -100
+                    labels = nn.functional.pad(inputs["labels"], (0, 1), value=_ignore_index)
+                    inputs["shift_labels"] = labels[:, 1:].contiguous()
+
+            buffers = []
+            buffer_seq_dims = []
+
+            if "input_ids" in inputs:
+                buffers.append(inputs["input_ids"])
+                buffer_seq_dims.append(1)  # Sequence dimension
+            if "labels" in inputs:
+                buffers.append(inputs["labels"])
+                buffer_seq_dims.append(1)
+            if "shift_labels" in inputs:
+                buffers.append(inputs["shift_labels"])
+                buffer_seq_dims.append(1)
+            if "attention_mask" in inputs and not getattr(self, "_attn_mask_causal_checked", False):
+                # Context parallel currently doesn't support other masks than causal
+                # Accelerate applies hooks to replace mask with is_causal arg in SDPA
+                # Check if the mask is really causal and if not throw an error
+                # TODO: check this only once or always, with speed being the cost
+                attention_mask = inputs["attention_mask"]
+                if not self._is_attention_mask_causal(attention_mask):
+                    raise ValueError(
+                        "Context parallelism only supports causal attention masks. "
+                        "The provided attention_mask is not causal. "
+                        "Please ensure your data uses causal masking (lower triangular) "
+                        "or remove the attention_mask to use the model's default causal masking."
+                    )
+                self._attn_mask_causal_checked = True
+            # Include position_ids in context parallelism splitting
+            if "position_ids" in inputs and inputs["position_ids"] is not None:
+                buffers.append(inputs["position_ids"])
+                buffer_seq_dims.append(1)
+
+            return partial(
+                self.accelerator.maybe_context_parallel,
+                buffers=buffers,
+                buffer_seq_dims=buffer_seq_dims,
+                no_restore_buffers=set(buffers),
+            ), inputs
+
+        return contextlib.nullcontext, inputs
+
+    def compute_loss_context_manager(self):
+        """
+        A helper wrapper to group together context managers.
+        """
+        ctx_stack = contextlib.ExitStack()
+
+        autocast_ctx = self.autocast_smart_context_manager()
+        if not isinstance(autocast_ctx, contextlib.nullcontext):
+            ctx_stack.enter_context(autocast_ctx)
+
+        return ctx_stack
+
+    def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+        """
+        A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+        arguments, depending on the situation.
+        """
+        if self.use_cpu_amp:
+            # TODO Matt: This syntax is deprecated and the preferred version is
+            #      torch.amp.autocast("cpu", cache_enabled=cache_enabled, dtype=self.amp_dtype)
+            #      but this is unavailable on Torch 2.1 or earlier. We can change this when we stop supporting 2.1.
+            ctx_manager = torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+        else:
+            ctx_manager = contextlib.nullcontext()
+
+        return ctx_manager
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Perform a training step on a batch of inputs.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to train.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+
+        Return:
+            `torch.Tensor`: The tensor with training loss on this batch.
+        """
+        # Prepare buffers for context parallelism
+
+        cp_context, inputs = self._prepare_context_parallel_inputs(model, inputs)
+
+        # Context manager is no-op if CP isn't enabled
+        with cp_context():
+            model.train()
+            if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+                self.optimizer.train()
+
+            inputs = self._prepare_inputs(inputs)
+            if is_sagemaker_mp_enabled():
+                loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+                return loss_mb.reduce_mean().detach().to(self.args.device)
+
+            with self.compute_loss_context_manager():
+                loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+
+            del inputs
+            if (
+                self.args.torch_empty_cache_steps is not None
+                and self.state.global_step % self.args.torch_empty_cache_steps == 0
+            ):
+                if is_torch_xpu_available():
+                    torch.xpu.empty_cache()
+                elif is_torch_mlu_available():
+                    torch.mlu.empty_cache()
+                elif is_torch_musa_available():
+                    torch.musa.empty_cache()
+                elif is_torch_npu_available():
+                    torch.npu.empty_cache()
+                elif is_torch_mps_available():
+                    torch.mps.empty_cache()
+                elif is_torch_hpu_available():
+                    logger.warning(
+                        "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
+                    )
+                else:
+                    torch.cuda.empty_cache()
+
+            kwargs = {}
+
+            # For LOMO optimizers you need to explicitly use the learning rate
+            if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
+                kwargs["learning_rate"] = self._get_learning_rate()
+
+            if self.args.n_gpu > 1:
+                loss = loss.mean()  # mean() to average on multi-gpu parallel training
+
+            if self.use_apex:
+                from apex import amp
+
+                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                # Finally we need to normalize the loss for reporting when the gradient accumulation loss fix is not applied in compute_loss
+                if (
+                    not self.model_accepts_loss_kwargs or num_items_in_batch is None
+                ) and self.compute_loss_func is None:
+                    # If the model does not accept loss kwargs, we need to normalize the loss by the number of gradient accumulation steps
+                    loss = loss / self.current_gradient_accumulation_steps
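+                    # e.g. with gradient_accumulation_steps=4 each micro-batch contributes loss / 4, so the
+                    # accumulated gradient matches the mean loss over the full effective batch.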
+
+                # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
+                # https://github.com/huggingface/transformers/pull/35808
+                if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
+                    kwargs["scale_wrt_gas"] = False
+
+                self.accelerator.backward(loss, **kwargs)
+
+            return loss.detach()
+
+    def compute_loss(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+        Args:
+            model (`nn.Module`):
+                The model to compute the loss for.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The input data for the model.
+            return_outputs (`bool`, *optional*, defaults to `False`):
+                Whether to return the model outputs along with the loss.
+            num_items_in_batch (`Optional[torch.Tensor]`, *optional*):
+                The number of items in the batch, used by models and loss functions that support it to normalize the loss.
+
+        Returns:
+            The loss of the model along with its output if return_outputs was set to True
+
+        Subclass and override for custom behavior. If you are not using `num_items_in_batch` when computing your loss,
+        make sure to overwrite `self.model_accepts_loss_kwargs` to `False`. Otherwise, the loss calculation might be slightly inaccurate when performing gradient accumulation.
+        """
+        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
+            labels = inputs.pop("labels")
+        else:
+            labels = None
+        if self.model_accepts_loss_kwargs:
+            kwargs = {}
+            if num_items_in_batch is not None:
+                kwargs["num_items_in_batch"] = num_items_in_batch
+            inputs = {**inputs, **kwargs}
+        outputs = model(**inputs)
+        # Save past state if it exists
+        # TODO: this needs to be fixed and made cleaner later.
+        if self.args.past_index >= 0:
+            self._past = outputs[self.args.past_index]
+
+        if labels is not None:
+            unwrapped_model = self.accelerator.unwrap_model(model)
+            if _is_peft_model(unwrapped_model):
+                model_name = unwrapped_model.base_model.model._get_name()
+            else:
+                model_name = unwrapped_model._get_name()
+            # User-defined compute_loss function
+            if self.compute_loss_func is not None:
+                loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch)
+            elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+                loss = self.label_smoother(outputs, labels, shift_labels=True)
+            else:
+                loss = self.label_smoother(outputs, labels)
+        else:
+            if isinstance(outputs, dict) and "loss" not in outputs:
+                raise ValueError(
+                    "The model did not return a loss from the inputs, only the following keys: "
+                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+                )
+            # We don't use .loss here since the model may return tuples instead of ModelOutput.
+            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        if (
+            self.args.average_tokens_across_devices
+            and (self.model_accepts_loss_kwargs or self.compute_loss_func)
+            and num_items_in_batch is not None
+        ):
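+            # Each rank's loss was computed against the global token count, so scaling by the number of
+            # processes compensates for the cross-rank averaging applied later.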
+            loss *= self.accelerator.num_processes
+
+        return (loss, outputs) if return_outputs else loss
+
+    def is_local_process_zero(self) -> bool:
+        """
+        Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several
+        machines) main process.
+        """
+        return self.args.local_process_index == 0
+
+    def is_world_process_zero(self) -> bool:
+        """
+        Whether or not this process is the global main process (when training in a distributed fashion on several
+        machines, this is only going to be `True` for one process).
+        """
+        # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+        # process index.
+        if is_sagemaker_mp_enabled():
+            return smp.rank() == 0
+        else:
+            return self.args.process_index == 0
+
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        """
+        Will save the model, so you can reload it using `from_pretrained()`.
+
+        Will only save from the main process.
+        """
+
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if is_torch_xla_available():
+            self._save_tpu(output_dir)
+        elif is_sagemaker_mp_enabled():
+            # Calling the state_dict needs to be done on the wrapped model and on all processes.
+            os.makedirs(output_dir, exist_ok=True)
+            state_dict = self.model_wrapped.state_dict()
+            if self.args.should_save:
+                self._save(output_dir, state_dict=state_dict)
+            if IS_SAGEMAKER_MP_POST_1_10:
+                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+                Path(os.path.join(output_dir, "user_content.pt")).touch()
+        # We are in N-D parallelism if we have parallelism_config set, so we ask accelerate whether this rank should save the model
+        elif getattr(self.accelerator, "parallelism_config", None) is not None:
+            if self.accelerator.should_save_model:
+                self._save(output_dir)
+        # If we drop to here, we're in 1D parallelism, so all ranks need to go to `save_pretrained`
+        elif (tp_size := getattr(self.model, "_tp_size", 0)) is not None and tp_size > 1:
+            self._save(output_dir)
+        elif self.is_fsdp_enabled:
+            if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
+                version.parse(accelerate_version) > version.parse("0.24.1")
+            ):
+                state_dict = self.accelerator.get_state_dict(self.model)
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
+        elif self.is_deepspeed_enabled:
+            try:
+                state_dict = self.accelerator.get_state_dict(self.deepspeed)
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                if self.args.should_save:
+                    self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model_wrapped.save_checkpoint(output_dir)
+
+        elif self.args.should_save:
+            self._save(output_dir)
+
+        # Push to the Hub when `save_model` is called by the user.
+        if self.args.push_to_hub and not _internal_call:
+            self.push_to_hub(commit_message="Model save", revision=self.args.hub_revision)
+
+    def _save_tpu(self, output_dir: Optional[str] = None):
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+
+        logger.info(f"Saving model checkpoint to {output_dir}")
+        model = self.model
+        xm.mark_step()
+
+        if xm.is_master_ordinal(local=False):
+            os.makedirs(output_dir, exist_ok=True)
+            torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        supported_classes = (PushToHubMixin,)
+        xm.rendezvous("saving_checkpoint")
+        if self.is_fsdp_xla_v1_enabled:
+            ckpt = {
+                "model": model.state_dict(),
+                "shard_metadata": model.get_shard_metadata(),
+            }
+            ckpt_path = os.path.join(
+                output_dir, f"rank{self.args.process_index}-of-{self.args.world_size}-{WEIGHTS_NAME}"
+            )
+            # All ranks save sharded checkpoint
+            xm.save(ckpt, ckpt_path, master_only=False)
+            # Make sure all ranks have saved checkpoints
+            xm.rendezvous("save_full_checkpoints")
+            # Master save full checkpoint
+            if self.args.should_save:
+                from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints
+
+                full_state_dict, _ = consolidate_sharded_model_checkpoints(
+                    ckpt_prefix=os.path.join(output_dir, ""),
+                    ckpt_suffix=f"rank*-of-*-{WEIGHTS_NAME}",
+                    save_model=False,
+                )
+                model = model.module.module
+                unwrapped_model = self.accelerator.unwrap_model(model)
+                if isinstance(unwrapped_model, supported_classes):
+                    unwrapped_model.save_pretrained(
+                        output_dir,
+                        state_dict=full_state_dict,
+                        save_function=xm.save,
+                        safe_serialization=self.args.save_safetensors,
+                    )
+                else:
+                    logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                    xm.save(full_state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        elif not isinstance(model, supported_classes):
+            if isinstance(self.accelerator.unwrap_model(model), supported_classes):
+                self.accelerator.unwrap_model(model).save_pretrained(
+                    output_dir,
+                    is_main_process=self.args.should_save,
+                    state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+                    save_function=xm.save,
+                    safe_serialization=self.args.save_safetensors,
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                state_dict = xm._maybe_convert_to_cpu(model.state_dict())
+                xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.args.should_save,
+                save_function=xm.save,
+                safe_serialization=self.args.save_safetensors,
+                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
+            )
+        if self.processing_class is not None and self.args.should_save:
+            self.processing_class.save_pretrained(output_dir)
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are process zero, so we don't need to check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+
+        supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, supported_classes):
+            if state_dict is None:
+                state_dict = self.model.state_dict()
+
+            if isinstance(self.accelerator.unwrap_model(self.model, keep_torch_compile=False), supported_classes):
+                self.accelerator.unwrap_model(self.model, keep_torch_compile=False).save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if self.args.save_safetensors:
+                    safetensors.torch.save_file(
+                        state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
+                    )
+                else:
+                    torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            self.model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
+
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        elif (
+            self.data_collator is not None
+            and hasattr(self.data_collator, "tokenizer")
+            and self.data_collator.tokenizer is not None
+        ):
+            logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
+            self.data_collator.tokenizer.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+    def store_flos(self):
+        # Storing the number of floating-point operations that went into the model
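+        # In distributed mode, the per-process counters are summed across all ranks before being accumulated.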
+        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            self.state.total_flos += (
+                distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+            )
+            self.current_flos = 0
+        else:
+            self.state.total_flos += self.current_flos
+            self.current_flos = 0
+
+    def _sorted_checkpoints(
+        self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+    ) -> list[str]:
+        ordering_and_checkpoint_path = []
+
+        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+        for path in glob_checkpoints:
+            if use_mtime:
+                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+            else:
+                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+                if regex_match is not None and regex_match.groups() is not None:
+                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+        # mtime is not reliable on all filesystems, especially on some fuse fs in cloud environments
+        # so we check if the mtime is fake and fallback to numerical ordering if needed
+        if use_mtime and len(ordering_and_checkpoint_path) > 1:
+            mtime_diff = checkpoints_sorted[-1][0] - checkpoints_sorted[0][0]
+            if mtime_diff < 1.0:  # less than 1 second, which is almost impossible when mtime works fine
+                warnings.warn("mtime may not be reliable on this filesystem, falling back to numerical ordering")
+                return self._sorted_checkpoints(
+                    use_mtime=False, output_dir=output_dir, checkpoint_prefix=checkpoint_prefix
+                )
+        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+
+        # Make sure we don't delete the best model.
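+        # `_rotate_checkpoints` deletes from the front of this oldest-first list, so bubble the best checkpoint
+        # up to the second-to-last position: it survives rotation while the most recent checkpoint stays last.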
+        if (
+            self.state.best_model_checkpoint is not None
+            and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted
+        ):
+            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+            for i in range(best_model_index, len(checkpoints_sorted) - 2):
+                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+        return checkpoints_sorted
+
+    def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+            return
+
+        # Check if we should delete older checkpoint(s)
+        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+        if len(checkpoints_sorted) <= self.args.save_total_limit:
+            return
+
+        # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+        # we don't do to allow resuming.
+        save_total_limit = self.args.save_total_limit
+        if (
+            self.state.best_model_checkpoint is not None
+            and self.args.save_total_limit == 1
+            and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+        ):
+            save_total_limit = 2
+
+        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+        for checkpoint in checkpoints_to_be_deleted:
+            logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+            shutil.rmtree(checkpoint, ignore_errors=True)
+
+    def evaluate(
+        self,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> dict[str, float]:
+        """
+        Run evaluation and return metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (Union[`Dataset`, dict[str, `Dataset`]], *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed. If it is a dictionary, it will
+                evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the
+                `__len__` method.
+
+                
+
+                If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run
+                separate evaluations on each dataset. This can be useful to monitor how training affects other
+                datasets or simply to get a more fine-grained evaluation.
+                When used with `load_best_model_at_end`, make sure `metric_for_best_model` references exactly one
+                of the datasets. If you, for example, pass in `{"data1": data1, "data2": data2}` for two datasets
+                `data1` and `data2`, you could specify `metric_for_best_model="eval_data1_loss"` for using the
+                loss on `data1` and `metric_for_best_model="eval_data2_loss"` for the loss on `data2`.
+
+                
+
+            ignore_keys (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                "eval_bleu" if the prefix is "eval" (default).
+
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
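+
+        Example (illustrative; `trainer` is an already constructed [`Trainer`] and `val_a`/`val_b` stand for any
+        two evaluation datasets):
+
+        ```python
+        metrics = trainer.evaluate(eval_dataset={"data1": val_a, "data2": val_b})
+        # Metric keys are prefixed per dataset, e.g. "eval_data1_loss" and "eval_data2_loss".
+        ```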
+        """
+        # handle multiple eval datasets
+        override = eval_dataset is not None
+        eval_dataset = eval_dataset if override else self.eval_dataset
+        if isinstance(eval_dataset, dict):
+            metrics = {}
+            for eval_dataset_name, _eval_dataset in eval_dataset.items():
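+                # If the dict comes from `self.eval_dataset`, pass the key so `get_eval_dataloader` can look up
+                # the corresponding entry; otherwise pass the dataset object that was given explicitly.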
+                dataset_metrics = self.evaluate(
+                    eval_dataset=_eval_dataset if override else eval_dataset_name,
+                    ignore_keys=ignore_keys,
+                    metric_key_prefix=f"{metric_key_prefix}_{eval_dataset_name}",
+                )
+                metrics.update(dataset_metrics)
+            return metrics
+
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        eval_dataloader = self.get_eval_dataloader(eval_dataset)
+        if self.is_fsdp_xla_v2_enabled:
+            eval_dataloader = tpu_spmd_dataloader(eval_dataloader)
+
+        start_time = time.time()
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            eval_dataloader,
+            description="Evaluation",
+            # No point gathering the predictions if there are no metrics, otherwise we defer to
+            # self.args.prediction_loss_only
+            prediction_loss_only=True if self.compute_metrics is None else None,
+            ignore_keys=ignore_keys,
+            metric_key_prefix=metric_key_prefix,
+        )
+
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
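+        # Shift the start time forward by any one-off JIT compilation / model preparation overhead so that it
+        # does not skew the speed metrics computed below.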
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
+
+        self.log(output.metrics)
+
+        if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+            xm.master_print(met.metrics_report())
+
+        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return output.metrics
+
+    def predict(
+        self, test_dataset: Dataset, ignore_keys: Optional[list[str]] = None, metric_key_prefix: str = "test"
+    ) -> PredictionOutput:
+        """
+        Run prediction and return predictions and potential metrics.
+
+        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+        will also return metrics, like in `evaluate()`.
+
+        Args:
+            test_dataset (`Dataset`):
+                Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
+                `model.forward()` method are automatically removed. Has to implement the method `__len__`.
+            ignore_keys (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+                An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+                "test_bleu" if the prefix is "test" (default).
+
+        
+
+        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
+        in a token classification task) the predictions will be padded (on the right) to allow for concatenation into
+        one array. The padding index is -100.
+
+        
+
+        Returns: *NamedTuple* A namedtuple with the following keys:
+
+            - predictions (`np.ndarray`): The predictions on `test_dataset`.
+            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+            - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+              labels).
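+
+        Example (illustrative; `trainer` is an already constructed [`Trainer`] and `test_ds` stands for any
+        dataset of test examples):
+
+        ```python
+        output = trainer.predict(test_ds)
+        predictions = output.predictions  # np.ndarray of raw predictions
+        metrics = output.metrics  # e.g. {"test_runtime": ..., "test_samples_per_second": ...}
+        ```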
+        """
+        # memory metrics - must set up as early as possible
+        self._memory_tracker.start()
+
+        test_dataloader = self.get_test_dataloader(test_dataset)
+        start_time = time.time()
+
+        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+        output = eval_loop(
+            test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+        )
+        total_batch_size = self.args.eval_batch_size * self.args.world_size
+        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+        if f"{metric_key_prefix}_model_preparation_time" in output.metrics:
+            start_time += output.metrics[f"{metric_key_prefix}_model_preparation_time"]
+        output.metrics.update(
+            speed_metrics(
+                metric_key_prefix,
+                start_time,
+                num_samples=output.num_samples,
+                num_steps=math.ceil(output.num_samples / total_batch_size),
+            )
+        )
+
+        self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
+        self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+        return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
+
+    def evaluation_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+        args = self.args
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
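+        # If `evaluate`/`predict` is called without a prior `train`, the model has not been prepared by
+        # accelerate yet, so prepare it here (DeepSpeed, and most FSDP setups, go through the full `prepare`;
+        # otherwise the lighter `prepare_model` in evaluation mode is enough).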
+        if len(self.accelerator._models) == 0 and model is self.model:
+            start_time = time.time()
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                or (self.is_fsdp_enabled and self.accelerator.mixed_precision != "fp8" and not self.args.torch_compile)
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+            self.model_preparation_time = round(time.time() - start_time, 4)
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = self.args.eval_batch_size
+
+        logger.info(f"\n***** Running {description} *****")
+        if has_length(dataloader):
+            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
+        else:
+            logger.info("  Num examples: Unknown")
+        logger.info(f"  Batch size = {batch_size}")
+
+        if hasattr(model, "eval") and callable(model.eval):
+            model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
+
+        self.callback_handler.eval_dataloader = dataloader
+        # Do this before wrapping.
+        eval_dataset = getattr(dataloader, "dataset", None)
+
+        if args.past_index >= 0:
+            self._past = None
+
+        # Initialize containers
+        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+
+        metrics = None
+        eval_set_kwargs = {}
+
+        # Will be useful when we have an iterable dataset, so we don't know its length in advance.
+        observed_num_examples = 0
+
+        # Main evaluation loop
+        for step, inputs in enumerate(dataloader):
+            # Update the observed num examples
+            observed_batch_size = find_batch_size(inputs)
+            if observed_batch_size is not None:
+                observed_num_examples += observed_batch_size
+                # For batch samplers, batch_size is not known by the dataloader in advance.
+                if batch_size is None:
+                    batch_size = observed_batch_size
+
+            # Prediction step
+            losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
+
+            if is_torch_xla_available():
+                xm.mark_step()
+
+            # Update containers
+            if losses is not None:
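+                # `prediction_step` returns a single scalar loss per batch; repeat it `batch_size` times so it
+                # can be gathered and averaged together with the per-sample predictions.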
+                losses = self.gather_function(losses.repeat(batch_size))
+                all_losses.add(losses)
+            if inputs_decode is not None:
+                inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
+                inputs_decode = self.gather_function(inputs_decode)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_inputs.add(inputs_decode)
+            if labels is not None:
+                # Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+            if logits is not None:
+                logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
+                if self.preprocess_logits_for_metrics is not None:
+                    logits = self.preprocess_logits_for_metrics(logits, labels)
+                logits = self.gather_function(logits)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_preds.add(logits)
+            if labels is not None:
+                labels = self.gather_function(labels)
+                if not self.args.batch_eval_metrics or description == "Prediction":
+                    all_labels.add(labels)
+
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and logits is not None and labels is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=logits, label_ids=labels, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
+
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
+
+            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+            elif args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+                all_losses.to_cpu_and_numpy()
+                all_preds.to_cpu_and_numpy()
+                all_labels.to_cpu_and_numpy()
+                all_inputs.to_cpu_and_numpy()
+
+                del losses, logits, labels, inputs
+                torch.cuda.empty_cache()
+
+        # After all calls to `.gather_function`, reset to `gather_for_metrics`:
+        self.gather_function = self.accelerator.gather_for_metrics
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        all_losses = all_losses.get_arrays()
+        all_preds = all_preds.get_arrays()
+        all_labels = all_labels.get_arrays()
+        all_inputs = all_inputs.get_arrays()
+
+        # Number of samples
+        if has_length(eval_dataset):
+            num_samples = len(eval_dataset)
+        # The instance check does not actually verify the type, only that the dataset exposes the right methods,
+        # so we also need to make sure the `num_examples` attribute is present.
+        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+            num_samples = eval_dataset.num_examples
+        else:
+            if has_length(dataloader):
+                num_samples = self.num_examples(dataloader)
+            else:  # both len(dataloader.dataset) and len(dataloader) fail
+                num_samples = observed_num_examples
+        if num_samples == 0 and observed_num_examples > 0:
+            num_samples = observed_num_examples
+
+        # Metrics!
+        if (
+            self.compute_metrics is not None
+            and all_preds is not None
+            and all_labels is not None
+            and not self.args.batch_eval_metrics
+        ):
+            eval_set_kwargs["losses"] = all_losses if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = all_inputs if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(
+                EvalPrediction(predictions=all_preds, label_ids=all_labels, **eval_set_kwargs)
+            )
+        elif metrics is None:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if isinstance(all_losses, list) and all_losses:
+            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+        elif isinstance(all_losses, np.ndarray):
+            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+        if hasattr(self, "jit_compilation_time"):
+            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+        if hasattr(self, "model_preparation_time"):
+            metrics[f"{metric_key_prefix}_model_preparation_time"] = self.model_preparation_time
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+
+    def _nested_gather(self, tensors, name=None):
+        """
+        Gather the value of `tensors` (a tensor or a list/tuple of nested tensors) across processes and
+        concatenate them.
+        """
+        if tensors is None:
+            return
+        if is_torch_xla_available():
+            if name is None:
+                name = "nested_gather"
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
+        elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+            self.args.distributed_state is None and self.args.local_rank != -1
+        ):
+            tensors = distributed_concat(tensors)
+        return tensors
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+            ignore_keys (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+
+        Return:
+            tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+            logits and labels (each being optional).
+        """
+        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+        # For CLIP-like models capable of returning loss values.
+        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
+        # is `True` in `model.forward`.
+        return_loss = inputs.get("return_loss")
+        if return_loss is None:
+            return_loss = self.can_return_loss
+        loss_without_labels = len(self.label_names) == 0 and return_loss
+
+        inputs = self._prepare_inputs(inputs)
+        if ignore_keys is None:
+            if hasattr(self.model, "config"):
+                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", ["past_key_values"])
+            else:
+                ignore_keys = []
+
+        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+        if has_labels or loss_without_labels:
+            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+            if len(labels) == 1:
+                labels = labels[0]
+        else:
+            labels = None
+
+        with torch.no_grad():
+            if is_sagemaker_mp_enabled():
+                raw_outputs = smp_forward_only(model, inputs)
+                if has_labels or loss_without_labels:
+                    if isinstance(raw_outputs, dict):
+                        loss_mb = raw_outputs["loss"]
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        loss_mb = raw_outputs[0]
+                        logits_mb = raw_outputs[1:]
+
+                    loss = loss_mb.reduce_mean().detach().cpu()
+                    logits = smp_nested_concat(logits_mb)
+                else:
+                    loss = None
+                    if isinstance(raw_outputs, dict):
+                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+                    else:
+                        logits_mb = raw_outputs
+                    logits = smp_nested_concat(logits_mb)
+            else:
+                if has_labels or loss_without_labels:
+                    with self.compute_loss_context_manager():
+                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    loss = loss.detach().mean()
+
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+                    else:
+                        logits = outputs[1:]
+                else:
+                    loss = None
+                    with self.compute_loss_context_manager():
+                        outputs = model(**inputs)
+                    if isinstance(outputs, dict):
+                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+                    else:
+                        logits = outputs
+                    # TODO: this needs to be fixed and made cleaner later.
+                    if self.args.past_index >= 0:
+                        self._past = outputs[self.args.past_index - 1]
+
+        if prediction_loss_only:
+            return (loss, None, None)
+
+        logits = nested_detach(logits)
+        if len(logits) == 1:
+            logits = logits[0]
+
+        return (loss, logits, labels)
+
+    def floating_point_ops(self, inputs: dict[str, Union[torch.Tensor, Any]]):
+        """
+        For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+        operations for every backward + forward pass. If using another model, either implement such a method in the
+        model or subclass and override this method.
+
+        Args:
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+        Returns:
+            `int`: The number of floating-point operations.
+        """
+        if hasattr(self.model, "floating_point_ops"):
+            return self.model.floating_point_ops(inputs)
+        else:
+            return 0
+
+    def init_hf_repo(self, token: Optional[str] = None):
+        """
+        Initializes a git repo in `self.args.hub_model_id`.
+        """
+        # Only on process zero
+        if not self.is_world_process_zero():
+            return
+
+        if self.args.hub_model_id is None:
+            repo_name = Path(self.args.output_dir).absolute().name
+        else:
+            repo_name = self.args.hub_model_id
+
+        token = token if token is not None else self.args.hub_token
+        repo_url = create_repo(repo_name, token=token, private=self.args.hub_private_repo, exist_ok=True)
+        self.hub_model_id = repo_url.repo_id
+        self.push_in_progress = None
+
+    def create_model_card(
+        self,
+        language: Optional[str] = None,
+        license: Optional[str] = None,
+        tags: Union[str, list[str], None] = None,
+        model_name: Optional[str] = None,
+        finetuned_from: Optional[str] = None,
+        tasks: Union[str, list[str], None] = None,
+        dataset_tags: Union[str, list[str], None] = None,
+        dataset: Union[str, list[str], None] = None,
+        dataset_args: Union[str, list[str], None] = None,
+    ):
+        """
+        Creates a draft of a model card using the information available to the `Trainer`.
+
+        Args:
+            language (`str`, *optional*):
+                The language of the model (if applicable)
+            license (`str`, *optional*):
+                The license of the model. Will default to the license of the pretrained model used, if the original
+                model given to the `Trainer` comes from a repo on the Hub.
+            tags (`str` or `list[str]`, *optional*):
+                Some tags to be included in the metadata of the model card.
+            model_name (`str`, *optional*):
+                The name of the model.
+            finetuned_from (`str`, *optional*):
+                The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+                of the original model given to the `Trainer` (if it comes from the Hub).
+            tasks (`str` or `list[str]`, *optional*):
+                One or several task identifiers, to be included in the metadata of the model card.
+            dataset_tags (`str` or `list[str]`, *optional*):
+                One or several dataset tags, to be included in the metadata of the model card.
+            dataset (`str` or `list[str]`, *optional*):
+                One or several dataset identifiers, to be included in the metadata of the model card.
+            dataset_args (`str` or `list[str]`, *optional*):
+                One or several dataset arguments, to be included in the metadata of the model card.
+        """
+        if not self.is_world_process_zero():
+            return
+
+        model_card_filepath = os.path.join(self.args.output_dir, "README.md")
+        is_peft_library = False
+        if os.path.exists(model_card_filepath):
+            library_name = ModelCard.load(model_card_filepath).data.get("library_name")
+            is_peft_library = library_name == "peft"
+
+            # Append existing tags in `tags`
+            existing_tags = ModelCard.load(model_card_filepath).data.tags
+            if tags is not None and existing_tags is not None:
+                if isinstance(tags, str):
+                    tags = [tags]
+                for tag in existing_tags:
+                    if tag not in tags:
+                        tags.append(tag)
+
+        training_summary = TrainingSummary.from_trainer(
+            self,
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+        )
+        model_card = training_summary.to_model_card()
+        with open(model_card_filepath, "w") as f:
+            f.write(model_card)
+
+        if is_peft_library:
+            self.accelerator.unwrap_model(self.model).create_or_update_model_card(self.args.output_dir)
+
+    def _push_from_checkpoint(self, checkpoint_folder):
+        # Only push from one node.
+        if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+            return
+        # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True.
+        if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done():
+            return
+
+        output_dir = self.args.output_dir
+        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
+        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+        #  Add sharded checkpoints if we have an index
+        for index_file in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
+            index_path = os.path.join(checkpoint_folder, index_file)
+            if os.path.isfile(index_path):
+                modeling_files.append(index_file)
+                with open(index_path) as f:
+                    index = json.loads(f.read())
+                shard_files = list(set(index["weight_map"].values()))
+                modeling_files.extend(shard_files)
+        if is_peft_available():
+            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
+        for modeling_file in modeling_files:
+            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+        # Saving the processing class is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+        if self.processing_class is not None:
+            self.processing_class.save_pretrained(output_dir)
+        # Same for the training arguments
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+        if self.args.save_strategy == SaveStrategy.STEPS:
+            commit_message = f"Training in progress, step {self.state.global_step}"
+        else:
+            commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+
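+        # Uploads run as background futures (`run_as_future=True`) so training is not blocked; they are waited
+        # on later in `_finish_current_push`.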
+        model_push_job = upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=output_dir,
+            commit_message=commit_message,
+            token=self.args.hub_token,
+            run_as_future=True,
+            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+            revision=self.args.hub_revision,
+        )
+
+        push_jobs = [model_push_job]
+
+        if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
+            path_in_repo = (
+                "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
+            )
+            checkpoint_push = upload_folder(
+                repo_id=self.hub_model_id,
+                folder_path=checkpoint_folder,
+                path_in_repo=path_in_repo,
+                commit_message=commit_message + ", checkpoint",
+                token=self.args.hub_token,
+                run_as_future=True,
+                revision=self.args.hub_revision,
+            )
+            push_jobs.append(checkpoint_push)
+
+        if self.push_in_progress is None or self.push_in_progress.is_done():
+            self.push_in_progress = PushInProgress(push_jobs)
+        else:
+            self.push_in_progress.jobs.extend(push_jobs)
+
+    def _finish_current_push(self):
+        if not hasattr(self, "push_in_progress"):
+            return
+        if self.push_in_progress is not None and not self.push_in_progress.is_done():
+            logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
+            self.push_in_progress.wait_until_done()
+
+    def push_to_hub(
+        self,
+        commit_message: Optional[str] = "End of training",
+        blocking: bool = True,
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        **kwargs,
+    ) -> str:
+        """
+        Upload `self.model` and `self.processing_class` to the 🤗 model hub on the repo `self.args.hub_model_id`.
+
+        Parameters:
+            commit_message (`str`, *optional*, defaults to `"End of training"`):
+                Message to commit while pushing.
+            blocking (`bool`, *optional*, defaults to `True`):
+                Whether the function should return only when the `git push` has finished.
+            token (`str`, *optional*, defaults to `None`):
+                Token with write permission. Overrides the token set in the Trainer's original arguments.
+            revision (`str`, *optional*):
+                The git revision to commit from. Defaults to the head of the "main" branch.
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to [`~Trainer.create_model_card`].
+
+        Returns:
+            The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
+            progress of the commit if `blocking=False`.
+        """
+        model_name = kwargs.pop("model_name", None)
+        if model_name is None and self.args.should_save:
+            if self.args.hub_model_id is None:
+                model_name = Path(self.args.output_dir).name
+            else:
+                model_name = self.args.hub_model_id.split("/")[-1]
+        token = token if token is not None else self.args.hub_token
+
+        # In case the user calls this method with args.push_to_hub = False
+        if self.hub_model_id is None:
+            self.init_hf_repo(token=token)
+
+        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
+        # self.args.should_save.
+        self.save_model(_internal_call=True)
+
+        # Only push from one node.
+        if not self.is_world_process_zero():
+            return
+
+        # Add additional tags in case the model already has some tags and the user passes a "tags" argument to
+        # `push_to_hub`, so that the Trainer handles internal tags itself, since it does not call
+        # `model.push_to_hub`.
+        if getattr(self.model, "model_tags", None) is not None:
+            if "tags" not in kwargs:
+                kwargs["tags"] = []
+
+            # If it is a string, convert it to a list
+            if isinstance(kwargs["tags"], str):
+                kwargs["tags"] = [kwargs["tags"]]
+
+            for model_tag in self.model.model_tags:
+                if model_tag not in kwargs["tags"]:
+                    kwargs["tags"].append(model_tag)
+
+        self.create_model_card(model_name=model_name, **kwargs)
+
+        if revision is None:
+            revision = self.args.hub_revision
+
+        # Wait for the current upload to be finished.
+        self._finish_current_push()
+
+        return upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=self.args.output_dir,
+            commit_message=commit_message,
+            token=token,
+            run_as_future=not blocking,
+            ignore_patterns=["_*", f"{PREFIX_CHECKPOINT_DIR}-*"],
+            revision=revision,
+        )
+
+    #
+    # Deprecated code
+    #
+
+    def prediction_loop(
+        self,
+        dataloader: DataLoader,
+        description: str,
+        prediction_loss_only: Optional[bool] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+    ) -> EvalLoopOutput:
+        """
+        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+        Works both with or without labels.
+        """
+        args = self.args
+
+        if not has_length(dataloader):
+            raise ValueError("dataloader must implement a working __len__")
+
+        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+        model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled or self.is_fsdp_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
+        # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+        # while ``train`` is running, cast it to the right dtype first and then put on device
+        if not self.is_in_train:
+            if args.fp16_full_eval:
+                model = model.to(dtype=torch.float16, device=args.device)
+            elif args.bf16_full_eval:
+                model = model.to(dtype=torch.bfloat16, device=args.device)
+
+        batch_size = (
+            dataloader.total_batch_size
+            if getattr(dataloader, "_is_accelerate_prepared", False)
+            else dataloader.batch_size
+        )
+
+        if batch_size is None:
+            raise ValueError(
+                "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+            )
+
+        num_examples = self.num_examples(dataloader)
+        logger.info(f"\n***** Running {description} *****")
+        logger.info(f"  Num examples = {num_examples}")
+        logger.info(f"  Batch size = {batch_size}")
+
+        losses_host: Optional[torch.Tensor] = None
+        preds_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+        labels_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+        inputs_host: Union[torch.Tensor, list[torch.Tensor], None] = None
+        metrics: Optional[dict] = None
+        eval_set_kwargs: dict = {}
+
+        world_size = max(1, args.world_size)
+
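+        # `DistributedTensorGatherer` pre-allocates `num_examples` slots and truncates the padding added by
+        # distributed samplers when `finalize()` is called.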
+        eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+        if not prediction_loss_only:
+            # The actual number of eval samples can be greater than num_examples in distributed settings (when we pass
+            # a batch size to the sampler).
+            make_multiple_of = None
+            if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+                make_multiple_of = dataloader.sampler.batch_size
+            preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+            labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+            inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+        model.eval()
+        if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
+            self.optimizer.eval()
+
+        if args.past_index >= 0:
+            self._past = None
+
+        self.callback_handler.eval_dataloader = dataloader
+
+        for step, inputs in enumerate(dataloader):
+            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = (
+                self._prepare_input(inputs[main_input_name]) if "inputs" in args.include_for_metrics else None
+            )
+
+            if loss is not None:
+                losses = loss.repeat(batch_size)
+                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+            if logits is not None:
+                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+            if labels is not None:
+                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+            if inputs_decode is not None:
+                inputs_host = (
+                    inputs_decode
+                    if inputs_host is None
+                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+                )
+            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+            if self.args.batch_eval_metrics:
+                if self.compute_metrics is not None and preds_host is not None and labels_host is not None:
+                    is_last_step = self.accelerator.gradient_state.end_of_dataloader
+                    batch_kwargs = {}
+                    batch_kwargs["losses"] = losses_host if "loss" in args.include_for_metrics else None
+                    batch_kwargs["inputs"] = inputs_host if "inputs" in args.include_for_metrics else None
+                    metrics = self.compute_metrics(
+                        EvalPrediction(predictions=preds_host, label_ids=labels_host, **batch_kwargs),
+                        compute_result=is_last_step,
+                    )
+
+            if self.args.batch_eval_metrics or (
+                args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0
+            ):
+                # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+                eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+                if not prediction_loss_only:
+                    preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+                    labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+                    inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+                # Set back to None to begin a new accumulation
+                del losses_host, preds_host, labels_host, inputs_host
+                torch.cuda.empty_cache()
+                losses_host, preds_host, labels_host, inputs_host = None, None, None, None
+
+        if args.past_index and hasattr(self, "_past"):
+            # Clean the state at the end of the evaluation loop
+            delattr(self, "_past")
+
+        # Gather all remaining tensors and put them back on the CPU
+        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+        if not prediction_loss_only:
+            preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+            labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+            inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+        eval_loss = eval_losses_gatherer.finalize()
+        preds = preds_gatherer.finalize() if not prediction_loss_only else None
+        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+        inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
+
+        if (
+            self.compute_metrics is not None
+            and preds is not None
+            and label_ids is not None
+            and not self.args.batch_eval_metrics
+        ):
+            eval_set_kwargs["losses"] = eval_loss if "loss" in args.include_for_metrics else None
+            eval_set_kwargs["inputs"] = inputs_ids if "inputs" in args.include_for_metrics else None
+            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids, **eval_set_kwargs))
+        elif metrics is None:
+            metrics = {}
+
+        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+        metrics = denumpify_detensorize(metrics)
+
+        if eval_loss is not None:
+            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+        # Prefix all keys with metric_key_prefix + '_'
+        for key in list(metrics.keys()):
+            if not key.startswith(f"{metric_key_prefix}_"):
+                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+        return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
+
+    def _gather_and_numpify(self, tensors, name):
+        """
+        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+        concatenating them to `gathered`
+        """
+        if tensors is None:
+            return
+        if is_torch_xla_available():
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif is_sagemaker_mp_enabled():
+            tensors = smp_gather(tensors)
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            tensors = distributed_concat(tensors)
+
+        return nested_numpify(tensors)
+
+    def _add_sm_patterns_to_gitignore(self) -> None:
+        """Add SageMaker Checkpointing patterns to .gitignore file."""
+        # Make sure we only do this on the main process
+        if not self.is_world_process_zero():
+            return
+
+        patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
+
+        # Get current .gitignore content
+        if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
+            with open(os.path.join(self.repo.local_dir, ".gitignore")) as f:
+                current_content = f.read()
+        else:
+            current_content = ""
+
+        # Add the patterns to .gitignore
+        content = current_content
+        for pattern in patterns:
+            if pattern not in content:
+                if content.endswith("\n"):
+                    content += pattern
+                else:
+                    content += f"\n{pattern}"
+
+        # Write the .gitignore file if it has changed
+        if content != current_content:
+            with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
+                logger.debug(f"Writing .gitignore file. Content: {content}")
+                f.write(content)
+
+        self.repo.git_add(".gitignore")
+
+        # avoid race condition with git status
+        time.sleep(0.5)
+
+        if not self.repo.is_repo_clean():
+            self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
+            self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        # We explicitly don't rely on the `Accelerator` to do gradient accumulation
+        grad_acc_kwargs = {}
+        if is_accelerate_available("0.28.0") and self.args.accelerator_config.gradient_accumulation_kwargs is not None:
+            grad_acc_kwargs = self.args.accelerator_config.gradient_accumulation_kwargs
+
+        # check if num_steps is attempted to be passed in gradient_accumulation_kwargs
+        if "num_steps" in grad_acc_kwargs:
+            if self.args.gradient_accumulation_steps > 1:
+                # raise because we do not know which setting is intended.
+                raise ValueError(
+                    "The `AcceleratorConfig`'s `num_steps` is set but `gradient_accumulation_steps` is greater than 1 in the passed `TrainingArguments`. "
+                    "If using the passed `AcceleratorConfig` is desired, do not set the `TrainingArguments` `gradient_accumulation_steps`."
+                )
+            else:
+                self.args.gradient_accumulation_steps = grad_acc_kwargs["num_steps"]
+
+        accelerator_config = self.args.accelerator_config.to_dict()
+
+        if is_accelerate_available("0.28.0"):
+            # Extract dataloader config params from accelerator config
+            dataloader_params = ["split_batches", "dispatch_batches", "even_batches", "use_seedable_sampler"]
+            dataloader_config = DataLoaderConfiguration(
+                **{param: accelerator_config.pop(param) for param in dataloader_params}
+            )
+            if is_accelerate_available("1.1.0"):
+                dataloader_config.data_seed = self.args.data_seed
+
+        non_blocking = accelerator_config.pop("non_blocking")
+        if not is_accelerate_available("0.30.0"):
+            if non_blocking:
+                raise ImportError(
+                    "`non_blocking` is only supported in accelerate v0.30.0 and above. Please upgrade accelerate to use this feature."
+                )
+        else:
+            if non_blocking and not self.args.dataloader_pin_memory:
+                logger.warning(
+                    "`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
+                )
+            dataloader_config.non_blocking = non_blocking
+        # this would have been updated above, no need for it anymore
+        accelerator_config.pop("gradient_accumulation_kwargs")
+
+        args = {
+            "deepspeed_plugin": self.args.deepspeed_plugin,
+        }
+
+        # We defer compatibility checks to accelerator
+        if self.args.parallelism_config is not None:
+            if not is_accelerate_available("1.10.1"):
+                raise ImportError(
+                    "ParallelismConfig requires accelerate v1.10.1 and above. Please upgrade accelerate to use this feature."
+                )
+
+            args["parallelism_config"] = self.args.parallelism_config
+
+        if is_accelerate_available("0.28.0"):
+            args["dataloader_config"] = dataloader_config
+        else:
+            args.update(accelerator_config)
+        # tp is initialized at Accelerator init phase so
+        # args should be prepared here
+        if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1:
+            self.is_tp_enabled = True
+            if version.parse(accelerate_version) > version.parse("1.3.0"):
+                args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size)
+            else:
+                raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")
+
+        # create accelerator object
+        self.accelerator = Accelerator(**args)
+        # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store the gather function
+        self.gather_function = self.accelerator.gather_for_metrics
+
+        if "use_gather_object" in inspect.signature(self.gather_function).parameters:
+            self.gather_function = functools.partial(
+                self.gather_function, use_gather_object=self.args.eval_use_gather_object
+            )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+        self.is_tp_enabled = getattr(self.accelerator.state, "torch_tp_plugin", None) is not None
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            for param in ["limit_all_gathers", "activation_checkpointing"]:
+                setattr(fsdp_plugin, param, self.args.fsdp_config.get(param, getattr(fsdp_plugin, param)))
+            if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
+                raise ValueError(
+                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
+                    "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
+                    "when using FSDP."
+                )
+
+        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+            self.propagate_args_to_deepspeed()
+
+        # `save_only_model` can't be used with DeepSpeed/FSDP along with `load_best_model_at_end`
+        if (
+            self.args.save_only_model
+            and (self.is_deepspeed_enabled or self.is_fsdp_enabled)
+            and self.args.load_best_model_at_end
+        ):
+            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
+            raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
+
+        # `auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3
+        if (
+            self.is_deepspeed_enabled
+            and self.accelerator.state.deepspeed_plugin.zero_stage == 3
+            and self.args.auto_find_batch_size
+        ):
+            raise ValueError(
+                "`auto_find_batch_size` isn't supported yet with DeepSpeed Zero-3. Please consider using Zero-2, Zero-1, or FSDP"
+            )
+        if (
+            self.args.save_only_model
+            and self.is_fsdp_enabled
+            and "SHARDED_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)
+        ):
+            raise ValueError("save_only_model option is not compatible with FSDP state dict type 'SHARDED_STATE_DICT'")
+
+    def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
+        """
+        Sets values in the deepspeed plugin based on the Trainer args
+        """
+        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+        ds_plugin = self.accelerator.state.deepspeed_plugin
+
+        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+        ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
+
+    def _fsdp_qlora_plugin_updates(self):
+        if self.is_fsdp_enabled and _is_peft_model(self.model):
+            from peft import PeftConfig
+            from peft.utils.other import fsdp_auto_wrap_policy
+
+            if isinstance(self.model.active_peft_config, PeftConfig):
+                self.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(self.model)
+            if (
+                getattr(self.model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                and self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage.is_floating_point
+                and version.parse(accelerate_version) > version.parse("0.27.0")
+            ):
+                self.accelerator.state.fsdp_plugin.set_mixed_precision(
+                    self.model.hf_quantizer.quantization_config.bnb_4bit_quant_storage, override=True
+                )
+
+    def get_batch_samples(
+        self, epoch_iterator: Iterator, num_batches: int, device: torch.device
+    ) -> tuple[list, Optional[torch.Tensor]]:
+        """
+        Collects a specified number of batches from the epoch iterator and optionally counts the number of items in the batches to properly scale the loss.
+        """
+        batch_samples = []
+        num_items_in_batch = None
+
+        for _ in range(num_batches):
+            try:
+                batch_samples.append(next(epoch_iterator))
+            except StopIteration:
+                break
+
+        count_num_items_in_batch = (
+            len(batch_samples) > 0
+            and "labels" in batch_samples[0]
+            and (
+                # num_items_in_batch is passed to model forward
+                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757
+                self.model_accepts_loss_kwargs
+                # num_items_in_batch is passed to compute_loss_func
+                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773
+                or self.compute_loss_func is not None
+                # num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func)
+                # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
+            )
+        )
+
+        if count_num_items_in_batch:
+            # For now we don't support object detection
+            try:
+                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
+            except (TypeError, AttributeError):
+                pass
+
+        if num_items_in_batch is not None:
+            if self.args.average_tokens_across_devices:
+                num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
+
+            if torch.is_tensor(num_items_in_batch):
+                num_items_in_batch = num_items_in_batch.to(device)
+
+                if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
+                    # In the DataParallel case, convert the scalar tensor into a 1-dim tensor
+                    num_items_in_batch = num_items_in_batch.unsqueeze(0)
+                # Divide by number of devices with the same batch
+                if pc := getattr(self.accelerator, "parallelism_config", None):
+                    num_items_in_batch = num_items_in_batch // pc.non_data_parallel_size
+
+        return batch_samples, num_items_in_batch
+
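+    # Illustrative sketch (not part of the Trainer API), assuming labels use the usual
+    # -100 ignore index: only the label positions that contribute to the loss are counted.
+    #
+    #   batch_samples = [
+    #       {"labels": torch.tensor([[1, 2, -100], [-100, 3, 4]])},  # 4 real labels
+    #       {"labels": torch.tensor([[5, -100, -100]])},             # 1 real label
+    #   ]
+    #   sum((b["labels"].ne(-100)).sum() for b in batch_samples)     # tensor(5)
+    #
+    # so the loss is normalized by 5 rather than by the padded token count.
+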
+    def set_initial_training_values(
+        self, args: TrainingArguments, dataloader: DataLoader, total_train_batch_size: int
+    ):
+        """
+        Calculates and returns the following values:
+        - `num_train_epochs`
+        - `num_update_steps_per_epoch`
+        - `num_examples`
+        - `num_train_samples`
+        - `epoch_based`
+        - `len_dataloader`
+        - `max_steps`
+        """
+        # Case 1: we rely on `args.max_steps` first
+        max_steps = args.max_steps
+        # If max_steps is negative, we use the number of epochs to determine the number of total steps later
+        epoch_based = max_steps < 0
+        len_dataloader = len(dataloader) if has_length(dataloader) else None
+
+        # Case 2: We have a dataloader length and can extrapolate
+        if len_dataloader is not None:
+            num_update_steps_per_epoch = max(
+                len_dataloader // args.gradient_accumulation_steps
+                + int(len_dataloader % args.gradient_accumulation_steps > 0),
+                1,
+            )
+            # Case 3: We have a length but are using epochs, we can extrapolate the number of steps
+            if epoch_based:
+                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+
+        # Now we figure out `num_examples`, `num_train_epochs`, and `train_samples`
+        if len_dataloader:
+            num_examples = self.num_examples(dataloader)
+            if args.max_steps > 0:
+                num_train_epochs = max_steps // num_update_steps_per_epoch + int(
+                    max_steps % num_update_steps_per_epoch > 0
+                )
+                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+                # the best we can do.
+                num_train_samples = max_steps * total_train_batch_size
+            else:
+                num_train_epochs = math.ceil(args.num_train_epochs)
+                num_train_samples = self.num_examples(dataloader) * args.num_train_epochs
+        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
+            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+            num_train_epochs = sys.maxsize
+            num_update_steps_per_epoch = max_steps
+            num_examples = total_train_batch_size * args.max_steps
+            num_train_samples = args.max_steps * total_train_batch_size
+        else:
+            raise ValueError(
+                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+                f" {args.max_steps}"
+            )
+        return (
+            num_train_epochs,
+            num_update_steps_per_epoch,
+            num_examples,
+            num_train_samples,
+            epoch_based,
+            len_dataloader,
+            max_steps,
+        )
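+
+    # Illustrative worked example for `set_initial_training_values` (assumed numbers):
+    # with len(dataloader) == 100, gradient_accumulation_steps == 4, num_train_epochs == 3
+    # and max_steps == -1:
+    #   num_update_steps_per_epoch = max(100 // 4 + int(100 % 4 > 0), 1) = 25
+    #   max_steps                  = ceil(3 * 25) = 75
+    #   num_train_epochs           = ceil(3) = 3
+    #   num_train_samples          = self.num_examples(dataloader) * 3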
diff --git a/phivenv/Lib/site-packages/transformers/trainer_callback.py b/phivenv/Lib/site-packages/transformers/trainer_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..7102f7a5bedc95a333421dcfcab75f7ea9303b87
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/trainer_callback.py
@@ -0,0 +1,785 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Callbacks to use with the Trainer class and customize the training loop.
+"""
+
+import dataclasses
+import json
+import math
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from .trainer_utils import HPSearchBackend, IntervalStrategy, SaveStrategy, has_length
+from .training_args import TrainingArguments
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TrainerState:
+    """
+    A class containing the [`Trainer`] inner state that will be saved along the model and optimizer when checkpointing
+    and passed to the [`TrainerCallback`].
+
+    
+
+    Throughout this class, one step is to be understood as one update step. When using gradient accumulation, one
+    update step may require several forward and backward passes: if you use `gradient_accumulation_steps=n`, then one
+    update step requires going through *n* batches.
+
+    
+
+    Args:
+        epoch (`float`, *optional*):
+            Only set during training, will represent the epoch the training is at (the decimal part being the
+            percentage of the current epoch completed).
+        global_step (`int`, *optional*, defaults to 0):
+            During training, represents the number of update steps completed.
+        max_steps (`int`, *optional*, defaults to 0):
+            The number of update steps to do during the current training.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Log every X update steps.
+        eval_steps (`int`, *optional*):
+            Run an evaluation every X steps.
+        save_steps (`int`, *optional*, defaults to 500):
+            Save checkpoint every X updates steps.
+        train_batch_size (`int`, *optional*):
+            The batch size for the training dataloader. Only needed when
+            `auto_find_batch_size` has been used.
+        num_input_tokens_seen (`int`, *optional*, defaults to 0):
+            When tracking the input tokens, the number of tokens seen during training (number of input tokens, not the
+            number of prediction tokens).
+        total_flos (`float`, *optional*, defaults to 0):
+            The total number of floating-point operations done by the model since the beginning of training (stored as floats
+            to avoid overflow).
+        log_history (`list[dict[str, float]]`, *optional*):
+            The list of logs done since the beginning of training.
+        best_metric (`float`, *optional*):
+            When tracking the best model, the value of the best metric encountered so far.
+        best_global_step (`int`, *optional*):
+            When tracking the best model, the step at which the best metric was encountered.
+            Used for setting `best_model_checkpoint`.
+        best_model_checkpoint (`str`, *optional*):
+            When tracking the best model, the value of the name of the checkpoint for the best model encountered so
+            far.
+        is_local_process_zero (`bool`, *optional*, defaults to `True`):
+            Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+            several machines) main process.
+        is_world_process_zero (`bool`, *optional*, defaults to `True`):
+            Whether or not this process is the global main process (when training in a distributed fashion on several
+            machines, this is only going to be `True` for one process).
+        is_hyper_param_search (`bool`, *optional*, defaults to `False`):
+            Whether we are in the process of a hyperparameter search using Trainer.hyperparameter_search. This will
+            impact the way data will be logged in TensorBoard.
+        stateful_callbacks (`list[StatefulTrainerCallback]`, *optional*):
+            Callbacks attached to the `Trainer` that should have their states saved or restored.
+            Relevant callbacks should implement `state` and `from_state` functions.
+    """
+
+    epoch: Optional[float] = None
+    global_step: int = 0
+    max_steps: int = 0
+    logging_steps: int = 500
+    eval_steps: int = 500
+    save_steps: int = 500
+    train_batch_size: Optional[int] = None
+    num_train_epochs: int = 0
+    num_input_tokens_seen: int = 0
+    total_flos: float = 0
+    log_history: list[dict[str, float]] = None
+    best_metric: Optional[float] = None
+    best_global_step: Optional[int] = None
+    best_model_checkpoint: Optional[str] = None
+    is_local_process_zero: bool = True
+    is_world_process_zero: bool = True
+    is_hyper_param_search: bool = False
+    trial_name: Optional[str] = None
+    trial_params: dict[str, Union[str, float, int, bool]] = None
+    stateful_callbacks: list["TrainerCallback"] = None
+
+    def __post_init__(self):
+        if self.log_history is None:
+            self.log_history = []
+        if self.stateful_callbacks is None:
+            self.stateful_callbacks = {}
+        elif isinstance(self.stateful_callbacks, dict):
+            # We are loading the callbacks in from the state file, no need to process them
+            pass
+        else:
+            # Saveable callbacks get stored as dict of kwargs
+            stateful_callbacks = {}
+            for callback in self.stateful_callbacks:
+                if not isinstance(callback, (ExportableState)):
+                    raise TypeError(
+                        f"All callbacks passed to be saved must inherit `ExportableState`, but received {type(callback)}"
+                    )
+                name = callback.__class__.__name__
+                if name in stateful_callbacks:
+                    # We can have multiple versions of the same callback
+                    # if so, we store them as a list of states to restore
+                    if not isinstance(stateful_callbacks[name], list):
+                        stateful_callbacks[name] = [stateful_callbacks[name]]
+                    stateful_callbacks[name].append(callback.state())
+                else:
+                    stateful_callbacks[name] = callback.state()
+            self.stateful_callbacks = stateful_callbacks
+
+    def save_to_json(self, json_path: str):
+        """Save the content of this instance in JSON format inside `json_path`."""
+        json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
+        with open(json_path, "w", encoding="utf-8") as f:
+            f.write(json_string)
+
+    @classmethod
+    def load_from_json(cls, json_path: str):
+        """Create an instance from the content of `json_path`."""
+        with open(json_path, encoding="utf-8") as f:
+            text = f.read()
+        return cls(**json.loads(text))
+
+    def compute_steps(self, args, max_steps):
+        """
+        Calculates and stores the absolute values for the logging, eval, and save steps, converting each one from a
+        proportion of `max_steps` when it is given as a float smaller than 1.
+        """
+        for step_kind in ("logging", "eval", "save"):
+            num_steps = getattr(args, f"{step_kind}_steps")
+            if num_steps is not None:
+                if num_steps < 1:
+                    num_steps = math.ceil(max_steps * num_steps)
+                setattr(self, f"{step_kind}_steps", num_steps)
+
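+    # Illustrative worked example (assumed numbers): with max_steps == 500, a value of
+    # logging_steps == 0.1 or eval_steps == 0.25 is treated as a proportion and becomes
+    # ceil(500 * 0.1) = 50 or ceil(500 * 0.25) = 125, while save_steps == 1000 is already
+    # an absolute value and is stored unchanged.
+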
+    def init_training_references(self, trainer, max_steps, num_train_epochs, trial):
+        """
+        Stores the initial training references needed in `self`
+        """
+        if trainer.hp_name is not None and trainer._trial is not None:
+            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
+            # parameter to Train when using DDP.
+            self.trial_name = trainer.hp_name(trainer._trial)
+        self.trial_params = None
+        if trial is not None:
+            from transformers.integrations import hp_params
+
+            assignments = trial.assignments if trainer.hp_search_backend == HPSearchBackend.SIGOPT else trial
+            self.trial_params = hp_params(assignments)
+
+        self.max_steps = max_steps
+        self.num_train_epochs = num_train_epochs
+        self.is_local_process_zero = trainer.is_local_process_zero()
+        self.is_world_process_zero = trainer.is_world_process_zero()
+
+
+class ExportableState:
+    """
+    A class for objects that include the ability to have their state
+    saved during `Trainer._save_checkpoint` and loaded back in during
+    `Trainer._load_from_checkpoint`.
+
+    These must implement a `state` function that gets called during the respective
+    Trainer function call. It should only include the parameters and attributes needed to
+    recreate the state at a particular time, so that pickling can be avoided and standard
+    file IO writing can be used.
+
+    Example:
+
+    ```python
+    class EarlyStoppingCallback(TrainerCallback, ExportableState):
+        def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
+            self.early_stopping_patience = early_stopping_patience
+            self.early_stopping_threshold = early_stopping_threshold
+            # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
+            self.early_stopping_patience_counter = 0
+
+        def state(self) -> dict:
+            return {
+                "args": {
+                    "early_stopping_patience": self.early_stopping_patience,
+                    "early_stopping_threshold": self.early_stopping_threshold,
+                },
+                "attributes": {
+                    "early_stopping_patience_counter": self.early_stopping_patience_counter,
+                }
+            }
+    ```"""
+
+    def state(self) -> dict:
+        raise NotImplementedError("You must implement a `state` function to utilize this class.")
+
+    @classmethod
+    def from_state(cls, state):
+        instance = cls(**state["args"])
+        for k, v in state["attributes"].items():
+            setattr(instance, k, v)
+        return instance
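+
+
+# Illustrative sketch (not part of the library): round-tripping a stateful callback
+# through `state()` / `from_state()`. `EarlyStoppingCallback` (defined later in this
+# module) is used here purely as an example of an `ExportableState` implementation.
+#
+#   cb = EarlyStoppingCallback(early_stopping_patience=3)
+#   cb.early_stopping_patience_counter = 2
+#   restored = EarlyStoppingCallback.from_state(cb.state())
+#   assert restored.early_stopping_patience == 3
+#   assert restored.early_stopping_patience_counter == 2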
+
+
+@dataclass
+class TrainerControl(ExportableState):
+    """
+    A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate some
+    switches in the training loop.
+
+    Args:
+        should_training_stop (`bool`, *optional*, defaults to `False`):
+            Whether or not the training should be interrupted.
+
+            If `True`, this variable will not be set back to `False`. The training will just stop.
+        should_epoch_stop (`bool`, *optional*, defaults to `False`):
+            Whether or not the current epoch should be interrupted.
+
+            If `True`, this variable will be set back to `False` at the beginning of the next epoch.
+        should_save (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should be saved at this step.
+
+            If `True`, this variable will be set back to `False` at the beginning of the next step.
+        should_evaluate (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should be evaluated at this step.
+
+            If `True`, this variable will be set back to `False` at the beginning of the next step.
+        should_log (`bool`, *optional*, defaults to `False`):
+            Whether or not the logs should be reported at this step.
+
+            If `True`, this variable will be set back to `False` at the beginning of the next step.
+    """
+
+    should_training_stop: bool = False
+    should_epoch_stop: bool = False
+    should_save: bool = False
+    should_evaluate: bool = False
+    should_log: bool = False
+
+    def _new_training(self):
+        """Internal method that resets the variable for a new training."""
+        self.should_training_stop = False
+
+    def _new_epoch(self):
+        """Internal method that resets the variable for a new epoch."""
+        self.should_epoch_stop = False
+
+    def _new_step(self):
+        """Internal method that resets the variable for a new step."""
+        self.should_save = False
+        self.should_evaluate = False
+        self.should_log = False
+
+    def state(self) -> dict:
+        return {
+            "args": {
+                "should_training_stop": self.should_training_stop,
+                "should_epoch_stop": self.should_epoch_stop,
+                "should_save": self.should_save,
+                "should_evaluate": self.should_evaluate,
+                "should_log": self.should_log,
+            },
+            "attributes": {},
+        }
+
+
+class TrainerCallback:
+    # no-format
+    """
+    A class for objects that will inspect the state of the training loop at some events and take some decisions. At
+    each of those events the following arguments are available:
+
+    Args:
+        args ([`TrainingArguments`]):
+            The training arguments used to instantiate the [`Trainer`].
+        state ([`TrainerState`]):
+            The current state of the [`Trainer`].
+        control ([`TrainerControl`]):
+            The object that is returned to the [`Trainer`] and can be used to make some decisions.
+        model ([`PreTrainedModel`] or `torch.nn.Module`):
+            The model being trained.
+        tokenizer ([`PreTrainedTokenizer`]):
+            The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
+        processing_class ([`PreTrainedTokenizer` or `BaseImageProcessor` or `ProcessorMixin` or `FeatureExtractionMixin`]):
+            The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
+        optimizer (`torch.optim.Optimizer`):
+            The optimizer used for the training steps.
+        lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
+            The scheduler used for setting the learning rate.
+        train_dataloader (`torch.utils.data.DataLoader`, *optional*):
+            The current dataloader used for training.
+        eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
+            The current dataloader used for evaluation.
+        metrics (`dict[str, float]`):
+            The metrics computed by the last evaluation phase.
+
+            Those are only accessible in the event `on_evaluate`.
+        logs  (`dict[str, float]`):
+            The values to log.
+
+            Those are only accessible in the event `on_log`.
+
+    The `control` object is the only one that can be changed by the callback, in which case the event that changes it
+    should return the modified version.
+
+    The arguments `args`, `state` and `control` are positional for all events; all the others are grouped in `kwargs`.
+    You can unpack the ones you need in the signature of the event using them. As an example, see the code of the
+    simple [`~transformers.PrinterCallback`].
+
+    Example:
+
+    ```python
+    class PrinterCallback(TrainerCallback):
+        def on_log(self, args, state, control, logs=None, **kwargs):
+            _ = logs.pop("total_flos", None)
+            if state.is_local_process_zero:
+                print(logs)
+    ```"""
+
+    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of the initialization of the [`Trainer`].
+        """
+        pass
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of training.
+        """
+        pass
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of training.
+        """
+        pass
+
+    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of an epoch.
+        """
+        pass
+
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of an epoch.
+        """
+        pass
+
+    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
+        several inputs.
+        """
+        pass
+
+    def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.
+        """
+        pass
+
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.
+        """
+        pass
+
+    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of a substep during gradient accumulation.
+        """
+        pass
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the end of a training step. If using gradient accumulation, one training step might take
+        several inputs.
+        """
+        pass
+
+    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after an evaluation phase.
+        """
+        pass
+
+    def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs):
+        """
+        Event called after a successful prediction.
+        """
+        pass
+
+    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after a checkpoint save.
+        """
+        pass
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after logging the last logs.
+        """
+        pass
+
+    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called after a prediction step.
+        """
+        pass
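+
+# Illustrative sketch (not part of the library): a custom callback that flips one of the
+# `TrainerControl` switches and returns the modified control so the Trainer picks it up.
+# Stopping after the first evaluation is an arbitrary choice for the example, e.g. for a
+# quick smoke test.
+#
+#   class StopAfterFirstEvalCallback(TrainerCallback):
+#       def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+#           control.should_training_stop = True
+#           return control  # returning `control` propagates the change to the Trainer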
+
+
+class CallbackHandler(TrainerCallback):
+    """Internal class that just calls the list of callbacks in order."""
+
+    def __init__(self, callbacks, model, processing_class, optimizer, lr_scheduler):
+        self.callbacks = []
+        for cb in callbacks:
+            self.add_callback(cb)
+        self.model = model
+        self.processing_class = processing_class
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.train_dataloader = None
+        self.eval_dataloader = None
+
+        if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
+            logger.warning(
+                "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+                + "should add one before training with `trainer.add_callback(DefaultFlowCallback)`. The current list of "
+                + "callbacks is:\n"
+                + self.callback_list
+            )
+
+    def add_callback(self, callback):
+        cb = callback() if isinstance(callback, type) else callback
+        cb_class = callback if isinstance(callback, type) else callback.__class__
+        if cb_class in [c.__class__ for c in self.callbacks]:
+            logger.warning(
+                f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current "
+                + "list of callbacks is:\n"
+                + self.callback_list
+            )
+        self.callbacks.append(cb)
+
+    def pop_callback(self, callback):
+        if isinstance(callback, type):
+            for cb in self.callbacks:
+                if isinstance(cb, callback):
+                    self.callbacks.remove(cb)
+                    return cb
+        else:
+            for cb in self.callbacks:
+                if cb == callback:
+                    self.callbacks.remove(cb)
+                    return cb
+
+    def remove_callback(self, callback):
+        if isinstance(callback, type):
+            for cb in self.callbacks:
+                if isinstance(cb, callback):
+                    self.callbacks.remove(cb)
+                    return
+        else:
+            self.callbacks.remove(callback)
+
+    @property
+    def callback_list(self):
+        return "\n".join(cb.__class__.__name__ for cb in self.callbacks)
+
+    def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_init_end", args, state, control)
+
+    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_training_stop = False
+        return self.call_event("on_train_begin", args, state, control)
+
+    def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_train_end", args, state, control)
+
+    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_epoch_stop = False
+        return self.call_event("on_epoch_begin", args, state, control)
+
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_epoch_end", args, state, control)
+
+    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_log = False
+        control.should_evaluate = False
+        control.should_save = False
+        return self.call_event("on_step_begin", args, state, control)
+
+    def on_pre_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_pre_optimizer_step", args, state, control)
+
+    def on_optimizer_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_optimizer_step", args, state, control)
+
+    def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_substep_end", args, state, control)
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_step_end", args, state, control)
+
+    def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+        control.should_evaluate = False
+        return self.call_event("on_evaluate", args, state, control, metrics=metrics)
+
+    def on_predict(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+        return self.call_event("on_predict", args, state, control, metrics=metrics)
+
+    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        control.should_save = False
+        return self.call_event("on_save", args, state, control)
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
+        control.should_log = False
+        return self.call_event("on_log", args, state, control, logs=logs)
+
+    def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+        return self.call_event("on_prediction_step", args, state, control)
+
+    def call_event(self, event, args, state, control, **kwargs):
+        for callback in self.callbacks:
+            result = getattr(callback, event)(
+                args,
+                state,
+                control,
+                model=self.model,
+                processing_class=self.processing_class,
+                optimizer=self.optimizer,
+                lr_scheduler=self.lr_scheduler,
+                train_dataloader=self.train_dataloader,
+                eval_dataloader=self.eval_dataloader,
+                **kwargs,
+            )
+            # A Callback can skip the return of `control` if it doesn't change it.
+            if result is not None:
+                control = result
+        return control
+
+
+class DefaultFlowCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that handles the default flow of the training loop for logs, evaluation and checkpoints.
+    """
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        # Log
+        if state.global_step == 1 and args.logging_first_step:
+            control.should_log = True
+        if args.logging_strategy == IntervalStrategy.STEPS and state.global_step % state.logging_steps == 0:
+            control.should_log = True
+
+        # Evaluate
+        if (
+            args.eval_strategy == IntervalStrategy.STEPS
+            and state.global_step % state.eval_steps == 0
+            and args.eval_delay <= state.global_step
+        ):
+            control.should_evaluate = True
+
+        # Save
+        if (
+            args.save_strategy == SaveStrategy.STEPS
+            and state.save_steps > 0
+            and state.global_step % state.save_steps == 0
+        ):
+            control.should_save = True
+
+        # End training
+        if state.global_step >= state.max_steps:
+            control.should_training_stop = True
+            # Save the model at the end if we have a save strategy
+            if args.save_strategy == SaveStrategy.STEPS:
+                control.should_save = True
+
+        return control
+
+    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        # Log
+        if args.logging_strategy == IntervalStrategy.EPOCH:
+            control.should_log = True
+
+        # Evaluate
+        if args.eval_strategy == IntervalStrategy.EPOCH and args.eval_delay <= state.epoch:
+            control.should_evaluate = True
+
+        # Save
+        if args.save_strategy == SaveStrategy.EPOCH:
+            control.should_save = True
+
+        return control
+
+
+class ProgressCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that displays the progress of training or evaluation.
+    You can modify `max_str_len` to control the length at which strings are truncated when logging.
+    """
+
+    def __init__(self, max_str_len: int = 100):
+        """
+        Initialize the callback with optional max_str_len parameter to control string truncation length.
+
+        Args:
+            max_str_len (`int`):
+                Maximum length of strings to display in logs.
+                Longer strings will be truncated with a message.
+        """
+        self.training_bar = None
+        self.prediction_bar = None
+        self.max_str_len = max_str_len
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
+        self.current_step = 0
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            self.training_bar.update(state.global_step - self.current_step)
+            self.current_step = state.global_step
+
+    def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+        if state.is_world_process_zero and has_length(eval_dataloader):
+            if self.prediction_bar is None:
+                self.prediction_bar = tqdm(
+                    total=len(eval_dataloader), leave=self.training_bar is None, dynamic_ncols=True
+                )
+            self.prediction_bar.update(1)
+
+    def on_evaluate(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            if self.prediction_bar is not None:
+                self.prediction_bar.close()
+            self.prediction_bar = None
+
+    def on_predict(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            if self.prediction_bar is not None:
+                self.prediction_bar.close()
+            self.prediction_bar = None
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        if state.is_world_process_zero and self.training_bar is not None:
+            # make a shallow copy of logs so we can mutate the fields copied
+            # but avoid doing any value pickling.
+            shallow_logs = {}
+            for k, v in logs.items():
+                if isinstance(v, str) and len(v) > self.max_str_len:
+                    shallow_logs[k] = (
+                        f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
+                        "Consider increasing `max_str_len` if needed.]"
+                    )
+                else:
+                    shallow_logs[k] = v
+            _ = shallow_logs.pop("total_flos", None)
+            # round numbers so that it looks better in console
+            if "epoch" in shallow_logs:
+                shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
+            self.training_bar.write(str(shallow_logs))
+
+    def on_train_end(self, args, state, control, **kwargs):
+        if state.is_world_process_zero:
+            self.training_bar.close()
+            self.training_bar = None
+
+
+class PrinterCallback(TrainerCallback):
+    """
+    A bare [`TrainerCallback`] that just prints the logs.
+    """
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        _ = logs.pop("total_flos", None)
+        if state.is_local_process_zero:
+            print(logs)
+
+
+class EarlyStoppingCallback(TrainerCallback, ExportableState):
+    """
+    A [`TrainerCallback`] that handles early stopping.
+
+    Args:
+        early_stopping_patience (`int`):
+            Use with `metric_for_best_model` to stop training when the specified metric worsens for
+            `early_stopping_patience` evaluation calls.
+        early_stopping_threshold (`float`, *optional*):
+            Use with TrainingArguments `metric_for_best_model` and `early_stopping_patience` to denote how much the
+            specified metric must improve to satisfy early stopping conditions.
+
+    This callback depends on [`TrainingArguments`] argument *load_best_model_at_end* functionality to set best_metric
+    in [`TrainerState`]. Note that if the [`TrainingArguments`] argument *save_steps* differs from *eval_steps*, the
+    early stopping will not occur until the next save step.
+    """
+
+    def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
+        self.early_stopping_patience = early_stopping_patience
+        self.early_stopping_threshold = early_stopping_threshold
+        # early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
+        self.early_stopping_patience_counter = 0
+
+    def check_metric_value(self, args, state, control, metric_value):
+        # best_metric is set by code for load_best_model
+        operator = np.greater if args.greater_is_better else np.less
+        if state.best_metric is None or (
+            operator(metric_value, state.best_metric)
+            and abs(metric_value - state.best_metric) > self.early_stopping_threshold
+        ):
+            self.early_stopping_patience_counter = 0
+        else:
+            self.early_stopping_patience_counter += 1
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        if not args.load_best_model_at_end:
+            logger.warning(
+                "Using EarlyStoppingCallback without load_best_model_at_end=True. "
+                "Once training is finished, the best model will not be loaded automatically."
+            )
+        assert args.metric_for_best_model is not None, (
+            "EarlyStoppingCallback requires metric_for_best_model to be defined"
+        )
+        assert args.eval_strategy != IntervalStrategy.NO, (
+            "EarlyStoppingCallback requires IntervalStrategy of steps or epoch"
+        )
+
+    def on_evaluate(self, args, state, control, metrics, **kwargs):
+        metric_to_check = args.metric_for_best_model
+        if not metric_to_check.startswith("eval_"):
+            metric_to_check = f"eval_{metric_to_check}"
+        metric_value = metrics.get(metric_to_check)
+
+        if metric_value is None:
+            logger.warning(
+                f"early stopping requires metric_for_best_model, but did not find {metric_to_check} so early stopping"
+                " is disabled"
+            )
+            return
+
+        self.check_metric_value(args, state, control, metric_value)
+        if self.early_stopping_patience_counter >= self.early_stopping_patience:
+            control.should_training_stop = True
+
+    def state(self) -> dict:
+        return {
+            "args": {
+                "early_stopping_patience": self.early_stopping_patience,
+                "early_stopping_threshold": self.early_stopping_threshold,
+            },
+            "attributes": {
+                "early_stopping_patience_counter": self.early_stopping_patience_counter,
+            },
+        }
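+
+
+# Illustrative sketch (not part of the library): wiring `EarlyStoppingCallback` into a
+# `Trainer`. `model`, `train_ds` and `eval_ds` are assumed to be defined elsewhere, and
+# the patience value is arbitrary; matching save/eval strategies plus
+# `load_best_model_at_end=True` are needed for the callback to see `state.best_metric`.
+#
+#   args = TrainingArguments(
+#       output_dir="out",
+#       eval_strategy="epoch",
+#       save_strategy="epoch",
+#       load_best_model_at_end=True,
+#       metric_for_best_model="eval_loss",
+#       greater_is_better=False,
+#   )
+#   trainer = Trainer(
+#       model=model,
+#       args=args,
+#       train_dataset=train_ds,
+#       eval_dataset=eval_ds,
+#       callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
+#   )
+#   trainer.train()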
diff --git a/phivenv/Lib/site-packages/transformers/trainer_pt_utils.py b/phivenv/Lib/site-packages/transformers/trainer_pt_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32516b167fe54f2d29e64d221e86e63c806f592
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/trainer_pt_utils.py
@@ -0,0 +1,1406 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Torch utilities for the Trainer class.
+"""
+
+import copy
+import datetime
+import io
+import json
+import math
+import os
+import re
+import sys
+import warnings
+from collections.abc import Iterator, Mapping
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from itertools import chain
+from logging import StreamHandler
+from typing import Any, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
+from torch.utils.data.distributed import DistributedSampler
+
+from .integrations.deepspeed import is_deepspeed_zero3_enabled
+from .tokenization_utils_base import BatchEncoding
+from .utils import (
+    is_sagemaker_mp_enabled,
+    is_torch_available,
+    is_torch_xla_available,
+    is_training_run_on_sagemaker,
+    logging,
+)
+
+
+if is_training_run_on_sagemaker():
+    logging.add_handler(StreamHandler(sys.stdout))
+
+if is_torch_xla_available():
+    import torch_xla.runtime as xr
+
+if is_torch_available():
+    from torch.optim.lr_scheduler import LRScheduler
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_dataloader_sampler(dataloader):
+    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
+        return get_dataloader_sampler(dataloader.batch_sampler)
+    elif hasattr(dataloader, "sampler"):
+        return dataloader.sampler
+
+
+def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
+    if isinstance(tensor_or_array, torch.Tensor):
+        if hasattr(torch, "atleast_1d"):
+            tensor_or_array = torch.atleast_1d(tensor_or_array)
+        elif tensor_or_array.ndim < 1:
+            tensor_or_array = tensor_or_array[None]
+    else:
+        tensor_or_array = np.atleast_1d(tensor_or_array)
+    return tensor_or_array
+
+
+def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
+    """Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
+    tensor1 = atleast_1d(tensor1)
+    tensor2 = atleast_1d(tensor2)
+
+    if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
+        return torch.cat((tensor1, tensor2), dim=0)
+
+    # Let's figure out the new shape
+    new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]
+
+    # Now let's fill the result tensor
+    result = tensor1.new_full(new_shape, padding_index)
+    result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
+    result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
+    return result
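+
+
+# Illustrative sketch (not part of the library): concatenating two tensors whose second
+# dimensions differ; shapes and values below are assumptions for the example.
+#
+#   a = torch.ones(2, 3)
+#   b = torch.zeros(2, 5)
+#   out = torch_pad_and_concatenate(a, b)  # shape (4, 5)
+#   # out[:2, :3] == a, out[2:, :5] == b, and out[:2, 3:] is filled with -100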
+
+
+def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
+    """Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
+    array1 = atleast_1d(array1)
+    array2 = atleast_1d(array2)
+
+    if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
+        return np.concatenate((array1, array2), axis=0)
+
+    # Let's figure out the new shape
+    new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]
+
+    # Now let's fill the result tensor
+    result = np.full_like(array1, padding_index, shape=new_shape)
+    result[: array1.shape[0], : array1.shape[1]] = array1
+    result[array1.shape[0] :, : array2.shape[1]] = array2
+    return result
+
+
+def nested_concat(tensors, new_tensors, padding_index=-100):
+    """
+    Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
+    nested list/tuples/dict of tensors.
+    """
+    if not (isinstance(tensors, torch.Tensor) and isinstance(new_tensors, torch.Tensor)):
+        assert type(tensors) is type(new_tensors), (
+            f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
+        )
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
+    elif isinstance(tensors, torch.Tensor):
+        return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)(
+            {k: nested_concat(t, new_tensors[k], padding_index=padding_index) for k, t in tensors.items()}
+        )
+    elif isinstance(tensors, np.ndarray):
+        return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
+    else:
+        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
+
+
+def find_batch_size(tensors):
+    """
+    Find the first dimension of a tensor in a nested list/tuple/dict of tensors.
+    """
+    if isinstance(tensors, (list, tuple)):
+        for t in tensors:
+            result = find_batch_size(t)
+            if result is not None:
+                return result
+    elif isinstance(tensors, Mapping):
+        for value in tensors.values():
+            result = find_batch_size(value)
+            if result is not None:
+                return result
+    elif isinstance(tensors, (torch.Tensor, np.ndarray)):
+        return tensors.shape[0] if len(tensors.shape) >= 1 else None
+
+
+def nested_numpify(tensors):
+    "Numpify `tensors` (even if it's a nested list/tuple/dict of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_numpify(t) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_numpify(t) for k, t in tensors.items()})
+
+    t = tensors.cpu()
+    if t.dtype == torch.bfloat16:
+        # As of Numpy 1.21.4, NumPy does not support bfloat16 (see
+        # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
+        # Until Numpy adds bfloat16, we must convert to float32.
+        t = t.to(torch.float32)
+    return t.numpy()
+
+
+def nested_detach(tensors):
+    "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
+    return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
+
+
+def nested_xla_mesh_reduce(tensors, name):
+    if is_torch_xla_available():
+        import torch_xla.core.xla_model as xm
+
+        if isinstance(tensors, (list, tuple)):
+            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        if isinstance(tensors, Mapping):
+            return type(tensors)(
+                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
+            )
+
+        tensors = atleast_1d(tensors)
+        return xm.mesh_reduce(name, tensors, torch.cat)
+    else:
+        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
+
+
+def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
+    try:
+        if isinstance(tensor, (tuple, list)):
+            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
+        if isinstance(tensor, Mapping):
+            return type(tensor)({k: distributed_concat(t, num_total_examples) for k, t in tensor.items()})
+        tensor = atleast_1d(tensor).contiguous()
+        output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensor)
+        concat = torch.cat(output_tensors, dim=0)
+
+        # truncate the dummy elements added by SequentialDistributedSampler
+        if num_total_examples is not None:
+            concat = concat[:num_total_examples]
+        return concat
+    except AssertionError:
+        raise AssertionError("Not currently using distributed training")
+
+
+def distributed_broadcast_scalars(
+    scalars: list[Union[int, float]],
+    num_total_examples: Optional[int] = None,
+    device: Optional[torch.device] = torch.device("cuda"),
+) -> torch.Tensor:
+    try:
+        tensorized_scalar = torch.tensor(scalars, device=device)
+        output_tensors = [tensorized_scalar.clone() for _ in range(dist.get_world_size())]
+        dist.all_gather(output_tensors, tensorized_scalar)
+        concat = torch.cat(output_tensors, dim=0)
+
+        # truncate the dummy elements added by SequentialDistributedSampler
+        if num_total_examples is not None:
+            concat = concat[:num_total_examples]
+        return concat
+    except AssertionError:
+        raise AssertionError("Not currently using distributed training")
+
+
+def reissue_pt_warnings(caught_warnings):
+    # Reissue warnings
+    if len(caught_warnings) > 1:
+        for w in caught_warnings:
+            if w.category is not UserWarning:
+                warnings.warn(w.message, w.category)
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+    """
+    Context manager to make all processes in distributed training wait for the local master to do something.
+
+    Args:
+        local_rank (`int`): The rank of the local process.
+    """
+    if local_rank not in [-1, 0]:
+        dist.barrier()
+    yield
+    if local_rank == 0:
+        dist.barrier()
+
+
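+# Illustrative usage sketch for `torch_distributed_zero_first`: let the local-main process run a
+# one-time setup step (e.g. preprocessing or caching a dataset) while the other local ranks wait at
+# the barrier. `prepare_fn` is a hypothetical callable supplied by the caller; this helper exists
+# for illustration only and is never called.
+def _example_torch_distributed_zero_first(local_rank: int, prepare_fn):
+    with torch_distributed_zero_first(local_rank):
+        # Rank 0 (or -1 in non-distributed runs) executes `prepare_fn` first; the remaining local
+        # ranks enter the block only after the barrier is released and typically hit a cache.
+        return prepare_fn()
+
+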
+class DistributedSamplerWithLoop(DistributedSampler):
+    """
+    Like a `torch.utils.data.distributed.DistributedSampler` but loops at the end back to the beginning of the shuffled
+    samples to make each process have a round multiple of batch_size samples.
+
+    Args:
+        dataset (`torch.utils.data.Dataset`):
+            Dataset used for sampling.
+        batch_size (`int`):
+            The batch size used with this sampler.
+        kwargs (`dict[str, Any]`, *optional*):
+            All other keyword arguments passed to `DistributedSampler`.
+    """
+
+    def __init__(self, dataset, batch_size, **kwargs):
+        super().__init__(dataset, **kwargs)
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(super().__iter__())
+        remainder = 0 if len(indices) % self.batch_size == 0 else self.batch_size - len(indices) % self.batch_size
+        # DistributedSampler already added samples from the beginning to make the number of samples a round multiple
+        # of the world size, so we skip those.
+        start_remainder = 1 if self.rank < len(self.dataset) % self.num_replicas else 0
+        indices += indices[start_remainder : start_remainder + remainder]
+        return iter(indices)
+
+
+class EvalLoopContainer:
+    """
+    Container to store intermediate results of evaluation loop.
+
+    Args:
+        do_nested_concat (`bool`, *optional*, defaults to `True`):
+            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
+            the existing stored tensors, provided that the structure of the existing object and the new one
+            are identical. If set to `False`, all newly added tensors will be stored in a list.
+        padding_index (`int`, *optional*, defaults to -100):
+            Value used to pad tensors of different shapes when `do_nested_concat=True`.
+    """
+
+    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
+        self.do_nested_concat = do_nested_concat
+        self.padding_index = padding_index
+        self.tensors = None
+        self.arrays = None
+
+    def add(self, tensors) -> None:
+        """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
+        if self.tensors is None:
+            self.tensors = tensors if self.do_nested_concat else [tensors]
+        elif self.do_nested_concat:
+            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
+        else:
+            self.tensors.append(tensors)
+
+    def to_cpu_and_numpy(self) -> None:
+        """Move tensors in stored objects to CPU and convert them to numpy arrays."""
+
+        # Check if we have something to add, if not just return
+        if self.tensors is None:
+            return
+
+        new_arrays = nested_numpify(self.tensors)
+        if self.arrays is None:
+            self.arrays = new_arrays
+        elif self.do_nested_concat:
+            self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
+        else:
+            self.arrays.extend(new_arrays)
+
+        # reset device tensors after adding to cpu
+        self.tensors = None
+
+    def get_arrays(self):
+        """Returns the numpified and moved to CPU stored objects."""
+        self.to_cpu_and_numpy()
+        return self.arrays
+
+
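+# Illustrative usage sketch for `EvalLoopContainer`: per-step outputs are accumulated on device and
+# periodically offloaded to CPU/NumPy to bound memory use. `model` and `eval_dataloader` are
+# hypothetical stand-ins for the caller's objects; this helper exists for illustration only.
+def _example_eval_loop_container(model, eval_dataloader, offload_every: int = 100):
+    container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+    for step, batch in enumerate(eval_dataloader):
+        with torch.no_grad():
+            logits = model(**batch).logits
+        container.add(logits)
+        if step % offload_every == 0:
+            # Move what has been accumulated so far to CPU and free the device tensors.
+            container.to_cpu_and_numpy()
+    # Returns a single (padded) numpy array since `do_nested_concat=True`.
+    return container.get_arrays()
+
+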
+class SequentialDistributedSampler(Sampler):
+    """
+    Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
+
+    Even though we only use this sampler for eval and predict (no training), which means that the model params won't
+    have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add
+    extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather`
+    or `reduce` resulting tensors at the end of the loop.
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
+        warnings.warn(
+            "SequentialDistributedSampler is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        num_samples = len(self.dataset)
+        # Add extra samples to make num_samples a multiple of batch_size if passed
+        if batch_size is not None:
+            self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
+        else:
+            self.num_samples = int(math.ceil(num_samples / num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        self.batch_size = batch_size
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size, (
+            f"Indices length {len(indices)} and total size {self.total_size} mismatched"
+        )
+
+        # subsample
+        indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
+        assert len(indices) == self.num_samples, (
+            f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
+        )
+
+        return iter(indices)
+
+    def __len__(self):
+        return self.num_samples
+
+
+def get_tpu_sampler(dataset: torch.utils.data.Dataset, batch_size: int):
+    if xr.world_size() <= 1:
+        return RandomSampler(dataset)
+    return DistributedSampler(dataset, num_replicas=xr.world_size(), rank=xr.global_ordinal())
+
+
+def nested_new_like(arrays, num_samples, padding_index=-100):
+    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
+    if isinstance(arrays, (list, tuple)):
+        return type(arrays)(nested_new_like(x, num_samples, padding_index=padding_index) for x in arrays)
+    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
+
+
+def expand_like(arrays, new_seq_length, padding_index=-100):
+    """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
+    result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
+    result[:, : arrays.shape[1]] = arrays
+    return result
+
+
+def nested_truncate(tensors, limit):
+    "Truncate `tensors` at `limit` (even if it's a nested list/tuple/dict of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_truncate(t, limit) for t in tensors)
+    if isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_truncate(t, limit) for k, t in tensors.items()})
+
+    return tensors[:limit]
+
+
+class DistributedTensorGatherer:
+    """
+    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
+
+    If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
+    step, our sampler will generate the following indices:
+
+        `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
+
+    to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and
+    2 will be responsible for making predictions for the following samples:
+
+        - P0: `[0, 1, 2, 3, 4, 5]`
+        - P1: `[6, 7, 8, 9, 10, 11]`
+        - P2: `[12, 13, 14, 15, 0, 1]`
+
+    The first batch treated on each process will be:
+
+        - P0: `[0, 1]`
+        - P1: `[6, 7]`
+        - P2: `[12, 13]`
+
+    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
+    the following indices:
+
+        `[0, 1, 6, 7, 12, 13]`
+
+    If we directly concatenate our results without taking any precautions, the user will then get the predictions for
+    the indices in this order at the end of the prediction loop:
+
+        `[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
+
+    That order does not match the dataset order, which is not what the user expects. This class is there to solve that problem.
+
+    Args:
+        world_size (`int`):
+            The number of processes used in the distributed training.
+        num_samples (`int`):
+            The number of samples in our dataset.
+        make_multiple_of (`int`, *optional*):
+            If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
+            (by adding samples).
+        padding_index (`int`, *optional*, defaults to -100):
+            The padding index to use if the arrays don't all have the same sequence length.
+    """
+
+    def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
+        warnings.warn(
+            "DistributedTensorGatherer is deprecated and will be removed in v5 of Transformers.",
+            FutureWarning,
+        )
+        self.world_size = world_size
+        self.num_samples = num_samples
+        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
+        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
+        self.process_length = self.total_samples // world_size
+        self._storage = None
+        self._offsets = None
+        self.padding_index = padding_index
+
+    def add_arrays(self, arrays):
+        """
+        Add `arrays` to the internal storage. This will initialize the storage to the full size at the first arrays
+        passed so that, if we're bound to get an OOM, it happens at the beginning.
+        """
+        if arrays is None:
+            return
+        if self._storage is None:
+            self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
+            self._offsets = list(range(0, self.total_samples, self.process_length))
+
+        slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
+        for i in range(self.world_size):
+            self._offsets[i] += slice_len
+
+    def _nested_set_tensors(self, storage, arrays):
+        if isinstance(arrays, (list, tuple)):
+            result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
+            return result[0][0], type(arrays)(r[1] for r in result)
+        assert arrays.shape[0] % self.world_size == 0, (
+            f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
+        )
+
+        slice_len = arrays.shape[0] // self.world_size
+        for i in range(self.world_size):
+            if len(arrays.shape) == 1:
+                storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
+            else:
+                # Expand the array on the fly if needed.
+                if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
+                    storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
+                storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
+                    i * slice_len : (i + 1) * slice_len
+                ]
+        return slice_len, storage
+
+    def finalize(self):
+        """
+        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
+        to get each process a dataset of the same length).
+        """
+        if self._storage is None:
+            return
+        if self._offsets[0] != self.process_length:
+            logger.warning("Not all data has been set. Are you sure you passed all values?")
+        return nested_truncate(self._storage, self.num_samples)
+
+
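+# Illustrative usage sketch for `DistributedTensorGatherer`: one gatherer per output, fed after every
+# step with the result of an all-gather, then finalized once the loop is done. `prediction_batches`
+# and `gather_fn` are hypothetical stand-ins (e.g. `gather_fn` could be `distributed_concat`); this
+# helper exists for illustration only.
+def _example_distributed_tensor_gatherer(world_size, num_samples, prediction_batches, gather_fn):
+    gatherer = DistributedTensorGatherer(world_size, num_samples)
+    for step_predictions in prediction_batches:
+        # Gather the step predictions from all processes, then store them at the right offsets.
+        gatherer.add_arrays(nested_numpify(gather_fn(step_predictions)))
+    # Truncates back to `num_samples`, restoring the original dataset order described above.
+    return gatherer.finalize()
+
+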
+@dataclass
+class LabelSmoother:
+    """
+    Adds label-smoothing on a pre-computed output from a Transformers model.
+
+    Args:
+        epsilon (`float`, *optional*, defaults to 0.1):
+            The label smoothing factor.
+        ignore_index (`int`, *optional*, defaults to -100):
+            The index in the labels to ignore when computing the loss.
+    """
+
+    epsilon: float = 0.1
+    ignore_index: int = -100
+
+    def __call__(self, model_output, labels, shift_labels=False):
+        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+        if shift_labels:
+            logits = logits[..., :-1, :].contiguous()
+            labels = labels[..., 1:].contiguous()
+
+        log_probs = -nn.functional.log_softmax(logits, dim=-1)
+        if labels.dim() == log_probs.dim() - 1:
+            labels = labels.unsqueeze(-1)
+
+        padding_mask = labels.eq(self.ignore_index)
+        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
+        # will ignore them in any case.
+        labels = torch.clamp(labels, min=0)
+        nll_loss = log_probs.gather(dim=-1, index=labels)
+        # works for fp16 input tensor too, by internally upcasting it to fp32
+        smoothed_loss = log_probs.sum(dim=-1, keepdim=True, dtype=torch.float32)
+
+        nll_loss.masked_fill_(padding_mask, 0.0)
+        smoothed_loss.masked_fill_(padding_mask, 0.0)
+
+        # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
+        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+        nll_loss = nll_loss.sum() / num_active_elements
+        smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
+        return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
+
+
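+# Illustrative usage sketch for `LabelSmoother`: it is called on the raw model output and the labels.
+# For causal LM objectives, `shift_labels=True` aligns the logits at position t with the label at
+# position t + 1. `model_output` and `labels` are hypothetical tensors supplied by the caller; this
+# helper exists for illustration only.
+def _example_label_smoother(model_output, labels):
+    smoother = LabelSmoother(epsilon=0.1, ignore_index=-100)
+    # loss = (1 - epsilon) * NLL + epsilon * mean negative log-prob over the vocabulary,
+    # with positions labeled `ignore_index` masked out of both terms.
+    return smoother(model_output, labels, shift_labels=True)
+
+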
+def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
+    """
+    Return a list of indices so that each slice of `batch_size` consecutive indices corresponds to elements of similar
+    lengths. To do this, the indices are:
+
+    - randomly permuted
+    - grouped in mega-batches of size `mega_batch_mult * batch_size`
+    - sorted by length in each mega-batch
+
+    The result is the concatenation of all mega-batches, with the batch of `batch_size` containing the element of
+    maximum length placed first, so that an OOM happens sooner rather than later.
+    """
+    # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
+    if mega_batch_mult is None:
+        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
+        # Just in case, for tiny datasets
+        if mega_batch_mult == 0:
+            mega_batch_mult = 1
+
+    # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+    indices = torch.randperm(len(lengths), generator=generator)
+    megabatch_size = mega_batch_mult * batch_size
+    megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+    megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+
+    # The rest is to get the biggest batch first.
+    # Since each megabatch is sorted by descending length, the longest element is the first
+    megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
+    max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
+    # Switch to put the longest element in first position
+    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
+
+    return [i for megabatch in megabatches for i in megabatch]
+
+
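+# Illustrative sketch for `get_length_grouped_indices` on toy data: indices are shuffled, grouped
+# into megabatches, sorted by length inside each megabatch, and the single longest element is
+# swapped into the very first position. The lengths below are made up for demonstration only.
+def _example_length_grouped_indices():
+    lengths = [5, 100, 7, 98, 6, 99, 8, 97]
+    generator = torch.Generator().manual_seed(0)
+    indices = get_length_grouped_indices(lengths, batch_size=2, mega_batch_mult=2, generator=generator)
+    # The very first index points to the longest element so that any OOM shows up as early as possible.
+    return indices
+
+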
+class LengthGroupedSampler(Sampler):
+    r"""
+    Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+    keeping a bit of randomness.
+    """
+
+    def __init__(
+        self,
+        batch_size: int,
+        dataset: Optional[Dataset] = None,
+        lengths: Optional[list[int]] = None,
+        model_input_name: Optional[str] = None,
+        generator=None,
+    ):
+        if dataset is None and lengths is None:
+            raise ValueError("One of dataset and lengths must be provided.")
+
+        self.batch_size = batch_size
+        if lengths is None:
+            model_input_name = model_input_name if model_input_name is not None else "input_ids"
+            if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    f"'{model_input_name}' key."
+                )
+            lengths = [len(feature[model_input_name]) for feature in dataset]
+        elif isinstance(lengths, torch.Tensor):
+            logger.info(
+                "If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to list[int]..."
+            )
+            lengths = lengths.tolist()
+
+        self.lengths = lengths
+        self.generator = generator
+
+    def __len__(self):
+        return len(self.lengths)
+
+    def __iter__(self):
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=self.generator)
+        return iter(indices)
+
+
+class DistributedLengthGroupedSampler(DistributedSampler):
+    r"""
+    Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
+    length while keeping a bit of randomness.
+    """
+
+    # Copied and adapted from PyTorch DistributedSampler.
+    def __init__(
+        self,
+        batch_size: int,
+        dataset: Optional[Dataset] = None,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+        drop_last: bool = False,
+        lengths: Optional[list[int]] = None,
+        model_input_name: Optional[str] = None,
+    ):
+        if dataset is None and lengths is None:
+            raise ValueError("One of dataset and lengths must be provided.")
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+
+        self.batch_size = batch_size
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+
+        if lengths is None:
+            model_input_name = model_input_name if model_input_name is not None else "input_ids"
+            if not isinstance(dataset[0], (dict, BatchEncoding)) or model_input_name not in dataset[0]:
+                raise ValueError(
+                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
+                    f"'{model_input_name}' key."
+                )
+            lengths = [len(feature[model_input_name]) for feature in dataset]
+        elif isinstance(lengths, torch.Tensor):
+            logger.info(
+                "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
+                " list[int]..."
+            )
+            lengths = lengths.tolist()
+
+        self.lengths = lengths
+
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.lengths) % self.num_replicas != 0:
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
+        else:
+            self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
+        self.total_size = self.num_samples * self.num_replicas
+        self.seed = seed
+
+    def __iter__(self) -> Iterator:
+        # Deterministically shuffle based on epoch and seed
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
+
+        if not self.drop_last:
+            # add extra samples to make it evenly divisible
+            indices += indices[: (self.total_size - len(indices))]
+        else:
+            # remove tail of data to make it evenly divisible
+            indices = indices[: self.total_size]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices)
+
+
+class ShardSampler(Sampler):
+    """
+    Sampler that shards batches between several processes. Dispatches indices batch by batch: on 2 processes with batch
+    size 4, the first two batches are `[0, 1, 2, 3, 4, 5, 6, 7]` and `[8, 9, 10, 11, 12, 13, 14, 15]`, which shard into
+    `[0, 1, 2, 3]` and `[8, 9, 10, 11]` for GPU-0 and `[4, 5, 6, 7]` and `[12, 13, 14, 15]` for GPU-1.
+
+    The sampler thus yields `[0, 1, 2, 3, 8, 9, 10, 11]` on GPU-0 and `[4, 5, 6, 7, 12, 13, 14, 15]` on GPU-1.
+    """
+
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int = 1,
+        drop_last: bool = False,
+        num_processes: int = 1,
+        process_index: int = 0,
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.num_processes = num_processes
+        self.process_index = process_index
+
+        self.total_batch_size = total_batch_size = batch_size * num_processes
+
+        num_batches = len(dataset) // total_batch_size if drop_last else math.ceil(len(dataset) / total_batch_size)
+        self.total_num_samples = num_batches * total_batch_size
+
+    def __iter__(self):
+        indices = list(range(len(self.dataset)))
+
+        # Add extra samples to make it evenly divisible. While loop is there in the edge case we have a tiny dataset
+        # and it needs to be done several times.
+        while len(indices) < self.total_num_samples:
+            indices += indices[: (self.total_num_samples - len(indices))]
+
+        result = []
+        for batch_start in range(self.batch_size * self.process_index, self.total_num_samples, self.total_batch_size):
+            result += indices[batch_start : batch_start + self.batch_size]
+
+        return iter(result)
+
+    def __len__(self):
+        # Each shard only sees a fraction of total_num_samples.
+        return self.total_num_samples // self.num_processes
+
+
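+# Illustrative sketch for `ShardSampler`, reproducing the 2-process example from the docstring above
+# on a toy 16-element dataset; this helper exists for illustration only.
+def _example_shard_sampler():
+    dataset = list(range(16))
+    shard0 = list(ShardSampler(dataset, batch_size=4, num_processes=2, process_index=0))
+    shard1 = list(ShardSampler(dataset, batch_size=4, num_processes=2, process_index=1))
+    # shard0 == [0, 1, 2, 3, 8, 9, 10, 11] and shard1 == [4, 5, 6, 7, 12, 13, 14, 15]
+    return shard0, shard1
+
+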
+class IterableDatasetShard(IterableDataset):
+    """
+    Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will
+    always yield a number of samples that is a round multiple of the actual batch size (which is `batch_size x
+    num_processes`). Depending on the value of the `drop_last` attribute, it will either stop the iteration at the
+    first batch that would be too small or loop with indices from the beginning.
+
+    On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]` with a batch size of
+    2:
+
+    - the shard on process 0 will yield `[0, 1, 4, 5, 8, 9]` so will see batches `[0, 1]`, `[4, 5]`, `[8, 9]`
+    - the shard on process 1 will yield `[2, 3, 6, 7, 10, 11]` so will see batches `[2, 3]`, `[6, 7]`, `[10, 11]`
+
+    
+
+        If your IterableDataset implements some randomization that needs to be applied the same way on all processes
+        (for instance, a shuffling), you should use a `torch.Generator` in a `generator` attribute of the `dataset` to
+        generate your random numbers and call the [`~trainer_pt_utils.IterableDatasetShard.set_epoch`] method of this
+        object. It will set the seed of this `generator` to `seed + epoch` on all processes before starting the
+        iteration. Alternatively, you can also implement a `set_epoch()` method in your iterable dataset to deal with
+        this.
+
+    
+
+    Args:
+        dataset (`torch.utils.data.IterableDataset`):
+            The iterable dataset to split in several shards.
+        batch_size (`int`, *optional*, defaults to 1):
+            The size of the batches per shard.
+        drop_last (`bool`, *optional*, defaults to `False`):
+            Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the
+            beginning.
+        num_processes (`int`, *optional*, defaults to 1):
+            The number of processes running concurrently.
+        process_index (`int`, *optional*, defaults to 0):
+            The index of the current process.
+        seed (`int`, *optional*, defaults to 0):
+            A random seed that will be used for the random number generation in
+            [`~trainer_pt_utils.IterableDatasetShard.set_epoch`].
+    """
+
+    def __init__(
+        self,
+        dataset: IterableDataset,
+        batch_size: int = 1,
+        drop_last: bool = False,
+        num_processes: int = 1,
+        process_index: int = 0,
+        seed: int = 0,
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.num_processes = num_processes
+        self.process_index = process_index
+        self.seed = seed
+        self.epoch = 0
+        self.num_examples = 0
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+        if hasattr(self.dataset, "set_epoch"):
+            self.dataset.set_epoch(epoch)
+
+    def __iter__(self):
+        self.num_examples = 0
+        if (
+            not hasattr(self.dataset, "set_epoch")
+            and hasattr(self.dataset, "generator")
+            and isinstance(self.dataset.generator, torch.Generator)
+        ):
+            self.dataset.generator.manual_seed(self.seed + self.epoch)
+        real_batch_size = self.batch_size * self.num_processes
+        process_slice = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)
+
+        first_batch = None
+        current_batch = []
+        for element in self.dataset:
+            self.num_examples += 1
+            current_batch.append(element)
+            # Wait to have a full batch before yielding elements.
+            if len(current_batch) == real_batch_size:
+                for i in process_slice:
+                    yield current_batch[i]
+                if first_batch is None:
+                    first_batch = current_batch.copy()
+                current_batch = []
+
+        # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning.
+        if not self.drop_last and len(current_batch) > 0:
+            if first_batch is None:
+                first_batch = current_batch.copy()
+            while len(current_batch) < real_batch_size:
+                current_batch += first_batch
+            for i in process_slice:
+                yield current_batch[i]
+
+    def __len__(self):
+        # Will raise an error if the underlying dataset is not sized.
+        if self.drop_last:
+            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
+        else:
+            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
+
+
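+# Illustrative sketch for `IterableDatasetShard`, reproducing the 2-process example from the
+# docstring above. A plain `range` stands in for a real `IterableDataset`; only iteration is
+# exercised, so the wrapper's `__len__` is not needed here.
+def _example_iterable_dataset_shard():
+    shard0 = list(IterableDatasetShard(range(12), batch_size=2, num_processes=2, process_index=0))
+    shard1 = list(IterableDatasetShard(range(12), batch_size=2, num_processes=2, process_index=1))
+    # shard0 == [0, 1, 4, 5, 8, 9] and shard1 == [2, 3, 6, 7, 10, 11]
+    return shard0, shard1
+
+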
+# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
+# helper methods here
+
+
+def _get_learning_rate(self):
+    if self.is_deepspeed_enabled:
+        # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
+        # not run for the first few dozen steps while loss scale is too large, and thus during
+        # that time `get_last_lr` will fail if called during that warm up stage, so work around it:
+        try:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+        except AssertionError as e:
+            if "need to call step" in str(e):
+                logger.warning("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
+                last_lr = 0
+            else:
+                raise
+    else:
+        if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+            last_lr = self.optimizer.param_groups[0]["lr"]
+        else:
+            last_lr = self.lr_scheduler.get_last_lr()[0]
+
+    if torch.is_tensor(last_lr):
+        last_lr = last_lr.item()
+    return last_lr
+
+
+def _secs2timedelta(secs):
+    """
+    Convert seconds to the hh:mm:ss.xx format, keeping two decimal places for the fractional seconds.
+    """
+
+    msec = int(abs(secs - int(secs)) * 100)
+    return f"{datetime.timedelta(seconds=int(secs))}.{msec:02d}"
+
+
+def metrics_format(self, metrics: dict[str, float]) -> dict[str, float]:
+    """
+    Reformat Trainer metrics values to a human-readable format.
+
+    Args:
+        metrics (`dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+
+    Returns:
+        metrics (`dict[str, float]`): The reformatted metrics
+    """
+
+    metrics_copy = metrics.copy()
+    for k, v in metrics_copy.items():
+        if "_mem_" in k:
+            metrics_copy[k] = f"{v >> 20}MB"
+        elif "_runtime" in k:
+            metrics_copy[k] = _secs2timedelta(v)
+        elif k == "total_flos":
+            metrics_copy[k] = f"{int(v) >> 30}GF"
+        elif isinstance(metrics_copy[k], float):
+            metrics_copy[k] = round(v, 4)
+
+    return metrics_copy
+
+
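+# Illustrative sketch of the unit conversions used in `metrics_format` above: `>> 20` turns bytes
+# into MiB and `>> 30` turns a raw FLO count into "GF" (2**30 floating-point operations). The values
+# below are made up for demonstration only.
+def _example_metrics_format_units():
+    memory_bytes = 1_363_148_800
+    total_flos = 3.2e12
+    return f"{memory_bytes >> 20}MB", f"{int(total_flos) >> 30}GF"  # ("1300MB", "2980GF")
+
+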
+def log_metrics(self, split, metrics):
+    """
+    Log metrics in a specially formatted way.
+
+    Under distributed environment this is done only for a process with rank 0.
+
+    Args:
+        split (`str`):
+            Mode/split name: one of `train`, `eval`, `test`
+        metrics (`dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+
+    Notes on memory reports:
+
+    In order to get a memory usage report you need to install `psutil`. You can do that with `pip install psutil`.
+
+    Now when this method is run, you will see a report that will include:
+
+    ```
+    init_mem_cpu_alloc_delta   =     1301MB
+    init_mem_cpu_peaked_delta  =      154MB
+    init_mem_gpu_alloc_delta   =      230MB
+    init_mem_gpu_peaked_delta  =        0MB
+    train_mem_cpu_alloc_delta  =     1345MB
+    train_mem_cpu_peaked_delta =        0MB
+    train_mem_gpu_alloc_delta  =      693MB
+    train_mem_gpu_peaked_delta =        7MB
+    ```
+
+    **Understanding the reports:**
+
+    - the first segment, e.g. `train_`, tells you which stage the metrics are for. Reports starting with `init_`
+        will be added to the first stage that gets run. So if only evaluation is run, the memory usage for
+        `__init__` will be reported along with the `eval_` metrics.
+    - the third segment is either `cpu` or `gpu` and tells you whether it's the general RAM or the gpu0 memory
+        metric.
+    - `*_alloc_delta` - is the difference in the used/allocated memory counter between the end and the start of the
+        stage - it can be negative if a function released more memory than it allocated.
+    - `*_peaked_delta` - is any extra memory that was consumed and then freed - relative to the current allocated
+        memory counter - it is never negative. When you look at the metrics of any stage you add up `alloc_delta` +
+        `peaked_delta` and you know how much memory was needed to complete that stage.
+
+    The reporting happens only for process of rank 0 and gpu 0 (if there is a gpu). Typically this is enough since the
+    main process does the bulk of work, but it could be not quite so if model parallel is used and then other GPUs may
+    use a different amount of gpu memory. This is also not the same under DataParallel where gpu0 may require much more
+    memory than the rest since it stores the gradient and optimizer states for all participating GPUs. Perhaps in the
+    future these reports will evolve to measure those too.
+
+    The CPU RAM metric measures RSS (Resident Set Size), which includes both the memory unique to the process and the
+    memory shared with other processes. It is important to note that it does not include swapped out memory, so the
+    reports could be imprecise.
+
+    The CPU peak memory is measured using a sampling thread. Due to python's GIL it may miss some of the peak memory if
+    that thread didn't get a chance to run when the highest memory was used. Therefore this report can be less than
+    reality. Using `tracemalloc` would have reported the exact peak memory, but it doesn't report memory allocations
+    outside of python. So if some C++ CUDA extension allocated its own memory it won't be reported. And therefore it
+    was dropped in favor of the memory sampling approach, which reads the current process memory usage.
+
+    The GPU allocated and peak memory reporting is done with `torch.cuda.memory_allocated()` and
+    `torch.cuda.max_memory_allocated()`. This metric reports only "deltas" for pytorch-specific allocations, as
+    `torch.cuda` memory management system doesn't track any memory allocated outside of pytorch. For example, the very
+    first cuda call typically loads CUDA kernels, which may take from 0.5 to 2GB of GPU memory.
+
+    Note that this tracker doesn't account for memory allocations outside of [`Trainer`]'s `__init__`, `train`,
+    `evaluate` and `predict` calls.
+
+    Because `evaluate` calls may happen during `train`, we can't handle nested invocations:
+    `torch.cuda.max_memory_allocated` is a single counter, so if it gets reset by a nested eval call, `train`'s tracker
+    will report incorrect info. If this [pytorch issue](https://github.com/pytorch/pytorch/issues/16266) gets resolved
+    it will be possible to change this class to be re-entrant. Until then we will only track the outer level of
+    `train`, `evaluate` and `predict` methods. Which means that if `eval` is called during `train`, it's the latter
+    that will account for its memory usage and that of the former.
+
+    This also means that if any other tool that is used along the [`Trainer`] calls
+    `torch.cuda.reset_peak_memory_stats`, the gpu peak memory stats could be invalid. And the [`Trainer`] will disrupt
+    the normal behavior of any such tools that rely on calling `torch.cuda.reset_peak_memory_stats` themselves.
+
+    For best performance you may want to consider turning the memory profiling off for production runs.
+    """
+    if not self.is_world_process_zero():
+        return
+
+    print(f"***** {split} metrics *****")
+    metrics_formatted = self.metrics_format(metrics)
+    k_width = max(len(str(x)) for x in metrics_formatted)
+    v_width = max(len(str(x)) for x in metrics_formatted.values())
+    for key in sorted(metrics_formatted.keys()):
+        print(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
+
+
+def save_metrics(self, split, metrics, combined=True):
+    """
+    Save metrics into a json file for that split, e.g. `train_results.json`.
+
+    Under distributed environment this is done only for a process with rank 0.
+
+    Args:
+        split (`str`):
+            Mode/split name: one of `train`, `eval`, `test`, `all`
+        metrics (`dict[str, float]`):
+            The metrics returned from train/evaluate/predict
+        combined (`bool`, *optional*, defaults to `True`):
+            Creates combined metrics by updating `all_results.json` with metrics of this call
+
+    To understand the metrics please read the docstring of [`~Trainer.log_metrics`]. The only difference is that raw
+    unformatted numbers are saved in the current method.
+
+    """
+    if not self.is_world_process_zero():
+        return
+
+    path = os.path.join(self.args.output_dir, f"{split}_results.json")
+    with open(path, "w") as f:
+        json.dump(metrics, f, indent=4, sort_keys=True)
+
+    if combined:
+        path = os.path.join(self.args.output_dir, "all_results.json")
+        if os.path.exists(path):
+            with open(path) as f:
+                all_metrics = json.load(f)
+        else:
+            all_metrics = {}
+
+        all_metrics.update(metrics)
+        with open(path, "w") as f:
+            json.dump(all_metrics, f, indent=4, sort_keys=True)
+
+
+def save_state(self):
+    """
+    Saves the Trainer state, since Trainer.save_model saves only the tokenizer with the model.
+
+    Under distributed environment this is done only for a process with rank 0.
+    """
+    if not self.is_world_process_zero():
+        return
+
+    path = os.path.join(self.args.output_dir, "trainer_state.json")
+    self.state.save_to_json(path)
+
+
+def get_model_param_count(model, trainable_only=False):
+    """
+    Calculate model's total param count. If trainable_only is True then count only those requiring grads.
+    """
+    if is_deepspeed_zero3_enabled():
+
+        def numel(p):
+            return p.ds_numel if hasattr(p, "ds_numel") else p.numel()
+
+    else:
+
+        def numel(p):
+            return p.numel()
+
+    return sum(numel(p) for p in model.parameters() if not trainable_only or p.requires_grad)
+
+
+def get_parameter_names(model, forbidden_layer_types, forbidden_layer_names=None):
+    """
+    Returns the names of the model parameters that are not inside a forbidden layer.
+    """
+    forbidden_layer_patterns = (
+        [re.compile(pattern) for pattern in forbidden_layer_names] if forbidden_layer_names is not None else []
+    )
+    result = []
+    for name, child in model.named_children():
+        child_params = get_parameter_names(child, forbidden_layer_types, forbidden_layer_names)
+        result += [
+            f"{name}.{n}"
+            for n in child_params
+            if not isinstance(child, tuple(forbidden_layer_types))
+            and not any(pattern.search(f"{name}.{n}".lower()) for pattern in forbidden_layer_patterns)
+        ]
+    # Add model specific parameters that are not in any child
+    result += [
+        k for k in model._parameters if not any(pattern.search(k.lower()) for pattern in forbidden_layer_patterns)
+    ]
+
+    return result
+
+
+def get_module_class_from_name(module, name):
+    """
+    Gets a class from a module by its name.
+
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+    elif len(modules_children) == 0:
+        return
+    else:
+        for child_module in modules_children:
+            module_class = get_module_class_from_name(child_module, name)
+            if module_class is not None:
+                return module_class
+
+
+def remove_dummy_checkpoint(is_main_process, output_dir, filenames):
+    if is_main_process:
+        for filename in filenames:
+            file = os.path.join(output_dir, filename)
+            if os.path.isfile(file):
+                os.remove(file)
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    @smp.step()
+    def smp_forward_backward(model, inputs, gradient_accumulation_steps=1):
+        outputs = model(**inputs)
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+        loss /= gradient_accumulation_steps
+        model.backward(loss)
+        return loss
+
+    @smp.step()
+    def smp_forward_only(model, inputs):
+        return model(**inputs)
+
+    def smp_gather(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_gather(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_gather(v) for k, v in tensor.items()})
+        elif not isinstance(tensor, torch.Tensor):
+            raise TypeError(
+                f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+            )
+        all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        all_tensors = [atleast_1d(t) for t in all_tensors]
+        return torch.cat([t.cpu() for t in all_tensors], dim=0)
+
+    def smp_nested_concat(tensor):
+        if isinstance(tensor, (list, tuple)):
+            return type(tensor)(smp_nested_concat(t) for t in tensor)
+        elif isinstance(tensor, dict):
+            return type(tensor)({k: smp_nested_concat(v) for k, v in tensor.items()})
+        # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
+        # which is also the name of the decorator so Python is confused.
+        return tensor.detach().concat().cpu()
+
+
+@dataclass
+class AcceleratorConfig:
+    """
+    A subset of arguments relating to the underlying [`accelerate.Accelerator`]
+    implementation utilized in the `Trainer` that can be customized, mostly relating to data.
+
+    Parameters:
+        split_batches (`bool`, *optional*, defaults to `False`):
+            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
+            `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
+            round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
+            in your script multiplied by the number of processes.
+        dispatch_batches (`bool`, *optional*):
+            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
+            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
+            underlying dataset is an `IterableDataset`, `False` otherwise.
+        even_batches (`bool`, *optional*, defaults to `True`):
+            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+            all workers.
+        use_seedable_sampler (`bool`, *optional*, defaults to `True`):
+            Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
+            training results are fully reproducible using a different sampling technique. While seed-to-seed results
+            may differ, on average the differences are negligible when using multiple different seeds to compare. Should
+            also be run with [`~utils.set_seed`] for the best results.
+        gradient_accumulation_kwargs (`dict`, *optional*):
+            Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
+            Any of the following (optional) keys are acceptable:
+              num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if
+                the latter is set to 1, otherwise an exception will be raised.
+              adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`].
+                The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`.
+              sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch.
+                The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`.
+        non_blocking (`bool`, *optional*, defaults to `False`):
+            Whether to use non-blocking CUDA calls to help minimize synchronization during
+            distributed training with prepared `DataLoader` inputs being moved to device.
+            Best if used with `pin_memory=True` in the `TrainingArguments`.
+        use_configured_state (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
+            before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
+            must be initialized. May lead to issues using sweeps or hyperparameter tuning.
+
+    """
+
+    # Data related arguments
+    split_batches: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
+            " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
+            " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
+            " in your script multiplied by the number of processes."
+        },
+    )
+    dispatch_batches: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
+            " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
+            " underlying dataset is an `IterableDataslet`, `False` otherwise."
+        },
+    )
+    even_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
+            " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
+            " all workers."
+        },
+    )
+    use_seedable_sampler: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
+            "Ensures training results are fully reproducible using a different sampling technique. "
+            "While seed-to-seed results may differ, on average the differences are negligible when using"
+            "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
+        },
+    )
+
+    non_blocking: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to use non-blocking CUDA calls to help minimize synchronization during "
+            "distributed training with prepared `DataLoader` inputs being moved to device. "
+            "Best if used with `pin_memory=True` in the `TrainingArguments`. Requires accelerate "
+            "v0.30.0."
+        },
+    )
+
+    gradient_accumulation_kwargs: Optional[dict] = field(
+        default=None,
+        metadata={
+            "help": "Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`]. "
+            "Any of the following (optional) keys are acceptable: "
+            "  num_steps (`int`): Will take precedence over [`~.TrainingArguments.gradient_accumulation_steps`] if "
+            "    the latter is set to 1, otherwise an exception will be raised. "
+            "  adjust_scheduler (`bool`): Whether to adjust the scheduler steps to account for [`~.TrainingArguments.gradient_accumulation_steps`]. "
+            "    The [`accelerate.utils.GradientAccumulationPlugin`] default is `True`. "
+            "  sync_each_batch (`bool`): Whether to synchronize the gradients at each data batch. "
+            "    The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
+        },
+    )
+    use_configured_state: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`."
+            "If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues using sweeps or hyperparameter tuning."
+        },
+    )
+
+    @classmethod
+    def from_json_file(cls, json_file):
+        # Check if exists
+        open_file = io.open if os.path.exists(json_file) else open
+        with open_file(json_file, "r", encoding="utf-8") as f:
+            config_dict = json.load(f)
+        # Check for keys and load sensible defaults
+        extra_keys = sorted(key for key in config_dict if key not in cls.__dataclass_fields__)
+        if len(extra_keys) > 0:
+            raise ValueError(
+                f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
+                " version or fix (and potentially remove these keys) from your config file."
+            )
+        return cls(**config_dict)
+
+    def to_dict(self):
+        return copy.deepcopy(self.__dict__)
+
+    def pop(self, key, default=None):
+        return self.__dict__.pop(key, default)
+
+
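+# Illustrative sketch for `AcceleratorConfig.from_json_file`: the file is expected to hold a JSON
+# object whose keys match the dataclass fields above (unknown keys raise a ValueError). The path is
+# hypothetical and this helper exists for illustration only.
+def _example_accelerator_config_roundtrip(path="accelerator_config.json"):
+    config = AcceleratorConfig(split_batches=True, non_blocking=True)
+    with open(path, "w", encoding="utf-8") as f:
+        json.dump(config.to_dict(), f, indent=2)
+    return AcceleratorConfig.from_json_file(path)
+
+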
+class LayerWiseDummyOptimizer(torch.optim.Optimizer):
+    """
+    For layer-wise optimizers such as the GaLore optimizer, the optimization
+    step is already done through the post gradient hooks. Therefore
+    the trick is to create a dummy optimizer that can take arbitrary
+    args and kwargs and return a no-op during training.
+
+    Initial idea from @hiyouga in LLaMA-Factory:
+    https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+    """
+
+    def __init__(self, optimizer_dict=None, *args, **kwargs):
+        dummy_tensor = torch.randn(1, 1)
+        self.optimizer_dict = optimizer_dict
+        super().__init__([dummy_tensor], {"lr": kwargs.get("lr", 1e-03)})
+
+    def zero_grad(self, set_to_none: bool = True) -> None:
+        pass
+
+    def step(self, closure=None) -> Optional[float]:
+        pass
+
+
+class LayerWiseDummyScheduler(LRScheduler):
+    """
+    For layer-wise optimizers such as the GaLore optimizer, the optimization and scheduling step
+    are already done through the post gradient hooks. Therefore
+    the trick is to create a dummy scheduler that can take arbitrary
+    args and kwargs and return a no-op during training.
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.default_lr = kwargs["lr"]
+        optimizer = LayerWiseDummyOptimizer(**kwargs)
+        last_epoch = -1
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        # default value
+        lrs = [self.default_lr]
+
+        # Take each lr from the parameter groups if they exist; this assumes the optimizer is a `LayerWiseDummyOptimizer`.
+        if self.optimizer is not None:
+            param_wise_lrs = [
+                [group["lr"] for group in optim.param_groups] for optim in self.optimizer.optimizer_dict.values()
+            ]
+            lrs = list(chain(*param_wise_lrs))
+
+        return lrs
+
+    def _get_closed_form_lr(self):
+        return self.base_lrs
+
+
+def set_rng_state_for_device(device_name, device_module, checkpoint_rng_state, is_distributed):
+    """Helper to set RNG state for a specific device type (CUDA, NPU, MLU, MUSA)"""
+    device_state_key = device_name.lower()
+    err_template = "Didn't manage to set back the RNG states of the {backend} because of the following error:\n {exception}\nThis won't yield the same results as if the training had not been interrupted."
+    try:
+        if is_distributed:
+            device_module.random.set_rng_state_all(checkpoint_rng_state[device_state_key])
+        else:
+            device_module.random.set_rng_state(checkpoint_rng_state[device_state_key])
+    except Exception as e:
+        # Log error if setting RNG state fails
+        logger.error(err_template.format(backend=device_name, exception=e))
diff --git a/phivenv/Lib/site-packages/transformers/trainer_seq2seq.py b/phivenv/Lib/site-packages/transformers/trainer_seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbcad1f9de3cd2842e36033746efebf5314ac83
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/trainer_seq2seq.py
@@ -0,0 +1,386 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+from copy import deepcopy
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+
+import torch
+from torch import nn
+from torch.distributed.fsdp import FullyShardedDataParallel
+from torch.utils.data import Dataset
+
+from .generation.configuration_utils import GenerationConfig
+from .integrations.deepspeed import is_deepspeed_zero3_enabled
+from .integrations.fsdp import is_fsdp_managed_module
+from .trainer import Trainer
+from .utils import is_datasets_available, logging
+from .utils.deprecation import deprecate_kwarg
+
+
+if is_datasets_available():
+    import datasets
+
+if TYPE_CHECKING:
+    from torch.utils.data import IterableDataset
+
+    from .data.data_collator import DataCollator
+    from .feature_extraction_utils import FeatureExtractionMixin
+    from .image_processing_utils import BaseImageProcessor
+    from .modeling_utils import PreTrainedModel
+    from .processing_utils import ProcessorMixin
+    from .tokenization_utils_base import PreTrainedTokenizerBase
+    from .trainer_callback import TrainerCallback
+    from .trainer_utils import EvalPrediction, PredictionOutput
+    from .training_args import TrainingArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class Seq2SeqTrainer(Trainer):
+    @deprecate_kwarg("tokenizer", new_name="processing_class", version="5.0.0", raise_if_both_names=True)
+    def __init__(
+        self,
+        model: Union["PreTrainedModel", nn.Module] = None,
+        args: "TrainingArguments" = None,
+        data_collator: Optional["DataCollator"] = None,
+        train_dataset: Optional[Union[Dataset, "IterableDataset", "datasets.Dataset"]] = None,
+        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
+        processing_class: Optional[
+            Union["PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"]
+        ] = None,
+        model_init: Optional[Callable[[], "PreTrainedModel"]] = None,
+        compute_loss_func: Optional[Callable] = None,
+        compute_metrics: Optional[Callable[["EvalPrediction"], dict]] = None,
+        callbacks: Optional[list["TrainerCallback"]] = None,
+        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+    ):
+        super().__init__(
+            model=model,
+            args=args,
+            data_collator=data_collator,
+            train_dataset=train_dataset,
+            eval_dataset=eval_dataset,
+            processing_class=processing_class,
+            model_init=model_init,
+            compute_loss_func=compute_loss_func,
+            compute_metrics=compute_metrics,
+            callbacks=callbacks,
+            optimizers=optimizers,
+            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
+        )
+
+        # Override self.model.generation_config if a GenerationConfig is specified in args.
+        # Priority: args.generation_config > model.generation_config > default GenerationConfig.
+        if self.args.generation_config is not None:
+            gen_config = self.load_generation_config(self.args.generation_config)
+            self.model.generation_config = gen_config
+
+    @staticmethod
+    def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig:
+        """
+        Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` argument.
+
+        Args:
+            gen_config_arg (`str` or [`~generation.GenerationConfig`]):
+                `Seq2SeqTrainingArguments.generation_config` argument.
+
+        Returns:
+            A `~generation.GenerationConfig`.
+        """
+
+        # GenerationConfig provided, nothing to do
+        if isinstance(gen_config_arg, GenerationConfig):
+            gen_config = deepcopy(gen_config_arg)
+        else:
+            # str or Path
+            pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
+            config_file_name = None
+
+            # Figure out whether it points to a file, to a directory, or is a model id / URL
+            # This step is required in order to determine config_file_name
+            if pretrained_model_name.is_file():
+                config_file_name = pretrained_model_name.name
+                pretrained_model_name = pretrained_model_name.parent
+            # dir path
+            elif pretrained_model_name.is_dir():
+                pass
+            # model id or URL
+            else:
+                pretrained_model_name = gen_config_arg
+
+            gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
+
+        # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
+        # an exception if there are warnings at validation time.
+        try:
+            gen_config.validate(strict=True)
+        except ValueError as exc:
+            raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
+
+        return gen_config
+
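+    # A minimal usage sketch (paths below are placeholders): `load_generation_config` accepts either a
+    # ready-made `GenerationConfig` or a string/`Path` pointing at a file, a directory, or a model id,
+    # and always validates the result strictly before returning it:
+    #
+    #   gen_cfg = Seq2SeqTrainer.load_generation_config(GenerationConfig(max_new_tokens=64))
+    #   gen_cfg = Seq2SeqTrainer.load_generation_config("path/to/checkpoint_dir")
+    #   gen_cfg = Seq2SeqTrainer.load_generation_config("path/to/checkpoint_dir/generation_config.json")
+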
+    def evaluate(
+        self,
+        eval_dataset: Optional[Dataset] = None,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "eval",
+        **gen_kwargs,
+    ) -> dict[str, float]:
+        """
+        Run evaluation and return metrics.
+
+        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+        (pass it to the init `compute_metrics` argument).
+
+        You can also subclass and override this method to inject custom behavior.
+
+        Args:
+            eval_dataset (`Dataset`, *optional*):
+                Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
+                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+                method.
+            ignore_keys (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "eval_bleu" if the prefix is `"eval"` (default)
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
+        Returns:
+            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+            dictionary also contains the epoch number which comes from the training state.
+        """
+
+        gen_kwargs = gen_kwargs.copy()
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
+        # We don't want to drop samples in general
+        self.gather_function = self.accelerator.gather
+        self._gen_kwargs = gen_kwargs
+        return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
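+    # Illustrative call (names are placeholders): extra keyword arguments are forwarded to `generate`,
+    # and `args.generation_max_length` / `args.generation_num_beams` only act as fallbacks when
+    # `max_length`/`max_new_tokens` and `num_beams` are not passed explicitly:
+    #
+    #   metrics = trainer.evaluate(max_new_tokens=64, num_beams=4)
+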
+    def predict(
+        self,
+        test_dataset: Dataset,
+        ignore_keys: Optional[list[str]] = None,
+        metric_key_prefix: str = "test",
+        **gen_kwargs,
+    ) -> "PredictionOutput":
+        """
+        Run prediction and return predictions and potential metrics.
+
+        Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+        will also return metrics, like in `evaluate()`.
+
+        Args:
+            test_dataset (`Dataset`):
+                Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
+                `model.forward()` method are automatically removed. Has to implement the method `__len__`
+            ignore_keys (`list[str]`, *optional*):
+                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+                gathering predictions.
+            metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
+                "test_bleu" if the prefix is `"test"` (default)
+            max_length (`int`, *optional*):
+                The maximum target length to use when predicting with the generate method.
+            num_beams (`int`, *optional*):
+                Number of beams for beam search that will be used when predicting with the generate method. 1 means no
+                beam search.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
+        <Tip>
+
+        If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
+        padding in a token classification task) the predictions will be padded (on the right) to allow for
+        concatenation into one array. The padding index is -100.
+
+        </Tip>
+
+        Returns: *NamedTuple* A namedtuple with the following keys:
+
+            - predictions (`np.ndarray`): The predictions on `test_dataset`.
+            - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+            - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+              labels).
+        """
+
+        gen_kwargs = gen_kwargs.copy()
+
+        # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
+        # training args
+        if (
+            gen_kwargs.get("max_length") is None
+            and gen_kwargs.get("max_new_tokens") is None
+            and self.args.generation_max_length is not None
+        ):
+            gen_kwargs["max_length"] = self.args.generation_max_length
+        if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
+            gen_kwargs["num_beams"] = self.args.generation_num_beams
+        self.gather_function = self.accelerator.gather
+        self._gen_kwargs = gen_kwargs
+
+        return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
+
+    def prediction_step(
+        self,
+        model: nn.Module,
+        inputs: dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[list[str]] = None,
+        **gen_kwargs,
+    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Perform an evaluation step on `model` using `inputs`.
+
+        Subclass and override to inject custom behavior.
+
+        Args:
+            model (`nn.Module`):
+                The model to evaluate.
+            inputs (`dict[str, Union[torch.Tensor, Any]]`):
+                The inputs and targets of the model.
+
+                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                argument `labels`. Check your model's documentation for all accepted arguments.
+            prediction_loss_only (`bool`):
+                Whether or not to return the loss only.
+            gen_kwargs:
+                Additional `generate` specific kwargs.
+
+        Return:
+            tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
+            labels (each being optional).
+        """
+
+        if not self.args.predict_with_generate or prediction_loss_only:
+            return super().prediction_step(
+                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
+            )
+
+        has_labels = "labels" in inputs
+        inputs = self._prepare_inputs(inputs)
+
+        # Priority (handled in generate):
+        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
+        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
+            gen_kwargs = self._gen_kwargs.copy()
+        if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
+            gen_kwargs.pop("num_beams")
+        if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
+            gen_kwargs.pop("max_length")
+
+        default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
+        gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)
+
+        generation_inputs = inputs.copy()
+        # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
+        # (otherwise, it would continue generating from the padded `decoder_input_ids`)
+        if (
+            "labels" in generation_inputs
+            and "decoder_input_ids" in generation_inputs
+            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
+        ):
+            generation_inputs = {
+                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
+            }
+
+        summon_full_params_context = (
+            FullyShardedDataParallel.summon_full_params(self.model)
+            if isinstance(self.model, FullyShardedDataParallel)
+            else contextlib.nullcontext()
+        )
+
+        with summon_full_params_context:
+            generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
+
+        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
+        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
+        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
+        if self.model.generation_config._from_model_config:
+            self.model.generation_config._from_model_config = False
+
+        # Retrieves GenerationConfig from model.generation_config
+        gen_config = self.model.generation_config
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < gen_config.max_length:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
+        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
+            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
+
+        with torch.no_grad():
+            if has_labels:
+                with self.compute_loss_context_manager():
+                    outputs = model(**inputs)
+                if self.label_smoother is not None:
+                    loss = self.label_smoother(outputs, inputs["labels"]).detach().mean()
+                else:
+                    loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).detach().mean()
+            else:
+                loss = None
+
+        if self.args.prediction_loss_only:
+            return loss, None, None
+
+        if has_labels:
+            labels = inputs["labels"]
+            if labels.shape[-1] < gen_config.max_length:
+                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
+            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
+                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
+        else:
+            labels = None
+
+        return loss, generated_tokens, labels
+
+    def _pad_tensors_to_max_len(self, tensor, max_length):
+        if self.processing_class is not None and hasattr(self.processing_class, "pad_token_id"):
+            # If PAD token is not defined at least EOS token has to be defined
+            pad_token_id = (
+                self.processing_class.pad_token_id
+                if self.processing_class.pad_token_id is not None
+                else self.processing_class.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+
+        padded_tensor = pad_token_id * torch.ones(
+            (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
+        )
+        padded_tensor[:, : tensor.shape[-1]] = tensor
+        return padded_tensor
diff --git a/phivenv/Lib/site-packages/transformers/trainer_utils.py b/phivenv/Lib/site-packages/transformers/trainer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f09617dc8bfdfd5407b0b8c056d87301813187d
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/trainer_utils.py
@@ -0,0 +1,910 @@
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PyTorch-independent utilities for the Trainer class.
+"""
+
+import copy
+import functools
+import gc
+import inspect
+import os
+import random
+import re
+import threading
+import time
+from typing import Any, NamedTuple, Optional, Union
+
+import numpy as np
+
+from .utils import (
+    ExplicitEnum,
+    is_psutil_available,
+    is_tf_available,
+    is_torch_available,
+    is_torch_cuda_available,
+    is_torch_hpu_available,
+    is_torch_mlu_available,
+    is_torch_mps_available,
+    is_torch_musa_available,
+    is_torch_npu_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    requires_backends,
+)
+
+
+if is_torch_available():
+    import torch
+
+
+def seed_worker(worker_id: int, num_workers: int, rank: int):
+    """
+    Helper function to set worker seed during Dataloader initialization.
+    """
+    init_seed = torch.initial_seed() % 2**32
+    worker_seed = num_workers * rank + init_seed
+    set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int, warn_only: bool = False):
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+    """
+    # set seed first
+    set_seed(seed)
+
+    if is_torch_available():
+        # Enable PyTorch deterministic mode. This potentially requires either the environment
+        # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+        # depending on the CUDA version, so we set them both here
+        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        # The environment variable required to enable deterministic mode on Ascend NPUs.
+        os.environ["ASCEND_LAUNCH_BLOCKING"] = "1"
+        os.environ["HCCL_DETERMINISTIC"] = "1"
+
+        os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
+        torch.use_deterministic_algorithms(True, warn_only=warn_only)
+
+        # Enable CUDNN deterministic mode
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    if is_tf_available():
+        import tensorflow as tf
+
+        tf.config.experimental.enable_op_determinism()
+
+
+def set_seed(seed: int, deterministic: bool = False):
+    """
+    Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
+
+    Args:
+        seed (`int`):
+            The seed to set.
+        deterministic (`bool`, *optional*, defaults to `False`):
+            Whether to use deterministic algorithms where available. Can slow down training.
+    """
+    random.seed(seed)
+    np.random.seed(seed)
+    if is_torch_available():
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        # ^^ safe to call this function even if cuda is not available
+        if deterministic:
+            torch.use_deterministic_algorithms(True)
+    if is_torch_mlu_available():
+        torch.mlu.manual_seed_all(seed)
+    if is_torch_musa_available():
+        torch.musa.manual_seed_all(seed)
+    if is_torch_npu_available():
+        torch.npu.manual_seed_all(seed)
+    if is_torch_hpu_available():
+        torch.hpu.manual_seed_all(seed)
+    if is_torch_xpu_available():
+        torch.xpu.manual_seed_all(seed)
+    if is_tf_available():
+        import tensorflow as tf
+
+        tf.random.set_seed(seed)
+        if deterministic:
+            tf.config.experimental.enable_op_determinism()
+
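+# A small reproducibility sketch (assuming the usual pattern of seeding before model creation):
+#
+#   set_seed(42)                      # seeds python, numpy and every available torch backend
+#   set_seed(42, deterministic=True)  # additionally forces deterministic torch algorithms (slower)
+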
+
+def neftune_post_forward_hook(module, input, output):
+    """
+    Implements the NEFTune forward pass for the model using forward hooks. Note this works only for torch.nn.Embedding
+    layers. This method is slightly adapted from the original source code that can be found here:
+    https://github.com/neelsjain/NEFTune. Simply add it to your model as follows:
+    ```python
+    model = ...
+    model.embed_tokens.neftune_noise_alpha = 0.1
+    model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
+    ```
+    Args:
+        module (`torch.nn.Module`):
+            The embedding module where the hook is attached. Note that you need to set `module.neftune_noise_alpha` to
+            the desired noise alpha value.
+        input (`torch.Tensor`):
+            The input tensor to the model.
+        output (`torch.Tensor`):
+            The output tensor of the model (i.e. the embeddings).
+    """
+    if module.training:
+        dims = torch.tensor(output.size(1) * output.size(2))
+        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
+        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
+    return output
+
+
+class EvalPrediction:
+    """
+    Evaluation output (always contains labels), to be used to compute metrics.
+
+    Parameters:
+        predictions (`np.ndarray`): Predictions of the model.
+        label_ids (`np.ndarray`): Targets to be matched.
+        inputs (`np.ndarray`, *optional*): Input data passed to the model.
+        losses (`np.ndarray`, *optional*): Loss values computed during evaluation.
+    """
+
+    def __init__(
+        self,
+        predictions: Union[np.ndarray, tuple[np.ndarray]],
+        label_ids: Union[np.ndarray, tuple[np.ndarray]],
+        inputs: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
+        losses: Optional[Union[np.ndarray, tuple[np.ndarray]]] = None,
+    ):
+        self.predictions = predictions
+        self.label_ids = label_ids
+        self.inputs = inputs
+        self.losses = losses
+        self.elements = (self.predictions, self.label_ids)
+        if self.inputs is not None:
+            self.elements += (self.inputs,)
+        if self.losses is not None:
+            self.elements += (self.losses,)
+
+    def __iter__(self):
+        return iter(self.elements)
+
+    def __getitem__(self, idx):
+        if idx < 0 or idx >= len(self.elements):
+            raise IndexError("tuple index out of range")
+        return self.elements[idx]
+
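+# Usage sketch (array names are placeholders): `EvalPrediction` behaves like a tuple whose length
+# depends on which optional fields were provided:
+#
+#   ep = EvalPrediction(predictions=logits, label_ids=labels, losses=batch_losses)
+#   preds, label_ids, losses = ep      # `inputs` would come third if it had been passed
+#   assert ep[0] is ep.predictions
+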
+
+class EvalLoopOutput(NamedTuple):
+    predictions: Union[np.ndarray, tuple[np.ndarray]]
+    label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
+    metrics: Optional[dict[str, float]]
+    num_samples: Optional[int]
+
+
+class PredictionOutput(NamedTuple):
+    predictions: Union[np.ndarray, tuple[np.ndarray]]
+    label_ids: Optional[Union[np.ndarray, tuple[np.ndarray]]]
+    metrics: Optional[dict[str, float]]
+
+
+class TrainOutput(NamedTuple):
+    global_step: int
+    training_loss: float
+    metrics: dict[str, float]
+
+
+PREFIX_CHECKPOINT_DIR = "checkpoint"
+_re_checkpoint = re.compile(r"^" + PREFIX_CHECKPOINT_DIR + r"\-(\d+)$")
+
+
+def get_last_checkpoint(folder):
+    content = os.listdir(folder)
+    checkpoints = [
+        path
+        for path in content
+        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))
+    ]
+    if len(checkpoints) == 0:
+        return
+    return os.path.join(folder, max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])))
+
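+# Behaviour sketch (directory names are placeholders): only sub-directories named `checkpoint-<step>`
+# are considered, and the one with the highest step wins:
+#
+#   # output_dir/ contains checkpoint-500/, checkpoint-1000/ and logs/
+#   get_last_checkpoint("output_dir")   # -> "output_dir/checkpoint-1000"
+#   get_last_checkpoint("empty_dir")    # -> None
+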
+
+class IntervalStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+
+
+class SaveStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+    BEST = "best"
+
+
+class EvaluationStrategy(ExplicitEnum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+
+
+class HubStrategy(ExplicitEnum):
+    END = "end"
+    EVERY_SAVE = "every_save"
+    CHECKPOINT = "checkpoint"
+    ALL_CHECKPOINTS = "all_checkpoints"
+
+
+class BestRun(NamedTuple):
+    """
+    The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).
+
+    Parameters:
+        run_id (`str`):
+            The id of the best run (if models were saved, the corresponding checkpoint will be in the folder ending
+            with run-{run_id}).
+        objective (`float`):
+            The objective that was obtained for this run.
+        hyperparameters (`dict[str, Any]`):
+            The hyperparameters picked to get this run.
+        run_summary (`Optional[Any]`):
+            A summary of tuning experiments. `ray.tune.ExperimentAnalysis` object for Ray backend.
+    """
+
+    run_id: str
+    objective: Union[float, list[float]]
+    hyperparameters: dict[str, Any]
+    run_summary: Optional[Any] = None
+
+
+def default_compute_objective(metrics: dict[str, float]) -> float:
+    """
+    The default objective to maximize/minimize when doing a hyperparameter search. It is the evaluation loss if no
+    metrics are provided to the [`Trainer`], the sum of all metrics otherwise.
+
+    Args:
+        metrics (`dict[str, float]`): The metrics returned by the evaluate method.
+
+    Return:
+        `float`: The objective to minimize or maximize
+    """
+    metrics = copy.deepcopy(metrics)
+    loss = metrics.pop("eval_loss", None)
+    _ = metrics.pop("epoch", None)
+    # Remove speed metrics
+    speed_metrics = [
+        m for m in metrics if m.endswith("_runtime") or m.endswith("_per_second") or m.endswith("_compilation_time")
+    ]
+    for sm in speed_metrics:
+        _ = metrics.pop(sm, None)
+    return loss if len(metrics) == 0 else sum(metrics.values())
+
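+# Worked example (metric values are made up): the objective is `eval_loss` when no other metrics
+# survive the filtering, and the sum of the remaining metrics otherwise:
+#
+#   default_compute_objective({"eval_loss": 0.31, "epoch": 2.0, "eval_runtime": 12.4})
+#   # -> 0.31
+#   default_compute_objective({"eval_loss": 0.31, "eval_bleu": 27.5, "eval_rouge1": 41.2})
+#   # -> 68.7  (eval_loss is popped first, so only eval_bleu + eval_rouge1 are summed)
+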
+
+def default_hp_space_optuna(trial) -> dict[str, float]:
+    from .integrations import is_optuna_available
+
+    assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
+    return {
+        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
+        "seed": trial.suggest_int("seed", 1, 40),
+        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
+    }
+
+
+def default_hp_space_ray(trial) -> dict[str, float]:
+    from .integrations import is_ray_tune_available
+
+    assert is_ray_tune_available(), "This function needs ray installed: `pip install ray[tune]`"
+    from ray import tune
+
+    return {
+        "learning_rate": tune.loguniform(1e-6, 1e-4),
+        "num_train_epochs": tune.choice(list(range(1, 6))),
+        "seed": tune.uniform(1, 40),
+        "per_device_train_batch_size": tune.choice([4, 8, 16, 32, 64]),
+    }
+
+
+def default_hp_space_sigopt(trial):
+    return [
+        {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double", "transformation": "log"},
+        {"bounds": {"min": 1, "max": 6}, "name": "num_train_epochs", "type": "int"},
+        {"bounds": {"min": 1, "max": 40}, "name": "seed", "type": "int"},
+        {
+            "categorical_values": ["4", "8", "16", "32", "64"],
+            "name": "per_device_train_batch_size",
+            "type": "categorical",
+        },
+    ]
+
+
+def default_hp_space_wandb(trial) -> dict[str, float]:
+    from .integrations import is_wandb_available
+
+    if not is_wandb_available():
+        raise ImportError("This function needs wandb installed: `pip install wandb`")
+
+    return {
+        "method": "random",
+        "metric": {"name": "objective", "goal": "minimize"},
+        "parameters": {
+            "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
+            "num_train_epochs": {"distribution": "int_uniform", "min": 1, "max": 6},
+            "seed": {"distribution": "int_uniform", "min": 1, "max": 40},
+            "per_device_train_batch_size": {"values": [4, 8, 16, 32, 64]},
+        },
+    }
+
+
+class HPSearchBackend(ExplicitEnum):
+    OPTUNA = "optuna"
+    RAY = "ray"
+    SIGOPT = "sigopt"
+    WANDB = "wandb"
+
+
+def is_main_process(local_rank):
+    """
+    Whether or not the current process is the main process, based on `xr.global_ordinal()` (for TPUs) first, then on
+    `local_rank`.
+    """
+    if is_torch_xla_available():
+        import torch_xla.runtime as xr
+
+        return xr.global_ordinal() == 0
+    return local_rank in [-1, 0]
+
+
+def total_processes_number(local_rank):
+    """
+    Return the number of processes launched in parallel. Works with `torch.distributed` and TPUs.
+    """
+    if is_torch_xla_available():
+        import torch_xla.runtime as xr
+
+        return xr.world_size()
+    elif local_rank != -1 and is_torch_available():
+        import torch
+
+        return torch.distributed.get_world_size()
+    return 1
+
+
+def speed_metrics(split, start_time, num_samples=None, num_steps=None, num_tokens=None):
+    """
+    Measure and return speed performance metrics.
+
+    This function requires a time snapshot `start_time` before the operation to be measured starts and this function
+    should be run immediately after the operation to be measured has completed.
+
+    Args:
+    - split: name to prefix metric (like train, eval, test...)
+    - start_time: operation start time
+    - num_samples: number of samples processed
+    - num_steps: number of steps processed
+    - num_tokens: number of tokens processed
+    """
+    runtime = time.time() - start_time
+    result = {f"{split}_runtime": round(runtime, 4)}
+    if runtime == 0:
+        return result
+    if num_samples is not None:
+        samples_per_second = num_samples / runtime
+        result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
+    if num_steps is not None:
+        steps_per_second = num_steps / runtime
+        result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
+    if num_tokens is not None:
+        tokens_per_second = num_tokens / runtime
+        result[f"{split}_tokens_per_second"] = round(tokens_per_second, 3)
+    return result
+
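+# Usage sketch (counts are placeholders): take a timestamp before the measured operation and call
+# `speed_metrics` immediately after it finishes:
+#
+#   start = time.time()
+#   # ... run evaluation ...
+#   metrics = speed_metrics("eval", start, num_samples=1000, num_steps=125)
+#   # -> {"eval_runtime": ..., "eval_samples_per_second": ..., "eval_steps_per_second": ...}
+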
+
+class SchedulerType(ExplicitEnum):
+    """
+    Scheduler names for the parameter `lr_scheduler_type` in [`TrainingArguments`].
+    By default, it uses "linear". Internally, this retrieves `get_linear_schedule_with_warmup` scheduler from [`Trainer`].
+    Scheduler types:
+       - "linear" = get_linear_schedule_with_warmup
+       - "cosine" = get_cosine_schedule_with_warmup
+       - "cosine_with_restarts" = get_cosine_with_hard_restarts_schedule_with_warmup
+       - "polynomial" = get_polynomial_decay_schedule_with_warmup
+       - "constant" =  get_constant_schedule
+       - "constant_with_warmup" = get_constant_schedule_with_warmup
+       - "inverse_sqrt" = get_inverse_sqrt_schedule
+       - "reduce_lr_on_plateau" = get_reduce_on_plateau_schedule
+       - "cosine_with_min_lr" = get_cosine_with_min_lr_schedule_with_warmup
+       - "warmup_stable_decay" = get_wsd_schedule
+    """
+
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
+    COSINE_WARMUP_WITH_MIN_LR = "cosine_warmup_with_min_lr"
+    WARMUP_STABLE_DECAY = "warmup_stable_decay"
+
+
+class TrainerMemoryTracker:
+    """
+    A helper class that tracks cpu and gpu memory.
+
+    This class will silently skip unless `psutil` is available. Install with `pip install psutil`.
+
+    When a stage completes, it can pass metrics dict to update with the memory metrics gathered during this stage.
+
+    Example:
+
+    ```python
+    self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+    self._memory_tracker.start()
+    # code ...
+    metrics = {"train_runtime": 10.5}
+    self._memory_tracker.stop_and_update_metrics(metrics)
+    ```
+
+    At the moment GPU tracking is only for `pytorch`, but can be extended to support `tensorflow`.
+
+    To understand this class' intricacies please read the documentation of [`~Trainer.log_metrics`].
+    """
+
+    # map trainer methods to metrics prefix
+    stages = {
+        "__init__": "init",
+        "train": "train",
+        "_inner_training_loop": "train",
+        "evaluate": "eval",
+        "predict": "test",
+    }
+
+    def __init__(self, skip_memory_metrics=False):
+        self.skip_memory_metrics = skip_memory_metrics
+
+        if not is_psutil_available():
+            # soft dependency on psutil
+            self.skip_memory_metrics = True
+
+        if self.skip_memory_metrics:
+            return
+
+        import psutil  # noqa
+
+        if is_torch_cuda_available() or is_torch_mlu_available() or is_torch_musa_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
+        elif is_torch_mps_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
+        elif is_torch_xpu_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
+        elif is_torch_npu_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
+        elif is_torch_hpu_available():
+            import torch
+
+            self.torch = torch
+            self.gpu = {}
+        else:
+            self.torch = None
+
+        self.process = psutil.Process()
+
+        self.cur_stage = None
+        self.cpu = {}
+        self.init_reported = False
+
+    def derive_stage(self):
+        """derives the stage/caller name automatically"""
+        caller = inspect.currentframe().f_back.f_back.f_code.co_name
+        if caller in self.stages:
+            return self.stages[caller]
+        else:
+            raise ValueError(
+                f"was called from {caller}, but only expect to be called from one of {self.stages.keys()}"
+            )
+
+    def cpu_mem_used(self):
+        """get resident set size memory for the current process"""
+        return self.process.memory_info().rss
+
+    def peak_monitor_func(self):
+        self.cpu_mem_used_peak = -1
+
+        while True:
+            self.cpu_mem_used_peak = max(self.cpu_mem_used(), self.cpu_mem_used_peak)
+
+            # can't sleep or will not catch the peak right (this comment is here on purpose)
+            # time.sleep(0.001) # 1msec
+
+            if not self.peak_monitoring:
+                break
+
+    def start(self):
+        """start tracking for the caller's stage"""
+        if self.skip_memory_metrics:
+            return
+
+        stage = self.derive_stage()
+        # deal with nested calls of eval during train - simply ignore those
+        if self.cur_stage is not None and self.cur_stage != stage:
+            return
+
+        self.cur_stage = stage
+
+        gc.collect()
+
+        if self.torch is not None:
+            if torch.cuda.is_available():
+                self.torch.cuda.reset_peak_memory_stats()
+                self.torch.cuda.empty_cache()
+            elif is_torch_mlu_available():
+                self.torch.mlu.reset_peak_memory_stats()
+                self.torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                self.torch.musa.reset_peak_memory_stats()
+                self.torch.musa.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.reset_peak_memory_stats()
+                self.torch.xpu.empty_cache()
+            elif is_torch_npu_available():
+                self.torch.npu.reset_peak_memory_stats()
+                self.torch.npu.empty_cache()
+            elif is_torch_hpu_available():
+                self.torch.hpu.reset_peak_memory_stats()
+                # not available on hpu as it reserves all device memory for the current process
+                # self.torch.hpu.empty_cache()
+            elif is_torch_mps_available():
+                self.torch.mps.empty_cache()
+
+        # gpu
+        if self.torch is not None:
+            if torch.cuda.is_available():
+                self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated()
+            elif is_torch_mlu_available():
+                self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated()
+            elif is_torch_musa_available():
+                self.gpu_mem_used_at_start = self.torch.musa.memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated()
+            elif is_torch_npu_available():
+                self.gpu_mem_used_at_start = self.torch.npu.memory_allocated()
+            elif is_torch_hpu_available():
+                self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated()
+            elif is_torch_mps_available():
+                self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory()
+
+        # cpu
+        self.cpu_mem_used_at_start = self.cpu_mem_used()
+
+        self.peak_monitoring = True
+        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+        peak_monitor_thread.daemon = True
+        peak_monitor_thread.start()
+
+    def stop(self, stage):
+        """stop tracking for the passed stage"""
+
+        # deal with nested calls of eval during train - simply ignore those
+        if self.cur_stage is not None and self.cur_stage != stage:
+            return
+
+        # this sends a signal to peak_monitor_func to complete its loop
+        self.peak_monitoring = False
+
+        # first ensure all objects get collected and their memory is freed
+        gc.collect()
+
+        if self.torch is not None:
+            if torch.cuda.is_available():
+                self.torch.cuda.empty_cache()
+            elif is_torch_mlu_available():
+                self.torch.mlu.empty_cache()
+            elif is_torch_musa_available():
+                self.torch.musa.empty_cache()
+            elif is_torch_xpu_available():
+                self.torch.xpu.empty_cache()
+            elif is_torch_npu_available():
+                self.torch.npu.empty_cache()
+            elif is_torch_hpu_available():
+                # not available on hpu as it reserves all device memory for the current process
+                # self.torch.hpu.empty_cache()
+                pass
+            elif is_torch_mps_available():
+                self.torch.mps.empty_cache()
+
+        # concepts:
+        # - alloc_delta:  the difference of allocated memory between the end and the start
+        # - peaked_delta: the difference between the peak memory and the current memory
+        # in order to know how much memory the measured code consumed one needs to sum these two
+
+        # gpu
+        if self.torch is not None:
+            if torch.cuda.is_available():
+                self.gpu_mem_used_now = self.torch.cuda.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.cuda.max_memory_allocated()
+            elif is_torch_mlu_available():
+                self.gpu_mem_used_now = self.torch.mlu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.mlu.max_memory_allocated()
+            elif is_torch_musa_available():
+                self.gpu_mem_used_now = self.torch.musa.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.musa.max_memory_allocated()
+            elif is_torch_xpu_available():
+                self.gpu_mem_used_now = self.torch.xpu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.xpu.max_memory_allocated()
+            elif is_torch_npu_available():
+                self.gpu_mem_used_now = self.torch.npu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.npu.max_memory_allocated()
+            elif is_torch_hpu_available():
+                self.gpu_mem_used_now = self.torch.hpu.memory_allocated()
+                self.gpu_mem_used_peak = self.torch.hpu.max_memory_allocated()
+            elif is_torch_mps_available():
+                self.gpu_mem_used_now = self.torch.mps.current_allocated_memory()
+                # self.torch.mps.max_memory_allocated() does not exist yet
+                self.gpu_mem_used_peak = None
+
+            else:
+                raise ValueError("No available GPU device found!")
+
+            self.gpu[self.cur_stage] = {
+                "begin": self.gpu_mem_used_at_start,
+                "end": self.gpu_mem_used_now,
+                "alloc": (self.gpu_mem_used_now - self.gpu_mem_used_at_start),
+            }
+            if self.gpu_mem_used_peak is not None:
+                self.gpu[self.cur_stage]["peaked"] = max(0, self.gpu_mem_used_peak - self.gpu_mem_used_now)
+            else:
+                self.gpu[self.cur_stage]["peaked"] = "Not available"
+
+        # cpu
+        self.cpu_mem_used_now = self.cpu_mem_used()
+        self.cpu[self.cur_stage] = {
+            "begin": self.cpu_mem_used_at_start,
+            "end": self.cpu_mem_used_now,
+            "alloc": (self.cpu_mem_used_now - self.cpu_mem_used_at_start),
+            "peaked": max(0, self.cpu_mem_used_peak - self.cpu_mem_used_now),
+        }
+
+        # reset - cycle finished
+        self.cur_stage = None
+
+    def update_metrics(self, stage, metrics):
+        """updates the metrics"""
+        if self.skip_memory_metrics:
+            return
+
+        # deal with nested calls of eval during train - simply ignore those
+        if self.cur_stage is not None and self.cur_stage != stage:
+            return
+
+        # since we don't have a way to return init metrics, we push them into the first of train/val/predict
+        stages = [stage]
+        if not self.init_reported:
+            stages.insert(0, "init")
+            self.init_reported = True
+
+        for stage in stages:
+            for t in ["alloc", "peaked"]:
+                if stage in self.cpu and t in self.cpu[stage]:
+                    metrics[f"{stage}_mem_cpu_{t}_delta"] = self.cpu[stage][t]
+                if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
+                    metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
+            # if we need additional debug info, enable the following
+            # for t in ["begin", "end"]:
+            #     if stage in self.cpu and t in self.cpu[stage]:
+            #         metrics[f"{stage}_mem_cpu_{t}"] = self.cpu[stage][t]
+            #     if self.torch is not None and stage in self.gpu and t in self.gpu[stage]:
+            #         metrics[f"{stage}_mem_gpu_{t}"] = self.gpu[stage][t]
+
+        # since memory can be allocated before init, and it might be difficult to track overall
+        # memory usage, in particular for GPU, let's report memory usage at the point init was called
+        if stages[0] == "init":
+            metrics["before_init_mem_cpu"] = self.cpu["init"]["begin"]
+            if self.torch is not None:
+                metrics["before_init_mem_gpu"] = self.gpu["init"]["begin"]
+            # if we also wanted to report any additional memory allocations in between init and
+            # whatever the next stage was we could also report this:
+            # if self.cpu["init"]["end"] != self.cpu[stage]["begin"]:
+            #     metrics[f"after_init_mem_cpu_delta"] = self.cpu[stage]["begin"] - self.cpu["init"]["end"]
+            # if self.torch is not None and self.gpu["init"]["end"] != self.gpu[stage]["begin"]:
+            #     metrics[f"after_init_mem_gpu_delta"] = self.gpu[stage]["begin"] - self.gpu["init"]["end"]
+
+    def stop_and_update_metrics(self, metrics=None):
+        """combine stop and metrics update in one call for simpler code"""
+        if self.skip_memory_metrics:
+            return
+
+        stage = self.derive_stage()
+        self.stop(stage)
+
+        # init doesn't have metrics to update so we just save that data for later stages to retrieve
+        if metrics is not None:
+            self.update_metrics(stage, metrics)
+
+
+def has_length(dataset):
+    """
+    Checks whether the dataset implements `__len__()` without raising an error
+    """
+    try:
+        return len(dataset) is not None
+    except TypeError:
+        # TypeError: len() of unsized object
+        return False
+    except AttributeError:
+        # Ray DataSets raises an AttributeError: https://github.com/ray-project/ray/blob/master/python/ray/data/dataset.py#L5616
+        return False
+
+
+def denumpify_detensorize(metrics):
+    """
+    Recursively calls `.item()` on NumPy scalars and single-element tensors found in the passed dict, list, or tuple
+    """
+    if isinstance(metrics, (list, tuple)):
+        return type(metrics)(denumpify_detensorize(m) for m in metrics)
+    elif isinstance(metrics, dict):
+        return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
+    elif isinstance(metrics, np.generic):
+        return metrics.item()
+    elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
+        return metrics.item()
+    return metrics
+
+
+def number_of_arguments(func):
+    """
+    Return the number of arguments of the passed function, even if it's a partial function.
+    """
+    if isinstance(func, functools.partial):
+        total_args = len(inspect.signature(func.func).parameters)
+        return total_args - len(func.args) - len(func.keywords)
+    return len(inspect.signature(func).parameters)
+
+
+def find_executable_batch_size(
+    function: Optional[callable] = None, starting_batch_size: int = 128, auto_find_batch_size: bool = False
+):
+    """
+    Args:
+    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
+    CUDNN, the batch size is multiplied by 0.9 and passed to `function`. `function` must take in a `batch_size` parameter as
+    its first argument.
+        function (`callable`, *optional*)
+            A function to wrap
+        starting_batch_size (`int`, *optional*)
+            The batch size to try and fit into memory
+        auto_find_batch_size (`bool`, *optional*)
+            If False, will just execute `function`
+    """
+    if function is None:
+        return functools.partial(
+            find_executable_batch_size,
+            starting_batch_size=starting_batch_size,
+            auto_find_batch_size=auto_find_batch_size,
+        )
+
+    if auto_find_batch_size:
+        requires_backends(find_executable_batch_size, "accelerate")
+        from accelerate.utils import find_executable_batch_size as accelerate_find_executable_batch_size
+
+        return accelerate_find_executable_batch_size(function=function, starting_batch_size=starting_batch_size)
+
+    return functools.partial(function, batch_size=starting_batch_size)
+
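+# A sketch of one way to use this helper (names are placeholders); with `auto_find_batch_size=True`
+# and accelerate installed, the wrapped function is retried with smaller batch sizes on OOM:
+#
+#   def _inner_training_loop(batch_size, args):
+#       ...
+#
+#   inner_loop = find_executable_batch_size(_inner_training_loop, starting_batch_size=64,
+#                                           auto_find_batch_size=True)
+#   inner_loop(args=args)   # `batch_size` is injected as the first positional argument
+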
+
+class FSDPOption(ExplicitEnum):
+    FULL_SHARD = "full_shard"
+    SHARD_GRAD_OP = "shard_grad_op"
+    NO_SHARD = "no_shard"
+    HYBRID_SHARD = "hybrid_shard"
+    HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
+    OFFLOAD = "offload"
+    AUTO_WRAP = "auto_wrap"
+
+
+class RemoveColumnsCollator:
+    """Wrap the data collator to remove unused columns before they are passed to the collator."""
+
+    def __init__(
+        self,
+        data_collator,
+        signature_columns,
+        logger=None,
+        model_name: Optional[str] = None,
+        description: Optional[str] = None,
+    ):
+        self.data_collator = data_collator
+        self.signature_columns = signature_columns
+        self.logger = logger
+        self.description = description
+        self.model_name = model_name
+        self.message_logged = False
+
+    def _remove_columns(self, feature: dict) -> dict:
+        if not isinstance(feature, dict):
+            return feature
+        if not self.message_logged and self.logger and self.model_name:
+            ignored_columns = list(set(feature.keys()) - set(self.signature_columns))
+            if len(ignored_columns) > 0:
+                dset_description = "" if self.description is None else f"in the {self.description} set"
+                self.logger.info(
+                    f"The following columns {dset_description} don't have a corresponding argument in "
+                    f"`{self.model_name}.forward` and have been ignored: {', '.join(ignored_columns)}."
+                    f" If {', '.join(ignored_columns)} are not expected by `{self.model_name}.forward`, "
+                    " you can safely ignore this message."
+                )
+                self.message_logged = True
+        return {k: v for k, v in feature.items() if k in self.signature_columns}
+
+    def __call__(self, features: list[dict]):
+        features = [self._remove_columns(feature) for feature in features]
+        return self.data_collator(features)
+
+
+def check_target_module_exists(optim_target_modules, key: str, return_is_regex: bool = False):
+    """A helper method to check if the passed module's key name matches any of the target modules in the optim_target_modules.
+
+    Args:
+        optim_target_modules (`Union[str, list[str]]`):
+            A list of strings to try to match. Can be also a full string.
+        key (`str`):
+            A key to search any matches in optim_target_modules
+        return_is_regex (`bool`):
+            If set to `True`, the method will return whether the passed `optim_target_modules`
+            is a regex or not.
+
+    Returns:
+        `bool`: `True` if `key` matches any of the target modules, `False` otherwise.
+        `bool`: Only returned when `return_is_regex=True`; whether the matched target module is a regex,
+        used to silence the Trainer warnings about extra matched modules (only meaningful when
+        `target_module_found=True` for a list of regexes).
+    """
+    target_module_found = False
+    is_regex = False
+
+    if isinstance(optim_target_modules, str):
+        target_module_found = bool(re.fullmatch(optim_target_modules, key))
+        is_regex = optim_target_modules != key
+    elif key in optim_target_modules:  # from here, target_module_found must be a list of str
+        # this module is specified directly in target_modules
+        target_module_found = True
+    elif any(target_key in key for target_key in optim_target_modules):
+        target_module_found = True
+    elif any(bool(re.fullmatch(optim_target_module, key)) for optim_target_module in optim_target_modules):
+        target_module_found = True
+        is_regex = True
+
+    if return_is_regex:
+        return target_module_found, is_regex
+
+    return target_module_found
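+
+# Matching sketch (module names are placeholders): a string pattern is treated as a full regex,
+# while a list is matched by exact key, substring, or full regex:
+#
+#   check_target_module_exists(["q_proj", "v_proj"], "model.layers.0.self_attn.q_proj")
+#   # -> True (substring match)
+#   check_target_module_exists(r".*\.q_proj", "model.layers.0.self_attn.q_proj", return_is_regex=True)
+#   # -> (True, True)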
diff --git a/phivenv/Lib/site-packages/transformers/training_args.py b/phivenv/Lib/site-packages/transformers/training_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..2337edc93b33f51932781f875837b00d40758894
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/training_args.py
@@ -0,0 +1,3144 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import json
+import math
+import os
+import warnings
+from dataclasses import asdict, dataclass, field, fields
+from datetime import timedelta
+from enum import Enum
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from huggingface_hub import get_full_repo_name
+
+from .debug_utils import DebugOption
+from .trainer_utils import (
+    EvaluationStrategy,
+    FSDPOption,
+    HubStrategy,
+    IntervalStrategy,
+    SaveStrategy,
+    SchedulerType,
+)
+from .utils import (
+    ACCELERATE_MIN_VERSION,
+    ExplicitEnum,
+    cached_property,
+    is_accelerate_available,
+    is_apex_available,
+    is_ipex_available,
+    is_safetensors_available,
+    is_sagemaker_dp_enabled,
+    is_sagemaker_mp_enabled,
+    is_torch_available,
+    is_torch_bf16_gpu_available,
+    is_torch_cuda_available,
+    is_torch_hpu_available,
+    is_torch_mlu_available,
+    is_torch_mps_available,
+    is_torch_musa_available,
+    is_torch_neuroncore_available,
+    is_torch_npu_available,
+    is_torch_tf32_available,
+    is_torch_xla_available,
+    is_torch_xpu_available,
+    logging,
+    requires_backends,
+)
+from .utils.generic import strtobool
+from .utils.import_utils import is_optimum_neuron_available
+
+
+logger = logging.get_logger(__name__)
+log_levels = logging.get_log_levels_dict().copy()
+trainer_log_levels = dict(**log_levels, passive=-1)
+
+if is_torch_available():
+    import torch
+    import torch.distributed as dist
+
+if is_accelerate_available():
+    from accelerate.state import AcceleratorState, PartialState
+    from accelerate.utils import DistributedType
+
+    from .trainer_pt_utils import AcceleratorConfig
+
+    if is_accelerate_available("1.10.1"):
+        from accelerate.parallelism_config import ParallelismConfig
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+if is_torch_neuroncore_available(check_device=False):
+    # torchrun support
+    # https://github.com/pytorch/xla/pull/3609
+    if os.environ.get("TORCHELASTIC_RUN_ID"):
+        if is_optimum_neuron_available():
+            logger.info(
+                "Make sure that you are performing the training with the NeuronTrainer from optimum[neuron], this "
+                "will fail otherwise."
+            )
+        else:
+            logger.warning(
+                "Please use the NeuronTrainer from optimum[neuron] instead of the Transformers library to perform "
+                "training on AWS Trainium instances. More information here: "
+                "https://github.com/huggingface/optimum-neuron"
+            )
+            import torch_xla.distributed.xla_backend as xbn
+
+            if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+                dist.init_process_group(backend="xla")
+                if not isinstance(dist.group.WORLD, xbn.ProcessGroupXla):
+                    raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
+
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+    smp.init()
+
+
+def default_logdir() -> str:
+    """
+    Same default as PyTorch
+    """
+    import socket
+    from datetime import datetime
+
+    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
+    return os.path.join("runs", current_time + "_" + socket.gethostname())
+
+
+def get_int_from_env(env_keys, default):
+    """Returns the first positive env value found in the `env_keys` list or the default."""
+    for e in env_keys:
+        val = int(os.environ.get(e, "-1"))
+        if val >= 0:
+            return val
+    return default
+
+
+def get_xla_device_type(device: "torch.device") -> Optional[str]:
+    """
+    Returns the xla device type (CPU|GPU|TPU) or None if the device is a non-xla device.
+    """
+    if is_torch_xla_available():
+        if device.type == "cpu":
+            return "CPU"
+        return xm.xla_real_devices([device])[0].split(":")[0]
+    return None
+
+
+class OptimizerNames(ExplicitEnum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
+    ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+    ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
+    ADEMAMIX = "ademamix"
+    SGD = "sgd"
+    ADAGRAD = "adagrad"
+    ADAMW_BNB = "adamw_bnb_8bit"
+    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    ADEMAMIX_8BIT = "ademamix_8bit"
+    LION_8BIT = "lion_8bit"
+    LION = "lion_32bit"
+    PAGED_ADAMW = "paged_adamw_32bit"
+    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_ADEMAMIX = "paged_ademamix_32bit"
+    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
+    PAGED_LION = "paged_lion_32bit"
+    PAGED_LION_8BIT = "paged_lion_8bit"
+    RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"
+    GALORE_ADAMW = "galore_adamw"
+    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+    GALORE_ADAFACTOR = "galore_adafactor"
+    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    LOMO = "lomo"
+    ADALOMO = "adalomo"
+    GROKADAMW = "grokadamw"
+    SCHEDULE_FREE_RADAM = "schedule_free_radam"
+    SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
+    SCHEDULE_FREE_SGD = "schedule_free_sgd"
+    APOLLO_ADAMW = "apollo_adamw"
+    APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
+    STABLE_ADAMW = "stable_adamw"
+
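+# Illustrative note: `OptimizerNames` is an `ExplicitEnum`, so the string passed via `--optim`
+# resolves through normal enum value lookup, e.g. OptimizerNames("adamw_torch") is
+# OptimizerNames.ADAMW_TORCH evaluates to True, while an unrecognised string raises a ValueError.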
+
+def _convert_str_dict(passed_value: dict):
+    "Safely checks that a passed value is a dictionary and converts any string values to their appropriate types."
+    for key, value in passed_value.items():
+        if isinstance(value, dict):
+            passed_value[key] = _convert_str_dict(value)
+        elif isinstance(value, str):
+            # First check for bool and convert
+            if value.lower() in ("true", "false"):
+                passed_value[key] = value.lower() == "true"
+            # Check for digit
+            elif value.isdigit():
+                passed_value[key] = int(value)
+            elif value.replace(".", "", 1).isdigit():
+                passed_value[key] = float(value)
+
+    return passed_value
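+
+# A small, hypothetical sketch of the conversion performed above:
+# _convert_str_dict({"use_reentrant": "false", "num_cycles": "3", "min_lr": "0.5", "nested": {"flag": "true"}})
+# returns {"use_reentrant": False, "num_cycles": 3, "min_lr": 0.5, "nested": {"flag": True}}.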
+
+
+# TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
+@dataclass
+class TrainingArguments:
+    """
+    TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+    itself**.
+
+    Using [`HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        output_dir (`str`, *optional*, defaults to `"trainer_output"`):
+            The output directory where the model predictions and checkpoints will be written.
+        overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+            If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+            points to a checkpoint directory.
+        do_train (`bool`, *optional*, defaults to `False`):
+            Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+            by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        do_eval (`bool`, *optional*):
+            Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
+            different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
+            training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        do_predict (`bool`, *optional*, defaults to `False`):
+            Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
+            intended to be used by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+            The evaluation strategy to adopt during training. Possible values are:
+
+                - `"no"`: No evaluation is done during training.
+                - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+                - `"epoch"`: Evaluation is done at the end of each epoch.
+
+        prediction_loss_only (`bool`, *optional*, defaults to `False`):
+            When performing evaluation and generating predictions, only returns the loss.
+        per_device_train_batch_size (`int`, *optional*, defaults to 8):
+            The batch size *per device*. The **global batch size** is computed as:
+            `per_device_train_batch_size * number_of_devices` in multi-GPU or distributed setups.
+        per_device_eval_batch_size (`int`, *optional*, defaults to 8):
+            The batch size per device accelerator core/CPU for evaluation.
+        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+            Number of update steps to accumulate the gradients for, before performing a backward/update pass.
+
+            
+
+            When using gradient accumulation, one step is counted as one step with a backward pass. Therefore, logging,
+            evaluation, and saving will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
+
+            
+
+        eval_accumulation_steps (`int`, *optional*):
+            Number of prediction steps to accumulate the output tensors for, before moving the results to the CPU. If
+            left unset, the whole predictions are accumulated on the device accelerator before being moved to the CPU (faster but
+            requires more memory).
+        eval_delay (`float`, *optional*):
+            Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+            eval_strategy.
+        torch_empty_cache_steps (`int`, *optional*):
+            Number of steps to wait before calling `torch.<device>.empty_cache()`. If left unset or set to None, cache will not be emptied.
+
+            
+
+            This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372).
+
+            
+
+        learning_rate (`float`, *optional*, defaults to 5e-5):
+            The initial learning rate for [`AdamW`] optimizer.
+        weight_decay (`float`, *optional*, defaults to 0):
+            The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in [`AdamW`]
+            optimizer.
+        adam_beta1 (`float`, *optional*, defaults to 0.9):
+            The beta1 hyperparameter for the [`AdamW`] optimizer.
+        adam_beta2 (`float`, *optional*, defaults to 0.999):
+            The beta2 hyperparameter for the [`AdamW`] optimizer.
+        adam_epsilon (`float`, *optional*, defaults to 1e-8):
+            The epsilon hyperparameter for the [`AdamW`] optimizer.
+        max_grad_norm (`float`, *optional*, defaults to 1.0):
+            Maximum gradient norm (for gradient clipping).
+        num_train_epochs (`float`, *optional*, defaults to 3.0):
+            Total number of training epochs to perform (if not an integer, the fractional part determines the
+            percentage of the last epoch performed before stopping training).
+        max_steps (`int`, *optional*, defaults to -1):
+            If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+            For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+            `max_steps` is reached.
+        lr_scheduler_type (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
+            The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+        lr_scheduler_kwargs (`dict`, *optional*, defaults to `{}`):
+            The extra arguments for the lr_scheduler. See the documentation of each scheduler for possible values.
+        warmup_ratio (`float`, *optional*, defaults to 0.0):
+            Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+        warmup_steps (`int`, *optional*, defaults to 0):
+            Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+        log_level (`str`, *optional*, defaults to `passive`):
+            Logger log level to use on the main process. Possible choices are the log levels as strings: 'debug',
+            'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and keeps the
+            current log level for the Transformers library (which will be `"warning"` by default).
+        log_level_replica (`str`, *optional*, defaults to `"warning"`):
+            Logger log level to use on replicas. Same choices as `log_level`.
+        log_on_each_node (`bool`, *optional*, defaults to `True`):
+            In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+            node.
+        logging_dir (`str`, *optional*):
+            [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+            *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.
+        logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+                - `"no"`: No logging is done during training.
+                - `"epoch"`: Logging is done at the end of each epoch.
+                - `"steps"`: Logging is done every `logging_steps`.
+
+        logging_first_step (`bool`, *optional*, defaults to `False`):
+            Whether to log the first `global_step` or not.
+        logging_steps (`int` or `float`, *optional*, defaults to 500):
+            Number of update steps between two logs if `logging_strategy="steps"`. Should be an integer or a float in
+            range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+        logging_nan_inf_filter (`bool`, *optional*, defaults to `True`):
+            Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is `nan`
+            or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+            
+
+            `logging_nan_inf_filter` only influences the logging of loss values; it does not change how the gradient
+            is computed or applied to the model.
+
+            
+
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+                - `"no"`: No save is done during training.
+                - `"epoch"`: Save is done at the end of each epoch.
+                - `"steps"`: Save is done every `save_steps`.
+                - `"best"`: Save is done whenever a new `best_metric` is achieved.
+
+                If `"epoch"` or `"steps"` is chosen, saving will also be performed at the
+                very end of training, always.
+        save_steps (`int` or `float`, *optional*, defaults to 500):
+            Number of update steps between two checkpoint saves if `save_strategy="steps"`. Should be an integer or a
+            float in range `[0,1)`. If smaller than 1, will be interpreted as ratio of total training steps.
+        save_total_limit (`int`, *optional*):
+            If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+            `output_dir`. When `load_best_model_at_end` is enabled, the "best" checkpoint according to
+            `metric_for_best_model` will always be retained in addition to the most recent ones. For example, for
+            `save_total_limit=5` and `load_best_model_at_end`, the four last checkpoints will always be retained
+            alongside the best model. When `save_total_limit=1` and `load_best_model_at_end`, it is possible that two
+            checkpoints are saved: the last one and the best one (if they are different).
+        save_safetensors (`bool`, *optional*, defaults to `True`):
+            Use [safetensors](https://huggingface.co/docs/safetensors) saving and loading for state dicts instead of
+            default `torch.load` and `torch.save`.
+        save_on_each_node (`bool`, *optional*, defaults to `False`):
+            When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
+            the main one.
+
+            This should not be activated when the different nodes use the same storage as the files will be saved with
+            the same names for each node.
+        save_only_model (`bool`, *optional*, defaults to `False`):
+            When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state.
+            Note that when this is true, you won't be able to resume training from a checkpoint.
+            This enables you to save storage by not storing the optimizer, scheduler & rng state.
+            You can only load the model using `from_pretrained` with this option set to `True`.
+        restore_callback_states_from_checkpoint (`bool`, *optional*, defaults to `False`):
+            Whether to restore the callback states from the checkpoint. If `True`, will override
+            callbacks passed to the `Trainer` if they exist in the checkpoint.
+        use_cpu (`bool`, *optional*, defaults to `False`):
+            Whether or not to use the CPU. If set to `False`, we will use the CUDA or MPS device if available.
+        seed (`int`, *optional*, defaults to 42):
+            Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
+            [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized parameters.
+        data_seed (`int`, *optional*):
+            Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+            same seed as `seed`. This can be used to ensure reproducibility of data sampling, independent of the model
+            seed.
+        jit_mode_eval (`bool`, *optional*, defaults to `False`):
+            Whether or not to use PyTorch jit trace for inference.
+        use_ipex (`bool`, *optional*, defaults to `False`):
+            Use Intel extension for PyTorch when it is available. [IPEX
+            installation](https://github.com/intel/intel-extension-for-pytorch).
+        bf16 (`bool`, *optional*, defaults to `False`):
+            Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
+            NVIDIA architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change.
+        fp16 (`bool`, *optional*, defaults to `False`):
+            Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
+        fp16_opt_level (`str`, *optional*, defaults to 'O1'):
+            For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
+            the [Apex documentation](https://nvidia.github.io/apex/amp).
+        fp16_backend (`str`, *optional*, defaults to `"auto"`):
+            This argument is deprecated. Use `half_precision_backend` instead.
+        half_precision_backend (`str`, *optional*, defaults to `"auto"`):
+            The backend to use for mixed precision training. Must be one of `"auto", "apex", "cpu_amp"`. `"auto"` will
+            use CPU/CUDA AMP or APEX depending on the PyTorch version detected, while the other choices will force the
+            requested backend.
+        bf16_full_eval (`bool`, *optional*, defaults to `False`):
+            Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+            metric values. This is an experimental API and it may change.
+        fp16_full_eval (`bool`, *optional*, defaults to `False`):
+            Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
+            metric values.
+        tf32 (`bool`, *optional*):
+            Whether to enable the TF32 mode, available in Ampere and newer GPU architectures. The default value depends
+            on PyTorch's version default of `torch.backends.cuda.matmul.allow_tf32`. For more details please refer to
+            the [TF32](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) documentation. This is an
+            experimental API and it may change.
+        local_rank (`int`, *optional*, defaults to -1):
+            Rank of the process during distributed training.
+        ddp_backend (`str`, *optional*):
+            The backend to use for distributed training. Must be one of `"nccl"`, `"mpi"`, `"ccl"`, `"gloo"`, `"hccl"`.
+        tpu_num_cores (`int`, *optional*):
+            When training on TPU, the number of TPU cores (automatically passed by launcher script).
+        dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+            or not.
+        eval_steps (`int` or `float`, *optional*):
+            Number of update steps between two evaluations if `eval_strategy="steps"`. Will default to the same
+            value as `logging_steps` if not set. Should be an integer or a float in range `[0,1)`. If smaller than 1,
+            will be interpreted as ratio of total training steps.
+        dataloader_num_workers (`int`, *optional*, defaults to 0):
+            Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the
+            main process.
+        past_index (`int`, *optional*, defaults to -1):
+            Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make use of
+            the past hidden states for their predictions. If this argument is set to a positive int, the `Trainer` will
+            use the corresponding output (usually index 2) as the past state and feed it to the model at the next
+            training step under the keyword argument `mems`.
+        run_name (`str`, *optional*, defaults to `output_dir`):
+            A descriptor for the run. Typically used for [trackio](https://github.com/gradio-app/trackio),
+            [wandb](https://www.wandb.com/), [mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and
+            [swanlab](https://swanlab.cn) logging. If not specified, will be the same as `output_dir`.
+        disable_tqdm (`bool`, *optional*):
+            Whether or not to disable the tqdm progress bars and table of metrics produced by
+            [`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
+            set to warn or lower (default), `False` otherwise.
+        remove_unused_columns (`bool`, *optional*, defaults to `True`):
+            Whether or not to automatically remove the columns unused by the model forward method.
+        label_names (`list[str]`, *optional*):
+            The list of keys in your dictionary of inputs that correspond to the labels.
+
+            Will eventually default to the list of argument names accepted by the model that contain the word "label",
+            except if the model used is one of the `XxxForQuestionAnswering` in which case it will also include the
+            `["start_positions", "end_positions"]` keys.
+
+            You should only specify `label_names` if you're using custom label names or if your model's `forward` consumes multiple label tensors (e.g., extractive QA).
+        load_best_model_at_end (`bool`, *optional*, defaults to `False`):
+            Whether or not to load the best model found during training at the end of training. When this option is
+            enabled, the best checkpoint will always be saved. See
+            [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit)
+            for more.
+
+            
+
+            When set to `True`, the parameter `save_strategy` needs to be the same as `eval_strategy`, and in
+            the case it is `"steps"`, `save_steps` must be a round multiple of `eval_steps`.
+
+            
+
+        metric_for_best_model (`str`, *optional*):
+            Use in conjunction with `load_best_model_at_end` to specify the metric to use to compare two different
+            models. Must be the name of a metric returned by the evaluation with or without the prefix `"eval_"`.
+
+            If not specified, this will default to `"loss"` when either `load_best_model_at_end == True`
+            or `lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU` (to use the evaluation loss).
+
+            If you set this value, `greater_is_better` will default to `True` unless the name ends with "loss".
+            Don't forget to set it to `False` if your metric is better when lower.
+        greater_is_better (`bool`, *optional*):
+            Use in conjunction with `load_best_model_at_end` and `metric_for_best_model` to specify if better models
+            should have a greater metric or not. Will default to:
+
+            - `True` if `metric_for_best_model` is set to a value that doesn't end in `"loss"`.
+            - `False` if `metric_for_best_model` is not set, or set to a value that ends in `"loss"`.
+        ignore_data_skip (`bool`, *optional*, defaults to `False`):
+            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
+            stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step
+            can take a long time) but will not yield the same results as the interrupted training would have.
+        fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`):
+            Use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training only).
+
+            A list of options from the following:
+
+            - `"full_shard"`: Shard parameters, gradients and optimizer states.
+            - `"shard_grad_op"`: Shard optimizer states and gradients.
+            - `"hybrid_shard"`: Apply `FULL_SHARD` within a node, and replicate parameters across nodes.
+            - `"hybrid_shard_zero2"`: Apply `SHARD_GRAD_OP` within a node, and replicate parameters across nodes.
+            - `"offload"`: Offload parameters and gradients to CPUs (only compatible with `"full_shard"` and
+              `"shard_grad_op"`).
+            - `"auto_wrap"`: Automatically recursively wrap layers with FSDP using `default_auto_wrap_policy`.
+        fsdp_config (`str` or `dict`, *optional*):
+            Config to be used with fsdp (PyTorch Fully Sharded Data Parallel training). The value is either a location of
+            fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`.
+
+            A list of config options:
+                - min_num_params (`int`, *optional*, defaults to `0`):
+                    FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
+                    passed).
+                - transformer_layer_cls_to_wrap (`list[str]`, *optional*):
+                    List of transformer layer class names (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`,
+                    `T5Block` ... (useful only when `fsdp` flag is passed).
+                - backward_prefetch (`str`, *optional*)
+                    FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
+                    `fsdp` field is passed).
+
+                    A list of options from the following:
+
+                    - `"backward_pre"`: Prefetches the next set of parameters before the current set of parameters'
+                      gradient computation.
+                    - `"backward_post"`: Prefetches the next set of parameters after the current set of parameters'
+                      gradient computation.
+                - forward_prefetch (`bool`, *optional*, defaults to `False`)
+                    FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
+                     If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
+                     forward pass.
+                - limit_all_gathers (`bool`, *optional*, defaults to `False`)
+                    FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
+                     If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
+                     all-gathers.
+                - use_orig_params (`bool`, *optional*, defaults to `True`)
+                    If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
+                    frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
+                    refer this
+                    [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019
+                - sync_module_states (`bool`, *optional*, defaults to `True`)
+                    If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
+                    ensure they are the same across all ranks after initialization.
+                - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
+                    If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
+                    have empty weights.  When this setting as `"True"`, `sync_module_states` also must to be `"True"`,
+                    otherwise all the processes except the main process would have random weights leading to unexpected
+                    behaviour during training.
+                - activation_checkpointing (`bool`, *optional*, defaults to `False`):
+                    If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
+                    certain layers and recomputing them during a backward pass. Effectively, this trades extra
+                    computation time for reduced memory usage.
+                - xla (`bool`, *optional*, defaults to `False`):
+                    Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
+                    and its API may evolve in the future.
+                - xla_fsdp_settings (`dict`, *optional*)
+                    The value is a dictionary which stores the XLA FSDP wrapping parameters.
+
+                    For a complete list of options, please see [here](
+                    https://github.com/pytorch/xla/blob/master/torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py).
+                - xla_fsdp_grad_ckpt (`bool`, *optional*, defaults to `False`):
+                    Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
+                    used when the xla flag is set to true, and an auto wrapping policy is specified through
+                    fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
+        deepspeed (`str` or `dict`, *optional*):
+            Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
+            evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
+            `ds_config.json`) or an already loaded json file as a `dict`.
+
+            
+                If enabling any Zero-init, make sure that your model is not initialized until
+                *after* initializing the `TrainingArguments`, else it will not be applied.
+            
+
+        accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
+            Config to be used with the internal `Accelerator` implementation. The value is either a location of
+            accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as `dict`,
+            or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
+
+            A list of config options:
+                - split_batches (`bool`, *optional*, defaults to `False`):
+                    Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
+                    `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
+                    round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
+                    in your script multiplied by the number of processes.
+                - dispatch_batches (`bool`, *optional*):
+                    If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
+                    and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
+                    underlying dataset is an `IterableDataset`, `False` otherwise.
+                - even_batches (`bool`, *optional*, defaults to `True`):
+                    If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+                    dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+                    all workers.
+                - use_seedable_sampler (`bool`, *optional*, defaults to `True`):
+                    Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
+                    training results are fully reproducible using a different sampling technique. While seed-to-seed results
+                    may differ, on average the differences are negligible when using multiple different seeds to compare. Should
+                    also be run with [`~utils.set_seed`] for the best results.
+                - use_configured_state (`bool`, *optional*, defaults to `False`):
+                    Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
+                    If `True`, an `Accelerator` or `PartialState` must be initialized. Note that by doing so, this could lead to issues
+                    with hyperparameter tuning.
+        parallelism_config (`ParallelismConfig`, *optional*):
+            Parallelism configuration for the training run. Requires Accelerate `1.10.1` or later.
+        label_smoothing_factor (`float`, *optional*, defaults to 0.0):
+            The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
+            labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
+            label_smoothing_factor/num_labels` respectively.
+        debug (`str` or list of [`~debug_utils.DebugOption`], *optional*, defaults to `""`):
+            Enable one or more debug features. This is an experimental feature.
+
+            Possible options are:
+
+            - `"underflow_overflow"`: detects overflow in model's input/outputs and reports the last frames that led to
+              the event
+            - `"tpu_metrics_debug"`: print debug metrics on TPU
+
+            The options should be separated by whitespaces.
+        optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"` (for torch>=2.8 `"adamw_torch_fused"`)):
+            The optimizer to use, such as "adamw_torch", "adamw_torch_fused", "adamw_apex_fused", "adamw_anyprecision",
+            "adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
+            for a full list of optimizers.
+        optim_args (`str`, *optional*):
+            Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
+        group_by_length (`bool`, *optional*, defaults to `False`):
+            Whether or not to group together samples of roughly the same length in the training dataset (to minimize
+            padding applied and be more efficient). Only useful if applying dynamic padding.
+        length_column_name (`str`, *optional*, defaults to `"length"`):
+            Column name for precomputed lengths. If the column exists, grouping by length will use these values rather
+            than computing them on train startup. Ignored unless `group_by_length` is `True` and the dataset is an
+            instance of `Dataset`.
+        report_to (`str` or `list[str]`, *optional*, defaults to `"all"`):
+            The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+            `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
+            `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all integrations
+            installed, `"none"` for no integrations.
+        ddp_find_unused_parameters (`bool`, *optional*):
+            When using distributed training, the value of the flag `find_unused_parameters` passed to
+            `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+        ddp_bucket_cap_mb (`int`, *optional*):
+            When using distributed training, the value of the flag `bucket_cap_mb` passed to `DistributedDataParallel`.
+        ddp_broadcast_buffers (`bool`, *optional*):
+            When using distributed training, the value of the flag `broadcast_buffers` passed to
+            `DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
+        dataloader_pin_memory (`bool`, *optional*, defaults to `True`):
+            Whether you want to pin memory in data loaders or not. Will default to `True`.
+        dataloader_persistent_workers (`bool`, *optional*, defaults to `False`):
+            If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
+            This keeps the workers' `Dataset` instances alive. Can potentially speed up training, but will
+            increase RAM usage. Will default to `False`.
+        dataloader_prefetch_factor (`int`, *optional*):
+            Number of batches loaded in advance by each worker.
+            2 means there will be a total of 2 * num_workers batches prefetched across all workers.
+        skip_memory_metrics (`bool`, *optional*, defaults to `True`):
+            Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
+            down the training and evaluation speed.
+        push_to_hub (`bool`, *optional*, defaults to `False`):
+            Whether or not to push the model to the Hub every time the model is saved. If this is activated,
+            `output_dir` will become a git repository synced with the repo (determined by `hub_model_id`) and the content
+            will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+            [`~Trainer.save_model`] will also trigger a push.
+
+            
+
+            If `output_dir` exists, it needs to be a local clone of the repository to which the model will be
+            pushed.
+
+            
+
+        resume_from_checkpoint (`str`, *optional*):
+            The path to a folder with a valid checkpoint for your model. This argument is not directly used by
+            [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        hub_model_id (`str`, *optional*):
+            The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
+            which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
+            for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
+            `"organization_name/model"`. Will default to `user_name/output_dir_name` with *output_dir_name* being the
+            name of `output_dir`.
+        hub_strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
+            Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+            - `"end"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and a
+              draft of a model card when the [`~Trainer.save_model`] method is called.
+            - `"every_save"`: push the model, its configuration, the processing class e.g. tokenizer (if passed along to the [`Trainer`]) and
+              a draft of a model card each time there is a model save. The pushes are asynchronous to not block
+              training, and in case the saves are very frequent, a new push is only attempted if the previous one is
+              finished. A last push is made with the final model at the end of training.
+            - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
+              last-checkpoint, allowing you to resume training easily with
+              `trainer.train(resume_from_checkpoint="last-checkpoint")`.
+            - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the output
+              folder (so you will get one checkpoint folder per folder in your final repository)
+
+        hub_token (`str`, *optional*):
+            The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
+            `hf auth login`.
+        hub_private_repo (`bool`, *optional*):
+            Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+        hub_always_push (`bool`, *optional*, defaults to `False`):
+            Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
+        hub_revision (`str`, *optional*):
+            The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
+        gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+            If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
+        gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+            Keyword arguments to be passed to the `gradient_checkpointing_enable` method.
+        include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
+            This argument is deprecated. Use `include_for_metrics` instead, e.g., `include_for_metrics = ["inputs"]`.
+        include_for_metrics (`list[str]`, *optional*, defaults to `[]`):
+            Include additional data in the `compute_metrics` function if needed for metrics computation.
+            Possible options to add to `include_for_metrics` list:
+            - `"inputs"`: Input data passed to the model, intended for calculating input dependent metrics.
+            - `"loss"`: Loss values computed during evaluation, intended for calculating loss dependent metrics.
+        eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+            Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
+            will instead store them as lists, with each batch kept separate.
+        auto_find_batch_size (`bool`, *optional*, defaults to `False`):
+            Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
+            CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
+        full_determinism (`bool`, *optional*, defaults to `False`):
+            If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+            distributed training. Important: this will negatively impact the performance, so only use it for debugging.
+        torchdynamo (`str`, *optional*):
+            If set, the backend compiler for TorchDynamo. Possible choices are `"eager"`, `"aot_eager"`, `"inductor"`,
+            `"nvfuser"`, `"aot_nvfuser"`, `"aot_cudagraphs"`, `"ofi"`, `"fx2trt"`, `"onnxrt"` and `"ipex"`.
+        ray_scope (`str`, *optional*, defaults to `"last"`):
+            The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+            then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+            are also available. See the [Ray documentation](
+            https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+            more options.
+        ddp_timeout (`int`, *optional*, defaults to 1800):
+            The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+            performing slow operations in distributed runs. Please refer to the [PyTorch
+            documentation](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for
+            more information.
+        use_mps_device (`bool`, *optional*, defaults to `False`):
+            This argument is deprecated. The `mps` device will be used if it is available, similar to the `cuda` device.
+        torch_compile (`bool`, *optional*, defaults to `False`):
+            Whether or not to compile the model using PyTorch 2.0
+            [`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
+
+            This will use the best defaults for the [`torch.compile`
+            API](https://pytorch.org/docs/stable/generated/torch.compile.html?highlight=torch+compile#torch.compile).
+            You can customize the defaults with the argument `torch_compile_backend` and `torch_compile_mode` but we
+            don't guarantee any of them will work as the support is progressively rolled out in PyTorch.
+
+            This flag and the whole compile API is experimental and subject to change in future releases.
+        torch_compile_backend (`str`, *optional*):
+            The backend to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
+        torch_compile_mode (`str`, *optional*):
+            The mode to use in `torch.compile`. If set to any value, `torch_compile` will be set to `True`.
+
+            Refer to the PyTorch doc for possible values and note that they may change across PyTorch versions.
+
+            This flag is experimental and subject to change in future releases.
+        include_tokens_per_second (`bool`, *optional*):
+            Whether or not to compute the number of tokens per second per device for training speed metrics.
+
+            This will iterate over the entire training dataloader once beforehand, and will slow down the entire
+            process.
+
+        include_num_input_tokens_seen (`bool`, *optional*):
+            Whether or not to track the number of input tokens seen throughout training.
+
+            May be slower in distributed training as gather operations must be called.
+
+        neftune_noise_alpha (`Optional[float]`):
+            If not `None`, this will activate NEFTune noise embeddings. This can drastically improve model performance
+            for instruction fine-tuning. Check out the [original paper](https://huggingface.co/papers/2310.05914) and the
+            [original code](https://github.com/neelsjain/NEFTune). Supports transformers `PreTrainedModel` and also
+            `PeftModel` from peft. The original paper used values in the range [5.0, 15.0].
+        optim_target_modules (`Union[str, list[str]]`, *optional*):
+            The target modules to optimize, i.e. the module names that you would like to train.
+            Currently used for the GaLore algorithm (https://huggingface.co/papers/2403.03507) and APOLLO algorithm (https://huggingface.co/papers/2412.05270).
+            See GaLore implementation (https://github.com/jiaweizzhao/GaLore) and APOLLO implementation (https://github.com/zhuhanqing/APOLLO) for more details.
+            You need to make sure to pass a valid GaLore or APOLLO optimizer, e.g., one of: "apollo_adamw", "galore_adamw", "galore_adamw_8bit", "galore_adafactor" and make sure that the target modules are `nn.Linear` modules only.
+
+        batch_eval_metrics (`Optional[bool]`, defaults to `False`):
+            If set to `True`, evaluation will call compute_metrics at the end of each batch to accumulate statistics
+            rather than saving all eval logits in memory. When set to `True`, you must pass a compute_metrics function
+            that takes a boolean argument `compute_result`, which when passed `True`, will trigger the final global
+            summary statistics from the batch-level summary statistics you've accumulated over the evaluation set.
+
+        eval_on_start (`bool`, *optional*, defaults to `False`):
+            Whether to perform an evaluation step (sanity check) before training to ensure the validation steps work correctly.
+
+        eval_use_gather_object (`bool`, *optional*, defaults to `False`):
+            Whether to recursively gather objects in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
+
+        use_liger_kernel (`bool`, *optional*, defaults to `False`):
+            Whether to enable the [Liger](https://github.com/linkedin/Liger-Kernel) kernel for LLM model training.
+            It can effectively increase multi-GPU training throughput by ~20% and reduce memory usage by ~60%; it works
+            out of the box with flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama,
+            mistral, mixtral and gemma models.
+
+        liger_kernel_config (`Optional[dict]`, *optional*):
+            Configuration to be used for Liger Kernel. When use_liger_kernel=True, this dict is passed as keyword arguments to the
+            `_apply_liger_kernel_to_instance` function, which specifies which kernels to apply. Available options vary by model but typically
+            include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', 'rms_norm', etc. If `None`, use the default kernel configurations.
+
+        average_tokens_across_devices (`bool`, *optional*, defaults to `True`):
+            Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize
+            num_tokens_in_batch for precise loss calculation. Reference:
+            https://github.com/huggingface/transformers/issues/34242
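+
+    Example of building the arguments in code and of parsing the same fields from the command line with
+    [`HfArgumentParser`]. This is a minimal, illustrative sketch; the hyperparameter values shown are arbitrary and
+    not recommendations:
+
+    ```python
+    >>> from transformers import HfArgumentParser, TrainingArguments
+
+    >>> # Build the arguments directly in code ...
+    >>> args = TrainingArguments(
+    ...     output_dir="my_model",
+    ...     per_device_train_batch_size=8,
+    ...     gradient_accumulation_steps=4,  # one optimizer step every 4 forward/backward passes
+    ...     eval_strategy="steps",
+    ...     eval_steps=500,
+    ...     save_strategy="steps",
+    ...     save_steps=500,  # must stay a round multiple of `eval_steps` when `load_best_model_at_end=True`
+    ...     load_best_model_at_end=True,
+    ... )
+
+    >>> # ... or parse the exact same fields from the command line.
+    >>> parser = HfArgumentParser(TrainingArguments)
+    >>> (cli_args,) = parser.parse_args_into_dataclasses(args=["--output_dir", "my_model"])
+    ```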
+    """
+
+    # Sometimes users will pass in a `str` repr of a dict in the CLI
+    # We need to track what fields those can be. Each time a new arg
+    # has a dict type, it must be added to this list.
+    # Important: These should be typed with Optional[Union[dict,str,...]]
+    _VALID_DICT_FIELDS = [
+        "accelerator_config",
+        "fsdp_config",
+        "deepspeed",
+        "gradient_checkpointing_kwargs",
+        "lr_scheduler_kwargs",
+    ]
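+    # Illustrative note: on the command line such a field typically arrives as a JSON-like string,
+    # e.g. --fsdp_config '{"min_num_params": 2000, "xla": "false"}' (a hypothetical invocation).
+    # During post-init the string is decoded into a dict and run through `_convert_str_dict` above,
+    # so the nested "false" ends up as the boolean False on the parsed arguments.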
+    framework = "pt"
+
+    output_dir: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The output directory where the model predictions and checkpoints will be written. Defaults to 'trainer_output' if not provided."
+        },
+    )
+    overwrite_output_dir: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Overwrite the content of the output directory. "
+                "Use this to continue training if output_dir points to a checkpoint directory."
+            )
+        },
+    )
+
+    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
+    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
+    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
+    eval_strategy: Union[IntervalStrategy, str] = field(
+        default="no",
+        metadata={"help": "The evaluation strategy to use."},
+    )
+    prediction_loss_only: bool = field(
+        default=False,
+        metadata={"help": "When performing evaluation and predictions, only returns the loss."},
+    )
+
+    per_device_train_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per device accelerator core/CPU for training."}
+    )
+    per_device_eval_batch_size: int = field(
+        default=8, metadata={"help": "Batch size per device accelerator core/CPU for evaluation."}
+    )
+
+    per_gpu_train_batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Deprecated, the use of `--per_device_train_batch_size` is preferred. "
+                "Batch size per GPU/TPU core/CPU for training."
+            )
+        },
+    )
+    per_gpu_eval_batch_size: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Deprecated, the use of `--per_device_eval_batch_size` is preferred. "
+                "Batch size per GPU/TPU core/CPU for evaluation."
+            )
+        },
+    )
+
+    gradient_accumulation_steps: int = field(
+        default=1,
+        metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
+    )
+    eval_accumulation_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
+    )
+
+    eval_delay: Optional[float] = field(
+        default=0,
+        metadata={
+            "help": (
+                "Number of epochs or steps to wait for before the first evaluation can be performed, depending on the"
+                " eval_strategy."
+            )
+        },
+    )
+
+    torch_empty_cache_steps: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Number of steps to wait before calling `torch..empty_cache()`."
+            "This can help avoid CUDA out-of-memory errors by lowering peak VRAM usage at a cost of about [10% slower performance](https://github.com/huggingface/transformers/issues/31372)."
+            "If left unset or set to None, cache will not be emptied."
+        },
+    )
+
+    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
+    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
+    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
+    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
+    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
+    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
+
+    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
+    max_steps: int = field(
+        default=-1,
+        metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
+    )
+    lr_scheduler_type: Union[SchedulerType, str] = field(
+        default="linear",
+        metadata={"help": "The scheduler type to use."},
+    )
+    lr_scheduler_kwargs: Optional[Union[dict[str, Any], str]] = field(
+        default_factory=dict,
+        metadata={
+            "help": (
+                "Extra parameters for the lr_scheduler such as {'num_cycles': 1} for the cosine with hard restarts."
+            )
+        },
+    )
+    warmup_ratio: float = field(
+        default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."}
+    )
+    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
+
+    log_level: str = field(
+        default="passive",
+        metadata={
+            "help": (
+                "Logger log level to use on the main node. Possible choices are the log levels as strings: 'debug',"
+                " 'info', 'warning', 'error' and 'critical', plus a 'passive' level which doesn't set anything and"
+                " lets the application set the level. Defaults to 'passive'."
+            ),
+            "choices": trainer_log_levels.keys(),
+        },
+    )
+    log_level_replica: str = field(
+        default="warning",
+        metadata={
+            "help": "Logger log level to use on replica nodes. Same choices and defaults as ``log_level``",
+            "choices": trainer_log_levels.keys(),
+        },
+    )
+    log_on_each_node: bool = field(
+        default=True,
+        metadata={
+            "help": (
+                "When doing a multinode distributed training, whether to log once per node or just once on the main"
+                " node."
+            )
+        },
+    )
+    logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
+    logging_strategy: Union[IntervalStrategy, str] = field(
+        default="steps",
+        metadata={"help": "The logging strategy to use."},
+    )
+    logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
+    logging_steps: float = field(
+        default=500,
+        metadata={
+            "help": (
+                "Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
+                "If smaller than 1, will be interpreted as ratio of total training steps."
+            )
+        },
+    )
+    logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
+    save_strategy: Union[SaveStrategy, str] = field(
+        default="steps",
+        metadata={"help": "The checkpoint save strategy to use."},
+    )
+    save_steps: float = field(
+        default=500,
+        metadata={
+            "help": (
+                "Save checkpoint every X updates steps. Should be an integer or a float in range `[0,1)`. "
+                "If smaller than 1, will be interpreted as ratio of total training steps."
+            )
+        },
+    )
+    save_total_limit: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in"
+                " `output_dir`. When `load_best_model_at_end` is enabled, the 'best' checkpoint according to"
+                " `metric_for_best_model` will always be retained in addition to the most recent ones. For example,"
+                " for `save_total_limit=5` and `load_best_model_at_end=True`, the four last checkpoints will always be"
+                " retained alongside the best model. When `save_total_limit=1` and `load_best_model_at_end=True`,"
+                " it is possible that two checkpoints are saved: the last one and the best one (if they are different)."
+                " Default is unlimited checkpoints"
+            )
+        },
+    )
+    save_safetensors: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Use safetensors saving and loading for state dicts instead of default torch.load and torch.save."
+        },
+    )
+    save_on_each_node: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
+                " only on the main one"
+            )
+        },
+    )
+    save_only_model: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When checkpointing, whether to only save the model, or also the optimizer, scheduler & rng state."
+                "Note that when this is true, you won't be able to resume training from checkpoint."
+                "This enables you to save storage by not storing the optimizer, scheduler & rng state."
+                "You can only load the model using from_pretrained with this option set to True."
+            )
+        },
+    )
+    restore_callback_states_from_checkpoint: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to restore the callback states from the checkpoint. If `True`, will override callbacks passed to the `Trainer` if they exist in the checkpoint."
+        },
+    )
+    no_cuda: bool = field(
+        default=False,
+        metadata={"help": "This argument is deprecated. It will be removed in version 5.0 of 🤗 Transformers."},
+    )
+    use_cpu: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to use cpu. If left to False, we will use the available torch device/backend (cuda/mps/xpu/hpu etc.)"
+        },
+    )
+    use_mps_device: bool = field(
+        default=False,
+        metadata={
+            "help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device."
+            " It will be removed in version 5.0 of 🤗 Transformers"
+        },
+    )
+    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
+    data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+    jit_mode_eval: bool = field(
+        default=False, metadata={"help": "Whether or not to use PyTorch jit trace for inference"}
+    )
+    use_ipex: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Use Intel extension for PyTorch when it is available, installation:"
+                " 'https://github.com/intel/intel-extension-for-pytorch'"
+            )
+        },
+    )
+    bf16: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA"
+                " architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change."
+            )
+        },
+    )
+    fp16: bool = field(
+        default=False,
+        metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
+    )
+    fp16_opt_level: str = field(
+        default="O1",
+        metadata={
+            "help": (
+                "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
+                "See details at https://nvidia.github.io/apex/amp.html"
+            )
+        },
+    )
+    half_precision_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "The backend to be used for half precision.",
+            "choices": ["auto", "apex", "cpu_amp"],
+        },
+    )
+    bf16_full_eval: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may"
+                " change."
+            )
+        },
+    )
+    fp16_full_eval: bool = field(
+        default=False,
+        metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
+    )
+    tf32: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental"
+                " API and it may change."
+            )
+        },
+    )
+    local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
+    ddp_backend: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The backend to be used for distributed training",
+            "choices": ["nccl", "gloo", "mpi", "ccl", "hccl", "cncl", "mccl"],
+        },
+    )
+    tpu_num_cores: Optional[int] = field(
+        default=None, metadata={"help": "TPU: Number of TPU cores (automatically passed by launcher script)"}
+    )
+    tpu_metrics_debug: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Deprecated, the use of `--debug tpu_metrics_debug` is preferred. TPU: Whether to print debug metrics"
+            )
+        },
+    )
+    debug: Union[str, list[DebugOption]] = field(
+        default="",
+        metadata={
+            "help": (
+                "Whether or not to enable debug mode. Current options: "
+                "`underflow_overflow` (Detect underflow and overflow in activations and weights), "
+                "`tpu_metrics_debug` (print debug metrics on TPU)."
+            )
+        },
+    )
+
+    dataloader_drop_last: bool = field(
+        default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
+    )
+    eval_steps: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Run an evaluation every X steps. Should be an integer or a float in range `[0,1)`. "
+                "If smaller than 1, will be interpreted as ratio of total training steps."
+            )
+        },
+    )
+    dataloader_num_workers: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded"
+                " in the main process."
+            )
+        },
+    )
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Number of batches loaded in advance by each worker. "
+                "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
+            )
+        },
+    )
+    past_index: int = field(
+        default=-1,
+        metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
+    )
+
+    run_name: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "An optional descriptor for the run. Notably used for trackio, wandb, mlflow comet and swanlab "
+                "logging."
+            )
+        },
+    )
+    disable_tqdm: Optional[bool] = field(
+        default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
+    )
+
+    remove_unused_columns: Optional[bool] = field(
+        default=True, metadata={"help": "Remove columns not required by the model when using an nlp.Dataset."}
+    )
+    label_names: Optional[list[str]] = field(
+        default=None, metadata={"help": "The list of keys in your dictionary of inputs that correspond to the labels."}
+    )
+    load_best_model_at_end: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to load the best model found during training at the end of training. When this option"
+                " is enabled, the best checkpoint will always be saved. See `save_total_limit` for more."
+            )
+        },
+    )
+    metric_for_best_model: Optional[str] = field(
+        default=None, metadata={"help": "The metric to use to compare two different models."}
+    )
+    greater_is_better: Optional[bool] = field(
+        default=None, metadata={"help": "Whether the `metric_for_best_model` should be maximized or not."}
+    )
+    ignore_data_skip: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "When resuming training, whether or not to skip the first epochs and batches to get to the same"
+                " training data."
+            )
+        },
+    )
+    fsdp: Optional[Union[list[FSDPOption], str]] = field(
+        default="",
+        metadata={
+            "help": (
+                "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"
+                " only). The base option should be `full_shard`, `shard_grad_op` or `no_shard` and you can add"
+                " CPU-offload to `full_shard` or `shard_grad_op` like this: full_shard offload` or `shard_grad_op"
+                " offload`. You can add auto-wrap to `full_shard` or `shard_grad_op` with the same syntax: full_shard"
+                " auto_wrap` or `shard_grad_op auto_wrap`."
+            ),
+        },
+    )
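+    # For example, `--fsdp "full_shard offload"` or `--fsdp "shard_grad_op auto_wrap"`, as
+    # described in the help text above.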
+    fsdp_min_num_params: int = field(
+        default=0,
+        metadata={
+            "help": (
+                "This parameter is deprecated. FSDP's minimum number of parameters for Default Auto Wrapping. (useful"
+                " only when `fsdp` field is passed)."
+            )
+        },
+    )
+    fsdp_config: Optional[Union[dict[str, Any], str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a "
+                "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`."
+            )
+        },
+    )
+    fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "This parameter is deprecated. Transformer layer class name (case-sensitive) to wrap, e.g,"
+                " `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed)."
+            )
+        },
+    )
+    accelerator_config: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Config to be used with the internal Accelerator object initialization. The value is either a "
+                "accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
+            )
+        },
+    )
+    parallelism_config: Optional["ParallelismConfig"] = field(
+        default=None,
+        metadata={"help": ("Parallelism configuration for the training run. Requires Accelerate `1.10.1`")},
+    )
+    deepspeed: Optional[Union[dict, str]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already"
+                " loaded json file as a dict"
+            )
+        },
+    )
+    label_smoothing_factor: float = field(
+        default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
+    )
+
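+    # Chosen at class-definition time: prefer the fused AdamW implementation when the installed
+    # torch version is recent enough (>= 2.8 here), otherwise fall back to plain `adamw_torch`.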
+    default_optim = "adamw_torch"
+    if is_torch_available():
+        from .pytorch_utils import is_torch_greater_or_equal_than_2_8
+
+        if is_torch_greater_or_equal_than_2_8:
+            default_optim = "adamw_torch_fused"
+    optim: Union[OptimizerNames, str] = field(
+        default=default_optim,
+        metadata={"help": "The optimizer to use."},
+    )
+    optim_args: Optional[str] = field(default=None, metadata={"help": "Optional arguments to supply to optimizer."})
+    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
+    group_by_length: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to group samples of roughly the same length together when batching."},
+    )
+    length_column_name: Optional[str] = field(
+        default="length",
+        metadata={"help": "Column name with precomputed lengths to use when grouping by length."},
+    )
+    report_to: Union[None, str, list[str]] = field(
+        default=None, metadata={"help": "The list of integrations to report the results and logs to."}
+    )
+    ddp_find_unused_parameters: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "When using distributed training, the value of the flag `find_unused_parameters` passed to "
+                "`DistributedDataParallel`."
+            )
+        },
+    )
+    ddp_bucket_cap_mb: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "When using distributed training, the value of the flag `bucket_cap_mb` passed to "
+                "`DistributedDataParallel`."
+            )
+        },
+    )
+    ddp_broadcast_buffers: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": (
+                "When using distributed training, the value of the flag `broadcast_buffers` passed to "
+                "`DistributedDataParallel`."
+            )
+        },
+    )
+    dataloader_pin_memory: bool = field(
+        default=True, metadata={"help": "Whether or not to pin memory for DataLoader."}
+    )
+    dataloader_persistent_workers: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. Can potentially speed up training, but will increase RAM usage."
+        },
+    )
+    skip_memory_metrics: bool = field(
+        default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
+    )
+    use_legacy_prediction_loop: bool = field(
+        default=False, metadata={"help": "Whether or not to use the legacy prediction_loop in the Trainer."}
+    )
+    push_to_hub: bool = field(
+        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
+    )
+    resume_from_checkpoint: Optional[str] = field(
+        default=None,
+        metadata={"help": "The path to a folder with a valid checkpoint for your model."},
+    )
+    hub_model_id: Optional[str] = field(
+        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
+    )
+    hub_strategy: Union[HubStrategy, str] = field(
+        default="every_save",
+        metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
+    )
+    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    hub_private_repo: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists."
+        },
+    )
+    hub_always_push: bool = field(
+        default=False,
+        metadata={"help": "Unless `True`, the Trainer will skip pushes if the previous one wasn't finished yet."},
+    )
+    hub_revision: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash."
+        },
+    )
+    gradient_checkpointing: bool = field(
+        default=False,
+        metadata={
+            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+        },
+    )
+    gradient_checkpointing_kwargs: Optional[Union[dict[str, Any], str]] = field(
+        default=None,
+        metadata={
+            "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
+        },
+    )
+    include_inputs_for_metrics: bool = field(
+        default=False,
+        metadata={
+            "help": "This argument is deprecated and will be removed in version 5 of 🤗 Transformers. Use `include_for_metrics` instead."
+        },
+    )
+    include_for_metrics: list[str] = field(
+        default_factory=list,
+        metadata={
+            "help": "List of strings to specify additional data to include in the `compute_metrics` function."
+            "Options: 'inputs', 'loss'."
+        },
+    )
+    eval_do_concat_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
+        },
+    )
+    # Deprecated arguments
+    fp16_backend: str = field(
+        default="auto",
+        metadata={
+            "help": "Deprecated. Use half_precision_backend instead",
+            "choices": ["auto", "apex", "cpu_amp"],
+        },
+    )
+    push_to_hub_model_id: Optional[str] = field(
+        default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
+    )
+    push_to_hub_organization: Optional[str] = field(
+        default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
+    )
+    push_to_hub_token: Optional[str] = field(
+        default=None, metadata={"help": "The token to use to push to the Model Hub."}
+    )
+    _n_gpu: int = field(init=False, repr=False, default=-1)
+    mp_parameters: str = field(
+        default="",
+        metadata={"help": "Used by the SageMaker launcher to send mp-specific args. Ignored in Trainer"},
+    )
+
+    auto_find_batch_size: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to automatically decrease the batch size in half and rerun the training loop again each time"
+                " a CUDA Out-of-Memory was reached"
+            )
+        },
+    )
+    full_determinism: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed"
+                " training. Important: this will negatively impact the performance, so only use it for debugging."
+            )
+        },
+    )
+    torchdynamo: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "This argument is deprecated, use `--torch_compile_backend` instead.",
+        },
+    )
+    ray_scope: Optional[str] = field(
+        default="last",
+        metadata={
+            "help": (
+                'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+                " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+                " other options are also available. See the Ray documentation"
+                " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+                "#ray.tune.ExperimentAnalysis.get_best_trial)"
+                " for more options."
+            )
+        },
+    )
+    ddp_timeout: int = field(
+        default=1800,
+        metadata={
+            "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+        },
+    )
+    torch_compile: bool = field(
+        default=False, metadata={"help": "If set to `True`, the model will be wrapped in `torch.compile`."}
+    )
+    torch_compile_backend: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Which backend to use with `torch.compile`, passing one will trigger a model compilation.",
+        },
+    )
+    torch_compile_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Which mode to use with `torch.compile`, passing one will trigger a model compilation.",
+        },
+    )
+
+    include_tokens_per_second: Optional[bool] = field(
+        default=False,
+        metadata={"help": "If set to `True`, the speed metrics will include `tgs` (tokens per second per device)."},
+    )
+
+    include_num_input_tokens_seen: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "If set to `True`, will track the number of input tokens seen throughout training. (May be slower in distributed training)"
+        },
+    )
+
+    neftune_noise_alpha: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Activates neftune noise embeddings into the model. NEFTune has been proven to drastically improve model performances for instruction fine-tuning. Check out the original paper here: https://huggingface.co/papers/2310.05914 and the original code here: https://github.com/neelsjain/NEFTune. Only supported for `PreTrainedModel` and `PeftModel` classes."
+        },
+    )
+
+    optim_target_modules: Union[None, str, list[str]] = field(
+        default=None,
+        metadata={
+            "help": "Target modules for the optimizer defined in the `optim` argument. Only used for the GaLore optimizer at the moment."
+        },
+    )
+
+    batch_eval_metrics: bool = field(
+        default=False,
+        metadata={"help": "Break eval metrics calculation into batches to save memory."},
+    )
+
+    eval_on_start: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to run through the entire `evaluation` step at the very beginning of training as a sanity check."
+        },
+    )
+
+    use_liger_kernel: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
+    )
+
+    liger_kernel_config: Optional[dict[str, bool]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Configuration to be used for Liger Kernel. When use_liger_kernel=True, "
+                "this dict is passed as keyword arguments to the `_apply_liger_kernel_to_instance` function, "
+                "which specifies which kernels to apply. Available options vary by model "
+                "but typically include: 'rope', 'swiglu', 'cross_entropy', 'fused_linear_cross_entropy', "
+                "'rms_norm', etc. If None, use the default kernel configurations."
+            )
+        },
+    )
+
+    eval_use_gather_object: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices."
+        },
+    )
+
+    average_tokens_across_devices: Optional[bool] = field(
+        default=True,
+        metadata={
+            "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to "
+            "synchronize num_tokens_in_batch for precise loss calculation. Reference: "
+            "https://github.com/huggingface/transformers/issues/34242"
+        },
+    )
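+    # A minimal usage sketch (illustrative values only; most arguments keep their defaults):
+    #   args = TrainingArguments(output_dir="out", eval_strategy="steps", eval_steps=500,
+    #                            save_steps=500, bf16=True, report_to="none")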
+
+    def __post_init__(self):
+        # Set default output_dir if not provided
+        if self.output_dir is None:
+            self.output_dir = "trainer_output"
+            logger.info(
+                "No output directory specified, defaulting to 'trainer_output'. "
+                "To change this behavior, specify --output_dir when creating TrainingArguments."
+            )
+
+        # Parse args that may arrive from the CLI as strings but should be `dict`s
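+        # e.g. `--accelerator_config '{"split_batches": true}'` arrives here as a JSON string and is
+        # converted to a real dict (illustrative; this applies to any field listed in `_VALID_DICT_FIELDS`).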
+        for field in self._VALID_DICT_FIELDS:
+            passed_value = getattr(self, field)
+            # We only want to do this if the str starts with a bracket to indicate a `dict`
+            # else it's likely a filename, if supported
+            if isinstance(passed_value, str) and passed_value.startswith("{"):
+                loaded_dict = json.loads(passed_value)
+                # Convert str values to types if applicable
+                loaded_dict = _convert_str_dict(loaded_dict)
+                setattr(self, field, loaded_dict)
+
+        # Expand paths; otherwise os.makedirs("~/bar") would create the directory
+        # in the current working directory instead of the actual home directory
+        # see https://github.com/huggingface/transformers/issues/10628
+        if self.output_dir is not None:
+            self.output_dir = os.path.expanduser(self.output_dir)
+        if self.logging_dir is None and self.output_dir is not None:
+            self.logging_dir = os.path.join(self.output_dir, default_logdir())
+        if self.logging_dir is not None:
+            self.logging_dir = os.path.expanduser(self.logging_dir)
+
+        if self.disable_tqdm is None:
+            self.disable_tqdm = logger.getEffectiveLevel() > logging.WARN
+
+        if isinstance(self.eval_strategy, EvaluationStrategy):
+            warnings.warn(
+                "using `EvaluationStrategy` for `eval_strategy` is deprecated and will be removed in version 5"
+                " of 🤗 Transformers. Use `IntervalStrategy` instead",
+                FutureWarning,
+            )
+            # Go back to the underlying string or we won't be able to instantiate `IntervalStrategy` on it.
+            self.eval_strategy = self.eval_strategy.value
+        if self.no_cuda:
+            warnings.warn(
+                "using `no_cuda` is deprecated and will be removed in version 5.0 of 🤗 Transformers. "
+                "Use `use_cpu` instead",
+                FutureWarning,
+            )
+            self.use_cpu = self.no_cuda
+        if self.use_ipex:
+            warnings.warn(
+                "using `use_ipex` is deprecated and will be removed in version 4.54 of 🤗 Transformers. "
+                "You only need PyTorch for the needed optimizations on Intel CPU and XPU.",
+                FutureWarning,
+            )
+
+        self.eval_strategy = IntervalStrategy(self.eval_strategy)
+        self.logging_strategy = IntervalStrategy(self.logging_strategy)
+        self.save_strategy = SaveStrategy(self.save_strategy)
+        self.hub_strategy = HubStrategy(self.hub_strategy)
+
+        self.lr_scheduler_type = SchedulerType(self.lr_scheduler_type)
+        if self.do_eval is False and self.eval_strategy != IntervalStrategy.NO:
+            self.do_eval = True
+
+        if self.torch_empty_cache_steps is not None:
+            if not (isinstance(self.torch_empty_cache_steps, int) and self.torch_empty_cache_steps > 0):
+                raise ValueError(
+                    f"`torch_empty_cache_steps` must be an integer bigger than 0, got {self.torch_empty_cache_steps}."
+                )
+
+        # eval_steps has to be defined and non-zero; fall back to logging_steps if the latter is non-zero
+        if self.eval_strategy == IntervalStrategy.STEPS and (self.eval_steps is None or self.eval_steps == 0):
+            if self.logging_steps > 0:
+                logger.info(f"using `logging_steps` to initialize `eval_steps` to {self.logging_steps}")
+                self.eval_steps = self.logging_steps
+            else:
+                raise ValueError(
+                    f"evaluation strategy {self.eval_strategy} requires either non-zero --eval_steps or"
+                    " --logging_steps"
+                )
+
+        # logging_steps must be non-zero for logging_strategy that is other than 'no'
+        if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps == 0:
+            raise ValueError(f"logging strategy {self.logging_strategy} requires non-zero --logging_steps")
+
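+        # Step counts given as floats in `[0,1)` are ratios of the total training steps and stay as
+        # floats; anything greater than 1 must be a whole number of steps, so coerce it to int here.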
+        if self.logging_strategy == IntervalStrategy.STEPS and self.logging_steps > 1:
+            if self.logging_steps != int(self.logging_steps):
+                raise ValueError(f"--logging_steps must be an integer if bigger than 1: {self.logging_steps}")
+            self.logging_steps = int(self.logging_steps)
+        if self.eval_strategy == IntervalStrategy.STEPS and self.eval_steps > 1:
+            if self.eval_steps != int(self.eval_steps):
+                raise ValueError(f"--eval_steps must be an integer if bigger than 1: {self.eval_steps}")
+            self.eval_steps = int(self.eval_steps)
+        if self.save_strategy == SaveStrategy.STEPS and self.save_steps > 1:
+            if self.save_steps != int(self.save_steps):
+                raise ValueError(f"--save_steps must be an integer if bigger than 1: {self.save_steps}")
+            self.save_steps = int(self.save_steps)
+
+        # Sanity checks for load_best_model_at_end: we require save and eval strategies to be compatible.
+        if self.load_best_model_at_end and self.save_strategy != SaveStrategy.BEST:
+            if self.eval_strategy != self.save_strategy:
+                raise ValueError(
+                    "--load_best_model_at_end requires the save and eval strategy to match, but found\n- Evaluation "
+                    f"strategy: {self.eval_strategy}\n- Save strategy: {self.save_strategy}"
+                )
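+            # With step-based strategies, every save step must also be an evaluation step; ratio-valued
+            # and absolute step counts cannot be mixed when checking this.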
+            if self.eval_strategy == IntervalStrategy.STEPS and self.save_steps % self.eval_steps != 0:
+                if self.eval_steps < 1 or self.save_steps < 1:
+                    if not (self.eval_steps < 1 and self.save_steps < 1):
+                        raise ValueError(
+                            "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+                            "steps, which cannot get guaranteed when mixing ratio and absolute steps for save_steps "
+                            f"{self.save_steps} and eval_steps {self.eval_steps}."
+                        )
+                    # Work around floating point precision issues
+                    LARGE_MULTIPLIER = 1_000_000
+                    if (self.save_steps * LARGE_MULTIPLIER) % (self.eval_steps * LARGE_MULTIPLIER) != 0:
+                        raise ValueError(
+                            "--load_best_model_at_end requires the saving steps to be a multiple of the evaluation "
+                            f"steps, but found {self.save_steps}, which is not a multiple of {self.eval_steps}."
+                        )
+                else:
+                    raise ValueError(
+                        "--load_best_model_at_end requires the saving steps to be a round multiple of the evaluation "
+                        f"steps, but found {self.save_steps}, which is not a round multiple of {self.eval_steps}."
+                    )
+
+        safetensors_available = is_safetensors_available()
+        if self.save_safetensors and not safetensors_available:
+            raise ValueError(f"--save_safetensors={self.save_safetensors} requires safetensors to be installed!")
+        if not self.save_safetensors and safetensors_available:
+            logger.info(
+                f"Found safetensors installation, but --save_safetensors={self.save_safetensors}. "
+                f"Safetensors should be a preferred weights saving format due to security and performance reasons. "
+                f"If your model cannot be saved by safetensors please feel free to open an issue at "
+                f"https://github.com/huggingface/safetensors!"
+            )
+
+        if (
+            self.load_best_model_at_end or self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU
+        ) and self.metric_for_best_model is None:
+            self.metric_for_best_model = "loss"
+        if self.greater_is_better is None and self.metric_for_best_model is not None:
+            self.greater_is_better = not self.metric_for_best_model.endswith("loss")
+        if self.framework == "pt" and is_torch_available():
+            if self.fp16_backend and self.fp16_backend != "auto":
+                warnings.warn(
+                    "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+                    " `half_precision_backend` instead",
+                    FutureWarning,
+                )
+                self.half_precision_backend = self.fp16_backend
+
+            if self.bf16 or self.bf16_full_eval:
+                if self.use_cpu and not is_torch_available() and not is_torch_xla_available():
+                    # cpu
+                    raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
+                elif not self.use_cpu:
+                    if not is_torch_bf16_gpu_available() and not is_torch_xla_available():  # added for tpu support
+                        error_message = "Your setup doesn't support bf16/gpu."
+                        if is_torch_cuda_available():
+                            error_message += " You need Ampere+ GPU with cuda>=11.0"
+                        # gpu
+                        raise ValueError(error_message)
+
+        if self.fp16 and self.bf16:
+            raise ValueError("At most one of fp16 and bf16 can be True, but not both")
+
+        if self.fp16_full_eval and self.bf16_full_eval:
+            raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
+
+        if self.bf16:
+            if self.half_precision_backend == "apex":
+                raise ValueError(" `--half_precision_backend apex`: GPU bf16 is not supported by apex.")
+
+        if self.half_precision_backend == "apex":
+            if not is_apex_available():
+                raise ImportError(
+                    "Using FP16 with APEX but APEX is not installed, please refer to"
+                    " https://www.github.com/nvidia/apex."
+                )
+            try:
+                from apex import amp  # noqa: F401
+            except ImportError as e:
+                raise ImportError(
+                    f"apex.amp is deprecated in the latest version of apex, causing this error {e}. Either revert to an older version or use pytorch amp by setting half_precision_backend='auto' instead. See https://github.com/NVIDIA/apex/pull/1896 "
+                )
+
+        if self.lr_scheduler_type == SchedulerType.REDUCE_ON_PLATEAU:
+            if self.eval_strategy == IntervalStrategy.NO:
+                raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires an eval strategy")
+            if not is_torch_available():
+                raise ValueError("lr_scheduler_type reduce_lr_on_plateau requires torch>=0.2.0")
+
+        self.optim = OptimizerNames(self.optim)
+        if self.adafactor:
+            warnings.warn(
+                "`--adafactor` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--optim"
+                " adafactor` instead",
+                FutureWarning,
+            )
+            self.optim = OptimizerNames.ADAFACTOR
+
+        # We need to setup the accelerator config here *before* the first call to `self.device`
+        if is_accelerate_available():
+            if not isinstance(self.accelerator_config, AcceleratorConfig):
+                if self.accelerator_config is None:
+                    self.accelerator_config = AcceleratorConfig()
+                elif isinstance(self.accelerator_config, dict):
+                    self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
+                # Check that a user didn't pass in the class instantiator
+                # such as `accelerator_config = AcceleratorConfig`
+                elif isinstance(self.accelerator_config, type):
+                    raise NotImplementedError(
+                        "Tried passing in a callable to `accelerator_config`, but this is not supported. "
+                        "Please pass in a fully constructed `AcceleratorConfig` object instead."
+                    )
+                else:
+                    self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
+            if self.accelerator_config.split_batches:
+                logger.info(
+                    "Using `split_batches=True` in `accelerator_config` will override the `per_device_train_batch_size` "
+                    "Batches will be split across all processes equally when using `split_batches=True`."
+                )
+
+        # Initialize device before we proceed
+        if self.framework == "pt" and is_torch_available():
+            self.device
+
+        # Disable average tokens when using single device
+        if self.average_tokens_across_devices:
+            try:
+                if self.world_size == 1:
+                    logger.info(
+                        "average_tokens_across_devices is True but world size is 1. Setting it to False automatically."
+                    )
+                    self.average_tokens_across_devices = False
+            except ImportError as e:
+                logger.warning(f"Can not specify world size due to {e}. Turn average_tokens_across_devices to False.")
+                self.average_tokens_across_devices = False
+
+        if self.torchdynamo is not None:
+            warnings.warn(
+                "`torchdynamo` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+                " `torch_compile_backend` instead",
+                FutureWarning,
+            )
+            self.torch_compile_backend = self.torchdynamo
+        if (self.torch_compile_mode is not None or self.torch_compile_backend is not None) and not self.torch_compile:
+            self.torch_compile = True
+        if self.torch_compile and self.torch_compile_backend is None:
+            if not self.use_cpu and is_torch_hpu_available():
+                self.torch_compile_backend = "hpu_backend"
+            else:
+                self.torch_compile_backend = "inductor"
+
+        # accelerate integration for torch compile
+        if self.torch_compile:
+            # set env vars for accelerate
+            prefix = "ACCELERATE_DYNAMO_"
+            os.environ[prefix + "BACKEND"] = self.torch_compile_backend
+            if self.torch_compile_mode is not None:
+                os.environ[prefix + "MODE"] = self.torch_compile_mode
+
+        if self.framework == "pt" and is_torch_available() and self.torch_compile:
+            if is_torch_tf32_available():
+                if self.tf32 is None and not self.fp16 or self.bf16:
+                    logger.info(
+                        "Setting TF32 in CUDA backends to speedup torch compile, you won't see any improvement"
+                        " otherwise."
+                    )
+                    torch.backends.cuda.matmul.allow_tf32 = True
+                    torch.backends.cudnn.allow_tf32 = True
+            else:
+                logger.warning(
+                    "The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here."
+                )
+        if self.framework == "pt" and is_torch_available() and self.tf32 is not None:
+            if self.tf32:
+                if is_torch_tf32_available():
+                    torch.backends.cuda.matmul.allow_tf32 = True
+                    torch.backends.cudnn.allow_tf32 = True
+                else:
+                    raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
+            else:
+                if is_torch_tf32_available():
+                    torch.backends.cuda.matmul.allow_tf32 = False
+                    torch.backends.cudnn.allow_tf32 = False
+                # no need to assert on else
+
+        # if training args is specified, it will override the one specified in the accelerate config
+        if self.half_precision_backend != "apex":
+            mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+            if self.fp16:
+                mixed_precision_dtype = "fp16"
+            elif self.bf16:
+                mixed_precision_dtype = "bf16"
+            os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
+
+        if self.report_to is None:
+            logger.info(
+                "The default value for the training argument `--report_to` will change in v5 (from all installed "
+                "integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as "
+                "now. You should start updating your code and make this info disappear :-)."
+            )
+            self.report_to = "all"
+        if self.report_to == "all" or self.report_to == ["all"]:
+            # Import at runtime to avoid a circular import.
+            from .integrations import get_available_reporting_integrations
+
+            self.report_to = get_available_reporting_integrations()
+
+            if "codecarbon" in self.report_to and torch.version.hip:
+                logger.warning(
+                    "When using the Trainer, CodeCarbonCallback requires the `codecarbon` package, which is not compatible with AMD ROCm (https://github.com/mlco2/codecarbon/pull/490). Automatically disabling the codecarbon callback. Reference: https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.report_to."
+                )
+                self.report_to.remove("codecarbon")
+
+        elif self.report_to == "none" or self.report_to == ["none"]:
+            self.report_to = []
+        elif not isinstance(self.report_to, list):
+            self.report_to = [self.report_to]
+
+        if self.warmup_ratio < 0 or self.warmup_ratio > 1:
+            raise ValueError("warmup_ratio must lie in range [0,1]")
+        elif self.warmup_ratio > 0 and self.warmup_steps > 0:
+            logger.info(
+                "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio"
+                " during training"
+            )
+
+        if not isinstance(self.warmup_steps, int) or self.warmup_steps < 0:
+            raise ValueError("warmup_steps must be of type int and must be 0 or a positive integer.")
+
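+        # Normalize `fsdp` into a list of `FSDPOption` values, whether it was passed as a bool,
+        # a space-separated string, or a list.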
+        if isinstance(self.fsdp, bool):
+            self.fsdp = [FSDPOption.FULL_SHARD] if self.fsdp else ""
+        if isinstance(self.fsdp, str):
+            self.fsdp = [FSDPOption(s) for s in self.fsdp.split()]
+        if self.fsdp == [FSDPOption.OFFLOAD]:
+            raise ValueError(
+                "`--fsdp offload` can't work on its own. It needs to be added to `--fsdp full_shard` or "
+                '`--fsdp shard_grad_op`. For example, `--fsdp "full_shard offload"`.'
+            )
+        elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp:
+            raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.")
+
+        if self.gradient_checkpointing and (
+            FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
+        ):
+            logger.warning(
+                "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
+                " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
+                " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
+            )
+
+        if self.fsdp_config is None:
+            self.fsdp_config = {}
+
+        if isinstance(self.fsdp_config, str):
+            if len(self.fsdp) == 0:
+                warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
+            with open(self.fsdp_config, encoding="utf-8") as f:
+                self.fsdp_config = json.load(f)
+
+        if self.fsdp_config is not None and isinstance(self.fsdp_config, dict):
+            for k in list(self.fsdp_config.keys()):
+                if k.startswith("fsdp_"):
+                    v = self.fsdp_config.pop(k)
+                    self.fsdp_config[k[5:]] = v
+
+        if self.fsdp_min_num_params > 0:
+            warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)
+
+        self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)
+
+        # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
+        if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]
+
+        if self.fsdp_transformer_layer_cls_to_wrap is not None:
+            warnings.warn(
+                "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
+            )
+            self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
+                "transformer_layer_cls_to_wrap", []
+            ) + [self.fsdp_transformer_layer_cls_to_wrap]
+
+        if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
+            warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")
+
+        if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+            warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
+
+        if (
+            len(self.fsdp) > 0
+            and self.fsdp_config["min_num_params"] > 0
+            and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
+        ):
+            raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
+        self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
+        self.fsdp_config["xla_fsdp_v2"] = self.fsdp_config.get("xla_fsdp_v2", False)
+        self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
+        if self.fsdp_config["xla"]:
+            if len(self.fsdp) > 0:
+                # store XLA fsdp configuration parameters into a dictionary
+                # Copy the config to avoid modifying the original config (which may be used for JSON serialization)
+                self.xla_fsdp_config = self.fsdp_config.get("xla_fsdp_settings", {}).copy()
+                # apply appropriate string to torch.dtype conversions for parameters
+                if "compute_dtype" in self.xla_fsdp_config:
+                    self.xla_fsdp_config["compute_dtype"] = getattr(torch, self.xla_fsdp_config["compute_dtype"])
+                if "buffer_dtype" in self.xla_fsdp_config:
+                    self.xla_fsdp_config["buffer_dtype"] = getattr(torch, self.xla_fsdp_config["buffer_dtype"])
+            else:
+                warnings.warn("XLA FSDP can be used only when `--fsdp` is specified.")
+        else:
+            if self.fsdp_config["xla_fsdp_grad_ckpt"]:
+                warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")
+
+        # accelerate integration for FSDP
+        if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
+            os.environ["ACCELERATE_USE_FSDP"] = "true"
+            from accelerate.utils.constants import (
+                FSDP_AUTO_WRAP_POLICY,
+                FSDP_SHARDING_STRATEGY,
+            )
+
+            prefix = "FSDP_"
+            for fsdp_option in self.fsdp:
+                if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
+                    # set environment variable for FSDP sharding strategy
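+                    # (torch's ShardingStrategy enum is 1-indexed, hence the `+ 1` below)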
+                    os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
+                        FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
+                    )
+                elif fsdp_option == FSDPOption.OFFLOAD:
+                    os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
+                elif fsdp_option == FSDPOption.AUTO_WRAP:
+                    os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
+                    if self.fsdp_config["min_num_params"] > 0:
+                        os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
+                        os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
+                    elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
+                        os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
+                            self.fsdp_config["transformer_layer_cls_to_wrap"]
+                        )
+            prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
+            os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()
+
+            sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
+            cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()
+
+            if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
+                # In this case, all the processes except the main process would have random weights leading
+                # to unexpected behaviour during training, thus throwing error here to prevent it.
+                raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
+
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
+            os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
+
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()
+
+        if self.tpu_metrics_debug:
+            warnings.warn(
+                "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
+                " `--debug tpu_metrics_debug` instead",
+                FutureWarning,
+            )
+            if self.debug is None:
+                self.debug = " tpu_metrics_debug"
+            else:
+                self.debug += " tpu_metrics_debug"
+            self.tpu_metrics_debug = False
+
+        if isinstance(self.debug, str):
+            self.debug = [DebugOption(s) for s in self.debug.split()]
+        elif self.debug is None:
+            self.debug = []
+
+        self.deepspeed_plugin = None
+        if self.deepspeed:
+            # - must be run very last in arg parsing, since it will use a lot of these settings.
+            # - must be run before the model is created.
+            if not is_accelerate_available():
+                raise ValueError(
+                    f"--deepspeed requires Accelerate to be installed: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`."
+                )
+            from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+            # will be used later by the Trainer
+            # note: leave self.deepspeed unmodified in case a user relies on it not to be modified
+            self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
+            self.hf_deepspeed_config.trainer_config_process(self)
+
+            # Accelerate DeepSpeed Plugin
+            from accelerate.utils import DeepSpeedPlugin
+
+            os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+            self.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.hf_deepspeed_config)
+        elif strtobool(os.environ.get("ACCELERATE_USE_DEEPSPEED", "false")):
+            # Accelerate DeepSpeed Plugin
+            from accelerate.utils import DeepSpeedPlugin
+
+            self.deepspeed_plugin = DeepSpeedPlugin()
+            mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
+            self.deepspeed_plugin.set_mixed_precision(mixed_precision)
+            self.deepspeed_plugin.set_deepspeed_weakref()
+
+        if self.use_cpu:
+            self.dataloader_pin_memory = False
+
+        if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
+            raise ValueError(
+                "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
+                " when --dataloader_num_workers > 1."
+            )
+
+        if self.push_to_hub_token is not None:
+            warnings.warn(
+                "`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+                "`--hub_token` instead.",
+                FutureWarning,
+            )
+            self.hub_token = self.push_to_hub_token
+
+        if self.push_to_hub_model_id is not None:
+            self.hub_model_id = get_full_repo_name(
+                self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
+            )
+            if self.push_to_hub_organization is not None:
+                warnings.warn(
+                    "`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
+                    "version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
+                    f"argument (in this case {self.hub_model_id}).",
+                    FutureWarning,
+                )
+            else:
+                warnings.warn(
+                    "`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+                    "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
+                    f"{self.hub_model_id}).",
+                    FutureWarning,
+                )
+        elif self.push_to_hub_organization is not None:
+            self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
+            warnings.warn(
+                "`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
+                "`--hub_model_id` instead and pass the full repo name to this argument (in this case "
+                f"{self.hub_model_id}).",
+                FutureWarning,
+            )
+
+        if self.eval_use_gather_object and not is_accelerate_available("0.30.0"):
+            raise ValueError(
+                "--eval_use_gather_object requires Accelerate to be version of `accelerate` > 0.30.0."
+                "This is not supported and we recommend you to update your version."
+            )
+
+        if self.data_seed is not None:
+            if not is_accelerate_available("1.1.0"):
+                raise NotImplementedError(
+                    "data_seed requires Accelerate version `accelerate` >= 1.1.0. "
+                    "This is not supported and we recommend you to update your version."
+                )
+
+        if self.include_inputs_for_metrics:
+            logger.warning(
+                "Using `include_inputs_for_metrics` is deprecated and will be removed in version 5 of 🤗 Transformers. Please use `include_for_metrics` list argument instead."
+            )
+            self.include_for_metrics.append("inputs")
+
+    def __str__(self):
+        self_as_dict = asdict(self)
+
+        # Remove deprecated arguments. That code should be removed once
+        # those deprecated arguments are removed from TrainingArguments. (TODO: v5)
+        del self_as_dict["per_gpu_train_batch_size"]
+        del self_as_dict["per_gpu_eval_batch_size"]
+
+        self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
+
+        attrs_as_str = [f"{k}={v},\n" for k, v in sorted(self_as_dict.items())]
+        return f"{self.__class__.__name__}(\n{''.join(attrs_as_str)})"
+
+    __repr__ = __str__
+
+    @property
+    def train_batch_size(self) -> int:
+        """
+        The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
+        """
+        if self.per_gpu_train_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_train_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
+        train_batch_size = per_device_batch_size * max(1, self.n_gpu)
+        return train_batch_size
+
+    @property
+    def eval_batch_size(self) -> int:
+        """
+        The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
+        """
+        if self.per_gpu_eval_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_eval_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+        eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
+        return eval_batch_size
+
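+    # Illustrative sketch of how the two properties above behave, assuming a
+    # hypothetical 4-GPU, non-distributed setup (`nn.DataParallel`):
+    #
+    #     args = TrainingArguments("out", per_device_train_batch_size=8)
+    #     # with self._n_gpu == 4, train_batch_size == 8 * max(1, 4) == 32
+    #
+    # Gradient accumulation is not folded in here; the `Trainer` applies
+    # `gradient_accumulation_steps` on top of this value.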
+    @property
+    def ddp_timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.ddp_timeout)
+
+    @cached_property
+    def _setup_devices(self) -> "torch.device":
+        requires_backends(self, ["torch"])
+        logger.info("PyTorch: setting up devices")
+        if not is_sagemaker_mp_enabled():
+            if not is_accelerate_available():
+                raise ImportError(
+                    f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
+                    f"Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
+                )
+        # We delay the init of `PartialState` to the end for clarity
+        accelerator_state_kwargs: dict[str, Any] = {"enabled": True, "use_configured_state": False}
+        if isinstance(self.accelerator_config, AcceleratorConfig):
+            accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
+                "use_configured_state", False
+            )
+        if accelerator_state_kwargs["use_configured_state"]:
+            if PartialState._shared_state == {}:
+                raise ValueError(
+                    "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
+                    "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
+                )
+            # We rely on `PartialState` to yell if there's issues here (which it will)
+            self.distributed_state = PartialState(cpu=self.use_cpu)
+            if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
+                raise RuntimeError(
+                    "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
+                    "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
+                    "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
+                )
+        else:
+            AcceleratorState._reset_state(reset_partial_state=True)
+            self.distributed_state = None
+        if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
+            os.environ["ACCELERATE_USE_IPEX"] = "false"
+
+        self._n_gpu = 1
+        if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
+            accelerator_state_kwargs["cpu"] = True
+            accelerator_state_kwargs["backend"] = self.ddp_backend
+            self._n_gpu = 0
+        elif is_sagemaker_mp_enabled():
+            accelerator_state_kwargs["enabled"] = False
+            local_rank = smp.local_rank()
+            device = torch.device("cuda", local_rank)
+            torch.cuda.set_device(device)
+        elif is_sagemaker_dp_enabled():
+            accelerator_state_kwargs["_use_sagemaker_dp"] = True
+        elif self.deepspeed:
+            accelerator_state_kwargs["use_deepspeed"] = True
+            accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
+        else:
+            accelerator_state_kwargs["backend"] = self.ddp_backend
+            accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
+
+        # Now we pop everything
+        if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
+            "use_configured_state", False
+        ):
+            # We need to patch this env var when enabling to detect deepspeed
+            use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
+            if use_deepspeed:
+                os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+            self.distributed_state = PartialState(**accelerator_state_kwargs)
+            if use_deepspeed:
+                del os.environ["ACCELERATE_USE_DEEPSPEED"]
+        if not is_sagemaker_mp_enabled():
+            device = self.distributed_state.device
+            self.local_rank = self.distributed_state.local_process_index
+        if dist.is_available() and dist.is_initialized() and self.parallel_mode != ParallelMode.DISTRIBUTED:
+            logger.warning(
+                "torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. "
+                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
+            )
+        if is_torch_xla_available():
+            device = self.distributed_state.device
+            self._n_gpu = 0
+        elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
+            # Already set _n_gpu
+            pass
+        elif self.distributed_state.distributed_type == DistributedType.NO:
+            if self.use_mps_device:
+                warnings.warn(
+                    "`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. "
+                    "`mps` device will be used by default if available similar to the way `cuda` device is used."
+                    "Therefore, no action from user is required. "
+                )
+                if device.type != "mps":
+                    raise ValueError(
+                        "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
+                        "or current PyTorch install was not built with MPS enabled."
+                    )
+            if self.use_cpu:
+                device = torch.device("cpu")
+            elif is_torch_mps_available():
+                device = torch.device("mps")
+            elif is_torch_xpu_available():
+                if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
+                    raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
+                device = torch.device("xpu:0")
+                torch.xpu.set_device(device)
+            elif is_torch_mlu_available():
+                device = torch.device("mlu:0")
+                torch.mlu.set_device(device)
+            elif is_torch_musa_available():
+                device = torch.device("musa:0")
+                torch.musa.set_device(device)
+            elif is_torch_npu_available():
+                device = torch.device("npu:0")
+                torch.npu.set_device(device)
+            elif is_torch_hpu_available():
+                device = torch.device("hpu:0")
+                torch.hpu.set_device(device)
+            else:
+                # if n_gpu is > 1 we'll use nn.DataParallel.
+                # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+                # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
+                # trigger an error that a device index is missing. Index 0 takes into account the
+                # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
+                # will use the first GPU in that env, i.e. GPU#1
+                device = torch.device(
+                    "cuda:0" if torch.cuda.is_available() else os.environ.get("ACCELERATE_TORCH_DEVICE", "cpu")
+                )
+                # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
+                # the default value.
+                self._n_gpu = torch.cuda.device_count()
+                if device.type == "cuda":
+                    torch.cuda.set_device(device)
+        return device
+
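+    # Sketch of the single-process fallback order implemented above, assuming
+    # no distributed backend is configured: explicit CPU, then MPS, XPU, MLU,
+    # MUSA, NPU, HPU, and finally `cuda:0` if available (otherwise the
+    # `ACCELERATE_TORCH_DEVICE` environment variable or plain CPU).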
+    @property
+    def device(self) -> "torch.device":
+        """
+        The device used by this process.
+        """
+        requires_backends(self, ["torch"])
+        return self._setup_devices
+
+    @property
+    def n_gpu(self):
+        """
+        The number of GPUs used by this process.
+
+        Note:
+            This will only be greater than one when you have multiple GPUs available but are not using distributed
+            training. For distributed training, it will always be 1.
+        """
+        requires_backends(self, ["torch"])
+        # Make sure `self._n_gpu` is properly setup.
+        if not hasattr(self, "_n_gpu"):
+            _ = self._setup_devices
+        return self._n_gpu
+
+    @property
+    def parallel_mode(self):
+        """
+        The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
+
+        - `ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
+        - `ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses `torch.nn.DataParallel`).
+        - `ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
+          `torch.nn.DistributedDataParallel`).
+        - `ParallelMode.TPU`: several TPU cores.
+        """
+        requires_backends(self, ["torch"])
+        if is_torch_xla_available():
+            return ParallelMode.TPU
+        elif is_sagemaker_mp_enabled():
+            return ParallelMode.SAGEMAKER_MODEL_PARALLEL
+        elif is_sagemaker_dp_enabled():
+            return ParallelMode.SAGEMAKER_DATA_PARALLEL
+        elif (
+            self.distributed_state is not None and self.distributed_state.distributed_type != DistributedType.NO
+        ) or (self.distributed_state is None and self.local_rank != -1):
+            return ParallelMode.DISTRIBUTED
+        elif self.n_gpu > 1:
+            return ParallelMode.NOT_DISTRIBUTED
+        else:
+            return ParallelMode.NOT_PARALLEL
+
+    @property
+    def world_size(self):
+        """
+        The number of processes used in parallel.
+        """
+        requires_backends(self, ["torch"])
+        if self.distributed_state is not None:
+            return self.distributed_state.num_processes
+        elif is_sagemaker_mp_enabled():
+            return smp.dp_size() if not smp.state.cfg.prescaled_batch else smp.rdp_size()
+        return 1
+
+    @property
+    def process_index(self):
+        """
+        The index of the current process used.
+        """
+        requires_backends(self, ["torch"])
+        if self.distributed_state is not None:
+            return self.distributed_state.process_index
+        elif is_sagemaker_mp_enabled():
+            return smp.dp_rank() if not smp.state.cfg.prescaled_batch else smp.rdp_rank()
+        return 0
+
+    @property
+    def local_process_index(self):
+        """
+        The index of the local process used.
+        """
+        requires_backends(self, ["torch"])
+
+        if self.distributed_state is not None:
+            return self.distributed_state.local_process_index
+        elif is_sagemaker_mp_enabled():
+            return smp.local_rank()
+        return 0
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce logs.
+        """
+        if self.log_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
+    @property
+    def should_save(self):
+        """
+        Whether or not the current process should write to disk, e.g., to save models and checkpoints.
+        """
+        if self.save_on_each_node:
+            return self.local_process_index == 0
+        else:
+            if is_sagemaker_mp_enabled():
+                return smp.rank() == 0
+            else:
+                return self.process_index == 0
+
+    def get_process_log_level(self):
+        """
+        Returns the log level to be used depending on whether this process is the main process of node 0, main process
+        of node non-0, or a non-main process.
+
+        For the main process the log level defaults to the logging level set (`logging.WARNING` if you didn't do
+        anything) unless overridden by the `log_level` argument.
+
+        For the replica processes the log level defaults to `logging.WARNING` unless overridden by the
+        `log_level_replica` argument.
+
+        The choice between the main and replica process settings is made according to the return value of `should_log`.
+        """
+
+        # convert to int
+        log_level = trainer_log_levels[self.log_level]
+        log_level_replica = trainer_log_levels[self.log_level_replica]
+
+        log_level_main_node = logging.get_verbosity() if log_level == -1 else log_level
+        log_level_replica_node = logging.get_verbosity() if log_level_replica == -1 else log_level_replica
+        return log_level_main_node if self.should_log else log_level_replica_node
+
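+    # Worked example for `get_process_log_level`, assuming `log_level="info"`,
+    # `log_level_replica="warning"` and the default `log_on_each_node=True`:
+    # the main process of each node logs at INFO, every other process at
+    # WARNING. The default "passive" maps to -1 and falls back to the
+    # library-wide verbosity from `logging.get_verbosity()`.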
+    @property
+    def place_model_on_device(self):
+        """
+        Can be subclassed and overridden for some specific integrations.
+        """
+        return not is_sagemaker_mp_enabled()
+
+    @property
+    def _no_sync_in_gradient_accumulation(self):
+        """
+        Whether or not to use no_sync for the gradients when doing gradient accumulation.
+        """
+        return not (
+            self.deepspeed or is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled() or is_torch_neuroncore_available()
+        )
+
+    @contextlib.contextmanager
+    def main_process_first(self, local=True, desc="work"):
+        """
+        A context manager for a torch distributed environment where one needs to do something on the main process, while
+        blocking replicas, and then release the replicas when it's finished.
+
+        One such use is for `datasets`'s `map` feature which, to be efficient, should be run once on the main process;
+        upon completion it saves a cached version of the results, which is then automatically loaded by the
+        replicas.
+
+        Args:
+            local (`bool`, *optional*, defaults to `True`):
+                If `True`, "first" means the process of rank 0 of each node; if `False`, it means the process of rank 0
+                of node rank 0. In a multi-node environment with a shared filesystem you most likely will want to use
+                `local=False` so that only the main process of the first node does the processing. If, however, the
+                filesystem is not shared, then the main process of each node will need to do the processing, which is
+                the default behavior.
+            desc (`str`, *optional*, defaults to `"work"`):
+                a work description to be used in debug logs
+
+        """
+        if is_torch_available() and self.world_size > 1:
+            main_process_desc = "main local process" if local else "main process"
+            if self.distributed_state is not None:
+                is_main_process = (
+                    self.distributed_state.is_local_main_process if local else self.distributed_state.is_main_process
+                )
+            elif is_sagemaker_mp_enabled():
+                is_main_process = smp.rank() == 0
+
+            try:
+                if not is_main_process:
+                    # tell all replicas to wait
+                    logger.debug(f"{self.process_index}: waiting for the {main_process_desc} to perform {desc}")
+
+                    if is_torch_xla_available():
+                        xm.rendezvous(desc)
+                    else:
+                        dist.barrier()
+                yield
+            finally:
+                if is_main_process:
+                    # the wait is over
+                    logger.debug(f"{self.process_index}: {main_process_desc} completed {desc}, releasing all replicas")
+                    if is_torch_xla_available():
+                        xm.rendezvous(desc)
+                    else:
+                        dist.barrier()
+        else:
+            yield
+
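+    # Minimal usage sketch for `main_process_first`, assuming a hypothetical
+    # `datasets.Dataset` named `raw_ds` and a `tokenize` function:
+    #
+    #     with args.main_process_first(desc="dataset map pre-processing"):
+    #         tokenized = raw_ds.map(tokenize, batched=True)
+    #
+    # Rank 0 runs the map and writes the cache; the replicas wait at the
+    # barrier and then load the cached result instead of recomputing it.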
+    def get_warmup_steps(self, num_training_steps: int):
+        """
+        Get number of steps used for a linear warmup.
+        """
+        warmup_steps = (
+            self.warmup_steps if self.warmup_steps > 0 else math.ceil(num_training_steps * self.warmup_ratio)
+        )
+        return warmup_steps
+
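+    # Worked example for `get_warmup_steps`: with `warmup_steps=0`,
+    # `warmup_ratio=0.06` and `num_training_steps=1000`, this returns
+    # math.ceil(1000 * 0.06) == 60; a non-zero `warmup_steps` always takes
+    # precedence over `warmup_ratio`.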
+    def _dict_dtype_to_str(self, d: dict[str, Any]) -> None:
+        """
+        Checks whether the passed dictionary and its nested dicts have a *dtype* key and, if it's not None,
+        converts the torch.dtype to a string containing just the type. For example, `torch.float32` gets converted
+        into the string *"float32"*, which can then be stored in JSON format.
+        """
+        if d.get("dtype") is not None and not isinstance(d["dtype"], str):
+            d["dtype"] = str(d["dtype"]).split(".")[1]
+        for value in d.values():
+            if isinstance(value, dict):
+                self._dict_dtype_to_str(value)
+
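+    # Illustrative example for `_dict_dtype_to_str`: the dict
+    # `{"dtype": torch.float32, "bnb": {"dtype": torch.bfloat16}}` (the "bnb"
+    # key is just a made-up nested entry) is rewritten in place to
+    # `{"dtype": "float32", "bnb": {"dtype": "bfloat16"}}`, since
+    # str(torch.float32) == "torch.float32" and only the part after the dot is kept.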
+    def to_dict(self):
+        """
+        Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It
+        obfuscates token values by replacing them with a placeholder.
+        """
+        # filter out fields that are defined as field(init=False)
+        d = {field.name: getattr(self, field.name) for field in fields(self) if field.init}
+
+        for k, v in d.items():
+            if isinstance(v, Enum):
+                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
+            if k.endswith("_token"):
+                d[k] = f"<{k.upper()}>"
+            # Handle the accelerator_config if passed
+            if is_accelerate_available() and isinstance(v, AcceleratorConfig):
+                d[k] = v.to_dict()
+            # Handle the quantization_config if passed
+            if k == "model_init_kwargs" and isinstance(v, dict) and "quantization_config" in v:
+                quantization_config = v.get("quantization_config")
+                if quantization_config and not isinstance(quantization_config, dict):
+                    d[k]["quantization_config"] = quantization_config.to_dict()
+            if k == "parallelism_config" and v is not None:
+                d[k] = v.to_json()
+
+        self._dict_dtype_to_str(d)
+
+        return d
+
+    def to_json_string(self):
+        """
+        Serializes this instance to a JSON string.
+        """
+        return json.dumps(self.to_dict(), indent=2)
+
+    def to_sanitized_dict(self) -> dict[str, Any]:
+        """
+        Sanitized serialization to use with TensorBoard's hparams.
+        """
+        d = self.to_dict()
+        d = {**d, **{"train_batch_size": self.train_batch_size, "eval_batch_size": self.eval_batch_size}}
+
+        valid_types = [bool, int, float, str]
+        if is_torch_available():
+            valid_types.append(torch.Tensor)
+
+        return {k: v if type(v) in valid_types else str(v) for k, v in d.items()}
+
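+    # Illustrative note on the serialization helpers above: any field whose
+    # name ends in `_token` (for example `hub_token`) is replaced by a
+    # placeholder such as "<HUB_TOKEN>", and `to_sanitized_dict` additionally
+    # stringifies every value that is not a bool/int/float/str (or tensor) so
+    # TensorBoard's hparams plugin can log it.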
+    # The following methods are there to simplify the instantiation of `TrainingArguments`
+    def set_training(
+        self,
+        learning_rate: float = 5e-5,
+        batch_size: int = 8,
+        weight_decay: float = 0,
+        num_epochs: float = 3,
+        max_steps: int = -1,
+        gradient_accumulation_steps: int = 1,
+        seed: int = 42,
+        gradient_checkpointing: bool = False,
+    ):
+        """
+        A method that regroups all basic arguments linked to the training.
+
+        
+
+        Calling this method will automatically set `self.do_train` to `True`.
+
+        
+
+        Args:
+            learning_rate (`float`, *optional*, defaults to 5e-5):
+                The initial learning rate for the optimizer.
+            batch_size (`int` *optional*, defaults to 8):
+                The batch size per device (GPU/TPU core/CPU...) used for training.
+            weight_decay (`float`, *optional*, defaults to 0):
+                The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the
+                optimizer.
+            num_epochs (`float`, *optional*, defaults to 3.0):
+                Total number of training epochs to perform (if not an integer, will perform the decimal part percents
+                of the last epoch before stopping training).
+            max_steps (`int`, *optional*, defaults to -1):
+                If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+                For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+                `max_steps` is reached.
+            gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+                Number of update steps to accumulate the gradients for, before performing a backward/update pass.
+
+                
+
+                When using gradient accumulation, one step is counted as one step with a backward pass. Therefore,
+                logging, evaluation and saving will be conducted every `gradient_accumulation_steps * xxx_step` training
+                examples.
+
+                
+
+            seed (`int`, *optional*, defaults to 42):
+                Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use
+                the [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized
+                parameters.
+            gradient_checkpointing (`bool`, *optional*, defaults to `False`):
+                If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_training(learning_rate=1e-4, batch_size=32)
+        >>> args.learning_rate
+        0.0001
+        ```
+        """
+        self.do_train = True
+        self.learning_rate = learning_rate
+        self.per_device_train_batch_size = batch_size
+        self.weight_decay = weight_decay
+        self.num_train_epochs = num_epochs
+        self.max_steps = max_steps
+        self.gradient_accumulation_steps = gradient_accumulation_steps
+        self.seed = seed
+        self.gradient_checkpointing = gradient_checkpointing
+        return self
+
+    def set_evaluate(
+        self,
+        strategy: Union[str, IntervalStrategy] = "no",
+        steps: int = 500,
+        batch_size: int = 8,
+        accumulation_steps: Optional[int] = None,
+        delay: Optional[float] = None,
+        loss_only: bool = False,
+        jit_mode: bool = False,
+    ):
+        """
+        A method that regroups all arguments linked to evaluation.
+
+        Args:
+            strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+                The evaluation strategy to adopt during training. Possible values are:
+
+                    - `"no"`: No evaluation is done during training.
+                    - `"steps"`: Evaluation is done (and logged) every `steps`.
+                    - `"epoch"`: Evaluation is done at the end of each epoch.
+
+                Setting a `strategy` different from `"no"` will set `self.do_eval` to `True`.
+            steps (`int`, *optional*, defaults to 500):
+                Number of update steps between two evaluations if `strategy="steps"`.
+            batch_size (`int` *optional*, defaults to 8):
+                The batch size per device (GPU/TPU core/CPU...) used for evaluation.
+            accumulation_steps (`int`, *optional*):
+                Number of prediction steps to accumulate the output tensors for, before moving the results to the CPU.
+                If left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster
+                but requires more memory).
+            delay (`float`, *optional*):
+                Number of epochs or steps to wait for before the first evaluation can be performed, depending on the
+                eval_strategy.
+            loss_only (`bool`, *optional*, defaults to `False`):
+                Ignores all outputs except the loss.
+            jit_mode (`bool`, *optional*):
+                Whether or not to use PyTorch jit trace for inference.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_evaluate(strategy="steps", steps=100)
+        >>> args.eval_steps
+        100
+        ```
+        """
+        self.eval_strategy = IntervalStrategy(strategy)
+        if self.eval_strategy == IntervalStrategy.STEPS and steps == 0:
+            raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+        self.do_eval = self.eval_strategy != IntervalStrategy.NO
+        self.eval_steps = steps
+        self.per_device_eval_batch_size = batch_size
+        self.eval_accumulation_steps = accumulation_steps
+        self.eval_delay = delay
+        self.prediction_loss_only = loss_only
+        self.jit_mode_eval = jit_mode
+        return self
+
+    def set_testing(
+        self,
+        batch_size: int = 8,
+        loss_only: bool = False,
+        jit_mode: bool = False,
+    ):
+        """
+        A method that regroups all basic arguments linked to testing on a held-out dataset.
+
+        
+
+        Calling this method will automatically set `self.do_predict` to `True`.
+
+        
+
+        Args:
+            batch_size (`int` *optional*, defaults to 8):
+                The batch size per device (GPU/TPU core/CPU...) used for testing.
+            loss_only (`bool`, *optional*, defaults to `False`):
+                Ignores all outputs except the loss.
+            jit_mode (`bool`, *optional*):
+                Whether or not to use PyTorch jit trace for inference.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_testing(batch_size=32)
+        >>> args.per_device_eval_batch_size
+        32
+        ```
+        """
+        self.do_predict = True
+        self.per_device_eval_batch_size = batch_size
+        self.prediction_loss_only = loss_only
+        self.jit_mode_eval = jit_mode
+        return self
+
+    def set_save(
+        self,
+        strategy: Union[str, IntervalStrategy] = "steps",
+        steps: int = 500,
+        total_limit: Optional[int] = None,
+        on_each_node: bool = False,
+    ):
+        """
+        A method that regroups all arguments linked to checkpoint saving.
+
+        Args:
+            strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+                The checkpoint save strategy to adopt during training. Possible values are:
+
+                    - `"no"`: No save is done during training.
+                    - `"epoch"`: Save is done at the end of each epoch.
+                    - `"steps"`: Save is done every `save_steps`.
+
+            steps (`int`, *optional*, defaults to 500):
+                Number of update steps between two checkpoint saves if `strategy="steps"`.
+            total_limit (`int`, *optional*):
+                If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+                `output_dir`.
+            on_each_node (`bool`, *optional*, defaults to `False`):
+                When doing multi-node distributed training, whether to save models and checkpoints on each node, or
+                only on the main one.
+
+                This should not be activated when the different nodes use the same storage as the files will be saved
+                with the same names for each node.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_save(strategy="steps", steps=100)
+        >>> args.save_steps
+        100
+        ```
+        """
+        self.save_strategy = SaveStrategy(strategy)
+        if self.save_strategy == SaveStrategy.STEPS and steps == 0:
+            raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+        self.save_steps = steps
+        self.save_total_limit = total_limit
+        self.save_on_each_node = on_each_node
+        return self
+
+    def set_logging(
+        self,
+        strategy: Union[str, IntervalStrategy] = "steps",
+        steps: int = 500,
+        report_to: Union[str, list[str]] = "none",
+        level: str = "passive",
+        first_step: bool = False,
+        nan_inf_filter: bool = False,
+        on_each_node: bool = False,
+        replica_level: str = "passive",
+    ):
+        """
+        A method that regroups all arguments linked to logging.
+
+        Args:
+            strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+                The logging strategy to adopt during training. Possible values are:
+
+                    - `"no"`: No logging is done during training.
+                    - `"epoch"`: Logging is done at the end of each epoch.
+                    - `"steps"`: Logging is done every `logging_steps`.
+
+            steps (`int`, *optional*, defaults to 500):
+                Number of update steps between two logs if `strategy="steps"`.
+            level (`str`, *optional*, defaults to `"passive"`):
+                Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`,
+                `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything
+                and lets the application set the level.
+            report_to (`str` or `list[str]`, *optional*, defaults to `"none"`):
+                The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
+                `"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
+                `"neptune"`, `"swanlab"`, `"tensorboard"`, `"trackio"` and `"wandb"`. Use `"all"` to report to all
+                integrations installed, `"none"` for no integrations.
+            first_step (`bool`, *optional*, defaults to `False`):
+                Whether to log and evaluate the first `global_step` or not.
+            nan_inf_filter (`bool`, *optional*, defaults to `False`):
+                Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is
+                `nan` or `inf` is filtered and the average loss of the current logging window is taken instead.
+
+                
+
+                `nan_inf_filter` only influences the logging of loss values, it does not change the way the
+                gradient is computed or applied to the model.
+
+                
+
+            on_each_node (`bool`, *optional*, defaults to `False`):
+                In multinode distributed training, whether to log using `log_level` once per node, or only on the main
+                node.
+            replica_level (`str`, *optional*, defaults to `"passive"`):
+                Logger log level to use on replicas. Same choices as `log_level`
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_logging(strategy="steps", steps=100)
+        >>> args.logging_steps
+        100
+        ```
+        """
+        self.logging_strategy = IntervalStrategy(strategy)
+        if self.logging_strategy == IntervalStrategy.STEPS and steps == 0:
+            raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.")
+        self.logging_steps = steps
+        self.report_to = report_to
+        self.log_level = level
+        self.logging_first_step = first_step
+        self.logging_nan_inf_filter = nan_inf_filter
+        self.log_on_each_node = on_each_node
+        self.log_level_replica = replica_level
+        return self
+
+    def set_push_to_hub(
+        self,
+        model_id: str,
+        strategy: Union[str, HubStrategy] = "every_save",
+        token: Optional[str] = None,
+        private_repo: Optional[bool] = None,
+        always_push: bool = False,
+        revision: Optional[str] = None,
+    ):
+        """
+        A method that regroups all arguments linked to synchronizing checkpoints with the Hub.
+
+        
+
+        Calling this method will set `self.push_to_hub` to `True`, which means the `output_dir` will become a git
+        repository synced with the Hub repo (determined by `model_id`) and the content will be pushed each time a save
+        is triggered (depending on your `self.save_strategy`). Calling [`~Trainer.save_model`] will also trigger a push.
+
+        
+
+        Args:
+            model_id (`str`):
+                The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in
+                which case the model will be pushed in your namespace. Otherwise it should be the whole repository
+                name, for instance `"user_name/model"`, which allows you to push to an organization you are a member of
+                with `"organization_name/model"`.
+            strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`):
+                Defines the scope of what is pushed to the Hub and when. Possible values are:
+
+                - `"end"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`]) and a
+                draft of a model card when the [`~Trainer.save_model`] method is called.
+                - `"every_save"`: push the model, its configuration, the processing_class e.g. tokenizer (if passed along to the [`Trainer`])
+                  and
+                a draft of a model card each time there is a model save. The pushes are asynchronous to not block
+                training, and in case the save are very frequent, a new push is only attempted if the previous one is
+                finished. A last push is made with the final model at the end of training.
+                - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named
+                last-checkpoint, allowing you to resume training easily with
+                `trainer.train(resume_from_checkpoint="last-checkpoint")`.
+                - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the
+                  output
+                folder (so you will get one checkpoint folder per folder in your final repository)
+
+            token (`str`, *optional*):
+                The token to use to push the model to the Hub. Will default to the token in the cache folder obtained
+                with `hf auth login`.
+            private_repo (`bool`, *optional*):
+                Whether to make the repo private. If `None` (default), the repo will be public unless the
+                organization's default is private. This value is ignored if the repo already exists.
+            always_push (`bool`, *optional*, defaults to `False`):
+                Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not
+                finished.
+            revision (`str`, *optional*):
+                The revision to use when pushing to the Hub. Can be a branch name, a tag, or a commit hash.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_push_to_hub("me/awesome-model")
+        >>> args.hub_model_id
+        'me/awesome-model'
+        ```
+        """
+        self.push_to_hub = True
+        self.hub_model_id = model_id
+        self.hub_strategy = HubStrategy(strategy)
+        self.hub_token = token
+        self.hub_private_repo = private_repo
+        self.hub_always_push = always_push
+        self.hub_revision = revision
+        return self
+
+    def set_optimizer(
+        self,
+        name: Union[str, OptimizerNames] = "adamw_torch",
+        learning_rate: float = 5e-5,
+        weight_decay: float = 0,
+        beta1: float = 0.9,
+        beta2: float = 0.999,
+        epsilon: float = 1e-8,
+        args: Optional[str] = None,
+    ):
+        """
+        A method that regroups all arguments linked to the optimizer and its hyperparameters.
+
+        Args:
+            name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_torch"`):
+                The optimizer to use: `"adamw_torch"`, `"adamw_torch_fused"`, `"adamw_apex_fused"`,
+                `"adamw_anyprecision"` or `"adafactor"`.
+            learning_rate (`float`, *optional*, defaults to 5e-5):
+                The initial learning rate.
+            weight_decay (`float`, *optional*, defaults to 0):
+                The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights.
+            beta1 (`float`, *optional*, defaults to 0.9):
+                The beta1 hyperparameter for the adam optimizer or its variants.
+            beta2 (`float`, *optional*, defaults to 0.999):
+                The beta2 hyperparameter for the adam optimizer or its variants.
+            epsilon (`float`, *optional*, defaults to 1e-8):
+                The epsilon hyperparameter for the adam optimizer or its variants.
+            args (`str`, *optional*):
+                Optional arguments that are supplied to AnyPrecisionAdamW (only useful when
+                `optim="adamw_anyprecision"`).
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_optimizer(name="adamw_torch", beta1=0.8)
+        >>> args.optim
+        'adamw_torch'
+        ```
+        """
+        self.optim = OptimizerNames(name)
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.adam_beta1 = beta1
+        self.adam_beta2 = beta2
+        self.adam_epsilon = epsilon
+        self.optim_args = args
+        return self
+
+    def set_lr_scheduler(
+        self,
+        name: Union[str, SchedulerType] = "linear",
+        num_epochs: float = 3.0,
+        max_steps: int = -1,
+        warmup_ratio: float = 0,
+        warmup_steps: int = 0,
+    ):
+        """
+        A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters.
+
+        Args:
+            name (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`):
+                The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values.
+            num_epochs (`float`, *optional*, defaults to 3.0):
+                Total number of training epochs to perform (if not an integer, will perform the decimal part percents
+                of the last epoch before stopping training).
+            max_steps (`int`, *optional*, defaults to -1):
+                If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+                For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+                `max_steps` is reached.
+            warmup_ratio (`float`, *optional*, defaults to 0.0):
+                Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+            warmup_steps (`int`, *optional*, defaults to 0):
+                Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of
+                `warmup_ratio`.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05)
+        >>> args.warmup_ratio
+        0.05
+        ```
+        """
+        self.lr_scheduler_type = SchedulerType(name)
+        self.num_train_epochs = num_epochs
+        self.max_steps = max_steps
+        self.warmup_ratio = warmup_ratio
+        self.warmup_steps = warmup_steps
+        return self
+
+    def set_dataloader(
+        self,
+        train_batch_size: int = 8,
+        eval_batch_size: int = 8,
+        drop_last: bool = False,
+        num_workers: int = 0,
+        pin_memory: bool = True,
+        persistent_workers: bool = False,
+        prefetch_factor: Optional[int] = None,
+        auto_find_batch_size: bool = False,
+        ignore_data_skip: bool = False,
+        sampler_seed: Optional[int] = None,
+    ):
+        """
+        A method that regroups all arguments linked to the dataloaders creation.
+
+        Args:
+            drop_last (`bool`, *optional*, defaults to `False`):
+                Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch
+                size) or not.
+            num_workers (`int`, *optional*, defaults to 0):
+                Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in
+                the main process.
+            pin_memory (`bool`, *optional*, defaults to `True`):
+                Whether you want to pin memory in data loaders or not. Will default to `True`.
+            persistent_workers (`bool`, *optional*, defaults to `False`):
+                If `True`, the data loader will not shut down the worker processes after a dataset has been consumed
+                once. This keeps the workers' `Dataset` instances alive, which can potentially speed up training,
+                but will increase RAM usage. Will default to `False`.
+            prefetch_factor (`int`, *optional*):
+                Number of batches loaded in advance by each worker.
+                2 means there will be a total of 2 * num_workers batches prefetched across all workers.
+            auto_find_batch_size (`bool`, *optional*, defaults to `False`):
+                Whether to find a batch size that will fit into memory automatically through exponential decay,
+                avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`).
+            ignore_data_skip (`bool`, *optional*, defaults to `False`):
+                When resuming training, whether or not to skip the epochs and batches to get the data loading at the
+                same stage as in the previous training. If set to `True`, the training will begin faster (as that
+                skipping step can take a long time) but will not yield the same results as the interrupted training
+                would have.
+            sampler_seed (`int`, *optional*):
+                Random seed to be used with data samplers. If not set, random generators for data sampling will use the
+                same seed as `self.seed`. This can be used to ensure reproducibility of data sampling, independent of
+                the model seed.
+
+        Example:
+
+        ```py
+        >>> from transformers import TrainingArguments
+
+        >>> args = TrainingArguments("working_dir")
+        >>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
+        >>> args.per_device_train_batch_size
+        16
+        ```
+        """
+        self.per_device_train_batch_size = train_batch_size
+        self.per_device_eval_batch_size = eval_batch_size
+        self.dataloader_drop_last = drop_last
+        self.dataloader_num_workers = num_workers
+        self.dataloader_pin_memory = pin_memory
+        self.dataloader_persistent_workers = persistent_workers
+        self.dataloader_prefetch_factor = prefetch_factor
+        self.auto_find_batch_size = auto_find_batch_size
+        self.ignore_data_skip = ignore_data_skip
+        self.data_seed = sampler_seed
+        return self
+
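+# Minimal sketch of chaining the `set_*` helpers above, which all return
+# `self` (the "working_dir" output directory is just an example value):
+#
+#     args = (
+#         TrainingArguments("working_dir")
+#         .set_training(learning_rate=1e-4, batch_size=32)
+#         .set_evaluate(strategy="steps", steps=200)
+#         .set_save(strategy="steps", steps=200)
+#     )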
+
+class ParallelMode(Enum):
+    NOT_PARALLEL = "not_parallel"
+    NOT_DISTRIBUTED = "not_distributed"
+    DISTRIBUTED = "distributed"
+    SAGEMAKER_MODEL_PARALLEL = "sagemaker_model_parallel"
+    SAGEMAKER_DATA_PARALLEL = "sagemaker_data_parallel"
+    TPU = "tpu"
diff --git a/phivenv/Lib/site-packages/transformers/training_args_seq2seq.py b/phivenv/Lib/site-packages/transformers/training_args_seq2seq.py
new file mode 100644
index 0000000000000000000000000000000000000000..5342b7add3932c542e35247e52920d8fc91ed325
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/training_args_seq2seq.py
@@ -0,0 +1,90 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Optional, Union
+
+from .generation.configuration_utils import GenerationConfig
+from .training_args import TrainingArguments
+from .utils import add_start_docstrings
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class Seq2SeqTrainingArguments(TrainingArguments):
+    """
+    Args:
+        predict_with_generate (`bool`, *optional*, defaults to `False`):
+            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
+        generation_max_length (`int`, *optional*):
+            The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
+            `max_length` value of the model configuration.
+        generation_num_beams (`int`, *optional*):
+            The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default to the
+            `num_beams` value of the model configuration.
+        generation_config (`str` or `Path` or [`~generation.GenerationConfig`], *optional*):
+            Allows to load a [`~generation.GenerationConfig`] from the `from_pretrained` method. This can be either:
+
+            - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
+              huggingface.co.
+            - a path to a *directory* containing a configuration file saved using the
+              [`~GenerationConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
+            - a [`~generation.GenerationConfig`] object.
+    """
+
+    sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
+    predict_with_generate: bool = field(
+        default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
+    )
+    generation_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+                "to the `max_length` value of the model configuration."
+            )
+        },
+    )
+    generation_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+                "to the `num_beams` value of the model configuration."
+            )
+        },
+    )
+    generation_config: Optional[Union[str, Path, GenerationConfig]] = field(
+        default=None,
+        metadata={
+            "help": "Model id, file path or url pointing to a GenerationConfig json file, to use during prediction."
+        },
+    )
+
+    def to_dict(self):
+        """
+        Serializes this instance while replacing `Enum` members by their values and `GenerationConfig` by dictionaries
+        (for JSON serialization support). It obfuscates token values by replacing them with a placeholder.
+        """
+        # filter out fields that are defined as field(init=False)
+        d = super().to_dict()
+        for k, v in d.items():
+            if isinstance(v, GenerationConfig):
+                d[k] = v.to_dict()
+        return d
diff --git a/phivenv/Lib/site-packages/transformers/training_args_tf.py b/phivenv/Lib/site-packages/transformers/training_args_tf.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf20503d63cc72f7ea5e28b2eef23d1338868d54
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/training_args_tf.py
@@ -0,0 +1,299 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass, field
+from typing import Optional
+
+from .training_args import TrainingArguments
+from .utils import cached_property, is_tf_available, logging, requires_backends
+
+
+logger = logging.get_logger(__name__)
+
+if is_tf_available():
+    import tensorflow as tf
+
+    from .modeling_tf_utils import keras
+
+
+@dataclass
+class TFTrainingArguments(TrainingArguments):
+    """
+    TrainingArguments is the subset of the arguments we use in our example scripts **which relate to the training loop
+    itself**.
+
+    Using [`HfArgumentParser`] we can turn this class into
+    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+    command line.
+
+    Parameters:
+        output_dir (`str`):
+            The output directory where the model predictions and checkpoints will be written.
+        overwrite_output_dir (`bool`, *optional*, defaults to `False`):
+            If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir`
+            points to a checkpoint directory.
+        do_train (`bool`, *optional*, defaults to `False`):
+            Whether to run training or not. This argument is not directly used by [`Trainer`], it's intended to be used
+            by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        do_eval (`bool`, *optional*):
+            Whether to run evaluation on the validation set or not. Will be set to `True` if `eval_strategy` is
+            different from `"no"`. This argument is not directly used by [`Trainer`], it's intended to be used by your
+            training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        do_predict (`bool`, *optional*, defaults to `False`):
+            Whether to run predictions on the test set or not. This argument is not directly used by [`Trainer`], it's
+            intended to be used by your training/evaluation scripts instead. See the [example
+            scripts](https://github.com/huggingface/transformers/tree/main/examples) for more details.
+        eval_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`):
+            The evaluation strategy to adopt during training. Possible values are:
+
+                - `"no"`: No evaluation is done during training.
+                - `"steps"`: Evaluation is done (and logged) every `eval_steps`.
+                - `"epoch"`: Evaluation is done at the end of each epoch.
+
+        per_device_train_batch_size (`int`, *optional*, defaults to 8):
+            The batch size per GPU/TPU core/CPU for training.
+        per_device_eval_batch_size (`int`, *optional*, defaults to 8):
+            The batch size per GPU/TPU core/CPU for evaluation.
+        gradient_accumulation_steps (`int`, *optional*, defaults to 1):
+            Number of update steps to accumulate the gradients for, before performing a backward/update pass.
+
+            
+
+            When using gradient accumulation, one step is counted as one step with a backward pass. Therefore, logging,
+            evaluation and saving will be conducted every `gradient_accumulation_steps * xxx_step` training examples.
+
+            
+
+        learning_rate (`float`, *optional*, defaults to 5e-5):
+            The initial learning rate for Adam.
+        weight_decay (`float`, *optional*, defaults to 0):
+            The weight decay to apply (if not zero).
+        adam_beta1 (`float`, *optional*, defaults to 0.9):
+            The beta1 hyperparameter for the Adam optimizer.
+        adam_beta2 (`float`, *optional*, defaults to 0.999):
+            The beta2 hyperparameter for the Adam optimizer.
+        adam_epsilon (`float`, *optional*, defaults to 1e-8):
+            The epsilon hyperparameter for the Adam optimizer.
+        max_grad_norm (`float`, *optional*, defaults to 1.0):
+            Maximum gradient norm (for gradient clipping).
+        num_train_epochs(`float`, *optional*, defaults to 3.0):
+            Total number of training epochs to perform.
+        max_steps (`int`, *optional*, defaults to -1):
+            If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.
+            For a finite dataset, training is reiterated through the dataset (if all data is exhausted) until
+            `max_steps` is reached.
+        warmup_ratio (`float`, *optional*, defaults to 0.0):
+            Ratio of total training steps used for a linear warmup from 0 to `learning_rate`.
+        warmup_steps (`int`, *optional*, defaults to 0):
+            Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.
+        logging_dir (`str`, *optional*):
+            [TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to
+            *runs/**CURRENT_DATETIME_HOSTNAME***.
+        logging_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`):
+            The logging strategy to adopt during training. Possible values are:
+
+                - `"no"`: No logging is done during training.
+                - `"epoch"`: Logging is done at the end of each epoch.
+                - `"steps"`: Logging is done every `logging_steps`.
+
+        logging_first_step (`bool`, *optional*, defaults to `False`):
+            Whether to log and evaluate the first `global_step` or not.
+        logging_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two logs if `logging_strategy="steps"`.
+        save_strategy (`str` or [`~trainer_utils.SaveStrategy`], *optional*, defaults to `"steps"`):
+            The checkpoint save strategy to adopt during training. Possible values are:
+
+                - `"no"`: No save is done during training.
+                - `"epoch"`: Save is done at the end of each epoch.
+                - `"steps"`: Save is done every `save_steps`.
+
+        save_steps (`int`, *optional*, defaults to 500):
+            Number of update steps between two checkpoint saves if `save_strategy="steps"`.
+        save_total_limit (`int`, *optional*):
+            If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in
+            `output_dir`.
+        no_cuda (`bool`, *optional*, defaults to `False`):
+            Whether to avoid using CUDA even when it is available.
+        seed (`int`, *optional*, defaults to 42):
+            Random seed that will be set at the beginning of training.
+        fp16 (`bool`, *optional*, defaults to `False`):
+            Whether to use 16-bit (mixed) precision training (through NVIDIA Apex) instead of 32-bit training.
+        fp16_opt_level (`str`, *optional*, defaults to 'O1'):
+            For `fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on
+            the [Apex documentation](https://nvidia.github.io/apex/amp).
+        local_rank (`int`, *optional*, defaults to -1):
+            During distributed training, the rank of the process.
+        tpu_num_cores (`int`, *optional*):
+            When training on TPU, the number of TPU cores (automatically passed by launcher script).
+        debug (`bool`, *optional*, defaults to `False`):
+            Whether to activate tracing to record computation graphs and profiling information.
+        dataloader_drop_last (`bool`, *optional*, defaults to `False`):
+            Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size)
+            or not.
+        eval_steps (`int`, *optional*, defaults to 1000):
+            Number of update steps between two evaluations.
+        past_index (`int`, *optional*, defaults to -1):
+            Some models like [TransformerXL](../model_doc/transformerxl) or [XLNet](../model_doc/xlnet) can make
+            use of the past hidden states for their predictions. If this argument is set to a positive int, the
+            `Trainer` will use the corresponding output (usually index 2) as the past state and feed it to the model at
+            the next training step under the keyword argument `mems`.
+        tpu_name (`str`, *optional*):
+            The name of the TPU the process is running on.
+        tpu_zone (`str`, *optional*):
+            The zone of the TPU the process is running on. If not specified, we will attempt to automatically detect
+            from metadata.
+        gcp_project (`str`, *optional*):
+            Google Cloud Project name for the Cloud TPU-enabled project. If not specified, we will attempt to
+            automatically detect from metadata.
+        run_name (`str`, *optional*):
+            A descriptor for the run. Notably used for trackio, wandb, mlflow, comet and swanlab logging.
+        xla (`bool`, *optional*):
+            Whether to activate the XLA compilation or not.
+    """
+
+    framework = "tf"
+    tpu_name: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name of TPU"},
+    )
+
+    tpu_zone: Optional[str] = field(
+        default=None,
+        metadata={"help": "Zone of TPU"},
+    )
+
+    gcp_project: Optional[str] = field(
+        default=None,
+        metadata={"help": "Name of Cloud TPU-enabled project"},
+    )
+
+    poly_power: float = field(
+        default=1.0,
+        metadata={"help": "Power for the Polynomial decay LR scheduler."},
+    )
+
+    xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
+
+    @cached_property
+    def _setup_strategy(self) -> "tf.distribute.Strategy":
+        requires_backends(self, ["tf"])
+        logger.info("Tensorflow: setting up strategy")
+
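+        # Strategy selection order: forced CPU when `no_cuda` is set, then TPU (if a cluster
+        # resolver can be found), then GPU(s) (MirroredStrategy for more than one), and
+        # finally a CPU fallback when no GPU is visible.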
+        gpus = tf.config.list_physical_devices("GPU")
+
+        # Set to float16 at first
+        if self.fp16:
+            keras.mixed_precision.set_global_policy("mixed_float16")
+
+        if self.no_cuda:
+            strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
+        else:
+            try:
+                if self.tpu_name:
+                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
+                        self.tpu_name, zone=self.tpu_zone, project=self.gcp_project
+                    )
+                else:
+                    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
+            except ValueError:
+                if self.tpu_name:
+                    raise RuntimeError(f"Couldn't connect to TPU {self.tpu_name}!")
+                else:
+                    tpu = None
+
+            if tpu:
+                # Set to bfloat16 in case of TPU
+                if self.fp16:
+                    keras.mixed_precision.set_global_policy("mixed_bfloat16")
+
+                tf.config.experimental_connect_to_cluster(tpu)
+                tf.tpu.experimental.initialize_tpu_system(tpu)
+
+                strategy = tf.distribute.TPUStrategy(tpu)
+
+            elif len(gpus) == 0:
+                strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
+            elif len(gpus) == 1:
+                strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
+            elif len(gpus) > 1:
+                # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
+                strategy = tf.distribute.MirroredStrategy()
+            else:
+                raise ValueError("Cannot find the proper strategy, please check your environment properties.")
+
+        return strategy
+
+    @property
+    def strategy(self) -> "tf.distribute.Strategy":
+        """
+        The strategy used for distributed training.
+        """
+        requires_backends(self, ["tf"])
+        return self._setup_strategy
+
+    @property
+    def n_replicas(self) -> int:
+        """
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+        """
+        requires_backends(self, ["tf"])
+        return self._setup_strategy.num_replicas_in_sync
+
+    @property
+    def should_log(self):
+        """
+        Whether or not the current process should produce logs.
+        """
+        return False  # TF Logging is handled by Keras not the Trainer
+
+    @property
+    def train_batch_size(self) -> int:
+        """
+        The actual batch size for training (may differ from `per_gpu_train_batch_size` in distributed training).
+        """
+        if self.per_gpu_train_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_train_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_train_batch_size or self.per_device_train_batch_size
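+        # For example, per_device_train_batch_size=8 under a MirroredStrategy spanning 4 GPUs
+        # (n_replicas=4) yields an effective global batch size of 8 * 4 = 32.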
+        return per_device_batch_size * self.n_replicas
+
+    @property
+    def eval_batch_size(self) -> int:
+        """
+        The actual batch size for evaluation (may differ from `per_gpu_eval_batch_size` in distributed training).
+        """
+        if self.per_gpu_eval_batch_size:
+            logger.warning(
+                "Using deprecated `--per_gpu_eval_batch_size` argument which will be removed in a future "
+                "version. Using `--per_device_eval_batch_size` is preferred."
+            )
+        per_device_batch_size = self.per_gpu_eval_batch_size or self.per_device_eval_batch_size
+        return per_device_batch_size * self.n_replicas
+
+    @property
+    def n_gpu(self) -> int:
+        """
+        The number of replicas (CPUs, GPUs or TPU cores) used in this training.
+        """
+        requires_backends(self, ["tf"])
+        warnings.warn(
+            "The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
+            FutureWarning,
+        )
+        return self._setup_strategy.num_replicas_in_sync
diff --git a/phivenv/Lib/site-packages/transformers/video_processing_utils.py b/phivenv/Lib/site-packages/transformers/video_processing_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e1db0266a92af473bc7d0b629bdba72141867d
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/video_processing_utils.py
@@ -0,0 +1,898 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import warnings
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Optional, Union
+
+import numpy as np
+
+from .dynamic_module_utils import custom_object_save
+from .image_processing_utils import (
+    BatchFeature,
+    get_size_dict,
+)
+from .image_processing_utils_fast import BaseImageProcessorFast
+from .image_utils import (
+    ChannelDimension,
+    SizeDict,
+    validate_kwargs,
+)
+from .processing_utils import Unpack, VideosKwargs
+from .utils import (
+    IMAGE_PROCESSOR_NAME,
+    PROCESSOR_NAME,
+    VIDEO_PROCESSOR_NAME,
+    TensorType,
+    add_start_docstrings,
+    copy_func,
+    download_url,
+    is_offline_mode,
+    is_remote_url,
+    is_torch_available,
+    is_torchcodec_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+    logging,
+)
+from .utils.hub import cached_files
+from .utils.import_utils import requires
+from .video_utils import (
+    VideoInput,
+    VideoMetadata,
+    group_videos_by_shape,
+    is_valid_video,
+    load_video,
+    make_batched_metadata,
+    make_batched_videos,
+    reorder_videos,
+    to_channel_dimension_format,
+)
+
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_available():
+    if is_torchvision_v2_available():
+        from torchvision.transforms.v2 import functional as F
+    else:
+        from torchvision.transforms import functional as F
+
+logger = logging.get_logger(__name__)
+
+
+BASE_VIDEO_PROCESSOR_DOCSTRING = r"""
+    Args:
+        do_resize (`bool`, *optional*, defaults to `self.do_resize`):
+            Whether to resize the video's (height, width) dimensions to the specified `size`. Can be overridden by the
+            `do_resize` parameter in the `preprocess` method.
+        size (`dict`, *optional*, defaults to `self.size`):
+            Size of the output video after resizing. Can be overridden by the `size` parameter in the `preprocess`
+            method.
+        size_divisor (`int`, *optional*, defaults to `self.size_divisor`):
+            The size by which to make sure both the height and width can be divided.
+        default_to_square (`bool`, *optional*, defaults to `self.default_to_square`):
+            Whether to default to a square video when resizing, if size is an int.
+        resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
+            Resampling filter to use if resizing the video. Only has an effect if `do_resize` is set to `True`. Can be
+            overridden by the `resample` parameter in the `preprocess` method.
+        do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
+            Whether to center crop the video to the specified `crop_size`. Can be overridden by `do_center_crop` in the
+            `preprocess` method.
+        do_pad (`bool`, *optional*):
+            Whether to pad the video to the `(max_height, max_width)` of the videos in the batch.
+        crop_size (`dict[str, int]`, *optional*, defaults to `self.crop_size`):
+            Size of the output video after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
+            method.
+        do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
+            Whether to rescale the video by the specified scale `rescale_factor`. Can be overridden by the
+            `do_rescale` parameter in the `preprocess` method.
+        rescale_factor (`int` or `float`, *optional*, defaults to `self.rescale_factor`):
+            Scale factor to use if rescaling the video. Only has an effect if `do_rescale` is set to `True`. Can be
+            overridden by the `rescale_factor` parameter in the `preprocess` method.
+        do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
+            Whether to normalize the video. Can be overridden by the `do_normalize` parameter in the `preprocess`
+            method.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
+            Mean to use if normalizing the video. This is a float or list of floats the length of the number of
+            channels in the video. Can be overridden by the `image_mean` parameter in the `preprocess` method.
+        image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
+            Standard deviation to use if normalizing the video. This is a float or list of floats the length of the
+            number of channels in the video. Can be overridden by the `image_std` parameter in the `preprocess` method.
+        do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
+            Whether to convert the video to RGB.
+        video_metadata (`VideoMetadata`, *optional*):
+            Metadata of the video containing information about total duration, fps and total number of frames.
+        do_sample_frames (`bool`, *optional*, defaults to `self.do_sample_frames`):
+            Whether to sample frames from the video before processing or to process the whole video.
+        num_frames (`int`, *optional*, defaults to `self.num_frames`):
+            Maximum number of frames to sample when `do_sample_frames=True`.
+        fps (`int` or `float`, *optional*, defaults to `self.fps`):
+            Target frames to sample per second when `do_sample_frames=True`.
+        return_tensors (`str` or `TensorType`, *optional*):
+            Returns stacked tensors if set to `pt`, otherwise returns a list of tensors.
+        data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
+            The channel dimension format for the output video. Can be one of:
+            - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+            - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+            - Unset: Use the channel dimension format of the input video.
+        input_data_format (`ChannelDimension` or `str`, *optional*):
+            The channel dimension format for the input video. If unset, the channel dimension format is inferred
+            from the input video. Can be one of:
+            - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_channels, height, width) format.
+            - `"channels_last"` or `ChannelDimension.LAST`: video in (height, width, num_channels) format.
+            - `"none"` or `ChannelDimension.NONE`: video in (height, width) format.
+        device (`torch.device`, *optional*):
+            The device to process the videos on. If unset, the device is inferred from the input videos.
+        return_metadata (`bool`, *optional*):
+            Whether to return video metadata or not.
+        """
+
+
+@add_start_docstrings(
+    "Constructs a base VideoProcessor.",
+    BASE_VIDEO_PROCESSOR_DOCSTRING,
+)
+@requires(backends=("vision", "torchvision"))
+class BaseVideoProcessor(BaseImageProcessorFast):
+    _auto_class = None
+
+    resample = None
+    image_mean = None
+    image_std = None
+    size = None
+    size_divisor = None
+    default_to_square = True
+    crop_size = None
+    do_resize = None
+    do_center_crop = None
+    do_pad = None
+    do_rescale = None
+    rescale_factor = 1 / 255
+    do_normalize = None
+    do_convert_rgb = None
+    do_sample_frames = None
+    fps = None
+    num_frames = None
+    video_metadata = None
+    return_metadata = False
+    valid_kwargs = VideosKwargs
+    model_input_names = ["pixel_values_videos"]
+
+    def __init__(self, **kwargs: Unpack[VideosKwargs]) -> None:
+        super().__init__()
+
+        self._processor_class = kwargs.pop("processor_class", None)
+
+        # Additional attributes without default values
+        for key, value in kwargs.items():
+            try:
+                setattr(self, key, value)
+            except AttributeError as err:
+                logger.error(f"Can't set {key} with value {value} for {self}")
+                raise err
+
+        # Prepare size-related keys and turn them into `SizeDict`
+        size = kwargs.pop("size", self.size)
+        self.size = (
+            get_size_dict(size=size, default_to_square=kwargs.pop("default_to_square", self.default_to_square))
+            if size is not None
+            else None
+        )
+        crop_size = kwargs.pop("crop_size", self.crop_size)
+        self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
+
+        # Save valid kwargs in a list for further processing
+        self.model_valid_processing_keys = list(self.valid_kwargs.__annotations__.keys())
+        for key in self.model_valid_processing_keys:
+            if kwargs.get(key) is not None:
+                setattr(self, key, kwargs[key])
+            else:
+                setattr(self, key, deepcopy(getattr(self, key, None)))
+
+    def __call__(self, videos, **kwargs) -> BatchFeature:
+        return self.preprocess(videos, **kwargs)
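+
+    # Illustrative usage via a concrete subclass (this base class is not instantiated directly),
+    # e.g.:
+    #     video_processor = LlavaOnevisionVideoProcessor.from_pretrained("llava-hf/llava-onevision-qwen2-0.5b-ov-hf")
+    #     batch = video_processor(video, return_tensors="pt")
+    #     batch["pixel_values_videos"]  # stacked, preprocessed video tensor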
+
+    def convert_to_rgb(
+        self,
+        video: "torch.Tensor",
+    ) -> VideoInput:
+        """
+        Converts a video to RGB format.
+
+        Args:
+            video (`"torch.Tensor"`):
+                The video to convert.
+
+        Returns:
+            `torch.Tensor`: The converted video.
+        """
+
+        video = F.grayscale_to_rgb(video)
+        if video.shape[-3] == 3 or not (video[..., 3, :, :] < 255).any():
+            return video
+
+        # There is a transparency layer, blend it with a white background.
+        # Calculate the alpha proportion for blending.
+        alpha = video[..., 3, :, :] / 255.0
+        video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
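+        # e.g. a pixel with alpha=128 (~0.502) and channel value 100 blends to
+        # (1 - 0.502) * 255 + 0.502 * 100 ≈ 177 against the white background.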
+        return video
+
+    def sample_frames(
+        self,
+        metadata: VideoMetadata,
+        num_frames: Optional[int] = None,
+        fps: Optional[Union[int, float]] = None,
+        **kwargs,
+    ):
+        """
+        Default sampling function which uniformly samples the desired number of frames between 0 and total number of frames.
+        If `fps` is passed along with metadata, `fps` frames per second are sampled uniformly. Arguments `num_frames`
+        and `fps` are mutually exclusive.
+
+        Args:
+            metadata (`VideoMetadata`):
+                Metadata of the video containing information about total duration, fps and total number of frames.
+            num_frames (`int`, *optional*):
+                Maximum number of frames to sample. Defaults to `self.num_frames`.
+            fps (`int` or `float`, *optional*):
+                Target frames to sample per second. Defaults to `self.fps`.
+
+        Returns:
+            np.ndarray:
+                Indices to sample video frames.
+        """
+        if fps is not None and num_frames is not None:
+            raise ValueError(
+                "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+            )
+
+        num_frames = num_frames if num_frames is not None else self.num_frames
+        fps = fps if fps is not None else self.fps
+        total_num_frames = metadata.total_num_frames
+
+        # If num_frames is not given but fps is, calculate num_frames from fps
+        if num_frames is None and fps is not None:
+            if metadata is None or metadata.fps is None:
+                raise ValueError(
+                    "Asked to sample `fps` frames per second but no video metadata was provided which is required when sampling with `fps`. "
+                    "Please pass in `VideoMetadata` object or use a fixed `num_frames` per input video"
+                )
+            num_frames = int(total_num_frames / metadata.fps * fps)
+
+        if num_frames is not None and num_frames > total_num_frames:
+            raise ValueError(
+                f"Video can't be sampled. The `num_frames={num_frames}` exceeds `total_num_frames={total_num_frames}`. "
+            )
+
+        if num_frames is not None:
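+            # e.g. total_num_frames=300, metadata.fps=30 and fps=1 give num_frames=10,
+            # so the sampled indices are [0, 30, 60, ..., 270]: one frame per second, uniformly spaced.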
+            indices = torch.arange(0, total_num_frames, total_num_frames / num_frames).int()
+        else:
+            indices = torch.arange(0, total_num_frames).int()
+        return indices
+
+    def _decode_and_sample_videos(
+        self,
+        videos: VideoInput,
+        video_metadata: Union[VideoMetadata, dict],
+        do_sample_frames: Optional[bool] = None,
+        sample_indices_fn: Optional[Callable] = None,
+    ) -> list["torch.Tensor"]:
+        """
+        Decode input videos and sample frames if needed.
+        """
+        videos = make_batched_videos(videos)
+        video_metadata = make_batched_metadata(videos, video_metadata=video_metadata)
+
+        # Only sample frames if an array video is passed, otherwise first decode -> then sample
+        if is_valid_video(videos[0]) and do_sample_frames:
+            sampled_videos = []
+            sampled_metadata = []
+            for video, metadata in zip(videos, video_metadata):
+                indices = sample_indices_fn(metadata=metadata)
+                metadata.frames_indices = indices
+                sampled_videos.append(video[indices])
+                sampled_metadata.append(metadata)
+            videos = sampled_videos
+            video_metadata = sampled_metadata
+        elif not is_valid_video(videos[0]):
+            if isinstance(videos[0], list):
+                # Videos sometimes are passed as a list of image URLs, especially through templates
+                videos = [
+                    torch.stack([F.pil_to_tensor(image) for image in images], dim=0)
+                    for images in self.fetch_images(videos)
+                ]
+                if do_sample_frames:
+                    raise ValueError(
+                        "Sampling frames from a list of images is not supported! Set `do_sample_frames=False`."
+                    )
+            else:
+                videos, video_metadata = self.fetch_videos(videos, sample_indices_fn=sample_indices_fn)
+
+        return videos, video_metadata
+
+    def _prepare_input_videos(
+        self,
+        videos: VideoInput,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        device: Optional[str] = None,
+    ) -> list["torch.Tensor"]:
+        """
+        Prepare the input videos for processing.
+        """
+        processed_videos = []
+        for video in videos:
+            # `make_batched_videos` always returns a 4D array per video
+            if isinstance(video, np.ndarray):
+                video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format)
+                # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays
+                video = torch.from_numpy(video).contiguous()
+
+            if device is not None:
+                video = video.to(device)
+
+            processed_videos.append(video)
+        return processed_videos
+
+    @add_start_docstrings(
+        BASE_VIDEO_PROCESSOR_DOCSTRING,
+    )
+    def preprocess(
+        self,
+        videos: VideoInput,
+        **kwargs: Unpack[VideosKwargs],
+    ) -> BatchFeature:
+        validate_kwargs(
+            captured_kwargs=kwargs.keys(),
+            valid_processor_keys=list(self.valid_kwargs.__annotations__.keys()) + ["return_tensors"],
+        )
+        # Set default kwargs from self. This ensures that if a kwarg is not provided
+        # by the user, it gets its default value from the instance, or is set to None.
+        for kwarg_name in self.valid_kwargs.__annotations__:
+            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
+
+        input_data_format = kwargs.pop("input_data_format")
+        do_sample_frames = kwargs.pop("do_sample_frames")
+        device = kwargs.pop("device")
+        video_metadata = kwargs.pop("video_metadata")
+
+        sample_indices_fn = partial(self.sample_frames, **kwargs) if do_sample_frames else None
+        videos, video_metadata = self._decode_and_sample_videos(
+            videos,
+            video_metadata=video_metadata,
+            do_sample_frames=do_sample_frames,
+            sample_indices_fn=sample_indices_fn,
+        )
+        videos = self._prepare_input_videos(videos=videos, input_data_format=input_data_format, device=device)
+
+        kwargs = self._further_process_kwargs(**kwargs)
+        self._validate_preprocess_kwargs(**kwargs)
+
+        # Pop kwargs that are not needed in _preprocess
+        kwargs.pop("data_format")
+        return_metadata = kwargs.pop("return_metadata")
+
+        preprocessed_videos = self._preprocess(videos=videos, **kwargs)
+        if return_metadata:
+            preprocessed_videos["video_metadata"] = video_metadata
+        return preprocessed_videos
+
+    def _preprocess(
+        self,
+        videos: list["torch.Tensor"],
+        do_convert_rgb: bool,
+        do_resize: bool,
+        size: SizeDict,
+        size_divisor: Optional[int],
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: SizeDict,
+        do_rescale: bool,
+        do_pad: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        return_tensors: Optional[Union[str, TensorType]] = None,
+        **kwargs,
+    ) -> BatchFeature:
+        # Group videos by size for batched resizing
+        grouped_videos, grouped_videos_index = group_videos_by_shape(videos)
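+        # e.g. videos of shapes (8, 3, 224, 224), (8, 3, 224, 224) and (8, 3, 336, 336) form two
+        # stacks, so each group of identically shaped videos can be resized in one batched call.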
+        resized_videos_grouped = {}
+        for shape, stacked_videos in grouped_videos.items():
+            if do_convert_rgb:
+                stacked_videos = self.convert_to_rgb(stacked_videos)
+            if do_resize:
+                stacked_videos = self.resize(
+                    stacked_videos, size=size, size_divisor=size_divisor, interpolation=interpolation
+                )
+            resized_videos_grouped[shape] = stacked_videos
+        resized_videos = reorder_videos(resized_videos_grouped, grouped_videos_index)
+
+        # Group videos by size for further processing
+        # Needed in case do_resize is False, or resize returns videos with different sizes
+        grouped_videos, grouped_videos_index = group_videos_by_shape(resized_videos)
+        processed_videos_grouped = {}
+        for shape, stacked_videos in grouped_videos.items():
+            if do_center_crop:
+                stacked_videos = self.center_crop(stacked_videos, crop_size)
+            # Fused rescale and normalize
+            stacked_videos = self.rescale_and_normalize(
+                stacked_videos, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            processed_videos_grouped[shape] = stacked_videos
+
+        processed_videos = reorder_videos(processed_videos_grouped, grouped_videos_index)
+        processed_videos = torch.stack(processed_videos, dim=0) if return_tensors else processed_videos
+
+        return BatchFeature(data={"pixel_values_videos": processed_videos}, tensor_type=return_tensors)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        pretrained_model_name_or_path: Union[str, os.PathLike],
+        cache_dir: Optional[Union[str, os.PathLike]] = None,
+        force_download: bool = False,
+        local_files_only: bool = False,
+        token: Optional[Union[str, bool]] = None,
+        revision: str = "main",
+        **kwargs,
+    ):
+        r"""
+        Instantiate a type of [`~video_processing_utils.VideoProcessorBase`] from a video processor.
+
+        Args:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                This can be either:
+
+                - a string, the *model id* of a pretrained video processor hosted inside a model repo on
+                  huggingface.co.
+                - a path to a *directory* containing a video processor file saved using the
+                  [`~video_processing_utils.VideoProcessorBase.save_pretrained`] method, e.g.,
+                  `./my_model_directory/`.
+                - a path or url to a saved video processor JSON *file*, e.g.,
+                  `./my_model_directory/video_preprocessor_config.json`.
+            cache_dir (`str` or `os.PathLike`, *optional*):
+                Path to a directory in which a downloaded pretrained model video processor should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force to (re-)download the video processor files and override the cached versions if
+                they exist.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible.
+                Will be removed in v5 of Transformers.
+            proxies (`dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+            token (`str` or `bool`, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
+                the token generated when running `hf auth login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
+                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+
+
+                
+
+                To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`.
+
+                
+
+            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+                If `False`, then this function returns just the final video processor object. If `True`, then this
+                functions returns a `Tuple(video_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
+                consisting of the key/value pairs whose keys are not video processor attributes: i.e., the part of
+                `kwargs` which has not been used to update `video_processor` and is otherwise ignored.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+            kwargs (`dict[str, Any]`, *optional*):
+                The values in kwargs of any keys which are video processor attributes will be used to override the
+                loaded values. Behavior concerning key/value pairs whose keys are *not* video processor attributes is
+                controlled by the `return_unused_kwargs` keyword parameter.
+
+        Returns:
+            A video processor of type [`~video_processing_utils.VideoProcessorBase`].
+
+        Examples:
+
+        ```python
+        # We can't directly instantiate the base class *VideoProcessorBase*, so let's show the examples on a
+        # derived class: *LlavaOnevisionVideoProcessor*
+        from transformers import LlavaOnevisionVideoProcessor
+
+        video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+            "llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
+        )  # Download video_processing_config from huggingface.co and cache.
+        video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+            "./test/saved_model/"
+        )  # E.g. video processor (or model) was saved using *save_pretrained('./test/saved_model/')*
+        video_processor = LlavaOnevisionVideoProcessor.from_pretrained("./test/saved_model/video_preprocessor_config.json")
+        video_processor = LlavaOnevisionVideoProcessor.from_pretrained(
+            "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False
+        )
+        assert video_processor.do_normalize is False
+        video_processor, unused_kwargs = LlavaOnevisionVideoProcessor.from_pretrained(
+            "llava-hf/llava-onevision-qwen2-0.5b-ov-hf", do_normalize=False, foo=False, return_unused_kwargs=True
+        )
+        assert video_processor.do_normalize is False
+        assert unused_kwargs == {"foo": False}
+        ```"""
+        kwargs["cache_dir"] = cache_dir
+        kwargs["force_download"] = force_download
+        kwargs["local_files_only"] = local_files_only
+        kwargs["revision"] = revision
+
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        if token is not None:
+            kwargs["token"] = token
+
+        video_processor_dict, kwargs = cls.get_video_processor_dict(pretrained_model_name_or_path, **kwargs)
+
+        return cls.from_dict(video_processor_dict, **kwargs)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+        """
+        Save a video processor object to the directory `save_directory`, so that it can be re-loaded using the
+        [`~video_processing_utils.VideoProcessorBase.from_pretrained`] class method.
+
+        Args:
+            save_directory (`str` or `os.PathLike`):
+                Directory where the video processor JSON file will be saved (will be created if it does not exist).
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
+                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+                namespace).
+            kwargs (`dict[str, Any]`, *optional*):
+                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+        """
+        use_auth_token = kwargs.pop("use_auth_token", None)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if kwargs.get("token") is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            kwargs["token"] = use_auth_token
+
+        if os.path.isfile(save_directory):
+            raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if push_to_hub:
+            commit_message = kwargs.pop("commit_message", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+            repo_id = self._create_repo(repo_id, **kwargs)
+            files_timestamps = self._get_files_timestamps(save_directory)
+
+        # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
+        # loaded from the Hub.
+        if self._auto_class is not None:
+            custom_object_save(self, save_directory, config=self)
+
+        # If we save using the predefined names, we can load using `from_pretrained`
+        output_video_processor_file = os.path.join(save_directory, VIDEO_PROCESSOR_NAME)
+
+        self.to_json_file(output_video_processor_file)
+        logger.info(f"Video processor saved in {output_video_processor_file}")
+
+        if push_to_hub:
+            self._upload_modified_files(
+                save_directory,
+                repo_id,
+                files_timestamps,
+                commit_message=commit_message,
+                token=kwargs.get("token"),
+            )
+
+        return [output_video_processor_file]
+
+    @classmethod
+    def get_video_processor_dict(
+        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+        video processor of type [`~video_processing_utils.VideoProcessorBase`] using `from_dict`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
+                specify the folder name here.
+
+        Returns:
+            `tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the video processor object.
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", None)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", "")
+
+        from_pipeline = kwargs.pop("_from_pipeline", None)
+        from_auto_class = kwargs.pop("_from_auto", False)
+
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
+                FutureWarning,
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
+
+        user_agent = {"file_type": "video processor", "from_auto_class": from_auto_class}
+        if from_pipeline is not None:
+            user_agent["using_pipeline"] = from_pipeline
+
+        if is_offline_mode() and not local_files_only:
+            logger.info("Offline mode: forcing local_files_only=True")
+            local_files_only = True
+
+        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        if os.path.isfile(pretrained_model_name_or_path):
+            resolved_video_processor_file = pretrained_model_name_or_path
+            is_local = True
+        elif is_remote_url(pretrained_model_name_or_path):
+            video_processor_file = pretrained_model_name_or_path
+            resolved_video_processor_file = download_url(pretrained_model_name_or_path)
+        else:
+            video_processor_file = VIDEO_PROCESSOR_NAME
+            try:
+                # Try to load with the new config name first and, if not successful, try with the old file name
+                # NOTE: we will gradually change to saving all processor configs as nested dict in PROCESSOR_NAME
+                resolved_video_processor_files = cached_files(
+                    pretrained_model_name_or_path,
+                    filenames=[VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME, PROCESSOR_NAME],
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
+                resolved_video_processor_file = resolved_video_processor_files[0]
+            except EnvironmentError:
+                # Raise any environment error raised by `cached_files`. It will have a helpful error message adapted to
+                # the original exception.
+                raise
+            except Exception:
+                # For any other exception, we throw a generic error.
+                raise OSError(
+                    f"Can't load video processor for '{pretrained_model_name_or_path}'. If you were trying to load"
+                    " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
+                    f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
+                    f" directory containing a {VIDEO_PROCESSOR_NAME} file"
+                )
+
+        try:
+            # Load video_processor dict
+            with open(resolved_video_processor_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
+            video_processor_dict = json.loads(text)
+            video_processor_dict = video_processor_dict.get("video_processor", video_processor_dict)
+
+        except json.JSONDecodeError:
+            raise OSError(
+                f"It looks like the config file at '{resolved_video_processor_file}' is not a valid JSON file."
+            )
+
+        if is_local:
+            logger.info(f"loading configuration file {resolved_video_processor_file}")
+        else:
+            logger.info(
+                f"loading configuration file {video_processor_file} from cache at {resolved_video_processor_file}"
+            )
+        return video_processor_dict, kwargs
+
+    @classmethod
+    def from_dict(cls, video_processor_dict: dict[str, Any], **kwargs):
+        """
+        Instantiates a type of [`~video_processing_utils.VideoProcessorBase`] from a Python dictionary of parameters.
+
+        Args:
+            video_processor_dict (`dict[str, Any]`):
+                Dictionary that will be used to instantiate the video processor object. Such a dictionary can be
+                retrieved from a pretrained checkpoint by leveraging the
+                [`~video_processing_utils.VideoProcessorBase.to_dict`] method.
+            kwargs (`dict[str, Any]`):
+                Additional parameters from which to initialize the video processor object.
+
+        Returns:
+            [`~video_processing_utils.VideoProcessorBase`]: The video processor object instantiated from those
+            parameters.
+        """
+        video_processor_dict = video_processor_dict.copy()
+        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+        # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
+        # We set `size` here directly to the `video_processor_dict` so that it is converted to the appropriate
+        # dict within the video processor and isn't overwritten if `size` is passed in as a kwarg.
+        if "size" in kwargs and "size" in video_processor_dict:
+            video_processor_dict["size"] = kwargs.pop("size")
+        if "crop_size" in kwargs and "crop_size" in video_processor_dict:
+            video_processor_dict["crop_size"] = kwargs.pop("crop_size")
+
+        video_processor = cls(**video_processor_dict)
+
+        # Update video_processor with kwargs if needed
+        to_remove = []
+        for key, value in kwargs.items():
+            if hasattr(video_processor, key):
+                setattr(video_processor, key, value)
+                to_remove.append(key)
+        for key in to_remove:
+            kwargs.pop(key, None)
+
+        logger.info(f"Video processor {video_processor}")
+        if return_unused_kwargs:
+            return video_processor, kwargs
+        else:
+            return video_processor
+
+    def to_dict(self) -> dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            `dict[str, Any]`: Dictionary of all the attributes that make up this video processor instance.
+        """
+        output = deepcopy(self.__dict__)
+        output.pop("model_valid_processing_keys", None)
+        output.pop("_valid_kwargs_names", None)
+        output["video_processor_type"] = self.__class__.__name__
+
+        return output
+
+    def to_json_string(self) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this video processor instance in JSON format.
+        """
+        dictionary = self.to_dict()
+
+        for key, value in dictionary.items():
+            if isinstance(value, np.ndarray):
+                dictionary[key] = value.tolist()
+
+        # make sure private name "_processor_class" is correctly
+        # saved as "processor_class"
+        _processor_class = dictionary.pop("_processor_class", None)
+        if _processor_class is not None:
+            dictionary["processor_class"] = _processor_class
+
+        return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this video processor instance's parameters will be saved.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string())
+
+    def __repr__(self):
+        return f"{self.__class__.__name__} {self.to_json_string()}"
+
+    @classmethod
+    def from_json_file(cls, json_file: Union[str, os.PathLike]):
+        """
+        Instantiates a video processor of type [`~video_processing_utils.VideoProcessorBase`] from the path to a JSON
+        file of parameters.
+
+        Args:
+            json_file (`str` or `os.PathLike`):
+                Path to the JSON file containing the parameters.
+
+        Returns:
+            A video processor of type [`~video_processing_utils.VideoProcessorBase`]: The video_processor object
+            instantiated from that JSON file.
+        """
+        with open(json_file, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        video_processor_dict = json.loads(text)
+        return cls(**video_processor_dict)
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoVideoProcessor"):
+        """
+        Register this class with a given auto class. This should only be used for custom video processors as the ones
+        in the library are already mapped with `AutoVideoProcessor`.
+
+        
+
+        This API is experimental and may have some slight breaking changes in the next releases.
+
+        
+
+        Args:
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoVideoProcessor"`):
+                The auto class to register this new video processor with.
+        """
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
+    def fetch_videos(self, video_url_or_urls: Union[str, list[str], list[list[str]]], sample_indices_fn=None):
+        """
+        Convert a single URL or a list of URLs into the corresponding `np.array` objects.
+
+        If a single URL is passed, the return value will be a single object. If a list is passed, a list of objects is
+        returned.
+        """
+        backend = "torchcodec"
+        if not is_torchcodec_available():
+            warnings.warn(
+                "`torchcodec` is not installed and cannot be used to decode the video by default. "
+                "Falling back to `torchvision`. Note that `torchvision` decoding is deprecated and will be removed in future versions. "
+            )
+            backend = "torchvision"
+
+        if isinstance(video_url_or_urls, list):
+            return list(zip(*[self.fetch_videos(x, sample_indices_fn=sample_indices_fn) for x in video_url_or_urls]))
+        else:
+            return load_video(video_url_or_urls, backend=backend, sample_indices_fn=sample_indices_fn)
+
+
+BaseVideoProcessor.push_to_hub = copy_func(BaseVideoProcessor.push_to_hub)
+if BaseVideoProcessor.push_to_hub.__doc__ is not None:
+    BaseVideoProcessor.push_to_hub.__doc__ = BaseVideoProcessor.push_to_hub.__doc__.format(
+        object="video processor", object_class="AutoVideoProcessor", object_files="video processor file"
+    )
diff --git a/phivenv/Lib/site-packages/transformers/video_utils.py b/phivenv/Lib/site-packages/transformers/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb69dd41b70cfcb13dea1c68d655f0cd29e1dc48
--- /dev/null
+++ b/phivenv/Lib/site-packages/transformers/video_utils.py
@@ -0,0 +1,878 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import warnings
+from collections.abc import Iterable, Mapping
+from contextlib import redirect_stdout
+from dataclasses import dataclass, fields
+from io import BytesIO
+from typing import Callable, NewType, Optional, Union
+from urllib.parse import urlparse
+
+import numpy as np
+import requests
+
+from .image_transforms import PaddingMode, to_channel_dimension_format
+from .image_utils import ChannelDimension, infer_channel_dimension_format, is_valid_image
+from .utils import (
+    is_av_available,
+    is_cv2_available,
+    is_decord_available,
+    is_numpy_array,
+    is_torch_available,
+    is_torch_tensor,
+    is_torchcodec_available,
+    is_torchvision_available,
+    is_vision_available,
+    is_yt_dlp_available,
+    logging,
+    requires_backends,
+)
+
+
+if is_vision_available():
+    import PIL.Image
+    import PIL.ImageOps
+
+    if is_torchvision_available():
+        from torchvision import io as torchvision_io
+
+if is_torch_available():
+    import torch
+
+
+logger = logging.get_logger(__name__)
+
+URL = NewType("URL", str)
+Path = NewType("Path", str)
+
+VideoInput = Union[
+    list["PIL.Image.Image"],
+    "np.ndarray",
+    "torch.Tensor",
+    list["np.ndarray"],
+    list["torch.Tensor"],
+    list[list["PIL.Image.Image"]],
+    list[list["np.ndarrray"]],
+    list[list["torch.Tensor"]],
+    URL,
+    list[URL],
+    list[list[URL]],
+    Path,
+    list[Path],
+    list[list[Path]],
+]  # noqa
+
+
+@dataclass
+class VideoMetadata(Mapping):
+    total_num_frames: int
+    fps: float = None
+    width: int = None
+    height: int = None
+    duration: float = None
+    video_backend: str = None
+    frames_indices: list[int] = None
+
+    def __iter__(self):
+        return (f.name for f in fields(self))
+
+    def __len__(self):
+        return len(fields(self))
+
+    def __getitem__(self, item):
+        return getattr(self, item)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
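+
+    # Implementing the `Mapping` protocol lets a `VideoMetadata` be used interchangeably with a
+    # plain dict, e.g. `metadata["fps"]` and `dict(metadata)` both work.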
+
+    @property
+    def timestamps(self) -> list[float]:
+        "Timestamps of the sampled frames in seconds."
+        if self.fps is None or self.frames_indices is None:
+            raise ValueError("Cannot infer video `timestamps` when `fps` or `frames_indices` is None.")
+        return [frame_idx / self.fps for frame_idx in self.frames_indices]
+
+    def update(self, dictionary):
+        for key, value in dictionary.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+
+
+def is_valid_video_frame(frame):
+    return isinstance(frame, PIL.Image.Image) or (
+        (is_numpy_array(frame) or is_torch_tensor(frame)) and frame.ndim == 3
+    )
+
+
+def is_valid_video(video):
+    if not isinstance(video, (list, tuple)):
+        return (is_numpy_array(video) or is_torch_tensor(video)) and video.ndim == 4
+    return all(is_valid_video_frame(frame) for frame in video)
+
+
+def valid_videos(videos):
+    # If we have a list of videos, it could be either one video as list of frames or a batch
+    if isinstance(videos, (list, tuple)):
+        for video_or_frame in videos:
+            if not (is_valid_video(video_or_frame) or is_valid_video_frame(video_or_frame)):
+                return False
+    # If not a list, then we have a single 4D video or 5D batched tensor
+    elif not is_valid_video(videos) or videos.ndim == 5:
+        return False
+    return True
+
+
+def is_batched_video(videos):
+    if isinstance(videos, (list, tuple)):
+        return is_valid_video(videos[0])
+    elif (is_numpy_array(videos) or is_torch_tensor(videos)) and videos.ndim == 5:
+        return True
+    return False
+
+
+def is_scaled_video(video: np.ndarray) -> bool:
+    """
+    Checks to see whether the pixel values have already been rescaled to [0, 1].
+    """
+    # It's possible the video has pixel values in [0, 255] but is of floating type
+    return np.min(video) >= 0 and np.max(video) <= 1
+
+
+def convert_pil_frames_to_video(videos: list[VideoInput]) -> list[Union["np.ndarray", "torch.Tensor"]]:
+    """
+    Given a batch of videos, converts each video to a 4D array. If video is already in array type,
+    it is simply returned. We assume that all inputs in the list are in the same format, based on the type of the first element.
+
+    Args:
+        videos (`VideoInput`):
+            Video inputs to turn into a list of videos.
+    """
+
+    if not (isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0])):
+        return videos
+
+    video_converted = []
+    for video in videos:
+        video = [np.array(frame) for frame in video]
+        video = np.stack(video)
+        video_converted.append(video)
+    return video_converted
+
+
+def make_batched_videos(videos) -> list[Union["np.ndarray", "torch.Tensor", "URL", "Path"]]:
+    """
+    Ensure that the input is a list of videos. If the input is a single video, it is converted to a list of length 1.
+    If the input is a batch of videos, it is converted to a list of 4D video arrays. Videos passed as list `PIL.Image`
+    frames are converted to 4D arrays.
+
+    We assume that all inputs in the list are in the same format, based on the type of the first element.
+
+    Args:
+        videos (`VideoInput`):
+            Video inputs to turn into a list of videos.
+    """
+    # Early exit for deeply nested list of image frame paths. We shouldn't flatten them
+    try:
+        if isinstance(videos[0][0][0], str):
+            return [image_paths for sublist in videos for image_paths in sublist]
+    except (IndexError, TypeError):
+        pass
+
+    if isinstance(videos, str) or is_valid_video(videos):
+        return convert_pil_frames_to_video([videos])
+    # only one frame passed, thus we unsqueeze time dim
+    elif is_valid_image(videos):
+        return [np.array(videos)[None, ...]]
+    elif not isinstance(videos, list):
+        raise ValueError(
+            f"Invalid video input. Expected either a list of video frames or an input of 4 or 5 dimensions, but got"
+            f" type {type(videos)}."
+        )
+
+    # Recursively flatten any nested structure
+    flat_videos_list = []
+    for item in videos:
+        if isinstance(item, str) or is_valid_video(item):
+            flat_videos_list.append(item)
+        elif isinstance(item, list):
+            flat_videos_list.extend(make_batched_videos(item))
+
+    flat_videos_list = convert_pil_frames_to_video(flat_videos_list)
+    return flat_videos_list
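+
+# Illustrative sketch (not part of the module's API): `make_batched_videos` always yields
+# a list of 4D arrays, assuming either a single numpy video or a list of PIL frames as input.
+#
+#     single = np.zeros((8, 64, 64, 3), dtype=np.uint8)
+#     assert len(make_batched_videos(single)) == 1                      # wrapped into a batch of one
+#     frames = [PIL.Image.new("RGB", (64, 64)) for _ in range(8)]
+#     assert make_batched_videos([frames])[0].shape == (8, 64, 64, 3)   # PIL frames stacked into one 4D array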
+
+
+def make_batched_metadata(videos: VideoInput, video_metadata: Union[VideoMetadata, dict]):
+    if video_metadata is None:
+        # Create default metadata and fill attributes we can infer from the given video
+        video_metadata = [
+            {
+                "total_num_frames": len(video),
+                "fps": None,
+                "duration": None,
+                "frames_indices": list(range(len(video))),
+                "height": get_video_size(video)[0] if is_valid_video(video) else None,
+                "width": get_video_size(video)[1] if is_valid_video(video) else None,
+            }
+            for video in videos
+        ]
+
+    if isinstance(video_metadata, list):
+        # Flatten if nested list
+        if isinstance(video_metadata[0], list):
+            video_metadata = [
+                VideoMetadata(**metadata) for metadata_list in video_metadata for metadata in metadata_list
+            ]
+        # Simply wrap in VideoMetadata if simple dict
+        elif isinstance(video_metadata[0], dict):
+            video_metadata = [VideoMetadata(**metadata) for metadata in video_metadata]
+    else:
+        # Create a batched list from single object
+        video_metadata = [VideoMetadata(**video_metadata)]
+    return video_metadata
+
+
+def get_video_size(video: np.ndarray, channel_dim: ChannelDimension = None) -> tuple[int, int]:
+    """
+    Returns the (height, width) dimensions of the video.
+
+    Args:
+        video (`np.ndarray`):
+            The video to get the dimensions of.
+        channel_dim (`ChannelDimension`, *optional*):
+            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the video.
+
+    Returns:
+        A tuple of the video's height and width.
+    """
+    if channel_dim is None:
+        channel_dim = infer_channel_dimension_format(video, num_channels=(1, 3, 4))
+
+    if channel_dim == ChannelDimension.FIRST:
+        return video.shape[-2], video.shape[-1]
+    elif channel_dim == ChannelDimension.LAST:
+        return video.shape[-3], video.shape[-2]
+    else:
+        raise ValueError(f"Unsupported data format: {channel_dim}")
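+
+# Illustrative sketch (not part of the module's API): the returned size depends only on
+# where the channel axis sits, assuming dummy numpy videos in THWC and TCHW layouts.
+#
+#     thwc = np.zeros((8, 480, 640, 3), dtype=np.uint8)
+#     tchw = np.zeros((8, 3, 480, 640), dtype=np.uint8)
+#     assert get_video_size(thwc) == (480, 640) == get_video_size(tchw)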
+
+
+def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
+    """
+    Creates a numpy array of frame indices for uniform sampling of `num_frames` frames from `total_num_frames`
+    when loading a video.
+
+    Args:
+        total_num_frames (`int`):
+            Total number of frames that a video has.
+        num_frames (`int`, *optional*):
+            Number of frames to sample uniformly. If not specified, all frames are sampled.
+
+    Returns:
+        np.ndarray: np array of frame indices that will be sampled.
+    """
+    if num_frames is not None:
+        indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
+    else:
+        indices = np.arange(0, total_num_frames).astype(int)
+    return indices
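+
+# Illustrative sketch (not part of the module's API): sampling 4 frames uniformly from a
+# 10-frame clip picks every ~2.5th frame, with the float positions truncated to integers.
+#
+#     assert get_uniform_frame_indices(10, num_frames=4).tolist() == [0, 2, 5, 7]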
+
+
+def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
+    """
+    A default sampling function that replicates the logic used in get_uniform_frame_indices,
+    while optionally handling `fps` if `num_frames` is not provided.
+
+    Args:
+        metadata (`VideoMetadata`):
+            `VideoMetadata` object containing metadata about the video, such as "total_num_frames" or "fps".
+        num_frames (`int`, *optional*):
+            Number of frames to sample uniformly.
+        fps (`int` or `float`, *optional*):
+            Desired frames per second. Used only when `num_frames` is not provided.
+
+    Returns:
+        `np.ndarray`: Array of frame indices to sample.
+    """
+    total_num_frames = metadata.total_num_frames
+    video_fps = metadata.fps
+
+    # If num_frames is not given but fps is, calculate num_frames from fps
+    if num_frames is None and fps is not None:
+        num_frames = int(total_num_frames / video_fps * fps)
+        if num_frames > total_num_frames:
+            raise ValueError(
+                f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
+                f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
+            )
+
+    if num_frames is not None:
+        indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
+    else:
+        indices = np.arange(0, total_num_frames).astype(int)
+    return indices
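+
+# Illustrative sketch (not part of the module's API): requesting `fps=1` from a 2-second,
+# 30 fps clip resolves to 2 uniformly spaced frames. The metadata object is built the same
+# way the backend readers below build it; the concrete numbers are made up.
+#
+#     meta = VideoMetadata(total_num_frames=60, fps=30.0, duration=2.0, video_backend="pyav")
+#     assert default_sample_indices_fn(meta, fps=1).tolist() == [0, 30]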
+
+
+def read_video_opencv(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode a video using the OpenCV backend.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import cv2
+    requires_backends(read_video_opencv, ["cv2"])
+    import cv2
+
+    video = cv2.VideoCapture(video_path)
+    total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
+    video_fps = video.get(cv2.CAP_PROP_FPS)
+    duration = total_num_frames / video_fps if video_fps else 0
+    metadata = VideoMetadata(
+        total_num_frames=int(total_num_frames),
+        fps=float(video_fps),
+        duration=float(duration),
+        video_backend="opencv",
+        height=int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+        width=int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    index = 0
+    frames = []
+    while video.isOpened():
+        success, frame = video.read()
+        if not success:
+            break
+        if index in indices:
+            # OpenCV decodes frames as BGR; convert to RGB before collecting them
+            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        index += 1
+        if index >= total_num_frames:
+            break
+
+    video.release()
+    metadata.frames_indices = indices
+    return np.stack(frames), metadata
+
+
+def read_video_decord(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode a video using the Decord backend.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import from decord
+    requires_backends(read_video_decord, ["decord"])
+    from decord import VideoReader, cpu
+
+    vr = VideoReader(uri=video_path, ctx=cpu(0))  # decord has problems with gpu
+    video_fps = vr.get_avg_fps()
+    total_num_frames = len(vr)
+    duration = total_num_frames / video_fps if video_fps else 0
+    metadata = VideoMetadata(
+        total_num_frames=int(total_num_frames),
+        fps=float(video_fps),
+        duration=float(duration),
+        video_backend="decord",
+    )
+
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+    video = vr.get_batch(indices).asnumpy()
+
+    metadata.update(
+        {
+            "frames_indices": indices,
+            "height": video.shape[1],
+            "width": video.shape[2],
+        }
+    )
+    return video, metadata
+
+
+def read_video_pyav(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with PyAV decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import av
+    requires_backends(read_video_pyav, ["av"])
+    import av
+
+    container = av.open(video_path)
+    total_num_frames = container.streams.video[0].frames
+    video_fps = container.streams.video[0].average_rate  # would `av_guess_frame_rate` be a better choice?
+    duration = total_num_frames / video_fps if video_fps else 0
+    metadata = VideoMetadata(
+        total_num_frames=int(total_num_frames),
+        fps=float(video_fps),
+        duration=float(duration),
+        video_backend="pyav",
+        height=container.streams.video[0].height,
+        width=container.streams.video[0].width,
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    frames = []
+    container.seek(0)
+    end_index = indices[-1]
+    for i, frame in enumerate(container.decode(video=0)):
+        if i > end_index:
+            break
+        if i >= 0 and i in indices:
+            frames.append(frame)
+
+    video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
+    metadata.frames_indices = indices
+    return video, metadata
+
+
+def read_video_torchvision(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with torchvision decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+            - Torch tensor of frames in RGB (shape: [num_frames, 3, height, width], matching the "TCHW" output format).
+            - `VideoMetadata` object.
+    """
+    warnings.warn(
+        "Using `torchvision` for video decoding is deprecated and will be removed in future versions. "
+        "Please use `torchcodec` instead."
+    )
+    video, _, info = torchvision_io.read_video(
+        video_path,
+        start_pts=0.0,
+        end_pts=None,
+        pts_unit="sec",
+        output_format="TCHW",
+    )
+    video_fps = info["video_fps"]
+    total_num_frames = video.size(0)
+    duration = total_num_frames / video_fps if video_fps else 0
+    metadata = VideoMetadata(
+        total_num_frames=int(total_num_frames),
+        fps=float(video_fps),
+        duration=float(duration),
+        video_backend="torchvision",
+    )
+
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = video[indices].contiguous()
+    metadata.update(
+        {
+            "frames_indices": indices,
+            "height": video.shape[1],
+            "width": video.shape[2],
+        }
+    )
+    return video, metadata
+
+
+def read_video_torchcodec(
+    video_path: Union["URL", "Path"],
+    sample_indices_fn: Callable,
+    **kwargs,
+):
+    """
+    Decode the video with torchcodec decoder.
+
+    Args:
+        video_path (`str`):
+            Path to the video file.
+        sample_indices_fn (`Callable`):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed.
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`torch.Tensor`, `VideoMetadata`]: A tuple containing:
+            - Torch tensor of frames in RGB (shape: [num_frames, 3, height, width]).
+            - `VideoMetadata` object.
+    """
+    # Lazy import torchcodec
+    requires_backends(read_video_torchcodec, ["torchcodec"])
+    from torchcodec.decoders import VideoDecoder
+
+    decoder = VideoDecoder(
+        video_path,
+        # Interestingly, `exact` seek mode takes less time than `approximate` when we load the whole video
+        seek_mode="exact",
+        # Let FFmpeg decide the number of threads for efficiency
+        num_ffmpeg_threads=0,
+        device=kwargs.get("device"),
+    )
+    metadata = VideoMetadata(
+        total_num_frames=decoder.metadata.num_frames,
+        fps=decoder.metadata.average_fps,
+        duration=decoder.metadata.duration_seconds,
+        video_backend="torchcodec",
+        height=decoder.metadata.height,
+        width=decoder.metadata.width,
+    )
+    indices = sample_indices_fn(metadata=metadata, **kwargs)
+
+    video = decoder.get_frames_at(indices=indices).data.contiguous()
+    metadata.frames_indices = indices
+    return video, metadata
+
+
+VIDEO_DECODERS = {
+    "decord": read_video_decord,
+    "opencv": read_video_opencv,
+    "pyav": read_video_pyav,
+    "torchvision": read_video_torchvision,
+    "torchcodec": read_video_torchcodec,
+}
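+
+# Illustrative sketch (not part of the module's API): the registry above lets `load_video`
+# dispatch by backend name, but a decoder can also be called directly with a custom sampler,
+# assuming a hypothetical local file "clip.mp4" and the `av` backend installed.
+#
+#     def first_8_frames(metadata, **kwargs):
+#         return np.arange(min(8, metadata.total_num_frames))
+#
+#     frames, meta = VIDEO_DECODERS["pyav"]("clip.mp4", first_8_frames)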
+
+
+def load_video(
+    video: VideoInput,
+    num_frames: Optional[int] = None,
+    fps: Optional[Union[int, float]] = None,
+    backend: str = "pyav",
+    sample_indices_fn: Optional[Callable] = None,
+    **kwargs,
+) -> np.array:
+    """
+    Loads `video` to a numpy array.
+
+    Args:
+        video (`VideoInput`):
+            The video to convert to the numpy array format. Can be a link to a video or a local path.
+        num_frames (`int`, *optional*):
+            Number of frames to sample uniformly. If not passed, the whole video is loaded.
+        fps (`int` or `float`, *optional*):
+            Number of frames to sample per second. Should be passed only when `num_frames=None`.
+            If not specified and `num_frames==None`, all frames are sampled.
+        backend (`str`, *optional*, defaults to `"pyav"`):
+            The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision", "torchcodec"]. Defaults to "pyav".
+        sample_indices_fn (`Callable`, *optional*):
+            A callable function that will return indices at which the video should be sampled. If the video has to be loaded
+            with a different sampling technique than the one provided by the `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
+            If not provided, simple uniform sampling with fps is performed; otherwise `sample_indices_fn` takes priority over the other args.
+            The function receives all args and kwargs passed to `load_video` and should output valid
+            indices at which the video should be sampled.
+
+            Example:
+            def sample_indices_fn(metadata, **kwargs):
+                return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)
+
+    Returns:
+        tuple[`np.array`, `VideoMetadata`]: A tuple containing:
+            - Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
+            - `VideoMetadata` object.
+    """
+
+    # If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
+    if fps is not None and num_frames is not None and sample_indices_fn is None:
+        raise ValueError(
+            "`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
+        )
+
+    # If user didn't pass a sampling function, create one on the fly with default logic
+    if sample_indices_fn is None:
+
+        def sample_indices_fn_func(metadata, **fn_kwargs):
+            return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)
+
+        sample_indices_fn = sample_indices_fn_func
+
+    # Early exit if provided an array or `PIL` frames
+    if not isinstance(video, str):
+        metadata = [None] * len(video)
+        return video, metadata
+
+    if urlparse(video).netloc in ["www.youtube.com", "youtube.com"]:
+        if not is_yt_dlp_available():
+            raise ImportError("To load a video from a YouTube URL you have to install `yt_dlp` first.")
+        # Lazy import from yt_dlp
+        requires_backends(load_video, ["yt_dlp"])
+        from yt_dlp import YoutubeDL
+
+        buffer = BytesIO()
+        with redirect_stdout(buffer), YoutubeDL() as f:
+            f.download([video])
+        bytes_obj = buffer.getvalue()
+        file_obj = BytesIO(bytes_obj)
+    elif video.startswith("http://") or video.startswith("https://"):
+        file_obj = BytesIO(requests.get(video).content)
+    elif os.path.isfile(video):
+        file_obj = video
+    else:
+        raise TypeError("Incorrect format used for video. Should be a URL linking to a video or a local path.")
+
+    # URL inputs are read into an in-memory buffer, which OpenCV's `VideoCapture` cannot decode,
+    # so reject the "opencv" backend for URL links
+    video_is_url = video.startswith("http://") or video.startswith("https://")
+    if video_is_url and backend in ["opencv"]:
+        raise ValueError("If you are trying to load a video from URL, you cannot use 'opencv' as backend")
+
+    if (
+        (not is_decord_available() and backend == "decord")
+        or (not is_av_available() and backend == "pyav")
+        or (not is_cv2_available() and backend == "opencv")
+        or (not is_torchvision_available() and backend == "torchvision")
+        or (not is_torchcodec_available() and backend == "torchcodec")
+    ):
+        raise ImportError(
+            f"You chose backend={backend} for loading the video but the required library is not found in your environment "
+            f"Make sure to install {backend} before loading the video."
+        )
+
+    video_decoder = VIDEO_DECODERS[backend]
+    video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
+    return video, metadata
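+
+# Illustrative sketch (not part of the module's API): sampling a clip at 1 frame per second
+# with the default pyav backend, assuming a hypothetical local file "clip.mp4".
+#
+#     frames, metadata = load_video("clip.mp4", fps=1)
+#     # or with a fully custom sampler, which then takes priority over `num_frames`/`fps`
+#     frames, metadata = load_video(
+#         "clip.mp4",
+#         sample_indices_fn=lambda metadata, **kw: np.arange(0, metadata.total_num_frames, 10),
+#     )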
+
+
+def convert_to_rgb(
+    video: np.array,
+    data_format: Optional[ChannelDimension] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.array:
+    """
+    Convert video to RGB by blending the transparency layer if it's in RGBA format, otherwise simply returns it.
+
+    Args:
+        video (`np.array`):
+            The video to convert.
+        data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the output video. If unset, will use the inferred format from the input.
+        input_data_format (`ChannelDimension`, *optional*):
+            The channel dimension format of the input video. If unset, will use the inferred format from the input.
+    """
+    if not isinstance(video, np.ndarray):
+        raise TypeError(f"Video has to be a numpy array to convert to RGB format, but found {type(video)}")
+
+    # np.array usually comes in ChannelDimension.LAST format, so let's convert it
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(video)
+    video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_channel_dim=input_data_format)
+
+    # 3 channels for RGB already
+    if video.shape[-3] == 3:
+        return video
+
+    # Grayscale video so we repeat it 3 times for each channel
+    if video.shape[-3] == 1:
+        return video.repeat(3, -3)
+
+    if not (video[..., 3, :, :] < 255).any():
+        return video
+
+    # There is a transparency layer, blend it with a white background.
+    # Calculate the alpha proportion for blending.
+    alpha = video[..., 3, :, :] / 255.0
+    video = (1 - alpha[..., None, :, :]) * 255 + alpha[..., None, :, :] * video[..., :3, :, :]
+    return video
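+
+# Illustrative sketch (not part of the module's API): an RGBA video is flattened onto a white
+# background, assuming a fully transparent channels-last uint8 input; the output is channels-first.
+#
+#     rgba = np.zeros((4, 32, 32, 4), dtype=np.uint8)
+#     rgb = convert_to_rgb(rgba, input_data_format=ChannelDimension.LAST)
+#     assert rgb.shape[-3] == 3 and rgb.max() == 255   # blended to white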
+
+
+def pad(
+    video: np.ndarray,
+    padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
+    mode: PaddingMode = PaddingMode.CONSTANT,
+    constant_values: Union[float, Iterable[float]] = 0.0,
+    data_format: Optional[Union[str, ChannelDimension]] = None,
+    input_data_format: Optional[Union[str, ChannelDimension]] = None,
+) -> np.ndarray:
+    """
+    Pads the `video` with the specified (height, width) `padding` and `mode`.
+
+    Args:
+        video (`np.ndarray`):
+            The video to pad.
+        padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
+            Padding to apply to the edges of the height, width axes. Can be one of three formats:
+            - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
+            - `((before, after),)` yields same before and after pad for height and width.
+            - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
+        mode (`PaddingMode`):
+            The padding mode to use. Can be one of:
+                - `"constant"`: pads with a constant value.
+                - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
+                  vector along each axis.
+                - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
+                - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
+        constant_values (`float` or `Iterable[float]`, *optional*):
+            The value to use for the padding if `mode` is `"constant"`.
+        data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the output video. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
+            If unset, will use same as the input video.
+        input_data_format (`str` or `ChannelDimension`, *optional*):
+            The channel dimension format for the input video. Can be one of:
+                - `"channels_first"` or `ChannelDimension.FIRST`: video in (num_frames, num_channels, height, width) format.
+                - `"channels_last"` or `ChannelDimension.LAST`: video in (num_frames, height, width, num_channels) format.
+            If unset, will use the inferred format of the input video.
+
+    Returns:
+        `np.ndarray`: The padded video.
+
+    """
+    if input_data_format is None:
+        input_data_format = infer_channel_dimension_format(video)
+
+    def _expand_for_data_format(values):
+        """
+        Convert values to be in the format expected by np.pad based on the data format.
+        """
+        if isinstance(values, (int, float)):
+            values = ((values, values), (values, values))
+        elif isinstance(values, tuple) and len(values) == 1:
+            values = ((values[0], values[0]), (values[0], values[0]))
+        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
+            values = (values, values)
+        elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
+            pass
+        else:
+            raise ValueError(f"Unsupported format: {values}")
+
+        # add no-op (0, 0) pads for the frame and channel dimensions
+        values = (
+            ((0, 0), (0, 0), *values) if input_data_format == ChannelDimension.FIRST else ((0, 0), *values, (0, 0))
+        )
+
+        # Add a no-op (before, after) pair for the batch dimension if there is one
+        values = ((0, 0), *values) if video.ndim == 5 else values
+        return values
+
+    padding_map = {
+        PaddingMode.CONSTANT: "constant",
+        PaddingMode.REFLECT: "reflect",
+        PaddingMode.REPLICATE: "edge",  # numpy's name for replication padding
+        PaddingMode.SYMMETRIC: "symmetric",
+    }
+    padding = _expand_for_data_format(padding)
+
+    pad_kwargs = {}
+    if mode not in padding_map:
+        raise ValueError(f"Invalid padding mode: {mode}")
+    elif mode == PaddingMode.CONSTANT:
+        pad_kwargs["constant_values"] = _expand_for_data_format(constant_values)
+
+    video = np.pad(video, padding, mode=padding_map[mode], **pad_kwargs)
+    video = to_channel_dimension_format(video, data_format, input_data_format) if data_format is not None else video
+    return video
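+
+# Illustrative sketch (not part of the module's API): padding a dummy (T, C, H, W) video by
+# 2 pixels on every side of height and width grows both spatial dims by 4.
+#
+#     video = np.zeros((8, 3, 32, 32), dtype=np.uint8)
+#     padded = pad(video, padding=2)
+#     assert padded.shape == (8, 3, 36, 36)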
+
+
+def group_videos_by_shape(
+    videos: list["torch.Tensor"],
+) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
+    """
+    Groups videos by shape.
+    Returns a dictionary with the shape as key and a list of videos with that shape as value,
+    and a dictionary with the index of the video in the original list as key and the shape and index in the grouped list as value.
+    """
+    grouped_videos = {}
+    grouped_videos_index = {}
+    for i, video in enumerate(videos):
+        shape = video.shape[-2:]
+        num_frames = video.shape[-4]  # videos are in (T, C, H, W) format
+        shape = (num_frames, *shape)
+        if shape not in grouped_videos:
+            grouped_videos[shape] = []
+        grouped_videos[shape].append(video)
+        grouped_videos_index[i] = (shape, len(grouped_videos[shape]) - 1)
+    # stack videos with the same size and number of frames
+    grouped_videos = {shape: torch.stack(videos, dim=0) for shape, videos in grouped_videos.items()}
+    return grouped_videos, grouped_videos_index
+
+
+def reorder_videos(
+    processed_videos: dict[tuple[int, int], "torch.Tensor"], grouped_videos_index: dict[int, tuple[int, int]]
+) -> list["torch.Tensor"]:
+    """
+    Reconstructs a list of videos in the original order.
+    """
+    return [
+        processed_videos[grouped_videos_index[i][0]][grouped_videos_index[i][1]]
+        for i in range(len(grouped_videos_index))
+    ]
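+
+# Illustrative sketch (not part of the module's API): grouping lets differently sized videos be
+# processed as a few stacked batches and then restored to their original order, assuming dummy
+# torch tensors in (T, C, H, W) layout.
+#
+#     vids = [torch.zeros(4, 3, 32, 32), torch.zeros(4, 3, 16, 16), torch.zeros(4, 3, 32, 32)]
+#     grouped, index = group_videos_by_shape(vids)
+#     processed = {shape: batch * 2 for shape, batch in grouped.items()}   # any batched op
+#     restored = reorder_videos(processed, index)
+#     assert [v.shape for v in restored] == [v.shape for v in vids]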